4 Commits
v0.4.4 ... main

19 changed files with 495 additions and 117 deletions

1
.gitignore vendored
View File

@@ -2,3 +2,4 @@
.bin
/tmp
diff
.DS_Store

View File

@@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"net/url"
"strings"
"google.golang.org/grpc"
@@ -19,7 +20,7 @@ type Client interface {
}
func New(opts ...Option) (Client, error) {
c := &client{opts: &options{}}
c := &client{opts: &options{dialOptions: []grpc.DialOption{grpc.WithContextDialer(dial)}}}
for _, o := range opts {
o(c.opts)
}
@@ -78,3 +79,32 @@ func (c *client) Invoke(ctx context.Context, method string, args interface{}, re
func (c *client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return c.cc.NewStream(ctx, desc, method, opts...)
}
func parseDialTarget(target string) (string, string) {
net := "tcp"
m1 := strings.Index(target, ":")
m2 := strings.Index(target, ":/")
// handle unix:addr which will fail with url.Parse
if m1 >= 0 && m2 < 0 {
if n := target[0:m1]; n == "unix" {
return n, target[m1+1:]
}
}
if strings.HasPrefix(target, `\\.\pipe\`) {
net = "pipe"
return net, target
}
if m2 >= 0 {
t, err := url.Parse(target)
if err != nil {
return net, target
}
scheme := t.Scheme
addr := t.Host
if scheme == "unix" {
addr += t.Path
}
return scheme, addr
}
return net, target
}

26
client/client_test.go Normal file
View File

@@ -0,0 +1,26 @@
package client
import (
"testing"
)
func TestParseDialTarget(t *testing.T) {
tests := []struct {
input string
expectedNet string
expectedAddr string
}{
{"tcp://localhost:50051", "tcp", "localhost:50051"},
{"localhost:50051", "tcp", "localhost:50051"},
{"unix:///tmp/socket", "unix", "/tmp/socket"},
{"unix://C:/path/to/socket", "unix", "C:/path/to/socket"},
{"unix:path/to/socket", "unix", "path/to/socket"},
{`\\.\pipe\example`, "pipe", `\\.\pipe\example`},
}
for _, test := range tests {
net, addr := parseDialTarget(test.input)
if net != test.expectedNet || addr != test.expectedAddr {
t.Errorf("parseDialTarget(%q) = (%q, %q); want (%q, %q)", test.input, net, addr, test.expectedNet, test.expectedAddr)
}
}
}

12
client/dialer.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build !windows
package client
import (
"context"
"net"
)
func dial(ctx context.Context, addr string) (net.Conn, error) {
network, address := parseDialTarget(addr)
return (&net.Dialer{}).DialContext(ctx, network, address)
}

18
client/dialer_windows.go Normal file
View File

@@ -0,0 +1,18 @@
//go:build windows
package client
import (
"context"
"net"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
func dial(ctx context.Context, addr string) (net.Conn, error) {
network, address := parseDialTarget(addr)
if network == "pipe" {
return namedpipe.DialContext(ctx, address)
}
return (&net.Dialer{}).DialContext(ctx, network, address)
}

View File

@@ -2,15 +2,36 @@ package peercreds
import (
"context"
"crypto/tls"
"errors"
"net"
"github.com/soheilhy/cmux"
"github.com/tailscale/peercred"
"google.golang.org/grpc/credentials"
)
var ErrUnsupportedConnType = peercred.ErrUnsupportedConnType
var _ credentials.TransportCredentials = (*peerCreds)(nil)
// Creds are the peer credentials.
type Creds struct {
pid int
uid string
}
func (c *Creds) PID() (pid int, ok bool) {
return c.pid, c.pid != 0
}
// UserID returns the userid (or Windows SID) that owns the other side
// of the connection, if known. (ok is false if not known)
// The returned string is suitable to passing to os/user.LookupId.
func (c *Creds) UserID() (uid string, ok bool) {
return c.uid, c.uid != ""
}
var common = credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}
func New() credentials.TransportCredentials {
@@ -27,12 +48,12 @@ type peerCreds struct {
// AuthInfo well attach to the gRPC peer
type AuthInfo struct {
credentials.CommonAuthInfo
Creds *peercred.Creds
Creds Creds
}
func (AuthInfo) AuthType() string { return "peercred" }
func (t *peerCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
func (t *peerCreds) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return t.handshakeConn(conn)
}
@@ -54,15 +75,30 @@ func (t *peerCreds) OverrideServerName(name string) error {
}
func (t *peerCreds) handshakeConn(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if conn.RemoteAddr().Network() != "unix" {
return nil, nil, errors.New("peercred only works with unix domain sockets")
if conn.RemoteAddr().Network() != "unix" && conn.RemoteAddr().Network() != "pipe" {
return nil, nil, errors.New("peercred only works with unix domain sockets or Windows named pipes")
}
creds, err := peercred.Get(conn)
inner := conn
unwrap:
for {
switch c := inner.(type) {
case *cmux.MuxConn:
inner = c.Conn
case *tls.Conn:
inner = c.NetConn()
default:
break unwrap
}
}
creds, err := Get(inner)
if err != nil {
if errors.Is(err, peercred.ErrNotImplemented) {
return nil, nil, errors.New("peercred not implemented on this OS")
}
return nil, nil, err
}
return conn, AuthInfo{Creds: creds, CommonAuthInfo: common}, nil
var c Creds
c.uid, _ = creds.UserID()
c.pid, _ = creds.PID()
return conn, AuthInfo{Creds: c, CommonAuthInfo: common}, nil
}

View File

@@ -0,0 +1,20 @@
//go:build !windows
package peercreds
import (
"net"
"github.com/tailscale/peercred"
)
func Get(conn net.Conn) (*Creds, error) {
creds, err := peercred.Get(conn)
if err != nil {
return nil, err
}
var c Creds
c.uid, _ = creds.UserID()
c.pid, _ = creds.PID()
return &c, nil
}

View File

@@ -0,0 +1,146 @@
//go:build windows
package peercreds
import (
"fmt"
"net"
"reflect"
"unsafe"
"golang.org/x/sys/windows"
)
// Get returns peer creds for the client connected to this server-side pipe
// connection. The conn must be a net.Conn returned from go-winio's ListenPipe.
func Get(conn net.Conn) (*Creds, error) {
if conn == nil {
return nil, ErrUnsupportedConnType
}
h, err := pipeHandle(conn)
if err != nil {
return nil, err
}
// Get client PID for this pipe instance.
var pid uint32
if err := windows.GetNamedPipeClientProcessId(h, &pid); err != nil {
return nil, fmt.Errorf("GetNamedPipeClientProcessId: %w", err)
}
if pid == 0 {
return nil, fmt.Errorf("GetNamedPipeClientProcessId returned pid=0")
}
// Open the client process with query rights.
const processQueryLimitedInfo = windows.PROCESS_QUERY_LIMITED_INFORMATION
ph, err := windows.OpenProcess(processQueryLimitedInfo, false, pid)
if err != nil {
return nil, fmt.Errorf("OpenProcess(%d): %w", pid, err)
}
defer windows.CloseHandle(ph)
// Open the process token.
var token windows.Token
if err := windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &token); err != nil {
return nil, fmt.Errorf("OpenProcessToken: %w", err)
}
defer token.Close()
// Get the token's user SID.
tu, err := token.GetTokenUser()
if err != nil {
return nil, fmt.Errorf("GetTokenUser: %w", err)
}
return &Creds{
uid: tu.User.Sid.String(),
pid: int(pid),
}, nil
}
type handle interface {
Handle() windows.Handle
}
func pipeHandle(conn net.Conn) (windows.Handle, error) {
if c, ok := conn.(handle); ok {
return c.Handle(), nil
}
return winioPipeHandle(conn)
}
// winioPipeHandle digs the underlying syscall HANDLE out of a go-winio
// pipe connection using reflect + unsafe. This depends on the current
// internal layout of github.com/Microsoft/go-winio:
//
// type win32Pipe struct {
// *win32File
// path string
// }
//
// type win32MessageBytePipe struct {
// win32Pipe
// writeClosed bool
// readEOF bool
// }
//
// type win32File struct {
// handle syscall.Handle
// ...
// }
//
// See pipe.go + file.go in go-winio. :contentReference[oaicite:1]{index=1}
func winioPipeHandle(conn net.Conn) (windows.Handle, error) {
v := reflect.ValueOf(conn)
if !v.IsValid() {
return 0, ErrUnsupportedConnType
}
// Peel off interface & pointer layers: net.Conn is an interface and the
// concrete type is *win32Pipe or *win32MessageBytePipe.
for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr {
if v.IsNil() {
return 0, ErrUnsupportedConnType
}
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return 0, ErrUnsupportedConnType
}
var wfField reflect.Value
// Case 1: *win32Pipe { *win32File; path string }
if f := v.FieldByName("win32File"); f.IsValid() && f.Kind() == reflect.Ptr {
wfField = f
} else if v.NumField() > 0 {
// Case 2: *win32MessageBytePipe { win32Pipe; ... }
embedded := v.Field(0)
if embedded.IsValid() && embedded.Kind() == reflect.Struct {
if f2 := embedded.FieldByName("win32File"); f2.IsValid() && f2.Kind() == reflect.Ptr {
wfField = f2
}
}
}
if !wfField.IsValid() || wfField.IsNil() {
return 0, ErrUnsupportedConnType
}
// wfField is a *win32File. Its first field is "handle syscall.Handle".
// We only need the first field, so we define a 1-field header type with
// compatible layout and reinterpret the pointer.
type win32FileHeader struct {
Handle windows.Handle // same underlying type as syscall.Handle
}
ptr := unsafe.Pointer(wfField.Pointer())
h := (*win32FileHeader)(ptr).Handle
if h == 0 {
return 0, fmt.Errorf("winio pipe handle is 0")
}
return h, nil
}

View File

@@ -7,6 +7,8 @@ import (
"io"
"net"
"net/http"
"os"
"runtime"
"strings"
"time"
@@ -51,7 +53,12 @@ func run(ctx context.Context, opts ...service.Option) {
)
defer p.Shutdown(ctx)
address := "0.0.0.0:9991"
// address := "0.0.0.0:9991"
address := "unix:///tmp/example.sock"
if runtime.GOOS == "windows" {
address = `\\.\pipe\example`
}
defer os.Remove("/tmp/example.sock")
var svc service.Service
opts = append(opts,
@@ -96,7 +103,7 @@ func run(ctx context.Context, opts ...service.Option) {
copts := []client.Option{
// client.WithName(name),
// client.WithVersion(version),
client.WithAddress("localhost:9991"),
client.WithAddress(address),
// client.WithRegistry(mdns.NewRegistry()),
client.WithSecure(secure),
client.WithInterceptors(tracing.NewClientInterceptors()),

View File

@@ -63,6 +63,7 @@ require (
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect
github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 // indirect
github.com/traefik/grpc-web v0.16.0 // indirect
github.com/uptrace/opentelemetry-go-extra/otellogrus v0.3.2 // indirect
@@ -86,6 +87,7 @@ require (
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect
nhooyr.io/websocket v1.8.17 // indirect
)

View File

@@ -326,6 +326,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA=
github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc=
github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 h1:6fotK7otjonDflCTK0BCfls4SPy3NcCVb5dqqmbRknE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75/go.mod h1:KO6IkyS8Y3j8OdNO85qEYBsRPuteD+YciPomcXdrMnk=
github.com/traefik/grpc-web v0.16.0 h1:eeUWZaFg6ZU0I9dWOYE2D5qkNzRBmXzzuRlxdltascY=
@@ -617,6 +619,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=

View File

@@ -7,6 +7,7 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"go.linka.cloud/grpc-tookit/example/pb"
"go.linka.cloud/grpc-toolkit/interceptors/iface"
@@ -30,6 +31,11 @@ func (g *GreeterHandler) SayHello(ctx context.Context, req *pb.HelloRequest) (*p
span := trace.SpanFromContext(ctx)
span.SetAttributes(attribute.String("name", req.Name))
logger.C(ctx).Infof("replying to %s", req.Name)
if p, ok := peer.FromContext(ctx); ok {
logger.C(ctx).Infof("peer auth info: %+v", p.AuthInfo)
} else {
logger.C(ctx).Infof("no peer info")
}
return &pb.HelloReply{Message: hello(req.Name)}, nil
}

View File

@@ -27,15 +27,12 @@ func newService(ctx context.Context, opts ...service.Option) (service.Service, e
log := logger.C(ctx)
metrics := metrics2.NewInterceptors(metrics2.WithExemplarFromContext(metrics2.DefaultExemplarFromCtx))
address := "0.0.0.0:9991"
var svc service.Service
opts = append(opts,
service.WithContext(ctx),
service.WithAddress(address),
// service.WithRegistry(mdns.NewRegistry()),
service.WithReflection(true),
service.WithoutCmux(),
// service.WithoutCmux(),
service.WithGateway(pb.RegisterGreeterHandler),
service.WithGatewayPrefix("/rest"),
service.WithGRPCWeb(true),

5
go.mod
View File

@@ -1,6 +1,6 @@
module go.linka.cloud/grpc-toolkit
go 1.23.0
go 1.23.1
toolchain go1.24.3
@@ -56,6 +56,8 @@ require (
go.uber.org/multierr v1.7.0
golang.org/x/net v0.40.0
golang.org/x/sync v0.14.0
golang.org/x/sys v0.33.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237
google.golang.org/grpc v1.72.1
@@ -96,7 +98,6 @@ require (
go.opentelemetry.io/proto/otlp v1.6.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
golang.org/x/tools v0.26.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

2
go.sum
View File

@@ -904,6 +904,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
@@ -46,7 +47,7 @@ func New() Logger {
return &logger{fl: logrus.New()}
}
func FromLogrus(fl logrus.Ext1FieldLogger) Logger {
func FromLogrus(fl LogrusLogger) Logger {
return &logger{fl: fl}
}
@@ -56,9 +57,10 @@ type Logger interface {
WithContext(ctx context.Context) Logger
WithReportCaller(b bool, depth ...uint) Logger
WithOffset(n int) Logger
WithField(key string, value interface{}) Logger
WithFields(kv ...interface{}) Logger
WithField(key string, value any) Logger
WithFields(kv ...any) Logger
WithError(err error) Logger
SetLevel(level Level) Logger
@@ -66,35 +68,35 @@ type Logger interface {
SetOutput(w io.Writer) Logger
Tracef(format string, args ...interface{})
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Printf(format string, args ...interface{})
Warnf(format string, args ...interface{})
Warningf(format string, args ...interface{})
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
Panicf(format string, args ...interface{})
Tracef(format string, args ...any)
Debugf(format string, args ...any)
Infof(format string, args ...any)
Printf(format string, args ...any)
Warnf(format string, args ...any)
Warningf(format string, args ...any)
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
Panicf(format string, args ...any)
Trace(args ...interface{})
Debug(args ...interface{})
Info(args ...interface{})
Print(args ...interface{})
Warn(args ...interface{})
Warning(args ...interface{})
Error(args ...interface{})
Fatal(args ...interface{})
Panic(args ...interface{})
Trace(args ...any)
Debug(args ...any)
Info(args ...any)
Print(args ...any)
Warn(args ...any)
Warning(args ...any)
Error(args ...any)
Fatal(args ...any)
Panic(args ...any)
Traceln(args ...interface{})
Debugln(args ...interface{})
Infoln(args ...interface{})
Println(args ...interface{})
Warnln(args ...interface{})
Warningln(args ...interface{})
Errorln(args ...interface{})
Fatalln(args ...interface{})
Panicln(args ...interface{})
Traceln(args ...any)
Debugln(args ...any)
Infoln(args ...any)
Println(args ...any)
Warnln(args ...any)
Warningln(args ...any)
Errorln(args ...any)
Fatalln(args ...any)
Panicln(args ...any)
Logr() logr.Logger
FieldLogger() logrus.FieldLogger
@@ -104,116 +106,120 @@ type Logger interface {
}
type logger struct {
fl logrus.Ext1FieldLogger
fl LogrusLogger
reportCaller *int
offset int
}
func (l *logger) Tracef(format string, args ...interface{}) {
l.withCaller().Tracef(format, args...)
func (l *logger) Tracef(format string, args ...any) {
l.logf(TraceLevel, format, args...)
}
func (l *logger) Debugf(format string, args ...interface{}) {
l.withCaller().Debugf(format, args...)
func (l *logger) Debugf(format string, args ...any) {
l.logf(DebugLevel, format, args...)
}
func (l *logger) Infof(format string, args ...interface{}) {
l.withCaller().Infof(format, args...)
func (l *logger) Infof(format string, args ...any) {
l.logf(InfoLevel, format, args...)
}
func (l *logger) Printf(format string, args ...interface{}) {
l.withCaller().Printf(format, args...)
func (l *logger) Printf(format string, args ...any) {
l.logf(InfoLevel, format, args...)
}
func (l *logger) Warnf(format string, args ...interface{}) {
l.withCaller().Warnf(format, args...)
func (l *logger) Warnf(format string, args ...any) {
l.logf(WarnLevel, format, args...)
}
func (l *logger) Warningf(format string, args ...interface{}) {
l.withCaller().Warningf(format, args...)
func (l *logger) Warningf(format string, args ...any) {
l.logf(WarnLevel, format, args...)
}
func (l *logger) Errorf(format string, args ...interface{}) {
l.withCaller().Errorf(format, args...)
func (l *logger) Errorf(format string, args ...any) {
l.logf(ErrorLevel, format, args...)
}
func (l *logger) Fatalf(format string, args ...interface{}) {
l.withCaller().Fatalf(format, args...)
func (l *logger) Fatalf(format string, args ...any) {
l.logf(FatalLevel, format, args...)
os.Exit(1)
}
func (l *logger) Panicf(format string, args ...interface{}) {
l.withCaller().Panicf(format, args...)
func (l *logger) Panicf(format string, args ...any) {
l.logf(PanicLevel, format, args...)
}
func (l *logger) Trace(args ...interface{}) {
l.withCaller().Trace(args...)
func (l *logger) Trace(args ...any) {
l.log(TraceLevel, args...)
}
func (l *logger) Debug(args ...interface{}) {
l.withCaller().Debug(args...)
func (l *logger) Debug(args ...any) {
l.log(DebugLevel, args...)
}
func (l *logger) Info(args ...interface{}) {
l.withCaller().Info(args...)
func (l *logger) Info(args ...any) {
l.log(InfoLevel, args...)
}
func (l *logger) Print(args ...interface{}) {
l.withCaller().Print(args...)
func (l *logger) Print(args ...any) {
l.log(InfoLevel, args...)
}
func (l *logger) Warn(args ...interface{}) {
l.withCaller().Warn(args...)
func (l *logger) Warn(args ...any) {
l.log(WarnLevel, args...)
}
func (l *logger) Warning(args ...interface{}) {
l.withCaller().Warning(args...)
func (l *logger) Warning(args ...any) {
l.log(WarnLevel, args...)
}
func (l *logger) Error(args ...interface{}) {
l.withCaller().Error(args...)
func (l *logger) Error(args ...any) {
l.log(ErrorLevel, args...)
}
func (l *logger) Fatal(args ...interface{}) {
l.withCaller().Fatal(args...)
func (l *logger) Fatal(args ...any) {
l.log(FatalLevel, args...)
os.Exit(1)
}
func (l *logger) Panic(args ...interface{}) {
l.withCaller().Panic(args...)
func (l *logger) Panic(args ...any) {
l.log(PanicLevel, args...)
}
func (l *logger) Traceln(args ...interface{}) {
l.withCaller().Traceln(args...)
func (l *logger) Traceln(args ...any) {
l.logln(TraceLevel, args...)
}
func (l *logger) Debugln(args ...interface{}) {
l.withCaller().Debugln(args...)
func (l *logger) Debugln(args ...any) {
l.logln(DebugLevel, args...)
}
func (l *logger) Infoln(args ...interface{}) {
l.withCaller().Infoln(args...)
func (l *logger) Infoln(args ...any) {
l.logln(InfoLevel, args...)
}
func (l *logger) Println(args ...interface{}) {
l.withCaller().Println(args...)
func (l *logger) Println(args ...any) {
l.logln(InfoLevel, args...)
}
func (l *logger) Warnln(args ...interface{}) {
l.withCaller().Warnln(args...)
func (l *logger) Warnln(args ...any) {
l.logln(WarnLevel, args...)
}
func (l *logger) Warningln(args ...interface{}) {
l.withCaller().Warningln(args...)
func (l *logger) Warningln(args ...any) {
l.logln(WarnLevel, args...)
}
func (l *logger) Errorln(args ...interface{}) {
l.withCaller().Errorln(args...)
func (l *logger) Errorln(args ...any) {
l.logln(ErrorLevel, args...)
}
func (l *logger) Fatalln(args ...interface{}) {
l.withCaller().Fatalln(args...)
func (l *logger) Fatalln(args ...any) {
l.logln(FatalLevel, args...)
os.Exit(1)
}
func (l *logger) Panicln(args ...interface{}) {
l.withCaller().Panicln(args...)
func (l *logger) Panicln(args ...any) {
l.logln(PanicLevel, args...)
}
func (l *logger) WriterLevel(level Level) *io.PipeWriter {
@@ -228,32 +234,32 @@ func (l *logger) SetLevel(level Level) Logger {
func (l *logger) WithContext(ctx context.Context) Logger {
switch t := l.fl.(type) {
case *logrus.Logger:
return &logger{fl: t.WithContext(ctx), reportCaller: l.reportCaller}
return &logger{fl: t.WithContext(ctx), reportCaller: l.reportCaller, offset: l.offset}
case *logrus.Entry:
return &logger{fl: t.WithContext(ctx), reportCaller: l.reportCaller}
return &logger{fl: t.WithContext(ctx), reportCaller: l.reportCaller, offset: l.offset}
}
panic(fmt.Sprintf("unexpected logger type %T", l.fl))
}
func (l *logger) WithField(key string, value interface{}) Logger {
return &logger{fl: l.fl.WithField(key, value), reportCaller: l.reportCaller}
func (l *logger) WithField(key string, value any) Logger {
return &logger{fl: l.fl.WithField(key, value), reportCaller: l.reportCaller, offset: l.offset}
}
func (l *logger) WithFields(kv ...interface{}) Logger {
log := &logger{fl: l.fl}
func (l *logger) WithFields(kv ...any) Logger {
log := &logger{fl: l.fl, reportCaller: l.reportCaller, offset: l.offset}
for i := 0; i < len(kv); i += 2 {
log = &logger{fl: log.fl.WithField(fmt.Sprintf("%v", kv[i]), kv[i+1]), reportCaller: l.reportCaller}
log = &logger{fl: log.fl.WithField(fmt.Sprintf("%v", kv[i]), kv[i+1]), reportCaller: l.reportCaller, offset: l.offset}
}
return log
}
func (l *logger) WithError(err error) Logger {
return &logger{fl: l.fl.WithError(err), reportCaller: l.reportCaller}
return &logger{fl: l.fl.WithError(err), reportCaller: l.reportCaller, offset: l.offset}
}
func (l *logger) WithReportCaller(b bool, depth ...uint) Logger {
if !b {
return &logger{fl: l.fl}
return &logger{fl: l.fl, reportCaller: nil, offset: l.offset}
}
var d int
if len(depth) > 0 {
@@ -261,11 +267,15 @@ func (l *logger) WithReportCaller(b bool, depth ...uint) Logger {
} else {
d = 0
}
return &logger{fl: l.fl, reportCaller: &d}
return &logger{fl: l.fl, reportCaller: &d, offset: l.offset}
}
func (l *logger) WithOffset(n int) Logger {
return &logger{fl: l.fl, reportCaller: l.reportCaller, offset: n}
}
func (l *logger) Logr() logr.Logger {
return logrusr.New(l.fl, logrusr.WithFormatter(func(i interface{}) interface{} {
return logrusr.New(l.fl, logrusr.WithFormatter(func(i any) any {
return fmt.Sprintf("%v", i)
}))
}
@@ -297,7 +307,7 @@ func (l *logger) Clone() Logger {
n.Out = t.Out
n.Formatter = t.Formatter
n.Hooks = t.Hooks
return &logger{fl: n, reportCaller: l.reportCaller}
return &logger{fl: n, reportCaller: l.reportCaller, offset: l.offset}
case *logrus.Entry:
t = t.Dup()
n.Level = t.Logger.Level
@@ -305,22 +315,41 @@ func (l *logger) Clone() Logger {
n.Formatter = t.Logger.Formatter
n.Hooks = t.Logger.Hooks
t.Logger = n
return &logger{fl: t, reportCaller: l.reportCaller}
return &logger{fl: t, reportCaller: l.reportCaller, offset: l.offset}
}
panic(fmt.Sprintf("unexpected logger type %T", l.fl))
}
func (l *logger) withCaller() logrus.Ext1FieldLogger {
func (l *logger) logf(level logrus.Level, format string, args ...any) {
l.withCaller().Logf(l.level(level), format, args...)
}
func (l *logger) log(level logrus.Level, args ...any) {
l.withCaller().Log(l.level(level), args...)
}
func (l *logger) logln(level logrus.Level, args ...any) {
l.withCaller().Logln(l.level(level), args...)
}
func (l *logger) withCaller() LogrusLogger {
if l.reportCaller == nil {
return l.fl
}
pcs := make([]uintptr, 1)
runtime.Callers(3+*l.reportCaller, pcs)
runtime.Callers(4+*l.reportCaller, pcs)
f, _ := runtime.CallersFrames(pcs).Next()
pkg := getPackageName(f.Function)
return l.fl.WithField("caller", fmt.Sprintf("%s/%s:%d", pkg, filepath.Base(f.File), f.Line)).WithField("func", f.Func.Name())
}
func (l *logger) level(lvl Level) logrus.Level {
if lvl > 3 {
return lvl + logrus.Level(l.offset)
}
return lvl
}
// getPackageName reduces a fully qualified function name to the package name
// There really ought to be a better way...
func getPackageName(f string) string {
@@ -336,3 +365,12 @@ func getPackageName(f string) string {
return f
}
type LogrusLogger interface {
logrus.FieldLogger
logrus.Ext1FieldLogger
Log(level logrus.Level, args ...any)
Logf(level logrus.Level, format string, args ...any)
Logln(level logrus.Level, args ...any)
}

11
service/listen_other.go Normal file
View File

@@ -0,0 +1,11 @@
//go:build !windows
package service
import (
"net"
)
func listen(network, address string) (net.Listener, error) {
return net.Listen(network, address)
}

17
service/listen_windows.go Normal file
View File

@@ -0,0 +1,17 @@
package service
import (
"net"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// listen uses wireguard's namedpipe package to listen on named pipes on Windows until
// https://github.com/golang/go/issues/49650 is resolved.
// For other networks, it falls back to the standard net.Listen.
func listen(network, address string) (net.Listener, error) {
if network == "pipe" {
return namedpipe.Listen(address)
}
return net.Listen(network, address)
}

View File

@@ -139,7 +139,8 @@ func newService(opts ...Option) (*service, error) {
grpc.StreamInterceptor(si),
grpc.UnaryInterceptor(ui),
}
if _, ok := s.opts.lis.(*net.UnixListener); ok || strings.HasPrefix(s.opts.address, "unix://") {
if (s.opts.lis != nil && (s.opts.lis.Addr().Network() == "unix" || s.opts.lis.Addr().Network() == "pipe")) ||
strings.HasPrefix(s.opts.address, "unix://") || strings.HasPrefix(s.opts.address, `\\.\pipe\`) {
gopts = append(gopts, grpc.Creds(peercreds.New()))
}
s.server = grpc.NewServer(append(gopts, s.opts.serverOpts...)...)
@@ -179,9 +180,12 @@ func (s *service) start() (*errgroup.Group, error) {
network = "unix"
s.opts.address = strings.TrimPrefix(s.opts.address, "unix://")
}
if strings.HasPrefix(s.opts.address, `\\.\pipe\`) {
network = "pipe"
}
if s.opts.lis == nil {
lis, err := net.Listen(network, s.opts.address)
lis, err := listen(network, s.opts.address)
if err != nil {
return nil, err
}