From ef9a12d89e872ffcafd1909e89caa7f954e3b8b8 Mon Sep 17 00:00:00 2001 From: Adphi Date: Wed, 7 Dec 2022 14:09:36 +0100 Subject: [PATCH] ban: more defaults options, simpler callback Signed-off-by: Adphi --- example/example.go | 9 ++++- interceptors/ban/ban.go | 17 ++++++--- interceptors/ban/options.go | 38 ++++++++++++------- interceptors/ban/rule.go | 74 ++++++++++++++++++++++--------------- 4 files changed, 87 insertions(+), 51 deletions(-) diff --git a/example/example.go b/example/example.go index 885f45b..ed010b8 100644 --- a/example/example.go +++ b/example/example.go @@ -138,7 +138,10 @@ func run(opts ...service.Option) { service.WithMiddlewares(httpLogger), service.WithInterceptors(metrics), service.WithServerInterceptors( - ban.NewInterceptors(), + ban.NewInterceptors(ban.WithDefaultJailDuration(time.Second), ban.WithDefaultCallback(func(action ban.Action, actor string, rule *ban.Rule) error { + log.WithFields("action", action, "actor", actor, "rule", rule.Name).Info("ban callback") + return nil + })), auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { return ctx, fmt.Errorf("invalid user or password") @@ -187,7 +190,7 @@ func run(opts ...service.Option) { } g := NewGreeterClient(s) h := grpc_health_v1.NewHealthClient(s) - for i := 0; i < 4; i++ { + for i := 0; i < 5; i++ { _, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) if err != nil { log.Error(err) @@ -195,6 +198,8 @@ func run(opts ...service.Option) { log.Fatalf("expected error") } } + log.Infof("waiting for unban") + time.Sleep(time.Second) s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...) if err != nil { log.Fatal(err) diff --git a/interceptors/ban/ban.go b/interceptors/ban/ban.go index 3fa11c1..b5404e4 100644 --- a/interceptors/ban/ban.go +++ b/interceptors/ban/ban.go @@ -27,16 +27,21 @@ func NewInterceptors(opts ...Option) interceptors.ServerInterceptors { rules := make(map[codes.Code]Rule) for _, r := range o.rules { rules[r.Code] = r + callback := r.Callback + if callback == nil { + callback = o.defaultCallback + } + expire := r.JailDuration + if expire == 0 { + expire = o.defaultJailDuration + } s.AddRule(&badactor.Rule{ Name: r.Name, Message: r.Message, StrikeLimit: r.StrikeLimit, - ExpireBase: r.ExpireBase, - Sentence: r.Sentence, - Action: &action{ - whenJailed: r.WhenJailed, - whenTimeServed: r.WhenTimeServed, - }, + ExpireBase: expire, + Sentence: expire, + Action: &action{fn: callback}, }) } // we ignore the error because CreateDirectors never returns an error diff --git a/interceptors/ban/options.go b/interceptors/ban/options.go index fa6c403..0b12f99 100644 --- a/interceptors/ban/options.go +++ b/interceptors/ban/options.go @@ -16,10 +16,12 @@ const ( var ( defaultOptions = options{ - cap: 1024, - reaperInterval: 10 * time.Minute, - rules: defaultRules, - actorFunc: DefaultActorFunc, + cap: 1024, + reaperInterval: 10 * time.Minute, + rules: defaultRules, + actorFunc: DefaultActorFunc, + defaultCallback: nil, + defaultJailDuration: 10 * time.Second, } defaultRules = []Rule{ @@ -28,16 +30,12 @@ var ( Message: "Too many unauthorized requests", Code: codes.PermissionDenied, StrikeLimit: 3, - ExpireBase: time.Second * 10, - Sentence: time.Second * 10, }, { Name: Unauthenticated, Message: "Too many unauthenticated requests", Code: codes.Unauthenticated, StrikeLimit: 3, - ExpireBase: time.Second * 10, - Sentence: time.Second * 10, }, } ) @@ -79,9 +77,23 @@ func WithActorFunc(f func(context.Context) (name string, found bool, err error)) } } -type options struct { - cap int32 - rules []Rule - reaperInterval time.Duration - actorFunc func(ctx context.Context) (name string, found bool, err error) +func WithDefaultCallback(f ActionCallback) Option { + return func(o *options) { + o.defaultCallback = f + } +} + +func WithDefaultJailDuration(expire time.Duration) Option { + return func(o *options) { + o.defaultJailDuration = expire + } +} + +type options struct { + cap int32 + rules []Rule + reaperInterval time.Duration + actorFunc func(ctx context.Context) (name string, found bool, err error) + defaultCallback ActionCallback + defaultJailDuration time.Duration } diff --git a/interceptors/ban/rule.go b/interceptors/ban/rule.go index 9a44479..172ed84 100644 --- a/interceptors/ban/rule.go +++ b/interceptors/ban/rule.go @@ -7,46 +7,60 @@ import ( "google.golang.org/grpc/codes" ) +type ActionCallback func(action Action, actor string, rule *Rule) error + +type Action int + +const ( + Jailed Action = iota + Released +) + +func (a Action) String() string { + switch a { + case Jailed: + return "Jailed" + case Released: + return "Released" + default: + return "Unknown" + } +} + type Rule struct { - Name string - Message string - Code codes.Code - StrikeLimit int - ExpireBase time.Duration - Sentence time.Duration - // WhenJailed is an optional function to call when an Actor isJailed - WhenJailed func(actor string, r *Rule) error - // WhenTimeServed is an optional function to call when an Actor is released because of timeServed - WhenTimeServed func(actor string, r *Rule) error + Name string + Message string + Code codes.Code + StrikeLimit int + JailDuration time.Duration + // Callback is an optional function to call when an Actor isJailed or released because of timeServed + Callback ActionCallback } type action struct { - whenJailed func(actor string, r *Rule) error - whenTimeServed func(actor string, r *Rule) error + fn ActionCallback } func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error { - if a2.whenJailed != nil { - return a2.whenJailed(a.Name(), &Rule{ - Name: r.Name, - Message: r.Message, - StrikeLimit: r.StrikeLimit, - ExpireBase: r.ExpireBase, - Sentence: r.Sentence, - }) + if a2.fn == nil { + return nil } - return nil + return a2.fn(Jailed, a.Name(), &Rule{ + Name: r.Name, + Message: r.Message, + StrikeLimit: r.StrikeLimit, + JailDuration: r.ExpireBase, + }) } func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error { - if a2.whenTimeServed != nil { - return a2.whenTimeServed(a.Name(), &Rule{ - Name: r.Name, - Message: r.Message, - StrikeLimit: r.StrikeLimit, - ExpireBase: r.ExpireBase, - Sentence: r.Sentence, - }) + if a2.fn == nil { + return nil } - return nil + return a2.fn(Released, a.Name(), &Rule{ + Name: r.Name, + Message: r.Message, + StrikeLimit: r.StrikeLimit, + JailDuration: r.ExpireBase, + }) }