diff --git a/service/command.go b/service/command.go index 4228c0e..28e2054 100644 --- a/service/command.go +++ b/service/command.go @@ -13,8 +13,8 @@ var cmd = &cobra.Command{ const ( serverAddress = "server_address" - secure = "secure" - reflect = "reflect" + secure = "secure" + reflection = "reflection" caCert = "ca_cert" serverCert = "server_cert" @@ -31,9 +31,9 @@ func init() { cmd.Flags().Bool(secure, true, "Generate self signed certificate if none provided [$SECURE]") viper.BindPFlag(secure, cmd.Flags().Lookup(secure)) - // reflect - cmd.Flags().Bool(reflect, false, "Enable gRPC reflection server [$REFLECT]") - viper.BindPFlag(reflect, cmd.Flags().Lookup(reflect)) + // reflection + cmd.Flags().Bool(reflection, false, "Enable gRPC reflection server [$REFLECT]") + viper.BindPFlag(reflection, cmd.Flags().Lookup(reflection)) // ca_cert cmd.Flags().String(caCert, "", "Path to Root CA certificate [$CA_CERT]") @@ -49,7 +49,7 @@ func init() { func parseFlags(o *options) *options { o.address = viper.GetString(serverAddress) o.secure = viper.GetBool(secure) - o.reflection = viper.GetBool(reflect) + o.reflection = viper.GetBool(reflection) o.caCert = viper.GetString(caCert) o.cert = viper.GetString(serverCert) o.key = viper.GetString(serverKey) diff --git a/service/options.go b/service/options.go index 5a233a1..6494a9b 100644 --- a/service/options.go +++ b/service/options.go @@ -12,6 +12,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/jinzhu/gorm" + "github.com/rs/cors" "go.uber.org/multierr" "google.golang.org/grpc" @@ -82,8 +83,8 @@ type Options interface { ClientInterceptors() []grpc.UnaryClientInterceptor StreamClientInterceptors() []grpc.StreamClientInterceptor - // TODO(adphi): CORS for http handler + Cors() cors.Options GRPCWeb() bool GRPCWebPrefix() string GRPCWebOpts() []grpcweb.Option @@ -92,7 +93,7 @@ type Options interface { GatewayPrefix() string GatewayOpts() []runtime.ServeMuxOption - // TODO(adphi): metrics + // TODO(adphi): metrics + tracing Default() } @@ -114,6 +115,7 @@ func (o *options) Default() { if o.transport == nil { o.transport = &grpc.Server{} } + } type Option func(*options) @@ -258,6 +260,12 @@ func WithSubscriberInterceptor(w ...interface{}) Option { } } +func WithCors(opts cors.Options) Option { + return func(o *options) { + o.cors = opts + } +} + func WithGRPCWeb(b bool) Option { return func(o *options) { o.grpcWeb = b @@ -329,9 +337,9 @@ type options struct { grpcWeb bool grpcWebOpts []grpcweb.Option grpcWebPrefix string - gateway RegisterGatewayFunc gatewayOpts []runtime.ServeMuxOption + cors cors.Options error error gatewayPrefix string @@ -421,6 +429,10 @@ func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor { return o.streamClientInterceptors } +func (o *options) Cors() cors.Options { + return o.cors +} + func (o *options) GRPCWeb() bool { return o.grpcWeb } diff --git a/service/service.go b/service/service.go index dee22cc..cf58a65 100644 --- a/service/service.go +++ b/service/service.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "os/signal" + "reflect" "strings" "sync" "syscall" @@ -17,12 +18,13 @@ import ( "github.com/google/uuid" grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/jinzhu/gorm" + "github.com/rs/cors" "github.com/sirupsen/logrus" "github.com/soheilhy/cmux" "github.com/spf13/cobra" "go.uber.org/multierr" "google.golang.org/grpc" - "google.golang.org/grpc/reflection" + greflect "google.golang.org/grpc/reflection" "go.linka.cloud/grpc/registry" "go.linka.cloud/grpc/registry/noop" @@ -116,7 +118,7 @@ func newService(opts ...Option) (*service, error) { } s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) if s.opts.reflection { - reflection.Register(s.server) + greflect.Register(s.server) } if err := s.gateway(s.opts.gatewayOpts...); err != nil { return nil, err @@ -180,8 +182,24 @@ func (s *service) run() error { errs := make(chan error, 3) + if reflect.DeepEqual(s.opts.cors, cors.Options{}) { + s.opts.cors = cors.Options{ + AllowedHeaders: []string{"*"}, + AllowedMethods: []string{ + http.MethodGet, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + http.MethodOptions, + http.MethodHead, + }, + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + } + } hServer := &http.Server{ - Handler: s.mux, + Handler: cors.New(s.opts.cors).Handler(s.mux), } if s.opts.Gateway() || s.opts.grpcWeb { go func() {