diff --git a/node/main.go b/node/main.go index 19252ff..9273823 100644 --- a/node/main.go +++ b/node/main.go @@ -108,11 +108,11 @@ func main(netName, listenIP string, port uint16) { peers := remotePeers{} for i := range peers { - peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter) + peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter, &peers) } go newHubPoller(netName, conf, peers).Run() - go readFromConn(conn, peers) + go readFromConn(conf.PeerIP, conn, peers) readFromIFace(iface, peers) } @@ -131,7 +131,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { // ---------------------------------------------------------------------------- -func readFromConn(conn *net.UDPConn, peers remotePeers) { +func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) { defer panicHandler() @@ -157,7 +157,12 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) { } h.Parse(data) - peers[h.SourceIP].HandlePacket(remoteAddr, h, data) + + if h.DestIP == localIP { + peers[h.SourceIP].HandlePacket(remoteAddr, h, data) + } else { + peers[h.DestIP].ForwardPacket(data) + } } } diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index e4dd881..cc615ab 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -1,6 +1,8 @@ package node import ( + "log" + "math/rand" "net/netip" "sync/atomic" "time" @@ -47,12 +49,15 @@ func (rp *peerSuper) Run() { func (rp *peerSuper) stateInit() stateFunc { //rp.logf("STATE: Init") + x := peerData{} rp.shared.Store(&x) + rp.peerData.relay = false rp.peerData.controlCipher = nil rp.peerData.dataCipher = nil rp.peerData.remoteAddr = zeroAddrPort + rp.peerData.relayIP = 0 if rp.peer == nil { return rp.stateDisconnected @@ -62,6 +67,8 @@ func (rp *peerSuper) stateInit() stateFunc { addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) if rp.remotePublic { rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) + } else { + rp.peerData.relay = false } rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) @@ -89,8 +96,7 @@ func (rp *peerSuper) stateSelectRole() stateFunc { rp.logf("STATE: SelectRole") if !rp.localPublic && !rp.remotePublic { - // TODO! - return rp.stateDisconnected + return rp.stateSelectMediator } if !rp.localPublic { @@ -99,12 +105,55 @@ func (rp *peerSuper) stateSelectRole() stateFunc { return rp.stateClient } - if rp.localIP < rp.peer.PeerIP { + if rp.localIP < rp.remoteIP { return rp.stateClient } return rp.stateServer } +// ---------------------------------------------------------------------------- + +func (rp *peerSuper) stateSelectMediator() stateFunc { + rp.logf("STATE: SelectMediator") + + for { + log.Printf("Selecting mediator...") + if ip := rp.selectMediator(); ip != 0 { + rp.logf("Got mediator: %d", ip) + rp.peerData.relayIP = ip + + if rp.localIP < rp.remoteIP { + return rp.stateClient + } + return rp.stateServer + } + + select { + case <-time.After(pingInterval): + continue + case rp.peer = <-rp.peerUpdates: + return rp.stateInit + } + } + +} + +func (rp *peerSuper) selectMediator() byte { + possible := make([]byte, 0, 8) + for _, peer := range rp.peers { + if peer.canRelay() { + rp.logf("relay: %v", peer.shared.Load()) + possible = append(possible, peer.remoteIP) + } + } + if len(possible) == 0 { + return 0 + } + return possible[rand.Intn(len(possible))] +} + +// ---------------------------------------------------------------------------- + // The remote is a server. func (rp *peerSuper) stateServer() stateFunc { rp.logf("STATE: Server") @@ -112,10 +161,12 @@ func (rp *peerSuper) stateServer() stateFunc { rp.updateShared() var ( - pingTimer = time.NewTimer(pingInterval) - ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} + pingTimer = time.NewTimer(pingInterval) + timeoutTimer = time.NewTimer(timeoutInterval) + ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} ) defer pingTimer.Stop() + defer timeoutTimer.Stop() ping.SentAt = time.Now().UnixMilli() rp.sendControlPacket(ping) @@ -127,8 +178,18 @@ func (rp *peerSuper) stateServer() stateFunc { rp.sendControlPacket(ping) pingTimer.Reset(pingInterval) - case <-rp.controlPackets: - // Ignore + case cPkt := <-rp.controlPackets: + if _, ok := cPkt.Payload.(pongPacket); ok { + timeoutTimer.Reset(timeoutInterval) + } + + case <-timeoutTimer.C: + if rp.peerData.relayIP != 0 { + rp.logf("Timeout (server, relay)") + return rp.stateSelectMediator + } else { + rp.logf("Timeout (server)") + } case rp.peer = <-rp.peerUpdates: return rp.stateInit @@ -143,8 +204,12 @@ func (rp *peerSuper) stateClient() stateFunc { rp.logf("STATE: Client") rp.updateShared() - // TODO: Could use timeout to set dataCipher to nil. - var currentKey = [32]byte{} + var ( + currentKey = [32]byte{} + timeoutTimer = time.NewTimer(timeoutInterval) + ) + + defer timeoutTimer.Stop() for { select { @@ -163,12 +228,22 @@ func (rp *peerSuper) stateClient() stateFunc { if ping.SharedKey != currentKey { rp.logf("Connected with new shared key") currentKey = ping.SharedKey + rp.peerData.up = true rp.peerData.dataCipher = newDataCipherFromKey(currentKey) rp.updateShared() } + timeoutTimer.Reset(timeoutInterval) rp.sendControlPacket(newPongPacket(ping.SentAt)) + case <-timeoutTimer.C: + if rp.peerData.relayIP != 0 { + rp.logf("Timeout (server, relay)") + return rp.stateSelectMediator + } else { + rp.logf("Timeout (server)") + } + case rp.peer = <-rp.peerUpdates: return rp.stateInit } @@ -193,5 +268,10 @@ func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) DestIP: rp.remoteIP, } buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) - rp.conn.WriteTo(buf, rp.peerData.remoteAddr) + if rp.peerData.relayIP == 0 { + rp.conn.WriteTo(buf, rp.peerData.remoteAddr) + return + } + + rp.peers[rp.peerData.relayIP].RelayControlData(buf) } diff --git a/node/peer.go b/node/peer.go index a344472..d999339 100644 --- a/node/peer.go +++ b/node/peer.go @@ -12,6 +12,8 @@ import ( type remotePeers [256]*remotePeer type peerData struct { + up bool + relay bool controlCipher *controlCipher dataCipher *dataCipher remoteAddr netip.AddrPort @@ -28,6 +30,7 @@ type remotePeer struct { conn *connWriter // Shared state. + peers *remotePeers shared *atomic.Pointer[peerData] // Only used in HandlePeerUpdate. @@ -48,7 +51,7 @@ type remotePeer struct { controlPackets chan controlPacket } -func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter) *remotePeer { +func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter, peers *remotePeers) *remotePeer { rp := &remotePeer{ localIP: conf.PeerIP, remoteIP: remoteIP, @@ -56,6 +59,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn localPublic: addrIsValid(conf.PublicIP), iface: iface, conn: conn, + peers: peers, shared: &atomic.Pointer[peerData]{}, dupCheck: newDupCheck(0), decryptBuf: make([]byte, bufferSize), @@ -97,10 +101,6 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) { case dataStreamID: rp.handleDataPacket(data) - case forwardStreamID: - fallthrough - // TODO - //rp.handleForwardPacket(h, data) default: rp.logf("Unknown stream ID: %d", h.StreamID) } @@ -115,6 +115,11 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] return } + if h.DestIP != rp.localIP { + rp.logf("Incorrect destination IP on control packet.") + return + } + out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) if !ok { rp.logf("Failed to decrypt control packet.") @@ -131,13 +136,6 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] return } - if h.DestIP != rp.localIP { - // TODO: Forward control packet. - // TODO: Probably this should be dropped. - // Control packets should be forwarded as data for efficiency. - return - } - pkt := controlPacket{ SrcIP: h.SourceIP, RemoteAddr: addr, @@ -167,6 +165,8 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] } } +// ---------------------------------------------------------------------------- + func (rp *remotePeer) handleDataPacket(data []byte) { shared := rp.shared.Load() if shared.dataCipher == nil { @@ -189,6 +189,29 @@ func (rp *remotePeer) handleDataPacket(data []byte) { // // This function is called by a single thread. func (rp *remotePeer) SendData(data []byte) { + rp.sendData(dataStreamID, data) +} + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) RelayControlData(data []byte) { + rp.sendData(forwardStreamID, data) +} + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) ForwardPacket(data []byte) { + shared := rp.shared.Load() + if shared.remoteAddr == zeroAddrPort { + rp.logf("Not connected (forward).") + return + } + rp.conn.WriteTo(data, shared.remoteAddr) +} + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) sendData(streamID byte, data []byte) { shared := rp.shared.Load() if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { rp.logf("Not connected (send).") @@ -196,7 +219,7 @@ func (rp *remotePeer) SendData(data []byte) { } h := header{ - StreamID: dataStreamID, + StreamID: streamID, Counter: atomic.AddUint64(&rp.counter, 1), SourceIP: rp.localIP, DestIP: rp.remoteIP, @@ -205,3 +228,10 @@ func (rp *remotePeer) SendData(data []byte) { enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) rp.conn.WriteTo(enc, shared.remoteAddr) } + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) canRelay() bool { + shared := rp.shared.Load() + return shared.relay && shared.up +}