mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-21 18:36:25 +00:00
add auth interceptors
This commit is contained in:
parent
55251b5020
commit
e578d62a29
36
interceptors/auth/basic.go
Normal file
36
interceptors/auth/basic.go
Normal file
@ -0,0 +1,36 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
|
||||
"go.linka.cloud/grpc/errors"
|
||||
)
|
||||
|
||||
func BasicAuth(user, password string) string {
|
||||
return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
|
||||
}
|
||||
|
||||
type BasicValidator func(ctx context.Context, user, password string) (context.Context,error)
|
||||
|
||||
func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
|
||||
return func(ctx context.Context) (context.Context, error) {
|
||||
a, err := grpc_auth.AuthFromMD(ctx, "basic")
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(a)
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return ctx, errors.Unauthenticatedf("malformed basic auth")
|
||||
}
|
||||
return v(ctx, cs[:s], cs[s+1:])
|
||||
}
|
||||
}
|
94
interceptors/auth/interceptors.go
Normal file
94
interceptors/auth/interceptors.go
Normal file
@ -0,0 +1,94 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"go.linka.cloud/grpc/interceptors"
|
||||
)
|
||||
|
||||
func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc {
|
||||
return func(ctx context.Context) (context.Context, error) {
|
||||
code := codes.Unauthenticated
|
||||
for _, v := range fn {
|
||||
ctx2, err := v(ctx)
|
||||
if err == nil {
|
||||
return ctx2, nil
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return ctx2, err
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
code = codes.PermissionDenied
|
||||
}
|
||||
}
|
||||
return ctx, status.Error(code, code.String())
|
||||
}
|
||||
}
|
||||
|
||||
func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors {
|
||||
o := options{}
|
||||
for _, v := range opts {
|
||||
v(&o)
|
||||
}
|
||||
return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)}
|
||||
}
|
||||
|
||||
type interceptor struct{
|
||||
o options
|
||||
authFn grpc_auth.AuthFunc
|
||||
}
|
||||
|
||||
func (i *interceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
a := grpc_auth.UnaryServerInterceptor(i.authFn)
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
if i.isNotProtected(info.FullMethod) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
return a(ctx, req, info, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *interceptor) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
a := grpc_auth.StreamServerInterceptor(i.authFn)
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if i.isNotProtected(info.FullMethod) {
|
||||
return handler(srv, ss)
|
||||
}
|
||||
return a(srv, ss, info, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *interceptor) isNotProtected(endpoint string) bool {
|
||||
// default to not ignored
|
||||
if len(i.o.ignoredMethods) == 0 && len(i.o.methods) == 0 {
|
||||
return false
|
||||
}
|
||||
// endpoint is like /helloworld.Greeter/SayHello
|
||||
parts := strings.Split(strings.TrimPrefix(endpoint, "/"), "/")
|
||||
// invalid endpoint format
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
method := parts[1]
|
||||
for _, v := range i.o.ignoredMethods {
|
||||
if v == method {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(i.o.methods) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, v := range i.o.methods {
|
||||
if v == method {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
146
interceptors/auth/interceptors_test.go
Normal file
146
interceptors/auth/interceptors_test.go
Normal file
@ -0,0 +1,146 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
assert2 "github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"go.linka.cloud/grpc/errors"
|
||||
)
|
||||
|
||||
func TestNotProtectededOnly(t *testing.T) {
|
||||
assert := assert2.New(t)
|
||||
i := &interceptor{o: options{ignoredMethods: []string{"ignored"}}}
|
||||
assert.False(i.isNotProtected("/test.Service/protected"))
|
||||
assert.True(i.isNotProtected("/test.Service/ignored"))
|
||||
}
|
||||
|
||||
func TestProtectedOnly(t *testing.T) {
|
||||
assert := assert2.New(t)
|
||||
i := &interceptor{o: options{methods: []string{"protected"}}}
|
||||
assert.False(i.isNotProtected("/test.Service/protected"))
|
||||
assert.True(i.isNotProtected("/test.Service/ignored"))
|
||||
}
|
||||
|
||||
func TestProtectedAndIgnored(t *testing.T) {
|
||||
assert := assert2.New(t)
|
||||
i := &interceptor{o: options{methods: []string{"protected"}, ignoredMethods: []string{"ignored"}}}
|
||||
assert.True(i.isNotProtected("/test.Service/ignored"))
|
||||
assert.False(i.isNotProtected("/test.Service/protected"))
|
||||
assert.True(i.isNotProtected("/test.Service/other"))
|
||||
}
|
||||
|
||||
func TestProtectedByDefault(t *testing.T) {
|
||||
i := &interceptor{}
|
||||
assert2.False(t, i.isNotProtected("nooop"))
|
||||
assert2.False(t, i.isNotProtected("/test.Service/method/cannotExists"))
|
||||
assert2.False(t, i.isNotProtected("/test.Service/validMethod"))
|
||||
}
|
||||
|
||||
var (
|
||||
adminAuth = func(ctx context.Context, user, password string) (context.Context, error) {
|
||||
if user == "admin" && password == "admin" {
|
||||
return ctx, nil
|
||||
}
|
||||
return ctx, errors.PermissionDeniedf("")
|
||||
}
|
||||
testAuth = func(ctx context.Context, user, password string) (context.Context, error) {
|
||||
if user == "test" && password == "test" {
|
||||
return ctx, nil
|
||||
}
|
||||
return ctx, errors.PermissionDeniedf("")
|
||||
}
|
||||
tokenAuth = func(ctx context.Context, token string) (context.Context, error) {
|
||||
if token == "token" {
|
||||
return ctx, nil
|
||||
}
|
||||
return ctx, errors.PermissionDeniedf("")
|
||||
}
|
||||
)
|
||||
|
||||
func TestChainedAuthFuncs(t *testing.T) {
|
||||
wantInternalError := false
|
||||
ctx := context.Background()
|
||||
auth := ChainedAuthFuncs([]grpc_auth.AuthFunc{
|
||||
makeBasicAuthFunc(adminAuth),
|
||||
makeBasicAuthFunc(testAuth),
|
||||
makeTokenAuthFunc(tokenAuth),
|
||||
makeTokenAuthFunc(func(ctx context.Context, token string) (context.Context, error) {
|
||||
if wantInternalError {
|
||||
return ctx, errors.Internalf("ooops")
|
||||
}
|
||||
return ctx, errors.Unauthenticatedf("")
|
||||
}),
|
||||
}...)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
auth string
|
||||
internalError bool
|
||||
err bool
|
||||
code codes.Code
|
||||
}{
|
||||
{
|
||||
name: "no auth",
|
||||
auth: "",
|
||||
err: true,
|
||||
code: codes.Unauthenticated,
|
||||
},
|
||||
{
|
||||
name: "valid token",
|
||||
auth: "bearer token",
|
||||
},
|
||||
{
|
||||
name: "empty bearer",
|
||||
auth: "bearer ",
|
||||
err: true,
|
||||
code: codes.PermissionDenied,
|
||||
},
|
||||
{
|
||||
name: "internal error",
|
||||
auth: "bearer internal",
|
||||
internalError: true,
|
||||
err: true,
|
||||
code: codes.PermissionDenied,
|
||||
},
|
||||
{
|
||||
name: "multiple auth: first basic valid",
|
||||
auth: BasicAuth("admin", "admin"),
|
||||
},
|
||||
{
|
||||
name: "multiple auth: second baisc valid",
|
||||
auth: BasicAuth("test", "test"),
|
||||
},
|
||||
{
|
||||
name: "invalid auth: bearer",
|
||||
auth: "bearer noop",
|
||||
err: true,
|
||||
code: codes.PermissionDenied,
|
||||
},
|
||||
{
|
||||
name: "invalid auth: basic",
|
||||
auth: BasicAuth("other", "other"),
|
||||
err: true,
|
||||
code: codes.PermissionDenied,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wantInternalError = tt.internalError
|
||||
rctx, err := auth(metadata.NewIncomingContext(ctx, metadata.Pairs("authorization", tt.auth)))
|
||||
if tt.err {
|
||||
assert2.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert2.True(t, ok)
|
||||
assert2.Equal(t, tt.code, s.Code())
|
||||
}
|
||||
assert2.NotNil(t, rctx)
|
||||
})
|
||||
}
|
||||
}
|
56
interceptors/auth/options.go
Normal file
56
interceptors/auth/options.go
Normal file
@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
)
|
||||
|
||||
type Option func(o *options)
|
||||
|
||||
func WithMethods(methods ...string) Option {
|
||||
return func(o *options) {
|
||||
o.methods = append(o.methods, methods...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithIgnoredMethods(methods ...string) Option {
|
||||
return func(o *options) {
|
||||
o.ignoredMethods = append(o.ignoredMethods, methods...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithBasicValidators(validators ...BasicValidator) Option {
|
||||
var authFns []grpc_auth.AuthFunc
|
||||
for _, v := range validators {
|
||||
authFns = append(authFns, makeBasicAuthFunc(v))
|
||||
}
|
||||
return func(o *options) {
|
||||
o.authFns = append(o.authFns, authFns...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenValidators(validators ...TokenValidator) Option {
|
||||
var authFns []grpc_auth.AuthFunc
|
||||
for _, v := range validators {
|
||||
authFns = append(authFns, makeTokenAuthFunc(v))
|
||||
}
|
||||
return func(o *options) {
|
||||
o.authFns = append(o.authFns, authFns...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithX509Validators(validators ...X509Validator) Option {
|
||||
var authFns []grpc_auth.AuthFunc
|
||||
for _, v := range validators {
|
||||
authFns = append(authFns, makeX509AuthFunc(v))
|
||||
}
|
||||
return func(o *options) {
|
||||
o.authFns = append(o.authFns, authFns...)
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
methods []string
|
||||
ignoredMethods []string
|
||||
|
||||
authFns []grpc_auth.AuthFunc
|
||||
}
|
19
interceptors/auth/token.go
Normal file
19
interceptors/auth/token.go
Normal file
@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
)
|
||||
|
||||
type TokenValidator func(ctx context.Context, token string) (context.Context, error)
|
||||
|
||||
func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc {
|
||||
return func(ctx context.Context) (context.Context, error) {
|
||||
a, err := grpc_auth.AuthFromMD(ctx, "bearer")
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
return v(ctx, a)
|
||||
}
|
||||
}
|
49
interceptors/auth/x509.go
Normal file
49
interceptors/auth/x509.go
Normal file
@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
|
||||
"go.linka.cloud/grpc/errors"
|
||||
)
|
||||
|
||||
type X509Validator func(ctx context.Context, sans []string) (context.Context, error)
|
||||
|
||||
// func _(ctx context.Context) {
|
||||
// p, ok := peer.FromContext(ctx)
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// i, ok := p.AuthInfo.(credentials.TLSInfo)
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// i.State.VerifiedChains
|
||||
// }
|
||||
|
||||
func makeX509AuthFunc(v X509Validator) grpc_auth.AuthFunc {
|
||||
return func(ctx context.Context) (context.Context, error) {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return ctx, errors.Internalf("peer not found")
|
||||
}
|
||||
i, ok := p.AuthInfo.(credentials.TLSInfo)
|
||||
if !ok {
|
||||
return ctx, errors.Unauthenticatedf("no TLS credentials")
|
||||
}
|
||||
if !i.State.HandshakeComplete {
|
||||
return ctx, errors.Unauthenticatedf("handshake not complete")
|
||||
}
|
||||
var sans []string
|
||||
for _, v := range i.State.VerifiedChains {
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
sans = append(sans, v[0].PermittedDNSDomains...)
|
||||
}
|
||||
return v(ctx, sans)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user