From 2bdd76e689a332479b0aeb9f4a26e567d3d713b4 Mon Sep 17 00:00:00 2001 From: jdl Date: Sun, 12 Jan 2025 20:31:36 +0100 Subject: [PATCH] Better address discovery. --- node/addrdiscovery.go | 94 ++++++++++++++-------------- node/addrdiscovery_test.go | 29 +++++++++ node/globalfuncs.go | 8 --- node/globals.go | 9 ++- node/main.go | 16 ++--- node/messages.go | 10 +-- node/packets-util.go | 33 +++++++++- node/packets-util_test.go | 18 +++++- node/packets.go | 75 ++++++++++++---------- node/supervisor.go | 125 +++++++++++++++++++++++-------------- 10 files changed, 253 insertions(+), 164 deletions(-) create mode 100644 node/addrdiscovery_test.go diff --git a/node/addrdiscovery.go b/node/addrdiscovery.go index 4875c1f..f3a3666 100644 --- a/node/addrdiscovery.go +++ b/node/addrdiscovery.go @@ -3,65 +3,65 @@ package node import ( "log" "net/netip" + "runtime/debug" + "sort" "time" ) -func addrDiscoveryServer() { - var ( - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) - ) +type pubAddrStore struct { + lastSeen map[netip.AddrPort]time.Time + addrList []netip.AddrPort +} - for { - msg := <-discoveryMessages - p := msg.Packet - - route := routingTable[msg.SrcIP].Load() - if route == nil || !route.RemoteAddr.IsValid() { - continue - } - - _sendControlPacket(addrDiscoveryPacket{ - TraceID: p.TraceID, - ToAddr: msg.SrcAddr, - }, *route, buf1, buf2) +func newPubAddrStore() *pubAddrStore { + return &pubAddrStore{ + lastSeen: map[netip.AddrPort]time.Time{}, + addrList: make([]netip.AddrPort, 0, 32), } } -func addrDiscoveryClient() { - var ( - checkInterval = 8 * time.Second - timer = time.NewTimer(4 * time.Second) +func (store *pubAddrStore) Store(add netip.AddrPort) { + if localPub { + log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) + return + } - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) + if _, exists := store.lastSeen[add]; !exists { + store.addrList = append(store.addrList, add) + } + store.lastSeen[add] = time.Now() + store.sort() +} - addrPacket addrDiscoveryPacket - lAddr netip.AddrPort - ) +func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { + if localPub { + addrs[0] = localAddr + return + } - for { - select { - case msg := <-discoveryMessages: - p := msg.Packet - if p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr { - continue - } + copy(addrs[:], store.addrList) + return +} - log.Printf("Discovered local address: %v", p.ToAddr) - lAddr = p.ToAddr - localAddr.Store(&p.ToAddr) +func (store *pubAddrStore) Clean() { + if localPub { + return + } - case <-timer.C: - timer.Reset(checkInterval) - - route := getRelayRoute() - if route == nil { - continue - } - - addrPacket.TraceID = newTraceID() - _sendControlPacket(addrPacket, *route, buf1, buf2) + for ip, lastSeen := range store.lastSeen { + if time.Since(lastSeen) > timeoutInterval { + delete(store.lastSeen, ip) } } + store.addrList = store.addrList[:0] + for ip := range store.lastSeen { + store.addrList = append(store.addrList, ip) + } + store.sort() +} + +func (store *pubAddrStore) sort() { + sort.Slice(store.addrList, func(i, j int) bool { + return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) + }) } diff --git a/node/addrdiscovery_test.go b/node/addrdiscovery_test.go new file mode 100644 index 0000000..9851d6a --- /dev/null +++ b/node/addrdiscovery_test.go @@ -0,0 +1,29 @@ +package node + +import ( + "net/netip" + "testing" + "time" +) + +func TestPubAddrStore(t *testing.T) { + s := newPubAddrStore() + + l := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), + } + + for i := range l { + s.Store(l[i]) + time.Sleep(time.Millisecond) + } + + s.Clean() + + l2 := s.Get() + if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { + t.Fatal(l, l2) + } +} diff --git a/node/globalfuncs.go b/node/globalfuncs.go index 98975da..f32ec0b 100644 --- a/node/globalfuncs.go +++ b/node/globalfuncs.go @@ -1,7 +1,6 @@ package node import ( - "net/netip" "sync/atomic" ) @@ -12,13 +11,6 @@ func getRelayRoute() *peerRoute { return nil } -func getLocalAddr() netip.AddrPort { - if a := localAddr.Load(); a != nil { - return *a - } - return netip.AddrPort{} -} - func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { buf := pkt.Marshal(buf2) h := header{ diff --git a/node/globals.go b/node/globals.go index beab12e..9b465f1 100644 --- a/node/globals.go +++ b/node/globals.go @@ -41,6 +41,7 @@ var ( netName string localIP byte localPub bool + localAddr netip.AddrPort privKey []byte privSignKey []byte @@ -78,10 +79,8 @@ var ( return }() - // Managed by the addrDiscovery* functions. - discoveryMessages = make(chan controlMsg[addrDiscoveryPacket], 256) - // Managed by the relayManager. - localAddr = &atomic.Pointer[netip.AddrPort]{} - relayIP = &atomic.Pointer[byte]{} + relayIP = &atomic.Pointer[byte]{} + + publicAddrs = newPubAddrStore() ) diff --git a/node/main.go b/node/main.go index f40db27..4e59cf7 100644 --- a/node/main.go +++ b/node/main.go @@ -152,17 +152,13 @@ func main() { ip, ok := netip.AddrFromSlice(config.PublicIP) if ok { localPub = true - addr := netip.AddrPortFrom(ip, config.Port) - localAddr.Store(&addr) + localAddr = netip.AddrPortFrom(ip, config.Port) } privKey = config.PrivKey privSignKey = config.PrivSignKey - if localPub { - go addrDiscoveryServer() - } else { - go addrDiscoveryClient() + if !localPub { go relayManager() go localDiscovery() } @@ -177,6 +173,7 @@ func main() { go newHubPoller().Run() go readFromConn(conn) + readFromIFace(iface) } @@ -232,7 +229,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { } if h.DestIP != localIP { - log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP) + log.Printf("Incorrect destination IP on control packet: %#v", h) return } @@ -258,11 +255,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { return } - if dm, ok := msg.(controlMsg[addrDiscoveryPacket]); ok { - discoveryMessages <- dm - return - } - select { case messages <- msg: default: diff --git a/node/messages.go b/node/messages.go index 5bd0397..76d86d4 100644 --- a/node/messages.go +++ b/node/messages.go @@ -24,7 +24,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error Packet: packet, }, err - case packetTypeSynAck: + case packetTypeAck: packet, err := parseAckPacket(buf) return controlMsg[ackPacket]{ SrcIP: srcIP, @@ -40,14 +40,6 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error Packet: packet, }, err - case packetTypeAddrDiscovery: - packet, err := parseAddrDiscoveryPacket(buf) - return controlMsg[addrDiscoveryPacket]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - default: return nil, errUnknownPacketType } diff --git a/node/packets-util.go b/node/packets-util.go index af10eb5..b3071ab 100644 --- a/node/packets-util.go +++ b/node/packets-util.go @@ -63,12 +63,20 @@ func (w *binWriter) Int64(x int64) *binWriter { } func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { + w.Bool(addrPort.IsValid()) addr := addrPort.Addr().As16() copy(w.b[w.i:w.i+16], addr[:]) w.i += 16 return w.Uint16(addrPort.Port()) } +func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { + for _, addrPort := range l { + w.AddrPort(addrPort) + } + return w +} + func (w *binWriter) Build() []byte { return w.b[:w.i] } @@ -146,15 +154,34 @@ func (r *binReader) Int64(x *int64) *binReader { } func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { - if !r.hasBytes(18) { + if !r.hasBytes(19) { return r } + + var ( + valid bool + port uint16 + ) + + r.Bool(&valid) addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() r.i += 16 - var port uint16 r.Uint16(&port) - *x = netip.AddrPortFrom(addr, port) + + if valid { + *x = netip.AddrPortFrom(addr, port) + } else { + *x = netip.AddrPort{} + } + + return r +} + +func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { + for i := range x { + r.AddrPort(&x[i]) + } return r } diff --git a/node/packets-util_test.go b/node/packets-util_test.go index 06b0370..96eab1a 100644 --- a/node/packets-util_test.go +++ b/node/packets-util_test.go @@ -12,15 +12,30 @@ func TestBinWriteRead(t *testing.T) { type Item struct { Type byte TraceID uint64 + Addrs [8]netip.AddrPort DestAddr netip.AddrPort } - in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} + in := Item{ + 1, + 2, + [8]netip.AddrPort{}, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), + } + + in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) + in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) + in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) + in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) + in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) + in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) + in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) buf = newBinWriter(buf). Byte(in.Type). Uint64(in.TraceID). AddrPort(in.DestAddr). + AddrPortArray(in.Addrs). Build() out := Item{} @@ -29,6 +44,7 @@ func TestBinWriteRead(t *testing.T) { Byte(&out.Type). Uint64(&out.TraceID). AddrPort(&out.DestAddr). + AddrPortArray(&out.Addrs). Error() if err != nil { t.Fatal(err) diff --git a/node/packets.go b/node/packets.go index 6b9463a..14d7377 100644 --- a/node/packets.go +++ b/node/packets.go @@ -21,10 +21,10 @@ const ( // ---------------------------------------------------------------------------- type synPacket struct { - TraceID uint64 // TraceID to match response w/ request. - SharedKey [32]byte // Our shared key. - Direct bool - FromAddr netip.AddrPort // The client's sending address. + TraceID uint64 // TraceID to match response w/ request. + SharedKey [32]byte // Our shared key. + Direct bool + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } func (p synPacket) Marshal(buf []byte) []byte { @@ -33,7 +33,14 @@ func (p synPacket) Marshal(buf []byte) []byte { Uint64(p.TraceID). SharedKey(p.SharedKey). Bool(p.Direct). - AddrPort(p.FromAddr). + AddrPort(p.PossibleAddrs[0]). + AddrPort(p.PossibleAddrs[1]). + AddrPort(p.PossibleAddrs[2]). + AddrPort(p.PossibleAddrs[3]). + AddrPort(p.PossibleAddrs[4]). + AddrPort(p.PossibleAddrs[5]). + AddrPort(p.PossibleAddrs[6]). + AddrPort(p.PossibleAddrs[7]). Build() } @@ -42,7 +49,14 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { Uint64(&p.TraceID). SharedKey(&p.SharedKey). Bool(&p.Direct). - AddrPort(&p.FromAddr). + AddrPort(&p.PossibleAddrs[0]). + AddrPort(&p.PossibleAddrs[1]). + AddrPort(&p.PossibleAddrs[2]). + AddrPort(&p.PossibleAddrs[3]). + AddrPort(&p.PossibleAddrs[4]). + AddrPort(&p.PossibleAddrs[5]). + AddrPort(&p.PossibleAddrs[6]). + AddrPort(&p.PossibleAddrs[7]). Error() return } @@ -50,45 +64,40 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { // ---------------------------------------------------------------------------- type ackPacket struct { - TraceID uint64 - FromAddr netip.AddrPort + TraceID uint64 + ToAddr netip.AddrPort + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } func (p ackPacket) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeSynAck). + Byte(packetTypeAck). Uint64(p.TraceID). - AddrPort(p.FromAddr). + AddrPort(p.ToAddr). + AddrPort(p.PossibleAddrs[0]). + AddrPort(p.PossibleAddrs[1]). + AddrPort(p.PossibleAddrs[2]). + AddrPort(p.PossibleAddrs[3]). + AddrPort(p.PossibleAddrs[4]). + AddrPort(p.PossibleAddrs[5]). + AddrPort(p.PossibleAddrs[6]). + AddrPort(p.PossibleAddrs[7]). Build() + } func parseAckPacket(buf []byte) (p ackPacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - AddrPort(&p.FromAddr). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type addrDiscoveryPacket struct { - TraceID uint64 - ToAddr netip.AddrPort -} - -func (p addrDiscoveryPacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeAddrDiscovery). - Uint64(p.TraceID). - AddrPort(p.ToAddr). - Build() -} - -func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). AddrPort(&p.ToAddr). + AddrPort(&p.PossibleAddrs[0]). + AddrPort(&p.PossibleAddrs[1]). + AddrPort(&p.PossibleAddrs[2]). + AddrPort(&p.PossibleAddrs[3]). + AddrPort(&p.PossibleAddrs[4]). + AddrPort(&p.PossibleAddrs[5]). + AddrPort(&p.PossibleAddrs[6]). + AddrPort(&p.PossibleAddrs[7]). Error() return } diff --git a/node/supervisor.go b/node/supervisor.go index e20f1a9..9d89ee4 100644 --- a/node/supervisor.go +++ b/node/supervisor.go @@ -14,7 +14,7 @@ import ( const ( pingInterval = 8 * time.Second - timeoutInterval = 25 * time.Second + timeoutInterval = 30 * time.Second ) // ---------------------------------------------------------------------------- @@ -28,7 +28,7 @@ func startPeerSuper() { buf1: make([]byte, bufferSize), buf2: make([]byte, bufferSize), limiter: ratelimiter.New(ratelimiter.Config{ - FillPeriod: 50 * time.Millisecond, + FillPeriod: 20 * time.Millisecond, MaxWaitCount: 1, }), } @@ -57,6 +57,7 @@ func runPeerSuper(peers [256]peerState) { peers[msg.SrcIP].OnLocalDiscovery(msg) case pingTimerMsg: + publicAddrs.Clean() for i := range peers { if newState := peers[i].OnPingTimer(); newState != nil { peers[i] = newState @@ -171,10 +172,13 @@ func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { } s.peer = peer - s.staged.IP = s.remoteIP - s.staged.PubSignKey = peer.PubSignKey - s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey) - s.staged.DataCipher = newDataCipher() + s.staged = peerRoute{ + IP: s.remoteIP, + PubSignKey: peer.PubSignKey, + ControlCipher: newControlCipher(privKey, peer.PubKey), + DataCipher: newDataCipher(), + } + s.remotePub = false if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { s.remotePub = true @@ -254,13 +258,21 @@ func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { // Always respond. ack := ackPacket{ - TraceID: p.TraceID, - FromAddr: getLocalAddr(), + TraceID: p.TraceID, + ToAddr: s.staged.RemoteAddr, + PossibleAddrs: publicAddrs.Get(), } s.sendControlPacket(ack) - if !s.staged.Direct && p.FromAddr.IsValid() { - s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr) + if s.staged.Direct { + return + } + + // Not direct => send probes. + for _, addr := range p.PossibleAddrs { + if addr.IsValid() { + s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr) + } } } @@ -290,26 +302,35 @@ type stateClient struct { syn synPacket ack ackPacket - probeTraceID uint64 - probeAddr netip.AddrPort - - localProbeTraceID uint64 - localProbeAddr netip.AddrPort + probes map[uint64]netip.AddrPort + localDiscoveryAddr chan netip.AddrPort } func enterStateClient(s *peerStateData) peerState { s.client = true - ss := &stateClient{stateDisconnected: &stateDisconnected{s}} - ss.syn = synPacket{ - TraceID: newTraceID(), - SharedKey: s.staged.DataCipher.Key(), - Direct: s.staged.Direct, - FromAddr: getLocalAddr(), + ss := &stateClient{ + stateDisconnected: &stateDisconnected{s}, + probes: map[uint64]netip.AddrPort{}, + localDiscoveryAddr: make(chan netip.AddrPort, 1), } - ss.sendSyn() + + ss.syn = synPacket{ + TraceID: newTraceID(), + SharedKey: s.staged.DataCipher.Key(), + Direct: s.staged.Direct, + PossibleAddrs: publicAddrs.Get(), + } + ss.sendControlPacket(ss.syn) + return ss } +func (s *stateClient) sendProbeTo(addr netip.AddrPort) { + probe := probePacket{TraceID: newTraceID()} + s.probes[probe.TraceID] = addr + s.sendControlPacketTo(probe, addr) +} + func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { if msg.Packet.TraceID != s.syn.TraceID { s.logf("Ack has incorrect trace ID") @@ -324,6 +345,12 @@ func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { s.logf("Got ack.") s.publish() } else { + // TODO: What???? + } + + // Store possible public address if we're not a public node. + if !localPub && s.remotePub { + publicAddrs.Store(msg.Packet.ToAddr) } } @@ -332,21 +359,18 @@ func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { return } - switch msg.Packet.TraceID { - case s.probeTraceID: - s.staged.RemoteAddr = s.probeAddr - case s.localProbeTraceID: - s.staged.RemoteAddr = s.localProbeAddr - default: + addr, ok := s.probes[msg.Packet.TraceID] + if !ok { return } + s.staged.RemoteAddr = addr s.staged.Direct = true s.publish() s.syn.TraceID = newTraceID() s.syn.Direct = true - s.syn.FromAddr = getLocalAddr() + s.syn.PossibleAddrs = [8]netip.AddrPort{} s.sendControlPacket(s.syn) s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) @@ -361,9 +385,14 @@ func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { // // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. - s.localProbeTraceID = newTraceID() - s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) - s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr) + addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + + select { + case s.localDiscoveryAddr <- addr: + // OK. + default: + log.Printf("Local discovery packet dropped.") + } } func (s *stateClient) OnPingTimer() peerState { @@ -374,22 +403,26 @@ func (s *stateClient) OnPingTimer() peerState { return s.OnPeerUpdate(s.peer) } - s.sendSyn() + s.sendControlPacket(s.syn) - if !s.staged.Direct && s.ack.FromAddr.IsValid() { - s.probeTraceID = newTraceID() - s.probeAddr = s.ack.FromAddr - s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr) + if s.staged.Direct { + return nil + } + + clear(s.probes) + for _, ip := range publicAddrs.Get() { + if !ip.IsValid() { + break + } + s.sendProbeTo(ip) + } + + select { + case addr := <-s.localDiscoveryAddr: + s.sendProbeTo(addr) + default: + // Nothing to do. } return nil } - -func (s *stateClient) sendSyn() { - localAddr := getLocalAddr() - if localAddr != s.syn.FromAddr { - s.syn.TraceID = newTraceID() - s.syn.FromAddr = localAddr - } - s.sendControlPacket(s.syn) -}