add metadata interceptors, auth client interceptors

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2022-03-11 12:33:18 +01:00
parent c0e79d8834
commit 97ced73270
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
11 changed files with 101 additions and 68 deletions

View File

@ -45,7 +45,7 @@ func New(opts ...Option) (Client, error) {
} }
if c.opts.addr == "" { if c.opts.addr == "" {
c.addr = fmt.Sprintf("%s:///%s", c.opts.registry.String(), c.opts.name) c.addr = fmt.Sprintf("%s:///%s", c.opts.registry.String(), c.opts.name)
} else if strings.HasPrefix(c.opts.addr, "tcp://"){ } else if strings.HasPrefix(c.opts.addr, "tcp://") {
c.addr = strings.Replace(c.opts.addr, "tcp://", "", 1) c.addr = strings.Replace(c.opts.addr, "tcp://", "", 1)
} else { } else {
c.addr = c.opts.addr c.addr = c.opts.addr

View File

@ -18,7 +18,7 @@ import (
"go.linka.cloud/grpc/config" "go.linka.cloud/grpc/config"
) )
func newConfigFile(t *testing.T) (config.Config, string, func()){ func newConfigFile(t *testing.T) (config.Config, string, func()) {
path := filepath.Join(os.TempDir(), "config.yaml") path := filepath.Join(os.TempDir(), "config.yaml")
if err := ioutil.WriteFile(path, []byte("ok"), os.ModePerm); err != nil { if err := ioutil.WriteFile(path, []byte("ok"), os.ModePerm); err != nil {
t.Fatal(err) t.Fatal(err)
@ -65,7 +65,7 @@ func TestWatch(t *testing.T) {
} }
// when overwriting the file and waiting for the custom change notification handler to be triggered // when overwriting the file and waiting for the custom change notification handler to be triggered
err := ioutil.WriteFile(cpath, []byte("foo: baz\n"), 0o640) err := ioutil.WriteFile(cpath, []byte("foo: baz\n"), 0o640)
b := <- updates b := <-updates
// then the config value should have changed // then the config value should have changed
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, []byte("foo: baz\n"), b) assert.Equal(t, []byte("foo: baz\n"), b)

View File

@ -11,6 +11,7 @@ import (
"net/mail" "net/mail"
"net/url" "net/url"
"regexp" "regexp"
"sort"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -31,15 +32,25 @@ var (
_ = (*url.URL)(nil) _ = (*url.URL)(nil)
_ = (*mail.Address)(nil) _ = (*mail.Address)(nil)
_ = anypb.Any{} _ = anypb.Any{}
_ = sort.Sort
) )
// Validate checks the field values on HelloRequest with the rules defined in // Validate checks the field values on HelloRequest with the rules defined in
// the proto definition for this message. If any rules are violated, an error // the proto definition for this message. If any rules are violated, the first
// is returned. When asked to return all errors, validation continues after // error encountered is returned, or nil if there are no violations.
// first violation, and the result is a list of violation errors wrapped in func (m *HelloRequest) Validate() error {
// HelloRequestMultiError, or nil if none found. Otherwise, only the first return m.validate(false)
// error is returned, if any. }
func (m *HelloRequest) Validate(all bool) error {
// ValidateAll checks the field values on HelloRequest with the rules defined
// in the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HelloRequestMultiError, or
// nil if none found.
func (m *HelloRequest) ValidateAll() error {
return m.validate(true)
}
func (m *HelloRequest) validate(all bool) error {
if m == nil { if m == nil {
return nil return nil
} }
@ -64,8 +75,7 @@ func (m *HelloRequest) Validate(all bool) error {
} }
// HelloRequestMultiError is an error wrapping multiple validation errors // HelloRequestMultiError is an error wrapping multiple validation errors
// returned by HelloRequest.Validate(true) if the designated constraints // returned by HelloRequest.ValidateAll() if the designated constraints aren't met.
// aren't met.
type HelloRequestMultiError []error type HelloRequestMultiError []error
// Error returns a concatenation of all the error messages it wraps. // Error returns a concatenation of all the error messages it wraps.
@ -135,12 +145,21 @@ var _ interface {
} = HelloRequestValidationError{} } = HelloRequestValidationError{}
// Validate checks the field values on HelloReply with the rules defined in the // Validate checks the field values on HelloReply with the rules defined in the
// proto definition for this message. If any rules are violated, an error is // proto definition for this message. If any rules are violated, the first
// returned. When asked to return all errors, validation continues after first // error encountered is returned, or nil if there are no violations.
// violation, and the result is a list of violation errors wrapped in func (m *HelloReply) Validate() error {
// HelloReplyMultiError, or nil if none found. Otherwise, only the first error return m.validate(false)
// is returned, if any. }
func (m *HelloReply) Validate(all bool) error {
// ValidateAll checks the field values on HelloReply with the rules defined in
// the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HelloReplyMultiError, or
// nil if none found.
func (m *HelloReply) ValidateAll() error {
return m.validate(true)
}
func (m *HelloReply) validate(all bool) error {
if m == nil { if m == nil {
return nil return nil
} }
@ -156,7 +175,7 @@ func (m *HelloReply) Validate(all bool) error {
} }
// HelloReplyMultiError is an error wrapping multiple validation errors // HelloReplyMultiError is an error wrapping multiple validation errors
// returned by HelloReply.Validate(true) if the designated constraints aren't met. // returned by HelloReply.ValidateAll() if the designated constraints aren't met.
type HelloReplyMultiError []error type HelloReplyMultiError []error
// Error returns a concatenation of all the error messages it wraps. // Error returns a concatenation of all the error messages it wraps.
@ -227,11 +246,20 @@ var _ interface {
// Validate checks the field values on HelloStreamRequest with the rules // Validate checks the field values on HelloStreamRequest with the rules
// defined in the proto definition for this message. If any rules are // defined in the proto definition for this message. If any rules are
// violated, an error is returned. When asked to return all errors, validation // violated, the first error encountered is returned, or nil if there are no violations.
// continues after first violation, and the result is a list of violation func (m *HelloStreamRequest) Validate() error {
// errors wrapped in HelloStreamRequestMultiError, or nil if none found. return m.validate(false)
// Otherwise, only the first error is returned, if any. }
func (m *HelloStreamRequest) Validate(all bool) error {
// ValidateAll checks the field values on HelloStreamRequest with the rules
// defined in the proto definition for this message. If any rules are
// violated, the result is a list of violation errors wrapped in
// HelloStreamRequestMultiError, or nil if none found.
func (m *HelloStreamRequest) ValidateAll() error {
return m.validate(true)
}
func (m *HelloStreamRequest) validate(all bool) error {
if m == nil { if m == nil {
return nil return nil
} }
@ -267,7 +295,7 @@ func (m *HelloStreamRequest) Validate(all bool) error {
} }
// HelloStreamRequestMultiError is an error wrapping multiple validation errors // HelloStreamRequestMultiError is an error wrapping multiple validation errors
// returned by HelloStreamRequest.Validate(true) if the designated constraints // returned by HelloStreamRequest.ValidateAll() if the designated constraints
// aren't met. // aren't met.
type HelloStreamRequestMultiError []error type HelloStreamRequestMultiError []error

View File

@ -8,13 +8,15 @@ import (
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"go.linka.cloud/grpc/errors" "go.linka.cloud/grpc/errors"
"go.linka.cloud/grpc/interceptors"
"go.linka.cloud/grpc/interceptors/metadata"
) )
func BasicAuth(user, password string) string { func BasicAuth(user, password string) string {
return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
} }
type BasicValidator func(ctx context.Context, user, password string) (context.Context,error) type BasicValidator func(ctx context.Context, user, password string) (context.Context, error)
func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc { func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) { return func(ctx context.Context) (context.Context, error) {
@ -34,3 +36,7 @@ func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
return v(ctx, cs[:s], cs[s+1:]) return v(ctx, cs[:s], cs[s+1:])
} }
} }
func NewBasicAuthClientIntereptors(user, password string) interceptors.ClientInterceptors {
return metadata.NewInterceptors("authorization", BasicAuth(user, password))
}

View File

@ -2,6 +2,7 @@ package auth
import ( import (
"context" "context"
"crypto/subtle"
"strings" "strings"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
@ -40,8 +41,8 @@ func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors {
return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)} return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)}
} }
type interceptor struct{ type interceptor struct {
o options o options
authFn grpc_auth.AuthFunc authFn grpc_auth.AuthFunc
} }
@ -92,3 +93,7 @@ func (i *interceptor) isNotProtected(endpoint string) bool {
} }
return true return true
} }
func Equals(s1, s2 string) bool {
return subtle.ConstantTimeCompare([]byte(s1), []byte(s2)) == 1
}

