mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-12-04 16:48:24 +00:00
certs: reload on both key and cert changes
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
parent
efaa4bd14f
commit
198bd2bd59
@ -14,6 +14,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -81,30 +82,60 @@ func New(host ...string) (tls.Certificate, error) {
|
|||||||
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
|
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TLSConfig(ctx context.Context, cert, key string) (*tls.Config, error) {
|
||||||
|
c, err := Load(ctx, cert, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
GetCertificate: c,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
|
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
|
||||||
f, err := file.NewConfig(cert)
|
c, err := file.NewConfig(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load cert: %v", err)
|
return nil, fmt.Errorf("failed to load cert: %v", err)
|
||||||
}
|
}
|
||||||
crt, err := load(f, key)
|
k, err := file.NewConfig(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load key: %v", err)
|
||||||
|
}
|
||||||
|
crt, err := load(c, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load cert: %v", err)
|
return nil, fmt.Errorf("failed to load cert: %v", err)
|
||||||
}
|
}
|
||||||
var mu sync.RWMutex
|
var mu sync.RWMutex
|
||||||
ch := make(chan []byte)
|
kch := make(chan []byte)
|
||||||
if err := f.Watch(ctx, ch); err != nil {
|
if err := k.Watch(ctx, kch); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to watch key: %v", err)
|
||||||
|
}
|
||||||
|
cch := make(chan []byte)
|
||||||
|
if err := c.Watch(ctx, cch); err != nil {
|
||||||
return nil, fmt.Errorf("failed to watch cert: %v", err)
|
return nil, fmt.Errorf("failed to watch cert: %v", err)
|
||||||
}
|
}
|
||||||
|
reload := func() {
|
||||||
|
c, err := load(c, key)
|
||||||
|
// ignore errors due to cert and key not matching as this is expected
|
||||||
|
// when the cert is being reloaded and the key is not yet updated or vice versa
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "does not match") {
|
||||||
|
logger.C(ctx).Errorf("failed to reload cert: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
crt = c
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
go func() {
|
go func() {
|
||||||
for range ch {
|
for {
|
||||||
c, err := load(f, key)
|
select {
|
||||||
if err != nil {
|
case <-kch:
|
||||||
logger.C(ctx).Errorf("failed to reload cert: %v", err)
|
reload()
|
||||||
continue
|
case <-cch:
|
||||||
|
reload()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
}
|
}
|
||||||
mu.Lock()
|
|
||||||
crt = c
|
|
||||||
mu.Unlock()
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -55,6 +55,7 @@ func TestLoad(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, got)
|
require.NotNil(t, got)
|
||||||
require.Equal(t, want.Certificate, got.Certificate)
|
require.Equal(t, want.Certificate, got.Certificate)
|
||||||
|
require.Equal(t, want.PrivateKey, got.PrivateKey)
|
||||||
require.Equal(t, want.Leaf, got.Leaf)
|
require.Equal(t, want.Leaf, got.Leaf)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -77,6 +78,9 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
|
|||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: cert.Certificate[0],
|
Bytes: cert.Certificate[0],
|
||||||
}))
|
}))
|
||||||
|
if err := crt.Sync(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
key, err := os.Create(filepath.Join(dir, "key.pem"))
|
key, err := os.Create(filepath.Join(dir, "key.pem"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer key.Close()
|
defer key.Close()
|
||||||
@ -86,4 +90,7 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
|
|||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}))
|
}))
|
||||||
|
if err := key.Sync(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user