mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-12-26 19:00:44 +00:00
certs: add Load function to watch for key and certificate changes
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
parent
b52ae2c670
commit
97f48d30c0
@ -2,6 +2,7 @@ package certs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
@ -9,9 +10,16 @@ import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.linka.cloud/grpc-toolkit/config"
|
||||
"go.linka.cloud/grpc-toolkit/config/file"
|
||||
"go.linka.cloud/grpc-toolkit/logger"
|
||||
)
|
||||
|
||||
func New(host ...string) (tls.Certificate, error) {
|
||||
@ -72,3 +80,53 @@ func New(host ...string) (tls.Certificate, error) {
|
||||
|
||||
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
|
||||
}
|
||||
|
||||
func Load(ctx context.Context, cert, key string) (func(info *tls.ClientHelloInfo) (*tls.Certificate, error), error) {
|
||||
f, err := file.NewConfig(cert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load cert: %v", err)
|
||||
}
|
||||
crt, err := load(f, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load cert: %v", err)
|
||||
}
|
||||
var mu sync.RWMutex
|
||||
ch := make(chan []byte)
|
||||
if err := f.Watch(ctx, ch); err != nil {
|
||||
return nil, fmt.Errorf("failed to watch cert: %v", err)
|
||||
}
|
||||
go func() {
|
||||
for range ch {
|
||||
c, err := load(f, key)
|
||||
if err != nil {
|
||||
logger.C(ctx).Errorf("failed to reload cert: %v", err)
|
||||
continue
|
||||
}
|
||||
mu.Lock()
|
||||
crt = c
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return crt, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func load(cert config.Config, key string) (*tls.Certificate, error) {
|
||||
cb, err := cert.Read()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cert: %v", err)
|
||||
}
|
||||
kb, err := os.ReadFile(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key: %v", err)
|
||||
}
|
||||
c, err := tls.X509KeyPair(cb, kb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
89
certs/certs_test.go
Normal file
89
certs/certs_test.go
Normal file
@ -0,0 +1,89 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
t.Run("missing", func(t *testing.T) {
|
||||
fn, err := Load(ctx, "missing", "missing")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fn)
|
||||
})
|
||||
dir, err := os.MkdirTemp("", "certs")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
var (
|
||||
want tls.Certificate
|
||||
fn func(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
)
|
||||
t.Run("load", func(t *testing.T) {
|
||||
want, err = New("acme.org")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, want.PrivateKey)
|
||||
require.NotEmpty(t, want.Certificate)
|
||||
write(t, dir, want)
|
||||
fn, err = Load(ctx, filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fn)
|
||||
got, err := fn(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, want.Certificate, got.Certificate)
|
||||
require.Equal(t, want.PrivateKey, got.PrivateKey)
|
||||
})
|
||||
t.Run("reload", func(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
want, err = New("acme.org")
|
||||
require.NoError(t, err)
|
||||
write(t, dir, want)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
got, err := fn(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, want.Certificate, got.Certificate)
|
||||
require.Equal(t, want.Leaf, got.Leaf)
|
||||
}
|
||||
})
|
||||
t.Run("removed", func(t *testing.T) {
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "cert.pem")))
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "key.pem")))
|
||||
got, err := fn(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, want.Certificate, got.Certificate)
|
||||
require.Equal(t, want.PrivateKey, got.PrivateKey)
|
||||
})
|
||||
}
|
||||
|
||||
func write(t *testing.T, dir string, cert tls.Certificate) {
|
||||
crt, err := os.Create(filepath.Join(dir, "cert.pem"))
|
||||
require.NoError(t, err)
|
||||
defer crt.Close()
|
||||
require.NoError(t, pem.Encode(crt, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Certificate[0],
|
||||
}))
|
||||
key, err := os.Create(filepath.Join(dir, "key.pem"))
|
||||
require.NoError(t, err)
|
||||
defer key.Close()
|
||||
b, err := x509.MarshalECPrivateKey(cert.PrivateKey.(*ecdsa.PrivateKey))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, pem.Encode(key, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: b,
|
||||
}))
|
||||
}
|
Loading…
Reference in New Issue
Block a user