diff --git a/client/client.go b/client/client.go index 8087c3f..d2aecd5 100644 --- a/client/client.go +++ b/client/client.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "strings" @@ -28,9 +27,8 @@ func New(opts ...Option) (Client, error) { c.opts.registry = noop.New() } resolver.Register(c.opts.registry.ResolverBuilder()) - c.pool = newPool(DefaultPoolSize, DefaultPoolTTL, DefaultPoolMaxIdle, DefaultPoolMaxStreams) - if c.opts.tlsConfig == nil && c.opts.Secure() { - c.opts.tlsConfig = &tls.Config{InsecureSkipVerify: true} + if err := c.opts.parseTLSConfig(); err != nil { + return nil, err } if c.opts.tlsConfig != nil { c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig))) @@ -59,27 +57,24 @@ func New(opts ...Option) (Client, error) { if c.opts.version != "" && c.opts.addr == "" { c.addr = c.addr + ":" + strings.TrimSpace(c.opts.version) } + cc, err := grpc.Dial(c.addr, c.opts.dialOptions...) + if err != nil { + return nil, err + } + c.cc = cc return c, nil } type client struct { addr string - pool *pool opts *options + cc *grpc.ClientConn } func (c *client) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { - pc, err := c.pool.getConn(c.addr, c.opts.dialOptions...) - if err != nil { - return err - } - return pc.Invoke(ctx, method, args, reply, opts...) + return c.cc.Invoke(ctx, method, args, reply, opts...) } func (c *client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - pc, err := c.pool.getConn(c.addr, c.opts.dialOptions...) - if err != nil { - return nil, err - } - return pc.NewStream(ctx, desc, method, opts...) + return c.cc.NewStream(ctx, desc, method, opts...) } diff --git a/client/flags.go b/client/flags.go index 5a9a64e..c1b174a 100644 --- a/client/flags.go +++ b/client/flags.go @@ -12,29 +12,31 @@ var u = strings.ToUpper func NewFlagSet() (*pflag.FlagSet, Option) { const ( - addr = "address" - secure = "secure" - // caCert = "ca-cert" - // clientCert = "client-cert" - // clientKey = "client-key" + addr = "address" + secure = "secure" + caCert = "ca-cert" + clientCert = "client-cert" + clientKey = "client-key" ) var ( optAddress string optSecure bool - // optCACert string - // optCert string - // optKey string + optCACert string + optCert string + optKey string ) flags := pflag.NewFlagSet("gRPC", pflag.ContinueOnError) flags.StringVar(&optAddress, addr, env.GetDefault(u(addr), "0.0.0.0:0"), "Bind address for the server. 127.0.0.1:9090"+flagEnv(addr)) flags.BoolVar(&optSecure, secure, env.GetBoolDefault(u(secure), true), "Generate self signed certificate if none provided"+flagEnv(secure)) - // flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(optCACert)) - // flags.StringVar(&optCert, clientCert, "", "Path to Server certificate"+flagEnv(clientCert)) - // flags.StringVar(&optKey, clientKey, "", "Path to Server key"+flagEnv(clientKey)) + flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(optCACert)) + flags.StringVar(&optCert, clientCert, "", "Path to Server certificate"+flagEnv(clientCert)) + flags.StringVar(&optKey, clientKey, "", "Path to Server key"+flagEnv(clientKey)) return flags, func(o *options) { o.addr = optAddress o.secure = optSecure - + o.caCert = optCACert + o.cert = optCert + o.key = optKey } } diff --git a/client/options.go b/client/options.go index 5b04e5b..f889cf0 100644 --- a/client/options.go +++ b/client/options.go @@ -2,6 +2,9 @@ package client import ( "crypto/tls" + "crypto/x509" + "fmt" + "os" "google.golang.org/grpc" @@ -15,6 +18,9 @@ type Options interface { Address() string Secure() bool Registry() registry.Registry + CA() string + Cert() string + Key() string TLSConfig() *tls.Config DialOptions() []grpc.DialOption UnaryInterceptors() []grpc.UnaryClientInterceptor @@ -47,6 +53,24 @@ func WithAddress(address string) Option { } } +func WithCA(ca string) Option { + return func(o *options) { + o.caCert = ca + } +} + +func WithCert(cert string) Option { + return func(o *options) { + o.cert = cert + } +} + +func WithKey(key string) Option { + return func(o *options) { + o.key = key + } +} + func WithTLSConfig(conf *tls.Config) Option { return func(o *options) { o.tlsConfig = conf @@ -87,10 +111,14 @@ func WithStreamInterceptors(i ...grpc.StreamClientInterceptor) Option { } type options struct { - registry registry.Registry - name string - version string - addr string + registry registry.Registry + name string + version string + addr string + + caCert string + cert string + key string tlsConfig *tls.Config secure bool dialOptions []grpc.DialOption @@ -115,6 +143,18 @@ func (o *options) Registry() registry.Registry { return o.registry } +func (o *options) CA() string { + return o.caCert +} + +func (o *options) Cert() string { + return o.cert +} + +func (o *options) Key() string { + return o.key +} + func (o *options) TLSConfig() *tls.Config { return o.tlsConfig } @@ -134,3 +174,38 @@ func (o *options) UnaryInterceptors() []grpc.UnaryClientInterceptor { func (o *options) StreamInterceptors() []grpc.StreamClientInterceptor { return o.streamInterceptors } + +func (o *options) hasTLSConfig() bool { + return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil +} + +func (o *options) parseTLSConfig() error { + if o.tlsConfig != nil { + return nil + } + if !o.hasTLSConfig() { + if !o.secure { + return nil + } + o.tlsConfig = &tls.Config{InsecureSkipVerify: true} + return nil + } + caCert, err := os.ReadFile(o.caCert) + if err != nil { + return err + } + caCertPool := x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM(caCert) + if !ok { + return fmt.Errorf("failed to load CA Cert from %s", o.caCert) + } + cert, err := tls.LoadX509KeyPair(o.cert, o.key) + if err != nil { + return err + } + o.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } + return nil +} diff --git a/client/pool.go b/client/pool.go deleted file mode 100644 index 5c05543..0000000 --- a/client/pool.go +++ /dev/null @@ -1,244 +0,0 @@ -/* -Taken from the https://github.com/micro/go-micro/client/grpc -*/ - -package client - -import ( - "sync" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" -) - -var ( - // DefaultPoolSize sets the connection pool size - DefaultPoolSize = 100 - // DefaultPoolTTL sets the connection pool ttl - DefaultPoolTTL = time.Minute - // DefaultPoolMaxStreams maximum streams on a connectioin - // (20) - DefaultPoolMaxStreams = 20 - - // DefaultPoolMaxIdle maximum idle conns of a pool - // (50) - DefaultPoolMaxIdle = 50 - - // DefaultMaxRecvMsgSize maximum message that client can receive - // (4 MB). - DefaultMaxRecvMsgSize = 1024 * 1024 * 4 - - // DefaultMaxSendMsgSize maximum message that client can send - // (4 MB). - DefaultMaxSendMsgSize = 1024 * 1024 * 4 -) - -type pool struct { - size int - ttl int64 - - // max streams on a *poolConn - maxStreams int - // max idle conns - maxIdle int - - sync.Mutex - conns map[string]*streamsPool -} - -type streamsPool struct { - // head of list - head *poolConn - // busy conns list - busy *poolConn - // the siza of list - count int - // idle conn - idle int -} - -type poolConn struct { - // grpc conn - *grpc.ClientConn - err error - addr string - - // pool and streams pool - pool *pool - sp *streamsPool - streams int - created int64 - - // list - pre *poolConn - next *poolConn - in bool -} - -func newPool(size int, ttl time.Duration, idle int, ms int) *pool { - if ms <= 0 { - ms = 1 - } - if idle < 0 { - idle = 0 - } - return &pool{ - size: size, - ttl: int64(ttl.Seconds()), - maxStreams: ms, - maxIdle: idle, - conns: make(map[string]*streamsPool), - } -} - -func (p *pool) getConn(addr string, opts ...grpc.DialOption) (*poolConn, error) { - now := time.Now().Unix() - p.Lock() - sp, ok := p.conns[addr] - if !ok { - sp = &streamsPool{head: &poolConn{}, busy: &poolConn{}, count: 0, idle: 0} - p.conns[addr] = sp - } - // while we have conns check streams and then return one - // otherwise we'll create a new conn - conn := sp.head.next - for conn != nil { - // check conn state - // https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md - switch conn.GetState() { - case connectivity.Connecting: - conn = conn.next - continue - case connectivity.Shutdown: - next := conn.next - if conn.streams == 0 { - removeConn(conn) - sp.idle-- - } - conn = next - continue - case connectivity.TransientFailure: - next := conn.next - if conn.streams == 0 { - removeConn(conn) - conn.ClientConn.Close() - sp.idle-- - } - conn = next - continue - case connectivity.Ready: - case connectivity.Idle: - } - // a old conn - if now-conn.created > p.ttl { - next := conn.next - if conn.streams == 0 { - removeConn(conn) - conn.ClientConn.Close() - sp.idle-- - } - conn = next - continue - } - // a busy conn - if conn.streams >= p.maxStreams { - next := conn.next - removeConn(conn) - addConnAfter(conn, sp.busy) - conn = next - continue - } - // a idle conn - if conn.streams == 0 { - sp.idle-- - } - // a good conn - conn.streams++ - p.Unlock() - return conn, nil - } - p.Unlock() - - // create new conn - cc, err := grpc.Dial(addr, opts...) - if err != nil { - return nil, err - } - conn = &poolConn{cc, nil, addr, p, sp, 1, time.Now().Unix(), nil, nil, false} - - // add conn to streams pool - p.Lock() - if sp.count < p.size { - addConnAfter(conn, sp.head) - } - p.Unlock() - - return conn, nil -} - -func (p *pool) release(addr string, conn *poolConn, err error) { - p.Lock() - p, sp, created := conn.pool, conn.sp, conn.created - // try to add conn - if !conn.in && sp.count < p.size { - addConnAfter(conn, sp.head) - } - if !conn.in { - p.Unlock() - conn.ClientConn.Close() - return - } - // a busy conn - if conn.streams >= p.maxStreams { - removeConn(conn) - addConnAfter(conn, sp.head) - } - conn.streams-- - // if streams == 0, we can do something - if conn.streams == 0 { - // 1. it has errored - // 2. too many idle conn or - // 3. conn is too old - now := time.Now().Unix() - if err != nil || sp.idle >= p.maxIdle || now-created > p.ttl { - removeConn(conn) - p.Unlock() - conn.ClientConn.Close() - return - } - sp.idle++ - } - p.Unlock() - return -} - -func (conn *poolConn) Close() { - conn.pool.release(conn.addr, conn, conn.err) -} - -func removeConn(conn *poolConn) { - if conn.pre != nil { - conn.pre.next = conn.next - } - if conn.next != nil { - conn.next.pre = conn.pre - } - conn.pre = nil - conn.next = nil - conn.in = false - conn.sp.count-- - return -} - -func addConnAfter(conn *poolConn, after *poolConn) { - conn.next = after.next - conn.pre = after - if after.next != nil { - after.next.pre = conn - } - after.next = conn - conn.in = true - conn.sp.count++ - return -} diff --git a/service/flags.go b/service/flags.go index 62e8f58..b24889b 100644 --- a/service/flags.go +++ b/service/flags.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - env "github.com/caitlinelfring/go-env-default" + "github.com/caitlinelfring/go-env-default" "github.com/spf13/pflag" ) @@ -17,6 +17,10 @@ const ( caCert = "ca-cert" serverCert = "server-cert" serverKey = "server-key" + + clientCACert = "client-ca-cert" + clientCert = "client-cert" + clientKey = "client-key" ) var u = strings.ToUpper @@ -37,6 +41,9 @@ func NewFlagSet() (*pflag.FlagSet, Option) { flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(caCert)) flags.StringVar(&optCert, serverCert, "", "Path to Server certificate"+flagEnv(serverCert)) flags.StringVar(&optKey, serverKey, "", "Path to Server key"+flagEnv(serverKey)) + flags.StringVar(&optCACert, clientCACert, "", "Path to Root CA certificate"+flagEnv(clientCACert)) + flags.StringVar(&optCert, clientCert, "", "Path to Client certificate"+flagEnv(clientCert)) + flags.StringVar(&optKey, clientKey, "", "Path to Client key"+flagEnv(clientKey)) return flags, func(o *options) { o.address = optAddress o.secure = !optInsecure @@ -44,6 +51,9 @@ func NewFlagSet() (*pflag.FlagSet, Option) { o.caCert = optCACert o.cert = optCert o.key = optKey + o.clientCACert = optCACert + o.clientCert = optCert + o.clientKey = optKey } } diff --git a/service/options.go b/service/options.go index 532e9ad..4df2cd0 100644 --- a/service/options.go +++ b/service/options.go @@ -38,6 +38,9 @@ type Options interface { CACert() string Cert() string Key() string + ClientCACert() string + ClientCert() string + ClientKey() string TLSConfig() *tls.Config Secure() bool @@ -180,6 +183,24 @@ func WithKey(path string) Option { } } +func WithClientCACert(path string) Option { + return func(o *options) { + o.clientCACert = path + } +} + +func WithClientCert(path string) Option { + return func(o *options) { + o.clientCert = path + } +} + +func WithClientKey(path string) Option { + return func(o *options) { + o.clientKey = path + } +} + func WithTLSConfig(conf *tls.Config) Option { return func(o *options) { o.tlsConfig = conf @@ -360,11 +381,14 @@ type options struct { reflection bool health bool - secure bool - caCert string - cert string - key string - tlsConfig *tls.Config + secure bool + caCert string + cert string + key string + clientCACert string + clientCert string + clientKey string + tlsConfig *tls.Config transport transport.Transport registry registry.Registry @@ -442,6 +466,18 @@ func (o *options) Key() string { return o.key } +func (o *options) ClientCACert() string { + return o.clientCACert +} + +func (o *options) ClientCert() string { + return o.clientCert +} + +func (o *options) ClientKey() string { + return o.clientKey +} + func (o *options) TLSConfig() *tls.Config { return o.tlsConfig } @@ -577,9 +613,32 @@ func (o *options) parseTLSConfig() error { Certificates: []tls.Certificate{cert}, RootCAs: caCertPool, } + if !o.hasClientTLSConfig() { + return nil + } + clientCACert, err := os.ReadFile(o.clientCACert) + if err != nil { + return err + } + clientCACertPool := x509.NewCertPool() + ok = clientCACertPool.AppendCertsFromPEM(clientCACert) + if !ok { + return fmt.Errorf("failed to load Client CA Cert from %s", o.clientCACert) + } + clientCert, err := tls.LoadX509KeyPair(o.clientCert, o.clientKey) + if err != nil { + return err + } + o.tlsConfig.ClientCAs = clientCACertPool + o.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + o.tlsConfig.Certificates = append(o.tlsConfig.Certificates, clientCert) return nil } func (o *options) hasTLSConfig() bool { - return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil + return o.caCert != "" && o.cert != "" && o.tlsConfig == nil +} + +func (o *options) hasClientTLSConfig() bool { + return o.clientCACert != "" && o.clientCert != "" && o.clientKey != "" }