mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-21 10:26:26 +00:00
interceptors: add ban
health: set services serving on start and not available on close Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
parent
9bf4e691ce
commit
c7096975b1
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
.idea
|
.idea
|
||||||
.bin
|
.bin
|
||||||
/tmp
|
/tmp
|
||||||
|
diff
|
||||||
|
@ -23,7 +23,7 @@ Features:
|
|||||||
- [ ] context logger
|
- [ ] context logger
|
||||||
- [x] sentry
|
- [x] sentry
|
||||||
- [ ] rate-limiting
|
- [ ] rate-limiting
|
||||||
- [ ] ban
|
- [x] ban
|
||||||
- [ ] auth claim in context
|
- [ ] auth claim in context
|
||||||
- [x] recovery (server side only)
|
- [x] recovery (server side only)
|
||||||
- [x] tracing (open-tracing)
|
- [x] tracing (open-tracing)
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
|
|
||||||
"go.linka.cloud/grpc/registry/noop"
|
"go.linka.cloud/grpc/registry/noop"
|
||||||
@ -34,8 +35,8 @@ func New(opts ...Option) (Client, error) {
|
|||||||
if c.opts.tlsConfig != nil {
|
if c.opts.tlsConfig != nil {
|
||||||
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig)))
|
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig)))
|
||||||
}
|
}
|
||||||
if !c.opts.secure {
|
if !c.opts.secure && c.opts.tlsConfig == nil {
|
||||||
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithInsecure())
|
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
}
|
}
|
||||||
if len(c.opts.unaryInterceptors) > 0 {
|
if len(c.opts.unaryInterceptors) > 0 {
|
||||||
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(c.opts.unaryInterceptors...)))
|
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(c.opts.unaryInterceptors...)))
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"go.linka.cloud/grpc/client"
|
"go.linka.cloud/grpc/client"
|
||||||
"go.linka.cloud/grpc/interceptors/auth"
|
"go.linka.cloud/grpc/interceptors/auth"
|
||||||
|
"go.linka.cloud/grpc/interceptors/ban"
|
||||||
"go.linka.cloud/grpc/interceptors/defaulter"
|
"go.linka.cloud/grpc/interceptors/defaulter"
|
||||||
"go.linka.cloud/grpc/interceptors/iface"
|
"go.linka.cloud/grpc/interceptors/iface"
|
||||||
metrics2 "go.linka.cloud/grpc/interceptors/metrics"
|
metrics2 "go.linka.cloud/grpc/interceptors/metrics"
|
||||||
@ -137,6 +138,7 @@ func run(opts ...service.Option) {
|
|||||||
service.WithMiddlewares(httpLogger),
|
service.WithMiddlewares(httpLogger),
|
||||||
service.WithInterceptors(metrics),
|
service.WithInterceptors(metrics),
|
||||||
service.WithServerInterceptors(
|
service.WithServerInterceptors(
|
||||||
|
ban.NewInterceptors(),
|
||||||
auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) {
|
auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) {
|
||||||
if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") {
|
if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") {
|
||||||
return ctx, fmt.Errorf("invalid user or password")
|
return ctx, fmt.Errorf("invalid user or password")
|
||||||
@ -167,7 +169,7 @@ func run(opts ...service.Option) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
<-ready
|
<-ready
|
||||||
s, err := client.New(
|
copts := []client.Option{
|
||||||
// client.WithName(name),
|
// client.WithName(name),
|
||||||
// client.WithVersion(version),
|
// client.WithVersion(version),
|
||||||
client.WithAddress("localhost:9991"),
|
client.WithAddress("localhost:9991"),
|
||||||
@ -177,14 +179,28 @@ func run(opts ...service.Option) {
|
|||||||
logger.From(ctx).WithFields("party", "client", "method", method).Info(req)
|
logger.From(ctx).WithFields("party", "client", "method", method).Info(req)
|
||||||
return invoker(ctx, method, req, reply, cc, opts...)
|
return invoker(ctx, method, req, reply, cc, opts...)
|
||||||
}),
|
}),
|
||||||
client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")),
|
}
|
||||||
)
|
s, err := client.New(copts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
g := NewGreeterClient(s)
|
g := NewGreeterClient(s)
|
||||||
h := grpc_health_v1.NewHealthClient(s)
|
h := grpc_health_v1.NewHealthClient(s)
|
||||||
hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
|
for i := 0; i < 4; i++ {
|
||||||
|
_, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
} else {
|
||||||
|
log.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
g = NewGreeterClient(s)
|
||||||
|
h = grpc_health_v1.NewHealthClient(s)
|
||||||
|
hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{Service: "helloworld.Greeter"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
1
go.mod
1
go.mod
@ -18,6 +18,7 @@ require (
|
|||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.5.0
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.5.0
|
||||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645
|
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645
|
||||||
github.com/improbable-eng/grpc-web v0.14.1
|
github.com/improbable-eng/grpc-web v0.14.1
|
||||||
|
github.com/jaredfolkins/badactor v1.2.0
|
||||||
github.com/johnbellone/grpc-middleware-sentry v0.2.0
|
github.com/johnbellone/grpc-middleware-sentry v0.2.0
|
||||||
github.com/justinas/alice v1.2.0
|
github.com/justinas/alice v1.2.0
|
||||||
github.com/lyft/protoc-gen-star v0.6.0 // indirect
|
github.com/lyft/protoc-gen-star v0.6.0 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -383,6 +383,8 @@ github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/
|
|||||||
github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk=
|
github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk=
|
||||||
github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g=
|
github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g=
|
||||||
github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw=
|
github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw=
|
||||||
|
github.com/jaredfolkins/badactor v1.2.0 h1:QTJBsVG9qhdIFmFx5eNet2Q9hX8T+qZ1rC9NJwyN+Hc=
|
||||||
|
github.com/jaredfolkins/badactor v1.2.0/go.mod h1:ZynkTrC/ICU1o8mmFy3JySRCErXVlx7trZiWEH6DDg8=
|
||||||
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||||
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||||
github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ=
|
github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ=
|
||||||
|
121
interceptors/ban/ban.go
Normal file
121
interceptors/ban/ban.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package ban
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jaredfolkins/badactor"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"go.linka.cloud/grpc/interceptors"
|
||||||
|
"go.linka.cloud/grpc/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ban struct {
|
||||||
|
s *badactor.Studio
|
||||||
|
rules map[codes.Code]Rule
|
||||||
|
actor func(ctx context.Context) (string, bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInterceptors(opts ...Option) interceptors.ServerInterceptors {
|
||||||
|
o := defaultOptions
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&o)
|
||||||
|
}
|
||||||
|
s := badactor.NewStudio(o.cap)
|
||||||
|
rules := make(map[codes.Code]Rule)
|
||||||
|
for _, r := range o.rules {
|
||||||
|
rules[r.Code] = r
|
||||||
|
s.AddRule(&badactor.Rule{
|
||||||
|
Name: r.Name,
|
||||||
|
Message: r.Message,
|
||||||
|
StrikeLimit: r.StrikeLimit,
|
||||||
|
ExpireBase: r.ExpireBase,
|
||||||
|
Sentence: r.Sentence,
|
||||||
|
Action: &action{
|
||||||
|
whenJailed: r.WhenJailed,
|
||||||
|
whenTimeServed: r.WhenTimeServed,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// we ignore the error because CreateDirectors never returns an error
|
||||||
|
_ = s.CreateDirectors(o.cap)
|
||||||
|
s.StartReaper(o.reaperInterval)
|
||||||
|
return &ban{s: s, rules: rules, actor: o.actorFunc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ban) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||||
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||||
|
actor, ok, err := b.check(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ctx = set(ctx, b, actor)
|
||||||
|
if !ok {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
for _, v := range b.rules {
|
||||||
|
if b.s.IsJailedFor(actor, v.Name) {
|
||||||
|
return nil, status.Error(v.Code, v.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res, err := handler(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, b.handleErr(ctx, actor, err)
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ban) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||||
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
actor, ok, err := b.check(ss.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ss = interceptors.NewContextServerStream(set(ss.Context(), b, actor), ss)
|
||||||
|
if !ok {
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
if err := handler(srv, ss); err != nil {
|
||||||
|
return b.handleErr(ss.Context(), actor, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ban) check(ctx context.Context) (actor string, ok bool, err error) {
|
||||||
|
actor, ok, err = b.actor(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
for _, v := range b.rules {
|
||||||
|
if b.s.IsJailedFor(actor, v.Name) {
|
||||||
|
return actor, false, status.Error(v.Code, v.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return actor, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ban) handleErr(ctx context.Context, actor string, err error) error {
|
||||||
|
v, ok := ctx.Value(key{}).(*value)
|
||||||
|
if !ok || v.done {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r, ok := b.rules[s.Code()]
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := b.s.Infraction(actor, r.Name); err != nil {
|
||||||
|
logger.C(ctx).Warnf("%s: failed to add infraction: %v", r.Name, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
34
interceptors/ban/context.go
Normal file
34
interceptors/ban/context.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package ban
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type key struct{}
|
||||||
|
|
||||||
|
type value struct {
|
||||||
|
ban *ban
|
||||||
|
actor string
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func set(ctx context.Context, b *ban, actor string) context.Context {
|
||||||
|
return context.WithValue(ctx, key{}, &value{ban: b, actor: actor})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Infraction(ctx context.Context, rule string) error {
|
||||||
|
v, ok := ctx.Value(key{}).(*value)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v.done = true
|
||||||
|
return v.ban.s.Infraction(v.actor, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Actor(ctx context.Context) string {
|
||||||
|
v, ok := ctx.Value(key{}).(*value)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return v.actor
|
||||||
|
}
|
83
interceptors/ban/options.go
Normal file
83
interceptors/ban/options.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package ban
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Unauthorized = "Unauthorized"
|
||||||
|
Unauthenticated = "Unauthenticated"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultOptions = options{
|
||||||
|
cap: 1024,
|
||||||
|
reaperInterval: 10 * time.Minute,
|
||||||
|
rules: defaultRules,
|
||||||
|
actorFunc: DefaultActorFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultRules = []Rule{
|
||||||
|
{
|
||||||
|
Name: Unauthorized,
|
||||||
|
Message: "Too many unauthorized requests",
|
||||||
|
Code: codes.PermissionDenied,
|
||||||
|
StrikeLimit: 3,
|
||||||
|
ExpireBase: time.Second * 10,
|
||||||
|
Sentence: time.Second * 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: Unauthenticated,
|
||||||
|
Message: "Too many unauthenticated requests",
|
||||||
|
Code: codes.Unauthenticated,
|
||||||
|
StrikeLimit: 3,
|
||||||
|
ExpireBase: time.Second * 10,
|
||||||
|
Sentence: time.Second * 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func DefaultActorFunc(ctx context.Context) (string, bool, error) {
|
||||||
|
p, ok := peer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
return p.Addr.String(), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
func WithCapacity(cap int32) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.cap = cap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRules(rules ...Rule) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.rules = rules
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithReaperInterval(interval time.Duration) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.reaperInterval = interval
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithActorFunc(f func(context.Context) (name string, found bool, err error)) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.actorFunc = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
cap int32
|
||||||
|
rules []Rule
|
||||||
|
reaperInterval time.Duration
|
||||||
|
actorFunc func(ctx context.Context) (name string, found bool, err error)
|
||||||
|
}
|
52
interceptors/ban/rule.go
Normal file
52
interceptors/ban/rule.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package ban
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jaredfolkins/badactor"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Rule struct {
|
||||||
|
Name string
|
||||||
|
Message string
|
||||||
|
Code codes.Code
|
||||||
|
StrikeLimit int
|
||||||
|
ExpireBase time.Duration
|
||||||
|
Sentence time.Duration
|
||||||
|
// WhenJailed is an optional function to call when an Actor isJailed
|
||||||
|
WhenJailed func(actor string, r *Rule) error
|
||||||
|
// WhenTimeServed is an optional function to call when an Actor is released because of timeServed
|
||||||
|
WhenTimeServed func(actor string, r *Rule) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type action struct {
|
||||||
|
whenJailed func(actor string, r *Rule) error
|
||||||
|
whenTimeServed func(actor string, r *Rule) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a2 *action) WhenJailed(a *badactor.Actor, r *badactor.Rule) error {
|
||||||
|
if a2.whenJailed != nil {
|
||||||
|
return a2.whenJailed(a.Name(), &Rule{
|
||||||
|
Name: r.Name,
|
||||||
|
Message: r.Message,
|
||||||
|
StrikeLimit: r.StrikeLimit,
|
||||||
|
ExpireBase: r.ExpireBase,
|
||||||
|
Sentence: r.Sentence,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a2 *action) WhenTimeServed(a *badactor.Actor, r *badactor.Rule) error {
|
||||||
|
if a2.whenTimeServed != nil {
|
||||||
|
return a2.whenTimeServed(a.Name(), &Rule{
|
||||||
|
Name: r.Name,
|
||||||
|
Message: r.Message,
|
||||||
|
StrikeLimit: r.StrikeLimit,
|
||||||
|
ExpireBase: r.ExpireBase,
|
||||||
|
Sentence: r.Sentence,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
package interceptors
|
package interceptors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -18,3 +20,16 @@ type Interceptors interface {
|
|||||||
ServerInterceptors
|
ServerInterceptors
|
||||||
ClientInterceptors
|
ClientInterceptors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewContextServerStream(ctx context.Context, ss grpc.ServerStream) grpc.ServerStream {
|
||||||
|
return &ContextWrapper{ServerStream: ss, ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContextWrapper struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ContextWrapper) Context() context.Context {
|
||||||
|
return w.ctx
|
||||||
|
}
|
||||||
|
@ -35,19 +35,6 @@ func (f *forward) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
|||||||
if md2, ok := metadata.FromOutgoingContext(ctx); ok {
|
if md2, ok := metadata.FromOutgoingContext(ctx); ok {
|
||||||
o = metadata.Join(o, md2.Copy())
|
o = metadata.Join(o, md2.Copy())
|
||||||
}
|
}
|
||||||
return handler(srv, NewContextServerStream(metadata.NewOutgoingContext(ctx, o), ss))
|
return handler(srv, interceptors.NewContextServerStream(metadata.NewOutgoingContext(ctx, o), ss))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextServerStream(ctx context.Context, ss grpc.ServerStream) grpc.ServerStream {
|
|
||||||
return &ContextWrapper{ServerStream: ss, ctx: ctx}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextWrapper struct {
|
|
||||||
grpc.ServerStream
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ContextWrapper) Context() context.Context {
|
|
||||||
return w.ctx
|
|
||||||
}
|
|
||||||
|
@ -56,6 +56,8 @@ type service struct {
|
|||||||
inproc *inprocgrpc.Channel
|
inproc *inprocgrpc.Channel
|
||||||
services map[string]*serviceInfo
|
services map[string]*serviceInfo
|
||||||
|
|
||||||
|
healthServer *health.Server
|
||||||
|
|
||||||
id string
|
id string
|
||||||
regSvc *registry.Service
|
regSvc *registry.Service
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
@ -120,7 +122,8 @@ func newService(opts ...Option) (*service, error) {
|
|||||||
greflect.Register(s.server)
|
greflect.Register(s.server)
|
||||||
}
|
}
|
||||||
if s.opts.health {
|
if s.opts.health {
|
||||||
s.registerService(&grpc_health_v1.Health_ServiceDesc, health.NewServer())
|
s.healthServer = health.NewServer()
|
||||||
|
s.registerService(&grpc_health_v1.Health_ServiceDesc, s.healthServer)
|
||||||
}
|
}
|
||||||
if err := s.gateway(s.opts.gatewayOpts...); err != nil {
|
if err := s.gateway(s.opts.gatewayOpts...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -226,6 +229,18 @@ func (s *service) run() error {
|
|||||||
}
|
}
|
||||||
errs <- nil
|
errs <- nil
|
||||||
}()
|
}()
|
||||||
|
if s.healthServer != nil {
|
||||||
|
for k := range s.services {
|
||||||
|
s.healthServer.SetServingStatus(k, grpc_health_v1.HealthCheckResponse_SERVING)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if s.healthServer != nil {
|
||||||
|
for k := range s.services {
|
||||||
|
s.healthServer.SetServingStatus(k, grpc_health_v1.HealthCheckResponse_NOT_SERVING)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
for i := range s.opts.afterStart {
|
for i := range s.opts.afterStart {
|
||||||
if err := s.opts.afterStart[i](); err != nil {
|
if err := s.opts.afterStart[i](); err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
Loading…
Reference in New Issue
Block a user