add cors option

This commit is contained in:
Adphi 2021-09-30 16:56:51 +02:00
parent 89ebbee8dc
commit 0fd0a6ecc3
3 changed files with 42 additions and 12 deletions

View File

@ -14,7 +14,7 @@ const (
serverAddress = "server_address" serverAddress = "server_address"
secure = "secure" secure = "secure"
reflect = "reflect" reflection = "reflection"
caCert = "ca_cert" caCert = "ca_cert"
serverCert = "server_cert" serverCert = "server_cert"
@ -31,9 +31,9 @@ func init() {
cmd.Flags().Bool(secure, true, "Generate self signed certificate if none provided [$SECURE]") cmd.Flags().Bool(secure, true, "Generate self signed certificate if none provided [$SECURE]")
viper.BindPFlag(secure, cmd.Flags().Lookup(secure)) viper.BindPFlag(secure, cmd.Flags().Lookup(secure))
// reflect // reflection
cmd.Flags().Bool(reflect, false, "Enable gRPC reflection server [$REFLECT]") cmd.Flags().Bool(reflection, false, "Enable gRPC reflection server [$REFLECT]")
viper.BindPFlag(reflect, cmd.Flags().Lookup(reflect)) viper.BindPFlag(reflection, cmd.Flags().Lookup(reflection))
// ca_cert // ca_cert
cmd.Flags().String(caCert, "", "Path to Root CA certificate [$CA_CERT]") cmd.Flags().String(caCert, "", "Path to Root CA certificate [$CA_CERT]")
@ -49,7 +49,7 @@ func init() {
func parseFlags(o *options) *options { func parseFlags(o *options) *options {
o.address = viper.GetString(serverAddress) o.address = viper.GetString(serverAddress)
o.secure = viper.GetBool(secure) o.secure = viper.GetBool(secure)
o.reflection = viper.GetBool(reflect) o.reflection = viper.GetBool(reflection)
o.caCert = viper.GetString(caCert) o.caCert = viper.GetString(caCert)
o.cert = viper.GetString(serverCert) o.cert = viper.GetString(serverCert)
o.key = viper.GetString(serverKey) o.key = viper.GetString(serverKey)

View File

@ -12,6 +12,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/rs/cors"
"go.uber.org/multierr" "go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -82,8 +83,8 @@ type Options interface {
ClientInterceptors() []grpc.UnaryClientInterceptor ClientInterceptors() []grpc.UnaryClientInterceptor
StreamClientInterceptors() []grpc.StreamClientInterceptor StreamClientInterceptors() []grpc.StreamClientInterceptor
// TODO(adphi): CORS for http handler
Cors() cors.Options
GRPCWeb() bool GRPCWeb() bool
GRPCWebPrefix() string GRPCWebPrefix() string
GRPCWebOpts() []grpcweb.Option GRPCWebOpts() []grpcweb.Option
@ -92,7 +93,7 @@ type Options interface {
GatewayPrefix() string GatewayPrefix() string
GatewayOpts() []runtime.ServeMuxOption GatewayOpts() []runtime.ServeMuxOption
// TODO(adphi): metrics // TODO(adphi): metrics + tracing
Default() Default()
} }
@ -114,6 +115,7 @@ func (o *options) Default() {
if o.transport == nil { if o.transport == nil {
o.transport = &grpc.Server{} o.transport = &grpc.Server{}
} }
} }
type Option func(*options) type Option func(*options)
@ -258,6 +260,12 @@ func WithSubscriberInterceptor(w ...interface{}) Option {
} }
} }
func WithCors(opts cors.Options) Option {
return func(o *options) {
o.cors = opts
}
}
func WithGRPCWeb(b bool) Option { func WithGRPCWeb(b bool) Option {
return func(o *options) { return func(o *options) {
o.grpcWeb = b o.grpcWeb = b
@ -329,9 +337,9 @@ type options struct {
grpcWeb bool grpcWeb bool
grpcWebOpts []grpcweb.Option grpcWebOpts []grpcweb.Option
grpcWebPrefix string grpcWebPrefix string
gateway RegisterGatewayFunc gateway RegisterGatewayFunc
gatewayOpts []runtime.ServeMuxOption gatewayOpts []runtime.ServeMuxOption
cors cors.Options
error error error error
gatewayPrefix string gatewayPrefix string
@ -421,6 +429,10 @@ func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor {
return o.streamClientInterceptors return o.streamClientInterceptors
} }
func (o *options) Cors() cors.Options {
return o.cors
}
func (o *options) GRPCWeb() bool { func (o *options) GRPCWeb() bool {
return o.grpcWeb return o.grpcWeb
} }

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"reflect"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@ -17,12 +18,13 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/rs/cors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/soheilhy/cmux" "github.com/soheilhy/cmux"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go.uber.org/multierr" "go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/reflection" greflect "google.golang.org/grpc/reflection"
"go.linka.cloud/grpc/registry" "go.linka.cloud/grpc/registry"
"go.linka.cloud/grpc/registry/noop" "go.linka.cloud/grpc/registry/noop"
@ -116,7 +118,7 @@ func newService(opts ...Option) (*service, error) {
} }
s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...)
if s.opts.reflection { if s.opts.reflection {
reflection.Register(s.server) greflect.Register(s.server)
} }
if err := s.gateway(s.opts.gatewayOpts...); err != nil { if err := s.gateway(s.opts.gatewayOpts...); err != nil {
return nil, err return nil, err
@ -180,8 +182,24 @@ func (s *service) run() error {
errs := make(chan error, 3) errs := make(chan error, 3)
if reflect.DeepEqual(s.opts.cors, cors.Options{}) {
s.opts.cors = cors.Options{
AllowedHeaders: []string{"*"},
AllowedMethods: []string{
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodOptions,
http.MethodHead,
},
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
}
hServer := &http.Server{ hServer := &http.Server{
Handler: s.mux, Handler: cors.New(s.opts.cors).Handler(s.mux),
} }
if s.opts.Gateway() || s.opts.grpcWeb { if s.opts.Gateway() || s.opts.grpcWeb {
go func() { go func() {