From c7096975b167f7749c31689e283c866834cbd61a Mon Sep 17 00:00:00 2001 From: Adphi Date: Tue, 27 Sep 2022 17:06:18 +0200 Subject: [PATCH] interceptors: add ban health: set services serving on start and not available on close Signed-off-by: Adphi --- .gitignore | 1 + README.md | 2 +- client/client.go | 5 +- example/example.go | 24 +++++- go.mod | 1 + go.sum | 2 + interceptors/ban/ban.go | 121 +++++++++++++++++++++++++++++++ interceptors/ban/context.go | 34 +++++++++ interceptors/ban/options.go | 83 +++++++++++++++++++++ interceptors/ban/rule.go | 52 +++++++++++++ interceptors/interceptors.go | 15 ++++ interceptors/metadata/forward.go | 15 +--- service/service.go | 17 ++++- 13 files changed, 350 insertions(+), 22 deletions(-) create mode 100644 interceptors/ban/ban.go create mode 100644 interceptors/ban/context.go create mode 100644 interceptors/ban/options.go create mode 100644 interceptors/ban/rule.go diff --git a/.gitignore b/.gitignore index dd67c09..33b84ac 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea .bin /tmp +diff diff --git a/README.md b/README.md index 49375c3..5fd50db 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Features: - [ ] context logger - [x] sentry - [ ] rate-limiting - - [ ] ban + - [x] ban - [ ] auth claim in context - [x] recovery (server side only) - [x] tracing (open-tracing) diff --git a/client/client.go b/client/client.go index 654b33f..6e4702c 100644 --- a/client/client.go +++ b/client/client.go @@ -9,6 +9,7 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/resolver" "go.linka.cloud/grpc/registry/noop" @@ -34,8 +35,8 @@ func New(opts ...Option) (Client, error) { if c.opts.tlsConfig != nil { c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig))) } - if !c.opts.secure { - c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithInsecure()) + if !c.opts.secure && c.opts.tlsConfig == nil { + c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) } if len(c.opts.unaryInterceptors) > 0 { c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(c.opts.unaryInterceptors...))) diff --git a/example/example.go b/example/example.go index 18caf86..e657c12 100644 --- a/example/example.go +++ b/example/example.go @@ -22,6 +22,7 @@ import ( "go.linka.cloud/grpc/client" "go.linka.cloud/grpc/interceptors/auth" + "go.linka.cloud/grpc/interceptors/ban" "go.linka.cloud/grpc/interceptors/defaulter" "go.linka.cloud/grpc/interceptors/iface" metrics2 "go.linka.cloud/grpc/interceptors/metrics" @@ -137,6 +138,7 @@ func run(opts ...service.Option) { service.WithMiddlewares(httpLogger), service.WithInterceptors(metrics), service.WithServerInterceptors( + ban.NewInterceptors(), auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { return ctx, fmt.Errorf("invalid user or password") @@ -167,7 +169,7 @@ func run(opts ...service.Option) { } }() <-ready - s, err := client.New( + copts := []client.Option{ // client.WithName(name), // client.WithVersion(version), client.WithAddress("localhost:9991"), @@ -177,14 +179,28 @@ func run(opts ...service.Option) { logger.From(ctx).WithFields("party", "client", "method", method).Info(req) return invoker(ctx, method, req, reply, cc, opts...) }), - client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")), - ) + } + s, err := client.New(copts...) if err != nil { log.Fatal(err) } g := NewGreeterClient(s) h := grpc_health_v1.NewHealthClient(s) - hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) + for i := 0; i < 4; i++ { + _, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) + if err != nil { + log.Error(err) + } else { + log.Fatalf("expected error") + } + } + s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...) + if err != nil { + log.Fatal(err) + } + g = NewGreeterClient(s) + h = grpc_health_v1.NewHealthClient(s) + hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{Service: "helloworld.Greeter"}) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index 433de09..9d797ef 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.5.0 github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 github.com/improbable-eng/grpc-web v0.14.1 + github.com/jaredfolkins/badactor v1.2.0 github.com/johnbellone/grpc-middleware-sentry v0.2.0 github.com/justinas/alice v1.2.0 github.com/lyft/protoc-gen-star v0.6.0 // indirect diff --git a/go.sum b/go.sum index 0b908e5..eeeca5c 100644 --- a/go.sum +++ b/go.sum @@ -383,6 +383,8 @@ github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/ github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= +github.com/jaredfolkins/badactor v1.2.0 h1:QTJBsVG9qhdIFmFx5eNet2Q9hX8T+qZ1rC9NJwyN+Hc= +github.com/jaredfolkins/badactor v1.2.0/go.mod h1:ZynkTrC/ICU1o8mmFy3JySRCErXVlx7trZiWEH6DDg8= github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= diff --git a/interceptors/ban/ban.go b/interceptors/ban/ban.go new file mode 100644 index 0000000..3fa11c1 --- /dev/null +++ b/interceptors/ban/ban.go @@ -0,0 +1,121 @@ +package ban + +import ( + "context" + + "github.com/jaredfolkins/badactor" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "go.linka.cloud/grpc/interceptors" + "go.linka.cloud/grpc/logger" +) + +type ban struct { + s *badactor.Studio + rules map[codes.Code]Rule + actor func(ctx context.Context) (string, bool, error) +} + +func NewInterceptors(opts ...Option) interceptors.ServerInterceptors { + o := defaultOptions + for _, opt := range opts { + opt(&o) + } + s := badactor.NewStudio(o.cap) + rules := make(map[codes.Code]Rule) + for _, r := range o.rules { + rules[r.Code] = r + s.AddRule(&badactor.Rule{ + Name: r.Name, + Message: r.Message, + StrikeLimit: r.StrikeLimit, + ExpireBase: r.ExpireBase, + Sentence: r.Sentence, + Action: &action{ + whenJailed: r.WhenJailed, + whenTimeServed: r.WhenTimeServed, + }, + }) + } + // we ignore the error because CreateDirectors never returns an error + _ = s.CreateDirectors(o.cap) + s.StartReaper(o.reaperInterval) + return &ban{s: s, rules: rules, actor: o.actorFunc} +} + +func (b *ban) UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + actor, ok, err := b.check(ctx) + if err != nil { + return nil, err + } + ctx = set(ctx, b, actor) + if !ok { + return handler(ctx, req) + } + for _, v := range b.rules { + if b.s.IsJailedFor(actor, v.Name) { + return nil, status.Error(v.Code, v.Message) + } + } + res, err := handler(ctx, req) + if err != nil { + return nil, b.handleErr(ctx, actor, err) + } + return res, nil + } +} + +func (b *ban) StreamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + actor, ok, err := b.check(ss.Context()) + if err != nil { + return err + } + ss = interceptors.NewContextServerStream(set(ss.Context(), b, actor), ss) + if !ok { + return handler(srv, ss) + } + if err := handler(srv, ss); err != nil { + return b.handleErr(ss.Context(), actor, err) + } + return nil + } +} + +func (b *ban) check(ctx context.Context) (actor string, ok bool, err error) { + actor, ok, err = b.actor(ctx) + if err != nil { + return "", false, err + } + if !ok { + return "", false, nil + } + for _, v := range b.rules { + if b.s.IsJailedFor(actor, v.Name) { + return actor, false, status.Error(v.Code, v.Message) + } + } + return actor, true, nil +} + +func (b *ban) handleErr(ctx context.Context, actor string, err error) error { + v, ok := ctx.Value(key{}).(*value) + if !ok || v.done { + return err + } + s, ok := status.FromError(err) + if !ok { + return err + } + r, ok := b.rules[s.Code()] + if !ok { + return err + } + if err := b.s.Infraction(actor, r.Name); err != nil { + logger.C(ctx).Warnf("%s: failed to add infraction: %v", r.Name, err) + } + return err +} diff --git a/interceptors/ban/context.go b/interceptors/ban/context.go new file mode 100644 index 0000000..879f203 --- /dev/null +++ b/interceptors/ban/context.go @@ -0,0 +1,34 @@ +package ban + +import ( + "context" +) + +type key struct{} + +type value struct { + ban *ban + actor string + done bool +} + +func set(ctx context.Context, b *ban, actor string) context.Context { + return context.WithValue(ctx, key{}, &value{ban: b, actor: actor}) +} + +func Infraction(ctx context.Context, rule string) error { + v, ok := ctx.Value(key{}).(*value) + if !ok { + return nil + } + v.done = true + return v.ban.s.Infraction(v.actor, rule) +} + +func Actor(ctx context.Context) string { + v, ok := ctx.Value(key{}).(*value) + if !ok { + return "" + } + return v.actor +} diff --git a/interceptors/ban/options.go b/interceptors/ban/options.go new file mode 100644 index 0000000..31b3400 --- /dev/null +++ b/interceptors/ban/options.go @@ -0,0 +1,83 @@ +package ban + +import ( + "context" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/peer" +) + +const ( + Unauthorized = "Unauthorized" + Unauthenticated = "Unauthenticated" +) + +var ( + defaultOptions = options{ + cap: 1024, + reaperInterval: 10 * time.Minute, + rules: defaultRules, + actorFunc: DefaultActorFunc, + } + + defaultRules = []Rule{ + { + Name: Unauthorized, + Message: "Too many unauthorized requests", + Code: codes.PermissionDenied, + StrikeLimit: 3, + ExpireBase: time.Second * 10, + Sentence: time.Second * 10, + }, + { + Name: Unauthenticated, + Message: "Too many unauthenticated requests", + Code: codes.Unauthenticated, + StrikeLimit: 3, + ExpireBase: time.Second * 10, + Sentence: time.Second * 10, + }, + } +) + +func DefaultActorFunc(ctx context.Context) (string, bool, error) { + p, ok := peer.FromContext(ctx) + if !ok { + return "", false, nil + } + return p.Addr.String(), true, nil +} + +type Option func(*options) + +func WithCapacity(cap int32) Option { + return func(o *options) { + o.cap = cap + } +} + +func WithRules(rules ...Rule) Option { + return func(o *options) { + o.rules = rules + } +} + +func WithReaperInterval(interval time.Duration) Option { + return func(o *options) { + o.reaperInterval = interval + } +} + +func WithActorFunc(f func(context.Context) (name string, found bool, err error)) Option { + return func(o *options) { + o.actorFunc = f + } +} + +type options struct { + cap int32 + rules []Rule + reaperInterval time.Duration + actorFunc func(ctx context.Context) (name string, found bool, err error) +} diff --git a/interceptors/ban/rule.go b/interceptors/ban/rule.go new file mode 100644 index 0000000..9a44479 --- /dev/null +++ b/interceptors/ban/rule.go @@ -0,0 +1,52 @@ +package ban + +import ( + "time" + + "github.com/jaredfolkins/badactor" + "google.golang.org/grpc/codes" +) + +type Rule struct { + Name string + Message string + Code codes.Code + StrikeLimit int + ExpireBase time.Duration + Sentence time.Duration + // WhenJailed is an optional function to call when an Actor isJailed + WhenJailed func(actor string, r *Rule) error + // WhenTimeServed is an optional function to call when an Actor is released because of timeServed + WhenTimeServed func(actor string, r *Rule) error +} + +type action struct { + whenJailed func(actor string, r *Rule) error + whenTimeServed func(actor string, r *Rule) error +} + +func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error { + if a2.whenJailed != nil { + return a2.whenJailed(a.Name(), &Rule{ + Name: r.Name, + Message: r.Message, + StrikeLimit: r.StrikeLimit, + ExpireBase: r.ExpireBase, + Sentence: r.Sentence, + }) + } + return nil +} + +func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error { + if a2.whenTimeServed != nil { + return a2.whenTimeServed(a.Name(), &Rule{ + Name: r.Name, + Message: r.Message, + StrikeLimit: r.StrikeLimit, + ExpireBase: r.ExpireBase, + Sentence: r.Sentence, + }) + } + return nil +} diff --git a/interceptors/interceptors.go b/interceptors/interceptors.go index 72ae9cb..dcd8cba 100644 --- a/interceptors/interceptors.go +++ b/interceptors/interceptors.go @@ -1,6 +1,8 @@ package interceptors import ( + "context" + "google.golang.org/grpc" ) @@ -18,3 +20,16 @@ type Interceptors interface { ServerInterceptors ClientInterceptors } + +func NewContextServerStream(ctx context.Context, ss grpc.ServerStream) grpc.ServerStream { + return &ContextWrapper{ServerStream: ss, ctx: ctx} +} + +type ContextWrapper struct { + grpc.ServerStream + ctx context.Context +} + +func (w *ContextWrapper) Context() context.Context { + return w.ctx +} diff --git a/interceptors/metadata/forward.go b/interceptors/metadata/forward.go index dd01c2b..813926c 100644 --- a/interceptors/metadata/forward.go +++ b/interceptors/metadata/forward.go @@ -35,19 +35,6 @@ func (f *forward) StreamServerInterceptor() grpc.StreamServerInterceptor { if md2, ok := metadata.FromOutgoingContext(ctx); ok { o = metadata.Join(o, md2.Copy()) } - return handler(srv, NewContextServerStream(metadata.NewOutgoingContext(ctx, o), ss)) + return handler(srv, interceptors.NewContextServerStream(metadata.NewOutgoingContext(ctx, o), ss)) } } - -func NewContextServerStream(ctx context.Context, ss grpc.ServerStream) grpc.ServerStream { - return &ContextWrapper{ServerStream: ss, ctx: ctx} -} - -type ContextWrapper struct { - grpc.ServerStream - ctx context.Context -} - -func (w *ContextWrapper) Context() context.Context { - return w.ctx -} diff --git a/service/service.go b/service/service.go index fa57003..0991a5f 100644 --- a/service/service.go +++ b/service/service.go @@ -56,6 +56,8 @@ type service struct { inproc *inprocgrpc.Channel services map[string]*serviceInfo + healthServer *health.Server + id string regSvc *registry.Service closed chan struct{} @@ -120,7 +122,8 @@ func newService(opts ...Option) (*service, error) { greflect.Register(s.server) } if s.opts.health { - s.registerService(&grpc_health_v1.Health_ServiceDesc, health.NewServer()) + s.healthServer = health.NewServer() + s.registerService(&grpc_health_v1.Health_ServiceDesc, s.healthServer) } if err := s.gateway(s.opts.gatewayOpts...); err != nil { return nil, err @@ -226,6 +229,18 @@ func (s *service) run() error { } errs <- nil }() + if s.healthServer != nil { + for k := range s.services { + s.healthServer.SetServingStatus(k, grpc_health_v1.HealthCheckResponse_SERVING) + } + defer func() { + if s.healthServer != nil { + for k := range s.services { + s.healthServer.SetServingStatus(k, grpc_health_v1.HealthCheckResponse_NOT_SERVING) + } + } + }() + } for i := range s.opts.afterStart { if err := s.opts.afterStart[i](); err != nil { s.mu.Unlock()