mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-12-25 02:10:46 +00:00
add metadata interceptors, auth client interceptors
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
parent
c0e79d8834
commit
97ced73270
@ -45,7 +45,7 @@ func New(opts ...Option) (Client, error) {
|
||||
}
|
||||
if c.opts.addr == "" {
|
||||
c.addr = fmt.Sprintf("%s:///%s", c.opts.registry.String(), c.opts.name)
|
||||
} else if strings.HasPrefix(c.opts.addr, "tcp://"){
|
||||
} else if strings.HasPrefix(c.opts.addr, "tcp://") {
|
||||
c.addr = strings.Replace(c.opts.addr, "tcp://", "", 1)
|
||||
} else {
|
||||
c.addr = c.opts.addr
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
"go.linka.cloud/grpc/config"
|
||||
)
|
||||
|
||||
func newConfigFile(t *testing.T) (config.Config, string, func()){
|
||||
func newConfigFile(t *testing.T) (config.Config, string, func()) {
|
||||
path := filepath.Join(os.TempDir(), "config.yaml")
|
||||
if err := ioutil.WriteFile(path, []byte("ok"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -65,7 +65,7 @@ func TestWatch(t *testing.T) {
|
||||
}
|
||||
// when overwriting the file and waiting for the custom change notification handler to be triggered
|
||||
err := ioutil.WriteFile(cpath, []byte("foo: baz\n"), 0o640)
|
||||
b := <- updates
|
||||
b := <-updates
|
||||
// then the config value should have changed
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, []byte("foo: baz\n"), b)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@ -31,15 +32,25 @@ var (
|
||||
_ = (*url.URL)(nil)
|
||||
_ = (*mail.Address)(nil)
|
||||
_ = anypb.Any{}
|
||||
_ = sort.Sort
|
||||
)
|
||||
|
||||
// Validate checks the field values on HelloRequest with the rules defined in
|
||||
// the proto definition for this message. If any rules are violated, an error
|
||||
// is returned. When asked to return all errors, validation continues after
|
||||
// first violation, and the result is a list of violation errors wrapped in
|
||||
// HelloRequestMultiError, or nil if none found. Otherwise, only the first
|
||||
// error is returned, if any.
|
||||
func (m *HelloRequest) Validate(all bool) error {
|
||||
// the proto definition for this message. If any rules are violated, the first
|
||||
// error encountered is returned, or nil if there are no violations.
|
||||
func (m *HelloRequest) Validate() error {
|
||||
return m.validate(false)
|
||||
}
|
||||
|
||||
// ValidateAll checks the field values on HelloRequest with the rules defined
|
||||
// in the proto definition for this message. If any rules are violated, the
|
||||
// result is a list of violation errors wrapped in HelloRequestMultiError, or
|
||||
// nil if none found.
|
||||
func (m *HelloRequest) ValidateAll() error {
|
||||
return m.validate(true)
|
||||
}
|
||||
|
||||
func (m *HelloRequest) validate(all bool) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
@ -64,8 +75,7 @@ func (m *HelloRequest) Validate(all bool) error {
|
||||
}
|
||||
|
||||
// HelloRequestMultiError is an error wrapping multiple validation errors
|
||||
// returned by HelloRequest.Validate(true) if the designated constraints
|
||||
// aren't met.
|
||||
// returned by HelloRequest.ValidateAll() if the designated constraints aren't met.
|
||||
type HelloRequestMultiError []error
|
||||
|
||||
// Error returns a concatenation of all the error messages it wraps.
|
||||
@ -135,12 +145,21 @@ var _ interface {
|
||||
} = HelloRequestValidationError{}
|
||||
|
||||
// Validate checks the field values on HelloReply with the rules defined in the
|
||||
// proto definition for this message. If any rules are violated, an error is
|
||||
// returned. When asked to return all errors, validation continues after first
|
||||
// violation, and the result is a list of violation errors wrapped in
|
||||
// HelloReplyMultiError, or nil if none found. Otherwise, only the first error
|
||||
// is returned, if any.
|
||||
func (m *HelloReply) Validate(all bool) error {
|
||||
// proto definition for this message. If any rules are violated, the first
|
||||
// error encountered is returned, or nil if there are no violations.
|
||||
func (m *HelloReply) Validate() error {
|
||||
return m.validate(false)
|
||||
}
|
||||
|
||||
// ValidateAll checks the field values on HelloReply with the rules defined in
|
||||
// the proto definition for this message. If any rules are violated, the
|
||||
// result is a list of violation errors wrapped in HelloReplyMultiError, or
|
||||
// nil if none found.
|
||||
func (m *HelloReply) ValidateAll() error {
|
||||
return m.validate(true)
|
||||
}
|
||||
|
||||
func (m *HelloReply) validate(all bool) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
@ -156,7 +175,7 @@ func (m *HelloReply) Validate(all bool) error {
|
||||
}
|
||||
|
||||
// HelloReplyMultiError is an error wrapping multiple validation errors
|
||||
// returned by HelloReply.Validate(true) if the designated constraints aren't met.
|
||||
// returned by HelloReply.ValidateAll() if the designated constraints aren't met.
|
||||
type HelloReplyMultiError []error
|
||||
|
||||
// Error returns a concatenation of all the error messages it wraps.
|
||||
@ -227,11 +246,20 @@ var _ interface {
|
||||
|
||||
// Validate checks the field values on HelloStreamRequest with the rules
|
||||
// defined in the proto definition for this message. If any rules are
|
||||
// violated, an error is returned. When asked to return all errors, validation
|
||||
// continues after first violation, and the result is a list of violation
|
||||
// errors wrapped in HelloStreamRequestMultiError, or nil if none found.
|
||||
// Otherwise, only the first error is returned, if any.
|
||||
func (m *HelloStreamRequest) Validate(all bool) error {
|
||||
// violated, the first error encountered is returned, or nil if there are no violations.
|
||||
func (m *HelloStreamRequest) Validate() error {
|
||||
return m.validate(false)
|
||||
}
|
||||
|
||||
// ValidateAll checks the field values on HelloStreamRequest with the rules
|
||||
// defined in the proto definition for this message. If any rules are
|
||||
// violated, the result is a list of violation errors wrapped in
|
||||
// HelloStreamRequestMultiError, or nil if none found.
|
||||
func (m *HelloStreamRequest) ValidateAll() error {
|
||||
return m.validate(true)
|
||||
}
|
||||
|
||||
func (m *HelloStreamRequest) validate(all bool) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
@ -267,7 +295,7 @@ func (m *HelloStreamRequest) Validate(all bool) error {
|
||||
}
|
||||
|
||||
// HelloStreamRequestMultiError is an error wrapping multiple validation errors
|
||||
// returned by HelloStreamRequest.Validate(true) if the designated constraints
|
||||
// returned by HelloStreamRequest.ValidateAll() if the designated constraints
|
||||
// aren't met.
|
||||
type HelloStreamRequestMultiError []error
|
||||
|
||||
|
@ -8,13 +8,15 @@ import (
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
|
||||
"go.linka.cloud/grpc/errors"
|
||||
"go.linka.cloud/grpc/interceptors"
|
||||
"go.linka.cloud/grpc/interceptors/metadata"
|
||||
)
|
||||
|
||||
func BasicAuth(user, password string) string {
|
||||
return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
|
||||
}
|
||||
|
||||
type BasicValidator func(ctx context.Context, user, password string) (context.Context,error)
|
||||
type BasicValidator func(ctx context.Context, user, password string) (context.Context, error)
|
||||
|
||||
func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
|
||||
return func(ctx context.Context) (context.Context, error) {
|
||||
@ -34,3 +36,7 @@ func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
|
||||
return v(ctx, cs[:s], cs[s+1:])
|
||||
}
|
||||
}
|
||||
|
||||
func NewBasicAuthClientIntereptors(user, password string) interceptors.ClientInterceptors {
|
||||
return metadata.NewInterceptors("authorization", BasicAuth(user, password))
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"strings"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
@ -40,8 +41,8 @@ func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors {
|
||||
return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)}
|
||||
}
|
||||
|
||||
type interceptor struct{
|
||||
o options
|
||||
type interceptor struct {
|
||||
o options
|
||||
authFn grpc_auth.AuthFunc
|
||||
}
|
||||
|
||||
@ -92,3 +93,7 @@ func (i *interceptor) isNotProtected(endpoint string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Equals(s1, s2 string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(s1), []byte(s2)) == 1
|
||||
}
|
||||
|
@ -102,11 +102,11 @@ func TestChainedAuthFuncs(t *testing.T) {
|
||||
code: codes.PermissionDenied,
|
||||
},
|
||||
{
|
||||
name: "internal error",
|
||||
auth: "bearer internal",
|
||||
name: "internal error",
|
||||
auth: "bearer internal",
|
||||
internalError: true,
|
||||
err: true,
|
||||
code: codes.PermissionDenied,
|
||||
err: true,
|
||||
code: codes.PermissionDenied,
|
||||
},
|
||||
{
|
||||
name: "multiple auth: first basic valid",
|
||||
|
@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
|
||||
"go.linka.cloud/grpc/interceptors"
|
||||
"go.linka.cloud/grpc/interceptors/metadata"
|
||||
)
|
||||
|
||||
type TokenValidator func(ctx context.Context, token string) (context.Context, error)
|
||||
@ -17,3 +20,7 @@ func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc {
|
||||
return v(ctx, a)
|
||||
}
|
||||
}
|
||||
|
||||
func NewBearerClientInterceptors(token string) interceptors.ClientInterceptors {
|
||||
return metadata.NewInterceptors("authorization", "Bearer "+token)
|
||||
}
|
||||
|
@ -1,19 +1,25 @@
|
||||
package service
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"go.linka.cloud/grpc/interceptors"
|
||||
)
|
||||
|
||||
func NewInterceptors(pairs ...string) interceptors.Interceptors {
|
||||
return mdInterceptors{pairs: pairs}
|
||||
}
|
||||
|
||||
type mdInterceptors struct {
|
||||
k, v string
|
||||
pairs []string
|
||||
}
|
||||
|
||||
func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil {
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handler(ctx, req)
|
||||
@ -22,7 +28,7 @@ func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
|
||||
func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.k, i.v)); err != nil {
|
||||
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.pairs...)); err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(srv, ss)
|
||||
@ -31,7 +37,7 @@ func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
|
||||
func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil {
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
|
||||
return err
|
||||
}
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
@ -40,7 +46,7 @@ func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||
|
||||
func (i mdInterceptors) StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil {
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
@ -30,6 +30,3 @@ func (i *recovery) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||
func (i *recovery) StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ package sentry
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/johnbellone/grpc-middleware-sentry"
|
||||
grpc_sentry "github.com/johnbellone/grpc-middleware-sentry"
|
||||
|
||||
"go.linka.cloud/grpc/interceptors"
|
||||
)
|
||||
@ -31,6 +31,3 @@ func (i *interceptor) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||
func (i *interceptor) StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||
return grpc_sentry.StreamClientInterceptor(i.opts...)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -28,6 +28,7 @@ import (
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
greflect "google.golang.org/grpc/reflection"
|
||||
|
||||
"go.linka.cloud/grpc/interceptors/metadata"
|
||||
"go.linka.cloud/grpc/registry"
|
||||
"go.linka.cloud/grpc/registry/noop"
|
||||
)
|
||||
@ -74,32 +75,18 @@ func newService(opts ...Option) (*service, error) {
|
||||
f(s.opts)
|
||||
}
|
||||
if s.opts.name != "" {
|
||||
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{
|
||||
k: "grpc-service-name", v: s.opts.name,
|
||||
}.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
|
||||
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{
|
||||
k: "grpc-service-name", v: s.opts.name,
|
||||
}.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
|
||||
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{
|
||||
k: "grpc-service-name", v: s.opts.name,
|
||||
}.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
|
||||
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{
|
||||
k: "grpc-service-name", v: s.opts.name,
|
||||
}.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
|
||||
i := metadata.NewInterceptors("grpc-service-name", s.opts.name)
|
||||
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
|
||||
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
|
||||
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
|
||||
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
|
||||
}
|
||||
if s.opts.version != "" {
|
||||
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{
|
||||
k: "grpc-service-version", v: s.opts.version,
|
||||
}.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
|
||||
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{
|
||||
k: "grpc-service-version", v: s.opts.version,
|
||||
}.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
|
||||
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{
|
||||
k: "grpc-service-version", v: s.opts.version,
|
||||
}.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
|
||||
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{
|
||||
k: "grpc-service-version", v: s.opts.version,
|
||||
}.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
|
||||
i := metadata.NewInterceptors("grpc-service-version", s.opts.version)
|
||||
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
|
||||
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
|
||||
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
|
||||
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
|
||||
}
|
||||
if s.opts.mux == nil {
|
||||
s.opts.mux = http.NewServeMux()
|
||||
|
Loading…
Reference in New Issue
Block a user