mirror of
https://github.com/linka-cloud/grpc.git
synced 2025-11-24 05:03:16 +00:00
105 lines
2.5 KiB
Go
105 lines
2.5 KiB
Go
package peercreds
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"errors"
|
||
"net"
|
||
|
||
"github.com/soheilhy/cmux"
|
||
"github.com/tailscale/peercred"
|
||
"google.golang.org/grpc/credentials"
|
||
)
|
||
|
||
var ErrUnsupportedConnType = peercred.ErrUnsupportedConnType
|
||
|
||
var _ credentials.TransportCredentials = (*peerCreds)(nil)
|
||
|
||
// Creds holds the credentials of the process on the other side of a
// connection, as captured during the handshake. A zero field means that
// piece of information is unknown.
type Creds struct {
	pid int    // peer process ID; 0 when unknown
	uid string // peer user ID (or Windows SID); "" when unknown
}

// PID returns the process ID of the peer, if known (ok is false if not
// known).
//
// Creds uses value receivers so its methods can be called on
// non-addressable values, e.g. directly on the Creds field of an AuthInfo
// obtained from a type assertion.
func (c Creds) PID() (pid int, ok bool) {
	return c.pid, c.pid != 0
}

// UserID returns the user ID (or Windows SID) that owns the other side of
// the connection, if known (ok is false if not known).
// The returned string is suitable for passing to os/user.LookupId.
func (c Creds) UserID() (uid string, ok bool) {
	return c.uid, c.uid != ""
}
|
||
|
||
// common is the CommonAuthInfo embedded in every AuthInfo produced by the
// handshake.
// NOTE(review): PrivacyAndIntegrity treats unix sockets / named pipes as a
// trusted channel — confirm this matches the deployment.
var common = credentials.CommonAuthInfo{
	SecurityLevel: credentials.PrivacyAndIntegrity,
}
|
||
|
||
func New() credentials.TransportCredentials {
|
||
return &peerCreds{info: credentials.ProtocolInfo{
|
||
SecurityProtocol: "peercred",
|
||
ProtocolVersion: "0.1",
|
||
}}
|
||
}
|
||
|
||
type peerCreds struct {
|
||
info credentials.ProtocolInfo
|
||
}
|
||
|
||
// AuthInfo we’ll attach to the gRPC peer
|
||
type AuthInfo struct {
|
||
credentials.CommonAuthInfo
|
||
Creds Creds
|
||
}
|
||
|
||
func (AuthInfo) AuthType() string { return "peercred" }
|
||
|
||
func (t *peerCreds) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||
return t.handshakeConn(conn)
|
||
}
|
||
|
||
func (t *peerCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||
return t.handshakeConn(conn)
|
||
}
|
||
|
||
func (t *peerCreds) Info() credentials.ProtocolInfo {
|
||
return t.info
|
||
}
|
||
|
||
func (t *peerCreds) Clone() credentials.TransportCredentials {
|
||
return &peerCreds{info: t.info}
|
||
}
|
||
|
||
func (t *peerCreds) OverrideServerName(name string) error {
|
||
t.info.ServerName = name
|
||
return nil
|
||
}
|
||
|
||
func (t *peerCreds) handshakeConn(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||
if conn.RemoteAddr().Network() != "unix" && conn.RemoteAddr().Network() != "pipe" {
|
||
return nil, nil, errors.New("peercred only works with unix domain sockets or Windows named pipes")
|
||
}
|
||
inner := conn
|
||
unwrap:
|
||
for {
|
||
switch c := inner.(type) {
|
||
case *cmux.MuxConn:
|
||
inner = c.Conn
|
||
case *tls.Conn:
|
||
inner = c.NetConn()
|
||
default:
|
||
break unwrap
|
||
}
|
||
}
|
||
creds, err := Get(inner)
|
||
if err != nil {
|
||
if errors.Is(err, peercred.ErrNotImplemented) {
|
||
return nil, nil, errors.New("peercred not implemented on this OS")
|
||
}
|
||
return nil, nil, err
|
||
}
|
||
var c Creds
|
||
c.uid, _ = creds.UserID()
|
||
c.pid, _ = creds.PID()
|
||
return conn, AuthInfo{Creds: c, CommonAuthInfo: common}, nil
|
||
}
|