mirror of
https://github.com/linka-cloud/grpc.git
synced 2024-11-24 20:06:25 +00:00
90 lines
2.3 KiB
Go
90 lines
2.3 KiB
Go
|
package certs
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/ecdsa"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestLoad(t *testing.T) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
t.Run("missing", func(t *testing.T) {
|
||
|
fn, err := Load(ctx, "missing", "missing")
|
||
|
require.Error(t, err)
|
||
|
require.Nil(t, fn)
|
||
|
})
|
||
|
dir, err := os.MkdirTemp("", "certs")
|
||
|
require.NoError(t, err)
|
||
|
defer os.RemoveAll(dir)
|
||
|
var (
|
||
|
want tls.Certificate
|
||
|
fn func(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||
|
)
|
||
|
t.Run("load", func(t *testing.T) {
|
||
|
want, err = New("acme.org")
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, want.PrivateKey)
|
||
|
require.NotEmpty(t, want.Certificate)
|
||
|
write(t, dir, want)
|
||
|
fn, err = Load(ctx, filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem"))
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, fn)
|
||
|
got, err := fn(nil)
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, got)
|
||
|
require.Equal(t, want.Certificate, got.Certificate)
|
||
|
require.Equal(t, want.PrivateKey, got.PrivateKey)
|
||
|
})
|
||
|
t.Run("reload", func(t *testing.T) {
|
||
|
for i := 0; i < 10; i++ {
|
||
|
want, err = New("acme.org")
|
||
|
require.NoError(t, err)
|
||
|
write(t, dir, want)
|
||
|
time.Sleep(100 * time.Millisecond)
|
||
|
got, err := fn(nil)
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, got)
|
||
|
require.Equal(t, want.Certificate, got.Certificate)
|
||
|
require.Equal(t, want.Leaf, got.Leaf)
|
||
|
}
|
||
|
})
|
||
|
t.Run("removed", func(t *testing.T) {
|
||
|
require.NoError(t, os.Remove(filepath.Join(dir, "cert.pem")))
|
||
|
require.NoError(t, os.Remove(filepath.Join(dir, "key.pem")))
|
||
|
got, err := fn(nil)
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, got)
|
||
|
require.Equal(t, want.Certificate, got.Certificate)
|
||
|
require.Equal(t, want.PrivateKey, got.PrivateKey)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func write(t *testing.T, dir string, cert tls.Certificate) {
|
||
|
crt, err := os.Create(filepath.Join(dir, "cert.pem"))
|
||
|
require.NoError(t, err)
|
||
|
defer crt.Close()
|
||
|
require.NoError(t, pem.Encode(crt, &pem.Block{
|
||
|
Type: "CERTIFICATE",
|
||
|
Bytes: cert.Certificate[0],
|
||
|
}))
|
||
|
key, err := os.Create(filepath.Join(dir, "key.pem"))
|
||
|
require.NoError(t, err)
|
||
|
defer key.Close()
|
||
|
b, err := x509.MarshalECPrivateKey(cert.PrivateKey.(*ecdsa.PrivateKey))
|
||
|
require.NoError(t, err)
|
||
|
require.NoError(t, pem.Encode(key, &pem.Block{
|
||
|
Type: "RSA PRIVATE KEY",
|
||
|
Bytes: b,
|
||
|
}))
|
||
|
}
|