auth interceptors: preserve error message

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2022-11-03 17:13:28 +01:00
parent 1d3d5315a4
commit dcd2f18f65
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
2 changed files with 13 additions and 9 deletions

View File

@ -8,13 +8,14 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
"go.linka.cloud/grpc/interceptors" "go.linka.cloud/grpc/interceptors"
) )
func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc { func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc {
return func(ctx context.Context) (context.Context, error) { return func(ctx context.Context) (context.Context, error) {
code := codes.Unauthenticated spb := status.New(codes.Unauthenticated, codes.Unauthenticated.String()).Proto()
for _, v := range fn { for _, v := range fn {
ctx2, err := v(ctx) ctx2, err := v(ctx)
if err == nil { if err == nil {
@ -24,11 +25,14 @@ func ChainedAuthFuncs(fn ...grpc_auth.AuthFunc) grpc_auth.AuthFunc {
if !ok { if !ok {
return ctx2, err return ctx2, err
} }
if s.Code() == codes.PermissionDenied { if spb.Code != s.Proto().Code {
code = codes.PermissionDenied spb.Code = s.Proto().Code
} }
d, _ := anypb.New(s.Proto())
spb.Details = append(spb.Details, d)
spb.Message += ", " + s.Proto().Message
} }
return ctx, status.Error(code, code.String()) return ctx, status.FromProto(spb).Err()
} }
} }

View File

@ -13,7 +13,7 @@ import (
"go.linka.cloud/grpc/errors" "go.linka.cloud/grpc/errors"
) )
func TestNotProtectededOnly(t *testing.T) { func TestNotProtectedOnly(t *testing.T) {
assert := assert2.New(t) assert := assert2.New(t)
i := &interceptor{o: options{ignoredMethods: []string{"/test.Service/ignored"}}} i := &interceptor{o: options{ignoredMethods: []string{"/test.Service/ignored"}}}
assert.False(i.isNotProtected("/test.Service/protected")) assert.False(i.isNotProtected("/test.Service/protected"))
@ -99,14 +99,14 @@ func TestChainedAuthFuncs(t *testing.T) {
name: "empty bearer", name: "empty bearer",
auth: "bearer ", auth: "bearer ",
err: true, err: true,
code: codes.PermissionDenied, code: codes.Unauthenticated,
}, },
{ {
name: "internal error", name: "internal error",
auth: "bearer internal", auth: "bearer internal",
internalError: true, internalError: true,
err: true, err: true,
code: codes.PermissionDenied, code: codes.Internal,
}, },
{ {
name: "multiple auth: first basic valid", name: "multiple auth: first basic valid",
@ -120,13 +120,13 @@ func TestChainedAuthFuncs(t *testing.T) {
name: "invalid auth: bearer", name: "invalid auth: bearer",
auth: "bearer noop", auth: "bearer noop",
err: true, err: true,
code: codes.PermissionDenied, code: codes.Unauthenticated,
}, },
{ {
name: "invalid auth: basic", name: "invalid auth: basic",
auth: BasicAuth("other", "other"), auth: BasicAuth("other", "other"),
err: true, err: true,
code: codes.PermissionDenied, code: codes.Unauthenticated,
}, },
} }