grpc/service/options.go

559 lines
11 KiB
Go

package service
import (
"context"
"crypto/tls"
"crypto/x509"
"embed"
"fmt"
"net"
"os"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/rs/cors"
"google.golang.org/grpc"
"go.linka.cloud/grpc/certs"
"go.linka.cloud/grpc/interceptors"
"go.linka.cloud/grpc/registry"
"go.linka.cloud/grpc/transport"
"go.linka.cloud/grpc/utils/addr"
)
// Compile-time check that *options satisfies the Options interface.
var _ Options = (*options)(nil)

// RegisterGatewayFunc is the signature of the function given to WithGateway.
// It registers grpc-gateway handlers on mux, using cc as the gRPC client
// connection the handlers talk to, and returns any registration error.
type RegisterGatewayFunc func(ctx context.Context, mux *runtime.ServeMux, cc grpc.ClientConnInterface) error
// Options is the read-only view of a service configuration. The package's
// own implementation, *options, is populated through Option functions.
type Options interface {
	// Context returns the context of the service.
	Context() context.Context
	// Name returns the service name.
	Name() string
	// Version returns the service version.
	Version() string
	// Address returns the configured listen address (host:port).
	Address() string
	// Reflection reports whether reflection was requested.
	Reflection() bool
	// Health reports whether the health service was requested.
	Health() bool
	// CACert returns the path of the CA certificate file.
	CACert() string
	// Cert returns the path of the certificate file.
	Cert() string
	// Key returns the path of the private key file.
	Key() string
	// TLSConfig returns the TLS configuration; it may be nil when TLS is
	// not configured.
	TLSConfig() *tls.Config
	// Secure reports whether TLS was requested with WithSecure.
	Secure() bool
	// Registry returns the service registry, if any.
	Registry() registry.Registry
	// BeforeStart returns the hooks registered with WithBeforeStart.
	BeforeStart() []func() error
	// AfterStart returns the hooks registered with WithAfterStart.
	AfterStart() []func() error
	// BeforeStop returns the hooks registered with WithBeforeStop.
	BeforeStop() []func() error
	// AfterStop returns the hooks registered with WithAfterStop.
	AfterStop() []func() error
	// ServerOpts returns the additional grpc.ServerOption values.
	ServerOpts() []grpc.ServerOption
	// ServerInterceptors returns the unary server interceptors.
	ServerInterceptors() []grpc.UnaryServerInterceptor
	// StreamServerInterceptors returns the stream server interceptors.
	StreamServerInterceptors() []grpc.StreamServerInterceptor
	// ClientInterceptors returns the unary client interceptors.
	ClientInterceptors() []grpc.UnaryClientInterceptor
	// StreamClientInterceptors returns the stream client interceptors.
	StreamClientInterceptors() []grpc.StreamClientInterceptor
	// Cors returns the CORS options.
	Cors() cors.Options
	// Mux returns the custom ServeMux set with WithMux, if any.
	Mux() ServeMux
	// GRPCWeb reports whether gRPC-Web support was requested.
	GRPCWeb() bool
	// GRPCWebPrefix returns the gRPC-Web path prefix.
	GRPCWebPrefix() string
	// GRPCWebOpts returns the options for the grpcweb wrapper.
	GRPCWebOpts() []grpcweb.Option
	// Gateway reports whether a gateway registration function was set.
	Gateway() bool
	// GatewayPrefix returns the gateway path prefix.
	GatewayPrefix() string
	// GatewayOpts returns the grpc-gateway ServeMux options.
	GatewayOpts() []runtime.ServeMuxOption
	// TODO(adphi): metrics + tracing

	// Default fills values that were left unset with defaults.
	Default()
}
// NewOptions returns options pre-populated with the package defaults: a
// background context, the ":0" listen address and the health service flag
// turned on. Every other field starts at its zero value.
func NewOptions() *options {
	o := new(options)
	o.ctx = context.Background()
	o.address = ":0"
	o.health = true
	return o
}
// Default fills in every field that was left unset:
//   - ctx falls back to context.Background();
//   - an empty address becomes "0.0.0.0:0" (port 0 lets the OS pick a port
//     when the address is later listened on);
//   - a nil transport becomes a gRPC server built with the configured
//     server options.
func (o *options) Default() {
	if o.ctx == nil {
		o.ctx = context.Background()
	}
	if o.address == "" {
		o.address = "0.0.0.0:0"
	}
	if o.transport == nil {
		// A zero-value grpc.Server (&grpc.Server{}) is unusable: its internal
		// state is only initialised by grpc.NewServer, and registering a
		// service on it panics. Always go through the constructor.
		// NOTE(review): interceptors are not applied here; confirm the caller
		// wires them when it relies on this default transport.
		o.transport = grpc.NewServer(o.serverOpts...)
	}
}
// Option is a functional option that mutates the service options.
type Option func(*options)

