grpc/certs/certs_test.go
Adphi 97f48d30c0
certs: add Load function to watch for key and certificate changes
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
2023-08-15 18:16:29 +01:00

90 lines
2.3 KiB
Go

package certs
import (
"context"
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestLoad(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("missing", func(t *testing.T) {
fn, err := Load(ctx, "missing", "missing")
require.Error(t, err)
require.Nil(t, fn)
})
dir, err := os.MkdirTemp("", "certs")
require.NoError(t, err)
defer os.RemoveAll(dir)
var (
want tls.Certificate
fn func(*tls.ClientHelloInfo) (*tls.Certificate, error)
)
t.Run("load", func(t *testing.T) {
want, err = New("acme.org")
require.NoError(t, err)
require.NotNil(t, want.PrivateKey)
require.NotEmpty(t, want.Certificate)
write(t, dir, want)
fn, err = Load(ctx, filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem"))
require.NoError(t, err)
require.NotNil(t, fn)
got, err := fn(nil)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, want.Certificate, got.Certificate)
require.Equal(t, want.PrivateKey, got.PrivateKey)
})
t.Run("reload", func(t *testing.T) {
for i := 0; i < 10; i++ {
want, err = New("acme.org")
require.NoError(t, err)
write(t, dir, want)
time.Sleep(100 * time.Millisecond)
got, err := fn(nil)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, want.Certificate, got.Certificate)
require.Equal(t, want.Leaf, got.Leaf)
}
})
t.Run("removed", func(t *testing.T) {
require.NoError(t, os.Remove(filepath.Join(dir, "cert.pem")))
require.NoError(t, os.Remove(filepath.Join(dir, "key.pem")))
got, err := fn(nil)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, want.Certificate, got.Certificate)
require.Equal(t, want.PrivateKey, got.PrivateKey)
})
}
func write(t *testing.T, dir string, cert tls.Certificate) {
crt, err := os.Create(filepath.Join(dir, "cert.pem"))
require.NoError(t, err)
defer crt.Close()
require.NoError(t, pem.Encode(crt, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Certificate[0],
}))
key, err := os.Create(filepath.Join(dir, "key.pem"))
require.NoError(t, err)
defer key.Close()
b, err := x509.MarshalECPrivateKey(cert.PrivateKey.(*ecdsa.PrivateKey))
require.NoError(t, err)
require.NoError(t, pem.Encode(key, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: b,
}))
}