From 2a1c809731efa47a5f56c866d185468ec45909f4 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 18 Dec 2024 13:29:06 +0100 Subject: [PATCH] FOrwarding is working. Not well at the moment. --- node/conn.go | 86 +++++++++++++++++++++++++++++++------------ node/header.go | 10 ++--- node/main.go | 16 ++++---- node/router.go | 29 +++++++++++++-- node/tmp_peerstate.go | 21 +++++++++-- 5 files changed, 119 insertions(+), 43 deletions(-) diff --git a/node/conn.go b/node/conn.go index 2e2bce8..0400526 100644 --- a/node/conn.go +++ b/node/conn.go @@ -14,6 +14,7 @@ type connWriter struct { lock sync.Mutex localIP byte buf []byte + buf2 []byte counters [256]uint64 routing *routingTable } @@ -23,6 +24,7 @@ func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *conn UDPConn: conn, localIP: localIP, buf: make([]byte, bufferSize), + buf2: make([]byte, bufferSize), routing: routing, } @@ -33,37 +35,76 @@ func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *conn return w } -func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) error { - // TODO: Handle mediator. - peer := w.routing.Get(remoteIP) - if peer == nil || peer.Addr == nil { +func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) { + dstPeer := w.routing.Get(remoteIP) + if dstPeer == nil { log.Printf("No peer: %d", remoteIP) - return nil + return } - if stream == streamData && !peer.Up { + if stream == streamData && !dstPeer.Up { log.Printf("Peer down: %d", remoteIP) + return } - return w.WriteToPeer(peer, stream, data) + + var viaPeer *peer + if dstPeer.Mediated { + viaPeer = w.routing.mediator.Load() + if viaPeer == nil || viaPeer.Addr == nil { + log.Printf("Mediator not connected") + return + } + } else if dstPeer.Addr == nil { + log.Printf("Peer doesn't have address: %d", remoteIP) + return + } + + w.WriteToPeer(dstPeer, viaPeer, stream, data) } -func (w *connWriter) WriteToPeer(peer *peer, stream byte, data []byte) error { +func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) { w.lock.Lock() - remoteIP := peer.IP + addr := dstPeer.Addr + h := header{ - Counter: atomic.AddUint64(&w.counters[remoteIP], 1), + Counter: atomic.AddUint64(&w.counters[dstPeer.IP], 1), SourceIP: w.localIP, - ViaIP: 0, - DestIP: remoteIP, + DestIP: dstPeer.IP, Stream: stream, } - buf := encryptPacket(&h, peer.SharedKey, data, w.buf) + buf := encryptPacket(&h, dstPeer.SharedKey, data, w.buf) - _, err := w.WriteToUDPAddrPort(buf, *peer.Addr) + if viaPeer != nil { + h := header{ + Counter: atomic.AddUint64(&w.counters[viaPeer.IP], 1), + SourceIP: w.localIP, + DestIP: dstPeer.IP, + Forward: 1, + Stream: stream, + } + + buf = encryptPacket(&h, viaPeer.SharedKey, buf, w.buf2) + addr = viaPeer.Addr + } + + if _, err := w.WriteToUDPAddrPort(buf, *addr); err != nil { + log.Fatalf("Failed to write to UDP port: %v", err) + } w.lock.Unlock() - return err +} + +func (w *connWriter) Forward(dstIP byte, packet []byte) { + dstPeer := w.routing.Get(dstIP) + if dstPeer == nil || dstPeer.Addr == nil { + log.Printf("No peer: %d", dstIP) + return + } + + if _, err := w.WriteToUDPAddrPort(packet, *dstPeer.Addr); err != nil { + log.Fatalf("Failed to write to UDP port: %v", err) + } } // ---------------------------------------------------------------------------- @@ -89,38 +130,37 @@ func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *conn return r } -func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte, err error) { - var n int +func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte) { + var ( + n int + err error + ) for { n, remoteAddr, err = r.ReadFromUDPAddrPort(buf[:bufferSize]) if err != nil { - return + log.Fatalf("Failed to read from UDP port: %v", err) } data = buf[:n] if n < headerSize { - log.Printf("Dropping short packet: %d", n) continue // Packet it soo short. } h.Parse(data) if len(data) != headerSize+int(h.DataSize) { - log.Printf("Malformed packet: %d != %d", len(data), headerSize+int(h.DataSize)) - continue + continue // Invalid header. } peer := r.routing.Get(h.SourceIP) if peer == nil { - log.Printf("No peer: %d...", h.SourceIP) continue } out, ok := decryptPacket(peer.SharedKey, data, r.buf) if !ok { - log.Printf("Decrypt failed...") continue } diff --git a/node/header.go b/node/header.go index 05ce29b..81d6ed0 100644 --- a/node/header.go +++ b/node/header.go @@ -11,8 +11,8 @@ const ( type header struct { Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. SourceIP byte - ViaIP byte DestIP byte + Forward byte Stream byte // See stream* constants. DataSize uint16 // Data size following associated data. } @@ -20,8 +20,8 @@ type header struct { func (hdr *header) Parse(nb []byte) { hdr.Counter = *(*uint64)(unsafe.Pointer(&nb[0])) hdr.SourceIP = nb[8] - hdr.ViaIP = nb[9] - hdr.DestIP = nb[10] + hdr.DestIP = nb[9] + hdr.Forward = nb[10] hdr.Stream = nb[11] hdr.DataSize = *(*uint16)(unsafe.Pointer(&nb[12])) } @@ -29,8 +29,8 @@ func (hdr *header) Parse(nb []byte) { func (hdr header) Marshal(buf []byte) { *(*uint64)(unsafe.Pointer(&buf[0])) = hdr.Counter buf[8] = hdr.SourceIP - buf[9] = hdr.ViaIP - buf[10] = hdr.DestIP + buf[9] = hdr.DestIP + buf[10] = hdr.Forward buf[11] = hdr.Stream *(*uint16)(unsafe.Pointer(&buf[12])) = hdr.DataSize } diff --git a/node/main.go b/node/main.go index 092751a..cac6df8 100644 --- a/node/main.go +++ b/node/main.go @@ -138,9 +138,11 @@ func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, rout ) for { - remoteAddr, h, data, err = r.Read(buf) - if err != nil { - log.Fatalf("Failed to read from UDP connection: %v", err) + remoteAddr, h, data = r.Read(buf) + + if h.Forward != 0 { + w.Forward(h.DestIP, data) + continue } switch h.Stream { @@ -178,13 +180,9 @@ func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { } if remoteIP == w.localIP { - //log.Printf("Incoming packet for self: %x", packet) - //iface.Write(packet) - continue + continue // Don't write to self. } - if err := w.WriteTo(remoteIP, streamData, packet); err != nil { - log.Fatalf("Failed to write to network: %v", err) - } + w.WriteTo(remoteIP, streamData, packet) } } diff --git a/node/router.go b/node/router.go index bd3e5a9..67c0756 100644 --- a/node/router.go +++ b/node/router.go @@ -14,6 +14,8 @@ import ( type peer struct { Up bool // No data will be sent to peers that are down. + Mediator bool + Mediated bool IP byte Addr *netip.AddrPort // If we have direct connection, otherwise use mediator. SharedKey []byte @@ -49,8 +51,8 @@ func (r *routingTable) Set(ip byte, p *peer) { // ---------------------------------------------------------------------------- type router struct { - netName string *routingTable + netName string peerSupers [256]*peerSupervisor } @@ -68,7 +70,7 @@ func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w * r.routingTable) } - // TODO: Handle Mediator + go r.selectMediator() go r.pollHub(conf) return r @@ -111,7 +113,6 @@ func (r *router) pollHub(conf m.PeerConfig) { } req.SetBasicAuth("", conf.APIKey) - // TODO: Before we start polling, load state from the file system. state, err := loadNetworkState(r.netName) if err != nil { log.Printf("Failed to load network state: %v", err) @@ -161,3 +162,25 @@ func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) { } } } + +// ---------------------------------------------------------------------------- + +func (r *router) selectMediator() { + for range time.Tick(8 * time.Second) { + current := r.mediator.Load() + if current != nil && current.Up { + continue + } + + for i := range r.table { + peer := r.table[i].Load() + if peer != nil && peer.Up && peer.Mediator { + log.Printf("Got mediator: %v", *peer) + r.mediator.Store(peer) + return + } + } + + r.mediator.Store(nil) + } +} diff --git a/node/tmp_peerstate.go b/node/tmp_peerstate.go index 5e365ed..14c9315 100644 --- a/node/tmp_peerstate.go +++ b/node/tmp_peerstate.go @@ -37,6 +37,7 @@ type peerSupervisor struct { version int64 // Ony accessed in HandlePeerUpdate. peer *m.Peer remoteAddrPort *netip.AddrPort + mediated bool sharedKey []byte // Used by our state functions. @@ -94,7 +95,6 @@ func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) { if p.Version == s.version { return } - s.logf("New peer version: %d", p.Version) s.version = p.Version } else { s.version = 0 @@ -154,6 +154,8 @@ func (s *peerSupervisor) stateSelectRole() stateFunc { s.updateRoutingTable(false) if s.remoteAddrPort != nil { + s.mediated = false + // If both remote and local are public, one side acts as client, and one // side as server. if s.localPublic && s.localIP < s.peer.PeerIP { @@ -164,6 +166,7 @@ func (s *peerSupervisor) stateSelectRole() stateFunc { // We're public, remote is not => can only wait for connection if s.localPublic { + s.mediated = false return s.stateAccept } @@ -277,8 +280,18 @@ func (s *peerSupervisor) stateConnected() stateFunc { func (s *peerSupervisor) stateMediated() stateFunc { s.logf("STATE: Mediated") - // TODO - select {} + s.mediated = true + s.updateRoutingTable(true) + + for { + select { + case <-s.packets: + // Drop. + case s.peer = <-s.peerUpdates: + s.logf("New peer: %v", s.peer) + return s.stateInit + } + } } // ---------------------------------------------------------------------------- @@ -290,6 +303,8 @@ func (s *peerSupervisor) clearRoutingTable() { func (s *peerSupervisor) updateRoutingTable(up bool) { s.table.Set(s.remoteIP, &peer{ Up: up, + Mediator: s.peer.Mediator, + Mediated: s.mediated, IP: s.remoteIP, Addr: s.remoteAddrPort, SharedKey: s.sharedKey,