mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-22 10:56:26 +00:00
161 lines
5.5 KiB
Go
161 lines
5.5 KiB
Go
|
// Copyright 2017 Michal Witkowski. All Rights Reserved.
|
||
|
// See LICENSE for licensing terms.
|
||
|
|
||
|
package proxy
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/status"
|
||
|
"google.golang.org/protobuf/types/known/emptypb"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
clientStreamDescForProxying = &grpc.StreamDesc{
|
||
|
ServerStreams: true,
|
||
|
ClientStreams: true,
|
||
|
}
|
||
|
)
|
||
|
|
||
|
// RegisterService sets up a proxy handler for a particular gRPC service and method.
|
||
|
// The behaviour is the same as if you were registering a handler method, e.g. from a generated pb.go file.
|
||
|
func RegisterService(server *grpc.Server, director StreamDirector, serviceName string, methodNames ...string) {
|
||
|
streamer := &handler{director}
|
||
|
fakeDesc := &grpc.ServiceDesc{
|
||
|
ServiceName: serviceName,
|
||
|
HandlerType: (*interface{})(nil),
|
||
|
}
|
||
|
for _, m := range methodNames {
|
||
|
streamDesc := grpc.StreamDesc{
|
||
|
StreamName: m,
|
||
|
Handler: streamer.handler,
|
||
|
ServerStreams: true,
|
||
|
ClientStreams: true,
|
||
|
}
|
||
|
fakeDesc.Streams = append(fakeDesc.Streams, streamDesc)
|
||
|
}
|
||
|
server.RegisterService(fakeDesc, streamer)
|
||
|
}
|
||
|
|
||
|
// TransparentHandler returns a handler that attempts to proxy all requests that are not registered in the server.
|
||
|
// The indented use here is as a transparent proxy, where the server doesn't know about the services implemented by the
|
||
|
// backends. It should be used as a `grpc.UnknownServiceHandler`.
|
||
|
func TransparentHandler(director StreamDirector) grpc.StreamHandler {
|
||
|
streamer := &handler{director: director}
|
||
|
return streamer.handler
|
||
|
}
|
||
|
|
||
|
type handler struct {
|
||
|
director StreamDirector
|
||
|
}
|
||
|
|
||
|
// handler is where the real magic of proxying happens.
|
||
|
// It is invoked like any gRPC server stream and uses the emptypb.Empty type server
|
||
|
// to proxy calls between the input and output streams.
|
||
|
func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error {
|
||
|
// little bit of gRPC internals never hurt anyone
|
||
|
fullMethodName, ok := grpc.MethodFromServerStream(serverStream)
|
||
|
if !ok {
|
||
|
return status.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
|
||
|
}
|
||
|
// We require that the director's returned context inherits from the serverStream.Context().
|
||
|
outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
clientCtx, clientCancel := context.WithCancel(outgoingCtx)
|
||
|
defer clientCancel()
|
||
|
// TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
|
||
|
clientStream, err := backendConn.NewStream(clientCtx, clientStreamDescForProxying, fullMethodName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
// Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate.
|
||
|
// Channels do not have to be closed, it is just a control flow mechanism, see
|
||
|
// https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
|
||
|
s2cErrChan := s.forwardServerToClient(serverStream, clientStream)
|
||
|
c2sErrChan := s.forwardClientToServer(clientStream, serverStream)
|
||
|
// We don't know which side is going to stop sending first, so we need a select between the two.
|
||
|
for i := 0; i < 2; i++ {
|
||
|
select {
|
||
|
case s2cErr := <-s2cErrChan:
|
||
|
if s2cErr == io.EOF {
|
||
|
// this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./
|
||
|
// the clientStream>serverStream may continue pumping though.
|
||
|
clientStream.CloseSend()
|
||
|
} else {
|
||
|
// however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
|
||
|
// to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and
|
||
|
// exit with an error to the stack
|
||
|
clientCancel()
|
||
|
return status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
|
||
|
}
|
||
|
case c2sErr := <-c2sErrChan:
|
||
|
// This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two
|
||
|
// cases we may have received Trailers as part of the call. In case of other errors (stream closed) the trailers
|
||
|
// will be nil.
|
||
|
serverStream.SetTrailer(clientStream.Trailer())
|
||
|
// c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error.
|
||
|
if c2sErr != io.EOF {
|
||
|
return c2sErr
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
return status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
|
||
|
}
|
||
|
|
||
|
func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
|
||
|
ret := make(chan error, 1)
|
||
|
go func() {
|
||
|
f := &emptypb.Empty{}
|
||
|
for i := 0; ; i++ {
|
||
|
if err := src.RecvMsg(f); err != nil {
|
||
|
ret <- err // this can be io.EOF which is happy case
|
||
|
break
|
||
|
}
|
||
|
if i == 0 {
|
||
|
// This is a bit of a hack, but client to server headers are only readable after first client msg is
|
||
|
// received but must be written to server stream before the first msg is flushed.
|
||
|
// This is the only place to do it nicely.
|
||
|
md, err := src.Header()
|
||
|
if err != nil {
|
||
|
ret <- err
|
||
|
break
|
||
|
}
|
||
|
if err := dst.SendHeader(md); err != nil {
|
||
|
ret <- err
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if err := dst.SendMsg(f); err != nil {
|
||
|
ret <- err
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error {
|
||
|
ret := make(chan error, 1)
|
||
|
go func() {
|
||
|
f := &emptypb.Empty{}
|
||
|
for i := 0; ; i++ {
|
||
|
if err := src.RecvMsg(f); err != nil {
|
||
|
ret <- err // this can be io.EOF which is happy case
|
||
|
break
|
||
|
}
|
||
|
if err := dst.SendMsg(f); err != nil {
|
||
|
ret <- err
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
return ret
|
||
|
}
|