diff --git a/certs/certs.go b/certs/certs.go index 33a7ad5..8eeb9f1 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -14,6 +14,7 @@ import ( "math/big" "net" "os" + "strings" "sync" "time" @@ -81,30 +82,60 @@ func New(host ...string) (tls.Certificate, error) { return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes()) } +func TLSConfig(ctx context.Context, cert, key string) (*tls.Config, error) { + c, err := Load(ctx, cert, key) + if err != nil { + return nil, err + } + return &tls.Config{ + GetCertificate: c, + }, nil +} + func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) { - f, err := file.NewConfig(cert) + c, err := file.NewConfig(cert) if err != nil { return nil, fmt.Errorf("failed to load cert: %v", err) } - crt, err := load(f, key) + k, err := file.NewConfig(key) + if err != nil { + return nil, fmt.Errorf("failed to load key: %v", err) + } + crt, err := load(c, key) if err != nil { return nil, fmt.Errorf("failed to load cert: %v", err) } var mu sync.RWMutex - ch := make(chan []byte) - if err := f.Watch(ctx, ch); err != nil { + kch := make(chan []byte) + if err := k.Watch(ctx, kch); err != nil { + return nil, fmt.Errorf("failed to watch key: %v", err) + } + cch := make(chan []byte) + if err := c.Watch(ctx, cch); err != nil { return nil, fmt.Errorf("failed to watch cert: %v", err) } + reload := func() { + c, err := load(c, key) + // ignore errors due to cert and key not matching as this is expected + // when the cert is being reloaded and the key is not yet updated or vice versa + if err != nil && !strings.Contains(err.Error(), "does not match") { + logger.C(ctx).Errorf("failed to reload cert: %v", err) + return + } + mu.Lock() + crt = c + mu.Unlock() + } go func() { - for range ch { - c, err := load(f, key) - if err != nil { - logger.C(ctx).Errorf("failed to reload cert: %v", err) - continue + for { + select { + case <-kch: + reload() + case <-cch: + reload() + case <-ctx.Done(): + return } - mu.Lock() - crt = c - mu.Unlock() } }() diff --git a/certs/certs_test.go b/certs/certs_test.go index 85d7822..ebc115c 100644 --- a/certs/certs_test.go +++ b/certs/certs_test.go @@ -55,6 +55,7 @@ func TestLoad(t *testing.T) { require.NoError(t, err) require.NotNil(t, got) require.Equal(t, want.Certificate, got.Certificate) + require.Equal(t, want.PrivateKey, got.PrivateKey) require.Equal(t, want.Leaf, got.Leaf) } }) @@ -77,6 +78,9 @@ func write(t *testing.T, dir string, cert tls.Certificate) { Type: "CERTIFICATE", Bytes: cert.Certificate[0], })) + if err := crt.Sync(); err != nil { + t.Fatal(err) + } key, err := os.Create(filepath.Join(dir, "key.pem")) require.NoError(t, err) defer key.Close() @@ -86,4 +90,7 @@ func write(t *testing.T, dir string, cert tls.Certificate) { Type: "RSA PRIVATE KEY", Bytes: b, })) + if err := key.Sync(); err != nil { + t.Fatal(err) + } }