mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-25 12:26:26 +00:00
add auth interceptors
This commit is contained in:
parent
55251b5020
commit
e578d62a29
36
interceptors/auth/basic.go
Normal file
36
interceptors/auth/basic.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||||
|
|
||||||
|
"go.linka.cloud/grpc/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BasicAuth(user, password string) string {
|
||||||
|
return "basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
|
||||||
|
}
|
||||||
|
|
||||||
|
type BasicValidator func(ctx context.Context, user, password string) (context.Context,error)
|
||||||
|
|
||||||
|
func makeBasicAuthFunc(v BasicValidator) grpc_auth.AuthFunc {
|
||||||
|
return func(ctx context.Context) (context.Context, error) {
|
||||||
|
a, err := grpc_auth.AuthFromMD(ctx, "basic")
|
||||||
|
if err != nil {
|
||||||
|
return ctx, err
|
||||||
|
}
|
||||||
|
c, err := base64.StdEncoding.DecodeString(a)
|
||||||
|
if err != nil {
|
||||||
|
return ctx, err
|
||||||
|
}
|
||||||
|
cs := string(c)
|
||||||
|
s := strings.IndexByte(cs, ':')
|
||||||
|
if s < 0 {
|
||||||
|
return ctx, errors.Unauthenticatedf("malformed basic auth")
|
||||||
|
}
|
||||||
|
return v(ctx, cs[:s], cs[s+1:])
|
||||||
|
}
|
||||||
|
}
|
94
interceptors/auth/interceptors.go
Normal file
94
interceptors/auth/interceptors.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"go.linka.cloud/grpc/interceptors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc {
|
||||||
|
return func(ctx context.Context) (context.Context, error) {
|
||||||
|
code := codes.Unauthenticated
|
||||||
|
for _, v := range fn {
|
||||||
|
ctx2, err := v(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return ctx2, nil
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return ctx2, err
|
||||||
|
}
|
||||||
|
if s.Code() == codes.PermissionDenied {
|
||||||
|
code = codes.PermissionDenied
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, status.Error(code, code.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServerInterceptors(opts ...Option) interceptors.ServerInterceptors {
|
||||||
|
o := options{}
|
||||||
|
for _, v := range opts {
|
||||||
|
v(&o)
|
||||||
|
}
|
||||||
|
return &interceptor{o: o, authFn: ChainedAuthFuncs(o.authFns...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
type interceptor struct{
|
||||||
|
o options
|
||||||
|
authFn grpc_auth.AuthFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *interceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||||
|
a := grpc_auth.UnaryServerInterceptor(i.authFn)
|
||||||
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||||
|
if i.isNotProtected(info.FullMethod) {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
return a(ctx, req, info, handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *interceptor) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||||
|
a := grpc_auth.StreamServerInterceptor(i.authFn)
|
||||||
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
if i.isNotProtected(info.FullMethod) {
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
return a(srv, ss, info, handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *interceptor) isNotProtected(endpoint string) bool {
|
||||||
|
// default to not ignored
|
||||||
|
if len(i.o.ignoredMethods) == 0 && len(i.o.methods) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// endpoint is like /helloworld.Greeter/SayHello
|
||||||
|
parts := strings.Split(strings.TrimPrefix(endpoint, "/"), "/")
|
||||||
|
// invalid endpoint format
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
method := parts[1]
|
||||||
|
for _, v := range i.o.ignoredMethods {
|
||||||
|
if v == method {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(i.o.methods) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, v := range i.o.methods {
|
||||||
|
if v == method {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
146
interceptors/auth/interceptors_test.go
Normal file
146
interceptors/auth/interceptors_test.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||||
|
assert2 "github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"go.linka.cloud/grpc/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNotProtectededOnly(t *testing.T) {
|
||||||
|
assert := assert2.New(t)
|
||||||
|
i := &interceptor{o: options{ignoredMethods: []string{"ignored"}}}
|
||||||
|
assert.False(i.isNotProtected("/test.Service/protected"))
|
||||||
|
assert.True(i.isNotProtected("/test.Service/ignored"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedOnly(t *testing.T) {
|
||||||
|
assert := assert2.New(t)
|
||||||
|
i := &interceptor{o: options{methods: []string{"protected"}}}
|
||||||
|
assert.False(i.isNotProtected("/test.Service/protected"))
|
||||||
|
assert.True(i.isNotProtected("/test.Service/ignored"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedAndIgnored(t *testing.T) {
|
||||||
|
assert := assert2.New(t)
|
||||||
|
i := &interceptor{o: options{methods: []string{"protected"}, ignoredMethods: []string{"ignored"}}}
|
||||||
|
assert.True(i.isNotProtected("/test.Service/ignored"))
|
||||||
|
assert.False(i.isNotProtected("/test.Service/protected"))
|
||||||
|
assert.True(i.isNotProtected("/test.Service/other"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedByDefault(t *testing.T) {
|
||||||
|
i := &interceptor{}
|
||||||
|
assert2.False(t, i.isNotProtected("nooop"))
|
||||||
|
assert2.False(t, i.isNotProtected("/test.Service/method/cannotExists"))
|
||||||
|
assert2.False(t, i.isNotProtected("/test.Service/validMethod"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
adminAuth = func(ctx context.Context, user, password string) (context.Context, error) {
|
||||||
|
if user == "admin" && password == "admin" {
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
return ctx, errors.PermissionDeniedf("")
|
||||||
|
}
|
||||||
|
testAuth = func(ctx context.Context, user, password string) (context.Context, error) {
|
||||||
|
if user == "test" && password == "test" {
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
return ctx, errors.PermissionDeniedf("")
|
||||||
|
}
|
||||||
|
tokenAuth = func(ctx context.Context, token string) (context.Context, error) {
|
||||||
|
if token == "token" {
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
return ctx, errors.PermissionDeniedf("")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChainedAuthFuncs(t *testing.T) {
|
||||||
|
wantInternalError := false
|
||||||
|
ctx := context.Background()
|
||||||
|
auth := ChainedAuthFuncs([]grpc_auth.AuthFunc{
|
||||||
|
makeBasicAuthFunc(adminAuth),
|
||||||
|
makeBasicAuthFunc(testAuth),
|
||||||
|
makeTokenAuthFunc(tokenAuth),
|
||||||
|
makeTokenAuthFunc(func(ctx context.Context, token string) (context.Context, error) {
|
||||||
|
if wantInternalError {
|
||||||
|
return ctx, errors.Internalf("ooops")
|
||||||
|
}
|
||||||
|
return ctx, errors.Unauthenticatedf("")
|
||||||
|
}),
|
||||||
|
}...)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
auth string
|
||||||
|
internalError bool
|
||||||
|
err bool
|
||||||
|
code codes.Code
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no auth",
|
||||||
|
auth: "",
|
||||||
|
err: true,
|
||||||
|
code: codes.Unauthenticated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid token",
|
||||||
|
auth: "bearer token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty bearer",
|
||||||
|
auth: "bearer ",
|
||||||
|
err: true,
|
||||||
|
code: codes.PermissionDenied,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "internal error",
|
||||||
|
auth: "bearer internal",
|
||||||
|
internalError: true,
|
||||||
|
err: true,
|
||||||
|
code: codes.PermissionDenied,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple auth: first basic valid",
|
||||||
|
auth: BasicAuth("admin", "admin"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple auth: second baisc valid",
|
||||||
|
auth: BasicAuth("test", "test"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid auth: bearer",
|
||||||
|
auth: "bearer noop",
|
||||||
|
err: true,
|
||||||
|
code: codes.PermissionDenied,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid auth: basic",
|
||||||
|
auth: BasicAuth("other", "other"),
|
||||||
|
err: true,
|
||||||
|
code: codes.PermissionDenied,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wantInternalError = tt.internalError
|
||||||
|
rctx, err := auth(metadata.NewIncomingContext(ctx, metadata.Pairs("authorization", tt.auth)))
|
||||||
|
if tt.err {
|
||||||
|
assert2.Error(t, err)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert2.True(t, ok)
|
||||||
|
assert2.Equal(t, tt.code, s.Code())
|
||||||
|
}
|
||||||
|
assert2.NotNil(t, rctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
56
interceptors/auth/options.go
Normal file
56
interceptors/auth/options.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Option func(o *options)
|
||||||
|
|
||||||
|
func WithMethods(methods ...string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.methods = append(o.methods, methods...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithIgnoredMethods(methods ...string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.ignoredMethods = append(o.ignoredMethods, methods...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBasicValidators(validators ...BasicValidator) Option {
|
||||||
|
var authFns []grpc_auth.AuthFunc
|
||||||
|
for _, v := range validators {
|
||||||
|
authFns = append(authFns, makeBasicAuthFunc(v))
|
||||||
|
}
|
||||||
|
return func(o *options) {
|
||||||
|
o.authFns = append(o.authFns, authFns...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTokenValidators(validators ...TokenValidator) Option {
|
||||||
|
var authFns []grpc_auth.AuthFunc
|
||||||
|
for _, v := range validators {
|
||||||
|
authFns = append(authFns, makeTokenAuthFunc(v))
|
||||||
|
}
|
||||||
|
return func(o *options) {
|
||||||
|
o.authFns = append(o.authFns, authFns...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithX509Validators(validators ...X509Validator) Option {
|
||||||
|
var authFns []grpc_auth.AuthFunc
|
||||||
|
for _, v := range validators {
|
||||||
|
authFns = append(authFns, makeX509AuthFunc(v))
|
||||||
|
}
|
||||||
|
return func(o *options) {
|
||||||
|
o.authFns = append(o.authFns, authFns...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
methods []string
|
||||||
|
ignoredMethods []string
|
||||||
|
|
||||||
|
authFns []grpc_auth.AuthFunc
|
||||||
|
}
|
19
interceptors/auth/token.go
Normal file
19
interceptors/auth/token.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenValidator func(ctx context.Context, token string) (context.Context, error)
|
||||||
|
|
||||||
|
func makeTokenAuthFunc(v TokenValidator) grpc_auth.AuthFunc {
|
||||||
|
return func(ctx context.Context) (context.Context, error) {
|
||||||
|
a, err := grpc_auth.AuthFromMD(ctx, "bearer")
|
||||||
|
if err != nil {
|
||||||
|
return ctx, err
|
||||||
|
}
|
||||||
|
return v(ctx, a)
|
||||||
|
}
|
||||||
|
}
|
49
interceptors/auth/x509.go
Normal file
49
interceptors/auth/x509.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
|
||||||
|
"go.linka.cloud/grpc/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type X509Validator func(ctx context.Context, sans []string) (context.Context, error)
|
||||||
|
|
||||||
|
// func _(ctx context.Context) {
|
||||||
|
// p, ok := peer.FromContext(ctx)
|
||||||
|
// if !ok {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// i, ok := p.AuthInfo.(credentials.TLSInfo)
|
||||||
|
// if !ok {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// i.State.VerifiedChains
|
||||||
|
// }
|
||||||
|
|
||||||
|
func makeX509AuthFunc(v X509Validator) grpc_auth.AuthFunc {
|
||||||
|
return func(ctx context.Context) (context.Context, error) {
|
||||||
|
p, ok := peer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return ctx, errors.Internalf("peer not found")
|
||||||
|
}
|
||||||
|
i, ok := p.AuthInfo.(credentials.TLSInfo)
|
||||||
|
if !ok {
|
||||||
|
return ctx, errors.Unauthenticatedf("no TLS credentials")
|
||||||
|
}
|
||||||
|
if !i.State.HandshakeComplete {
|
||||||
|
return ctx, errors.Unauthenticatedf("handshake not complete")
|
||||||
|
}
|
||||||
|
var sans []string
|
||||||
|
for _, v := range i.State.VerifiedChains {
|
||||||
|
if len(v) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sans = append(sans, v[0].PermittedDNSDomains...)
|
||||||
|
}
|
||||||
|
return v(ctx, sans)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user