add cert generation, tls config, reflection

This commit is contained in:
Adphi 2020-08-07 10:59:27 +02:00
parent 8e5e48f39b
commit 4ca6bedf5a
6 changed files with 156 additions and 31 deletions

74
certs/certs.go Normal file
View File

@ -0,0 +1,74 @@
package certs
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
)
func New(host ...string) (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour * 24 * 365)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return tls.Certificate{}, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
for _, h := range host {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
// create public key
certOut := bytes.NewBuffer(nil)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
// create private key
keyOut := bytes.NewBuffer(nil)
b, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
}

View File

@ -35,6 +35,8 @@ func main() {
svc, err = service.New( svc, err = service.New(
service.WithContext(ctx), service.WithContext(ctx),
service.WithName("Greeting"), service.WithName("Greeting"),
service.WithReflection(true),
service.WithSecure(true),
service.WithAfterStart(func() error { service.WithAfterStart(func() error {
fmt.Println("Server listening on", svc.Options().Address()) fmt.Println("Server listening on", svc.Options().Address())
return nil return nil

1
go.mod
View File

@ -10,5 +10,6 @@ require (
github.com/spf13/cobra v0.0.5 github.com/spf13/cobra v0.0.5
github.com/spf13/pflag v1.0.3 github.com/spf13/pflag v1.0.3
github.com/spf13/viper v1.6.2 github.com/spf13/viper v1.6.2
go.uber.org/multierr v1.1.0
google.golang.org/grpc v1.26.0 google.golang.org/grpc v1.26.0
) )

2
go.sum
View File

@ -141,7 +141,9 @@ github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljT
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=

View File

@ -6,10 +6,12 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
"gitlab.bertha.cloud/partitio/lab/grpc/certs"
) )
/* /*
@ -63,10 +65,12 @@ type Options interface {
Context() context.Context Context() context.Context
Name() string Name() string
Address() string Address() string
Reflection() bool
Secure() bool
CACert() string CACert() string
Cert() string Cert() string
Key() string Key() string
TLSConfig() tls.Config TLSConfig() *tls.Config
DB() *gorm.DB DB() *gorm.DB
BeforeStart() []func() error BeforeStart() []func() error
AfterStart() []func() error AfterStart() []func() error
@ -120,6 +124,18 @@ func WithAddress(addr string) Option {
} }
} }
func WithReflection(r bool) Option {
return func(o *options) {
o.reflection = r
}
}
func WithSecure(s bool) Option {
return func(o *options) {
o.secure = s
}
}
func WithGRPCServerOpts(opts ...grpc.ServerOption) Option { func WithGRPCServerOpts(opts ...grpc.ServerOption) Option {
return func(o *options) { return func(o *options) {
o.serverOpts = append(o.serverOpts, opts...) o.serverOpts = append(o.serverOpts, opts...)
@ -148,15 +164,13 @@ func WithDB(dialect string, args ...interface{}) Option {
db, err := gorm.Open(dialect, args...) db, err := gorm.Open(dialect, args...)
return func(o *options) { return func(o *options) {
o.db = db o.db = db
o.error = err o.error = multierr.Append(o.error, err)
} }
} }
func WithTLSConfig(conf *tls.Config) Option { func WithTLSConfig(conf *tls.Config) Option {
return func(o *options) { return func(o *options) {
if conf != nil { o.tlsConfig = conf
o.tlsConfig = *conf
}
} }
} }
@ -220,10 +234,12 @@ type options struct {
ctx context.Context ctx context.Context
name string name string
address string address string
secure bool
reflection bool
caCert string caCert string
cert string cert string
key string key string
tlsConfig tls.Config tlsConfig *tls.Config
db *gorm.DB db *gorm.DB
beforeStart []func() error beforeStart []func() error
@ -254,6 +270,14 @@ func (o *options) Address() string {
return o.address return o.address
} }
func (o *options) Reflection() bool {
return o.reflection
}
func (o *options) Secure() bool {
return o.secure
}
func (o *options) CACert() string { func (o *options) CACert() string {
return o.caCert return o.caCert
} }
@ -266,7 +290,7 @@ func (o *options) Key() string {
return o.key return o.key
} }
func (o *options) TLSConfig() tls.Config { func (o *options) TLSConfig() *tls.Config {
return o.tlsConfig return o.tlsConfig
} }
@ -311,7 +335,21 @@ func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor {
} }
func (o *options) parseTLSConfig() error { func (o *options) parseTLSConfig() error {
if o.hasTLSConfig() { if (o.tlsConfig != nil) {
return nil
}
if !o.hasTLSConfig() {
if !o.secure {
return nil
}
cert, err := certs.New(o.address, "localhost", "127.0.0.1", o.name)
if err != nil {
return err
}
o.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}
return nil return nil
} }
caCert, err := ioutil.ReadFile(o.caCert) caCert, err := ioutil.ReadFile(o.caCert)
@ -327,15 +365,13 @@ func (o *options) parseTLSConfig() error {
if err != nil { if err != nil {
return err return err
} }
o.tlsConfig = tls.Config{ o.tlsConfig = &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caCertPool,
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
} }
return nil return nil
} }
func (o *options) hasTLSConfig() bool { func (o *options) hasTLSConfig() bool {
return reflect.DeepEqual(o.tlsConfig, tls.Config{}) return o.caCert != "" && o.cert != "" && o.key != "" && o.tlsConfig == nil
} }

View File

@ -11,6 +11,8 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go.uber.org/multierr" "go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
) )
type Service interface { type Service interface {
@ -60,6 +62,11 @@ func newService(opts ...Option) (*service, error) {
} }
} }
}() }()
var err error
s.list, err = net.Listen("tcp", s.opts.address)
if err != nil {
return nil, err
}
if err := s.opts.parseTLSConfig(); err != nil { if err := s.opts.parseTLSConfig(); err != nil {
return nil, err return nil, err
} }
@ -70,10 +77,19 @@ func newService(opts ...Option) (*service, error) {
} }
return s.run() return s.run()
} }
gopts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcmiddleware.ChainUnaryServer(s.opts.serverInterceptors...))} gopts := []grpc.ServerOption{
// TODO : check tls config and tls auth grpc.Creds(credentials.NewTLS(s.opts.tlsConfig)),
// grpc.Creds(credentials.NewTLS(&s.opts.tlsConfig)) grpc.UnaryInterceptor(
grpcmiddleware.ChainUnaryServer(s.opts.serverInterceptors...),
),
}
if s.opts.tlsConfig != nil {
gopts = append(gopts)
}
s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...)
if s.opts.reflection {
reflection.Register(s.server)
}
return s, nil return s, nil
} }
@ -100,13 +116,7 @@ func (s *service) run() error {
return err return err
} }
} }
var err error
s.running = true s.running = true
s.list, err = net.Listen("tcp", s.opts.address)
if err != nil {
s.mu.Unlock()
return err
}
s.opts.address = s.list.Addr().String() s.opts.address = s.list.Addr().String()
errs := make(chan error) errs := make(chan error)
go func() { go func() {