remove client pool and add tls client auth support

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2024-10-17 18:09:58 +02:00
parent 3a3d77169c
commit abe69f1c80
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
6 changed files with 179 additions and 282 deletions

View File

@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"strings"
@ -28,9 +27,8 @@ func New(opts ...Option) (Client, error) {
c.opts.registry = noop.New()
}
resolver.Register(c.opts.registry.ResolverBuilder())
c.pool = newPool(DefaultPoolSize, DefaultPoolTTL, DefaultPoolMaxIdle, DefaultPoolMaxStreams)
if c.opts.tlsConfig == nil && c.opts.Secure() {
c.opts.tlsConfig = &tls.Config{InsecureSkipVerify: true}
if err := c.opts.parseTLSConfig(); err != nil {
return nil, err
}
if c.opts.tlsConfig != nil {
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig)))
@ -59,27 +57,24 @@ func New(opts ...Option) (Client, error) {
if c.opts.version != "" && c.opts.addr == "" {
c.addr = c.addr + ":" + strings.TrimSpace(c.opts.version)
}
cc, err := grpc.Dial(c.addr, c.opts.dialOptions...)
if err != nil {
return nil, err
}
c.cc = cc
return c, nil
}
type client struct {
addr string
pool *pool
opts *options
cc *grpc.ClientConn
}
func (c *client) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
pc, err := c.pool.getConn(c.addr, c.opts.dialOptions...)
if err != nil {
return err
}
return pc.Invoke(ctx, method, args, reply, opts...)
return c.cc.Invoke(ctx, method, args, reply, opts...)
}
func (c *client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
pc, err := c.pool.getConn(c.addr, c.opts.dialOptions...)
if err != nil {
return nil, err
}
return pc.NewStream(ctx, desc, method, opts...)
return c.cc.NewStream(ctx, desc, method, opts...)
}

View File

@ -14,27 +14,29 @@ func NewFlagSet() (*pflag.FlagSet, Option) {
const (
addr = "address"
secure = "secure"
// caCert = "ca-cert"
// clientCert = "client-cert"
// clientKey = "client-key"
caCert = "ca-cert"
clientCert = "client-cert"
clientKey = "client-key"
)
var (
optAddress string
optSecure bool
// optCACert string
// optCert string
// optKey string
optCACert string
optCert string
optKey string
)
flags := pflag.NewFlagSet("gRPC", pflag.ContinueOnError)
flags.StringVar(&optAddress, addr, env.GetDefault(u(addr), "0.0.0.0:0"), "Bind address for the server. 127.0.0.1:9090"+flagEnv(addr))
flags.BoolVar(&optSecure, secure, env.GetBoolDefault(u(secure), true), "Generate self signed certificate if none provided"+flagEnv(secure))
// flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(optCACert))
// flags.StringVar(&optCert, clientCert, "", "Path to Server certificate"+flagEnv(clientCert))
// flags.StringVar(&optKey, clientKey, "", "Path to Server key"+flagEnv(clientKey))
flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(optCACert))
flags.StringVar(&optCert, clientCert, "", "Path to Server certificate"+flagEnv(clientCert))
flags.StringVar(&optKey, clientKey, "", "Path to Server key"+flagEnv(clientKey))
return flags, func(o *options) {
o.addr = optAddress
o.secure = optSecure
o.caCert = optCACert
o.cert = optCert
o.key = optKey
}
}

View File