// WithName sets the service name returned by Options.Name.
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithVersion sets the service version returned by Options.Version.
func WithVersion(version string) Option {
	return func(o *options) {
		o.version = version
	}
}

// WithRegistry sets the service registry returned by Options.Registry.
func WithRegistry(r registry.Registry) Option {
	// The parameter is named r rather than registry so that it does not
	// shadow the imported registry package inside the function.
	return func(o *options) {
		o.registry = r
	}
}
// WithContext sets the context of the service, returned by Options.Context.
// Per the package design it may be cancelled to signal shutdown of the
// service, and it may carry extra option values.
func WithContext(ctx context.Context) Option {
	return func(o *options) { o.ctx = ctx }
}

// WithAddress sets the address (host:port) the server should listen on.
func WithAddress(addr string) Option {
	return func(o *options) { o.address = addr }
}

// WithListener provides a pre-built listener for the service, for example a
// custom one. It is meant to take precedence over WithAddress and
// WithTLSConfig; that precedence is applied where the listener is consumed.
func WithListener(lis net.Listener) Option {
	return func(o *options) { o.lis = lis }
}
// WithReflection enables or disables reflection (disabled by default).
func WithReflection(r bool) Option {
	return func(o *options) {
		o.reflection = r
	}
}

// WithHealth enables or disables the health service. NewOptions enables it
// by default.
func WithHealth(h bool) Option {
	return func(o *options) {
		o.health = h
	}
}

// WithSecure requests TLS. When no tls.Config is supplied and the CA
// certificate, certificate and key paths are not all set, a certificate is
// generated for the host part of the listen address, or for the addresses
// returned by addr.IPs when that host is empty.
func WithSecure(s bool) Option {
	return func(o *options) {
		o.secure = s
	}
}

// WithGRPCServerOpts appends additional grpc.ServerOption values. Repeated
// calls accumulate.
func WithGRPCServerOpts(opts ...grpc.ServerOption) Option {
	return func(o *options) {
		o.serverOpts = append(o.serverOpts, opts...)
	}
}
// WithCACert sets the path of the PEM-encoded CA certificate. The files set
// with WithCACert, WithCert and WithKey are only loaded when all three paths
// are set and no tls.Config was supplied; the CA then becomes the RootCAs
// pool of the resulting configuration.
func WithCACert(path string) Option {
	return func(o *options) {
		o.caCert = path
	}
}

// WithCert sets the path of the PEM-encoded certificate (see WithCACert).
func WithCert(path string) Option {
	return func(o *options) {
		o.cert = path
	}
}

// WithKey sets the path of the PEM-encoded private key (see WithCACert).
func WithKey(path string) Option {
	return func(o *options) {
		o.key = path
	}
}

// WithTLSConfig sets an explicit TLS configuration. It is used as-is and
// takes precedence over the certificate files and over WithSecure.
func WithTLSConfig(conf *tls.Config) Option {
	return func(o *options) {
		o.tlsConfig = conf
	}
}
// WithBeforeStart appends hooks intended to run before the service starts,
// returned by Options.BeforeStart. Repeated calls accumulate.
func WithBeforeStart(fn ...func() error) Option {
	return func(o *options) {
		o.beforeStart = append(o.beforeStart, fn...)
	}
}

