From 640f4b998605abfd01db32509097b08f6bf8465b Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 24 Dec 2024 19:34:16 +0100 Subject: [PATCH] Cleanup. Direct, relayed, and hole-punching is working. --- fasttime/time.go | 20 ---- fasttime/time_test.go | 18 ---- hub/api/db/written.go | 4 +- node/addrdiscovery.go | 71 +++++++++++++ node/addrutil.go | 8 -- node/globalfuncs.go | 31 ++++-- node/globals.go | 53 +++++++--- node/header.go | 2 +- node/main.go | 69 +++++++------ node/packets-util.go | 12 +-- node/packets.go | 35 +++++-- node/packets_test.go | 21 +--- node/peer-supervisor.go | 222 ++++++++++++++++++++-------------------- node/relaymanager.go | 40 ++++++++ 14 files changed, 357 insertions(+), 249 deletions(-) delete mode 100644 fasttime/time.go delete mode 100644 fasttime/time_test.go create mode 100644 node/addrdiscovery.go delete mode 100644 node/addrutil.go create mode 100644 node/relaymanager.go diff --git a/fasttime/time.go b/fasttime/time.go deleted file mode 100644 index 5c569ac..0000000 --- a/fasttime/time.go +++ /dev/null @@ -1,20 +0,0 @@ -package fasttime - -import ( - "sync/atomic" - "time" -) - -var _timestamp int64 = time.Now().Unix() - -func init() { - go func() { - for range time.Tick(1100 * time.Millisecond) { - atomic.StoreInt64(&_timestamp, time.Now().Unix()) - } - }() -} - -func Now() int64 { - return atomic.LoadInt64(&_timestamp) -} diff --git a/fasttime/time_test.go b/fasttime/time_test.go deleted file mode 100644 index b0a85d0..0000000 --- a/fasttime/time_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package fasttime - -import ( - "testing" - "time" -) - -func BenchmarkNow(b *testing.B) { - for i := 0; i < b.N; i++ { - Now() - } -} - -func BenchmarkTimeUnix(b *testing.B) { - for i := 0; i < b.N; i++ { - time.Now().Unix() - } -} diff --git a/hub/api/db/written.go b/hub/api/db/written.go index 65769c4..5b8bb15 100644 --- a/hub/api/db/written.go +++ b/hub/api/db/written.go @@ -1,12 +1,12 @@ package db -import "vppn/fasttime" +import "time" func Session_UpdateLastSeenAt( tx TX, id string, ) (err error) { - _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", fasttime.Now(), id) + _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id) return err } diff --git a/node/addrdiscovery.go b/node/addrdiscovery.go new file mode 100644 index 0000000..b62e13f --- /dev/null +++ b/node/addrdiscovery.go @@ -0,0 +1,71 @@ +package node + +import ( + "log" + "net/netip" + "time" +) + +func addrDiscoveryServer() { + var ( + buf1 = make([]byte, bufferSize) + buf2 = make([]byte, bufferSize) + ) + + for { + pkt := <-discoveryPackets + + p, ok := pkt.Payload.(addrDiscoveryPacket) + if !ok { + continue + } + + route := routingTable[pkt.SrcIP].Load() + if route == nil || !route.RemoteAddr.IsValid() { + continue + } + + _sendControlPacket(addrDiscoveryPacket{ + TraceID: p.TraceID, + ToAddr: pkt.SrcAddr, + }, *route, buf1, buf2) + } +} + +func addrDiscoveryClient() { + var ( + checkInterval = 8 * time.Second + timer = time.NewTimer(4 * time.Second) + + buf1 = make([]byte, bufferSize) + buf2 = make([]byte, bufferSize) + + addrPacket addrDiscoveryPacket + lAddr netip.AddrPort + ) + + for { + select { + case pkt := <-discoveryPackets: + p, ok := pkt.Payload.(addrDiscoveryPacket) + if !ok || p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr { + continue + } + + log.Printf("Discovered local address: %v", p.ToAddr) + lAddr = p.ToAddr + localAddr.Store(&p.ToAddr) + + case <-timer.C: + timer.Reset(checkInterval) + + route := getRelayRoute() + if route == nil { + continue + } + + addrPacket.TraceID = newTraceID() + _sendControlPacket(addrPacket, *route, buf1, buf2) + } + } +} diff --git a/node/addrutil.go b/node/addrutil.go deleted file mode 100644 index 590c80c..0000000 --- a/node/addrutil.go +++ /dev/null @@ -1,8 +0,0 @@ -package node - -import "net/netip" - -func addrIsValid(in []byte) bool { - _, ok := netip.AddrFromSlice(in) - return ok -} diff --git a/node/globalfuncs.go b/node/globalfuncs.go index 406588e..98975da 100644 --- a/node/globalfuncs.go +++ b/node/globalfuncs.go @@ -1,10 +1,24 @@ package node import ( - "log" + "net/netip" "sync/atomic" ) +func getRelayRoute() *peerRoute { + if ip := relayIP.Load(); ip != nil { + return routingTable[*ip].Load() + } + return nil +} + +func getLocalAddr() netip.AddrPort { + if a := localAddr.Load(); a != nil { + return *a + } + return netip.AddrPort{} +} + func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { buf := pkt.Marshal(buf2) h := header{ @@ -15,12 +29,12 @@ func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute } buf = route.ControlCipher.Encrypt(h, buf, buf1) - if route.RelayIP == 0 { + if route.Direct { _conn.WriteTo(buf, route.RemoteAddr) return } - _relayPacket(route.RelayIP, route.IP, buf, buf2) + _relayPacket(route.IP, buf, buf2) } func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) { @@ -33,18 +47,17 @@ func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) { enc := route.DataCipher.Encrypt(h, pkt, buf1) - if route.RelayIP == 0 { + if route.Direct { _conn.WriteTo(enc, route.RemoteAddr) return } - _relayPacket(route.RelayIP, route.IP, enc, buf2) + _relayPacket(route.IP, enc, buf2) } -func _relayPacket(relayIP, destIP byte, data, buf []byte) { - relayRoute := routingTable[relayIP].Load() - if !relayRoute.Up || !relayRoute.Relay { - log.Print("Failed to send data packet: relay not available.") +func _relayPacket(destIP byte, data, buf []byte) { + relayRoute := getRelayRoute() + if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay { return } diff --git a/node/globals.go b/node/globals.go index 25eee33..3b8edea 100644 --- a/node/globals.go +++ b/node/globals.go @@ -3,11 +3,10 @@ package node import ( "net/netip" "sync/atomic" + "time" "vppn/m" ) -var zeroAddrPort = netip.AddrPort{} - const ( bufferSize = 1536 if_mtu = 1200 @@ -20,13 +19,10 @@ type peerRoute struct { IP byte Up bool // True if data can be sent on the route. Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. ControlCipher *controlCipher DataCipher *dataCipher RemoteAddr netip.AddrPort // Remote address if directly connected. - // TODO: Remove this and use global localAddr and relayIP. - // Replace w/ a Direct boolean. - LocalAddr netip.AddrPort // Local address as seen by the remote. - RelayIP byte // Non-zero if we should relay. } var ( @@ -34,7 +30,6 @@ var ( netName string localIP byte localPub bool - localAddr netip.AddrPort privateKey []byte // Shared interface for writing. @@ -44,22 +39,48 @@ var ( _conn *connWriter // Counters for sending to each peer. - sendCounters [256]uint64 + sendCounters [256]uint64 = func() (out [256]uint64) { + for i := range out { + out[i] = uint64(time.Now().Unix()<<30 + 1) + } + return + }() // Duplicate checkers for incoming packets. - dupChecks [256]*dupCheck + dupChecks [256]*dupCheck = func() (out [256]*dupCheck) { + for i := range out { + out[i] = newDupCheck(0) + } + return + }() // Channels for incoming control packets. - controlPackets [256]chan controlPacket + controlPackets [256]chan controlPacket = func() (out [256]chan controlPacket) { + for i := range out { + out[i] = make(chan controlPacket, 256) + } + return + }() // Channels for incoming peer updates from the hub. - peerUpdates [256]chan *m.Peer + peerUpdates [256]chan *m.Peer = func() (out [256]chan *m.Peer) { + for i := range out { + out[i] = make(chan *m.Peer) + } + return + }() // Global routing table. - routingTable [256]*atomic.Pointer[peerRoute] + routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { + for i := range out { + out[i] = &atomic.Pointer[peerRoute]{} + out[i].Store(&peerRoute{}) + } + return + }() - // TODO: use relay for local address discovery. This should be new stream ID, - // managed by a single thread. - // localAddr *atomic.Pointer[netip.AddrPort] - // relayIP *atomic.Pointer[byte] + // Managed by the relayManager. + discoveryPackets chan controlPacket + localAddr *atomic.Pointer[netip.AddrPort] // May be nil. + relayIP *atomic.Pointer[byte] // May be nil. ) diff --git a/node/header.go b/node/header.go index fd28962..58ba852 100644 --- a/node/header.go +++ b/node/header.go @@ -14,7 +14,7 @@ const ( type header struct { StreamID byte - Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. + Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. SourceIP byte DestIP byte } diff --git a/node/main.go b/node/main.go index 70857c3..ee2e7a7 100644 --- a/node/main.go +++ b/node/main.go @@ -12,7 +12,6 @@ import ( "os" "runtime/debug" "sync/atomic" - "time" "vppn/m" ) @@ -104,36 +103,34 @@ func main(listenIP string, port uint16) { } // Intialize globals. + _iface = newIFWriter(iface) + _conn = newConnWriter(conn) + localIP = config.PeerIP + discoveryPackets = make(chan controlPacket, 256) + localAddr = &atomic.Pointer[netip.AddrPort]{} + relayIP = &atomic.Pointer[byte]{} ip, ok := netip.AddrFromSlice(config.PublicIP) if ok { localPub = true - localAddr = netip.AddrPortFrom(ip, config.Port) + addr := netip.AddrPortFrom(ip, config.Port) + localAddr.Store(&addr) } privateKey = config.PrivKey - _iface = newIFWriter(iface) - _conn = newConnWriter(conn) - - for i := range 256 { - sendCounters[i] = uint64(time.Now().Unix()<<30) + 1 - dupChecks[i] = newDupCheck(0) - controlPackets[i] = make(chan controlPacket, 256) - peerUpdates[i] = make(chan *m.Peer) - routingTable[i] = &atomic.Pointer[peerRoute]{} - route := peerRoute{IP: byte(i)} - routingTable[i].Store(&route) - } - // Start supervisors. for i := range 256 { go newPeerSupervisor(i).Run() } - // -------------------- - + if localPub { + go addrDiscoveryServer() + } else { + go addrDiscoveryClient() + go relayManager() + } go newHubPoller(config).Run() go readFromConn(conn) readFromIFace(iface) @@ -173,6 +170,8 @@ func readFromConn(conn *net.UDPConn) { log.Fatalf("Failed to read from UDP port: %v", err) } + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) + data = buf[:n] if n < headerSize { @@ -184,8 +183,6 @@ func readFromConn(conn *net.UDPConn) { case controlStreamID: handleControlPacket(remoteAddr, h, data, decBuf) - // TODO: discoveryStreamID - case dataStreamID: handleDataPacket(h, data, decBuf) @@ -198,7 +195,7 @@ func readFromConn(conn *net.UDPConn) { func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { route := routingTable[h.SourceIP].Load() if route.ControlCipher == nil { - log.Printf("Not connected (control).") + //log.Printf("Not connected (control).") return } @@ -209,17 +206,17 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { out, ok := route.ControlCipher.Decrypt(data, decBuf) if !ok { - log.Printf("Failed to decrypt control packet.") + //log.Printf("Failed to decrypt control packet.") return } if len(out) == 0 { - log.Printf("Empty control packet from: %d", h.SourceIP) + //log.Printf("Empty control packet from: %d", h.SourceIP) return } if dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter) + //log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter) return } @@ -233,17 +230,29 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { return } - select { - case controlPackets[h.SourceIP] <- pkt: + switch pkt.Payload.(type) { + + case addrDiscoveryPacket: + select { + case discoveryPackets <- pkt: + default: + log.Printf("Dropping discovery packet.") + } + default: - log.Printf("Dropping control packet.") + select { + case controlPackets[h.SourceIP] <- pkt: + default: + log.Printf("Dropping control packet.") + } } + } func handleDataPacket(h header, data []byte, decBuf []byte) { route := routingTable[h.SourceIP].Load() if !route.Up { - log.Printf("Not connected (recv).") + //log.Printf("Not connected (recv).") return } @@ -254,7 +263,7 @@ func handleDataPacket(h header, data []byte, decBuf []byte) { } if dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter) + //log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter) return } @@ -264,8 +273,8 @@ func handleDataPacket(h header, data []byte, decBuf []byte) { } destRoute := routingTable[h.DestIP].Load() - if !destRoute.Up || destRoute.RelayIP != 0 { - log.Printf("Not connected (relay)") + if !destRoute.Up { + log.Printf("Not connected (relay): %v", destRoute) return } diff --git a/node/packets-util.go b/node/packets-util.go index 8a6e13a..af10eb5 100644 --- a/node/packets-util.go +++ b/node/packets-util.go @@ -3,16 +3,14 @@ package node import ( "net/netip" "sync/atomic" + "time" "unsafe" - "vppn/fasttime" ) -var ( - traceIDCounter uint64 -) +var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1 func newTraceID() uint64 { - return uint64(fasttime.Now()<<30) + atomic.AddUint64(&traceIDCounter, 1) + return atomic.AddUint64(&traceIDCounter, 1) } // ---------------------------------------------------------------------------- @@ -151,9 +149,9 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { if !r.hasBytes(18) { return r } - addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])) - addr = addr.Unmap() + addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() r.i += 16 + var port uint16 r.Uint16(&port) *x = netip.AddrPortFrom(addr, port) diff --git a/node/packets.go b/node/packets.go index f0ea736..267fed0 100644 --- a/node/packets.go +++ b/node/packets.go @@ -2,7 +2,6 @@ package node import ( "errors" - "log" "net/netip" ) @@ -16,6 +15,7 @@ const ( packetTypeSynAck packetTypeAck packetTypeProbe + packetTypeAddrDiscovery ) // ---------------------------------------------------------------------------- @@ -33,8 +33,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { case packetTypeSynAck: p.Payload, err = parseSynAckPacket(buf) case packetTypeProbe: - log.Printf("Got probe...") p.Payload, err = parseProbePacket(buf) + case packetTypeAddrDiscovery: + p.Payload, err = parseAddrDiscoveryPacket(buf) default: return errUnknownPacketType } @@ -46,7 +47,7 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { type synPacket struct { TraceID uint64 // TraceID to match response w/ request. SharedKey [32]byte // Our shared key. - RelayIP byte + Direct bool FromAddr netip.AddrPort // The client's sending address. } @@ -55,7 +56,7 @@ func (p synPacket) Marshal(buf []byte) []byte { Byte(packetTypeSyn). Uint64(p.TraceID). SharedKey(p.SharedKey). - Byte(p.RelayIP). + Bool(p.Direct). AddrPort(p.FromAddr). Build() } @@ -64,7 +65,7 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). SharedKey(&p.SharedKey). - Byte(&p.RelayIP). + Bool(&p.Direct). AddrPort(&p.FromAddr). Error() return @@ -75,7 +76,6 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { type synAckPacket struct { TraceID uint64 FromAddr netip.AddrPort - ToAddr netip.AddrPort } func (p synAckPacket) Marshal(buf []byte) []byte { @@ -83,7 +83,6 @@ func (p synAckPacket) Marshal(buf []byte) []byte { Byte(packetTypeSynAck). Uint64(p.TraceID). AddrPort(p.FromAddr). - AddrPort(p.ToAddr). Build() } @@ -91,7 +90,6 @@ func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). AddrPort(&p.FromAddr). - AddrPort(&p.ToAddr). Error() return } @@ -99,9 +97,24 @@ func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { // ---------------------------------------------------------------------------- type addrDiscoveryPacket struct { - TraceID uint64 - FromAddr netip.AddrPort - ToAddr netip.AddrPort + TraceID uint64 + ToAddr netip.AddrPort +} + +func (p addrDiscoveryPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeAddrDiscovery). + Uint64(p.TraceID). + AddrPort(p.ToAddr). + Build() +} + +func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + AddrPort(&p.ToAddr). + Error() + return } // ---------------------------------------------------------------------------- diff --git a/node/packets_test.go b/node/packets_test.go index bd83080..60295ec 100644 --- a/node/packets_test.go +++ b/node/packets_test.go @@ -9,7 +9,9 @@ import ( func TestPacketSyn(t *testing.T) { in := synPacket{ - TraceID: newTraceID(), + TraceID: newTraceID(), + RelayIP: 4, + FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), } rand.Read(in.SharedKey[:]) @@ -26,7 +28,7 @@ func TestPacketSyn(t *testing.T) { func TestPacketSynAck(t *testing.T) { in := synAckPacket{ TraceID: newTraceID(), - RecvAddr: netip.AddrPort{}, + FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), } out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) @@ -38,18 +40,3 @@ func TestPacketSynAck(t *testing.T) { t.Fatal("\n", in, "\n", out) } } - -func TestPacketAck(t *testing.T) { - in := ackPacket{ - TraceID: newTraceID(), - } - - out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal("\n", in, "\n", out) - } -} diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index e4f056e..76e329c 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -3,7 +3,6 @@ package node import ( "fmt" "log" - "math/rand" "net/netip" "sync/atomic" "time" @@ -11,10 +10,8 @@ import ( ) const ( - dialTimeout = 8 * time.Second - connectTimeout = 6 * time.Second - pingInterval = 6 * time.Second - timeoutInterval = 20 * time.Second + pingInterval = 8 * time.Second + timeoutInterval = 25 * time.Second ) // ---------------------------------------------------------------------------- @@ -64,6 +61,7 @@ func (s *peerSupervisor) Run() { func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { _sendControlPacket(pkt, s.staged, s.buf1, s.buf2) + time.Sleep(500 * time.Millisecond) // Rate limit packets. } func (s *peerSupervisor) sendControlPacketTo( @@ -75,25 +73,10 @@ func (s *peerSupervisor) sendControlPacketTo( return } route := s.staged - route.RelayIP = 0 + route.Direct = true route.RemoteAddr = addr _sendControlPacket(pkt, route, s.buf1, s.buf2) -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) getLocalAddr() netip.AddrPort { - if localPub { - return localAddr - } - - if s.staged.RelayIP != 0 { - if addr := routingTable[s.staged.RelayIP].Load().LocalAddr; addr.IsValid() { - return addr - } - } - - return s.staged.LocalAddr + time.Sleep(500 * time.Millisecond) // Rate limit packets. } // ---------------------------------------------------------------------------- @@ -138,18 +121,21 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { s.remotePub = true s.staged.Relay = peer.Relay + s.staged.Direct = true s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) + } else if localPub { + s.staged.Direct = true } if s.remotePub == localPub { if localIP < s.remoteIP { return s.server } - return s.clientInit + return s.client } if s.remotePub { - return s.clientInit + return s.client } return s.server } @@ -157,9 +143,14 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { // ---------------------------------------------------------------------------- func (s *peerSupervisor) server() stateFunc { - s.logf("STATE: server") + logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) } - var syn synPacket + logf("DOWN") + + var ( + syn synPacket + timeoutTimer = time.NewTimer(timeoutInterval) + ) for { select { @@ -172,110 +163,80 @@ func (s *peerSupervisor) server() stateFunc { case synPacket: // Before we can respond to this packet, we need to make sure the // route is setup properly. - if p.TraceID != syn.TraceID { + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != syn.TraceID || !s.staged.Up { + if p.Direct { + logf("UP - Direct") + } else { + logf("UP - Relayed") + } + syn = p s.staged.Up = true - s.staged.RemoteAddr = pkt.SrcAddr + s.staged.Direct = syn.Direct s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) - s.staged.RelayIP = syn.RelayIP - s.staged.LocalAddr = s.getLocalAddr() + s.staged.RemoteAddr = pkt.SrcAddr + s.publish() } // We should always respond. - s.sendControlPacket(synAckPacket{ + ack := synAckPacket{ TraceID: syn.TraceID, - FromAddr: s.staged.LocalAddr, - ToAddr: pkt.SrcAddr, - }) + FromAddr: getLocalAddr(), + } + s.sendControlPacket(ack) - // If we're relayed, attempt to probe the client. - if s.staged.RelayIP != 0 && syn.FromAddr.IsValid() { - probe := probePacket{TraceID: newTraceID()} - s.logf("SERVER sending probe %v: %v", probe, syn.FromAddr) - s.sendControlPacketTo(probe, syn.FromAddr) + if s.staged.Direct { + continue } + if !syn.FromAddr.IsValid() { + continue + } + + probe := probePacket{TraceID: newTraceID()} + s.sendControlPacketTo(probe, syn.FromAddr) + case probePacket: - s.logf("SERVER got probe: %v", p) - s.logf("SERVER sending probe: %v", pkt.SrcAddr) - s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr) - } - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) clientInit() stateFunc { - s.logf("STATE: client-init") - if !s.remotePub { - return s.clientSelectRelay - } - - return s.client -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) clientSelectRelay() stateFunc { - s.logf("STATE: client-select-relay") - - timer := time.NewTimer(0) - defer timer.Stop() - - for { - select { - case peer := <-s.peerUpdates: - return s.peerUpdate(peer) - - case <-timer.C: - relay := s.selectRelay() - if relay != nil { - s.logf("Got relay: %d", relay.IP) - s.staged.RelayIP = relay.IP - s.staged.LocalAddr = relay.LocalAddr - s.publish() - return s.client + if pkt.SrcAddr.IsValid() { + s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr) + } else { + logf("Invalid probe address") + } } - s.logf("No relay available.") - timer.Reset(pingInterval) + case <-timeoutTimer.C: + logf("Connection timeout") + s.staged.Up = false + s.publish() } } } -func (s *peerSupervisor) selectRelay() *peerRoute { - possible := make([]*peerRoute, 0, 8) - for i := range routingTable { - route := routingTable[i].Load() - if !route.Up || !route.Relay { - continue - } - possible = append(possible, route) - } - - if len(possible) == 0 { - return nil - } - return possible[rand.Intn(len(possible))] -} - // ---------------------------------------------------------------------------- func (s *peerSupervisor) client() stateFunc { - s.logf("STATE: client") + logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) } + + logf("DOWN") var ( syn = synPacket{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), - RelayIP: s.staged.RelayIP, - FromAddr: s.getLocalAddr(), + Direct: s.staged.Direct, + FromAddr: getLocalAddr(), } + ack synAckPacket - probe = probePacket{TraceID: newTraceID()} + probe probePacket + probeAddr netip.AddrPort + + lAddr netip.AddrPort timeoutTimer = time.NewTimer(timeoutInterval) pingTimer = time.NewTimer(pingInterval) @@ -297,33 +258,74 @@ func (s *peerSupervisor) client() stateFunc { case synAckPacket: if p.TraceID != syn.TraceID { - s.logf("Bad traceID?") continue // Hmm... } + ack = p timeoutTimer.Reset(timeoutInterval) if !s.staged.Up { + if s.staged.Direct { + logf("UP - Direct") + } else { + logf("UP - Relayed") + } + s.staged.Up = true - s.staged.LocalAddr = p.ToAddr s.publish() } case probePacket: - s.logf("CLIENT got probe: %v", p) + if s.staged.Direct { + continue + } + + if p.TraceID != probe.TraceID { + continue + } + + // Upgrade connection. + + logf("UP - Direct") + s.staged.Direct = true + s.staged.RemoteAddr = probeAddr + s.publish() + + syn.TraceID = newTraceID() + syn.Direct = true + syn.FromAddr = getLocalAddr() + s.sendControlPacket(syn) } case <-pingTimer.C: - s.sendControlPacket(syn) - pingTimer.Reset(pingInterval) + // Send syn. - if s.staged.RelayIP != 0 && ack.FromAddr.IsValid() { - s.logf("CLIENT sending probe %v: %v", probe, ack.FromAddr) - s.sendControlPacketTo(probe, ack.FromAddr) + syn.FromAddr = getLocalAddr() + if syn.FromAddr != lAddr { + syn.TraceID = newTraceID() + lAddr = syn.FromAddr } + s.sendControlPacket(syn) + + pingTimer.Reset(pingInterval) + + if s.staged.Direct { + continue + } + + if !ack.FromAddr.IsValid() { + continue + } + + probe = probePacket{TraceID: newTraceID()} + probeAddr = ack.FromAddr + + s.sendControlPacketTo(probe, ack.FromAddr) + case <-timeoutTimer.C: - return s.clientInit + logf("Connection timeout") + return s.peerUpdate(s.peer) } } } diff --git a/node/relaymanager.go b/node/relaymanager.go new file mode 100644 index 0000000..5c44ea8 --- /dev/null +++ b/node/relaymanager.go @@ -0,0 +1,40 @@ +package node + +import ( + "log" + "math/rand" + "time" +) + +func relayManager() { + time.Sleep(2 * time.Second) + updateRelayRoute() + + for range time.Tick(8 * time.Second) { + relay := getRelayRoute() + if relay == nil || !relay.Up || !relay.Relay { + updateRelayRoute() + } + } +} + +func updateRelayRoute() { + possible := make([]*peerRoute, 0, 8) + for i := range routingTable { + route := routingTable[i].Load() + if !route.Up || !route.Relay { + continue + } + possible = append(possible, route) + } + + if len(possible) == 0 { + log.Printf("No relay available.") + relayIP.Store(nil) + return + } + + ip := possible[rand.Intn(len(possible))].IP + log.Printf("New relay IP: %d", ip) + relayIP.Store(&ip) +}