mirror of
				https://github.com/linka-cloud/grpc.git
				synced 2025-10-31 01:22:29 +00:00 
			
		
		
		
	interceptors: add ban
health: set services serving on start and not available on close Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,3 +1,4 @@ | |||||||
| .idea | .idea | ||||||
| .bin | .bin | ||||||
| /tmp | /tmp | ||||||
|  | diff | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ Features: | |||||||
|     - [ ] context logger |     - [ ] context logger | ||||||
|     - [x] sentry |     - [x] sentry | ||||||
|     - [ ] rate-limiting |     - [ ] rate-limiting | ||||||
|     - [ ] ban |     - [x] ban | ||||||
|     - [ ] auth claim in context |     - [ ] auth claim in context | ||||||
|     - [x] recovery (server side only) |     - [x] recovery (server side only) | ||||||
|     - [x] tracing (open-tracing) |     - [x] tracing (open-tracing) | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
| 	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" | 	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
|  | 	"google.golang.org/grpc/credentials/insecure" | ||||||
| 	"google.golang.org/grpc/resolver" | 	"google.golang.org/grpc/resolver" | ||||||
|  |  | ||||||
| 	"go.linka.cloud/grpc/registry/noop" | 	"go.linka.cloud/grpc/registry/noop" | ||||||
| @@ -34,8 +35,8 @@ func New(opts ...Option) (Client, error) { | |||||||
| 	if c.opts.tlsConfig != nil { | 	if c.opts.tlsConfig != nil { | ||||||
| 		c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig))) | 		c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig))) | ||||||
| 	} | 	} | ||||||
| 	if !c.opts.secure { | 	if !c.opts.secure && c.opts.tlsConfig == nil { | ||||||
| 		c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithInsecure()) | 		c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) | ||||||
| 	} | 	} | ||||||
| 	if len(c.opts.unaryInterceptors) > 0 { | 	if len(c.opts.unaryInterceptors) > 0 { | ||||||
| 		c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(c.opts.unaryInterceptors...))) | 		c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(c.opts.unaryInterceptors...))) | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ import ( | |||||||
|  |  | ||||||
| 	"go.linka.cloud/grpc/client" | 	"go.linka.cloud/grpc/client" | ||||||
| 	"go.linka.cloud/grpc/interceptors/auth" | 	"go.linka.cloud/grpc/interceptors/auth" | ||||||
|  | 	"go.linka.cloud/grpc/interceptors/ban" | ||||||
| 	"go.linka.cloud/grpc/interceptors/defaulter" | 	"go.linka.cloud/grpc/interceptors/defaulter" | ||||||
| 	"go.linka.cloud/grpc/interceptors/iface" | 	"go.linka.cloud/grpc/interceptors/iface" | ||||||
| 	metrics2 "go.linka.cloud/grpc/interceptors/metrics" | 	metrics2 "go.linka.cloud/grpc/interceptors/metrics" | ||||||
| @@ -137,6 +138,7 @@ func run(opts ...service.Option) { | |||||||
| 		service.WithMiddlewares(httpLogger), | 		service.WithMiddlewares(httpLogger), | ||||||
| 		service.WithInterceptors(metrics), | 		service.WithInterceptors(metrics), | ||||||
| 		service.WithServerInterceptors( | 		service.WithServerInterceptors( | ||||||
|  | 			ban.NewInterceptors(), | ||||||
| 			auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { | 			auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) { | ||||||
| 				if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { | 				if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") { | ||||||
| 					return ctx, fmt.Errorf("invalid user or password") | 					return ctx, fmt.Errorf("invalid user or password") | ||||||
| @@ -167,7 +169,7 @@ func run(opts ...service.Option) { | |||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 	<-ready | 	<-ready | ||||||
| 	s, err := client.New( | 	copts := []client.Option{ | ||||||
| 		// client.WithName(name), | 		// client.WithName(name), | ||||||
| 		// client.WithVersion(version), | 		// client.WithVersion(version), | ||||||
| 		client.WithAddress("localhost:9991"), | 		client.WithAddress("localhost:9991"), | ||||||
| @@ -177,14 +179,28 @@ func run(opts ...service.Option) { | |||||||
| 			logger.From(ctx).WithFields("party", "client", "method", method).Info(req) | 			logger.From(ctx).WithFields("party", "client", "method", method).Info(req) | ||||||
| 			return invoker(ctx, method, req, reply, cc, opts...) | 			return invoker(ctx, method, req, reply, cc, opts...) | ||||||
| 		}), | 		}), | ||||||
| 		client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")), | 	} | ||||||
| 	) | 	s, err := client.New(copts...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal(err) | 		log.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	g := NewGreeterClient(s) | 	g := NewGreeterClient(s) | ||||||
| 	h := grpc_health_v1.NewHealthClient(s) | 	h := grpc_health_v1.NewHealthClient(s) | ||||||
| 	hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) | 	for i := 0; i < 4; i++ { | ||||||
|  | 		_, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Error(err) | ||||||
|  | 		} else { | ||||||
|  | 			log.Fatalf("expected error") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	g = NewGreeterClient(s) | ||||||
|  | 	h = grpc_health_v1.NewHealthClient(s) | ||||||
|  | 	hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{Service: "helloworld.Greeter"}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal(err) | 		log.Fatal(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @@ -18,6 +18,7 @@ require ( | |||||||
| 	github.com/grpc-ecosystem/grpc-gateway/v2 v2.5.0 | 	github.com/grpc-ecosystem/grpc-gateway/v2 v2.5.0 | ||||||
| 	github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 | 	github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 | ||||||
| 	github.com/improbable-eng/grpc-web v0.14.1 | 	github.com/improbable-eng/grpc-web v0.14.1 | ||||||
|  | 	github.com/jaredfolkins/badactor v1.2.0 | ||||||
| 	github.com/johnbellone/grpc-middleware-sentry v0.2.0 | 	github.com/johnbellone/grpc-middleware-sentry v0.2.0 | ||||||
| 	github.com/justinas/alice v1.2.0 | 	github.com/justinas/alice v1.2.0 | ||||||
| 	github.com/lyft/protoc-gen-star v0.6.0 // indirect | 	github.com/lyft/protoc-gen-star v0.6.0 // indirect | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @@ -383,6 +383,8 @@ github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/ | |||||||
| github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= | github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= | ||||||
| github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= | github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= | ||||||
| github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= | github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= | ||||||
|  | github.com/jaredfolkins/badactor v1.2.0 h1:QTJBsVG9qhdIFmFx5eNet2Q9hX8T+qZ1rC9NJwyN+Hc= | ||||||
|  | github.com/jaredfolkins/badactor v1.2.0/go.mod h1:ZynkTrC/ICU1o8mmFy3JySRCErXVlx7trZiWEH6DDg8= | ||||||
| github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | ||||||
| github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | ||||||
| github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= | github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= | ||||||
|   | |||||||
							
								
								
									
										121
									
								
								interceptors/ban/ban.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								interceptors/ban/ban.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,121 @@ | |||||||
|  | package ban | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
|  | 	"github.com/jaredfolkins/badactor" | ||||||
|  | 	"google.golang.org/grpc" | ||||||
|  | 	"google.golang.org/grpc/codes" | ||||||
|  | 	"google.golang.org/grpc/status" | ||||||
|  |  | ||||||
|  | 	"go.linka.cloud/grpc/interceptors" | ||||||
|  | 	"go.linka.cloud/grpc/logger" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ban struct { | ||||||
|  | 	s     *badactor.Studio | ||||||
|  | 	rules map[codes.Code]Rule | ||||||
|  | 	actor func(ctx context.Context) (string, bool, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewInterceptors(opts ...Option) interceptors.ServerInterceptors { | ||||||
|  | 	o := defaultOptions | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		opt(&o) | ||||||
|  | 	} | ||||||
|  | 	s := badactor.NewStudio(o.cap) | ||||||
|  | 	rules := make(map[codes.Code]Rule) | ||||||
|  | 	for _, r := range o.rules { | ||||||
|  | 		rules[r.Code] = r | ||||||
|  | 		s.AddRule(&badactor.Rule{ | ||||||
|  | 			Name:        r.Name, | ||||||
|  | 			Message:     r.Message, | ||||||
|  | 			StrikeLimit: r.StrikeLimit, | ||||||
|  | 			ExpireBase:  r.ExpireBase, | ||||||
|  | 			Sentence:    r.Sentence, | ||||||
|  | 			Action: &action{ | ||||||
|  | 				whenJailed:     r.WhenJailed, | ||||||
|  | 				whenTimeServed: r.WhenTimeServed, | ||||||
|  | 			}, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	// we ignore the error because CreateDirectors never returns an error | ||||||
|  | 	_ = s.CreateDirectors(o.cap) | ||||||
|  | 	s.StartReaper(o.reaperInterval) | ||||||
|  | 	return &ban{s: s, rules: rules, actor: o.actorFunc} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *ban) UnaryServerInterceptor() grpc.UnaryServerInterceptor { | ||||||
|  | 	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { | ||||||
|  | 		actor, ok, err := b.check(ctx) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		ctx = set(ctx, b, actor) | ||||||
|  | 		if !ok { | ||||||
|  | 			return handler(ctx, req) | ||||||
|  | 		} | ||||||
|  | 		for _, v := range b.rules { | ||||||
|  | 			if b.s.IsJailedFor(actor, v.Name) { | ||||||
|  | 				return nil, status.Error(v.Code, v.Message) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		res, err := handler(ctx, req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, b.handleErr(ctx, actor, err) | ||||||
|  | 		} | ||||||
|  | 		return res, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *ban) StreamServerInterceptor() grpc.StreamServerInterceptor { | ||||||
|  | 	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { | ||||||
|  | 		actor, ok, err := b.check(ss.Context()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		ss = interceptors.NewContextServerStream(set(ss.Context(), b, actor), ss) | ||||||
|  | 		if !ok { | ||||||
|  | 			return handler(srv, ss) | ||||||
|  | 		} | ||||||
|  | 		if err := handler(srv, ss); err != nil { | ||||||
|  | 			return b.handleErr(ss.Context(), actor, err) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *ban) check(ctx context.Context) (actor string, ok bool, err error) { | ||||||
|  | 	actor, ok, err = b.actor(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", false, err | ||||||
|  | 	} | ||||||
|  | 	if !ok { | ||||||
|  | 		return "", false, nil | ||||||
|  | 	} | ||||||
|  | 	for _, v := range b.rules { | ||||||
|  | 		if b.s.IsJailedFor(actor, v.Name) { | ||||||
|  | 			return actor, false, status.Error(v.Code, v.Message) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return actor, true, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *ban) handleErr(ctx context.Context, actor string, err error) error { | ||||||
|  | 	v, ok := ctx.Value(key{}).(*value) | ||||||
|  | 	if !ok || v.done { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	s, ok := status.FromError(err) | ||||||
|  | 	if !ok { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	r, ok := b.rules[s.Code()] | ||||||
|  | 	if !ok { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err := b.s.Infraction(actor, r.Name); err != nil { | ||||||
|  | 		logger.C(ctx).Warnf("%s: failed to add infraction: %v", r.Name, err) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
							
								
								
									
										34
									
								
								interceptors/ban/context.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								interceptors/ban/context.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | |||||||
|  | package ban | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type key struct{} | ||||||
|  |  | ||||||
|  | type value struct { | ||||||
|  | 	ban   *ban | ||||||
|  | 	actor string | ||||||
|  | 	done  bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func set(ctx context.Context, b *ban, actor string) context.Context { | ||||||
|  | 	return context.WithValue(ctx, key{}, &value{ban: b, actor: actor}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Infraction(ctx context.Context, rule string) error { | ||||||
|  | 	v, ok := ctx.Value(key{}).(*value) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	v.done = true | ||||||
|  | 	return v.ban.s.Infraction(v.actor, rule) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Actor(ctx context.Context) string { | ||||||
|  | 	v, ok := ctx.Value(key{}).(*value) | ||||||
|  | 	if !ok { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return v.actor | ||||||
|  | } | ||||||
							
								
								
									
										83
									
								
								interceptors/ban/options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								interceptors/ban/options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | |||||||
|  | package ban | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"google.golang.org/grpc/codes" | ||||||
|  | 	"google.golang.org/grpc/peer" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	Unauthorized    = "Unauthorized" | ||||||
|  | 	Unauthenticated = "Unauthenticated" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	defaultOptions = options{ | ||||||
|  | 		cap:            1024, | ||||||
|  | 		reaperInterval: 10 * time.Minute, | ||||||
|  | 		rules:          defaultRules, | ||||||
|  | 		actorFunc:      DefaultActorFunc, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defaultRules = []Rule{ | ||||||
|  | 		{ | ||||||
|  | 			Name:        Unauthorized, | ||||||
|  | 			Message:     "Too many unauthorized requests", | ||||||
|  | 			Code:        codes.PermissionDenied, | ||||||
|  | 			StrikeLimit: 3, | ||||||
|  | 			ExpireBase:  time.Second * 10, | ||||||
|  | 			Sentence:    time.Second * 10, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Name:        Unauthenticated, | ||||||
|  | 			Message:     "Too many unauthenticated requests", | ||||||
|  | 			Code:        codes.Unauthenticated, | ||||||
|  | 			StrikeLimit: 3, | ||||||
|  | 			ExpireBase:  time.Second * 10, | ||||||
|  | 			Sentence:    time.Second * 10, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func DefaultActorFunc(ctx context.Context) (string, bool, error) { | ||||||
|  | 	p, ok := peer.FromContext(ctx) | ||||||
|  | 	if !ok { | ||||||
|  | 		return "", false, nil | ||||||
|  | 	} | ||||||
|  | 	return p.Addr.String(), true, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Option func(*options) | ||||||
|  |  | ||||||
|  | func WithCapacity(cap int32) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.cap = cap | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithRules(rules ...Rule) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.rules = rules | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithReaperInterval(interval time.Duration) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.reaperInterval = interval | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithActorFunc(f func(context.Context) (name string, found bool, err error)) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.actorFunc = f | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type options struct { | ||||||
|  | 	cap            int32 | ||||||
|  | 	rules          []Rule | ||||||
|  | 	reaperInterval time.Duration | ||||||
|  | 	actorFunc      func(ctx context.Context) (name string, found bool, err error) | ||||||
|  | } | ||||||
							
								
								
									
										52
									
								
								interceptors/ban/rule.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								interceptors/ban/rule.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | package ban | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/jaredfolkins/badactor" | ||||||
|  | 	"google.golang.org/grpc/codes" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Rule struct { | ||||||
|  | 	Name        string | ||||||
|  | 	Message     string | ||||||
|  | 	Code        codes.Code | ||||||
|  | 	StrikeLimit int | ||||||
|  | 	ExpireBase  time.Duration | ||||||
|  | 	Sentence    time.Duration | ||||||
|  | 	// WhenJailed is an optional function to call when an Actor isJailed | ||||||
|  | 	WhenJailed func(actor string, r *Rule) error | ||||||
|  | 	// WhenTimeServed is an optional function to call when an Actor is released because of timeServed | ||||||
|  | 	WhenTimeServed func(actor string, r *Rule) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type action struct { | ||||||
|  | 	whenJailed     func(actor string, r *Rule) error | ||||||
|  | 	whenTimeServed func(actor string, r *Rule) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error { | ||||||
|  | 	if a2.whenJailed != nil { | ||||||
|  | 		return a2.whenJailed(a.Name(), &Rule{ | ||||||
|  | 			Name:        r.Name, | ||||||
|  | 			Message:     r.Message, | ||||||
|  | 			StrikeLimit: r.StrikeLimit, | ||||||
|  | 			ExpireBase:  r.ExpireBase, | ||||||
|  | 			Sentence:    r.Sentence, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error { | ||||||
|  | 	if a2.whenTimeServed != nil { | ||||||
|  | 		return a2.whenTimeServed(a.Name(), &Rule{ | ||||||
|  | 			Name:        r.Name, | ||||||
|  | 			Message:     r.Message, | ||||||
|  | 			StrikeLimit: r.StrikeLimit, | ||||||
|  | 			ExpireBase:  r.ExpireBase, | ||||||
|  | 			Sentence:    r.Sentence, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
| @@ -1,6 +1,8 @@ | |||||||
| package interceptors | package interceptors | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -18,3 +20,16 @@ type Interceptors interface { | |||||||
| 	ServerInterceptors | 	ServerInterceptors | ||||||
| 	ClientInterceptors | 	ClientInterceptors | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func NewContextServerStream(ctx context.Context, ss grpc.ServerStream) grpc.ServerStream { | ||||||
|  | 	return &ContextWrapper{ServerStream: ss, ctx: ctx} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ContextWrapper struct { | ||||||
|  | 	grpc.ServerStream | ||||||
|  | 	ctx context.Context | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *ContextWrapper) Context() context.Context { | ||||||
|  | 	return w.ctx | ||||||
|  | } | ||||||
|   | |||||||
| @@ -35,19 +35,6 @@ func (f *forward) StreamServerInterceptor() grpc.StreamServerInterceptor { | |||||||
| 		if md2, ok := metadata.FromOutgoingContext(ctx); ok { | 		if md2, ok := metadata.FromOutgoingContext(ctx); ok { | ||||||
| 			o = metadata.Join(o, md2.Copy()) | 			o = metadata.Join(o, md2.Copy()) | ||||||
| 		} | 		} | ||||||
| 		return handler(srv, NewContextServerStream(metadata.NewOutgoingContext(ctx, o), ss)) | 		return handler(srv, interceptors.NewContextServerStream(metadata.NewOutgoingContext(ctx, o), ss)) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewContextServerStream(ctx context.Context, ss grpc.ServerStream) grpc.ServerStream { |  | ||||||
| 	return &ContextWrapper{ServerStream: ss, ctx: ctx} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ContextWrapper struct { |  | ||||||
| 	grpc.ServerStream |  | ||||||
| 	ctx context.Context |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *ContextWrapper) Context() context.Context { |  | ||||||
| 	return w.ctx |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -56,6 +56,8 @@ type service struct { | |||||||
| 	inproc   *inprocgrpc.Channel | 	inproc   *inprocgrpc.Channel | ||||||
| 	services map[string]*serviceInfo | 	services map[string]*serviceInfo | ||||||
|  |  | ||||||
|  | 	healthServer *health.Server | ||||||
|  |  | ||||||
| 	id     string | 	id     string | ||||||
| 	regSvc *registry.Service | 	regSvc *registry.Service | ||||||
| 	closed chan struct{} | 	closed chan struct{} | ||||||
| @@ -120,7 +122,8 @@ func newService(opts ...Option) (*service, error) { | |||||||
| 		greflect.Register(s.server) | 		greflect.Register(s.server) | ||||||
| 	} | 	} | ||||||
| 	if s.opts.health { | 	if s.opts.health { | ||||||
| 		s.registerService(&grpc_health_v1.Health_ServiceDesc, health.NewServer()) | 		s.healthServer = health.NewServer() | ||||||
|  | 		s.registerService(&grpc_health_v1.Health_ServiceDesc, s.healthServer) | ||||||
| 	} | 	} | ||||||
| 	if err := s.gateway(s.opts.gatewayOpts...); err != nil { | 	if err := s.gateway(s.opts.gatewayOpts...); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -226,6 +229,18 @@ func (s *service) run() error { | |||||||
| 		} | 		} | ||||||
| 		errs <- nil | 		errs <- nil | ||||||
| 	}() | 	}() | ||||||
|  | 	if s.healthServer != nil { | ||||||
|  | 		for k := range s.services { | ||||||
|  | 			s.healthServer.SetServingStatus(k, grpc_health_v1.HealthCheckResponse_SERVING) | ||||||
|  | 		} | ||||||
|  | 		defer func() { | ||||||
|  | 			if s.healthServer != nil { | ||||||
|  | 				for k := range s.services { | ||||||
|  | 					s.healthServer.SetServingStatus(k, grpc_health_v1.HealthCheckResponse_NOT_SERVING) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
| 	for i := range s.opts.afterStart { | 	for i := range s.opts.afterStart { | ||||||
| 		if err := s.opts.afterStart[i](); err != nil { | 		if err := s.opts.afterStart[i](); err != nil { | ||||||
| 			s.mu.Unlock() | 			s.mu.Unlock() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user