@ -2,6 +2,9 @@ package client
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"google.golang.org/grpc"
@ -15,6 +18,9 @@ type Options interface {
Address() string
Secure() bool
Registry() registry.Registry
CA() string
Cert() string
Key() string
TLSConfig() *tls.Config
DialOptions() []grpc.DialOption
UnaryInterceptors() []grpc.UnaryClientInterceptor
@ -47,6 +53,24 @@ func WithAddress(address string) Option {
}
}
func WithCA(ca string) Option {
return func(o *options) {
o.caCert = ca
}
}
func WithCert(cert string) Option {
return func(o *options) {
o.cert = cert
}
}
func WithKey(key string) Option {
return func(o *options) {
o.key = key
}
}
func WithTLSConfig(conf *tls.Config) Option {
return func(o *options) {
o.tlsConfig = conf
@ -91,6 +115,10 @@ type options struct {
name string
version string
addr string
caCert string
cert string
key string
tlsConfig *tls.Config
secure bool
dialOptions []grpc.DialOption
@ -115,6 +143,18 @@ func (o *options) Registry() registry.Registry {
return o.registry
}
func (o *options) CA() string {
return o.caCert
}
func (o *options) Cert() string {
return o.cert
}
func (o *options) Key() string {
return o.key
}
func (o *options) TLSConfig() *tls.Config {
return o.tlsConfig
}
@ -134,3 +174,38 @@ func (o *options) UnaryInterceptors() []grpc.UnaryClientInterceptor {
func (o *options) StreamInterceptors() []grpc.StreamClientInterceptor {
return o.streamInterceptors
}
func (o *options) hasTLSConfig() bool {
return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil
}
func (o *options) parseTLSConfig() error {
if o.tlsConfig != nil {
return nil
}
if !o.hasTLSConfig() {
if !o.secure {
return nil
}
o.tlsConfig = &tls.Config{InsecureSkipVerify: true}
return nil
}
caCert, err := os.ReadFile(o.caCert)
if err != nil {
return err
}
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(caCert)
if !ok {
return fmt.Errorf("failed to load CA Cert from %s", o.caCert)
}
cert, err := tls.LoadX509KeyPair(o.cert, o.key)
if err != nil {
return err
}
o.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
}
return nil
}

View File

@ -1,244 +0,0 @@
/*
Taken from the https://github.com/micro/go-micro/client/grpc
*/
package client
import (
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
)
var (
// DefaultPoolSize sets the connection pool size
DefaultPoolSize = 100
// DefaultPoolTTL sets the connection pool ttl
DefaultPoolTTL = time.Minute
// DefaultPoolMaxStreams maximum streams on a connectioin
// (20)
DefaultPoolMaxStreams = 20
// DefaultPoolMaxIdle maximum idle conns of a pool
// (50)
DefaultPoolMaxIdle = 50
// DefaultMaxRecvMsgSize maximum message that client can receive
// (4 MB).
DefaultMaxRecvMsgSize = 1024 * 1024 * 4
// DefaultMaxSendMsgSize maximum message that client can send
// (4 MB).
DefaultMaxSendMsgSize = 1024 * 1024 * 4
)
type pool struct {
size int
ttl int64
// max streams on a *poolConn
maxStreams int
// max idle conns
maxIdle int
sync.Mutex
conns map[string]*streamsPool
}
type streamsPool struct {
// head of list
head *poolConn
// busy conns list
busy *poolConn
// the siza of list
count int
// idle conn
idle int
}
type poolConn struct {
// grpc conn
*grpc.ClientConn
err error
addr string
// pool and streams pool
pool *pool
sp *streamsPool
streams int
created int64
// list
pre *poolConn
next *poolConn
in bool
}
func newPool(size int, ttl time.Duration, idle int, ms int) *pool {
if ms <= 0 {
ms = 1
}
if idle < 0 {
idle = 0
}
return &pool{
size: size,
ttl: int64(ttl.Seconds()),
maxStreams: ms,
maxIdle: idle,
conns: make(map[string]*streamsPool),
}
}
func (p *pool) getConn(addr string, opts ...grpc.DialOption) (*poolConn, error) {
now := time.Now().Unix()
p.Lock()
sp, ok := p.conns[addr]
if !ok {
sp = &streamsPool{head: &poolConn{}, busy: &poolConn{}, count: 0, idle: 0}
p.conns[addr] = sp
}
// while we have conns check streams and then return one
// otherwise we'll create a new conn
conn := sp.head.next
for conn != nil {
// check conn state
// https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md
switch conn.GetState() {
case connectivity.Connecting:
conn = conn.next
continue
case connectivity.Shutdown:
next := conn.next
if conn.streams == 0 {
removeConn(conn)
sp.idle--
}
conn = next
continue
case connectivity.TransientFailure:
next := conn.next
if conn.streams == 0 {
removeConn(conn)
conn.ClientConn.Close()
sp.idle--
}
conn = next
continue
case connectivity.Ready:
case connectivity.Idle:
}
// a old conn
if now-conn.created > p.ttl {
next := conn.next
if conn.streams == 0 {
removeConn(conn)
conn.ClientConn.Close()
sp.idle--
}
conn = next
continue
}
// a busy conn
if conn.streams >= p.maxStreams {
next := conn.next
removeConn(conn)
addConnAfter(conn, sp.busy)
conn = next
continue
}
// a idle conn
if conn.streams == 0 {
sp.idle--
}
// a good conn
conn.streams++
p.Unlock()
return conn, nil
}
p.Unlock()
// create new conn
cc, err := grpc.Dial(addr, opts...)
if err != nil {
return nil, err
}
conn = &poolConn{cc, nil, addr, p, sp, 1, time.Now().Unix(), nil, nil, false}
// add conn to streams pool
p.Lock()
if sp.count < p.size {
addConnAfter(conn, sp.head)
}
p.Unlock()
return conn, nil
}
func (p *pool) release(addr string, conn *poolConn, err error) {
p.Lock()
p, sp, created := conn.pool, conn.sp, conn.created
// try to add conn
if !conn.in && sp.count < p.size {
addConnAfter(conn, sp.head)
}
if !conn.in {
p.Unlock()
conn.ClientConn.Close()
return
}
// a busy conn
if conn.streams >= p.maxStreams {
removeConn(conn)
addConnAfter(conn, sp.head)
}
conn.streams--
// if streams == 0, we can do something
if conn.streams == 0 {
// 1. it has errored
// 2. too many idle conn or
// 3. conn is too old
now := time.Now().Unix()
if err != nil || sp.idle >= p.maxIdle || now-created > p.ttl {
removeConn(conn)
p.Unlock()
conn.ClientConn.Close()
return
}
sp.idle++
}
p.Unlock()
return
}
func (conn *poolConn) Close() {
conn.pool.release(conn.addr, conn, conn.err)
}
func removeConn(conn *poolConn) {
if conn.pre != nil {
conn.pre.next = conn.next
}
if conn.next != nil {
conn.next.pre = conn.pre
}
conn.pre = nil
conn.next = nil
conn.in = false
conn.sp.count--
return
}
func addConnAfter(conn *poolConn, after *poolConn) {
conn.next = after.next
conn.pre = after
if after.next != nil {
after.next.pre = conn
}
after.next = conn
conn.in = true
conn.sp.count++
return
}

View File

@ -4,7 +4,7 @@ import (
"fmt"
"strings"
env "github.com/caitlinelfring/go-env-default"
"github.com/caitlinelfring/go-env-default"
"github.com/spf13/pflag"
)
@ -17,6 +17,10 @@ const (
caCert = "ca-cert"
serverCert = "server-cert"
serverKey = "server-key"
clientCACert = "client-ca-cert"
clientCert = "client-cert"
clientKey = "client-key"
)
var u = strings.ToUpper
@ -37,6 +41,9 @@ func NewFlagSet() (*pflag.FlagSet, Option) {
flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(caCert))
flags.StringVar(&optCert, serverCert, "", "Path to Server certificate"+flagEnv(serverCert))
flags.StringVar(&optKey, serverKey, "", "Path to Server key"+flagEnv(serverKey))
flags.StringVar(&optCACert, clientCACert, "", "Path to Root CA certificate"+flagEnv(clientCACert))
flags.StringVar(&optCert, clientCert, "", "Path to Client certificate"+flagEnv(clientCert))
flags.StringVar(&optKey, clientKey, "", "Path to Client key"+flagEnv(clientKey))
return flags, func(o *options) {
o.address = optAddress
o.secure = !optInsecure
@ -44,6 +51,9 @@ func NewFlagSet() (*pflag.FlagSet, Option) {
o.caCert = optCACert
o.cert = optCert
o.key = optKey
o.clientCACert = optCACert
o.clientCert = optCert
o.clientKey = optKey
}
}

View File

@ -38,6 +38,9 @@ type Options interface {
CACert() string
Cert() string
Key() string
ClientCACert() string
ClientCert() string
ClientKey() string
TLSConfig() *tls.Config
Secure() bool
@ -180,6 +183,24 @@ func WithKey(path string) Option {
}
}
func WithClientCACert(path string) Option {
return func(o *options) {
o.clientCACert = path
}
}
func WithClientCert(path string) Option {
return func(o *options) {
o.clientCert = path
}
}
func WithClientKey(path string) Option {
return func(o *options) {
o.clientKey = path
}
}
func WithTLSConfig(conf *tls.Config) Option {
return func(o *options) {
o.tlsConfig = conf
@ -364,6 +385,9 @@ type options struct {
caCert string
cert string
key string
clientCACert string
clientCert string
clientKey string
tlsConfig *tls.Config
transport transport.Transport
@ -442,6 +466,18 @@ func (o *options) Key() string {
return o.key
}
func (o *options) ClientCACert() string {
return o.clientCACert
}
func (o *options) ClientCert() string {
return o.clientCert
}
func (o *options) ClientKey() string {
return o.clientKey
}
func (o *options) TLSConfig() *tls.Config {
return o.tlsConfig
}
@ -577,9 +613,32 @@ func (o *options) parseTLSConfig() error {
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
}
if !o.hasClientTLSConfig() {
return nil
}
clientCACert, err := os.ReadFile(o.clientCACert)
if err != nil {
return err
}
clientCACertPool := x509.NewCertPool()
ok = clientCACertPool.AppendCertsFromPEM(clientCACert)
if !ok {
return fmt.Errorf("failed to load Client CA Cert from %s", o.clientCACert)
}
clientCert, err := tls.LoadX509KeyPair(o.clientCert, o.clientKey)
if err != nil {
return err
}
o.tlsConfig.ClientCAs = clientCACertPool
o.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
o.tlsConfig.Certificates = append(o.tlsConfig.Certificates, clientCert)
return nil
}
func (o *options) hasTLSConfig() bool {
return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil
return o.caCert != "" && o.cert != "" && o.tlsConfig == nil
}
func (o *options) hasClientTLSConfig() bool {
return o.clientCACert != "" && o.clientCert != "" && o.clientKey != ""
}