598 lines
14 KiB
Go
598 lines
14 KiB
Go
|
package utp
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"log"
|
||
|
"math/rand"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/anacrolix/missinggo"
|
||
|
"github.com/anacrolix/missinggo/inproc"
|
||
|
"github.com/anacrolix/missinggo/pproffd"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
_ net.Listener = &Socket{}
|
||
|
_ net.PacketConn = &Socket{}
|
||
|
)
|
||
|
|
||
|
// Uniquely identifies any uTP connection on top of the underlying packet
// stream.
//
// Field order matters: keys are built with positional composite literals
// such as connKey{remoteAddr, id}.
type connKey struct {
	// String form of the remote peer's address (from net.Addr.String()).
	remoteAddr resolvedAddrStr
	// The local Conn's recv_id; registerConn keys Conns on it.
	connID uint16
}
|
||
|
|
||
|
// A Socket wraps a net.PacketConn, diverting uTP packets to its child uTP
// Conns.
type Socket struct {
	// The underlying packet transport; destroy closes it.
	pc net.PacketConn
	// Registered Conns keyed by remote address and the Conn's recv_id.
	// Allocated lazily by registerConn.
	conns map[connKey]*Conn

	// Set while backlog is non-empty (maintained by backlogChanged).
	backlogNotEmpty missinggo.Event
	// Inbound SYNs waiting for Accept; pushBacklog caps its size at the
	// package-level backlog limit.
	backlog map[syn]struct{}

	// Set by Close and CloseNow.
	closed missinggo.Event
	// Set by destroy.
	destroyed missinggo.Event

	// Counts in-flight reads/writes on pc so CloseNow can wait for them.
	wgReadWrite sync.WaitGroup

	// Received packets no uTP Conn consumed, served to ReadFrom. unusedRead
	// drops packets when this is full.
	unusedReads chan read

	// Read/write deadlines consulted by ReadFrom and WriteTo.
	connDeadlines
	// If a read error occurs on the underlying net.PacketConn, it is put
	// here. This is because reading is done in its own goroutine to dispatch
	// to uTP Conns.
	ReadErr error
}
|
||
|
|
||
|
func listenPacket(network, addr string) (pc net.PacketConn, err error) {
|
||
|
if network == "inproc" {
|
||
|
return inproc.ListenPacket(network, addr)
|
||
|
}
|
||
|
return net.ListenPacket(network, addr)
|
||
|
}
|
||
|
|
||
|
// NewSocket creates a net.PacketConn with the given network and address, and
|
||
|
// returns a Socket dispatching on it.
|
||
|
func NewSocket(network, addr string) (s *Socket, err error) {
|
||
|
if network == "" {
|
||
|
network = "udp"
|
||
|
}
|
||
|
pc, err := listenPacket(network, addr)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
return NewSocketFromPacketConn(pc)
|
||
|
}
|
||
|
|
||
|
// Create a Socket, using the provided net.PacketConn. If you want to retain
|
||
|
// use of the net.PacketConn after the Socket closes it, override the
|
||
|
// net.PacketConn's Close method, or use NetSocketFromPacketConnNoClose.
|
||
|
func NewSocketFromPacketConn(pc net.PacketConn) (s *Socket, err error) {
|
||
|
s = &Socket{
|
||
|
backlog: make(map[syn]struct{}, backlog),
|
||
|
pc: pc,
|
||
|
unusedReads: make(chan read, 100),
|
||
|
wgReadWrite: sync.WaitGroup{},
|
||
|
}
|
||
|
mu.Lock()
|
||
|
sockets[s] = struct{}{}
|
||
|
mu.Unlock()
|
||
|
go s.reader()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Create a Socket using the provided PacketConn, that doesn't close the
|
||
|
// PacketConn when the Socket is closed.
|
||
|
func NewSocketFromPacketConnNoClose(pc net.PacketConn) (s *Socket, err error) {
|
||
|
return NewSocketFromPacketConn(packetConnNopCloser{pc})
|
||
|
}
|
||
|
|
||
|
func (s *Socket) unusedRead(read read) {
|
||
|
unusedReads.Add(1)
|
||
|
select {
|
||
|
case s.unusedReads <- read:
|
||
|
default:
|
||
|
// Drop the packet.
|
||
|
unusedReadsDropped.Add(1)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Socket) strNetAddr(str string) (a net.Addr) {
|
||
|
var err error
|
||
|
switch n := s.network(); n {
|
||
|
case "udp":
|
||
|
a, err = net.ResolveUDPAddr(n, str)
|
||
|
case "inproc":
|
||
|
a, err = inproc.ResolveAddr(n, str)
|
||
|
default:
|
||
|
panic(n)
|
||
|
}
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *Socket) pushBacklog(syn syn) {
|
||
|
if _, ok := s.backlog[syn]; ok {
|
||
|
return
|
||
|
}
|
||
|
// Pop a pseudo-random syn to make room. TODO: Use missinggo/orderedmap,
|
||
|
// coz that's what is wanted here.
|
||
|
for k := range s.backlog {
|
||
|
if len(s.backlog) < backlog {
|
||
|
break
|
||
|
}
|
||
|
delete(s.backlog, k)
|
||
|
// A syn is sent on the remote's recv_id, so this is where we can send
|
||
|
// the reset.
|
||
|
s.reset(s.strNetAddr(k.addr), k.seq_nr, k.conn_id)
|
||
|
}
|
||
|
s.backlog[syn] = struct{}{}
|
||
|
s.backlogChanged()
|
||
|
}
|
||
|
|
||
|
// reader is run as a goroutine by NewSocketFromPacketConn. It reads packets
// from the underlying PacketConn and dispatches them via
// handleReceivedPacket until the Socket is destroyed or a read fails, and it
// destroys the Socket on exit. mu is held except around the blocking read.
func (s *Socket) reader() {
	mu.Lock()
	// Deferred calls run LIFO, so destroy executes while mu is still held.
	defer mu.Unlock()
	defer s.destroy()
	// Receive buffer reused across reads; each packet is copied out before
	// dispatch.
	var b [maxRecvSize]byte
	for {
		// Add under mu so it can't race CloseNow's wgReadWrite.Wait, which
		// also runs under mu. Done is called before re-locking mu, so that
		// Wait can return while CloseNow holds the lock.
		s.wgReadWrite.Add(1)
		mu.Unlock()
		n, addr, err := s.pc.ReadFrom(b[:])
		s.wgReadWrite.Done()
		mu.Lock()
		// The Socket may have been destroyed while we were blocked reading.
		if s.destroyed.IsSet() {
			return
		}
		if err != nil {
			log.Printf("error reading Socket PacketConn: %s", err)
			// nextSyn hands this to Accept callers once destroyed is set.
			s.ReadErr = err
			return
		}
		s.handleReceivedPacket(read{
			append([]byte(nil), b[:n]...),
			addr,
		})
	}
}
|
||
|
|
||
|
func receivedUTPPacketSize(n int) {
|
||
|
if n > largestReceivedUTPPacket {
|
||
|
largestReceivedUTPPacket = n
|
||
|
largestReceivedUTPPacketExpvar.Set(int64(n))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Socket) connForRead(h header, from net.Addr) (c *Conn, ok bool) {
|
||
|
c, ok = s.conns[connKey{
|
||
|
resolvedAddrStr(from.String()),
|
||
|
func() uint16 {
|
||
|
if h.Type == stSyn {
|
||
|
// SYNs have a ConnID one lower than the eventual recvID, and we index
|
||
|
// the connections with that, so use it for the lookup.
|
||
|
return h.ConnID + 1
|
||
|
} else {
|
||
|
return h.ConnID
|
||
|
}
|
||
|
}(),
|
||
|
}]
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *Socket) handlePacketReceivedForEstablishedConn(h header, from net.Addr, data []byte, c *Conn) {
|
||
|
if h.Type == stSyn {
|
||
|
if h.ConnID == c.send_id-2 {
|
||
|
// This is a SYN for connection that cannot exist locally. The
|
||
|
// connection the remote wants to establish here with the proposed
|
||
|
// recv_id, already has an existing connection that was dialled
|
||
|
// *out* from this socket, which is why the send_id is 1 higher,
|
||
|
// rather than 1 lower than the recv_id.
|
||
|
log.Print("resetting conflicting syn")
|
||
|
s.reset(from, h.SeqNr, h.ConnID)
|
||
|
return
|
||
|
} else if h.ConnID != c.send_id {
|
||
|
panic("bad assumption")
|
||
|
}
|
||
|
}
|
||
|
c.receivePacket(h, data)
|
||
|
}
|
||
|
|
||
|
func (s *Socket) handleReceivedPacket(p read) {
|
||
|
if len(p.data) < 20 {
|
||
|
s.unusedRead(p)
|
||
|
return
|
||
|
}
|
||
|
var h header
|
||
|
hEnd, err := h.Unmarshal(p.data)
|
||
|
if err != nil || h.Type > stMax || h.Version != 1 {
|
||
|
s.unusedRead(p)
|
||
|
return
|
||
|
}
|
||
|
if c, ok := s.connForRead(h, p.from); ok {
|
||
|
receivedUTPPacketSize(len(p.data))
|
||
|
s.handlePacketReceivedForEstablishedConn(h, p.from, p.data[hEnd:], c)
|
||
|
return
|
||
|
}
|
||
|
// Packet doesn't belong to an existing connection.
|
||
|
switch h.Type {
|
||
|
case stSyn:
|
||
|
s.pushBacklog(syn{
|
||
|
seq_nr: h.SeqNr,
|
||
|
conn_id: h.ConnID,
|
||
|
addr: p.from.String(),
|
||
|
})
|
||
|
return
|
||
|
case stReset:
|
||
|
// Could be a late arriving packet for a Conn we're already done with.
|
||
|
// If it was for an existing connection, we would have handled it
|
||
|
// earlier.
|
||
|
default:
|
||
|
unexpectedPacketsRead.Add(1)
|
||
|
// This is an unexpected packet. We'll send a reset, but also pass it
|
||
|
// on. I don't think you can reset on the received packets ConnID if
|
||
|
// it isn't a SYN, as the send_id will differ in this case.
|
||
|
s.reset(p.from, h.SeqNr, h.ConnID)
|
||
|
// Connection initiated by remote.
|
||
|
s.reset(p.from, h.SeqNr, h.ConnID-1)
|
||
|
// Connection initiated locally.
|
||
|
s.reset(p.from, h.SeqNr, h.ConnID+1)
|
||
|
}
|
||
|
s.unusedRead(p)
|
||
|
}
|
||
|
|
||
|
// Send a reset in response to a packet with the given header.
|
||
|
func (s *Socket) reset(addr net.Addr, ackNr, connId uint16) {
|
||
|
b := make([]byte, 0, maxHeaderSize)
|
||
|
h := header{
|
||
|
Type: stReset,
|
||
|
Version: 1,
|
||
|
ConnID: connId,
|
||
|
AckNr: ackNr,
|
||
|
}
|
||
|
b = b[:h.Marshal(b)]
|
||
|
go s.writeTo(b, addr)
|
||
|
}
|
||
|
|
||
|
// Return a recv_id that should be free. Handling the case where it isn't is
|
||
|
// deferred to a more appropriate function.
|
||
|
func (s *Socket) newConnID(remoteAddr resolvedAddrStr) (id uint16) {
|
||
|
// Rather than use math.Rand, which requires generating all the IDs up
|
||
|
// front and allocating a slice, we do it on the stack, generating the IDs
|
||
|
// only as required. To do this, we use the fact that the array is
|
||
|
// default-initialized. IDs that are 0, are actually their index in the
|
||
|
// array. IDs that are non-zero, are +1 from their intended ID.
|
||
|
var idsBack [0x10000]int
|
||
|
ids := idsBack[:]
|
||
|
for len(ids) != 0 {
|
||
|
// Pick the next ID from the untried ids.
|
||
|
i := rand.Intn(len(ids))
|
||
|
id = uint16(ids[i])
|
||
|
// If it's zero, then treat it as though the index i was the ID.
|
||
|
// Otherwise the value we get is the ID+1.
|
||
|
if id == 0 {
|
||
|
id = uint16(i)
|
||
|
} else {
|
||
|
id--
|
||
|
}
|
||
|
// Check there's no connection using this ID for its recv_id...
|
||
|
_, ok1 := s.conns[connKey{remoteAddr, id}]
|
||
|
// and if we're connecting to our own Socket, that there isn't a Conn
|
||
|
// already receiving on what will correspond to our send_id. Note that
|
||
|
// we just assume that we could be connecting to our own Socket. This
|
||
|
// will halve the available connection IDs to each distinct remote
|
||
|
// address. Presumably that's ~0x8000, down from ~0x10000.
|
||
|
_, ok2 := s.conns[connKey{remoteAddr, id + 1}]
|
||
|
_, ok4 := s.conns[connKey{remoteAddr, id - 1}]
|
||
|
if !ok1 && !ok2 && !ok4 {
|
||
|
return
|
||
|
}
|
||
|
// The set of possible IDs is shrinking. The highest one will be lost, so
|
||
|
// it's moved to the location of the one we just tried.
|
||
|
ids[i] = len(ids) // Conveniently already +1.
|
||
|
// And shrink.
|
||
|
ids = ids[:len(ids)-1]
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
zeroipv4 = net.ParseIP("0.0.0.0")
|
||
|
zeroipv6 = net.ParseIP("::")
|
||
|
|
||
|
ipv4lo = mustResolveUDP("127.0.0.1")
|
||
|
ipv6lo = mustResolveUDP("::1")
|
||
|
)
|
||
|
|
||
|
// mustResolveUDP resolves addr on the "ip" network and returns its IP,
// panicking if resolution fails. It is used to initialize package-level
// address variables.
func mustResolveUDP(addr string) net.IP {
	resolved, err := net.ResolveIPAddr("ip", addr)
	if err == nil {
		return resolved.IP
	}
	panic(err)
}
|
||
|
|
||
|
// realRemoteAddr maps a UDP address whose IP is unspecified (0.0.0.0 or ::)
// to the matching loopback address (127.0.0.1 or ::1), keeping port and
// zone. Presumably this is so a Conn dialled to a wildcard listen address is
// identified by the address replies actually come from; confirm with
// callers.
//
// All other addresses, including non-UDP ones and a nil *net.UDPAddr, are
// returned unchanged. The argument is never mutated: a substituted address
// is a fresh copy with its own IP slice.
func realRemoteAddr(addr net.Addr) net.Addr {
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok || udpAddr == nil {
		return addr
	}
	var loopback net.IP
	switch {
	case udpAddr.IP.Equal(net.IPv4zero):
		loopback = net.IPv4(127, 0, 0, 1)
	case udpAddr.IP.Equal(net.IPv6unspecified):
		loopback = net.ParseIP("::1")
	default:
		return addr
	}
	substituted := *udpAddr
	substituted.IP = loopback
	return &substituted
}
|
||
|
|
||
|
func (s *Socket) newConn(addr net.Addr) (c *Conn) {
|
||
|
addr = realRemoteAddr(addr)
|
||
|
|
||
|
c = &Conn{
|
||
|
socket: s,
|
||
|
remoteSocketAddr: addr,
|
||
|
created: time.Now(),
|
||
|
}
|
||
|
c.sendPendingSendSendStateTimer = missinggo.StoppedFuncTimer(c.sendPendingSendStateTimerCallback)
|
||
|
c.packetReadTimeoutTimer = time.AfterFunc(packetReadTimeout, c.receivePacketTimeoutCallback)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *Socket) Dial(addr string) (net.Conn, error) {
|
||
|
return s.DialContext(context.Background(), "", addr)
|
||
|
}
|
||
|
|
||
|
func (s *Socket) resolveAddr(network, addr string) (net.Addr, error) {
|
||
|
n := s.network()
|
||
|
if network != "" {
|
||
|
n = network
|
||
|
}
|
||
|
if n == "inproc" {
|
||
|
return inproc.ResolveAddr(n, addr)
|
||
|
}
|
||
|
return net.ResolveUDPAddr(n, addr)
|
||
|
}
|
||
|
|
||
|
func (s *Socket) network() string {
|
||
|
return s.pc.LocalAddr().Network()
|
||
|
}
|
||
|
|
||
|
func (s *Socket) startOutboundConn(addr net.Addr) (c *Conn, err error) {
|
||
|
mu.Lock()
|
||
|
defer mu.Unlock()
|
||
|
c = s.newConn(addr)
|
||
|
c.recv_id = s.newConnID(resolvedAddrStr(c.RemoteAddr().String()))
|
||
|
c.send_id = c.recv_id + 1
|
||
|
if logLevel >= 1 {
|
||
|
log.Printf("dial registering addr: %s", c.RemoteAddr().String())
|
||
|
}
|
||
|
if !s.registerConn(c.recv_id, resolvedAddrStr(c.RemoteAddr().String()), c) {
|
||
|
err = errors.New("couldn't register new connection")
|
||
|
log.Println(c.recv_id, c.RemoteAddr().String())
|
||
|
for k, c := range s.conns {
|
||
|
log.Println(k, c, c.age())
|
||
|
}
|
||
|
log.Printf("that's %d connections", len(s.conns))
|
||
|
}
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
c.seq_nr = 1
|
||
|
c.writeSyn()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *Socket) DialContext(ctx context.Context, network, addr string) (nc net.Conn, err error) {
|
||
|
netAddr, err := s.resolveAddr(network, addr)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
c, err := s.startOutboundConn(netAddr)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
connErr := make(chan error, 1)
|
||
|
go func() {
|
||
|
connErr <- c.recvSynAck()
|
||
|
}()
|
||
|
select {
|
||
|
case err = <-connErr:
|
||
|
case <-ctx.Done():
|
||
|
err = ctx.Err()
|
||
|
}
|
||
|
if err != nil {
|
||
|
mu.Lock()
|
||
|
c.destroy(errors.New("dial timeout"))
|
||
|
mu.Unlock()
|
||
|
return
|
||
|
}
|
||
|
mu.Lock()
|
||
|
c.updateCanWrite()
|
||
|
mu.Unlock()
|
||
|
nc = pproffd.WrapNetConn(c)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (me *Socket) writeTo(b []byte, addr net.Addr) (n int, err error) {
|
||
|
apdc := artificialPacketDropChance
|
||
|
if apdc != 0 {
|
||
|
if rand.Float64() < apdc {
|
||
|
n = len(b)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
n, err = me.pc.WriteTo(b, addr)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Returns true if the connection was newly registered, false otherwise.
|
||
|
func (s *Socket) registerConn(recvID uint16, remoteAddr resolvedAddrStr, c *Conn) bool {
|
||
|
if s.conns == nil {
|
||
|
s.conns = make(map[connKey]*Conn)
|
||
|
}
|
||
|
key := connKey{remoteAddr, recvID}
|
||
|
if _, ok := s.conns[key]; ok {
|
||
|
return false
|
||
|
}
|
||
|
c.connKey = key
|
||
|
s.conns[key] = c
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (s *Socket) backlogChanged() {
|
||
|
if len(s.backlog) != 0 {
|
||
|
s.backlogNotEmpty.Set()
|
||
|
} else {
|
||
|
s.backlogNotEmpty.Clear()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Socket) nextSyn() (syn syn, err error) {
|
||
|
for {
|
||
|
missinggo.WaitEvents(&mu, &s.closed, &s.backlogNotEmpty, &s.destroyed)
|
||
|
if s.closed.IsSet() {
|
||
|
err = errClosed
|
||
|
return
|
||
|
}
|
||
|
if s.destroyed.IsSet() {
|
||
|
err = s.ReadErr
|
||
|
return
|
||
|
}
|
||
|
for k := range s.backlog {
|
||
|
syn = k
|
||
|
delete(s.backlog, k)
|
||
|
s.backlogChanged()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ACK a SYN, and return a new Conn for it. ok is false if the SYN is bad, and
|
||
|
// the Conn invalid.
|
||
|
func (s *Socket) ackSyn(syn syn) (c *Conn, ok bool) {
|
||
|
c = s.newConn(s.strNetAddr(syn.addr))
|
||
|
c.send_id = syn.conn_id
|
||
|
c.recv_id = c.send_id + 1
|
||
|
c.seq_nr = uint16(rand.Int())
|
||
|
c.lastAck = c.seq_nr - 1
|
||
|
c.ack_nr = syn.seq_nr
|
||
|
c.synAcked = true
|
||
|
c.updateCanWrite()
|
||
|
if !s.registerConn(c.recv_id, resolvedAddrStr(syn.addr), c) {
|
||
|
// SYN that triggered this accept duplicates existing connection.
|
||
|
// Ack again in case the SYN was a resend.
|
||
|
c = s.conns[connKey{resolvedAddrStr(syn.addr), c.recv_id}]
|
||
|
if c.send_id != syn.conn_id {
|
||
|
panic(":|")
|
||
|
}
|
||
|
c.sendState()
|
||
|
return
|
||
|
}
|
||
|
c.sendState()
|
||
|
ok = true
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Accept and return a new uTP connection.
|
||
|
func (s *Socket) Accept() (net.Conn, error) {
|
||
|
mu.Lock()
|
||
|
defer mu.Unlock()
|
||
|
for {
|
||
|
syn, err := s.nextSyn()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
c, ok := s.ackSyn(syn)
|
||
|
if ok {
|
||
|
c.updateCanWrite()
|
||
|
return c, nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// The address we're listening on for new uTP connections.
|
||
|
func (s *Socket) Addr() net.Addr {
|
||
|
return s.pc.LocalAddr()
|
||
|
}
|
||
|
|
||
|
func (s *Socket) CloseNow() error {
|
||
|
mu.Lock()
|
||
|
defer mu.Unlock()
|
||
|
s.closed.Set()
|
||
|
for _, c := range s.conns {
|
||
|
c.closeNow()
|
||
|
}
|
||
|
s.destroy()
|
||
|
s.wgReadWrite.Wait()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Socket) Close() error {
|
||
|
mu.Lock()
|
||
|
defer mu.Unlock()
|
||
|
s.closed.Set()
|
||
|
s.lazyDestroy()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Socket) lazyDestroy() {
|
||
|
if len(s.conns) != 0 {
|
||
|
return
|
||
|
}
|
||
|
if !s.closed.IsSet() {
|
||
|
return
|
||
|
}
|
||
|
s.destroy()
|
||
|
}
|
||
|
|
||
|
func (s *Socket) destroy() {
|
||
|
delete(sockets, s)
|
||
|
s.destroyed.Set()
|
||
|
s.pc.Close()
|
||
|
for _, c := range s.conns {
|
||
|
c.destroy(errors.New("Socket destroyed"))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Socket) LocalAddr() net.Addr {
|
||
|
return s.pc.LocalAddr()
|
||
|
}
|
||
|
|
||
|
func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||
|
select {
|
||
|
case read, ok := <-s.unusedReads:
|
||
|
if !ok {
|
||
|
err = io.EOF
|
||
|
return
|
||
|
}
|
||
|
n = copy(p, read.data)
|
||
|
addr = read.from
|
||
|
return
|
||
|
case <-s.connDeadlines.read.passed.LockedChan(&mu):
|
||
|
err = errTimeout
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Socket) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||
|
mu.Lock()
|
||
|
if s.connDeadlines.write.passed.IsSet() {
|
||
|
err = errTimeout
|
||
|
}
|
||
|
s.wgReadWrite.Add(1)
|
||
|
defer s.wgReadWrite.Done()
|
||
|
mu.Unlock()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
return s.pc.WriteTo(b, addr)
|
||
|
}
|