grpc/interceptors/validation/validation.go

182 lines
5.2 KiB
Go
Raw Permalink Normal View History

package validation
import (
"context"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc"
"go.linka.cloud/grpc-toolkit/errors"
"go.linka.cloud/grpc-toolkit/interceptors"
)
// validatorAll is satisfied by generated messages that expose both a
// fail-fast Validate method and a collect-everything ValidateAll method.
type validatorAll interface {
	ValidateAll() error
	Validate() error
}
// validator is the Validate signature used by protoc-gen-validate from
// v0.6.0 onwards: one method whose boolean argument asks for every violation
// to be collected instead of stopping at the first.
// See https://github.com/envoyproxy/protoc-gen-validate/pull/455.
type validator interface {
	Validate(validateAll bool) error
}
// validatorLegacy is the argument-less Validate signature generated by
// protoc-gen-validate releases before v0.6.0.
type validatorLegacy interface{ Validate() error }
// validatorMultiError is satisfied by errors that aggregate several
// validation failures and expose them individually through AllErrors.
type validatorMultiError interface{ AllErrors() []error }
// validatorError is satisfied by a single field-level validation error.
// Field and Reason identify the offending field and why it was rejected;
// Cause carries a nested error when the field is itself a message that
// failed validation.
type validatorError interface {
	ErrorName() string
	Field() string
	Reason() string
	Cause() error
	Key() bool
}
// validatorErrorToGrpc converts a field-level validation error into one or
// more gRPC BadRequest field violations.
//
// prefix is prepended to the reported field name. While descending into
// nested errors (e.g. "embedded message failed validation"), the path of
// every enclosing field is accumulated into prefix. A violation deep in the
// message tree is therefore reported with its full path, such as
// "Outer.Inner.Name", not just its innermost segments.
//
// If the nested cause is a multi-error that contains no field-level errors,
// the violation for e itself is returned instead of an empty result. This
// keeps the failure from being dropped.
func validatorErrorToGrpc(e validatorError, prefix string) []*errdetails.BadRequest_FieldViolation {
	// Path of this field, used as the prefix of any nested violations.
	// It must build on the incoming prefix so outer path segments are kept.
	nested := prefix + e.Field() + "."

	switch cause := e.Cause().(type) {
	case validatorError:
		return validatorErrorToGrpc(cause, nested)
	case validatorMultiError:
		var details []*errdetails.BadRequest_FieldViolation
		for _, item := range cause.AllErrors() {
			if fieldErr, ok := item.(validatorError); ok {
				details = append(details, validatorErrorToGrpc(fieldErr, nested)...)
			}
		}
		if len(details) > 0 {
			return details
		}
		// No nested entry carried field information: fall through and report
		// this field itself rather than returning no violation at all.
	}

	return []*errdetails.BadRequest_FieldViolation{{
		Field:       prefix + e.Field(),
		Description: e.Reason(),
	}}
}
// errToStatus turns an error produced by a generated validation method into an
// InvalidArgument status error.
//
// A single field-level error, or every field-level entry of a multi-error, is
// attached as BadRequest field violations. Any other error is wrapped without
// details. A nil err yields nil.
func errToStatus(err error) error {
	if err == nil {
		return nil
	}

	if single, ok := err.(validatorError); ok {
		return errors.InvalidArgumentd(err, &errdetails.BadRequest{FieldViolations: validatorErrorToGrpc(single, "")})
	}

	multi, ok := err.(validatorMultiError)
	if !ok {
		return errors.InvalidArgument(err)
	}

	var violations []*errdetails.BadRequest_FieldViolation
	for _, item := range multi.AllErrors() {
		if fieldErr, isField := item.(validatorError); isField {
			violations = append(violations, validatorErrorToGrpc(fieldErr, "")...)
		}
	}
	return errors.InvalidArgumentd(err, &errdetails.BadRequest{FieldViolations: violations})
}
// validate calls whichever generated validation method req exposes and
// converts a failure with errToStatus. A req with no recognised Validate
// method is accepted and nil is returned.
//
// With i.all set, ValidateAll or Validate(true) is used. Otherwise Validate()
// or Validate(false) is used.
func (i interceptor) validate(req interface{}) error {
	// Any validatorAll also satisfies validatorLegacy. So when i.all is unset,
	// such messages fall through to the plain Validate() call below.
	if withAll, ok := req.(validatorAll); ok && i.all {
		return errToStatus(withAll.ValidateAll())
	}
	if legacy, ok := req.(validatorLegacy); ok {
		return errToStatus(legacy.Validate())
	}
	if flagged, ok := req.(validator); ok {
		return errToStatus(flagged.Validate(i.all))
	}
	return nil
}
// interceptor holds the configuration shared by the validation interceptors
// (unary/stream, client/server) defined on it.
type interceptor struct {
	// all selects ValidateAll / Validate(true) instead of Validate() /
	// Validate(false) when a message supports it (see validate).
	all bool
}
// NewInterceptors returns validation interceptors for unary and streaming
// RPCs, on both the client and the server side.
//
// validateAll chooses ValidateAll / Validate(true) over Validate() /
// Validate(false) for messages that provide those methods.
func NewInterceptors(validateAll bool) interceptors.Interceptors {
	cfg := interceptor{all: validateAll}
	return &cfg
}
// UnaryServerInterceptor returns a unary server interceptor that validates each
// incoming request.
//
// A request that fails validation is answered with `InvalidArgument` and never
// reaches the user-supplied handler.
func (i interceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		err := i.validate(req)
		if err == nil {
			return handler(ctx, req)
		}
		return nil, err
	}
}
// UnaryClientInterceptor returns a unary client interceptor that validates each
// outgoing request.
//
// A request that fails validation is rejected locally with `InvalidArgument`;
// nothing is sent to the server.
func (i interceptor) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		err := i.validate(req)
		if err == nil {
			err = invoker(ctx, method, req, reply, cc, opts...)
		}
		return err
	}
}
// StreamServerInterceptor returns a streaming server interceptor that validates
// every message received on the stream.
//
// Invalid messages are rejected with `InvalidArgument`; when that happens
// depends on the RPC shape:
//   - `ServerStream` (1:m): before any user-supplied handler code runs.
//   - `ClientStream` (n:1) and `BidiStream` (n:m): on the `stream.Recv()` call
//     that delivers the offending message.
func (i interceptor) StreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		return handler(srv, &recvWrapper{i: i, ServerStream: stream})
	}
}
func (i interceptor) StreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
desc.Handler = (&sendWrapper{handler: desc.Handler, i: i}).Handler()
return streamer(ctx, desc, cc, method)
}
}
// recvWrapper wraps a grpc.ServerStream so that every received message is
// validated after it is read (see RecvMsg).
type recvWrapper struct {
	// i performs the validation.
	i interceptor
	// ServerStream is the underlying stream messages are read from.
	grpc.ServerStream
}
// RecvMsg reads the next message into m from the underlying stream, then
// validates it. A transport error is returned unchanged; otherwise the result
// of validation is returned (nil when m is valid).
func (s *recvWrapper) RecvMsg(m interface{}) error {
	err := s.ServerStream.RecvMsg(m)
	if err == nil {
		err = s.i.validate(m)
	}
	return err
}
// sendWrapper wraps a grpc.ServerStream so that outgoing messages are
// validated before being sent (see SendMsg).
type sendWrapper struct {
	// i performs the validation.
	i interceptor
	// ServerStream receives the messages that pass validation; SendMsg calls
	// it directly, so it must be non-nil by the time SendMsg is used.
	grpc.ServerStream
	// handler is the stream handler wrapped by Handler.
	handler grpc.StreamHandler
}
// Handler returns a grpc.StreamHandler that calls the wrapped handler. The
// handler receives a stream whose SendMsg validates every outgoing message.
func (s *sendWrapper) Handler() grpc.StreamHandler {
	return func(srv interface{}, stream grpc.ServerStream) error {
		// Build a wrapper around the stream passed in for this call.
		// Passing s itself would ignore stream, forwarding whatever (possibly
		// nil) ServerStream s happens to embed. It would also share one
		// wrapper between concurrent calls.
		return s.handler(srv, &sendWrapper{i: s.i, ServerStream: stream, handler: s.handler})
	}
}
// SendMsg validates m and sends it on the underlying stream only when it is
// valid. A validation failure is returned instead of sending.
func (s *sendWrapper) SendMsg(m interface{}) error {
	err := s.i.validate(m)
	if err == nil {
		err = s.ServerStream.SendMsg(m)
	}
	return err
}