From 4ca6bedf5a098590191f966878901b373827b819 Mon Sep 17 00:00:00 2001 From: Adphi Date: Fri, 7 Aug 2020 10:59:27 +0200 Subject: [PATCH] add cert generation, tls config, reflection --- certs/certs.go | 74 +++++++++++++++++++++++++++++++++++++++++++ example/example.go | 2 ++ go.mod | 1 + go.sum | 2 ++ service/options.go | 78 +++++++++++++++++++++++++++++++++------------- service/service.go | 30 ++++++++++++------ 6 files changed, 156 insertions(+), 31 deletions(-) create mode 100644 certs/certs.go diff --git a/certs/certs.go b/certs/certs.go new file mode 100644 index 0000000..3879bb1 --- /dev/null +++ b/certs/certs.go @@ -0,0 +1,74 @@ +package certs + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "time" +) + +func New(host ...string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(time.Hour * 24 * 365) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + for _, h := range host { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + // create public key + certOut := bytes.NewBuffer(nil) + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + // create private key + keyOut := bytes.NewBuffer(nil) + b, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + + return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes()) +} \ No newline at end of file diff --git a/example/example.go b/example/example.go index bf0fcfd..0b7d8ea 100644 --- a/example/example.go +++ b/example/example.go @@ -35,6 +35,8 @@ func main() { svc, err = service.New( service.WithContext(ctx), service.WithName("Greeting"), + service.WithReflection(true), + service.WithSecure(true), service.WithAfterStart(func() error { fmt.Println("Server listening on", svc.Options().Address()) return nil diff --git a/go.mod b/go.mod index fb8ead7..5c365d0 100644 --- a/go.mod +++ b/go.mod @@ -10,5 +10,6 @@ require ( github.com/spf13/cobra v0.0.5 github.com/spf13/pflag v1.0.3 github.com/spf13/viper v1.6.2 + go.uber.org/multierr v1.1.0 google.golang.org/grpc v1.26.0 ) diff --git a/go.sum b/go.sum index 295b524..e5fc8eb 100644 --- a/go.sum +++ b/go.sum @@ -141,7 +141,9 @@ github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljT github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/service/options.go b/service/options.go index ac35179..79d7a75 100644 --- a/service/options.go +++ b/service/options.go @@ -6,10 +6,12 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "reflect" "github.com/jinzhu/gorm" + "go.uber.org/multierr" "google.golang.org/grpc" + + "gitlab.bertha.cloud/partitio/lab/grpc/certs" ) /* @@ -63,10 +65,12 @@ type Options interface { Context() context.Context Name() string Address() string + Reflection() bool + Secure() bool CACert() string Cert() string Key() string - TLSConfig() tls.Config + TLSConfig() *tls.Config DB() *gorm.DB BeforeStart() []func() error AfterStart() []func() error @@ -120,6 +124,18 @@ func WithAddress(addr string) Option { } } +func WithReflection(r bool) Option { + return func(o *options) { + o.reflection = r + } +} + +func WithSecure(s bool) Option { + return func(o *options) { + o.secure = s + } +} + func WithGRPCServerOpts(opts ...grpc.ServerOption) Option { return func(o *options) { o.serverOpts = append(o.serverOpts, opts...) @@ -148,15 +164,13 @@ func WithDB(dialect string, args ...interface{}) Option { db, err := gorm.Open(dialect, args...) return func(o *options) { o.db = db - o.error = err + o.error = multierr.Append(o.error, err) } } func WithTLSConfig(conf *tls.Config) Option { return func(o *options) { - if conf != nil { - o.tlsConfig = *conf - } + o.tlsConfig = conf } } @@ -217,14 +231,16 @@ func WithSubscriberInterceptor(w ...interface{}) Option { } type options struct { - ctx context.Context - name string - address string - caCert string - cert string - key string - tlsConfig tls.Config - db *gorm.DB + ctx context.Context + name string + address string + secure bool + reflection bool + caCert string + cert string + key string + tlsConfig *tls.Config + db *gorm.DB beforeStart []func() error afterStart []func() error @@ -254,6 +270,14 @@ func (o *options) Address() string { return o.address } +func (o *options) Reflection() bool { + return o.reflection +} + +func (o *options) Secure() bool { + return o.secure +} + func (o *options) CACert() string { return o.caCert } @@ -266,7 +290,7 @@ func (o *options) Key() string { return o.key } -func (o *options) TLSConfig() tls.Config { +func (o *options) TLSConfig() *tls.Config { return o.tlsConfig } @@ -311,7 +335,21 @@ func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor { } func (o *options) parseTLSConfig() error { - if o.hasTLSConfig() { + if (o.tlsConfig != nil) { + return nil + } + if !o.hasTLSConfig() { + if !o.secure { + return nil + } + cert, err := certs.New(o.address, "localhost", "127.0.0.1", o.name) + if err != nil { + return err + } + o.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + } return nil } caCert, err := ioutil.ReadFile(o.caCert) @@ -327,15 +365,13 @@ func (o *options) parseTLSConfig() error { if err != nil { return err } - o.tlsConfig = tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: caCertPool, - RootCAs: caCertPool, + o.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, } return nil } func (o *options) hasTLSConfig() bool { - return reflect.DeepEqual(o.tlsConfig, tls.Config{}) + return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil } diff --git a/service/service.go b/service/service.go index cb66ec8..3f8fec4 100644 --- a/service/service.go +++ b/service/service.go @@ -11,6 +11,8 @@ import ( "github.com/spf13/cobra" "go.uber.org/multierr" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/reflection" ) type Service interface { @@ -60,6 +62,11 @@ func newService(opts ...Option) (*service, error) { } } }() + var err error + s.list, err = net.Listen("tcp", s.opts.address) + if err != nil { + return nil, err + } if err := s.opts.parseTLSConfig(); err != nil { return nil, err } @@ -70,10 +77,19 @@ func newService(opts ...Option) (*service, error) { } return s.run() } - gopts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcmiddleware.ChainUnaryServer(s.opts.serverInterceptors...))} - // TODO : check tls config and tls auth - // grpc.Creds(credentials.NewTLS(&s.opts.tlsConfig)) + gopts := []grpc.ServerOption{ + grpc.Creds(credentials.NewTLS(s.opts.tlsConfig)), + grpc.UnaryInterceptor( + grpcmiddleware.ChainUnaryServer(s.opts.serverInterceptors...), + ), + } + if s.opts.tlsConfig != nil { + gopts = append(gopts) + } s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) + if s.opts.reflection { + reflection.Register(s.server) + } return s, nil } @@ -100,13 +116,7 @@ func (s *service) run() error { return err } } - var err error s.running = true - s.list, err = net.Listen("tcp", s.opts.address) - if err != nil { - s.mu.Unlock() - return err - } s.opts.address = s.list.Addr().String() errs := make(chan error) go func() { @@ -130,7 +140,7 @@ func (s *service) Start() error { func (s *service) Stop() error { s.mu.Lock() defer s.mu.Unlock() - if ! s.running { + if !s.running { return nil } for i := range s.opts.beforeStop {