mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-21 18:36:25 +00:00
127 lines
3.0 KiB
Go
127 lines
3.0 KiB
Go
package ban
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/jaredfolkins/badactor"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"go.linka.cloud/grpc-toolkit/interceptors"
|
|
"go.linka.cloud/grpc-toolkit/logger"
|
|
)
|
|
|
|
type ban struct {
|
|
s *badactor.Studio
|
|
rules map[codes.Code]Rule
|
|
actor func(ctx context.Context) (string, bool, error)
|
|
}
|
|
|
|
func NewInterceptors(opts ...Option) interceptors.ServerInterceptors {
|
|
o := defaultOptions
|
|
for _, opt := range opts {
|
|
opt(&o)
|
|
}
|
|
s := badactor.NewStudio(o.cap)
|
|
rules := make(map[codes.Code]Rule)
|
|
for _, r := range o.rules {
|
|
rules[r.Code] = r
|
|
callback := r.Callback
|
|
if callback == nil {
|
|
callback = o.defaultCallback
|
|
}
|
|
expire := r.JailDuration
|
|
if expire == 0 {
|
|
expire = o.defaultJailDuration
|
|
}
|
|
s.AddRule(&badactor.Rule{
|
|
Name: r.Name,
|
|
Message: r.Message,
|
|
StrikeLimit: r.StrikeLimit,
|
|
ExpireBase: expire,
|
|
Sentence: expire,
|
|
Action: &action{fn: callback},
|
|
})
|
|
}
|
|
// we ignore the error because CreateDirectors never returns an error
|
|
_ = s.CreateDirectors(o.cap)
|
|
s.StartReaper(o.reaperInterval)
|
|
return &ban{s: s, rules: rules, actor: o.actorFunc}
|
|
}
|
|
|
|
func (b *ban) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
|
actor, ok, err := b.check(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctx = set(ctx, b, actor)
|
|
if !ok {
|
|
return handler(ctx, req)
|
|
}
|
|
for _, v := range b.rules {
|
|
if b.s.IsJailedFor(actor, v.Name) {
|
|
return nil, status.Error(v.Code, v.Message)
|
|
}
|
|
}
|
|
res, err := handler(ctx, req)
|
|
if err != nil {
|
|
return nil, b.handleErr(ctx, actor, err)
|
|
}
|
|
return res, nil
|
|
}
|
|
}
|
|
|
|
func (b *ban) StreamServerInterceptor() grpc.StreamServerInterceptor {
|
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
actor, ok, err := b.check(ss.Context())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ss = interceptors.NewContextServerStream(set(ss.Context(), b, actor), ss)
|
|
if !ok {
|
|
return handler(srv, ss)
|
|
}
|
|
if err := handler(srv, ss); err != nil {
|
|
return b.handleErr(ss.Context(), actor, err)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (b *ban) check(ctx context.Context) (actor string, ok bool, err error) {
|
|
actor, ok, err = b.actor(ctx)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
if !ok {
|
|
return "", false, nil
|
|
}
|
|
for _, v := range b.rules {
|
|
if b.s.IsJailedFor(actor, v.Name) {
|
|
return actor, false, status.Error(v.Code, v.Message)
|
|
}
|
|
}
|
|
return actor, true, nil
|
|
}
|
|
|
|
func (b *ban) handleErr(ctx context.Context, actor string, err error) error {
|
|
v, ok := ctx.Value(key{}).(*value)
|
|
if !ok || v.done {
|
|
return err
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return err
|
|
}
|
|
r, ok := b.rules[s.Code()]
|
|
if !ok {
|
|
return err
|
|
}
|
|
if err := b.s.Infraction(actor, r.Name); err != nil {
|
|
logger.C(ctx).Warnf("%s: failed to add infraction: %v", r.Name, err)
|
|
}
|
|
return err
|
|
}
|