From d79902a83bc865288d618627fa4cc37ff945aa2e Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 22 Jan 2025 14:09:43 +0100 Subject: [PATCH 01/26] WIP --- README.md | 6 - node/README.md | 17 +++ node/conn.go | 47 ------ node/connwriter.go | 146 ++++++++++++++++++ node/connwriter_test.go | 291 ++++++++++++++++++++++++++++++++++++ node/crypto.go | 30 ++++ node/data-flow.dot | 14 ++ node/globalfuncs.go | 57 ------- node/globals.go | 37 +---- node/ifreader.go | 102 +++++++++++++ node/ifreader_test.go | 117 +++++++++++++++ node/ifwriter.go | 5 + node/localdiscovery_test.go | 2 +- node/main.go | 30 ++-- node/main_test.go | 37 +++++ node/mcwriter.go | 62 ++++++++ node/mcwriter_test.go | 102 +++++++++++++ node/messages.go | 5 +- node/packets.go | 3 +- node/packets_test.go | 40 ----- node/packetsender.go | 127 ++++++++++++++++ node/relaymanager.go | 1 + node/shared.go | 59 ++++++++ node/shared_test.go | 16 ++ node/supervisor.go | 28 ++-- 25 files changed, 1168 insertions(+), 213 deletions(-) create mode 100644 node/README.md create mode 100644 node/connwriter.go create mode 100644 node/connwriter_test.go create mode 100644 node/crypto.go create mode 100644 node/data-flow.dot create mode 100644 node/ifreader.go create mode 100644 node/ifreader_test.go create mode 100644 node/ifwriter.go create mode 100644 node/main_test.go create mode 100644 node/mcwriter.go create mode 100644 node/mcwriter_test.go create mode 100644 node/packetsender.go create mode 100644 node/shared.go create mode 100644 node/shared_test.go diff --git a/README.md b/README.md index c6cc0e1..4567196 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,5 @@ # vppn: Virtual Potentially Private Network -## TODO - -* Add `-force-init` argument to `node` main? - ## Hub Server Configuration ``` @@ -33,7 +29,6 @@ WorkingDirectory=/home/user/ ExecStart=/home/user/hub -listen :https -root-dir=/home/user Restart=always RestartSec=8 -TimeoutStopSec=24 [Install] WantedBy=default.target @@ -70,7 +65,6 @@ WorkingDirectory=/home/user/ ExecStart=/home/user/vppn -name vppn -hub-address https://my.hub -api-key 1234567890 Restart=always RestartSec=8 -TimeoutStopSec=24 [Install] WantedBy=default.target diff --git a/node/README.md b/node/README.md new file mode 100644 index 0000000..30b77c4 --- /dev/null +++ b/node/README.md @@ -0,0 +1,17 @@ +# VPPN Peer Code + +## Refactoring for Testability + +* [ ] connWriter + * [ ] Separate send/relay calls +* [x] mcWriter +* [x] ifWriter +* [ ] ifReader +* [ ] connReader +* [ ] mcReader +* [ ] hubPoller +* [ ] supervisor + +## Updates + +* [ ] Send timing info w/ syn/ack packets diff --git a/node/conn.go b/node/conn.go index 2a1e762..e000557 100644 --- a/node/conn.go +++ b/node/conn.go @@ -1,50 +1,3 @@ package node -import ( - "io" - "log" - "net" - "net/netip" - "sync" -) - // ---------------------------------------------------------------------------- - -type connWriter struct { - lock sync.Mutex - conn *net.UDPConn -} - -func newConnWriter(conn *net.UDPConn) *connWriter { - return &connWriter{conn: conn} -} - -func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) { - // Even though a conn is safe for concurrent use, it turns out that a mutex - // in Go is more fair when there's contention. Without this lock, control - // packets may fail to be sent in a timely manner causing timeouts. - w.lock.Lock() - if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { - log.Printf("Failed to write to UDP port: %v", err) - } - w.lock.Unlock() -} - -// ---------------------------------------------------------------------------- - -type ifWriter struct { - lock sync.Mutex - iface io.ReadWriteCloser -} - -func newIFWriter(iface io.ReadWriteCloser) *ifWriter { - return &ifWriter{iface: iface} -} - -func (w *ifWriter) Write(packet []byte) { - w.lock.Lock() - if _, err := w.iface.Write(packet); err != nil { - log.Fatalf("Failed to write to interface: %v", err) - } - w.lock.Unlock() -} diff --git a/node/connwriter.go b/node/connwriter.go new file mode 100644 index 0000000..597b886 --- /dev/null +++ b/node/connwriter.go @@ -0,0 +1,146 @@ +package node + +import ( + "log" + "net/netip" + "sync" + "sync/atomic" + "time" +) + +// ---------------------------------------------------------------------------- + +type peerRoute struct { + IP byte + Up bool // True if data can be sent on the route. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + PubSignKey []byte + ControlCipher *controlCipher + DataCipher *dataCipher + RemoteAddr netip.AddrPort // Remote address if directly connected. +} + +// ---------------------------------------------------------------------------- + +type udpAddrPortWriter interface { + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) +} + +type marshaller interface { + Marshal([]byte) []byte +} + +// ---------------------------------------------------------------------------- + +type connWriter struct { + localIP byte + conn udpAddrPortWriter + + // For sending control packets. + cBuf1 []byte + cBuf2 []byte + + // For sending data packets. + dBuf1 []byte + dBuf2 []byte + + counters [256]uint64 + + // Lock around for sending on UDP Conn. + wLock sync.Mutex +} + +func newConnWriter(conn udpAddrPortWriter, localIP byte) *connWriter { + w := &connWriter{ + localIP: localIP, + conn: conn, + cBuf1: make([]byte, bufferSize), + cBuf2: make([]byte, bufferSize), + dBuf1: make([]byte, bufferSize), + dBuf2: make([]byte, bufferSize), + } + for i := range w.counters { + w.counters[i] = uint64(time.Now().Unix()<<30 + 1) + } + return w +} + +// Not safe for concurrent use. Should only be called by supervisor. +func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { + buf := pkt.Marshal(w.cBuf1) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(&w.counters[route.IP], 1), + SourceIP: w.localIP, + DestIP: route.IP, + } + buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) + w.writeTo(buf, route.RemoteAddr) +} + +func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { + buf := pkt.Marshal(w.cBuf1) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(&w.counters[route.IP], 1), + SourceIP: w.localIP, + DestIP: route.IP, + } + buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) + w.relayPacket(buf, w.cBuf1, route, relay) +} + +// Not safe for concurrent use. Should only be called by ifReader. +func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&w.counters[route.IP], 1), + SourceIP: w.localIP, + DestIP: route.IP, + } + + enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) + + if route.Direct { + w.writeTo(enc, route.RemoteAddr) + return + } + + w.relayPacket(enc, w.dBuf2, route, relay) +} + +// TODO: RelayDataPacket + +// Safe for concurrent use. Should only be called by connReader. +// +// This function will send pkt to the peer directly. This is used when a peer +// is acting as a relay and is forwarding already encrypted data for another +// peer. +func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { + w.writeTo(pkt, route.RemoteAddr) +} + +func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) { + if relay == nil || !relay.Up { + return + } + + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&w.counters[relay.IP], 1), + SourceIP: w.localIP, + DestIP: route.IP, + } + + enc := relay.DataCipher.Encrypt(h, data, buf) + w.writeTo(enc, relay.RemoteAddr) +} + +func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { + w.wLock.Lock() + if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { + log.Printf("Failed to write to UDP port: %v", err) + } + w.wLock.Unlock() +} diff --git a/node/connwriter_test.go b/node/connwriter_test.go new file mode 100644 index 0000000..595d5b7 --- /dev/null +++ b/node/connwriter_test.go @@ -0,0 +1,291 @@ +package node + +import ( + "bytes" + "net/netip" + "testing" +) + +// ---------------------------------------------------------------------------- + +type testUDPPacket struct { + Addr netip.AddrPort + Data []byte +} + +type testUDPAddrPortWriter struct { + written []testUDPPacket +} + +func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + w.written = append(w.written, testUDPPacket{ + Addr: addr, + Data: bytes.Clone(b), + }) + return len(b), nil +} + +func (w *testUDPAddrPortWriter) Written() []testUDPPacket { + out := w.written + w.written = []testUDPPacket{} + return out +} + +// ---------------------------------------------------------------------------- + +type testPacket string + +func (p testPacket) Marshal(b []byte) []byte { + b = b[:len(p)] + copy(b, []byte(p)) + return b +} + +// ---------------------------------------------------------------------------- + +func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) { + localKeys := generateKeys() + remoteKeys := generateKeys() + + local = &peerRoute{ + IP: 2, + Up: true, + Relay: false, + PubSignKey: remoteKeys.PubSignKey, + ControlCipher: newControlCipher(localKeys.PrivKey, remoteKeys.PubKey), + DataCipher: newDataCipher(), + RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100), + } + + remote = &peerRoute{ + IP: 1, + Up: true, + Relay: false, + PubSignKey: localKeys.PubSignKey, + ControlCipher: newControlCipher(remoteKeys.PrivKey, localKeys.PubKey), + DataCipher: local.DataCipher, + RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100), + } + + rLocalKeys := generateKeys() + rRemoteKeys := generateKeys() + + relayLocal = &peerRoute{ + IP: 3, + Up: true, + Relay: true, + Direct: true, + PubSignKey: rRemoteKeys.PubSignKey, + ControlCipher: newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey), + DataCipher: newDataCipher(), + RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100), + } + + relayRemote = &peerRoute{ + IP: 1, + Up: true, + Relay: false, + Direct: true, + PubSignKey: rLocalKeys.PubSignKey, + ControlCipher: newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey), + DataCipher: relayLocal.DataCipher, + RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100), + } + + return +} + +// ---------------------------------------------------------------------------- + +// Testing if we can send a control packet directly to the remote route. +func TestConnWriter_SendControlPacket_direct(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + route.Direct = true + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.SendControlPacket(in, route) + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + if string(dec) != string(in) { + t.Fatal(dec) + } +} + +// Testing if we can relay a packet via an intermediary. +func TestConnWriter_SendControlPacket_relay(t *testing.T) { + route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.RelayControlPacket(in, route, relay) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != relay.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if string(dec2) != string(in) { + t.Fatal(dec2) + } +} + +// Testing that a nil relay doesn't cause an issue. +func TestConnWriter_SendControlPacket_relay_relayNil(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.RelayControlPacket(in, route, nil) + + out := writer.Written() + if len(out) != 0 { + t.Fatal(out) + } + +} + +// Testing that we don't send anything if the relay isn't up. +func TestConnWriter_SendControlPacket_relay_relayNotUp(t *testing.T) { + route, rRoute, relay, _ := testConnWriter_getTestRoutes() + relay.Up = false + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.RelayControlPacket(in, route, relay) + + out := writer.Written() + if len(out) != 0 { + t.Fatal(out) + } +} + +// Testing that we can send a data packet directly to a remote route. +func TestConnWriter_SendDataPacket_direct(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + route.Direct = true + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + + in := []byte("hello world!") + w.SendDataPacket(in, route, nil) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(dec, in) { + t.Fatal(dec) + } +} + +// Testing that we can relay a data packet via a relay. +func TestConnWriter_SendDataPacket_relay(t *testing.T) { + route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.SendDataPacket(in, route, relay) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != relay.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(dec2, in) { + t.Fatal(dec2) + } +} + +// Testing that we don't attempt to relay if the relay is nil. +func TestConnWriter_SendDataPacket_relay_relayNil(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.SendDataPacket(in, route, nil) + + out := writer.Written() + if len(out) != 0 { + t.Fatal(out) + } +} + +// Testing that we don't attempt to relay if the relay isn't up. +func TestConnWriter_SendDataPacket_relay_relayNotUp(t *testing.T) { + route, rRoute, relay, _ := testConnWriter_getTestRoutes() + relay.Up = false + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.SendDataPacket(in, route, relay) + + out := writer.Written() + if len(out) != 0 { + t.Fatal(out) + } +} diff --git a/node/crypto.go b/node/crypto.go new file mode 100644 index 0000000..c24aaad --- /dev/null +++ b/node/crypto.go @@ -0,0 +1,30 @@ +package node + +import ( + "crypto/rand" + "log" + + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/sign" +) + +type cryptoKeys struct { + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func generateKeys() cryptoKeys { + pubKey, privKey, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate encryption keys: %v", err) + } + + pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate signing keys: %v", err) + } + + return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} +} diff --git a/node/data-flow.dot b/node/data-flow.dot new file mode 100644 index 0000000..45b6f05 --- /dev/null +++ b/node/data-flow.dot @@ -0,0 +1,14 @@ +digraph d { + ifReader -> connWriter; + connReader -> ifWriter; + connReader -> connWriter; + connReader -> supervisor; + mcReader -> supervisor; + supervisor -> connWriter; + supervisor -> mcWriter; + hubPoller -> supervisor; + + connWriter [shape="box"]; + mcWriter [shape="box"]; + ifWriter [shape="box"]; +} \ No newline at end of file diff --git a/node/globalfuncs.go b/node/globalfuncs.go index f32ec0b..2d13f57 100644 --- a/node/globalfuncs.go +++ b/node/globalfuncs.go @@ -1,65 +1,8 @@ package node -import ( - "sync/atomic" -) - func getRelayRoute() *peerRoute { if ip := relayIP.Load(); ip != nil { return routingTable[*ip].Load() } return nil } - -func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { - buf := pkt.Marshal(buf2) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&sendCounters[route.IP], 1), - SourceIP: localIP, - DestIP: route.IP, - } - buf = route.ControlCipher.Encrypt(h, buf, buf1) - - if route.Direct { - _conn.WriteTo(buf, route.RemoteAddr) - return - } - - _relayPacket(route.IP, buf, buf2) -} - -func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&sendCounters[route.IP], 1), - SourceIP: localIP, - DestIP: route.IP, - } - - enc := route.DataCipher.Encrypt(h, pkt, buf1) - - if route.Direct { - _conn.WriteTo(enc, route.RemoteAddr) - return - } - - _relayPacket(route.IP, enc, buf2) -} - -func _relayPacket(destIP byte, data, buf []byte) { - relayRoute := getRelayRoute() - if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay { - return - } - - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1), - SourceIP: localIP, - DestIP: destIP, - } - - enc := relayRoute.DataCipher.Encrypt(h, data, buf) - _conn.WriteTo(enc, relayRoute.RemoteAddr) -} diff --git a/node/globals.go b/node/globals.go index b72acc4..8538c4a 100644 --- a/node/globals.go +++ b/node/globals.go @@ -5,7 +5,6 @@ import ( "net/netip" "net/url" "sync/atomic" - "time" ) const ( @@ -17,21 +16,9 @@ const ( signOverhead = 64 ) -var ( - multicastIP = netip.AddrFrom4([4]byte{224, 0, 0, 157}) - multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(multicastIP, 4560)) -) - -type peerRoute struct { - IP byte - Up bool // True if data can be sent on the route. - Relay bool // True if the peer is a relay. - Direct bool // True if this is a direct connection. - PubSignKey []byte - ControlCipher *controlCipher - DataCipher *dataCipher - RemoteAddr netip.AddrPort // Remote address if directly connected. -} +var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( + netip.AddrFrom4([4]byte{224, 0, 0, 157}), + 4560)) var ( hubURL *url.URL @@ -45,20 +32,7 @@ var ( privKey []byte privSignKey []byte - // Shared interface for writing. - _iface *ifWriter - - // Shared connection for writing. - _conn *connWriter - - // Counters for sending to each peer. - sendCounters [256]uint64 = func() (out [256]uint64) { - for i := range out { - out[i] = uint64(time.Now().Unix()<<30 + 1) - } - return - }() - + // TODO: Doesn't need to be global. // Duplicate checkers for incoming packets. dupChecks [256]*dupCheck = func() (out [256]*dupCheck) { for i := range out { @@ -67,9 +41,11 @@ var ( return }() + // TODO: Doesn't need to be global . // Messages for the supervisor. messages = make(chan any, 1024) + // TODO: Doesn't need to be global . // Global routing table. routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { for i := range out { @@ -82,5 +58,6 @@ var ( // Managed by the relayManager. relayIP = &atomic.Pointer[byte]{} + // TODO: Only used by supervisor: can make local there. publicAddrs = newPubAddrStore() ) diff --git a/node/ifreader.go b/node/ifreader.go new file mode 100644 index 0000000..a0e7a54 --- /dev/null +++ b/node/ifreader.go @@ -0,0 +1,102 @@ +package node + +import ( + "io" + "log" + "sync/atomic" +) + +type ifReader struct { + iface io.Reader + routes [256]*atomic.Pointer[peerRoute] + relay *atomic.Pointer[peerRoute] + sendDataPacket func(pkt []byte, route *peerRoute) + relayDataPacket func(pkt []byte, route, relay *peerRoute) +} + +func newIFReader( + iface io.Reader, + routes [256]*atomic.Pointer[peerRoute], + relay *atomic.Pointer[peerRoute], + sendDataPacket func(pkt []byte, route *peerRoute), + relayDackPacket func(pkt []byte, route, relay *peerRoute), +) *ifReader { + return &ifReader{ + iface: iface, + routes: routes, + relay: relay, + sendDataPacket: sendDataPacket, + } +} + +func (r *ifReader) Run() { + var ( + packet = make([]byte, bufferSize) + remoteIP byte + ok bool + ) + + for { + packet = r.readNextPacket(packet) + if remoteIP, ok = r.parsePacket(packet); ok { + r.sendPacket(packet, remoteIP) + } + } +} + +func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { + route := r.routes[remoteIP].Load() + if !route.Up { + log.Printf("Route not connected: %d", remoteIP) + return + } + + // Direct path => early return. + if route.Direct { + r.sendDataPacket(pkt, route) + return + } + + if relay := r.relay.Load(); relay != nil { + r.relayDataPacket(pkt, route, relay) + } +} + +// Get next packet, returning packet, and destination ip. +func (r *ifReader) readNextPacket(buf []byte) []byte { + n, err := r.iface.Read(buf[:cap(buf)]) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + return buf[:n] +} + +func (r *ifReader) parsePacket(buf []byte) (byte, bool) { + n := len(buf) + if n == 0 { + return 0, false + } + + version := buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + log.Printf("Short IPv4 packet: %d", len(buf)) + return 0, false + } + return buf[19], true + + case 6: + if len(buf) < 40 { + log.Printf("Short IPv6 packet: %d", len(buf)) + return 0, false + } + return buf[39], true + + default: + log.Printf("Invalid IP packet version: %v", version) + return 0, false + } +} diff --git a/node/ifreader_test.go b/node/ifreader_test.go new file mode 100644 index 0000000..8f173f4 --- /dev/null +++ b/node/ifreader_test.go @@ -0,0 +1,117 @@ +package node + +import ( + "bytes" + "net" + "sync/atomic" + "testing" +) + +// Test that we parse IPv4 packets correctly. +func TestIFReader_parsePacket_ipv4(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + + pkt := make([]byte, 1234) + pkt[0] = 4 << 4 + pkt[19] = 128 + + if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { + t.Fatal(ip, ok) + } +} + +// Test that we parse IPv6 packets correctly. +func TestIFReader_parsePacket_ipv6(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + + pkt := make([]byte, 1234) + pkt[0] = 6 << 4 + pkt[39] = 42 + + if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { + t.Fatal(ip, ok) + } +} + +// Test that empty packets work as expected. +func TestIFReader_parsePacket_emptyPacket(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + + pkt := make([]byte, 0) + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that invalid IP versions fail. +func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + + for i := byte(1); i < 16; i++ { + if i == 4 || i == 6 { + continue + } + pkt := make([]byte, 1234) + pkt[0] = i << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(i, ip, ok) + } + } +} + +// Test that short IPv4 packets fail. +func TestIFReader_parsePacket_shortIPv4(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + + pkt := make([]byte, 19) + pkt[0] = 4 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that short IPv6 packets fail. +func TestIFReader_parsePacket_shortIPv6(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + + pkt := make([]byte, 39) + pkt[0] = 6 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that we can read a packet. +func TestIFReader_readNextpacket(t *testing.T) { + in, out := net.Pipe() + r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) + defer in.Close() + defer out.Close() + + go in.Write([]byte("hello world!")) + + pkt := r.readNextPacket(make([]byte, bufferSize)) + if !bytes.Equal(pkt, []byte("hello world!")) { + t.Fatalf("%s", pkt) + } +} + +// Testing that we can send a packet directly. +func TestIFReader_sendPacket_direct(t *testing.T) { + // TODO +} + +// Testing that we don't send a packet if route isn't up. +func TestIFReader_sendPacket_directNotUp(t *testing.T) { + // TODO +} + +// Testing that we can send a packet via a relay. +func TestIFReader_sendPacket_relayed(t *testing.T) { + // TODO +} + +// Testing that we don't try to send on a nil relay IP. diff --git a/node/ifwriter.go b/node/ifwriter.go new file mode 100644 index 0000000..adb74e3 --- /dev/null +++ b/node/ifwriter.go @@ -0,0 +1,5 @@ +package node + +import "io" + +type ifWriter io.Writer diff --git a/node/localdiscovery_test.go b/node/localdiscovery_test.go index 7f4eaa3..b00b29d 100644 --- a/node/localdiscovery_test.go +++ b/node/localdiscovery_test.go @@ -20,7 +20,7 @@ func TestLocalDiscoveryPacketSigning(t *testing.T) { privSignKey = privSigKey[:] route := routingTable[localIP].Load() route.IP = byte(localIP) - route.PubSignKey = pubSignKey[0:32] + route.PubSignKey = pubSignKey[:] routingTable[localIP].Store(route) out := buildLocalDiscoveryPacket(buf1, buf2) diff --git a/node/main.go b/node/main.go index 4e59cf7..8e53cb4 100644 --- a/node/main.go +++ b/node/main.go @@ -143,10 +143,6 @@ func main() { conn.SetReadBuffer(1024 * 1024 * 8) conn.SetWriteBuffer(1024 * 1024 * 8) - // Intialize globals. - _iface = newIFWriter(iface) - _conn = newConnWriter(conn) - localIP = config.PeerIP ip, ok := netip.AddrFromSlice(config.PublicIP) @@ -169,17 +165,19 @@ func main() { } }() - go startPeerSuper() + sender := newPacketSender(conn) + + go startPeerSuper(routingTable, messages, sender) go newHubPoller().Run() - go readFromConn(conn) + go readFromConn(conn, iface, sender) - readFromIFace(iface) + readFromIFace(iface, sender) } // ---------------------------------------------------------------------------- -func readFromConn(conn *net.UDPConn) { +func readFromConn(conn *net.UDPConn, iface io.ReadWriteCloser, sender dataPacketSender) { defer panicHandler() @@ -213,7 +211,7 @@ func readFromConn(conn *net.UDPConn) { handleControlPacket(remoteAddr, h, data, decBuf) case dataStreamID: - handleDataPacket(h, data, decBuf) + handleDataPacket(h, data, decBuf, iface, sender) default: log.Printf("Unknown stream ID: %d", h.StreamID) @@ -263,7 +261,7 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { } -func handleDataPacket(h header, data []byte, decBuf []byte) { +func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { route := routingTable[h.SourceIP].Load() if !route.Up { log.Printf("Not connected (recv).") @@ -282,7 +280,9 @@ func handleDataPacket(h header, data []byte, decBuf []byte) { } if h.DestIP == localIP { - _iface.Write(dec) + if _, err := iface.Write(dec); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } return } @@ -292,16 +292,14 @@ func handleDataPacket(h header, data []byte, decBuf []byte) { return } - _conn.WriteTo(dec, destRoute.RemoteAddr) + sender.SendEncryptedDataPacket(dec, destRoute.RemoteAddr) } // ---------------------------------------------------------------------------- -func readFromIFace(iface io.ReadWriteCloser) { +func readFromIFace(iface io.ReadWriteCloser, sender dataPacketSender) { var ( packet = make([]byte, bufferSize) - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) remoteIP byte err error ) @@ -318,6 +316,6 @@ func readFromIFace(iface io.ReadWriteCloser) { continue } - _sendDataPacket(route, packet, buf1, buf2) + sender.SendDataPacket(packet, *route) } } diff --git a/node/main_test.go b/node/main_test.go new file mode 100644 index 0000000..bf077a2 --- /dev/null +++ b/node/main_test.go @@ -0,0 +1,37 @@ +package node + +import ( + "crypto/rand" + "log" + + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/sign" +) + +type testPeer struct { + IP byte + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func newTestPeer(ip byte) testPeer { + encPubKey, encPrivKey, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate encryption keys: %v", err) + } + + signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate signing keys: %v", err) + } + + return testPeer{ + IP: ip, + PubKey: encPubKey[:], + PrivKey: encPrivKey[:], + PubSignKey: signPubKey[:], + PrivSignKey: signPrivKey[:], + } +} diff --git a/node/mcwriter.go b/node/mcwriter.go new file mode 100644 index 0000000..99e5b58 --- /dev/null +++ b/node/mcwriter.go @@ -0,0 +1,62 @@ +package node + +import ( + "log" + "net" + + "golang.org/x/crypto/nacl/sign" +) + +// ---------------------------------------------------------------------------- + +type udpWriter interface { + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +// ---------------------------------------------------------------------------- + +func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { + h := header{ + SourceIP: localIP, + DestIP: 255, + } + buf := make([]byte, headerSize) + h.Marshal(buf) + out := make([]byte, headerSize+signOverhead) + return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) +} + +func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) { + if len(pkt) != headerSize+signOverhead { + return + } + + h.Parse(pkt[signOverhead:]) + ok = true + return +} + +func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { + _, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) + return ok +} + +// ---------------------------------------------------------------------------- + +type mcWriter struct { + conn udpWriter + discoveryPacket []byte +} + +func newMCWriter(conn udpWriter, localIP byte, signingKey []byte) *mcWriter { + return &mcWriter{ + conn: conn, + discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), + } +} + +func (w *mcWriter) SendLocalDiscovery() { + if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { + log.Printf("Failed to write multicast UDP packet: %v", err) + } +} diff --git a/node/mcwriter_test.go b/node/mcwriter_test.go new file mode 100644 index 0000000..d182239 --- /dev/null +++ b/node/mcwriter_test.go @@ -0,0 +1,102 @@ +package node + +import ( + "bytes" + "net" + "testing" +) + +// ---------------------------------------------------------------------------- + +// Testing that we can create and verify a local discovery packet. +func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + header, ok := headerFromLocalDiscoveryPacket(created) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 55 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Not valid") + } +} + +// Testing that we don't try to parse short packets. +func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + _, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) + if ok { + t.Fatal(ok) + } +} + +// Testing that modifying a packet makes it invalid. +func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + buf := make([]byte, 1024) + for i := range created { + modified := bytes.Clone(created) + modified[i]++ + if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { + t.Fatal("Verification should have failed.") + } + } +} + +// ---------------------------------------------------------------------------- + +type testUDPWriter struct { + written [][]byte +} + +func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + w.written = append(w.written, bytes.Clone(b)) + return len(b), nil +} + +func (w *testUDPWriter) Written() [][]byte { + out := w.written + w.written = [][]byte{} + return out +} + +// ---------------------------------------------------------------------------- + +// Testing that the mcWriter sends local discovery packets as expected. +func TestMCWriter_SendLocalDiscovery(t *testing.T) { + keys := generateKeys() + writer := &testUDPWriter{} + + mcw := newMCWriter(writer, 42, keys.PrivSignKey) + mcw.SendLocalDiscovery() + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + pkt := out[0] + + header, ok := headerFromLocalDiscoveryPacket(pkt) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 42 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Verification should succeed.") + } +} diff --git a/node/messages.go b/node/messages.go index 76d86d4..64ca5fe 100644 --- a/node/messages.go +++ b/node/messages.go @@ -10,7 +10,8 @@ import ( type controlMsg[T any] struct { SrcIP byte SrcAddr netip.AddrPort - Packet T + // TODO: RecvdAt int64 // Unixmilli. + Packet T } func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { @@ -55,5 +56,3 @@ type peerUpdateMsg struct { // ---------------------------------------------------------------------------- type pingTimerMsg struct{} - -// ---------------------------------------------------------------------------- diff --git a/node/packets.go b/node/packets.go index 14d7377..f3aa523 100644 --- a/node/packets.go +++ b/node/packets.go @@ -21,7 +21,8 @@ const ( // ---------------------------------------------------------------------------- type synPacket struct { - TraceID uint64 // TraceID to match response w/ request. + TraceID uint64 // TraceID to match response w/ request. + // TODO: SentAt int64 // Unixmilli. SharedKey [32]byte // Our shared key. Direct bool PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. diff --git a/node/packets_test.go b/node/packets_test.go index 254bcc7..2b4023a 100644 --- a/node/packets_test.go +++ b/node/packets_test.go @@ -1,41 +1 @@ package node - -import ( - "crypto/rand" - "net/netip" - "reflect" - "testing" -) - -func TestPacketSyn(t *testing.T) { - in := synPacket{ - TraceID: newTraceID(), - FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), - } - rand.Read(in.SharedKey[:]) - - out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize))) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal("\n", in, "\n", out) - } -} - -func TestPacketSynAck(t *testing.T) { - in := ackPacket{ - TraceID: newTraceID(), - FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), - } - - out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal("\n", in, "\n", out) - } -} diff --git a/node/packetsender.go b/node/packetsender.go new file mode 100644 index 0000000..07e083a --- /dev/null +++ b/node/packetsender.go @@ -0,0 +1,127 @@ +package node + +import ( + "log" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" +) + +type controlPacketSender interface { + SendControlPacket(pkt marshaller, route peerRoute) +} + +type dataPacketSender interface { + SendDataPacket(pkt []byte, route peerRoute) + SendEncryptedDataPacket(pkt []byte, addr netip.AddrPort) +} + +// ---------------------------------------------------------------------------- + +type packetSender struct { + conn *net.UDPConn + + // For sending control packets. + cLock sync.Mutex + cBuf1 []byte + cBuf2 []byte + + // For sending data packets. + dBuf1 []byte + dBuf2 []byte + + counters [256]uint64 + + // Lock around for sending on UDP Conn. + wLock sync.Mutex +} + +func newPacketSender(conn *net.UDPConn) *packetSender { + ps := &packetSender{ + conn: conn, + cBuf1: make([]byte, bufferSize), + cBuf2: make([]byte, bufferSize), + dBuf1: make([]byte, bufferSize), + dBuf2: make([]byte, bufferSize), + } + for i := range ps.counters { + ps.counters[i] = uint64(time.Now().Unix()<<30 + 1) + } + return ps +} + +// Safe for concurrent use. +func (sender *packetSender) SendControlPacket(pkt marshaller, route peerRoute) { + sender.cLock.Lock() + defer sender.cLock.Unlock() + + buf := pkt.Marshal(sender.cBuf1) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(&sender.counters[route.IP], 1), + SourceIP: localIP, + DestIP: route.IP, + } + buf = route.ControlCipher.Encrypt(h, buf, sender.cBuf2) + + if route.Direct { + sender.writeTo(buf, route.RemoteAddr) + return + } + + sender.relayPacket(route.IP, buf, sender.cBuf1) +} + +// Not safe for concurrent use. +func (sender *packetSender) SendDataPacket(pkt []byte, route peerRoute) { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&sender.counters[route.IP], 1), + SourceIP: localIP, + DestIP: route.IP, + } + + enc := route.DataCipher.Encrypt(h, pkt, sender.dBuf1) + + if route.Direct { + sender.writeTo(enc, route.RemoteAddr) + return + } + + sender.relayPacket(route.IP, enc, sender.dBuf2) +} + +func (sender *packetSender) SendEncryptedDataPacket(pkt []byte, addr netip.AddrPort) { + sender.writeTo(pkt, addr) +} + +func (sender *packetSender) relayPacket(destIP byte, data, buf []byte) { + ip := relayIP.Load() + if ip == nil { + return + } + relayRoute := routingTable[*ip].Load() + if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay { + return + } + + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&sender.counters[relayRoute.IP], 1), + SourceIP: localIP, + DestIP: destIP, + } + + enc := relayRoute.DataCipher.Encrypt(h, data, buf) + sender.writeTo(enc, relayRoute.RemoteAddr) +} + +func (sender *packetSender) writeTo(packet []byte, addr netip.AddrPort) { + sender.wLock.Lock() + if _, err := sender.conn.WriteToUDPAddrPort(packet, addr); err != nil { + log.Printf("Failed to write to UDP port: %v", err) + } + sender.wLock.Unlock() +} diff --git a/node/relaymanager.go b/node/relaymanager.go index 5c44ea8..a333ce1 100644 --- a/node/relaymanager.go +++ b/node/relaymanager.go @@ -6,6 +6,7 @@ import ( "time" ) +// TODO: Make part of main loop on ping timer func relayManager() { time.Sleep(2 * time.Second) updateRelayRoute() diff --git a/node/shared.go b/node/shared.go new file mode 100644 index 0000000..dbdb6ee --- /dev/null +++ b/node/shared.go @@ -0,0 +1,59 @@ +package node + +import ( + "net/netip" + "sync/atomic" +) + +type sharedState struct { + // Immutable: + HubAddress string + APIKey string + NetName string + LocalIP byte + LocalPub bool + LocalAddr netip.AddrPort + PrivKey []byte + PrivSignKey []byte + + // Mutable: + Routes [256]*atomic.Pointer[peerRoute] + RelayIP *atomic.Pointer[byte] + + // Messages for supervisor main loop. + Messages chan any +} + +func newSharedState( + netName, + hubAddress, + apiKey string, + conf localConfig, +) ( + ss sharedState, +) { + ss.HubAddress = hubAddress + + ss.APIKey = apiKey + ss.NetName = netName + ss.LocalIP = conf.PeerIP + + ip, ok := netip.AddrFromSlice(conf.PublicIP) + if ok { + ss.LocalPub = true + ss.LocalAddr = netip.AddrPortFrom(ip, conf.Port) + } + + ss.PrivKey = conf.PrivKey + ss.PrivSignKey = conf.PrivSignKey + + for i := range ss.Routes { + ss.Routes[i] = &atomic.Pointer[peerRoute]{} + ss.Routes[i].Store(&peerRoute{}) + } + + ss.RelayIP = &atomic.Pointer[byte]{} + + ss.Messages = make(chan any, 1024) + return +} diff --git a/node/shared_test.go b/node/shared_test.go new file mode 100644 index 0000000..4009e7d --- /dev/null +++ b/node/shared_test.go @@ -0,0 +1,16 @@ +package node + +import "vppn/m" + +// TODO: +var sharedStateForTesting = func() sharedState { + ss := newSharedState( + "testNet", + "http://localhost:39499", + "123", + localConfig{ + PeerConfig: m.PeerConfig{}, + }) + + return ss +} diff --git a/node/supervisor.go b/node/supervisor.go index 6b5e96a..726d47f 100644 --- a/node/supervisor.go +++ b/node/supervisor.go @@ -19,14 +19,17 @@ const ( // ---------------------------------------------------------------------------- -func startPeerSuper() { +func startPeerSuper( + routingTable [256]*atomic.Pointer[peerRoute], + messages chan any, + sender controlPacketSender, +) { peers := [256]peerState{} for i := range peers { data := &peerStateData{ + sender: sender, published: routingTable[i], remoteIP: byte(i), - buf1: make([]byte, bufferSize), - buf2: make([]byte, bufferSize), limiter: ratelimiter.New(ratelimiter.Config{ FillPeriod: 20 * time.Millisecond, MaxWaitCount: 1, @@ -34,10 +37,10 @@ func startPeerSuper() { } peers[i] = data.OnPeerUpdate(nil) } - go runPeerSuper(peers) + go runPeerSuper(peers, messages) } -func runPeerSuper(peers [256]peerState) { +func runPeerSuper(peers [256]peerState, messages chan any) { for raw := range messages { switch msg := raw.(type) { @@ -84,6 +87,8 @@ type peerState interface { // ---------------------------------------------------------------------------- type peerStateData struct { + sender controlPacketSender + // The purpose of this state machine is to manage this published data. published *atomic.Pointer[peerRoute] staged peerRoute // Local copy of shared data. See publish(). @@ -95,10 +100,6 @@ type peerStateData struct { peer *m.Peer remotePub bool - // Buffers for sending control packets. - buf1 []byte - buf2 []byte - // For logging. Set per-state. client bool @@ -129,7 +130,7 @@ func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte s.logf("Not sending control packet: rate limited.") // Shouldn't happen. return } - _sendControlPacket(pkt, route, s.buf1, s.buf2) + s.sender.SendControlPacket(pkt, route) } // ---------------------------------------------------------------------------- @@ -175,8 +176,9 @@ func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { s.peer = peer s.staged = peerRoute{ - IP: s.remoteIP, - PubSignKey: peer.PubSignKey, + IP: s.remoteIP, + PubSignKey: peer.PubSignKey, + // TODO: privKey global. ControlCipher: newControlCipher(privKey, peer.PubKey), DataCipher: newDataCipher(), } @@ -192,6 +194,7 @@ func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { } if s.remotePub == localPub { + // TODO: localIP is global if localIP < s.remoteIP { return enterStateServer(s) } @@ -349,6 +352,7 @@ func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { } // Store possible public address if we're not a public node. + // TODO: localPub is global, publicAddrs is global. if !localPub && s.remotePub { publicAddrs.Store(msg.Packet.ToAddr) } -- 2.39.5 From 1a6503bbda4946dc7687893fc643cd037ff6a6be Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 29 Jan 2025 11:45:09 +0100 Subject: [PATCH 02/26] wip --- node/README.md | 5 +- node/connwriter.go | 46 +++--- node/connwriter_test.go | 69 ++------ node/dupcheck.go | 4 +- node/header.go | 12 ++ node/ifreader.go | 2 +- peer/bitset.go | 21 +++ peer/bitset_test.go | 48 ++++++ peer/cipher-control.go | 26 +++ peer/cipher-control_test.go | 122 ++++++++++++++ peer/cipher-data.go | 61 +++++++ peer/cipher-data_test.go | 141 ++++++++++++++++ peer/cipher-discovery.go | 13 ++ peer/connreader.go | 141 ++++++++++++++++ peer/connreader_test.go | 318 ++++++++++++++++++++++++++++++++++++ peer/connwriter.go | 80 +++++++++ peer/connwriter_test.go | 240 +++++++++++++++++++++++++++ peer/controlmessage.go | 58 +++++++ peer/crypto.go | 113 +++++++++++++ peer/crypto_test.go | 213 ++++++++++++++++++++++++ peer/dupcheck.go | 76 +++++++++ peer/dupcheck_test.go | 57 +++++++ peer/errors.go | 10 ++ peer/globals.go | 19 +++ peer/header.go | 49 ++++++ peer/header_test.go | 21 +++ peer/ifreader.go | 100 ++++++++++++ peer/ifreader_test.go | 232 ++++++++++++++++++++++++++ peer/ifwriter.go | 5 + peer/interfaces.go | 28 ++++ peer/mcwriter.go | 62 +++++++ peer/mcwriter_test.go | 102 ++++++++++++ peer/packets-util.go | 190 +++++++++++++++++++++ peer/packets-util_test.go | 56 +++++++ peer/packets.go | 123 ++++++++++++++ peer/packets_test.go | 1 + peer/state.go | 29 ++++ 37 files changed, 2808 insertions(+), 85 deletions(-) create mode 100644 peer/bitset.go create mode 100644 peer/bitset_test.go create mode 100644 peer/cipher-control.go create mode 100644 peer/cipher-control_test.go create mode 100644 peer/cipher-data.go create mode 100644 peer/cipher-data_test.go create mode 100644 peer/cipher-discovery.go create mode 100644 peer/connreader.go create mode 100644 peer/connreader_test.go create mode 100644 peer/connwriter.go create mode 100644 peer/connwriter_test.go create mode 100644 peer/controlmessage.go create mode 100644 peer/crypto.go create mode 100644 peer/crypto_test.go create mode 100644 peer/dupcheck.go create mode 100644 peer/dupcheck_test.go create mode 100644 peer/errors.go create mode 100644 peer/globals.go create mode 100644 peer/header.go create mode 100644 peer/header_test.go create mode 100644 peer/ifreader.go create mode 100644 peer/ifreader_test.go create mode 100644 peer/ifwriter.go create mode 100644 peer/interfaces.go create mode 100644 peer/mcwriter.go create mode 100644 peer/mcwriter_test.go create mode 100644 peer/packets-util.go create mode 100644 peer/packets-util_test.go create mode 100644 peer/packets.go create mode 100644 peer/packets_test.go create mode 100644 peer/state.go diff --git a/node/README.md b/node/README.md index 30b77c4..58b4298 100644 --- a/node/README.md +++ b/node/README.md @@ -2,11 +2,10 @@ ## Refactoring for Testability -* [ ] connWriter - * [ ] Separate send/relay calls +* [x] connWriter * [x] mcWriter * [x] ifWriter -* [ ] ifReader +* [ ] ifReader (testing) * [ ] connReader * [ ] mcReader * [ ] hubPoller diff --git a/node/connwriter.go b/node/connwriter.go index 597b886..62caa75 100644 --- a/node/connwriter.go +++ b/node/connwriter.go @@ -68,18 +68,18 @@ func newConnWriter(conn udpAddrPortWriter, localIP byte) *connWriter { // Not safe for concurrent use. Should only be called by supervisor. func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { - buf := pkt.Marshal(w.cBuf1) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&w.counters[route.IP], 1), - SourceIP: w.localIP, - DestIP: route.IP, - } - buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) + buf := w.encryptControlPacket(pkt, route) w.writeTo(buf, route.RemoteAddr) } +// Relay control packet. Routes must not be nil. func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { + buf := w.encryptControlPacket(pkt, route) + w.relayPacket(buf, w.cBuf1, route, relay) +} + +// Encrypted packet will occupy cBuf2. +func (w *connWriter) encryptControlPacket(pkt marshaller, route *peerRoute) []byte { buf := pkt.Marshal(w.cBuf1) h := header{ StreamID: controlStreamID, @@ -87,12 +87,11 @@ func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) SourceIP: w.localIP, DestIP: route.IP, } - buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) - w.relayPacket(buf, w.cBuf1, route, relay) + return route.ControlCipher.Encrypt(h, buf, w.cBuf2) } // Not safe for concurrent use. Should only be called by ifReader. -func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { +func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { h := header{ StreamID: dataStreamID, Counter: atomic.AddUint64(&w.counters[route.IP], 1), @@ -101,16 +100,21 @@ func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { } enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) - - if route.Direct { - w.writeTo(enc, route.RemoteAddr) - return - } - - w.relayPacket(enc, w.dBuf2, route, relay) + w.writeTo(enc, route.RemoteAddr) } -// TODO: RelayDataPacket +// Relay a data packet. Routes must not be nil. +func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&w.counters[route.IP], 1), + SourceIP: w.localIP, + DestIP: route.IP, + } + + enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) + w.relayPacket(enc, w.dBuf2, route, relay) +} // Safe for concurrent use. Should only be called by connReader. // @@ -122,10 +126,6 @@ func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { } func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) { - if relay == nil || !relay.Up { - return - } - h := header{ StreamID: dataStreamID, Counter: atomic.AddUint64(&w.counters[relay.IP], 1), diff --git a/node/connwriter_test.go b/node/connwriter_test.go index 595d5b7..388fbbc 100644 --- a/node/connwriter_test.go +++ b/node/connwriter_test.go @@ -126,7 +126,7 @@ func TestConnWriter_SendControlPacket_direct(t *testing.T) { } // Testing if we can relay a packet via an intermediary. -func TestConnWriter_SendControlPacket_relay(t *testing.T) { +func TestConnWriter_RelayControlPacket_relay(t *testing.T) { route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() writer := &testUDPAddrPortWriter{} @@ -159,40 +159,6 @@ func TestConnWriter_SendControlPacket_relay(t *testing.T) { } } -// Testing that a nil relay doesn't cause an issue. -func TestConnWriter_SendControlPacket_relay_relayNil(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, nil) - - out := writer.Written() - if len(out) != 0 { - t.Fatal(out) - } - -} - -// Testing that we don't send anything if the relay isn't up. -func TestConnWriter_SendControlPacket_relay_relayNotUp(t *testing.T) { - route, rRoute, relay, _ := testConnWriter_getTestRoutes() - relay.Up = false - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, relay) - - out := writer.Written() - if len(out) != 0 { - t.Fatal(out) - } -} - // Testing that we can send a data packet directly to a remote route. func TestConnWriter_SendDataPacket_direct(t *testing.T) { route, rRoute, _, _ := testConnWriter_getTestRoutes() @@ -202,7 +168,7 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { w := newConnWriter(writer, rRoute.IP) in := []byte("hello world!") - w.SendDataPacket(in, route, nil) + w.SendDataPacket(in, route) out := writer.Written() if len(out) != 1 { @@ -224,14 +190,14 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { } // Testing that we can relay a data packet via a relay. -func TestConnWriter_SendDataPacket_relay(t *testing.T) { +func TestConnWriter_RelayDataPacket_relay(t *testing.T) { route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() writer := &testUDPAddrPortWriter{} w := newConnWriter(writer, rRoute.IP) in := []byte("Hello world!") - w.SendDataPacket(in, route, relay) + w.RelayDataPacket(in, route, relay) out := writer.Written() if len(out) != 1 { @@ -257,35 +223,26 @@ func TestConnWriter_SendDataPacket_relay(t *testing.T) { } } -// Testing that we don't attempt to relay if the relay is nil. -func TestConnWriter_SendDataPacket_relay_relayNil(t *testing.T) { +// Testing that we can send an already encrypted packet. +func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { route, rRoute, _, _ := testConnWriter_getTestRoutes() writer := &testUDPAddrPortWriter{} w := newConnWriter(writer, rRoute.IP) in := []byte("Hello world!") - w.SendDataPacket(in, route, nil) + w.SendEncryptedDataPacket(in, route) out := writer.Written() - if len(out) != 0 { + if len(out) != 1 { t.Fatal(out) } -} -// Testing that we don't attempt to relay if the relay isn't up. -func TestConnWriter_SendDataPacket_relay_relayNotUp(t *testing.T) { - route, rRoute, relay, _ := testConnWriter_getTestRoutes() - relay.Up = false + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.SendDataPacket(in, route, relay) - - out := writer.Written() - if len(out) != 0 { - t.Fatal(out) + if !bytes.Equal(out[0].Data, in) { + t.Fatal(out[0]) } } diff --git a/node/dupcheck.go b/node/dupcheck.go index fac7a72..76792ae 100644 --- a/node/dupcheck.go +++ b/node/dupcheck.go @@ -38,14 +38,14 @@ func (dc *dupCheck) IsDup(counter uint64) bool { delta := counter - dc.tailCounter // Full clear. - if delta >= bitSetSize { + if delta >= bitSetSize-1 { dc.ClearAll() dc.Set(0) dc.tail = 1 dc.head = 2 dc.tailCounter = counter + 1 - dc.headCounter = dc.tailCounter - bitSetSize + dc.headCounter = dc.tailCounter - bitSetSize + 1 return false } diff --git a/node/header.go b/node/header.go index 9d0417a..915fe3e 100644 --- a/node/header.go +++ b/node/header.go @@ -20,6 +20,18 @@ type header struct { Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. } +func parseHeader(b []byte) (h header, ok bool) { + if len(b) < headerSize { + return + } + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) + return h, true +} + func (h *header) Parse(b []byte) { h.Version = b[0] h.StreamID = b[1] diff --git a/node/ifreader.go b/node/ifreader.go index a0e7a54..67d0999 100644 --- a/node/ifreader.go +++ b/node/ifreader.go @@ -57,7 +57,7 @@ func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { return } - if relay := r.relay.Load(); relay != nil { + if relay := r.relay.Load(); relay != nil && relay.Up { r.relayDataPacket(pkt, route, relay) } } diff --git a/peer/bitset.go b/peer/bitset.go new file mode 100644 index 0000000..8d03b50 --- /dev/null +++ b/peer/bitset.go @@ -0,0 +1,21 @@ +package peer + +const bitSetSize = 512 // Multiple of 64. + +type bitSet [bitSetSize / 64]uint64 + +func (bs *bitSet) Set(i int) { + bs[i/64] |= 1 << (i % 64) +} + +func (bs *bitSet) Clear(i int) { + bs[i/64] &= ^(1 << (i % 64)) +} + +func (bs *bitSet) ClearAll() { + clear(bs[:]) +} + +func (bs *bitSet) Get(i int) bool { + return bs[i/64]&(1<<(i%64)) != 0 +} diff --git a/peer/bitset_test.go b/peer/bitset_test.go new file mode 100644 index 0000000..01ae82b --- /dev/null +++ b/peer/bitset_test.go @@ -0,0 +1,48 @@ +package peer + +import ( + "math/rand" + "testing" +) + +func TestBitSet(t *testing.T) { + state := make([]bool, bitSetSize) + for i := range state { + state[i] = rand.Float32() > 0.5 + } + + bs := bitSet{} + + for i := range state { + if state[i] { + bs.Set(i) + } + } + + for i := range state { + if bs.Get(i) != state[i] { + t.Fatal(i, state[i], bs.Get(i)) + } + } + + for i := range state { + if rand.Float32() > 0.5 { + state[i] = false + bs.Clear(i) + } + } + + for i := range state { + if bs.Get(i) != state[i] { + t.Fatal(i, state[i], bs.Get(i)) + } + } + + bs.ClearAll() + + for i := range state { + if bs.Get(i) { + t.Fatal(i, bs.Get(i)) + } + } +} diff --git a/peer/cipher-control.go b/peer/cipher-control.go new file mode 100644 index 0000000..bfecaeb --- /dev/null +++ b/peer/cipher-control.go @@ -0,0 +1,26 @@ +package peer + +import "golang.org/x/crypto/nacl/box" + +type controlCipher struct { + sharedKey [32]byte +} + +func newControlCipher(privKey, pubKey []byte) *controlCipher { + shared := [32]byte{} + box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) + return &controlCipher{shared} +} + +func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte { + const s = controlHeaderSize + out = out[:s+controlCipherOverhead+len(data)] + h.Marshal(out[:s]) + box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) + return out +} + +func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = controlHeaderSize + return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) +} diff --git a/peer/cipher-control_test.go b/peer/cipher-control_test.go new file mode 100644 index 0000000..916d2ea --- /dev/null +++ b/peer/cipher-control_test.go @@ -0,0 +1,122 @@ +package peer + +import ( + "bytes" + "crypto/rand" + "reflect" + "testing" + + "golang.org/x/crypto/nacl/box" +) + +func newControlCipherForTesting() (c1, c2 *controlCipher) { + pubKey1, privKey1, err := box.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + pubKey2, privKey2, err := box.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + return newControlCipher(privKey1[:], pubKey2[:]), + newControlCipher(privKey2[:], pubKey1[:]) +} + +func TestControlCipher(t *testing.T) { + c1, c2 := newControlCipherForTesting() + + maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + StreamID: controlStreamID, + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + encrypted = c1.Encrypt(h1, plaintext, encrypted) + + h2 := header{} + h2.Parse(encrypted) + if !reflect.DeepEqual(h1, h2) { + t.Fatal(h1, h2) + } + + decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Fatal("not equal") + } + } +} + +func TestControlCipher_ShortCiphertext(t *testing.T) { + c1, _ := newControlCipherForTesting() + shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) + rand.Read(shortText) + _, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkControlCipher_Encrypt(b *testing.B) { + c1, _ := newControlCipherForTesting() + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = c1.Encrypt(h1, plaintext, encrypted) + } +} + +func BenchmarkControlCipher_Decrypt(b *testing.B) { + c1, c2 := newControlCipherForTesting() + + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + encrypted = c1.Encrypt(h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _ = c2.Decrypt(encrypted, decrypted) + } +} diff --git a/peer/cipher-data.go b/peer/cipher-data.go new file mode 100644 index 0000000..9b229bb --- /dev/null +++ b/peer/cipher-data.go @@ -0,0 +1,61 @@ +package peer + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "log" +) + +type dataCipher struct { + key [32]byte + aead cipher.AEAD +} + +func newDataCipher() *dataCipher { + key := [32]byte{} + if _, err := rand.Read(key[:]); err != nil { + log.Fatalf("Failed to read random data: %v", err) + } + return newDataCipherFromKey(key) +} + +func newDataCipherFromKey(key [32]byte) *dataCipher { + block, err := aes.NewCipher(key[:]) + if err != nil { + log.Fatalf("Failed to create new cipher: %v", err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + log.Fatalf("Failed to create new GCM: %v", err) + } + + return &dataCipher{key: key, aead: aead} +} + +func (sc *dataCipher) Key() [32]byte { + return sc.key +} + +func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte { + const s = dataHeaderSize + out = out[:s+dataCipherOverhead+len(data)] + h.Marshal(out[:s]) + sc.aead.Seal(out[s:s], out[:s], data, nil) + return out +} + +func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = dataHeaderSize + if len(encrypted) < s+dataCipherOverhead { + ok = false + return + } + + var err error + + data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) + ok = err == nil + return +} diff --git a/peer/cipher-data_test.go b/peer/cipher-data_test.go new file mode 100644 index 0000000..ac9a03a --- /dev/null +++ b/peer/cipher-data_test.go @@ -0,0 +1,141 @@ +package peer + +import ( + "bytes" + "crypto/rand" + mrand "math/rand/v2" + "reflect" + "testing" +) + +func TestDataCipher(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + StreamID: dataStreamID, + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + h2 := header{} + h2.Parse(encrypted) + + dc2 := newDataCipherFromKey(dc1.Key()) + + decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("not equal") + } + + if !reflect.DeepEqual(h1, h2) { + t.Fatalf("%v != %v", h1, h2) + } + } +} + +func TestDataCipher_ModifyCiphertext(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + encrypted[mrand.IntN(len(encrypted))]++ + + dc2 := newDataCipherFromKey(dc1.Key()) + + _, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if ok { + t.Fatal(ok) + } + } +} + +func TestDataCipher_ShortCiphertext(t *testing.T) { + dc1 := newDataCipher() + shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) + rand.Read(shortText) + _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkDataCipher_Encrypt(b *testing.B) { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + } +} + +func BenchmarkDataCipher_Decrypt(b *testing.B) { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _ = dc1.Decrypt(encrypted, decrypted) + } +} diff --git a/peer/cipher-discovery.go b/peer/cipher-discovery.go new file mode 100644 index 0000000..0e66650 --- /dev/null +++ b/peer/cipher-discovery.go @@ -0,0 +1,13 @@ +package peer + +/* +func signData(privKey *[64]byte, h header, data, out []byte) []byte { + out = out[:headerSize] + h.Marshal(out) + return sign.Sign(out, data, privKey) +} + +func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) { + return sign.Open(out[:0], signed[headerSize:], pubKey) +} +*/ diff --git a/peer/connreader.go b/peer/connreader.go new file mode 100644 index 0000000..757b37c --- /dev/null +++ b/peer/connreader.go @@ -0,0 +1,141 @@ +package peer + +import ( + "log" + "net/netip" + "sync/atomic" +) + +type connReader struct { + conn udpReader + iface ifWriter + sender encryptedPacketSender + super controlMsgHandler + localIP byte + routes [256]*atomic.Pointer[peerRoute] + + buf []byte + decBuf []byte + dupChecks [256]*dupCheck +} + +func newConnReader( + conn udpReader, + ifWriter ifWriter, + sender encryptedPacketSender, + super controlMsgHandler, + localIP byte, + routes [256]*atomic.Pointer[peerRoute], +) *connReader { + return &connReader{ + conn: conn, + iface: ifWriter, + sender: sender, + super: super, + localIP: localIP, + routes: routes, + buf: make([]byte, bufferSize), + decBuf: make([]byte, bufferSize), + dupChecks: func() (out [256]*dupCheck) { + for i := range out { + out[i] = newDupCheck(0) + } + return + }(), + } +} + +func (r *connReader) Run() { + for { + r.handleNextPacket() + } +} + +func (r *connReader) logf(s string, args ...any) { + log.Printf("[ConnReader] "+s, args...) +} + +func (r *connReader) handleNextPacket() { + buf := r.buf[:bufferSize] + n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) + } + + if n < headerSize { + return + } + + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) + + buf = buf[:n] + h, ok := parseHeader(buf) + if !ok { + return + } + + route := r.routes[h.SourceIP].Load() + + switch h.StreamID { + case controlStreamID: + r.handleControlPacket(route, remoteAddr, h, buf) + + case dataStreamID: + r.handleDataPacket(route, h, buf) + + default: + r.logf("Unknown stream ID: %d", h.StreamID) + } +} + +func (r *connReader) handleControlPacket( + route *peerRoute, + addr netip.AddrPort, + h header, + enc []byte, +) { + if route.ControlCipher == nil { + return + } + + if h.DestIP != r.localIP { + r.logf("Incorrect destination IP on control packet: %d", h.DestIP) + return + } + + msg, err := decryptControlPacket(route, addr, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt control packet: %v", err) + return + } + + r.super.HandleControlMsg(msg) +} + +func (r *connReader) handleDataPacket(route *peerRoute, h header, enc []byte) { + if !route.Up { + r.logf("Not connected (recv).") + return + } + + data, err := decryptDataPacket(route, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt data packet: %v", err) + return + } + + if h.DestIP == r.localIP { + if _, err := r.iface.Write(data); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } + return + } + + destRoute := r.routes[h.DestIP].Load() + if !destRoute.Up { + r.logf("Not connected (relay): %d", destRoute.IP) + return + } + + r.sender.SendEncryptedDataPacket(data, destRoute) +} diff --git a/peer/connreader_test.go b/peer/connreader_test.go new file mode 100644 index 0000000..7ef4ad8 --- /dev/null +++ b/peer/connreader_test.go @@ -0,0 +1,318 @@ +package peer + +/* +type mockIfWriter struct { + Written [][]byte +} + +func (w *mockIfWriter) Write(b []byte) (int, error) { + w.Written = append(w.Written, bytes.Clone(b)) + return len(b), nil +} + +type mockEncryptedPacket struct { + Packet []byte + Route *peerRoute +} + +type mockEncryptedPacketSender struct { + Sent []mockEncryptedPacket +} + +func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { + m.Sent = append(m.Sent, mockEncryptedPacket{ + Packet: bytes.Clone(pkt), + Route: route, + }) +} + +type mockControlMsgHandler struct { + Messages []any +} + +func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { + m.Messages = append(m.Messages, pkt) +} + +type udpPipe struct { + packets chan []byte +} + +func newUDPPipe() *udpPipe { + return &udpPipe{make(chan []byte, 1024)} +} + +func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + p.packets <- bytes.Clone(b) + return len(b), nil +} + +func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + packet := <-p.packets + copy(b, packet) + return len(packet), netip.AddrPort{}, nil +} + +type connReaderTestHarness struct { + Pipe *udpPipe + R *connReader + WRemote *connWriter + WRelayRemote *connWriter + Remote *peerRoute + RelayRemote *peerRoute + IFace *mockIfWriter + Sender *mockEncryptedPacketSender + Super *mockControlMsgHandler +} + +// Peer 2 is indirect, peer 3 is direct. +func newConnReadeTestHarness() (h connReaderTestHarness) { + pipe := newUDPPipe() + routes := [256]*atomic.Pointer[peerRoute]{} + for i := range routes { + routes[i] = &atomic.Pointer[peerRoute]{} + routes[i].Store(&peerRoute{}) + } + + local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() + routes[2].Store(local) + routes[3].Store(relayLocal) + + h.Pipe = pipe + h.WRemote = newConnWriter(pipe, 2) + h.WRelayRemote = newConnWriter(pipe, 3) + + h.Remote = remote + h.RelayRemote = relayRemote + h.IFace = &mockIfWriter{} + h.Sender = &mockEncryptedPacketSender{} + h.Super = &mockControlMsgHandler{} + h.R = newConnReader( + pipe, + h.IFace, + h.Sender, + h.Super, + 1, + routes) + return h +} + +// Testing that we can receive a control packet. +func TestConnReader_handleControlPacket(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + h.WRemote.SendControlPacket(pkt, h.Remote) + + h.R.handleNextPacket() + + if len(h.Super.Messages) != 1 { + t.Fatal(h.Super.Messages) + } + + msg := h.Super.Messages[0].(controlMsg[synPacket]) + if !reflect.DeepEqual(pkt, msg.Packet) { + t.Fatal(msg.Packet) + } +} + +// Testing that a short packet is ignored. +func TestConnReader_handleNextPacket_short(t *testing.T) { + h := newConnReadeTestHarness() + + h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) + h.R.handleNextPacket() + + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that a packet with an unexpected stream ID is ignored. +func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + var header header + header.Parse(encrypted) + header.StreamID = 100 + header.Marshal(encrypted) + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that control packet without matching control cipher is ignored. +func TestConnReader_handleControlPacket_noCipher(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + var header header + header.Parse(encrypted) + header.SourceIP = 10 + header.Marshal(encrypted) + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that control packet with incrrect destination IP is ignored. +func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + var header header + header.Parse(encrypted) + header.DestIP++ + header.Marshal(encrypted) + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that modified control packet is ignored. +func TestConnReader_handleControlPacket_modified(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted[len(encrypted)-1]++ + + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +type emptyPacket struct{} + +func (p emptyPacket) Marshal(buf []byte) []byte { + return buf[:0] +} + +// Testing that an empty control packet is ignored. +func TestConnReader_handleControlPacket_empty(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := emptyPacket{} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that a duplicate control packet is ignored. +func TestConnReader_handleControlPacket_duplicate(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := synPacket{TraceID: 1234} + + log.Printf("%d", h.WRemote.counters[1]) + h.WRemote.SendControlPacket(pkt, h.Remote) + log.Printf("%d", h.WRemote.counters[1]) + + // Rewind the counter. + h.WRemote.counters[1] = h.WRemote.counters[1] - 1 + log.Printf("%d", h.WRemote.counters[1]) + h.WRemote.SendControlPacket(pkt, h.Remote) + + h.R.handleNextPacket() + h.R.handleNextPacket() + + if len(h.Super.Messages) != 1 { + t.Fatal(h.Super.Messages) + } + + msg := h.Super.Messages[0].(controlMsg[synPacket]) + if !reflect.DeepEqual(pkt, msg.Packet) { + t.Fatal(msg.Packet) + } +} + +type invalidPacket struct { +} + +func (p invalidPacket) Marshal(b []byte) []byte { + out := b[:256] + clear(out) + return out +} + +// Testing that an invalid control packet is ignored (fails to parse). +func TestConnReader_handleControlPacket_cantParse(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := invalidPacket{} + + encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + h.WRemote.writeTo(encrypted, netip.AddrPort{}) + h.R.handleNextPacket() + if len(h.Super.Messages) != 0 { + t.Fatal(h.Super.Messages) + } +} + +// Testing that we can receive a data packet. +func TestConnReader_handleDataPacket(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := make([]byte, 1024) + rand.Read(pkt) + + h.WRemote.SendDataPacket(pkt, h.Remote) + + h.R.handleNextPacket() + + if len(h.IFace.Written) != 1 { + t.Fatal(h.IFace.Written) + } + + if !bytes.Equal(pkt, h.IFace.Written[0]) { + t.Fatal(h.IFace.Written) + } +} + +// Testing that data packet is ignored if route isn't up. +func TestConnReader_handleDataPacket_routeDown(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := make([]byte, 1024) + rand.Read(pkt) + + h.WRemote.SendDataPacket(pkt, h.Remote) + route := h.R.routes[2].Load() + route.Up = false + + h.R.handleNextPacket() + + if len(h.IFace.Written) != 0 { + t.Fatal(h.IFace.Written) + } +} +*/ +// Testing that a duplicate data packet is ignored. + +// Testing that we send a relayed data packet. + +// Testing that a relayed data packet is ignored if destination isn't up. diff --git a/peer/connwriter.go b/peer/connwriter.go new file mode 100644 index 0000000..928b2a0 --- /dev/null +++ b/peer/connwriter.go @@ -0,0 +1,80 @@ +package peer + +import ( + "log" + "net/netip" + "sync" +) + +// ---------------------------------------------------------------------------- + +type connWriter struct { + localIP byte + conn udpWriter + + // For sending control packets. + cBuf1 []byte + cBuf2 []byte + + // For sending data packets. + dBuf1 []byte + dBuf2 []byte + + // Lock around for sending on UDP Conn. + wLock sync.Mutex +} + +func newConnWriter(conn udpWriter, localIP byte) *connWriter { + w := &connWriter{ + localIP: localIP, + conn: conn, + cBuf1: make([]byte, bufferSize), + cBuf2: make([]byte, bufferSize), + dBuf1: make([]byte, bufferSize), + dBuf2: make([]byte, bufferSize), + } + return w +} + +// Not safe for concurrent use. Should only be called by supervisor. +func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { + enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) + w.writeTo(enc, route.RemoteAddr) +} + +// Relay control packet. Route must not be nil. +func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { + enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) + enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.cBuf1) + w.writeTo(enc, relay.RemoteAddr) +} + +// Not safe for concurrent use. Should only be called by ifReader. +func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { + enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) + w.writeTo(enc, route.RemoteAddr) +} + +// Relay a data packet. Route must not be nil. +func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { + enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) + enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.dBuf2) + w.writeTo(enc, relay.RemoteAddr) +} + +// Safe for concurrent use. Should only be called by connReader. +// +// This function will send pkt to the peer directly. This is used when a peer +// is acting as a relay and is forwarding already encrypted data for another +// peer. +func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { + w.writeTo(pkt, route.RemoteAddr) +} + +func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { + w.wLock.Lock() + if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { + log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) + } + w.wLock.Unlock() +} diff --git a/peer/connwriter_test.go b/peer/connwriter_test.go new file mode 100644 index 0000000..14f128e --- /dev/null +++ b/peer/connwriter_test.go @@ -0,0 +1,240 @@ +package peer + +import ( + "bytes" + "net/netip" + "testing" +) + +// ---------------------------------------------------------------------------- + +type testUDPPacket struct { + Addr netip.AddrPort + Data []byte +} + +type testUDPAddrPortWriter struct { + written []testUDPPacket +} + +func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + w.written = append(w.written, testUDPPacket{ + Addr: addr, + Data: bytes.Clone(b), + }) + return len(b), nil +} + +func (w *testUDPAddrPortWriter) Written() []testUDPPacket { + out := w.written + w.written = []testUDPPacket{} + return out +} + +// ---------------------------------------------------------------------------- + +type testPacket string + +func (p testPacket) Marshal(b []byte) []byte { + b = b[:len(p)] + copy(b, []byte(p)) + return b +} + +// ---------------------------------------------------------------------------- + +func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) { + localKeys := generateKeys() + remoteKeys := generateKeys() + + local = newPeerRoute(2) + local.Up = true + local.Relay = false + local.PubSignKey = remoteKeys.PubSignKey + local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) + local.DataCipher = newDataCipher() + local.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) + + remote = newPeerRoute(1) + remote.Up = true + remote.Relay = false + remote.PubSignKey = localKeys.PubSignKey + remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) + remote.DataCipher = local.DataCipher + remote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) + + rLocalKeys := generateKeys() + rRemoteKeys := generateKeys() + + relayLocal = newPeerRoute(3) + relayLocal.Up = true + relayLocal.Relay = true + relayLocal.Direct = true + relayLocal.PubSignKey = rRemoteKeys.PubSignKey + relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) + relayLocal.DataCipher = newDataCipher() + relayLocal.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) + + relayRemote = newPeerRoute(1) + relayRemote.Up = true + relayRemote.Relay = false + relayRemote.Direct = true + relayRemote.PubSignKey = rLocalKeys.PubSignKey + relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) + relayRemote.DataCipher = relayLocal.DataCipher + relayRemote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) + + return +} + +// ---------------------------------------------------------------------------- + +// Testing if we can send a control packet directly to the remote route. +func TestConnWriter_SendControlPacket_direct(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + route.Direct = true + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.SendControlPacket(in, route) + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + if string(dec) != string(in) { + t.Fatal(dec) + } +} + +// Testing if we can relay a packet via an intermediary. +func TestConnWriter_RelayControlPacket_relay(t *testing.T) { + route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := testPacket("hello world!") + + w.RelayControlPacket(in, route, relay) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != relay.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if string(dec2) != string(in) { + t.Fatal(dec2) + } +} + +// Testing that we can send a data packet directly to a remote route. +func TestConnWriter_SendDataPacket_direct(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + route.Direct = true + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + + in := []byte("hello world!") + w.SendDataPacket(in, route) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(dec, in) { + t.Fatal(dec) + } +} + +// Testing that we can relay a data packet via a relay. +func TestConnWriter_RelayDataPacket_relay(t *testing.T) { + route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.RelayDataPacket(in, route, relay) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != relay.RemoteAddr { + t.Fatal(out[0]) + } + + dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(dec2, in) { + t.Fatal(dec2) + } +} + +// Testing that we can send an already encrypted packet. +func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { + route, rRoute, _, _ := testConnWriter_getTestRoutes() + + writer := &testUDPAddrPortWriter{} + w := newConnWriter(writer, rRoute.IP) + in := []byte("Hello world!") + + w.SendEncryptedDataPacket(in, route) + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + if out[0].Addr != route.RemoteAddr { + t.Fatal(out[0]) + } + + if !bytes.Equal(out[0].Data, in) { + t.Fatal(out[0]) + } +} diff --git a/peer/controlmessage.go b/peer/controlmessage.go new file mode 100644 index 0000000..d8e9a17 --- /dev/null +++ b/peer/controlmessage.go @@ -0,0 +1,58 @@ +package peer + +import ( + "net/netip" + "vppn/m" +) + +// ---------------------------------------------------------------------------- + +type controlMsg[T any] struct { + SrcIP byte + SrcAddr netip.AddrPort + // TODO: RecvdAt int64 // Unixmilli. + Packet T +} + +func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { + switch buf[0] { + + case packetTypeSyn: + packet, err := parseSynPacket(buf) + return controlMsg[synPacket]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + + case packetTypeAck: + packet, err := parseAckPacket(buf) + return controlMsg[ackPacket]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + + case packetTypeProbe: + packet, err := parseProbePacket(buf) + return controlMsg[probePacket]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + + default: + return nil, errUnknownPacketType + } +} + +// ---------------------------------------------------------------------------- + +type peerUpdateMsg struct { + PeerIP byte + Peer *m.Peer +} + +// ---------------------------------------------------------------------------- + +type pingTimerMsg struct{} diff --git a/peer/crypto.go b/peer/crypto.go new file mode 100644 index 0000000..f41c8bf --- /dev/null +++ b/peer/crypto.go @@ -0,0 +1,113 @@ +package peer + +import ( + "crypto/rand" + "log" + "net/netip" + "sync/atomic" + + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/sign" +) + +type cryptoKeys struct { + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func generateKeys() cryptoKeys { + pubKey, privKey, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate encryption keys: %v", err) + } + + pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate signing keys: %v", err) + } + + return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} +} + +// ---------------------------------------------------------------------------- + +// Route must have a ControlCipher. +func encryptControlPacket( + localIP byte, + route *peerRoute, + pkt marshaller, + tmp []byte, + out []byte, +) []byte { + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(route.Counter, 1), + SourceIP: localIP, + DestIP: route.IP, + } + tmp = pkt.Marshal(tmp) + return route.ControlCipher.Encrypt(h, tmp, out) +} + +// Returns a controlMsg[PacketType]. Route must have ControlCipher. +func decryptControlPacket( + route *peerRoute, + fromAddr netip.AddrPort, + h header, + encrypted []byte, + tmp []byte, +) (any, error) { + out, ok := route.ControlCipher.Decrypt(encrypted, tmp) + if !ok { + return nil, errDecryptionFailed + } + + if route.DupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + msg, err := parseControlMsg(h.SourceIP, fromAddr, out) + if err != nil { + return nil, err + } + + return msg, nil +} + +// ---------------------------------------------------------------------------- + +func encryptDataPacket( + localIP byte, + destIP byte, + route *peerRoute, + data []byte, + out []byte, +) []byte { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(route.Counter, 1), + SourceIP: localIP, + DestIP: destIP, + } + return route.DataCipher.Encrypt(h, data, out) +} + +func decryptDataPacket( + route *peerRoute, + h header, + encrypted []byte, + out []byte, +) ([]byte, error) { + dec, ok := route.DataCipher.Decrypt(encrypted, out) + if !ok { + return nil, errDecryptionFailed + } + + if route.DupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + return dec, nil +} diff --git a/peer/crypto_test.go b/peer/crypto_test.go new file mode 100644 index 0000000..29ee377 --- /dev/null +++ b/peer/crypto_test.go @@ -0,0 +1,213 @@ +package peer + +import ( + "bytes" + "crypto/rand" + "errors" + "net/netip" + "reflect" + "testing" +) + +func newRoutePairForTesting() (*peerRoute, *peerRoute) { + keys1 := generateKeys() + keys2 := generateKeys() + + r1 := newPeerRoute(1) + r1.PubSignKey = keys1.PubSignKey + r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) + r1.DataCipher = newDataCipher() + + r2 := newPeerRoute(2) + r2.PubSignKey = keys2.PubSignKey + r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) + r2.DataCipher = r1.DataCipher + + return r1, r2 +} + +func TestDecryptControlPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := synPacket{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if err != nil { + t.Fatal(err) + } + + msg, ok := iMsg.(controlMsg[synPacket]) + if !ok { + t.Fatal(ok) + } + + if !reflect.DeepEqual(msg.Packet, in) { + t.Fatal(msg) + } +} + +func TestDecryptControlPacket_decryptionFailed(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := synPacket{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + for i := range enc { + x := bytes.Clone(enc) + x[i]++ + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(i, err) + } + } +} + +func TestDecryptControlPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := synPacket{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { + t.Fatal(err) + } + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } +} + +func TestDecryptControlPacket_invalidPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := testPacket("hello!") + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errUnknownPacketType) { + t.Fatal(err) + } +} + +func TestDecryptDataPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, out) { + t.Fatal(data, out) + } +} + +func TestDecryptDataPacket_incorrectCipher(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + r1.DataCipher = newDataCipher() + _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(err) + } +} + +func TestDecryptDataPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) + h, ok := parseHeader(enc) + if !ok { + t.Fatal(h, ok) + } + + _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if err != nil { + t.Fatal(err) + } + + _, err = decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } +} diff --git a/peer/dupcheck.go b/peer/dupcheck.go new file mode 100644 index 0000000..09b5b11 --- /dev/null +++ b/peer/dupcheck.go @@ -0,0 +1,76 @@ +package peer + +type dupCheck struct { + bitSet + head int + tail int + headCounter uint64 + tailCounter uint64 // Also next expected counter value. +} + +func newDupCheck(headCounter uint64) *dupCheck { + return &dupCheck{ + headCounter: headCounter, + tailCounter: headCounter + 1, + tail: 1, + } +} + +func (dc *dupCheck) IsDup(counter uint64) bool { + + // Before head => it's late, say it's a dup. + if counter < dc.headCounter { + return true + } + + // It's within the counter bounds. + if counter < dc.tailCounter { + index := (int(counter-dc.headCounter) + dc.head) % bitSetSize + if dc.Get(index) { + return true + } + + dc.Set(index) + return false + } + + // It's more than 1 beyond the tail. + delta := counter - dc.tailCounter + + // Full clear. + if delta >= bitSetSize-1 { + dc.ClearAll() + dc.Set(0) + + dc.tail = 1 + dc.head = 2 + dc.tailCounter = counter + 1 + dc.headCounter = dc.tailCounter - bitSetSize + 1 + + return false + } + + // Clear if necessary. + for i := 0; i < int(delta); i++ { + dc.put(false) + } + + dc.put(true) + return false +} + +func (dc *dupCheck) put(set bool) { + if set { + dc.Set(dc.tail) + } else { + dc.Clear(dc.tail) + } + + dc.tail = (dc.tail + 1) % bitSetSize + dc.tailCounter++ + + if dc.head == dc.tail { + dc.head = (dc.head + 1) % bitSetSize + dc.headCounter++ + } +} diff --git a/peer/dupcheck_test.go b/peer/dupcheck_test.go new file mode 100644 index 0000000..2b50d74 --- /dev/null +++ b/peer/dupcheck_test.go @@ -0,0 +1,57 @@ +package peer + +import ( + "testing" +) + +func TestDupCheck(t *testing.T) { + dc := newDupCheck(0) + + for i := range bitSetSize { + if dc.IsDup(uint64(i)) { + t.Fatal("!") + } + } + + type TestCase struct { + Counter uint64 + Dup bool + } + + testCases := []TestCase{ + {511, true}, + {0, true}, + {1, true}, + {2, true}, + {3, true}, + {63, true}, + {256, true}, + {510, true}, + {511, true}, + {512, false}, + {0, true}, + {512, true}, + {513, false}, + {517, false}, + {512, true}, + {513, true}, + {514, false}, + {515, false}, + {516, false}, + {517, true}, + {2512, false}, + {2512, true}, + {2001, true}, + {2002, false}, + {2002, true}, + {4000, false}, + {4000 - 511, true}, // Too old. + {4000 - 510, false}, // Just in the window. + } + + for i, tc := range testCases { + if ok := dc.IsDup(tc.Counter); ok != tc.Dup { + t.Fatal(i, ok, tc) + } + } +} diff --git a/peer/errors.go b/peer/errors.go new file mode 100644 index 0000000..b1e07e2 --- /dev/null +++ b/peer/errors.go @@ -0,0 +1,10 @@ +package peer + +import "errors" + +var ( + errDecryptionFailed = errors.New("decryption failed") + errDuplicateSeqNum = errors.New("duplicate sequence number") + errMalformedPacket = errors.New("malformed packet") + errUnknownPacketType = errors.New("unknown packet type") +) diff --git a/peer/globals.go b/peer/globals.go new file mode 100644 index 0000000..a4d8d65 --- /dev/null +++ b/peer/globals.go @@ -0,0 +1,19 @@ +package peer + +import ( + "net" + "net/netip" +) + +const ( + bufferSize = 1536 + if_mtu = 1200 + if_queue_len = 2048 + controlCipherOverhead = 16 + dataCipherOverhead = 16 + signOverhead = 64 +) + +var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( + netip.AddrFrom4([4]byte{224, 0, 0, 157}), + 4560)) diff --git a/peer/header.go b/peer/header.go new file mode 100644 index 0000000..08698dd --- /dev/null +++ b/peer/header.go @@ -0,0 +1,49 @@ +package peer + +import "unsafe" + +// ---------------------------------------------------------------------------- + +const ( + headerSize = 12 + controlStreamID = 2 + controlHeaderSize = 24 + dataStreamID = 1 + dataHeaderSize = 12 +) + +type header struct { + Version byte + StreamID byte + SourceIP byte + DestIP byte + Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. +} + +func parseHeader(b []byte) (h header, ok bool) { + if len(b) < headerSize { + return + } + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) + return h, true +} + +func (h *header) Parse(b []byte) { + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) +} + +func (h *header) Marshal(buf []byte) { + buf[0] = h.Version + buf[1] = h.StreamID + buf[2] = h.SourceIP + buf[3] = h.DestIP + *(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter +} diff --git a/peer/header_test.go b/peer/header_test.go new file mode 100644 index 0000000..11e2f8f --- /dev/null +++ b/peer/header_test.go @@ -0,0 +1,21 @@ +package peer + +import "testing" + +func TestHeaderMarshalParse(t *testing.T) { + nIn := header{ + StreamID: 23, + Counter: 3212, + SourceIP: 34, + DestIP: 200, + } + + buf := make([]byte, headerSize) + nIn.Marshal(buf) + + nOut := header{} + nOut.Parse(buf) + if nIn != nOut { + t.Fatal(nIn, nOut) + } +} diff --git a/peer/ifreader.go b/peer/ifreader.go new file mode 100644 index 0000000..61627a2 --- /dev/null +++ b/peer/ifreader.go @@ -0,0 +1,100 @@ +package peer + +import ( + "io" + "log" + "sync/atomic" +) + +type ifReader struct { + iface io.Reader + routes [256]*atomic.Pointer[peerRoute] + relay *atomic.Pointer[peerRoute] + sender dataPacketSender +} + +func newIFReader( + iface io.Reader, + routes [256]*atomic.Pointer[peerRoute], + relay *atomic.Pointer[peerRoute], + sender dataPacketSender, +) *ifReader { + return &ifReader{ + iface: iface, + routes: routes, + relay: relay, + sender: sender, + } +} + +func (r *ifReader) Run() { + var ( + packet = make([]byte, bufferSize) + remoteIP byte + ok bool + ) + + for { + packet = r.readNextPacket(packet) + if remoteIP, ok = r.parsePacket(packet); ok { + r.sendPacket(packet, remoteIP) + } + } +} + +func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { + route := r.routes[remoteIP].Load() + if !route.Up { + log.Printf("Route not connected: %d", remoteIP) + return + } + + // Direct path => early return. + if route.Direct { + r.sender.SendDataPacket(pkt, route) + return + } + + if relay := r.relay.Load(); relay != nil && relay.Up { + r.sender.RelayDataPacket(pkt, route, relay) + } +} + +// Get next packet, returning packet, and destination ip. +func (r *ifReader) readNextPacket(buf []byte) []byte { + n, err := r.iface.Read(buf[:cap(buf)]) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + return buf[:n] +} + +func (r *ifReader) parsePacket(buf []byte) (byte, bool) { + n := len(buf) + if n == 0 { + return 0, false + } + + version := buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + log.Printf("Short IPv4 packet: %d", len(buf)) + return 0, false + } + return buf[19], true + + case 6: + if len(buf) < 40 { + log.Printf("Short IPv6 packet: %d", len(buf)) + return 0, false + } + return buf[39], true + + default: + log.Printf("Invalid IP packet version: %v", version) + return 0, false + } +} diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go new file mode 100644 index 0000000..c5efb30 --- /dev/null +++ b/peer/ifreader_test.go @@ -0,0 +1,232 @@ +package peer + +import ( + "bytes" + "net" + "reflect" + "sync/atomic" + "testing" +) + +// Test that we parse IPv4 packets correctly. +func TestIFReader_parsePacket_ipv4(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 1234) + pkt[0] = 4 << 4 + pkt[19] = 128 + + if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { + t.Fatal(ip, ok) + } +} + +// Test that we parse IPv6 packets correctly. +func TestIFReader_parsePacket_ipv6(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 1234) + pkt[0] = 6 << 4 + pkt[39] = 42 + + if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { + t.Fatal(ip, ok) + } +} + +// Test that empty packets work as expected. +func TestIFReader_parsePacket_emptyPacket(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 0) + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that invalid IP versions fail. +func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + for i := byte(1); i < 16; i++ { + if i == 4 || i == 6 { + continue + } + pkt := make([]byte, 1234) + pkt[0] = i << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(i, ip, ok) + } + } +} + +// Test that short IPv4 packets fail. +func TestIFReader_parsePacket_shortIPv4(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 19) + pkt[0] = 4 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that short IPv6 packets fail. +func TestIFReader_parsePacket_shortIPv6(t *testing.T) { + r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + + pkt := make([]byte, 39) + pkt[0] = 6 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +// Test that we can read a packet. +func TestIFReader_readNextpacket(t *testing.T) { + in, out := net.Pipe() + r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + defer in.Close() + defer out.Close() + + go in.Write([]byte("hello world!")) + + pkt := r.readNextPacket(make([]byte, bufferSize)) + if !bytes.Equal(pkt, []byte("hello world!")) { + t.Fatalf("%s", pkt) + } +} + +// ---------------------------------------------------------------------------- + +type sentPacket struct { + Relayed bool + Packet []byte + Route peerRoute + Relay peerRoute +} + +type sendPacketTestHarness struct { + Packets []sentPacket +} + +func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *peerRoute) { + h.Packets = append(h.Packets, sentPacket{ + Packet: bytes.Clone(pkt), + Route: *route, + }) +} + +func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *peerRoute) { + h.Packets = append(h.Packets, sentPacket{ + Relayed: true, + Packet: bytes.Clone(pkt), + Route: *route, + Relay: *relay, + }) +} + +func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { + h := &sendPacketTestHarness{} + + routes := [256]*atomic.Pointer[peerRoute]{} + for i := range routes { + routes[i] = &atomic.Pointer[peerRoute]{} + routes[i].Store(&peerRoute{}) + } + relay := &atomic.Pointer[peerRoute]{} + r := newIFReader(nil, routes, relay, h) + return r, h +} + +// Testing that we can send a packet directly. +func TestIFReader_sendPacket_direct(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Up = true + route.Direct = true + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 1 { + t.Fatal(h.Packets) + } + + expected := sentPacket{ + Relayed: false, + Packet: in, + Route: *route, + } + + if !reflect.DeepEqual(h.Packets[0], expected) { + t.Fatal(h.Packets[0]) + } +} + +// Testing that we don't send a packet if route isn't up. +func TestIFReader_sendPacket_directNotUp(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Direct = true + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 0 { + t.Fatal(h.Packets) + } +} + +// Testing that we can send a packet via a relay. +func TestIFReader_sendPacket_relayed(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Up = true + route.Direct = false + + relay := r.routes[3].Load() + r.relay.Store(relay) + relay.Up = true + relay.Direct = true + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 1 { + t.Fatal(h.Packets) + } + + expected := sentPacket{ + Relayed: true, + Packet: in, + Route: *route, + Relay: *relay, + } + + if !reflect.DeepEqual(h.Packets[0], expected) { + t.Fatal(h.Packets[0]) + } +} + +// Testing that we don't try to send on a nil relay IP. +func TestIFReader_sendPacket_nilRealy(t *testing.T) { + r, h := newIFReaderForSendPacketTesting() + + route := r.routes[2].Load() + route.Up = true + route.Direct = false + + in := []byte("hello world") + + r.sendPacket(in, 2) + if len(h.Packets) != 0 { + t.Fatal(h.Packets) + } +} diff --git a/peer/ifwriter.go b/peer/ifwriter.go new file mode 100644 index 0000000..59e2e26 --- /dev/null +++ b/peer/ifwriter.go @@ -0,0 +1,5 @@ +package peer + +import "io" + +type ifWriter io.Writer diff --git a/peer/interfaces.go b/peer/interfaces.go new file mode 100644 index 0000000..84f9c99 --- /dev/null +++ b/peer/interfaces.go @@ -0,0 +1,28 @@ +package peer + +import "net/netip" + +type udpReader interface { + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) +} + +type udpWriter interface { + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) +} + +type marshaller interface { + Marshal([]byte) []byte +} + +type dataPacketSender interface { + SendDataPacket(pkt []byte, route *peerRoute) + RelayDataPacket(pkt []byte, route, relay *peerRoute) +} + +type encryptedPacketSender interface { + SendEncryptedDataPacket(pkt []byte, route *peerRoute) +} + +type controlMsgHandler interface { + HandleControlMsg(pkt any) +} diff --git a/peer/mcwriter.go b/peer/mcwriter.go new file mode 100644 index 0000000..db9a76b --- /dev/null +++ b/peer/mcwriter.go @@ -0,0 +1,62 @@ +package peer + +import ( + "log" + "net" + + "golang.org/x/crypto/nacl/sign" +) + +// ---------------------------------------------------------------------------- + +type mcUDPWriter interface { + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +// ---------------------------------------------------------------------------- + +func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { + h := header{ + SourceIP: localIP, + DestIP: 255, + } + buf := make([]byte, headerSize) + h.Marshal(buf) + out := make([]byte, headerSize+signOverhead) + return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) +} + +func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) { + if len(pkt) != headerSize+signOverhead { + return + } + + h.Parse(pkt[signOverhead:]) + ok = true + return +} + +func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { + _, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) + return ok +} + +// ---------------------------------------------------------------------------- + +type mcWriter struct { + conn mcUDPWriter + discoveryPacket []byte +} + +func newMCWriter(conn mcUDPWriter, localIP byte, signingKey []byte) *mcWriter { + return &mcWriter{ + conn: conn, + discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), + } +} + +func (w *mcWriter) SendLocalDiscovery() { + if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { + log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) + } +} diff --git a/peer/mcwriter_test.go b/peer/mcwriter_test.go new file mode 100644 index 0000000..ffef05d --- /dev/null +++ b/peer/mcwriter_test.go @@ -0,0 +1,102 @@ +package peer + +import ( + "bytes" + "net" + "testing" +) + +// ---------------------------------------------------------------------------- + +// Testing that we can create and verify a local discovery packet. +func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + header, ok := headerFromLocalDiscoveryPacket(created) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 55 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Not valid") + } +} + +// Testing that we don't try to parse short packets. +func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + _, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) + if ok { + t.Fatal(ok) + } +} + +// Testing that modifying a packet makes it invalid. +func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + buf := make([]byte, 1024) + for i := range created { + modified := bytes.Clone(created) + modified[i]++ + if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { + t.Fatal("Verification should have failed.") + } + } +} + +// ---------------------------------------------------------------------------- + +type testUDPWriter struct { + written [][]byte +} + +func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + w.written = append(w.written, bytes.Clone(b)) + return len(b), nil +} + +func (w *testUDPWriter) Written() [][]byte { + out := w.written + w.written = [][]byte{} + return out +} + +// ---------------------------------------------------------------------------- + +// Testing that the mcWriter sends local discovery packets as expected. +func TestMCWriter_SendLocalDiscovery(t *testing.T) { + keys := generateKeys() + writer := &testUDPWriter{} + + mcw := newMCWriter(writer, 42, keys.PrivSignKey) + mcw.SendLocalDiscovery() + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + pkt := out[0] + + header, ok := headerFromLocalDiscoveryPacket(pkt) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 42 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Verification should succeed.") + } +} diff --git a/peer/packets-util.go b/peer/packets-util.go new file mode 100644 index 0000000..bda33b9 --- /dev/null +++ b/peer/packets-util.go @@ -0,0 +1,190 @@ +package peer + +import ( + "net/netip" + "sync/atomic" + "time" + "unsafe" +) + +var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1 + +func newTraceID() uint64 { + return atomic.AddUint64(&traceIDCounter, 1) +} + +// ---------------------------------------------------------------------------- + +type binWriter struct { + b []byte + i int +} + +func newBinWriter(buf []byte) *binWriter { + buf = buf[:cap(buf)] + return &binWriter{buf, 0} +} + +func (w *binWriter) Bool(b bool) *binWriter { + if b { + return w.Byte(1) + } + return w.Byte(0) +} + +func (w *binWriter) Byte(b byte) *binWriter { + w.b[w.i] = b + w.i++ + return w +} + +func (w *binWriter) SharedKey(key [32]byte) *binWriter { + copy(w.b[w.i:w.i+32], key[:]) + w.i += 32 + return w +} + +func (w *binWriter) Uint16(x uint16) *binWriter { + *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 2 + return w +} + +func (w *binWriter) Uint64(x uint64) *binWriter { + *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) Int64(x int64) *binWriter { + *(*int64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { + w.Bool(addrPort.IsValid()) + addr := addrPort.Addr().As16() + copy(w.b[w.i:w.i+16], addr[:]) + w.i += 16 + return w.Uint16(addrPort.Port()) +} + +func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { + for _, addrPort := range l { + w.AddrPort(addrPort) + } + return w +} + +func (w *binWriter) Build() []byte { + return w.b[:w.i] +} + +// ---------------------------------------------------------------------------- + +type binReader struct { + b []byte + i int + err error +} + +func newBinReader(buf []byte) *binReader { + return &binReader{b: buf} +} + +func (r *binReader) hasBytes(n int) bool { + if r.err != nil || (len(r.b)-r.i) < n { + r.err = errMalformedPacket + return false + } + return true +} + +func (r *binReader) Bool(b *bool) *binReader { + var bb byte + r.Byte(&bb) + *b = bb != 0 + return r +} + +func (r *binReader) Byte(b *byte) *binReader { + if !r.hasBytes(1) { + return r + } + *b = r.b[r.i] + r.i++ + return r +} + +func (r *binReader) SharedKey(x *[32]byte) *binReader { + if !r.hasBytes(32) { + return r + } + *x = ([32]byte)(r.b[r.i : r.i+32]) + r.i += 32 + return r +} + +func (r *binReader) Uint16(x *uint16) *binReader { + if !r.hasBytes(2) { + return r + } + *x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) + r.i += 2 + return r +} + +func (r *binReader) Uint64(x *uint64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) Int64(x *int64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*int64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { + if !r.hasBytes(19) { + return r + } + + var ( + valid bool + port uint16 + ) + + r.Bool(&valid) + addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() + r.i += 16 + + r.Uint16(&port) + + if valid { + *x = netip.AddrPortFrom(addr, port) + } else { + *x = netip.AddrPort{} + } + + return r +} + +func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { + for i := range x { + r.AddrPort(&x[i]) + } + return r +} + +func (r *binReader) Error() error { + return r.err +} diff --git a/peer/packets-util_test.go b/peer/packets-util_test.go new file mode 100644 index 0000000..5a518d7 --- /dev/null +++ b/peer/packets-util_test.go @@ -0,0 +1,56 @@ +package peer + +import ( + "net/netip" + "reflect" + "testing" +) + +func TestBinWriteRead(t *testing.T) { + buf := make([]byte, 1024) + + type Item struct { + Type byte + TraceID uint64 + Addrs [8]netip.AddrPort + DestAddr netip.AddrPort + } + + in := Item{ + 1, + 2, + [8]netip.AddrPort{}, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), + } + + in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) + in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) + in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) + in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) + in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) + in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) + in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) + + buf = newBinWriter(buf). + Byte(in.Type). + Uint64(in.TraceID). + AddrPort(in.DestAddr). + AddrPortArray(in.Addrs). + Build() + + out := Item{} + + err := newBinReader(buf). + Byte(&out.Type). + Uint64(&out.TraceID). + AddrPort(&out.DestAddr). + AddrPortArray(&out.Addrs). + Error() + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal(in, out) + } +} diff --git a/peer/packets.go b/peer/packets.go new file mode 100644 index 0000000..f7f1f85 --- /dev/null +++ b/peer/packets.go @@ -0,0 +1,123 @@ +package peer + +import ( + "net/netip" +) + +const ( + packetTypeSyn = iota + 1 + packetTypeSynAck + packetTypeAck + packetTypeProbe + packetTypeAddrDiscovery +) + +// ---------------------------------------------------------------------------- + +type synPacket struct { + TraceID uint64 // TraceID to match response w/ request. + // TODO: SentAt int64 // Unixmilli. + SharedKey [32]byte // Our shared key. + Direct bool + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. +} + +func (p synPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSyn). + Uint64(p.TraceID). + SharedKey(p.SharedKey). + Bool(p.Direct). + AddrPort(p.PossibleAddrs[0]). + AddrPort(p.PossibleAddrs[1]). + AddrPort(p.PossibleAddrs[2]). + AddrPort(p.PossibleAddrs[3]). + AddrPort(p.PossibleAddrs[4]). + AddrPort(p.PossibleAddrs[5]). + AddrPort(p.PossibleAddrs[6]). + AddrPort(p.PossibleAddrs[7]). + Build() +} + +func parseSynPacket(buf []byte) (p synPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + SharedKey(&p.SharedKey). + Bool(&p.Direct). + AddrPort(&p.PossibleAddrs[0]). + AddrPort(&p.PossibleAddrs[1]). + AddrPort(&p.PossibleAddrs[2]). + AddrPort(&p.PossibleAddrs[3]). + AddrPort(&p.PossibleAddrs[4]). + AddrPort(&p.PossibleAddrs[5]). + AddrPort(&p.PossibleAddrs[6]). + AddrPort(&p.PossibleAddrs[7]). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type ackPacket struct { + TraceID uint64 + ToAddr netip.AddrPort + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. +} + +func (p ackPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeAck). + Uint64(p.TraceID). + AddrPort(p.ToAddr). + AddrPort(p.PossibleAddrs[0]). + AddrPort(p.PossibleAddrs[1]). + AddrPort(p.PossibleAddrs[2]). + AddrPort(p.PossibleAddrs[3]). + AddrPort(p.PossibleAddrs[4]). + AddrPort(p.PossibleAddrs[5]). + AddrPort(p.PossibleAddrs[6]). + AddrPort(p.PossibleAddrs[7]). + Build() +} + +func parseAckPacket(buf []byte) (p ackPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + AddrPort(&p.ToAddr). + AddrPort(&p.PossibleAddrs[0]). + AddrPort(&p.PossibleAddrs[1]). + AddrPort(&p.PossibleAddrs[2]). + AddrPort(&p.PossibleAddrs[3]). + AddrPort(&p.PossibleAddrs[4]). + AddrPort(&p.PossibleAddrs[5]). + AddrPort(&p.PossibleAddrs[6]). + AddrPort(&p.PossibleAddrs[7]). + Error() + return +} + +// ---------------------------------------------------------------------------- + +// A probeReqPacket is sent from a client to a server to determine if direct +// UDP communication can be used. +type probePacket struct { + TraceID uint64 +} + +func (p probePacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeProbe). + Uint64(p.TraceID). + Build() +} + +func parseProbePacket(buf []byte) (p probePacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type localDiscoveryPacket struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go new file mode 100644 index 0000000..333deff --- /dev/null +++ b/peer/packets_test.go @@ -0,0 +1 @@ +package peer diff --git a/peer/state.go b/peer/state.go new file mode 100644 index 0000000..2ef248b --- /dev/null +++ b/peer/state.go @@ -0,0 +1,29 @@ +package peer + +import ( + "net/netip" + "time" +) + +type peerRoute struct { + IP byte // VPN IP of peer (last byte). + Up bool // True if data can be sent on the route. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + PubSignKey []byte + ControlCipher *controlCipher + DataCipher *dataCipher + RemoteAddr netip.AddrPort // Remote address if directly connected. + + Counter *uint64 // For sending to. Atomic access only. + DupCheck *dupCheck // For receiving from. Not safe for concurrent use. +} + +func newPeerRoute(ip byte) *peerRoute { + counter := uint64(time.Now().Unix()<<30 + 1) + return &peerRoute{ + IP: ip, + Counter: &counter, + DupCheck: newDupCheck(0), + } +} -- 2.39.5 From a730f95af223fa308d5bc4d46156bb6f280b832e Mon Sep 17 00:00:00 2001 From: jdl Date: Thu, 30 Jan 2025 20:01:20 +0100 Subject: [PATCH 03/26] WIP --- peer/connreader_test.go | 62 +++++++++++++++-------------------------- peer/crypto.go | 2 ++ peer/globals.go | 4 +++ 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/peer/connreader_test.go b/peer/connreader_test.go index 7ef4ad8..5d91547 100644 --- a/peer/connreader_test.go +++ b/peer/connreader_test.go @@ -1,6 +1,13 @@ package peer -/* +import ( + "bytes" + "net/netip" + "reflect" + "sync/atomic" + "testing" +) + type mockIfWriter struct { Written [][]byte } @@ -135,7 +142,7 @@ func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) var header header header.Parse(encrypted) header.StreamID = 100 @@ -154,7 +161,8 @@ func TestConnReader_handleControlPacket_noCipher(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + //encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) var header header header.Parse(encrypted) header.SourceIP = 10 @@ -173,7 +181,7 @@ func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) var header header header.Parse(encrypted) header.DestIP++ @@ -192,7 +200,7 @@ func TestConnReader_handleControlPacket_modified(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) encrypted[len(encrypted)-1]++ h.WRemote.writeTo(encrypted, netip.AddrPort{}) @@ -202,19 +210,21 @@ func TestConnReader_handleControlPacket_modified(t *testing.T) { } } -type emptyPacket struct{} +type unknownPacket struct{} -func (p emptyPacket) Marshal(buf []byte) []byte { - return buf[:0] +func (p unknownPacket) Marshal(buf []byte) []byte { + buf = buf[:1] + buf[0] = 100 + return buf } // Testing that an empty control packet is ignored. -func TestConnReader_handleControlPacket_empty(t *testing.T) { +func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { h := newConnReadeTestHarness() - pkt := emptyPacket{} + pkt := unknownPacket{} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) h.WRemote.writeTo(encrypted, netip.AddrPort{}) h.R.handleNextPacket() if len(h.Super.Messages) != 0 { @@ -228,13 +238,8 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { pkt := synPacket{TraceID: 1234} - log.Printf("%d", h.WRemote.counters[1]) h.WRemote.SendControlPacket(pkt, h.Remote) - log.Printf("%d", h.WRemote.counters[1]) - - // Rewind the counter. - h.WRemote.counters[1] = h.WRemote.counters[1] - 1 - log.Printf("%d", h.WRemote.counters[1]) + *h.Remote.Counter = *h.Remote.Counter - 1 h.WRemote.SendControlPacket(pkt, h.Remote) h.R.handleNextPacket() @@ -250,28 +255,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { } } -type invalidPacket struct { -} - -func (p invalidPacket) Marshal(b []byte) []byte { - out := b[:256] - clear(out) - return out -} - -// Testing that an invalid control packet is ignored (fails to parse). -func TestConnReader_handleControlPacket_cantParse(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := invalidPacket{} - - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} +/* // Testing that we can receive a data packet. func TestConnReader_handleDataPacket(t *testing.T) { diff --git a/peer/crypto.go b/peer/crypto.go index f41c8bf..3bc970f 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -52,6 +52,8 @@ func encryptControlPacket( } // Returns a controlMsg[PacketType]. Route must have ControlCipher. +// +// This function also drops packets with duplicate sequence numbers. func decryptControlPacket( route *peerRoute, fromAddr netip.AddrPort, diff --git a/peer/globals.go b/peer/globals.go index a4d8d65..4733ac8 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -17,3 +17,7 @@ const ( var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( netip.AddrFrom4([4]byte{224, 0, 0, 157}), 4560)) + +func newBuf() []byte { + return make([]byte, bufferSize) +} -- 2.39.5 From b2e63f6c033d681b3e0abef2af8fb1b146b1b7fd Mon Sep 17 00:00:00 2001 From: jdl Date: Thu, 30 Jan 2025 20:02:14 +0100 Subject: [PATCH 04/26] WIP --- peer/connreader_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/peer/connreader_test.go b/peer/connreader_test.go index 5d91547..fe8e6cb 100644 --- a/peer/connreader_test.go +++ b/peer/connreader_test.go @@ -236,7 +236,7 @@ func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { func TestConnReader_handleControlPacket_duplicate(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := ackPacket{TraceID: 1234} h.WRemote.SendControlPacket(pkt, h.Remote) *h.Remote.Counter = *h.Remote.Counter - 1 @@ -249,7 +249,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { t.Fatal(h.Super.Messages) } - msg := h.Super.Messages[0].(controlMsg[synPacket]) + msg := h.Super.Messages[0].(controlMsg[ackPacket]) if !reflect.DeepEqual(pkt, msg.Packet) { t.Fatal(msg.Packet) } -- 2.39.5 From 8c618616dd4fd8084da862a309d33def1198a3a4 Mon Sep 17 00:00:00 2001 From: jdl Date: Fri, 31 Jan 2025 21:32:47 +0100 Subject: [PATCH 05/26] wip --- peer/connreader.go | 48 ++++++++++--------------- peer/connreader_test.go | 77 ++++++++++++++++++++++++++++++++++------- peer/connwriter.go | 36 +++++++++---------- peer/connwriter_test.go | 28 +++++++-------- peer/crypto.go | 31 +++++++++-------- peer/crypto_test.go | 41 ++++++---------------- peer/header.go | 7 ++-- peer/ifreader.go | 22 ++++++------ peer/ifreader_test.go | 40 ++++++++++----------- peer/interfaces.go | 15 +++++--- peer/mcreader.go | 57 ++++++++++++++++++++++++++++++ peer/mcwriter.go | 9 ----- peer/state.go | 16 ++++----- 13 files changed, 250 insertions(+), 177 deletions(-) create mode 100644 peer/mcreader.go diff --git a/peer/connreader.go b/peer/connreader.go index 757b37c..b127030 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -12,11 +12,10 @@ type connReader struct { sender encryptedPacketSender super controlMsgHandler localIP byte - routes [256]*atomic.Pointer[peerRoute] + peers [256]*atomic.Pointer[RemotePeer] - buf []byte - decBuf []byte - dupChecks [256]*dupCheck + buf []byte + decBuf []byte } func newConnReader( @@ -25,7 +24,7 @@ func newConnReader( sender encryptedPacketSender, super controlMsgHandler, localIP byte, - routes [256]*atomic.Pointer[peerRoute], + peers [256]*atomic.Pointer[RemotePeer], ) *connReader { return &connReader{ conn: conn, @@ -33,15 +32,9 @@ func newConnReader( sender: sender, super: super, localIP: localIP, - routes: routes, + peers: peers, buf: make([]byte, bufferSize), decBuf: make([]byte, bufferSize), - dupChecks: func() (out [256]*dupCheck) { - for i := range out { - out[i] = newDupCheck(0) - } - return - }(), } } @@ -69,19 +62,16 @@ func (r *connReader) handleNextPacket() { remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) buf = buf[:n] - h, ok := parseHeader(buf) - if !ok { - return - } + h := parseHeader(buf) - route := r.routes[h.SourceIP].Load() + peer := r.peers[h.SourceIP].Load() switch h.StreamID { case controlStreamID: - r.handleControlPacket(route, remoteAddr, h, buf) + r.handleControlPacket(peer, remoteAddr, h, buf) case dataStreamID: - r.handleDataPacket(route, h, buf) + r.handleDataPacket(peer, h, buf) default: r.logf("Unknown stream ID: %d", h.StreamID) @@ -89,12 +79,12 @@ func (r *connReader) handleNextPacket() { } func (r *connReader) handleControlPacket( - route *peerRoute, + peer *RemotePeer, addr netip.AddrPort, h header, enc []byte, ) { - if route.ControlCipher == nil { + if peer.ControlCipher == nil { return } @@ -103,7 +93,7 @@ func (r *connReader) handleControlPacket( return } - msg, err := decryptControlPacket(route, addr, h, enc, r.decBuf) + msg, err := decryptControlPacket(peer, addr, h, enc, r.decBuf) if err != nil { r.logf("Failed to decrypt control packet: %v", err) return @@ -112,13 +102,13 @@ func (r *connReader) handleControlPacket( r.super.HandleControlMsg(msg) } -func (r *connReader) handleDataPacket(route *peerRoute, h header, enc []byte) { - if !route.Up { +func (r *connReader) handleDataPacket(peer *RemotePeer, h header, enc []byte) { + if !peer.Up { r.logf("Not connected (recv).") return } - data, err := decryptDataPacket(route, h, enc, r.decBuf) + data, err := decryptDataPacket(peer, h, enc, r.decBuf) if err != nil { r.logf("Failed to decrypt data packet: %v", err) return @@ -131,11 +121,11 @@ func (r *connReader) handleDataPacket(route *peerRoute, h header, enc []byte) { return } - destRoute := r.routes[h.DestIP].Load() - if !destRoute.Up { - r.logf("Not connected (relay): %d", destRoute.IP) + destPeer := r.peers[h.DestIP].Load() + if !destPeer.Up { + r.logf("Not connected (relay): %d", destPeer.IP) return } - r.sender.SendEncryptedDataPacket(data, destRoute) + r.sender.SendEncryptedDataPacket(data, destPeer) } diff --git a/peer/connreader_test.go b/peer/connreader_test.go index fe8e6cb..39da83c 100644 --- a/peer/connreader_test.go +++ b/peer/connreader_test.go @@ -2,6 +2,7 @@ package peer import ( "bytes" + "crypto/rand" "net/netip" "reflect" "sync/atomic" @@ -19,14 +20,14 @@ func (w *mockIfWriter) Write(b []byte) (int, error) { type mockEncryptedPacket struct { Packet []byte - Route *peerRoute + Route *RemotePeer } type mockEncryptedPacketSender struct { Sent []mockEncryptedPacket } -func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { +func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *RemotePeer) { m.Sent = append(m.Sent, mockEncryptedPacket{ Packet: bytes.Clone(pkt), Route: route, @@ -65,8 +66,8 @@ type connReaderTestHarness struct { R *connReader WRemote *connWriter WRelayRemote *connWriter - Remote *peerRoute - RelayRemote *peerRoute + Remote *RemotePeer + RelayRemote *RemotePeer IFace *mockIfWriter Sender *mockEncryptedPacketSender Super *mockControlMsgHandler @@ -75,10 +76,10 @@ type connReaderTestHarness struct { // Peer 2 is indirect, peer 3 is direct. func newConnReadeTestHarness() (h connReaderTestHarness) { pipe := newUDPPipe() - routes := [256]*atomic.Pointer[peerRoute]{} + routes := [256]*atomic.Pointer[RemotePeer]{} for i := range routes { - routes[i] = &atomic.Pointer[peerRoute]{} - routes[i].Store(&peerRoute{}) + routes[i] = &atomic.Pointer[RemotePeer]{} + routes[i].Store(&RemotePeer{}) } local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() @@ -255,8 +256,6 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { } } -/* - // Testing that we can receive a data packet. func TestConnReader_handleDataPacket(t *testing.T) { h := newConnReadeTestHarness() @@ -285,7 +284,7 @@ func TestConnReader_handleDataPacket_routeDown(t *testing.T) { rand.Read(pkt) h.WRemote.SendDataPacket(pkt, h.Remote) - route := h.R.routes[2].Load() + route := h.R.peers[2].Load() route.Up = false h.R.handleNextPacket() @@ -294,9 +293,61 @@ func TestConnReader_handleDataPacket_routeDown(t *testing.T) { t.Fatal(h.IFace.Written) } } -*/ + // Testing that a duplicate data packet is ignored. +func TestConnReader_handleDataPacket_duplicate(t *testing.T) { + h := newConnReadeTestHarness() -// Testing that we send a relayed data packet. + pkt := make([]byte, 123) -// Testing that a relayed data packet is ignored if destination isn't up. + h.WRemote.SendDataPacket(pkt, h.Remote) + *h.Remote.Counter = *h.Remote.Counter - 1 + h.WRemote.SendDataPacket(pkt, h.Remote) + + h.R.handleNextPacket() + h.R.handleNextPacket() + + if len(h.IFace.Written) != 1 { + t.Fatal(h.IFace.Written) + } + + if !bytes.Equal(pkt, h.IFace.Written[0]) { + t.Fatal(h.IFace.Written) + } +} + +// Testing that we can relay a data packet. +func TestConnReader_handleDataPacket_relay(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := make([]byte, 1024) + rand.Read(pkt) + + h.RelayRemote.IP = 3 + h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) + h.R.handleNextPacket() + + if len(h.Sender.Sent) != 1 { + t.Fatal(h.Sender.Sent) + } + +} + +// Testing that we drop a relayed packet if destination is down. +func TestConnReader_handleDataPacket_relayDown(t *testing.T) { + h := newConnReadeTestHarness() + + pkt := make([]byte, 1024) + rand.Read(pkt) + + h.RelayRemote.IP = 3 + relay := h.R.peers[3].Load() + relay.Up = false + + h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) + h.R.handleNextPacket() + + if len(h.Sender.Sent) != 0 { + t.Fatal(h.Sender.Sent) + } +} diff --git a/peer/connwriter.go b/peer/connwriter.go index 928b2a0..7daa567 100644 --- a/peer/connwriter.go +++ b/peer/connwriter.go @@ -37,29 +37,29 @@ func newConnWriter(conn udpWriter, localIP byte) *connWriter { } // Not safe for concurrent use. Should only be called by supervisor. -func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { - enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) - w.writeTo(enc, route.RemoteAddr) +func (w *connWriter) SendControlPacket(pkt marshaller, peer *RemotePeer) { + enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) + w.writeTo(enc, peer.DirectAddr) } -// Relay control packet. Route must not be nil. -func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { - enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) - enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.cBuf1) - w.writeTo(enc, relay.RemoteAddr) +// Relay control packet. Peer must not be nil. +func (w *connWriter) RelayControlPacket(pkt marshaller, peer, relay *RemotePeer) { + enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) + enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) + w.writeTo(enc, relay.DirectAddr) } // Not safe for concurrent use. Should only be called by ifReader. -func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { - enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) - w.writeTo(enc, route.RemoteAddr) +func (w *connWriter) SendDataPacket(pkt []byte, peer *RemotePeer) { + enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) + w.writeTo(enc, peer.DirectAddr) } -// Relay a data packet. Route must not be nil. -func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { - enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) - enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.dBuf2) - w.writeTo(enc, relay.RemoteAddr) +// Relay a data packet. Peer must not be nil. +func (w *connWriter) RelayDataPacket(pkt []byte, peer, relay *RemotePeer) { + enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) + enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.dBuf2) + w.writeTo(enc, relay.DirectAddr) } // Safe for concurrent use. Should only be called by connReader. @@ -67,8 +67,8 @@ func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { // This function will send pkt to the peer directly. This is used when a peer // is acting as a relay and is forwarding already encrypted data for another // peer. -func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { - w.writeTo(pkt, route.RemoteAddr) +func (w *connWriter) SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) { + w.writeTo(pkt, peer.DirectAddr) } func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { diff --git a/peer/connwriter_test.go b/peer/connwriter_test.go index 14f128e..d8c0365 100644 --- a/peer/connwriter_test.go +++ b/peer/connwriter_test.go @@ -43,46 +43,46 @@ func (p testPacket) Marshal(b []byte) []byte { // ---------------------------------------------------------------------------- -func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) { +func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *RemotePeer) { localKeys := generateKeys() remoteKeys := generateKeys() - local = newPeerRoute(2) + local = NewRemotePeer(2) local.Up = true local.Relay = false local.PubSignKey = remoteKeys.PubSignKey local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) local.DataCipher = newDataCipher() - local.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) + local.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) - remote = newPeerRoute(1) + remote = NewRemotePeer(1) remote.Up = true remote.Relay = false remote.PubSignKey = localKeys.PubSignKey remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) remote.DataCipher = local.DataCipher - remote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) + remote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) rLocalKeys := generateKeys() rRemoteKeys := generateKeys() - relayLocal = newPeerRoute(3) + relayLocal = NewRemotePeer(3) relayLocal.Up = true relayLocal.Relay = true relayLocal.Direct = true relayLocal.PubSignKey = rRemoteKeys.PubSignKey relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) relayLocal.DataCipher = newDataCipher() - relayLocal.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) + relayLocal.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) - relayRemote = newPeerRoute(1) + relayRemote = NewRemotePeer(1) relayRemote.Up = true relayRemote.Relay = false relayRemote.Direct = true relayRemote.PubSignKey = rLocalKeys.PubSignKey relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) relayRemote.DataCipher = relayLocal.DataCipher - relayRemote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) + relayRemote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) return } @@ -104,7 +104,7 @@ func TestConnWriter_SendControlPacket_direct(t *testing.T) { t.Fatal(out) } - if out[0].Addr != route.RemoteAddr { + if out[0].Addr != route.DirectAddr { t.Fatal(out[0]) } @@ -132,7 +132,7 @@ func TestConnWriter_RelayControlPacket_relay(t *testing.T) { t.Fatal(out) } - if out[0].Addr != relay.RemoteAddr { + if out[0].Addr != relay.DirectAddr { t.Fatal(out[0]) } @@ -167,7 +167,7 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { t.Fatal(out) } - if out[0].Addr != route.RemoteAddr { + if out[0].Addr != route.DirectAddr { t.Fatal(out[0]) } @@ -196,7 +196,7 @@ func TestConnWriter_RelayDataPacket_relay(t *testing.T) { t.Fatal(out) } - if out[0].Addr != relay.RemoteAddr { + if out[0].Addr != relay.DirectAddr { t.Fatal(out[0]) } @@ -230,7 +230,7 @@ func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { t.Fatal(out) } - if out[0].Addr != route.RemoteAddr { + if out[0].Addr != route.DirectAddr { t.Fatal(out[0]) } diff --git a/peer/crypto.go b/peer/crypto.go index 3bc970f..f9c61db 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -33,40 +33,40 @@ func generateKeys() cryptoKeys { // ---------------------------------------------------------------------------- -// Route must have a ControlCipher. +// Peer must have a ControlCipher. func encryptControlPacket( localIP byte, - route *peerRoute, + peer *RemotePeer, pkt marshaller, tmp []byte, out []byte, ) []byte { h := header{ StreamID: controlStreamID, - Counter: atomic.AddUint64(route.Counter, 1), + Counter: atomic.AddUint64(peer.Counter, 1), SourceIP: localIP, - DestIP: route.IP, + DestIP: peer.IP, } tmp = pkt.Marshal(tmp) - return route.ControlCipher.Encrypt(h, tmp, out) + return peer.ControlCipher.Encrypt(h, tmp, out) } -// Returns a controlMsg[PacketType]. Route must have ControlCipher. +// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. // // This function also drops packets with duplicate sequence numbers. func decryptControlPacket( - route *peerRoute, + peer *RemotePeer, fromAddr netip.AddrPort, h header, encrypted []byte, tmp []byte, ) (any, error) { - out, ok := route.ControlCipher.Decrypt(encrypted, tmp) + out, ok := peer.ControlCipher.Decrypt(encrypted, tmp) if !ok { return nil, errDecryptionFailed } - if route.DupCheck.IsDup(h.Counter) { + if peer.DupCheck.IsDup(h.Counter) { return nil, errDuplicateSeqNum } @@ -83,31 +83,32 @@ func decryptControlPacket( func encryptDataPacket( localIP byte, destIP byte, - route *peerRoute, + peer *RemotePeer, data []byte, out []byte, ) []byte { h := header{ StreamID: dataStreamID, - Counter: atomic.AddUint64(route.Counter, 1), + Counter: atomic.AddUint64(peer.Counter, 1), SourceIP: localIP, DestIP: destIP, } - return route.DataCipher.Encrypt(h, data, out) + return peer.DataCipher.Encrypt(h, data, out) } +// Decrypts and de-dups incoming data packets. func decryptDataPacket( - route *peerRoute, + peer *RemotePeer, h header, encrypted []byte, out []byte, ) ([]byte, error) { - dec, ok := route.DataCipher.Decrypt(encrypted, out) + dec, ok := peer.DataCipher.Decrypt(encrypted, out) if !ok { return nil, errDecryptionFailed } - if route.DupCheck.IsDup(h.Counter) { + if peer.DupCheck.IsDup(h.Counter) { return nil, errDuplicateSeqNum } diff --git a/peer/crypto_test.go b/peer/crypto_test.go index 29ee377..c93b87f 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -9,16 +9,16 @@ import ( "testing" ) -func newRoutePairForTesting() (*peerRoute, *peerRoute) { +func newRoutePairForTesting() (*RemotePeer, *RemotePeer) { keys1 := generateKeys() keys2 := generateKeys() - r1 := newPeerRoute(1) + r1 := NewRemotePeer(1) r1.PubSignKey = keys1.PubSignKey r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) r1.DataCipher = newDataCipher() - r2 := newPeerRoute(2) + r2 := NewRemotePeer(2) r2.PubSignKey = keys2.PubSignKey r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) r2.DataCipher = r1.DataCipher @@ -40,10 +40,7 @@ func TestDecryptControlPacket(t *testing.T) { } enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) if err != nil { @@ -74,10 +71,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { } enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) for i := range enc { x := bytes.Clone(enc) @@ -103,10 +97,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { } enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { t.Fatal(err) @@ -128,10 +119,7 @@ func TestDecryptControlPacket_invalidPacket(t *testing.T) { in := testPacket("hello!") enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) if !errors.Is(err, errUnknownPacketType) { @@ -149,10 +137,7 @@ func TestDecryptDataPacket(t *testing.T) { rand.Read(data) enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out) if err != nil { @@ -174,10 +159,7 @@ func TestDecryptDataPacket_incorrectCipher(t *testing.T) { rand.Read(data) enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) r1.DataCipher = newDataCipher() _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) @@ -196,10 +178,7 @@ func TestDecryptDataPacket_duplicate(t *testing.T) { rand.Read(data) enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) - h, ok := parseHeader(enc) - if !ok { - t.Fatal(h, ok) - } + h := parseHeader(enc) _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) if err != nil { diff --git a/peer/header.go b/peer/header.go index 08698dd..fae3780 100644 --- a/peer/header.go +++ b/peer/header.go @@ -20,16 +20,13 @@ type header struct { Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. } -func parseHeader(b []byte) (h header, ok bool) { - if len(b) < headerSize { - return - } +func parseHeader(b []byte) (h header) { h.Version = b[0] h.StreamID = b[1] h.SourceIP = b[2] h.DestIP = b[3] h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) - return h, true + return h } func (h *header) Parse(b []byte) { diff --git a/peer/ifreader.go b/peer/ifreader.go index 61627a2..79ff441 100644 --- a/peer/ifreader.go +++ b/peer/ifreader.go @@ -8,20 +8,20 @@ import ( type ifReader struct { iface io.Reader - routes [256]*atomic.Pointer[peerRoute] - relay *atomic.Pointer[peerRoute] + peers [256]*atomic.Pointer[RemotePeer] + relay *atomic.Pointer[RemotePeer] sender dataPacketSender } func newIFReader( iface io.Reader, - routes [256]*atomic.Pointer[peerRoute], - relay *atomic.Pointer[peerRoute], + peers [256]*atomic.Pointer[RemotePeer], + relay *atomic.Pointer[RemotePeer], sender dataPacketSender, ) *ifReader { return &ifReader{ iface: iface, - routes: routes, + peers: peers, relay: relay, sender: sender, } @@ -43,20 +43,20 @@ func (r *ifReader) Run() { } func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { - route := r.routes[remoteIP].Load() - if !route.Up { - log.Printf("Route not connected: %d", remoteIP) + peer := r.peers[remoteIP].Load() + if !peer.Up { + log.Printf("Peer not connected: %d", remoteIP) return } // Direct path => early return. - if route.Direct { - r.sender.SendDataPacket(pkt, route) + if peer.Direct { + r.sender.SendDataPacket(pkt, peer) return } if relay := r.relay.Load(); relay != nil && relay.Up { - r.sender.RelayDataPacket(pkt, route, relay) + r.sender.RelayDataPacket(pkt, peer, relay) } } diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go index c5efb30..e8c5683 100644 --- a/peer/ifreader_test.go +++ b/peer/ifreader_test.go @@ -10,7 +10,7 @@ import ( // Test that we parse IPv4 packets correctly. func TestIFReader_parsePacket_ipv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) pkt := make([]byte, 1234) pkt[0] = 4 << 4 @@ -23,7 +23,7 @@ func TestIFReader_parsePacket_ipv4(t *testing.T) { // Test that we parse IPv6 packets correctly. func TestIFReader_parsePacket_ipv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) pkt := make([]byte, 1234) pkt[0] = 6 << 4 @@ -36,7 +36,7 @@ func TestIFReader_parsePacket_ipv6(t *testing.T) { // Test that empty packets work as expected. func TestIFReader_parsePacket_emptyPacket(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) pkt := make([]byte, 0) if ip, ok := r.parsePacket(pkt); ok { @@ -46,7 +46,7 @@ func TestIFReader_parsePacket_emptyPacket(t *testing.T) { // Test that invalid IP versions fail. func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) for i := byte(1); i < 16; i++ { if i == 4 || i == 6 { @@ -63,7 +63,7 @@ func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { // Test that short IPv4 packets fail. func TestIFReader_parsePacket_shortIPv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) pkt := make([]byte, 19) pkt[0] = 4 << 4 @@ -75,7 +75,7 @@ func TestIFReader_parsePacket_shortIPv4(t *testing.T) { // Test that short IPv6 packets fail. func TestIFReader_parsePacket_shortIPv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) pkt := make([]byte, 39) pkt[0] = 6 << 4 @@ -88,7 +88,7 @@ func TestIFReader_parsePacket_shortIPv6(t *testing.T) { // Test that we can read a packet. func TestIFReader_readNextpacket(t *testing.T) { in, out := net.Pipe() - r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil) + r := newIFReader(out, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) defer in.Close() defer out.Close() @@ -105,22 +105,22 @@ func TestIFReader_readNextpacket(t *testing.T) { type sentPacket struct { Relayed bool Packet []byte - Route peerRoute - Relay peerRoute + Route RemotePeer + Relay RemotePeer } type sendPacketTestHarness struct { Packets []sentPacket } -func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *peerRoute) { +func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *RemotePeer) { h.Packets = append(h.Packets, sentPacket{ Packet: bytes.Clone(pkt), Route: *route, }) } -func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *peerRoute) { +func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *RemotePeer) { h.Packets = append(h.Packets, sentPacket{ Relayed: true, Packet: bytes.Clone(pkt), @@ -132,12 +132,12 @@ func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *peerRo func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { h := &sendPacketTestHarness{} - routes := [256]*atomic.Pointer[peerRoute]{} + routes := [256]*atomic.Pointer[RemotePeer]{} for i := range routes { - routes[i] = &atomic.Pointer[peerRoute]{} - routes[i].Store(&peerRoute{}) + routes[i] = &atomic.Pointer[RemotePeer]{} + routes[i].Store(&RemotePeer{}) } - relay := &atomic.Pointer[peerRoute]{} + relay := &atomic.Pointer[RemotePeer]{} r := newIFReader(nil, routes, relay, h) return r, h } @@ -146,7 +146,7 @@ func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { func TestIFReader_sendPacket_direct(t *testing.T) { r, h := newIFReaderForSendPacketTesting() - route := r.routes[2].Load() + route := r.peers[2].Load() route.Up = true route.Direct = true @@ -172,7 +172,7 @@ func TestIFReader_sendPacket_direct(t *testing.T) { func TestIFReader_sendPacket_directNotUp(t *testing.T) { r, h := newIFReaderForSendPacketTesting() - route := r.routes[2].Load() + route := r.peers[2].Load() route.Direct = true in := []byte("hello world") @@ -187,11 +187,11 @@ func TestIFReader_sendPacket_directNotUp(t *testing.T) { func TestIFReader_sendPacket_relayed(t *testing.T) { r, h := newIFReaderForSendPacketTesting() - route := r.routes[2].Load() + route := r.peers[2].Load() route.Up = true route.Direct = false - relay := r.routes[3].Load() + relay := r.peers[3].Load() r.relay.Store(relay) relay.Up = true relay.Direct = true @@ -219,7 +219,7 @@ func TestIFReader_sendPacket_relayed(t *testing.T) { func TestIFReader_sendPacket_nilRealy(t *testing.T) { r, h := newIFReaderForSendPacketTesting() - route := r.routes[2].Load() + route := r.peers[2].Load() route.Up = true route.Direct = false diff --git a/peer/interfaces.go b/peer/interfaces.go index 84f9c99..0d826c3 100644 --- a/peer/interfaces.go +++ b/peer/interfaces.go @@ -1,6 +1,9 @@ package peer -import "net/netip" +import ( + "net" + "net/netip" +) type udpReader interface { ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) @@ -15,14 +18,18 @@ type marshaller interface { } type dataPacketSender interface { - SendDataPacket(pkt []byte, route *peerRoute) - RelayDataPacket(pkt []byte, route, relay *peerRoute) + SendDataPacket(pkt []byte, peer *RemotePeer) + RelayDataPacket(pkt []byte, peer, relay *RemotePeer) } type encryptedPacketSender interface { - SendEncryptedDataPacket(pkt []byte, route *peerRoute) + SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) } type controlMsgHandler interface { HandleControlMsg(pkt any) } + +type mcUDPWriter interface { + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} diff --git a/peer/mcreader.go b/peer/mcreader.go new file mode 100644 index 0000000..7d5c959 --- /dev/null +++ b/peer/mcreader.go @@ -0,0 +1,57 @@ +package peer + +import ( + "log" + "sync/atomic" +) + +type mcReader struct { + conn udpReader + super controlMsgHandler + peers [256]*atomic.Pointer[RemotePeer] + + incoming []byte + buf []byte +} + +func newMCReader( + conn udpReader, + super controlMsgHandler, + peers [256]*atomic.Pointer[RemotePeer], +) *mcReader { + return &mcReader{conn, super, peers, newBuf(), newBuf()} +} + +func (r *mcReader) Run() { + for { + r.handleNextPacket() + } +} + +func (r *mcReader) handleNextPacket() { + incoming := r.incoming[:bufferSize] + n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(incoming) + if err != nil { + log.Fatalf("Failed to read from UDP multicast port: %v", err) + } + incoming = incoming[:n] + + h, ok := headerFromLocalDiscoveryPacket(incoming) + if !ok { + return + } + + peer := r.peers[h.SourceIP].Load() + if peer == nil || peer.PubSignKey == nil { + return + } + + if !verifyLocalDiscoveryPacket(incoming, r.buf, peer.PubSignKey) { + return + } + + r.super.HandleControlMsg(controlMsg[localDiscoveryPacket]{ + SrcIP: h.SourceIP, + SrcAddr: remoteAddr, + }) +} diff --git a/peer/mcwriter.go b/peer/mcwriter.go index db9a76b..a8b55e9 100644 --- a/peer/mcwriter.go +++ b/peer/mcwriter.go @@ -2,19 +2,10 @@ package peer import ( "log" - "net" "golang.org/x/crypto/nacl/sign" ) -// ---------------------------------------------------------------------------- - -type mcUDPWriter interface { - WriteToUDP([]byte, *net.UDPAddr) (int, error) -} - -// ---------------------------------------------------------------------------- - func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { h := header{ SourceIP: localIP, diff --git a/peer/state.go b/peer/state.go index 2ef248b..d6589fe 100644 --- a/peer/state.go +++ b/peer/state.go @@ -5,23 +5,23 @@ import ( "time" ) -type peerRoute struct { - IP byte // VPN IP of peer (last byte). - Up bool // True if data can be sent on the route. - Relay bool // True if the peer is a relay. - Direct bool // True if this is a direct connection. +type RemotePeer struct { + IP byte // VPN IP of peer (last byte). + Up bool // True if data can be sent on the peer. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + DirectAddr netip.AddrPort // Remote address if directly connected. PubSignKey []byte ControlCipher *controlCipher DataCipher *dataCipher - RemoteAddr netip.AddrPort // Remote address if directly connected. Counter *uint64 // For sending to. Atomic access only. DupCheck *dupCheck // For receiving from. Not safe for concurrent use. } -func newPeerRoute(ip byte) *peerRoute { +func NewRemotePeer(ip byte) *RemotePeer { counter := uint64(time.Now().Unix()<<30 + 1) - return &peerRoute{ + return &RemotePeer{ IP: ip, Counter: &counter, DupCheck: newDupCheck(0), -- 2.39.5 From 6b3216f2d2c74d15ed9c4e086387343fa881b714 Mon Sep 17 00:00:00 2001 From: jdl Date: Mon, 10 Feb 2025 19:11:30 +0100 Subject: [PATCH 06/26] WIP --- peer/connreader2.go | 132 ++++++++++ peer/connreader_test.go | 20 +- peer/connwriter.go | 4 +- peer/connwriter2.go | 109 ++++++++ peer/connwriter2_test.go | 145 +++++++++++ peer/controlmessage.go | 18 +- peer/crypto.go | 10 +- peer/crypto_test.go | 8 +- peer/data-flow.dot | 14 ++ peer/files.go | 90 +++++++ peer/files_test.go | 57 +++++ peer/globals.go | 12 +- peer/hubpoller.go | 100 ++++++++ peer/ifreader2.go | 78 ++++++ peer/ifreader2_test.go | 83 +++++++ peer/ifreader_test.go | 4 +- peer/ifwriter.go | 5 - peer/interfaces.go | 24 +- peer/mcreader.go | 2 +- peer/mcreader_test.go | 138 +++++++++++ peer/mock-iface_test.go | 31 +++ peer/mock-network_test.go | 80 ++++++ peer/packets-util.go | 4 +- peer/packets-util_test.go | 24 +- peer/packets.go | 79 +++--- peer/packets_test.go | 65 +++++ peer/peer_test.go | 125 ++++++++++ peer/peerstates.go | 355 ++++++++++++++++++++++++++ peer/peerstates_test.go | 509 ++++++++++++++++++++++++++++++++++++++ peer/pubaddrs.go | 75 ++++++ peer/pubaddrs_test.go | 29 +++ peer/routingtable.go | 137 ++++++++++ peer/routingtable_test.go | 169 +++++++++++++ peer/state.go | 29 --- peer/supervisor.go | 103 ++++++++ peer/util_test.go | 26 ++ 36 files changed, 2763 insertions(+), 130 deletions(-) create mode 100644 peer/connreader2.go create mode 100644 peer/connwriter2.go create mode 100644 peer/connwriter2_test.go create mode 100644 peer/data-flow.dot create mode 100644 peer/files.go create mode 100644 peer/files_test.go create mode 100644 peer/hubpoller.go create mode 100644 peer/ifreader2.go create mode 100644 peer/ifreader2_test.go delete mode 100644 peer/ifwriter.go create mode 100644 peer/mcreader_test.go create mode 100644 peer/mock-iface_test.go create mode 100644 peer/mock-network_test.go create mode 100644 peer/peer_test.go create mode 100644 peer/peerstates.go create mode 100644 peer/peerstates_test.go create mode 100644 peer/pubaddrs.go create mode 100644 peer/pubaddrs_test.go create mode 100644 peer/routingtable.go create mode 100644 peer/routingtable_test.go delete mode 100644 peer/state.go create mode 100644 peer/supervisor.go create mode 100644 peer/util_test.go diff --git a/peer/connreader2.go b/peer/connreader2.go new file mode 100644 index 0000000..d9feab8 --- /dev/null +++ b/peer/connreader2.go @@ -0,0 +1,132 @@ +package peer + +import ( + "io" + "log" + "net/netip" + "sync/atomic" +) + +type ConnReader struct { + // Input + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) + + // Output + iface io.Writer + forwardData func(ip byte, pkt []byte) + handleControlMsg func(pkt any) + + localIP byte + rt *atomic.Pointer[RoutingTable] + + buf []byte + decBuf []byte +} + +func NewConnReader( + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), + iface io.Writer, + forwardData func(ip byte, pkt []byte), + handleControlMsg func(pkt any), + rt *atomic.Pointer[RoutingTable], +) *ConnReader { + return &ConnReader{ + readFromUDPAddrPort: readFromUDPAddrPort, + iface: iface, + forwardData: forwardData, + handleControlMsg: handleControlMsg, + localIP: rt.Load().LocalIP, + rt: rt, + buf: newBuf(), + decBuf: newBuf(), + } +} + +func (r *ConnReader) Run() { + for { + r.handleNextPacket() + } +} + +func (r *ConnReader) handleNextPacket() { + buf := r.buf[:bufferSize] + n, remoteAddr, err := r.readFromUDPAddrPort(buf) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) + } + + if n < headerSize { + return + } + + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) + + buf = buf[:n] + h := parseHeader(buf) + + peer := r.rt.Load().Peers[h.SourceIP] + //peer := rt.Peers[h.SourceIP] + + switch h.StreamID { + case controlStreamID: + r.handleControlPacket(remoteAddr, peer, h, buf) + case dataStreamID: + r.handleDataPacket(peer, h, buf) + default: + r.logf("Unknown stream ID: %d", h.StreamID) + } +} + +func (r *ConnReader) handleControlPacket( + remoteAddr netip.AddrPort, + peer RemotePeer, + h header, + enc []byte, +) { + if peer.ControlCipher == nil { + return + } + + if h.DestIP != r.localIP { + r.logf("Incorrect destination IP on control packet: %d", h.DestIP) + return + } + + msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt control packet: %v", err) + return + } + + r.handleControlMsg(msg) +} + +func (r *ConnReader) handleDataPacket( + peer RemotePeer, + h header, + enc []byte, +) { + if !peer.Up { + r.logf("Not connected (recv).") + return + } + + data, err := peer.DecryptDataPacket(h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt data packet: %v", err) + return + } + + if h.DestIP == r.localIP { + if _, err := r.iface.Write(data); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } + return + } + + r.forwardData(h.DestIP, data) +} + +func (r *ConnReader) logf(format string, args ...any) { + log.Printf("[ConnReader] "+format, args...) +} diff --git a/peer/connreader_test.go b/peer/connreader_test.go index 39da83c..714f6f3 100644 --- a/peer/connreader_test.go +++ b/peer/connreader_test.go @@ -109,7 +109,7 @@ func newConnReadeTestHarness() (h connReaderTestHarness) { func TestConnReader_handleControlPacket(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} h.WRemote.SendControlPacket(pkt, h.Remote) @@ -119,7 +119,7 @@ func TestConnReader_handleControlPacket(t *testing.T) { t.Fatal(h.Super.Messages) } - msg := h.Super.Messages[0].(controlMsg[synPacket]) + msg := h.Super.Messages[0].(controlMsg[PacketSyn]) if !reflect.DeepEqual(pkt, msg.Packet) { t.Fatal(msg.Packet) } @@ -141,7 +141,7 @@ func TestConnReader_handleNextPacket_short(t *testing.T) { func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) var header header @@ -160,7 +160,7 @@ func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { func TestConnReader_handleControlPacket_noCipher(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} //encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) @@ -180,7 +180,7 @@ func TestConnReader_handleControlPacket_noCipher(t *testing.T) { func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) var header header @@ -199,7 +199,7 @@ func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { func TestConnReader_handleControlPacket_modified(t *testing.T) { h := newConnReadeTestHarness() - pkt := synPacket{TraceID: 1234} + pkt := PacketSyn{TraceID: 1234} encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) encrypted[len(encrypted)-1]++ @@ -237,10 +237,10 @@ func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { func TestConnReader_handleControlPacket_duplicate(t *testing.T) { h := newConnReadeTestHarness() - pkt := ackPacket{TraceID: 1234} + pkt := PacketAck{TraceID: 1234} h.WRemote.SendControlPacket(pkt, h.Remote) - *h.Remote.Counter = *h.Remote.Counter - 1 + *h.Remote.counter = *h.Remote.counter - 1 h.WRemote.SendControlPacket(pkt, h.Remote) h.R.handleNextPacket() @@ -250,7 +250,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { t.Fatal(h.Super.Messages) } - msg := h.Super.Messages[0].(controlMsg[ackPacket]) + msg := h.Super.Messages[0].(controlMsg[PacketAck]) if !reflect.DeepEqual(pkt, msg.Packet) { t.Fatal(msg.Packet) } @@ -301,7 +301,7 @@ func TestConnReader_handleDataPacket_duplicate(t *testing.T) { pkt := make([]byte, 123) h.WRemote.SendDataPacket(pkt, h.Remote) - *h.Remote.Counter = *h.Remote.Counter - 1 + *h.Remote.counter = *h.Remote.counter - 1 h.WRemote.SendDataPacket(pkt, h.Remote) h.R.handleNextPacket() diff --git a/peer/connwriter.go b/peer/connwriter.go index 7daa567..8a09e35 100644 --- a/peer/connwriter.go +++ b/peer/connwriter.go @@ -37,13 +37,13 @@ func newConnWriter(conn udpWriter, localIP byte) *connWriter { } // Not safe for concurrent use. Should only be called by supervisor. -func (w *connWriter) SendControlPacket(pkt marshaller, peer *RemotePeer) { +func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) w.writeTo(enc, peer.DirectAddr) } // Relay control packet. Peer must not be nil. -func (w *connWriter) RelayControlPacket(pkt marshaller, peer, relay *RemotePeer) { +func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) w.writeTo(enc, relay.DirectAddr) diff --git a/peer/connwriter2.go b/peer/connwriter2.go new file mode 100644 index 0000000..e58250d --- /dev/null +++ b/peer/connwriter2.go @@ -0,0 +1,109 @@ +package peer + +import ( + "log" + "net/netip" + "sync" + "sync/atomic" +) + +type ConnWriter struct { + wLock sync.Mutex // Lock around for sending on UDP Conn. + + // Output. + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + + // Shared state. + rt *atomic.Pointer[RoutingTable] + + // For sending control packets. + cBuf1 []byte + cBuf2 []byte + + // For sending data packets. + dBuf1 []byte + dBuf2 []byte +} + +func NewConnWriter( + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[RoutingTable], +) *ConnWriter { + return &ConnWriter{ + writeToUDPAddrPort: writeToUDPAddrPort, + rt: rt, + cBuf1: newBuf(), + cBuf2: newBuf(), + dBuf1: newBuf(), + dBuf2: newBuf(), + } +} + +// Called by ConnReader to forward already encrypted bytes to another peer. +func (w *ConnWriter) Forward(ip byte, pkt []byte) { + peer := w.rt.Load().Peers[ip] + if !(peer.Up && peer.Direct) { + w.logf("Failed to forward to %d.", ip) + return + } + w.writeTo(pkt, peer.DirectAddr) +} + +// Called by IFReader to send data. Encryption will be applied, and packet will +// be relayed if appropriate. +func (w *ConnWriter) WriteData(ip byte, pkt []byte) { + rt := w.rt.Load() + peer := rt.Peers[ip] + if !peer.Up { + w.logf("Failed to send data to %d.", ip) + return + } + + enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) + + if peer.Direct { + w.writeTo(enc, peer.DirectAddr) + return + } + + relay, ok := rt.GetRelay() + if !ok { + w.logf("Failed to send data to %d. No relay.", ip) + return + } + + enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) + w.writeTo(enc, relay.DirectAddr) +} + +// Called by Supervisor to send control packets. +func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { + enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) + + if peer.Direct { + w.writeTo(enc, peer.DirectAddr) + return + } + + rt := w.rt.Load() + relay, ok := rt.GetRelay() + if !ok { + w.logf("Failed to send control to %d. No relay.", peer.IP) + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) + w.writeTo(enc, relay.DirectAddr) +} + +func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { + w.wLock.Lock() + if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { + w.logf("Failed to write to UDP port: %v", err) + } + w.wLock.Unlock() +} + +func (w *ConnWriter) logf(s string, args ...any) { + log.Printf("[ConnWriter] "+s, args...) +} diff --git a/peer/connwriter2_test.go b/peer/connwriter2_test.go new file mode 100644 index 0000000..f0bb00f --- /dev/null +++ b/peer/connwriter2_test.go @@ -0,0 +1,145 @@ +package peer + +import ( + "testing" +) + +func TestConnWriter_WriteData_direct(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteData_peerNotUp(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + p1.RT.Load().Peers[2].Up = false + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteData_relay(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().RelayIP = 3 + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p3.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().Peers[3].Up = false + p1.RT.Load().RelayIP = 3 + + in := RandPacket() + p1.ConnWriter.WriteData(2, in) + + packets := p3.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteControl_direct(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + orig := PacketProbe{TraceID: newTraceID()} + + p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteControl_relay(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().RelayIP = 3 + + orig := PacketProbe{TraceID: newTraceID()} + + p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) + + packets := p3.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { + p1, _, p3 := NewPeersForTesting() + + p1.RT.Load().Peers[2].Direct = false + p1.RT.Load().Peers[3].Up = false + p1.RT.Load().RelayIP = 3 + + orig := PacketProbe{TraceID: newTraceID()} + + p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) + + packets := p3.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter__Forward(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + in := RandPacket() + p1.ConnWriter.Forward(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestConnWriter__Forward_notUp(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + p1.RT.Load().Peers[2].Up = false + + in := RandPacket() + p1.ConnWriter.Forward(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} + +func TestConnWriter__Forward_notDirect(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + p1.RT.Load().Peers[2].Direct = false + + in := RandPacket() + p1.ConnWriter.Forward(2, in) + + packets := p2.Conn.Packets() + if len(packets) != 0 { + t.Fatal(packets) + } +} diff --git a/peer/controlmessage.go b/peer/controlmessage.go index d8e9a17..7180dd0 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -17,25 +17,25 @@ type controlMsg[T any] struct { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { switch buf[0] { - case packetTypeSyn: - packet, err := parseSynPacket(buf) - return controlMsg[synPacket]{ + case PacketTypeSyn: + packet, err := ParsePacketSyn(buf) + return controlMsg[PacketSyn]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case packetTypeAck: - packet, err := parseAckPacket(buf) - return controlMsg[ackPacket]{ + case PacketTypeAck: + packet, err := ParsePacketAck(buf) + return controlMsg[PacketAck]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case packetTypeProbe: - packet, err := parseProbePacket(buf) - return controlMsg[probePacket]{ + case PacketTypeProbe: + packet, err := ParsePacketProbe(buf) + return controlMsg[PacketProbe]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, diff --git a/peer/crypto.go b/peer/crypto.go index f9c61db..dcc042b 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -37,13 +37,13 @@ func generateKeys() cryptoKeys { func encryptControlPacket( localIP byte, peer *RemotePeer, - pkt marshaller, + pkt Marshaller, tmp []byte, out []byte, ) []byte { h := header{ StreamID: controlStreamID, - Counter: atomic.AddUint64(peer.Counter, 1), + Counter: atomic.AddUint64(peer.counter, 1), SourceIP: localIP, DestIP: peer.IP, } @@ -66,7 +66,7 @@ func decryptControlPacket( return nil, errDecryptionFailed } - if peer.DupCheck.IsDup(h.Counter) { + if peer.dupCheck.IsDup(h.Counter) { return nil, errDuplicateSeqNum } @@ -89,7 +89,7 @@ func encryptDataPacket( ) []byte { h := header{ StreamID: dataStreamID, - Counter: atomic.AddUint64(peer.Counter, 1), + Counter: atomic.AddUint64(peer.counter, 1), SourceIP: localIP, DestIP: destIP, } @@ -108,7 +108,7 @@ func decryptDataPacket( return nil, errDecryptionFailed } - if peer.DupCheck.IsDup(h.Counter) { + if peer.dupCheck.IsDup(h.Counter) { return nil, errDuplicateSeqNum } diff --git a/peer/crypto_test.go b/peer/crypto_test.go index c93b87f..824a43a 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { out = make([]byte, bufferSize) ) - in := synPacket{ + in := PacketSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { t.Fatal(err) } - msg, ok := iMsg.(controlMsg[synPacket]) + msg, ok := iMsg.(controlMsg[PacketSyn]) if !ok { t.Fatal(ok) } @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { out = make([]byte, bufferSize) ) - in := synPacket{ + in := PacketSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { out = make([]byte, bufferSize) ) - in := synPacket{ + in := PacketSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, diff --git a/peer/data-flow.dot b/peer/data-flow.dot new file mode 100644 index 0000000..45b6f05 --- /dev/null +++ b/peer/data-flow.dot @@ -0,0 +1,14 @@ +digraph d { + ifReader -> connWriter; + connReader -> ifWriter; + connReader -> connWriter; + connReader -> supervisor; + mcReader -> supervisor; + supervisor -> connWriter; + supervisor -> mcWriter; + hubPoller -> supervisor; + + connWriter [shape="box"]; + mcWriter [shape="box"]; + ifWriter [shape="box"]; +} \ No newline at end of file diff --git a/peer/files.go b/peer/files.go new file mode 100644 index 0000000..b0eade5 --- /dev/null +++ b/peer/files.go @@ -0,0 +1,90 @@ +package peer + +import ( + "encoding/json" + "log" + "os" + "path/filepath" + "vppn/m" +) + +type localConfig struct { + m.PeerConfig + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func configDir(netName string) string { + d, err := os.UserHomeDir() + if err != nil { + log.Fatalf("Failed to get user home directory: %v", err) + } + return filepath.Join(d, ".vppn", netName) +} + +func peerConfigPath(netName string) string { + return filepath.Join(configDir(netName), "peer-config.json") +} + +func peerStatePath(netName string) string { + return filepath.Join(configDir(netName), "peer-state.json") +} + +func storeJson(x any, outPath string) error { + outDir := filepath.Dir(outPath) + _ = os.MkdirAll(outDir, 0700) + + tmpPath := outPath + ".tmp" + buf, err := json.Marshal(x) + if err != nil { + return err + } + + f, err := os.Create(tmpPath) + if err != nil { + return err + } + + if _, err := f.Write(buf); err != nil { + f.Close() + return err + } + + if err := f.Sync(); err != nil { + f.Close() + return err + } + + if err := f.Close(); err != nil { + return err + } + + return os.Rename(tmpPath, outPath) +} + +func storePeerConfig(netName string, pc localConfig) error { + return storeJson(pc, peerConfigPath(netName)) +} + +func storeNetworkState(netName string, ps m.NetworkState) error { + return storeJson(ps, peerStatePath(netName)) +} + +func loadJson(dataPath string, ptr any) error { + data, err := os.ReadFile(dataPath) + if err != nil { + return err + } + + return json.Unmarshal(data, ptr) +} + +func loadPeerConfig(netName string) (pc localConfig, err error) { + return pc, loadJson(peerConfigPath(netName), &pc) +} + +func loadNetworkState(netName string) (ps m.NetworkState, err error) { + return ps, loadJson(peerStatePath(netName), &ps) +} diff --git a/peer/files_test.go b/peer/files_test.go new file mode 100644 index 0000000..5e32ced --- /dev/null +++ b/peer/files_test.go @@ -0,0 +1,57 @@ +package peer + +import ( + "path/filepath" + "reflect" + "testing" +) + +func TestFilePaths(t *testing.T) { + confDir := configDir("netName") + if filepath.Base(confDir) != "netName" { + t.Fatal(confDir) + } + if filepath.Base(filepath.Dir(confDir)) != ".vppn" { + t.Fatal(confDir) + } + + path := peerConfigPath("netName") + if path != filepath.Join(confDir, "peer-config.json") { + t.Fatal(path) + } + + path = peerStatePath("netName") + if path != filepath.Join(confDir, "peer-state.json") { + t.Fatal(path) + } +} + +func TestStoreLoadJson(t *testing.T) { + type Object struct { + Name string + Age int + Price float64 + } + + tmpDir := t.TempDir() + outPath := filepath.Join(tmpDir, "object.json") + + obj := Object{ + Name: "Jason", + Age: 22, + Price: 123.534, + } + + if err := storeJson(obj, outPath); err != nil { + t.Fatal(err) + } + + obj2 := Object{} + if err := loadJson(outPath, &obj2); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(obj, obj2) { + t.Fatal(obj, obj2) + } +} diff --git a/peer/globals.go b/peer/globals.go index 4733ac8..0d7ada3 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -3,15 +3,21 @@ package peer import ( "net" "net/netip" + "time" ) const ( - bufferSize = 1536 - if_mtu = 1200 - if_queue_len = 2048 + bufferSize = 1536 + + if_mtu = 1200 + if_queue_len = 2048 + controlCipherOverhead = 16 dataCipherOverhead = 16 signOverhead = 64 + + pingInterval = 8 * time.Second + timeoutInterval = 30 * time.Second ) var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( diff --git a/peer/hubpoller.go b/peer/hubpoller.go new file mode 100644 index 0000000..f608bd5 --- /dev/null +++ b/peer/hubpoller.go @@ -0,0 +1,100 @@ +package peer + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/url" + "time" + "vppn/m" +) + +type hubPoller struct { + client *http.Client + req *http.Request + versions [256]int64 + localIP byte + netName string + super controlMsgHandler +} + +func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { + u, err := url.Parse(hubURL) + if err != nil { + return nil, err + } + u.Path = "/peer/fetch-state/" + + client := &http.Client{Timeout: 8 * time.Second} + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + req.SetBasicAuth("", apiKey) + + return &hubPoller{ + client: client, + req: req, + localIP: localIP, + netName: netName, + super: super, + }, nil +} + +func (hp *hubPoller) Run() { + state, err := loadNetworkState(hp.netName) + if err != nil { + log.Printf("Failed to load network state: %v", err) + log.Printf("Polling hub...") + hp.pollHub() + } else { + hp.applyNetworkState(state) + } + + for range time.Tick(64 * time.Second) { + hp.pollHub() + } +} + +func (hp *hubPoller) pollHub() { + var state m.NetworkState + + resp, err := hp.client.Do(hp.req) + if err != nil { + log.Printf("Failed to fetch peer state: %v", err) + return + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + log.Printf("Failed to read body from hub: %v", err) + return + } + + if err := json.Unmarshal(body, &state); err != nil { + log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) + return + } + + hp.applyNetworkState(state) + + if err := storeNetworkState(hp.netName, state); err != nil { + log.Printf("Failed to store network state: %v", err) + } +} + +func (hp *hubPoller) applyNetworkState(state m.NetworkState) { + for i, peer := range state.Peers { + if i != int(hp.localIP) { + if peer == nil || peer.Version != hp.versions[i] { + hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) + if peer != nil { + hp.versions[i] = peer.Version + } + } + } + } +} diff --git a/peer/ifreader2.go b/peer/ifreader2.go new file mode 100644 index 0000000..c390e8f --- /dev/null +++ b/peer/ifreader2.go @@ -0,0 +1,78 @@ +package peer + +import ( + "io" + "log" +) + +type IFReader struct { + iface io.Reader + connWriter interface { + WriteData(ip byte, pkt []byte) + } +} + +func NewIFReader( + iface io.Reader, + connWriter interface { + WriteData(ip byte, pkt []byte) + }, +) *IFReader { + return &IFReader{iface, connWriter} +} + +func (r *IFReader) Run() { + packet := newBuf() + for { + r.handleNextPacket(packet) + } +} + +func (r *IFReader) handleNextPacket(packet []byte) { + packet = r.readNextPacket(packet) + if remoteIP, ok := r.parsePacket(packet); ok { + r.connWriter.WriteData(remoteIP, packet) + } +} + +func (r *IFReader) readNextPacket(buf []byte) []byte { + n, err := r.iface.Read(buf[:cap(buf)]) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + return buf[:n] +} + +func (r *IFReader) parsePacket(buf []byte) (byte, bool) { + n := len(buf) + if n == 0 { + return 0, false + } + + version := buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + r.logf("Short IPv4 packet: %d", len(buf)) + return 0, false + } + return buf[19], true + + case 6: + if len(buf) < 40 { + r.logf("Short IPv6 packet: %d", len(buf)) + return 0, false + } + return buf[39], true + + default: + r.logf("Invalid IP packet version: %v", version) + return 0, false + } +} + +func (*IFReader) logf(s string, args ...any) { + log.Printf("[IFReader] "+s, args...) +} diff --git a/peer/ifreader2_test.go b/peer/ifreader2_test.go new file mode 100644 index 0000000..779cf49 --- /dev/null +++ b/peer/ifreader2_test.go @@ -0,0 +1,83 @@ +package peer + +import ( + "testing" +) + +func TestIFReader_IPv4(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + pkt := make([]byte, 1234) + pkt[0] = 4 << 4 + pkt[19] = 2 // IP. + + p1.IFace.UserWrite(pkt) + p1.IFReader.handleNextPacket(newBuf()) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestIFReader_IPv6(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + pkt := make([]byte, 1234) + pkt[0] = 6 << 4 + pkt[39] = 2 // IP. + + p1.IFace.UserWrite(pkt) + p1.IFReader.handleNextPacket(newBuf()) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestIFReader_parsePacket_emptyPacket(t *testing.T) { + r := NewIFReader(nil, nil) + pkt := make([]byte, 0) + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { + r := NewIFReader(nil, nil) + + for i := byte(1); i < 16; i++ { + if i == 4 || i == 6 { + continue + } + pkt := make([]byte, 1234) + pkt[0] = i << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(i, ip, ok) + } + } +} + +func TestIFReader_parsePacket_shortIPv4(t *testing.T) { + r := NewIFReader(nil, nil) + + pkt := make([]byte, 19) + pkt[0] = 4 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +func TestIFReader_parsePacket_shortIPv6(t *testing.T) { + r := NewIFReader(nil, nil) + + pkt := make([]byte, 39) + pkt[0] = 6 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go index e8c5683..620d2b1 100644 --- a/peer/ifreader_test.go +++ b/peer/ifreader_test.go @@ -2,7 +2,6 @@ package peer import ( "bytes" - "net" "reflect" "sync/atomic" "testing" @@ -34,6 +33,7 @@ func TestIFReader_parsePacket_ipv6(t *testing.T) { } } +/* // Test that empty packets work as expected. func TestIFReader_parsePacket_emptyPacket(t *testing.T) { r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) @@ -99,7 +99,7 @@ func TestIFReader_readNextpacket(t *testing.T) { t.Fatalf("%s", pkt) } } - +*/ // ---------------------------------------------------------------------------- type sentPacket struct { diff --git a/peer/ifwriter.go b/peer/ifwriter.go deleted file mode 100644 index 59e2e26..0000000 --- a/peer/ifwriter.go +++ /dev/null @@ -1,5 +0,0 @@ -package peer - -import "io" - -type ifWriter io.Writer diff --git a/peer/interfaces.go b/peer/interfaces.go index 0d826c3..8e99e8b 100644 --- a/peer/interfaces.go +++ b/peer/interfaces.go @@ -1,10 +1,19 @@ package peer import ( + "io" "net" "net/netip" ) +type UDPConn interface { + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +type ifWriter io.Writer + type udpReader interface { ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) } @@ -13,7 +22,11 @@ type udpWriter interface { WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) } -type marshaller interface { +type mcUDPWriter interface { + WriteToUDP([]byte, *net.UDPAddr) (int, error) +} + +type Marshaller interface { Marshal([]byte) []byte } @@ -22,6 +35,11 @@ type dataPacketSender interface { RelayDataPacket(pkt []byte, peer, relay *RemotePeer) } +type controlPacketSender interface { + SendControlPacket(pkt Marshaller, peer *RemotePeer) + RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) +} + type encryptedPacketSender interface { SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) } @@ -29,7 +47,3 @@ type encryptedPacketSender interface { type controlMsgHandler interface { HandleControlMsg(pkt any) } - -type mcUDPWriter interface { - WriteToUDP([]byte, *net.UDPAddr) (int, error) -} diff --git a/peer/mcreader.go b/peer/mcreader.go index 7d5c959..38921f1 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { return } - r.super.HandleControlMsg(controlMsg[localDiscoveryPacket]{ + r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ SrcIP: h.SourceIP, SrcAddr: remoteAddr, }) diff --git a/peer/mcreader_test.go b/peer/mcreader_test.go new file mode 100644 index 0000000..50bf821 --- /dev/null +++ b/peer/mcreader_test.go @@ -0,0 +1,138 @@ +package peer + +import ( + "bytes" + "net" + "net/netip" + "sync/atomic" + "testing" +) + +type mcMockConn struct { + packets chan []byte +} + +func newMCMockConn() *mcMockConn { + return &mcMockConn{make(chan []byte, 32)} +} + +func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) { + c.packets <- bytes.Clone(in) + return len(in), nil +} + +func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + buf := <-c.packets + b = b[:len(buf)] + copy(b, buf) + return len(b), netip.AddrPort{}, nil +} + +func TestMCReader(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 1, keys.PrivSignKey) + r := newMCReader(conn, super, peers) + + w.SendLocalDiscovery() + r.handleNextPacket() + + if len(super.Messages) != 1 { + t.Fatal(super.Messages) + } + msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery]) + if !ok || msg.SrcIP != 1 { + t.Fatal(ok, msg) + } +} + +func TestMCReader_noHeader(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + r := newMCReader(conn, super, peers) + conn.WriteToUDP([]byte("0123546789"), nil) + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} + +func TestMCReader_noPeer(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[2] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 2, keys.PrivSignKey) + r := newMCReader(conn, super, peers) + + w.SendLocalDiscovery() + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} + +func TestMCReader_badSignature(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 1, keys.PrivSignKey) + w.SendLocalDiscovery() + + // Break signing. + packet := <-conn.packets + packet[0]++ + conn.packets <- packet + + r := newMCReader(conn, super, peers) + + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} diff --git a/peer/mock-iface_test.go b/peer/mock-iface_test.go new file mode 100644 index 0000000..ffef5d9 --- /dev/null +++ b/peer/mock-iface_test.go @@ -0,0 +1,31 @@ +package peer + +import "bytes" + +type TestIFace struct { + out *bytes.Buffer // Toward the network. + in *bytes.Buffer // From the network +} + +func NewTestIFace() *TestIFace { + return &TestIFace{ + out: &bytes.Buffer{}, + in: &bytes.Buffer{}, + } +} + +func (iface *TestIFace) Write(b []byte) (int, error) { + return iface.in.Write(b) +} + +func (iface *TestIFace) Read(b []byte) (int, error) { + return iface.out.Read(b) +} + +func (iface *TestIFace) UserWrite(b []byte) (int, error) { + return iface.out.Write(b) +} + +func (iface *TestIFace) UserRead(b []byte) (int, error) { + return iface.in.Read(b) +} diff --git a/peer/mock-network_test.go b/peer/mock-network_test.go new file mode 100644 index 0000000..4b5240c --- /dev/null +++ b/peer/mock-network_test.go @@ -0,0 +1,80 @@ +package peer + +import ( + "bytes" + "net" + "net/netip" + "sync" +) + +type TestPacket struct { + Addr netip.AddrPort + Data []byte +} + +type TestNetwork struct { + lock sync.Mutex + packets map[netip.AddrPort]chan TestPacket +} + +func NewTestNetwork() *TestNetwork { + return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}} +} + +func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn { + n.lock.Lock() + defer n.lock.Unlock() + if _, ok := n.packets[localAddr]; !ok { + n.packets[localAddr] = make(chan TestPacket, 1024) + } + return &TestUDPConn{ + addr: localAddr, + n: n, + packets: n.packets[localAddr], + } +} + +func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) { + n.lock.Lock() + defer n.lock.Unlock() + if _, ok := n.packets[to]; !ok { + n.packets[to] = make(chan TestPacket, 1024) + } + n.packets[to] <- TestPacket{ + Addr: from, + Data: bytes.Clone(b), + } +} + +type TestUDPConn struct { + addr netip.AddrPort + n *TestNetwork + packets chan TestPacket +} + +func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + c.n.write(b, c.addr, addr) + return len(b), nil +} + +func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + return c.WriteToUDPAddrPort(b, addr.AddrPort()) +} + +func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + pkt := <-c.packets + b = b[:len(pkt.Data)] + copy(b, pkt.Data) + return len(b), pkt.Addr, nil +} + +func (c *TestUDPConn) Packets() (out []TestPacket) { + for { + select { + case pkt := <-c.packets: + out = append(out, pkt) + default: + return + } + } +} diff --git a/peer/packets-util.go b/peer/packets-util.go index bda33b9..c0264e5 100644 --- a/peer/packets-util.go +++ b/peer/packets-util.go @@ -70,7 +70,7 @@ func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { return w.Uint16(addrPort.Port()) } -func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { +func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter { for _, addrPort := range l { w.AddrPort(addrPort) } @@ -178,7 +178,7 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { return r } -func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { +func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader { for i := range x { r.AddrPort(&x[i]) } diff --git a/peer/packets-util_test.go b/peer/packets-util_test.go index 5a518d7..6e4a98c 100644 --- a/peer/packets-util_test.go +++ b/peer/packets-util_test.go @@ -6,6 +6,26 @@ import ( "testing" ) +func TestBinWriteRead_invalidAddrPort(t *testing.T) { + addr := netip.AddrPort{} + buf := make([]byte, 1024) + buf = newBinWriter(buf). + AddrPort(addr). + Build() + + var addr2 netip.AddrPort + err := newBinReader(buf). + AddrPort(&addr2). + Error() + if err != nil { + t.Fatal(err) + } + + if addr2.IsValid() { + t.Fatal(addr, addr2) + } +} + func TestBinWriteRead(t *testing.T) { buf := make([]byte, 1024) @@ -35,7 +55,7 @@ func TestBinWriteRead(t *testing.T) { Byte(in.Type). Uint64(in.TraceID). AddrPort(in.DestAddr). - AddrPortArray(in.Addrs). + AddrPort8(in.Addrs). Build() out := Item{} @@ -44,7 +64,7 @@ func TestBinWriteRead(t *testing.T) { Byte(&out.Type). Uint64(&out.TraceID). AddrPort(&out.DestAddr). - AddrPortArray(&out.Addrs). + AddrPort8(&out.Addrs). Error() if err != nil { t.Fatal(err) diff --git a/peer/packets.go b/peer/packets.go index f7f1f85..596483d 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -5,93 +5,70 @@ import ( ) const ( - packetTypeSyn = iota + 1 - packetTypeSynAck - packetTypeAck - packetTypeProbe - packetTypeAddrDiscovery + PacketTypeSyn = iota + 1 + PacketTypeSynAck + PacketTypeAck + PacketTypeProbe + PacketTypeAddrDiscovery ) // ---------------------------------------------------------------------------- -type synPacket struct { +type PacketSyn struct { TraceID uint64 // TraceID to match response w/ request. - // TODO: SentAt int64 // Unixmilli. + //SentAt int64 // Unixmilli. + //SharedKeyType byte // Currently only 1 is supported for AES. SharedKey [32]byte // Our shared key. Direct bool PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p synPacket) Marshal(buf []byte) []byte { +func (p PacketSyn) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeSyn). + Byte(PacketTypeSyn). Uint64(p.TraceID). + //Int64(p.SentAt). + //Byte(p.SharedKeyType). SharedKey(p.SharedKey). Bool(p.Direct). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). + AddrPort8(p.PossibleAddrs). Build() } -func parseSynPacket(buf []byte) (p synPacket, err error) { +func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). + //Int64(&p.SentAt). + //Byte(&p.SharedKeyType). SharedKey(&p.SharedKey). Bool(&p.Direct). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). + AddrPort8(&p.PossibleAddrs). Error() return } // ---------------------------------------------------------------------------- -type ackPacket struct { +type PacketAck struct { TraceID uint64 ToAddr netip.AddrPort PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p ackPacket) Marshal(buf []byte) []byte { +func (p PacketAck) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeAck). + Byte(PacketTypeAck). Uint64(p.TraceID). AddrPort(p.ToAddr). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). + AddrPort8(p.PossibleAddrs). Build() } -func parseAckPacket(buf []byte) (p ackPacket, err error) { +func ParsePacketAck(buf []byte) (p PacketAck, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). AddrPort(&p.ToAddr). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). + AddrPort8(&p.PossibleAddrs). Error() return } @@ -100,18 +77,18 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) { // A probeReqPacket is sent from a client to a server to determine if direct // UDP communication can be used. -type probePacket struct { +type PacketProbe struct { TraceID uint64 } -func (p probePacket) Marshal(buf []byte) []byte { +func (p PacketProbe) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(packetTypeProbe). + Byte(PacketTypeProbe). Uint64(p.TraceID). Build() } -func parseProbePacket(buf []byte) (p probePacket, err error) { +func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). Error() @@ -120,4 +97,4 @@ func parseProbePacket(buf []byte) (p probePacket, err error) { // ---------------------------------------------------------------------------- -type localDiscoveryPacket struct{} +type PacketLocalDiscovery struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go index 333deff..3ddc1a0 100644 --- a/peer/packets_test.go +++ b/peer/packets_test.go @@ -1 +1,66 @@ package peer + +import ( + "crypto/rand" + "net/netip" + "reflect" + "testing" +) + +func TestSynPacket(t *testing.T) { + p := PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + } + rand.Read(p.SharedKey[:]) + + p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234) + p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) + p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) + + buf := p.Marshal(newBuf()) + p2, err := ParsePacketSyn(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} + +func TestAckPacket(t *testing.T) { + p := PacketAck{ + TraceID: newTraceID(), + ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), + } + + p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100) + p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) + p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) + + buf := p.Marshal(newBuf()) + p2, err := ParsePacketAck(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} + +func TestProbePacket(t *testing.T) { + p := PacketProbe{ + TraceID: newTraceID(), + } + + buf := p.Marshal(newBuf()) + p2, err := ParsePacketProbe(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} diff --git a/peer/peer_test.go b/peer/peer_test.go new file mode 100644 index 0000000..414beaa --- /dev/null +++ b/peer/peer_test.go @@ -0,0 +1,125 @@ +package peer + +import ( + "bytes" + "crypto/rand" + mrand "math/rand" + "net/netip" + "sync/atomic" +) + +// A test peer. +type P struct { + cryptoKeys + RT *atomic.Pointer[RoutingTable] + Conn *TestUDPConn + IFace *TestIFace + ConnWriter *ConnWriter + ConnReader *ConnReader + IFReader *IFReader + Super *Supervisor +} + +func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { + p := P{ + cryptoKeys: generateKeys(), + RT: &atomic.Pointer[RoutingTable]{}, + IFace: NewTestIFace(), + } + + rt := NewRoutingTable(ip, addr) + p.RT.Store(&rt) + p.Conn = n.NewUDPConn(addr) + p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) + p.IFReader = NewIFReader(p.IFace, p.ConnWriter) + + /* + p.ConnReader = NewConnReader( + p.Conn.ReadFromUDPAddrPort, + p.IFace, + p.ConnWriter.Forward, + p.Super.HandleControlMsg, + p.RT) + */ + return p +} + +func ConnectPeers(p1, p2 *P) { + rt1 := p1.RT.Load() + rt2 := p2.RT.Load() + + ip1 := rt1.LocalIP + ip2 := rt2.LocalIP + + rt1.Peers[ip2].Up = true + rt1.Peers[ip2].Direct = true + rt1.Peers[ip2].Relay = true + rt1.Peers[ip2].DirectAddr = rt2.LocalAddr + rt1.Peers[ip2].PubSignKey = p2.PubSignKey + rt1.Peers[ip2].ControlCipher = newControlCipher(p1.PrivKey, p2.PubKey) + rt1.Peers[ip2].DataCipher = newDataCipher() + + rt2.Peers[ip1].Up = true + rt2.Peers[ip1].Direct = true + rt2.Peers[ip1].Relay = true + rt2.Peers[ip1].DirectAddr = rt1.LocalAddr + rt2.Peers[ip1].PubSignKey = p1.PubSignKey + rt2.Peers[ip1].ControlCipher = newControlCipher(p2.PrivKey, p1.PubKey) + rt2.Peers[ip1].DataCipher = rt1.Peers[ip2].DataCipher +} + +func NewPeersForTesting() (p1, p2, p3 P) { + n := NewTestNetwork() + + p1 = NewPeerForTesting( + n, + 1, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)) + + p2 = NewPeerForTesting( + n, + 2, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 200)) + + p3 = NewPeerForTesting( + n, + 3, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 300)) + + ConnectPeers(&p1, &p2) + ConnectPeers(&p1, &p3) + ConnectPeers(&p2, &p3) + + return +} + +func RandPacket() []byte { + n := mrand.Intn(1200) + b := make([]byte, n) + rand.Read(b) + return b +} + +func ModifyPacket(in []byte) []byte { + x := make([]byte, 1) + + for { + rand.Read(x) + out := bytes.Clone(in) + idx := mrand.Intn(len(out)) + if out[idx] != x[0] { + out[idx] = x[0] + return out + } + } +} + +// ---------------------------------------------------------------------------- + +type UnknownControlPacket struct { + TraceID uint64 +} + +func (p UnknownControlPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf).Byte(255).Uint64(p.TraceID).Build() +} diff --git a/peer/peerstates.go b/peer/peerstates.go new file mode 100644 index 0000000..b05826c --- /dev/null +++ b/peer/peerstates.go @@ -0,0 +1,355 @@ +package peer + +import ( + "fmt" + "log" + "net/netip" + "strings" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type PeerState interface { + OnPeerUpdate(*m.Peer) PeerState + OnSyn(controlMsg[PacketSyn]) PeerState + OnAck(controlMsg[PacketAck]) + OnProbe(controlMsg[PacketProbe]) PeerState + OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) + OnPingTimer() PeerState +} + +// ---------------------------------------------------------------------------- + +type State struct { + // Output. + publish func(RemotePeer) + sendControlPacket func(RemotePeer, Marshaller) + + // Immutable data. + localIP byte + remoteIP byte + privKey []byte + localAddr netip.AddrPort // If valid, then local peer is publicly accessible. + + pubAddrs *pubAddrStore + + // The purpose of this state machine is to manage the RemotePeer object, + // publishing it as necessary. + staged RemotePeer // Local copy of shared data. See publish(). + + // Mutable peer data. + peer *m.Peer + + // We rate limit per remote endpoint because if we don't we tend to lose + // packets. + limiter *ratelimiter.Limiter +} + +func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { + defer func() { + // Don't defer directly otherwise s.staged will be evaluated immediately + // and won't reflect changes made in the function. + s.publish(s.staged) + }() + + if peer == nil { + return EnterStateDisconnected(s) + } + + s.peer = peer + s.staged.Relay = false + s.staged.Direct = false + s.staged.DirectAddr = netip.AddrPort{} + s.staged.PubSignKey = nil + s.staged.PubSignKey = peer.PubSignKey + s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) + s.staged.DataCipher = newDataCipher() + + if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + s.staged.Relay = peer.Relay + s.staged.Direct = true + s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) + + if s.localAddr.IsValid() && s.localIP < s.remoteIP { + return EnterStateServer(s) + } + + return EnterStateClientDirect(s) + } + + if s.localAddr.IsValid() { + s.staged.Direct = true + return EnterStateServer(s) + } + + if s.localIP < s.remoteIP { + return EnterStateServer(s) + } + + return EnterStateClientRelayed(s) +} + +func (s *State) logf(format string, args ...any) { + b := strings.Builder{} + name := "--" + if s.peer != nil { + name = s.peer.Name + } + b.WriteString(fmt.Sprintf("%30s: ", name)) + + if s.staged.Direct { + b.WriteString("DIRECT | ") + } else { + b.WriteString("RELAYED | ") + } + + if s.staged.Up { + b.WriteString("UP | ") + } else { + b.WriteString("DOWN | ") + } + + log.Printf(b.String()+format, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { + if !addr.IsValid() { + return + } + route := s.staged + route.Direct = true + route.DirectAddr = addr + s.Send(route, pkt) +} + +func (s *State) Send(peer RemotePeer, pkt Marshaller) { + if err := s.limiter.Limit(); err != nil { + s.logf("Rate limited.") + return + } + s.sendControlPacket(peer, pkt) +} + +// ---------------------------------------------------------------------------- + +type StateDisconnected struct{ *State } + +func EnterStateDisconnected(s *State) PeerState { + s.logf("==> Disconnected") + s.peer = nil + s.staged.Up = false + s.staged.Relay = false + s.staged.Direct = false + s.staged.DirectAddr = netip.AddrPort{} + s.staged.PubSignKey = nil + s.staged.ControlCipher = nil + s.staged.DataCipher = nil + s.publish(s.staged) + return &StateDisconnected{State: s} +} + +func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil } +func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {} +func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil } +func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} +func (s *StateDisconnected) OnPingTimer() PeerState { return nil } + +// ---------------------------------------------------------------------------- + +type StateServer struct { + *StateDisconnected + lastSeen time.Time + synTraceID uint64 +} + +func EnterStateServer(s *State) PeerState { + s.logf("==> Server") + return &StateServer{StateDisconnected: &StateDisconnected{State: s}} +} + +func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { + s.lastSeen = time.Now() + p := msg.Packet + + // Before we can respond to this packet, we need to make sure the + // route is setup properly. + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != s.synTraceID || !s.staged.Up { + s.synTraceID = p.TraceID + s.staged.Up = true + s.staged.Direct = p.Direct + s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + s.logf("Got SYN.") + } + + // Always respond. + ack := PacketAck{ + TraceID: p.TraceID, + ToAddr: s.staged.DirectAddr, + PossibleAddrs: s.pubAddrs.Get(), + } + s.Send(s.staged, ack) + + if p.Direct { + return nil + } + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) + } + + return nil +} + +func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { + if msg.SrcAddr.IsValid() { + s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + } + return nil +} + +func (s *StateServer) OnPingTimer() PeerState { + if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return nil +} + +// ---------------------------------------------------------------------------- + +type StateClientDirect struct { + *StateDisconnected + lastSeen time.Time + syn PacketSyn +} + +func EnterStateClientDirect(s *State) PeerState { + s.logf("==> ClientDirect") + return NewStateClientDirect(s) +} + +func NewStateClientDirect(s *State) *StateClientDirect { + state := &StateClientDirect{ + StateDisconnected: &StateDisconnected{s}, + lastSeen: time.Now(), // Avoid immediate timeout. + } + + state.syn = PacketSyn{ + TraceID: newTraceID(), + SharedKey: s.staged.DataCipher.Key(), + Direct: s.staged.Direct, + PossibleAddrs: s.pubAddrs.Get(), + } + state.Send(s.staged, state.syn) + return state +} + +func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { + if msg.Packet.TraceID != s.syn.TraceID { + return + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + s.pubAddrs.Store(msg.Packet.ToAddr) +} + +func (s *StateClientDirect) OnPingTimer() PeerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return s.OnPeerUpdate(s.peer) + } + + s.Send(s.staged, s.syn) + return nil +} + +// ---------------------------------------------------------------------------- + +type StateClientRelayed struct { + *StateClientDirect + ack PacketAck + probes map[uint64]netip.AddrPort + localDiscoveryAddr netip.AddrPort +} + +func EnterStateClientRelayed(s *State) PeerState { + s.logf("==> ClientRelayed") + return &StateClientRelayed{ + StateClientDirect: NewStateClientDirect(s), + probes: map[uint64]netip.AddrPort{}, + } +} + +func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { + s.ack = msg.Packet + s.StateClientDirect.OnAck(msg) +} + +func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { + addr, ok := s.probes[msg.Packet.TraceID] + if !ok { + return nil + } + + s.staged.DirectAddr = addr + s.staged.Direct = true + s.publish(s.staged) + return EnterStateClientDirect(s.StateClientDirect.State) +} + +func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { + // The source port will be the multicast port, so we'll have to + // construct the correct address using the peer's listed port. + s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) +} + +func (s *StateClientRelayed) OnPingTimer() PeerState { + if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { + return nextState + } + + clear(s.probes) + for _, addr := range s.ack.PossibleAddrs { + if !addr.IsValid() { + break + } + s.sendProbeTo(addr) + } + + if s.localDiscoveryAddr.IsValid() { + s.sendProbeTo(s.localDiscoveryAddr) + s.localDiscoveryAddr = netip.AddrPort{} + } + + return nil +} + +func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { + probe := PacketProbe{TraceID: newTraceID()} + s.probes[probe.TraceID] = addr + s.SendTo(probe, addr) +} diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go new file mode 100644 index 0000000..16805d0 --- /dev/null +++ b/peer/peerstates_test.go @@ -0,0 +1,509 @@ +package peer + +import ( + "net/netip" + "testing" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +// ---------------------------------------------------------------------------- + +type PeerStateControlMsg struct { + Peer RemotePeer + Packet any +} + +type PeerStateTestHarness struct { + State PeerState + Published RemotePeer + Sent []PeerStateControlMsg +} + +func NewPeerStateTestHarness() *PeerStateTestHarness { + h := &PeerStateTestHarness{} + + keys := generateKeys() + + state := &State{ + publish: func(rp RemotePeer) { + h.Published = rp + }, + sendControlPacket: func(rp RemotePeer, pkt Marshaller) { + h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) + }, + localIP: 2, + remoteIP: 3, + privKey: keys.PrivKey, + pubAddrs: newPubAddrStore(netip.AddrPort{}), + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + + h.State = EnterStateDisconnected(state) + return h +} + +func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { + if s := h.State.OnPeerUpdate(p); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { + if s := h.State.OnSyn(msg); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { + if s := h.State.OnProbe(msg); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) OnPingTimer() { + if s := h.State.OnPingTimer(); s != nil { + h.State = s + } +} + +func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { + keys := generateKeys() + + state := h.State.(*StateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 1, 1, 3}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 2, 3, 4}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateClientDirect](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { + keys := generateKeys() + + state := h.State.(*StateDisconnected) + state.remoteIP = 1 + + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*StateClientRelayed](t, h.State) +} + +// ---------------------------------------------------------------------------- + +func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { + h := NewPeerStateTestHarness() + h.PeerUpdate(nil) + assertType[*StateDisconnected](t, h.State) +} + +func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { + keys := generateKeys() + h := NewPeerStateTestHarness() + + state := h.State.(*StateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + assertType[*StateServer](t, h.State) +} + +func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) +} + +func TestPeerState_OnPeerUpdate_serverRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) +} + +func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) +} + +func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) +} + +func TestStateServer_directSyn(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + assertEqual(t, h.Published.Up, false) + + synMsg := controlMsg[PacketSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + }, + } + + h.State.OnSyn(synMsg) + + assertEqual(t, len(h.Sent), 1) + ack := assertType[PacketAck](t, h.Sent[0].Packet) + assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.IP, 3) + assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) + assertEqual(t, h.Published.Up, true) +} + +func TestStateServer_relayedSyn(t *testing.T) { + h := NewPeerStateTestHarness() + state := h.ConfigServer_Relayed(t) + + state.pubAddrs.Store(addrPort4(4, 5, 6, 7, 1234)) + + assertEqual(t, h.Published.Up, false) + + synMsg := controlMsg[PacketSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: false, + }, + } + synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) + synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) + + h.State.OnSyn(synMsg) + + assertEqual(t, len(h.Sent), 3) + + ack := assertType[PacketAck](t, h.Sent[0].Packet) + assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.IP, 3) + assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) + assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) + assertEqual(t, h.Published.Up, true) + + assertType[PacketProbe](t, h.Sent[1].Packet) + assertType[PacketProbe](t, h.Sent[2].Packet) + assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) + assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) +} + +func TestStateServer_onProbe(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + assertEqual(t, h.Published.Up, false) + + probeMsg := controlMsg[PacketProbe]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketProbe{TraceID: newTraceID()}, + } + + h.State.OnProbe(probeMsg) + + assertEqual(t, len(h.Sent), 1) + + probe := assertType[PacketProbe](t, h.Sent[0].Packet) + assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) +} + +func TestStateServer_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + synMsg := controlMsg[PacketSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: PacketSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + }, + } + + h.State.OnSyn(synMsg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Published.Up, true) + + // Ping shouldn't timeout. + h.OnPingTimer() + assertEqual(t, h.Published.Up, true) + + // Advance the time, then ping. + state := assertType[*StateServer](t, h.State) + state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) + + h.OnPingTimer() + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) +} + +func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID + 1}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnPingTimer(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + assertType[PacketSyn](t, h.Sent[0].Packet) + + h.OnPingTimer() + + // On ping timer, another syn should be sent. Additionally, we should remain + // in the same state. + assertEqual(t, len(h.Sent), 2) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[*StateClientDirect](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) + + state := assertType[*StateClientDirect](t, h.State) + state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) + + h.OnPingTimer() + + // On ping timer, we should timeout, causing the client to reset. Another SYN + // will be sent when re-entering the state, but the connection should be down. + assertEqual(t, len(h.Sent), 2) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[*StateClientDirect](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientRelayed_OnAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) +} + +func TestStateClientRelayed_OnPingTimer_noAddrs(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + + // If we haven't had an ack yet, we won't have addresses to probe. Therefore + // we'll have just one more syn packet sent. + h.OnPingTimer() + assertEqual(t, len(h.Sent), 2) +} + +func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) + ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) + + h.State.OnAck(ack) + + // Add a local discovery address. Note that the port will be configured port + // and no the one provided here. + h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ + SrcIP: 3, + SrcAddr: addrPort4(2, 2, 2, 3, 300), + }) + + // We should see one SYN and three probe packets. + h.OnPingTimer() + assertEqual(t, len(h.Sent), 5) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[PacketProbe](t, h.Sent[2].Packet) + assertType[PacketProbe](t, h.Sent[3].Packet) + assertType[PacketProbe](t, h.Sent[4].Packet) + + assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) + assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) + assertEqual(t, h.Sent[4].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 456)) +} + +func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{ + Packet: PacketAck{TraceID: syn.TraceID}, + } + h.State.OnAck(ack) + assertEqual(t, h.Published.Up, true) + + state := assertType[*StateClientRelayed](t, h.State) + state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) + + h.OnPingTimer() + + // On ping timer, we should timeout, causing the client to reset. Another SYN + // will be sent when re-entering the state, but the connection should be down. + assertEqual(t, len(h.Sent), 2) + assertType[PacketSyn](t, h.Sent[1].Packet) + assertType[*StateClientRelayed](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.OnProbe(controlMsg[PacketProbe]{ + Packet: PacketProbe{TraceID: newTraceID()}, + }) + + assertType[*StateClientRelayed](t, h.State) +} + +func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + syn := assertType[PacketSyn](t, h.Sent[0].Packet) + + ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) + ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) + + h.State.OnAck(ack) + h.OnPingTimer() + + probe := assertType[PacketProbe](t, h.Sent[2].Packet) + h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) + + assertType[*StateClientDirect](t, h.State) +} diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go new file mode 100644 index 0000000..13ab66f --- /dev/null +++ b/peer/pubaddrs.go @@ -0,0 +1,75 @@ +package peer + +import ( + "log" + "net/netip" + "runtime/debug" + "sort" + "time" +) + +type pubAddrStore struct { + localPub bool + localAddr netip.AddrPort + lastSeen map[netip.AddrPort]time.Time + addrList []netip.AddrPort +} + +func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { + return &pubAddrStore{ + localPub: localAddr.IsValid(), + localAddr: localAddr, + lastSeen: map[netip.AddrPort]time.Time{}, + addrList: make([]netip.AddrPort, 0, 32), + } +} + +func (store *pubAddrStore) Store(add netip.AddrPort) { + if store.localPub { + log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) + return + } + + if !add.IsValid() { + return + } + + if _, exists := store.lastSeen[add]; !exists { + store.addrList = append(store.addrList, add) + } + store.lastSeen[add] = time.Now() + store.sort() +} + +func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { + if store.localPub { + addrs[0] = store.localAddr + return + } + + copy(addrs[:], store.addrList) + return +} + +func (store *pubAddrStore) Clean() { + if store.localPub { + return + } + + for ip, lastSeen := range store.lastSeen { + if time.Since(lastSeen) > timeoutInterval { + delete(store.lastSeen, ip) + } + } + store.addrList = store.addrList[:0] + for ip := range store.lastSeen { + store.addrList = append(store.addrList, ip) + } + store.sort() +} + +func (store *pubAddrStore) sort() { + sort.Slice(store.addrList, func(i, j int) bool { + return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) + }) +} diff --git a/peer/pubaddrs_test.go b/peer/pubaddrs_test.go new file mode 100644 index 0000000..b79e854 --- /dev/null +++ b/peer/pubaddrs_test.go @@ -0,0 +1,29 @@ +package peer + +import ( + "net/netip" + "testing" + "time" +) + +func TestPubAddrStore(t *testing.T) { + s := newPubAddrStore(netip.AddrPort{}) + + l := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), + } + + for i := range l { + s.Store(l[i]) + time.Sleep(time.Millisecond) + } + + s.Clean() + + l2 := s.Get() + if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { + t.Fatal(l, l2) + } +} diff --git a/peer/routingtable.go b/peer/routingtable.go new file mode 100644 index 0000000..0943ab2 --- /dev/null +++ b/peer/routingtable.go @@ -0,0 +1,137 @@ +package peer + +import ( + "net/netip" + "sync/atomic" + "time" +) + +// TODO: Remove +func NewRemotePeer(ip byte) *RemotePeer { + counter := uint64(time.Now().Unix()<<30 + 1) + return &RemotePeer{ + IP: ip, + counter: &counter, + dupCheck: newDupCheck(0), + } +} + +// ---------------------------------------------------------------------------- + +type RemotePeer struct { + localIP byte + IP byte // VPN IP of peer (last byte). + Up bool // True if data can be sent on the peer. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + DirectAddr netip.AddrPort // Remote address if directly connected. + PubSignKey []byte + ControlCipher *controlCipher + DataCipher *dataCipher + + counter *uint64 // For sending to. Atomic access only. + dupCheck *dupCheck // For receiving from. Not safe for concurrent use. +} + +func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(p.counter, 1), + SourceIP: p.localIP, + DestIP: destIP, + } + return p.DataCipher.Encrypt(h, data, out) +} + +// Decrypts and de-dups incoming data packets. +func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { + dec, ok := p.DataCipher.Decrypt(enc, out) + if !ok { + return nil, errDecryptionFailed + } + + if p.dupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + return dec, nil +} + +// Peer must have a ControlCipher. +func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(p.counter, 1), + SourceIP: p.localIP, + DestIP: p.IP, + } + tmp = pkt.Marshal(tmp) + return p.ControlCipher.Encrypt(h, tmp, out) +} + +// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. +// +// This function also drops packets with duplicate sequence numbers. +func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { + out, ok := p.ControlCipher.Decrypt(enc, tmp) + if !ok { + return nil, errDecryptionFailed + } + + if p.dupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + msg, err := parseControlMsg(h.SourceIP, fromAddr, out) + if err != nil { + return nil, err + } + + return msg, nil +} + +// ---------------------------------------------------------------------------- + +type RoutingTable struct { + // The LocalIP is the configured IP address of the local peer on the VPN. + // + // This value is constant. + LocalIP byte + + // The LocalAddr is the configured local public address of the peer on the + // internet. If LocalAddr.IsValid(), then the local peer has a public + // address. + // + // This value is constant. + LocalAddr netip.AddrPort + + // The remote peer configurations. These are updated by + Peers [256]RemotePeer + + // The current relay's VPN IP address, or zero if no relay is available. + RelayIP byte +} + +func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { + rt := RoutingTable{ + LocalIP: localIP, + LocalAddr: localAddr, + } + + for i := range rt.Peers { + counter := uint64(time.Now().Unix()<<30 + 1) + rt.Peers[i] = RemotePeer{ + localIP: localIP, + IP: byte(i), + counter: &counter, + dupCheck: newDupCheck(0), + } + } + + return rt +} + +func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { + relay := rt.Peers[rt.RelayIP] + return relay, relay.Up && relay.Direct +} diff --git a/peer/routingtable_test.go b/peer/routingtable_test.go new file mode 100644 index 0000000..b5497a4 --- /dev/null +++ b/peer/routingtable_test.go @@ -0,0 +1,169 @@ +package peer + +import ( + "bytes" + "reflect" + "testing" +) + +func TestRemotePeer_DecryptDataPacket(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + dec, err := peer1.DecryptDataPacket(h, enc, newBuf()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(orig, dec) { + t.Fatal(dec) + } +} + +func TestRemotePeer_DecryptDataPacket_packetAltered(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + + h := parseHeader(enc) + + for range 2048 { + _, err := peer1.DecryptDataPacket(h, ModifyPacket(enc), newBuf()) + if err == nil { + t.Fatal(enc) + } + } +} + +func TestRemotePeer_DecryptDataPacket_duplicateSequenceNumber(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + h := parseHeader(enc) + + if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err != nil { + t.Fatal(err) + } + + if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} + +func TestRemotePeer_DecryptControlPacket(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := PacketProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()) + if err != nil { + t.Fatal(err) + } + + dec, ok := ctrlMsg.(controlMsg[PacketProbe]) + if !ok { + t.Fatal(ctrlMsg) + } + + if dec.SrcIP != 1 || dec.SrcAddr != p1.RT.Load().LocalAddr { + t.Fatal(dec) + } + + if !reflect.DeepEqual(dec.Packet, orig) { + t.Fatal(dec) + } +} + +func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := PacketProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + for range 2048 { + ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, ModifyPacket(enc), newBuf()) + if err == nil { + t.Fatal(ctrlMsg) + } + } +} + +func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := PacketProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err != nil { + t.Fatal(err) + } + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} + +func TestRemotePeer_DecryptControlPacket_unknownPacketType(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := UnknownControlPacket{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} diff --git a/peer/state.go b/peer/state.go deleted file mode 100644 index d6589fe..0000000 --- a/peer/state.go +++ /dev/null @@ -1,29 +0,0 @@ -package peer - -import ( - "net/netip" - "time" -) - -type RemotePeer struct { - IP byte // VPN IP of peer (last byte). - Up bool // True if data can be sent on the peer. - Relay bool // True if the peer is a relay. - Direct bool // True if this is a direct connection. - DirectAddr netip.AddrPort // Remote address if directly connected. - PubSignKey []byte - ControlCipher *controlCipher - DataCipher *dataCipher - - Counter *uint64 // For sending to. Atomic access only. - DupCheck *dupCheck // For receiving from. Not safe for concurrent use. -} - -func NewRemotePeer(ip byte) *RemotePeer { - counter := uint64(time.Now().Unix()<<30 + 1) - return &RemotePeer{ - IP: ip, - Counter: &counter, - DupCheck: newDupCheck(0), - } -} diff --git a/peer/supervisor.go b/peer/supervisor.go new file mode 100644 index 0000000..0f82a3f --- /dev/null +++ b/peer/supervisor.go @@ -0,0 +1,103 @@ +package peer + +import ( + "log" + "sync/atomic" + "time" + + "git.crumpington.com/lib/go/ratelimiter" +) + +// ---------------------------------------------------------------------------- + +type Supervisor struct { + messages chan any // Incoming control messages. + peers [256]PeerState + pubAddrs *pubAddrStore + rt *atomic.Pointer[RoutingTable] + staged RoutingTable +} + +func NewSupervisor( + sendControl func(RemotePeer, Marshaller), + privKey []byte, + rt *atomic.Pointer[RoutingTable], +) *Supervisor { + s := &Supervisor{ + messages: make(chan any, 1024), + pubAddrs: newPubAddrStore(rt.Load().LocalAddr), + rt: rt, + } + + routes := rt.Load() + + for i := range s.peers { + state := &State{ + publish: s.Publish, + sendControlPacket: sendControl, + localIP: routes.LocalIP, + remoteIP: byte(i), + privKey: privKey, + localAddr: routes.LocalAddr, + pubAddrs: s.pubAddrs, + staged: routes.Peers[i], + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + s.peers[i] = state.OnPeerUpdate(nil) + } + + return s +} + +func (s *Supervisor) HandleControlMsg(msg any) { + select { + case s.messages <- msg: + default: + } +} + +func (s *Supervisor) Run() { + for raw := range s.messages { + switch msg := raw.(type) { + + case peerUpdateMsg: + s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) + + case controlMsg[PacketSyn]: + if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { + s.peers[msg.SrcIP] = newState + } + + case controlMsg[PacketAck]: + s.peers[msg.SrcIP].OnAck(msg) + + case controlMsg[PacketProbe]: + if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { + s.peers[msg.SrcIP] = newState + } + + case controlMsg[PacketLocalDiscovery]: + s.peers[msg.SrcIP].OnLocalDiscovery(msg) + + case pingTimerMsg: + s.pubAddrs.Clean() + for i := range s.peers { + if newState := s.peers[i].OnPingTimer(); newState != nil { + s.peers[i] = newState + } + } + + default: + log.Printf("WARNING: unknown message type: %+v", msg) + } + } +} + +func (s *Supervisor) Publish(rp RemotePeer) { + s.staged.Peers[rp.IP] = rp + rt := s.staged // Copy. + s.rt.Store(&rt) +} diff --git a/peer/util_test.go b/peer/util_test.go new file mode 100644 index 0000000..56b9d6f --- /dev/null +++ b/peer/util_test.go @@ -0,0 +1,26 @@ +package peer + +import ( + "net/netip" + "testing" +) + +func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) +} + +func assertType[T any](t *testing.T, obj any) T { + t.Helper() + x, ok := obj.(T) + if !ok { + t.Fatal("invalid type", obj) + } + return x +} + +func assertEqual[T comparable](t *testing.T, a, b T) { + t.Helper() + if a != b { + t.Fatal(a, " != ", b) + } +} -- 2.39.5 From affeb0b9ce9423a035907931f19038e8192100e9 Mon Sep 17 00:00:00 2001 From: jdl Date: Mon, 10 Feb 2025 19:21:36 +0100 Subject: [PATCH 07/26] WIP: Working --- peer/peerstates.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/peer/peerstates.go b/peer/peerstates.go index b05826c..f12ab08 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -59,10 +59,13 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { } s.peer = peer + + s.staged.localIP = s.localIP + s.staged.IP = peer.PeerIP + s.staged.Up = false s.staged.Relay = false s.staged.Direct = false s.staged.DirectAddr = netip.AddrPort{} - s.staged.PubSignKey = nil s.staged.PubSignKey = peer.PubSignKey s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.DataCipher = newDataCipher() @@ -93,7 +96,7 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { func (s *State) logf(format string, args ...any) { b := strings.Builder{} - name := "--" + name := "" if s.peer != nil { name = s.peer.Name } -- 2.39.5 From 08dc79283e2914016943074990cbd7ed5ec2ea39 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 19 Feb 2025 14:13:25 +0100 Subject: [PATCH 08/26] wip --- cmd/vppn/main.go | 4 +- node/main.go | 1 - peer/connreader.go | 8 +- peer/connreader2.go | 39 +++-- peer/connreader_test.go | 353 -------------------------------------- peer/connwriter.go | 80 --------- peer/connwriter2.go | 109 ------------ peer/connwriter2_test.go | 145 ---------------- peer/connwriter_test.go | 240 -------------------------- peer/controlmessage.go | 18 +- peer/crypto.go | 8 +- peer/crypto_test.go | 39 +++-- peer/hubpoller.go | 32 ++-- peer/ifreader.go | 100 ----------- peer/ifreader2.go | 45 +++-- peer/ifreader2_test.go | 6 +- peer/ifreader_test.go | 232 ------------------------- peer/interface.go | 177 +++++++++++++++++++ peer/interfaces.go | 10 +- peer/main.go | 23 +++ peer/mcreader.go | 6 +- peer/mcreader_test.go | 10 +- peer/packets.go | 43 ++--- peer/packets_test.go | 12 +- peer/peer.go | 161 +++++++++++++++++ peer/peer_test.go | 19 +- peer/peerstates.go | 176 +++++++++---------- peer/peerstates_test.go | 150 ++++++++-------- peer/peersuper.go | 172 +++++++++++++++++++ peer/routingtable.go | 29 ++-- peer/routingtable_test.go | 8 +- peer/supervisor.go | 103 ----------- 32 files changed, 873 insertions(+), 1685 deletions(-) delete mode 100644 peer/connreader_test.go delete mode 100644 peer/connwriter.go delete mode 100644 peer/connwriter2.go delete mode 100644 peer/connwriter2_test.go delete mode 100644 peer/connwriter_test.go delete mode 100644 peer/ifreader.go delete mode 100644 peer/ifreader_test.go create mode 100644 peer/interface.go create mode 100644 peer/main.go create mode 100644 peer/peer.go create mode 100644 peer/peersuper.go delete mode 100644 peer/supervisor.go diff --git a/cmd/vppn/main.go b/cmd/vppn/main.go index 8c04016..5daa907 100644 --- a/cmd/vppn/main.go +++ b/cmd/vppn/main.go @@ -2,10 +2,10 @@ package main import ( "log" - "vppn/node" + "vppn/peer" ) func main() { log.SetFlags(0) - node.Main() + peer.Main() } diff --git a/node/main.go b/node/main.go index 8e53cb4..78611a8 100644 --- a/node/main.go +++ b/node/main.go @@ -258,7 +258,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { default: log.Printf("Dropping control packet.") } - } func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { diff --git a/peer/connreader.go b/peer/connreader.go index b127030..37a4c87 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -12,7 +12,7 @@ type connReader struct { sender encryptedPacketSender super controlMsgHandler localIP byte - peers [256]*atomic.Pointer[RemotePeer] + peers [256]*atomic.Pointer[remotePeer] buf []byte decBuf []byte @@ -24,7 +24,7 @@ func newConnReader( sender encryptedPacketSender, super controlMsgHandler, localIP byte, - peers [256]*atomic.Pointer[RemotePeer], + peers [256]*atomic.Pointer[remotePeer], ) *connReader { return &connReader{ conn: conn, @@ -79,7 +79,7 @@ func (r *connReader) handleNextPacket() { } func (r *connReader) handleControlPacket( - peer *RemotePeer, + peer *remotePeer, addr netip.AddrPort, h header, enc []byte, @@ -102,7 +102,7 @@ func (r *connReader) handleControlPacket( r.super.HandleControlMsg(msg) } -func (r *connReader) handleDataPacket(peer *RemotePeer, h header, enc []byte) { +func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { if !peer.Up { r.logf("Not connected (recv).") return diff --git a/peer/connreader2.go b/peer/connreader2.go index d9feab8..9e870d7 100644 --- a/peer/connreader2.go +++ b/peer/connreader2.go @@ -12,12 +12,12 @@ type ConnReader struct { readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) // Output - iface io.Writer - forwardData func(ip byte, pkt []byte) - handleControlMsg func(pkt any) + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + iface io.Writer + handleControlMsg func(fromIP byte, pkt any) localIP byte - rt *atomic.Pointer[RoutingTable] + rt *atomic.Pointer[routingTable] buf []byte decBuf []byte @@ -25,15 +25,15 @@ type ConnReader struct { func NewConnReader( readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), iface io.Writer, - forwardData func(ip byte, pkt []byte), - handleControlMsg func(pkt any), - rt *atomic.Pointer[RoutingTable], + handleControlMsg func(fromIP byte, pkt any), + rt *atomic.Pointer[routingTable], ) *ConnReader { return &ConnReader{ readFromUDPAddrPort: readFromUDPAddrPort, + writeToUDPAddrPort: writeToUDPAddrPort, iface: iface, - forwardData: forwardData, handleControlMsg: handleControlMsg, localIP: rt.Load().LocalIP, rt: rt, @@ -50,7 +50,9 @@ func (r *ConnReader) Run() { func (r *ConnReader) handleNextPacket() { buf := r.buf[:bufferSize] + log.Printf("Getting next packet...") n, remoteAddr, err := r.readFromUDPAddrPort(buf) + log.Printf("Packet from %v...", remoteAddr) if err != nil { log.Fatalf("Failed to read from UDP port: %v", err) } @@ -64,14 +66,14 @@ func (r *ConnReader) handleNextPacket() { buf = buf[:n] h := parseHeader(buf) - peer := r.rt.Load().Peers[h.SourceIP] - //peer := rt.Peers[h.SourceIP] + rt := r.rt.Load() + peer := rt.Peers[h.SourceIP] switch h.StreamID { case controlStreamID: r.handleControlPacket(remoteAddr, peer, h, buf) case dataStreamID: - r.handleDataPacket(peer, h, buf) + r.handleDataPacket(rt, peer, h, buf) default: r.logf("Unknown stream ID: %d", h.StreamID) } @@ -79,7 +81,7 @@ func (r *ConnReader) handleNextPacket() { func (r *ConnReader) handleControlPacket( remoteAddr netip.AddrPort, - peer RemotePeer, + peer remotePeer, h header, enc []byte, ) { @@ -98,11 +100,12 @@ func (r *ConnReader) handleControlPacket( return } - r.handleControlMsg(msg) + r.handleControlMsg(h.SourceIP, msg) } func (r *ConnReader) handleDataPacket( - peer RemotePeer, + rt *routingTable, + peer remotePeer, h header, enc []byte, ) { @@ -124,7 +127,13 @@ func (r *ConnReader) handleDataPacket( return } - r.forwardData(h.DestIP, data) + relay, ok := rt.GetRelay() + if !ok { + r.logf("Relay not available.") + return + } + + r.writeToUDPAddrPort(data, relay.DirectAddr) } func (r *ConnReader) logf(format string, args ...any) { diff --git a/peer/connreader_test.go b/peer/connreader_test.go deleted file mode 100644 index 714f6f3..0000000 --- a/peer/connreader_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package peer - -import ( - "bytes" - "crypto/rand" - "net/netip" - "reflect" - "sync/atomic" - "testing" -) - -type mockIfWriter struct { - Written [][]byte -} - -func (w *mockIfWriter) Write(b []byte) (int, error) { - w.Written = append(w.Written, bytes.Clone(b)) - return len(b), nil -} - -type mockEncryptedPacket struct { - Packet []byte - Route *RemotePeer -} - -type mockEncryptedPacketSender struct { - Sent []mockEncryptedPacket -} - -func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *RemotePeer) { - m.Sent = append(m.Sent, mockEncryptedPacket{ - Packet: bytes.Clone(pkt), - Route: route, - }) -} - -type mockControlMsgHandler struct { - Messages []any -} - -func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { - m.Messages = append(m.Messages, pkt) -} - -type udpPipe struct { - packets chan []byte -} - -func newUDPPipe() *udpPipe { - return &udpPipe{make(chan []byte, 1024)} -} - -func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - p.packets <- bytes.Clone(b) - return len(b), nil -} - -func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { - packet := <-p.packets - copy(b, packet) - return len(packet), netip.AddrPort{}, nil -} - -type connReaderTestHarness struct { - Pipe *udpPipe - R *connReader - WRemote *connWriter - WRelayRemote *connWriter - Remote *RemotePeer - RelayRemote *RemotePeer - IFace *mockIfWriter - Sender *mockEncryptedPacketSender - Super *mockControlMsgHandler -} - -// Peer 2 is indirect, peer 3 is direct. -func newConnReadeTestHarness() (h connReaderTestHarness) { - pipe := newUDPPipe() - routes := [256]*atomic.Pointer[RemotePeer]{} - for i := range routes { - routes[i] = &atomic.Pointer[RemotePeer]{} - routes[i].Store(&RemotePeer{}) - } - - local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() - routes[2].Store(local) - routes[3].Store(relayLocal) - - h.Pipe = pipe - h.WRemote = newConnWriter(pipe, 2) - h.WRelayRemote = newConnWriter(pipe, 3) - - h.Remote = remote - h.RelayRemote = relayRemote - h.IFace = &mockIfWriter{} - h.Sender = &mockEncryptedPacketSender{} - h.Super = &mockControlMsgHandler{} - h.R = newConnReader( - pipe, - h.IFace, - h.Sender, - h.Super, - 1, - routes) - return h -} - -// Testing that we can receive a control packet. -func TestConnReader_handleControlPacket(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - h.WRemote.SendControlPacket(pkt, h.Remote) - - h.R.handleNextPacket() - - if len(h.Super.Messages) != 1 { - t.Fatal(h.Super.Messages) - } - - msg := h.Super.Messages[0].(controlMsg[PacketSyn]) - if !reflect.DeepEqual(pkt, msg.Packet) { - t.Fatal(msg.Packet) - } -} - -// Testing that a short packet is ignored. -func TestConnReader_handleNextPacket_short(t *testing.T) { - h := newConnReadeTestHarness() - - h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) - h.R.handleNextPacket() - - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that a packet with an unexpected stream ID is ignored. -func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) - var header header - header.Parse(encrypted) - header.StreamID = 100 - header.Marshal(encrypted) - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that control packet without matching control cipher is ignored. -func TestConnReader_handleControlPacket_noCipher(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - //encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) - encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) - var header header - header.Parse(encrypted) - header.SourceIP = 10 - header.Marshal(encrypted) - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that control packet with incrrect destination IP is ignored. -func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) - var header header - header.Parse(encrypted) - header.DestIP++ - header.Marshal(encrypted) - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that modified control packet is ignored. -func TestConnReader_handleControlPacket_modified(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) - encrypted[len(encrypted)-1]++ - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -type unknownPacket struct{} - -func (p unknownPacket) Marshal(buf []byte) []byte { - buf = buf[:1] - buf[0] = 100 - return buf -} - -// Testing that an empty control packet is ignored. -func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := unknownPacket{} - - encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that a duplicate control packet is ignored. -func TestConnReader_handleControlPacket_duplicate(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketAck{TraceID: 1234} - - h.WRemote.SendControlPacket(pkt, h.Remote) - *h.Remote.counter = *h.Remote.counter - 1 - h.WRemote.SendControlPacket(pkt, h.Remote) - - h.R.handleNextPacket() - h.R.handleNextPacket() - - if len(h.Super.Messages) != 1 { - t.Fatal(h.Super.Messages) - } - - msg := h.Super.Messages[0].(controlMsg[PacketAck]) - if !reflect.DeepEqual(pkt, msg.Packet) { - t.Fatal(msg.Packet) - } -} - -// Testing that we can receive a data packet. -func TestConnReader_handleDataPacket(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.WRemote.SendDataPacket(pkt, h.Remote) - - h.R.handleNextPacket() - - if len(h.IFace.Written) != 1 { - t.Fatal(h.IFace.Written) - } - - if !bytes.Equal(pkt, h.IFace.Written[0]) { - t.Fatal(h.IFace.Written) - } -} - -// Testing that data packet is ignored if route isn't up. -func TestConnReader_handleDataPacket_routeDown(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.WRemote.SendDataPacket(pkt, h.Remote) - route := h.R.peers[2].Load() - route.Up = false - - h.R.handleNextPacket() - - if len(h.IFace.Written) != 0 { - t.Fatal(h.IFace.Written) - } -} - -// Testing that a duplicate data packet is ignored. -func TestConnReader_handleDataPacket_duplicate(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 123) - - h.WRemote.SendDataPacket(pkt, h.Remote) - *h.Remote.counter = *h.Remote.counter - 1 - h.WRemote.SendDataPacket(pkt, h.Remote) - - h.R.handleNextPacket() - h.R.handleNextPacket() - - if len(h.IFace.Written) != 1 { - t.Fatal(h.IFace.Written) - } - - if !bytes.Equal(pkt, h.IFace.Written[0]) { - t.Fatal(h.IFace.Written) - } -} - -// Testing that we can relay a data packet. -func TestConnReader_handleDataPacket_relay(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.RelayRemote.IP = 3 - h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) - h.R.handleNextPacket() - - if len(h.Sender.Sent) != 1 { - t.Fatal(h.Sender.Sent) - } - -} - -// Testing that we drop a relayed packet if destination is down. -func TestConnReader_handleDataPacket_relayDown(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.RelayRemote.IP = 3 - relay := h.R.peers[3].Load() - relay.Up = false - - h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) - h.R.handleNextPacket() - - if len(h.Sender.Sent) != 0 { - t.Fatal(h.Sender.Sent) - } -} diff --git a/peer/connwriter.go b/peer/connwriter.go deleted file mode 100644 index 8a09e35..0000000 --- a/peer/connwriter.go +++ /dev/null @@ -1,80 +0,0 @@ -package peer - -import ( - "log" - "net/netip" - "sync" -) - -// ---------------------------------------------------------------------------- - -type connWriter struct { - localIP byte - conn udpWriter - - // For sending control packets. - cBuf1 []byte - cBuf2 []byte - - // For sending data packets. - dBuf1 []byte - dBuf2 []byte - - // Lock around for sending on UDP Conn. - wLock sync.Mutex -} - -func newConnWriter(conn udpWriter, localIP byte) *connWriter { - w := &connWriter{ - localIP: localIP, - conn: conn, - cBuf1: make([]byte, bufferSize), - cBuf2: make([]byte, bufferSize), - dBuf1: make([]byte, bufferSize), - dBuf2: make([]byte, bufferSize), - } - return w -} - -// Not safe for concurrent use. Should only be called by supervisor. -func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { - enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) - w.writeTo(enc, peer.DirectAddr) -} - -// Relay control packet. Peer must not be nil. -func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { - enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) - enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) - w.writeTo(enc, relay.DirectAddr) -} - -// Not safe for concurrent use. Should only be called by ifReader. -func (w *connWriter) SendDataPacket(pkt []byte, peer *RemotePeer) { - enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) - w.writeTo(enc, peer.DirectAddr) -} - -// Relay a data packet. Peer must not be nil. -func (w *connWriter) RelayDataPacket(pkt []byte, peer, relay *RemotePeer) { - enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) - enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.dBuf2) - w.writeTo(enc, relay.DirectAddr) -} - -// Safe for concurrent use. Should only be called by connReader. -// -// This function will send pkt to the peer directly. This is used when a peer -// is acting as a relay and is forwarding already encrypted data for another -// peer. -func (w *connWriter) SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) { - w.writeTo(pkt, peer.DirectAddr) -} - -func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { - w.wLock.Lock() - if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { - log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) - } - w.wLock.Unlock() -} diff --git a/peer/connwriter2.go b/peer/connwriter2.go deleted file mode 100644 index e58250d..0000000 --- a/peer/connwriter2.go +++ /dev/null @@ -1,109 +0,0 @@ -package peer - -import ( - "log" - "net/netip" - "sync" - "sync/atomic" -) - -type ConnWriter struct { - wLock sync.Mutex // Lock around for sending on UDP Conn. - - // Output. - writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) - - // Shared state. - rt *atomic.Pointer[RoutingTable] - - // For sending control packets. - cBuf1 []byte - cBuf2 []byte - - // For sending data packets. - dBuf1 []byte - dBuf2 []byte -} - -func NewConnWriter( - writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), - rt *atomic.Pointer[RoutingTable], -) *ConnWriter { - return &ConnWriter{ - writeToUDPAddrPort: writeToUDPAddrPort, - rt: rt, - cBuf1: newBuf(), - cBuf2: newBuf(), - dBuf1: newBuf(), - dBuf2: newBuf(), - } -} - -// Called by ConnReader to forward already encrypted bytes to another peer. -func (w *ConnWriter) Forward(ip byte, pkt []byte) { - peer := w.rt.Load().Peers[ip] - if !(peer.Up && peer.Direct) { - w.logf("Failed to forward to %d.", ip) - return - } - w.writeTo(pkt, peer.DirectAddr) -} - -// Called by IFReader to send data. Encryption will be applied, and packet will -// be relayed if appropriate. -func (w *ConnWriter) WriteData(ip byte, pkt []byte) { - rt := w.rt.Load() - peer := rt.Peers[ip] - if !peer.Up { - w.logf("Failed to send data to %d.", ip) - return - } - - enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) - - if peer.Direct { - w.writeTo(enc, peer.DirectAddr) - return - } - - relay, ok := rt.GetRelay() - if !ok { - w.logf("Failed to send data to %d. No relay.", ip) - return - } - - enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) - w.writeTo(enc, relay.DirectAddr) -} - -// Called by Supervisor to send control packets. -func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { - enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) - - if peer.Direct { - w.writeTo(enc, peer.DirectAddr) - return - } - - rt := w.rt.Load() - relay, ok := rt.GetRelay() - if !ok { - w.logf("Failed to send control to %d. No relay.", peer.IP) - return - } - - enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) - w.writeTo(enc, relay.DirectAddr) -} - -func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { - w.wLock.Lock() - if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { - w.logf("Failed to write to UDP port: %v", err) - } - w.wLock.Unlock() -} - -func (w *ConnWriter) logf(s string, args ...any) { - log.Printf("[ConnWriter] "+s, args...) -} diff --git a/peer/connwriter2_test.go b/peer/connwriter2_test.go deleted file mode 100644 index f0bb00f..0000000 --- a/peer/connwriter2_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package peer - -import ( - "testing" -) - -func TestConnWriter_WriteData_direct(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteData_peerNotUp(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - p1.RT.Load().Peers[2].Up = false - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteData_relay(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().RelayIP = 3 - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p3.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().Peers[3].Up = false - p1.RT.Load().RelayIP = 3 - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p3.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteControl_direct(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - orig := PacketProbe{TraceID: newTraceID()} - - p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteControl_relay(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().RelayIP = 3 - - orig := PacketProbe{TraceID: newTraceID()} - - p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) - - packets := p3.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().Peers[3].Up = false - p1.RT.Load().RelayIP = 3 - - orig := PacketProbe{TraceID: newTraceID()} - - p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) - - packets := p3.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter__Forward(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - in := RandPacket() - p1.ConnWriter.Forward(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter__Forward_notUp(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - p1.RT.Load().Peers[2].Up = false - - in := RandPacket() - p1.ConnWriter.Forward(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter__Forward_notDirect(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - p1.RT.Load().Peers[2].Direct = false - - in := RandPacket() - p1.ConnWriter.Forward(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} diff --git a/peer/connwriter_test.go b/peer/connwriter_test.go deleted file mode 100644 index d8c0365..0000000 --- a/peer/connwriter_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package peer - -import ( - "bytes" - "net/netip" - "testing" -) - -// ---------------------------------------------------------------------------- - -type testUDPPacket struct { - Addr netip.AddrPort - Data []byte -} - -type testUDPAddrPortWriter struct { - written []testUDPPacket -} - -func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - w.written = append(w.written, testUDPPacket{ - Addr: addr, - Data: bytes.Clone(b), - }) - return len(b), nil -} - -func (w *testUDPAddrPortWriter) Written() []testUDPPacket { - out := w.written - w.written = []testUDPPacket{} - return out -} - -// ---------------------------------------------------------------------------- - -type testPacket string - -func (p testPacket) Marshal(b []byte) []byte { - b = b[:len(p)] - copy(b, []byte(p)) - return b -} - -// ---------------------------------------------------------------------------- - -func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *RemotePeer) { - localKeys := generateKeys() - remoteKeys := generateKeys() - - local = NewRemotePeer(2) - local.Up = true - local.Relay = false - local.PubSignKey = remoteKeys.PubSignKey - local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) - local.DataCipher = newDataCipher() - local.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) - - remote = NewRemotePeer(1) - remote.Up = true - remote.Relay = false - remote.PubSignKey = localKeys.PubSignKey - remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) - remote.DataCipher = local.DataCipher - remote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) - - rLocalKeys := generateKeys() - rRemoteKeys := generateKeys() - - relayLocal = NewRemotePeer(3) - relayLocal.Up = true - relayLocal.Relay = true - relayLocal.Direct = true - relayLocal.PubSignKey = rRemoteKeys.PubSignKey - relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) - relayLocal.DataCipher = newDataCipher() - relayLocal.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) - - relayRemote = NewRemotePeer(1) - relayRemote.Up = true - relayRemote.Relay = false - relayRemote.Direct = true - relayRemote.PubSignKey = rLocalKeys.PubSignKey - relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) - relayRemote.DataCipher = relayLocal.DataCipher - relayRemote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) - - return -} - -// ---------------------------------------------------------------------------- - -// Testing if we can send a control packet directly to the remote route. -func TestConnWriter_SendControlPacket_direct(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - route.Direct = true - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.SendControlPacket(in, route) - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - if string(dec) != string(in) { - t.Fatal(dec) - } -} - -// Testing if we can relay a packet via an intermediary. -func TestConnWriter_RelayControlPacket_relay(t *testing.T) { - route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, relay) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != relay.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if string(dec2) != string(in) { - t.Fatal(dec2) - } -} - -// Testing that we can send a data packet directly to a remote route. -func TestConnWriter_SendDataPacket_direct(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - route.Direct = true - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - - in := []byte("hello world!") - w.SendDataPacket(in, route) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(dec, in) { - t.Fatal(dec) - } -} - -// Testing that we can relay a data packet via a relay. -func TestConnWriter_RelayDataPacket_relay(t *testing.T) { - route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.RelayDataPacket(in, route, relay) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != relay.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(dec2, in) { - t.Fatal(dec2) - } -} - -// Testing that we can send an already encrypted packet. -func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.SendEncryptedDataPacket(in, route) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.DirectAddr { - t.Fatal(out[0]) - } - - if !bytes.Equal(out[0].Data, in) { - t.Fatal(out[0]) - } -} diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 7180dd0..09935ab 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -17,25 +17,25 @@ type controlMsg[T any] struct { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { switch buf[0] { - case PacketTypeSyn: - packet, err := ParsePacketSyn(buf) - return controlMsg[PacketSyn]{ + case packetTypeSyn: + packet, err := parsePacketSyn(buf) + return controlMsg[packetSyn]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case PacketTypeAck: - packet, err := ParsePacketAck(buf) - return controlMsg[PacketAck]{ + case packetTypeAck: + packet, err := parsePacketAck(buf) + return controlMsg[packetAck]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case PacketTypeProbe: - packet, err := ParsePacketProbe(buf) - return controlMsg[PacketProbe]{ + case packetTypeProbe: + packet, err := parsePacketProbe(buf) + return controlMsg[packetProbe]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, diff --git a/peer/crypto.go b/peer/crypto.go index dcc042b..e8afe60 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -36,7 +36,7 @@ func generateKeys() cryptoKeys { // Peer must have a ControlCipher. func encryptControlPacket( localIP byte, - peer *RemotePeer, + peer *remotePeer, pkt Marshaller, tmp []byte, out []byte, @@ -55,7 +55,7 @@ func encryptControlPacket( // // This function also drops packets with duplicate sequence numbers. func decryptControlPacket( - peer *RemotePeer, + peer *remotePeer, fromAddr netip.AddrPort, h header, encrypted []byte, @@ -83,7 +83,7 @@ func decryptControlPacket( func encryptDataPacket( localIP byte, destIP byte, - peer *RemotePeer, + peer *remotePeer, data []byte, out []byte, ) []byte { @@ -98,7 +98,7 @@ func encryptDataPacket( // Decrypts and de-dups incoming data packets. func decryptDataPacket( - peer *RemotePeer, + peer *remotePeer, h header, encrypted []byte, out []byte, diff --git a/peer/crypto_test.go b/peer/crypto_test.go index 824a43a..57adfd2 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func newRoutePairForTesting() (*RemotePeer, *RemotePeer) { +func newRoutePairForTesting() (*remotePeer, *remotePeer) { keys1 := generateKeys() keys2 := generateKeys() @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { out = make([]byte, bufferSize) ) - in := PacketSyn{ + in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { t.Fatal(err) } - msg, ok := iMsg.(controlMsg[PacketSyn]) + msg, ok := iMsg.(controlMsg[packetSyn]) if !ok { t.Fatal(ok) } @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { out = make([]byte, bufferSize) ) - in := PacketSyn{ + in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { out = make([]byte, bufferSize) ) - in := PacketSyn{ + in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -109,24 +109,25 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { } } -func TestDecryptControlPacket_invalidPacket(t *testing.T) { - var ( - r1, r2 = newRoutePairForTesting() - tmp = make([]byte, bufferSize) - out = make([]byte, bufferSize) - ) +/* + func TestDecryptControlPacket_invalidPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) - in := testPacket("hello!") + in := testPacket("hello!") - enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h := parseHeader(enc) + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) - _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) - if !errors.Is(err, errUnknownPacketType) { - t.Fatal(err) + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errUnknownPacketType) { + t.Fatal(err) + } } -} - +*/ func TestDecryptDataPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() diff --git a/peer/hubpoller.go b/peer/hubpoller.go index f608bd5..2b50495 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -11,15 +11,21 @@ import ( ) type hubPoller struct { - client *http.Client - req *http.Request - versions [256]int64 - localIP byte - netName string - super controlMsgHandler + client *http.Client + req *http.Request + versions [256]int64 + localIP byte + netName string + handleControlMsg func(fromIP byte, msg any) } -func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { +func newHubPoller( + localIP byte, + netName, + hubURL, + apiKey string, + handleControlMsg func(byte, any), +) (*hubPoller, error) { u, err := url.Parse(hubURL) if err != nil { return nil, err @@ -36,11 +42,11 @@ func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsg req.SetBasicAuth("", apiKey) return &hubPoller{ - client: client, - req: req, - localIP: localIP, - netName: netName, - super: super, + client: client, + req: req, + localIP: localIP, + netName: netName, + handleControlMsg: handleControlMsg, }, nil } @@ -90,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { for i, peer := range state.Peers { if i != int(hp.localIP) { if peer == nil || peer.Version != hp.versions[i] { - hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) + hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) if peer != nil { hp.versions[i] = peer.Version } diff --git a/peer/ifreader.go b/peer/ifreader.go deleted file mode 100644 index 79ff441..0000000 --- a/peer/ifreader.go +++ /dev/null @@ -1,100 +0,0 @@ -package peer - -import ( - "io" - "log" - "sync/atomic" -) - -type ifReader struct { - iface io.Reader - peers [256]*atomic.Pointer[RemotePeer] - relay *atomic.Pointer[RemotePeer] - sender dataPacketSender -} - -func newIFReader( - iface io.Reader, - peers [256]*atomic.Pointer[RemotePeer], - relay *atomic.Pointer[RemotePeer], - sender dataPacketSender, -) *ifReader { - return &ifReader{ - iface: iface, - peers: peers, - relay: relay, - sender: sender, - } -} - -func (r *ifReader) Run() { - var ( - packet = make([]byte, bufferSize) - remoteIP byte - ok bool - ) - - for { - packet = r.readNextPacket(packet) - if remoteIP, ok = r.parsePacket(packet); ok { - r.sendPacket(packet, remoteIP) - } - } -} - -func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { - peer := r.peers[remoteIP].Load() - if !peer.Up { - log.Printf("Peer not connected: %d", remoteIP) - return - } - - // Direct path => early return. - if peer.Direct { - r.sender.SendDataPacket(pkt, peer) - return - } - - if relay := r.relay.Load(); relay != nil && relay.Up { - r.sender.RelayDataPacket(pkt, peer, relay) - } -} - -// Get next packet, returning packet, and destination ip. -func (r *ifReader) readNextPacket(buf []byte) []byte { - n, err := r.iface.Read(buf[:cap(buf)]) - if err != nil { - log.Fatalf("Failed to read from interface: %v", err) - } - - return buf[:n] -} - -func (r *ifReader) parsePacket(buf []byte) (byte, bool) { - n := len(buf) - if n == 0 { - return 0, false - } - - version := buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - log.Printf("Short IPv4 packet: %d", len(buf)) - return 0, false - } - return buf[19], true - - case 6: - if len(buf) < 40 { - log.Printf("Short IPv6 packet: %d", len(buf)) - return 0, false - } - return buf[39], true - - default: - log.Printf("Invalid IP packet version: %v", version) - return 0, false - } -} diff --git a/peer/ifreader2.go b/peer/ifreader2.go index c390e8f..22bd7cf 100644 --- a/peer/ifreader2.go +++ b/peer/ifreader2.go @@ -3,22 +3,24 @@ package peer import ( "io" "log" + "net/netip" + "sync/atomic" ) type IFReader struct { - iface io.Reader - connWriter interface { - WriteData(ip byte, pkt []byte) - } + iface io.Reader + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + rt *atomic.Pointer[routingTable] + buf1 []byte + buf2 []byte } func NewIFReader( iface io.Reader, - connWriter interface { - WriteData(ip byte, pkt []byte) - }, + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[routingTable], ) *IFReader { - return &IFReader{iface, connWriter} + return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} } func (r *IFReader) Run() { @@ -30,9 +32,32 @@ func (r *IFReader) Run() { func (r *IFReader) handleNextPacket(packet []byte) { packet = r.readNextPacket(packet) - if remoteIP, ok := r.parsePacket(packet); ok { - r.connWriter.WriteData(remoteIP, packet) + remoteIP, ok := r.parsePacket(packet) + if !ok { + return } + + rt := r.rt.Load() + peer := rt.Peers[remoteIP] + if !peer.Up { + r.logf("Peer %d not up.", peer.IP) + return + } + + enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1) + if peer.Direct { + r.writeToUDPAddrPort(enc, peer.DirectAddr) + return + } + + relay, ok := rt.GetRelay() + if !ok { + r.logf("Relay not available for peer %d.", peer.IP) + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2) + r.writeToUDPAddrPort(enc, relay.DirectAddr) } func (r *IFReader) readNextPacket(buf []byte) []byte { diff --git a/peer/ifreader2_test.go b/peer/ifreader2_test.go index 779cf49..92ec5ac 100644 --- a/peer/ifreader2_test.go +++ b/peer/ifreader2_test.go @@ -1,9 +1,6 @@ package peer -import ( - "testing" -) - +/* func TestIFReader_IPv4(t *testing.T) { p1, p2, _ := NewPeersForTesting() @@ -81,3 +78,4 @@ func TestIFReader_parsePacket_shortIPv6(t *testing.T) { t.Fatal(ip, ok) } } +*/ diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go deleted file mode 100644 index 620d2b1..0000000 --- a/peer/ifreader_test.go +++ /dev/null @@ -1,232 +0,0 @@ -package peer - -import ( - "bytes" - "reflect" - "sync/atomic" - "testing" -) - -// Test that we parse IPv4 packets correctly. -func TestIFReader_parsePacket_ipv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 1234) - pkt[0] = 4 << 4 - pkt[19] = 128 - - if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { - t.Fatal(ip, ok) - } -} - -// Test that we parse IPv6 packets correctly. -func TestIFReader_parsePacket_ipv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 1234) - pkt[0] = 6 << 4 - pkt[39] = 42 - - if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { - t.Fatal(ip, ok) - } -} - -/* -// Test that empty packets work as expected. -func TestIFReader_parsePacket_emptyPacket(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 0) - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that invalid IP versions fail. -func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - for i := byte(1); i < 16; i++ { - if i == 4 || i == 6 { - continue - } - pkt := make([]byte, 1234) - pkt[0] = i << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(i, ip, ok) - } - } -} - -// Test that short IPv4 packets fail. -func TestIFReader_parsePacket_shortIPv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 19) - pkt[0] = 4 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that short IPv6 packets fail. -func TestIFReader_parsePacket_shortIPv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 39) - pkt[0] = 6 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that we can read a packet. -func TestIFReader_readNextpacket(t *testing.T) { - in, out := net.Pipe() - r := newIFReader(out, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - defer in.Close() - defer out.Close() - - go in.Write([]byte("hello world!")) - - pkt := r.readNextPacket(make([]byte, bufferSize)) - if !bytes.Equal(pkt, []byte("hello world!")) { - t.Fatalf("%s", pkt) - } -} -*/ -// ---------------------------------------------------------------------------- - -type sentPacket struct { - Relayed bool - Packet []byte - Route RemotePeer - Relay RemotePeer -} - -type sendPacketTestHarness struct { - Packets []sentPacket -} - -func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *RemotePeer) { - h.Packets = append(h.Packets, sentPacket{ - Packet: bytes.Clone(pkt), - Route: *route, - }) -} - -func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *RemotePeer) { - h.Packets = append(h.Packets, sentPacket{ - Relayed: true, - Packet: bytes.Clone(pkt), - Route: *route, - Relay: *relay, - }) -} - -func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { - h := &sendPacketTestHarness{} - - routes := [256]*atomic.Pointer[RemotePeer]{} - for i := range routes { - routes[i] = &atomic.Pointer[RemotePeer]{} - routes[i].Store(&RemotePeer{}) - } - relay := &atomic.Pointer[RemotePeer]{} - r := newIFReader(nil, routes, relay, h) - return r, h -} - -// Testing that we can send a packet directly. -func TestIFReader_sendPacket_direct(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Up = true - route.Direct = true - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 1 { - t.Fatal(h.Packets) - } - - expected := sentPacket{ - Relayed: false, - Packet: in, - Route: *route, - } - - if !reflect.DeepEqual(h.Packets[0], expected) { - t.Fatal(h.Packets[0]) - } -} - -// Testing that we don't send a packet if route isn't up. -func TestIFReader_sendPacket_directNotUp(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Direct = true - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 0 { - t.Fatal(h.Packets) - } -} - -// Testing that we can send a packet via a relay. -func TestIFReader_sendPacket_relayed(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Up = true - route.Direct = false - - relay := r.peers[3].Load() - r.relay.Store(relay) - relay.Up = true - relay.Direct = true - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 1 { - t.Fatal(h.Packets) - } - - expected := sentPacket{ - Relayed: true, - Packet: in, - Route: *route, - Relay: *relay, - } - - if !reflect.DeepEqual(h.Packets[0], expected) { - t.Fatal(h.Packets[0]) - } -} - -// Testing that we don't try to send on a nil relay IP. -func TestIFReader_sendPacket_nilRealy(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Up = true - route.Direct = false - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 0 { - t.Fatal(h.Packets) - } -} diff --git a/peer/interface.go b/peer/interface.go new file mode 100644 index 0000000..7035c43 --- /dev/null +++ b/peer/interface.go @@ -0,0 +1,177 @@ +package peer + +import ( + "fmt" + "io" + "log" + "net" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +// Get next packet, returning packet, ip, and possible error. +func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { + var ( + version byte + ip byte + ) + for { + n, err := iface.Read(buf[:cap(buf)]) + if err != nil { + return nil, ip, err + } + + buf = buf[:n] + version = buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + log.Printf("Short IPv4 packet: %d", len(buf)) + continue + } + ip = buf[19] + + case 6: + if len(buf) < 40 { + log.Printf("Short IPv6 packet: %d", len(buf)) + continue + } + ip = buf[39] + + default: + log.Printf("Invalid IP packet version: %v", version) + continue + } + + return buf, ip, nil + } +} + +func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { + if len(network) != 4 { + return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) + } + ip := net.IPv4(network[0], network[1], network[2], localIP) + + ////////////////////////// + // Create TUN Interface // + ////////////////////////// + + tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open TUN device: %w", err) + } + + // New interface request. + req, err := unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) + } + + // Flags: + // + // IFF_NO_PI => don't add packet info data to packets sent to the interface. + // IFF_TUN => create a TUN device handling IP packets. + req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) + + err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) + if err != nil { + return nil, fmt.Errorf("failed to set TUN device settings: %w", err) + } + + // Name may not be exactly the same? + name = req.Name() + + ///////////// + // Set MTU // + ///////////// + + // We need a socket file descriptor to set other options for some reason. + sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return nil, fmt.Errorf("failed to open socket: %w", err) + } + defer unix.Close(sockFD) + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create MTU interface request: %w", err) + } + + req.SetUint32(if_mtu) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { + return nil, fmt.Errorf("failed to set interface MTU: %w", err) + } + + ////////////////////// + // Set Queue Length // + ////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + req.SetUint16(if_queue_len) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { + return nil, fmt.Errorf("failed to set interface queue length: %w", err) + } + + ///////////////////// + // Set IP and Mask // + ///////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + if err := req.SetInet4Addr(ip.To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request IP: %w", err) + } + + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { + return nil, fmt.Errorf("failed to set interface IP: %w", err) + } + + // SET MASK - must happen after setting address. + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create mask interface request: %w", err) + } + + if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request mask: %w", err) + } + + if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { + return nil, fmt.Errorf("failed to set interface mask: %w", err) + } + + //////////////////////// + // Bring Interface Up // + //////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create up interface request: %w", err) + } + + // Get current flags. + if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to get interface flags: %w", err) + } + + flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING + + // Set UP flag / broadcast flags. + req.SetUint16(flags) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to set interface up: %w", err) + } + + return os.NewFile(uintptr(tunFD), "tun"), nil +} diff --git a/peer/interfaces.go b/peer/interfaces.go index 8e99e8b..d6b90d5 100644 --- a/peer/interfaces.go +++ b/peer/interfaces.go @@ -31,17 +31,17 @@ type Marshaller interface { } type dataPacketSender interface { - SendDataPacket(pkt []byte, peer *RemotePeer) - RelayDataPacket(pkt []byte, peer, relay *RemotePeer) + SendDataPacket(pkt []byte, peer *remotePeer) + RelayDataPacket(pkt []byte, peer, relay *remotePeer) } type controlPacketSender interface { - SendControlPacket(pkt Marshaller, peer *RemotePeer) - RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) + SendControlPacket(pkt Marshaller, peer *remotePeer) + RelayControlPacket(pkt Marshaller, peer, relay *remotePeer) } type encryptedPacketSender interface { - SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) + SendEncryptedDataPacket(pkt []byte, peer *remotePeer) } type controlMsgHandler interface { diff --git a/peer/main.go b/peer/main.go new file mode 100644 index 0000000..c1ce110 --- /dev/null +++ b/peer/main.go @@ -0,0 +1,23 @@ +package peer + +import ( + "flag" + "os" +) + +func Main() { + conf := Config{} + + flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") + flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") + flag.StringVar(&conf.APIKey, "api-key", "", "[REQUIRED] The node's API key.") + flag.Parse() + + if conf.NetName == "" || conf.HubAddress == "" || conf.APIKey == "" { + flag.Usage() + os.Exit(1) + } + + peer := New(conf) + peer.Run() +} diff --git a/peer/mcreader.go b/peer/mcreader.go index 38921f1..3410655 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -8,7 +8,7 @@ import ( type mcReader struct { conn udpReader super controlMsgHandler - peers [256]*atomic.Pointer[RemotePeer] + peers [256]*atomic.Pointer[remotePeer] incoming []byte buf []byte @@ -17,7 +17,7 @@ type mcReader struct { func newMCReader( conn udpReader, super controlMsgHandler, - peers [256]*atomic.Pointer[RemotePeer], + peers [256]*atomic.Pointer[remotePeer], ) *mcReader { return &mcReader{conn, super, peers, newBuf(), newBuf()} } @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { return } - r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ + r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{ SrcIP: h.SourceIP, SrcAddr: remoteAddr, }) diff --git a/peer/mcreader_test.go b/peer/mcreader_test.go index 50bf821..60feb44 100644 --- a/peer/mcreader_test.go +++ b/peer/mcreader_test.go @@ -1,13 +1,6 @@ package peer -import ( - "bytes" - "net" - "net/netip" - "sync/atomic" - "testing" -) - +/* type mcMockConn struct { packets chan []byte } @@ -136,3 +129,4 @@ func TestMCReader_badSignature(t *testing.T) { t.Fatal(super.Messages) } } +*/ diff --git a/peer/packets.go b/peer/packets.go index 596483d..b300dee 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -5,41 +5,34 @@ import ( ) const ( - PacketTypeSyn = iota + 1 - PacketTypeSynAck - PacketTypeAck - PacketTypeProbe - PacketTypeAddrDiscovery + packetTypeSyn = 1 + packetTypeAck = 3 + packetTypeProbe = 4 + packetTypeAddrDiscovery = 5 ) // ---------------------------------------------------------------------------- -type PacketSyn struct { - TraceID uint64 // TraceID to match response w/ request. - //SentAt int64 // Unixmilli. - //SharedKeyType byte // Currently only 1 is supported for AES. +type packetSyn struct { + TraceID uint64 // TraceID to match response w/ request. SharedKey [32]byte // Our shared key. Direct bool PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p PacketSyn) Marshal(buf []byte) []byte { +func (p packetSyn) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(PacketTypeSyn). + Byte(packetTypeSyn). Uint64(p.TraceID). - //Int64(p.SentAt). - //Byte(p.SharedKeyType). SharedKey(p.SharedKey). Bool(p.Direct). AddrPort8(p.PossibleAddrs). Build() } -func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { +func parsePacketSyn(buf []byte) (p packetSyn, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). - //Int64(&p.SentAt). - //Byte(&p.SharedKeyType). SharedKey(&p.SharedKey). Bool(&p.Direct). AddrPort8(&p.PossibleAddrs). @@ -49,22 +42,22 @@ func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { // ---------------------------------------------------------------------------- -type PacketAck struct { +type packetAck struct { TraceID uint64 ToAddr netip.AddrPort PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p PacketAck) Marshal(buf []byte) []byte { +func (p packetAck) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(PacketTypeAck). + Byte(packetTypeAck). Uint64(p.TraceID). AddrPort(p.ToAddr). AddrPort8(p.PossibleAddrs). Build() } -func ParsePacketAck(buf []byte) (p PacketAck, err error) { +func parsePacketAck(buf []byte) (p packetAck, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). AddrPort(&p.ToAddr). @@ -77,18 +70,18 @@ func ParsePacketAck(buf []byte) (p PacketAck, err error) { // A probeReqPacket is sent from a client to a server to determine if direct // UDP communication can be used. -type PacketProbe struct { +type packetProbe struct { TraceID uint64 } -func (p PacketProbe) Marshal(buf []byte) []byte { +func (p packetProbe) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(PacketTypeProbe). + Byte(packetTypeProbe). Uint64(p.TraceID). Build() } -func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { +func parsePacketProbe(buf []byte) (p packetProbe, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). Error() @@ -97,4 +90,4 @@ func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { // ---------------------------------------------------------------------------- -type PacketLocalDiscovery struct{} +type packetLocalDiscovery struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go index 3ddc1a0..c18b40a 100644 --- a/peer/packets_test.go +++ b/peer/packets_test.go @@ -8,7 +8,7 @@ import ( ) func TestSynPacket(t *testing.T) { - p := PacketSyn{ + p := packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -21,7 +21,7 @@ func TestSynPacket(t *testing.T) { p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) buf := p.Marshal(newBuf()) - p2, err := ParsePacketSyn(buf) + p2, err := parsePacketSyn(buf) if err != nil { t.Fatal(err) } @@ -31,7 +31,7 @@ func TestSynPacket(t *testing.T) { } func TestAckPacket(t *testing.T) { - p := PacketAck{ + p := packetAck{ TraceID: newTraceID(), ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), } @@ -41,7 +41,7 @@ func TestAckPacket(t *testing.T) { p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) buf := p.Marshal(newBuf()) - p2, err := ParsePacketAck(buf) + p2, err := parsePacketAck(buf) if err != nil { t.Fatal(err) } @@ -51,12 +51,12 @@ func TestAckPacket(t *testing.T) { } func TestProbePacket(t *testing.T) { - p := PacketProbe{ + p := packetProbe{ TraceID: newTraceID(), } buf := p.Marshal(newBuf()) - p2, err := ParsePacketProbe(buf) + p2, err := parsePacketProbe(buf) if err != nil { t.Fatal(err) } diff --git a/peer/peer.go b/peer/peer.go new file mode 100644 index 0000000..6dc925b --- /dev/null +++ b/peer/peer.go @@ -0,0 +1,161 @@ +package peer + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "net/netip" + "net/url" + "sync" + "sync/atomic" + "vppn/m" +) + +type Peer struct { + ifReader *IFReader + connReader *ConnReader + iface io.Writer + hubPoller *hubPoller + super *Super +} + +type Config struct { + NetName string + HubAddress string + APIKey string +} + +func New(conf Config) *Peer { + config, err := loadPeerConfig(conf.NetName) + if err != nil { + log.Printf("Failed to load configuration: %v", err) + log.Printf("Initializing...") + initPeerWithHub(conf) + + config, err = loadPeerConfig(conf.NetName) + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + } + + iface, err := openInterface(config.Network, config.PeerIP, conf.NetName) + if err != nil { + log.Fatalf("Failed to open interface: %v", err) + } + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) + if err != nil { + log.Fatalf("Failed to resolve UDP address: %v", err) + } + + log.Printf("Listening on %v...", myAddr) + conn, err := net.ListenUDP("udp", myAddr) + if err != nil { + log.Fatalf("Failed to open UDP port: %v", err) + } + + conn.SetReadBuffer(1024 * 1024 * 8) + conn.SetWriteBuffer(1024 * 1024 * 8) + + // Wrap write function - this is necessary to avoid starvation. + writeLock := sync.Mutex{} + writeToUDPAddrPort := func(b []byte, addr netip.AddrPort) (n int, err error) { + writeLock.Lock() + n, err = conn.WriteToUDPAddrPort(b, addr) + if err != nil { + log.Printf("Failed to write packet: %v", err) + } + writeLock.Unlock() + return n, err + } + + var localAddr netip.AddrPort + ip, ok := netip.AddrFromSlice(config.PublicIP) + if ok { + localAddr = netip.AddrPortFrom(ip, config.Port) + } + + rt := newRoutingTable(config.PeerIP, localAddr) + rtPtr := &atomic.Pointer[routingTable]{} + rtPtr.Store(&rt) + + ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr) + super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey) + connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) + hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) + if err != nil { + log.Fatalf("Failed to create hub poller: %v", err) + } + + return &Peer{ + iface: iface, + ifReader: ifReader, + connReader: connReader, + hubPoller: hubPoller, + super: super, + } +} + +func (p *Peer) Run() { + go p.ifReader.Run() + go p.connReader.Run() + p.super.Start() + p.hubPoller.Run() +} + +func initPeerWithHub(conf Config) { + keys := generateKeys() + + initURL, err := url.Parse(conf.HubAddress) + if err != nil { + log.Fatalf("Failed to parse hub URL: %v", err) + } + initURL.Path = "/peer/init/" + + args := m.PeerInitArgs{ + EncPubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(args); err != nil { + log.Fatalf("Failed to encode init args: %v", err) + } + + req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) + if err != nil { + log.Fatalf("Failed to construct request: %v", err) + } + req.SetBasicAuth("", conf.APIKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatalf("Failed to init with hub: %v", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("Failed to read response body: %v", err) + } + + peerConfig := localConfig{} + if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { + log.Fatalf("Failed to parse configuration: %v\n%s", err, data) + } + + peerConfig.PubKey = keys.PubKey + peerConfig.PrivKey = keys.PrivKey + peerConfig.PubSignKey = keys.PubSignKey + peerConfig.PrivSignKey = keys.PrivSignKey + + if err := storePeerConfig(conf.NetName, peerConfig); err != nil { + log.Fatalf("Failed to store configuration: %v", err) + } + + log.Print("Initialization successful.") +} diff --git a/peer/peer_test.go b/peer/peer_test.go index 414beaa..863ca8f 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -11,36 +11,25 @@ import ( // A test peer. type P struct { cryptoKeys - RT *atomic.Pointer[RoutingTable] + RT *atomic.Pointer[routingTable] Conn *TestUDPConn IFace *TestIFace - ConnWriter *ConnWriter ConnReader *ConnReader IFReader *IFReader - Super *Supervisor } func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { p := P{ cryptoKeys: generateKeys(), - RT: &atomic.Pointer[RoutingTable]{}, + RT: &atomic.Pointer[routingTable]{}, IFace: NewTestIFace(), } - rt := NewRoutingTable(ip, addr) + rt := newRoutingTable(ip, addr) p.RT.Store(&rt) p.Conn = n.NewUDPConn(addr) - p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) - p.IFReader = NewIFReader(p.IFace, p.ConnWriter) + //p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) - /* - p.ConnReader = NewConnReader( - p.Conn.ReadFromUDPAddrPort, - p.IFace, - p.ConnWriter.Forward, - p.Super.HandleControlMsg, - p.RT) - */ return p } diff --git a/peer/peerstates.go b/peer/peerstates.go index f12ab08..5ded157 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -11,21 +11,21 @@ import ( "git.crumpington.com/lib/go/ratelimiter" ) -type PeerState interface { - OnPeerUpdate(*m.Peer) PeerState - OnSyn(controlMsg[PacketSyn]) PeerState - OnAck(controlMsg[PacketAck]) - OnProbe(controlMsg[PacketProbe]) PeerState - OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) - OnPingTimer() PeerState +type peerState interface { + OnPeerUpdate(*m.Peer) peerState + OnSyn(controlMsg[packetSyn]) peerState + OnAck(controlMsg[packetAck]) + OnProbe(controlMsg[packetProbe]) peerState + OnLocalDiscovery(controlMsg[packetLocalDiscovery]) + OnPingTimer() peerState } // ---------------------------------------------------------------------------- -type State struct { +type pState struct { // Output. - publish func(RemotePeer) - sendControlPacket func(RemotePeer, Marshaller) + publish func(remotePeer) + sendControlPacket func(remotePeer, Marshaller) // Immutable data. localIP byte @@ -37,7 +37,7 @@ type State struct { // The purpose of this state machine is to manage the RemotePeer object, // publishing it as necessary. - staged RemotePeer // Local copy of shared data. See publish(). + staged remotePeer // Local copy of shared data. See publish(). // Mutable peer data. peer *m.Peer @@ -47,25 +47,28 @@ type State struct { limiter *ratelimiter.Limiter } -func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { +func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { defer func() { // Don't defer directly otherwise s.staged will be evaluated immediately // and won't reflect changes made in the function. s.publish(s.staged) }() - if peer == nil { - return EnterStateDisconnected(s) - } - s.peer = peer - s.staged.localIP = s.localIP - s.staged.IP = peer.PeerIP s.staged.Up = false s.staged.Relay = false s.staged.Direct = false s.staged.DirectAddr = netip.AddrPort{} + s.staged.PubSignKey = nil + s.staged.ControlCipher = nil + s.staged.DataCipher = nil + + if peer == nil { + return enterStateDisconnected(s) + } + + s.staged.IP = peer.PeerIP s.staged.PubSignKey = peer.PubSignKey s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.DataCipher = newDataCipher() @@ -76,30 +79,32 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) if s.localAddr.IsValid() && s.localIP < s.remoteIP { - return EnterStateServer(s) + return enterStateServer(s) } - return EnterStateClientDirect(s) + return enterStateClientDirect(s) } if s.localAddr.IsValid() { s.staged.Direct = true - return EnterStateServer(s) + return enterStateServer(s) } if s.localIP < s.remoteIP { - return EnterStateServer(s) + return enterStateServer(s) } - return EnterStateClientRelayed(s) + return enterStateClientRelayed(s) } -func (s *State) logf(format string, args ...any) { +func (s *pState) logf(format string, args ...any) { b := strings.Builder{} name := "" if s.peer != nil { name = s.peer.Name } + b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) + b.WriteString(fmt.Sprintf("%30s: ", name)) if s.staged.Direct { @@ -119,7 +124,7 @@ func (s *State) logf(format string, args ...any) { // ---------------------------------------------------------------------------- -func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { +func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { if !addr.IsValid() { return } @@ -129,7 +134,7 @@ func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { s.Send(route, pkt) } -func (s *State) Send(peer RemotePeer, pkt Marshaller) { +func (s *pState) Send(peer remotePeer, pkt Marshaller) { if err := s.limiter.Limit(); err != nil { s.logf("Rate limited.") return @@ -139,42 +144,32 @@ func (s *State) Send(peer RemotePeer, pkt Marshaller) { // ---------------------------------------------------------------------------- -type StateDisconnected struct{ *State } +type stateDisconnected struct{ *pState } -func EnterStateDisconnected(s *State) PeerState { - s.logf("==> Disconnected") - s.peer = nil - s.staged.Up = false - s.staged.Relay = false - s.staged.Direct = false - s.staged.DirectAddr = netip.AddrPort{} - s.staged.PubSignKey = nil - s.staged.ControlCipher = nil - s.staged.DataCipher = nil - s.publish(s.staged) - return &StateDisconnected{State: s} +func enterStateDisconnected(s *pState) peerState { + return &stateDisconnected{pState: s} } -func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil } -func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {} -func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil } -func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} -func (s *StateDisconnected) OnPingTimer() PeerState { return nil } +func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState { return s } +func (s *stateDisconnected) OnAck(controlMsg[packetAck]) {} +func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s } +func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {} +func (s *stateDisconnected) OnPingTimer() peerState { return s } // ---------------------------------------------------------------------------- -type StateServer struct { - *StateDisconnected +type stateServer struct { + *stateDisconnected lastSeen time.Time synTraceID uint64 } -func EnterStateServer(s *State) PeerState { +func enterStateServer(s *pState) peerState { s.logf("==> Server") - return &StateServer{StateDisconnected: &StateDisconnected{State: s}} + return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} } -func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { +func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -194,7 +189,7 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { } // Always respond. - ack := PacketAck{ + ack := packetAck{ TraceID: p.TraceID, ToAddr: s.staged.DirectAddr, PossibleAddrs: s.pubAddrs.Get(), @@ -202,55 +197,55 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { s.Send(s.staged, ack) if p.Direct { - return nil + return s } for _, addr := range msg.Packet.PossibleAddrs { if !addr.IsValid() { break } - s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) + s.SendTo(packetProbe{TraceID: newTraceID()}, addr) } - return nil + return s } -func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { +func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState { if msg.SrcAddr.IsValid() { - s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) } - return nil + return s } -func (s *StateServer) OnPingTimer() PeerState { +func (s *stateServer) OnPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { s.staged.Up = false s.publish(s.staged) s.logf("Timeout.") } - return nil + return s } // ---------------------------------------------------------------------------- -type StateClientDirect struct { - *StateDisconnected +type stateClientDirect struct { + *stateDisconnected lastSeen time.Time - syn PacketSyn + syn packetSyn } -func EnterStateClientDirect(s *State) PeerState { +func enterStateClientDirect(s *pState) peerState { s.logf("==> ClientDirect") - return NewStateClientDirect(s) + return newStateClientDirect(s) } -func NewStateClientDirect(s *State) *StateClientDirect { - state := &StateClientDirect{ - StateDisconnected: &StateDisconnected{s}, +func newStateClientDirect(s *pState) *stateClientDirect { + state := &stateClientDirect{ + stateDisconnected: &stateDisconnected{s}, lastSeen: time.Now(), // Avoid immediate timeout. } - state.syn = PacketSyn{ + state.syn = packetSyn{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), Direct: s.staged.Direct, @@ -260,7 +255,7 @@ func NewStateClientDirect(s *State) *StateClientDirect { return state } -func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { +func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { if msg.Packet.TraceID != s.syn.TraceID { return } @@ -276,7 +271,14 @@ func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { s.pubAddrs.Store(msg.Packet.ToAddr) } -func (s *StateClientDirect) OnPingTimer() PeerState { +func (s *stateClientDirect) OnPingTimer() peerState { + if next := s.onPingTimer(); next != nil { + return next + } + return s +} + +func (s *stateClientDirect) onPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval { if s.staged.Up { s.staged.Up = false @@ -292,47 +294,47 @@ func (s *StateClientDirect) OnPingTimer() PeerState { // ---------------------------------------------------------------------------- -type StateClientRelayed struct { - *StateClientDirect - ack PacketAck +type stateClientRelayed struct { + *stateClientDirect + ack packetAck probes map[uint64]netip.AddrPort localDiscoveryAddr netip.AddrPort } -func EnterStateClientRelayed(s *State) PeerState { +func enterStateClientRelayed(s *pState) peerState { s.logf("==> ClientRelayed") - return &StateClientRelayed{ - StateClientDirect: NewStateClientDirect(s), + return &stateClientRelayed{ + stateClientDirect: newStateClientDirect(s), probes: map[uint64]netip.AddrPort{}, } } -func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { +func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { s.ack = msg.Packet - s.StateClientDirect.OnAck(msg) + s.stateClientDirect.OnAck(msg) } -func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { +func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { addr, ok := s.probes[msg.Packet.TraceID] if !ok { - return nil + return s } s.staged.DirectAddr = addr s.staged.Direct = true s.publish(s.staged) - return EnterStateClientDirect(s.StateClientDirect.State) + return enterStateClientDirect(s.stateClientDirect.pState) } -func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { +func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) } -func (s *StateClientRelayed) OnPingTimer() PeerState { - if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { - return nextState +func (s *stateClientRelayed) OnPingTimer() peerState { + if next := s.stateClientDirect.onPingTimer(); next != nil { + return next } clear(s.probes) @@ -348,11 +350,11 @@ func (s *StateClientRelayed) OnPingTimer() PeerState { s.localDiscoveryAddr = netip.AddrPort{} } - return nil + return s } -func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { - probe := PacketProbe{TraceID: newTraceID()} +func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { + probe := packetProbe{TraceID: newTraceID()} s.probes[probe.TraceID] = addr s.SendTo(probe, addr) } diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index 16805d0..1ac531d 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -12,13 +12,13 @@ import ( // ---------------------------------------------------------------------------- type PeerStateControlMsg struct { - Peer RemotePeer + Peer remotePeer Packet any } type PeerStateTestHarness struct { - State PeerState - Published RemotePeer + State peerState + Published remotePeer Sent []PeerStateControlMsg } @@ -27,11 +27,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { keys := generateKeys() - state := &State{ - publish: func(rp RemotePeer) { + state := &pState{ + publish: func(rp remotePeer) { h.Published = rp }, - sendControlPacket: func(rp RemotePeer, pkt Marshaller) { + sendControlPacket: func(rp remotePeer, pkt Marshaller) { h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) }, localIP: 2, @@ -44,7 +44,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { }), } - h.State = EnterStateDisconnected(state) + h.State = enterStateDisconnected(state) return h } @@ -54,13 +54,13 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { } } -func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { +func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { if s := h.State.OnSyn(msg); s != nil { h.State = s } } -func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { +func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { if s := h.State.OnProbe(msg); s != nil { h.State = s } @@ -72,10 +72,10 @@ func (h *PeerStateTestHarness) OnPingTimer() { } } -func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { +func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { keys := generateKeys() - state := h.State.(*StateDisconnected) + state := h.State.(*stateDisconnected) state.localAddr = addrPort4(1, 1, 1, 2, 200) peer := &m.Peer{ @@ -88,10 +88,10 @@ func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateServer](t, h.State) + return assertType[*stateServer](t, h.State) } -func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { +func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { keys := generateKeys() peer := &m.Peer{ PeerIP: 3, @@ -102,10 +102,10 @@ func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateServer](t, h.State) + return assertType[*stateServer](t, h.State) } -func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { +func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDirect { keys := generateKeys() peer := &m.Peer{ PeerIP: 3, @@ -117,13 +117,13 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDire h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateClientDirect](t, h.State) + return assertType[*stateClientDirect](t, h.State) } -func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { +func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClientRelayed { keys := generateKeys() - state := h.State.(*StateDisconnected) + state := h.State.(*stateDisconnected) state.remoteIP = 1 peer := &m.Peer{ @@ -135,7 +135,7 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateClientRelayed](t, h.State) + return assertType[*stateClientRelayed](t, h.State) } // ---------------------------------------------------------------------------- @@ -143,14 +143,14 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { h := NewPeerStateTestHarness() h.PeerUpdate(nil) - assertType[*StateDisconnected](t, h.State) + assertType[*stateDisconnected](t, h.State) } func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { keys := generateKeys() h := NewPeerStateTestHarness() - state := h.State.(*StateDisconnected) + state := h.State.(*stateDisconnected) state.localAddr = addrPort4(1, 1, 1, 2, 200) peer := &m.Peer{ @@ -162,7 +162,7 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - assertType[*StateServer](t, h.State) + assertType[*stateServer](t, h.State) } func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { @@ -191,10 +191,10 @@ func TestStateServer_directSyn(t *testing.T) { assertEqual(t, h.Published.Up, false) - synMsg := controlMsg[PacketSyn]{ + synMsg := controlMsg[packetSyn]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketSyn{ + Packet: packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -205,7 +205,7 @@ func TestStateServer_directSyn(t *testing.T) { h.State.OnSyn(synMsg) assertEqual(t, len(h.Sent), 1) - ack := assertType[PacketAck](t, h.Sent[0].Packet) + ack := assertType[packetAck](t, h.Sent[0].Packet) assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) assertEqual(t, h.Sent[0].Peer.IP, 3) assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) @@ -220,10 +220,10 @@ func TestStateServer_relayedSyn(t *testing.T) { assertEqual(t, h.Published.Up, false) - synMsg := controlMsg[PacketSyn]{ + synMsg := controlMsg[packetSyn]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketSyn{ + Packet: packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -237,15 +237,15 @@ func TestStateServer_relayedSyn(t *testing.T) { assertEqual(t, len(h.Sent), 3) - ack := assertType[PacketAck](t, h.Sent[0].Packet) + ack := assertType[packetAck](t, h.Sent[0].Packet) assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) assertEqual(t, h.Sent[0].Peer.IP, 3) assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) assertEqual(t, h.Published.Up, true) - assertType[PacketProbe](t, h.Sent[1].Packet) - assertType[PacketProbe](t, h.Sent[2].Packet) + assertType[packetProbe](t, h.Sent[1].Packet) + assertType[packetProbe](t, h.Sent[2].Packet) assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) } @@ -255,17 +255,17 @@ func TestStateServer_onProbe(t *testing.T) { h.ConfigServer_Relayed(t) assertEqual(t, h.Published.Up, false) - probeMsg := controlMsg[PacketProbe]{ + probeMsg := controlMsg[packetProbe]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketProbe{TraceID: newTraceID()}, + Packet: packetProbe{TraceID: newTraceID()}, } h.State.OnProbe(probeMsg) assertEqual(t, len(h.Sent), 1) - probe := assertType[PacketProbe](t, h.Sent[0].Packet) + probe := assertType[packetProbe](t, h.Sent[0].Packet) assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) } @@ -274,10 +274,10 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigServer_Relayed(t) - synMsg := controlMsg[PacketSyn]{ + synMsg := controlMsg[packetSyn]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketSyn{ + Packet: packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -294,7 +294,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { assertEqual(t, h.Published.Up, true) // Advance the time, then ping. - state := assertType[*StateServer](t, h.State) + state := assertType[*stateServer](t, h.State) state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) h.OnPingTimer() @@ -309,10 +309,10 @@ func TestStateClientDirect_OnAck(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) @@ -326,10 +326,10 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID + 1}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID + 1}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, false) @@ -341,15 +341,15 @@ func TestStateClientDirect_OnPingTimer(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - assertType[PacketSyn](t, h.Sent[0].Packet) + assertType[packetSyn](t, h.Sent[0].Packet) h.OnPingTimer() // On ping timer, another syn should be sent. Additionally, we should remain // in the same state. assertEqual(t, len(h.Sent), 2) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[*StateClientDirect](t, h.State) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientDirect](t, h.State) assertEqual(t, h.Published.Up, false) } @@ -361,15 +361,15 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) - state := assertType[*StateClientDirect](t, h.State) + state := assertType[*stateClientDirect](t, h.State) state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) h.OnPingTimer() @@ -377,8 +377,8 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { // On ping timer, we should timeout, causing the client to reset. Another SYN // will be sent when re-entering the state, but the connection should be down. assertEqual(t, len(h.Sent), 2) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[*StateClientDirect](t, h.State) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientDirect](t, h.State) assertEqual(t, h.Published.Up, false) } @@ -390,10 +390,10 @@ func TestStateClientRelayed_OnAck(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) @@ -423,9 +423,9 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) @@ -433,7 +433,7 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { // Add a local discovery address. Note that the port will be configured port // and no the one provided here. - h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ + h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ SrcIP: 3, SrcAddr: addrPort4(2, 2, 2, 3, 300), }) @@ -441,10 +441,10 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { // We should see one SYN and three probe packets. h.OnPingTimer() assertEqual(t, len(h.Sent), 5) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[PacketProbe](t, h.Sent[2].Packet) - assertType[PacketProbe](t, h.Sent[3].Packet) - assertType[PacketProbe](t, h.Sent[4].Packet) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[packetProbe](t, h.Sent[2].Packet) + assertType[packetProbe](t, h.Sent[3].Packet) + assertType[packetProbe](t, h.Sent[4].Packet) assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) @@ -457,15 +457,15 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) - state := assertType[*StateClientRelayed](t, h.State) + state := assertType[*stateClientRelayed](t, h.State) state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) h.OnPingTimer() @@ -473,8 +473,8 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { // On ping timer, we should timeout, causing the client to reset. Another SYN // will be sent when re-entering the state, but the connection should be down. assertEqual(t, len(h.Sent), 2) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[*StateClientRelayed](t, h.State) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientRelayed](t, h.State) assertEqual(t, h.Published.Up, false) } @@ -482,28 +482,28 @@ func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigClientRelayed(t) - h.OnProbe(controlMsg[PacketProbe]{ - Packet: PacketProbe{TraceID: newTraceID()}, + h.OnProbe(controlMsg[packetProbe]{ + Packet: packetProbe{TraceID: newTraceID()}, }) - assertType[*StateClientRelayed](t, h.State) + assertType[*stateClientRelayed](t, h.State) } func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigClientRelayed(t) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) h.State.OnAck(ack) h.OnPingTimer() - probe := assertType[PacketProbe](t, h.Sent[2].Packet) - h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) + probe := assertType[packetProbe](t, h.Sent[2].Packet) + h.OnProbe(controlMsg[packetProbe]{Packet: probe}) - assertType[*StateClientDirect](t, h.State) + assertType[*stateClientDirect](t, h.State) } diff --git a/peer/peersuper.go b/peer/peersuper.go new file mode 100644 index 0000000..d402141 --- /dev/null +++ b/peer/peersuper.go @@ -0,0 +1,172 @@ +package peer + +import ( + "log" + "math/rand" + "net/netip" + "sync" + "sync/atomic" + "time" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type Super struct { + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + staged routingTable + shared *atomic.Pointer[routingTable] + peers [256]*PeerSuper + lock sync.Mutex + + buf1 []byte + buf2 []byte +} + +func NewSuper( + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[routingTable], + privKey []byte, +) *Super { + + routes := rt.Load() + + s := &Super{ + writeToUDPAddrPort: writeToUDPAddrPort, + staged: *routes, + shared: rt, + buf1: newBuf(), + buf2: newBuf(), + } + + pubAddrs := newPubAddrStore(routes.LocalAddr) + + for i := range s.peers { + state := &pState{ + publish: s.publish, + sendControlPacket: s.send, + localIP: routes.LocalIP, + remoteIP: byte(i), + privKey: privKey, + localAddr: routes.LocalAddr, + pubAddrs: pubAddrs, + staged: routes.Peers[i], + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + s.peers[i] = NewPeerSuper(state) + } + + return s +} + +func (s *Super) Start() { + for i := range s.peers { + go s.peers[i].Run() + } +} + +func (s *Super) HandleControlMsg(destIP byte, msg any) { + s.peers[destIP].HandleControlMsg(msg) +} + +func (s *Super) send(peer remotePeer, pkt Marshaller) { + s.lock.Lock() + defer s.lock.Unlock() + + enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2) + if peer.Direct { + s.writeToUDPAddrPort(enc, peer.DirectAddr) + return + } + + relay, ok := s.staged.GetRelay() + if !ok { + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1) + s.writeToUDPAddrPort(enc, relay.DirectAddr) +} + +func (s *Super) publish(rp remotePeer) { + s.lock.Lock() + defer s.lock.Unlock() + + s.staged.Peers[rp.IP] = rp + s.ensureRelay() + copy := s.staged + s.shared.Store(©) +} + +func (s *Super) ensureRelay() { + if _, ok := s.staged.GetRelay(); ok { + return + } + + // TODO: Random selection? + for _, peer := range s.staged.Peers { + if peer.Up && peer.Direct && peer.Relay { + s.staged.RelayIP = peer.IP + return + } + } +} + +// ---------------------------------------------------------------------------- + +type PeerSuper struct { + messages chan any + state peerState +} + +func NewPeerSuper(state *pState) *PeerSuper { + return &PeerSuper{ + messages: make(chan any, 8), + state: state.OnPeerUpdate(nil), + } +} + +func (s *PeerSuper) HandleControlMsg(msg any) { + select { + case s.messages <- msg: + default: + } +} + +func (s *PeerSuper) Run() { + go func() { + // Randomize ping timers. + time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) + for range time.Tick(pingInterval) { + s.messages <- pingTimerMsg{} + } + }() + + for rawMsg := range s.messages { + switch msg := rawMsg.(type) { + + case peerUpdateMsg: + s.state = s.state.OnPeerUpdate(msg.Peer) + + case controlMsg[packetSyn]: + s.state = s.state.OnSyn(msg) + + case controlMsg[packetAck]: + s.state.OnAck(msg) + + case controlMsg[packetProbe]: + s.state = s.state.OnProbe(msg) + + case controlMsg[packetLocalDiscovery]: + s.state.OnLocalDiscovery(msg) + + case pingTimerMsg: + s.state = s.state.OnPingTimer() + + default: + log.Printf("WARNING: unknown message type: %+v", msg) + } + } +} diff --git a/peer/routingtable.go b/peer/routingtable.go index 0943ab2..7bbf542 100644 --- a/peer/routingtable.go +++ b/peer/routingtable.go @@ -7,9 +7,9 @@ import ( ) // TODO: Remove -func NewRemotePeer(ip byte) *RemotePeer { +func NewRemotePeer(ip byte) *remotePeer { counter := uint64(time.Now().Unix()<<30 + 1) - return &RemotePeer{ + return &remotePeer{ IP: ip, counter: &counter, dupCheck: newDupCheck(0), @@ -18,7 +18,7 @@ func NewRemotePeer(ip byte) *RemotePeer { // ---------------------------------------------------------------------------- -type RemotePeer struct { +type remotePeer struct { localIP byte IP byte // VPN IP of peer (last byte). Up bool // True if data can be sent on the peer. @@ -33,7 +33,7 @@ type RemotePeer struct { dupCheck *dupCheck // For receiving from. Not safe for concurrent use. } -func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { +func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { h := header{ StreamID: dataStreamID, Counter: atomic.AddUint64(p.counter, 1), @@ -44,7 +44,7 @@ func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { } // Decrypts and de-dups incoming data packets. -func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { +func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { dec, ok := p.DataCipher.Decrypt(enc, out) if !ok { return nil, errDecryptionFailed @@ -58,21 +58,22 @@ func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) } // Peer must have a ControlCipher. -func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { +func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { + tmp = pkt.Marshal(tmp) h := header{ StreamID: controlStreamID, Counter: atomic.AddUint64(p.counter, 1), SourceIP: p.localIP, DestIP: p.IP, } - tmp = pkt.Marshal(tmp) + return p.ControlCipher.Encrypt(h, tmp, out) } // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. // // This function also drops packets with duplicate sequence numbers. -func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { +func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { out, ok := p.ControlCipher.Decrypt(enc, tmp) if !ok { return nil, errDecryptionFailed @@ -92,7 +93,7 @@ func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, // ---------------------------------------------------------------------------- -type RoutingTable struct { +type routingTable struct { // The LocalIP is the configured IP address of the local peer on the VPN. // // This value is constant. @@ -106,21 +107,21 @@ type RoutingTable struct { LocalAddr netip.AddrPort // The remote peer configurations. These are updated by - Peers [256]RemotePeer + Peers [256]remotePeer // The current relay's VPN IP address, or zero if no relay is available. RelayIP byte } -func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { - rt := RoutingTable{ +func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable { + rt := routingTable{ LocalIP: localIP, LocalAddr: localAddr, } for i := range rt.Peers { counter := uint64(time.Now().Unix()<<30 + 1) - rt.Peers[i] = RemotePeer{ + rt.Peers[i] = remotePeer{ localIP: localIP, IP: byte(i), counter: &counter, @@ -131,7 +132,7 @@ func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { return rt } -func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { +func (rt *routingTable) GetRelay() (remotePeer, bool) { relay := rt.Peers[rt.RelayIP] return relay, relay.Up && relay.Direct } diff --git a/peer/routingtable_test.go b/peer/routingtable_test.go index b5497a4..919449b 100644 --- a/peer/routingtable_test.go +++ b/peer/routingtable_test.go @@ -74,7 +74,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { peer2 := p1.RT.Load().Peers[2] peer1 := p2.RT.Load().Peers[1] - orig := PacketProbe{TraceID: newTraceID()} + orig := packetProbe{TraceID: newTraceID()} enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) @@ -88,7 +88,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { t.Fatal(err) } - dec, ok := ctrlMsg.(controlMsg[PacketProbe]) + dec, ok := ctrlMsg.(controlMsg[packetProbe]) if !ok { t.Fatal(ctrlMsg) } @@ -108,7 +108,7 @@ func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { peer2 := p1.RT.Load().Peers[2] peer1 := p2.RT.Load().Peers[1] - orig := PacketProbe{TraceID: newTraceID()} + orig := packetProbe{TraceID: newTraceID()} enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) @@ -131,7 +131,7 @@ func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { peer2 := p1.RT.Load().Peers[2] peer1 := p2.RT.Load().Peers[1] - orig := PacketProbe{TraceID: newTraceID()} + orig := packetProbe{TraceID: newTraceID()} enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) diff --git a/peer/supervisor.go b/peer/supervisor.go deleted file mode 100644 index 0f82a3f..0000000 --- a/peer/supervisor.go +++ /dev/null @@ -1,103 +0,0 @@ -package peer - -import ( - "log" - "sync/atomic" - "time" - - "git.crumpington.com/lib/go/ratelimiter" -) - -// ---------------------------------------------------------------------------- - -type Supervisor struct { - messages chan any // Incoming control messages. - peers [256]PeerState - pubAddrs *pubAddrStore - rt *atomic.Pointer[RoutingTable] - staged RoutingTable -} - -func NewSupervisor( - sendControl func(RemotePeer, Marshaller), - privKey []byte, - rt *atomic.Pointer[RoutingTable], -) *Supervisor { - s := &Supervisor{ - messages: make(chan any, 1024), - pubAddrs: newPubAddrStore(rt.Load().LocalAddr), - rt: rt, - } - - routes := rt.Load() - - for i := range s.peers { - state := &State{ - publish: s.Publish, - sendControlPacket: sendControl, - localIP: routes.LocalIP, - remoteIP: byte(i), - privKey: privKey, - localAddr: routes.LocalAddr, - pubAddrs: s.pubAddrs, - staged: routes.Peers[i], - limiter: ratelimiter.New(ratelimiter.Config{ - FillPeriod: 20 * time.Millisecond, - MaxWaitCount: 1, - }), - } - s.peers[i] = state.OnPeerUpdate(nil) - } - - return s -} - -func (s *Supervisor) HandleControlMsg(msg any) { - select { - case s.messages <- msg: - default: - } -} - -func (s *Supervisor) Run() { - for raw := range s.messages { - switch msg := raw.(type) { - - case peerUpdateMsg: - s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) - - case controlMsg[PacketSyn]: - if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { - s.peers[msg.SrcIP] = newState - } - - case controlMsg[PacketAck]: - s.peers[msg.SrcIP].OnAck(msg) - - case controlMsg[PacketProbe]: - if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { - s.peers[msg.SrcIP] = newState - } - - case controlMsg[PacketLocalDiscovery]: - s.peers[msg.SrcIP].OnLocalDiscovery(msg) - - case pingTimerMsg: - s.pubAddrs.Clean() - for i := range s.peers { - if newState := s.peers[i].OnPingTimer(); newState != nil { - s.peers[i] = newState - } - } - - default: - log.Printf("WARNING: unknown message type: %+v", msg) - } - } -} - -func (s *Supervisor) Publish(rp RemotePeer) { - s.staged.Peers[rp.IP] = rp - rt := s.staged // Copy. - s.rt.Store(&rt) -} -- 2.39.5 From e1a5f50e1a59e5aa9a65006b8ab945464172375e Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 19 Feb 2025 14:22:26 +0100 Subject: [PATCH 09/26] wip --- peer/connreader.go | 88 +++++++----- peer/connreader2.go | 141 ------------------- peer/crypto.go | 2 +- peer/crypto_test.go | 4 +- peer/globals.go | 4 + peer/{ifreader2.go => ifreader.go} | 18 +-- peer/{ifreader2_test.go => ifreader_test.go} | 0 peer/interface.go | 40 ------ peer/interfaces.go | 49 ------- peer/main.go | 4 +- peer/mcreader.go | 7 +- peer/mcwriter.go | 6 +- peer/mcwriter_test.go | 8 +- peer/peer.go | 24 ++-- peer/peer_test.go | 4 +- peer/peerstates.go | 6 +- peer/peerstates_test.go | 2 +- peer/peersuper.go | 32 ++--- peer/routingtable.go | 4 +- 19 files changed, 110 insertions(+), 333 deletions(-) delete mode 100644 peer/connreader2.go rename peer/{ifreader2.go => ifreader.go} (81%) rename peer/{ifreader2_test.go => ifreader_test.go} (100%) delete mode 100644 peer/interfaces.go diff --git a/peer/connreader.go b/peer/connreader.go index 37a4c87..a07275e 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -1,40 +1,44 @@ package peer import ( + "io" "log" "net/netip" "sync/atomic" ) type connReader struct { - conn udpReader - iface ifWriter - sender encryptedPacketSender - super controlMsgHandler + // Input + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) + + // Output + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + iface io.Writer + handleControlMsg func(fromIP byte, pkt any) + localIP byte - peers [256]*atomic.Pointer[remotePeer] + rt *atomic.Pointer[routingTable] buf []byte decBuf []byte } func newConnReader( - conn udpReader, - ifWriter ifWriter, - sender encryptedPacketSender, - super controlMsgHandler, - localIP byte, - peers [256]*atomic.Pointer[remotePeer], + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + iface io.Writer, + handleControlMsg func(fromIP byte, pkt any), + rt *atomic.Pointer[routingTable], ) *connReader { return &connReader{ - conn: conn, - iface: ifWriter, - sender: sender, - super: super, - localIP: localIP, - peers: peers, - buf: make([]byte, bufferSize), - decBuf: make([]byte, bufferSize), + readFromUDPAddrPort: readFromUDPAddrPort, + writeToUDPAddrPort: writeToUDPAddrPort, + iface: iface, + handleControlMsg: handleControlMsg, + localIP: rt.Load().LocalIP, + rt: rt, + buf: newBuf(), + decBuf: newBuf(), } } @@ -44,13 +48,11 @@ func (r *connReader) Run() { } } -func (r *connReader) logf(s string, args ...any) { - log.Printf("[ConnReader] "+s, args...) -} - func (r *connReader) handleNextPacket() { buf := r.buf[:bufferSize] - n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf) + log.Printf("Getting next packet...") + n, remoteAddr, err := r.readFromUDPAddrPort(buf) + log.Printf("Packet from %v...", remoteAddr) if err != nil { log.Fatalf("Failed to read from UDP port: %v", err) } @@ -64,23 +66,22 @@ func (r *connReader) handleNextPacket() { buf = buf[:n] h := parseHeader(buf) - peer := r.peers[h.SourceIP].Load() + rt := r.rt.Load() + peer := rt.Peers[h.SourceIP] switch h.StreamID { case controlStreamID: - r.handleControlPacket(peer, remoteAddr, h, buf) - + r.handleControlPacket(remoteAddr, peer, h, buf) case dataStreamID: - r.handleDataPacket(peer, h, buf) - + r.handleDataPacket(rt, peer, h, buf) default: r.logf("Unknown stream ID: %d", h.StreamID) } } func (r *connReader) handleControlPacket( - peer *remotePeer, - addr netip.AddrPort, + remoteAddr netip.AddrPort, + peer remotePeer, h header, enc []byte, ) { @@ -93,22 +94,27 @@ func (r *connReader) handleControlPacket( return } - msg, err := decryptControlPacket(peer, addr, h, enc, r.decBuf) + msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) if err != nil { r.logf("Failed to decrypt control packet: %v", err) return } - r.super.HandleControlMsg(msg) + r.handleControlMsg(h.SourceIP, msg) } -func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { +func (r *connReader) handleDataPacket( + rt *routingTable, + peer remotePeer, + h header, + enc []byte, +) { if !peer.Up { r.logf("Not connected (recv).") return } - data, err := decryptDataPacket(peer, h, enc, r.decBuf) + data, err := peer.DecryptDataPacket(h, enc, r.decBuf) if err != nil { r.logf("Failed to decrypt data packet: %v", err) return @@ -121,11 +127,15 @@ func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { return } - destPeer := r.peers[h.DestIP].Load() - if !destPeer.Up { - r.logf("Not connected (relay): %d", destPeer.IP) + relay, ok := rt.GetRelay() + if !ok { + r.logf("Relay not available.") return } - r.sender.SendEncryptedDataPacket(data, destPeer) + r.writeToUDPAddrPort(data, relay.DirectAddr) +} + +func (r *connReader) logf(format string, args ...any) { + log.Printf("[ConnReader] "+format, args...) } diff --git a/peer/connreader2.go b/peer/connreader2.go deleted file mode 100644 index 9e870d7..0000000 --- a/peer/connreader2.go +++ /dev/null @@ -1,141 +0,0 @@ -package peer - -import ( - "io" - "log" - "net/netip" - "sync/atomic" -) - -type ConnReader struct { - // Input - readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) - - // Output - writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) - iface io.Writer - handleControlMsg func(fromIP byte, pkt any) - - localIP byte - rt *atomic.Pointer[routingTable] - - buf []byte - decBuf []byte -} - -func NewConnReader( - readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), - writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), - iface io.Writer, - handleControlMsg func(fromIP byte, pkt any), - rt *atomic.Pointer[routingTable], -) *ConnReader { - return &ConnReader{ - readFromUDPAddrPort: readFromUDPAddrPort, - writeToUDPAddrPort: writeToUDPAddrPort, - iface: iface, - handleControlMsg: handleControlMsg, - localIP: rt.Load().LocalIP, - rt: rt, - buf: newBuf(), - decBuf: newBuf(), - } -} - -func (r *ConnReader) Run() { - for { - r.handleNextPacket() - } -} - -func (r *ConnReader) handleNextPacket() { - buf := r.buf[:bufferSize] - log.Printf("Getting next packet...") - n, remoteAddr, err := r.readFromUDPAddrPort(buf) - log.Printf("Packet from %v...", remoteAddr) - if err != nil { - log.Fatalf("Failed to read from UDP port: %v", err) - } - - if n < headerSize { - return - } - - remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) - - buf = buf[:n] - h := parseHeader(buf) - - rt := r.rt.Load() - peer := rt.Peers[h.SourceIP] - - switch h.StreamID { - case controlStreamID: - r.handleControlPacket(remoteAddr, peer, h, buf) - case dataStreamID: - r.handleDataPacket(rt, peer, h, buf) - default: - r.logf("Unknown stream ID: %d", h.StreamID) - } -} - -func (r *ConnReader) handleControlPacket( - remoteAddr netip.AddrPort, - peer remotePeer, - h header, - enc []byte, -) { - if peer.ControlCipher == nil { - return - } - - if h.DestIP != r.localIP { - r.logf("Incorrect destination IP on control packet: %d", h.DestIP) - return - } - - msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) - if err != nil { - r.logf("Failed to decrypt control packet: %v", err) - return - } - - r.handleControlMsg(h.SourceIP, msg) -} - -func (r *ConnReader) handleDataPacket( - rt *routingTable, - peer remotePeer, - h header, - enc []byte, -) { - if !peer.Up { - r.logf("Not connected (recv).") - return - } - - data, err := peer.DecryptDataPacket(h, enc, r.decBuf) - if err != nil { - r.logf("Failed to decrypt data packet: %v", err) - return - } - - if h.DestIP == r.localIP { - if _, err := r.iface.Write(data); err != nil { - log.Fatalf("Failed to write to interface: %v", err) - } - return - } - - relay, ok := rt.GetRelay() - if !ok { - r.logf("Relay not available.") - return - } - - r.writeToUDPAddrPort(data, relay.DirectAddr) -} - -func (r *ConnReader) logf(format string, args ...any) { - log.Printf("[ConnReader] "+format, args...) -} diff --git a/peer/crypto.go b/peer/crypto.go index e8afe60..160f7fd 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -37,7 +37,7 @@ func generateKeys() cryptoKeys { func encryptControlPacket( localIP byte, peer *remotePeer, - pkt Marshaller, + pkt marshaller, tmp []byte, out []byte, ) []byte { diff --git a/peer/crypto_test.go b/peer/crypto_test.go index 57adfd2..802653c 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -13,12 +13,12 @@ func newRoutePairForTesting() (*remotePeer, *remotePeer) { keys1 := generateKeys() keys2 := generateKeys() - r1 := NewRemotePeer(1) + r1 := newRemotePeer(1) r1.PubSignKey = keys1.PubSignKey r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) r1.DataCipher = newDataCipher() - r2 := NewRemotePeer(2) + r2 := newRemotePeer(2) r2.PubSignKey = keys2.PubSignKey r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) r2.DataCipher = r1.DataCipher diff --git a/peer/globals.go b/peer/globals.go index 0d7ada3..f967c8a 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -27,3 +27,7 @@ var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( func newBuf() []byte { return make([]byte, bufferSize) } + +type marshaller interface { + Marshal([]byte) []byte +} diff --git a/peer/ifreader2.go b/peer/ifreader.go similarity index 81% rename from peer/ifreader2.go rename to peer/ifreader.go index 22bd7cf..2419758 100644 --- a/peer/ifreader2.go +++ b/peer/ifreader.go @@ -7,7 +7,7 @@ import ( "sync/atomic" ) -type IFReader struct { +type ifReader struct { iface io.Reader writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) rt *atomic.Pointer[routingTable] @@ -15,22 +15,22 @@ type IFReader struct { buf2 []byte } -func NewIFReader( +func newIFReader( iface io.Reader, writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), rt *atomic.Pointer[routingTable], -) *IFReader { - return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} +) *ifReader { + return &ifReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} } -func (r *IFReader) Run() { +func (r *ifReader) Run() { packet := newBuf() for { r.handleNextPacket(packet) } } -func (r *IFReader) handleNextPacket(packet []byte) { +func (r *ifReader) handleNextPacket(packet []byte) { packet = r.readNextPacket(packet) remoteIP, ok := r.parsePacket(packet) if !ok { @@ -60,7 +60,7 @@ func (r *IFReader) handleNextPacket(packet []byte) { r.writeToUDPAddrPort(enc, relay.DirectAddr) } -func (r *IFReader) readNextPacket(buf []byte) []byte { +func (r *ifReader) readNextPacket(buf []byte) []byte { n, err := r.iface.Read(buf[:cap(buf)]) if err != nil { log.Fatalf("Failed to read from interface: %v", err) @@ -69,7 +69,7 @@ func (r *IFReader) readNextPacket(buf []byte) []byte { return buf[:n] } -func (r *IFReader) parsePacket(buf []byte) (byte, bool) { +func (r *ifReader) parsePacket(buf []byte) (byte, bool) { n := len(buf) if n == 0 { return 0, false @@ -98,6 +98,6 @@ func (r *IFReader) parsePacket(buf []byte) (byte, bool) { } } -func (*IFReader) logf(s string, args ...any) { +func (*ifReader) logf(s string, args ...any) { log.Printf("[IFReader] "+s, args...) } diff --git a/peer/ifreader2_test.go b/peer/ifreader_test.go similarity index 100% rename from peer/ifreader2_test.go rename to peer/ifreader_test.go diff --git a/peer/interface.go b/peer/interface.go index 7035c43..0022392 100644 --- a/peer/interface.go +++ b/peer/interface.go @@ -3,7 +3,6 @@ package peer import ( "fmt" "io" - "log" "net" "os" "syscall" @@ -11,45 +10,6 @@ import ( "golang.org/x/sys/unix" ) -// Get next packet, returning packet, ip, and possible error. -func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { - var ( - version byte - ip byte - ) - for { - n, err := iface.Read(buf[:cap(buf)]) - if err != nil { - return nil, ip, err - } - - buf = buf[:n] - version = buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - log.Printf("Short IPv4 packet: %d", len(buf)) - continue - } - ip = buf[19] - - case 6: - if len(buf) < 40 { - log.Printf("Short IPv6 packet: %d", len(buf)) - continue - } - ip = buf[39] - - default: - log.Printf("Invalid IP packet version: %v", version) - continue - } - - return buf, ip, nil - } -} - func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { if len(network) != 4 { return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) diff --git a/peer/interfaces.go b/peer/interfaces.go deleted file mode 100644 index d6b90d5..0000000 --- a/peer/interfaces.go +++ /dev/null @@ -1,49 +0,0 @@ -package peer - -import ( - "io" - "net" - "net/netip" -) - -type UDPConn interface { - ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) - WriteToUDP([]byte, *net.UDPAddr) (int, error) -} - -type ifWriter io.Writer - -type udpReader interface { - ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) -} - -type udpWriter interface { - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) -} - -type mcUDPWriter interface { - WriteToUDP([]byte, *net.UDPAddr) (int, error) -} - -type Marshaller interface { - Marshal([]byte) []byte -} - -type dataPacketSender interface { - SendDataPacket(pkt []byte, peer *remotePeer) - RelayDataPacket(pkt []byte, peer, relay *remotePeer) -} - -type controlPacketSender interface { - SendControlPacket(pkt Marshaller, peer *remotePeer) - RelayControlPacket(pkt Marshaller, peer, relay *remotePeer) -} - -type encryptedPacketSender interface { - SendEncryptedDataPacket(pkt []byte, peer *remotePeer) -} - -type controlMsgHandler interface { - HandleControlMsg(pkt any) -} diff --git a/peer/main.go b/peer/main.go index c1ce110..9ab9ab7 100644 --- a/peer/main.go +++ b/peer/main.go @@ -6,7 +6,7 @@ import ( ) func Main() { - conf := Config{} + conf := peerConfig{} flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") @@ -18,6 +18,6 @@ func Main() { os.Exit(1) } - peer := New(conf) + peer := newPeerMain(conf) peer.Run() } diff --git a/peer/mcreader.go b/peer/mcreader.go index 3410655..a56576e 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -1,10 +1,6 @@ package peer -import ( - "log" - "sync/atomic" -) - +/* type mcReader struct { conn udpReader super controlMsgHandler @@ -55,3 +51,4 @@ func (r *mcReader) handleNextPacket() { SrcAddr: remoteAddr, }) } +*/ diff --git a/peer/mcwriter.go b/peer/mcwriter.go index a8b55e9..c26c2c8 100644 --- a/peer/mcwriter.go +++ b/peer/mcwriter.go @@ -1,8 +1,6 @@ package peer import ( - "log" - "golang.org/x/crypto/nacl/sign" ) @@ -34,7 +32,9 @@ func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { // ---------------------------------------------------------------------------- +/* type mcWriter struct { + conn mcUDPWriter discoveryPacket []byte } @@ -50,4 +50,4 @@ func (w *mcWriter) SendLocalDiscovery() { if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) } -} + }*/ diff --git a/peer/mcwriter_test.go b/peer/mcwriter_test.go index ffef05d..74411f4 100644 --- a/peer/mcwriter_test.go +++ b/peer/mcwriter_test.go @@ -1,11 +1,6 @@ package peer -import ( - "bytes" - "net" - "testing" -) - +/* // ---------------------------------------------------------------------------- // Testing that we can create and verify a local discovery packet. @@ -100,3 +95,4 @@ func TestMCWriter_SendLocalDiscovery(t *testing.T) { t.Fatal("Verification should succeed.") } } +*/ diff --git a/peer/peer.go b/peer/peer.go index 6dc925b..a0afc3b 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -15,21 +15,21 @@ import ( "vppn/m" ) -type Peer struct { - ifReader *IFReader - connReader *ConnReader +type peerMain struct { + ifReader *ifReader + connReader *connReader iface io.Writer hubPoller *hubPoller - super *Super + super *supervisor } -type Config struct { +type peerConfig struct { NetName string HubAddress string APIKey string } -func New(conf Config) *Peer { +func newPeerMain(conf peerConfig) *peerMain { config, err := loadPeerConfig(conf.NetName) if err != nil { log.Printf("Failed to load configuration: %v", err) @@ -83,15 +83,15 @@ func New(conf Config) *Peer { rtPtr := &atomic.Pointer[routingTable]{} rtPtr.Store(&rt) - ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr) - super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey) - connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) + ifReader := newIFReader(iface, writeToUDPAddrPort, rtPtr) + super := newSupervisor(writeToUDPAddrPort, rtPtr, config.PrivKey) + connReader := newConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) if err != nil { log.Fatalf("Failed to create hub poller: %v", err) } - return &Peer{ + return &peerMain{ iface: iface, ifReader: ifReader, connReader: connReader, @@ -100,14 +100,14 @@ func New(conf Config) *Peer { } } -func (p *Peer) Run() { +func (p *peerMain) Run() { go p.ifReader.Run() go p.connReader.Run() p.super.Start() p.hubPoller.Run() } -func initPeerWithHub(conf Config) { +func initPeerWithHub(conf peerConfig) { keys := generateKeys() initURL, err := url.Parse(conf.HubAddress) diff --git a/peer/peer_test.go b/peer/peer_test.go index 863ca8f..2c25812 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -14,8 +14,8 @@ type P struct { RT *atomic.Pointer[routingTable] Conn *TestUDPConn IFace *TestIFace - ConnReader *ConnReader - IFReader *IFReader + ConnReader *connReader + IFReader *ifReader } func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { diff --git a/peer/peerstates.go b/peer/peerstates.go index 5ded157..a68afb1 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -25,7 +25,7 @@ type peerState interface { type pState struct { // Output. publish func(remotePeer) - sendControlPacket func(remotePeer, Marshaller) + sendControlPacket func(remotePeer, marshaller) // Immutable data. localIP byte @@ -124,7 +124,7 @@ func (s *pState) logf(format string, args ...any) { // ---------------------------------------------------------------------------- -func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { +func (s *pState) SendTo(pkt marshaller, addr netip.AddrPort) { if !addr.IsValid() { return } @@ -134,7 +134,7 @@ func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { s.Send(route, pkt) } -func (s *pState) Send(peer remotePeer, pkt Marshaller) { +func (s *pState) Send(peer remotePeer, pkt marshaller) { if err := s.limiter.Limit(); err != nil { s.logf("Rate limited.") return diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index 1ac531d..daf5c14 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -31,7 +31,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { publish: func(rp remotePeer) { h.Published = rp }, - sendControlPacket: func(rp remotePeer, pkt Marshaller) { + sendControlPacket: func(rp remotePeer, pkt marshaller) { h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) }, localIP: 2, diff --git a/peer/peersuper.go b/peer/peersuper.go index d402141..7682d87 100644 --- a/peer/peersuper.go +++ b/peer/peersuper.go @@ -11,26 +11,26 @@ import ( "git.crumpington.com/lib/go/ratelimiter" ) -type Super struct { +type supervisor struct { writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) staged routingTable shared *atomic.Pointer[routingTable] - peers [256]*PeerSuper + peers [256]*peerSuper lock sync.Mutex buf1 []byte buf2 []byte } -func NewSuper( +func newSupervisor( writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), rt *atomic.Pointer[routingTable], privKey []byte, -) *Super { +) *supervisor { routes := rt.Load() - s := &Super{ + s := &supervisor{ writeToUDPAddrPort: writeToUDPAddrPort, staged: *routes, shared: rt, @@ -55,23 +55,23 @@ func NewSuper( MaxWaitCount: 1, }), } - s.peers[i] = NewPeerSuper(state) + s.peers[i] = newPeerSuper(state) } return s } -func (s *Super) Start() { +func (s *supervisor) Start() { for i := range s.peers { go s.peers[i].Run() } } -func (s *Super) HandleControlMsg(destIP byte, msg any) { +func (s *supervisor) HandleControlMsg(destIP byte, msg any) { s.peers[destIP].HandleControlMsg(msg) } -func (s *Super) send(peer remotePeer, pkt Marshaller) { +func (s *supervisor) send(peer remotePeer, pkt marshaller) { s.lock.Lock() defer s.lock.Unlock() @@ -90,7 +90,7 @@ func (s *Super) send(peer remotePeer, pkt Marshaller) { s.writeToUDPAddrPort(enc, relay.DirectAddr) } -func (s *Super) publish(rp remotePeer) { +func (s *supervisor) publish(rp remotePeer) { s.lock.Lock() defer s.lock.Unlock() @@ -100,7 +100,7 @@ func (s *Super) publish(rp remotePeer) { s.shared.Store(©) } -func (s *Super) ensureRelay() { +func (s *supervisor) ensureRelay() { if _, ok := s.staged.GetRelay(); ok { return } @@ -116,26 +116,26 @@ func (s *Super) ensureRelay() { // ---------------------------------------------------------------------------- -type PeerSuper struct { +type peerSuper struct { messages chan any state peerState } -func NewPeerSuper(state *pState) *PeerSuper { - return &PeerSuper{ +func newPeerSuper(state *pState) *peerSuper { + return &peerSuper{ messages: make(chan any, 8), state: state.OnPeerUpdate(nil), } } -func (s *PeerSuper) HandleControlMsg(msg any) { +func (s *peerSuper) HandleControlMsg(msg any) { select { case s.messages <- msg: default: } } -func (s *PeerSuper) Run() { +func (s *peerSuper) Run() { go func() { // Randomize ping timers. time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) diff --git a/peer/routingtable.go b/peer/routingtable.go index 7bbf542..3f0aac3 100644 --- a/peer/routingtable.go +++ b/peer/routingtable.go @@ -7,7 +7,7 @@ import ( ) // TODO: Remove -func NewRemotePeer(ip byte) *remotePeer { +func newRemotePeer(ip byte) *remotePeer { counter := uint64(time.Now().Unix()<<30 + 1) return &remotePeer{ IP: ip, @@ -58,7 +58,7 @@ func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) } // Peer must have a ControlCipher. -func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { +func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte { tmp = pkt.Marshal(tmp) h := header{ StreamID: controlStreamID, -- 2.39.5 From fb8f51ba679e1c16db386d9c1b9a68abd4901882 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 19 Feb 2025 16:08:39 +0100 Subject: [PATCH 10/26] wip --- peer/crypto.go | 86 -------------------------------------- peer/crypto_test.go | 98 ++++++++++++++++++++++---------------------- peer/hubpoller.go | 2 + peer/peerstates.go | 1 + peer/routingtable.go | 3 ++ 5 files changed, 54 insertions(+), 136 deletions(-) diff --git a/peer/crypto.go b/peer/crypto.go index 160f7fd..a533e6d 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -3,8 +3,6 @@ package peer import ( "crypto/rand" "log" - "net/netip" - "sync/atomic" "golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/sign" @@ -30,87 +28,3 @@ func generateKeys() cryptoKeys { return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} } - -// ---------------------------------------------------------------------------- - -// Peer must have a ControlCipher. -func encryptControlPacket( - localIP byte, - peer *remotePeer, - pkt marshaller, - tmp []byte, - out []byte, -) []byte { - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(peer.counter, 1), - SourceIP: localIP, - DestIP: peer.IP, - } - tmp = pkt.Marshal(tmp) - return peer.ControlCipher.Encrypt(h, tmp, out) -} - -// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. -// -// This function also drops packets with duplicate sequence numbers. -func decryptControlPacket( - peer *remotePeer, - fromAddr netip.AddrPort, - h header, - encrypted []byte, - tmp []byte, -) (any, error) { - out, ok := peer.ControlCipher.Decrypt(encrypted, tmp) - if !ok { - return nil, errDecryptionFailed - } - - if peer.dupCheck.IsDup(h.Counter) { - return nil, errDuplicateSeqNum - } - - msg, err := parseControlMsg(h.SourceIP, fromAddr, out) - if err != nil { - return nil, err - } - - return msg, nil -} - -// ---------------------------------------------------------------------------- - -func encryptDataPacket( - localIP byte, - destIP byte, - peer *remotePeer, - data []byte, - out []byte, -) []byte { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(peer.counter, 1), - SourceIP: localIP, - DestIP: destIP, - } - return peer.DataCipher.Encrypt(h, data, out) -} - -// Decrypts and de-dups incoming data packets. -func decryptDataPacket( - peer *remotePeer, - h header, - encrypted []byte, - out []byte, -) ([]byte, error) { - dec, ok := peer.DataCipher.Decrypt(encrypted, out) - if !ok { - return nil, errDecryptionFailed - } - - if peer.dupCheck.IsDup(h.Counter) { - return nil, errDuplicateSeqNum - } - - return dec, nil -} diff --git a/peer/crypto_test.go b/peer/crypto_test.go index 802653c..b3c00f3 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -1,9 +1,6 @@ package peer import ( - "bytes" - "crypto/rand" - "errors" "net/netip" "reflect" "testing" @@ -39,10 +36,10 @@ func TestDecryptControlPacket(t *testing.T) { Direct: true, } - enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + enc := r1.EncryptControlPacket(in, tmp, out) h := parseHeader(enc) - iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + iMsg, err := r2.DecryptControlPacket(netip.AddrPort{}, h, enc, tmp) if err != nil { t.Fatal(err) } @@ -57,59 +54,59 @@ func TestDecryptControlPacket(t *testing.T) { } } -func TestDecryptControlPacket_decryptionFailed(t *testing.T) { - var ( - r1, r2 = newRoutePairForTesting() - tmp = make([]byte, bufferSize) - out = make([]byte, bufferSize) - ) +/* + func TestDecryptControlPacket_decryptionFailed(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) - in := packetSyn{ - TraceID: newTraceID(), - SharedKey: r1.DataCipher.Key(), - Direct: true, - } + in := packetSyn{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } - enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h := parseHeader(enc) + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) - for i := range enc { - x := bytes.Clone(enc) - x[i]++ - _, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) - if !errors.Is(err, errDecryptionFailed) { - t.Fatal(i, err) + for i := range enc { + x := bytes.Clone(enc) + x[i]++ + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(i, err) + } } } -} -func TestDecryptControlPacket_duplicate(t *testing.T) { - var ( - r1, r2 = newRoutePairForTesting() - tmp = make([]byte, bufferSize) - out = make([]byte, bufferSize) - ) + func TestDecryptControlPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) - in := packetSyn{ - TraceID: newTraceID(), - SharedKey: r1.DataCipher.Key(), - Direct: true, + in := packetSyn{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) + + if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { + t.Fatal(err) + } + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } } - enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h := parseHeader(enc) - - if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { - t.Fatal(err) - } - - _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) - if !errors.Is(err, errDuplicateSeqNum) { - t.Fatal(err) - } -} - -/* func TestDecryptControlPacket_invalidPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() @@ -127,7 +124,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { t.Fatal(err) } } -*/ + func TestDecryptDataPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() @@ -191,3 +188,4 @@ func TestDecryptDataPacket_duplicate(t *testing.T) { t.Fatal(err) } } +*/ diff --git a/peer/hubpoller.go b/peer/hubpoller.go index 2b50495..572cb74 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -51,7 +51,9 @@ func newHubPoller( } func (hp *hubPoller) Run() { + log.Printf("Running hub poller...") state, err := loadNetworkState(hp.netName) + log.Printf("Got state (%s) : %v", hp.netName, state) if err != nil { log.Printf("Failed to load network state: %v", err) log.Printf("Polling hub...") diff --git a/peer/peerstates.go b/peer/peerstates.go index a68afb1..6a13f9f 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -70,6 +70,7 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { s.staged.IP = peer.PeerIP s.staged.PubSignKey = peer.PubSignKey + log.Printf("New cipher: %x, %x", s.privKey, peer.PubKey) s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.DataCipher = newDataCipher() diff --git a/peer/routingtable.go b/peer/routingtable.go index 3f0aac3..8caa380 100644 --- a/peer/routingtable.go +++ b/peer/routingtable.go @@ -1,6 +1,7 @@ package peer import ( + "log" "net/netip" "sync/atomic" "time" @@ -67,6 +68,8 @@ func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte DestIP: p.IP, } + log.Printf("Encrypting with header: %#v", h) + return p.ControlCipher.Encrypt(h, tmp, out) } -- 2.39.5 From b797c5b321869c7de9a58068857a8ea2abb5ed2b Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 19 Feb 2025 16:34:03 +0100 Subject: [PATCH 11/26] Working --- peer/connreader.go | 2 - peer/hubpoller.go | 2 - peer/mcreader.go | 96 +++++++++++++++++++++++++------------------ peer/mcwriter.go | 30 +++++++------- peer/peer.go | 6 +++ peer/peerstates.go | 1 - peer/pubaddrs.go | 12 +++++- peer/pubaddrs_test.go | 2 +- peer/routingtable.go | 3 -- 9 files changed, 89 insertions(+), 65 deletions(-) diff --git a/peer/connreader.go b/peer/connreader.go index a07275e..b78e58f 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -50,9 +50,7 @@ func (r *connReader) Run() { func (r *connReader) handleNextPacket() { buf := r.buf[:bufferSize] - log.Printf("Getting next packet...") n, remoteAddr, err := r.readFromUDPAddrPort(buf) - log.Printf("Packet from %v...", remoteAddr) if err != nil { log.Fatalf("Failed to read from UDP port: %v", err) } diff --git a/peer/hubpoller.go b/peer/hubpoller.go index 572cb74..2b50495 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -51,9 +51,7 @@ func newHubPoller( } func (hp *hubPoller) Run() { - log.Printf("Running hub poller...") state, err := loadNetworkState(hp.netName) - log.Printf("Got state (%s) : %v", hp.netName, state) if err != nil { log.Printf("Failed to load network state: %v", err) log.Printf("Polling hub...") diff --git a/peer/mcreader.go b/peer/mcreader.go index a56576e..7c63f26 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -1,54 +1,70 @@ package peer -/* -type mcReader struct { - conn udpReader - super controlMsgHandler - peers [256]*atomic.Pointer[remotePeer] +import ( + "log" + "net" + "sync/atomic" + "time" +) - incoming []byte - buf []byte -} - -func newMCReader( - conn udpReader, - super controlMsgHandler, - peers [256]*atomic.Pointer[remotePeer], -) *mcReader { - return &mcReader{conn, super, peers, newBuf(), newBuf()} -} - -func (r *mcReader) Run() { +func runMCReader( + rt *atomic.Pointer[routingTable], + handleControlMsg func(destIP byte, msg any), +) { for { - r.handleNextPacket() + runMCReader2(rt, handleControlMsg) + time.Sleep(8 * time.Second) } } -func (r *mcReader) handleNextPacket() { - incoming := r.incoming[:bufferSize] - n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(incoming) +func runMCReader2( + rt *atomic.Pointer[routingTable], + handleControlMsg func(destIP byte, msg any), +) { + var ( + raw = newBuf() + buf = newBuf() + logf = func(s string, args ...any) { + log.Printf("[MCReader] "+s, args...) + } + ) + + conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) if err != nil { - log.Fatalf("Failed to read from UDP multicast port: %v", err) - } - incoming = incoming[:n] - - h, ok := headerFromLocalDiscoveryPacket(incoming) - if !ok { + logf("Failed to bind to multicast address: %v", err) return } - peer := r.peers[h.SourceIP].Load() - if peer == nil || peer.PubSignKey == nil { - return - } + for { + conn.SetReadDeadline(time.Now().Add(32 * time.Second)) + n, remoteAddr, err := conn.ReadFromUDPAddrPort(raw[:bufferSize]) + if err != nil { + logf("Failed to read from UDP port): %v", err) + return + } - if !verifyLocalDiscoveryPacket(incoming, r.buf, peer.PubSignKey) { - return - } + raw = raw[:n] + h, ok := headerFromLocalDiscoveryPacket(raw) + if !ok { + logf("Failed to open discovery packet?") + continue + } - r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{ - SrcIP: h.SourceIP, - SrcAddr: remoteAddr, - }) + peer := rt.Load().Peers[h.SourceIP] + if peer.PubSignKey == nil { + logf("No signing key for peer %d.", h.SourceIP) + continue + } + + if !verifyLocalDiscoveryPacket(raw, buf, peer.PubSignKey) { + logf("Invalid signature from peer: %d", h.SourceIP) + continue + } + + msg := controlMsg[packetLocalDiscovery]{ + SrcIP: h.SourceIP, + SrcAddr: remoteAddr, + } + handleControlMsg(h.SourceIP, msg) + } } -*/ diff --git a/peer/mcwriter.go b/peer/mcwriter.go index c26c2c8..5559547 100644 --- a/peer/mcwriter.go +++ b/peer/mcwriter.go @@ -1,6 +1,10 @@ package peer import ( + "log" + "net" + "time" + "golang.org/x/crypto/nacl/sign" ) @@ -32,22 +36,18 @@ func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { // ---------------------------------------------------------------------------- -/* -type mcWriter struct { +func runMCWriter(localIP byte, signingKey []byte) { + discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey) - conn mcUDPWriter - discoveryPacket []byte -} + conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) + if err != nil { + log.Fatalf("Failed to bind to multicast address: %v", err) + } -func newMCWriter(conn mcUDPWriter, localIP byte, signingKey []byte) *mcWriter { - return &mcWriter{ - conn: conn, - discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), + for range time.Tick(16 * time.Second) { + _, err := conn.WriteToUDP(discoveryPacket, multicastAddr) + if err != nil { + log.Printf("[MCWriter] Failed to write multicast: %v", err) + } } } - -func (w *mcWriter) SendLocalDiscovery() { - if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { - log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) - } - }*/ diff --git a/peer/peer.go b/peer/peer.go index a0afc3b..45627b0 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -16,6 +16,8 @@ import ( ) type peerMain struct { + conf localConfig + rt *atomic.Pointer[routingTable] ifReader *ifReader connReader *connReader iface io.Writer @@ -92,6 +94,8 @@ func newPeerMain(conf peerConfig) *peerMain { } return &peerMain{ + conf: config, + rt: rtPtr, iface: iface, ifReader: ifReader, connReader: connReader, @@ -104,6 +108,8 @@ func (p *peerMain) Run() { go p.ifReader.Run() go p.connReader.Run() p.super.Start() + go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) + go runMCReader(p.rt, p.super.HandleControlMsg) p.hubPoller.Run() } diff --git a/peer/peerstates.go b/peer/peerstates.go index 6a13f9f..a68afb1 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -70,7 +70,6 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { s.staged.IP = peer.PeerIP s.staged.PubSignKey = peer.PubSignKey - log.Printf("New cipher: %x, %x", s.privKey, peer.PubKey) s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.DataCipher = newDataCipher() diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go index 13ab66f..027057a 100644 --- a/peer/pubaddrs.go +++ b/peer/pubaddrs.go @@ -5,10 +5,12 @@ import ( "net/netip" "runtime/debug" "sort" + "sync" "time" ) type pubAddrStore struct { + lock sync.Mutex localPub bool localAddr netip.AddrPort lastSeen map[netip.AddrPort]time.Time @@ -25,6 +27,9 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { } func (store *pubAddrStore) Store(add netip.AddrPort) { + store.lock.Lock() + defer store.lock.Unlock() + if store.localPub { log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) return @@ -42,6 +47,11 @@ func (store *pubAddrStore) Store(add netip.AddrPort) { } func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { + store.lock.Lock() + defer store.lock.Unlock() + + store.clean() + if store.localPub { addrs[0] = store.localAddr return @@ -51,7 +61,7 @@ func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { return } -func (store *pubAddrStore) Clean() { +func (store *pubAddrStore) clean() { if store.localPub { return } diff --git a/peer/pubaddrs_test.go b/peer/pubaddrs_test.go index b79e854..fa47c22 100644 --- a/peer/pubaddrs_test.go +++ b/peer/pubaddrs_test.go @@ -20,7 +20,7 @@ func TestPubAddrStore(t *testing.T) { time.Sleep(time.Millisecond) } - s.Clean() + s.clean() l2 := s.Get() if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { diff --git a/peer/routingtable.go b/peer/routingtable.go index 8caa380..3f0aac3 100644 --- a/peer/routingtable.go +++ b/peer/routingtable.go @@ -1,7 +1,6 @@ package peer import ( - "log" "net/netip" "sync/atomic" "time" @@ -68,8 +67,6 @@ func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte DestIP: p.IP, } - log.Printf("Encrypting with header: %#v", h) - return p.ControlCipher.Encrypt(h, tmp, out) } -- 2.39.5 From 589aa08866778cd1680acc72193fa56cd7eb7603 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 19 Feb 2025 16:34:13 +0100 Subject: [PATCH 12/26] working --- node/README.md | 16 -- node/addrdiscovery.go | 71 ------ node/addrdiscovery_test.go | 29 --- node/bitset.go | 21 -- node/bitset_test.go | 48 ---- node/cipher-control.go | 26 --- node/cipher-control_test.go | 122 ----------- node/cipher-data.go | 60 ----- node/cipher-data_test.go | 141 ------------ node/cipher-discovery.go | 13 -- node/config.go | 11 - node/conn.go | 3 - node/connwriter.go | 146 ------------- node/connwriter_test.go | 248 --------------------- node/crypto.go | 30 --- node/data-flow.dot | 14 -- node/dupcheck.go | 76 ------- node/dupcheck_test.go | 54 ----- node/files.go | 82 ------- node/globalfuncs.go | 8 - node/globals.go | 63 ------ node/header.go | 49 ----- node/header_test.go | 21 -- node/hubpoller.go | 92 -------- node/ifreader.go | 102 --------- node/ifreader_test.go | 117 ---------- node/ifwriter.go | 5 - node/interface.go | 177 --------------- node/localdiscovery.go | 97 --------- node/localdiscovery_test.go | 35 --- node/main.go | 320 --------------------------- node/main_test.go | 37 ---- node/mcwriter.go | 62 ------ node/mcwriter_test.go | 102 --------- node/messages.go | 58 ----- node/packets-util.go | 190 ---------------- node/packets-util_test.go | 56 ----- node/packets.go | 130 ----------- node/packets_test.go | 1 - node/packetsender.go | 127 ----------- node/relaymanager.go | 41 ---- node/shared.go | 59 ----- node/shared_test.go | 16 -- node/supervisor.go | 421 ------------------------------------ 44 files changed, 3597 deletions(-) delete mode 100644 node/README.md delete mode 100644 node/addrdiscovery.go delete mode 100644 node/addrdiscovery_test.go delete mode 100644 node/bitset.go delete mode 100644 node/bitset_test.go delete mode 100644 node/cipher-control.go delete mode 100644 node/cipher-control_test.go delete mode 100644 node/cipher-data.go delete mode 100644 node/cipher-data_test.go delete mode 100644 node/cipher-discovery.go delete mode 100644 node/config.go delete mode 100644 node/conn.go delete mode 100644 node/connwriter.go delete mode 100644 node/connwriter_test.go delete mode 100644 node/crypto.go delete mode 100644 node/data-flow.dot delete mode 100644 node/dupcheck.go delete mode 100644 node/dupcheck_test.go delete mode 100644 node/files.go delete mode 100644 node/globalfuncs.go delete mode 100644 node/globals.go delete mode 100644 node/header.go delete mode 100644 node/header_test.go delete mode 100644 node/hubpoller.go delete mode 100644 node/ifreader.go delete mode 100644 node/ifreader_test.go delete mode 100644 node/ifwriter.go delete mode 100644 node/interface.go delete mode 100644 node/localdiscovery.go delete mode 100644 node/localdiscovery_test.go delete mode 100644 node/main.go delete mode 100644 node/main_test.go delete mode 100644 node/mcwriter.go delete mode 100644 node/mcwriter_test.go delete mode 100644 node/messages.go delete mode 100644 node/packets-util.go delete mode 100644 node/packets-util_test.go delete mode 100644 node/packets.go delete mode 100644 node/packets_test.go delete mode 100644 node/packetsender.go delete mode 100644 node/relaymanager.go delete mode 100644 node/shared.go delete mode 100644 node/shared_test.go delete mode 100644 node/supervisor.go diff --git a/node/README.md b/node/README.md deleted file mode 100644 index 58b4298..0000000 --- a/node/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# VPPN Peer Code - -## Refactoring for Testability - -* [x] connWriter -* [x] mcWriter -* [x] ifWriter -* [ ] ifReader (testing) -* [ ] connReader -* [ ] mcReader -* [ ] hubPoller -* [ ] supervisor - -## Updates - -* [ ] Send timing info w/ syn/ack packets diff --git a/node/addrdiscovery.go b/node/addrdiscovery.go deleted file mode 100644 index 160c7a0..0000000 --- a/node/addrdiscovery.go +++ /dev/null @@ -1,71 +0,0 @@ -package node - -import ( - "log" - "net/netip" - "runtime/debug" - "sort" - "time" -) - -type pubAddrStore struct { - lastSeen map[netip.AddrPort]time.Time - addrList []netip.AddrPort -} - -func newPubAddrStore() *pubAddrStore { - return &pubAddrStore{ - lastSeen: map[netip.AddrPort]time.Time{}, - addrList: make([]netip.AddrPort, 0, 32), - } -} - -func (store *pubAddrStore) Store(add netip.AddrPort) { - if localPub { - log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) - return - } - - if !add.IsValid() { - return - } - - if _, exists := store.lastSeen[add]; !exists { - store.addrList = append(store.addrList, add) - } - store.lastSeen[add] = time.Now() - store.sort() -} - -func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { - if localPub { - addrs[0] = localAddr - return - } - - copy(addrs[:], store.addrList) - return -} - -func (store *pubAddrStore) Clean() { - if localPub { - return - } - - for ip, lastSeen := range store.lastSeen { - if time.Since(lastSeen) > timeoutInterval { - delete(store.lastSeen, ip) - } - } - store.addrList = store.addrList[:0] - for ip := range store.lastSeen { - store.addrList = append(store.addrList, ip) - } - store.sort() -} - -func (store *pubAddrStore) sort() { - sort.Slice(store.addrList, func(i, j int) bool { - return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) - }) -} diff --git a/node/addrdiscovery_test.go b/node/addrdiscovery_test.go deleted file mode 100644 index 9851d6a..0000000 --- a/node/addrdiscovery_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package node - -import ( - "net/netip" - "testing" - "time" -) - -func TestPubAddrStore(t *testing.T) { - s := newPubAddrStore() - - l := []netip.AddrPort{ - netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), - netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), - netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), - } - - for i := range l { - s.Store(l[i]) - time.Sleep(time.Millisecond) - } - - s.Clean() - - l2 := s.Get() - if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { - t.Fatal(l, l2) - } -} diff --git a/node/bitset.go b/node/bitset.go deleted file mode 100644 index a9024cb..0000000 --- a/node/bitset.go +++ /dev/null @@ -1,21 +0,0 @@ -package node - -const bitSetSize = 512 // Multiple of 64. - -type bitSet [bitSetSize / 64]uint64 - -func (bs *bitSet) Set(i int) { - bs[i/64] |= 1 << (i % 64) -} - -func (bs *bitSet) Clear(i int) { - bs[i/64] &= ^(1 << (i % 64)) -} - -func (bs *bitSet) ClearAll() { - clear(bs[:]) -} - -func (bs *bitSet) Get(i int) bool { - return bs[i/64]&(1<<(i%64)) != 0 -} diff --git a/node/bitset_test.go b/node/bitset_test.go deleted file mode 100644 index bd3307a..0000000 --- a/node/bitset_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package node - -import ( - "math/rand" - "testing" -) - -func TestBitSet(t *testing.T) { - state := make([]bool, bitSetSize) - for i := range state { - state[i] = rand.Float32() > 0.5 - } - - bs := bitSet{} - - for i := range state { - if state[i] { - bs.Set(i) - } - } - - for i := range state { - if bs.Get(i) != state[i] { - t.Fatal(i, state[i], bs.Get(i)) - } - } - - for i := range state { - if rand.Float32() > 0.5 { - state[i] = false - bs.Clear(i) - } - } - - for i := range state { - if bs.Get(i) != state[i] { - t.Fatal(i, state[i], bs.Get(i)) - } - } - - bs.ClearAll() - - for i := range state { - if bs.Get(i) { - t.Fatal(i, bs.Get(i)) - } - } -} diff --git a/node/cipher-control.go b/node/cipher-control.go deleted file mode 100644 index bd11470..0000000 --- a/node/cipher-control.go +++ /dev/null @@ -1,26 +0,0 @@ -package node - -import "golang.org/x/crypto/nacl/box" - -type controlCipher struct { - sharedKey [32]byte -} - -func newControlCipher(privKey, pubKey []byte) *controlCipher { - shared := [32]byte{} - box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) - return &controlCipher{shared} -} - -func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte { - const s = controlHeaderSize - out = out[:s+controlCipherOverhead+len(data)] - h.Marshal(out[:s]) - box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) - return out -} - -func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { - const s = controlHeaderSize - return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) -} diff --git a/node/cipher-control_test.go b/node/cipher-control_test.go deleted file mode 100644 index ab28860..0000000 --- a/node/cipher-control_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "reflect" - "testing" - - "golang.org/x/crypto/nacl/box" -) - -func newControlCipherForTesting() (c1, c2 *controlCipher) { - pubKey1, privKey1, err := box.GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - - pubKey2, privKey2, err := box.GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - - return newControlCipher(privKey1[:], pubKey2[:]), - newControlCipher(privKey2[:], pubKey1[:]) -} - -func TestControlCipher(t *testing.T) { - c1, c2 := newControlCipherForTesting() - - maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) - rand.Read(maxSizePlaintext) - - testCases := [][]byte{ - make([]byte, 0), - {1}, - {255}, - {1, 2, 3, 4, 5}, - []byte("Hello world"), - maxSizePlaintext, - } - - for _, plaintext := range testCases { - h1 := header{ - StreamID: controlStreamID, - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - encrypted := make([]byte, bufferSize) - - encrypted = c1.Encrypt(h1, plaintext, encrypted) - - h2 := header{} - h2.Parse(encrypted) - if !reflect.DeepEqual(h1, h2) { - t.Fatal(h1, h2) - } - - decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(decrypted, plaintext) { - t.Fatal("not equal") - } - } -} - -func TestControlCipher_ShortCiphertext(t *testing.T) { - c1, _ := newControlCipherForTesting() - shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) - rand.Read(shortText) - _, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) - if ok { - t.Fatal(ok) - } -} - -func BenchmarkControlCipher_Encrypt(b *testing.B) { - c1, _ := newControlCipherForTesting() - h1 := header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - encrypted = c1.Encrypt(h1, plaintext, encrypted) - } -} - -func BenchmarkControlCipher_Decrypt(b *testing.B) { - c1, c2 := newControlCipherForTesting() - - h1 := header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - encrypted = c1.Encrypt(h1, plaintext, encrypted) - - decrypted := make([]byte, bufferSize) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - decrypted, _ = c2.Decrypt(encrypted, decrypted) - } -} diff --git a/node/cipher-data.go b/node/cipher-data.go deleted file mode 100644 index 9151870..0000000 --- a/node/cipher-data.go +++ /dev/null @@ -1,60 +0,0 @@ -package node - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" -) - -type dataCipher struct { - key [32]byte - aead cipher.AEAD -} - -func newDataCipher() *dataCipher { - key := [32]byte{} - if _, err := rand.Read(key[:]); err != nil { - panic(err) - } - return newDataCipherFromKey(key) -} - -func newDataCipherFromKey(key [32]byte) *dataCipher { - block, err := aes.NewCipher(key[:]) - if err != nil { - panic(err) - } - - aead, err := cipher.NewGCM(block) - if err != nil { - panic(err) - } - - return &dataCipher{key: key, aead: aead} -} - -func (sc *dataCipher) Key() [32]byte { - return sc.key -} - -func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte { - const s = dataHeaderSize - out = out[:s+dataCipherOverhead+len(data)] - h.Marshal(out[:s]) - sc.aead.Seal(out[s:s], out[:s], data, nil) - return out -} - -func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { - const s = dataHeaderSize - if len(encrypted) < s+dataCipherOverhead { - ok = false - return - } - - var err error - - data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) - ok = err == nil - return -} diff --git a/node/cipher-data_test.go b/node/cipher-data_test.go deleted file mode 100644 index 493c198..0000000 --- a/node/cipher-data_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - mrand "math/rand/v2" - "reflect" - "testing" -) - -func TestDataCipher(t *testing.T) { - maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(maxSizePlaintext) - - testCases := [][]byte{ - make([]byte, 0), - {1}, - {255}, - {1, 2, 3, 4, 5}, - []byte("Hello world"), - maxSizePlaintext, - } - - for _, plaintext := range testCases { - h1 := header{ - StreamID: dataStreamID, - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - h2 := header{} - h2.Parse(encrypted) - - dc2 := newDataCipherFromKey(dc1.Key()) - - decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("not equal") - } - - if !reflect.DeepEqual(h1, h2) { - t.Fatalf("%v != %v", h1, h2) - } - } -} - -func TestDataCipher_ModifyCiphertext(t *testing.T) { - maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(maxSizePlaintext) - - testCases := [][]byte{ - make([]byte, 0), - {1}, - {255}, - {1, 2, 3, 4, 5}, - []byte("Hello world"), - maxSizePlaintext, - } - - for _, plaintext := range testCases { - h1 := header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - encrypted[mrand.IntN(len(encrypted))]++ - - dc2 := newDataCipherFromKey(dc1.Key()) - - _, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) - if ok { - t.Fatal(ok) - } - } -} - -func TestDataCipher_ShortCiphertext(t *testing.T) { - dc1 := newDataCipher() - shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) - rand.Read(shortText) - _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) - if ok { - t.Fatal(ok) - } -} - -func BenchmarkDataCipher_Encrypt(b *testing.B) { - h1 := header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - b.ResetTimer() - for i := 0; i < b.N; i++ { - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - } -} - -func BenchmarkDataCipher_Decrypt(b *testing.B) { - h1 := header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - - decrypted := make([]byte, bufferSize) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - decrypted, _ = dc1.Decrypt(encrypted, decrypted) - } -} diff --git a/node/cipher-discovery.go b/node/cipher-discovery.go deleted file mode 100644 index 85e1381..0000000 --- a/node/cipher-discovery.go +++ /dev/null @@ -1,13 +0,0 @@ -package node - -/* -func signData(privKey *[64]byte, h header, data, out []byte) []byte { - out = out[:headerSize] - h.Marshal(out) - return sign.Sign(out, data, privKey) -} - -func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) { - return sign.Open(out[:0], signed[headerSize:], pubKey) -} -*/ diff --git a/node/config.go b/node/config.go deleted file mode 100644 index 46da9eb..0000000 --- a/node/config.go +++ /dev/null @@ -1,11 +0,0 @@ -package node - -import "vppn/m" - -type localConfig struct { - m.PeerConfig - PubKey []byte - PrivKey []byte - PubSignKey []byte - PrivSignKey []byte -} diff --git a/node/conn.go b/node/conn.go deleted file mode 100644 index e000557..0000000 --- a/node/conn.go +++ /dev/null @@ -1,3 +0,0 @@ -package node - -// ---------------------------------------------------------------------------- diff --git a/node/connwriter.go b/node/connwriter.go deleted file mode 100644 index 62caa75..0000000 --- a/node/connwriter.go +++ /dev/null @@ -1,146 +0,0 @@ -package node - -import ( - "log" - "net/netip" - "sync" - "sync/atomic" - "time" -) - -// ---------------------------------------------------------------------------- - -type peerRoute struct { - IP byte - Up bool // True if data can be sent on the route. - Relay bool // True if the peer is a relay. - Direct bool // True if this is a direct connection. - PubSignKey []byte - ControlCipher *controlCipher - DataCipher *dataCipher - RemoteAddr netip.AddrPort // Remote address if directly connected. -} - -// ---------------------------------------------------------------------------- - -type udpAddrPortWriter interface { - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) -} - -type marshaller interface { - Marshal([]byte) []byte -} - -// ---------------------------------------------------------------------------- - -type connWriter struct { - localIP byte - conn udpAddrPortWriter - - // For sending control packets. - cBuf1 []byte - cBuf2 []byte - - // For sending data packets. - dBuf1 []byte - dBuf2 []byte - - counters [256]uint64 - - // Lock around for sending on UDP Conn. - wLock sync.Mutex -} - -func newConnWriter(conn udpAddrPortWriter, localIP byte) *connWriter { - w := &connWriter{ - localIP: localIP, - conn: conn, - cBuf1: make([]byte, bufferSize), - cBuf2: make([]byte, bufferSize), - dBuf1: make([]byte, bufferSize), - dBuf2: make([]byte, bufferSize), - } - for i := range w.counters { - w.counters[i] = uint64(time.Now().Unix()<<30 + 1) - } - return w -} - -// Not safe for concurrent use. Should only be called by supervisor. -func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { - buf := w.encryptControlPacket(pkt, route) - w.writeTo(buf, route.RemoteAddr) -} - -// Relay control packet. Routes must not be nil. -func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { - buf := w.encryptControlPacket(pkt, route) - w.relayPacket(buf, w.cBuf1, route, relay) -} - -// Encrypted packet will occupy cBuf2. -func (w *connWriter) encryptControlPacket(pkt marshaller, route *peerRoute) []byte { - buf := pkt.Marshal(w.cBuf1) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&w.counters[route.IP], 1), - SourceIP: w.localIP, - DestIP: route.IP, - } - return route.ControlCipher.Encrypt(h, buf, w.cBuf2) -} - -// Not safe for concurrent use. Should only be called by ifReader. -func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&w.counters[route.IP], 1), - SourceIP: w.localIP, - DestIP: route.IP, - } - - enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) - w.writeTo(enc, route.RemoteAddr) -} - -// Relay a data packet. Routes must not be nil. -func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&w.counters[route.IP], 1), - SourceIP: w.localIP, - DestIP: route.IP, - } - - enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) - w.relayPacket(enc, w.dBuf2, route, relay) -} - -// Safe for concurrent use. Should only be called by connReader. -// -// This function will send pkt to the peer directly. This is used when a peer -// is acting as a relay and is forwarding already encrypted data for another -// peer. -func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { - w.writeTo(pkt, route.RemoteAddr) -} - -func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&w.counters[relay.IP], 1), - SourceIP: w.localIP, - DestIP: route.IP, - } - - enc := relay.DataCipher.Encrypt(h, data, buf) - w.writeTo(enc, relay.RemoteAddr) -} - -func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { - w.wLock.Lock() - if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { - log.Printf("Failed to write to UDP port: %v", err) - } - w.wLock.Unlock() -} diff --git a/node/connwriter_test.go b/node/connwriter_test.go deleted file mode 100644 index 388fbbc..0000000 --- a/node/connwriter_test.go +++ /dev/null @@ -1,248 +0,0 @@ -package node - -import ( - "bytes" - "net/netip" - "testing" -) - -// ---------------------------------------------------------------------------- - -type testUDPPacket struct { - Addr netip.AddrPort - Data []byte -} - -type testUDPAddrPortWriter struct { - written []testUDPPacket -} - -func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - w.written = append(w.written, testUDPPacket{ - Addr: addr, - Data: bytes.Clone(b), - }) - return len(b), nil -} - -func (w *testUDPAddrPortWriter) Written() []testUDPPacket { - out := w.written - w.written = []testUDPPacket{} - return out -} - -// ---------------------------------------------------------------------------- - -type testPacket string - -func (p testPacket) Marshal(b []byte) []byte { - b = b[:len(p)] - copy(b, []byte(p)) - return b -} - -// ---------------------------------------------------------------------------- - -func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) { - localKeys := generateKeys() - remoteKeys := generateKeys() - - local = &peerRoute{ - IP: 2, - Up: true, - Relay: false, - PubSignKey: remoteKeys.PubSignKey, - ControlCipher: newControlCipher(localKeys.PrivKey, remoteKeys.PubKey), - DataCipher: newDataCipher(), - RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100), - } - - remote = &peerRoute{ - IP: 1, - Up: true, - Relay: false, - PubSignKey: localKeys.PubSignKey, - ControlCipher: newControlCipher(remoteKeys.PrivKey, localKeys.PubKey), - DataCipher: local.DataCipher, - RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100), - } - - rLocalKeys := generateKeys() - rRemoteKeys := generateKeys() - - relayLocal = &peerRoute{ - IP: 3, - Up: true, - Relay: true, - Direct: true, - PubSignKey: rRemoteKeys.PubSignKey, - ControlCipher: newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey), - DataCipher: newDataCipher(), - RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100), - } - - relayRemote = &peerRoute{ - IP: 1, - Up: true, - Relay: false, - Direct: true, - PubSignKey: rLocalKeys.PubSignKey, - ControlCipher: newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey), - DataCipher: relayLocal.DataCipher, - RemoteAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100), - } - - return -} - -// ---------------------------------------------------------------------------- - -// Testing if we can send a control packet directly to the remote route. -func TestConnWriter_SendControlPacket_direct(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - route.Direct = true - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.SendControlPacket(in, route) - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.RemoteAddr { - t.Fatal(out[0]) - } - - dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - if string(dec) != string(in) { - t.Fatal(dec) - } -} - -// Testing if we can relay a packet via an intermediary. -func TestConnWriter_RelayControlPacket_relay(t *testing.T) { - route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, relay) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != relay.RemoteAddr { - t.Fatal(out[0]) - } - - dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if string(dec2) != string(in) { - t.Fatal(dec2) - } -} - -// Testing that we can send a data packet directly to a remote route. -func TestConnWriter_SendDataPacket_direct(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - route.Direct = true - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - - in := []byte("hello world!") - w.SendDataPacket(in, route) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.RemoteAddr { - t.Fatal(out[0]) - } - - dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(dec, in) { - t.Fatal(dec) - } -} - -// Testing that we can relay a data packet via a relay. -func TestConnWriter_RelayDataPacket_relay(t *testing.T) { - route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.RelayDataPacket(in, route, relay) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != relay.RemoteAddr { - t.Fatal(out[0]) - } - - dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(dec2, in) { - t.Fatal(dec2) - } -} - -// Testing that we can send an already encrypted packet. -func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.SendEncryptedDataPacket(in, route) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.RemoteAddr { - t.Fatal(out[0]) - } - - if !bytes.Equal(out[0].Data, in) { - t.Fatal(out[0]) - } -} diff --git a/node/crypto.go b/node/crypto.go deleted file mode 100644 index c24aaad..0000000 --- a/node/crypto.go +++ /dev/null @@ -1,30 +0,0 @@ -package node - -import ( - "crypto/rand" - "log" - - "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" -) - -type cryptoKeys struct { - PubKey []byte - PrivKey []byte - PubSignKey []byte - PrivSignKey []byte -} - -func generateKeys() cryptoKeys { - pubKey, privKey, err := box.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate encryption keys: %v", err) - } - - pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate signing keys: %v", err) - } - - return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} -} diff --git a/node/data-flow.dot b/node/data-flow.dot deleted file mode 100644 index 45b6f05..0000000 --- a/node/data-flow.dot +++ /dev/null @@ -1,14 +0,0 @@ -digraph d { - ifReader -> connWriter; - connReader -> ifWriter; - connReader -> connWriter; - connReader -> supervisor; - mcReader -> supervisor; - supervisor -> connWriter; - supervisor -> mcWriter; - hubPoller -> supervisor; - - connWriter [shape="box"]; - mcWriter [shape="box"]; - ifWriter [shape="box"]; -} \ No newline at end of file diff --git a/node/dupcheck.go b/node/dupcheck.go deleted file mode 100644 index 76792ae..0000000 --- a/node/dupcheck.go +++ /dev/null @@ -1,76 +0,0 @@ -package node - -type dupCheck struct { - bitSet - head int - tail int - headCounter uint64 - tailCounter uint64 // Also next expected counter value. -} - -func newDupCheck(headCounter uint64) *dupCheck { - return &dupCheck{ - headCounter: headCounter, - tailCounter: headCounter + 1, - tail: 1, - } -} - -func (dc *dupCheck) IsDup(counter uint64) bool { - - // Before head => it's late, say it's a dup. - if counter < dc.headCounter { - return true - } - - // It's within the counter bounds. - if counter < dc.tailCounter { - index := (int(counter-dc.headCounter) + dc.head) % bitSetSize - if dc.Get(index) { - return true - } - - dc.Set(index) - return false - } - - // It's more than 1 beyond the tail. - delta := counter - dc.tailCounter - - // Full clear. - if delta >= bitSetSize-1 { - dc.ClearAll() - dc.Set(0) - - dc.tail = 1 - dc.head = 2 - dc.tailCounter = counter + 1 - dc.headCounter = dc.tailCounter - bitSetSize + 1 - - return false - } - - // Clear if necessary. - for i := 0; i < int(delta); i++ { - dc.put(false) - } - - dc.put(true) - return false -} - -func (dc *dupCheck) put(set bool) { - if set { - dc.Set(dc.tail) - } else { - dc.Clear(dc.tail) - } - - dc.tail = (dc.tail + 1) % bitSetSize - dc.tailCounter++ - - if dc.head == dc.tail { - dc.head = (dc.head + 1) % bitSetSize - dc.headCounter++ - } -} diff --git a/node/dupcheck_test.go b/node/dupcheck_test.go deleted file mode 100644 index 2156b4e..0000000 --- a/node/dupcheck_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package node - -import ( - "testing" -) - -func TestDupCheck(t *testing.T) { - dc := newDupCheck(0) - - for i := range bitSetSize { - if dc.IsDup(uint64(i)) { - t.Fatal("!") - } - } - - type TestCase struct { - Counter uint64 - Dup bool - } - - testCases := []TestCase{ - {0, true}, - {1, true}, - {2, true}, - {3, true}, - {63, true}, - {256, true}, - {510, true}, - {511, true}, - {512, false}, - {0, true}, - {512, true}, - {513, false}, - {517, false}, - {512, true}, - {513, true}, - {514, false}, - {515, false}, - {516, false}, - {517, true}, - {2512, false}, - {2000, true}, - {2001, false}, - {4000, false}, - {4000 - 512, true}, // Too old. - {4000 - 511, false}, // Just in the window. - } - - for i, tc := range testCases { - if ok := dc.IsDup(tc.Counter); ok != tc.Dup { - t.Fatal(i, ok, tc) - } - } -} diff --git a/node/files.go b/node/files.go deleted file mode 100644 index 18f539b..0000000 --- a/node/files.go +++ /dev/null @@ -1,82 +0,0 @@ -package node - -import ( - "encoding/json" - "log" - "os" - "path/filepath" - "vppn/m" -) - -func configDir(netName string) string { - d, err := os.UserHomeDir() - if err != nil { - log.Fatalf("Failed to get user home directory: %v", err) - } - return filepath.Join(d, ".vppn", netName) -} - -func peerConfigPath(netName string) string { - return filepath.Join(configDir(netName), "peer-config.json") -} - -func peerStatePath(netName string) string { - return filepath.Join(configDir(netName), "peer-state.json") -} - -func storeJson(x any, outPath string) error { - outDir := filepath.Dir(outPath) - _ = os.MkdirAll(outDir, 0700) - - tmpPath := outPath + ".tmp" - buf, err := json.Marshal(x) - if err != nil { - return err - } - - f, err := os.Create(tmpPath) - if err != nil { - return err - } - - if _, err := f.Write(buf); err != nil { - f.Close() - return err - } - - if err := f.Sync(); err != nil { - f.Close() - return err - } - - if err := f.Close(); err != nil { - return err - } - - return os.Rename(tmpPath, outPath) -} - -func storePeerConfig(netName string, pc localConfig) error { - return storeJson(pc, peerConfigPath(netName)) -} - -func storeNetworkState(netName string, ps m.NetworkState) error { - return storeJson(ps, peerStatePath(netName)) -} - -func loadJson(dataPath string, ptr any) error { - data, err := os.ReadFile(dataPath) - if err != nil { - return err - } - - return json.Unmarshal(data, ptr) -} - -func loadPeerConfig(netName string) (pc localConfig, err error) { - return pc, loadJson(peerConfigPath(netName), &pc) -} - -func loadNetworkState(netName string) (ps m.NetworkState, err error) { - return ps, loadJson(peerStatePath(netName), &ps) -} diff --git a/node/globalfuncs.go b/node/globalfuncs.go deleted file mode 100644 index 2d13f57..0000000 --- a/node/globalfuncs.go +++ /dev/null @@ -1,8 +0,0 @@ -package node - -func getRelayRoute() *peerRoute { - if ip := relayIP.Load(); ip != nil { - return routingTable[*ip].Load() - } - return nil -} diff --git a/node/globals.go b/node/globals.go deleted file mode 100644 index 8538c4a..0000000 --- a/node/globals.go +++ /dev/null @@ -1,63 +0,0 @@ -package node - -import ( - "net" - "net/netip" - "net/url" - "sync/atomic" -) - -const ( - bufferSize = 1536 - if_mtu = 1200 - if_queue_len = 2048 - controlCipherOverhead = 16 - dataCipherOverhead = 16 - signOverhead = 64 -) - -var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( - netip.AddrFrom4([4]byte{224, 0, 0, 157}), - 4560)) - -var ( - hubURL *url.URL - apiKey string - - // Configuration for this peer. - netName string - localIP byte - localPub bool - localAddr netip.AddrPort - privKey []byte - privSignKey []byte - - // TODO: Doesn't need to be global. - // Duplicate checkers for incoming packets. - dupChecks [256]*dupCheck = func() (out [256]*dupCheck) { - for i := range out { - out[i] = newDupCheck(0) - } - return - }() - - // TODO: Doesn't need to be global . - // Messages for the supervisor. - messages = make(chan any, 1024) - - // TODO: Doesn't need to be global . - // Global routing table. - routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { - for i := range out { - out[i] = &atomic.Pointer[peerRoute]{} - out[i].Store(&peerRoute{}) - } - return - }() - - // Managed by the relayManager. - relayIP = &atomic.Pointer[byte]{} - - // TODO: Only used by supervisor: can make local there. - publicAddrs = newPubAddrStore() -) diff --git a/node/header.go b/node/header.go deleted file mode 100644 index 915fe3e..0000000 --- a/node/header.go +++ /dev/null @@ -1,49 +0,0 @@ -package node - -import "unsafe" - -// ---------------------------------------------------------------------------- - -const ( - headerSize = 12 - controlStreamID = 2 - controlHeaderSize = 24 - dataStreamID = 1 - dataHeaderSize = 12 -) - -type header struct { - Version byte - StreamID byte - SourceIP byte - DestIP byte - Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. -} - -func parseHeader(b []byte) (h header, ok bool) { - if len(b) < headerSize { - return - } - h.Version = b[0] - h.StreamID = b[1] - h.SourceIP = b[2] - h.DestIP = b[3] - h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) - return h, true -} - -func (h *header) Parse(b []byte) { - h.Version = b[0] - h.StreamID = b[1] - h.SourceIP = b[2] - h.DestIP = b[3] - h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) -} - -func (h *header) Marshal(buf []byte) { - buf[0] = h.Version - buf[1] = h.StreamID - buf[2] = h.SourceIP - buf[3] = h.DestIP - *(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter -} diff --git a/node/header_test.go b/node/header_test.go deleted file mode 100644 index 9dbb061..0000000 --- a/node/header_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package node - -import "testing" - -func TestHeaderMarshalParse(t *testing.T) { - nIn := header{ - StreamID: 23, - Counter: 3212, - SourceIP: 34, - DestIP: 200, - } - - buf := make([]byte, headerSize) - nIn.Marshal(buf) - - nOut := header{} - nOut.Parse(buf) - if nIn != nOut { - t.Fatal(nIn, nOut) - } -} diff --git a/node/hubpoller.go b/node/hubpoller.go deleted file mode 100644 index a069c8b..0000000 --- a/node/hubpoller.go +++ /dev/null @@ -1,92 +0,0 @@ -package node - -import ( - "encoding/json" - "io" - "log" - "net/http" - "time" - "vppn/m" -) - -type hubPoller struct { - client *http.Client - req *http.Request - versions [256]int64 -} - -func newHubPoller() *hubPoller { - u := *hubURL - u.Path = "/peer/fetch-state/" - - client := &http.Client{Timeout: 8 * time.Second} - - req := &http.Request{ - Method: http.MethodGet, - URL: &u, - Header: http.Header{}, - } - req.SetBasicAuth("", apiKey) - - return &hubPoller{ - client: client, - req: req, - } -} - -func (hp *hubPoller) Run() { - defer panicHandler() - - state, err := loadNetworkState(netName) - if err != nil { - log.Printf("Failed to load network state: %v", err) - log.Printf("Polling hub...") - hp.pollHub() - } else { - hp.applyNetworkState(state) - } - - for range time.Tick(64 * time.Second) { - hp.pollHub() - } -} - -func (hp *hubPoller) pollHub() { - var state m.NetworkState - - resp, err := hp.client.Do(hp.req) - if err != nil { - log.Printf("Failed to fetch peer state: %v", err) - return - } - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - log.Printf("Failed to read body from hub: %v", err) - return - } - - if err := json.Unmarshal(body, &state); err != nil { - log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) - return - } - - hp.applyNetworkState(state) - - if err := storeNetworkState(netName, state); err != nil { - log.Printf("Failed to store network state: %v", err) - } -} - -func (hp *hubPoller) applyNetworkState(state m.NetworkState) { - for i, peer := range state.Peers { - if i != int(localIP) { - if peer == nil || peer.Version != hp.versions[i] { - messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]} - if peer != nil { - hp.versions[i] = peer.Version - } - } - } - } -} diff --git a/node/ifreader.go b/node/ifreader.go deleted file mode 100644 index 67d0999..0000000 --- a/node/ifreader.go +++ /dev/null @@ -1,102 +0,0 @@ -package node - -import ( - "io" - "log" - "sync/atomic" -) - -type ifReader struct { - iface io.Reader - routes [256]*atomic.Pointer[peerRoute] - relay *atomic.Pointer[peerRoute] - sendDataPacket func(pkt []byte, route *peerRoute) - relayDataPacket func(pkt []byte, route, relay *peerRoute) -} - -func newIFReader( - iface io.Reader, - routes [256]*atomic.Pointer[peerRoute], - relay *atomic.Pointer[peerRoute], - sendDataPacket func(pkt []byte, route *peerRoute), - relayDackPacket func(pkt []byte, route, relay *peerRoute), -) *ifReader { - return &ifReader{ - iface: iface, - routes: routes, - relay: relay, - sendDataPacket: sendDataPacket, - } -} - -func (r *ifReader) Run() { - var ( - packet = make([]byte, bufferSize) - remoteIP byte - ok bool - ) - - for { - packet = r.readNextPacket(packet) - if remoteIP, ok = r.parsePacket(packet); ok { - r.sendPacket(packet, remoteIP) - } - } -} - -func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { - route := r.routes[remoteIP].Load() - if !route.Up { - log.Printf("Route not connected: %d", remoteIP) - return - } - - // Direct path => early return. - if route.Direct { - r.sendDataPacket(pkt, route) - return - } - - if relay := r.relay.Load(); relay != nil && relay.Up { - r.relayDataPacket(pkt, route, relay) - } -} - -// Get next packet, returning packet, and destination ip. -func (r *ifReader) readNextPacket(buf []byte) []byte { - n, err := r.iface.Read(buf[:cap(buf)]) - if err != nil { - log.Fatalf("Failed to read from interface: %v", err) - } - - return buf[:n] -} - -func (r *ifReader) parsePacket(buf []byte) (byte, bool) { - n := len(buf) - if n == 0 { - return 0, false - } - - version := buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - log.Printf("Short IPv4 packet: %d", len(buf)) - return 0, false - } - return buf[19], true - - case 6: - if len(buf) < 40 { - log.Printf("Short IPv6 packet: %d", len(buf)) - return 0, false - } - return buf[39], true - - default: - log.Printf("Invalid IP packet version: %v", version) - return 0, false - } -} diff --git a/node/ifreader_test.go b/node/ifreader_test.go deleted file mode 100644 index 8f173f4..0000000 --- a/node/ifreader_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package node - -import ( - "bytes" - "net" - "sync/atomic" - "testing" -) - -// Test that we parse IPv4 packets correctly. -func TestIFReader_parsePacket_ipv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - - pkt := make([]byte, 1234) - pkt[0] = 4 << 4 - pkt[19] = 128 - - if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { - t.Fatal(ip, ok) - } -} - -// Test that we parse IPv6 packets correctly. -func TestIFReader_parsePacket_ipv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - - pkt := make([]byte, 1234) - pkt[0] = 6 << 4 - pkt[39] = 42 - - if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { - t.Fatal(ip, ok) - } -} - -// Test that empty packets work as expected. -func TestIFReader_parsePacket_emptyPacket(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - - pkt := make([]byte, 0) - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that invalid IP versions fail. -func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - - for i := byte(1); i < 16; i++ { - if i == 4 || i == 6 { - continue - } - pkt := make([]byte, 1234) - pkt[0] = i << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(i, ip, ok) - } - } -} - -// Test that short IPv4 packets fail. -func TestIFReader_parsePacket_shortIPv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - - pkt := make([]byte, 19) - pkt[0] = 4 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that short IPv6 packets fail. -func TestIFReader_parsePacket_shortIPv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - - pkt := make([]byte, 39) - pkt[0] = 6 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that we can read a packet. -func TestIFReader_readNextpacket(t *testing.T) { - in, out := net.Pipe() - r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil, nil) - defer in.Close() - defer out.Close() - - go in.Write([]byte("hello world!")) - - pkt := r.readNextPacket(make([]byte, bufferSize)) - if !bytes.Equal(pkt, []byte("hello world!")) { - t.Fatalf("%s", pkt) - } -} - -// Testing that we can send a packet directly. -func TestIFReader_sendPacket_direct(t *testing.T) { - // TODO -} - -// Testing that we don't send a packet if route isn't up. -func TestIFReader_sendPacket_directNotUp(t *testing.T) { - // TODO -} - -// Testing that we can send a packet via a relay. -func TestIFReader_sendPacket_relayed(t *testing.T) { - // TODO -} - -// Testing that we don't try to send on a nil relay IP. diff --git a/node/ifwriter.go b/node/ifwriter.go deleted file mode 100644 index adb74e3..0000000 --- a/node/ifwriter.go +++ /dev/null @@ -1,5 +0,0 @@ -package node - -import "io" - -type ifWriter io.Writer diff --git a/node/interface.go b/node/interface.go deleted file mode 100644 index 4b492b4..0000000 --- a/node/interface.go +++ /dev/null @@ -1,177 +0,0 @@ -package node - -import ( - "fmt" - "io" - "log" - "net" - "os" - "syscall" - - "golang.org/x/sys/unix" -) - -// Get next packet, returning packet, ip, and possible error. -func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { - var ( - version byte - ip byte - ) - for { - n, err := iface.Read(buf[:cap(buf)]) - if err != nil { - return nil, ip, err - } - - buf = buf[:n] - version = buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - log.Printf("Short IPv4 packet: %d", len(buf)) - continue - } - ip = buf[19] - - case 6: - if len(buf) < 40 { - log.Printf("Short IPv6 packet: %d", len(buf)) - continue - } - ip = buf[39] - - default: - log.Printf("Invalid IP packet version: %v", version) - continue - } - - return buf, ip, nil - } -} - -func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { - if len(network) != 4 { - return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) - } - ip := net.IPv4(network[0], network[1], network[2], localIP) - - ////////////////////////// - // Create TUN Interface // - ////////////////////////// - - tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) - if err != nil { - return nil, fmt.Errorf("failed to open TUN device: %w", err) - } - - // New interface request. - req, err := unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) - } - - // Flags: - // - // IFF_NO_PI => don't add packet info data to packets sent to the interface. - // IFF_TUN => create a TUN device handling IP packets. - req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) - - err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) - if err != nil { - return nil, fmt.Errorf("failed to set TUN device settings: %w", err) - } - - // Name may not be exactly the same? - name = req.Name() - - ///////////// - // Set MTU // - ///////////// - - // We need a socket file descriptor to set other options for some reason. - sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) - if err != nil { - return nil, fmt.Errorf("failed to open socket: %w", err) - } - defer unix.Close(sockFD) - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create MTU interface request: %w", err) - } - - req.SetUint32(if_mtu) - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { - return nil, fmt.Errorf("failed to set interface MTU: %w", err) - } - - ////////////////////// - // Set Queue Length // - ////////////////////// - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create IP interface request: %w", err) - } - - req.SetUint16(if_queue_len) - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { - return nil, fmt.Errorf("failed to set interface queue length: %w", err) - } - - ///////////////////// - // Set IP and Mask // - ///////////////////// - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create IP interface request: %w", err) - } - - if err := req.SetInet4Addr(ip.To4()); err != nil { - return nil, fmt.Errorf("failed to set interface request IP: %w", err) - } - - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { - return nil, fmt.Errorf("failed to set interface IP: %w", err) - } - - // SET MASK - must happen after setting address. - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create mask interface request: %w", err) - } - - if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { - return nil, fmt.Errorf("failed to set interface request mask: %w", err) - } - - if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { - return nil, fmt.Errorf("failed to set interface mask: %w", err) - } - - //////////////////////// - // Bring Interface Up // - //////////////////////// - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create up interface request: %w", err) - } - - // Get current flags. - if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { - return nil, fmt.Errorf("failed to get interface flags: %w", err) - } - - flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING - - // Set UP flag / broadcast flags. - req.SetUint16(flags) - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { - return nil, fmt.Errorf("failed to set interface up: %w", err) - } - - return os.NewFile(uintptr(tunFD), "tun"), nil -} diff --git a/node/localdiscovery.go b/node/localdiscovery.go deleted file mode 100644 index 90f2e60..0000000 --- a/node/localdiscovery.go +++ /dev/null @@ -1,97 +0,0 @@ -package node - -import ( - "log" - "net" - "time" - - "golang.org/x/crypto/nacl/sign" -) - -func localDiscovery() { - conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) - if err != nil { - log.Printf("Failed to bind to multicast address: %v", err) - return - } - - go sendLocalDiscovery(conn) - go recvLocalDiscovery(conn) -} - -func sendLocalDiscovery(conn *net.UDPConn) { - var ( - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) - ) - - for range time.Tick(16 * time.Second) { - signed := buildLocalDiscoveryPacket(buf1, buf2) - if _, err := conn.WriteToUDP(signed, multicastAddr); err != nil { - log.Printf("Failed to write multicast UDP packet: %v", err) - } - } -} - -func recvLocalDiscovery(conn *net.UDPConn) { - var ( - raw = make([]byte, bufferSize) - buf = make([]byte, bufferSize) - ) - - for { - n, remoteAddr, err := conn.ReadFromUDPAddrPort(raw[:bufferSize]) - if err != nil { - log.Fatalf("Failed to read from UDP port (multicast): %v", err) - } - - raw = raw[:n] - h, ok := openLocalDiscoveryPacket(raw, buf) - if !ok { - log.Printf("Failed to open discovery packet?") - continue - } - - msg := controlMsg[localDiscoveryPacket]{ - SrcIP: h.SourceIP, - SrcAddr: remoteAddr, - Packet: localDiscoveryPacket{}, - } - - select { - case messages <- msg: - default: - log.Printf("Dropping local discovery message.") - } - } -} - -func buildLocalDiscoveryPacket(buf1, buf2 []byte) []byte { - h := header{ - StreamID: controlStreamID, - Counter: 0, - SourceIP: localIP, - DestIP: 255, - } - out := buf1[:headerSize] - h.Marshal(out) - return sign.Sign(buf2[:0], out, (*[64]byte)(privSignKey)) -} - -func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) { - if len(raw) != headerSize+signOverhead { - ok = false - return - } - - h.Parse(raw[signOverhead:]) - route := routingTable[h.SourceIP].Load() - if route == nil || route.PubSignKey == nil { - log.Printf("Missing signing key: %d", h.SourceIP) - ok = false - return - } - - _, ok = sign.Open(buf[:0], raw, (*[32]byte)(route.PubSignKey)) - return -} diff --git a/node/localdiscovery_test.go b/node/localdiscovery_test.go deleted file mode 100644 index b00b29d..0000000 --- a/node/localdiscovery_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "testing" - - "golang.org/x/crypto/nacl/sign" -) - -func TestLocalDiscoveryPacketSigning(t *testing.T) { - localIP = 32 - - var ( - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) - pubSignKey, privSigKey, _ = sign.GenerateKey(rand.Reader) - ) - - privSignKey = privSigKey[:] - route := routingTable[localIP].Load() - route.IP = byte(localIP) - route.PubSignKey = pubSignKey[:] - routingTable[localIP].Store(route) - - out := buildLocalDiscoveryPacket(buf1, buf2) - - h, ok := openLocalDiscoveryPacket(bytes.Clone(out), buf1) - if !ok { - t.Fatal(h, ok) - } - if h.StreamID != controlStreamID || h.SourceIP != localIP || h.DestIP != 255 { - t.Fatal(h) - } -} diff --git a/node/main.go b/node/main.go deleted file mode 100644 index 78611a8..0000000 --- a/node/main.go +++ /dev/null @@ -1,320 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "encoding/json" - "flag" - "fmt" - "io" - "log" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "runtime/debug" - "time" - "vppn/m" - - "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" -) - -func panicHandler() { - if r := recover(); r != nil { - log.Fatalf("\n %v\n\nstacktrace from panic: %s\n", r, string(debug.Stack())) - } -} - -func Main() { - defer panicHandler() - - var hubAddress string - - flag.StringVar(&netName, "name", "", "[REQUIRED] The network name.") - flag.StringVar(&hubAddress, "hub-address", "", "[REQUIRED] The hub address.") - flag.StringVar(&apiKey, "api-key", "", "[REQUIRED] The node's API key.") - flag.Parse() - - if netName == "" || hubAddress == "" || apiKey == "" { - flag.Usage() - os.Exit(1) - } - - var err error - - hubURL, err = url.Parse(hubAddress) - if err != nil { - log.Fatalf("Failed to parse hub address: %v", err) - } - - main() -} - -func initPeerWithHub() { - encPubKey, encPrivKey, err := box.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate encryption keys: %v", err) - } - - signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate signing keys: %v", err) - } - - initURL := *hubURL - initURL.Path = "/peer/init/" - - args := m.PeerInitArgs{ - EncPubKey: encPubKey[:], - PubSignKey: signPubKey[:], - } - - buf := &bytes.Buffer{} - if err := json.NewEncoder(buf).Encode(args); err != nil { - log.Fatalf("Failed to encode init args: %v", err) - } - - req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) - if err != nil { - log.Fatalf("Failed to construct request: %v", err) - } - req.SetBasicAuth("", apiKey) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - log.Fatalf("Failed to init with hub: %v", err) - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - log.Fatalf("Failed to read response body: %v", err) - } - - peerConfig := localConfig{} - if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { - log.Fatalf("Failed to parse configuration: %v\n%s", err, data) - } - - peerConfig.PubKey = encPubKey[:] - peerConfig.PrivKey = encPrivKey[:] - peerConfig.PubSignKey = signPubKey[:] - peerConfig.PrivSignKey = signPrivKey[:] - - if err := storePeerConfig(netName, peerConfig); err != nil { - log.Fatalf("Failed to store configuration: %v", err) - } - - log.Print("Initialization successful.") -} - -// ---------------------------------------------------------------------------- - -func main() { - config, err := loadPeerConfig(netName) - if err != nil { - log.Printf("Failed to load configuration: %v", err) - log.Printf("Initializing...") - initPeerWithHub() - - config, err = loadPeerConfig(netName) - if err != nil { - log.Fatalf("Failed to load configuration: %v", err) - } - } - - iface, err := openInterface(config.Network, config.PeerIP, netName) - if err != nil { - log.Fatalf("Failed to open interface: %v", err) - } - - myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) - if err != nil { - log.Fatalf("Failed to resolve UDP address: %v", err) - } - - conn, err := net.ListenUDP("udp", myAddr) - if err != nil { - log.Fatalf("Failed to open UDP port: %v", err) - } - - conn.SetReadBuffer(1024 * 1024 * 8) - conn.SetWriteBuffer(1024 * 1024 * 8) - - localIP = config.PeerIP - - ip, ok := netip.AddrFromSlice(config.PublicIP) - if ok { - localPub = true - localAddr = netip.AddrPortFrom(ip, config.Port) - } - - privKey = config.PrivKey - privSignKey = config.PrivSignKey - - if !localPub { - go relayManager() - go localDiscovery() - } - - go func() { - for range time.Tick(pingInterval) { - messages <- pingTimerMsg{} - } - }() - - sender := newPacketSender(conn) - - go startPeerSuper(routingTable, messages, sender) - - go newHubPoller().Run() - go readFromConn(conn, iface, sender) - - readFromIFace(iface, sender) -} - -// ---------------------------------------------------------------------------- - -func readFromConn(conn *net.UDPConn, iface io.ReadWriteCloser, sender dataPacketSender) { - - defer panicHandler() - - var ( - remoteAddr netip.AddrPort - n int - err error - buf = make([]byte, bufferSize) - decBuf = make([]byte, bufferSize) - data []byte - h header - ) - - for { - n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize]) - if err != nil { - log.Fatalf("Failed to read from UDP port: %v", err) - } - - remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) - - data = buf[:n] - - if n < headerSize { - continue // Packet it soo short. - } - - h.Parse(data) - switch h.StreamID { - case controlStreamID: - handleControlPacket(remoteAddr, h, data, decBuf) - - case dataStreamID: - handleDataPacket(h, data, decBuf, iface, sender) - - default: - log.Printf("Unknown stream ID: %d", h.StreamID) - } - } -} - -func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { - route := routingTable[h.SourceIP].Load() - if route.ControlCipher == nil { - //log.Printf("Not connected (control).") - return - } - - if h.DestIP != localIP { - log.Printf("Incorrect destination IP on control packet: %#v", h) - return - } - - out, ok := route.ControlCipher.Decrypt(data, decBuf) - if !ok { - log.Printf("Failed to decrypt control packet.") - return - } - - if len(out) == 0 { - log.Printf("Empty control packet from: %d", h.SourceIP) - return - } - - if dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter) - return - } - - msg, err := parseControlMsg(h.SourceIP, addr, out) - if err != nil { - log.Printf("Failed to parse control packet: %v", err) - return - } - - select { - case messages <- msg: - default: - log.Printf("Dropping control packet.") - } -} - -func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { - route := routingTable[h.SourceIP].Load() - if !route.Up { - log.Printf("Not connected (recv).") - return - } - - dec, ok := route.DataCipher.Decrypt(data, decBuf) - if !ok { - log.Printf("Failed to decrypt data packet.") - return - } - - if dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter) - return - } - - if h.DestIP == localIP { - if _, err := iface.Write(dec); err != nil { - log.Fatalf("Failed to write to interface: %v", err) - } - return - } - - destRoute := routingTable[h.DestIP].Load() - if !destRoute.Up { - log.Printf("Not connected (relay): %d", destRoute.IP) - return - } - - sender.SendEncryptedDataPacket(dec, destRoute.RemoteAddr) -} - -// ---------------------------------------------------------------------------- - -func readFromIFace(iface io.ReadWriteCloser, sender dataPacketSender) { - var ( - packet = make([]byte, bufferSize) - remoteIP byte - err error - ) - - for { - packet, remoteIP, err = readNextPacket(iface, packet) - if err != nil { - log.Fatalf("Failed to read from interface: %v", err) - } - - route := routingTable[remoteIP].Load() - if !route.Up { - log.Printf("Route not connected: %d", remoteIP) - continue - } - - sender.SendDataPacket(packet, *route) - } -} diff --git a/node/main_test.go b/node/main_test.go deleted file mode 100644 index bf077a2..0000000 --- a/node/main_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package node - -import ( - "crypto/rand" - "log" - - "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" -) - -type testPeer struct { - IP byte - PubKey []byte - PrivKey []byte - PubSignKey []byte - PrivSignKey []byte -} - -func newTestPeer(ip byte) testPeer { - encPubKey, encPrivKey, err := box.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate encryption keys: %v", err) - } - - signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate signing keys: %v", err) - } - - return testPeer{ - IP: ip, - PubKey: encPubKey[:], - PrivKey: encPrivKey[:], - PubSignKey: signPubKey[:], - PrivSignKey: signPrivKey[:], - } -} diff --git a/node/mcwriter.go b/node/mcwriter.go deleted file mode 100644 index 99e5b58..0000000 --- a/node/mcwriter.go +++ /dev/null @@ -1,62 +0,0 @@ -package node - -import ( - "log" - "net" - - "golang.org/x/crypto/nacl/sign" -) - -// ---------------------------------------------------------------------------- - -type udpWriter interface { - WriteToUDP([]byte, *net.UDPAddr) (int, error) -} - -// ---------------------------------------------------------------------------- - -func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { - h := header{ - SourceIP: localIP, - DestIP: 255, - } - buf := make([]byte, headerSize) - h.Marshal(buf) - out := make([]byte, headerSize+signOverhead) - return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) -} - -func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) { - if len(pkt) != headerSize+signOverhead { - return - } - - h.Parse(pkt[signOverhead:]) - ok = true - return -} - -func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { - _, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) - return ok -} - -// ---------------------------------------------------------------------------- - -type mcWriter struct { - conn udpWriter - discoveryPacket []byte -} - -func newMCWriter(conn udpWriter, localIP byte, signingKey []byte) *mcWriter { - return &mcWriter{ - conn: conn, - discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), - } -} - -func (w *mcWriter) SendLocalDiscovery() { - if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { - log.Printf("Failed to write multicast UDP packet: %v", err) - } -} diff --git a/node/mcwriter_test.go b/node/mcwriter_test.go deleted file mode 100644 index d182239..0000000 --- a/node/mcwriter_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package node - -import ( - "bytes" - "net" - "testing" -) - -// ---------------------------------------------------------------------------- - -// Testing that we can create and verify a local discovery packet. -func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { - keys := generateKeys() - - created := createLocalDiscoveryPacket(55, keys.PrivSignKey) - - header, ok := headerFromLocalDiscoveryPacket(created) - if !ok { - t.Fatal(ok) - } - if header.SourceIP != 55 || header.DestIP != 255 { - t.Fatal(header) - } - - if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { - t.Fatal("Not valid") - } -} - -// Testing that we don't try to parse short packets. -func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { - keys := generateKeys() - - created := createLocalDiscoveryPacket(55, keys.PrivSignKey) - - _, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) - if ok { - t.Fatal(ok) - } -} - -// Testing that modifying a packet makes it invalid. -func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { - keys := generateKeys() - - created := createLocalDiscoveryPacket(55, keys.PrivSignKey) - buf := make([]byte, 1024) - for i := range created { - modified := bytes.Clone(created) - modified[i]++ - if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { - t.Fatal("Verification should have failed.") - } - } -} - -// ---------------------------------------------------------------------------- - -type testUDPWriter struct { - written [][]byte -} - -func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { - w.written = append(w.written, bytes.Clone(b)) - return len(b), nil -} - -func (w *testUDPWriter) Written() [][]byte { - out := w.written - w.written = [][]byte{} - return out -} - -// ---------------------------------------------------------------------------- - -// Testing that the mcWriter sends local discovery packets as expected. -func TestMCWriter_SendLocalDiscovery(t *testing.T) { - keys := generateKeys() - writer := &testUDPWriter{} - - mcw := newMCWriter(writer, 42, keys.PrivSignKey) - mcw.SendLocalDiscovery() - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - pkt := out[0] - - header, ok := headerFromLocalDiscoveryPacket(pkt) - if !ok { - t.Fatal(ok) - } - if header.SourceIP != 42 || header.DestIP != 255 { - t.Fatal(header) - } - - if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { - t.Fatal("Verification should succeed.") - } -} diff --git a/node/messages.go b/node/messages.go deleted file mode 100644 index 64ca5fe..0000000 --- a/node/messages.go +++ /dev/null @@ -1,58 +0,0 @@ -package node - -import ( - "net/netip" - "vppn/m" -) - -// ---------------------------------------------------------------------------- - -type controlMsg[T any] struct { - SrcIP byte - SrcAddr netip.AddrPort - // TODO: RecvdAt int64 // Unixmilli. - Packet T -} - -func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { - switch buf[0] { - - case packetTypeSyn: - packet, err := parseSynPacket(buf) - return controlMsg[synPacket]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - case packetTypeAck: - packet, err := parseAckPacket(buf) - return controlMsg[ackPacket]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - case packetTypeProbe: - packet, err := parseProbePacket(buf) - return controlMsg[probePacket]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - default: - return nil, errUnknownPacketType - } -} - -// ---------------------------------------------------------------------------- - -type peerUpdateMsg struct { - PeerIP byte - Peer *m.Peer -} - -// ---------------------------------------------------------------------------- - -type pingTimerMsg struct{} diff --git a/node/packets-util.go b/node/packets-util.go deleted file mode 100644 index b3071ab..0000000 --- a/node/packets-util.go +++ /dev/null @@ -1,190 +0,0 @@ -package node - -import ( - "net/netip" - "sync/atomic" - "time" - "unsafe" -) - -var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1 - -func newTraceID() uint64 { - return atomic.AddUint64(&traceIDCounter, 1) -} - -// ---------------------------------------------------------------------------- - -type binWriter struct { - b []byte - i int -} - -func newBinWriter(buf []byte) *binWriter { - buf = buf[:cap(buf)] - return &binWriter{buf, 0} -} - -func (w *binWriter) Bool(b bool) *binWriter { - if b { - return w.Byte(1) - } - return w.Byte(0) -} - -func (w *binWriter) Byte(b byte) *binWriter { - w.b[w.i] = b - w.i++ - return w -} - -func (w *binWriter) SharedKey(key [32]byte) *binWriter { - copy(w.b[w.i:w.i+32], key[:]) - w.i += 32 - return w -} - -func (w *binWriter) Uint16(x uint16) *binWriter { - *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x - w.i += 2 - return w -} - -func (w *binWriter) Uint64(x uint64) *binWriter { - *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x - w.i += 8 - return w -} - -func (w *binWriter) Int64(x int64) *binWriter { - *(*int64)(unsafe.Pointer(&w.b[w.i])) = x - w.i += 8 - return w -} - -func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { - w.Bool(addrPort.IsValid()) - addr := addrPort.Addr().As16() - copy(w.b[w.i:w.i+16], addr[:]) - w.i += 16 - return w.Uint16(addrPort.Port()) -} - -func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { - for _, addrPort := range l { - w.AddrPort(addrPort) - } - return w -} - -func (w *binWriter) Build() []byte { - return w.b[:w.i] -} - -// ---------------------------------------------------------------------------- - -type binReader struct { - b []byte - i int - err error -} - -func newBinReader(buf []byte) *binReader { - return &binReader{b: buf} -} - -func (r *binReader) hasBytes(n int) bool { - if r.err != nil || (len(r.b)-r.i) < n { - r.err = errMalformedPacket - return false - } - return true -} - -func (r *binReader) Bool(b *bool) *binReader { - var bb byte - r.Byte(&bb) - *b = bb != 0 - return r -} - -func (r *binReader) Byte(b *byte) *binReader { - if !r.hasBytes(1) { - return r - } - *b = r.b[r.i] - r.i++ - return r -} - -func (r *binReader) SharedKey(x *[32]byte) *binReader { - if !r.hasBytes(32) { - return r - } - *x = ([32]byte)(r.b[r.i : r.i+32]) - r.i += 32 - return r -} - -func (r *binReader) Uint16(x *uint16) *binReader { - if !r.hasBytes(2) { - return r - } - *x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) - r.i += 2 - return r -} - -func (r *binReader) Uint64(x *uint64) *binReader { - if !r.hasBytes(8) { - return r - } - *x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) - r.i += 8 - return r -} - -func (r *binReader) Int64(x *int64) *binReader { - if !r.hasBytes(8) { - return r - } - *x = *(*int64)(unsafe.Pointer(&r.b[r.i])) - r.i += 8 - return r -} - -func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { - if !r.hasBytes(19) { - return r - } - - var ( - valid bool - port uint16 - ) - - r.Bool(&valid) - addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() - r.i += 16 - - r.Uint16(&port) - - if valid { - *x = netip.AddrPortFrom(addr, port) - } else { - *x = netip.AddrPort{} - } - - return r -} - -func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { - for i := range x { - r.AddrPort(&x[i]) - } - return r -} - -func (r *binReader) Error() error { - return r.err -} diff --git a/node/packets-util_test.go b/node/packets-util_test.go deleted file mode 100644 index 96eab1a..0000000 --- a/node/packets-util_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package node - -import ( - "net/netip" - "reflect" - "testing" -) - -func TestBinWriteRead(t *testing.T) { - buf := make([]byte, 1024) - - type Item struct { - Type byte - TraceID uint64 - Addrs [8]netip.AddrPort - DestAddr netip.AddrPort - } - - in := Item{ - 1, - 2, - [8]netip.AddrPort{}, - netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), - } - - in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) - in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) - in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) - in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) - in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) - in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) - in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) - - buf = newBinWriter(buf). - Byte(in.Type). - Uint64(in.TraceID). - AddrPort(in.DestAddr). - AddrPortArray(in.Addrs). - Build() - - out := Item{} - - err := newBinReader(buf). - Byte(&out.Type). - Uint64(&out.TraceID). - AddrPort(&out.DestAddr). - AddrPortArray(&out.Addrs). - Error() - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal(in, out) - } -} diff --git a/node/packets.go b/node/packets.go deleted file mode 100644 index f3aa523..0000000 --- a/node/packets.go +++ /dev/null @@ -1,130 +0,0 @@ -package node - -import ( - "errors" - "net/netip" -) - -var ( - errMalformedPacket = errors.New("malformed packet") - errUnknownPacketType = errors.New("unknown packet type") -) - -const ( - packetTypeSyn = iota + 1 - packetTypeSynAck - packetTypeAck - packetTypeProbe - packetTypeAddrDiscovery -) - -// ---------------------------------------------------------------------------- - -type synPacket struct { - TraceID uint64 // TraceID to match response w/ request. - // TODO: SentAt int64 // Unixmilli. - SharedKey [32]byte // Our shared key. - Direct bool - PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. -} - -func (p synPacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeSyn). - Uint64(p.TraceID). - SharedKey(p.SharedKey). - Bool(p.Direct). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). - Build() -} - -func parseSynPacket(buf []byte) (p synPacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - SharedKey(&p.SharedKey). - Bool(&p.Direct). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type ackPacket struct { - TraceID uint64 - ToAddr netip.AddrPort - PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. -} - -func (p ackPacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeAck). - Uint64(p.TraceID). - AddrPort(p.ToAddr). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). - Build() - -} - -func parseAckPacket(buf []byte) (p ackPacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - AddrPort(&p.ToAddr). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). - Error() - return -} - -// ---------------------------------------------------------------------------- - -// A probeReqPacket is sent from a client to a server to determine if direct -// UDP communication can be used. -type probePacket struct { - TraceID uint64 -} - -func (p probePacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeProbe). - Uint64(p.TraceID). - Build() -} - -func parseProbePacket(buf []byte) (p probePacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type localDiscoveryPacket struct{} diff --git a/node/packets_test.go b/node/packets_test.go deleted file mode 100644 index 2b4023a..0000000 --- a/node/packets_test.go +++ /dev/null @@ -1 +0,0 @@ -package node diff --git a/node/packetsender.go b/node/packetsender.go deleted file mode 100644 index 07e083a..0000000 --- a/node/packetsender.go +++ /dev/null @@ -1,127 +0,0 @@ -package node - -import ( - "log" - "net" - "net/netip" - "sync" - "sync/atomic" - "time" -) - -type controlPacketSender interface { - SendControlPacket(pkt marshaller, route peerRoute) -} - -type dataPacketSender interface { - SendDataPacket(pkt []byte, route peerRoute) - SendEncryptedDataPacket(pkt []byte, addr netip.AddrPort) -} - -// ---------------------------------------------------------------------------- - -type packetSender struct { - conn *net.UDPConn - - // For sending control packets. - cLock sync.Mutex - cBuf1 []byte - cBuf2 []byte - - // For sending data packets. - dBuf1 []byte - dBuf2 []byte - - counters [256]uint64 - - // Lock around for sending on UDP Conn. - wLock sync.Mutex -} - -func newPacketSender(conn *net.UDPConn) *packetSender { - ps := &packetSender{ - conn: conn, - cBuf1: make([]byte, bufferSize), - cBuf2: make([]byte, bufferSize), - dBuf1: make([]byte, bufferSize), - dBuf2: make([]byte, bufferSize), - } - for i := range ps.counters { - ps.counters[i] = uint64(time.Now().Unix()<<30 + 1) - } - return ps -} - -// Safe for concurrent use. -func (sender *packetSender) SendControlPacket(pkt marshaller, route peerRoute) { - sender.cLock.Lock() - defer sender.cLock.Unlock() - - buf := pkt.Marshal(sender.cBuf1) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&sender.counters[route.IP], 1), - SourceIP: localIP, - DestIP: route.IP, - } - buf = route.ControlCipher.Encrypt(h, buf, sender.cBuf2) - - if route.Direct { - sender.writeTo(buf, route.RemoteAddr) - return - } - - sender.relayPacket(route.IP, buf, sender.cBuf1) -} - -// Not safe for concurrent use. -func (sender *packetSender) SendDataPacket(pkt []byte, route peerRoute) { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&sender.counters[route.IP], 1), - SourceIP: localIP, - DestIP: route.IP, - } - - enc := route.DataCipher.Encrypt(h, pkt, sender.dBuf1) - - if route.Direct { - sender.writeTo(enc, route.RemoteAddr) - return - } - - sender.relayPacket(route.IP, enc, sender.dBuf2) -} - -func (sender *packetSender) SendEncryptedDataPacket(pkt []byte, addr netip.AddrPort) { - sender.writeTo(pkt, addr) -} - -func (sender *packetSender) relayPacket(destIP byte, data, buf []byte) { - ip := relayIP.Load() - if ip == nil { - return - } - relayRoute := routingTable[*ip].Load() - if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay { - return - } - - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&sender.counters[relayRoute.IP], 1), - SourceIP: localIP, - DestIP: destIP, - } - - enc := relayRoute.DataCipher.Encrypt(h, data, buf) - sender.writeTo(enc, relayRoute.RemoteAddr) -} - -func (sender *packetSender) writeTo(packet []byte, addr netip.AddrPort) { - sender.wLock.Lock() - if _, err := sender.conn.WriteToUDPAddrPort(packet, addr); err != nil { - log.Printf("Failed to write to UDP port: %v", err) - } - sender.wLock.Unlock() -} diff --git a/node/relaymanager.go b/node/relaymanager.go deleted file mode 100644 index a333ce1..0000000 --- a/node/relaymanager.go +++ /dev/null @@ -1,41 +0,0 @@ -package node - -import ( - "log" - "math/rand" - "time" -) - -// TODO: Make part of main loop on ping timer -func relayManager() { - time.Sleep(2 * time.Second) - updateRelayRoute() - - for range time.Tick(8 * time.Second) { - relay := getRelayRoute() - if relay == nil || !relay.Up || !relay.Relay { - updateRelayRoute() - } - } -} - -func updateRelayRoute() { - possible := make([]*peerRoute, 0, 8) - for i := range routingTable { - route := routingTable[i].Load() - if !route.Up || !route.Relay { - continue - } - possible = append(possible, route) - } - - if len(possible) == 0 { - log.Printf("No relay available.") - relayIP.Store(nil) - return - } - - ip := possible[rand.Intn(len(possible))].IP - log.Printf("New relay IP: %d", ip) - relayIP.Store(&ip) -} diff --git a/node/shared.go b/node/shared.go deleted file mode 100644 index dbdb6ee..0000000 --- a/node/shared.go +++ /dev/null @@ -1,59 +0,0 @@ -package node - -import ( - "net/netip" - "sync/atomic" -) - -type sharedState struct { - // Immutable: - HubAddress string - APIKey string - NetName string - LocalIP byte - LocalPub bool - LocalAddr netip.AddrPort - PrivKey []byte - PrivSignKey []byte - - // Mutable: - Routes [256]*atomic.Pointer[peerRoute] - RelayIP *atomic.Pointer[byte] - - // Messages for supervisor main loop. - Messages chan any -} - -func newSharedState( - netName, - hubAddress, - apiKey string, - conf localConfig, -) ( - ss sharedState, -) { - ss.HubAddress = hubAddress - - ss.APIKey = apiKey - ss.NetName = netName - ss.LocalIP = conf.PeerIP - - ip, ok := netip.AddrFromSlice(conf.PublicIP) - if ok { - ss.LocalPub = true - ss.LocalAddr = netip.AddrPortFrom(ip, conf.Port) - } - - ss.PrivKey = conf.PrivKey - ss.PrivSignKey = conf.PrivSignKey - - for i := range ss.Routes { - ss.Routes[i] = &atomic.Pointer[peerRoute]{} - ss.Routes[i].Store(&peerRoute{}) - } - - ss.RelayIP = &atomic.Pointer[byte]{} - - ss.Messages = make(chan any, 1024) - return -} diff --git a/node/shared_test.go b/node/shared_test.go deleted file mode 100644 index 4009e7d..0000000 --- a/node/shared_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package node - -import "vppn/m" - -// TODO: -var sharedStateForTesting = func() sharedState { - ss := newSharedState( - "testNet", - "http://localhost:39499", - "123", - localConfig{ - PeerConfig: m.PeerConfig{}, - }) - - return ss -} diff --git a/node/supervisor.go b/node/supervisor.go deleted file mode 100644 index 726d47f..0000000 --- a/node/supervisor.go +++ /dev/null @@ -1,421 +0,0 @@ -package node - -import ( - "fmt" - "log" - "net/netip" - "strings" - "sync/atomic" - "time" - "vppn/m" - - "git.crumpington.com/lib/go/ratelimiter" -) - -const ( - pingInterval = 8 * time.Second - timeoutInterval = 30 * time.Second -) - -// ---------------------------------------------------------------------------- - -func startPeerSuper( - routingTable [256]*atomic.Pointer[peerRoute], - messages chan any, - sender controlPacketSender, -) { - peers := [256]peerState{} - for i := range peers { - data := &peerStateData{ - sender: sender, - published: routingTable[i], - remoteIP: byte(i), - limiter: ratelimiter.New(ratelimiter.Config{ - FillPeriod: 20 * time.Millisecond, - MaxWaitCount: 1, - }), - } - peers[i] = data.OnPeerUpdate(nil) - } - go runPeerSuper(peers, messages) -} - -func runPeerSuper(peers [256]peerState, messages chan any) { - for raw := range messages { - switch msg := raw.(type) { - - case peerUpdateMsg: - peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer) - - case controlMsg[synPacket]: - peers[msg.SrcIP].OnSyn(msg) - - case controlMsg[ackPacket]: - peers[msg.SrcIP].OnAck(msg) - - case controlMsg[probePacket]: - peers[msg.SrcIP].OnProbe(msg) - - case controlMsg[localDiscoveryPacket]: - peers[msg.SrcIP].OnLocalDiscovery(msg) - - case pingTimerMsg: - publicAddrs.Clean() - for i := range peers { - if newState := peers[i].OnPingTimer(); newState != nil { - peers[i] = newState - } - } - - default: - log.Printf("WARNING: unknown message type: %+v", msg) - } - } -} - -// ---------------------------------------------------------------------------- - -type peerState interface { - OnPeerUpdate(*m.Peer) peerState - OnSyn(controlMsg[synPacket]) - OnAck(controlMsg[ackPacket]) - OnProbe(controlMsg[probePacket]) - OnLocalDiscovery(controlMsg[localDiscoveryPacket]) - OnPingTimer() peerState -} - -// ---------------------------------------------------------------------------- - -type peerStateData struct { - sender controlPacketSender - - // The purpose of this state machine is to manage this published data. - published *atomic.Pointer[peerRoute] - staged peerRoute // Local copy of shared data. See publish(). - - // Immutable data. - remoteIP byte // Remote VPN IP. - - // Mutable peer data. - peer *m.Peer - remotePub bool - - // For logging. Set per-state. - client bool - - // We rate limit per remote endpoint because if we don't we tend to lose - // packets. - limiter *ratelimiter.Limiter -} - -// ---------------------------------------------------------------------------- - -func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { - s._sendControlPacket(pkt, s.staged) -} - -func (s *peerStateData) sendControlPacketTo(pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort) { - if !addr.IsValid() { - s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) - return - } - route := s.staged - route.Direct = true - route.RemoteAddr = addr - s._sendControlPacket(pkt, route) -} - -func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute) { - if err := s.limiter.Limit(); err != nil { - s.logf("Not sending control packet: rate limited.") // Shouldn't happen. - return - } - s.sender.SendControlPacket(pkt, route) -} - -// ---------------------------------------------------------------------------- - -func (s *peerStateData) publish() { - data := s.staged - s.published.Store(&data) -} - -func (s *peerStateData) logf(format string, args ...any) { - b := strings.Builder{} - b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name)) - - if s.client { - b.WriteString("CLIENT | ") - } else { - b.WriteString("SERVER | ") - } - - if s.staged.Direct { - b.WriteString("DIRECT | ") - } else { - b.WriteString("RELAYED | ") - } - - if s.staged.Up { - b.WriteString("UP | ") - } else { - b.WriteString("DOWN | ") - } - - log.Printf(b.String()+format, args...) -} - -// ---------------------------------------------------------------------------- - -func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { - defer s.publish() - - if peer == nil { - return enterStateDisconnected(s) - } - - s.peer = peer - s.staged = peerRoute{ - IP: s.remoteIP, - PubSignKey: peer.PubSignKey, - // TODO: privKey global. - ControlCipher: newControlCipher(privKey, peer.PubKey), - DataCipher: newDataCipher(), - } - s.remotePub = false - - if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { - s.remotePub = true - s.staged.Relay = peer.Relay - s.staged.Direct = true - s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) - } else if localPub { - s.staged.Direct = true - } - - if s.remotePub == localPub { - // TODO: localIP is global - if localIP < s.remoteIP { - return enterStateServer(s) - } - return enterStateClient(s) - } - - if s.remotePub { - return enterStateClient(s) - } - return enterStateServer(s) -} - -// ---------------------------------------------------------------------------- - -type stateDisconnected struct { - *peerStateData -} - -func enterStateDisconnected(s *peerStateData) peerState { - s.peer = nil - s.staged = peerRoute{} - s.publish() - return &stateDisconnected{s} -} - -func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {} -func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {} -func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {} -func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {} - -func (s *stateDisconnected) OnPingTimer() peerState { - return nil -} - -// ---------------------------------------------------------------------------- - -type stateServer struct { - *stateDisconnected - lastSeen time.Time - synTraceID uint64 -} - -func enterStateServer(s *peerStateData) peerState { - s.client = false - return &stateServer{stateDisconnected: &stateDisconnected{s}} -} - -func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { - s.lastSeen = time.Now() - p := msg.Packet - - // Before we can respond to this packet, we need to make sure the - // route is setup properly. - // - // The client will update the syn's TraceID whenever there's a change. - // The server will follow the client's request. - if p.TraceID != s.synTraceID || !s.staged.Up { - s.synTraceID = p.TraceID - s.staged.Up = true - s.staged.Direct = p.Direct - s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) - s.staged.RemoteAddr = msg.SrcAddr - s.publish() - s.logf("Got syn.") - } - - // Always respond. - ack := ackPacket{ - TraceID: p.TraceID, - ToAddr: s.staged.RemoteAddr, - PossibleAddrs: publicAddrs.Get(), - } - s.sendControlPacket(ack) - - if s.staged.Direct { - return - } - - // Not direct => send probes. - for _, addr := range p.PossibleAddrs { - if !addr.IsValid() { - break - } - s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr) - } -} - -func (s *stateServer) OnProbe(msg controlMsg[probePacket]) { - if !msg.SrcAddr.IsValid() { - s.logf("Invalid probe address.") - return - } - s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) -} - -func (s *stateServer) OnPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { - s.staged.Up = false - s.publish() - s.logf("Connection timeout.") - } - return nil -} - -// ---------------------------------------------------------------------------- - -type stateClient struct { - *stateDisconnected - - lastSeen time.Time - syn synPacket - ack ackPacket - - probes map[uint64]netip.AddrPort - localDiscoveryAddr netip.AddrPort -} - -func enterStateClient(s *peerStateData) peerState { - s.client = true - ss := &stateClient{ - stateDisconnected: &stateDisconnected{s}, - probes: map[uint64]netip.AddrPort{}, - } - - ss.syn = synPacket{ - TraceID: newTraceID(), - SharedKey: s.staged.DataCipher.Key(), - Direct: s.staged.Direct, - PossibleAddrs: publicAddrs.Get(), - } - ss.sendControlPacket(ss.syn) - - return ss -} - -func (s *stateClient) sendProbeTo(addr netip.AddrPort) { - probe := probePacket{TraceID: newTraceID()} - s.probes[probe.TraceID] = addr - s.sendControlPacketTo(probe, addr) -} - -func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { - if msg.Packet.TraceID != s.syn.TraceID { - s.logf("Ack has incorrect trace ID") - return - } - - s.ack = msg.Packet - s.lastSeen = time.Now() - - if !s.staged.Up { - s.staged.Up = true - s.logf("Got ack.") - s.publish() - } - - // Store possible public address if we're not a public node. - // TODO: localPub is global, publicAddrs is global. - if !localPub && s.remotePub { - publicAddrs.Store(msg.Packet.ToAddr) - } -} - -func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { - if s.staged.Direct { - return - } - - addr, ok := s.probes[msg.Packet.TraceID] - if !ok { - return - } - - s.staged.RemoteAddr = addr - s.staged.Direct = true - s.publish() - - s.syn.TraceID = newTraceID() - s.syn.Direct = true - s.syn.PossibleAddrs = [8]netip.AddrPort{} - s.sendControlPacket(s.syn) - - s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) -} - -func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { - if s.staged.Direct { - return - } - - // The source port will be the multicast port, so we'll have to - // construct the correct address using the peer's listed port. - s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) -} - -func (s *stateClient) OnPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval { - if s.staged.Up { - s.logf("Connection timeout.") - } - return s.OnPeerUpdate(s.peer) - } - - s.sendControlPacket(s.syn) - - if s.staged.Direct { - return nil - } - - clear(s.probes) - for _, addr := range s.ack.PossibleAddrs { - if !addr.IsValid() { - break - } - s.sendProbeTo(addr) - } - - if s.localDiscoveryAddr.IsValid() { - s.sendProbeTo(s.localDiscoveryAddr) - s.localDiscoveryAddr = netip.AddrPort{} - } - - return nil -} -- 2.39.5 From 5cafd030acb81cd6e4ff81a5ae3c5c7ff9bb21be Mon Sep 17 00:00:00 2001 From: jdl Date: Fri, 21 Feb 2025 07:42:16 +0100 Subject: [PATCH 13/26] wip --- peer/packets.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/peer/packets.go b/peer/packets.go index b300dee..5be89b0 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -9,10 +9,18 @@ const ( packetTypeAck = 3 packetTypeProbe = 4 packetTypeAddrDiscovery = 5 + packetTypeInit = 6 ) // ---------------------------------------------------------------------------- +type packetInit struct { + TraceID uint64 + Version uint64 +} + +// ---------------------------------------------------------------------------- + type packetSyn struct { TraceID uint64 // TraceID to match response w/ request. SharedKey [32]byte // Our shared key. -- 2.39.5 From 3f7c42bb413fa7147597b5b7e53ad54e6300ac98 Mon Sep 17 00:00:00 2001 From: jdl Date: Sun, 23 Feb 2025 16:58:41 +0100 Subject: [PATCH 14/26] wip: working --- peer/controlmessage.go | 3 +- peer/hubpoller.go | 2 +- peer/mcwriter.go | 2 +- peer/peerstates.go | 90 ++++++++++++++++++++++++++++++++++------- peer/peerstates_test.go | 49 ++++++++++------------ peer/peersuper.go | 56 ++++++++----------------- 6 files changed, 115 insertions(+), 87 deletions(-) diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 09935ab..3a18bc8 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -49,8 +49,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error // ---------------------------------------------------------------------------- type peerUpdateMsg struct { - PeerIP byte - Peer *m.Peer + Peer *m.Peer } // ---------------------------------------------------------------------------- diff --git a/peer/hubpoller.go b/peer/hubpoller.go index 2b50495..238dfda 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -96,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { for i, peer := range state.Peers { if i != int(hp.localIP) { if peer == nil || peer.Version != hp.versions[i] { - hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) + hp.handleControlMsg(byte(i), peerUpdateMsg{Peer: state.Peers[i]}) if peer != nil { hp.versions[i] = peer.Version } diff --git a/peer/mcwriter.go b/peer/mcwriter.go index 5559547..29cf2be 100644 --- a/peer/mcwriter.go +++ b/peer/mcwriter.go @@ -44,7 +44,7 @@ func runMCWriter(localIP byte, signingKey []byte) { log.Fatalf("Failed to bind to multicast address: %v", err) } - for range time.Tick(16 * time.Second) { + for range time.Tick(8 * time.Second) { _, err := conn.WriteToUDP(discoveryPacket, multicastAddr) if err != nil { log.Printf("[MCWriter] Failed to write multicast: %v", err) diff --git a/peer/peerstates.go b/peer/peerstates.go index a68afb1..b5abfb7 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -12,12 +12,7 @@ import ( ) type peerState interface { - OnPeerUpdate(*m.Peer) peerState - OnSyn(controlMsg[packetSyn]) peerState - OnAck(controlMsg[packetAck]) - OnProbe(controlMsg[packetProbe]) peerState - OnLocalDiscovery(controlMsg[packetLocalDiscovery]) - OnPingTimer() peerState + OnMsg(raw any) peerState } // ---------------------------------------------------------------------------- @@ -26,6 +21,7 @@ type pState struct { // Output. publish func(remotePeer) sendControlPacket func(remotePeer, marshaller) + pingTimer *time.Ticker // Immutable data. localIP byte @@ -147,9 +143,20 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { type stateDisconnected struct{ *pState } func enterStateDisconnected(s *pState) peerState { + s.pingTimer.Stop() return &stateDisconnected{pState: s} } +func (s *stateDisconnected) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + default: + // TODO: Log. + return s + } +} + func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState { return s } func (s *stateDisconnected) OnAck(controlMsg[packetAck]) {} func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s } @@ -166,9 +173,26 @@ type stateServer struct { func enterStateServer(s *pState) peerState { s.logf("==> Server") + s.pingTimer.Reset(pingInterval) return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} } +func (s *stateServer) OnMsg(rawMsg any) peerState { + switch msg := rawMsg.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetSyn]: + return s.OnSyn(msg) + case controlMsg[packetProbe]: + return s.OnProbe(msg) + case pingTimerMsg: + return s.OnPingTimer() + default: + // TODO: Log + return s + } +} + func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -236,6 +260,7 @@ type stateClientDirect struct { func enterStateClientDirect(s *pState) peerState { s.logf("==> ClientDirect") + s.pingTimer.Reset(pingInterval) return newStateClientDirect(s) } @@ -255,6 +280,24 @@ func newStateClientDirect(s *pState) *stateClientDirect { return state } +func (s *stateClientDirect) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetAck]: + s.OnAck(msg) + return s + case pingTimerMsg: + if next := s.onPingTimer(); next != nil { + return next + } + return s + default: + // TODO: Log + return s + } +} + func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { if msg.Packet.TraceID != s.syn.TraceID { return @@ -271,13 +314,6 @@ func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { s.pubAddrs.Store(msg.Packet.ToAddr) } -func (s *stateClientDirect) OnPingTimer() peerState { - if next := s.onPingTimer(); next != nil { - return next - } - return s -} - func (s *stateClientDirect) onPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval { if s.staged.Up { @@ -297,21 +333,44 @@ func (s *stateClientDirect) onPingTimer() peerState { type stateClientRelayed struct { *stateClientDirect ack packetAck - probes map[uint64]netip.AddrPort - localDiscoveryAddr netip.AddrPort + probes map[uint64]netip.AddrPort // TODO: something better + localDiscoveryAddr netip.AddrPort // TODO: Remove } func enterStateClientRelayed(s *pState) peerState { s.logf("==> ClientRelayed") + s.pingTimer.Reset(pingInterval) return &stateClientRelayed{ stateClientDirect: newStateClientDirect(s), probes: map[uint64]netip.AddrPort{}, } } +func (s *stateClientRelayed) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetAck]: + s.OnAck(msg) + return s + case controlMsg[packetProbe]: + return s.OnProbe(msg) + case controlMsg[packetLocalDiscovery]: + s.OnLocalDiscovery(msg) + return s + case pingTimerMsg: + return s.OnPingTimer() + default: + // TODO: Log + return s + } +} + func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { s.ack = msg.Packet s.stateClientDirect.OnAck(msg) + + // TODO: Send probes now. } func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { @@ -330,6 +389,7 @@ func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscover // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + // TODO: s.sendProbeTo(s.localDiscoveryAddr) } func (s *stateClientRelayed) OnPingTimer() peerState { diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index daf5c14..cbe2474 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -34,10 +34,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { sendControlPacket: func(rp remotePeer, pkt marshaller) { h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) }, - localIP: 2, - remoteIP: 3, - privKey: keys.PrivKey, - pubAddrs: newPubAddrStore(netip.AddrPort{}), + pingTimer: time.NewTicker(pingInterval), + localIP: 2, + remoteIP: 3, + privKey: keys.PrivKey, + pubAddrs: newPubAddrStore(netip.AddrPort{}), limiter: ratelimiter.New(ratelimiter.Config{ FillPeriod: 20 * time.Millisecond, MaxWaitCount: 1, @@ -49,27 +50,19 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { } func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { - if s := h.State.OnPeerUpdate(p); s != nil { - h.State = s - } + h.State = h.State.OnMsg(peerUpdateMsg{p}) } func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { - if s := h.State.OnSyn(msg); s != nil { - h.State = s - } + h.State = h.State.OnMsg(msg) } func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { - if s := h.State.OnProbe(msg); s != nil { - h.State = s - } + h.State = h.State.OnMsg(msg) } func (h *PeerStateTestHarness) OnPingTimer() { - if s := h.State.OnPingTimer(); s != nil { - h.State = s - } + h.State = h.State.OnMsg(pingTimerMsg{}) } func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { @@ -202,7 +195,7 @@ func TestStateServer_directSyn(t *testing.T) { }, } - h.State.OnSyn(synMsg) + h.State = h.State.OnMsg(synMsg) assertEqual(t, len(h.Sent), 1) ack := assertType[packetAck](t, h.Sent[0].Packet) @@ -233,7 +226,7 @@ func TestStateServer_relayedSyn(t *testing.T) { synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) - h.State.OnSyn(synMsg) + h.State = h.State.OnMsg(synMsg) assertEqual(t, len(h.Sent), 3) @@ -261,7 +254,7 @@ func TestStateServer_onProbe(t *testing.T) { Packet: packetProbe{TraceID: newTraceID()}, } - h.State.OnProbe(probeMsg) + h.State = h.State.OnMsg(probeMsg) assertEqual(t, len(h.Sent), 1) @@ -285,7 +278,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { }, } - h.State.OnSyn(synMsg) + h.State = h.State.OnMsg(synMsg) assertEqual(t, len(h.Sent), 1) assertEqual(t, h.Published.Up, true) @@ -314,7 +307,7 @@ func TestStateClientDirect_OnAck(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) } @@ -331,7 +324,7 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID + 1}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, false) } @@ -366,7 +359,7 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) state := assertType[*stateClientDirect](t, h.State) @@ -395,7 +388,7 @@ func TestStateClientRelayed_OnAck(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) } @@ -429,11 +422,11 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) // Add a local discovery address. Note that the port will be configured port // and no the one provided here. - h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ + h.State = h.State.OnMsg(controlMsg[packetLocalDiscovery]{ SrcIP: 3, SrcAddr: addrPort4(2, 2, 2, 3, 300), }) @@ -462,7 +455,7 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) state := assertType[*stateClientRelayed](t, h.State) @@ -499,7 +492,7 @@ func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) h.OnPingTimer() probe := assertType[packetProbe](t, h.Sent[2].Packet) diff --git a/peer/peersuper.go b/peer/peersuper.go index 7682d87..6fa724a 100644 --- a/peer/peersuper.go +++ b/peer/peersuper.go @@ -1,8 +1,6 @@ package peer import ( - "log" - "math/rand" "net/netip" "sync" "sync/atomic" @@ -44,6 +42,7 @@ func newSupervisor( state := &pState{ publish: s.publish, sendControlPacket: s.send, + pingTimer: time.NewTicker(timeoutInterval), localIP: routes.LocalIP, remoteIP: byte(i), privKey: privKey, @@ -55,7 +54,7 @@ func newSupervisor( MaxWaitCount: 1, }), } - s.peers[i] = newPeerSuper(state) + s.peers[i] = newPeerSuper(state, state.pingTimer) } return s @@ -105,7 +104,7 @@ func (s *supervisor) ensureRelay() { return } - // TODO: Random selection? + // TODO: Random selection? Something else? for _, peer := range s.staged.Peers { if peer.Up && peer.Direct && peer.Relay { s.staged.RelayIP = peer.IP @@ -117,14 +116,16 @@ func (s *supervisor) ensureRelay() { // ---------------------------------------------------------------------------- type peerSuper struct { - messages chan any - state peerState + messages chan any + state peerState + pingTimer *time.Ticker } -func newPeerSuper(state *pState) *peerSuper { +func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { return &peerSuper{ - messages: make(chan any, 8), - state: state.OnPeerUpdate(nil), + messages: make(chan any, 8), + state: state.OnPeerUpdate(nil), + pingTimer: pingTimer, } } @@ -136,37 +137,12 @@ func (s *peerSuper) HandleControlMsg(msg any) { } func (s *peerSuper) Run() { - go func() { - // Randomize ping timers. - time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) - for range time.Tick(pingInterval) { - s.messages <- pingTimerMsg{} - } - }() - - for rawMsg := range s.messages { - switch msg := rawMsg.(type) { - - case peerUpdateMsg: - s.state = s.state.OnPeerUpdate(msg.Peer) - - case controlMsg[packetSyn]: - s.state = s.state.OnSyn(msg) - - case controlMsg[packetAck]: - s.state.OnAck(msg) - - case controlMsg[packetProbe]: - s.state = s.state.OnProbe(msg) - - case controlMsg[packetLocalDiscovery]: - s.state.OnLocalDiscovery(msg) - - case pingTimerMsg: - s.state = s.state.OnPingTimer() - - default: - log.Printf("WARNING: unknown message type: %+v", msg) + for { + select { + case <-s.pingTimer.C: + s.state = s.state.OnMsg(pingTimerMsg{}) + case raw := <-s.messages: + s.state = s.state.OnMsg(raw) } } } -- 2.39.5 From 9fd6d90f9cdec592a2187a33ad06a06e6ff57ad5 Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 25 Feb 2025 02:43:29 +0100 Subject: [PATCH 15/26] wip: cleanup --- peer/connreader.go | 9 ++- peer/controlmessage.go | 8 ++ peer/globals.go | 2 + peer/logging.go | 13 ++++ peer/packets.go | 21 +++++- peer/peerstates.go | 94 +++++++++++++++++++++++- peer/peerstates_test.go | 5 ++ peer/peersuper.go | 2 +- peer/pubaddrs.go | 9 +-- peer/state-clientdirect.go | 85 +++++++++++++++++++++ peer/state-clientinit.go | 93 +++++++++++++++++++++++ peer/state-clientrelayed.go | 142 ++++++++++++++++++++++++++++++++++++ peer/state-disconnected.go | 33 +++++++++ peer/state-server.go | 127 ++++++++++++++++++++++++++++++++ peer/statedata.go | 28 +++++++ 15 files changed, 657 insertions(+), 14 deletions(-) create mode 100644 peer/logging.go create mode 100644 peer/state-clientdirect.go create mode 100644 peer/state-clientinit.go create mode 100644 peer/state-clientrelayed.go create mode 100644 peer/state-disconnected.go create mode 100644 peer/state-server.go create mode 100644 peer/statedata.go diff --git a/peer/connreader.go b/peer/connreader.go index b78e58f..0727ced 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -84,6 +84,7 @@ func (r *connReader) handleControlPacket( enc []byte, ) { if peer.ControlCipher == nil { + log.Printf("No control cipher for peer: %v", h) return } @@ -125,13 +126,13 @@ func (r *connReader) handleDataPacket( return } - relay, ok := rt.GetRelay() - if !ok { - r.logf("Relay not available.") + remote := rt.Peers[h.DestIP] + if !remote.Direct { + r.logf("Unable to relay data to %d.", h.DestIP) return } - r.writeToUDPAddrPort(data, relay.DirectAddr) + r.writeToUDPAddrPort(data, remote.DirectAddr) } func (r *connReader) logf(format string, args ...any) { diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 3a18bc8..33d4e9c 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -41,6 +41,14 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error Packet: packet, }, err + case packetTypeInit: + packet, err := parsePacketInit(buf) + return controlMsg[packetInit]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + default: return nil, errUnknownPacketType } diff --git a/peer/globals.go b/peer/globals.go index f967c8a..cd0e1f6 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -7,6 +7,8 @@ import ( ) const ( + version = 1 + bufferSize = 1536 if_mtu = 1200 diff --git a/peer/logging.go b/peer/logging.go new file mode 100644 index 0000000..4906b04 --- /dev/null +++ b/peer/logging.go @@ -0,0 +1,13 @@ +package peer + +import "log" + +func logPacket(p []byte, notes string) { + h := parseHeader(p) + log.Printf(`Sending: Data: %v | From: %d | To: %d | %s +`, + h.StreamID == dataStreamID, + h.SourceIP, + h.DestIP, + notes) +} diff --git a/peer/packets.go b/peer/packets.go index 5be89b0..b673a4c 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -6,19 +6,38 @@ import ( const ( packetTypeSyn = 1 + packetTypeInit = 2 packetTypeAck = 3 packetTypeProbe = 4 packetTypeAddrDiscovery = 5 - packetTypeInit = 6 ) // ---------------------------------------------------------------------------- type packetInit struct { TraceID uint64 + Direct bool Version uint64 } +func (p packetInit) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeInit). + Uint64(p.TraceID). + Bool(p.Direct). + Uint64(p.Version). + Build() +} + +func parsePacketInit(buf []byte) (p packetInit, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Bool(&p.Direct). + Uint64(&p.Version). + Error() + return +} + // ---------------------------------------------------------------------------- type packetSyn struct { diff --git a/peer/peerstates.go b/peer/peerstates.go index b5abfb7..6c52f55 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -43,6 +43,7 @@ type pState struct { limiter *ratelimiter.Limiter } +/* func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { defer func() { // Don't defer directly otherwise s.staged will be evaluated immediately @@ -78,7 +79,7 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { return enterStateServer(s) } - return enterStateClientDirect(s) + return enterStateClientinit(s) } if s.localAddr.IsValid() { @@ -90,8 +91,9 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { return enterStateServer(s) } - return enterStateClientRelayed(s) + return enterStateClientinit(s) } +*/ func (s *pState) logf(format string, args ...any) { b := strings.Builder{} @@ -140,6 +142,7 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { // ---------------------------------------------------------------------------- +/* type stateDisconnected struct{ *pState } func enterStateDisconnected(s *pState) peerState { @@ -181,6 +184,8 @@ func (s *stateServer) OnMsg(rawMsg any) peerState { switch msg := rawMsg.(type) { case peerUpdateMsg: return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetInit]: + return s.OnInit(msg) case controlMsg[packetSyn]: return s.OnSyn(msg) case controlMsg[packetProbe]: @@ -193,6 +198,21 @@ func (s *stateServer) OnMsg(rawMsg any) peerState { } } +func (s *stateServer) OnInit(msg controlMsg[packetInit]) peerState { + s.logf("Responding to INIT.") + route := s.staged + route.Direct = msg.Packet.Direct + route.DirectAddr = msg.SrcAddr + + s.Send(route, packetInit{ + TraceID: msg.Packet.TraceID, + Direct: route.Direct, + Version: version, + }) + + return s +} + func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -252,6 +272,75 @@ func (s *stateServer) OnPingTimer() peerState { // ---------------------------------------------------------------------------- +type stateClientInit struct { + *stateDisconnected + startedAt time.Time + traceID uint64 +} + +func enterStateClientinit(s *pState) peerState { + s.logf("==> ClientInit") + s.pingTimer.Reset(pingInterval) + + state := &stateClientInit{ + stateDisconnected: &stateDisconnected{s}, + startedAt: time.Now(), + traceID: newTraceID(), + } + state.Send(s.staged, packetInit{ + TraceID: state.traceID, + Direct: s.staged.Direct, + Version: version, + }) + return state +} + +func (s *stateClientInit) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case pingTimerMsg: + return s.onPing() + default: + return s + } +} + +func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState { + if msg.Packet.TraceID != s.traceID { + s.logf("Invalid trace ID on INIT.") + return s + } + s.logf("Got INIT version %d.", msg.Packet.Version) + return s.nextState() +} + +func (s *stateClientInit) onPing() peerState { + if time.Since(s.startedAt) > timeoutInterval { + s.logf("Init timeout. Assuming version 1.") + return s.nextState() + } + + s.traceID = newTraceID() + s.Send(s.staged, packetInit{ + TraceID: s.traceID, + Direct: s.staged.Direct, + Version: version, + }) + return s +} + +func (s *stateClientInit) nextState() peerState { + if s.staged.Direct { + return enterStateClientDirect(s.pState) + } + return enterStateClientRelayed(s.pState) +} + +// ---------------------------------------------------------------------------- + type stateClientDirect struct { *stateDisconnected lastSeen time.Time @@ -418,3 +507,4 @@ func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { s.probes[probe.TraceID] = addr s.SendTo(probe, addr) } +*/ diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index cbe2474..15f7d18 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -53,6 +53,10 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { h.State = h.State.OnMsg(peerUpdateMsg{p}) } +func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) { + h.State = h.State.OnMsg(msg) +} + func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { h.State = h.State.OnMsg(msg) } @@ -110,6 +114,7 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDire h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) + return assertType[*stateClientDirect](t, h.State) } diff --git a/peer/peersuper.go b/peer/peersuper.go index 6fa724a..ec8c741 100644 --- a/peer/peersuper.go +++ b/peer/peersuper.go @@ -124,7 +124,7 @@ type peerSuper struct { func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { return &peerSuper{ messages: make(chan any, 8), - state: state.OnPeerUpdate(nil), + state: initPeerState(state, nil), pingTimer: pingTimer, } } diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go index 027057a..c56b28e 100644 --- a/peer/pubaddrs.go +++ b/peer/pubaddrs.go @@ -1,9 +1,7 @@ package peer import ( - "log" "net/netip" - "runtime/debug" "sort" "sync" "time" @@ -27,14 +25,13 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { } func (store *pubAddrStore) Store(add netip.AddrPort) { - store.lock.Lock() - defer store.lock.Unlock() - if store.localPub { - log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) return } + store.lock.Lock() + defer store.lock.Unlock() + if !add.IsValid() { return } diff --git a/peer/state-clientdirect.go b/peer/state-clientdirect.go new file mode 100644 index 0000000..c6c552d --- /dev/null +++ b/peer/state-clientdirect.go @@ -0,0 +1,85 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateClientDirect2 struct { + *peerData + lastSeen time.Time + syn packetSyn +} + +func enterStateClientDirect2(data *peerData, directAddr netip.AddrPort) peerState { + data.staged.Relay = data.peer.Relay + data.staged.Direct = true + data.staged.DirectAddr = directAddr + data.publish(data.staged) + + state := &stateClientDirect2{ + peerData: data, + lastSeen: time.Now(), + syn: packetSyn{ + TraceID: newTraceID(), + SharedKey: data.staged.DataCipher.Key(), + Direct: true, + }, + } + + state.Send(state.staged, state.syn) + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientDirect") + return state +} + +func (s *stateClientDirect2) logf(str string, args ...any) { + s.peerData.logf("CLNT | "+str, args...) +} + +func (s *stateClientDirect2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetAck]: + return s.onAck(msg) + case pingTimerMsg: + return s.onPingTimer() + case controlMsg[packetLocalDiscovery]: + return s + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateClientDirect2) onAck(msg controlMsg[packetAck]) peerState { + if msg.Packet.TraceID != s.syn.TraceID { + return s + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + s.pubAddrs.Store(msg.Packet.ToAddr) + return s +} + +func (s *stateClientDirect2) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.logf("Timeout.") + } + return initPeerState(s.peerData, s.peer) + } + + s.Send(s.staged, s.syn) + return s +} diff --git a/peer/state-clientinit.go b/peer/state-clientinit.go new file mode 100644 index 0000000..8a84100 --- /dev/null +++ b/peer/state-clientinit.go @@ -0,0 +1,93 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateClientInit2 struct { + *peerData + startedAt time.Time + traceID uint64 +} + +func enterStateClientInit2(data *peerData) peerState { + ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) + + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = ipValid + data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) + data.staged.PubSignKey = data.peer.PubSignKey + data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) + data.staged.DataCipher = newDataCipher() + + data.publish(data.staged) + + state := &stateClientInit2{ + peerData: data, + startedAt: time.Now(), + traceID: newTraceID(), + } + state.sendInit() + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientInit") + return state +} + +func (s *stateClientInit2) logf(str string, args ...any) { + s.peerData.logf("INIT | "+str, args...) +} + +func (s *stateClientInit2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case pingTimerMsg: + return s.onPing() + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { + if msg.Packet.TraceID != s.traceID { + s.logf("Invalid trace ID on INIT.") + return s + } + s.logf("Got INIT version %d.", msg.Packet.Version) + return s.nextState() +} + +func (s *stateClientInit2) onPing() peerState { + if time.Since(s.startedAt) > timeoutInterval { + s.logf("Init timeout. Assuming version 1.") + return s.nextState() + } + + s.sendInit() + return s +} + +func (s *stateClientInit2) sendInit() { + s.traceID = newTraceID() + init := packetInit{ + TraceID: s.traceID, + Direct: s.staged.Direct, + Version: version, + } + s.Send(s.staged, init) +} + +func (s *stateClientInit2) nextState() peerState { + if s.staged.Direct { + return enterStateClientDirect2(s.peerData, s.staged.DirectAddr) + } + + return enterStateClientRelayed2(s.peerData) +} diff --git a/peer/state-clientrelayed.go b/peer/state-clientrelayed.go new file mode 100644 index 0000000..737f0a9 --- /dev/null +++ b/peer/state-clientrelayed.go @@ -0,0 +1,142 @@ +package peer + +import ( + "net/netip" + "time" +) + +type sentProbe struct { + SentAt time.Time + Addr netip.AddrPort +} + +type stateClientRelayed2 struct { + *peerData + lastSeen time.Time + syn packetSyn + probes map[uint64]sentProbe +} + +func enterStateClientRelayed2(data *peerData) peerState { + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.publish(data.staged) + + state := &stateClientRelayed2{ + peerData: data, + lastSeen: time.Now(), + syn: packetSyn{ + TraceID: newTraceID(), + SharedKey: data.staged.DataCipher.Key(), + Direct: false, + PossibleAddrs: data.pubAddrs.Get(), + }, + probes: map[uint64]sentProbe{}, + } + + state.Send(state.staged, state.syn) + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientRelayed") + return state +} + +func (s *stateClientRelayed2) logf(str string, args ...any) { + s.peerData.logf("CLNT | "+str, args...) +} + +func (s *stateClientRelayed2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetAck]: + return s.onAck(msg) + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + return s.onLocalDiscovery(msg) + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { + if msg.Packet.TraceID != s.syn.TraceID { + return s + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + s.pubAddrs.Store(msg.Packet.ToAddr) + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.sendProbeTo(addr) + } + + s.cleanProbes() + + return s +} + +func (s *stateClientRelayed2) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.logf("Timeout.") + } + return initPeerState(s.peerData, s.peer) + } + + s.Send(s.staged, s.syn) + return s +} + +func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { + s.cleanProbes() + + sent, ok := s.probes[msg.Packet.TraceID] + if !ok { + return s + } + + s.logf("Successful probe.") + return enterStateClientDirect2(s.peerData, sent.Addr) +} + +func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) peerState { + // The source port will be the multicast port, so we'll have to + // construct the correct address using the peer's listed port. + addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + s.sendProbeTo(addr) + return s +} + +func (s *stateClientRelayed2) cleanProbes() { + for key, sent := range s.probes { + if time.Since(sent.SentAt) > pingInterval { + delete(s.probes, key) + } + } +} + +func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { + probe := packetProbe{TraceID: newTraceID()} + s.probes[probe.TraceID] = sentProbe{ + SentAt: time.Now(), + Addr: addr, + } + s.SendTo(probe, addr) +} diff --git a/peer/state-disconnected.go b/peer/state-disconnected.go new file mode 100644 index 0000000..3fdbd23 --- /dev/null +++ b/peer/state-disconnected.go @@ -0,0 +1,33 @@ +package peer + +import "net/netip" + +type stateDisconnected2 struct { + *peerData +} + +func enterStateDisconnected2(data *peerData) peerState { + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.staged.PubSignKey = nil + data.staged.ControlCipher = nil + data.staged.DataCipher = nil + + data.publish(data.staged) + + data.pingTimer.Stop() + + return &stateDisconnected2{data} +} + +func (s *stateDisconnected2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + default: + s.logf("Ignoring message: %v", raw) + return s + } +} diff --git a/peer/state-server.go b/peer/state-server.go new file mode 100644 index 0000000..f3d19da --- /dev/null +++ b/peer/state-server.go @@ -0,0 +1,127 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateServer2 struct { + *peerData + lastSeen time.Time + synTraceID uint64 // Last syn trace ID. +} + +func enterStateServer2(data *peerData) peerState { + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.staged.PubSignKey = data.peer.PubSignKey + data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) + data.staged.DataCipher = nil + + data.publish(data.staged) + + data.pingTimer.Reset(pingInterval) + + state := &stateServer2{peerData: data} + state.logf("==> Server") + return state +} + +func (s *stateServer2) logf(str string, args ...any) { + s.peerData.logf("SRVR | "+str, args...) +} + +func (s *stateServer2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case controlMsg[packetSyn]: + return s.onSyn(msg) + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + return s + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateServer2) onInit(msg controlMsg[packetInit]) peerState { + s.staged.Up = false + s.staged.Direct = msg.Packet.Direct + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + + init := packetInit{ + TraceID: msg.Packet.TraceID, + Direct: s.staged.Direct, + Version: version, + } + + s.Send(s.staged, init) + + return s +} + +func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { + s.lastSeen = time.Now() + p := msg.Packet + + // Before we can respond to this packet, we need to make sure the + // route is setup properly. + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != s.synTraceID || !s.staged.Up { + s.synTraceID = p.TraceID + s.staged.Up = true + s.staged.Direct = p.Direct + s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + s.logf("Got SYN.") + } + + // Always respond. + s.Send(s.staged, packetAck{ + TraceID: p.TraceID, + ToAddr: s.staged.DirectAddr, + PossibleAddrs: s.pubAddrs.Get(), + }) + + if p.Direct { + return s + } + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.SendTo(packetProbe{TraceID: newTraceID()}, addr) + } + + return s +} + +func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { + if msg.SrcAddr.IsValid() { + s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + } + return s +} + +func (s *stateServer2) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return s +} diff --git a/peer/statedata.go b/peer/statedata.go new file mode 100644 index 0000000..44330fa --- /dev/null +++ b/peer/statedata.go @@ -0,0 +1,28 @@ +package peer + +import ( + "net/netip" + "vppn/m" +) + +type peerData = pState + +func initPeerState(data *peerData, peer *m.Peer) peerState { + data.peer = peer + + if peer == nil { + return enterStateDisconnected2(data) + } + + if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + if data.localAddr.IsValid() && data.localIP < data.remoteIP { + return enterStateServer2(data) + } + return enterStateClientInit2(data) + } + + if data.localAddr.IsValid() || data.localIP < data.remoteIP { + return enterStateServer2(data) + } + return enterStateClientInit2(data) +} -- 2.39.5 From 7b9c8353dd519176bbdb482372077624912b40a5 Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 25 Feb 2025 08:56:49 +0100 Subject: [PATCH 16/26] wip: working needs cleanup --- peer/pubaddrs.go | 20 ++++++++----- peer/state-clientinit.go | 12 ++------ peer/state-clientrelayed.go | 59 +++++++++++++++++++++++++------------ peer/state-server.go | 2 ++ 4 files changed, 56 insertions(+), 37 deletions(-) diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go index c56b28e..7945458 100644 --- a/peer/pubaddrs.go +++ b/peer/pubaddrs.go @@ -24,22 +24,26 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { } } -func (store *pubAddrStore) Store(add netip.AddrPort) { +func (store *pubAddrStore) Store(addr netip.AddrPort) { if store.localPub { return } + if !addr.IsValid() { + return + } + + if addr.Addr().IsPrivate() { + return + } + store.lock.Lock() defer store.lock.Unlock() - if !add.IsValid() { - return + if _, exists := store.lastSeen[addr]; !exists { + store.addrList = append(store.addrList, addr) } - - if _, exists := store.lastSeen[add]; !exists { - store.addrList = append(store.addrList, add) - } - store.lastSeen[add] = time.Now() + store.lastSeen[addr] = time.Now() store.sort() } diff --git a/peer/state-clientinit.go b/peer/state-clientinit.go index 8a84100..674d63e 100644 --- a/peer/state-clientinit.go +++ b/peer/state-clientinit.go @@ -61,13 +61,13 @@ func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { return s } s.logf("Got INIT version %d.", msg.Packet.Version) - return s.nextState() + return enterStateClient(s.peerData) } func (s *stateClientInit2) onPing() peerState { if time.Since(s.startedAt) > timeoutInterval { s.logf("Init timeout. Assuming version 1.") - return s.nextState() + return enterStateClient(s.peerData) } s.sendInit() @@ -83,11 +83,3 @@ func (s *stateClientInit2) sendInit() { } s.Send(s.staged, init) } - -func (s *stateClientInit2) nextState() peerState { - if s.staged.Direct { - return enterStateClientDirect2(s.peerData, s.staged.DirectAddr) - } - - return enterStateClientRelayed2(s.peerData) -} diff --git a/peer/state-clientrelayed.go b/peer/state-clientrelayed.go index 737f0a9..b51398d 100644 --- a/peer/state-clientrelayed.go +++ b/peer/state-clientrelayed.go @@ -17,10 +17,12 @@ type stateClientRelayed2 struct { probes map[uint64]sentProbe } -func enterStateClientRelayed2(data *peerData) peerState { - data.staged.Relay = false - data.staged.Direct = false - data.staged.DirectAddr = netip.AddrPort{} +func enterStateClient(data *peerData) peerState { + ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) + + data.staged.Relay = data.peer.Relay && ipValid + data.staged.Direct = ipValid + data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) data.publish(data.staged) state := &stateClientRelayed2{ @@ -29,7 +31,7 @@ func enterStateClientRelayed2(data *peerData) peerState { syn: packetSyn{ TraceID: newTraceID(), SharedKey: data.staged.DataCipher.Key(), - Direct: false, + Direct: data.staged.Direct, PossibleAddrs: data.pubAddrs.Get(), }, probes: map[uint64]sentProbe{}, @@ -39,7 +41,7 @@ func enterStateClientRelayed2(data *peerData) peerState { data.pingTimer.Reset(pingInterval) - state.logf("==> ClientRelayed") + state.logf("==> Client") return state } @@ -52,22 +54,22 @@ func (s *stateClientRelayed2) OnMsg(raw any) peerState { case peerUpdateMsg: return initPeerState(s.peerData, msg.Peer) case controlMsg[packetAck]: - return s.onAck(msg) + s.onAck(msg) case controlMsg[packetProbe]: return s.onProbe(msg) case controlMsg[packetLocalDiscovery]: - return s.onLocalDiscovery(msg) + s.onLocalDiscovery(msg) case pingTimerMsg: return s.onPingTimer() default: s.logf("Ignoring message: %v", raw) - return s } + return s } -func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { +func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) { if msg.Packet.TraceID != s.syn.TraceID { - return s + return } s.lastSeen = time.Now() @@ -78,7 +80,14 @@ func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { s.logf("Got ACK.") } - s.pubAddrs.Store(msg.Packet.ToAddr) + if s.staged.Direct { + s.pubAddrs.Store(msg.Packet.ToAddr) + return + } + + // Relayed below. + + s.cleanProbes() for _, addr := range msg.Packet.PossibleAddrs { if !addr.IsValid() { @@ -86,10 +95,6 @@ func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { } s.sendProbeTo(addr) } - - s.cleanProbes() - - return s } func (s *stateClientRelayed2) onPingTimer() peerState { @@ -105,6 +110,10 @@ func (s *stateClientRelayed2) onPingTimer() peerState { } func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { + if s.staged.Direct { + return s + } + s.cleanProbes() sent, ok := s.probes[msg.Packet.TraceID] @@ -112,16 +121,27 @@ func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { return s } + s.staged.Direct = true + s.staged.DirectAddr = sent.Addr + s.publish(s.staged) + + s.syn.TraceID = newTraceID() + s.syn.Direct = true + s.Send(s.staged, s.syn) + s.logf("Successful probe.") - return enterStateClientDirect2(s.peerData, sent.Addr) + return s } -func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) peerState { +func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { + if s.staged.Direct { + return + } + // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) s.sendProbeTo(addr) - return s } func (s *stateClientRelayed2) cleanProbes() { @@ -138,5 +158,6 @@ func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { SentAt: time.Now(), Addr: addr, } + s.logf("Probing %v...", addr) s.SendTo(probe, addr) } diff --git a/peer/state-server.go b/peer/state-server.go index f3d19da..4543a60 100644 --- a/peer/state-server.go +++ b/peer/state-server.go @@ -104,6 +104,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { if !addr.IsValid() { break } + s.logf("Probing %v...", addr) s.SendTo(packetProbe{TraceID: newTraceID()}, addr) } @@ -112,6 +113,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { if msg.SrcAddr.IsValid() { + s.logf("Probe response %v...", msg.SrcAddr) s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) } return s -- 2.39.5 From 71410204128801adcad71a64974ac005d757e8ae Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 25 Feb 2025 18:14:39 +0100 Subject: [PATCH 17/26] wip --- peer/peerstates.go | 510 ------------------------------------ peer/peerstates_test.go | 2 +- peer/peersuper.go | 4 +- peer/state-clientdirect.go | 85 ------ peer/state-clientinit.go | 16 +- peer/state-clientrelayed.go | 163 ------------ peer/state-server.go | 19 +- peer/statedata.go | 91 ++++++- 8 files changed, 107 insertions(+), 783 deletions(-) delete mode 100644 peer/peerstates.go delete mode 100644 peer/state-clientdirect.go delete mode 100644 peer/state-clientrelayed.go diff --git a/peer/peerstates.go b/peer/peerstates.go deleted file mode 100644 index 6c52f55..0000000 --- a/peer/peerstates.go +++ /dev/null @@ -1,510 +0,0 @@ -package peer - -import ( - "fmt" - "log" - "net/netip" - "strings" - "time" - "vppn/m" - - "git.crumpington.com/lib/go/ratelimiter" -) - -type peerState interface { - OnMsg(raw any) peerState -} - -// ---------------------------------------------------------------------------- - -type pState struct { - // Output. - publish func(remotePeer) - sendControlPacket func(remotePeer, marshaller) - pingTimer *time.Ticker - - // Immutable data. - localIP byte - remoteIP byte - privKey []byte - localAddr netip.AddrPort // If valid, then local peer is publicly accessible. - - pubAddrs *pubAddrStore - - // The purpose of this state machine is to manage the RemotePeer object, - // publishing it as necessary. - staged remotePeer // Local copy of shared data. See publish(). - - // Mutable peer data. - peer *m.Peer - - // We rate limit per remote endpoint because if we don't we tend to lose - // packets. - limiter *ratelimiter.Limiter -} - -/* -func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { - defer func() { - // Don't defer directly otherwise s.staged will be evaluated immediately - // and won't reflect changes made in the function. - s.publish(s.staged) - }() - - s.peer = peer - s.staged.localIP = s.localIP - s.staged.Up = false - s.staged.Relay = false - s.staged.Direct = false - s.staged.DirectAddr = netip.AddrPort{} - s.staged.PubSignKey = nil - s.staged.ControlCipher = nil - s.staged.DataCipher = nil - - if peer == nil { - return enterStateDisconnected(s) - } - - s.staged.IP = peer.PeerIP - s.staged.PubSignKey = peer.PubSignKey - s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) - s.staged.DataCipher = newDataCipher() - - if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { - s.staged.Relay = peer.Relay - s.staged.Direct = true - s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) - - if s.localAddr.IsValid() && s.localIP < s.remoteIP { - return enterStateServer(s) - } - - return enterStateClientinit(s) - } - - if s.localAddr.IsValid() { - s.staged.Direct = true - return enterStateServer(s) - } - - if s.localIP < s.remoteIP { - return enterStateServer(s) - } - - return enterStateClientinit(s) -} -*/ - -func (s *pState) logf(format string, args ...any) { - b := strings.Builder{} - name := "" - if s.peer != nil { - name = s.peer.Name - } - b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) - - b.WriteString(fmt.Sprintf("%30s: ", name)) - - if s.staged.Direct { - b.WriteString("DIRECT | ") - } else { - b.WriteString("RELAYED | ") - } - - if s.staged.Up { - b.WriteString("UP | ") - } else { - b.WriteString("DOWN | ") - } - - log.Printf(b.String()+format, args...) -} - -// ---------------------------------------------------------------------------- - -func (s *pState) SendTo(pkt marshaller, addr netip.AddrPort) { - if !addr.IsValid() { - return - } - route := s.staged - route.Direct = true - route.DirectAddr = addr - s.Send(route, pkt) -} - -func (s *pState) Send(peer remotePeer, pkt marshaller) { - if err := s.limiter.Limit(); err != nil { - s.logf("Rate limited.") - return - } - s.sendControlPacket(peer, pkt) -} - -// ---------------------------------------------------------------------------- - -/* -type stateDisconnected struct{ *pState } - -func enterStateDisconnected(s *pState) peerState { - s.pingTimer.Stop() - return &stateDisconnected{pState: s} -} - -func (s *stateDisconnected) OnMsg(raw any) peerState { - switch msg := raw.(type) { - case peerUpdateMsg: - return s.OnPeerUpdate(msg.Peer) - default: - // TODO: Log. - return s - } -} - -func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState { return s } -func (s *stateDisconnected) OnAck(controlMsg[packetAck]) {} -func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s } -func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {} -func (s *stateDisconnected) OnPingTimer() peerState { return s } - -// ---------------------------------------------------------------------------- - -type stateServer struct { - *stateDisconnected - lastSeen time.Time - synTraceID uint64 -} - -func enterStateServer(s *pState) peerState { - s.logf("==> Server") - s.pingTimer.Reset(pingInterval) - return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} -} - -func (s *stateServer) OnMsg(rawMsg any) peerState { - switch msg := rawMsg.(type) { - case peerUpdateMsg: - return s.OnPeerUpdate(msg.Peer) - case controlMsg[packetInit]: - return s.OnInit(msg) - case controlMsg[packetSyn]: - return s.OnSyn(msg) - case controlMsg[packetProbe]: - return s.OnProbe(msg) - case pingTimerMsg: - return s.OnPingTimer() - default: - // TODO: Log - return s - } -} - -func (s *stateServer) OnInit(msg controlMsg[packetInit]) peerState { - s.logf("Responding to INIT.") - route := s.staged - route.Direct = msg.Packet.Direct - route.DirectAddr = msg.SrcAddr - - s.Send(route, packetInit{ - TraceID: msg.Packet.TraceID, - Direct: route.Direct, - Version: version, - }) - - return s -} - -func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { - s.lastSeen = time.Now() - p := msg.Packet - - // Before we can respond to this packet, we need to make sure the - // route is setup properly. - // - // The client will update the syn's TraceID whenever there's a change. - // The server will follow the client's request. - if p.TraceID != s.synTraceID || !s.staged.Up { - s.synTraceID = p.TraceID - s.staged.Up = true - s.staged.Direct = p.Direct - s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) - s.staged.DirectAddr = msg.SrcAddr - s.publish(s.staged) - s.logf("Got SYN.") - } - - // Always respond. - ack := packetAck{ - TraceID: p.TraceID, - ToAddr: s.staged.DirectAddr, - PossibleAddrs: s.pubAddrs.Get(), - } - s.Send(s.staged, ack) - - if p.Direct { - return s - } - - for _, addr := range msg.Packet.PossibleAddrs { - if !addr.IsValid() { - break - } - s.SendTo(packetProbe{TraceID: newTraceID()}, addr) - } - - return s -} - -func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState { - if msg.SrcAddr.IsValid() { - s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) - } - return s -} - -func (s *stateServer) OnPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { - s.staged.Up = false - s.publish(s.staged) - s.logf("Timeout.") - } - return s -} - -// ---------------------------------------------------------------------------- - -type stateClientInit struct { - *stateDisconnected - startedAt time.Time - traceID uint64 -} - -func enterStateClientinit(s *pState) peerState { - s.logf("==> ClientInit") - s.pingTimer.Reset(pingInterval) - - state := &stateClientInit{ - stateDisconnected: &stateDisconnected{s}, - startedAt: time.Now(), - traceID: newTraceID(), - } - state.Send(s.staged, packetInit{ - TraceID: state.traceID, - Direct: s.staged.Direct, - Version: version, - }) - return state -} - -func (s *stateClientInit) OnMsg(raw any) peerState { - switch msg := raw.(type) { - case peerUpdateMsg: - return s.OnPeerUpdate(msg.Peer) - case controlMsg[packetInit]: - return s.onInit(msg) - case pingTimerMsg: - return s.onPing() - default: - return s - } -} - -func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState { - if msg.Packet.TraceID != s.traceID { - s.logf("Invalid trace ID on INIT.") - return s - } - s.logf("Got INIT version %d.", msg.Packet.Version) - return s.nextState() -} - -func (s *stateClientInit) onPing() peerState { - if time.Since(s.startedAt) > timeoutInterval { - s.logf("Init timeout. Assuming version 1.") - return s.nextState() - } - - s.traceID = newTraceID() - s.Send(s.staged, packetInit{ - TraceID: s.traceID, - Direct: s.staged.Direct, - Version: version, - }) - return s -} - -func (s *stateClientInit) nextState() peerState { - if s.staged.Direct { - return enterStateClientDirect(s.pState) - } - return enterStateClientRelayed(s.pState) -} - -// ---------------------------------------------------------------------------- - -type stateClientDirect struct { - *stateDisconnected - lastSeen time.Time - syn packetSyn -} - -func enterStateClientDirect(s *pState) peerState { - s.logf("==> ClientDirect") - s.pingTimer.Reset(pingInterval) - return newStateClientDirect(s) -} - -func newStateClientDirect(s *pState) *stateClientDirect { - state := &stateClientDirect{ - stateDisconnected: &stateDisconnected{s}, - lastSeen: time.Now(), // Avoid immediate timeout. - } - - state.syn = packetSyn{ - TraceID: newTraceID(), - SharedKey: s.staged.DataCipher.Key(), - Direct: s.staged.Direct, - PossibleAddrs: s.pubAddrs.Get(), - } - state.Send(s.staged, state.syn) - return state -} - -func (s *stateClientDirect) OnMsg(raw any) peerState { - switch msg := raw.(type) { - case peerUpdateMsg: - return s.OnPeerUpdate(msg.Peer) - case controlMsg[packetAck]: - s.OnAck(msg) - return s - case pingTimerMsg: - if next := s.onPingTimer(); next != nil { - return next - } - return s - default: - // TODO: Log - return s - } -} - -func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { - if msg.Packet.TraceID != s.syn.TraceID { - return - } - - s.lastSeen = time.Now() - - if !s.staged.Up { - s.staged.Up = true - s.publish(s.staged) - s.logf("Got ACK.") - } - - s.pubAddrs.Store(msg.Packet.ToAddr) -} - -func (s *stateClientDirect) onPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval { - if s.staged.Up { - s.staged.Up = false - s.publish(s.staged) - s.logf("Timeout.") - } - return s.OnPeerUpdate(s.peer) - } - - s.Send(s.staged, s.syn) - return nil -} - -// ---------------------------------------------------------------------------- - -type stateClientRelayed struct { - *stateClientDirect - ack packetAck - probes map[uint64]netip.AddrPort // TODO: something better - localDiscoveryAddr netip.AddrPort // TODO: Remove -} - -func enterStateClientRelayed(s *pState) peerState { - s.logf("==> ClientRelayed") - s.pingTimer.Reset(pingInterval) - return &stateClientRelayed{ - stateClientDirect: newStateClientDirect(s), - probes: map[uint64]netip.AddrPort{}, - } -} - -func (s *stateClientRelayed) OnMsg(raw any) peerState { - switch msg := raw.(type) { - case peerUpdateMsg: - return s.OnPeerUpdate(msg.Peer) - case controlMsg[packetAck]: - s.OnAck(msg) - return s - case controlMsg[packetProbe]: - return s.OnProbe(msg) - case controlMsg[packetLocalDiscovery]: - s.OnLocalDiscovery(msg) - return s - case pingTimerMsg: - return s.OnPingTimer() - default: - // TODO: Log - return s - } -} - -func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { - s.ack = msg.Packet - s.stateClientDirect.OnAck(msg) - - // TODO: Send probes now. -} - -func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { - addr, ok := s.probes[msg.Packet.TraceID] - if !ok { - return s - } - - s.staged.DirectAddr = addr - s.staged.Direct = true - s.publish(s.staged) - return enterStateClientDirect(s.stateClientDirect.pState) -} - -func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { - // The source port will be the multicast port, so we'll have to - // construct the correct address using the peer's listed port. - s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) - // TODO: s.sendProbeTo(s.localDiscoveryAddr) -} - -func (s *stateClientRelayed) OnPingTimer() peerState { - if next := s.stateClientDirect.onPingTimer(); next != nil { - return next - } - - clear(s.probes) - for _, addr := range s.ack.PossibleAddrs { - if !addr.IsValid() { - break - } - s.sendProbeTo(addr) - } - - if s.localDiscoveryAddr.IsValid() { - s.sendProbeTo(s.localDiscoveryAddr) - s.localDiscoveryAddr = netip.AddrPort{} - } - - return s -} - -func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { - probe := packetProbe{TraceID: newTraceID()} - s.probes[probe.TraceID] = addr - s.SendTo(probe, addr) -} -*/ diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index 15f7d18..26ebacd 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -27,7 +27,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { keys := generateKeys() - state := &pState{ + state := &peerData{ publish: func(rp remotePeer) { h.Published = rp }, diff --git a/peer/peersuper.go b/peer/peersuper.go index ec8c741..2ce6d03 100644 --- a/peer/peersuper.go +++ b/peer/peersuper.go @@ -39,7 +39,7 @@ func newSupervisor( pubAddrs := newPubAddrStore(routes.LocalAddr) for i := range s.peers { - state := &pState{ + state := &peerData{ publish: s.publish, sendControlPacket: s.send, pingTimer: time.NewTicker(timeoutInterval), @@ -121,7 +121,7 @@ type peerSuper struct { pingTimer *time.Ticker } -func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { +func newPeerSuper(state *peerData, pingTimer *time.Ticker) *peerSuper { return &peerSuper{ messages: make(chan any, 8), state: initPeerState(state, nil), diff --git a/peer/state-clientdirect.go b/peer/state-clientdirect.go deleted file mode 100644 index c6c552d..0000000 --- a/peer/state-clientdirect.go +++ /dev/null @@ -1,85 +0,0 @@ -package peer - -import ( - "net/netip" - "time" -) - -type stateClientDirect2 struct { - *peerData - lastSeen time.Time - syn packetSyn -} - -func enterStateClientDirect2(data *peerData, directAddr netip.AddrPort) peerState { - data.staged.Relay = data.peer.Relay - data.staged.Direct = true - data.staged.DirectAddr = directAddr - data.publish(data.staged) - - state := &stateClientDirect2{ - peerData: data, - lastSeen: time.Now(), - syn: packetSyn{ - TraceID: newTraceID(), - SharedKey: data.staged.DataCipher.Key(), - Direct: true, - }, - } - - state.Send(state.staged, state.syn) - - data.pingTimer.Reset(pingInterval) - - state.logf("==> ClientDirect") - return state -} - -func (s *stateClientDirect2) logf(str string, args ...any) { - s.peerData.logf("CLNT | "+str, args...) -} - -func (s *stateClientDirect2) OnMsg(raw any) peerState { - switch msg := raw.(type) { - case peerUpdateMsg: - return initPeerState(s.peerData, msg.Peer) - case controlMsg[packetAck]: - return s.onAck(msg) - case pingTimerMsg: - return s.onPingTimer() - case controlMsg[packetLocalDiscovery]: - return s - default: - s.logf("Ignoring message: %v", raw) - return s - } -} - -func (s *stateClientDirect2) onAck(msg controlMsg[packetAck]) peerState { - if msg.Packet.TraceID != s.syn.TraceID { - return s - } - - s.lastSeen = time.Now() - - if !s.staged.Up { - s.staged.Up = true - s.publish(s.staged) - s.logf("Got ACK.") - } - - s.pubAddrs.Store(msg.Packet.ToAddr) - return s -} - -func (s *stateClientDirect2) onPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval { - if s.staged.Up { - s.logf("Timeout.") - } - return initPeerState(s.peerData, s.peer) - } - - s.Send(s.staged, s.syn) - return s -} diff --git a/peer/state-clientinit.go b/peer/state-clientinit.go index 674d63e..f34854a 100644 --- a/peer/state-clientinit.go +++ b/peer/state-clientinit.go @@ -5,13 +5,13 @@ import ( "time" ) -type stateClientInit2 struct { +type stateClientInit struct { *peerData startedAt time.Time traceID uint64 } -func enterStateClientInit2(data *peerData) peerState { +func enterStateClientInit(data *peerData) peerState { ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) data.staged.Up = false @@ -24,7 +24,7 @@ func enterStateClientInit2(data *peerData) peerState { data.publish(data.staged) - state := &stateClientInit2{ + state := &stateClientInit{ peerData: data, startedAt: time.Now(), traceID: newTraceID(), @@ -37,11 +37,11 @@ func enterStateClientInit2(data *peerData) peerState { return state } -func (s *stateClientInit2) logf(str string, args ...any) { +func (s *stateClientInit) logf(str string, args ...any) { s.peerData.logf("INIT | "+str, args...) } -func (s *stateClientInit2) OnMsg(raw any) peerState { +func (s *stateClientInit) OnMsg(raw any) peerState { switch msg := raw.(type) { case peerUpdateMsg: return initPeerState(s.peerData, msg.Peer) @@ -55,7 +55,7 @@ func (s *stateClientInit2) OnMsg(raw any) peerState { } } -func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { +func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState { if msg.Packet.TraceID != s.traceID { s.logf("Invalid trace ID on INIT.") return s @@ -64,7 +64,7 @@ func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { return enterStateClient(s.peerData) } -func (s *stateClientInit2) onPing() peerState { +func (s *stateClientInit) onPing() peerState { if time.Since(s.startedAt) > timeoutInterval { s.logf("Init timeout. Assuming version 1.") return enterStateClient(s.peerData) @@ -74,7 +74,7 @@ func (s *stateClientInit2) onPing() peerState { return s } -func (s *stateClientInit2) sendInit() { +func (s *stateClientInit) sendInit() { s.traceID = newTraceID() init := packetInit{ TraceID: s.traceID, diff --git a/peer/state-clientrelayed.go b/peer/state-clientrelayed.go deleted file mode 100644 index b51398d..0000000 --- a/peer/state-clientrelayed.go +++ /dev/null @@ -1,163 +0,0 @@ -package peer - -import ( - "net/netip" - "time" -) - -type sentProbe struct { - SentAt time.Time - Addr netip.AddrPort -} - -type stateClientRelayed2 struct { - *peerData - lastSeen time.Time - syn packetSyn - probes map[uint64]sentProbe -} - -func enterStateClient(data *peerData) peerState { - ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) - - data.staged.Relay = data.peer.Relay && ipValid - data.staged.Direct = ipValid - data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) - data.publish(data.staged) - - state := &stateClientRelayed2{ - peerData: data, - lastSeen: time.Now(), - syn: packetSyn{ - TraceID: newTraceID(), - SharedKey: data.staged.DataCipher.Key(), - Direct: data.staged.Direct, - PossibleAddrs: data.pubAddrs.Get(), - }, - probes: map[uint64]sentProbe{}, - } - - state.Send(state.staged, state.syn) - - data.pingTimer.Reset(pingInterval) - - state.logf("==> Client") - return state -} - -func (s *stateClientRelayed2) logf(str string, args ...any) { - s.peerData.logf("CLNT | "+str, args...) -} - -func (s *stateClientRelayed2) OnMsg(raw any) peerState { - switch msg := raw.(type) { - case peerUpdateMsg: - return initPeerState(s.peerData, msg.Peer) - case controlMsg[packetAck]: - s.onAck(msg) - case controlMsg[packetProbe]: - return s.onProbe(msg) - case controlMsg[packetLocalDiscovery]: - s.onLocalDiscovery(msg) - case pingTimerMsg: - return s.onPingTimer() - default: - s.logf("Ignoring message: %v", raw) - } - return s -} - -func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) { - if msg.Packet.TraceID != s.syn.TraceID { - return - } - - s.lastSeen = time.Now() - - if !s.staged.Up { - s.staged.Up = true - s.publish(s.staged) - s.logf("Got ACK.") - } - - if s.staged.Direct { - s.pubAddrs.Store(msg.Packet.ToAddr) - return - } - - // Relayed below. - - s.cleanProbes() - - for _, addr := range msg.Packet.PossibleAddrs { - if !addr.IsValid() { - break - } - s.sendProbeTo(addr) - } -} - -func (s *stateClientRelayed2) onPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval { - if s.staged.Up { - s.logf("Timeout.") - } - return initPeerState(s.peerData, s.peer) - } - - s.Send(s.staged, s.syn) - return s -} - -func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { - if s.staged.Direct { - return s - } - - s.cleanProbes() - - sent, ok := s.probes[msg.Packet.TraceID] - if !ok { - return s - } - - s.staged.Direct = true - s.staged.DirectAddr = sent.Addr - s.publish(s.staged) - - s.syn.TraceID = newTraceID() - s.syn.Direct = true - s.Send(s.staged, s.syn) - - s.logf("Successful probe.") - return s -} - -func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { - if s.staged.Direct { - return - } - - // The source port will be the multicast port, so we'll have to - // construct the correct address using the peer's listed port. - addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) - s.sendProbeTo(addr) -} - -func (s *stateClientRelayed2) cleanProbes() { - for key, sent := range s.probes { - if time.Since(sent.SentAt) > pingInterval { - delete(s.probes, key) - } - } -} - -func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { - probe := packetProbe{TraceID: newTraceID()} - s.probes[probe.TraceID] = sentProbe{ - SentAt: time.Now(), - Addr: addr, - } - s.logf("Probing %v...", addr) - s.SendTo(probe, addr) -} diff --git a/peer/state-server.go b/peer/state-server.go index 4543a60..723af5c 100644 --- a/peer/state-server.go +++ b/peer/state-server.go @@ -5,13 +5,13 @@ import ( "time" ) -type stateServer2 struct { +type stateServer struct { *peerData lastSeen time.Time synTraceID uint64 // Last syn trace ID. } -func enterStateServer2(data *peerData) peerState { +func enterStateServer(data *peerData) peerState { data.staged.Up = false data.staged.Relay = false data.staged.Direct = false @@ -24,16 +24,16 @@ func enterStateServer2(data *peerData) peerState { data.pingTimer.Reset(pingInterval) - state := &stateServer2{peerData: data} + state := &stateServer{peerData: data} state.logf("==> Server") return state } -func (s *stateServer2) logf(str string, args ...any) { +func (s *stateServer) logf(str string, args ...any) { s.peerData.logf("SRVR | "+str, args...) } -func (s *stateServer2) OnMsg(raw any) peerState { +func (s *stateServer) OnMsg(raw any) peerState { switch msg := raw.(type) { case peerUpdateMsg: return initPeerState(s.peerData, msg.Peer) @@ -53,7 +53,7 @@ func (s *stateServer2) OnMsg(raw any) peerState { } } -func (s *stateServer2) onInit(msg controlMsg[packetInit]) peerState { +func (s *stateServer) onInit(msg controlMsg[packetInit]) peerState { s.staged.Up = false s.staged.Direct = msg.Packet.Direct s.staged.DirectAddr = msg.SrcAddr @@ -70,7 +70,7 @@ func (s *stateServer2) onInit(msg controlMsg[packetInit]) peerState { return s } -func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { +func (s *stateServer) onSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -100,6 +100,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { return s } + // Send probes if not a direct connection. for _, addr := range msg.Packet.PossibleAddrs { if !addr.IsValid() { break @@ -111,7 +112,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { return s } -func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { +func (s *stateServer) onProbe(msg controlMsg[packetProbe]) peerState { if msg.SrcAddr.IsValid() { s.logf("Probe response %v...", msg.SrcAddr) s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) @@ -119,7 +120,7 @@ func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { return s } -func (s *stateServer2) onPingTimer() peerState { +func (s *stateServer) onPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { s.staged.Up = false s.publish(s.staged) diff --git a/peer/statedata.go b/peer/statedata.go index 44330fa..0ea0929 100644 --- a/peer/statedata.go +++ b/peer/statedata.go @@ -1,11 +1,92 @@ package peer import ( + "fmt" + "log" "net/netip" + "strings" + "time" "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" ) -type peerData = pState +type peerState interface { + OnMsg(raw any) peerState +} + +// ---------------------------------------------------------------------------- + +type peerData struct { + // Output. + publish func(remotePeer) + sendControlPacket func(remotePeer, marshaller) + pingTimer *time.Ticker + + // Immutable data. + localIP byte + remoteIP byte + privKey []byte + localAddr netip.AddrPort // If valid, then local peer is publicly accessible. + + pubAddrs *pubAddrStore + + // The purpose of this state machine is to manage the RemotePeer object, + // publishing it as necessary. + staged remotePeer // Local copy of shared data. See publish(). + + // Mutable peer data. + peer *m.Peer + + // We rate limit per remote endpoint because if we don't we tend to lose + // packets. + limiter *ratelimiter.Limiter +} + +func (s *peerData) logf(format string, args ...any) { + b := strings.Builder{} + name := "" + if s.peer != nil { + name = s.peer.Name + } + b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) + + b.WriteString(fmt.Sprintf("%30s: ", name)) + + if s.staged.Direct { + b.WriteString("DIRECT | ") + } else { + b.WriteString("RELAYED | ") + } + + if s.staged.Up { + b.WriteString("UP | ") + } else { + b.WriteString("DOWN | ") + } + + log.Printf(b.String()+format, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *peerData) SendTo(pkt marshaller, addr netip.AddrPort) { + if !addr.IsValid() { + return + } + route := s.staged + route.Direct = true + route.DirectAddr = addr + s.Send(route, pkt) +} + +func (s *peerData) Send(peer remotePeer, pkt marshaller) { + if err := s.limiter.Limit(); err != nil { + s.logf("Rate limited.") + return + } + s.sendControlPacket(peer, pkt) +} func initPeerState(data *peerData, peer *m.Peer) peerState { data.peer = peer @@ -16,13 +97,13 @@ func initPeerState(data *peerData, peer *m.Peer) peerState { if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { if data.localAddr.IsValid() && data.localIP < data.remoteIP { - return enterStateServer2(data) + return enterStateServer(data) } - return enterStateClientInit2(data) + return enterStateClientInit(data) } if data.localAddr.IsValid() || data.localIP < data.remoteIP { - return enterStateServer2(data) + return enterStateServer(data) } - return enterStateClientInit2(data) + return enterStateClientInit(data) } -- 2.39.5 From 68cc5195b8b69eb0dfdb42ea673d09edc5867a1c Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 25 Feb 2025 19:23:44 +0100 Subject: [PATCH 18/26] wip --- peer/connreader.go | 2 +- peer/globals.go | 6 +- peer/hubpoller.go | 16 +++-- peer/logging.go | 13 ---- peer/mcreader.go | 6 +- peer/mcwriter.go | 4 +- peer/peer.go | 28 +++++--- peer/state-client.go | 162 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 201 insertions(+), 36 deletions(-) delete mode 100644 peer/logging.go create mode 100644 peer/state-client.go diff --git a/peer/connreader.go b/peer/connreader.go index 0727ced..4c156f4 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -84,7 +84,7 @@ func (r *connReader) handleControlPacket( enc []byte, ) { if peer.ControlCipher == nil { - log.Printf("No control cipher for peer: %v", h) + r.logf("No control cipher for peer: %d", h.SourceIP) return } diff --git a/peer/globals.go b/peer/globals.go index cd0e1f6..6dd26eb 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -18,8 +18,10 @@ const ( dataCipherOverhead = 16 signOverhead = 64 - pingInterval = 8 * time.Second - timeoutInterval = 30 * time.Second + pingInterval = 8 * time.Second + timeoutInterval = 30 * time.Second + broadcastInterval = 16 * time.Second + broadcastErrorTimeoutInterval = 8 * time.Second ) var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( diff --git a/peer/hubpoller.go b/peer/hubpoller.go index 238dfda..0082989 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -50,11 +50,15 @@ func newHubPoller( }, nil } +func (hp *hubPoller) logf(s string, args ...any) { + log.Printf("[HubPoller] "+s, args...) +} + func (hp *hubPoller) Run() { state, err := loadNetworkState(hp.netName) if err != nil { - log.Printf("Failed to load network state: %v", err) - log.Printf("Polling hub...") + hp.logf("Failed to load network state: %v", err) + hp.logf("Polling hub...") hp.pollHub() } else { hp.applyNetworkState(state) @@ -70,25 +74,25 @@ func (hp *hubPoller) pollHub() { resp, err := hp.client.Do(hp.req) if err != nil { - log.Printf("Failed to fetch peer state: %v", err) + hp.logf("Failed to fetch peer state: %v", err) return } body, err := io.ReadAll(resp.Body) _ = resp.Body.Close() if err != nil { - log.Printf("Failed to read body from hub: %v", err) + hp.logf("Failed to read body from hub: %v", err) return } if err := json.Unmarshal(body, &state); err != nil { - log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) + hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body) return } hp.applyNetworkState(state) if err := storeNetworkState(hp.netName, state); err != nil { - log.Printf("Failed to store network state: %v", err) + hp.logf("Failed to store network state: %v", err) } } diff --git a/peer/logging.go b/peer/logging.go deleted file mode 100644 index 4906b04..0000000 --- a/peer/logging.go +++ /dev/null @@ -1,13 +0,0 @@ -package peer - -import "log" - -func logPacket(p []byte, notes string) { - h := parseHeader(p) - log.Printf(`Sending: Data: %v | From: %d | To: %d | %s -`, - h.StreamID == dataStreamID, - h.SourceIP, - h.DestIP, - notes) -} diff --git a/peer/mcreader.go b/peer/mcreader.go index 7c63f26..7b8af27 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -12,12 +12,12 @@ func runMCReader( handleControlMsg func(destIP byte, msg any), ) { for { - runMCReader2(rt, handleControlMsg) - time.Sleep(8 * time.Second) + runMCReaderInner(rt, handleControlMsg) + time.Sleep(broadcastErrorTimeoutInterval) } } -func runMCReader2( +func runMCReaderInner( rt *atomic.Pointer[routingTable], handleControlMsg func(destIP byte, msg any), ) { diff --git a/peer/mcwriter.go b/peer/mcwriter.go index 29cf2be..eb53af4 100644 --- a/peer/mcwriter.go +++ b/peer/mcwriter.go @@ -41,10 +41,10 @@ func runMCWriter(localIP byte, signingKey []byte) { conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) if err != nil { - log.Fatalf("Failed to bind to multicast address: %v", err) + log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err) } - for range time.Tick(8 * time.Second) { + for range time.Tick(broadcastInterval) { _, err := conn.WriteToUDP(discoveryPacket, multicastAddr) if err != nil { log.Printf("[MCWriter] Failed to write multicast: %v", err) diff --git a/peer/peer.go b/peer/peer.go index 45627b0..c210af4 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -32,10 +32,14 @@ type peerConfig struct { } func newPeerMain(conf peerConfig) *peerMain { + logf := func(s string, args ...any) { + log.Printf("[Main] "+s, args...) + } + config, err := loadPeerConfig(conf.NetName) if err != nil { - log.Printf("Failed to load configuration: %v", err) - log.Printf("Initializing...") + logf("Failed to load configuration: %v", err) + logf("Initializing...") initPeerWithHub(conf) config, err = loadPeerConfig(conf.NetName) @@ -54,7 +58,7 @@ func newPeerMain(conf peerConfig) *peerMain { log.Fatalf("Failed to resolve UDP address: %v", err) } - log.Printf("Listening on %v...", myAddr) + logf("Listening on %v...", myAddr) conn, err := net.ListenUDP("udp", myAddr) if err != nil { log.Fatalf("Failed to open UDP port: %v", err) @@ -69,15 +73,15 @@ func newPeerMain(conf peerConfig) *peerMain { writeLock.Lock() n, err = conn.WriteToUDPAddrPort(b, addr) if err != nil { - log.Printf("Failed to write packet: %v", err) + logf("Failed to write packet: %v", err) } writeLock.Unlock() return n, err } var localAddr netip.AddrPort - ip, ok := netip.AddrFromSlice(config.PublicIP) - if ok { + ip, localAddrValid := netip.AddrFromSlice(config.PublicIP) + if localAddrValid { localAddr = netip.AddrPortFrom(ip, config.Port) } @@ -105,12 +109,18 @@ func newPeerMain(conf peerConfig) *peerMain { } func (p *peerMain) Run() { + go p.ifReader.Run() go p.connReader.Run() p.super.Start() - go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) - go runMCReader(p.rt, p.super.HandleControlMsg) - p.hubPoller.Run() + + if !p.rt.Load().LocalAddr.IsValid() { + go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) + go runMCReader(p.rt, p.super.HandleControlMsg) + } + + go p.hubPoller.Run() + select {} } func initPeerWithHub(conf peerConfig) { diff --git a/peer/state-client.go b/peer/state-client.go new file mode 100644 index 0000000..49e4375 --- /dev/null +++ b/peer/state-client.go @@ -0,0 +1,162 @@ +package peer + +import ( + "net/netip" + "time" +) + +type sentProbe struct { + SentAt time.Time + Addr netip.AddrPort +} + +type stateClient struct { + *peerData + lastSeen time.Time + syn packetSyn + probes map[uint64]sentProbe +} + +func enterStateClient(data *peerData) peerState { + ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) + + data.staged.Relay = data.peer.Relay && ipValid + data.staged.Direct = ipValid + data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) + data.publish(data.staged) + + state := &stateClient{ + peerData: data, + lastSeen: time.Now(), + syn: packetSyn{ + TraceID: newTraceID(), + SharedKey: data.staged.DataCipher.Key(), + Direct: data.staged.Direct, + PossibleAddrs: data.pubAddrs.Get(), + }, + probes: map[uint64]sentProbe{}, + } + + state.Send(state.staged, state.syn) + + data.pingTimer.Reset(pingInterval) + + state.logf("==> Client") + return state +} + +func (s *stateClient) logf(str string, args ...any) { + s.peerData.logf("CLNT | "+str, args...) +} + +func (s *stateClient) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetAck]: + s.onAck(msg) + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + s.onLocalDiscovery(msg) + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Ignoring message: %v", raw) + } + return s +} + +func (s *stateClient) onAck(msg controlMsg[packetAck]) { + if msg.Packet.TraceID != s.syn.TraceID { + return + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + if s.staged.Direct { + s.pubAddrs.Store(msg.Packet.ToAddr) + return + } + + // Relayed below. + + s.cleanProbes() + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.sendProbeTo(addr) + } +} + +func (s *stateClient) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.logf("Timeout.") + } + return initPeerState(s.peerData, s.peer) + } + + s.Send(s.staged, s.syn) + return s +} + +func (s *stateClient) onProbe(msg controlMsg[packetProbe]) peerState { + if s.staged.Direct { + return s + } + + s.cleanProbes() + + sent, ok := s.probes[msg.Packet.TraceID] + if !ok { + return s + } + + s.staged.Direct = true + s.staged.DirectAddr = sent.Addr + s.publish(s.staged) + + s.syn.TraceID = newTraceID() + s.syn.Direct = true + s.Send(s.staged, s.syn) + s.logf("Successful probe.") + return s +} + +func (s *stateClient) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { + if s.staged.Direct { + return + } + + // The source port will be the multicast port, so we'll have to + // construct the correct address using the peer's listed port. + addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + s.sendProbeTo(addr) +} + +func (s *stateClient) cleanProbes() { + for key, sent := range s.probes { + if time.Since(sent.SentAt) > pingInterval { + delete(s.probes, key) + } + } +} + +func (s *stateClient) sendProbeTo(addr netip.AddrPort) { + probe := packetProbe{TraceID: newTraceID()} + s.probes[probe.TraceID] = sentProbe{ + SentAt: time.Now(), + Addr: addr, + } + s.logf("Probing %v...", addr) + s.SendTo(probe, addr) +} -- 2.39.5 From 8dbccaa3e6fc07410ed60726ac070b9994029037 Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 25 Feb 2025 22:24:51 +0100 Subject: [PATCH 19/26] wip: testing --- peer/peerstates_test.go | 135 +------------------------------ peer/state-clientinit_test.go | 83 +++++++++++++++++++ peer/state-disconnected.go | 8 +- peer/state-util_test.go | 146 ++++++++++++++++++++++++++++++++++ peer/statedata.go | 2 +- 5 files changed, 238 insertions(+), 136 deletions(-) create mode 100644 peer/state-clientinit_test.go create mode 100644 peer/state-util_test.go diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index 26ebacd..80f8210 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -1,143 +1,12 @@ package peer import ( - "net/netip" "testing" - "time" "vppn/m" - - "git.crumpington.com/lib/go/ratelimiter" ) // ---------------------------------------------------------------------------- -type PeerStateControlMsg struct { - Peer remotePeer - Packet any -} - -type PeerStateTestHarness struct { - State peerState - Published remotePeer - Sent []PeerStateControlMsg -} - -func NewPeerStateTestHarness() *PeerStateTestHarness { - h := &PeerStateTestHarness{} - - keys := generateKeys() - - state := &peerData{ - publish: func(rp remotePeer) { - h.Published = rp - }, - sendControlPacket: func(rp remotePeer, pkt marshaller) { - h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) - }, - pingTimer: time.NewTicker(pingInterval), - localIP: 2, - remoteIP: 3, - privKey: keys.PrivKey, - pubAddrs: newPubAddrStore(netip.AddrPort{}), - limiter: ratelimiter.New(ratelimiter.Config{ - FillPeriod: 20 * time.Millisecond, - MaxWaitCount: 1, - }), - } - - h.State = enterStateDisconnected(state) - return h -} - -func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { - h.State = h.State.OnMsg(peerUpdateMsg{p}) -} - -func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) { - h.State = h.State.OnMsg(msg) -} - -func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { - h.State = h.State.OnMsg(msg) -} - -func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { - h.State = h.State.OnMsg(msg) -} - -func (h *PeerStateTestHarness) OnPingTimer() { - h.State = h.State.OnMsg(pingTimerMsg{}) -} - -func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { - keys := generateKeys() - - state := h.State.(*stateDisconnected) - state.localAddr = addrPort4(1, 1, 1, 2, 200) - - peer := &m.Peer{ - PeerIP: 3, - PublicIP: []byte{1, 1, 1, 3}, - Port: 456, - PubKey: keys.PubKey, - PubSignKey: keys.PubSignKey, - } - - h.PeerUpdate(peer) - assertEqual(t, h.Published.Up, false) - return assertType[*stateServer](t, h.State) -} - -func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { - keys := generateKeys() - peer := &m.Peer{ - PeerIP: 3, - Port: 456, - PubKey: keys.PubKey, - PubSignKey: keys.PubSignKey, - } - - h.PeerUpdate(peer) - assertEqual(t, h.Published.Up, false) - return assertType[*stateServer](t, h.State) -} - -func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDirect { - keys := generateKeys() - peer := &m.Peer{ - PeerIP: 3, - PublicIP: []byte{1, 2, 3, 4}, - Port: 456, - PubKey: keys.PubKey, - PubSignKey: keys.PubSignKey, - } - - h.PeerUpdate(peer) - assertEqual(t, h.Published.Up, false) - - return assertType[*stateClientDirect](t, h.State) -} - -func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClientRelayed { - keys := generateKeys() - - state := h.State.(*stateDisconnected) - state.remoteIP = 1 - - peer := &m.Peer{ - PeerIP: 3, - Port: 456, - PubKey: keys.PubKey, - PubSignKey: keys.PubSignKey, - } - - h.PeerUpdate(peer) - assertEqual(t, h.Published.Up, false) - return assertType[*stateClientRelayed](t, h.State) -} - -// ---------------------------------------------------------------------------- - func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { h := NewPeerStateTestHarness() h.PeerUpdate(nil) @@ -163,6 +32,7 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { assertType[*stateServer](t, h.State) } +/* func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigServer_Public(t) @@ -178,11 +48,13 @@ func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { h.ConfigClientDirect(t) } +/* func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigClientRelayed(t) } +/* func TestStateServer_directSyn(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigServer_Relayed(t) @@ -505,3 +377,4 @@ func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { assertType[*stateClientDirect](t, h.State) } +*/ diff --git a/peer/state-clientinit_test.go b/peer/state-clientinit_test.go new file mode 100644 index 0000000..87cdc8b --- /dev/null +++ b/peer/state-clientinit_test.go @@ -0,0 +1,83 @@ +package peer + +import ( + "testing" + "time" +) + +func TestPeerState_ClientInit_initWithIncorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + // Should have sent the first init packet. + assertEqual(t, len(h.Sent), 1) + init := assertType[packetInit](t, h.Sent[0].Packet) + + init.TraceID = newTraceID() + h.OnInit(controlMsg[packetInit]{Packet: init}) + + assertType[*stateClientInit](t, h.State) +} + +func TestPeerState_ClientInit_init(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + // Should have sent the first init packet. + assertEqual(t, len(h.Sent), 1) + init := assertType[packetInit](t, h.Sent[0].Packet) + h.OnInit(controlMsg[packetInit]{Packet: init}) + + assertType[*stateClient](t, h.State) +} + +func TestPeerState_ClientInit_onPing(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + // Should have sent the first init packet. + assertEqual(t, len(h.Sent), 1) + h.Sent = h.Sent[:0] + + for range 3 { + h.OnPingTimer() + } + + assertEqual(t, len(h.Sent), 3) + + for i := range h.Sent { + assertType[packetInit](t, h.Sent[i].Packet) + } +} + +func TestPeerState_ClientInit_onPingTimeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + state := assertType[*stateClientInit](t, h.State) + state.startedAt = time.Now().Add(-2 * timeoutInterval) + + h.OnPingTimer() + + // Should have moved into the client state due to timeout. + assertType[*stateClient](t, h.State) +} + +func TestPeerState_ClientInit_onPeerUpdate(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + h.PeerUpdate(nil) + + // Should have moved into the client state due to timeout. + assertType[*stateDisconnected](t, h.State) +} + +func TestPeerState_ClientInit_ignoreMessage(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + h.OnProbe(controlMsg[packetProbe]{}) + + // Shouldn't do anything. + assertType[*stateClientInit](t, h.State) +} diff --git a/peer/state-disconnected.go b/peer/state-disconnected.go index 3fdbd23..4c0b9c0 100644 --- a/peer/state-disconnected.go +++ b/peer/state-disconnected.go @@ -2,11 +2,11 @@ package peer import "net/netip" -type stateDisconnected2 struct { +type stateDisconnected struct { *peerData } -func enterStateDisconnected2(data *peerData) peerState { +func enterStateDisconnected(data *peerData) peerState { data.staged.Up = false data.staged.Relay = false data.staged.Direct = false @@ -19,10 +19,10 @@ func enterStateDisconnected2(data *peerData) peerState { data.pingTimer.Stop() - return &stateDisconnected2{data} + return &stateDisconnected{data} } -func (s *stateDisconnected2) OnMsg(raw any) peerState { +func (s *stateDisconnected) OnMsg(raw any) peerState { switch msg := raw.(type) { case peerUpdateMsg: return initPeerState(s.peerData, msg.Peer) diff --git a/peer/state-util_test.go b/peer/state-util_test.go new file mode 100644 index 0000000..f1dafb8 --- /dev/null +++ b/peer/state-util_test.go @@ -0,0 +1,146 @@ +package peer + +import ( + "net/netip" + "testing" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type PeerStateControlMsg struct { + Peer remotePeer + Packet any +} + +type PeerStateTestHarness struct { + State peerState + Published remotePeer + Sent []PeerStateControlMsg +} + +func NewPeerStateTestHarness() *PeerStateTestHarness { + h := &PeerStateTestHarness{} + + keys := generateKeys() + + state := &peerData{ + publish: func(rp remotePeer) { + h.Published = rp + }, + sendControlPacket: func(rp remotePeer, pkt marshaller) { + h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) + }, + pingTimer: time.NewTicker(pingInterval), + localIP: 2, + remoteIP: 3, + privKey: keys.PrivKey, + pubAddrs: newPubAddrStore(netip.AddrPort{}), + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + + h.State = enterStateDisconnected(state) + return h +} + +func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { + h.State = h.State.OnMsg(peerUpdateMsg{p}) +} + +func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnPingTimer() { + h.State = h.State.OnMsg(pingTimerMsg{}) +} + +func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { + keys := generateKeys() + + state := h.State.(*stateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 1, 1, 3}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientInit(t *testing.T) *stateClientInit { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 2, 3, 4}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateClientInit](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClient { + h.ConfigClientInit(t) + init := assertType[packetInit](t, h.Sent[0].Packet) + h.OnInit(controlMsg[packetInit]{ + Packet: init, + }) + + return assertType[*stateClient](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClient { + keys := generateKeys() + + state := h.State.(*stateDisconnected) + state.remoteIP = 1 + + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + // TODO: Fix me. + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateClient](t, h.State) +} diff --git a/peer/statedata.go b/peer/statedata.go index 0ea0929..5aee302 100644 --- a/peer/statedata.go +++ b/peer/statedata.go @@ -92,7 +92,7 @@ func initPeerState(data *peerData, peer *m.Peer) peerState { data.peer = peer if peer == nil { - return enterStateDisconnected2(data) + return enterStateDisconnected(data) } if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { -- 2.39.5 From d78d704a451b21901e81e8d8f0e8a5be6a10341a Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 26 Feb 2025 06:46:39 +0100 Subject: [PATCH 20/26] wip: testing and cleanup --- peer/controlmessage.go | 3 +-- peer/peerstates_test.go | 9 --------- peer/state-server_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 peer/state-server_test.go diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 33d4e9c..75a94d0 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -10,8 +10,7 @@ import ( type controlMsg[T any] struct { SrcIP byte SrcAddr netip.AddrPort - // TODO: RecvdAt int64 // Unixmilli. - Packet T + Packet T } func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index 80f8210..32dc207 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -33,15 +33,6 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { } /* -func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { - h := NewPeerStateTestHarness() - h.ConfigServer_Public(t) -} - -func TestPeerState_OnPeerUpdate_serverRelayed(t *testing.T) { - h := NewPeerStateTestHarness() - h.ConfigServer_Relayed(t) -} func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { h := NewPeerStateTestHarness() diff --git a/peer/state-server_test.go b/peer/state-server_test.go new file mode 100644 index 0000000..ad9f1cd --- /dev/null +++ b/peer/state-server_test.go @@ -0,0 +1,32 @@ +package peer + +import "testing" + +func TestStateServer_peerUpdate(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) + h.PeerUpdate(nil) + assertType[*stateDisconnected](t, h.State) +} + +func TestStateServer_onInit(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) + + msg := controlMsg[packetInit]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 1000), + Packet: packetInit{ + TraceID: newTraceID(), + Direct: true, + Version: 4, + }, + } + + h.OnInit(msg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) + resp := assertType[packetInit](t, h.Sent[0].Packet) + assertEqual(t, msg.Packet.TraceID, resp.TraceID) + assertEqual(t, resp.Version, version) +} -- 2.39.5 From 17ffc01be264ca13b2da8152975d60e85e5f4134 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 26 Feb 2025 07:56:09 +0100 Subject: [PATCH 21/26] wip: cleanup and testing --- peer/state-server_test.go | 50 +++++++++++++++++++++++++++++++++++++++ peer/util_test.go | 7 ++++++ 2 files changed, 57 insertions(+) diff --git a/peer/state-server_test.go b/peer/state-server_test.go index ad9f1cd..4d517a1 100644 --- a/peer/state-server_test.go +++ b/peer/state-server_test.go @@ -30,3 +30,53 @@ func TestStateServer_onInit(t *testing.T) { assertEqual(t, msg.Packet.TraceID, resp.TraceID) assertEqual(t, resp.Version, version) } + +func TestStateServer_onSynDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) + + msg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 1000), + Packet: packetSyn{ + TraceID: newTraceID(), + Direct: true, + }, + } + + msg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 1000) + msg.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 2000) + + h.OnSyn(msg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) + resp := assertType[packetAck](t, h.Sent[0].Packet) + assertEqual(t, msg.Packet.TraceID, resp.TraceID) +} + +func TestStateServer_onSynRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 1000), + Packet: packetSyn{ + TraceID: newTraceID(), + }, + } + + msg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 1000) + msg.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 2000) + + h.OnSyn(msg) + assertEqual(t, len(h.Sent), 3) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) + resp := assertType[packetAck](t, h.Sent[0].Packet) + assertEqual(t, msg.Packet.TraceID, resp.TraceID) + + for i, pkt := range h.Sent[1:] { + assertEqual(t, pkt.Peer.DirectAddr, msg.Packet.PossibleAddrs[i]) + assertType[packetProbe](t, pkt.Packet) + } +} diff --git a/peer/util_test.go b/peer/util_test.go index 56b9d6f..9703a97 100644 --- a/peer/util_test.go +++ b/peer/util_test.go @@ -9,6 +9,13 @@ func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) } +func assertNil(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + func assertType[T any](t *testing.T, obj any) T { t.Helper() x, ok := obj.(T) -- 2.39.5 From 1d18d297ed67a8d7744f36de87d310f537d71476 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 26 Feb 2025 12:45:43 +0100 Subject: [PATCH 22/26] wip: testing --- peer/state-server.go | 5 ++- peer/state-server_test.go | 84 ++++++++++++++++++++++++++++++++++++++- peer/state-util_test.go | 8 ++++ 3 files changed, 95 insertions(+), 2 deletions(-) diff --git a/peer/state-server.go b/peer/state-server.go index 723af5c..aba9c84 100644 --- a/peer/state-server.go +++ b/peer/state-server.go @@ -24,7 +24,10 @@ func enterStateServer(data *peerData) peerState { data.pingTimer.Reset(pingInterval) - state := &stateServer{peerData: data} + state := &stateServer{ + peerData: data, + lastSeen: time.Now(), + } state.logf("==> Server") return state } diff --git a/peer/state-server_test.go b/peer/state-server_test.go index 4d517a1..b367786 100644 --- a/peer/state-server_test.go +++ b/peer/state-server_test.go @@ -1,6 +1,9 @@ package peer -import "testing" +import ( + "testing" + "time" +) func TestStateServer_peerUpdate(t *testing.T) { h := NewPeerStateTestHarness() @@ -80,3 +83,82 @@ func TestStateServer_onSynRelayed(t *testing.T) { assertType[packetProbe](t, pkt.Packet) } } + +func TestStateServer_onProbe(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetProbe]{ + SrcIP: 3, + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + h.Sent = h.Sent[:0] + + h.OnProbe(msg) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateServer_onProbe_valid(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetProbe]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 100), + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + h.Sent = h.Sent[:0] + + h.OnProbe(msg) + assertEqual(t, len(h.Sent), 1) + assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) +} + +func TestStateServer_onPing(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + h.Sent = h.Sent[:0] + h.OnPingTimer() + assertEqual(t, len(h.Sent), 0) + assertType[*stateServer](t, h.State) +} + +func TestStateServer_onPing_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + h.Sent = h.Sent[:0] + state := assertType[*stateServer](t, h.State) + state.staged.Up = true + state.lastSeen = time.Now().Add(-2 * timeoutInterval) + + h.OnPingTimer() + state = assertType[*stateServer](t, h.State) + assertEqual(t, len(h.Sent), 0) + assertEqual(t, state.staged.Up, false) +} + +func TestStateServer_onLocalDiscovery(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetLocalDiscovery]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 100), + } + h.OnLocalDiscovery(msg) + assertType[*stateServer](t, h.State) +} + +func TestStateServer_onAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + msg := controlMsg[packetAck]{} + h.OnAck(msg) + assertType[*stateServer](t, h.State) +} diff --git a/peer/state-util_test.go b/peer/state-util_test.go index f1dafb8..8bb2904 100644 --- a/peer/state-util_test.go +++ b/peer/state-util_test.go @@ -59,10 +59,18 @@ func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { h.State = h.State.OnMsg(msg) } +func (h *PeerStateTestHarness) OnAck(msg controlMsg[packetAck]) { + h.State = h.State.OnMsg(msg) +} + func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { h.State = h.State.OnMsg(msg) } +func (h *PeerStateTestHarness) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { + h.State = h.State.OnMsg(msg) +} + func (h *PeerStateTestHarness) OnPingTimer() { h.State = h.State.OnMsg(pingTimerMsg{}) } -- 2.39.5 From 10fcb244660438e902da4c108971e4e4f746fd00 Mon Sep 17 00:00:00 2001 From: jdl Date: Wed, 26 Feb 2025 12:51:31 +0100 Subject: [PATCH 23/26] wip --- peer/state-client_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 peer/state-client_test.go diff --git a/peer/state-client_test.go b/peer/state-client_test.go new file mode 100644 index 0000000..8feb4d2 --- /dev/null +++ b/peer/state-client_test.go @@ -0,0 +1,34 @@ +package peer + +import "testing" + +func TestStateClient_peerUpdate(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + h.PeerUpdate(nil) + assertType[*stateDisconnected](t, h.State) +} + +func TestStateClient_initialPackets(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, len(h.Sent), 2) + assertType[packetInit](t, h.Sent[0].Packet) + assertType[packetSyn](t, h.Sent[1].Packet) +} + +func TestStateClient_onAck_incorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + h.Sent = h.Sent[:0] + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: newTraceID()}, + } + h.OnAck(ack) + + // Nothing should have happened. + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} -- 2.39.5 From ea3e997df86ca3550c8012f72c8f0cbe601e0226 Mon Sep 17 00:00:00 2001 From: jdl Date: Thu, 27 Feb 2025 21:00:48 +0100 Subject: [PATCH 24/26] wip - testing --- peer/state-client_test.go | 80 +++++++++++++++++++++++++++++++++++++- peer/state-clientinit.go | 12 +++++- peer/state-disconnected.go | 19 ++++++++- peer/state-server.go | 5 ++- peer/state-util_test.go | 25 ++++++------ peer/util_test.go | 2 +- 6 files changed, 124 insertions(+), 19 deletions(-) diff --git a/peer/state-client_test.go b/peer/state-client_test.go index 8feb4d2..88f4010 100644 --- a/peer/state-client_test.go +++ b/peer/state-client_test.go @@ -1,6 +1,9 @@ package peer -import "testing" +import ( + "testing" + "time" +) func TestStateClient_peerUpdate(t *testing.T) { h := NewPeerStateTestHarness() @@ -32,3 +35,78 @@ func TestStateClient_onAck_incorrectTraceID(t *testing.T) { assertType[*stateClient](t, h.State) assertEqual(t, len(h.Sent), 0) } + +func TestStateClient_onAck_direct_downToUp(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, len(h.Sent), 2) + syn := assertType[packetSyn](t, h.Sent[1].Packet) + h.Sent = h.Sent[:0] + + assertEqual(t, h.Published.Up, false) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + + h.OnAck(ack) + + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onAck_relayed_sendsProbes(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, len(h.Sent), 2) + syn := assertType[packetSyn](t, h.Sent[1].Packet) + h.Sent = h.Sent[:0] + + assertEqual(t, h.Published.Up, false) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + ack.Packet.PossibleAddrs[0] = addrPort4(1, 2, 3, 4, 100) + ack.Packet.PossibleAddrs[1] = addrPort4(2, 3, 4, 5, 200) + + h.OnAck(ack) + + assertEqual(t, len(h.Sent), 2) + assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, h.Sent[0].Peer.DirectAddr, ack.Packet.PossibleAddrs[0]) + assertType[packetProbe](t, h.Sent[1].Packet) + assertEqual(t, h.Sent[1].Peer.DirectAddr, ack.Packet.PossibleAddrs[1]) +} + +func TestStateClient_onPing(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + h.Sent = h.Sent[:0] + h.OnPingTimer() + assertEqual(t, len(h.Sent), 1) + assertType[*stateClient](t, h.State) + assertType[packetSyn](t, h.Sent[0].Packet) +} + +func TestStateClient_onPing_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + h.Sent = h.Sent[:0] + state := assertType[*stateClient](t, h.State) + state.lastSeen = time.Now().Add(-2 * timeoutInterval) + state.staged.Up = true + h.OnPingTimer() + + newState := assertType[*stateClientInit](t, h.State) + assertEqual(t, newState.staged.Up, false) + assertEqual(t, len(h.Sent), 1) + assertType[packetInit](t, h.Sent[0].Packet) +} + +// probe direct + +// probe relayed - no match + +// probe relayed - match diff --git a/peer/state-clientinit.go b/peer/state-clientinit.go index f34854a..8d963bc 100644 --- a/peer/state-clientinit.go +++ b/peer/state-clientinit.go @@ -47,10 +47,20 @@ func (s *stateClientInit) OnMsg(raw any) peerState { return initPeerState(s.peerData, msg.Peer) case controlMsg[packetInit]: return s.onInit(msg) + case controlMsg[packetSyn]: + s.logf("Unexpected SYN") + return s + case controlMsg[packetAck]: + s.logf("Unexpected ACK") + return s + case controlMsg[packetProbe]: + return s + case controlMsg[packetLocalDiscovery]: + return s case pingTimerMsg: return s.onPing() default: - s.logf("Ignoring message: %v", raw) + s.logf("Ignoring message: %#v", raw) return s } } diff --git a/peer/state-disconnected.go b/peer/state-disconnected.go index 4c0b9c0..ea503dc 100644 --- a/peer/state-disconnected.go +++ b/peer/state-disconnected.go @@ -26,8 +26,25 @@ func (s *stateDisconnected) OnMsg(raw any) peerState { switch msg := raw.(type) { case peerUpdateMsg: return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + s.logf("Unexpected INIT") + return s + case controlMsg[packetSyn]: + s.logf("Unexpected SYN") + return s + case controlMsg[packetAck]: + s.logf("Unexpected ACK") + return s + case controlMsg[packetProbe]: + s.logf("Unexpected probe") + return s + case controlMsg[packetLocalDiscovery]: + return s + case pingTimerMsg: + s.logf("Unexpected ping") + return s default: - s.logf("Ignoring message: %v", raw) + s.logf("Ignoring message: %#v", raw) return s } } diff --git a/peer/state-server.go b/peer/state-server.go index aba9c84..c9c76db 100644 --- a/peer/state-server.go +++ b/peer/state-server.go @@ -44,6 +44,9 @@ func (s *stateServer) OnMsg(raw any) peerState { return s.onInit(msg) case controlMsg[packetSyn]: return s.onSyn(msg) + case controlMsg[packetAck]: + s.logf("Unexpected ACK") + return s case controlMsg[packetProbe]: return s.onProbe(msg) case controlMsg[packetLocalDiscovery]: @@ -51,7 +54,7 @@ func (s *stateServer) OnMsg(raw any) peerState { case pingTimerMsg: return s.onPingTimer() default: - s.logf("Ignoring message: %v", raw) + s.logf("Unexpected message: %#v", raw) return s } } diff --git a/peer/state-util_test.go b/peer/state-util_test.go index 8bb2904..465a8c3 100644 --- a/peer/state-util_test.go +++ b/peer/state-util_test.go @@ -15,6 +15,7 @@ type PeerStateControlMsg struct { } type PeerStateTestHarness struct { + data *peerData State peerState Published remotePeer Sent []PeerStateControlMsg @@ -42,6 +43,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { MaxWaitCount: 1, }), } + h.data = state h.State = enterStateDisconnected(state) return h @@ -109,6 +111,8 @@ func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { } func (h *PeerStateTestHarness) ConfigClientInit(t *testing.T) *stateClientInit { + // Remote IP should be less than local IP. + h.data.localIP = 4 keys := generateKeys() peer := &m.Peer{ PeerIP: 3, @@ -134,21 +138,14 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClient { } func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClient { - keys := generateKeys() + h.ConfigClientInit(t) + state := assertType[*stateClientInit](t, h.State) + state.peer.PublicIP = nil // Force relay. - state := h.State.(*stateDisconnected) - state.remoteIP = 1 + init := assertType[packetInit](t, h.Sent[0].Packet) + h.OnInit(controlMsg[packetInit]{ + Packet: init, + }) - peer := &m.Peer{ - PeerIP: 3, - Port: 456, - PubKey: keys.PubKey, - PubSignKey: keys.PubSignKey, - } - - // TODO: Fix me. - - h.PeerUpdate(peer) - assertEqual(t, h.Published.Up, false) return assertType[*stateClient](t, h.State) } diff --git a/peer/util_test.go b/peer/util_test.go index 9703a97..128d29a 100644 --- a/peer/util_test.go +++ b/peer/util_test.go @@ -20,7 +20,7 @@ func assertType[T any](t *testing.T, obj any) T { t.Helper() x, ok := obj.(T) if !ok { - t.Fatal("invalid type", obj) + t.Fatalf("invalid type: %#v", obj) } return x } -- 2.39.5 From 1cd83dd098ca8782465ea182db04d7e6d704869a Mon Sep 17 00:00:00 2001 From: jdl Date: Fri, 28 Feb 2025 21:31:43 +0100 Subject: [PATCH 25/26] wip: testing, etc. --- peer/state-client_test.go | 87 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 84 insertions(+), 3 deletions(-) diff --git a/peer/state-client_test.go b/peer/state-client_test.go index 88f4010..25441e8 100644 --- a/peer/state-client_test.go +++ b/peer/state-client_test.go @@ -105,8 +105,89 @@ func TestStateClient_onPing_timeout(t *testing.T) { assertType[packetInit](t, h.Sent[0].Packet) } -// probe direct +func TestStateClient_onProbe_direct(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) -// probe relayed - no match + h.Sent = h.Sent[:0] + probe := controlMsg[packetProbe]{ + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } -// probe relayed - match + h.OnProbe(probe) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onProbe_noMatch(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.Sent = h.Sent[:0] + probe := controlMsg[packetProbe]{ + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + + h.OnProbe(probe) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onProbe_directUpgrade(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + state := assertType[*stateClient](t, h.State) + traceID := newTraceID() + state.probes[traceID] = sentProbe{ + SentAt: time.Now(), + Addr: addrPort4(1, 2, 3, 4, 500), + } + + probe := controlMsg[packetProbe]{ + Packet: packetProbe{TraceID: traceID}, + } + + assertEqual(t, h.Published.Direct, false) + h.Sent = h.Sent[:0] + h.OnProbe(probe) + assertEqual(t, h.Published.Direct, true) + + assertEqual(t, len(h.Sent), 1) + assertType[packetSyn](t, h.Sent[0].Packet) +} + +func TestStateClient_onLocalDiscovery_direct(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + h.Sent = h.Sent[:0] + pkt := controlMsg[packetLocalDiscovery]{ + Packet: packetLocalDiscovery{}, + } + + h.OnLocalDiscovery(pkt) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onLocalDiscovery_relayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.Sent = h.Sent[:0] + pkt := controlMsg[packetLocalDiscovery]{ + SrcAddr: addrPort4(1, 2, 3, 4, 500), + Packet: packetLocalDiscovery{}, + } + + h.OnLocalDiscovery(pkt) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 1) + assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 2, 3, 4, 456)) +} -- 2.39.5 From 1d318b4ae7c7cf5d8f0abcd8c7558df2b00e26be Mon Sep 17 00:00:00 2001 From: jdl Date: Sat, 1 Mar 2025 17:29:43 +0100 Subject: [PATCH 26/26] cleanup --- peer/controlmessage.go | 16 ++++++++-------- peer/state-client.go | 2 +- peer/util_test.go | 7 ------- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 75a94d0..f327291 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -16,6 +16,14 @@ type controlMsg[T any] struct { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { switch buf[0] { + case packetTypeInit: + packet, err := parsePacketInit(buf) + return controlMsg[packetInit]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + case packetTypeSyn: packet, err := parsePacketSyn(buf) return controlMsg[packetSyn]{ @@ -40,14 +48,6 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error Packet: packet, }, err - case packetTypeInit: - packet, err := parsePacketInit(buf) - return controlMsg[packetInit]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - default: return nil, errUnknownPacketType } diff --git a/peer/state-client.go b/peer/state-client.go index 49e4375..7e9d7c9 100644 --- a/peer/state-client.go +++ b/peer/state-client.go @@ -128,7 +128,7 @@ func (s *stateClient) onProbe(msg controlMsg[packetProbe]) peerState { s.syn.TraceID = newTraceID() s.syn.Direct = true s.Send(s.staged, s.syn) - s.logf("Successful probe.") + s.logf("Successful probe to %v.", sent.Addr) return s } diff --git a/peer/util_test.go b/peer/util_test.go index 128d29a..af05365 100644 --- a/peer/util_test.go +++ b/peer/util_test.go @@ -9,13 +9,6 @@ func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) } -func assertNil(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Fatal(err) - } -} - func assertType[T any](t *testing.T, obj any) T { t.Helper() x, ok := obj.(T) -- 2.39.5