From 1d3cc1f959c48eb42a982435a77b81a8277ac8df Mon Sep 17 00:00:00 2001 From: "J. David Lee" Date: Sat, 1 Mar 2025 20:02:27 +0000 Subject: [PATCH] refactor-for-testability (#3) Co-authored-by: jdl Co-authored-by: jdl Reviewed-on: https://git.crumpington.com/app/vppn/pulls/3 --- README.md | 6 - cmd/vppn/main.go | 4 +- node/addrdiscovery.go | 71 --- node/config.go | 11 - node/conn.go | 50 --- node/globalfuncs.go | 65 --- node/globals.go | 86 ---- node/hubpoller.go | 92 ---- node/localdiscovery.go | 97 ---- node/localdiscovery_test.go | 35 -- node/main.go | 323 -------------- node/packets.go | 129 ------ node/packets_test.go | 41 -- node/relaymanager.go | 40 -- node/supervisor.go | 417 ------------------ {node => peer}/bitset.go | 2 +- {node => peer}/bitset_test.go | 2 +- {node => peer}/cipher-control.go | 2 +- {node => peer}/cipher-control_test.go | 2 +- {node => peer}/cipher-data.go | 9 +- {node => peer}/cipher-data_test.go | 2 +- {node => peer}/cipher-discovery.go | 2 +- peer/connreader.go | 140 ++++++ node/messages.go => peer/controlmessage.go | 27 +- peer/crypto.go | 30 ++ peer/crypto_test.go | 191 ++++++++ peer/data-flow.dot | 14 + {node => peer}/dupcheck.go | 6 +- {node => peer}/dupcheck_test.go | 13 +- peer/errors.go | 10 + {node => peer}/files.go | 10 +- peer/files_test.go | 57 +++ peer/globals.go | 37 ++ {node => peer}/header.go | 11 +- {node => peer}/header_test.go | 2 +- peer/hubpoller.go | 110 +++++ peer/ifreader.go | 103 +++++ peer/ifreader_test.go | 81 ++++ {node => peer}/interface.go | 42 +- peer/main.go | 23 + peer/mcreader.go | 70 +++ peer/mcreader_test.go | 132 ++++++ peer/mcwriter.go | 53 +++ peer/mcwriter_test.go | 98 ++++ peer/mock-iface_test.go | 31 ++ peer/mock-network_test.go | 80 ++++ {node => peer}/packets-util.go | 6 +- {node => peer}/packets-util_test.go | 26 +- peer/packets.go | 120 +++++ peer/packets_test.go | 66 +++ peer/peer.go | 177 ++++++++ peer/peer_test.go | 114 +++++ peer/peerstates_test.go | 371 ++++++++++++++++ peer/peersuper.go | 148 +++++++ peer/pubaddrs.go | 86 ++++ .../pubaddrs_test.go | 6 +- peer/routingtable.go | 138 ++++++ peer/routingtable_test.go | 169 +++++++ peer/state-client.go | 162 +++++++ peer/state-client_test.go | 193 ++++++++ peer/state-clientinit.go | 95 ++++ peer/state-clientinit_test.go | 83 ++++ peer/state-disconnected.go | 50 +++ peer/state-server.go | 136 ++++++ peer/state-server_test.go | 164 +++++++ peer/state-util_test.go | 151 +++++++ peer/statedata.go | 109 +++++ peer/util_test.go | 26 ++ 68 files changed, 3908 insertions(+), 1547 deletions(-) delete mode 100644 node/addrdiscovery.go delete mode 100644 node/config.go delete mode 100644 node/conn.go delete mode 100644 node/globalfuncs.go delete mode 100644 node/globals.go delete mode 100644 node/hubpoller.go delete mode 100644 node/localdiscovery.go delete mode 100644 node/localdiscovery_test.go delete mode 100644 node/main.go delete mode 100644 node/packets.go delete mode 100644 node/packets_test.go delete mode 100644 node/relaymanager.go delete mode 100644 node/supervisor.go rename {node => peer}/bitset.go (96%) rename {node => peer}/bitset_test.go (97%) rename {node => peer}/cipher-control.go (98%) rename {node => peer}/cipher-control_test.go (99%) rename {node => peer}/cipher-data.go (85%) rename {node => peer}/cipher-data_test.go (99%) rename {node => peer}/cipher-discovery.go (95%) create mode 100644 peer/connreader.go rename node/messages.go => peer/controlmessage.go (68%) create mode 100644 peer/crypto.go create mode 100644 peer/crypto_test.go create mode 100644 peer/data-flow.dot rename {node => peer}/dupcheck.go (92%) rename {node => peer}/dupcheck_test.go (79%) create mode 100644 peer/errors.go rename {node => peer}/files.go (92%) create mode 100644 peer/files_test.go create mode 100644 peer/globals.go rename {node => peer}/header.go (79%) rename {node => peer}/header_test.go (95%) create mode 100644 peer/hubpoller.go create mode 100644 peer/ifreader.go create mode 100644 peer/ifreader_test.go rename {node => peer}/interface.go (84%) create mode 100644 peer/main.go create mode 100644 peer/mcreader.go create mode 100644 peer/mcreader_test.go create mode 100644 peer/mcwriter.go create mode 100644 peer/mcwriter_test.go create mode 100644 peer/mock-iface_test.go create mode 100644 peer/mock-network_test.go rename {node => peer}/packets-util.go (95%) rename {node => peer}/packets-util_test.go (75%) create mode 100644 peer/packets.go create mode 100644 peer/packets_test.go create mode 100644 peer/peer.go create mode 100644 peer/peer_test.go create mode 100644 peer/peerstates_test.go create mode 100644 peer/peersuper.go create mode 100644 peer/pubaddrs.go rename node/addrdiscovery_test.go => peer/pubaddrs_test.go (87%) create mode 100644 peer/routingtable.go create mode 100644 peer/routingtable_test.go create mode 100644 peer/state-client.go create mode 100644 peer/state-client_test.go create mode 100644 peer/state-clientinit.go create mode 100644 peer/state-clientinit_test.go create mode 100644 peer/state-disconnected.go create mode 100644 peer/state-server.go create mode 100644 peer/state-server_test.go create mode 100644 peer/state-util_test.go create mode 100644 peer/statedata.go create mode 100644 peer/util_test.go diff --git a/README.md b/README.md index c6cc0e1..4567196 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,5 @@ # vppn: Virtual Potentially Private Network -## TODO - -* Add `-force-init` argument to `node` main? - ## Hub Server Configuration ``` @@ -33,7 +29,6 @@ WorkingDirectory=/home/user/ ExecStart=/home/user/hub -listen :https -root-dir=/home/user Restart=always RestartSec=8 -TimeoutStopSec=24 [Install] WantedBy=default.target @@ -70,7 +65,6 @@ WorkingDirectory=/home/user/ ExecStart=/home/user/vppn -name vppn -hub-address https://my.hub -api-key 1234567890 Restart=always RestartSec=8 -TimeoutStopSec=24 [Install] WantedBy=default.target diff --git a/cmd/vppn/main.go b/cmd/vppn/main.go index 8c04016..5daa907 100644 --- a/cmd/vppn/main.go +++ b/cmd/vppn/main.go @@ -2,10 +2,10 @@ package main import ( "log" - "vppn/node" + "vppn/peer" ) func main() { log.SetFlags(0) - node.Main() + peer.Main() } diff --git a/node/addrdiscovery.go b/node/addrdiscovery.go deleted file mode 100644 index 160c7a0..0000000 --- a/node/addrdiscovery.go +++ /dev/null @@ -1,71 +0,0 @@ -package node - -import ( - "log" - "net/netip" - "runtime/debug" - "sort" - "time" -) - -type pubAddrStore struct { - lastSeen map[netip.AddrPort]time.Time - addrList []netip.AddrPort -} - -func newPubAddrStore() *pubAddrStore { - return &pubAddrStore{ - lastSeen: map[netip.AddrPort]time.Time{}, - addrList: make([]netip.AddrPort, 0, 32), - } -} - -func (store *pubAddrStore) Store(add netip.AddrPort) { - if localPub { - log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) - return - } - - if !add.IsValid() { - return - } - - if _, exists := store.lastSeen[add]; !exists { - store.addrList = append(store.addrList, add) - } - store.lastSeen[add] = time.Now() - store.sort() -} - -func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { - if localPub { - addrs[0] = localAddr - return - } - - copy(addrs[:], store.addrList) - return -} - -func (store *pubAddrStore) Clean() { - if localPub { - return - } - - for ip, lastSeen := range store.lastSeen { - if time.Since(lastSeen) > timeoutInterval { - delete(store.lastSeen, ip) - } - } - store.addrList = store.addrList[:0] - for ip := range store.lastSeen { - store.addrList = append(store.addrList, ip) - } - store.sort() -} - -func (store *pubAddrStore) sort() { - sort.Slice(store.addrList, func(i, j int) bool { - return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) - }) -} diff --git a/node/config.go b/node/config.go deleted file mode 100644 index 46da9eb..0000000 --- a/node/config.go +++ /dev/null @@ -1,11 +0,0 @@ -package node - -import "vppn/m" - -type localConfig struct { - m.PeerConfig - PubKey []byte - PrivKey []byte - PubSignKey []byte - PrivSignKey []byte -} diff --git a/node/conn.go b/node/conn.go deleted file mode 100644 index 2a1e762..0000000 --- a/node/conn.go +++ /dev/null @@ -1,50 +0,0 @@ -package node - -import ( - "io" - "log" - "net" - "net/netip" - "sync" -) - -// ---------------------------------------------------------------------------- - -type connWriter struct { - lock sync.Mutex - conn *net.UDPConn -} - -func newConnWriter(conn *net.UDPConn) *connWriter { - return &connWriter{conn: conn} -} - -func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) { - // Even though a conn is safe for concurrent use, it turns out that a mutex - // in Go is more fair when there's contention. Without this lock, control - // packets may fail to be sent in a timely manner causing timeouts. - w.lock.Lock() - if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { - log.Printf("Failed to write to UDP port: %v", err) - } - w.lock.Unlock() -} - -// ---------------------------------------------------------------------------- - -type ifWriter struct { - lock sync.Mutex - iface io.ReadWriteCloser -} - -func newIFWriter(iface io.ReadWriteCloser) *ifWriter { - return &ifWriter{iface: iface} -} - -func (w *ifWriter) Write(packet []byte) { - w.lock.Lock() - if _, err := w.iface.Write(packet); err != nil { - log.Fatalf("Failed to write to interface: %v", err) - } - w.lock.Unlock() -} diff --git a/node/globalfuncs.go b/node/globalfuncs.go deleted file mode 100644 index f32ec0b..0000000 --- a/node/globalfuncs.go +++ /dev/null @@ -1,65 +0,0 @@ -package node - -import ( - "sync/atomic" -) - -func getRelayRoute() *peerRoute { - if ip := relayIP.Load(); ip != nil { - return routingTable[*ip].Load() - } - return nil -} - -func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { - buf := pkt.Marshal(buf2) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&sendCounters[route.IP], 1), - SourceIP: localIP, - DestIP: route.IP, - } - buf = route.ControlCipher.Encrypt(h, buf, buf1) - - if route.Direct { - _conn.WriteTo(buf, route.RemoteAddr) - return - } - - _relayPacket(route.IP, buf, buf2) -} - -func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) { - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&sendCounters[route.IP], 1), - SourceIP: localIP, - DestIP: route.IP, - } - - enc := route.DataCipher.Encrypt(h, pkt, buf1) - - if route.Direct { - _conn.WriteTo(enc, route.RemoteAddr) - return - } - - _relayPacket(route.IP, enc, buf2) -} - -func _relayPacket(destIP byte, data, buf []byte) { - relayRoute := getRelayRoute() - if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay { - return - } - - h := header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1), - SourceIP: localIP, - DestIP: destIP, - } - - enc := relayRoute.DataCipher.Encrypt(h, data, buf) - _conn.WriteTo(enc, relayRoute.RemoteAddr) -} diff --git a/node/globals.go b/node/globals.go deleted file mode 100644 index b72acc4..0000000 --- a/node/globals.go +++ /dev/null @@ -1,86 +0,0 @@ -package node - -import ( - "net" - "net/netip" - "net/url" - "sync/atomic" - "time" -) - -const ( - bufferSize = 1536 - if_mtu = 1200 - if_queue_len = 2048 - controlCipherOverhead = 16 - dataCipherOverhead = 16 - signOverhead = 64 -) - -var ( - multicastIP = netip.AddrFrom4([4]byte{224, 0, 0, 157}) - multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(multicastIP, 4560)) -) - -type peerRoute struct { - IP byte - Up bool // True if data can be sent on the route. - Relay bool // True if the peer is a relay. - Direct bool // True if this is a direct connection. - PubSignKey []byte - ControlCipher *controlCipher - DataCipher *dataCipher - RemoteAddr netip.AddrPort // Remote address if directly connected. -} - -var ( - hubURL *url.URL - apiKey string - - // Configuration for this peer. - netName string - localIP byte - localPub bool - localAddr netip.AddrPort - privKey []byte - privSignKey []byte - - // Shared interface for writing. - _iface *ifWriter - - // Shared connection for writing. - _conn *connWriter - - // Counters for sending to each peer. - sendCounters [256]uint64 = func() (out [256]uint64) { - for i := range out { - out[i] = uint64(time.Now().Unix()<<30 + 1) - } - return - }() - - // Duplicate checkers for incoming packets. - dupChecks [256]*dupCheck = func() (out [256]*dupCheck) { - for i := range out { - out[i] = newDupCheck(0) - } - return - }() - - // Messages for the supervisor. - messages = make(chan any, 1024) - - // Global routing table. - routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { - for i := range out { - out[i] = &atomic.Pointer[peerRoute]{} - out[i].Store(&peerRoute{}) - } - return - }() - - // Managed by the relayManager. - relayIP = &atomic.Pointer[byte]{} - - publicAddrs = newPubAddrStore() -) diff --git a/node/hubpoller.go b/node/hubpoller.go deleted file mode 100644 index a069c8b..0000000 --- a/node/hubpoller.go +++ /dev/null @@ -1,92 +0,0 @@ -package node - -import ( - "encoding/json" - "io" - "log" - "net/http" - "time" - "vppn/m" -) - -type hubPoller struct { - client *http.Client - req *http.Request - versions [256]int64 -} - -func newHubPoller() *hubPoller { - u := *hubURL - u.Path = "/peer/fetch-state/" - - client := &http.Client{Timeout: 8 * time.Second} - - req := &http.Request{ - Method: http.MethodGet, - URL: &u, - Header: http.Header{}, - } - req.SetBasicAuth("", apiKey) - - return &hubPoller{ - client: client, - req: req, - } -} - -func (hp *hubPoller) Run() { - defer panicHandler() - - state, err := loadNetworkState(netName) - if err != nil { - log.Printf("Failed to load network state: %v", err) - log.Printf("Polling hub...") - hp.pollHub() - } else { - hp.applyNetworkState(state) - } - - for range time.Tick(64 * time.Second) { - hp.pollHub() - } -} - -func (hp *hubPoller) pollHub() { - var state m.NetworkState - - resp, err := hp.client.Do(hp.req) - if err != nil { - log.Printf("Failed to fetch peer state: %v", err) - return - } - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - log.Printf("Failed to read body from hub: %v", err) - return - } - - if err := json.Unmarshal(body, &state); err != nil { - log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) - return - } - - hp.applyNetworkState(state) - - if err := storeNetworkState(netName, state); err != nil { - log.Printf("Failed to store network state: %v", err) - } -} - -func (hp *hubPoller) applyNetworkState(state m.NetworkState) { - for i, peer := range state.Peers { - if i != int(localIP) { - if peer == nil || peer.Version != hp.versions[i] { - messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]} - if peer != nil { - hp.versions[i] = peer.Version - } - } - } - } -} diff --git a/node/localdiscovery.go b/node/localdiscovery.go deleted file mode 100644 index 90f2e60..0000000 --- a/node/localdiscovery.go +++ /dev/null @@ -1,97 +0,0 @@ -package node - -import ( - "log" - "net" - "time" - - "golang.org/x/crypto/nacl/sign" -) - -func localDiscovery() { - conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) - if err != nil { - log.Printf("Failed to bind to multicast address: %v", err) - return - } - - go sendLocalDiscovery(conn) - go recvLocalDiscovery(conn) -} - -func sendLocalDiscovery(conn *net.UDPConn) { - var ( - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) - ) - - for range time.Tick(16 * time.Second) { - signed := buildLocalDiscoveryPacket(buf1, buf2) - if _, err := conn.WriteToUDP(signed, multicastAddr); err != nil { - log.Printf("Failed to write multicast UDP packet: %v", err) - } - } -} - -func recvLocalDiscovery(conn *net.UDPConn) { - var ( - raw = make([]byte, bufferSize) - buf = make([]byte, bufferSize) - ) - - for { - n, remoteAddr, err := conn.ReadFromUDPAddrPort(raw[:bufferSize]) - if err != nil { - log.Fatalf("Failed to read from UDP port (multicast): %v", err) - } - - raw = raw[:n] - h, ok := openLocalDiscoveryPacket(raw, buf) - if !ok { - log.Printf("Failed to open discovery packet?") - continue - } - - msg := controlMsg[localDiscoveryPacket]{ - SrcIP: h.SourceIP, - SrcAddr: remoteAddr, - Packet: localDiscoveryPacket{}, - } - - select { - case messages <- msg: - default: - log.Printf("Dropping local discovery message.") - } - } -} - -func buildLocalDiscoveryPacket(buf1, buf2 []byte) []byte { - h := header{ - StreamID: controlStreamID, - Counter: 0, - SourceIP: localIP, - DestIP: 255, - } - out := buf1[:headerSize] - h.Marshal(out) - return sign.Sign(buf2[:0], out, (*[64]byte)(privSignKey)) -} - -func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) { - if len(raw) != headerSize+signOverhead { - ok = false - return - } - - h.Parse(raw[signOverhead:]) - route := routingTable[h.SourceIP].Load() - if route == nil || route.PubSignKey == nil { - log.Printf("Missing signing key: %d", h.SourceIP) - ok = false - return - } - - _, ok = sign.Open(buf[:0], raw, (*[32]byte)(route.PubSignKey)) - return -} diff --git a/node/localdiscovery_test.go b/node/localdiscovery_test.go deleted file mode 100644 index 7f4eaa3..0000000 --- a/node/localdiscovery_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "testing" - - "golang.org/x/crypto/nacl/sign" -) - -func TestLocalDiscoveryPacketSigning(t *testing.T) { - localIP = 32 - - var ( - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) - pubSignKey, privSigKey, _ = sign.GenerateKey(rand.Reader) - ) - - privSignKey = privSigKey[:] - route := routingTable[localIP].Load() - route.IP = byte(localIP) - route.PubSignKey = pubSignKey[0:32] - routingTable[localIP].Store(route) - - out := buildLocalDiscoveryPacket(buf1, buf2) - - h, ok := openLocalDiscoveryPacket(bytes.Clone(out), buf1) - if !ok { - t.Fatal(h, ok) - } - if h.StreamID != controlStreamID || h.SourceIP != localIP || h.DestIP != 255 { - t.Fatal(h) - } -} diff --git a/node/main.go b/node/main.go deleted file mode 100644 index 4e59cf7..0000000 --- a/node/main.go +++ /dev/null @@ -1,323 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "encoding/json" - "flag" - "fmt" - "io" - "log" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "runtime/debug" - "time" - "vppn/m" - - "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" -) - -func panicHandler() { - if r := recover(); r != nil { - log.Fatalf("\n %v\n\nstacktrace from panic: %s\n", r, string(debug.Stack())) - } -} - -func Main() { - defer panicHandler() - - var hubAddress string - - flag.StringVar(&netName, "name", "", "[REQUIRED] The network name.") - flag.StringVar(&hubAddress, "hub-address", "", "[REQUIRED] The hub address.") - flag.StringVar(&apiKey, "api-key", "", "[REQUIRED] The node's API key.") - flag.Parse() - - if netName == "" || hubAddress == "" || apiKey == "" { - flag.Usage() - os.Exit(1) - } - - var err error - - hubURL, err = url.Parse(hubAddress) - if err != nil { - log.Fatalf("Failed to parse hub address: %v", err) - } - - main() -} - -func initPeerWithHub() { - encPubKey, encPrivKey, err := box.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate encryption keys: %v", err) - } - - signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate signing keys: %v", err) - } - - initURL := *hubURL - initURL.Path = "/peer/init/" - - args := m.PeerInitArgs{ - EncPubKey: encPubKey[:], - PubSignKey: signPubKey[:], - } - - buf := &bytes.Buffer{} - if err := json.NewEncoder(buf).Encode(args); err != nil { - log.Fatalf("Failed to encode init args: %v", err) - } - - req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) - if err != nil { - log.Fatalf("Failed to construct request: %v", err) - } - req.SetBasicAuth("", apiKey) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - log.Fatalf("Failed to init with hub: %v", err) - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - log.Fatalf("Failed to read response body: %v", err) - } - - peerConfig := localConfig{} - if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { - log.Fatalf("Failed to parse configuration: %v\n%s", err, data) - } - - peerConfig.PubKey = encPubKey[:] - peerConfig.PrivKey = encPrivKey[:] - peerConfig.PubSignKey = signPubKey[:] - peerConfig.PrivSignKey = signPrivKey[:] - - if err := storePeerConfig(netName, peerConfig); err != nil { - log.Fatalf("Failed to store configuration: %v", err) - } - - log.Print("Initialization successful.") -} - -// ---------------------------------------------------------------------------- - -func main() { - config, err := loadPeerConfig(netName) - if err != nil { - log.Printf("Failed to load configuration: %v", err) - log.Printf("Initializing...") - initPeerWithHub() - - config, err = loadPeerConfig(netName) - if err != nil { - log.Fatalf("Failed to load configuration: %v", err) - } - } - - iface, err := openInterface(config.Network, config.PeerIP, netName) - if err != nil { - log.Fatalf("Failed to open interface: %v", err) - } - - myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) - if err != nil { - log.Fatalf("Failed to resolve UDP address: %v", err) - } - - conn, err := net.ListenUDP("udp", myAddr) - if err != nil { - log.Fatalf("Failed to open UDP port: %v", err) - } - - conn.SetReadBuffer(1024 * 1024 * 8) - conn.SetWriteBuffer(1024 * 1024 * 8) - - // Intialize globals. - _iface = newIFWriter(iface) - _conn = newConnWriter(conn) - - localIP = config.PeerIP - - ip, ok := netip.AddrFromSlice(config.PublicIP) - if ok { - localPub = true - localAddr = netip.AddrPortFrom(ip, config.Port) - } - - privKey = config.PrivKey - privSignKey = config.PrivSignKey - - if !localPub { - go relayManager() - go localDiscovery() - } - - go func() { - for range time.Tick(pingInterval) { - messages <- pingTimerMsg{} - } - }() - - go startPeerSuper() - - go newHubPoller().Run() - go readFromConn(conn) - - readFromIFace(iface) -} - -// ---------------------------------------------------------------------------- - -func readFromConn(conn *net.UDPConn) { - - defer panicHandler() - - var ( - remoteAddr netip.AddrPort - n int - err error - buf = make([]byte, bufferSize) - decBuf = make([]byte, bufferSize) - data []byte - h header - ) - - for { - n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize]) - if err != nil { - log.Fatalf("Failed to read from UDP port: %v", err) - } - - remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) - - data = buf[:n] - - if n < headerSize { - continue // Packet it soo short. - } - - h.Parse(data) - switch h.StreamID { - case controlStreamID: - handleControlPacket(remoteAddr, h, data, decBuf) - - case dataStreamID: - handleDataPacket(h, data, decBuf) - - default: - log.Printf("Unknown stream ID: %d", h.StreamID) - } - } -} - -func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { - route := routingTable[h.SourceIP].Load() - if route.ControlCipher == nil { - //log.Printf("Not connected (control).") - return - } - - if h.DestIP != localIP { - log.Printf("Incorrect destination IP on control packet: %#v", h) - return - } - - out, ok := route.ControlCipher.Decrypt(data, decBuf) - if !ok { - log.Printf("Failed to decrypt control packet.") - return - } - - if len(out) == 0 { - log.Printf("Empty control packet from: %d", h.SourceIP) - return - } - - if dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter) - return - } - - msg, err := parseControlMsg(h.SourceIP, addr, out) - if err != nil { - log.Printf("Failed to parse control packet: %v", err) - return - } - - select { - case messages <- msg: - default: - log.Printf("Dropping control packet.") - } - -} - -func handleDataPacket(h header, data []byte, decBuf []byte) { - route := routingTable[h.SourceIP].Load() - if !route.Up { - log.Printf("Not connected (recv).") - return - } - - dec, ok := route.DataCipher.Decrypt(data, decBuf) - if !ok { - log.Printf("Failed to decrypt data packet.") - return - } - - if dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter) - return - } - - if h.DestIP == localIP { - _iface.Write(dec) - return - } - - destRoute := routingTable[h.DestIP].Load() - if !destRoute.Up { - log.Printf("Not connected (relay): %d", destRoute.IP) - return - } - - _conn.WriteTo(dec, destRoute.RemoteAddr) -} - -// ---------------------------------------------------------------------------- - -func readFromIFace(iface io.ReadWriteCloser) { - var ( - packet = make([]byte, bufferSize) - buf1 = make([]byte, bufferSize) - buf2 = make([]byte, bufferSize) - remoteIP byte - err error - ) - - for { - packet, remoteIP, err = readNextPacket(iface, packet) - if err != nil { - log.Fatalf("Failed to read from interface: %v", err) - } - - route := routingTable[remoteIP].Load() - if !route.Up { - log.Printf("Route not connected: %d", remoteIP) - continue - } - - _sendDataPacket(route, packet, buf1, buf2) - } -} diff --git a/node/packets.go b/node/packets.go deleted file mode 100644 index 14d7377..0000000 --- a/node/packets.go +++ /dev/null @@ -1,129 +0,0 @@ -package node - -import ( - "errors" - "net/netip" -) - -var ( - errMalformedPacket = errors.New("malformed packet") - errUnknownPacketType = errors.New("unknown packet type") -) - -const ( - packetTypeSyn = iota + 1 - packetTypeSynAck - packetTypeAck - packetTypeProbe - packetTypeAddrDiscovery -) - -// ---------------------------------------------------------------------------- - -type synPacket struct { - TraceID uint64 // TraceID to match response w/ request. - SharedKey [32]byte // Our shared key. - Direct bool - PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. -} - -func (p synPacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeSyn). - Uint64(p.TraceID). - SharedKey(p.SharedKey). - Bool(p.Direct). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). - Build() -} - -func parseSynPacket(buf []byte) (p synPacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - SharedKey(&p.SharedKey). - Bool(&p.Direct). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type ackPacket struct { - TraceID uint64 - ToAddr netip.AddrPort - PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. -} - -func (p ackPacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeAck). - Uint64(p.TraceID). - AddrPort(p.ToAddr). - AddrPort(p.PossibleAddrs[0]). - AddrPort(p.PossibleAddrs[1]). - AddrPort(p.PossibleAddrs[2]). - AddrPort(p.PossibleAddrs[3]). - AddrPort(p.PossibleAddrs[4]). - AddrPort(p.PossibleAddrs[5]). - AddrPort(p.PossibleAddrs[6]). - AddrPort(p.PossibleAddrs[7]). - Build() - -} - -func parseAckPacket(buf []byte) (p ackPacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - AddrPort(&p.ToAddr). - AddrPort(&p.PossibleAddrs[0]). - AddrPort(&p.PossibleAddrs[1]). - AddrPort(&p.PossibleAddrs[2]). - AddrPort(&p.PossibleAddrs[3]). - AddrPort(&p.PossibleAddrs[4]). - AddrPort(&p.PossibleAddrs[5]). - AddrPort(&p.PossibleAddrs[6]). - AddrPort(&p.PossibleAddrs[7]). - Error() - return -} - -// ---------------------------------------------------------------------------- - -// A probeReqPacket is sent from a client to a server to determine if direct -// UDP communication can be used. -type probePacket struct { - TraceID uint64 -} - -func (p probePacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeProbe). - Uint64(p.TraceID). - Build() -} - -func parseProbePacket(buf []byte) (p probePacket, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type localDiscoveryPacket struct{} diff --git a/node/packets_test.go b/node/packets_test.go deleted file mode 100644 index 254bcc7..0000000 --- a/node/packets_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package node - -import ( - "crypto/rand" - "net/netip" - "reflect" - "testing" -) - -func TestPacketSyn(t *testing.T) { - in := synPacket{ - TraceID: newTraceID(), - FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), - } - rand.Read(in.SharedKey[:]) - - out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize))) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal("\n", in, "\n", out) - } -} - -func TestPacketSynAck(t *testing.T) { - in := ackPacket{ - TraceID: newTraceID(), - FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), - } - - out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal("\n", in, "\n", out) - } -} diff --git a/node/relaymanager.go b/node/relaymanager.go deleted file mode 100644 index 5c44ea8..0000000 --- a/node/relaymanager.go +++ /dev/null @@ -1,40 +0,0 @@ -package node - -import ( - "log" - "math/rand" - "time" -) - -func relayManager() { - time.Sleep(2 * time.Second) - updateRelayRoute() - - for range time.Tick(8 * time.Second) { - relay := getRelayRoute() - if relay == nil || !relay.Up || !relay.Relay { - updateRelayRoute() - } - } -} - -func updateRelayRoute() { - possible := make([]*peerRoute, 0, 8) - for i := range routingTable { - route := routingTable[i].Load() - if !route.Up || !route.Relay { - continue - } - possible = append(possible, route) - } - - if len(possible) == 0 { - log.Printf("No relay available.") - relayIP.Store(nil) - return - } - - ip := possible[rand.Intn(len(possible))].IP - log.Printf("New relay IP: %d", ip) - relayIP.Store(&ip) -} diff --git a/node/supervisor.go b/node/supervisor.go deleted file mode 100644 index 6b5e96a..0000000 --- a/node/supervisor.go +++ /dev/null @@ -1,417 +0,0 @@ -package node - -import ( - "fmt" - "log" - "net/netip" - "strings" - "sync/atomic" - "time" - "vppn/m" - - "git.crumpington.com/lib/go/ratelimiter" -) - -const ( - pingInterval = 8 * time.Second - timeoutInterval = 30 * time.Second -) - -// ---------------------------------------------------------------------------- - -func startPeerSuper() { - peers := [256]peerState{} - for i := range peers { - data := &peerStateData{ - published: routingTable[i], - remoteIP: byte(i), - buf1: make([]byte, bufferSize), - buf2: make([]byte, bufferSize), - limiter: ratelimiter.New(ratelimiter.Config{ - FillPeriod: 20 * time.Millisecond, - MaxWaitCount: 1, - }), - } - peers[i] = data.OnPeerUpdate(nil) - } - go runPeerSuper(peers) -} - -func runPeerSuper(peers [256]peerState) { - for raw := range messages { - switch msg := raw.(type) { - - case peerUpdateMsg: - peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer) - - case controlMsg[synPacket]: - peers[msg.SrcIP].OnSyn(msg) - - case controlMsg[ackPacket]: - peers[msg.SrcIP].OnAck(msg) - - case controlMsg[probePacket]: - peers[msg.SrcIP].OnProbe(msg) - - case controlMsg[localDiscoveryPacket]: - peers[msg.SrcIP].OnLocalDiscovery(msg) - - case pingTimerMsg: - publicAddrs.Clean() - for i := range peers { - if newState := peers[i].OnPingTimer(); newState != nil { - peers[i] = newState - } - } - - default: - log.Printf("WARNING: unknown message type: %+v", msg) - } - } -} - -// ---------------------------------------------------------------------------- - -type peerState interface { - OnPeerUpdate(*m.Peer) peerState - OnSyn(controlMsg[synPacket]) - OnAck(controlMsg[ackPacket]) - OnProbe(controlMsg[probePacket]) - OnLocalDiscovery(controlMsg[localDiscoveryPacket]) - OnPingTimer() peerState -} - -// ---------------------------------------------------------------------------- - -type peerStateData struct { - // The purpose of this state machine is to manage this published data. - published *atomic.Pointer[peerRoute] - staged peerRoute // Local copy of shared data. See publish(). - - // Immutable data. - remoteIP byte // Remote VPN IP. - - // Mutable peer data. - peer *m.Peer - remotePub bool - - // Buffers for sending control packets. - buf1 []byte - buf2 []byte - - // For logging. Set per-state. - client bool - - // We rate limit per remote endpoint because if we don't we tend to lose - // packets. - limiter *ratelimiter.Limiter -} - -// ---------------------------------------------------------------------------- - -func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { - s._sendControlPacket(pkt, s.staged) -} - -func (s *peerStateData) sendControlPacketTo(pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort) { - if !addr.IsValid() { - s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) - return - } - route := s.staged - route.Direct = true - route.RemoteAddr = addr - s._sendControlPacket(pkt, route) -} - -func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute) { - if err := s.limiter.Limit(); err != nil { - s.logf("Not sending control packet: rate limited.") // Shouldn't happen. - return - } - _sendControlPacket(pkt, route, s.buf1, s.buf2) -} - -// ---------------------------------------------------------------------------- - -func (s *peerStateData) publish() { - data := s.staged - s.published.Store(&data) -} - -func (s *peerStateData) logf(format string, args ...any) { - b := strings.Builder{} - b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name)) - - if s.client { - b.WriteString("CLIENT | ") - } else { - b.WriteString("SERVER | ") - } - - if s.staged.Direct { - b.WriteString("DIRECT | ") - } else { - b.WriteString("RELAYED | ") - } - - if s.staged.Up { - b.WriteString("UP | ") - } else { - b.WriteString("DOWN | ") - } - - log.Printf(b.String()+format, args...) -} - -// ---------------------------------------------------------------------------- - -func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState { - defer s.publish() - - if peer == nil { - return enterStateDisconnected(s) - } - - s.peer = peer - s.staged = peerRoute{ - IP: s.remoteIP, - PubSignKey: peer.PubSignKey, - ControlCipher: newControlCipher(privKey, peer.PubKey), - DataCipher: newDataCipher(), - } - s.remotePub = false - - if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { - s.remotePub = true - s.staged.Relay = peer.Relay - s.staged.Direct = true - s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) - } else if localPub { - s.staged.Direct = true - } - - if s.remotePub == localPub { - if localIP < s.remoteIP { - return enterStateServer(s) - } - return enterStateClient(s) - } - - if s.remotePub { - return enterStateClient(s) - } - return enterStateServer(s) -} - -// ---------------------------------------------------------------------------- - -type stateDisconnected struct { - *peerStateData -} - -func enterStateDisconnected(s *peerStateData) peerState { - s.peer = nil - s.staged = peerRoute{} - s.publish() - return &stateDisconnected{s} -} - -func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {} -func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {} -func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {} -func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {} - -func (s *stateDisconnected) OnPingTimer() peerState { - return nil -} - -// ---------------------------------------------------------------------------- - -type stateServer struct { - *stateDisconnected - lastSeen time.Time - synTraceID uint64 -} - -func enterStateServer(s *peerStateData) peerState { - s.client = false - return &stateServer{stateDisconnected: &stateDisconnected{s}} -} - -func (s *stateServer) OnSyn(msg controlMsg[synPacket]) { - s.lastSeen = time.Now() - p := msg.Packet - - // Before we can respond to this packet, we need to make sure the - // route is setup properly. - // - // The client will update the syn's TraceID whenever there's a change. - // The server will follow the client's request. - if p.TraceID != s.synTraceID || !s.staged.Up { - s.synTraceID = p.TraceID - s.staged.Up = true - s.staged.Direct = p.Direct - s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) - s.staged.RemoteAddr = msg.SrcAddr - s.publish() - s.logf("Got syn.") - } - - // Always respond. - ack := ackPacket{ - TraceID: p.TraceID, - ToAddr: s.staged.RemoteAddr, - PossibleAddrs: publicAddrs.Get(), - } - s.sendControlPacket(ack) - - if s.staged.Direct { - return - } - - // Not direct => send probes. - for _, addr := range p.PossibleAddrs { - if !addr.IsValid() { - break - } - s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, addr) - } -} - -func (s *stateServer) OnProbe(msg controlMsg[probePacket]) { - if !msg.SrcAddr.IsValid() { - s.logf("Invalid probe address.") - return - } - s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr) -} - -func (s *stateServer) OnPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { - s.staged.Up = false - s.publish() - s.logf("Connection timeout.") - } - return nil -} - -// ---------------------------------------------------------------------------- - -type stateClient struct { - *stateDisconnected - - lastSeen time.Time - syn synPacket - ack ackPacket - - probes map[uint64]netip.AddrPort - localDiscoveryAddr netip.AddrPort -} - -func enterStateClient(s *peerStateData) peerState { - s.client = true - ss := &stateClient{ - stateDisconnected: &stateDisconnected{s}, - probes: map[uint64]netip.AddrPort{}, - } - - ss.syn = synPacket{ - TraceID: newTraceID(), - SharedKey: s.staged.DataCipher.Key(), - Direct: s.staged.Direct, - PossibleAddrs: publicAddrs.Get(), - } - ss.sendControlPacket(ss.syn) - - return ss -} - -func (s *stateClient) sendProbeTo(addr netip.AddrPort) { - probe := probePacket{TraceID: newTraceID()} - s.probes[probe.TraceID] = addr - s.sendControlPacketTo(probe, addr) -} - -func (s *stateClient) OnAck(msg controlMsg[ackPacket]) { - if msg.Packet.TraceID != s.syn.TraceID { - s.logf("Ack has incorrect trace ID") - return - } - - s.ack = msg.Packet - s.lastSeen = time.Now() - - if !s.staged.Up { - s.staged.Up = true - s.logf("Got ack.") - s.publish() - } - - // Store possible public address if we're not a public node. - if !localPub && s.remotePub { - publicAddrs.Store(msg.Packet.ToAddr) - } -} - -func (s *stateClient) OnProbe(msg controlMsg[probePacket]) { - if s.staged.Direct { - return - } - - addr, ok := s.probes[msg.Packet.TraceID] - if !ok { - return - } - - s.staged.RemoteAddr = addr - s.staged.Direct = true - s.publish() - - s.syn.TraceID = newTraceID() - s.syn.Direct = true - s.syn.PossibleAddrs = [8]netip.AddrPort{} - s.sendControlPacket(s.syn) - - s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String()) -} - -func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) { - if s.staged.Direct { - return - } - - // The source port will be the multicast port, so we'll have to - // construct the correct address using the peer's listed port. - s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) -} - -func (s *stateClient) OnPingTimer() peerState { - if time.Since(s.lastSeen) > timeoutInterval { - if s.staged.Up { - s.logf("Connection timeout.") - } - return s.OnPeerUpdate(s.peer) - } - - s.sendControlPacket(s.syn) - - if s.staged.Direct { - return nil - } - - clear(s.probes) - for _, addr := range s.ack.PossibleAddrs { - if !addr.IsValid() { - break - } - s.sendProbeTo(addr) - } - - if s.localDiscoveryAddr.IsValid() { - s.sendProbeTo(s.localDiscoveryAddr) - s.localDiscoveryAddr = netip.AddrPort{} - } - - return nil -} diff --git a/node/bitset.go b/peer/bitset.go similarity index 96% rename from node/bitset.go rename to peer/bitset.go index a9024cb..8d03b50 100644 --- a/node/bitset.go +++ b/peer/bitset.go @@ -1,4 +1,4 @@ -package node +package peer const bitSetSize = 512 // Multiple of 64. diff --git a/node/bitset_test.go b/peer/bitset_test.go similarity index 97% rename from node/bitset_test.go rename to peer/bitset_test.go index bd3307a..01ae82b 100644 --- a/node/bitset_test.go +++ b/peer/bitset_test.go @@ -1,4 +1,4 @@ -package node +package peer import ( "math/rand" diff --git a/node/cipher-control.go b/peer/cipher-control.go similarity index 98% rename from node/cipher-control.go rename to peer/cipher-control.go index bd11470..bfecaeb 100644 --- a/node/cipher-control.go +++ b/peer/cipher-control.go @@ -1,4 +1,4 @@ -package node +package peer import "golang.org/x/crypto/nacl/box" diff --git a/node/cipher-control_test.go b/peer/cipher-control_test.go similarity index 99% rename from node/cipher-control_test.go rename to peer/cipher-control_test.go index ab28860..916d2ea 100644 --- a/node/cipher-control_test.go +++ b/peer/cipher-control_test.go @@ -1,4 +1,4 @@ -package node +package peer import ( "bytes" diff --git a/node/cipher-data.go b/peer/cipher-data.go similarity index 85% rename from node/cipher-data.go rename to peer/cipher-data.go index 9151870..9b229bb 100644 --- a/node/cipher-data.go +++ b/peer/cipher-data.go @@ -1,9 +1,10 @@ -package node +package peer import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "log" ) type dataCipher struct { @@ -14,7 +15,7 @@ type dataCipher struct { func newDataCipher() *dataCipher { key := [32]byte{} if _, err := rand.Read(key[:]); err != nil { - panic(err) + log.Fatalf("Failed to read random data: %v", err) } return newDataCipherFromKey(key) } @@ -22,12 +23,12 @@ func newDataCipher() *dataCipher { func newDataCipherFromKey(key [32]byte) *dataCipher { block, err := aes.NewCipher(key[:]) if err != nil { - panic(err) + log.Fatalf("Failed to create new cipher: %v", err) } aead, err := cipher.NewGCM(block) if err != nil { - panic(err) + log.Fatalf("Failed to create new GCM: %v", err) } return &dataCipher{key: key, aead: aead} diff --git a/node/cipher-data_test.go b/peer/cipher-data_test.go similarity index 99% rename from node/cipher-data_test.go rename to peer/cipher-data_test.go index 493c198..ac9a03a 100644 --- a/node/cipher-data_test.go +++ b/peer/cipher-data_test.go @@ -1,4 +1,4 @@ -package node +package peer import ( "bytes" diff --git a/node/cipher-discovery.go b/peer/cipher-discovery.go similarity index 95% rename from node/cipher-discovery.go rename to peer/cipher-discovery.go index 85e1381..0e66650 100644 --- a/node/cipher-discovery.go +++ b/peer/cipher-discovery.go @@ -1,4 +1,4 @@ -package node +package peer /* func signData(privKey *[64]byte, h header, data, out []byte) []byte { diff --git a/peer/connreader.go b/peer/connreader.go new file mode 100644 index 0000000..4c156f4 --- /dev/null +++ b/peer/connreader.go @@ -0,0 +1,140 @@ +package peer + +import ( + "io" + "log" + "net/netip" + "sync/atomic" +) + +type connReader struct { + // Input + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) + + // Output + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + iface io.Writer + handleControlMsg func(fromIP byte, pkt any) + + localIP byte + rt *atomic.Pointer[routingTable] + + buf []byte + decBuf []byte +} + +func newConnReader( + readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + iface io.Writer, + handleControlMsg func(fromIP byte, pkt any), + rt *atomic.Pointer[routingTable], +) *connReader { + return &connReader{ + readFromUDPAddrPort: readFromUDPAddrPort, + writeToUDPAddrPort: writeToUDPAddrPort, + iface: iface, + handleControlMsg: handleControlMsg, + localIP: rt.Load().LocalIP, + rt: rt, + buf: newBuf(), + decBuf: newBuf(), + } +} + +func (r *connReader) Run() { + for { + r.handleNextPacket() + } +} + +func (r *connReader) handleNextPacket() { + buf := r.buf[:bufferSize] + n, remoteAddr, err := r.readFromUDPAddrPort(buf) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) + } + + if n < headerSize { + return + } + + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) + + buf = buf[:n] + h := parseHeader(buf) + + rt := r.rt.Load() + peer := rt.Peers[h.SourceIP] + + switch h.StreamID { + case controlStreamID: + r.handleControlPacket(remoteAddr, peer, h, buf) + case dataStreamID: + r.handleDataPacket(rt, peer, h, buf) + default: + r.logf("Unknown stream ID: %d", h.StreamID) + } +} + +func (r *connReader) handleControlPacket( + remoteAddr netip.AddrPort, + peer remotePeer, + h header, + enc []byte, +) { + if peer.ControlCipher == nil { + r.logf("No control cipher for peer: %d", h.SourceIP) + return + } + + if h.DestIP != r.localIP { + r.logf("Incorrect destination IP on control packet: %d", h.DestIP) + return + } + + msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt control packet: %v", err) + return + } + + r.handleControlMsg(h.SourceIP, msg) +} + +func (r *connReader) handleDataPacket( + rt *routingTable, + peer remotePeer, + h header, + enc []byte, +) { + if !peer.Up { + r.logf("Not connected (recv).") + return + } + + data, err := peer.DecryptDataPacket(h, enc, r.decBuf) + if err != nil { + r.logf("Failed to decrypt data packet: %v", err) + return + } + + if h.DestIP == r.localIP { + if _, err := r.iface.Write(data); err != nil { + log.Fatalf("Failed to write to interface: %v", err) + } + return + } + + remote := rt.Peers[h.DestIP] + if !remote.Direct { + r.logf("Unable to relay data to %d.", h.DestIP) + return + } + + r.writeToUDPAddrPort(data, remote.DirectAddr) +} + +func (r *connReader) logf(format string, args ...any) { + log.Printf("[ConnReader] "+format, args...) +} diff --git a/node/messages.go b/peer/controlmessage.go similarity index 68% rename from node/messages.go rename to peer/controlmessage.go index 76d86d4..f327291 100644 --- a/node/messages.go +++ b/peer/controlmessage.go @@ -1,4 +1,4 @@ -package node +package peer import ( "net/netip" @@ -16,25 +16,33 @@ type controlMsg[T any] struct { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { switch buf[0] { + case packetTypeInit: + packet, err := parsePacketInit(buf) + return controlMsg[packetInit]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + case packetTypeSyn: - packet, err := parseSynPacket(buf) - return controlMsg[synPacket]{ + packet, err := parsePacketSyn(buf) + return controlMsg[packetSyn]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err case packetTypeAck: - packet, err := parseAckPacket(buf) - return controlMsg[ackPacket]{ + packet, err := parsePacketAck(buf) + return controlMsg[packetAck]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err case packetTypeProbe: - packet, err := parseProbePacket(buf) - return controlMsg[probePacket]{ + packet, err := parsePacketProbe(buf) + return controlMsg[packetProbe]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, @@ -48,12 +56,9 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error // ---------------------------------------------------------------------------- type peerUpdateMsg struct { - PeerIP byte - Peer *m.Peer + Peer *m.Peer } // ---------------------------------------------------------------------------- type pingTimerMsg struct{} - -// ---------------------------------------------------------------------------- diff --git a/peer/crypto.go b/peer/crypto.go new file mode 100644 index 0000000..a533e6d --- /dev/null +++ b/peer/crypto.go @@ -0,0 +1,30 @@ +package peer + +import ( + "crypto/rand" + "log" + + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/sign" +) + +type cryptoKeys struct { + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + +func generateKeys() cryptoKeys { + pubKey, privKey, err := box.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate encryption keys: %v", err) + } + + pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) + if err != nil { + log.Fatalf("Failed to generate signing keys: %v", err) + } + + return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} +} diff --git a/peer/crypto_test.go b/peer/crypto_test.go new file mode 100644 index 0000000..b3c00f3 --- /dev/null +++ b/peer/crypto_test.go @@ -0,0 +1,191 @@ +package peer + +import ( + "net/netip" + "reflect" + "testing" +) + +func newRoutePairForTesting() (*remotePeer, *remotePeer) { + keys1 := generateKeys() + keys2 := generateKeys() + + r1 := newRemotePeer(1) + r1.PubSignKey = keys1.PubSignKey + r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) + r1.DataCipher = newDataCipher() + + r2 := newRemotePeer(2) + r2.PubSignKey = keys2.PubSignKey + r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) + r2.DataCipher = r1.DataCipher + + return r1, r2 +} + +func TestDecryptControlPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := packetSyn{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := r1.EncryptControlPacket(in, tmp, out) + h := parseHeader(enc) + + iMsg, err := r2.DecryptControlPacket(netip.AddrPort{}, h, enc, tmp) + if err != nil { + t.Fatal(err) + } + + msg, ok := iMsg.(controlMsg[packetSyn]) + if !ok { + t.Fatal(ok) + } + + if !reflect.DeepEqual(msg.Packet, in) { + t.Fatal(msg) + } +} + +/* + func TestDecryptControlPacket_decryptionFailed(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := packetSyn{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) + + for i := range enc { + x := bytes.Clone(enc) + x[i]++ + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(i, err) + } + } + } + + func TestDecryptControlPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := packetSyn{ + TraceID: newTraceID(), + SharedKey: r1.DataCipher.Key(), + Direct: true, + } + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) + + if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { + t.Fatal(err) + } + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } + } + + func TestDecryptControlPacket_invalidPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) + + in := testPacket("hello!") + + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) + + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errUnknownPacketType) { + t.Fatal(err) + } + } + +func TestDecryptDataPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out) + h := parseHeader(enc) + + out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, out) { + t.Fatal(data, out) + } +} + +func TestDecryptDataPacket_incorrectCipher(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) + h := parseHeader(enc) + + r1.DataCipher = newDataCipher() + _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if !errors.Is(err, errDecryptionFailed) { + t.Fatal(err) + } +} + +func TestDecryptDataPacket_duplicate(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + out = make([]byte, bufferSize) + data = make([]byte, 1024) + ) + + rand.Read(data) + + enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) + h := parseHeader(enc) + + _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if err != nil { + t.Fatal(err) + } + + _, err = decryptDataPacket(r1, h, enc, bytes.Clone(out)) + if !errors.Is(err, errDuplicateSeqNum) { + t.Fatal(err) + } +} +*/ diff --git a/peer/data-flow.dot b/peer/data-flow.dot new file mode 100644 index 0000000..45b6f05 --- /dev/null +++ b/peer/data-flow.dot @@ -0,0 +1,14 @@ +digraph d { + ifReader -> connWriter; + connReader -> ifWriter; + connReader -> connWriter; + connReader -> supervisor; + mcReader -> supervisor; + supervisor -> connWriter; + supervisor -> mcWriter; + hubPoller -> supervisor; + + connWriter [shape="box"]; + mcWriter [shape="box"]; + ifWriter [shape="box"]; +} \ No newline at end of file diff --git a/node/dupcheck.go b/peer/dupcheck.go similarity index 92% rename from node/dupcheck.go rename to peer/dupcheck.go index fac7a72..09b5b11 100644 --- a/node/dupcheck.go +++ b/peer/dupcheck.go @@ -1,4 +1,4 @@ -package node +package peer type dupCheck struct { bitSet @@ -38,14 +38,14 @@ func (dc *dupCheck) IsDup(counter uint64) bool { delta := counter - dc.tailCounter // Full clear. - if delta >= bitSetSize { + if delta >= bitSetSize-1 { dc.ClearAll() dc.Set(0) dc.tail = 1 dc.head = 2 dc.tailCounter = counter + 1 - dc.headCounter = dc.tailCounter - bitSetSize + dc.headCounter = dc.tailCounter - bitSetSize + 1 return false } diff --git a/node/dupcheck_test.go b/peer/dupcheck_test.go similarity index 79% rename from node/dupcheck_test.go rename to peer/dupcheck_test.go index 2156b4e..2b50d74 100644 --- a/node/dupcheck_test.go +++ b/peer/dupcheck_test.go @@ -1,4 +1,4 @@ -package node +package peer import ( "testing" @@ -19,6 +19,7 @@ func TestDupCheck(t *testing.T) { } testCases := []TestCase{ + {511, true}, {0, true}, {1, true}, {2, true}, @@ -39,11 +40,13 @@ func TestDupCheck(t *testing.T) { {516, false}, {517, true}, {2512, false}, - {2000, true}, - {2001, false}, + {2512, true}, + {2001, true}, + {2002, false}, + {2002, true}, {4000, false}, - {4000 - 512, true}, // Too old. - {4000 - 511, false}, // Just in the window. + {4000 - 511, true}, // Too old. + {4000 - 510, false}, // Just in the window. } for i, tc := range testCases { diff --git a/peer/errors.go b/peer/errors.go new file mode 100644 index 0000000..b1e07e2 --- /dev/null +++ b/peer/errors.go @@ -0,0 +1,10 @@ +package peer + +import "errors" + +var ( + errDecryptionFailed = errors.New("decryption failed") + errDuplicateSeqNum = errors.New("duplicate sequence number") + errMalformedPacket = errors.New("malformed packet") + errUnknownPacketType = errors.New("unknown packet type") +) diff --git a/node/files.go b/peer/files.go similarity index 92% rename from node/files.go rename to peer/files.go index 18f539b..b0eade5 100644 --- a/node/files.go +++ b/peer/files.go @@ -1,4 +1,4 @@ -package node +package peer import ( "encoding/json" @@ -8,6 +8,14 @@ import ( "vppn/m" ) +type localConfig struct { + m.PeerConfig + PubKey []byte + PrivKey []byte + PubSignKey []byte + PrivSignKey []byte +} + func configDir(netName string) string { d, err := os.UserHomeDir() if err != nil { diff --git a/peer/files_test.go b/peer/files_test.go new file mode 100644 index 0000000..5e32ced --- /dev/null +++ b/peer/files_test.go @@ -0,0 +1,57 @@ +package peer + +import ( + "path/filepath" + "reflect" + "testing" +) + +func TestFilePaths(t *testing.T) { + confDir := configDir("netName") + if filepath.Base(confDir) != "netName" { + t.Fatal(confDir) + } + if filepath.Base(filepath.Dir(confDir)) != ".vppn" { + t.Fatal(confDir) + } + + path := peerConfigPath("netName") + if path != filepath.Join(confDir, "peer-config.json") { + t.Fatal(path) + } + + path = peerStatePath("netName") + if path != filepath.Join(confDir, "peer-state.json") { + t.Fatal(path) + } +} + +func TestStoreLoadJson(t *testing.T) { + type Object struct { + Name string + Age int + Price float64 + } + + tmpDir := t.TempDir() + outPath := filepath.Join(tmpDir, "object.json") + + obj := Object{ + Name: "Jason", + Age: 22, + Price: 123.534, + } + + if err := storeJson(obj, outPath); err != nil { + t.Fatal(err) + } + + obj2 := Object{} + if err := loadJson(outPath, &obj2); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(obj, obj2) { + t.Fatal(obj, obj2) + } +} diff --git a/peer/globals.go b/peer/globals.go new file mode 100644 index 0000000..6dd26eb --- /dev/null +++ b/peer/globals.go @@ -0,0 +1,37 @@ +package peer + +import ( + "net" + "net/netip" + "time" +) + +const ( + version = 1 + + bufferSize = 1536 + + if_mtu = 1200 + if_queue_len = 2048 + + controlCipherOverhead = 16 + dataCipherOverhead = 16 + signOverhead = 64 + + pingInterval = 8 * time.Second + timeoutInterval = 30 * time.Second + broadcastInterval = 16 * time.Second + broadcastErrorTimeoutInterval = 8 * time.Second +) + +var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( + netip.AddrFrom4([4]byte{224, 0, 0, 157}), + 4560)) + +func newBuf() []byte { + return make([]byte, bufferSize) +} + +type marshaller interface { + Marshal([]byte) []byte +} diff --git a/node/header.go b/peer/header.go similarity index 79% rename from node/header.go rename to peer/header.go index 9d0417a..fae3780 100644 --- a/node/header.go +++ b/peer/header.go @@ -1,4 +1,4 @@ -package node +package peer import "unsafe" @@ -20,6 +20,15 @@ type header struct { Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. } +func parseHeader(b []byte) (h header) { + h.Version = b[0] + h.StreamID = b[1] + h.SourceIP = b[2] + h.DestIP = b[3] + h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) + return h +} + func (h *header) Parse(b []byte) { h.Version = b[0] h.StreamID = b[1] diff --git a/node/header_test.go b/peer/header_test.go similarity index 95% rename from node/header_test.go rename to peer/header_test.go index 9dbb061..11e2f8f 100644 --- a/node/header_test.go +++ b/peer/header_test.go @@ -1,4 +1,4 @@ -package node +package peer import "testing" diff --git a/peer/hubpoller.go b/peer/hubpoller.go new file mode 100644 index 0000000..0082989 --- /dev/null +++ b/peer/hubpoller.go @@ -0,0 +1,110 @@ +package peer + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/url" + "time" + "vppn/m" +) + +type hubPoller struct { + client *http.Client + req *http.Request + versions [256]int64 + localIP byte + netName string + handleControlMsg func(fromIP byte, msg any) +} + +func newHubPoller( + localIP byte, + netName, + hubURL, + apiKey string, + handleControlMsg func(byte, any), +) (*hubPoller, error) { + u, err := url.Parse(hubURL) + if err != nil { + return nil, err + } + u.Path = "/peer/fetch-state/" + + client := &http.Client{Timeout: 8 * time.Second} + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + req.SetBasicAuth("", apiKey) + + return &hubPoller{ + client: client, + req: req, + localIP: localIP, + netName: netName, + handleControlMsg: handleControlMsg, + }, nil +} + +func (hp *hubPoller) logf(s string, args ...any) { + log.Printf("[HubPoller] "+s, args...) +} + +func (hp *hubPoller) Run() { + state, err := loadNetworkState(hp.netName) + if err != nil { + hp.logf("Failed to load network state: %v", err) + hp.logf("Polling hub...") + hp.pollHub() + } else { + hp.applyNetworkState(state) + } + + for range time.Tick(64 * time.Second) { + hp.pollHub() + } +} + +func (hp *hubPoller) pollHub() { + var state m.NetworkState + + resp, err := hp.client.Do(hp.req) + if err != nil { + hp.logf("Failed to fetch peer state: %v", err) + return + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + hp.logf("Failed to read body from hub: %v", err) + return + } + + if err := json.Unmarshal(body, &state); err != nil { + hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body) + return + } + + hp.applyNetworkState(state) + + if err := storeNetworkState(hp.netName, state); err != nil { + hp.logf("Failed to store network state: %v", err) + } +} + +func (hp *hubPoller) applyNetworkState(state m.NetworkState) { + for i, peer := range state.Peers { + if i != int(hp.localIP) { + if peer == nil || peer.Version != hp.versions[i] { + hp.handleControlMsg(byte(i), peerUpdateMsg{Peer: state.Peers[i]}) + if peer != nil { + hp.versions[i] = peer.Version + } + } + } + } +} diff --git a/peer/ifreader.go b/peer/ifreader.go new file mode 100644 index 0000000..2419758 --- /dev/null +++ b/peer/ifreader.go @@ -0,0 +1,103 @@ +package peer + +import ( + "io" + "log" + "net/netip" + "sync/atomic" +) + +type ifReader struct { + iface io.Reader + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + rt *atomic.Pointer[routingTable] + buf1 []byte + buf2 []byte +} + +func newIFReader( + iface io.Reader, + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[routingTable], +) *ifReader { + return &ifReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} +} + +func (r *ifReader) Run() { + packet := newBuf() + for { + r.handleNextPacket(packet) + } +} + +func (r *ifReader) handleNextPacket(packet []byte) { + packet = r.readNextPacket(packet) + remoteIP, ok := r.parsePacket(packet) + if !ok { + return + } + + rt := r.rt.Load() + peer := rt.Peers[remoteIP] + if !peer.Up { + r.logf("Peer %d not up.", peer.IP) + return + } + + enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1) + if peer.Direct { + r.writeToUDPAddrPort(enc, peer.DirectAddr) + return + } + + relay, ok := rt.GetRelay() + if !ok { + r.logf("Relay not available for peer %d.", peer.IP) + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2) + r.writeToUDPAddrPort(enc, relay.DirectAddr) +} + +func (r *ifReader) readNextPacket(buf []byte) []byte { + n, err := r.iface.Read(buf[:cap(buf)]) + if err != nil { + log.Fatalf("Failed to read from interface: %v", err) + } + + return buf[:n] +} + +func (r *ifReader) parsePacket(buf []byte) (byte, bool) { + n := len(buf) + if n == 0 { + return 0, false + } + + version := buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + r.logf("Short IPv4 packet: %d", len(buf)) + return 0, false + } + return buf[19], true + + case 6: + if len(buf) < 40 { + r.logf("Short IPv6 packet: %d", len(buf)) + return 0, false + } + return buf[39], true + + default: + r.logf("Invalid IP packet version: %v", version) + return 0, false + } +} + +func (*ifReader) logf(s string, args ...any) { + log.Printf("[IFReader] "+s, args...) +} diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go new file mode 100644 index 0000000..92ec5ac --- /dev/null +++ b/peer/ifreader_test.go @@ -0,0 +1,81 @@ +package peer + +/* +func TestIFReader_IPv4(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + pkt := make([]byte, 1234) + pkt[0] = 4 << 4 + pkt[19] = 2 // IP. + + p1.IFace.UserWrite(pkt) + p1.IFReader.handleNextPacket(newBuf()) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestIFReader_IPv6(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + pkt := make([]byte, 1234) + pkt[0] = 6 << 4 + pkt[39] = 2 // IP. + + p1.IFace.UserWrite(pkt) + p1.IFReader.handleNextPacket(newBuf()) + + packets := p2.Conn.Packets() + if len(packets) != 1 { + t.Fatal(packets) + } +} + +func TestIFReader_parsePacket_emptyPacket(t *testing.T) { + r := NewIFReader(nil, nil) + pkt := make([]byte, 0) + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { + r := NewIFReader(nil, nil) + + for i := byte(1); i < 16; i++ { + if i == 4 || i == 6 { + continue + } + pkt := make([]byte, 1234) + pkt[0] = i << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(i, ip, ok) + } + } +} + +func TestIFReader_parsePacket_shortIPv4(t *testing.T) { + r := NewIFReader(nil, nil) + + pkt := make([]byte, 19) + pkt[0] = 4 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} + +func TestIFReader_parsePacket_shortIPv6(t *testing.T) { + r := NewIFReader(nil, nil) + + pkt := make([]byte, 39) + pkt[0] = 6 << 4 + + if ip, ok := r.parsePacket(pkt); ok { + t.Fatal(ip, ok) + } +} +*/ diff --git a/node/interface.go b/peer/interface.go similarity index 84% rename from node/interface.go rename to peer/interface.go index 4b492b4..0022392 100644 --- a/node/interface.go +++ b/peer/interface.go @@ -1,9 +1,8 @@ -package node +package peer import ( "fmt" "io" - "log" "net" "os" "syscall" @@ -11,45 +10,6 @@ import ( "golang.org/x/sys/unix" ) -// Get next packet, returning packet, ip, and possible error. -func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { - var ( - version byte - ip byte - ) - for { - n, err := iface.Read(buf[:cap(buf)]) - if err != nil { - return nil, ip, err - } - - buf = buf[:n] - version = buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - log.Printf("Short IPv4 packet: %d", len(buf)) - continue - } - ip = buf[19] - - case 6: - if len(buf) < 40 { - log.Printf("Short IPv6 packet: %d", len(buf)) - continue - } - ip = buf[39] - - default: - log.Printf("Invalid IP packet version: %v", version) - continue - } - - return buf, ip, nil - } -} - func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { if len(network) != 4 { return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) diff --git a/peer/main.go b/peer/main.go new file mode 100644 index 0000000..9ab9ab7 --- /dev/null +++ b/peer/main.go @@ -0,0 +1,23 @@ +package peer + +import ( + "flag" + "os" +) + +func Main() { + conf := peerConfig{} + + flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") + flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") + flag.StringVar(&conf.APIKey, "api-key", "", "[REQUIRED] The node's API key.") + flag.Parse() + + if conf.NetName == "" || conf.HubAddress == "" || conf.APIKey == "" { + flag.Usage() + os.Exit(1) + } + + peer := newPeerMain(conf) + peer.Run() +} diff --git a/peer/mcreader.go b/peer/mcreader.go new file mode 100644 index 0000000..7b8af27 --- /dev/null +++ b/peer/mcreader.go @@ -0,0 +1,70 @@ +package peer + +import ( + "log" + "net" + "sync/atomic" + "time" +) + +func runMCReader( + rt *atomic.Pointer[routingTable], + handleControlMsg func(destIP byte, msg any), +) { + for { + runMCReaderInner(rt, handleControlMsg) + time.Sleep(broadcastErrorTimeoutInterval) + } +} + +func runMCReaderInner( + rt *atomic.Pointer[routingTable], + handleControlMsg func(destIP byte, msg any), +) { + var ( + raw = newBuf() + buf = newBuf() + logf = func(s string, args ...any) { + log.Printf("[MCReader] "+s, args...) + } + ) + + conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) + if err != nil { + logf("Failed to bind to multicast address: %v", err) + return + } + + for { + conn.SetReadDeadline(time.Now().Add(32 * time.Second)) + n, remoteAddr, err := conn.ReadFromUDPAddrPort(raw[:bufferSize]) + if err != nil { + logf("Failed to read from UDP port): %v", err) + return + } + + raw = raw[:n] + h, ok := headerFromLocalDiscoveryPacket(raw) + if !ok { + logf("Failed to open discovery packet?") + continue + } + + peer := rt.Load().Peers[h.SourceIP] + if peer.PubSignKey == nil { + logf("No signing key for peer %d.", h.SourceIP) + continue + } + + if !verifyLocalDiscoveryPacket(raw, buf, peer.PubSignKey) { + logf("Invalid signature from peer: %d", h.SourceIP) + continue + } + + msg := controlMsg[packetLocalDiscovery]{ + SrcIP: h.SourceIP, + SrcAddr: remoteAddr, + } + handleControlMsg(h.SourceIP, msg) + } +} diff --git a/peer/mcreader_test.go b/peer/mcreader_test.go new file mode 100644 index 0000000..60feb44 --- /dev/null +++ b/peer/mcreader_test.go @@ -0,0 +1,132 @@ +package peer + +/* +type mcMockConn struct { + packets chan []byte +} + +func newMCMockConn() *mcMockConn { + return &mcMockConn{make(chan []byte, 32)} +} + +func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) { + c.packets <- bytes.Clone(in) + return len(in), nil +} + +func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + buf := <-c.packets + b = b[:len(buf)] + copy(b, buf) + return len(b), netip.AddrPort{}, nil +} + +func TestMCReader(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 1, keys.PrivSignKey) + r := newMCReader(conn, super, peers) + + w.SendLocalDiscovery() + r.handleNextPacket() + + if len(super.Messages) != 1 { + t.Fatal(super.Messages) + } + msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery]) + if !ok || msg.SrcIP != 1 { + t.Fatal(ok, msg) + } +} + +func TestMCReader_noHeader(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + r := newMCReader(conn, super, peers) + conn.WriteToUDP([]byte("0123546789"), nil) + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} + +func TestMCReader_noPeer(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[2] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 2, keys.PrivSignKey) + r := newMCReader(conn, super, peers) + + w.SendLocalDiscovery() + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} + +func TestMCReader_badSignature(t *testing.T) { + keys := generateKeys() + super := &mockControlMsgHandler{} + conn := newMCMockConn() + + peers := [256]*atomic.Pointer[RemotePeer]{} + peer := &RemotePeer{ + IP: 1, + Up: true, + PubSignKey: keys.PubSignKey, + } + peers[1] = &atomic.Pointer[RemotePeer]{} + peers[1].Store(peer) + + w := newMCWriter(conn, 1, keys.PrivSignKey) + w.SendLocalDiscovery() + + // Break signing. + packet := <-conn.packets + packet[0]++ + conn.packets <- packet + + r := newMCReader(conn, super, peers) + + r.handleNextPacket() + + if len(super.Messages) != 0 { + t.Fatal(super.Messages) + } +} +*/ diff --git a/peer/mcwriter.go b/peer/mcwriter.go new file mode 100644 index 0000000..eb53af4 --- /dev/null +++ b/peer/mcwriter.go @@ -0,0 +1,53 @@ +package peer + +import ( + "log" + "net" + "time" + + "golang.org/x/crypto/nacl/sign" +) + +func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { + h := header{ + SourceIP: localIP, + DestIP: 255, + } + buf := make([]byte, headerSize) + h.Marshal(buf) + out := make([]byte, headerSize+signOverhead) + return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) +} + +func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) { + if len(pkt) != headerSize+signOverhead { + return + } + + h.Parse(pkt[signOverhead:]) + ok = true + return +} + +func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { + _, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) + return ok +} + +// ---------------------------------------------------------------------------- + +func runMCWriter(localIP byte, signingKey []byte) { + discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey) + + conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) + if err != nil { + log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err) + } + + for range time.Tick(broadcastInterval) { + _, err := conn.WriteToUDP(discoveryPacket, multicastAddr) + if err != nil { + log.Printf("[MCWriter] Failed to write multicast: %v", err) + } + } +} diff --git a/peer/mcwriter_test.go b/peer/mcwriter_test.go new file mode 100644 index 0000000..74411f4 --- /dev/null +++ b/peer/mcwriter_test.go @@ -0,0 +1,98 @@ +package peer + +/* +// ---------------------------------------------------------------------------- + +// Testing that we can create and verify a local discovery packet. +func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + header, ok := headerFromLocalDiscoveryPacket(created) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 55 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Not valid") + } +} + +// Testing that we don't try to parse short packets. +func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + + _, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) + if ok { + t.Fatal(ok) + } +} + +// Testing that modifying a packet makes it invalid. +func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { + keys := generateKeys() + + created := createLocalDiscoveryPacket(55, keys.PrivSignKey) + buf := make([]byte, 1024) + for i := range created { + modified := bytes.Clone(created) + modified[i]++ + if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { + t.Fatal("Verification should have failed.") + } + } +} + +// ---------------------------------------------------------------------------- + +type testUDPWriter struct { + written [][]byte +} + +func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + w.written = append(w.written, bytes.Clone(b)) + return len(b), nil +} + +func (w *testUDPWriter) Written() [][]byte { + out := w.written + w.written = [][]byte{} + return out +} + +// ---------------------------------------------------------------------------- + +// Testing that the mcWriter sends local discovery packets as expected. +func TestMCWriter_SendLocalDiscovery(t *testing.T) { + keys := generateKeys() + writer := &testUDPWriter{} + + mcw := newMCWriter(writer, 42, keys.PrivSignKey) + mcw.SendLocalDiscovery() + + out := writer.Written() + if len(out) != 1 { + t.Fatal(out) + } + + pkt := out[0] + + header, ok := headerFromLocalDiscoveryPacket(pkt) + if !ok { + t.Fatal(ok) + } + if header.SourceIP != 42 || header.DestIP != 255 { + t.Fatal(header) + } + + if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { + t.Fatal("Verification should succeed.") + } +} +*/ diff --git a/peer/mock-iface_test.go b/peer/mock-iface_test.go new file mode 100644 index 0000000..ffef5d9 --- /dev/null +++ b/peer/mock-iface_test.go @@ -0,0 +1,31 @@ +package peer + +import "bytes" + +type TestIFace struct { + out *bytes.Buffer // Toward the network. + in *bytes.Buffer // From the network +} + +func NewTestIFace() *TestIFace { + return &TestIFace{ + out: &bytes.Buffer{}, + in: &bytes.Buffer{}, + } +} + +func (iface *TestIFace) Write(b []byte) (int, error) { + return iface.in.Write(b) +} + +func (iface *TestIFace) Read(b []byte) (int, error) { + return iface.out.Read(b) +} + +func (iface *TestIFace) UserWrite(b []byte) (int, error) { + return iface.out.Write(b) +} + +func (iface *TestIFace) UserRead(b []byte) (int, error) { + return iface.in.Read(b) +} diff --git a/peer/mock-network_test.go b/peer/mock-network_test.go new file mode 100644 index 0000000..4b5240c --- /dev/null +++ b/peer/mock-network_test.go @@ -0,0 +1,80 @@ +package peer + +import ( + "bytes" + "net" + "net/netip" + "sync" +) + +type TestPacket struct { + Addr netip.AddrPort + Data []byte +} + +type TestNetwork struct { + lock sync.Mutex + packets map[netip.AddrPort]chan TestPacket +} + +func NewTestNetwork() *TestNetwork { + return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}} +} + +func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn { + n.lock.Lock() + defer n.lock.Unlock() + if _, ok := n.packets[localAddr]; !ok { + n.packets[localAddr] = make(chan TestPacket, 1024) + } + return &TestUDPConn{ + addr: localAddr, + n: n, + packets: n.packets[localAddr], + } +} + +func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) { + n.lock.Lock() + defer n.lock.Unlock() + if _, ok := n.packets[to]; !ok { + n.packets[to] = make(chan TestPacket, 1024) + } + n.packets[to] <- TestPacket{ + Addr: from, + Data: bytes.Clone(b), + } +} + +type TestUDPConn struct { + addr netip.AddrPort + n *TestNetwork + packets chan TestPacket +} + +func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + c.n.write(b, c.addr, addr) + return len(b), nil +} + +func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + return c.WriteToUDPAddrPort(b, addr.AddrPort()) +} + +func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + pkt := <-c.packets + b = b[:len(pkt.Data)] + copy(b, pkt.Data) + return len(b), pkt.Addr, nil +} + +func (c *TestUDPConn) Packets() (out []TestPacket) { + for { + select { + case pkt := <-c.packets: + out = append(out, pkt) + default: + return + } + } +} diff --git a/node/packets-util.go b/peer/packets-util.go similarity index 95% rename from node/packets-util.go rename to peer/packets-util.go index b3071ab..c0264e5 100644 --- a/node/packets-util.go +++ b/peer/packets-util.go @@ -1,4 +1,4 @@ -package node +package peer import ( "net/netip" @@ -70,7 +70,7 @@ func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { return w.Uint16(addrPort.Port()) } -func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { +func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter { for _, addrPort := range l { w.AddrPort(addrPort) } @@ -178,7 +178,7 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { return r } -func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { +func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader { for i := range x { r.AddrPort(&x[i]) } diff --git a/node/packets-util_test.go b/peer/packets-util_test.go similarity index 75% rename from node/packets-util_test.go rename to peer/packets-util_test.go index 96eab1a..6e4a98c 100644 --- a/node/packets-util_test.go +++ b/peer/packets-util_test.go @@ -1,4 +1,4 @@ -package node +package peer import ( "net/netip" @@ -6,6 +6,26 @@ import ( "testing" ) +func TestBinWriteRead_invalidAddrPort(t *testing.T) { + addr := netip.AddrPort{} + buf := make([]byte, 1024) + buf = newBinWriter(buf). + AddrPort(addr). + Build() + + var addr2 netip.AddrPort + err := newBinReader(buf). + AddrPort(&addr2). + Error() + if err != nil { + t.Fatal(err) + } + + if addr2.IsValid() { + t.Fatal(addr, addr2) + } +} + func TestBinWriteRead(t *testing.T) { buf := make([]byte, 1024) @@ -35,7 +55,7 @@ func TestBinWriteRead(t *testing.T) { Byte(in.Type). Uint64(in.TraceID). AddrPort(in.DestAddr). - AddrPortArray(in.Addrs). + AddrPort8(in.Addrs). Build() out := Item{} @@ -44,7 +64,7 @@ func TestBinWriteRead(t *testing.T) { Byte(&out.Type). Uint64(&out.TraceID). AddrPort(&out.DestAddr). - AddrPortArray(&out.Addrs). + AddrPort8(&out.Addrs). Error() if err != nil { t.Fatal(err) diff --git a/peer/packets.go b/peer/packets.go new file mode 100644 index 0000000..b673a4c --- /dev/null +++ b/peer/packets.go @@ -0,0 +1,120 @@ +package peer + +import ( + "net/netip" +) + +const ( + packetTypeSyn = 1 + packetTypeInit = 2 + packetTypeAck = 3 + packetTypeProbe = 4 + packetTypeAddrDiscovery = 5 +) + +// ---------------------------------------------------------------------------- + +type packetInit struct { + TraceID uint64 + Direct bool + Version uint64 +} + +func (p packetInit) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeInit). + Uint64(p.TraceID). + Bool(p.Direct). + Uint64(p.Version). + Build() +} + +func parsePacketInit(buf []byte) (p packetInit, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Bool(&p.Direct). + Uint64(&p.Version). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type packetSyn struct { + TraceID uint64 // TraceID to match response w/ request. + SharedKey [32]byte // Our shared key. + Direct bool + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. +} + +func (p packetSyn) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSyn). + Uint64(p.TraceID). + SharedKey(p.SharedKey). + Bool(p.Direct). + AddrPort8(p.PossibleAddrs). + Build() +} + +func parsePacketSyn(buf []byte) (p packetSyn, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + SharedKey(&p.SharedKey). + Bool(&p.Direct). + AddrPort8(&p.PossibleAddrs). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type packetAck struct { + TraceID uint64 + ToAddr netip.AddrPort + PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. +} + +func (p packetAck) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeAck). + Uint64(p.TraceID). + AddrPort(p.ToAddr). + AddrPort8(p.PossibleAddrs). + Build() +} + +func parsePacketAck(buf []byte) (p packetAck, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + AddrPort(&p.ToAddr). + AddrPort8(&p.PossibleAddrs). + Error() + return +} + +// ---------------------------------------------------------------------------- + +// A probeReqPacket is sent from a client to a server to determine if direct +// UDP communication can be used. +type packetProbe struct { + TraceID uint64 +} + +func (p packetProbe) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeProbe). + Uint64(p.TraceID). + Build() +} + +func parsePacketProbe(buf []byte) (p packetProbe, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type packetLocalDiscovery struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go new file mode 100644 index 0000000..c18b40a --- /dev/null +++ b/peer/packets_test.go @@ -0,0 +1,66 @@ +package peer + +import ( + "crypto/rand" + "net/netip" + "reflect" + "testing" +) + +func TestSynPacket(t *testing.T) { + p := packetSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + } + rand.Read(p.SharedKey[:]) + + p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234) + p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) + p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) + + buf := p.Marshal(newBuf()) + p2, err := parsePacketSyn(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} + +func TestAckPacket(t *testing.T) { + p := packetAck{ + TraceID: newTraceID(), + ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), + } + + p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100) + p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) + p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) + + buf := p.Marshal(newBuf()) + p2, err := parsePacketAck(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} + +func TestProbePacket(t *testing.T) { + p := packetProbe{ + TraceID: newTraceID(), + } + + buf := p.Marshal(newBuf()) + p2, err := parsePacketProbe(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(p, p2) { + t.Fatal(p2) + } +} diff --git a/peer/peer.go b/peer/peer.go new file mode 100644 index 0000000..c210af4 --- /dev/null +++ b/peer/peer.go @@ -0,0 +1,177 @@ +package peer + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "net/netip" + "net/url" + "sync" + "sync/atomic" + "vppn/m" +) + +type peerMain struct { + conf localConfig + rt *atomic.Pointer[routingTable] + ifReader *ifReader + connReader *connReader + iface io.Writer + hubPoller *hubPoller + super *supervisor +} + +type peerConfig struct { + NetName string + HubAddress string + APIKey string +} + +func newPeerMain(conf peerConfig) *peerMain { + logf := func(s string, args ...any) { + log.Printf("[Main] "+s, args...) + } + + config, err := loadPeerConfig(conf.NetName) + if err != nil { + logf("Failed to load configuration: %v", err) + logf("Initializing...") + initPeerWithHub(conf) + + config, err = loadPeerConfig(conf.NetName) + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + } + + iface, err := openInterface(config.Network, config.PeerIP, conf.NetName) + if err != nil { + log.Fatalf("Failed to open interface: %v", err) + } + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) + if err != nil { + log.Fatalf("Failed to resolve UDP address: %v", err) + } + + logf("Listening on %v...", myAddr) + conn, err := net.ListenUDP("udp", myAddr) + if err != nil { + log.Fatalf("Failed to open UDP port: %v", err) + } + + conn.SetReadBuffer(1024 * 1024 * 8) + conn.SetWriteBuffer(1024 * 1024 * 8) + + // Wrap write function - this is necessary to avoid starvation. + writeLock := sync.Mutex{} + writeToUDPAddrPort := func(b []byte, addr netip.AddrPort) (n int, err error) { + writeLock.Lock() + n, err = conn.WriteToUDPAddrPort(b, addr) + if err != nil { + logf("Failed to write packet: %v", err) + } + writeLock.Unlock() + return n, err + } + + var localAddr netip.AddrPort + ip, localAddrValid := netip.AddrFromSlice(config.PublicIP) + if localAddrValid { + localAddr = netip.AddrPortFrom(ip, config.Port) + } + + rt := newRoutingTable(config.PeerIP, localAddr) + rtPtr := &atomic.Pointer[routingTable]{} + rtPtr.Store(&rt) + + ifReader := newIFReader(iface, writeToUDPAddrPort, rtPtr) + super := newSupervisor(writeToUDPAddrPort, rtPtr, config.PrivKey) + connReader := newConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) + hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) + if err != nil { + log.Fatalf("Failed to create hub poller: %v", err) + } + + return &peerMain{ + conf: config, + rt: rtPtr, + iface: iface, + ifReader: ifReader, + connReader: connReader, + hubPoller: hubPoller, + super: super, + } +} + +func (p *peerMain) Run() { + + go p.ifReader.Run() + go p.connReader.Run() + p.super.Start() + + if !p.rt.Load().LocalAddr.IsValid() { + go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) + go runMCReader(p.rt, p.super.HandleControlMsg) + } + + go p.hubPoller.Run() + select {} +} + +func initPeerWithHub(conf peerConfig) { + keys := generateKeys() + + initURL, err := url.Parse(conf.HubAddress) + if err != nil { + log.Fatalf("Failed to parse hub URL: %v", err) + } + initURL.Path = "/peer/init/" + + args := m.PeerInitArgs{ + EncPubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(args); err != nil { + log.Fatalf("Failed to encode init args: %v", err) + } + + req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) + if err != nil { + log.Fatalf("Failed to construct request: %v", err) + } + req.SetBasicAuth("", conf.APIKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatalf("Failed to init with hub: %v", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("Failed to read response body: %v", err) + } + + peerConfig := localConfig{} + if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { + log.Fatalf("Failed to parse configuration: %v\n%s", err, data) + } + + peerConfig.PubKey = keys.PubKey + peerConfig.PrivKey = keys.PrivKey + peerConfig.PubSignKey = keys.PubSignKey + peerConfig.PrivSignKey = keys.PrivSignKey + + if err := storePeerConfig(conf.NetName, peerConfig); err != nil { + log.Fatalf("Failed to store configuration: %v", err) + } + + log.Print("Initialization successful.") +} diff --git a/peer/peer_test.go b/peer/peer_test.go new file mode 100644 index 0000000..2c25812 --- /dev/null +++ b/peer/peer_test.go @@ -0,0 +1,114 @@ +package peer + +import ( + "bytes" + "crypto/rand" + mrand "math/rand" + "net/netip" + "sync/atomic" +) + +// A test peer. +type P struct { + cryptoKeys + RT *atomic.Pointer[routingTable] + Conn *TestUDPConn + IFace *TestIFace + ConnReader *connReader + IFReader *ifReader +} + +func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { + p := P{ + cryptoKeys: generateKeys(), + RT: &atomic.Pointer[routingTable]{}, + IFace: NewTestIFace(), + } + + rt := newRoutingTable(ip, addr) + p.RT.Store(&rt) + p.Conn = n.NewUDPConn(addr) + //p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) + + return p +} + +func ConnectPeers(p1, p2 *P) { + rt1 := p1.RT.Load() + rt2 := p2.RT.Load() + + ip1 := rt1.LocalIP + ip2 := rt2.LocalIP + + rt1.Peers[ip2].Up = true + rt1.Peers[ip2].Direct = true + rt1.Peers[ip2].Relay = true + rt1.Peers[ip2].DirectAddr = rt2.LocalAddr + rt1.Peers[ip2].PubSignKey = p2.PubSignKey + rt1.Peers[ip2].ControlCipher = newControlCipher(p1.PrivKey, p2.PubKey) + rt1.Peers[ip2].DataCipher = newDataCipher() + + rt2.Peers[ip1].Up = true + rt2.Peers[ip1].Direct = true + rt2.Peers[ip1].Relay = true + rt2.Peers[ip1].DirectAddr = rt1.LocalAddr + rt2.Peers[ip1].PubSignKey = p1.PubSignKey + rt2.Peers[ip1].ControlCipher = newControlCipher(p2.PrivKey, p1.PubKey) + rt2.Peers[ip1].DataCipher = rt1.Peers[ip2].DataCipher +} + +func NewPeersForTesting() (p1, p2, p3 P) { + n := NewTestNetwork() + + p1 = NewPeerForTesting( + n, + 1, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)) + + p2 = NewPeerForTesting( + n, + 2, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 200)) + + p3 = NewPeerForTesting( + n, + 3, + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 300)) + + ConnectPeers(&p1, &p2) + ConnectPeers(&p1, &p3) + ConnectPeers(&p2, &p3) + + return +} + +func RandPacket() []byte { + n := mrand.Intn(1200) + b := make([]byte, n) + rand.Read(b) + return b +} + +func ModifyPacket(in []byte) []byte { + x := make([]byte, 1) + + for { + rand.Read(x) + out := bytes.Clone(in) + idx := mrand.Intn(len(out)) + if out[idx] != x[0] { + out[idx] = x[0] + return out + } + } +} + +// ---------------------------------------------------------------------------- + +type UnknownControlPacket struct { + TraceID uint64 +} + +func (p UnknownControlPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf).Byte(255).Uint64(p.TraceID).Build() +} diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go new file mode 100644 index 0000000..32dc207 --- /dev/null +++ b/peer/peerstates_test.go @@ -0,0 +1,371 @@ +package peer + +import ( + "testing" + "vppn/m" +) + +// ---------------------------------------------------------------------------- + +func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { + h := NewPeerStateTestHarness() + h.PeerUpdate(nil) + assertType[*stateDisconnected](t, h.State) +} + +func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { + keys := generateKeys() + h := NewPeerStateTestHarness() + + state := h.State.(*stateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + assertType[*stateServer](t, h.State) +} + +/* + +func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) +} + +/* +func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) +} + +/* +func TestStateServer_directSyn(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + assertEqual(t, h.Published.Up, false) + + synMsg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: packetSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + }, + } + + h.State = h.State.OnMsg(synMsg) + + assertEqual(t, len(h.Sent), 1) + ack := assertType[packetAck](t, h.Sent[0].Packet) + assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.IP, 3) + assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) + assertEqual(t, h.Published.Up, true) +} + +func TestStateServer_relayedSyn(t *testing.T) { + h := NewPeerStateTestHarness() + state := h.ConfigServer_Relayed(t) + + state.pubAddrs.Store(addrPort4(4, 5, 6, 7, 1234)) + + assertEqual(t, h.Published.Up, false) + + synMsg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: packetSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: false, + }, + } + synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) + synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) + + h.State = h.State.OnMsg(synMsg) + + assertEqual(t, len(h.Sent), 3) + + ack := assertType[packetAck](t, h.Sent[0].Packet) + assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.IP, 3) + assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) + assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) + assertEqual(t, h.Published.Up, true) + + assertType[packetProbe](t, h.Sent[1].Packet) + assertType[packetProbe](t, h.Sent[2].Packet) + assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) + assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) +} + +func TestStateServer_onProbe(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + assertEqual(t, h.Published.Up, false) + + probeMsg := controlMsg[packetProbe]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: packetProbe{TraceID: newTraceID()}, + } + + h.State = h.State.OnMsg(probeMsg) + + assertEqual(t, len(h.Sent), 1) + + probe := assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) + assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) +} + +func TestStateServer_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + synMsg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 1, 1, 3, 300), + Packet: packetSyn{ + TraceID: newTraceID(), + //SentAt: time.Now().UnixMilli(), + //SharedKeyType: 1, + Direct: true, + }, + } + + h.State = h.State.OnMsg(synMsg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Published.Up, true) + + // Ping shouldn't timeout. + h.OnPingTimer() + assertEqual(t, h.Published.Up, true) + + // Advance the time, then ping. + state := assertType[*stateServer](t, h.State) + state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) + + h.OnPingTimer() + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + h.State = h.State.OnMsg(ack) + assertEqual(t, h.Published.Up, true) +} + +func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID + 1}, + } + h.State = h.State.OnMsg(ack) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnPingTimer(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + assertType[packetSyn](t, h.Sent[0].Packet) + + h.OnPingTimer() + + // On ping timer, another syn should be sent. Additionally, we should remain + // in the same state. + assertEqual(t, len(h.Sent), 2) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientDirect](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + h.State = h.State.OnMsg(ack) + assertEqual(t, h.Published.Up, true) + + state := assertType[*stateClientDirect](t, h.State) + state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) + + h.OnPingTimer() + + // On ping timer, we should timeout, causing the client to reset. Another SYN + // will be sent when re-entering the state, but the connection should be down. + assertEqual(t, len(h.Sent), 2) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientDirect](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientRelayed_OnAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + h.State = h.State.OnMsg(ack) + assertEqual(t, h.Published.Up, true) +} + +func TestStateClientRelayed_OnPingTimer_noAddrs(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + + // If we haven't had an ack yet, we won't have addresses to probe. Therefore + // we'll have just one more syn packet sent. + h.OnPingTimer() + assertEqual(t, len(h.Sent), 2) +} + +func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, h.Published.Up, false) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} + ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) + ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) + + h.State = h.State.OnMsg(ack) + + // Add a local discovery address. Note that the port will be configured port + // and no the one provided here. + h.State = h.State.OnMsg(controlMsg[packetLocalDiscovery]{ + SrcIP: 3, + SrcAddr: addrPort4(2, 2, 2, 3, 300), + }) + + // We should see one SYN and three probe packets. + h.OnPingTimer() + assertEqual(t, len(h.Sent), 5) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[packetProbe](t, h.Sent[2].Packet) + assertType[packetProbe](t, h.Sent[3].Packet) + assertType[packetProbe](t, h.Sent[4].Packet) + + assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) + assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) + assertEqual(t, h.Sent[4].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 456)) +} + +func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + // On entering the state, a SYN should have been sent. + assertEqual(t, len(h.Sent), 1) + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + h.State = h.State.OnMsg(ack) + assertEqual(t, h.Published.Up, true) + + state := assertType[*stateClientRelayed](t, h.State) + state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) + + h.OnPingTimer() + + // On ping timer, we should timeout, causing the client to reset. Another SYN + // will be sent when re-entering the state, but the connection should be down. + assertEqual(t, len(h.Sent), 2) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientRelayed](t, h.State) + assertEqual(t, h.Published.Up, false) +} + +func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.OnProbe(controlMsg[packetProbe]{ + Packet: packetProbe{TraceID: newTraceID()}, + }) + + assertType[*stateClientRelayed](t, h.State) +} + +func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + syn := assertType[packetSyn](t, h.Sent[0].Packet) + + ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} + ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) + ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) + + h.State = h.State.OnMsg(ack) + h.OnPingTimer() + + probe := assertType[packetProbe](t, h.Sent[2].Packet) + h.OnProbe(controlMsg[packetProbe]{Packet: probe}) + + assertType[*stateClientDirect](t, h.State) +} +*/ diff --git a/peer/peersuper.go b/peer/peersuper.go new file mode 100644 index 0000000..2ce6d03 --- /dev/null +++ b/peer/peersuper.go @@ -0,0 +1,148 @@ +package peer + +import ( + "net/netip" + "sync" + "sync/atomic" + "time" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type supervisor struct { + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + staged routingTable + shared *atomic.Pointer[routingTable] + peers [256]*peerSuper + lock sync.Mutex + + buf1 []byte + buf2 []byte +} + +func newSupervisor( + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[routingTable], + privKey []byte, +) *supervisor { + + routes := rt.Load() + + s := &supervisor{ + writeToUDPAddrPort: writeToUDPAddrPort, + staged: *routes, + shared: rt, + buf1: newBuf(), + buf2: newBuf(), + } + + pubAddrs := newPubAddrStore(routes.LocalAddr) + + for i := range s.peers { + state := &peerData{ + publish: s.publish, + sendControlPacket: s.send, + pingTimer: time.NewTicker(timeoutInterval), + localIP: routes.LocalIP, + remoteIP: byte(i), + privKey: privKey, + localAddr: routes.LocalAddr, + pubAddrs: pubAddrs, + staged: routes.Peers[i], + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + s.peers[i] = newPeerSuper(state, state.pingTimer) + } + + return s +} + +func (s *supervisor) Start() { + for i := range s.peers { + go s.peers[i].Run() + } +} + +func (s *supervisor) HandleControlMsg(destIP byte, msg any) { + s.peers[destIP].HandleControlMsg(msg) +} + +func (s *supervisor) send(peer remotePeer, pkt marshaller) { + s.lock.Lock() + defer s.lock.Unlock() + + enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2) + if peer.Direct { + s.writeToUDPAddrPort(enc, peer.DirectAddr) + return + } + + relay, ok := s.staged.GetRelay() + if !ok { + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1) + s.writeToUDPAddrPort(enc, relay.DirectAddr) +} + +func (s *supervisor) publish(rp remotePeer) { + s.lock.Lock() + defer s.lock.Unlock() + + s.staged.Peers[rp.IP] = rp + s.ensureRelay() + copy := s.staged + s.shared.Store(©) +} + +func (s *supervisor) ensureRelay() { + if _, ok := s.staged.GetRelay(); ok { + return + } + + // TODO: Random selection? Something else? + for _, peer := range s.staged.Peers { + if peer.Up && peer.Direct && peer.Relay { + s.staged.RelayIP = peer.IP + return + } + } +} + +// ---------------------------------------------------------------------------- + +type peerSuper struct { + messages chan any + state peerState + pingTimer *time.Ticker +} + +func newPeerSuper(state *peerData, pingTimer *time.Ticker) *peerSuper { + return &peerSuper{ + messages: make(chan any, 8), + state: initPeerState(state, nil), + pingTimer: pingTimer, + } +} + +func (s *peerSuper) HandleControlMsg(msg any) { + select { + case s.messages <- msg: + default: + } +} + +func (s *peerSuper) Run() { + for { + select { + case <-s.pingTimer.C: + s.state = s.state.OnMsg(pingTimerMsg{}) + case raw := <-s.messages: + s.state = s.state.OnMsg(raw) + } + } +} diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go new file mode 100644 index 0000000..7945458 --- /dev/null +++ b/peer/pubaddrs.go @@ -0,0 +1,86 @@ +package peer + +import ( + "net/netip" + "sort" + "sync" + "time" +) + +type pubAddrStore struct { + lock sync.Mutex + localPub bool + localAddr netip.AddrPort + lastSeen map[netip.AddrPort]time.Time + addrList []netip.AddrPort +} + +func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { + return &pubAddrStore{ + localPub: localAddr.IsValid(), + localAddr: localAddr, + lastSeen: map[netip.AddrPort]time.Time{}, + addrList: make([]netip.AddrPort, 0, 32), + } +} + +func (store *pubAddrStore) Store(addr netip.AddrPort) { + if store.localPub { + return + } + + if !addr.IsValid() { + return + } + + if addr.Addr().IsPrivate() { + return + } + + store.lock.Lock() + defer store.lock.Unlock() + + if _, exists := store.lastSeen[addr]; !exists { + store.addrList = append(store.addrList, addr) + } + store.lastSeen[addr] = time.Now() + store.sort() +} + +func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { + store.lock.Lock() + defer store.lock.Unlock() + + store.clean() + + if store.localPub { + addrs[0] = store.localAddr + return + } + + copy(addrs[:], store.addrList) + return +} + +func (store *pubAddrStore) clean() { + if store.localPub { + return + } + + for ip, lastSeen := range store.lastSeen { + if time.Since(lastSeen) > timeoutInterval { + delete(store.lastSeen, ip) + } + } + store.addrList = store.addrList[:0] + for ip := range store.lastSeen { + store.addrList = append(store.addrList, ip) + } + store.sort() +} + +func (store *pubAddrStore) sort() { + sort.Slice(store.addrList, func(i, j int) bool { + return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) + }) +} diff --git a/node/addrdiscovery_test.go b/peer/pubaddrs_test.go similarity index 87% rename from node/addrdiscovery_test.go rename to peer/pubaddrs_test.go index 9851d6a..fa47c22 100644 --- a/node/addrdiscovery_test.go +++ b/peer/pubaddrs_test.go @@ -1,4 +1,4 @@ -package node +package peer import ( "net/netip" @@ -7,7 +7,7 @@ import ( ) func TestPubAddrStore(t *testing.T) { - s := newPubAddrStore() + s := newPubAddrStore(netip.AddrPort{}) l := []netip.AddrPort{ netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), @@ -20,7 +20,7 @@ func TestPubAddrStore(t *testing.T) { time.Sleep(time.Millisecond) } - s.Clean() + s.clean() l2 := s.Get() if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { diff --git a/peer/routingtable.go b/peer/routingtable.go new file mode 100644 index 0000000..3f0aac3 --- /dev/null +++ b/peer/routingtable.go @@ -0,0 +1,138 @@ +package peer + +import ( + "net/netip" + "sync/atomic" + "time" +) + +// TODO: Remove +func newRemotePeer(ip byte) *remotePeer { + counter := uint64(time.Now().Unix()<<30 + 1) + return &remotePeer{ + IP: ip, + counter: &counter, + dupCheck: newDupCheck(0), + } +} + +// ---------------------------------------------------------------------------- + +type remotePeer struct { + localIP byte + IP byte // VPN IP of peer (last byte). + Up bool // True if data can be sent on the peer. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + DirectAddr netip.AddrPort // Remote address if directly connected. + PubSignKey []byte + ControlCipher *controlCipher + DataCipher *dataCipher + + counter *uint64 // For sending to. Atomic access only. + dupCheck *dupCheck // For receiving from. Not safe for concurrent use. +} + +func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(p.counter, 1), + SourceIP: p.localIP, + DestIP: destIP, + } + return p.DataCipher.Encrypt(h, data, out) +} + +// Decrypts and de-dups incoming data packets. +func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { + dec, ok := p.DataCipher.Decrypt(enc, out) + if !ok { + return nil, errDecryptionFailed + } + + if p.dupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + return dec, nil +} + +// Peer must have a ControlCipher. +func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte { + tmp = pkt.Marshal(tmp) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(p.counter, 1), + SourceIP: p.localIP, + DestIP: p.IP, + } + + return p.ControlCipher.Encrypt(h, tmp, out) +} + +// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. +// +// This function also drops packets with duplicate sequence numbers. +func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { + out, ok := p.ControlCipher.Decrypt(enc, tmp) + if !ok { + return nil, errDecryptionFailed + } + + if p.dupCheck.IsDup(h.Counter) { + return nil, errDuplicateSeqNum + } + + msg, err := parseControlMsg(h.SourceIP, fromAddr, out) + if err != nil { + return nil, err + } + + return msg, nil +} + +// ---------------------------------------------------------------------------- + +type routingTable struct { + // The LocalIP is the configured IP address of the local peer on the VPN. + // + // This value is constant. + LocalIP byte + + // The LocalAddr is the configured local public address of the peer on the + // internet. If LocalAddr.IsValid(), then the local peer has a public + // address. + // + // This value is constant. + LocalAddr netip.AddrPort + + // The remote peer configurations. These are updated by + Peers [256]remotePeer + + // The current relay's VPN IP address, or zero if no relay is available. + RelayIP byte +} + +func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable { + rt := routingTable{ + LocalIP: localIP, + LocalAddr: localAddr, + } + + for i := range rt.Peers { + counter := uint64(time.Now().Unix()<<30 + 1) + rt.Peers[i] = remotePeer{ + localIP: localIP, + IP: byte(i), + counter: &counter, + dupCheck: newDupCheck(0), + } + } + + return rt +} + +func (rt *routingTable) GetRelay() (remotePeer, bool) { + relay := rt.Peers[rt.RelayIP] + return relay, relay.Up && relay.Direct +} diff --git a/peer/routingtable_test.go b/peer/routingtable_test.go new file mode 100644 index 0000000..919449b --- /dev/null +++ b/peer/routingtable_test.go @@ -0,0 +1,169 @@ +package peer + +import ( + "bytes" + "reflect" + "testing" +) + +func TestRemotePeer_DecryptDataPacket(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + dec, err := peer1.DecryptDataPacket(h, enc, newBuf()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(orig, dec) { + t.Fatal(dec) + } +} + +func TestRemotePeer_DecryptDataPacket_packetAltered(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + + h := parseHeader(enc) + + for range 2048 { + _, err := peer1.DecryptDataPacket(h, ModifyPacket(enc), newBuf()) + if err == nil { + t.Fatal(enc) + } + } +} + +func TestRemotePeer_DecryptDataPacket_duplicateSequenceNumber(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + orig := RandPacket() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + enc := peer2.EncryptDataPacket(2, orig, newBuf()) + h := parseHeader(enc) + + if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err != nil { + t.Fatal(err) + } + + if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} + +func TestRemotePeer_DecryptControlPacket(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := packetProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()) + if err != nil { + t.Fatal(err) + } + + dec, ok := ctrlMsg.(controlMsg[packetProbe]) + if !ok { + t.Fatal(ctrlMsg) + } + + if dec.SrcIP != 1 || dec.SrcAddr != p1.RT.Load().LocalAddr { + t.Fatal(dec) + } + + if !reflect.DeepEqual(dec.Packet, orig) { + t.Fatal(dec) + } +} + +func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := packetProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + for range 2048 { + ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, ModifyPacket(enc), newBuf()) + if err == nil { + t.Fatal(ctrlMsg) + } + } +} + +func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := packetProbe{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err != nil { + t.Fatal(err) + } + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} + +func TestRemotePeer_DecryptControlPacket_unknownPacketType(t *testing.T) { + p1, p2, _ := NewPeersForTesting() + + peer2 := p1.RT.Load().Peers[2] + peer1 := p2.RT.Load().Peers[1] + + orig := UnknownControlPacket{TraceID: newTraceID()} + + enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) + + h := parseHeader(enc) + if h.DestIP != 2 || h.SourceIP != 1 { + t.Fatal(h) + } + + if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { + t.Fatal(err) + } +} diff --git a/peer/state-client.go b/peer/state-client.go new file mode 100644 index 0000000..7e9d7c9 --- /dev/null +++ b/peer/state-client.go @@ -0,0 +1,162 @@ +package peer + +import ( + "net/netip" + "time" +) + +type sentProbe struct { + SentAt time.Time + Addr netip.AddrPort +} + +type stateClient struct { + *peerData + lastSeen time.Time + syn packetSyn + probes map[uint64]sentProbe +} + +func enterStateClient(data *peerData) peerState { + ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) + + data.staged.Relay = data.peer.Relay && ipValid + data.staged.Direct = ipValid + data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) + data.publish(data.staged) + + state := &stateClient{ + peerData: data, + lastSeen: time.Now(), + syn: packetSyn{ + TraceID: newTraceID(), + SharedKey: data.staged.DataCipher.Key(), + Direct: data.staged.Direct, + PossibleAddrs: data.pubAddrs.Get(), + }, + probes: map[uint64]sentProbe{}, + } + + state.Send(state.staged, state.syn) + + data.pingTimer.Reset(pingInterval) + + state.logf("==> Client") + return state +} + +func (s *stateClient) logf(str string, args ...any) { + s.peerData.logf("CLNT | "+str, args...) +} + +func (s *stateClient) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetAck]: + s.onAck(msg) + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + s.onLocalDiscovery(msg) + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Ignoring message: %v", raw) + } + return s +} + +func (s *stateClient) onAck(msg controlMsg[packetAck]) { + if msg.Packet.TraceID != s.syn.TraceID { + return + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + if s.staged.Direct { + s.pubAddrs.Store(msg.Packet.ToAddr) + return + } + + // Relayed below. + + s.cleanProbes() + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.sendProbeTo(addr) + } +} + +func (s *stateClient) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.logf("Timeout.") + } + return initPeerState(s.peerData, s.peer) + } + + s.Send(s.staged, s.syn) + return s +} + +func (s *stateClient) onProbe(msg controlMsg[packetProbe]) peerState { + if s.staged.Direct { + return s + } + + s.cleanProbes() + + sent, ok := s.probes[msg.Packet.TraceID] + if !ok { + return s + } + + s.staged.Direct = true + s.staged.DirectAddr = sent.Addr + s.publish(s.staged) + + s.syn.TraceID = newTraceID() + s.syn.Direct = true + s.Send(s.staged, s.syn) + s.logf("Successful probe to %v.", sent.Addr) + return s +} + +func (s *stateClient) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { + if s.staged.Direct { + return + } + + // The source port will be the multicast port, so we'll have to + // construct the correct address using the peer's listed port. + addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + s.sendProbeTo(addr) +} + +func (s *stateClient) cleanProbes() { + for key, sent := range s.probes { + if time.Since(sent.SentAt) > pingInterval { + delete(s.probes, key) + } + } +} + +func (s *stateClient) sendProbeTo(addr netip.AddrPort) { + probe := packetProbe{TraceID: newTraceID()} + s.probes[probe.TraceID] = sentProbe{ + SentAt: time.Now(), + Addr: addr, + } + s.logf("Probing %v...", addr) + s.SendTo(probe, addr) +} diff --git a/peer/state-client_test.go b/peer/state-client_test.go new file mode 100644 index 0000000..25441e8 --- /dev/null +++ b/peer/state-client_test.go @@ -0,0 +1,193 @@ +package peer + +import ( + "testing" + "time" +) + +func TestStateClient_peerUpdate(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + h.PeerUpdate(nil) + assertType[*stateDisconnected](t, h.State) +} + +func TestStateClient_initialPackets(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, len(h.Sent), 2) + assertType[packetInit](t, h.Sent[0].Packet) + assertType[packetSyn](t, h.Sent[1].Packet) +} + +func TestStateClient_onAck_incorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + h.Sent = h.Sent[:0] + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: newTraceID()}, + } + h.OnAck(ack) + + // Nothing should have happened. + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onAck_direct_downToUp(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + assertEqual(t, len(h.Sent), 2) + syn := assertType[packetSyn](t, h.Sent[1].Packet) + h.Sent = h.Sent[:0] + + assertEqual(t, h.Published.Up, false) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + + h.OnAck(ack) + + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onAck_relayed_sendsProbes(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + assertEqual(t, len(h.Sent), 2) + syn := assertType[packetSyn](t, h.Sent[1].Packet) + h.Sent = h.Sent[:0] + + assertEqual(t, h.Published.Up, false) + + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, + } + ack.Packet.PossibleAddrs[0] = addrPort4(1, 2, 3, 4, 100) + ack.Packet.PossibleAddrs[1] = addrPort4(2, 3, 4, 5, 200) + + h.OnAck(ack) + + assertEqual(t, len(h.Sent), 2) + assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, h.Sent[0].Peer.DirectAddr, ack.Packet.PossibleAddrs[0]) + assertType[packetProbe](t, h.Sent[1].Packet) + assertEqual(t, h.Sent[1].Peer.DirectAddr, ack.Packet.PossibleAddrs[1]) +} + +func TestStateClient_onPing(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + h.Sent = h.Sent[:0] + h.OnPingTimer() + assertEqual(t, len(h.Sent), 1) + assertType[*stateClient](t, h.State) + assertType[packetSyn](t, h.Sent[0].Packet) +} + +func TestStateClient_onPing_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + h.Sent = h.Sent[:0] + state := assertType[*stateClient](t, h.State) + state.lastSeen = time.Now().Add(-2 * timeoutInterval) + state.staged.Up = true + h.OnPingTimer() + + newState := assertType[*stateClientInit](t, h.State) + assertEqual(t, newState.staged.Up, false) + assertEqual(t, len(h.Sent), 1) + assertType[packetInit](t, h.Sent[0].Packet) +} + +func TestStateClient_onProbe_direct(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + h.Sent = h.Sent[:0] + probe := controlMsg[packetProbe]{ + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + + h.OnProbe(probe) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onProbe_noMatch(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.Sent = h.Sent[:0] + probe := controlMsg[packetProbe]{ + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + + h.OnProbe(probe) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onProbe_directUpgrade(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + state := assertType[*stateClient](t, h.State) + traceID := newTraceID() + state.probes[traceID] = sentProbe{ + SentAt: time.Now(), + Addr: addrPort4(1, 2, 3, 4, 500), + } + + probe := controlMsg[packetProbe]{ + Packet: packetProbe{TraceID: traceID}, + } + + assertEqual(t, h.Published.Direct, false) + h.Sent = h.Sent[:0] + h.OnProbe(probe) + assertEqual(t, h.Published.Direct, true) + + assertEqual(t, len(h.Sent), 1) + assertType[packetSyn](t, h.Sent[0].Packet) +} + +func TestStateClient_onLocalDiscovery_direct(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientDirect(t) + + h.Sent = h.Sent[:0] + pkt := controlMsg[packetLocalDiscovery]{ + Packet: packetLocalDiscovery{}, + } + + h.OnLocalDiscovery(pkt) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateClient_onLocalDiscovery_relayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientRelayed(t) + + h.Sent = h.Sent[:0] + pkt := controlMsg[packetLocalDiscovery]{ + SrcAddr: addrPort4(1, 2, 3, 4, 500), + Packet: packetLocalDiscovery{}, + } + + h.OnLocalDiscovery(pkt) + assertType[*stateClient](t, h.State) + assertEqual(t, len(h.Sent), 1) + assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 2, 3, 4, 456)) +} diff --git a/peer/state-clientinit.go b/peer/state-clientinit.go new file mode 100644 index 0000000..8d963bc --- /dev/null +++ b/peer/state-clientinit.go @@ -0,0 +1,95 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateClientInit struct { + *peerData + startedAt time.Time + traceID uint64 +} + +func enterStateClientInit(data *peerData) peerState { + ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) + + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = ipValid + data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) + data.staged.PubSignKey = data.peer.PubSignKey + data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) + data.staged.DataCipher = newDataCipher() + + data.publish(data.staged) + + state := &stateClientInit{ + peerData: data, + startedAt: time.Now(), + traceID: newTraceID(), + } + state.sendInit() + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientInit") + return state +} + +func (s *stateClientInit) logf(str string, args ...any) { + s.peerData.logf("INIT | "+str, args...) +} + +func (s *stateClientInit) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case controlMsg[packetSyn]: + s.logf("Unexpected SYN") + return s + case controlMsg[packetAck]: + s.logf("Unexpected ACK") + return s + case controlMsg[packetProbe]: + return s + case controlMsg[packetLocalDiscovery]: + return s + case pingTimerMsg: + return s.onPing() + default: + s.logf("Ignoring message: %#v", raw) + return s + } +} + +func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState { + if msg.Packet.TraceID != s.traceID { + s.logf("Invalid trace ID on INIT.") + return s + } + s.logf("Got INIT version %d.", msg.Packet.Version) + return enterStateClient(s.peerData) +} + +func (s *stateClientInit) onPing() peerState { + if time.Since(s.startedAt) > timeoutInterval { + s.logf("Init timeout. Assuming version 1.") + return enterStateClient(s.peerData) + } + + s.sendInit() + return s +} + +func (s *stateClientInit) sendInit() { + s.traceID = newTraceID() + init := packetInit{ + TraceID: s.traceID, + Direct: s.staged.Direct, + Version: version, + } + s.Send(s.staged, init) +} diff --git a/peer/state-clientinit_test.go b/peer/state-clientinit_test.go new file mode 100644 index 0000000..87cdc8b --- /dev/null +++ b/peer/state-clientinit_test.go @@ -0,0 +1,83 @@ +package peer + +import ( + "testing" + "time" +) + +func TestPeerState_ClientInit_initWithIncorrectTraceID(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + // Should have sent the first init packet. + assertEqual(t, len(h.Sent), 1) + init := assertType[packetInit](t, h.Sent[0].Packet) + + init.TraceID = newTraceID() + h.OnInit(controlMsg[packetInit]{Packet: init}) + + assertType[*stateClientInit](t, h.State) +} + +func TestPeerState_ClientInit_init(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + // Should have sent the first init packet. + assertEqual(t, len(h.Sent), 1) + init := assertType[packetInit](t, h.Sent[0].Packet) + h.OnInit(controlMsg[packetInit]{Packet: init}) + + assertType[*stateClient](t, h.State) +} + +func TestPeerState_ClientInit_onPing(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + // Should have sent the first init packet. + assertEqual(t, len(h.Sent), 1) + h.Sent = h.Sent[:0] + + for range 3 { + h.OnPingTimer() + } + + assertEqual(t, len(h.Sent), 3) + + for i := range h.Sent { + assertType[packetInit](t, h.Sent[i].Packet) + } +} + +func TestPeerState_ClientInit_onPingTimeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + state := assertType[*stateClientInit](t, h.State) + state.startedAt = time.Now().Add(-2 * timeoutInterval) + + h.OnPingTimer() + + // Should have moved into the client state due to timeout. + assertType[*stateClient](t, h.State) +} + +func TestPeerState_ClientInit_onPeerUpdate(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + + h.PeerUpdate(nil) + + // Should have moved into the client state due to timeout. + assertType[*stateDisconnected](t, h.State) +} + +func TestPeerState_ClientInit_ignoreMessage(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigClientInit(t) + h.OnProbe(controlMsg[packetProbe]{}) + + // Shouldn't do anything. + assertType[*stateClientInit](t, h.State) +} diff --git a/peer/state-disconnected.go b/peer/state-disconnected.go new file mode 100644 index 0000000..ea503dc --- /dev/null +++ b/peer/state-disconnected.go @@ -0,0 +1,50 @@ +package peer + +import "net/netip" + +type stateDisconnected struct { + *peerData +} + +func enterStateDisconnected(data *peerData) peerState { + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.staged.PubSignKey = nil + data.staged.ControlCipher = nil + data.staged.DataCipher = nil + + data.publish(data.staged) + + data.pingTimer.Stop() + + return &stateDisconnected{data} +} + +func (s *stateDisconnected) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + s.logf("Unexpected INIT") + return s + case controlMsg[packetSyn]: + s.logf("Unexpected SYN") + return s + case controlMsg[packetAck]: + s.logf("Unexpected ACK") + return s + case controlMsg[packetProbe]: + s.logf("Unexpected probe") + return s + case controlMsg[packetLocalDiscovery]: + return s + case pingTimerMsg: + s.logf("Unexpected ping") + return s + default: + s.logf("Ignoring message: %#v", raw) + return s + } +} diff --git a/peer/state-server.go b/peer/state-server.go new file mode 100644 index 0000000..c9c76db --- /dev/null +++ b/peer/state-server.go @@ -0,0 +1,136 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateServer struct { + *peerData + lastSeen time.Time + synTraceID uint64 // Last syn trace ID. +} + +func enterStateServer(data *peerData) peerState { + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.staged.PubSignKey = data.peer.PubSignKey + data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) + data.staged.DataCipher = nil + + data.publish(data.staged) + + data.pingTimer.Reset(pingInterval) + + state := &stateServer{ + peerData: data, + lastSeen: time.Now(), + } + state.logf("==> Server") + return state +} + +func (s *stateServer) logf(str string, args ...any) { + s.peerData.logf("SRVR | "+str, args...) +} + +func (s *stateServer) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case controlMsg[packetSyn]: + return s.onSyn(msg) + case controlMsg[packetAck]: + s.logf("Unexpected ACK") + return s + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + return s + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Unexpected message: %#v", raw) + return s + } +} + +func (s *stateServer) onInit(msg controlMsg[packetInit]) peerState { + s.staged.Up = false + s.staged.Direct = msg.Packet.Direct + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + + init := packetInit{ + TraceID: msg.Packet.TraceID, + Direct: s.staged.Direct, + Version: version, + } + + s.Send(s.staged, init) + + return s +} + +func (s *stateServer) onSyn(msg controlMsg[packetSyn]) peerState { + s.lastSeen = time.Now() + p := msg.Packet + + // Before we can respond to this packet, we need to make sure the + // route is setup properly. + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != s.synTraceID || !s.staged.Up { + s.synTraceID = p.TraceID + s.staged.Up = true + s.staged.Direct = p.Direct + s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + s.logf("Got SYN.") + } + + // Always respond. + s.Send(s.staged, packetAck{ + TraceID: p.TraceID, + ToAddr: s.staged.DirectAddr, + PossibleAddrs: s.pubAddrs.Get(), + }) + + if p.Direct { + return s + } + + // Send probes if not a direct connection. + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.logf("Probing %v...", addr) + s.SendTo(packetProbe{TraceID: newTraceID()}, addr) + } + + return s +} + +func (s *stateServer) onProbe(msg controlMsg[packetProbe]) peerState { + if msg.SrcAddr.IsValid() { + s.logf("Probe response %v...", msg.SrcAddr) + s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + } + return s +} + +func (s *stateServer) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return s +} diff --git a/peer/state-server_test.go b/peer/state-server_test.go new file mode 100644 index 0000000..b367786 --- /dev/null +++ b/peer/state-server_test.go @@ -0,0 +1,164 @@ +package peer + +import ( + "testing" + "time" +) + +func TestStateServer_peerUpdate(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) + h.PeerUpdate(nil) + assertType[*stateDisconnected](t, h.State) +} + +func TestStateServer_onInit(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) + + msg := controlMsg[packetInit]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 1000), + Packet: packetInit{ + TraceID: newTraceID(), + Direct: true, + Version: 4, + }, + } + + h.OnInit(msg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) + resp := assertType[packetInit](t, h.Sent[0].Packet) + assertEqual(t, msg.Packet.TraceID, resp.TraceID) + assertEqual(t, resp.Version, version) +} + +func TestStateServer_onSynDirect(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Public(t) + + msg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 1000), + Packet: packetSyn{ + TraceID: newTraceID(), + Direct: true, + }, + } + + msg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 1000) + msg.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 2000) + + h.OnSyn(msg) + assertEqual(t, len(h.Sent), 1) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) + resp := assertType[packetAck](t, h.Sent[0].Packet) + assertEqual(t, msg.Packet.TraceID, resp.TraceID) +} + +func TestStateServer_onSynRelayed(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetSyn]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 1000), + Packet: packetSyn{ + TraceID: newTraceID(), + }, + } + + msg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 1000) + msg.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 2000) + + h.OnSyn(msg) + assertEqual(t, len(h.Sent), 3) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) + resp := assertType[packetAck](t, h.Sent[0].Packet) + assertEqual(t, msg.Packet.TraceID, resp.TraceID) + + for i, pkt := range h.Sent[1:] { + assertEqual(t, pkt.Peer.DirectAddr, msg.Packet.PossibleAddrs[i]) + assertType[packetProbe](t, pkt.Packet) + } +} + +func TestStateServer_onProbe(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetProbe]{ + SrcIP: 3, + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + h.Sent = h.Sent[:0] + + h.OnProbe(msg) + assertEqual(t, len(h.Sent), 0) +} + +func TestStateServer_onProbe_valid(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetProbe]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 100), + Packet: packetProbe{ + TraceID: newTraceID(), + }, + } + h.Sent = h.Sent[:0] + + h.OnProbe(msg) + assertEqual(t, len(h.Sent), 1) + assertType[packetProbe](t, h.Sent[0].Packet) + assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr) +} + +func TestStateServer_onPing(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + h.Sent = h.Sent[:0] + h.OnPingTimer() + assertEqual(t, len(h.Sent), 0) + assertType[*stateServer](t, h.State) +} + +func TestStateServer_onPing_timeout(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + h.Sent = h.Sent[:0] + state := assertType[*stateServer](t, h.State) + state.staged.Up = true + state.lastSeen = time.Now().Add(-2 * timeoutInterval) + + h.OnPingTimer() + state = assertType[*stateServer](t, h.State) + assertEqual(t, len(h.Sent), 0) + assertEqual(t, state.staged.Up, false) +} + +func TestStateServer_onLocalDiscovery(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + + msg := controlMsg[packetLocalDiscovery]{ + SrcIP: 3, + SrcAddr: addrPort4(1, 2, 3, 4, 100), + } + h.OnLocalDiscovery(msg) + assertType[*stateServer](t, h.State) +} + +func TestStateServer_onAck(t *testing.T) { + h := NewPeerStateTestHarness() + h.ConfigServer_Relayed(t) + msg := controlMsg[packetAck]{} + h.OnAck(msg) + assertType[*stateServer](t, h.State) +} diff --git a/peer/state-util_test.go b/peer/state-util_test.go new file mode 100644 index 0000000..465a8c3 --- /dev/null +++ b/peer/state-util_test.go @@ -0,0 +1,151 @@ +package peer + +import ( + "net/netip" + "testing" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type PeerStateControlMsg struct { + Peer remotePeer + Packet any +} + +type PeerStateTestHarness struct { + data *peerData + State peerState + Published remotePeer + Sent []PeerStateControlMsg +} + +func NewPeerStateTestHarness() *PeerStateTestHarness { + h := &PeerStateTestHarness{} + + keys := generateKeys() + + state := &peerData{ + publish: func(rp remotePeer) { + h.Published = rp + }, + sendControlPacket: func(rp remotePeer, pkt marshaller) { + h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) + }, + pingTimer: time.NewTicker(pingInterval), + localIP: 2, + remoteIP: 3, + privKey: keys.PrivKey, + pubAddrs: newPubAddrStore(netip.AddrPort{}), + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + h.data = state + + h.State = enterStateDisconnected(state) + return h +} + +func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { + h.State = h.State.OnMsg(peerUpdateMsg{p}) +} + +func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnAck(msg controlMsg[packetAck]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { + h.State = h.State.OnMsg(msg) +} + +func (h *PeerStateTestHarness) OnPingTimer() { + h.State = h.State.OnMsg(pingTimerMsg{}) +} + +func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { + keys := generateKeys() + + state := h.State.(*stateDisconnected) + state.localAddr = addrPort4(1, 1, 1, 2, 200) + + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 1, 1, 3}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateServer](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientInit(t *testing.T) *stateClientInit { + // Remote IP should be less than local IP. + h.data.localIP = 4 + keys := generateKeys() + peer := &m.Peer{ + PeerIP: 3, + PublicIP: []byte{1, 2, 3, 4}, + Port: 456, + PubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + h.PeerUpdate(peer) + assertEqual(t, h.Published.Up, false) + return assertType[*stateClientInit](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClient { + h.ConfigClientInit(t) + init := assertType[packetInit](t, h.Sent[0].Packet) + h.OnInit(controlMsg[packetInit]{ + Packet: init, + }) + + return assertType[*stateClient](t, h.State) +} + +func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClient { + h.ConfigClientInit(t) + state := assertType[*stateClientInit](t, h.State) + state.peer.PublicIP = nil // Force relay. + + init := assertType[packetInit](t, h.Sent[0].Packet) + h.OnInit(controlMsg[packetInit]{ + Packet: init, + }) + + return assertType[*stateClient](t, h.State) +} diff --git a/peer/statedata.go b/peer/statedata.go new file mode 100644 index 0000000..5aee302 --- /dev/null +++ b/peer/statedata.go @@ -0,0 +1,109 @@ +package peer + +import ( + "fmt" + "log" + "net/netip" + "strings" + "time" + "vppn/m" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type peerState interface { + OnMsg(raw any) peerState +} + +// ---------------------------------------------------------------------------- + +type peerData struct { + // Output. + publish func(remotePeer) + sendControlPacket func(remotePeer, marshaller) + pingTimer *time.Ticker + + // Immutable data. + localIP byte + remoteIP byte + privKey []byte + localAddr netip.AddrPort // If valid, then local peer is publicly accessible. + + pubAddrs *pubAddrStore + + // The purpose of this state machine is to manage the RemotePeer object, + // publishing it as necessary. + staged remotePeer // Local copy of shared data. See publish(). + + // Mutable peer data. + peer *m.Peer + + // We rate limit per remote endpoint because if we don't we tend to lose + // packets. + limiter *ratelimiter.Limiter +} + +func (s *peerData) logf(format string, args ...any) { + b := strings.Builder{} + name := "" + if s.peer != nil { + name = s.peer.Name + } + b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) + + b.WriteString(fmt.Sprintf("%30s: ", name)) + + if s.staged.Direct { + b.WriteString("DIRECT | ") + } else { + b.WriteString("RELAYED | ") + } + + if s.staged.Up { + b.WriteString("UP | ") + } else { + b.WriteString("DOWN | ") + } + + log.Printf(b.String()+format, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *peerData) SendTo(pkt marshaller, addr netip.AddrPort) { + if !addr.IsValid() { + return + } + route := s.staged + route.Direct = true + route.DirectAddr = addr + s.Send(route, pkt) +} + +func (s *peerData) Send(peer remotePeer, pkt marshaller) { + if err := s.limiter.Limit(); err != nil { + s.logf("Rate limited.") + return + } + s.sendControlPacket(peer, pkt) +} + +func initPeerState(data *peerData, peer *m.Peer) peerState { + data.peer = peer + + if peer == nil { + return enterStateDisconnected(data) + } + + if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + if data.localAddr.IsValid() && data.localIP < data.remoteIP { + return enterStateServer(data) + } + return enterStateClientInit(data) + } + + if data.localAddr.IsValid() || data.localIP < data.remoteIP { + return enterStateServer(data) + } + return enterStateClientInit(data) +} diff --git a/peer/util_test.go b/peer/util_test.go new file mode 100644 index 0000000..af05365 --- /dev/null +++ b/peer/util_test.go @@ -0,0 +1,26 @@ +package peer + +import ( + "net/netip" + "testing" +) + +func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) +} + +func assertType[T any](t *testing.T, obj any) T { + t.Helper() + x, ok := obj.(T) + if !ok { + t.Fatalf("invalid type: %#v", obj) + } + return x +} + +func assertEqual[T comparable](t *testing.T, a, b T) { + t.Helper() + if a != b { + t.Fatal(a, " != ", b) + } +}