View File

@ -102,11 +102,11 @@ func TestChainedAuthFuncs(t *testing.T) {
code: codes.PermissionDenied, code: codes.PermissionDenied,
}, },
{ {
name: "internal error", name: "internal error",
auth: "bearer internal", auth: "bearer internal",
internalError: true, internalError: true,
err: true, err: true,
code: codes.PermissionDenied, code: codes.PermissionDenied,
}, },
{ {
name: "multiple auth: first basic valid", name: "multiple auth: first basic valid",

View File

@ -4,6 +4,9 @@ import (
"context" "context"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"go.linka.cloud/grpc/interceptors"
"go.linka.cloud/grpc/interceptors/metadata"
) )
type TokenValidator func(ctx context.Context, token string) (context.Context, error) type TokenValidator func(ctx context.Context, token string) (context.Context, error)
@ -17,3 +20,7 @@ func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc {
return v(ctx, a) return v(ctx, a)
} }
} }
func NewBearerClientInterceptors(token string) interceptors.ClientInterceptors {
return metadata.NewInterceptors("authorization", "Bearer "+token)
}

View File

@ -1,19 +1,25 @@
package service package metadata
import ( import (
"context" "context"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"go.linka.cloud/grpc/interceptors"
) )
func NewInterceptors(pairs ...string) interceptors.Interceptors {
return mdInterceptors{pairs: pairs}
}
type mdInterceptors struct { type mdInterceptors struct {
k, v string pairs []string
} }
func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor { func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil { if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
return nil, err return nil, err
} }
return handler(ctx, req) return handler(ctx, req)
@ -22,7 +28,7 @@ func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor { func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.k, i.v)); err != nil { if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.pairs...)); err != nil {
return err return err
} }
return handler(srv, ss) return handler(srv, ss)
@ -31,7 +37,7 @@ func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor {
func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor { func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil { if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
return err return err
} }
return invoker(ctx, method, req, reply, cc, opts...) return invoker(ctx, method, req, reply, cc, opts...)
@ -40,7 +46,7 @@ func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
func (i mdInterceptors) StreamClientInterceptor() grpc.StreamClientInterceptor { func (i mdInterceptors) StreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil { if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
return nil, err return nil, err
} }
return streamer(ctx, desc, cc, method, opts...) return streamer(ctx, desc, cc, method, opts...)

View File

@ -30,6 +30,3 @@ func (i *recovery) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
func (i *recovery) StreamClientInterceptor() grpc.StreamClientInterceptor { func (i *recovery) StreamClientInterceptor() grpc.StreamClientInterceptor {
panic("not implemented") panic("not implemented")
} }

View File

@ -3,7 +3,7 @@ package sentry
import ( import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/johnbellone/grpc-middleware-sentry" grpc_sentry "github.com/johnbellone/grpc-middleware-sentry"
"go.linka.cloud/grpc/interceptors" "go.linka.cloud/grpc/interceptors"
) )
@ -31,6 +31,3 @@ func (i *interceptor) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
func (i *interceptor) StreamClientInterceptor() grpc.StreamClientInterceptor { func (i *interceptor) StreamClientInterceptor() grpc.StreamClientInterceptor {
return grpc_sentry.StreamClientInterceptor(i.opts...) return grpc_sentry.StreamClientInterceptor(i.opts...)
} }

