diff --git a/node/peer-states.go b/node/peer-states.go index c3c0904..7a1de54 100644 --- a/node/peer-states.go +++ b/node/peer-states.go @@ -3,6 +3,7 @@ package node import ( "fmt" "log" + "math/rand" "net/netip" "sync/atomic" "time" @@ -22,7 +23,7 @@ type peerState interface { type stateBase struct { // The purpose of this state machine is to manage this published data. - published *atomic.Pointer[peerData] + published *atomic.Pointer[peerRoutingData] // The other remote peers. peers *remotePeers @@ -38,9 +39,9 @@ type stateBase struct { counter *uint64 // Mutable peer data. - peer *m.Peer - remotePub bool - data peerData // Local copy of shared data. See publish(). + peer *m.Peer + remotePub bool + routingData peerRoutingData // Local copy of shared data. See publish(). // Timers pingTimer *time.Timer @@ -63,16 +64,24 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { return nil } - s.peer = peer + return s.selectStateFromPeer(peer) +} - s.data = peerData{} - s.data.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) +func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { + s.peer = peer + s.routingData = peerRoutingData{} + + if peer == nil { + return newStateNoPeer(s) + } + + s.routingData.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) ip, isValid := netip.AddrFromSlice(peer.PublicIP) if isValid { s.remotePub = true - s.data.remoteAddr = netip.AddrPortFrom(ip, peer.Port) - s.data.relay = peer.Mediator + s.routingData.remoteAddr = netip.AddrPortFrom(ip, peer.Port) + s.routingData.relay = peer.Mediator if s.localPub && s.localIP < s.remoteIP { return newStateServer(s) @@ -84,9 +93,7 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { return newStateServer(s) } - // TODO: return newStateMediated(a/b) - - return nil + return newStateSelectRelay(s) } func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil } @@ -106,10 +113,24 @@ func (s *stateBase) logf(msg string, args ...any) { } func (s *stateBase) publish() { - data := s.data + data := s.routingData s.published.Store(&data) } +func (s *stateBase) selectRelay() byte { + possible := make([]byte, 0, 8) + for i, peer := range s.peers { + if peer.CanRelay() { + possible = append(possible, byte(i)) + } + } + + if len(possible) == 0 { + return 0 + } + return possible[rand.Intn(len(possible))] +} + func (s *stateBase) sendPing(sharedKey [32]byte) { s.sendControlPacket(newPingPacket(sharedKey)) } @@ -127,13 +148,22 @@ func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { DestIP: s.remoteIP, } - buf = s.data.controlCipher.Encrypt(h, buf, s.encBuf) - if s.data.relayIP == 0 { - s.conn.WriteTo(buf, s.data.remoteAddr) - return + buf = s.routingData.controlCipher.Encrypt(h, buf, s.encBuf) + if s.routingData.relayIP != 0 { + s.peers[s.routingData.relayIP].RelayFor(s.remoteIP, buf) + } else { + s.conn.WriteTo(buf, s.routingData.remoteAddr) } +} - // TODO: Relay! +// ---------------------------------------------------------------------------- + +type stateNoPeer struct{ *stateBase } + +func newStateNoPeer(b *stateBase) *stateNoPeer { + s := &stateNoPeer{b} + s.publish() + return s } // ---------------------------------------------------------------------------- @@ -147,8 +177,8 @@ func newStateClient(b *stateBase) peerState { s := &stateClient{stateBase: b} s.publish() - s.data.dataCipher = newDataCipher() - s.sharedKey = s.data.dataCipher.Key() + s.routingData.dataCipher = newDataCipher() + s.sharedKey = s.routingData.dataCipher.Key() s.sendPing(s.sharedKey) s.resetPingTimer() @@ -159,8 +189,8 @@ func newStateClient(b *stateBase) peerState { func (s *stateClient) Name() string { return "client" } func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState { - if !s.data.up { - s.data.up = true + if !s.routingData.up { + s.routingData.up = true s.publish() } s.resetTimeoutTimer() @@ -174,7 +204,7 @@ func (s *stateClient) OnPingTimer() peerState { } func (s *stateClient) OnTimeoutTimer() peerState { - s.data.up = false + s.routingData.up = false s.publish() return nil } @@ -196,19 +226,134 @@ func newStateServer(b *stateBase) peerState { func (s *stateServer) Name() string { return "server" } func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState { - if addr != s.data.remoteAddr { + if addr != s.routingData.remoteAddr { s.logf("Got new peer address: %v", addr) - s.data.remoteAddr = addr - s.data.up = true + s.routingData.remoteAddr = addr + s.routingData.up = true s.publish() } - if s.data.dataCipher == nil || p.SharedKey != s.data.dataCipher.Key() { + if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() { s.logf("Got new shared key.") - s.data.dataCipher = newDataCipherFromKey(p.SharedKey) + s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey) + s.routingData.up = true s.publish() } s.sendPong(p) return nil } + +// ---------------------------------------------------------------------------- + +type stateSelectRelay struct { + *stateBase +} + +func newStateSelectRelay(b *stateBase) peerState { + s := &stateSelectRelay{stateBase: b} + s.routingData.dataCipher = nil + s.routingData.up = false + s.publish() + + if relay := s.selectRelay(); relay != 0 { + s.routingData.up = false + s.routingData.relayIP = relay + return s.selectRole() + } + + s.resetPingTimer() + s.stopTimeoutTimer() + return s +} + +func (s *stateSelectRelay) selectRole() peerState { + if s.localIP < s.remoteIP { + return newStateServerRelayed(s.stateBase) + } + return newStateClientRelayed(s.stateBase) +} + +func (s *stateSelectRelay) Name() string { return "select-relay" } + +func (s *stateSelectRelay) OnPingTimer() peerState { + if relay := s.selectRelay(); relay != 0 { + s.routingData.relayIP = relay + return s.selectRole() + } + s.resetPingTimer() + return nil +} + +// ---------------------------------------------------------------------------- + +type stateClientRelayed struct { + sharedKey [32]byte + *stateBase +} + +func newStateClientRelayed(b *stateBase) peerState { + s := &stateClientRelayed{stateBase: b} + + s.routingData.dataCipher = newDataCipher() + s.sharedKey = s.routingData.dataCipher.Key() + s.publish() + + s.sendPing(s.sharedKey) + s.resetPingTimer() + s.resetTimeoutTimer() + return s +} + +func (s *stateClientRelayed) Name() string { return "client-relayed" } + +func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState { + if !s.routingData.up { + s.routingData.up = true + s.publish() + } + s.resetTimeoutTimer() + return nil +} + +func (s *stateClientRelayed) OnPingTimer() peerState { + s.sendPing(s.sharedKey) + s.resetPingTimer() + return nil +} + +func (s *stateClientRelayed) OnTimeoutTimer() peerState { + return newStateSelectRelay(s.stateBase) +} + +// ---------------------------------------------------------------------------- + +type stateServerRelayed struct { + *stateBase +} + +func newStateServerRelayed(b *stateBase) peerState { + s := &stateServerRelayed{b} + s.stopPingTimer() + s.resetTimeoutTimer() + return s +} + +func (s *stateServerRelayed) Name() string { return "server-relayed" } + +func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState { + if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() { + s.logf("Got new shared key.") + s.routingData.up = true + s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey) + s.publish() + } + + s.sendPong(p) + s.resetTimeoutTimer() + return nil +} + +func (s *stateServerRelayed) OnTimeoutTimer() peerState { + return newStateSelectRelay(s.stateBase) +} diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index 2c46ad2..ac2508e 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -11,12 +11,7 @@ const ( timeoutInterval = 20 * time.Second ) -func (rp *remotePeer) supervise( - conf m.PeerConfig, - remoteIP byte, - conn *connWriter, - peers *remotePeers, -) { +func (rp *remotePeer) supervise(conf m.PeerConfig) { defer panicHandler() base := &stateBase{ diff --git a/node/peer.go b/node/peer.go index 3cc308e..bae2c9c 100644 --- a/node/peer.go +++ b/node/peer.go @@ -11,7 +11,7 @@ import ( type remotePeers [256]*remotePeer -type peerData struct { +type peerRoutingData struct { up bool relay bool controlCipher *controlCipher @@ -29,7 +29,7 @@ type remotePeer struct { // Shared state. peers *remotePeers - published *atomic.Pointer[peerData] + published *atomic.Pointer[peerRoutingData] // Only used in HandlePacket / Not synchronized. dupCheck *dupCheck @@ -53,7 +53,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn iface: iface, conn: conn, peers: peers, - published: &atomic.Pointer[peerData]{}, + published: &atomic.Pointer[peerRoutingData]{}, dupCheck: newDupCheck(0), decryptBuf: make([]byte, bufferSize), encryptBuf: make([]byte, bufferSize), @@ -62,11 +62,11 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn controlPackets: make(chan controlPacket, 512), } - pd := peerData{} + pd := peerRoutingData{} rp.published.Store(&pd) //go newPeerSuper(rp).Run() - go rp.supervise(conf, remoteIP, conn, peers) + go rp.supervise(conf) return rp } @@ -102,9 +102,8 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) { // ---------------------------------------------------------------------------- func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) { - shared := rp.published.Load() - if shared.controlCipher == nil { - log.Printf("Shared: %+v", *shared) + routingData := rp.published.Load() + if routingData.controlCipher == nil { rp.logf("Not connected (control).") return } @@ -114,7 +113,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] return } - out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) + out, ok := routingData.controlCipher.Decrypt(data, rp.decryptBuf) if !ok { rp.logf("Failed to decrypt control packet.") return @@ -150,13 +149,13 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] // ---------------------------------------------------------------------------- func (rp *remotePeer) handleDataPacket(data []byte) { - shared := rp.published.Load() - if shared.dataCipher == nil { + routingData := rp.published.Load() + if routingData.dataCipher == nil { rp.logf("Not connected (recv).") return } - dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf) + dec, ok := routingData.dataCipher.Decrypt(data, rp.decryptBuf) if !ok { rp.logf("Failed to decrypt data packet.") return @@ -168,19 +167,19 @@ func (rp *remotePeer) handleDataPacket(data []byte) { // ---------------------------------------------------------------------------- func (rp *remotePeer) handleRelayPacket(h header, data []byte) { - shared := rp.published.Load() - if shared.dataCipher == nil { + routingData := rp.published.Load() + if routingData.dataCipher == nil { rp.logf("Not connected (recv).") return } - dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf) + dec, ok := routingData.dataCipher.Decrypt(data, rp.decryptBuf) if !ok { rp.logf("Failed to decrypt data packet.") return } - rp.peers[h.DestIP].sendDirect(dec) + rp.peers[h.DestIP].SendAsIs(dec) } // ---------------------------------------------------------------------------- @@ -189,13 +188,13 @@ func (rp *remotePeer) handleRelayPacket(h header, data []byte) { // // This function is called by a single thread. func (rp *remotePeer) SendData(data []byte) { - rp.sendData(dataStreamID, rp.remoteIP, data) + rp.encryptAndSend(dataStreamID, rp.remoteIP, data) } func (rp *remotePeer) HandleInterfacePacket(data []byte) { - shared := rp.published.Load() + routingData := rp.published.Load() - if shared.dataCipher == nil { + if routingData.dataCipher == nil { rp.logf("Not connected (handle interface).") return } @@ -207,10 +206,10 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) { DestIP: rp.remoteIP, } - enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) + enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf) - if shared.relayIP != 0 { - rp.peers[shared.relayIP].RelayData(shared.relayIP, enc) + if routingData.relayIP != 0 { + rp.peers[routingData.relayIP].RelayFor(rp.remoteIP, enc) } else { rp.SendData(data) } @@ -218,16 +217,23 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) { // ---------------------------------------------------------------------------- -func (rp *remotePeer) RelayData(destIP byte, data []byte) { - rp.sendData(relayStreamID, destIP, data) +func (rp *remotePeer) CanRelay() bool { + data := rp.published.Load() + return data.relay && data.up } // ---------------------------------------------------------------------------- -func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) { - shared := rp.published.Load() - if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { - rp.logf("Not connected (send).") +func (rp *remotePeer) RelayFor(destIP byte, data []byte) { + rp.encryptAndSend(relayStreamID, destIP, data) +} + +// ---------------------------------------------------------------------------- + +func (rp *remotePeer) encryptAndSend(streamID byte, destIP byte, data []byte) { + routingData := rp.published.Load() + if routingData.dataCipher == nil || routingData.remoteAddr == zeroAddrPort { + rp.logf("Not connected (encrypt and send).") return } @@ -238,15 +244,19 @@ func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) { DestIP: destIP, } - enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) - rp.conn.WriteTo(enc, shared.remoteAddr) + enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf) + rp.conn.WriteTo(enc, routingData.remoteAddr) } -func (rp *remotePeer) sendDirect(data []byte) { - shared := rp.published.Load() - if shared.remoteAddr == zeroAddrPort { - rp.logf("Not connected (send).") +// ---------------------------------------------------------------------------- + +// SendAsIs is used when forwarding already-encrypted data from one peer to +// another. +func (rp *remotePeer) SendAsIs(data []byte) { + routingData := rp.published.Load() + if routingData.remoteAddr == zeroAddrPort { + rp.logf("Not connected (send direct).") return } - rp.conn.WriteTo(data, shared.remoteAddr) + rp.conn.WriteTo(data, routingData.remoteAddr) }