grpc/example/example.go
Adphi 9a645d3f44
add h2c/http2 support with WithoutCmux option
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
2023-07-10 13:50:29 +02:00

383 lines
11 KiB
Go

package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/cobra"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
greflectsvc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
"go.linka.cloud/grpc-toolkit/client"
"go.linka.cloud/grpc-toolkit/interceptors/auth"
"go.linka.cloud/grpc-toolkit/interceptors/ban"
"go.linka.cloud/grpc-toolkit/interceptors/defaulter"
"go.linka.cloud/grpc-toolkit/interceptors/iface"
metrics2 "go.linka.cloud/grpc-toolkit/interceptors/metrics"
validation2 "go.linka.cloud/grpc-toolkit/interceptors/validation"
"go.linka.cloud/grpc-toolkit/logger"
"go.linka.cloud/grpc-toolkit/service"
)
var (
_ iface.UnaryInterceptor = (*GreeterHandler)(nil)
_ iface.StreamInterceptor = (*GreeterHandler)(nil)
)
type GreeterHandler struct {
UnimplementedGreeterServer
}
func hello(name string) string {
return fmt.Sprintf("Hello %s !", name)
}
func (g *GreeterHandler) SayHello(ctx context.Context, req *HelloRequest) (*HelloReply, error) {
return &HelloReply{Message: hello(req.Name)}, nil
}
func (g *GreeterHandler) SayHelloStream(req *HelloStreamRequest, s Greeter_SayHelloStreamServer) error {
for i := int64(0); i < req.Count; i++ {
if err := s.Send(&HelloReply{Message: fmt.Sprintf("Hello %s (%d)!", req.Name, i+1)}); err != nil {
return err
}
// time.Sleep(time.Second)
}
return nil
}
func (g *GreeterHandler) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
logger.C(ctx).Infof("called service interface unary interceptor")
return handler(ctx, req)
}
}
func (g *GreeterHandler) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
logger.C(ss.Context()).Infof("called service interface stream interceptor")
return handler(srv, ss)
}
}
func httpLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
start := time.Now()
log := logger.From(request.Context()).WithFields(
"method", request.Method,
"host", request.Host,
"path", request.URL.Path,
"remoteAddress", request.RemoteAddr,
)
next.ServeHTTP(writer, request)
log.WithField("duration", time.Since(start)).Info()
})
}
func main() {
f, opts := service.NewFlagSet()
cmd := &cobra.Command{
Use: "example",
Run: func(cmd *cobra.Command, args []string) {
run(opts)
},
}
cmd.Flags().AddFlagSet(f)
cmd.Execute()
}
func run(opts ...service.Option) {
name := "greeter"
version := "v0.0.1"
secure := true
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
log := logger.New().WithFields("service", name)
ctx = logger.Set(ctx, log)
done := make(chan struct{})
ready := make(chan struct{})
var svc service.Service
var err error
metrics := metrics2.DefaultInterceptors()
validation := validation2.NewInterceptors(true)
defaulter := defaulter.NewInterceptors()
address := "0.0.0.0:9991"
opts = append(opts, service.WithContext(ctx),
service.WithName(name),
service.WithVersion(version),
service.WithAddress(address),
// service.WithRegistry(mdns.NewRegistry()),
service.WithReflection(true),
service.WithSecure(secure),
service.WithAfterStart(func() error {
log.Info("Server listening on", svc.Options().Address())
close(ready)
return nil
}),
service.WithAfterStop(func() error {
log.Info("Stopping server")
close(done)
return nil
}),
service.WithoutCmux(),
service.WithGateway(RegisterGreeterHandler),
service.WithGatewayPrefix("/rest"),
service.WithGRPCWeb(true),
service.WithGRPCWebPrefix("/grpc"),
service.WithMiddlewares(httpLogger),
service.WithInterceptors(metrics),
service.WithServerInterceptors(
ban.NewInterceptors(ban.WithDefaultJailDuration(time.Second), ban.WithDefaultCallback(func(action ban.Action, actor string, rule *ban.Rule) error {
log.WithFields("action", action, "actor", actor, "rule", rule.Name).Info("ban callback")
return nil
})),
auth.NewServerInterceptors(auth.WithBasicValidators(func(ctx context.Context, user, password string) (context.Context, error) {
if !auth.Equals(user, "admin") || !auth.Equals(password, "admin") {
return ctx, fmt.Errorf("invalid user or password")
}
log.Infof("request authenticated")
return ctx, nil
})),
),
service.WithInterceptors(defaulter, validation),
// enable server interface interceptor
service.WithServerInterceptors(iface.New()),
)
svc, err = service.New(opts...)
if err != nil {
panic(err)
}
RegisterGreeterServer(svc, &GreeterHandler{})
metrics.EnableHandlingTimeHistogram()
metrics.Register(svc)
go func() {
if err := svc.Start(); err != nil {
panic(err)
}
}()
go func() {
http.Handle("/metrics", promhttp.Handler())
if err := http.ListenAndServe(":9992", nil); err != nil {
panic(err)
}
}()
<-ready
copts := []client.Option{
// client.WithName(name),
// client.WithVersion(version),
client.WithAddress("localhost:9991"),
// client.WithRegistry(mdns.NewRegistry()),
client.WithSecure(secure),
client.WithUnaryInterceptors(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
logger.From(ctx).WithFields("party", "client", "method", method).Info(req)
return invoker(ctx, method, req, reply, cc, opts...)
}),
}
s, err := client.New(copts...)
if err != nil {
log.Fatal(err)
}
g := NewGreeterClient(s)
h := grpc_health_v1.NewHealthClient(s)
for i := 0; i < 5; i++ {
_, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
if err != nil {
log.Error(err)
} else {
log.Fatalf("expected error")
}
}
log.Infof("waiting for unban")
time.Sleep(time.Second)
s, err = client.New(append(copts, client.WithInterceptors(auth.NewBasicAuthClientIntereptors("admin", "admin")))...)
if err != nil {
log.Fatal(err)
}
g = NewGreeterClient(s)
h = grpc_health_v1.NewHealthClient(s)
hres, err := h.Check(ctx, &grpc_health_v1.HealthCheckRequest{Service: "helloworld.Greeter"})
if err != nil {
log.Fatal(err)
}
log.Infof("status: %v", hres.Status)
md := metadata.MD{}
res, err := g.SayHello(ctx, &HelloRequest{Name: "test"}, grpc.Header(&md))
if err != nil {
log.Fatal(err)
}
logMetadata(ctx, md)
log.Infof("received message: %s", res.Message)
md = metadata.MD{}
res, err = g.SayHello(ctx, &HelloRequest{}, grpc.Header(&md))
if err == nil {
log.Fatal("expected validation error")
}
logMetadata(ctx, md)
stream, err := g.SayHelloStream(ctx, &HelloStreamRequest{Name: "test"}, grpc.Header(&md))
if err != nil {
log.Fatal(err)
}
if md, err := stream.Header(); err == nil {
logMetadata(ctx, md)
}
for {
m, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
log.Fatal(err)
}
log.Infof("received stream message: %s", m.Message)
}
scheme := "http://"
var (
tlsConfig *tls.Config
dial func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error)
)
if secure {
scheme = "https://"
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
} else {
dial = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
}
}
httpc := &http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
TLSClientConfig: tlsConfig,
DialTLSContext: dial,
},
}
req := `{"name":"test"}`
do := func(url, contentType string) {
req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(req))
if err != nil {
log.Fatal(err)
}
req.Header.Set("content-type", contentType)
req.Header.Set("authorization", auth.BasicAuth("admin", "admin"))
resp, err := httpc.Do(req)
if err != nil {
log.Fatal(err)
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
log.Info(string(b))
}
do(scheme+address+"/rest/api/v1/greeter/hello", "application/json")
do(scheme+address+"/grpc/helloworld.Greeter/SayHello", "application/grpc-web+json")
if err := readSvcs(ctx, s); err != nil {
log.Fatal(err)
}
cancel()
<-done
}
func logMetadata(ctx context.Context, md metadata.MD) {
log := logger.From(ctx)
for k, v := range md {
log.Infof("%s: %v", k, v)
}
}
func readSvcs(ctx context.Context, c client.Client) (err error) {
log := logger.From(ctx)
rc := greflectsvc.NewServerReflectionClient(c)
rstream, err := rc.ServerReflectionInfo(ctx)
if err != nil {
return err
}
defer func() {
if err2 := rstream.CloseSend(); err2 != nil && err == nil {
err = err2
}
}()
if err = rstream.Send(&greflectsvc.ServerReflectionRequest{MessageRequest: &greflectsvc.ServerReflectionRequest_ListServices{}}); err != nil {
return err
}
var rres *greflectsvc.ServerReflectionResponse
rres, err = rstream.Recv()
if err != nil {
return err
}
rlist, ok := rres.MessageResponse.(*greflectsvc.ServerReflectionResponse_ListServicesResponse)
if !ok {
return fmt.Errorf("unexpected reflection response type: %T", rres.MessageResponse)
}
for _, v := range rlist.ListServicesResponse.Service {
if v.Name == "grpc.reflection.v1alpha.ServerReflection" {
continue
}
parts := strings.Split(v.Name, ".")
if len(parts) < 2 {
return fmt.Errorf("malformed service name: %s", v.Name)
}
pkg := strings.Join(parts[:len(parts)-1], ".")
svc := parts[len(parts)-1]
if err = rstream.Send(&greflectsvc.ServerReflectionRequest{MessageRequest: &greflectsvc.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: v.Name,
}}); err != nil {
return err
}
rres, err = rstream.Recv()
if err != nil {
log.Fatal(err)
}
rfile, ok := rres.MessageResponse.(*greflectsvc.ServerReflectionResponse_FileDescriptorResponse)
if !ok {
return fmt.Errorf("unexpected reflection response type: %T", rres.MessageResponse)
}
fdps := make(map[string]*descriptorpb.DescriptorProto)
var sdp *descriptorpb.ServiceDescriptorProto
for _, v := range rfile.FileDescriptorResponse.FileDescriptorProto {
fdp := &descriptorpb.FileDescriptorProto{}
if err = proto.Unmarshal(v, fdp); err != nil {
return err
}
for _, s := range fdp.GetService() {
if fdp.GetPackage() == pkg && s.GetName() == svc {
if sdp != nil {
log.Warnf("service already found: %s.%s", fdp.GetPackage(), s.GetName())
continue
}
sdp = s
}
}
for _, m := range fdp.GetMessageType() {
fdps[fdp.GetPackage()+"."+m.GetName()] = m
}
}
if sdp == nil {
return fmt.Errorf("%s: service not found", v.Name)
}
for _, m := range sdp.GetMethod() {
log.Infof("%s: %s", v.Name, m.GetName())
}
}
return nil
}