certs: reload on both key and cert changes

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2023-11-24 20:05:56 +01:00
parent efaa4bd14f
commit 198bd2bd59
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
2 changed files with 50 additions and 12 deletions

View File

@ -14,6 +14,7 @@ import (
"math/big"
"net"
"os"
"strings"
"sync"
"time"
@ -81,30 +82,60 @@ func New(host ...string) (tls.Certificate, error) {
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
}
func TLSConfig(ctx context.Context, cert, key string) (*tls.Config, error) {
c, err := Load(ctx, cert, key)
if err != nil {
return nil, err
}
return &tls.Config{
GetCertificate: c,
}, nil
}
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
f, err := file.NewConfig(cert)
c, err := file.NewConfig(cert)
if err != nil {
return nil, fmt.Errorf("failed to load cert: %v", err)
}
crt, err := load(f, key)
k, err := file.NewConfig(key)
if err != nil {
return nil, fmt.Errorf("failed to load key: %v", err)
}
crt, err := load(c, key)
if err != nil {
return nil, fmt.Errorf("failed to load cert: %v", err)
}
var mu sync.RWMutex
ch := make(chan []byte)
if err := f.Watch(ctx, ch); err != nil {
kch := make(chan []byte)
if err := k.Watch(ctx, kch); err != nil {
return nil, fmt.Errorf("failed to watch key: %v", err)
}
cch := make(chan []byte)
if err := c.Watch(ctx, cch); err != nil {
return nil, fmt.Errorf("failed to watch cert: %v", err)
}
reload := func() {
c, err := load(c, key)
// ignore errors due to cert and key not matching as this is expected
// when the cert is being reloaded and the key is not yet updated or vice versa
if err != nil && !strings.Contains(err.Error(), "does not match") {
logger.C(ctx).Errorf("failed to reload cert: %v", err)
return
}
mu.Lock()
crt = c
mu.Unlock()
}
go func() {
for range ch {
c, err := load(f, key)
if err != nil {
logger.C(ctx).Errorf("failed to reload cert: %v", err)
continue
for {
select {
case <-kch:
reload()
case <-cch:
reload()
case <-ctx.Done():
return
}
mu.Lock()
crt = c
mu.Unlock()
}
}()

View File

@ -55,6 +55,7 @@ func TestLoad(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, want.Certificate, got.Certificate)
require.Equal(t, want.PrivateKey, got.PrivateKey)
require.Equal(t, want.Leaf, got.Leaf)
}
})
@ -77,6 +78,9 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
Type: "CERTIFICATE",
Bytes: cert.Certificate[0],
}))
if err := crt.Sync(); err != nil {
t.Fatal(err)
}
key, err := os.Create(filepath.Join(dir, "key.pem"))
require.NoError(t, err)
defer key.Close()
@ -86,4 +90,7 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
Type: "RSA PRIVATE KEY",
Bytes: b,
}))
if err := key.Sync(); err != nil {
t.Fatal(err)
}
}