certs: reload on both key and cert changes

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
Adphi 2023-11-24 20:05:56 +01:00
parent efaa4bd14f
commit 198bd2bd59
Signed by: adphi
GPG Key ID: 46BE4062DB2397FF
2 changed files with 50 additions and 12 deletions

View File

@ -14,6 +14,7 @@ import (
"math/big" "math/big"
"net" "net"
"os" "os"
"strings"
"sync" "sync"
"time" "time"
@ -81,30 +82,60 @@ func New(host ...string) (tls.Certificate, error) {
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes()) return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
} }
func TLSConfig(ctx context.Context, cert, key string) (*tls.Config, error) {
c, err := Load(ctx, cert, key)
if err != nil {
return nil, err
}
return &tls.Config{
GetCertificate: c,
}, nil
}
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) { func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
f, err := file.NewConfig(cert) c, err := file.NewConfig(cert)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load cert: %v", err) return nil, fmt.Errorf("failed to load cert: %v", err)
} }
crt, err := load(f, key) k, err := file.NewConfig(key)
if err != nil {
return nil, fmt.Errorf("failed to load key: %v", err)
}
crt, err := load(c, key)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load cert: %v", err) return nil, fmt.Errorf("failed to load cert: %v", err)
} }
var mu sync.RWMutex var mu sync.RWMutex
ch := make(chan []byte) kch := make(chan []byte)
if err := f.Watch(ctx, ch); err != nil { if err := k.Watch(ctx, kch); err != nil {
return nil, fmt.Errorf("failed to watch key: %v", err)
}
cch := make(chan []byte)
if err := c.Watch(ctx, cch); err != nil {
return nil, fmt.Errorf("failed to watch cert: %v", err) return nil, fmt.Errorf("failed to watch cert: %v", err)
} }
reload := func() {
c, err := load(c, key)
// ignore errors due to cert and key not matching as this is expected
// when the cert is being reloaded and the key is not yet updated or vice versa
if err != nil && !strings.Contains(err.Error(), "does not match") {
logger.C(ctx).Errorf("failed to reload cert: %v", err)
return
}
mu.Lock()
crt = c
mu.Unlock()
}
go func() { go func() {
for range ch { for {
c, err := load(f, key) select {
if err != nil { case <-kch:
logger.C(ctx).Errorf("failed to reload cert: %v", err) reload()
continue case <-cch:
reload()
case <-ctx.Done():
return
} }
mu.Lock()
crt = c
mu.Unlock()
} }
}() }()

View File

@ -55,6 +55,7 @@ func TestLoad(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, got) require.NotNil(t, got)
require.Equal(t, want.Certificate, got.Certificate) require.Equal(t, want.Certificate, got.Certificate)
require.Equal(t, want.PrivateKey, got.PrivateKey)
require.Equal(t, want.Leaf, got.Leaf) require.Equal(t, want.Leaf, got.Leaf)
} }
}) })
@ -77,6 +78,9 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Certificate[0], Bytes: cert.Certificate[0],
})) }))
if err := crt.Sync(); err != nil {
t.Fatal(err)
}
key, err := os.Create(filepath.Join(dir, "key.pem")) key, err := os.Create(filepath.Join(dir, "key.pem"))
require.NoError(t, err) require.NoError(t, err)
defer key.Close() defer key.Close()
@ -86,4 +90,7 @@ func write(t *testing.T, dir string, cert tls.Certificate) {
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: b, Bytes: b,
})) }))
if err := key.Sync(); err != nil {
t.Fatal(err)
}
} }