mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-21 18:36:25 +00:00
remove client pool and add tls client auth support
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
parent
3a3d77169c
commit
abe69f1c80
@ -2,7 +2,6 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -28,9 +27,8 @@ func New(opts ...Option) (Client, error) {
|
|||||||
c.opts.registry = noop.New()
|
c.opts.registry = noop.New()
|
||||||
}
|
}
|
||||||
resolver.Register(c.opts.registry.ResolverBuilder())
|
resolver.Register(c.opts.registry.ResolverBuilder())
|
||||||
c.pool = newPool(DefaultPoolSize, DefaultPoolTTL, DefaultPoolMaxIdle, DefaultPoolMaxStreams)
|
if err := c.opts.parseTLSConfig(); err != nil {
|
||||||
if c.opts.tlsConfig == nil && c.opts.Secure() {
|
return nil, err
|
||||||
c.opts.tlsConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
}
|
}
|
||||||
if c.opts.tlsConfig != nil {
|
if c.opts.tlsConfig != nil {
|
||||||
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig)))
|
c.opts.dialOptions = append(c.opts.dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(c.opts.tlsConfig)))
|
||||||
@ -59,27 +57,24 @@ func New(opts ...Option) (Client, error) {
|
|||||||
if c.opts.version != "" && c.opts.addr == "" {
|
if c.opts.version != "" && c.opts.addr == "" {
|
||||||
c.addr = c.addr + ":" + strings.TrimSpace(c.opts.version)
|
c.addr = c.addr + ":" + strings.TrimSpace(c.opts.version)
|
||||||
}
|
}
|
||||||
|
cc, err := grpc.Dial(c.addr, c.opts.dialOptions...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.cc = cc
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
addr string
|
addr string
|
||||||
pool *pool
|
|
||||||
opts *options
|
opts *options
|
||||||
|
cc *grpc.ClientConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
|
func (c *client) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
|
||||||
pc, err := c.pool.getConn(c.addr, c.opts.dialOptions...)
|
return c.cc.Invoke(ctx, method, args, reply, opts...)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return pc.Invoke(ctx, method, args, reply, opts...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
func (c *client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||||
pc, err := c.pool.getConn(c.addr, c.opts.dialOptions...)
|
return c.cc.NewStream(ctx, desc, method, opts...)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return pc.NewStream(ctx, desc, method, opts...)
|
|
||||||
}
|
}
|
||||||
|
@ -14,27 +14,29 @@ func NewFlagSet() (*pflag.FlagSet, Option) {
|
|||||||
const (
|
const (
|
||||||
addr = "address"
|
addr = "address"
|
||||||
secure = "secure"
|
secure = "secure"
|
||||||
// caCert = "ca-cert"
|
caCert = "ca-cert"
|
||||||
// clientCert = "client-cert"
|
clientCert = "client-cert"
|
||||||
// clientKey = "client-key"
|
clientKey = "client-key"
|
||||||
)
|
)
|
||||||
var (
|
var (
|
||||||
optAddress string
|
optAddress string
|
||||||
optSecure bool
|
optSecure bool
|
||||||
// optCACert string
|
optCACert string
|
||||||
// optCert string
|
optCert string
|
||||||
// optKey string
|
optKey string
|
||||||
)
|
)
|
||||||
flags := pflag.NewFlagSet("gRPC", pflag.ContinueOnError)
|
flags := pflag.NewFlagSet("gRPC", pflag.ContinueOnError)
|
||||||
flags.StringVar(&optAddress, addr, env.GetDefault(u(addr), "0.0.0.0:0"), "Bind address for the server. 127.0.0.1:9090"+flagEnv(addr))
|
flags.StringVar(&optAddress, addr, env.GetDefault(u(addr), "0.0.0.0:0"), "Bind address for the server. 127.0.0.1:9090"+flagEnv(addr))
|
||||||
flags.BoolVar(&optSecure, secure, env.GetBoolDefault(u(secure), true), "Generate self signed certificate if none provided"+flagEnv(secure))
|
flags.BoolVar(&optSecure, secure, env.GetBoolDefault(u(secure), true), "Generate self signed certificate if none provided"+flagEnv(secure))
|
||||||
// flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(optCACert))
|
flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(optCACert))
|
||||||
// flags.StringVar(&optCert, clientCert, "", "Path to Server certificate"+flagEnv(clientCert))
|
flags.StringVar(&optCert, clientCert, "", "Path to Server certificate"+flagEnv(clientCert))
|
||||||
// flags.StringVar(&optKey, clientKey, "", "Path to Server key"+flagEnv(clientKey))
|
flags.StringVar(&optKey, clientKey, "", "Path to Server key"+flagEnv(clientKey))
|
||||||
return flags, func(o *options) {
|
return flags, func(o *options) {
|
||||||
o.addr = optAddress
|
o.addr = optAddress
|
||||||
o.secure = optSecure
|
o.secure = optSecure
|
||||||
|
o.caCert = optCACert
|
||||||
|
o.cert = optCert
|
||||||
|
o.key = optKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,9 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
@ -15,6 +18,9 @@ type Options interface {
|
|||||||
Address() string
|
Address() string
|
||||||
Secure() bool
|
Secure() bool
|
||||||
Registry() registry.Registry
|
Registry() registry.Registry
|
||||||
|
CA() string
|
||||||
|
Cert() string
|
||||||
|
Key() string
|
||||||
TLSConfig() *tls.Config
|
TLSConfig() *tls.Config
|
||||||
DialOptions() []grpc.DialOption
|
DialOptions() []grpc.DialOption
|
||||||
UnaryInterceptors() []grpc.UnaryClientInterceptor
|
UnaryInterceptors() []grpc.UnaryClientInterceptor
|
||||||
@ -47,6 +53,24 @@ func WithAddress(address string) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithCA(ca string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.caCert = ca
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCert(cert string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.cert = cert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithKey(key string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.key = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithTLSConfig(conf *tls.Config) Option {
|
func WithTLSConfig(conf *tls.Config) Option {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.tlsConfig = conf
|
o.tlsConfig = conf
|
||||||
@ -91,6 +115,10 @@ type options struct {
|
|||||||
name string
|
name string
|
||||||
version string
|
version string
|
||||||
addr string
|
addr string
|
||||||
|
|
||||||
|
caCert string
|
||||||
|
cert string
|
||||||
|
key string
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
secure bool
|
secure bool
|
||||||
dialOptions []grpc.DialOption
|
dialOptions []grpc.DialOption
|
||||||
@ -115,6 +143,18 @@ func (o *options) Registry() registry.Registry {
|
|||||||
return o.registry
|
return o.registry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *options) CA() string {
|
||||||
|
return o.caCert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) Cert() string {
|
||||||
|
return o.cert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) Key() string {
|
||||||
|
return o.key
|
||||||
|
}
|
||||||
|
|
||||||
func (o *options) TLSConfig() *tls.Config {
|
func (o *options) TLSConfig() *tls.Config {
|
||||||
return o.tlsConfig
|
return o.tlsConfig
|
||||||
}
|
}
|
||||||
@ -134,3 +174,38 @@ func (o *options) UnaryInterceptors() []grpc.UnaryClientInterceptor {
|
|||||||
func (o *options) StreamInterceptors() []grpc.StreamClientInterceptor {
|
func (o *options) StreamInterceptors() []grpc.StreamClientInterceptor {
|
||||||
return o.streamInterceptors
|
return o.streamInterceptors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *options) hasTLSConfig() bool {
|
||||||
|
return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) parseTLSConfig() error {
|
||||||
|
if o.tlsConfig != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !o.hasTLSConfig() {
|
||||||
|
if !o.secure {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
o.tlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
caCert, err := os.ReadFile(o.caCert)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
ok := caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("failed to load CA Cert from %s", o.caCert)
|
||||||
|
}
|
||||||
|
cert, err := tls.LoadX509KeyPair(o.cert, o.key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.tlsConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
244
client/pool.go
244
client/pool.go
@ -1,244 +0,0 @@
|
|||||||
/*
|
|
||||||
Taken from the https://github.com/micro/go-micro/client/grpc
|
|
||||||
*/
|
|
||||||
|
|
||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// DefaultPoolSize sets the connection pool size
|
|
||||||
DefaultPoolSize = 100
|
|
||||||
// DefaultPoolTTL sets the connection pool ttl
|
|
||||||
DefaultPoolTTL = time.Minute
|
|
||||||
// DefaultPoolMaxStreams maximum streams on a connectioin
|
|
||||||
// (20)
|
|
||||||
DefaultPoolMaxStreams = 20
|
|
||||||
|
|
||||||
// DefaultPoolMaxIdle maximum idle conns of a pool
|
|
||||||
// (50)
|
|
||||||
DefaultPoolMaxIdle = 50
|
|
||||||
|
|
||||||
// DefaultMaxRecvMsgSize maximum message that client can receive
|
|
||||||
// (4 MB).
|
|
||||||
DefaultMaxRecvMsgSize = 1024 * 1024 * 4
|
|
||||||
|
|
||||||
// DefaultMaxSendMsgSize maximum message that client can send
|
|
||||||
// (4 MB).
|
|
||||||
DefaultMaxSendMsgSize = 1024 * 1024 * 4
|
|
||||||
)
|
|
||||||
|
|
||||||
type pool struct {
|
|
||||||
size int
|
|
||||||
ttl int64
|
|
||||||
|
|
||||||
// max streams on a *poolConn
|
|
||||||
maxStreams int
|
|
||||||
// max idle conns
|
|
||||||
maxIdle int
|
|
||||||
|
|
||||||
sync.Mutex
|
|
||||||
conns map[string]*streamsPool
|
|
||||||
}
|
|
||||||
|
|
||||||
type streamsPool struct {
|
|
||||||
// head of list
|
|
||||||
head *poolConn
|
|
||||||
// busy conns list
|
|
||||||
busy *poolConn
|
|
||||||
// the siza of list
|
|
||||||
count int
|
|
||||||
// idle conn
|
|
||||||
idle int
|
|
||||||
}
|
|
||||||
|
|
||||||
type poolConn struct {
|
|
||||||
// grpc conn
|
|
||||||
*grpc.ClientConn
|
|
||||||
err error
|
|
||||||
addr string
|
|
||||||
|
|
||||||
// pool and streams pool
|
|
||||||
pool *pool
|
|
||||||
sp *streamsPool
|
|
||||||
streams int
|
|
||||||
created int64
|
|
||||||
|
|
||||||
// list
|
|
||||||
pre *poolConn
|
|
||||||
next *poolConn
|
|
||||||
in bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPool(size int, ttl time.Duration, idle int, ms int) *pool {
|
|
||||||
if ms <= 0 {
|
|
||||||
ms = 1
|
|
||||||
}
|
|
||||||
if idle < 0 {
|
|
||||||
idle = 0
|
|
||||||
}
|
|
||||||
return &pool{
|
|
||||||
size: size,
|
|
||||||
ttl: int64(ttl.Seconds()),
|
|
||||||
maxStreams: ms,
|
|
||||||
maxIdle: idle,
|
|
||||||
conns: make(map[string]*streamsPool),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pool) getConn(addr string, opts ...grpc.DialOption) (*poolConn, error) {
|
|
||||||
now := time.Now().Unix()
|
|
||||||
p.Lock()
|
|
||||||
sp, ok := p.conns[addr]
|
|
||||||
if !ok {
|
|
||||||
sp = &streamsPool{head: &poolConn{}, busy: &poolConn{}, count: 0, idle: 0}
|
|
||||||
p.conns[addr] = sp
|
|
||||||
}
|
|
||||||
// while we have conns check streams and then return one
|
|
||||||
// otherwise we'll create a new conn
|
|
||||||
conn := sp.head.next
|
|
||||||
for conn != nil {
|
|
||||||
// check conn state
|
|
||||||
// https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md
|
|
||||||
switch conn.GetState() {
|
|
||||||
case connectivity.Connecting:
|
|
||||||
conn = conn.next
|
|
||||||
continue
|
|
||||||
case connectivity.Shutdown:
|
|
||||||
next := conn.next
|
|
||||||
if conn.streams == 0 {
|
|
||||||
removeConn(conn)
|
|
||||||
sp.idle--
|
|
||||||
}
|
|
||||||
conn = next
|
|
||||||
continue
|
|
||||||
case connectivity.TransientFailure:
|
|
||||||
next := conn.next
|
|
||||||
if conn.streams == 0 {
|
|
||||||
removeConn(conn)
|
|
||||||
conn.ClientConn.Close()
|
|
||||||
sp.idle--
|
|
||||||
}
|
|
||||||
conn = next
|
|
||||||
continue
|
|
||||||
case connectivity.Ready:
|
|
||||||
case connectivity.Idle:
|
|
||||||
}
|
|
||||||
// a old conn
|
|
||||||
if now-conn.created > p.ttl {
|
|
||||||
next := conn.next
|
|
||||||
if conn.streams == 0 {
|
|
||||||
removeConn(conn)
|
|
||||||
conn.ClientConn.Close()
|
|
||||||
sp.idle--
|
|
||||||
}
|
|
||||||
conn = next
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// a busy conn
|
|
||||||
if conn.streams >= p.maxStreams {
|
|
||||||
next := conn.next
|
|
||||||
removeConn(conn)
|
|
||||||
addConnAfter(conn, sp.busy)
|
|
||||||
conn = next
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// a idle conn
|
|
||||||
if conn.streams == 0 {
|
|
||||||
sp.idle--
|
|
||||||
}
|
|
||||||
// a good conn
|
|
||||||
conn.streams++
|
|
||||||
p.Unlock()
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
p.Unlock()
|
|
||||||
|
|
||||||
// create new conn
|
|
||||||
cc, err := grpc.Dial(addr, opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn = &poolConn{cc, nil, addr, p, sp, 1, time.Now().Unix(), nil, nil, false}
|
|
||||||
|
|
||||||
// add conn to streams pool
|
|
||||||
p.Lock()
|
|
||||||
if sp.count < p.size {
|
|
||||||
addConnAfter(conn, sp.head)
|
|
||||||
}
|
|
||||||
p.Unlock()
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pool) release(addr string, conn *poolConn, err error) {
|
|
||||||
p.Lock()
|
|
||||||
p, sp, created := conn.pool, conn.sp, conn.created
|
|
||||||
// try to add conn
|
|
||||||
if !conn.in && sp.count < p.size {
|
|
||||||
addConnAfter(conn, sp.head)
|
|
||||||
}
|
|
||||||
if !conn.in {
|
|
||||||
p.Unlock()
|
|
||||||
conn.ClientConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// a busy conn
|
|
||||||
if conn.streams >= p.maxStreams {
|
|
||||||
removeConn(conn)
|
|
||||||
addConnAfter(conn, sp.head)
|
|
||||||
}
|
|
||||||
conn.streams--
|
|
||||||
// if streams == 0, we can do something
|
|
||||||
if conn.streams == 0 {
|
|
||||||
// 1. it has errored
|
|
||||||
// 2. too many idle conn or
|
|
||||||
// 3. conn is too old
|
|
||||||
now := time.Now().Unix()
|
|
||||||
if err != nil || sp.idle >= p.maxIdle || now-created > p.ttl {
|
|
||||||
removeConn(conn)
|
|
||||||
p.Unlock()
|
|
||||||
conn.ClientConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sp.idle++
|
|
||||||
}
|
|
||||||
p.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *poolConn) Close() {
|
|
||||||
conn.pool.release(conn.addr, conn, conn.err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeConn(conn *poolConn) {
|
|
||||||
if conn.pre != nil {
|
|
||||||
conn.pre.next = conn.next
|
|
||||||
}
|
|
||||||
if conn.next != nil {
|
|
||||||
conn.next.pre = conn.pre
|
|
||||||
}
|
|
||||||
conn.pre = nil
|
|
||||||
conn.next = nil
|
|
||||||
conn.in = false
|
|
||||||
conn.sp.count--
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func addConnAfter(conn *poolConn, after *poolConn) {
|
|
||||||
conn.next = after.next
|
|
||||||
conn.pre = after
|
|
||||||
if after.next != nil {
|
|
||||||
after.next.pre = conn
|
|
||||||
}
|
|
||||||
after.next = conn
|
|
||||||
conn.in = true
|
|
||||||
conn.sp.count++
|
|
||||||
return
|
|
||||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
env "github.com/caitlinelfring/go-env-default"
|
"github.com/caitlinelfring/go-env-default"
|
||||||
"github.com/spf13/pflag"
|
"github.com/spf13/pflag"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,6 +17,10 @@ const (
|
|||||||
caCert = "ca-cert"
|
caCert = "ca-cert"
|
||||||
serverCert = "server-cert"
|
serverCert = "server-cert"
|
||||||
serverKey = "server-key"
|
serverKey = "server-key"
|
||||||
|
|
||||||
|
clientCACert = "client-ca-cert"
|
||||||
|
clientCert = "client-cert"
|
||||||
|
clientKey = "client-key"
|
||||||
)
|
)
|
||||||
|
|
||||||
var u = strings.ToUpper
|
var u = strings.ToUpper
|
||||||
@ -37,6 +41,9 @@ func NewFlagSet() (*pflag.FlagSet, Option) {
|
|||||||
flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(caCert))
|
flags.StringVar(&optCACert, caCert, "", "Path to Root CA certificate"+flagEnv(caCert))
|
||||||
flags.StringVar(&optCert, serverCert, "", "Path to Server certificate"+flagEnv(serverCert))
|
flags.StringVar(&optCert, serverCert, "", "Path to Server certificate"+flagEnv(serverCert))
|
||||||
flags.StringVar(&optKey, serverKey, "", "Path to Server key"+flagEnv(serverKey))
|
flags.StringVar(&optKey, serverKey, "", "Path to Server key"+flagEnv(serverKey))
|
||||||
|
flags.StringVar(&optCACert, clientCACert, "", "Path to Root CA certificate"+flagEnv(clientCACert))
|
||||||
|
flags.StringVar(&optCert, clientCert, "", "Path to Client certificate"+flagEnv(clientCert))
|
||||||
|
flags.StringVar(&optKey, clientKey, "", "Path to Client key"+flagEnv(clientKey))
|
||||||
return flags, func(o *options) {
|
return flags, func(o *options) {
|
||||||
o.address = optAddress
|
o.address = optAddress
|
||||||
o.secure = !optInsecure
|
o.secure = !optInsecure
|
||||||
@ -44,6 +51,9 @@ func NewFlagSet() (*pflag.FlagSet, Option) {
|
|||||||
o.caCert = optCACert
|
o.caCert = optCACert
|
||||||
o.cert = optCert
|
o.cert = optCert
|
||||||
o.key = optKey
|
o.key = optKey
|
||||||
|
o.clientCACert = optCACert
|
||||||
|
o.clientCert = optCert
|
||||||
|
o.clientKey = optKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +38,9 @@ type Options interface {
|
|||||||
CACert() string
|
CACert() string
|
||||||
Cert() string
|
Cert() string
|
||||||
Key() string
|
Key() string
|
||||||
|
ClientCACert() string
|
||||||
|
ClientCert() string
|
||||||
|
ClientKey() string
|
||||||
TLSConfig() *tls.Config
|
TLSConfig() *tls.Config
|
||||||
Secure() bool
|
Secure() bool
|
||||||
|
|
||||||
@ -180,6 +183,24 @@ func WithKey(path string) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithClientCACert(path string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.clientCACert = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithClientCert(path string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.clientCert = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithClientKey(path string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.clientKey = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithTLSConfig(conf *tls.Config) Option {
|
func WithTLSConfig(conf *tls.Config) Option {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.tlsConfig = conf
|
o.tlsConfig = conf
|
||||||
@ -364,6 +385,9 @@ type options struct {
|
|||||||
caCert string
|
caCert string
|
||||||
cert string
|
cert string
|
||||||
key string
|
key string
|
||||||
|
clientCACert string
|
||||||
|
clientCert string
|
||||||
|
clientKey string
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
transport transport.Transport
|
transport transport.Transport
|
||||||
@ -442,6 +466,18 @@ func (o *options) Key() string {
|
|||||||
return o.key
|
return o.key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *options) ClientCACert() string {
|
||||||
|
return o.clientCACert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) ClientCert() string {
|
||||||
|
return o.clientCert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) ClientKey() string {
|
||||||
|
return o.clientKey
|
||||||
|
}
|
||||||
|
|
||||||
func (o *options) TLSConfig() *tls.Config {
|
func (o *options) TLSConfig() *tls.Config {
|
||||||
return o.tlsConfig
|
return o.tlsConfig
|
||||||
}
|
}
|
||||||
@ -577,9 +613,32 @@ func (o *options) parseTLSConfig() error {
|
|||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
RootCAs: caCertPool,
|
RootCAs: caCertPool,
|
||||||
}
|
}
|
||||||
|
if !o.hasClientTLSConfig() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clientCACert, err := os.ReadFile(o.clientCACert)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
clientCACertPool := x509.NewCertPool()
|
||||||
|
ok = clientCACertPool.AppendCertsFromPEM(clientCACert)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("failed to load Client CA Cert from %s", o.clientCACert)
|
||||||
|
}
|
||||||
|
clientCert, err := tls.LoadX509KeyPair(o.clientCert, o.clientKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.tlsConfig.ClientCAs = clientCACertPool
|
||||||
|
o.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
o.tlsConfig.Certificates = append(o.tlsConfig.Certificates, clientCert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *options) hasTLSConfig() bool {
|
func (o *options) hasTLSConfig() bool {
|
||||||
return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil
|
return o.caCert != "" && o.cert != "" && o.tlsConfig == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) hasClientTLSConfig() bool {
|
||||||
|
return o.clientCACert != "" && o.clientCert != "" && o.clientKey != ""
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user