ban: more defaults options, simpler callback

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2022-12-07 14:09:36 +01:00
parent 01b37a0d91
commit ef9a12d89e
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
4 changed files with 87 additions and 51 deletions

View File

@ -138,7 +138,10 @@ func run(opts ...service.Option) {
service.WithMiddlewares(httpLogger), service.WithMiddlewares(httpLogger),
service.WithInterceptors(metrics), service.WithInterceptors(metrics),
service.WithServerInterceptors( service.WithServerInterceptors(
ban.NewInterceptors(), ban.NewInterceptors(ban.WithDefaultJailDuration(time.Second), ban.WithDefaultCallback(func(action ban.Action, actor string, rule *ban.Rule) error {
log.WithFields("action", action, "actor", actor, "rule", rule.Name).Info("ban callback")
return nil
})),
auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) {
if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") {
return ctx, fmt.Errorf("invalid user or password") return ctx, fmt.Errorf("invalid user or password")
@ -187,7 +190,7 @@ func run(opts ...service.Option) {
} }
g := NewGreeterClient(s) g := NewGreeterClient(s)
h := grpc_health_v1.NewHealthClient(s) h := grpc_health_v1.NewHealthClient(s)
for i := 0; i < 4; i++ { for i := 0; i < 5; i++ {
_, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) _, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -195,6 +198,8 @@ func run(opts ...service.Option) {
log.Fatalf("expected error") log.Fatalf("expected error")
} }
} }
log.Infof("waiting for unban")
time.Sleep(time.Second)
s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...) s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -27,16 +27,21 @@ func NewInterceptors(opts ...Option) interceptors.ServerInterceptors {
rules := make(map[codes.Code]Rule) rules := make(map[codes.Code]Rule)
for _, r := range o.rules { for _, r := range o.rules {
rules[r.Code] = r rules[r.Code] = r
callback := r.Callback
if callback == nil {
callback = o.defaultCallback
}
expire := r.JailDuration
if expire == 0 {
expire = o.defaultJailDuration
}
s.AddRule(&badactor.Rule{ s.AddRule(&badactor.Rule{
Name: r.Name, Name: r.Name,
Message: r.Message, Message: r.Message,
StrikeLimit: r.StrikeLimit, StrikeLimit: r.StrikeLimit,
ExpireBase: r.ExpireBase, ExpireBase: expire,
Sentence: r.Sentence, Sentence: expire,
Action: &action{ Action: &action{fn: callback},
whenJailed: r.WhenJailed,
whenTimeServed: r.WhenTimeServed,
},
}) })
} }
// we ignore the error because CreateDirectors never returns an error // we ignore the error because CreateDirectors never returns an error

View File

@ -20,6 +20,8 @@ var (
reaperInterval: 10 * time.Minute, reaperInterval: 10 * time.Minute,
rules: defaultRules, rules: defaultRules,
actorFunc: DefaultActorFunc, actorFunc: DefaultActorFunc,
defaultCallback: nil,
defaultJailDuration: 10 * time.Second,
} }
defaultRules = []Rule{ defaultRules = []Rule{
@ -28,16 +30,12 @@ var (
Message: "Too many unauthorized requests", Message: "Too many unauthorized requests",
Code: codes.PermissionDenied, Code: codes.PermissionDenied,
StrikeLimit: 3, StrikeLimit: 3,
ExpireBase: time.Second * 10,
Sentence: time.Second * 10,
}, },
{ {
Name: Unauthenticated, Name: Unauthenticated,
Message: "Too many unauthenticated requests", Message: "Too many unauthenticated requests",
Code: codes.Unauthenticated, Code: codes.Unauthenticated,
StrikeLimit: 3, StrikeLimit: 3,
ExpireBase: time.Second * 10,
Sentence: time.Second * 10,
}, },
} }
) )
@ -79,9 +77,23 @@ func WithActorFunc(f func(context.Context) (name string, found bool, err error))
} }
} }
func WithDefaultCallback(f ActionCallback) Option {
return func(o *options) {
o.defaultCallback = f
}
}
func WithDefaultJailDuration(expire time.Duration) Option {
return func(o *options) {
o.defaultJailDuration = expire
}
}
type options struct { type options struct {
cap int32 cap int32
rules []Rule rules []Rule
reaperInterval time.Duration reaperInterval time.Duration
actorFunc func(ctx context.Context) (name string, found bool, err error) actorFunc func(ctx context.Context) (name string, found bool, err error)
defaultCallback ActionCallback
defaultJailDuration time.Duration
} }

View File

@ -7,46 +7,60 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
) )
type ActionCallback func(action Action, actor string, rule *Rule) error
type Action int
const (
Jailed Action = iota
Released
)
func (a Action) String() string {
switch a {
case Jailed:
return "Jailed"
case Released:
return "Released"
default:
return "Unknown"
}
}
type Rule struct { type Rule struct {
Name string Name string
Message string Message string
Code codes.Code Code codes.Code
StrikeLimit int StrikeLimit int
ExpireBase time.Duration JailDuration time.Duration
Sentence time.Duration // Callback is an optional function to call when an Actor isJailed or released because of timeServed
// WhenJailed is an optional function to call when an Actor isJailed Callback ActionCallback
WhenJailed func(actor string, r *Rule) error
// WhenTimeServed is an optional function to call when an Actor is released because of timeServed
WhenTimeServed func(actor string, r *Rule) error
} }
type action struct { type action struct {
whenJailed func(actor string, r *Rule) error fn ActionCallback
whenTimeServed func(actor string, r *Rule) error
} }
func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error { func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error {
if a2.whenJailed != nil { if a2.fn == nil {
return a2.whenJailed(a.Name(), &Rule{ return nil
}
return a2.fn(Jailed, a.Name(), &Rule{
Name: r.Name, Name: r.Name,
Message: r.Message, Message: r.Message,
StrikeLimit: r.StrikeLimit, StrikeLimit: r.StrikeLimit,
ExpireBase: r.ExpireBase, JailDuration: r.ExpireBase,
Sentence: r.Sentence,
}) })
}
return nil
} }
func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error { func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error {
if a2.whenTimeServed != nil { if a2.fn == nil {
return a2.whenTimeServed(a.Name(), &Rule{ return nil
}
return a2.fn(Released, a.Name(), &Rule{
Name: r.Name, Name: r.Name,
Message: r.Message, Message: r.Message,
StrikeLimit: r.StrikeLimit, StrikeLimit: r.StrikeLimit,
ExpireBase: r.ExpireBase, JailDuration: r.ExpireBase,
Sentence: r.Sentence,
}) })
}
return nil
} }