add auth interceptors

This commit is contained in:
Adphi 2021-12-13 12:08:10 +01:00
parent 55251b5020
commit e578d62a29
6 changed files with 400 additions and 0 deletions

View File

@ -0,0 +1,36 @@
package auth
import (
"context"
"encoding/base64"
"strings"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"go.linka.cloud/grpc/errors"
)
func BasicAuth(user, password string) string {
return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
}
type BasicValidator func(ctx context.Context, user, password string) (context.Context,error)
func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) {
a, err := grpc_auth.AuthFromMD(ctx, "basic")
if err != nil {
return ctx, err
}
c, err := base64.StdEncoding.DecodeString(a)
if err != nil {
return ctx, err
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return ctx, errors.Unauthenticatedf("malformed basic auth")
}
return v(ctx, cs[:s], cs[s+1:])
}
}

View File

@ -0,0 +1,94 @@
package auth
import (
"context"
"strings"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"go.linka.cloud/grpc/interceptors"
)
func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) {
code := codes.Unauthenticated
for _, v := range fn {
ctx2, err := v(ctx)
if err == nil {
return ctx2, nil
}
s, ok := status.FromError(err)
if !ok {
return ctx2, err
}
if s.Code() == codes.PermissionDenied {
code = codes.PermissionDenied
}
}
return ctx, status.Error(code, code.String())
}
}
func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors {
o := options{}
for _, v := range opts {
v(&o)
}
return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)}
}
type interceptor struct{
o options
authFn grpc_auth.AuthFunc
}
func (i *interceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
a := grpc_auth.UnaryServerInterceptor(i.authFn)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if i.isNotProtected(info.FullMethod) {
return handler(ctx, req)
}
return a(ctx, req, info, handler)
}
}
func (i *interceptor) StreamServerInterceptor() grpc.StreamServerInterceptor {
a := grpc_auth.StreamServerInterceptor(i.authFn)
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if i.isNotProtected(info.FullMethod) {
return handler(srv, ss)
}
return a(srv, ss, info, handler)
}
}
func (i *interceptor) isNotProtected(endpoint string) bool {
// default to not ignored
if len(i.o.ignoredMethods) == 0 && len(i.o.methods) == 0 {
return false
}
// endpoint is like /helloworld.Greeter/SayHello
parts := strings.Split(strings.TrimPrefix(endpoint, "/"), "/")
// invalid endpoint format
if len(parts) != 2 {
return false
}
method := parts[1]
for _, v := range i.o.ignoredMethods {
if v == method {
return true
}
}
if len(i.o.methods) == 0 {
return false
}
for _, v := range i.o.methods {
if v == method {
return false
}
}
return true
}

View File

@ -0,0 +1,146 @@
package auth
import (
"context"
"testing"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
assert2 "github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"go.linka.cloud/grpc/errors"
)
func TestNotProtectededOnly(t *testing.T) {
assert := assert2.New(t)
i := &interceptor{o: options{ignoredMethods: []string{"ignored"}}}
assert.False(i.isNotProtected("/test.Service/protected"))
assert.True(i.isNotProtected("/test.Service/ignored"))
}
func TestProtectedOnly(t *testing.T) {
assert := assert2.New(t)
i := &interceptor{o: options{methods: []string{"protected"}}}
assert.False(i.isNotProtected("/test.Service/protected"))
assert.True(i.isNotProtected("/test.Service/ignored"))
}
func TestProtectedAndIgnored(t *testing.T) {
assert := assert2.New(t)
i := &interceptor{o: options{methods: []string{"protected"}, ignoredMethods: []string{"ignored"}}}
assert.True(i.isNotProtected("/test.Service/ignored"))
assert.False(i.isNotProtected("/test.Service/protected"))
assert.True(i.isNotProtected("/test.Service/other"))
}
func TestProtectedByDefault(t *testing.T) {
i := &interceptor{}
assert2.False(t, i.isNotProtected("nooop"))
assert2.False(t, i.isNotProtected("/test.Service/method/cannotExists"))
assert2.False(t, i.isNotProtected("/test.Service/validMethod"))
}
var (
adminAuth = func(ctx context.Context, user, password string) (context.Context, error) {
if user == "admin" && password == "admin" {
return ctx, nil
}
return ctx, errors.PermissionDeniedf("")
}
testAuth = func(ctx context.Context, user, password string) (context.Context, error) {
if user == "test" && password == "test" {
return ctx, nil
}
return ctx, errors.PermissionDeniedf("")
}
tokenAuth = func(ctx context.Context, token string) (context.Context, error) {
if token == "token" {
return ctx, nil
}
return ctx, errors.PermissionDeniedf("")
}
)
func TestChainedAuthFuncs(t *testing.T) {
wantInternalError := false
ctx := context.Background()
auth := ChainedAuthFuncs([]grpc_auth.AuthFunc{
makeBasicAuthFunc(adminAuth),
makeBasicAuthFunc(testAuth),
makeTokenAuthFunc(tokenAuth),
makeTokenAuthFunc(func(ctx context.Context, token string) (context.Context, error) {
if wantInternalError {
return ctx, errors.Internalf("ooops")
}
return ctx, errors.Unauthenticatedf("")
}),
}...)
tests := []struct {
name string
auth string
internalError bool
err bool
code codes.Code
}{
{
name: "no auth",
auth: "",
err: true,
code: codes.Unauthenticated,
},
{
name: "valid token",
auth: "bearer token",
},
{
name: "empty bearer",
auth: "bearer ",
err: true,
code: codes.PermissionDenied,
},
{
name: "internal error",
auth: "bearer internal",
internalError: true,
err: true,
code: codes.PermissionDenied,
},
{
name: "multiple auth: first basic valid",
auth: BasicAuth("admin", "admin"),
},
{
name: "multiple auth: second baisc valid",
auth: BasicAuth("test", "test"),
},
{
name: "invalid auth: bearer",
auth: "bearer noop",
err: true,
code: codes.PermissionDenied,
},
{
name: "invalid auth: basic",
auth: BasicAuth("other", "other"),
err: true,
code: codes.PermissionDenied,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wantInternalError = tt.internalError
rctx, err := auth(metadata.NewIncomingContext(ctx, metadata.Pairs("authorization", tt.auth)))
if tt.err {
assert2.Error(t, err)
s, ok := status.FromError(err)
assert2.True(t, ok)
assert2.Equal(t, tt.code, s.Code())
}
assert2.NotNil(t, rctx)
})
}
}

