add cert generation, tls config, reflection

This commit is contained in:
Adphi 2020-08-07 10:59:27 +02:00
parent 8e5e48f39b
commit 4ca6bedf5a
6 changed files with 156 additions and 31 deletions

74
certs/certs.go Normal file
View File

@ -0,0 +1,74 @@
package certs
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
)
func New(host ...string) (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour * 24 * 365)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return tls.Certificate{}, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
for _, h := range host {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
// create public key
certOut := bytes.NewBuffer(nil)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
// create private key
keyOut := bytes.NewBuffer(nil)
b, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
}

View File

@ -35,6 +35,8 @@ func main() {
svc, err = service.New( svc, err = service.New(
service.WithContext(ctx), service.WithContext(ctx),
service.WithName("Greeting"), service.WithName("Greeting"),
service.WithReflection(true),
service.WithSecure(true),
service.WithAfterStart(func() error { service.WithAfterStart(func() error {
fmt.Println("Server listening on", svc.Options().Address()) fmt.Println("Server listening on", svc.Options().Address())
return nil return nil

1
go.mod
View File

@ -10,5 +10,6 @@ require (
github.com/spf13/cobra v0.0.5 github.com/spf13/cobra v0.0.5
github.com/spf13/pflag v1.0.3 github.com/spf13/pflag v1.0.3
github.com/spf13/viper v1.6.2 github.com/spf13/viper v1.6.2
go.uber.org/multierr v1.1.0
google.golang.org/grpc v1.26.0 google.golang.org/grpc v1.26.0
) )

2
go.sum
View File

@ -141,7 +141,9 @@ github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljT
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=

View File

@ -6,10 +6,12 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
"gitlab.bertha.cloud/partitio/lab/grpc/certs"
) )
/* /*
@ -63,10 +65,12 @@ type Options interface {
Context() context.Context Context() context.Context
Name() string Name() string
Address() string Address() string
Reflection() bool
Secure() bool
CACert() string CACert() string
Cert() string Cert() string
Key() string Key() string
TLSConfig() tls.Config TLSConfig() *tls.Config
DB() *gorm.DB DB() *gorm.DB
BeforeStart() []func() error BeforeStart() []func() error
AfterStart() []func() error AfterStart() []func() error
@ -120,6 +124,18 @@ func WithAddress(addr string) Option {
} }
} }
func WithReflection(r bool) Option {
return func(o *options) {
o.reflection = r
}
}
func WithSecure(s bool) Option {
return func(o *options) {
o.secure = s
}
}
func WithGRPCServerOpts(opts ...grpc.ServerOption) Option { func WithGRPCServerOpts(opts ...grpc.ServerOption) Option {
return func(o *options) { return func(o *options) {
o.serverOpts = append(o.serverOpts, opts...) o.serverOpts = append(o.serverOpts, opts...)
@ -148,15 +164,13 @@ func WithDB(dialect string, args ...interface{}) Option {
db, err := gorm.Open(dialect, args...) db, err := gorm.Open(dialect, args...)
return func(o *options) { return func(o *options) {
o.db = db o.db = db
o.error = err o.error = multierr.Append(o.error, err)
} }
} }
func WithTLSConfig(conf *tls.Config) Option { func WithTLSConfig(conf *tls.Config) Option {
return func(o *options) { return func(o *options) {
if conf != nil { o.tlsConfig = conf
o.tlsConfig = *conf
}
} }
} }
@ -217,14 +231,16 @@ func WithSubscriberInterceptor(w ...interface{}) Option {
} }
type options struct { type options struct {
ctx context.Context ctx context.Context
name string name string
address string address string
caCert string secure bool
cert string reflection bool
key string caCert string
tlsConfig tls.Config cert string
db *gorm.DB key string
tlsConfig *tls.Config
db *gorm.DB
beforeStart []func() error beforeStart []func() error
afterStart []func() error afterStart []func() error
@ -254,6 +270,14 @@ func (o *options) Address() string {
return o.address return o.address
} }
func (o *options) Reflection() bool {
return o.reflection
}
func (o *options) Secure() bool {
return o.secure
}
func (o *options) CACert() string { func (o *options) CACert() string {
return o.caCert return o.caCert
} }
@ -266,7 +290,7 @@ func (o *options) Key() string {
return o.key return o.key
} }
func (o *options) TLSConfig() tls.Config { func (o *options) TLSConfig() *tls.Config {
return o.tlsConfig return o.tlsConfig
} }
@ -311,7 +335,21 @@ func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor {
} }
func (o *options) parseTLSConfig() error { func (o *options) parseTLSConfig() error {
if o.hasTLSConfig() { if (o.tlsConfig != nil) {
return nil
}
if !o.hasTLSConfig() {
if !o.secure {
return nil
}
cert, err := certs.New(o.address, "localhost", "127.0.0.1", o.name)
if err != nil {
return err
}
o.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}
return nil return nil
} }
caCert, err := ioutil.ReadFile(o.caCert) caCert, err := ioutil.ReadFile(o.caCert)
@ -327,15 +365,13 @@ func (o *options) parseTLSConfig() error {
if err != nil { if err != nil {
return err return err
} }
o.tlsConfig = tls.Config{ o.tlsConfig = &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caCertPool,
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
} }
return nil return nil
} }
func (o *options) hasTLSConfig() bool { func (o *options) hasTLSConfig() bool {
return reflect.DeepEqual(o.tlsConfig, tls.Config{}) return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil
} }

View File

@ -11,6 +11,8 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go.uber.org/multierr" "go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
) )
type Service interface { type Service interface {
@ -60,6 +62,11 @@ func newService(opts ...Option) (*service, error) {
} }
} }
}() }()
var err error
s.list, err = net.Listen("tcp", s.opts.address)
if err != nil {
return nil, err
}
if err := s.opts.parseTLSConfig(); err != nil { if err := s.opts.parseTLSConfig(); err != nil {
return nil, err return nil, err
} }
@ -70,10 +77,19 @@ func newService(opts ...Option) (*service, error) {
} }
return s.run() return s.run()
} }
gopts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcmiddleware.ChainUnaryServer(s.opts.serverInterceptors...))} gopts := []grpc.ServerOption{
// TODO : check tls config and tls auth grpc.Creds(credentials.NewTLS(s.opts.tlsConfig)),
// grpc.Creds(credentials.NewTLS(&s.opts.tlsConfig)) grpc.UnaryInterceptor(
grpcmiddleware.ChainUnaryServer(s.opts.serverInterceptors...),
),
}
if s.opts.tlsConfig != nil {
gopts = append(gopts)
}
s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...)
if s.opts.reflection {
reflection.Register(s.server)
}
return s, nil return s, nil
} }
@ -100,13 +116,7 @@ func (s *service) run() error {
return err return err
} }
} }
var err error
s.running = true s.running = true
s.list, err = net.Listen("tcp", s.opts.address)
if err != nil {
s.mu.Unlock()
return err
}
s.opts.address = s.list.Addr().String() s.opts.address = s.list.Addr().String()
errs := make(chan error) errs := make(chan error)
go func() { go func() {
@ -130,7 +140,7 @@ func (s *service) Start() error {
func (s *service) Stop() error { func (s *service) Stop() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if ! s.running { if !s.running {
return nil return nil
} }
for i := range s.opts.beforeStop { for i := range s.opts.beforeStop {