// WithBeforeStop appends hooks intended to run before the service stops,
// returned by Options.BeforeStop. Repeated calls accumulate.
func WithBeforeStop(fn ...func() error) Option {
	return func(o *options) {
		o.beforeStop = append(o.beforeStop, fn...)
	}
}

// WithAfterStart appends hooks intended to run after the service starts,
// returned by Options.AfterStart. Repeated calls accumulate.
func WithAfterStart(fn ...func() error) Option {
	return func(o *options) {
		o.afterStart = append(o.afterStart, fn...)
	}
}

// WithAfterStop appends hooks intended to run after the service stops,
// returned by Options.AfterStop. Repeated calls accumulate.
func WithAfterStop(fn ...func() error) Option {
	return func(o *options) {
		o.afterStop = append(o.afterStop, fn...)
	}
}
// WithInterceptors registers complete interceptor sets. For each set, in
// argument order, its unary and stream server interceptors and its unary and
// stream client interceptors are appended to the matching lists.
func WithInterceptors(i ...interceptors.Interceptors) Option {
	return func(o *options) {
		for n := range i {
			set := i[n]
			o.unaryServerInterceptors = append(o.unaryServerInterceptors, set.UnaryServerInterceptor())
			o.streamServerInterceptors = append(o.streamServerInterceptors, set.StreamServerInterceptor())
			o.unaryClientInterceptors = append(o.unaryClientInterceptors, set.UnaryClientInterceptor())
			o.streamClientInterceptors = append(o.streamClientInterceptors, set.StreamClientInterceptor())
		}
	}
}

// WithServerInterceptors appends the unary and stream server interceptors of
// each given set, in argument order.
func WithServerInterceptors(i ...interceptors.ServerInterceptors) Option {
	return func(o *options) {
		for n := range i {
			set := i[n]
			o.unaryServerInterceptors = append(o.unaryServerInterceptors, set.UnaryServerInterceptor())
			o.streamServerInterceptors = append(o.streamServerInterceptors, set.StreamServerInterceptor())
		}
	}
}

// WithClientInterceptors appends the unary and stream client interceptors of
// each given set, in argument order.
func WithClientInterceptors(i ...interceptors.ClientInterceptors) Option {
	return func(o *options) {
		for n := range i {
			set := i[n]
			o.unaryClientInterceptors = append(o.unaryClientInterceptors, set.UnaryClientInterceptor())
			o.streamClientInterceptors = append(o.streamClientInterceptors, set.StreamClientInterceptor())
		}
	}
}
// WithUnaryClientInterceptor appends unary client interceptors.
func WithUnaryClientInterceptor(i ...grpc.UnaryClientInterceptor) Option {
	return func(o *options) { o.unaryClientInterceptors = append(o.unaryClientInterceptors, i...) }
}

// WithUnaryServerInterceptor appends unary server interceptors.
func WithUnaryServerInterceptor(i ...grpc.UnaryServerInterceptor) Option {
	return func(o *options) { o.unaryServerInterceptors = append(o.unaryServerInterceptors, i...) }
}

// WithStreamServerInterceptor appends stream server interceptors.
func WithStreamServerInterceptor(i ...grpc.StreamServerInterceptor) Option {
	return func(o *options) { o.streamServerInterceptors = append(o.streamServerInterceptors, i...) }
}

// WithStreamClientInterceptor appends stream client interceptors.
func WithStreamClientInterceptor(i ...grpc.StreamClientInterceptor) Option {
	return func(o *options) { o.streamClientInterceptors = append(o.streamClientInterceptors, i...) }
}
// WithSubscriberInterceptor is a no-op kept for backward compatibility: its
// arguments are ignored and the returned Option changes nothing.
//
// Deprecated: this option has no effect; do not use it.
func WithSubscriberInterceptor(_ ...interface{}) Option {
	return func(*options) {}
}
// WithCors sets the CORS options (github.com/rs/cors) returned by
// Options.Cors. Each call replaces the previous value.
func WithCors(opts cors.Options) Option {
	return func(o *options) {
		o.cors = opts
	}
}

