mirror of
				https://github.com/linka-cloud/grpc.git
				synced 2025-10-30 17:12:28 +00:00 
			
		
		
		
	ban: more defaults options, simpler callback
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
		| @@ -138,7 +138,10 @@ func run(opts ...service.Option) { | |||||||
| 		service.WithMiddlewares(httpLogger), | 		service.WithMiddlewares(httpLogger), | ||||||
| 		service.WithInterceptors(metrics), | 		service.WithInterceptors(metrics), | ||||||
| 		service.WithServerInterceptors( | 		service.WithServerInterceptors( | ||||||
| 			ban.NewInterceptors(), | 			ban.NewInterceptors(ban.WithDefaultJailDuration(time.Second), ban.WithDefaultCallback(func(action ban.Action, actor string, rule *ban.Rule) error { | ||||||
|  | 				log.WithFields("action", action, "actor", actor, "rule", rule.Name).Info("ban callback") | ||||||
|  | 				return nil | ||||||
|  | 			})), | ||||||
| 			auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { | 			auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { | ||||||
| 				if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { | 				if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { | ||||||
| 					return ctx, fmt.Errorf("invalid user or password") | 					return ctx, fmt.Errorf("invalid user or password") | ||||||
| @@ -187,7 +190,7 @@ func run(opts ...service.Option) { | |||||||
| 	} | 	} | ||||||
| 	g := NewGreeterClient(s) | 	g := NewGreeterClient(s) | ||||||
| 	h := grpc_health_v1.NewHealthClient(s) | 	h := grpc_health_v1.NewHealthClient(s) | ||||||
| 	for i := 0; i < 4; i++ { | 	for i := 0; i < 5; i++ { | ||||||
| 		_, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) | 		_, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(err) | 			log.Error(err) | ||||||
| @@ -195,6 +198,8 @@ func run(opts ...service.Option) { | |||||||
| 			log.Fatalf("expected error") | 			log.Fatalf("expected error") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	log.Infof("waiting for unban") | ||||||
|  | 	time.Sleep(time.Second) | ||||||
| 	s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...) | 	s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal(err) | 		log.Fatal(err) | ||||||
|   | |||||||
| @@ -27,16 +27,21 @@ func NewInterceptors(opts ...Option) interceptors.ServerInterceptors { | |||||||
| 	rules := make(map[codes.Code]Rule) | 	rules := make(map[codes.Code]Rule) | ||||||
| 	for _, r := range o.rules { | 	for _, r := range o.rules { | ||||||
| 		rules[r.Code] = r | 		rules[r.Code] = r | ||||||
|  | 		callback := r.Callback | ||||||
|  | 		if callback == nil { | ||||||
|  | 			callback = o.defaultCallback | ||||||
|  | 		} | ||||||
|  | 		expire := r.JailDuration | ||||||
|  | 		if expire == 0 { | ||||||
|  | 			expire = o.defaultJailDuration | ||||||
|  | 		} | ||||||
| 		s.AddRule(&badactor.Rule{ | 		s.AddRule(&badactor.Rule{ | ||||||
| 			Name:        r.Name, | 			Name:        r.Name, | ||||||
| 			Message:     r.Message, | 			Message:     r.Message, | ||||||
| 			StrikeLimit: r.StrikeLimit, | 			StrikeLimit: r.StrikeLimit, | ||||||
| 			ExpireBase:  r.ExpireBase, | 			ExpireBase:  expire, | ||||||
| 			Sentence:    r.Sentence, | 			Sentence:    expire, | ||||||
| 			Action: &action{ | 			Action:      &action{fn: callback}, | ||||||
| 				whenJailed:     r.WhenJailed, |  | ||||||
| 				whenTimeServed: r.WhenTimeServed, |  | ||||||
| 			}, |  | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 	// we ignore the error because CreateDirectors never returns an error | 	// we ignore the error because CreateDirectors never returns an error | ||||||
|   | |||||||
| @@ -16,10 +16,12 @@ const ( | |||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	defaultOptions = options{ | 	defaultOptions = options{ | ||||||
| 		cap:            1024, | 		cap:                 1024, | ||||||
| 		reaperInterval: 10 * time.Minute, | 		reaperInterval:      10 * time.Minute, | ||||||
| 		rules:          defaultRules, | 		rules:               defaultRules, | ||||||
| 		actorFunc:      DefaultActorFunc, | 		actorFunc:           DefaultActorFunc, | ||||||
|  | 		defaultCallback:     nil, | ||||||
|  | 		defaultJailDuration: 10 * time.Second, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	defaultRules = []Rule{ | 	defaultRules = []Rule{ | ||||||
| @@ -28,16 +30,12 @@ var ( | |||||||
| 			Message:     "Too many unauthorized requests", | 			Message:     "Too many unauthorized requests", | ||||||
| 			Code:        codes.PermissionDenied, | 			Code:        codes.PermissionDenied, | ||||||
| 			StrikeLimit: 3, | 			StrikeLimit: 3, | ||||||
| 			ExpireBase:  time.Second * 10, |  | ||||||
| 			Sentence:    time.Second * 10, |  | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Name:        Unauthenticated, | 			Name:        Unauthenticated, | ||||||
| 			Message:     "Too many unauthenticated requests", | 			Message:     "Too many unauthenticated requests", | ||||||
| 			Code:        codes.Unauthenticated, | 			Code:        codes.Unauthenticated, | ||||||
| 			StrikeLimit: 3, | 			StrikeLimit: 3, | ||||||
| 			ExpireBase:  time.Second * 10, |  | ||||||
| 			Sentence:    time.Second * 10, |  | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| ) | ) | ||||||
| @@ -79,9 +77,23 @@ func WithActorFunc(f func(context.Context) (name string, found bool, err error)) | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type options struct { | func WithDefaultCallback(f ActionCallback) Option { | ||||||
| 	cap            int32 | 	return func(o *options) { | ||||||
| 	rules          []Rule | 		o.defaultCallback = f | ||||||
| 	reaperInterval time.Duration | 	} | ||||||
| 	actorFunc      func(ctx context.Context) (name string, found bool, err error) | } | ||||||
|  |  | ||||||
|  | func WithDefaultJailDuration(expire time.Duration) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.defaultJailDuration = expire | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type options struct { | ||||||
|  | 	cap                 int32 | ||||||
|  | 	rules               []Rule | ||||||
|  | 	reaperInterval      time.Duration | ||||||
|  | 	actorFunc           func(ctx context.Context) (name string, found bool, err error) | ||||||
|  | 	defaultCallback     ActionCallback | ||||||
|  | 	defaultJailDuration time.Duration | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,46 +7,60 @@ import ( | |||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type ActionCallback func(action Action, actor string, rule *Rule) error | ||||||
|  |  | ||||||
|  | type Action int | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	Jailed Action = iota | ||||||
|  | 	Released | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (a Action) String() string { | ||||||
|  | 	switch a { | ||||||
|  | 	case Jailed: | ||||||
|  | 		return "Jailed" | ||||||
|  | 	case Released: | ||||||
|  | 		return "Released" | ||||||
|  | 	default: | ||||||
|  | 		return "Unknown" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| type Rule struct { | type Rule struct { | ||||||
| 	Name        string | 	Name         string | ||||||
| 	Message     string | 	Message      string | ||||||
| 	Code        codes.Code | 	Code         codes.Code | ||||||
| 	StrikeLimit int | 	StrikeLimit  int | ||||||
| 	ExpireBase  time.Duration | 	JailDuration time.Duration | ||||||
| 	Sentence    time.Duration | 	// Callback is an optional function to call when an Actor isJailed or released because of timeServed | ||||||
| 	// WhenJailed is an optional function to call when an Actor isJailed | 	Callback ActionCallback | ||||||
| 	WhenJailed func(actor string, r *Rule) error |  | ||||||
| 	// WhenTimeServed is an optional function to call when an Actor is released because of timeServed |  | ||||||
| 	WhenTimeServed func(actor string, r *Rule) error |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type action struct { | type action struct { | ||||||
| 	whenJailed     func(actor string, r *Rule) error | 	fn ActionCallback | ||||||
| 	whenTimeServed func(actor string, r *Rule) error |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error { | func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error { | ||||||
| 	if a2.whenJailed != nil { | 	if a2.fn == nil { | ||||||
| 		return a2.whenJailed(a.Name(), &Rule{ | 		return nil | ||||||
| 			Name:        r.Name, |  | ||||||
| 			Message:     r.Message, |  | ||||||
| 			StrikeLimit: r.StrikeLimit, |  | ||||||
| 			ExpireBase:  r.ExpireBase, |  | ||||||
| 			Sentence:    r.Sentence, |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return a2.fn(Jailed, a.Name(), &Rule{ | ||||||
|  | 		Name:         r.Name, | ||||||
|  | 		Message:      r.Message, | ||||||
|  | 		StrikeLimit:  r.StrikeLimit, | ||||||
|  | 		JailDuration: r.ExpireBase, | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error { | func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error { | ||||||
| 	if a2.whenTimeServed != nil { | 	if a2.fn == nil { | ||||||
| 		return a2.whenTimeServed(a.Name(), &Rule{ | 		return nil | ||||||
| 			Name:        r.Name, |  | ||||||
| 			Message:     r.Message, |  | ||||||
| 			StrikeLimit: r.StrikeLimit, |  | ||||||
| 			ExpireBase:  r.ExpireBase, |  | ||||||
| 			Sentence:    r.Sentence, |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return a2.fn(Released, a.Name(), &Rule{ | ||||||
|  | 		Name:         r.Name, | ||||||
|  | 		Message:      r.Message, | ||||||
|  | 		StrikeLimit:  r.StrikeLimit, | ||||||
|  | 		JailDuration: r.ExpireBase, | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user