From 97ced73270555200b17b89fd5b063bcb19e6764d Mon Sep 17 00:00:00 2001 From: Adphi Date: Fri, 11 Mar 2022 12:33:18 +0100 Subject: [PATCH] add metadata interceptors, auth client interceptors Signed-off-by: Adphi --- client/client.go | 2 +- config/file/config_test.go | 4 +- example/example.pb.validate.go | 70 +++++++++++++------ interceptors/auth/basic.go | 8 ++- interceptors/auth/interceptors.go | 9 ++- interceptors/auth/interceptors_test.go | 8 +-- interceptors/auth/token.go | 7 ++ .../metadata/metadata.go | 18 +++-- interceptors/recovery/interceptors.go | 3 - interceptors/sentry/interceptors.go | 5 +- service/service.go | 35 +++------- 11 files changed, 101 insertions(+), 68 deletions(-) rename service/internal_interceptors.go => interceptors/metadata/metadata.go (72%) diff --git a/client/client.go b/client/client.go index 866e8c6..8051295 100644 --- a/client/client.go +++ b/client/client.go @@ -45,7 +45,7 @@ func New(opts ...Option) (Client, error) { } if c.opts.addr == "" { c.addr = fmt.Sprintf("%s:///%s", c.opts.registry.String(), c.opts.name) - } else if strings.HasPrefix(c.opts.addr, "tcp://"){ + } else if strings.HasPrefix(c.opts.addr, "tcp://") { c.addr = strings.Replace(c.opts.addr, "tcp://", "", 1) } else { c.addr = c.opts.addr diff --git a/config/file/config_test.go b/config/file/config_test.go index 0836633..fa331e2 100644 --- a/config/file/config_test.go +++ b/config/file/config_test.go @@ -18,7 +18,7 @@ import ( "go.linka.cloud/grpc/config" ) -func newConfigFile(t *testing.T) (config.Config, string, func()){ +func newConfigFile(t *testing.T) (config.Config, string, func()) { path := filepath.Join(os.TempDir(), "config.yaml") if err := ioutil.WriteFile(path, []byte("ok"), os.ModePerm); err != nil { t.Fatal(err) @@ -65,7 +65,7 @@ func TestWatch(t *testing.T) { } // when overwriting the file and waiting for the custom change notification handler to be triggered err := ioutil.WriteFile(cpath, []byte("foo: baz\n"), 0o640) - b := <- updates + b := <-updates // then the config value should have changed require.Nil(t, err) assert.Equal(t, []byte("foo: baz\n"), b) diff --git a/example/example.pb.validate.go b/example/example.pb.validate.go index 1f6a5ef..0fba72c 100644 --- a/example/example.pb.validate.go +++ b/example/example.pb.validate.go @@ -11,6 +11,7 @@ import ( "net/mail" "net/url" "regexp" + "sort" "strings" "time" "unicode/utf8" @@ -31,15 +32,25 @@ var ( _ = (*url.URL)(nil) _ = (*mail.Address)(nil) _ = anypb.Any{} + _ = sort.Sort ) // Validate checks the field values on HelloRequest with the rules defined in -// the proto definition for this message. If any rules are violated, an error -// is returned. When asked to return all errors, validation continues after -// first violation, and the result is a list of violation errors wrapped in -// HelloRequestMultiError, or nil if none found. Otherwise, only the first -// error is returned, if any. -func (m *HelloRequest) Validate(all bool) error { +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *HelloRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on HelloRequest with the rules defined +// in the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in HelloRequestMultiError, or +// nil if none found. +func (m *HelloRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *HelloRequest) validate(all bool) error { if m == nil { return nil } @@ -64,8 +75,7 @@ func (m *HelloRequest) Validate(all bool) error { } // HelloRequestMultiError is an error wrapping multiple validation errors -// returned by HelloRequest.Validate(true) if the designated constraints -// aren't met. +// returned by HelloRequest.ValidateAll() if the designated constraints aren't met. type HelloRequestMultiError []error // Error returns a concatenation of all the error messages it wraps. @@ -135,12 +145,21 @@ var _ interface { } = HelloRequestValidationError{} // Validate checks the field values on HelloReply with the rules defined in the -// proto definition for this message. If any rules are violated, an error is -// returned. When asked to return all errors, validation continues after first -// violation, and the result is a list of violation errors wrapped in -// HelloReplyMultiError, or nil if none found. Otherwise, only the first error -// is returned, if any. -func (m *HelloReply) Validate(all bool) error { +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *HelloReply) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on HelloReply with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in HelloReplyMultiError, or +// nil if none found. +func (m *HelloReply) ValidateAll() error { + return m.validate(true) +} + +func (m *HelloReply) validate(all bool) error { if m == nil { return nil } @@ -156,7 +175,7 @@ func (m *HelloReply) Validate(all bool) error { } // HelloReplyMultiError is an error wrapping multiple validation errors -// returned by HelloReply.Validate(true) if the designated constraints aren't met. +// returned by HelloReply.ValidateAll() if the designated constraints aren't met. type HelloReplyMultiError []error // Error returns a concatenation of all the error messages it wraps. @@ -227,11 +246,20 @@ var _ interface { // Validate checks the field values on HelloStreamRequest with the rules // defined in the proto definition for this message. If any rules are -// violated, an error is returned. When asked to return all errors, validation -// continues after first violation, and the result is a list of violation -// errors wrapped in HelloStreamRequestMultiError, or nil if none found. -// Otherwise, only the first error is returned, if any. -func (m *HelloStreamRequest) Validate(all bool) error { +// violated, the first error encountered is returned, or nil if there are no violations. +func (m *HelloStreamRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on HelloStreamRequest with the rules +// defined in the proto definition for this message. If any rules are +// violated, the result is a list of violation errors wrapped in +// HelloStreamRequestMultiError, or nil if none found. +func (m *HelloStreamRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *HelloStreamRequest) validate(all bool) error { if m == nil { return nil } @@ -267,7 +295,7 @@ func (m *HelloStreamRequest) Validate(all bool) error { } // HelloStreamRequestMultiError is an error wrapping multiple validation errors -// returned by HelloStreamRequest.Validate(true) if the designated constraints +// returned by HelloStreamRequest.ValidateAll() if the designated constraints // aren't met. type HelloStreamRequestMultiError []error diff --git a/interceptors/auth/basic.go b/interceptors/auth/basic.go index 03fb2be..c61d6f6 100644 --- a/interceptors/auth/basic.go +++ b/interceptors/auth/basic.go @@ -8,13 +8,15 @@ import ( grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" "go.linka.cloud/grpc/errors" + "go.linka.cloud/grpc/interceptors" + "go.linka.cloud/grpc/interceptors/metadata" ) func BasicAuth(user, password string) string { return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) } -type BasicValidator func(ctx context.Context, user, password string) (context.Context,error) +type BasicValidator func(ctx context.Context, user, password string) (context.Context, error) func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc { return func(ctx context.Context) (context.Context, error) { @@ -34,3 +36,7 @@ func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc { return v(ctx, cs[:s], cs[s+1:]) } } + +func NewBasicAuthClientIntereptors(user, password string) interceptors.ClientInterceptors { + return metadata.NewInterceptors("authorization", BasicAuth(user, password)) +} diff --git a/interceptors/auth/interceptors.go b/interceptors/auth/interceptors.go index 93f8582..f376fd9 100644 --- a/interceptors/auth/interceptors.go +++ b/interceptors/auth/interceptors.go @@ -2,6 +2,7 @@ package auth import ( "context" + "crypto/subtle" "strings" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" @@ -40,8 +41,8 @@ func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors { return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)} } -type interceptor struct{ - o options +type interceptor struct { + o options authFn grpc_auth.AuthFunc } @@ -92,3 +93,7 @@ func (i *interceptor) isNotProtected(endpoint string) bool { } return true } + +func Equals(s1, s2 string) bool { + return subtle.ConstantTimeCompare([]byte(s1), []byte(s2)) == 1 +} diff --git a/interceptors/auth/interceptors_test.go b/interceptors/auth/interceptors_test.go index f101e4a..2ae06ed 100644 --- a/interceptors/auth/interceptors_test.go +++ b/interceptors/auth/interceptors_test.go @@ -102,11 +102,11 @@ func TestChainedAuthFuncs(t *testing.T) { code: codes.PermissionDenied, }, { - name: "internal error", - auth: "bearer internal", + name: "internal error", + auth: "bearer internal", internalError: true, - err: true, - code: codes.PermissionDenied, + err: true, + code: codes.PermissionDenied, }, { name: "multiple auth: first basic valid", diff --git a/interceptors/auth/token.go b/interceptors/auth/token.go index 6424c64..6c8caf9 100644 --- a/interceptors/auth/token.go +++ b/interceptors/auth/token.go @@ -4,6 +4,9 @@ import ( "context" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + + "go.linka.cloud/grpc/interceptors" + "go.linka.cloud/grpc/interceptors/metadata" ) type TokenValidator func(ctx context.Context, token string) (context.Context, error) @@ -17,3 +20,7 @@ func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc { return v(ctx, a) } } + +func NewBearerClientInterceptors(token string) interceptors.ClientInterceptors { + return metadata.NewInterceptors("authorization", "Bearer "+token) +} diff --git a/service/internal_interceptors.go b/interceptors/metadata/metadata.go similarity index 72% rename from service/internal_interceptors.go rename to interceptors/metadata/metadata.go index 42701e1..983f642 100644 --- a/service/internal_interceptors.go +++ b/interceptors/metadata/metadata.go @@ -1,19 +1,25 @@ -package service +package metadata import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" + + "go.linka.cloud/grpc/interceptors" ) +func NewInterceptors(pairs ...string) interceptors.Interceptors { + return mdInterceptors{pairs: pairs} +} + type mdInterceptors struct { - k, v string + pairs []string } func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil { + if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil { return nil, err } return handler(ctx, req) @@ -22,7 +28,7 @@ func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor { func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.k, i.v)); err != nil { + if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.pairs...)); err != nil { return err } return handler(srv, ss) @@ -31,7 +37,7 @@ func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor { func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil { + if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) @@ -40,7 +46,7 @@ func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor { func (i mdInterceptors) StreamClientInterceptor() grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil { + if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil { return nil, err } return streamer(ctx, desc, cc, method, opts...) diff --git a/interceptors/recovery/interceptors.go b/interceptors/recovery/interceptors.go index b8bba13..5e4f38b 100644 --- a/interceptors/recovery/interceptors.go +++ b/interceptors/recovery/interceptors.go @@ -30,6 +30,3 @@ func (i *recovery) UnaryClientInterceptor() grpc.UnaryClientInterceptor { func (i *recovery) StreamClientInterceptor() grpc.StreamClientInterceptor { panic("not implemented") } - - - diff --git a/interceptors/sentry/interceptors.go b/interceptors/sentry/interceptors.go index 47a25aa..cb7191b 100644 --- a/interceptors/sentry/interceptors.go +++ b/interceptors/sentry/interceptors.go @@ -3,7 +3,7 @@ package sentry import ( "google.golang.org/grpc" - "github.com/johnbellone/grpc-middleware-sentry" + grpc_sentry "github.com/johnbellone/grpc-middleware-sentry" "go.linka.cloud/grpc/interceptors" ) @@ -31,6 +31,3 @@ func (i *interceptor) UnaryClientInterceptor() grpc.UnaryClientInterceptor { func (i *interceptor) StreamClientInterceptor() grpc.StreamClientInterceptor { return grpc_sentry.StreamClientInterceptor(i.opts...) } - - - diff --git a/service/service.go b/service/service.go index b7e2936..f50ad33 100644 --- a/service/service.go +++ b/service/service.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" greflect "google.golang.org/grpc/reflection" + "go.linka.cloud/grpc/interceptors/metadata" "go.linka.cloud/grpc/registry" "go.linka.cloud/grpc/registry/noop" ) @@ -74,32 +75,18 @@ func newService(opts ...Option) (*service, error) { f(s.opts) } if s.opts.name != "" { - s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{ - k: "grpc-service-name", v: s.opts.name, - }.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...) - s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{ - k: "grpc-service-name", v: s.opts.name, - }.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...) - s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{ - k: "grpc-service-name", v: s.opts.name, - }.StreamServerInterceptor()}, s.opts.streamServerInterceptors...) - s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{ - k: "grpc-service-name", v: s.opts.name, - }.StreamClientInterceptor()}, s.opts.streamClientInterceptors...) + i := metadata.NewInterceptors("grpc-service-name", s.opts.name) + s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...) + s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...) + s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...) + s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...) } if s.opts.version != "" { - s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{ - k: "grpc-service-version", v: s.opts.version, - }.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...) - s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{ - k: "grpc-service-version", v: s.opts.version, - }.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...) - s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{ - k: "grpc-service-version", v: s.opts.version, - }.StreamServerInterceptor()}, s.opts.streamServerInterceptors...) - s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{ - k: "grpc-service-version", v: s.opts.version, - }.StreamClientInterceptor()}, s.opts.streamClientInterceptors...) + i := metadata.NewInterceptors("grpc-service-version", s.opts.version) + s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...) + s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...) + s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...) + s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...) } if s.opts.mux == nil { s.opts.mux = http.NewServeMux()