// WithMux sets a custom ServeMux, returned by Options.Mux.
func WithMux(mux ServeMux) Option {
	return func(o *options) {
		o.mux = mux
	}
}
// WithMiddlewares appends middlewares to the service. Repeated calls
// accumulate, like the other variadic options of this package (for example
// WithGRPCServerOpts), and the caller's slice is copied rather than retained.
func WithMiddlewares(m ...Middleware) Option {
	return func(o *options) {
		// append (instead of plain assignment) keeps middlewares from earlier
		// calls and never aliases the variadic slice passed by the caller.
		o.middlewares = append(o.middlewares, m...)
	}
}
// WithGRPCWeb enables or disables gRPC-Web support.
func WithGRPCWeb(b bool) Option {
	return func(o *options) {
		o.grpcWeb = b
	}
}

// WithGRPCWebPrefix sets the gRPC-Web path prefix. A single trailing slash
// is removed, so "/grpc/" and "/grpc" are equivalent.
func WithGRPCWebPrefix(prefix string) Option {
	return func(o *options) {
		o.grpcWebPrefix = strings.TrimSuffix(prefix, "/")
	}
}

// WithGRPCWebOpts appends options for the grpcweb wrapper. Repeated calls
// accumulate, consistent with WithGRPCServerOpts, and the caller's slice is
// copied rather than retained.
func WithGRPCWebOpts(opts ...grpcweb.Option) Option {
	return func(o *options) {
		o.grpcWebOpts = append(o.grpcWebOpts, opts...)
	}
}
// WithGateway sets the function that registers grpc-gateway handlers.
// Options.Gateway reports true once a non-nil function is set.
func WithGateway(fn RegisterGatewayFunc) Option {
	return func(o *options) {
		o.gateway = fn
	}
}

// WithGatewayPrefix sets the gateway path prefix. A single trailing slash is
// removed, so "/api/" and "/api" are equivalent.
func WithGatewayPrefix(prefix string) Option {
	return func(o *options) {
		o.gatewayPrefix = strings.TrimSuffix(prefix, "/")
	}
}

