diff --git a/certs/certs.go b/certs/certs.go index cf4c1c8..33a7ad5 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -2,6 +2,7 @@ package certs import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -9,9 +10,16 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "math/big" "net" + "os" + "sync" "time" + + "go.linka.cloud/grpc-toolkit/config" + "go.linka.cloud/grpc-toolkit/config/file" + "go.linka.cloud/grpc-toolkit/logger" ) func New(host ...string) (tls.Certificate, error) { @@ -72,3 +80,53 @@ func New(host ...string) (tls.Certificate, error) { return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes()) } + +func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) { + f, err := file.NewConfig(cert) + if err != nil { + return nil, fmt.Errorf("failed to load cert: %v", err) + } + crt, err := load(f, key) + if err != nil { + return nil, fmt.Errorf("failed to load cert: %v", err) + } + var mu sync.RWMutex + ch := make(chan []byte) + if err := f.Watch(ctx, ch); err != nil { + return nil, fmt.Errorf("failed to watch cert: %v", err) + } + go func() { + for range ch { + c, err := load(f, key) + if err != nil { + logger.C(ctx).Errorf("failed to reload cert: %v", err) + continue + } + mu.Lock() + crt = c + mu.Unlock() + } + }() + + return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + mu.RLock() + defer mu.RUnlock() + return crt, nil + }, nil +} + +func load(cert config.Config, key string) (*tls.Certificate, error) { + cb, err := cert.Read() + if err != nil { + return nil, fmt.Errorf("failed to read cert: %v", err) + } + kb, err := os.ReadFile(key) + if err != nil { + return nil, fmt.Errorf("failed to read key: %v", err) + } + c, err := tls.X509KeyPair(cb, kb) + if err != nil { + return nil, err + } + return &c, nil +} diff --git a/certs/certs_test.go b/certs/certs_test.go new file mode 100644 index 0000000..85d7822 --- /dev/null +++ b/certs/certs_test.go @@ -0,0 +1,89 @@ +package certs + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestLoad(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + t.Run("missing", func(t *testing.T) { + fn, err := Load(ctx, "missing", "missing") + require.Error(t, err) + require.Nil(t, fn) + }) + dir, err := os.MkdirTemp("", "certs") + require.NoError(t, err) + defer os.RemoveAll(dir) + var ( + want tls.Certificate + fn func(*tls.ClientHelloInfo) (*tls.Certificate, error) + ) + t.Run("load", func(t *testing.T) { + want, err = New("acme.org") + require.NoError(t, err) + require.NotNil(t, want.PrivateKey) + require.NotEmpty(t, want.Certificate) + write(t, dir, want) + fn, err = Load(ctx, filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")) + require.NoError(t, err) + require.NotNil(t, fn) + got, err := fn(nil) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, want.Certificate, got.Certificate) + require.Equal(t, want.PrivateKey, got.PrivateKey) + }) + t.Run("reload", func(t *testing.T) { + for i := 0; i < 10; i++ { + want, err = New("acme.org") + require.NoError(t, err) + write(t, dir, want) + time.Sleep(100 * time.Millisecond) + got, err := fn(nil) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, want.Certificate, got.Certificate) + require.Equal(t, want.Leaf, got.Leaf) + } + }) + t.Run("removed", func(t *testing.T) { + require.NoError(t, os.Remove(filepath.Join(dir, "cert.pem"))) + require.NoError(t, os.Remove(filepath.Join(dir, "key.pem"))) + got, err := fn(nil) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, want.Certificate, got.Certificate) + require.Equal(t, want.PrivateKey, got.PrivateKey) + }) +} + +func write(t *testing.T, dir string, cert tls.Certificate) { + crt, err := os.Create(filepath.Join(dir, "cert.pem")) + require.NoError(t, err) + defer crt.Close() + require.NoError(t, pem.Encode(crt, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Certificate[0], + })) + key, err := os.Create(filepath.Join(dir, "key.pem")) + require.NoError(t, err) + defer key.Close() + b, err := x509.MarshalECPrivateKey(cert.PrivateKey.(*ecdsa.PrivateKey)) + require.NoError(t, err) + require.NoError(t, pem.Encode(key, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: b, + })) +}