certs: add Load function to watch for key and certificate changes

Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
2023-08-15 18:16:29 +01:00
parent b52ae2c670
commit 97f48d30c0
2 changed files with 147 additions and 0 deletions

View File

@ -2,6 +2,7 @@ package certs
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@ -9,9 +10,16 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"sync"
"time"
"go.linka.cloud/grpc-toolkit/config"
"go.linka.cloud/grpc-toolkit/config/file"
"go.linka.cloud/grpc-toolkit/logger"
)
func New(host ...string) (tls.Certificate, error) {
@ -72,3 +80,53 @@ func New(host ...string) (tls.Certificate, error) {
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
}
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
f, err := file.NewConfig(cert)
if err != nil {
return nil, fmt.Errorf("failed to load cert: %v", err)
}
crt, err := load(f, key)
if err != nil {
return nil, fmt.Errorf("failed to load cert: %v", err)
}
var mu sync.RWMutex
ch := make(chan []byte)
if err := f.Watch(ctx, ch); err != nil {
return nil, fmt.Errorf("failed to watch cert: %v", err)
}
go func() {
for range ch {
c, err := load(f, key)
if err != nil {
logger.C(ctx).Errorf("failed to reload cert: %v", err)
continue
}
mu.Lock()
crt = c
mu.Unlock()
}
}()
return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
mu.RLock()
defer mu.RUnlock()
return crt, nil
}, nil
}
func load(cert config.Config, key string) (*tls.Certificate, error) {
cb, err := cert.Read()
if err != nil {
return nil, fmt.Errorf("failed to read cert: %v", err)
}
kb, err := os.ReadFile(key)
if err != nil {
return nil, fmt.Errorf("failed to read key: %v", err)
}
c, err := tls.X509KeyPair(cb, kb)
if err != nil {
return nil, err
}
return &c, nil
}