// WithGatewayOpts appends grpc-gateway ServeMux options. Repeated calls
// accumulate, consistent with WithGRPCServerOpts, and the caller's slice is
// copied rather than retained.
func WithGatewayOpts(opts ...runtime.ServeMuxOption) Option {
	return func(o *options) {
		o.gatewayOpts = append(o.gatewayOpts, opts...)
	}
}
// WithReactUI adds static single-page-app serving to the HTTP server.
// fs is the read-only embedded file system holding the app, and subpath is
// the directory inside fs used as the root of the served content.
func WithReactUI(fs embed.FS, subpath string) Option {
	return func(o *options) {
		o.hasReactUI = true
		o.reactUI, o.reactUISubPath = fs, subpath
	}
}
// options is the package's implementation of Options. Its fields are set by
// the With* Option functions and read back through the accessor methods.
type options struct {
	// ctx is the service context; NewOptions and Default use context.Background().
	ctx     context.Context
	name    string
	version string
	// address is the listen address: ":0" from NewOptions, "0.0.0.0:0" from
	// Default when empty.
	address string
	// lis is a listener supplied with WithListener.
	lis        net.Listener
	reflection bool
	health     bool
	// secure requests TLS; parseTLSConfig generates a certificate when no
	// certificate files or tls.Config are provided.
	secure bool
	// caCert, cert and key are file paths, loaded by parseTLSConfig only when
	// all three are set.
	caCert string
	cert   string
	key    string
	// tlsConfig is either supplied with WithTLSConfig or built by
	// parseTLSConfig.
	tlsConfig *tls.Config
	// transport is defaulted by Default when nil.
	transport transport.Transport
	registry  registry.Registry
	// Lifecycle hooks, appended by the With{Before,After}{Start,Stop} options.
	beforeStart []func() error
	afterStart  []func() error
	beforeStop  []func() error
	afterStop   []func() error
	serverOpts  []grpc.ServerOption
	// Interceptor lists, appended by the With*Interceptor(s) options.
	unaryServerInterceptors  []grpc.UnaryServerInterceptor
	streamServerInterceptors []grpc.StreamServerInterceptor
	unaryClientInterceptors  []grpc.UnaryClientInterceptor
	streamClientInterceptors []grpc.StreamClientInterceptor
	mux                      ServeMux
	middlewares              []Middleware
	grpcWeb                  bool
	grpcWebOpts              []grpcweb.Option
	// grpcWebPrefix has one trailing slash trimmed by WithGRPCWebPrefix.
	grpcWebPrefix string
	// gateway is the gateway registration function; Gateway() reports
	// whether it is non-nil.
	gateway     RegisterGatewayFunc
	gatewayOpts []runtime.ServeMuxOption
	cors        cors.Options
	// reactUI, reactUISubPath and hasReactUI are set together by WithReactUI.
	reactUI        embed.FS
	reactUISubPath string
	hasReactUI     bool
	// NOTE(review): error is not read or written in the visible part of this
	// file; confirm whether it is still used.
	error error
	// gatewayPrefix has one trailing slash trimmed by WithGatewayPrefix.
	gatewayPrefix string
}
// Name returns the service name set with WithName.
func (o *options) Name() string {
	return o.name
}

// Version returns the service version set with WithVersion.
func (o *options) Version() string {
	return o.version
}

// Context returns the service context (context.Background by default).
func (o *options) Context() context.Context {
	return o.ctx
}

// Address returns the configured listen address.
func (o *options) Address() string {
	return o.address
}

// Registry returns the registry set with WithRegistry, or nil.
func (o *options) Registry() registry.Registry {
	return o.registry
}

// Reflection reports whether reflection was enabled with WithReflection.
func (o *options) Reflection() bool {
	return o.reflection
}

// Health reports whether the health service is enabled; NewOptions enables
// it by default.
func (o *options) Health() bool {
	return o.health
}
// CACert returns the CA certificate path set with WithCACert.
func (o *options) CACert() string {
	return o.caCert
}

// Cert returns the certificate path set with WithCert.
func (o *options) Cert() string {
	return o.cert
}

// Key returns the private key path set with WithKey.
func (o *options) Key() string {
	return o.key
}

// TLSConfig returns the TLS configuration, either supplied with
// WithTLSConfig or built by parseTLSConfig; it may be nil.
func (o *options) TLSConfig() *tls.Config {
	return o.tlsConfig
}

// Secure reports whether TLS was requested with WithSecure.
func (o *options) Secure() bool {
	return o.secure
}
// BeforeStart returns the hooks appended with WithBeforeStart.
func (o *options) BeforeStart() []func() error {
	return o.beforeStart
}

// AfterStart returns the hooks appended with WithAfterStart.
func (o *options) AfterStart() []func() error {
	return o.afterStart
}

// BeforeStop returns the hooks appended with WithBeforeStop.
func (o *options) BeforeStop() []func() error {
	return o.beforeStop
}

// AfterStop returns the hooks appended with WithAfterStop.
func (o *options) AfterStop() []func() error {
	return o.afterStop
}
// ServerOpts returns the grpc.ServerOption values appended with
// WithGRPCServerOpts.
func (o *options) ServerOpts() []grpc.ServerOption {
	return o.serverOpts
}

// ServerInterceptors returns the unary server interceptors.
func (o *options) ServerInterceptors() []grpc.UnaryServerInterceptor {
	return o.unaryServerInterceptors
}

