diff --git a/creds/peercreds/peercreds.go b/creds/peercreds/peercreds.go new file mode 100644 index 0000000..277cb91 --- /dev/null +++ b/creds/peercreds/peercreds.go @@ -0,0 +1,68 @@ +package peercreds + +import ( + "context" + "errors" + "net" + + "github.com/tailscale/peercred" + "google.golang.org/grpc/credentials" +) + +var _ credentials.TransportCredentials = (*peerCreds)(nil) + +var common = credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity} + +func New() credentials.TransportCredentials { + return &peerCreds{info: credentials.ProtocolInfo{ + SecurityProtocol: "peercred", + ProtocolVersion: "0.1", + }} +} + +type peerCreds struct { + info credentials.ProtocolInfo +} + +// AuthInfo we’ll attach to the gRPC peer +type AuthInfo struct { + credentials.CommonAuthInfo + Creds *peercred.Creds +} + +func (AuthInfo) AuthType() string { return "peercred" } + +func (t *peerCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return t.handshakeConn(conn) +} + +func (t *peerCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return t.handshakeConn(conn) +} + +func (t *peerCreds) Info() credentials.ProtocolInfo { + return t.info +} + +func (t *peerCreds) Clone() credentials.TransportCredentials { + return &peerCreds{info: t.info} +} + +func (t *peerCreds) OverrideServerName(name string) error { + t.info.ServerName = name + return nil +} + +func (t *peerCreds) handshakeConn(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if conn.RemoteAddr().Network() != "unix" { + return nil, nil, errors.New("peercred only works with unix domain sockets") + } + creds, err := peercred.Get(conn) + if err != nil { + if errors.Is(err, peercred.ErrNotImplemented) { + return nil, nil, errors.New("peercred not implemented on this OS") + } + return nil, nil, err + } + return conn, AuthInfo{Creds: creds, CommonAuthInfo: common}, nil +} diff --git a/go.mod b/go.mod index 694060d..2429415 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/spf13/cobra v1.3.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.10.0 + github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 github.com/traefik/grpc-web v0.16.0 github.com/uptrace/opentelemetry-go-extra/otellogrus v0.3.2 diff --git a/go.sum b/go.sum index 1808060..5af313e 100644 --- a/go.sum +++ b/go.sum @@ -506,6 +506,8 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 h1:LnC5Kc/wtumK+WB441p7ynQJzVuNRJiqddSIE3IlSEQ= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/traefik/grpc-web v0.16.0 h1:eeUWZaFg6ZU0I9dWOYE2D5qkNzRBmXzzuRlxdltascY= diff --git a/service/service.go b/service/service.go index 62b47d5..ed44b86 100644 --- a/service/service.go +++ b/service/service.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" greflect "google.golang.org/grpc/reflection" + "go.linka.cloud/grpc-toolkit/creds/peercreds" "go.linka.cloud/grpc-toolkit/interceptors/chain" "go.linka.cloud/grpc-toolkit/internal/injectlogger" "go.linka.cloud/grpc-toolkit/logger" @@ -138,6 +139,9 @@ func newService(opts ...Option) (*service, error) { grpc.StreamInterceptor(si), grpc.UnaryInterceptor(ui), } + if _, ok := s.opts.lis.(*net.UnixListener); ok || strings.HasPrefix(s.opts.address, "unix://") { + gopts = append(gopts, grpc.Creds(peercreds.New())) + } s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) if s.opts.reflection { greflect.Register(s.server)