diff --git a/go.mod b/go.mod index 1bf6ca9..3d807ac 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/justinas/alice v1.2.0 github.com/miekg/dns v1.1.41 github.com/opentracing/opentracing-go v1.1.0 + github.com/pires/go-proxyproto v0.7.0 github.com/planetscale/vtprotobuf v0.4.0 github.com/prometheus/client_golang v1.14.0 github.com/rs/cors v1.7.0 diff --git a/go.sum b/go.sum index 9be2b12..957ad96 100644 --- a/go.sum +++ b/go.sum @@ -546,6 +546,8 @@ github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0 github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/service/options.go b/service/options.go index 2da6175..229042f 100644 --- a/service/options.go +++ b/service/options.go @@ -67,6 +67,9 @@ type Options interface { // TODO(adphi): metrics + tracing + WithoutCmux() bool + ProxyProtocol() bool + Default() } @@ -341,6 +344,13 @@ func WithoutCmux() Option { } } +func WithProxyProtocol(addrs ...string) Option { + return func(o *options) { + o.proxyProtocol = true + o.proxyProtocolAddrs = addrs + } +} + type options struct { ctx context.Context name string @@ -386,9 +396,11 @@ type options struct { reactUISubPath string hasReactUI bool - error error - gatewayPrefix string - withoutCmux bool + error error + gatewayPrefix string + withoutCmux bool + proxyProtocol bool + proxyProtocolAddrs []string } func (o *options) Name() string { @@ -511,6 +523,10 @@ func (o *options) WithoutCmux() bool { return o.withoutCmux } +func (o *options) ProxyProtocol() bool { + return o.proxyProtocol +} + func (o *options) parseTLSConfig() error { if o.tlsConfig != nil { return nil diff --git a/service/service.go b/service/service.go index c5d55f7..d4e4067 100644 --- a/service/service.go +++ b/service/service.go @@ -18,6 +18,7 @@ import ( "github.com/google/uuid" grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/justinas/alice" + "github.com/pires/go-proxyproto" "github.com/rs/cors" "github.com/soheilhy/cmux" "go.uber.org/multierr" @@ -172,6 +173,34 @@ func (s *service) start() (*errgroup.Group, error) { s.opts.address = s.opts.lis.Addr().String() } + if s.opts.proxyProtocol { + p := func(upstream net.Addr) (proxyproto.Policy, error) { + u, _, err := net.SplitHostPort(upstream.String()) + if err != nil { + return proxyproto.REJECT, err + } + ip := net.ParseIP(u) + if ip == nil { + return proxyproto.REJECT, fmt.Errorf("proxyproto: invalid IP address") + } + if ip.IsPrivate() || ip.IsLoopback() { + return proxyproto.USE, nil + } + return proxyproto.REJECT, nil + } + if len(s.opts.proxyProtocolAddrs) > 0 { + var err error + p, err = proxyproto.StrictWhiteListPolicy(s.opts.proxyProtocolAddrs) + if err != nil { + return nil, err + } + } + s.opts.lis = &proxyproto.Listener{ + Listener: s.opts.lis, + Policy: p, + } + } + for i := range s.opts.beforeStart { if err := s.opts.beforeStart[i](); err != nil { return nil, err