// StreamServerInterceptors returns the stream server interceptors.
func (o *options) StreamServerInterceptors() []grpc.StreamServerInterceptor {
	return o.streamServerInterceptors
}

// ClientInterceptors returns the unary client interceptors.
func (o *options) ClientInterceptors() []grpc.UnaryClientInterceptor {
	return o.unaryClientInterceptors
}

// StreamClientInterceptors returns the stream client interceptors.
func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor {
	return o.streamClientInterceptors
}
// Cors returns the CORS options set with WithCors.
func (o *options) Cors() cors.Options {
	return o.cors
}

// Mux returns the ServeMux set with WithMux, if any.
func (o *options) Mux() ServeMux {
	return o.mux
}

// GRPCWeb reports whether gRPC-Web support was enabled with WithGRPCWeb.
func (o *options) GRPCWeb() bool {
	return o.grpcWeb
}

// GRPCWebPrefix returns the gRPC-Web path prefix (one trailing slash
// trimmed).
func (o *options) GRPCWebPrefix() string {
	return o.grpcWebPrefix
}

// GRPCWebOpts returns the options for the grpcweb wrapper.
func (o *options) GRPCWebOpts() []grpcweb.Option {
	return o.grpcWebOpts
}

// Gateway reports whether a gateway registration function was set with
// WithGateway.
func (o *options) Gateway() bool {
	return o.gateway != nil
}

// GatewayPrefix returns the gateway path prefix (one trailing slash
// trimmed).
func (o *options) GatewayPrefix() string {
	return o.gatewayPrefix
}

// GatewayOpts returns the grpc-gateway ServeMux options.
func (o *options) GatewayOpts() []runtime.ServeMuxOption {
	return o.gatewayOpts
}
// parseTLSConfig resolves o.tlsConfig using this order of precedence:
//  1. a tls.Config supplied with WithTLSConfig is kept untouched;
//  2. when the CA certificate, certificate and key paths are all set, they
//     are loaded from disk and the CA pool is installed as RootCAs;
//  3. when WithSecure(true) was given, a certificate is generated with
//     certs.New for the host part of the listen address (or for every
//     address from addr.IPs when that host is empty);
//  4. otherwise TLS stays disabled and o.tlsConfig remains nil.
func (o *options) parseTLSConfig() error {
	if o.tlsConfig != nil {
		return nil
	}
	if o.hasTLSConfig() {
		pem, err := os.ReadFile(o.caCert)
		if err != nil {
			return err
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(pem) {
			return fmt.Errorf("failed to load CA Cert from %s", o.caCert)
		}
		keyPair, err := tls.LoadX509KeyPair(o.cert, o.key)
		if err != nil {
			return err
		}
		o.tlsConfig = &tls.Config{
			Certificates: []tls.Certificate{keyPair},
			RootCAs:      pool,
		}
		return nil
	}
	if !o.secure {
		return nil
	}
	// NOTE(review): an address that SplitHostPort rejects leaves hosts
	// empty, so the certificate is generated without any host.
	var hosts []string
	if host, _, err := net.SplitHostPort(o.address); err == nil {
		if host == "" {
			hosts = addr.IPs()
		} else {
			hosts = []string{host}
		}
	}
	// Each host is passed through addr.Extract, rewriting the slice in place.
	for i := range hosts {
		extracted, err := addr.Extract(hosts[i])
		if err != nil {
			return err
		}
		hosts[i] = extracted
	}
	generated, err := certs.New(hosts...)
	if err != nil {
		return err
	}
	o.tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{generated},
		ClientAuth:   tls.NoClientCert,
	}
	return nil
}
// hasTLSConfig reports whether TLS material should be loaded from files:
// no explicit tls.Config was supplied and the CA certificate, certificate and
// key paths are all non-empty.
func (o *options) hasTLSConfig() bool {
	if o.tlsConfig != nil {
		return false
	}
	return o.caCert != "" && o.cert != "" && o.key != ""
}