diff --git a/interceptors/auth/basic.go b/interceptors/auth/basic.go new file mode 100644 index 0000000..03fb2be --- /dev/null +++ b/interceptors/auth/basic.go @@ -0,0 +1,36 @@ +package auth + +import ( + "context" + "encoding/base64" + "strings" + + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + + "go.linka.cloud/grpc/errors" +) + +func BasicAuth(user, password string) string { + return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) +} + +type BasicValidator func(ctx context.Context, user, password string) (context.Context,error) + +func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc { + return func(ctx context.Context) (context.Context, error) { + a, err := grpc_auth.AuthFromMD(ctx, "basic") + if err != nil { + return ctx, err + } + c, err := base64.StdEncoding.DecodeString(a) + if err != nil { + return ctx, err + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return ctx, errors.Unauthenticatedf("malformed basic auth") + } + return v(ctx, cs[:s], cs[s+1:]) + } +} diff --git a/interceptors/auth/interceptors.go b/interceptors/auth/interceptors.go new file mode 100644 index 0000000..93f8582 --- /dev/null +++ b/interceptors/auth/interceptors.go @@ -0,0 +1,94 @@ +package auth + +import ( + "context" + "strings" + + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "go.linka.cloud/grpc/interceptors" +) + +func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc { + return func(ctx context.Context) (context.Context, error) { + code := codes.Unauthenticated + for _, v := range fn { + ctx2, err := v(ctx) + if err == nil { + return ctx2, nil + } + s, ok := status.FromError(err) + if !ok { + return ctx2, err + } + if s.Code() == codes.PermissionDenied { + code = codes.PermissionDenied + } + } + return ctx, status.Error(code, code.String()) + } +} + +func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors { + o := options{} + for _, v := range opts { + v(&o) + } + return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)} +} + +type interceptor struct{ + o options + authFn grpc_auth.AuthFunc +} + +func (i *interceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor { + a := grpc_auth.UnaryServerInterceptor(i.authFn) + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if i.isNotProtected(info.FullMethod) { + return handler(ctx, req) + } + return a(ctx, req, info, handler) + } +} + +func (i *interceptor) StreamServerInterceptor() grpc.StreamServerInterceptor { + a := grpc_auth.StreamServerInterceptor(i.authFn) + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if i.isNotProtected(info.FullMethod) { + return handler(srv, ss) + } + return a(srv, ss, info, handler) + } +} + +func (i *interceptor) isNotProtected(endpoint string) bool { + // default to not ignored + if len(i.o.ignoredMethods) == 0 && len(i.o.methods) == 0 { + return false + } + // endpoint is like /helloworld.Greeter/SayHello + parts := strings.Split(strings.TrimPrefix(endpoint, "/"), "/") + // invalid endpoint format + if len(parts) != 2 { + return false + } + method := parts[1] + for _, v := range i.o.ignoredMethods { + if v == method { + return true + } + } + if len(i.o.methods) == 0 { + return false + } + for _, v := range i.o.methods { + if v == method { + return false + } + } + return true +} diff --git a/interceptors/auth/interceptors_test.go b/interceptors/auth/interceptors_test.go new file mode 100644 index 0000000..f101e4a --- /dev/null +++ b/interceptors/auth/interceptors_test.go @@ -0,0 +1,146 @@ +package auth + +import ( + "context" + "testing" + + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + assert2 "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "go.linka.cloud/grpc/errors" +) + +func TestNotProtectededOnly(t *testing.T) { + assert := assert2.New(t) + i := &interceptor{o: options{ignoredMethods: []string{"ignored"}}} + assert.False(i.isNotProtected("/test.Service/protected")) + assert.True(i.isNotProtected("/test.Service/ignored")) +} + +func TestProtectedOnly(t *testing.T) { + assert := assert2.New(t) + i := &interceptor{o: options{methods: []string{"protected"}}} + assert.False(i.isNotProtected("/test.Service/protected")) + assert.True(i.isNotProtected("/test.Service/ignored")) +} + +func TestProtectedAndIgnored(t *testing.T) { + assert := assert2.New(t) + i := &interceptor{o: options{methods: []string{"protected"}, ignoredMethods: []string{"ignored"}}} + assert.True(i.isNotProtected("/test.Service/ignored")) + assert.False(i.isNotProtected("/test.Service/protected")) + assert.True(i.isNotProtected("/test.Service/other")) +} + +func TestProtectedByDefault(t *testing.T) { + i := &interceptor{} + assert2.False(t, i.isNotProtected("nooop")) + assert2.False(t, i.isNotProtected("/test.Service/method/cannotExists")) + assert2.False(t, i.isNotProtected("/test.Service/validMethod")) +} + +var ( + adminAuth = func(ctx context.Context, user, password string) (context.Context, error) { + if user == "admin" && password == "admin" { + return ctx, nil + } + return ctx, errors.PermissionDeniedf("") + } + testAuth = func(ctx context.Context, user, password string) (context.Context, error) { + if user == "test" && password == "test" { + return ctx, nil + } + return ctx, errors.PermissionDeniedf("") + } + tokenAuth = func(ctx context.Context, token string) (context.Context, error) { + if token == "token" { + return ctx, nil + } + return ctx, errors.PermissionDeniedf("") + } +) + +func TestChainedAuthFuncs(t *testing.T) { + wantInternalError := false + ctx := context.Background() + auth := ChainedAuthFuncs([]grpc_auth.AuthFunc{ + makeBasicAuthFunc(adminAuth), + makeBasicAuthFunc(testAuth), + makeTokenAuthFunc(tokenAuth), + makeTokenAuthFunc(func(ctx context.Context, token string) (context.Context, error) { + if wantInternalError { + return ctx, errors.Internalf("ooops") + } + return ctx, errors.Unauthenticatedf("") + }), + }...) + + tests := []struct { + name string + auth string + internalError bool + err bool + code codes.Code + }{ + { + name: "no auth", + auth: "", + err: true, + code: codes.Unauthenticated, + }, + { + name: "valid token", + auth: "bearer token", + }, + { + name: "empty bearer", + auth: "bearer ", + err: true, + code: codes.PermissionDenied, + }, + { + name: "internal error", + auth: "bearer internal", + internalError: true, + err: true, + code: codes.PermissionDenied, + }, + { + name: "multiple auth: first basic valid", + auth: BasicAuth("admin", "admin"), + }, + { + name: "multiple auth: second baisc valid", + auth: BasicAuth("test", "test"), + }, + { + name: "invalid auth: bearer", + auth: "bearer noop", + err: true, + code: codes.PermissionDenied, + }, + { + name: "invalid auth: basic", + auth: BasicAuth("other", "other"), + err: true, + code: codes.PermissionDenied, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantInternalError = tt.internalError + rctx, err := auth(metadata.NewIncomingContext(ctx, metadata.Pairs("authorization", tt.auth))) + if tt.err { + assert2.Error(t, err) + s, ok := status.FromError(err) + assert2.True(t, ok) + assert2.Equal(t, tt.code, s.Code()) + } + assert2.NotNil(t, rctx) + }) + } +} diff --git a/interceptors/auth/options.go b/interceptors/auth/options.go new file mode 100644 index 0000000..dd59b57 --- /dev/null +++ b/interceptors/auth/options.go @@ -0,0 +1,56 @@ +package auth + +import ( + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" +) + +type Option func(o *options) + +func WithMethods(methods ...string) Option { + return func(o *options) { + o.methods = append(o.methods, methods...) + } +} + +func WithIgnoredMethods(methods ...string) Option { + return func(o *options) { + o.ignoredMethods = append(o.ignoredMethods, methods...) + } +} + +func WithBasicValidators(validators ...BasicValidator) Option { + var authFns []grpc_auth.AuthFunc + for _, v := range validators { + authFns = append(authFns, makeBasicAuthFunc(v)) + } + return func(o *options) { + o.authFns = append(o.authFns, authFns...) + } +} + +func WithTokenValidators(validators ...TokenValidator) Option { + var authFns []grpc_auth.AuthFunc + for _, v := range validators { + authFns = append(authFns, makeTokenAuthFunc(v)) + } + return func(o *options) { + o.authFns = append(o.authFns, authFns...) + } +} + +func WithX509Validators(validators ...X509Validator) Option { + var authFns []grpc_auth.AuthFunc + for _, v := range validators { + authFns = append(authFns, makeX509AuthFunc(v)) + } + return func(o *options) { + o.authFns = append(o.authFns, authFns...) + } +} + +type options struct { + methods []string + ignoredMethods []string + + authFns []grpc_auth.AuthFunc +} diff --git a/interceptors/auth/token.go b/interceptors/auth/token.go new file mode 100644 index 0000000..6424c64 --- /dev/null +++ b/interceptors/auth/token.go @@ -0,0 +1,19 @@ +package auth + +import ( + "context" + + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" +) + +type TokenValidator func(ctx context.Context, token string) (context.Context, error) + +func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc { + return func(ctx context.Context) (context.Context, error) { + a, err := grpc_auth.AuthFromMD(ctx, "bearer") + if err != nil { + return ctx, err + } + return v(ctx, a) + } +} diff --git a/interceptors/auth/x509.go b/interceptors/auth/x509.go new file mode 100644 index 0000000..2aae089 --- /dev/null +++ b/interceptors/auth/x509.go @@ -0,0 +1,49 @@ +package auth + +import ( + "context" + + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" + + "go.linka.cloud/grpc/errors" +) + +type X509Validator func(ctx context.Context, sans []string) (context.Context, error) + +// func _(ctx context.Context) { +// p, ok := peer.FromContext(ctx) +// if !ok { +// return +// } +// i, ok := p.AuthInfo.(credentials.TLSInfo) +// if !ok { +// return +// } +// i.State.VerifiedChains +// } + +func makeX509AuthFunc(v X509Validator) grpc_auth.AuthFunc { + return func(ctx context.Context) (context.Context, error) { + p, ok := peer.FromContext(ctx) + if !ok { + return ctx, errors.Internalf("peer not found") + } + i, ok := p.AuthInfo.(credentials.TLSInfo) + if !ok { + return ctx, errors.Unauthenticatedf("no TLS credentials") + } + if !i.State.HandshakeComplete { + return ctx, errors.Unauthenticatedf("handshake not complete") + } + var sans []string + for _, v := range i.State.VerifiedChains { + if len(v) == 0 { + continue + } + sans = append(sans, v[0].PermittedDNSDomains...) + } + return v(ctx, sans) + } +}