diff --git a/client/client.go b/client/client.go index 2d3e53d..7613df7 100644 --- a/client/client.go +++ b/client/client.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "net/url" "strings" "google.golang.org/grpc" @@ -19,7 +20,7 @@ type Client interface { } func New(opts ...Option) (Client, error) { - c := &client{opts: &options{}} + c := &client{opts: &options{dialOptions: []grpc.DialOption{grpc.WithContextDialer(dial)}}} for _, o := range opts { o(c.opts) } @@ -78,3 +79,32 @@ func (c *client) Invoke(ctx context.Context, method string, args interface{}, re func (c *client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { return c.cc.NewStream(ctx, desc, method, opts...) } + +func parseDialTarget(target string) (string, string) { + net := "tcp" + m1 := strings.Index(target, ":") + m2 := strings.Index(target, ":/") + // handle unix:addr which will fail with url.Parse + if m1 >= 0 && m2 < 0 { + if n := target[0:m1]; n == "unix" { + return n, target[m1+1:] + } + } + if strings.HasPrefix(target, `\\.\pipe\`) { + net = "pipe" + return net, target + } + if m2 >= 0 { + t, err := url.Parse(target) + if err != nil { + return net, target + } + scheme := t.Scheme + addr := t.Host + if scheme == "unix" { + addr += t.Path + } + return scheme, addr + } + return net, target +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..dfde524 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,26 @@ +package client + +import ( + "testing" +) + +func TestParseDialTarget(t *testing.T) { + tests := []struct { + input string + expectedNet string + expectedAddr string + }{ + {"tcp://localhost:50051", "tcp", "localhost:50051"}, + {"localhost:50051", "tcp", "localhost:50051"}, + {"unix:///tmp/socket", "unix", "/tmp/socket"}, + {"unix://C:/path/to/socket", "unix", "C:/path/to/socket"}, + {"unix:path/to/socket", "unix", "path/to/socket"}, + {`\\.\pipe\example`, "pipe", `\\.\pipe\example`}, + } + for _, test := range tests { + net, addr := parseDialTarget(test.input) + if net != test.expectedNet || addr != test.expectedAddr { + t.Errorf("parseDialTarget(%q) = (%q, %q); want (%q, %q)", test.input, net, addr, test.expectedNet, test.expectedAddr) + } + } +} diff --git a/client/dialer.go b/client/dialer.go new file mode 100644 index 0000000..d3aa49d --- /dev/null +++ b/client/dialer.go @@ -0,0 +1,12 @@ +//go:build !windows +package client + +import ( + "context" + "net" +) + +func dial(ctx context.Context, addr string) (net.Conn, error) { + network, address := parseDialTarget(addr) + return (&net.Dialer{}).DialContext(ctx, network, address) +} diff --git a/client/dialer_windows.go b/client/dialer_windows.go new file mode 100644 index 0000000..909b209 --- /dev/null +++ b/client/dialer_windows.go @@ -0,0 +1,18 @@ +//go:build windows + +package client + +import ( + "context" + "net" + + "github.com/Microsoft/go-winio" +) + +func dial(ctx context.Context, addr string) (net.Conn, error) { + network, address := parseDialTarget(addr) + if network == "pipe" { + return winio.DialPipeContext(ctx, address) + } + return (&net.Dialer{}).DialContext(ctx, network, address) +} diff --git a/creds/peercreds/peercreds.go b/creds/peercreds/peercreds.go index 277cb91..94494e3 100644 --- a/creds/peercreds/peercreds.go +++ b/creds/peercreds/peercreds.go @@ -2,15 +2,36 @@ package peercreds import ( "context" + "crypto/tls" "errors" "net" + "github.com/soheilhy/cmux" "github.com/tailscale/peercred" "google.golang.org/grpc/credentials" ) +var ErrUnsupportedConnType = peercred.ErrUnsupportedConnType + var _ credentials.TransportCredentials = (*peerCreds)(nil) +// Creds are the peer credentials. +type Creds struct { + pid int + uid string +} + +func (c *Creds) PID() (pid int, ok bool) { + return c.pid, c.pid != 0 +} + +// UserID returns the userid (or Windows SID) that owns the other side +// of the connection, if known. (ok is false if not known) +// The returned string is suitable to passing to os/user.LookupId. +func (c *Creds) UserID() (uid string, ok bool) { + return c.uid, c.uid != "" +} + var common = credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity} func New() credentials.TransportCredentials { @@ -27,12 +48,12 @@ type peerCreds struct { // AuthInfo we’ll attach to the gRPC peer type AuthInfo struct { credentials.CommonAuthInfo - Creds *peercred.Creds + Creds Creds } func (AuthInfo) AuthType() string { return "peercred" } -func (t *peerCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { +func (t *peerCreds) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { return t.handshakeConn(conn) } @@ -54,15 +75,30 @@ func (t *peerCreds) OverrideServerName(name string) error { } func (t *peerCreds) handshakeConn(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - if conn.RemoteAddr().Network() != "unix" { - return nil, nil, errors.New("peercred only works with unix domain sockets") + if conn.RemoteAddr().Network() != "unix" && conn.RemoteAddr().Network() != "pipe" { + return nil, nil, errors.New("peercred only works with unix domain sockets or Windows named pipes") } - creds, err := peercred.Get(conn) + inner := conn +unwrap: + for { + switch c := inner.(type) { + case *cmux.MuxConn: + inner = c.Conn + case *tls.Conn: + inner = c.NetConn() + default: + break unwrap + } + } + creds, err := Get(inner) if err != nil { if errors.Is(err, peercred.ErrNotImplemented) { return nil, nil, errors.New("peercred not implemented on this OS") } return nil, nil, err } - return conn, AuthInfo{Creds: creds, CommonAuthInfo: common}, nil + var c Creds + c.uid, _ = creds.UserID() + c.pid, _ = creds.PID() + return conn, AuthInfo{Creds: c, CommonAuthInfo: common}, nil } diff --git a/creds/peercreds/peercreds_unix.go b/creds/peercreds/peercreds_unix.go new file mode 100644 index 0000000..c10ba42 --- /dev/null +++ b/creds/peercreds/peercreds_unix.go @@ -0,0 +1,20 @@ +//go:build !windows + +package peercreds + +import ( + "net" + + "github.com/tailscale/peercred" +) + +func Get(conn net.Conn) (*Creds, error) { + creds, err := peercred.Get(conn) + if err != nil { + return nil, err + } + var c Creds + c.uid, _ = creds.UserID() + c.pid, _ = creds.PID() + return &c, nil +} diff --git a/creds/peercreds/peercreds_windows.go b/creds/peercreds/peercreds_windows.go new file mode 100644 index 0000000..3f2a2a4 --- /dev/null +++ b/creds/peercreds/peercreds_windows.go @@ -0,0 +1,134 @@ +//go:build windows + +package peercreds + +import ( + "fmt" + "net" + "reflect" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Get returns peer creds for the client connected to this server-side pipe +// connection. The conn must be a net.Conn returned from go-winio's ListenPipe. +func Get(conn net.Conn) (*Creds, error) { + if conn == nil { + return nil, ErrUnsupportedConnType + } + + h, err := winioPipeHandle(conn) + if err != nil { + return nil, err + } + + // Get client PID for this pipe instance. + var pid uint32 + if err := windows.GetNamedPipeClientProcessId(h, &pid); err != nil { + return nil, fmt.Errorf("GetNamedPipeClientProcessId: %w", err) + } + if pid == 0 { + return nil, fmt.Errorf("GetNamedPipeClientProcessId returned pid=0") + } + + // Open the client process with query rights. + const processQueryLimitedInfo = windows.PROCESS_QUERY_LIMITED_INFORMATION + ph, err := windows.OpenProcess(processQueryLimitedInfo, false, pid) + if err != nil { + return nil, fmt.Errorf("OpenProcess(%d): %w", pid, err) + } + defer windows.CloseHandle(ph) + + // Open the process token. + var token windows.Token + if err := windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &token); err != nil { + return nil, fmt.Errorf("OpenProcessToken: %w", err) + } + defer token.Close() + + // Get the token's user SID. + tu, err := token.GetTokenUser() + if err != nil { + return nil, fmt.Errorf("GetTokenUser: %w", err) + } + + return &Creds{ + uid: tu.User.Sid.String(), + pid: int(pid), + }, nil +} + +// winioPipeHandle digs the underlying syscall HANDLE out of a go-winio +// pipe connection using reflect + unsafe. This depends on the current +// internal layout of github.com/Microsoft/go-winio: +// +// type win32Pipe struct { +// *win32File +// path string +// } +// +// type win32MessageBytePipe struct { +// win32Pipe +// writeClosed bool +// readEOF bool +// } +// +// type win32File struct { +// handle syscall.Handle +// ... +// } +// +// See pipe.go + file.go in go-winio. :contentReference[oaicite:1]{index=1} +func winioPipeHandle(conn net.Conn) (windows.Handle, error) { + v := reflect.ValueOf(conn) + if !v.IsValid() { + return 0, ErrUnsupportedConnType + } + + // Peel off interface & pointer layers: net.Conn is an interface and the + // concrete type is *win32Pipe or *win32MessageBytePipe. + for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr { + if v.IsNil() { + return 0, ErrUnsupportedConnType + } + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return 0, ErrUnsupportedConnType + } + + var wfField reflect.Value + + // Case 1: *win32Pipe { *win32File; path string } + if f := v.FieldByName("win32File"); f.IsValid() && f.Kind() == reflect.Ptr { + wfField = f + } else if v.NumField() > 0 { + // Case 2: *win32MessageBytePipe { win32Pipe; ... } + embedded := v.Field(0) + if embedded.IsValid() && embedded.Kind() == reflect.Struct { + if f2 := embedded.FieldByName("win32File"); f2.IsValid() && f2.Kind() == reflect.Ptr { + wfField = f2 + } + } + } + + if !wfField.IsValid() || wfField.IsNil() { + return 0, ErrUnsupportedConnType + } + + // wfField is a *win32File. Its first field is "handle syscall.Handle". + // We only need the first field, so we define a 1-field header type with + // compatible layout and reinterpret the pointer. + type win32FileHeader struct { + Handle windows.Handle // same underlying type as syscall.Handle + } + + ptr := unsafe.Pointer(wfField.Pointer()) + h := (*win32FileHeader)(ptr).Handle + if h == 0 { + return 0, fmt.Errorf("winio pipe handle is 0") + } + return h, nil +} diff --git a/example/example.go b/example/example.go index 3c5f306..0a56b0a 100644 --- a/example/example.go +++ b/example/example.go @@ -7,6 +7,8 @@ import ( "io" "net" "net/http" + "os" + "runtime" "strings" "time" @@ -51,7 +53,12 @@ func run(ctx context.Context, opts ...service.Option) { ) defer p.Shutdown(ctx) - address := "0.0.0.0:9991" + // address := "0.0.0.0:9991" + address := "unix:///tmp/example.sock" + if runtime.GOOS == "windows" { + address = `\\.\pipe\example` + } + defer os.Remove("/tmp/example.sock") var svc service.Service opts = append(opts, @@ -96,7 +103,7 @@ func run(ctx context.Context, opts ...service.Option) { copts := []client.Option{ // client.WithName(name), // client.WithVersion(version), - client.WithAddress("localhost:9991"), + client.WithAddress(address), // client.WithRegistry(mdns.NewRegistry()), client.WithSecure(secure), client.WithInterceptors(tracing.NewClientInterceptors()), diff --git a/example/go.mod b/example/go.mod index bdf31b2..dce9d3b 100644 --- a/example/go.mod +++ b/example/go.mod @@ -32,6 +32,7 @@ require ( require ( cloud.google.com/go/compute/metadata v0.6.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bombsimon/logrusr/v4 v4.1.0 // indirect github.com/bufbuild/protocompile v0.14.1 // indirect @@ -63,6 +64,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 // indirect github.com/traefik/grpc-web v0.16.0 // indirect github.com/uptrace/opentelemetry-go-extra/otellogrus v0.3.2 // indirect diff --git a/example/go.sum b/example/go.sum index 7f94834..41de522 100644 --- a/example/go.sum +++ b/example/go.sum @@ -35,6 +35,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -326,6 +328,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 h1:6fotK7otjonDflCTK0BCfls4SPy3NcCVb5dqqmbRknE= github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75/go.mod h1:KO6IkyS8Y3j8OdNO85qEYBsRPuteD+YciPomcXdrMnk= github.com/traefik/grpc-web v0.16.0 h1:eeUWZaFg6ZU0I9dWOYE2D5qkNzRBmXzzuRlxdltascY= diff --git a/example/server.go b/example/server.go index d3201d6..c197232 100644 --- a/example/server.go +++ b/example/server.go @@ -7,6 +7,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" + "google.golang.org/grpc/peer" "go.linka.cloud/grpc-tookit/example/pb" "go.linka.cloud/grpc-toolkit/interceptors/iface" @@ -30,6 +31,11 @@ func (g *GreeterHandler) SayHello(ctx context.Context, req *pb.HelloRequest) (*p span := trace.SpanFromContext(ctx) span.SetAttributes(attribute.String("name", req.Name)) logger.C(ctx).Infof("replying to %s", req.Name) + if p, ok := peer.FromContext(ctx); ok { + logger.C(ctx).Infof("peer auth info: %+v", p.AuthInfo) + } else { + logger.C(ctx).Infof("no peer info") + } return &pb.HelloReply{Message: hello(req.Name)}, nil } diff --git a/example/service.go b/example/service.go index a0fe7ce..1cb8b7c 100644 --- a/example/service.go +++ b/example/service.go @@ -27,15 +27,12 @@ func newService(ctx context.Context, opts ...service.Option) (service.Service, e log := logger.C(ctx) metrics := metrics2.NewInterceptors(metrics2.WithExemplarFromContext(metrics2.DefaultExemplarFromCtx)) - address := "0.0.0.0:9991" - var svc service.Service opts = append(opts, service.WithContext(ctx), - service.WithAddress(address), // service.WithRegistry(mdns.NewRegistry()), service.WithReflection(true), - service.WithoutCmux(), + // service.WithoutCmux(), service.WithGateway(pb.RegisterGreeterHandler), service.WithGatewayPrefix("/rest"), service.WithGRPCWeb(true), diff --git a/go.mod b/go.mod index 2429415..3e50632 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 toolchain go1.24.3 require ( + github.com/Microsoft/go-winio v0.6.2 github.com/alta/protopatch v0.5.3 github.com/bombsimon/logrusr/v4 v4.0.0 github.com/caitlinelfring/go-env-default v1.1.0 @@ -56,6 +57,7 @@ require ( go.uber.org/multierr v1.7.0 golang.org/x/net v0.40.0 golang.org/x/sync v0.14.0 + golang.org/x/sys v0.33.0 google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237 google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 google.golang.org/grpc v1.72.1 @@ -96,7 +98,6 @@ require ( go.opentelemetry.io/proto/otlp v1.6.0 // indirect go.uber.org/atomic v1.7.0 // indirect golang.org/x/mod v0.21.0 // indirect - golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.25.0 // indirect golang.org/x/tools v0.26.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 5af313e..744b51e 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= diff --git a/service/listen.go b/service/listen.go new file mode 100644 index 0000000..a77d45f --- /dev/null +++ b/service/listen.go @@ -0,0 +1,11 @@ +//go:build !windows + +package service + +import ( + "net" +) + +func listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} diff --git a/service/listen_windows.go b/service/listen_windows.go new file mode 100644 index 0000000..8bff049 --- /dev/null +++ b/service/listen_windows.go @@ -0,0 +1,14 @@ +package service + +import ( + "net" + + "github.com/Microsoft/go-winio" +) + +func listen(network, address string) (net.Listener, error) { + if network == "pipe" { + return winio.ListenPipe(address, nil) + } + return net.Listen(network, address) +} diff --git a/service/service.go b/service/service.go index ed44b86..45bc9ba 100644 --- a/service/service.go +++ b/service/service.go @@ -139,7 +139,7 @@ func newService(opts ...Option) (*service, error) { grpc.StreamInterceptor(si), grpc.UnaryInterceptor(ui), } - if _, ok := s.opts.lis.(*net.UnixListener); ok || strings.HasPrefix(s.opts.address, "unix://") { + if _, ok := s.opts.lis.(*net.UnixListener); ok || strings.HasPrefix(s.opts.address, "unix://") || strings.HasPrefix(s.opts.address, `\\.\pipe\`) { gopts = append(gopts, grpc.Creds(peercreds.New())) } s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...) @@ -179,9 +179,12 @@ func (s *service) start() (*errgroup.Group, error) { network = "unix" s.opts.address = strings.TrimPrefix(s.opts.address, "unix://") } + if strings.HasPrefix(s.opts.address, `\\.\pipe\`) { + network = "pipe" + } if s.opts.lis == nil { - lis, err := net.Listen(network, s.opts.address) + lis, err := listen(network, s.opts.address) if err != nil { return nil, err }