add cors option

This commit is contained in:
Adphi 2021-09-30 16:56:51 +02:00
parent 89ebbee8dc
commit 0fd0a6ecc3
3 changed files with 42 additions and 12 deletions

View File

@ -13,8 +13,8 @@ var cmd = &cobra.Command{
const (
serverAddress = "server_address"
secure = "secure"
reflect = "reflect"
secure = "secure"
reflection = "reflection"
caCert = "ca_cert"
serverCert = "server_cert"
@ -31,9 +31,9 @@ func init() {
cmd.Flags().Bool(secure, true, "Generate self signed certificate if none provided [$SECURE]")
viper.BindPFlag(secure, cmd.Flags().Lookup(secure))
// reflect
cmd.Flags().Bool(reflect, false, "Enable gRPC reflection server [$REFLECT]")
viper.BindPFlag(reflect, cmd.Flags().Lookup(reflect))
// reflection
cmd.Flags().Bool(reflection, false, "Enable gRPC reflection server [$REFLECT]")
viper.BindPFlag(reflection, cmd.Flags().Lookup(reflection))
// ca_cert
cmd.Flags().String(caCert, "", "Path to Root CA certificate [$CA_CERT]")
@ -49,7 +49,7 @@ func init() {
func parseFlags(o *options) *options {
o.address = viper.GetString(serverAddress)
o.secure = viper.GetBool(secure)
o.reflection = viper.GetBool(reflect)
o.reflection = viper.GetBool(reflection)
o.caCert = viper.GetString(caCert)
o.cert = viper.GetString(serverCert)
o.key = viper.GetString(serverKey)

View File

@ -12,6 +12,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/jinzhu/gorm"
"github.com/rs/cors"
"go.uber.org/multierr"
"google.golang.org/grpc"
@ -82,8 +83,8 @@ type Options interface {
ClientInterceptors() []grpc.UnaryClientInterceptor
StreamClientInterceptors() []grpc.StreamClientInterceptor
// TODO(adphi): CORS for http handler
Cors() cors.Options
GRPCWeb() bool
GRPCWebPrefix() string
GRPCWebOpts() []grpcweb.Option
@ -92,7 +93,7 @@ type Options interface {
GatewayPrefix() string
GatewayOpts() []runtime.ServeMuxOption
// TODO(adphi): metrics
// TODO(adphi): metrics + tracing
Default()
}
@ -114,6 +115,7 @@ func (o *options) Default() {
if o.transport == nil {
o.transport = &grpc.Server{}
}
}
type Option func(*options)
@ -258,6 +260,12 @@ func WithSubscriberInterceptor(w ...interface{}) Option {
}
}
func WithCors(opts cors.Options) Option {
return func(o *options) {
o.cors = opts
}
}
func WithGRPCWeb(b bool) Option {
return func(o *options) {
o.grpcWeb = b
@ -329,9 +337,9 @@ type options struct {
grpcWeb bool
grpcWebOpts []grpcweb.Option
grpcWebPrefix string
gateway RegisterGatewayFunc
gatewayOpts []runtime.ServeMuxOption
cors cors.Options
error error
gatewayPrefix string
@ -421,6 +429,10 @@ func (o *options) StreamClientInterceptors() []grpc.StreamClientInterceptor {
return o.streamClientInterceptors
}
func (o *options) Cors() cors.Options {
return o.cors
}
func (o *options) GRPCWeb() bool {
return o.grpcWeb
}

View File

@ -8,6 +8,7 @@ import (
"net/http"
"os"
"os/signal"
"reflect"
"strings"
"sync"
"syscall"
@ -17,12 +18,13 @@ import (
"github.com/google/uuid"
grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/jinzhu/gorm"
"github.com/rs/cors"
"github.com/sirupsen/logrus"
"github.com/soheilhy/cmux"
"github.com/spf13/cobra"
"go.uber.org/multierr"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
greflect "google.golang.org/grpc/reflection"
"go.linka.cloud/grpc/registry"
"go.linka.cloud/grpc/registry/noop"
@ -116,7 +118,7 @@ func newService(opts ...Option) (*service, error) {
}
s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...)
if s.opts.reflection {
reflection.Register(s.server)
greflect.Register(s.server)
}
if err := s.gateway(s.opts.gatewayOpts...); err != nil {
return nil, err
@ -180,8 +182,24 @@ func (s *service) run() error {
errs := make(chan error, 3)
if reflect.DeepEqual(s.opts.cors, cors.Options{}) {
s.opts.cors = cors.Options{
AllowedHeaders: []string{"*"},
AllowedMethods: []string{
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodOptions,
http.MethodHead,
},
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
}
hServer := &http.Server{
Handler: s.mux,
Handler: cors.New(s.opts.cors).Handler(s.mux),
}
if s.opts.Gateway() || s.opts.grpcWeb {
go func() {