From f8a0df0263204bba1ca5fb31078e894e3a6d7cf4 Mon Sep 17 00:00:00 2001 From: jdl Date: Mon, 23 Dec 2024 20:28:49 +0100 Subject: [PATCH] wip: working - moving on to single relay w/ address discovery --- node/globals.go | 12 ++- node/hubpoller.go | 1 - node/main.go | 14 ++- node/packets.go | 49 ++++++---- node/peer-supervisor.go | 203 ++++++++++++++++++---------------------- 5 files changed, 142 insertions(+), 137 deletions(-) diff --git a/node/globals.go b/node/globals.go index f782cb5..25eee33 100644 --- a/node/globals.go +++ b/node/globals.go @@ -23,8 +23,10 @@ type peerRoute struct { ControlCipher *controlCipher DataCipher *dataCipher RemoteAddr netip.AddrPort // Remote address if directly connected. - LocalAddr netip.AddrPort // Local address as seen by the remote. - RelayIP byte // Non-zero if we should relay. + // TODO: Remove this and use global localAddr and relayIP. + // Replace w/ a Direct boolean. + LocalAddr netip.AddrPort // Local address as seen by the remote. + RelayIP byte // Non-zero if we should relay. } var ( @@ -32,6 +34,7 @@ var ( netName string localIP byte localPub bool + localAddr netip.AddrPort privateKey []byte // Shared interface for writing. @@ -54,4 +57,9 @@ var ( // Global routing table. routingTable [256]*atomic.Pointer[peerRoute] + + // TODO: use relay for local address discovery. This should be new stream ID, + // managed by a single thread. + // localAddr *atomic.Pointer[netip.AddrPort] + // relayIP *atomic.Pointer[byte] ) diff --git a/node/hubpoller.go b/node/hubpoller.go index ef36431..ac6b110 100644 --- a/node/hubpoller.go +++ b/node/hubpoller.go @@ -58,7 +58,6 @@ func (hp *hubPoller) Run() { func (hp *hubPoller) pollHub() { var state m.NetworkState - log.Printf("Fetching peer state...") resp, err := hp.client.Do(hp.req) if err != nil { log.Printf("Failed to fetch peer state: %v", err) diff --git a/node/main.go b/node/main.go index 419f644..70857c3 100644 --- a/node/main.go +++ b/node/main.go @@ -105,7 +105,13 @@ func main(listenIP string, port uint16) { // Intialize globals. localIP = config.PeerIP - localPub = addrIsValid(config.PublicIP) + + ip, ok := netip.AddrFromSlice(config.PublicIP) + if ok { + localPub = true + localAddr = netip.AddrPortFrom(ip, config.Port) + } + privateKey = config.PrivKey _iface = newIFWriter(iface) @@ -178,6 +184,8 @@ func readFromConn(conn *net.UDPConn) { case controlStreamID: handleControlPacket(remoteAddr, h, data, decBuf) + // TODO: discoveryStreamID + case dataStreamID: handleDataPacket(h, data, decBuf) @@ -216,8 +224,8 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { } pkt := controlPacket{ - SrcIP: h.SourceIP, - RemoteAddr: addr, + SrcIP: h.SourceIP, + SrcAddr: addr, } if err := pkt.ParsePayload(out); err != nil { diff --git a/node/packets.go b/node/packets.go index f6d92e1..f0ea736 100644 --- a/node/packets.go +++ b/node/packets.go @@ -2,6 +2,7 @@ package node import ( "errors" + "log" "net/netip" ) @@ -14,14 +15,15 @@ const ( packetTypeSyn = iota + 1 packetTypeSynAck packetTypeAck + packetTypeProbe ) // ---------------------------------------------------------------------------- type controlPacket struct { - SrcIP byte - RemoteAddr netip.AddrPort - Payload any + SrcIP byte + SrcAddr netip.AddrPort + Payload any } func (p *controlPacket) ParsePayload(buf []byte) (err error) { @@ -30,8 +32,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { p.Payload, err = parseSynPacket(buf) case packetTypeSynAck: p.Payload, err = parseSynAckPacket(buf) - case packetTypeAck: - p.Payload, err = parseAckPacket(buf) + case packetTypeProbe: + log.Printf("Got probe...") + p.Payload, err = parseProbePacket(buf) default: return errUnknownPacketType } @@ -44,6 +47,7 @@ type synPacket struct { TraceID uint64 // TraceID to match response w/ request. SharedKey [32]byte // Our shared key. RelayIP byte + FromAddr netip.AddrPort // The client's sending address. } func (p synPacket) Marshal(buf []byte) []byte { @@ -52,6 +56,7 @@ func (p synPacket) Marshal(buf []byte) []byte { Uint64(p.TraceID). SharedKey(p.SharedKey). Byte(p.RelayIP). + AddrPort(p.FromAddr). Build() } @@ -60,6 +65,7 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { Uint64(&p.TraceID). SharedKey(&p.SharedKey). Byte(&p.RelayIP). + AddrPort(&p.FromAddr). Error() return } @@ -68,47 +74,54 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { type synAckPacket struct { TraceID uint64 - RecvAddr netip.AddrPort + FromAddr netip.AddrPort + ToAddr netip.AddrPort } func (p synAckPacket) Marshal(buf []byte) []byte { return newBinWriter(buf). Byte(packetTypeSynAck). Uint64(p.TraceID). - AddrPort(p.RecvAddr). + AddrPort(p.FromAddr). + AddrPort(p.ToAddr). Build() } func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). - AddrPort(&p.RecvAddr). + AddrPort(&p.FromAddr). + AddrPort(&p.ToAddr). Error() return } // ---------------------------------------------------------------------------- -type ackPacket struct { +type addrDiscoveryPacket struct { TraceID uint64 - SendAddr netip.AddrPort // Address of the sender. - RecvAddr netip.AddrPort // Address of the recipient as seen by sender. + FromAddr netip.AddrPort + ToAddr netip.AddrPort } -func (p ackPacket) Marshal(buf []byte) []byte { +// ---------------------------------------------------------------------------- + +// A probeReqPacket is sent from a client to a server to determine if direct +// UDP communication can be used. +type probePacket struct { + TraceID uint64 +} + +func (p probePacket) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeAck). + Byte(packetTypeProbe). Uint64(p.TraceID). - AddrPort(p.SendAddr). - AddrPort(p.RecvAddr). Build() } -func parseAckPacket(buf []byte) (p ackPacket, err error) { +func parseProbePacket(buf []byte) (p probePacket, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). - AddrPort(&p.SendAddr). - AddrPort(&p.RecvAddr). Error() return } diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index e47d0ae..e4f056e 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -66,6 +66,36 @@ func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte _sendControlPacket(pkt, s.staged, s.buf1, s.buf2) } +func (s *peerSupervisor) sendControlPacketTo( + pkt interface{ Marshal([]byte) []byte }, + addr netip.AddrPort, +) { + if !addr.IsValid() { + s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) + return + } + route := s.staged + route.RelayIP = 0 + route.RemoteAddr = addr + _sendControlPacket(pkt, route, s.buf1, s.buf2) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) getLocalAddr() netip.AddrPort { + if localPub { + return localAddr + } + + if s.staged.RelayIP != 0 { + if addr := routingTable[s.staged.RelayIP].Load().LocalAddr; addr.IsValid() { + return addr + } + } + + return s.staged.LocalAddr +} + // ---------------------------------------------------------------------------- func (s *peerSupervisor) logf(msg string, args ...any) { @@ -113,7 +143,7 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { if s.remotePub == localPub { if localIP < s.remoteIP { - return s.serverAccept + return s.server } return s.clientInit } @@ -121,18 +151,13 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { if s.remotePub { return s.clientInit } - return s.serverAccept + return s.server } // ---------------------------------------------------------------------------- -func (s *peerSupervisor) serverAccept() stateFunc { - s.logf("STATE: server-accept") - s.staged.Up = false - s.staged.DataCipher = nil - s.staged.RemoteAddr = zeroAddrPort - s.staged.RelayIP = 0 - s.publish() +func (s *peerSupervisor) server() stateFunc { + s.logf("STATE: server") var syn synPacket @@ -145,60 +170,37 @@ func (s *peerSupervisor) serverAccept() stateFunc { switch p := pkt.Payload.(type) { case synPacket: - syn = p - s.staged.RemoteAddr = pkt.RemoteAddr - s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) - s.staged.RelayIP = syn.RelayIP - s.publish() - s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr}) - - case ackPacket: + // Before we can respond to this packet, we need to make sure the + // route is setup properly. if p.TraceID != syn.TraceID { - continue + syn = p + s.staged.Up = true + s.staged.RemoteAddr = pkt.SrcAddr + s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) + s.staged.RelayIP = syn.RelayIP + s.staged.LocalAddr = s.getLocalAddr() + s.publish() } - // Publish. - return s.serverConnected(syn.TraceID) - } - } - } -} + // We should always respond. + s.sendControlPacket(synAckPacket{ + TraceID: syn.TraceID, + FromAddr: s.staged.LocalAddr, + ToAddr: pkt.SrcAddr, + }) -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc { - s.logf("STATE: server-connected") - s.staged.Up = true - s.publish() - return func() stateFunc { - return s._serverConnected(traceID) - } -} - -func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc { - - timeoutTimer := time.NewTimer(timeoutInterval) - defer timeoutTimer.Stop() - - for { - select { - case peer := <-s.peerUpdates: - return s.peerUpdate(peer) - - case pkt := <-s.controlPackets: - switch p := pkt.Payload.(type) { - - case ackPacket: - if p.TraceID != traceID { - return s.serverAccept + // If we're relayed, attempt to probe the client. + if s.staged.RelayIP != 0 && syn.FromAddr.IsValid() { + probe := probePacket{TraceID: newTraceID()} + s.logf("SERVER sending probe %v: %v", probe, syn.FromAddr) + s.sendControlPacketTo(probe, syn.FromAddr) } - s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr}) - timeoutTimer.Reset(timeoutInterval) - } - case <-timeoutTimer.C: - s.logf("server timeout") - return s.serverAccept + case probePacket: + s.logf("SERVER got probe: %v", p) + s.logf("SERVER sending probe: %v", pkt.SrcAddr) + s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr) + } } } } @@ -208,13 +210,10 @@ func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc { func (s *peerSupervisor) clientInit() stateFunc { s.logf("STATE: client-init") if !s.remotePub { - // TODO: Check local discovery for IP. - // TODO: Attempt UDP hole punch. - // TODO: client-relayed return s.clientSelectRelay } - return s.clientDial + return s.client } // ---------------------------------------------------------------------------- @@ -237,7 +236,7 @@ func (s *peerSupervisor) clientSelectRelay() stateFunc { s.staged.RelayIP = relay.IP s.staged.LocalAddr = relay.LocalAddr s.publish() - return s.clientDial + return s.client } s.logf("No relay available.") @@ -264,20 +263,26 @@ func (s *peerSupervisor) selectRelay() *peerRoute { // ---------------------------------------------------------------------------- -func (s *peerSupervisor) clientDial() stateFunc { - s.logf("STATE: client-dial") +func (s *peerSupervisor) client() stateFunc { + s.logf("STATE: client") var ( syn = synPacket{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), RelayIP: s.staged.RelayIP, + FromAddr: s.getLocalAddr(), } + ack synAckPacket - timeout = time.NewTimer(dialTimeout) + probe = probePacket{TraceID: newTraceID()} + + timeoutTimer = time.NewTimer(timeoutInterval) + pingTimer = time.NewTimer(pingInterval) ) - defer timeout.Stop() + defer timeoutTimer.Stop() + defer pingTimer.Stop() s.sendControlPacket(syn) @@ -289,64 +294,36 @@ func (s *peerSupervisor) clientDial() stateFunc { case pkt := <-s.controlPackets: switch p := pkt.Payload.(type) { + case synAckPacket: if p.TraceID != syn.TraceID { + s.logf("Bad traceID?") continue // Hmm... } - s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr}) - return s.clientConnected(p) - } - - case <-timeout.C: - return s.clientInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc { - s.logf("STATE: client-connected") - s.staged.Up = true - s.staged.LocalAddr = p.RecvAddr - s.publish() - - return func() stateFunc { - return s._clientConnected(p.TraceID) - } -} - -func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc { - - pingTimer := time.NewTimer(pingInterval) - timeoutTimer := time.NewTimer(timeoutInterval) - - defer pingTimer.Stop() - defer timeoutTimer.Stop() - - for { - select { - case peer := <-s.peerUpdates: - return s.peerUpdate(peer) - - case pkt := <-s.controlPackets: - switch p := pkt.Payload.(type) { - - case ackPacket: - if p.TraceID != traceID { - return s.clientInit - } + ack = p timeoutTimer.Reset(timeoutInterval) + + if !s.staged.Up { + s.staged.Up = true + s.staged.LocalAddr = p.ToAddr + s.publish() + } + + case probePacket: + s.logf("CLIENT got probe: %v", p) } case <-pingTimer.C: - s.sendControlPacket(ackPacket{TraceID: traceID}) + s.sendControlPacket(syn) pingTimer.Reset(pingInterval) - case <-timeoutTimer.C: - s.logf("client timeout") - return s.clientInit + if s.staged.RelayIP != 0 && ack.FromAddr.IsValid() { + s.logf("CLIENT sending probe %v: %v", probe, ack.FromAddr) + s.sendControlPacketTo(probe, ack.FromAddr) + } + case <-timeoutTimer.C: + return s.clientInit } } }