diff --git a/errors/errors.go b/errors/errors.go index 6351592..32c3be1 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -1,6 +1,10 @@ package errors import ( + "context" + "errors" + "strings" + status2 "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -8,6 +12,10 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) +func Canceled(err error) error { + return status.Error(codes.Canceled, err.Error()) +} + func InvalidArgument(err error) error { return status.Error(codes.InvalidArgument, err.Error()) } @@ -63,6 +71,12 @@ func makeDetails(m ...proto.Message) []*anypb.Any { return out } +func Canceledf(msg string, args ...interface{}) error { + return status.Errorf(codes.Canceled, msg, args...) +} +func Canceledd(err error, details ...proto.Message) error { + return statusErr(codes.Canceled, err, details...) +} func InvalidArgumentf(msg string, args ...interface{}) error { return status.Errorf(codes.InvalidArgument, msg, args...) } @@ -152,7 +166,7 @@ func IsCanceled(err error) bool { if err == nil { return false } - return status.Convert(err).Code() == codes.Canceled + return status.Convert(err).Code() == codes.Canceled || IsCanceled(err) } func IsUnknown(err error) bool { if err == nil { @@ -170,7 +184,7 @@ func IsDeadlineExceeded(err error) bool { if err == nil { return false } - return status.Convert(err).Code() == codes.DeadlineExceeded + return status.Convert(err).Code() == codes.DeadlineExceeded || IsContextDeadlineExceeded(err) } func IsNotFound(err error) bool { if err == nil { @@ -244,3 +258,30 @@ func IsUnauthenticated(err error) bool { } return status.Convert(err).Code() == codes.Unauthenticated } + +func IsContextCanceled(err error) bool { + err = Unwrap(err) + if err == nil { + return false + } + return strings.Contains(err.Error(), context.Canceled.Error()) +} + +func IsContextDeadlineExceeded(err error) bool { + err = Unwrap(err) + if err == nil { + return false + } + return strings.Contains(err.Error(), context.DeadlineExceeded.Error()) +} + +func Unwrap(err error) error { + s, ok := status.FromError(err) + if s == nil { + return nil + } + if ok { + return errors.New(s.Message()) + } + return err +} diff --git a/example/example.go b/example/example.go index ce744ab..a86d30b 100644 --- a/example/example.go +++ b/example/example.go @@ -20,6 +20,7 @@ import ( "google.golang.org/protobuf/types/descriptorpb" "go.linka.cloud/grpc/client" + "go.linka.cloud/grpc/interceptors/auth" "go.linka.cloud/grpc/interceptors/defaulter" metrics2 "go.linka.cloud/grpc/interceptors/metrics" validation2 "go.linka.cloud/grpc/interceptors/validation" @@ -113,7 +114,15 @@ func run(opts ...service.Option) { service.WithGRPCWeb(true), service.WithGRPCWebPrefix("/grpc"), service.WithMiddlewares(httpLogger), - service.WithInterceptors(metrics, defaulter, validation), + service.WithInterceptors(metrics), + service.WithServerInterceptors(auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { + if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { + return ctx, fmt.Errorf("invalid user or password") + } + log.Infof("request authenticated") + return ctx, nil + }))), + service.WithInterceptors(defaulter, validation), ) svc, err = service.New(opts...) if err != nil { @@ -143,6 +152,7 @@ func run(opts ...service.Option) { logger.From(ctx).WithFields("party", "client", "method", method).Info(req) return invoker(ctx, method, req, reply, cc, opts...) }), + client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")), ) if err != nil { log.Fatal(err) @@ -193,7 +203,13 @@ func run(opts ...service.Option) { req := `{"name":"test"}` do := func(url, contentType string) { - resp, err := httpc.Post(url, contentType, strings.NewReader(req)) + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(req)) + if err != nil { + log.Fatal(err) + } + req.Header.Set("content-type", contentType) + req.Header.Set("authorization", auth.BasicAuth("admin", "admin")) + resp, err := httpc.Do(req) if err != nil { log.Fatal(err) } diff --git a/interceptors/metadata/metadata.go b/interceptors/metadata/metadata.go index 6f7e8f7..297df47 100644 --- a/interceptors/metadata/metadata.go +++ b/interceptors/metadata/metadata.go @@ -37,14 +37,14 @@ func (i mdInterceptors) StreamServerInterceptor() grpc.StreamServerInterceptor { func (i mdInterceptors) UnaryClientInterceptor() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(i.pairs...)) + ctx = metadata.AppendToOutgoingContext(ctx, i.pairs...) return invoker(ctx, method, req, reply, cc, opts...) } } func (i mdInterceptors) StreamClientInterceptor() grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(i.pairs...)) + ctx = metadata.AppendToOutgoingContext(ctx, i.pairs...) return streamer(ctx, desc, cc, method, opts...) } }