View File

@ -28,6 +28,7 @@ import (
"google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/health/grpc_health_v1"
greflect "google.golang.org/grpc/reflection" greflect "google.golang.org/grpc/reflection"
"go.linka.cloud/grpc/interceptors/metadata"
"go.linka.cloud/grpc/registry" "go.linka.cloud/grpc/registry"
"go.linka.cloud/grpc/registry/noop" "go.linka.cloud/grpc/registry/noop"
) )
@ -74,32 +75,18 @@ func newService(opts ...Option) (*service, error) {
f(s.opts) f(s.opts)
} }
if s.opts.name != "" { if s.opts.name != "" {
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{ i := metadata.NewInterceptors("grpc-service-name", s.opts.name)
k: "grpc-service-name", v: s.opts.name, s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
}.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...) s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{ s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
k: "grpc-service-name", v: s.opts.name, s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
}.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{
k: "grpc-service-name", v: s.opts.name,
}.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{
k: "grpc-service-name", v: s.opts.name,
}.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
} }
if s.opts.version != "" { if s.opts.version != "" {
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{ i := metadata.NewInterceptors("grpc-service-version", s.opts.version)
k: "grpc-service-version", v: s.opts.version, s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
}.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...) s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{ s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
k: "grpc-service-version", v: s.opts.version, s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
}.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{
k: "grpc-service-version", v: s.opts.version,
}.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{
k: "grpc-service-version", v: s.opts.version,
}.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
} }
if s.opts.mux == nil { if s.opts.mux == nil {
s.opts.mux = http.NewServeMux() s.opts.mux = http.NewServeMux()