From 6b3216f2d2c74d15ed9c4e086387343fa881b714 Mon Sep 17 00:00:00 2001 From: jdl Date: Mon, 10 Feb 2025 19:11:30 +0100 Subject: [PATCH] WIP --- peer/connreader2.go | 132 ++++++++++ peer/connreader_test.go | 20 +- peer/connwriter.go | 4 +- peer/connwriter2.go | 109 ++++++++ peer/connwriter2_test.go | 145 +++++++++++ peer/controlmessage.go | 18 +- peer/crypto.go | 10 +- peer/crypto_test.go | 8 +- peer/data-flow.dot | 14 ++ peer/files.go | 90 +++++++ peer/files_test.go | 57 +++++ peer/globals.go | 12 +- peer/hubpoller.go | 100 ++++++++ peer/ifreader2.go | 78 ++++++ peer/ifreader2_test.go | 83 +++++++ peer/ifreader_test.go | 4 +- peer/ifwriter.go | 5 - peer/interfaces.go | 24 +- peer/mcreader.go | 2 +- peer/mcreader_test.go | 138 +++++++++++ peer/mock-iface_test.go | 31 +++ peer/mock-network_test.go | 80 ++++++ peer/packets-util.go | 4 +- peer/packets-util_test.go | 24 +- peer/packets.go | 79 +++--- peer/packets_test.go | 65 +++++ peer/peer_test.go | 125 ++++++++++ peer/peerstates.go | 355 ++++++++++++++++++++++++++ peer/peerstates_test.go | 509 ++++++++++++++++++++++++++++++++++++++ peer/pubaddrs.go | 75 ++++++ peer/pubaddrs_test.go | 29 +++ peer/routingtable.go | 137 ++++++++++ peer/routingtable_test.go | 169 +++++++++++++ peer/state.go | 29 --- peer/supervisor.go | 103 ++++++++ peer/util_test.go | 26 ++ 36 files changed, 2763 insertions(+), 130 deletions(-) create mode 100644 peer/connreader2.go create mode 100644 peer/connwriter2.go create mode 100644 peer/connwriter2_test.go create mode 100644 peer/data-flow.dot create mode 100644 peer/files.go create mode 100644 peer/files_test.go create mode 100644 peer/hubpoller.go create mode 100644 peer/ifreader2.go create mode 100644 peer/ifreader2_test.go delete mode 100644 peer/ifwriter.go create mode 100644 peer/mcreader_test.go create mode 100644 peer/mock-iface_test.go create mode 100644 peer/mock-network_test.go create mode 100644 peer/peer_test.go create mode 100644 peer/peerstates.go create mode 100644 peer/peerstates_test.go create mode 100644 peer/pubaddrs.go create mode 100644 peer/pubaddrs_test.go create mode 100644 peer/routingtable.go create mode 100644 peer/routingtable_test.go delete mode 100644 peer/state.go create mode 100644 peer/supervisor.go create mode 100644 peer/util_test.go diff --git a/peer/connreader2.go b/peer/connreader2.go new file mode 100644 index 0000000..d9feab8 --- /dev/null +++ b/peer/connreader2.go @@ -0,0 +1,132 @@ +package peer + +import ( + "io" + "log" + "net/netip" + "sync/atomic" +) + +type ConnReader struct { + // Input + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) + + // Output + iface io.Writer + forwardData func(ip byte, pkt []byte) + handleControlMsg func(pkt any) + + localIP byte + rt *atomic.Pointer[RoutingTable] + + buf []byte + decBuf []byte +} + +func NewConnReader( + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), + iface io.Writer, + forwardData func(ip byte, pkt []byte), + handleControlMsg func(pkt any), + rt *atomic.Pointer[RoutingTable], +) *ConnReader { + return &ConnReader{ + readFromUDPAddrPort: readFromUDPAddrPort, + iface: iface, + forwardData: forwardData, + handleControlMsg: handleControlMsg, + localIP: rt.Load().LocalIP, + rt: rt, + buf: newBuf(), + decBuf: newBuf(), + } +} + +func (r *ConnReader) Run() { + for { + r.handleNextPacket() + } +} + +func (r *ConnReader) handleNextPacket() { + buf := r.buf[:bufferSize] + n, remoteAddr, err := r.readFromUDPAddrPort(buf) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) + } + + if n < headerSize { + return + } + + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) + + buf = buf[:n] + h := parseHeader(buf) + + peer := r.rt.Load().Peers[h.SourceIP] + //peer := rt.Peers[h.SourceIP] + + switch h.StreamID { + case controlStreamID: + r.handleControlPacket(remoteAddr, peer, h, buf) + case dataStreamID: + r.handleDataPacket(peer, h, buf) + default: + r.logf("Unknown stream ID: %d", h.StreamID) + } +} + +func (r *ConnReader) handleControlPacket( + remoteAddr netip.AddrPort, + peer RemotePeer, + h header, + enc []byte, +) { + if peer.ControlCipher == nil { + return + } + + if h.DestIP != r.localIP { + r.logf("Incorrect destination IP on control packet: %d", h.DestIP) + return + } + + msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt control packet: %v", err) + return + } + + r.handleControlMsg(msg) +} + +func (r *ConnReader) handleDataPacket( + peer RemotePeer, + h header, + enc []byte, +) { + if !peer.Up { + r.logf("Not connected (recv).") + return + } + + data, err := peer.DecryptDataPacket(h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt data packet: %v", err) + return + } + + if h.DestIP == r.localIP { + if _, err := r.iface.Write(data); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } + return + } + + r.forwardData(h.DestIP, data) +} + +func (r *ConnReader) logf(format string, args ...any) { + log.Printf("[ConnReader] "+format, args...) +} diff --git a/peer/connreader_test.go b/peer/connreader_test.go index 39da83c..714f6f3 100644 --- a/peer/connreader_test.go +++ b/peer/connreader_test.go @@ -109,7 +109,7 @@ func newConnReadeTestHarness() (h connReaderTestHarness) { func TestConnReader_handleControlPacket(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} h.WRemote.SendControlPacket(pkt, h.Remote) @@ -119,7 +119,7 @@ func TestConnReader_handleControlPacket(t *testing.T) { t.Fatal(h.Super.Messages) } - msg := h.Super.Messages[0].(controlMsg[synPacket]) + msg := h.Super.Messages[0].(controlMsg[PacketSyn]) if !reflect.DeepEqual(pkt, msg.Packet) { t.Fatal(msg.Packet) } @@ -141,7 +141,7 @@ func TestConnReader_handleNextPacket_short(t *testing.T) { func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) var header header @@ -160,7 +160,7 @@ func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { func TestConnReader_handleControlPacket_noCipher(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} //encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) @@ -180,7 +180,7 @@ func TestConnReader_handleControlPacket_noCipher(t *testing.T) { func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) var header header @@ -199,7 +199,7 @@ func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { func TestConnReader_handleControlPacket_modified(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) encrypted[len(encrypted)-1]++ @@ -237,10 +237,10 @@ func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { func TestConnReader_handleControlPacket_duplicate(t *testing.T) { h := newConnReadeTestHarness() - pkt := ackPacket{TraceID: 1234} + pkt := PacketAck{TraceID: 1234} h.WRemote.SendControlPacket(pkt, h.Remote) - *h.Remote.Counter = *h.Remote.Counter - 1 + *h.Remote.counter = *h.Remote.counter - 1 h.WRemote.SendControlPacket(pkt, h.Remote) h.R.handleNextPacket() @@ -250,7 +250,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { t.Fatal(h.Super.Messages) } - msg := h.Super.Messages[0].(controlMsg[ackPacket]) + msg := h.Super.Messages[0].(controlMsg[PacketAck]) if !reflect.DeepEqual(pkt, msg.Packet) { t.Fatal(msg.Packet) } @@ -301,7 +301,7 @@ func TestConnReader_handleDataPacket_duplicate(t *testing.T) { pkt := make([]byte, 123) h.WRemote.SendDataPacket(pkt, h.Remote) - *h.Remote.Counter = *h.Remote.Counter - 1 + *h.Remote.counter = *h.Remote.counter - 1 h.WRemote.SendDataPacket(pkt, h.Remote) h.R.handleNextPacket() diff --git a/peer/connwriter.go b/peer/connwriter.go index 7daa567..8a09e35 100644 --- a/peer/connwriter.go +++ b/peer/connwriter.go @@ -37,13 +37,13 @@ func newConnWriter(conn udpWriter, localIP byte) *connWriter { } // Not safe for concurrent use. Should only be called by supervisor. -func (w *connWriter) SendControlPacket(pkt marshaller, peer *RemotePeer) { +func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) w.writeTo(enc, peer.DirectAddr) } // Relay control packet. Peer must not be nil. -func (w *connWriter) RelayControlPacket(pkt marshaller, peer, relay *RemotePeer) { +func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) w.writeTo(enc, relay.DirectAddr) diff --git a/peer/connwriter2.go b/peer/connwriter2.go new file mode 100644 index 0000000..e58250d --- /dev/null +++ b/peer/connwriter2.go @@ -0,0 +1,109 @@ +package peer + +import ( + "log" + "net/netip" + "sync" + "sync/atomic" +) + +type ConnWriter struct { + wLock sync.Mutex // Lock around for sending on UDP Conn. + + // Output. + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + + // Shared state. + rt *atomic.Pointer[RoutingTable] + + // For sending control packets. + cBuf1 []byte + cBuf2 []byte + + // For sending data packets. + dBuf1 []byte + dBuf2 []byte +} + +func NewConnWriter( + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[RoutingTable], +) *ConnWriter { + return &ConnWriter{ + writeToUDPAddrPort: writeToUDPAddrPort, + rt: rt, + cBuf1: newBuf(), + cBuf2: newBuf(), + dBuf1: newBuf(), + dBuf2: newBuf(), + } +} + +// Called by ConnReader to forward already encrypted bytes to another peer. +func (w *ConnWriter) Forward(ip byte, pkt []byte) { + peer := w.rt.Load().Peers[ip] + if !(peer.Up && peer.Direct) { + w.logf("Failed to forward to %d.", ip) + return + } + w.writeTo(pkt, peer.DirectAddr) +} + +// Called by IFReader to send data. Encryption will be applied, and packet will +// be relayed if appropriate. +func (w *ConnWriter) WriteData(ip byte, pkt []byte) { + rt := w.rt.Load() + peer := rt.Peers[ip] + if !peer.Up { + w.logf("Failed to send data to %d.", ip) + return + } + + enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) + + if peer.Direct { + w.writeTo(enc, peer.DirectAddr) + return + } + + relay, ok := rt.GetRelay() + if !ok { + w.logf("Failed to send data to %d. No relay.", ip) + return + } + + enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) + w.writeTo(enc, relay.DirectAddr) +} + +// Called by Supervisor to send control packets. +func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { + enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) + + if peer.Direct { + w.writeTo(enc, peer.DirectAddr) + return + } + + rt := w.rt.Load() + relay, ok := rt.GetRelay() + if !ok { + w.logf("Failed to send control to %d. No relay.", peer.IP) + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) + w.writeTo(enc, relay.DirectAddr) +} + +func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { + w.wLock.Lock() + if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { + w.logf("Failed to write to UDP port: %v", err) + } + w.wLock.Unlock() +} + +func (w *ConnWriter) logf(s string, args ...any) { + log.Printf("[ConnWriter] "+s, args...) +} diff --git a/peer/connwriter2_test.go b/peer/connwriter2_test.go new file mode 100644 index 0000000..f0bb00f --- /dev/null +++ b/peer/connwriter2_test.go @@ -0,0 +1,145 @@ +package peer + +import ( + "testing" +) + +func TestConnWriter_WriteData_direct(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteData_peerNotUp(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + p1.RT.Load().Peers[2].Up = false + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteData_relay(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().RelayIP = 3 + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p3.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().Peers[3].Up = false + p1.RT.Load().RelayIP = 3 + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p3.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteControl_direct(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + orig := PacketProbe{TraceID: newTraceID()} + + p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteControl_relay(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().RelayIP = 3 + + orig := PacketProbe{TraceID: newTraceID()} + + p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) + + packets := p3.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().Peers[3].Up = false + p1.RT.Load().RelayIP = 3 + + orig := PacketProbe{TraceID: newTraceID()} + + p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) + + packets := p3.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter__Forward(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + in := RandPacket() + p1.ConnWriter.Forward(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter__Forward_notUp(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + p1.RT.Load().Peers[2].Up = false + + in := RandPacket() + p1.ConnWriter.Forward(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter__Forward_notDirect(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + p1.RT.Load().Peers[2].Direct = false + + in := RandPacket() + p1.ConnWriter.Forward(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} diff --git a/peer/controlmessage.go b/peer/controlmessage.go index d8e9a17..7180dd0 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -17,25 +17,25 @@ type controlMsg[T any] struct { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { switch buf[0] { - case packetTypeSyn: - packet, err := parseSynPacket(buf) - return controlMsg[synPacket]{ + case PacketTypeSyn: + packet, err := ParsePacketSyn(buf) + return controlMsg[PacketSyn]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case packetTypeAck: - packet, err := parseAckPacket(buf) - return controlMsg[ackPacket]{ + case PacketTypeAck: + packet, err := ParsePacketAck(buf) + return controlMsg[PacketAck]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case packetTypeProbe: - packet, err := parseProbePacket(buf) - return controlMsg[probePacket]{ + case PacketTypeProbe: + packet, err := ParsePacketProbe(buf) + return controlMsg[PacketProbe]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, diff --git a/peer/crypto.go b/peer/crypto.go index f9c61db..dcc042b 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -37,13 +37,13 @@ func generateKeys() cryptoKeys { func encryptControlPacket( localIP byte, peer *RemotePeer, - pkt marshaller, + pkt Marshaller, tmp []byte, out []byte, ) []byte { h := header{ StreamID: controlStreamID, - Counter: atomic.AddUint64(peer.Counter, 1), + Counter: atomic.AddUint64(peer.counter, 1), SourceIP: localIP, DestIP: peer.IP, } @@ -66,7 +66,7 @@ func decryptControlPacket( return nil, errDecryptionFailed } - if peer.DupCheck.IsDup(h.Counter) { + if peer.dupCheck.IsDup(h.Counter) { return nil, errDuplicateSeqNum } @@ -89,7 +89,7 @@ func encryptDataPacket( ) []byte { h := header{ StreamID: dataStreamID, - Counter: atomic.AddUint64(peer.Counter, 1), + Counter: atomic.AddUint64(peer.counter, 1), SourceIP: localIP, DestIP: destIP, } @@ -108,7 +108,7 @@ func decryptDataPacket( return nil, errDecryptionFailed } - if peer.DupCheck.IsDup(h.Counter) { + if peer.dupCheck.IsDup(h.Counter) { return nil, errDuplicateSeqNum } diff --git a/peer/crypto_test.go b/peer/crypto_test.go index c93b87f..824a43a 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { out = make([]byte, bufferSize) ) - in := synPacket{ + in := PacketSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { t.Fatal(err) } - msg, ok := iMsg.(controlMsg[synPacket]) + msg, ok := iMsg.(controlMsg[PacketSyn]) if !ok { t.Fatal(ok) } @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { out = make([]byte, bufferSize) ) - in := synPacket{ + in := PacketSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { out = make([]byte, bufferSize) ) - in := synPacket{ + in := PacketSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, diff --git a/peer/data-flow.dot b/peer/data-flow.dot new file mode 100644 index 0000000..45b6f05 --- /dev/null +++ b/peer/data-flow.dot @@ -0,0 +1,14 @@ +digraph d { + ifReader -> connWriter; + connReader -> ifWriter; + connReader -> connWriter; + connReader -> supervisor; + mcReader -> supervisor; + supervisor -> connWriter; + supervisor -> mcWriter; + hubPoller -> supervisor; + + connWriter [shape="box"]; + mcWriter [shape="box"]; + ifWriter [shape="box"]; +} \ No newline at end of file diff --git a/peer/files.go b/peer/files.go new file mode 100644 index 0000000..b0eade5 --- /dev/null +++ b/peer/files.go @@ -0,0 +1,90 @@ +package peer + +import ( + "encoding/json" + "log" + "os" + "path/filepath" + "vppn/m" +) + +type localConfig struct { + m.PeerConfig + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func configDir(netName string) string { + d, err := os.UserHomeDir() + if err != nil { + log.Fatalf("Failed to get user home directory: %v", err) + } + return filepath.Join(d, ".vppn", netName) +} + +func peerConfigPath(netName string) string { + return filepath.Join(configDir(netName), "peer-config.json") +} + +func peerStatePath(netName string) string { + return filepath.Join(configDir(netName), "peer-state.json") +} + +func storeJson(x any, outPath string) error { + outDir := filepath.Dir(outPath) + _ = os.MkdirAll(outDir, 0700) + + tmpPath := outPath + ".tmp" + buf, err := json.Marshal(x) + if err != nil { + return err + } + + f, err := os.Create(tmpPath) + if err != nil { + return err + } + + if _, err := f.Write(buf); err != nil { + f.Close() + return err + } + + if err := f.Sync(); err != nil { + f.Close() + return err + } + + if err := f.Close(); err != nil { + return err + } + + return os.Rename(tmpPath, outPath) +} + +func storePeerConfig(netName string, pc localConfig) error { + return storeJson(pc, peerConfigPath(netName)) +} + +func storeNetworkState(netName string, ps m.NetworkState) error { + return storeJson(ps, peerStatePath(netName)) +} + +func loadJson(dataPath string, ptr any) error { + data, err := os.ReadFile(dataPath) + if err != nil { + return err + } + + return json.Unmarshal(data, ptr) +} + +func loadPeerConfig(netName string) (pc localConfig, err error) { + return pc, loadJson(peerConfigPath(netName), &pc) +} + +func loadNetworkState(netName string) (ps m.NetworkState, err error) { + return ps, loadJson(peerStatePath(netName), &ps) +} diff --git a/peer/files_test.go b/peer/files_test.go new file mode 100644 index 0000000..5e32ced --- /dev/null +++ b/peer/files_test.go @@ -0,0 +1,57 @@ +package peer + +import ( + "path/filepath" + "reflect" + "testing" +) + +func TestFilePaths(t *testing.T) { + confDir := configDir("netName") + if filepath.Base(confDir) != "netName" { + t.Fatal(confDir) + } + if filepath.Base(filepath.Dir(confDir)) != ".vppn" { + t.Fatal(confDir) + } + + path := peerConfigPath("netName") + if path != filepath.Join(confDir, "peer-config.json") { + t.Fatal(path) + } + + path = peerStatePath("netName") + if path != filepath.Join(confDir, "peer-state.json") { + t.Fatal(path) + } +} + +func TestStoreLoadJson(t *testing.T) { + type Object struct { + Name string + Age int + Price float64 + } + + tmpDir := t.TempDir() + outPath := filepath.Join(tmpDir, "object.json") + + obj := Object{ + Name: "Jason", + Age: 22, + Price: 123.534, + } + + if err := storeJson(obj, outPath); err != nil { + t.Fatal(err) + } + + obj2 := Object{} + if err := loadJson(outPath, &obj2); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(obj, obj2) { + t.Fatal(obj, obj2) + } +} diff --git a/peer/globals.go b/peer/globals.go index 4733ac8..0d7ada3 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -3,15 +3,21 @@ package peer import ( "net" "net/netip" + "time" ) const ( - bufferSize = 1536 - if_mtu = 1200 - if_queue_len = 2048 + bufferSize = 1536 + + if_mtu = 1200 + if_queue_len = 2048 + controlCipherOverhead = 16 dataCipherOverhead = 16 signOverhead = 64 + + pingInterval = 8 * time.Second + timeoutInterval = 30 * time.Second ) var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( diff --git a/peer/hubpoller.go b/peer/hubpoller.go new file mode 100644 index 0000000..f608bd5 --- /dev/null +++ b/peer/hubpoller.go @@ -0,0 +1,100 @@ +package peer + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/url" + "time" + "vppn/m" +) + +type hubPoller struct { + client *http.Client + req *http.Request + versions [256]int64 + localIP byte + netName string + super controlMsgHandler +} + +func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { + u, err := url.Parse(hubURL) + if err != nil { + return nil, err + } + u.Path = "/peer/fetch-state/" + + client := &http.Client{Timeout: 8 * time.Second} + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + req.SetBasicAuth("", apiKey) + + return &hubPoller{ + client: client, + req: req, + localIP: localIP, + netName: netName, + super: super, + }, nil +} + +func (hp *hubPoller) Run() { + state, err := loadNetworkState(hp.netName) + if err != nil { + log.Printf("Failed to load network state: %v", err) + log.Printf("Polling hub...") + hp.pollHub() + } else { + hp.applyNetworkState(state) + } + + for range time.Tick(64 * time.Second) { + hp.pollHub() + } +} + +func (hp *hubPoller) pollHub() { + var state m.NetworkState + + resp, err := hp.client.Do(hp.req) + if err != nil { + log.Printf("Failed to fetch peer state: %v", err) + return + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + log.Printf("Failed to read body from hub: %v", err) + return + } + + if err := json.Unmarshal(body, &state); err != nil { + log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) + return + } + + hp.applyNetworkState(state) + + if err := storeNetworkState(hp.netName, state); err != nil { + log.Printf("Failed to store network state: %v", err) + } +} + +func (hp *hubPoller) applyNetworkState(state m.NetworkState) { + for i, peer := range state.Peers { + if i != int(hp.localIP) { + if peer == nil || peer.Version != hp.versions[i] { + hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) + if peer != nil { + hp.versions[i] = peer.Version + } + } + } + } +} diff --git a/peer/ifreader2.go b/peer/ifreader2.go new file mode 100644 index 0000000..c390e8f --- /dev/null +++ b/peer/ifreader2.go @@ -0,0 +1,78 @@ +package peer + +import ( + "io" + "log" +) + +type IFReader struct { + iface io.Reader + connWriter interface { + WriteData(ip byte, pkt []byte) + } +} + +func NewIFReader( + iface io.Reader, + connWriter interface { + WriteData(ip byte, pkt []byte) + }, +) *IFReader { + return &IFReader{iface, connWriter} +} + +func (r *IFReader) Run() { + packet := newBuf() + for { + r.handleNextPacket(packet) + } +} + +func (r *IFReader) handleNextPacket(packet []byte) { + packet = r.readNextPacket(packet) + if remoteIP, ok := r.parsePacket(packet); ok { + r.connWriter.WriteData(remoteIP, packet) + } +} + +func (r *IFReader) readNextPacket(buf []byte) []byte { + n, err := r.iface.Read(buf[:cap(buf)]) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + return buf[:n] +} + +func (r *IFReader) parsePacket(buf []byte) (byte, bool) { + n := len(buf) + if n == 0 { + return 0, false + } + + version := buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + r.logf("Short IPv4 packet: %d", len(buf)) + return 0, false + } + return buf[19], true + + case 6: + if len(buf) < 40 { + r.logf("Short IPv6 packet: %d", len(buf)) + return 0, false + } + return buf[39], true + + default: + r.logf("Invalid IP packet version: %v", version) + return 0, false + } +} + +func (*IFReader) logf(s string, args ...any) { + log.Printf("[IFReader] "+s, args...) +} diff --git a/peer/ifreader2_test.go b/peer/ifreader2_test.go new file mode 100644 index 0000000..779cf49 --- /dev/null +++ b/peer/ifreader2_test.go @@ -0,0 +1,83 @@ +package peer + +import ( + "testing" +) + +func TestIFReader_IPv4(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + pkt := make([]byte, 1234) + pkt[0] = 4 << 4 + pkt[19] = 2 // IP. + + p1.IFace.UserWrite(pkt) + p1.IFReader.handleNextPacket(newBuf()) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestIFReader_IPv6(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + pkt := make([]byte, 1234) + pkt[0] = 6 << 4 + pkt[39] = 2 // IP. + + p1.IFace.UserWrite(pkt) + p1.IFReader.handleNextPacket(newBuf()) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestIFReader_parsePacket_emptyPacket(t *testing.T) { + r := NewIFReader(nil, nil) + pkt := make([]byte, 0) + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { + r := NewIFReader(nil, nil) + + for i := byte(1); i < 16; i++ { + if i == 4 || i == 6 { + continue + } + pkt := make([]byte, 1234) + pkt[0] = i << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(i, ip, ok) + } + } +} + +func TestIFReader_parsePacket_shortIPv4(t *testing.T) { + r := NewIFReader(nil, nil) + + pkt := make([]byte, 19) + pkt[0] = 4 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +func TestIFReader_parsePacket_shortIPv6(t *testing.T) { + r := NewIFReader(nil, nil) + + pkt := make([]byte, 39) + pkt[0] = 6 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go index e8c5683..620d2b1 100644 --- a/peer/ifreader_test.go +++ b/peer/ifreader_test.go @@ -2,7 +2,6 @@ package peer import ( "bytes" - "net" "reflect" "sync/atomic" "testing" @@ -34,6 +33,7 @@ func TestIFReader_parsePacket_ipv6(t *testing.T) { } } +/* // Test that empty packets work as expected. func TestIFReader_parsePacket_emptyPacket(t *testing.T) { r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) @@ -99,7 +99,7 @@ func TestIFReader_readNextpacket(t *testing.T) { t.Fatalf("%s", pkt) } } - +*/ // ---------------------------------------------------------------------------- type sentPacket struct { diff --git a/peer/ifwriter.go b/peer/ifwriter.go deleted file mode 100644 index 59e2e26..0000000 --- a/peer/ifwriter.go +++ /dev/null @@ -1,5 +0,0 @@ -package peer - -import "io" - -type ifWriter io.Writer diff --git a/peer/interfaces.go b/peer/interfaces.go index 0d826c3..8e99e8b 100644 --- a/peer/interfaces.go +++ b/peer/interfaces.go @@ -1,10 +1,19 @@ package peer import ( + "io" "net" "net/netip" ) +type UDPConn interface { + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +type ifWriter io.Writer + type udpReader interface { ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) } @@ -13,7 +22,11 @@ type udpWriter interface { WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) } -type marshaller interface { +type mcUDPWriter interface { + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +type Marshaller interface { Marshal([]byte) []byte } @@ -22,6 +35,11 @@ type dataPacketSender interface { RelayDataPacket(pkt []byte, peer, relay *RemotePeer) } +type controlPacketSender interface { + SendControlPacket(pkt Marshaller, peer *RemotePeer) + RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) +} + type encryptedPacketSender interface { SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) } @@ -29,7 +47,3 @@ type encryptedPacketSender interface { type controlMsgHandler interface { HandleControlMsg(pkt any) } - -type mcUDPWriter interface { - WriteToUDP([]byte, *net.UDPAddr) (int, error) -} diff --git a/peer/mcreader.go b/peer/mcreader.go index 7d5c959..38921f1 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { return } - r.super.HandleControlMsg(controlMsg[localDiscoveryPacket]{ + r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ SrcIP: h.SourceIP, SrcAddr: remoteAddr, }) diff --git a/peer/mcreader_test.go b/peer/mcreader_test.go new file mode 100644 index 0000000..50bf821 --- /dev/null +++ b/peer/mcreader_test.go @@ -0,0 +1,138 @@ +package peer + +import ( + "bytes" + "net" + "net/netip" + "sync/atomic" + "testing" +) + +type mcMockConn struct { + packets chan []byte +} + +func newMCMockConn() *mcMockConn { + return &mcMockConn{make(chan []byte, 32)} +} + +func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) { + c.packets <- bytes.Clone(in) + return len(in), nil +} + +func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + buf := <-c.packets + b = b[:len(buf)] + copy(b, buf) + return len(b), netip.AddrPort{}, nil +} + +func TestMCReader(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 1, keys.PrivSignKey) + r := newMCReader(conn, super, peers) + + w.SendLocalDiscovery() + r.handleNextPacket() + + if len(super.Messages) != 1 { + t.Fatal(super.Messages) + } + msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery]) + if !ok || msg.SrcIP != 1 { + t.Fatal(ok, msg) + } +} + +func TestMCReader_noHeader(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + r := newMCReader(conn, super, peers) + conn.WriteToUDP([]byte("0123546789"), nil) + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} + +func TestMCReader_noPeer(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[2] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 2, keys.PrivSignKey) + r := newMCReader(conn, super, peers) + + w.SendLocalDiscovery() + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} + +func TestMCReader_badSignature(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 1, keys.PrivSignKey) + w.SendLocalDiscovery() + + // Break signing. + packet := <-conn.packets + packet[0]++ + conn.packets <- packet + + r := newMCReader(conn, super, peers) + + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} diff --git a/peer/mock-iface_test.go b/peer/mock-iface_test.go new file mode 100644 index 0000000..ffef5d9 --- /dev/null +++ b/peer/mock-iface_test.go @@ -0,0 +1,31 @@ +package peer + +import "bytes" + +type TestIFace struct { + out *bytes.Buffer // Toward the network. + in *bytes.Buffer // From the network +} + +func NewTestIFace() *TestIFace { + return &TestIFace{ + out: &bytes.Buffer{}, + in: &bytes.Buffer{}, + } +} + +func (iface *TestIFace) Write(b []byte) (int, error) { + return iface.in.Write(b) +} + +func (iface *TestIFace) Read(b []byte) (int, error) { + return iface.out.Read(b) +} + +func (iface *TestIFace) UserWrite(b []byte) (int, error) { + return iface.out.Write(b) +} + +func (iface *TestIFace) UserRead(b []byte) (int, error) { + return iface.in.Read(b) +} diff --git a/peer/mock-network_test.go b/peer/mock-network_test.go new file mode 100644 index 0000000..4b5240c --- /dev/null +++ b/peer/mock-network_test.go @@ -0,0 +1,80 @@ +package peer + +import ( + "bytes" + "net" + "net/netip" + "sync" +) + +type TestPacket struct { + Addr netip.AddrPort + Data []byte +} + +type TestNetwork struct { + lock sync.Mutex + packets map[netip.AddrPort]chan TestPacket +} + +func NewTestNetwork() *TestNetwork { + return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}} +} + +func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn { + n.lock.Lock() + defer n.lock.Unlock() + if _, ok := n.packets[localAddr]; !ok { + n.packets[localAddr] = make(chan TestPacket, 1024) + } + return &TestUDPConn{ + addr: localAddr, + n: n, + packets: n.packets[localAddr], + } +} + +func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) { + n.lock.Lock() + defer n.lock.Unlock() + if _, ok := n.packets[to]; !ok { + n.packets[to] = make(chan TestPacket, 1024) + } + n.packets[to] <- TestPacket{ + Addr: from, + Data: bytes.Clone(b), + } +} + +type TestUDPConn struct { + addr netip.AddrPort + n *TestNetwork + packets chan TestPacket +} + +func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + c.n.write(b, c.addr, addr) + return len(b), nil +} + +func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + return c.WriteToUDPAddrPort(b, addr.AddrPort()) +} + +func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + pkt := <-c.packets + b = b[:len(pkt.Data)] + copy(b, pkt.Data) + return len(b), pkt.Addr, nil +} + +func (c *TestUDPConn) Packets() (out []TestPacket) { + for { + select { + case pkt := <-c.packets: + out = append(out, pkt) + default: + return + } + } +} diff --git a/peer/packets-util.go b/peer/packets-util.go index bda33b9..c0264e5 100644 --- a/peer/packets-util.go +++ b/peer/packets-util.go @@ -70,7 +70,7 @@ func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { return w.Uint16(addrPort.Port()) } -func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { +func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter { for _, addrPort := range l { w.AddrPort(addrPort) } @@ -178,7 +178,7 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { return r } -func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { +func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader { for i := range x { r.AddrPort(&x[i]) } diff --git a/peer/packets-util_test.go b/peer/packets-util_test.go index 5a518d7..6e4a98c 100644 --- a/peer/packets-util_test.go +++ b/peer/packets-util_test.go @@ -6,6 +6,26 @@ import ( "testing" ) +func TestBinWriteRead_invalidAddrPort(t *testing.T) { + addr := netip.AddrPort{} + buf := make([]byte, 1024) + buf = newBinWriter(buf). + AddrPort(addr). + Build() + + var addr2 netip.AddrPort + err := newBinReader(buf). + AddrPort(&addr2). + Error() + if err != nil { + t.Fatal(err) + } + + if addr2.IsValid() { + t.Fatal(addr, addr2) + } +} + func TestBinWriteRead(t *testing.T) { buf := make([]byte, 1024) @@ -35,7 +55,7 @@ func TestBinWriteRead(t *testing.T) { Byte(in.Type). Uint64(in.TraceID). AddrPort(in.DestAddr). - AddrPortArray(in.Addrs). + AddrPort8(in.Addrs). Build() out := Item{} @@ -44,7 +64,7 @@ func TestBinWriteRead(t *testing.T) { Byte(&out.Type). Uint64(&out.TraceID). AddrPort(&out.DestAddr). - AddrPortArray(&out.Addrs). + AddrPort8(&out.Addrs). Error() if err != nil { t.Fatal(err) diff --git a/peer/packets.go b/peer/packets.go index f7f1f85..596483d 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -5,93 +5,70 @@ import ( ) const ( - packetTypeSyn = iota + 1 - packetTypeSynAck - packetTypeAck - packetTypeProbe - packetTypeAddrDiscovery + PacketTypeSyn = iota + 1 + PacketTypeSynAck + PacketTypeAck + PacketTypeProbe + PacketTypeAddrDiscovery ) // ---------------------------------------------------------------------------- -type synPacket struct { +type PacketSyn struct { TraceID uint64 // TraceID to match response w/ request. - // TODO: SentAt int64 // Unixmilli. + //SentAt int64 // Unixmilli. + //SharedKeyType byte // Currently only 1 is supported for AES. SharedKey [32]byte // Our shared key. Direct bool PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p synPacket) Marshal(buf []byte) []byte { +func (p PacketSyn) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeSyn). + Byte(PacketTypeSyn). Uint64(p.TraceID). + //Int64(p.SentAt). + //Byte(p.SharedKeyType). SharedKey(p.SharedKey). Bool(p.Direct). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). + AddrPort8(p.PossibleAddrs). Build() } -func parseSynPacket(buf []byte) (p synPacket, err error) { +func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). + //Int64(&p.SentAt). + //Byte(&p.SharedKeyType). SharedKey(&p.SharedKey). Bool(&p.Direct). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). + AddrPort8(&p.PossibleAddrs). Error() return } // ---------------------------------------------------------------------------- -type ackPacket struct { +type PacketAck struct { TraceID uint64 ToAddr netip.AddrPort PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p ackPacket) Marshal(buf []byte) []byte { +func (p PacketAck) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeAck). + Byte(PacketTypeAck). Uint64(p.TraceID). AddrPort(p.ToAddr). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). + AddrPort8(p.PossibleAddrs). Build() } -func parseAckPacket(buf []byte) (p ackPacket, err error) { +func ParsePacketAck(buf []byte) (p PacketAck, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). AddrPort(&p.ToAddr). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). + AddrPort8(&p.PossibleAddrs). Error() return } @@ -100,18 +77,18 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) { // A probeReqPacket is sent from a client to a server to determine if direct // UDP communication can be used. -type probePacket struct { +type PacketProbe struct { TraceID uint64 } -func (p probePacket) Marshal(buf []byte) []byte { +func (p PacketProbe) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeProbe). + Byte(PacketTypeProbe). Uint64(p.TraceID). Build() } -func parseProbePacket(buf []byte) (p probePacket, err error) { +func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). Error() @@ -120,4 +97,4 @@ func parseProbePacket(buf []byte) (p probePacket, err error) { // ---------------------------------------------------------------------------- -type localDiscoveryPacket struct{} +type PacketLocalDiscovery struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go index 333deff..3ddc1a0 100644 --- a/peer/packets_test.go +++ b/peer/packets_test.go @@ -1 +1,66 @@ package peer + +import ( + "crypto/rand" + "net/netip" + "reflect" + "testing" +) + +func TestSynPacket(t *testing.T) { + p := PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + } + rand.Read(p.SharedKey[:]) + + p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234) + p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) + p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) + + buf := p.Marshal(newBuf()) + p2, err := ParsePacketSyn(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} + +func TestAckPacket(t *testing.T) { + p := PacketAck{ + TraceID: newTraceID(), + ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), + } + + p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100) + p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) + p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) + + buf := p.Marshal(newBuf()) + p2, err := ParsePacketAck(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} + +func TestProbePacket(t *testing.T) { + p := PacketProbe{ + TraceID: newTraceID(), + } + + buf := p.Marshal(newBuf()) + p2, err := ParsePacketProbe(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} diff --git a/peer/peer_test.go b/peer/peer_test.go new file mode 100644 index 0000000..414beaa --- /dev/null +++ b/peer/peer_test.go @@ -0,0 +1,125 @@ +package peer + +import ( + "bytes" + "crypto/rand" + mrand "math/rand" + "net/netip" + "sync/atomic" +) + +// A test peer. +type P struct { + cryptoKeys + RT *atomic.Pointer[RoutingTable] + Conn *TestUDPConn + IFace *TestIFace + ConnWriter *ConnWriter + ConnReader *ConnReader + IFReader *IFReader + Super *Supervisor +} + +func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { + p := P{ + cryptoKeys: generateKeys(), + RT: &atomic.Pointer[RoutingTable]{}, + IFace: NewTestIFace(), + } + + rt := NewRoutingTable(ip, addr) + p.RT.Store(&rt) + p.Conn = n.NewUDPConn(addr) + p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) + p.IFReader = NewIFReader(p.IFace, p.ConnWriter) + + /* + p.ConnReader = NewConnReader( + p.Conn.ReadFromUDPAddrPort, + p.IFace, + p.ConnWriter.Forward, + p.Super.HandleControlMsg, + p.RT) + */ + return p +} + +func ConnectPeers(p1, p2 *P) { + rt1 := p1.RT.Load() + rt2 := p2.RT.Load() + + ip1 := rt1.LocalIP + ip2 := rt2.LocalIP + + rt1.Peers[ip2].Up = true + rt1.Peers[ip2].Direct = true + rt1.Peers[ip2].Relay = true + rt1.Peers[ip2].DirectAddr = rt2.LocalAddr + rt1.Peers[ip2].PubSignKey = p2.PubSignKey + rt1.Peers[ip2].ControlCipher = newControlCipher(p1.PrivKey, p2.PubKey) + rt1.Peers[ip2].DataCipher = newDataCipher() + + rt2.Peers[ip1].Up = true + rt2.Peers[ip1].Direct = true + rt2.Peers[ip1].Relay = true + rt2.Peers[ip1].DirectAddr = rt1.LocalAddr + rt2.Peers[ip1].PubSignKey = p1.PubSignKey + rt2.Peers[ip1].ControlCipher = newControlCipher(p2.PrivKey, p1.PubKey) + rt2.Peers[ip1].DataCipher = rt1.Peers[ip2].DataCipher +} + +func NewPeersForTesting() (p1, p2, p3 P) { + n := NewTestNetwork() + + p1 = NewPeerForTesting( + n, + 1, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)) + + p2 = NewPeerForTesting( + n, + 2, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 200)) + + p3 = NewPeerForTesting( + n, + 3, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 300)) + + ConnectPeers(&p1, &p2) + ConnectPeers(&p1, &p3) + ConnectPeers(&p2, &p3) + + return +} + +func RandPacket() []byte { + n := mrand.Intn(1200) + b := make([]byte, n) + rand.Read(b) + return b +} + +func ModifyPacket(in []byte) []byte { + x := make([]byte, 1) + + for { + rand.Read(x) + out := bytes.Clone(in) + idx := mrand.Intn(len(out)) + if out[idx] != x[0] { + out[idx] = x[0] + return out + } + } +} + +// ---------------------------------------------------------------------------- + +type UnknownControlPacket struct { + TraceID uint64 +} + +func (p UnknownControlPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf).Byte(255).Uint64(p.TraceID).Build() +} diff --git a/peer/peerstates.go b/peer/peerstates.go new file mode 100644 index 0000000..b05826c --- /dev/null +++ b/peer/peerstates.go @@ -0,0 +1,355 @@ +package peer + +import ( + "fmt" + "log" + "net/netip" + "strings" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type PeerState interface { + OnPeerUpdate(*m.Peer) PeerState + OnSyn(controlMsg[PacketSyn]) PeerState + OnAck(controlMsg[PacketAck]) + OnProbe(controlMsg[PacketProbe]) PeerState + OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) + OnPingTimer() PeerState +} + +// ---------------------------------------------------------------------------- + +type State struct { + // Output. + publish func(RemotePeer) + sendControlPacket func(RemotePeer, Marshaller) + + // Immutable data. + localIP byte + remoteIP byte + privKey []byte + localAddr netip.AddrPort // If valid, then local peer is publicly accessible. + + pubAddrs *pubAddrStore + + // The purpose of this state machine is to manage the RemotePeer object, + // publishing it as necessary. + staged RemotePeer // Local copy of shared data. See publish(). + + // Mutable peer data. + peer *m.Peer + + // We rate limit per remote endpoint because if we don't we tend to lose + // packets. + limiter *ratelimiter.Limiter +} + +func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { + defer func() { + // Don't defer directly otherwise s.staged will be evaluated immediately + // and won't reflect changes made in the function. + s.publish(s.staged) + }() + + if peer == nil { + return EnterStateDisconnected(s) + } + + s.peer = peer + s.staged.Relay = false + s.staged.Direct = false + s.staged.DirectAddr = netip.AddrPort{} + s.staged.PubSignKey = nil + s.staged.PubSignKey = peer.PubSignKey + s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) + s.staged.DataCipher = newDataCipher() + + if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + s.staged.Relay = peer.Relay + s.staged.Direct = true + s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) + + if s.localAddr.IsValid() && s.localIP < s.remoteIP { + return EnterStateServer(s) + } + + return EnterStateClientDirect(s) + } + + if s.localAddr.IsValid() { + s.staged.Direct = true + return EnterStateServer(s) + } + + if s.localIP < s.remoteIP { + return EnterStateServer(s) + } + + return EnterStateClientRelayed(s) +} + +func (s *State) logf(format string, args ...any) { + b := strings.Builder{} + name := "--" + if s.peer != nil { + name = s.peer.Name + } + b.WriteString(fmt.Sprintf("%30s: ", name)) + + if s.staged.Direct { + b.WriteString("DIRECT | ") + } else { + b.WriteString("RELAYED | ") + } + + if s.staged.Up { + b.WriteString("UP | ") + } else { + b.WriteString("DOWN | ") + } + + log.Printf(b.String()+format, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { + if !addr.IsValid() { + return + } + route := s.staged + route.Direct = true + route.DirectAddr = addr + s.Send(route, pkt) +} + +func (s *State) Send(peer RemotePeer, pkt Marshaller) { + if err := s.limiter.Limit(); err != nil { + s.logf("Rate limited.") + return + } + s.sendControlPacket(peer, pkt) +} + +// ---------------------------------------------------------------------------- + +type StateDisconnected struct{ *State } + +func EnterStateDisconnected(s *State) PeerState { + s.logf("==> Disconnected") + s.peer = nil + s.staged.Up = false + s.staged.Relay = false + s.staged.Direct = false + s.staged.DirectAddr = netip.AddrPort{} + s.staged.PubSignKey = nil + s.staged.ControlCipher = nil + s.staged.DataCipher = nil + s.publish(s.staged) + return &StateDisconnected{State: s} +} + +func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil } +func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {} +func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil } +func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} +func (s *StateDisconnected) OnPingTimer() PeerState { return nil } + +// ---------------------------------------------------------------------------- + +type StateServer struct { + *StateDisconnected + lastSeen time.Time + synTraceID uint64 +} + +func EnterStateServer(s *State) PeerState { + s.logf("==> Server") + return &StateServer{StateDisconnected: &StateDisconnected{State: s}} +} + +func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { + s.lastSeen = time.Now() + p := msg.Packet + + // Before we can respond to this packet, we need to make sure the + // route is setup properly. + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != s.synTraceID || !s.staged.Up { + s.synTraceID = p.TraceID + s.staged.Up = true + s.staged.Direct = p.Direct + s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + s.logf("Got SYN.") + } + + // Always respond. + ack := PacketAck{ + TraceID: p.TraceID, + ToAddr: s.staged.DirectAddr, + PossibleAddrs: s.pubAddrs.Get(), + } + s.Send(s.staged, ack) + + if p.Direct { + return nil + } + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) + } + + return nil +} + +func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { + if msg.SrcAddr.IsValid() { + s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + } + return nil +} + +func (s *StateServer) OnPingTimer() PeerState { + if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return nil +} + +// ---------------------------------------------------------------------------- + +type StateClientDirect struct { + *StateDisconnected + lastSeen time.Time + syn PacketSyn +} + +func EnterStateClientDirect(s *State) PeerState { + s.logf("==> ClientDirect") + return NewStateClientDirect(s) +} + +func NewStateClientDirect(s *State) *StateClientDirect { + state := &StateClientDirect{ + StateDisconnected: &StateDisconnected{s}, + lastSeen: time.Now(), // Avoid immediate timeout. + } + + state.syn = PacketSyn{ + TraceID: newTraceID(), + SharedKey: s.staged.DataCipher.Key(), + Direct: s.staged.Direct, + PossibleAddrs: s.pubAddrs.Get(), + } + state.Send(s.staged, state.syn) + return state +} + +func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { + if msg.Packet.TraceID != s.syn.TraceID { + return + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + s.pubAddrs.Store(msg.Packet.ToAddr) +} + +func (s *StateClientDirect) OnPingTimer() PeerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return s.OnPeerUpdate(s.peer) + } + + s.Send(s.staged, s.syn) + return nil +} + +// ---------------------------------------------------------------------------- + +type StateClientRelayed struct { + *StateClientDirect + ack PacketAck + probes map[uint64]netip.AddrPort + localDiscoveryAddr netip.AddrPort +} + +func EnterStateClientRelayed(s *State) PeerState { + s.logf("==> ClientRelayed") + return &StateClientRelayed{ + StateClientDirect: NewStateClientDirect(s), + probes: map[uint64]netip.AddrPort{}, + } +} + +func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { + s.ack = msg.Packet + s.StateClientDirect.OnAck(msg) +} + +func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { + addr, ok := s.probes[msg.Packet.TraceID] + if !ok { + return nil + } + + s.staged.DirectAddr = addr + s.staged.Direct = true + s.publish(s.staged) + return EnterStateClientDirect(s.StateClientDirect.State) +} + +func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { + // The source port will be the multicast port, so we'll have to + // construct the correct address using the peer's listed port. + s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) +} + +func (s *StateClientRelayed) OnPingTimer() PeerState { + if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { + return nextState + } + + clear(s.probes) + for _, addr := range s.ack.PossibleAddrs { + if !addr.IsValid() { + break + } + s.sendProbeTo(addr) + } + + if s.localDiscoveryAddr.IsValid() { + s.sendProbeTo(s.localDiscoveryAddr) + s.localDiscoveryAddr = netip.AddrPort{} + } + + return nil +} + +func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { + probe := PacketProbe{TraceID: newTraceID()} + s.probes[probe.TraceID] = addr + s.SendTo(probe, addr) +} diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go new file mode 100644 index 0000000..16805d0 --- /dev/null +++ b/peer/peerstates_test.go @@ -0,0 +1,509 @@ +package peer + +import ( + "net/netip" + "testing" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +// ---------------------------------------------------------------------------- + +type PeerStateControlMsg struct { + Peer RemotePeer + Packet any +} + +type PeerStateTestHarness struct { + State PeerState + Published RemotePeer + Sent []PeerStateControlMsg +} + +func NewPeerStateTestHarness() *PeerStateTestHarness { + h := &PeerStateTestHarness{} + + keys := generateKeys() + + state := &State{ + publish: func(rp RemotePeer) { + h.Published = rp + }, + sendControlPacket: func(rp RemotePeer, pkt Marshaller) { + h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) + }, + localIP: 2, + remoteIP: 3, + privKey: keys.PrivKey, + pubAddrs: newPubAddrStore(netip.AddrPort{}), + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + + h.State = EnterStateDisconnected(state) + return h +} + +func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { + if s := h.State.OnPeerUpdate(p); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { + if s := h.State.OnSyn(msg); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { + if s := h.State.OnProbe(msg); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) OnPingTimer() { + if s := h.State.OnPingTimer(); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { + keys := generateKeys() + + state := h.State.(*StateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 1, 1, 3}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 2, 3, 4}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateClientDirect](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { + keys := generateKeys() + + state := h.State.(*StateDisconnected) + state.remoteIP = 1 + + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateClientRelayed](t, h.State) +} + +// ---------------------------------------------------------------------------- + +func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { + h := NewPeerStateTestHarness() + h.PeerUpdate(nil) + assertType[*StateDisconnected](t, h.State) +} + +func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { + keys := generateKeys() + h := NewPeerStateTestHarness() + + state := h.State.(*StateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + assertType[*StateServer](t, h.State) +} + +func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) +} + +func TestPeerState_OnPeerUpdate_serverRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) +} + +func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) +} + +func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) +} + +func TestStateServer_directSyn(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + assertEqual(t, h.Published.Up, false) + + synMsg := controlMsg[PacketSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + }, + } + + h.State.OnSyn(synMsg) + + assertEqual(t, len(h.Sent), 1) + ack := assertType[PacketAck](t, h.Sent[0].Packet) + assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.IP, 3) + assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) + assertEqual(t, h.Published.Up, true) +} + +func TestStateServer_relayedSyn(t *testing.T) { + h := NewPeerStateTestHarness() + state := h.ConfigServer_Relayed(t) + + state.pubAddrs.Store(addrPort4(4, 5, 6, 7, 1234)) + + assertEqual(t, h.Published.Up, false) + + synMsg := controlMsg[PacketSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: false, + }, + } + synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) + synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) + + h.State.OnSyn(synMsg) + + assertEqual(t, len(h.Sent), 3) + + ack := assertType[PacketAck](t, h.Sent[0].Packet) + assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.IP, 3) + assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) + assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) + assertEqual(t, h.Published.Up, true) + + assertType[PacketProbe](t, h.Sent[1].Packet) + assertType[PacketProbe](t, h.Sent[2].Packet) + assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) + assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) +} + +func TestStateServer_onProbe(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + assertEqual(t, h.Published.Up, false) + + probeMsg := controlMsg[PacketProbe]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketProbe{TraceID: newTraceID()}, + } + + h.State.OnProbe(probeMsg) + + assertEqual(t, len(h.Sent), 1) + + probe := assertType[PacketProbe](t, h.Sent[0].Packet) + assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) +} + +func TestStateServer_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + synMsg := controlMsg[PacketSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + }, + } + + h.State.OnSyn(synMsg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Published.Up, true) + + // Ping shouldn't timeout. + h.OnPingTimer() + assertEqual(t, h.Published.Up, true) + + // Advance the time, then ping. + state := assertType[*StateServer](t, h.State) + state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) + + h.OnPingTimer() + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) +} + +func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID + 1}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnPingTimer(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + assertType[PacketSyn](t, h.Sent[0].Packet) + + h.OnPingTimer() + + // On ping timer, another syn should be sent. Additionally, we should remain + // in the same state. + assertEqual(t, len(h.Sent), 2) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[*StateClientDirect](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) + + state := assertType[*StateClientDirect](t, h.State) + state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) + + h.OnPingTimer() + + // On ping timer, we should timeout, causing the client to reset. Another SYN + // will be sent when re-entering the state, but the connection should be down. + assertEqual(t, len(h.Sent), 2) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[*StateClientDirect](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientRelayed_OnAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) +} + +func TestStateClientRelayed_OnPingTimer_noAddrs(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + + // If we haven't had an ack yet, we won't have addresses to probe. Therefore + // we'll have just one more syn packet sent. + h.OnPingTimer() + assertEqual(t, len(h.Sent), 2) +} + +func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) + ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) + + h.State.OnAck(ack) + + // Add a local discovery address. Note that the port will be configured port + // and no the one provided here. + h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ + SrcIP: 3, + SrcAddr: addrPort4(2, 2, 2, 3, 300), + }) + + // We should see one SYN and three probe packets. + h.OnPingTimer() + assertEqual(t, len(h.Sent), 5) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[PacketProbe](t, h.Sent[2].Packet) + assertType[PacketProbe](t, h.Sent[3].Packet) + assertType[PacketProbe](t, h.Sent[4].Packet) + + assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) + assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) + assertEqual(t, h.Sent[4].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 456)) +} + +func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) + + state := assertType[*StateClientRelayed](t, h.State) + state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) + + h.OnPingTimer() + + // On ping timer, we should timeout, causing the client to reset. Another SYN + // will be sent when re-entering the state, but the connection should be down. + assertEqual(t, len(h.Sent), 2) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[*StateClientRelayed](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.OnProbe(controlMsg[PacketProbe]{ + Packet: PacketProbe{TraceID: newTraceID()}, + }) + + assertType[*StateClientRelayed](t, h.State) +} + +func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) + ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) + + h.State.OnAck(ack) + h.OnPingTimer() + + probe := assertType[PacketProbe](t, h.Sent[2].Packet) + h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) + + assertType[*StateClientDirect](t, h.State) +} diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go new file mode 100644 index 0000000..13ab66f --- /dev/null +++ b/peer/pubaddrs.go @@ -0,0 +1,75 @@ +package peer + +import ( + "log" + "net/netip" + "runtime/debug" + "sort" + "time" +) + +type pubAddrStore struct { + localPub bool + localAddr netip.AddrPort + lastSeen map[netip.AddrPort]time.Time + addrList []netip.AddrPort +} + +func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { + return &pubAddrStore{ + localPub: localAddr.IsValid(), + localAddr: localAddr, + lastSeen: map[netip.AddrPort]time.Time{}, + addrList: make([]netip.AddrPort, 0, 32), + } +} + +func (store *pubAddrStore) Store(add netip.AddrPort) { + if store.localPub { + log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) + return + } + + if !add.IsValid() { + return + } + + if _, exists := store.lastSeen[add]; !exists { + store.addrList = append(store.addrList, add) + } + store.lastSeen[add] = time.Now() + store.sort() +} + +func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { + if store.localPub { + addrs[0] = store.localAddr + return + } + + copy(addrs[:], store.addrList) + return +} + +func (store *pubAddrStore) Clean() { + if store.localPub { + return + } + + for ip, lastSeen := range store.lastSeen { + if time.Since(lastSeen) > timeoutInterval { + delete(store.lastSeen, ip) + } + } + store.addrList = store.addrList[:0] + for ip := range store.lastSeen { + store.addrList = append(store.addrList, ip) + } + store.sort() +} + +func (store *pubAddrStore) sort() { + sort.Slice(store.addrList, func(i, j int) bool { + return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) + }) +} diff --git a/peer/pubaddrs_test.go b/peer/pubaddrs_test.go new file mode 100644 index 0000000..b79e854 --- /dev/null +++ b/peer/pubaddrs_test.go @@ -0,0 +1,29 @@ +package peer + +import ( + "net/netip" + "testing" + "time" +) + +func TestPubAddrStore(t *testing.T) { + s := newPubAddrStore(netip.AddrPort{}) + + l := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), + } + + for i := range l { + s.Store(l[i]) + time.Sleep(time.Millisecond) + } + + s.Clean() + + l2 := s.Get() + if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { + t.Fatal(l, l2) + } +} diff --git a/peer/routingtable.go b/peer/routingtable.go new file mode 100644 index 0000000..0943ab2 --- /dev/null +++ b/peer/routingtable.go @@ -0,0 +1,137 @@ +package peer + +import ( + "net/netip" + "sync/atomic" + "time" +) + +// TODO: Remove +func NewRemotePeer(ip byte) *RemotePeer { + counter := uint64(time.Now().Unix()<<30 + 1) + return &RemotePeer{ + IP: ip, + counter: &counter, + dupCheck: newDupCheck(0), + } +} + +// ---------------------------------------------------------------------------- + +type RemotePeer struct { + localIP byte + IP byte // VPN IP of peer (last byte). + Up bool // True if data can be sent on the peer. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + DirectAddr netip.AddrPort // Remote address if directly connected. + PubSignKey []byte + ControlCipher *controlCipher + DataCipher *dataCipher + + counter *uint64 // For sending to. Atomic access only. + dupCheck *dupCheck // For receiving from. Not safe for concurrent use. +} + +func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(p.counter, 1), + SourceIP: p.localIP, + DestIP: destIP, + } + return p.DataCipher.Encrypt(h, data, out) +} + +// Decrypts and de-dups incoming data packets. +func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { + dec, ok := p.DataCipher.Decrypt(enc, out) + if !ok { + return nil, errDecryptionFailed + } + + if p.dupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + return dec, nil +} + +// Peer must have a ControlCipher. +func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(p.counter, 1), + SourceIP: p.localIP, + DestIP: p.IP, + } + tmp = pkt.Marshal(tmp) + return p.ControlCipher.Encrypt(h, tmp, out) +} + +// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. +// +// This function also drops packets with duplicate sequence numbers. +func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { + out, ok := p.ControlCipher.Decrypt(enc, tmp) + if !ok { + return nil, errDecryptionFailed + } + + if p.dupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + msg, err := parseControlMsg(h.SourceIP, fromAddr, out) + if err != nil { + return nil, err + } + + return msg, nil +} + +// ---------------------------------------------------------------------------- + +type RoutingTable struct { + // The LocalIP is the configured IP address of the local peer on the VPN. + // + // This value is constant. + LocalIP byte + + // The LocalAddr is the configured local public address of the peer on the + // internet. If LocalAddr.IsValid(), then the local peer has a public + // address. + // + // This value is constant. + LocalAddr netip.AddrPort + + // The remote peer configurations. These are updated by + Peers [256]RemotePeer + + // The current relay's VPN IP address, or zero if no relay is available. + RelayIP byte +} + +func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { + rt := RoutingTable{ + LocalIP: localIP, + LocalAddr: localAddr, + } + + for i := range rt.Peers { + counter := uint64(time.Now().Unix()<<30 + 1) + rt.Peers[i] = RemotePeer{ + localIP: localIP, + IP: byte(i), + counter: &counter, + dupCheck: newDupCheck(0), + } + } + + return rt +} + +func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { + relay := rt.Peers[rt.RelayIP] + return relay, relay.Up && relay.Direct +} diff --git a/peer/routingtable_test.go b/peer/routingtable_test.go new file mode 100644 index 0000000..b5497a4 --- /dev/null +++ b/peer/routingtable_test.go @@ -0,0 +1,169 @@ +package peer + +import ( + "bytes" + "reflect" + "testing" +) + +func TestRemotePeer_DecryptDataPacket(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + dec, err := peer1.DecryptDataPacket(h, enc, newBuf()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(orig, dec) { + t.Fatal(dec) + } +} + +func TestRemotePeer_DecryptDataPacket_packetAltered(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + + h := parseHeader(enc) + + for range 2048 { + _, err := peer1.DecryptDataPacket(h, ModifyPacket(enc), newBuf()) + if err == nil { + t.Fatal(enc) + } + } +} + +func TestRemotePeer_DecryptDataPacket_duplicateSequenceNumber(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + h := parseHeader(enc) + + if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err != nil { + t.Fatal(err) + } + + if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} + +func TestRemotePeer_DecryptControlPacket(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := PacketProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()) + if err != nil { + t.Fatal(err) + } + + dec, ok := ctrlMsg.(controlMsg[PacketProbe]) + if !ok { + t.Fatal(ctrlMsg) + } + + if dec.SrcIP != 1 || dec.SrcAddr != p1.RT.Load().LocalAddr { + t.Fatal(dec) + } + + if !reflect.DeepEqual(dec.Packet, orig) { + t.Fatal(dec) + } +} + +func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := PacketProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + for range 2048 { + ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, ModifyPacket(enc), newBuf()) + if err == nil { + t.Fatal(ctrlMsg) + } + } +} + +func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := PacketProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err != nil { + t.Fatal(err) + } + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} + +func TestRemotePeer_DecryptControlPacket_unknownPacketType(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := UnknownControlPacket{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} diff --git a/peer/state.go b/peer/state.go deleted file mode 100644 index d6589fe..0000000 --- a/peer/state.go +++ /dev/null @@ -1,29 +0,0 @@ -package peer - -import ( - "net/netip" - "time" -) - -type RemotePeer struct { - IP byte // VPN IP of peer (last byte). - Up bool // True if data can be sent on the peer. - Relay bool // True if the peer is a relay. - Direct bool // True if this is a direct connection. - DirectAddr netip.AddrPort // Remote address if directly connected. - PubSignKey []byte - ControlCipher *controlCipher - DataCipher *dataCipher - - Counter *uint64 // For sending to. Atomic access only. - DupCheck *dupCheck // For receiving from. Not safe for concurrent use. -} - -func NewRemotePeer(ip byte) *RemotePeer { - counter := uint64(time.Now().Unix()<<30 + 1) - return &RemotePeer{ - IP: ip, - Counter: &counter, - DupCheck: newDupCheck(0), - } -} diff --git a/peer/supervisor.go b/peer/supervisor.go new file mode 100644 index 0000000..0f82a3f --- /dev/null +++ b/peer/supervisor.go @@ -0,0 +1,103 @@ +package peer + +import ( + "log" + "sync/atomic" + "time" + + "git.crumpington.com/lib/go/ratelimiter" +) + +// ---------------------------------------------------------------------------- + +type Supervisor struct { + messages chan any // Incoming control messages. + peers [256]PeerState + pubAddrs *pubAddrStore + rt *atomic.Pointer[RoutingTable] + staged RoutingTable +} + +func NewSupervisor( + sendControl func(RemotePeer, Marshaller), + privKey []byte, + rt *atomic.Pointer[RoutingTable], +) *Supervisor { + s := &Supervisor{ + messages: make(chan any, 1024), + pubAddrs: newPubAddrStore(rt.Load().LocalAddr), + rt: rt, + } + + routes := rt.Load() + + for i := range s.peers { + state := &State{ + publish: s.Publish, + sendControlPacket: sendControl, + localIP: routes.LocalIP, + remoteIP: byte(i), + privKey: privKey, + localAddr: routes.LocalAddr, + pubAddrs: s.pubAddrs, + staged: routes.Peers[i], + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + s.peers[i] = state.OnPeerUpdate(nil) + } + + return s +} + +func (s *Supervisor) HandleControlMsg(msg any) { + select { + case s.messages <- msg: + default: + } +} + +func (s *Supervisor) Run() { + for raw := range s.messages { + switch msg := raw.(type) { + + case peerUpdateMsg: + s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) + + case controlMsg[PacketSyn]: + if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { + s.peers[msg.SrcIP] = newState + } + + case controlMsg[PacketAck]: + s.peers[msg.SrcIP].OnAck(msg) + + case controlMsg[PacketProbe]: + if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { + s.peers[msg.SrcIP] = newState + } + + case controlMsg[PacketLocalDiscovery]: + s.peers[msg.SrcIP].OnLocalDiscovery(msg) + + case pingTimerMsg: + s.pubAddrs.Clean() + for i := range s.peers { + if newState := s.peers[i].OnPingTimer(); newState != nil { + s.peers[i] = newState + } + } + + default: + log.Printf("WARNING: unknown message type: %+v", msg) + } + } +} + +func (s *Supervisor) Publish(rp RemotePeer) { + s.staged.Peers[rp.IP] = rp + rt := s.staged // Copy. + s.rt.Store(&rt) +} diff --git a/peer/util_test.go b/peer/util_test.go new file mode 100644 index 0000000..56b9d6f --- /dev/null +++ b/peer/util_test.go @@ -0,0 +1,26 @@ +package peer + +import ( + "net/netip" + "testing" +) + +func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) +} + +func assertType[T any](t *testing.T, obj any) T { + t.Helper() + x, ok := obj.(T) + if !ok { + t.Fatal("invalid type", obj) + } + return x +} + +func assertEqual[T comparable](t *testing.T, a, b T) { + t.Helper() + if a != b { + t.Fatal(a, " != ", b) + } +}