mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-12-03 16:18:25 +00:00
certs: reload on both key and cert changes
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
parent
efaa4bd14f
commit
198bd2bd59
@ -14,6 +14,7 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -81,30 +82,60 @@ func New(host ...string) (tls.Certificate, error) {
|
||||
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
|
||||
}
|
||||
|
||||
func TLSConfig(ctx context.Context, cert, key string) (*tls.Config, error) {
|
||||
c, err := Load(ctx, cert, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tls.Config{
|
||||
GetCertificate: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
|
||||
f, err := file.NewConfig(cert)
|
||||
c, err := file.NewConfig(cert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load cert: %v", err)
|
||||
}
|
||||
crt, err := load(f, key)
|
||||
k, err := file.NewConfig(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load key: %v", err)
|
||||
}
|
||||
crt, err := load(c, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load cert: %v", err)
|
||||
}
|
||||
var mu sync.RWMutex
|
||||
ch := make(chan []byte)
|
||||
if err := f.Watch(ctx, ch); err != nil {
|
||||
kch := make(chan []byte)
|
||||
if err := k.Watch(ctx, kch); err != nil {
|
||||
return nil, fmt.Errorf("failed to watch key: %v", err)
|
||||
}
|
||||
cch := make(chan []byte)
|
||||
if err := c.Watch(ctx, cch); err != nil {
|
||||
return nil, fmt.Errorf("failed to watch cert: %v", err)
|
||||
}
|
||||
reload := func() {
|
||||
c, err := load(c, key)
|
||||
// ignore errors due to cert and key not matching as this is expected
|
||||
// when the cert is being reloaded and the key is not yet updated or vice versa
|
||||
if err != nil && !strings.Contains(err.Error(), "does not match") {
|
||||
logger.C(ctx).Errorf("failed to reload cert: %v", err)
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
crt = c
|
||||
mu.Unlock()
|
||||
}
|
||||
go func() {
|
||||
for range ch {
|
||||
c, err := load(f, key)
|
||||
if err != nil {
|
||||
logger.C(ctx).Errorf("failed to reload cert: %v", err)
|
||||
continue
|
||||
for {
|
||||
select {
|
||||
case <-kch:
|
||||
reload()
|
||||
case <-cch:
|
||||
reload()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
crt = c
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -55,6 +55,7 @@ func TestLoad(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, want.Certificate, got.Certificate)
|
||||
require.Equal(t, want.PrivateKey, got.PrivateKey)
|
||||
require.Equal(t, want.Leaf, got.Leaf)
|
||||
}
|
||||
})
|
||||
@ -77,6 +78,9 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Certificate[0],
|
||||
}))
|
||||
if err := crt.Sync(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key, err := os.Create(filepath.Join(dir, "key.pem"))
|
||||
require.NoError(t, err)
|
||||
defer key.Close()
|
||||
@ -86,4 +90,7 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: b,
|
||||
}))
|
||||
if err := key.Sync(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user