View File

@ -0,0 +1,56 @@
package auth
import (
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
)
type Option func(o *options)
func WithMethods(methods ...string) Option {
return func(o *options) {
o.methods = append(o.methods, methods...)
}
}
func WithIgnoredMethods(methods ...string) Option {
return func(o *options) {
o.ignoredMethods = append(o.ignoredMethods, methods...)
}
}
func WithBasicValidators(validators ...BasicValidator) Option {
var authFns []grpc_auth.AuthFunc
for _, v := range validators {
authFns = append(authFns, makeBasicAuthFunc(v))
}
return func(o *options) {
o.authFns = append(o.authFns, authFns...)
}
}
func WithTokenValidators(validators ...TokenValidator) Option {
var authFns []grpc_auth.AuthFunc
for _, v := range validators {
authFns = append(authFns, makeTokenAuthFunc(v))
}
return func(o *options) {
o.authFns = append(o.authFns, authFns...)
}
}
func WithX509Validators(validators ...X509Validator) Option {
var authFns []grpc_auth.AuthFunc
for _, v := range validators {
authFns = append(authFns, makeX509AuthFunc(v))
}
return func(o *options) {
o.authFns = append(o.authFns, authFns...)
}
}
type options struct {
methods []string
ignoredMethods []string
authFns []grpc_auth.AuthFunc
}

View File

@ -0,0 +1,19 @@
package auth
import (
"context"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
)
type TokenValidator func(ctx context.Context, token string) (context.Context, error)
func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) {
a, err := grpc_auth.AuthFromMD(ctx, "bearer")
if err != nil {
return ctx, err
}
return v(ctx, a)
}
}

49
interceptors/auth/x509.go Normal file
View File

@ -0,0 +1,49 @@
package auth
import (
"context"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"go.linka.cloud/grpc/errors"
)
type X509Validator func(ctx context.Context, sans []string) (context.Context, error)
// func _(ctx context.Context) {
// p, ok := peer.FromContext(ctx)
// if !ok {
// return
// }
// i, ok := p.AuthInfo.(credentials.TLSInfo)
// if !ok {
// return
// }
// i.State.VerifiedChains
// }
func makeX509AuthFunc(v X509Validator) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return ctx, errors.Internalf("peer not found")
}
i, ok := p.AuthInfo.(credentials.TLSInfo)
if !ok {
return ctx, errors.Unauthenticatedf("no TLS credentials")
}
if !i.State.HandshakeComplete {
return ctx, errors.Unauthenticatedf("handshake not complete")
}
var sans []string
for _, v := range i.State.VerifiedChains {
if len(v) == 0 {
continue
}
sans = append(sans, v[0].PermittedDNSDomains...)
}
return v(ctx, sans)
}
}