add metadata interceptors, auth client interceptors

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2022-03-11 12:33:18 +01:00
parent c0e79d8834
commit 97ced73270
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
11 changed files with 101 additions and 68 deletions

View File

@ -45,7 +45,7 @@ func New(opts ...Option) (Client, error) {
}
if c.opts.addr == "" {
c.addr = fmt.Sprintf("%s:///%s", c.opts.registry.String(), c.opts.name)
} else if strings.HasPrefix(c.opts.addr, "tcp://"){
} else if strings.HasPrefix(c.opts.addr, "tcp://") {
c.addr = strings.Replace(c.opts.addr, "tcp://", "", 1)
} else {
c.addr = c.opts.addr

View File

@ -18,7 +18,7 @@ import (
"go.linka.cloud/grpc/config"
)
func newConfigFile(t *testing.T) (config.Config, string, func()){
func newConfigFile(t *testing.T) (config.Config, string, func()) {
path := filepath.Join(os.TempDir(), "config.yaml")
if err := ioutil.WriteFile(path, []byte("ok"), os.ModePerm); err != nil {
t.Fatal(err)
@ -65,7 +65,7 @@ func TestWatch(t *testing.T) {
}
// when overwriting the file and waiting for the custom change notification handler to be triggered
err := ioutil.WriteFile(cpath, []byte("foo: baz\n"), 0o640)
b := <- updates
b := <-updates
// then the config value should have changed
require.Nil(t, err)
assert.Equal(t, []byte("foo: baz\n"), b)

View File

@ -11,6 +11,7 @@ import (
"net/mail"
"net/url"
"regexp"
"sort"
"strings"
"time"
"unicode/utf8"
@ -31,15 +32,25 @@ var (
_ = (*url.URL)(nil)
_ = (*mail.Address)(nil)
_ = anypb.Any{}
_ = sort.Sort
)
// Validate checks the field values on HelloRequest with the rules defined in
// the proto definition for this message. If any rules are violated, an error
// is returned. When asked to return all errors, validation continues after
// first violation, and the result is a list of violation errors wrapped in
// HelloRequestMultiError, or nil if none found. Otherwise, only the first
// error is returned, if any.
func (m *HelloRequest) Validate(all bool) error {
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *HelloRequest) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on HelloRequest with the rules defined
// in the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HelloRequestMultiError, or
// nil if none found.
func (m *HelloRequest) ValidateAll() error {
return m.validate(true)
}
func (m *HelloRequest) validate(all bool) error {
if m == nil {
return nil
}
@ -64,8 +75,7 @@ func (m *HelloRequest) Validate(all bool) error {
}
// HelloRequestMultiError is an error wrapping multiple validation errors
// returned by HelloRequest.Validate(true) if the designated constraints
// aren't met.
// returned by HelloRequest.ValidateAll() if the designated constraints aren't met.
type HelloRequestMultiError []error
// Error returns a concatenation of all the error messages it wraps.
@ -135,12 +145,21 @@ var _ interface {
} = HelloRequestValidationError{}
// Validate checks the field values on HelloReply with the rules defined in the
// proto definition for this message. If any rules are violated, an error is
// returned. When asked to return all errors, validation continues after first
// violation, and the result is a list of violation errors wrapped in
// HelloReplyMultiError, or nil if none found. Otherwise, only the first error
// is returned, if any.
func (m *HelloReply) Validate(all bool) error {
// proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *HelloReply) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on HelloReply with the rules defined in
// the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HelloReplyMultiError, or
// nil if none found.
func (m *HelloReply) ValidateAll() error {
return m.validate(true)
}
func (m *HelloReply) validate(all bool) error {
if m == nil {
return nil
}
@ -156,7 +175,7 @@ func (m *HelloReply) Validate(all bool) error {
}
// HelloReplyMultiError is an error wrapping multiple validation errors
// returned by HelloReply.Validate(true) if the designated constraints aren't met.
// returned by HelloReply.ValidateAll() if the designated constraints aren't met.
type HelloReplyMultiError []error
// Error returns a concatenation of all the error messages it wraps.
@ -227,11 +246,20 @@ var _ interface {
// Validate checks the field values on HelloStreamRequest with the rules
// defined in the proto definition for this message. If any rules are
// violated, an error is returned. When asked to return all errors, validation
// continues after first violation, and the result is a list of violation
// errors wrapped in HelloStreamRequestMultiError, or nil if none found.
// Otherwise, only the first error is returned, if any.
func (m *HelloStreamRequest) Validate(all bool) error {
// violated, the first error encountered is returned, or nil if there are no violations.
func (m *HelloStreamRequest) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on HelloStreamRequest with the rules
// defined in the proto definition for this message. If any rules are
// violated, the result is a list of violation errors wrapped in
// HelloStreamRequestMultiError, or nil if none found.
func (m *HelloStreamRequest) ValidateAll() error {
return m.validate(true)
}
func (m *HelloStreamRequest) validate(all bool) error {
if m == nil {
return nil
}
@ -267,7 +295,7 @@ func (m *HelloStreamRequest) Validate(all bool) error {
}
// HelloStreamRequestMultiError is an error wrapping multiple validation errors
// returned by HelloStreamRequest.Validate(true) if the designated constraints
// returned by HelloStreamRequest.ValidateAll() if the designated constraints
// aren't met.
type HelloStreamRequestMultiError []error

View File

@ -8,13 +8,15 @@ import (
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"go.linka.cloud/grpc/errors"
"go.linka.cloud/grpc/interceptors"
"go.linka.cloud/grpc/interceptors/metadata"
)
func BasicAuth(user, password string) string {
return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
}
type BasicValidator func(ctx context.Context, user, password string) (context.Context,error)
type BasicValidator func(ctx context.Context, user, password string) (context.Context, error)
func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) {
@ -34,3 +36,7 @@ func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
return v(ctx, cs[:s], cs[s+1:])
}
}
func NewBasicAuthClientIntereptors(user, password string) interceptors.ClientInterceptors {
return metadata.NewInterceptors("authorization", BasicAuth(user, password))
}

View File

@ -2,6 +2,7 @@ package auth
import (
"context"
"crypto/subtle"
"strings"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
@ -40,8 +41,8 @@ func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors {
return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)}
}
type interceptor struct{
o options
type interceptor struct {
o options
authFn grpc_auth.AuthFunc
}
@ -92,3 +93,7 @@ func (i *interceptor) isNotProtected(endpoint string) bool {
}
return true
}
func Equals(s1, s2 string) bool {
return subtle.ConstantTimeCompare([]byte(s1), []byte(s2)) == 1
}

View File

@ -102,11 +102,11 @@ func TestChainedAuthFuncs(t *testing.T) {
code: codes.PermissionDenied,
},
{
name: "internal error",
auth: "bearer internal",
name: "internal error",
auth: "bearer internal",
internalError: true,
err: true,
code: codes.PermissionDenied,
err: true,
code: codes.PermissionDenied,
},
{
name: "multiple auth: first basic valid",

View File

@ -4,6 +4,9 @@ import (
"context"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"go.linka.cloud/grpc/interceptors"
"go.linka.cloud/grpc/interceptors/metadata"
)
type TokenValidator func(ctx context.Context, token string) (context.Context, error)
@ -17,3 +20,7 @@ func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc {
return v(ctx, a)
}
}
func NewBearerClientInterceptors(token string) interceptors.ClientInterceptors {
return metadata.NewInterceptors("authorization", "Bearer "+token)
}

View File

@ -1,19 +1,25 @@
package service
package metadata
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"go.linka.cloud/grpc/interceptors"
)
func NewInterceptors(pairs ...string) interceptors.Interceptors {
return mdInterceptors{pairs: pairs}
}
type mdInterceptors struct {
k, v string
pairs []string
}
func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
return nil, err
}
return handler(ctx, req)
@ -22,7 +28,7 @@ func (i mdInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.k, i.v)); err != nil {
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(i.pairs...)); err != nil {
return err
}
return handler(srv, ss)
@ -31,7 +37,7 @@ func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor {
func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
@ -40,7 +46,7 @@ func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
func (i mdInterceptors) StreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.k, i.v)); err != nil {
if err := grpc.SetHeader(ctx, metadata.Pairs(i.pairs...)); err != nil {
return nil, err
}
return streamer(ctx, desc, cc, method, opts...)

View File

@ -30,6 +30,3 @@ func (i *recovery) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
func (i *recovery) StreamClientInterceptor() grpc.StreamClientInterceptor {
panic("not implemented")
}

View File

@ -3,7 +3,7 @@ package sentry
import (
"google.golang.org/grpc"
"github.com/johnbellone/grpc-middleware-sentry"
grpc_sentry "github.com/johnbellone/grpc-middleware-sentry"
"go.linka.cloud/grpc/interceptors"
)
@ -31,6 +31,3 @@ func (i *interceptor) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
func (i *interceptor) StreamClientInterceptor() grpc.StreamClientInterceptor {
return grpc_sentry.StreamClientInterceptor(i.opts...)
}

View File

@ -28,6 +28,7 @@ import (
"google.golang.org/grpc/health/grpc_health_v1"
greflect "google.golang.org/grpc/reflection"
"go.linka.cloud/grpc/interceptors/metadata"
"go.linka.cloud/grpc/registry"
"go.linka.cloud/grpc/registry/noop"
)
@ -74,32 +75,18 @@ func newService(opts ...Option) (*service, error) {
f(s.opts)
}
if s.opts.name != "" {
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{
k: "grpc-service-name", v: s.opts.name,
}.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{
k: "grpc-service-name", v: s.opts.name,
}.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{
k: "grpc-service-name", v: s.opts.name,
}.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{
k: "grpc-service-name", v: s.opts.name,
}.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
i := metadata.NewInterceptors("grpc-service-name", s.opts.name)
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
}
if s.opts.version != "" {
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{mdInterceptors{
k: "grpc-service-version", v: s.opts.version,
}.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{mdInterceptors{
k: "grpc-service-version", v: s.opts.version,
}.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{mdInterceptors{
k: "grpc-service-version", v: s.opts.version,
}.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{mdInterceptors{
k: "grpc-service-version", v: s.opts.version,
}.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
i := metadata.NewInterceptors("grpc-service-version", s.opts.version)
s.opts.unaryServerInterceptors = append([]grpc.UnaryServerInterceptor{i.UnaryServerInterceptor()}, s.opts.unaryServerInterceptors...)
s.opts.unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{i.UnaryClientInterceptor()}, s.opts.unaryClientInterceptors...)
s.opts.streamServerInterceptors = append([]grpc.StreamServerInterceptor{i.StreamServerInterceptor()}, s.opts.streamServerInterceptors...)
s.opts.streamClientInterceptors = append([]grpc.StreamClientInterceptor{i.StreamClientInterceptor()}, s.opts.streamClientInterceptors...)
}
if s.opts.mux == nil {
s.opts.mux = http.NewServeMux()