From 08f11ce82ba98c5abd5d5319b7077822ae0bac73 Mon Sep 17 00:00:00 2001 From: jdl Date: Sun, 22 Dec 2024 19:17:58 +0100 Subject: [PATCH] WIP: client/server/relay working. --- node/globals.go | 8 +- node/packets.go | 75 +------ node/packets_test.go | 40 +--- node/peer-states.go | 405 -------------------------------------- node/peer-super-states.go | 276 ++++++++++++++++++++++++++ node/peer-super.go | 80 ++++++++ node/peer-supervisor.go | 63 ++---- node/peer.go | 15 +- node/router.go | 7 - 9 files changed, 393 insertions(+), 576 deletions(-) delete mode 100644 node/peer-states.go create mode 100644 node/peer-super-states.go create mode 100644 node/peer-super.go delete mode 100644 node/router.go diff --git a/node/globals.go b/node/globals.go index d646e71..b78c2c9 100644 --- a/node/globals.go +++ b/node/globals.go @@ -1,9 +1,15 @@ package node +import "net/netip" + const ( bufferSize = 1536 - if_mtu = 1200 + if_mtu = 1400 if_queue_len = 2048 controlCipherOverhead = 16 dataCipherOverhead = 16 ) + +var ( + zeroAddrPort = netip.AddrPort{} +) diff --git a/node/packets.go b/node/packets.go index ffda859..0126359 100644 --- a/node/packets.go +++ b/node/packets.go @@ -3,8 +3,6 @@ package node import ( "errors" "net/netip" - "time" - "unsafe" ) var ( @@ -31,10 +29,6 @@ type controlPacket struct { func (p *controlPacket) ParsePayload(buf []byte) (err error) { switch buf[0] { - case packetTypePing: - p.Payload, err = parsePingPacket(buf) - case packetTypePong: - p.Payload, err = parsePongPacket(buf) case packetTypeSyn: p.Payload, err = parseSynPacket(buf) case packetTypeSynAck: @@ -50,10 +44,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { // ---------------------------------------------------------------------------- type synPacket struct { - TraceID uint64 // TraceID to match response w/ request. - SharedKey [32]byte // Our shared key. - ServerAddr netip.AddrPort // The address we're sending to. - RelayIP byte + TraceID uint64 // TraceID to match response w/ request. + SharedKey [32]byte // Our shared key. + RelayIP byte } func (p synPacket) Marshal(buf []byte) []byte { @@ -61,7 +54,6 @@ func (p synPacket) Marshal(buf []byte) []byte { Byte(packetTypeSyn). Uint64(p.TraceID). SharedKey(p.SharedKey). - AddrPort(p.ServerAddr). Byte(p.RelayIP). Build() } @@ -70,7 +62,6 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). SharedKey(&p.SharedKey). - AddrPort(&p.ServerAddr). Byte(&p.RelayIP). Error() return @@ -119,63 +110,3 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) { Error() return } - -// ---------------------------------------------------------------------------- - -// A pingPacket is sent from a node acting as a client, to a node acting -// as a server. It always contains the shared key the client is expecting -// to use for data encryption with the server. -type pingPacket struct { - SentAt int64 // UnixMilli. // Not used. Use traceID. -} - -func newPingPacket() (pp pingPacket) { - pp.SentAt = time.Now().UnixMilli() - return -} - -func (p pingPacket) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypePing). - Int64(p.SentAt). - Build() -} - -func parsePingPacket(buf []byte) (p pingPacket, err error) { - err = newBinReader(buf[1:]). - Int64(&p.SentAt). - Error() - return -} - -// ---------------------------------------------------------------------------- - -// A pongPacket is sent by a node in a server role in response to a pingPacket. -type pongPacket struct { - SentAt int64 // UnixMilli. - RecvdAt int64 // UnixMilli. -} - -func newPongPacket(sentAt int64) (pp pongPacket) { - pp.SentAt = sentAt - pp.RecvdAt = time.Now().UnixMilli() - return -} - -func (p pongPacket) Marshal(buf []byte) []byte { - buf = buf[:17] - buf[0] = packetTypePong - *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt) - *(*uint64)(unsafe.Pointer(&buf[9])) = uint64(p.RecvdAt) - - return buf -} - -func parsePongPacket(buf []byte) (p pongPacket, err error) { - if len(buf) != 17 { - return p, errMalformedPacket - } - p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) - p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9])) - return -} diff --git a/node/packets_test.go b/node/packets_test.go index 6d96ccb..660d30e 100644 --- a/node/packets_test.go +++ b/node/packets_test.go @@ -2,16 +2,13 @@ package node import ( "crypto/rand" - "net/netip" "reflect" "testing" ) func TestPacketSyn(t *testing.T) { in := synPacket{ - TraceID: newTraceID(), - Direct: true, - ServerAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 34), + TraceID: newTraceID(), } rand.Read(in.SharedKey[:]) @@ -54,38 +51,3 @@ func TestPacketAck(t *testing.T) { t.Fatal("\n", in, "\n", out) } } - -func TestPacketPing(t *testing.T) { - sharedKey := make([]byte, 32) - rand.Read(sharedKey) - - buf := make([]byte, bufferSize) - - p := newPingPacket([32]byte(sharedKey)) - out := p.Marshal(buf) - - p2, err := parsePingPacket(out) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(p, p2) { - t.Fatal(p, p2) - } -} - -func TestPacketPong(t *testing.T) { - buf := make([]byte, bufferSize) - - p := newPongPacket(123566) - out := p.Marshal(buf) - - p2, err := parsePongPacket(out) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(p, p2) { - t.Fatal(p, p2) - } -} diff --git a/node/peer-states.go b/node/peer-states.go deleted file mode 100644 index 39990bd..0000000 --- a/node/peer-states.go +++ /dev/null @@ -1,405 +0,0 @@ -package node - -import ( - "fmt" - "log" - "math/rand" - "net/netip" - "sync/atomic" - "time" - "vppn/m" -) - -type peerState interface { - Name() string - OnSyn(netip.AddrPort, synPacket) peerState - OnSynAck(netip.AddrPort, synAckPacket) peerState - OnAck(netip.AddrPort, ackPacket) peerState - - OnPingTimer() peerState - OnTimeoutTimer() peerState - - // When the peer is updated, we reset. Handled by base state. - OnPeerUpdate(*m.Peer) peerState -} - -// ---------------------------------------------------------------------------- - -type stateBase struct { - // The purpose of this state machine is to manage this published data. - published *atomic.Pointer[peerRoutingData] - staged peerRoutingData // Local copy of shared data. See publish(). - - // The other remote peers. - peers *remotePeers - - // Immutable data. - localIP byte - localPub bool - remoteIP byte - privKey []byte - conn *connWriter - - // For sending to peer. - counter *uint64 - - // Mutable peer data. - peer *m.Peer - remotePub bool - - // Timers - pingTimer *time.Timer - timeoutTimer *time.Timer - - buf []byte - encBuf []byte -} - -func (sb *stateBase) Name() string { return "idle" } - -func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { - // Both nil: no change. - if peer == nil && s.peer == nil { - return nil - } - - // No change. - if peer != nil && s.peer != nil && s.peer.Version == peer.Version { - return nil - } - - return s.selectStateFromPeer(peer) -} - -func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { - s.peer = peer - s.staged = peerRoutingData{} - defer s.publish() - - if peer == nil { - return newStateNoPeer(s) - } - - s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) - s.staged.dataCipher = newDataCipher() - - s.resetPingTimer() - s.resetTimeoutTimer() - - ip, isValid := netip.AddrFromSlice(peer.PublicIP) - if isValid { - s.remotePub = true - s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port) - s.staged.relay = peer.Mediator - } - - if s.remotePub == s.localPub { - if s.localIP < s.remoteIP { - return newStateServer2(s) - } - return newStateDialLocal(s) - } - - if s.remotePub { - return newStateDialLocal(s) - } - return newStateServer2(s) -} - -func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState { return nil } -func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil } -func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState { return nil } - -func (s *stateBase) OnPingTimer() peerState { return nil } -func (s *stateBase) OnTimeoutTimer() peerState { return nil } - -// Helpers. - -func (s *stateBase) resetPingTimer() { s.pingTimer.Reset(pingInterval) } -func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) } -func (s *stateBase) stopPingTimer() { s.pingTimer.Stop() } -func (s *stateBase) stopTimeoutTimer() { s.timeoutTimer.Stop() } - -func (s *stateBase) logf(msg string, args ...any) { - log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) -} - -func (s *stateBase) publish() { - data := s.staged - s.published.Store(&data) -} - -func (s *stateBase) selectRelay() byte { - possible := make([]byte, 0, 8) - for i, peer := range s.peers { - if peer.CanRelay() { - possible = append(possible, byte(i)) - } - } - - if len(possible) == 0 { - return 0 - } - return possible[rand.Intn(len(possible))] -} - -func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { - buf := pkt.Marshal(s.buf) - h := header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(s.counter, 1), - SourceIP: s.localIP, - DestIP: s.remoteIP, - } - - buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf) - if s.staged.relayIP != 0 { - s.peers[s.staged.relayIP].RelayFor(s.remoteIP, buf) - } else { - s.conn.WriteTo(buf, s.staged.remoteAddr) - } -} - -// ---------------------------------------------------------------------------- - -type stateNoPeer struct{ *stateBase } - -func newStateNoPeer(b *stateBase) *stateNoPeer { - s := &stateNoPeer{b} - s.pingTimer.Stop() - s.timeoutTimer.Stop() - s.publish() - return s -} - -// ---------------------------------------------------------------------------- - -type stateServer2 struct { - *stateBase - syn synPacket - publishedTraceID uint64 -} - -// TODO: Server should send SynAck packets on a loop. -func newStateServer2(b *stateBase) peerState { - s := &stateServer2{stateBase: b} - s.resetTimeoutTimer() - return s -} - -func (s *stateServer2) Name() string { return "server" } - -func (s *stateServer2) OnSyn(remoteAddr netip.AddrPort, p synPacket) peerState { - s.syn = p - s.sendControlPacket(newSynAckPacket(p.TraceID)) - return nil -} - -func (s *stateServer2) OnAck(remoteAddr netip.AddrPort, p ackPacket) peerState { - if p.TraceID != s.syn.TraceID { - return nil - } - - s.resetTimeoutTimer() - - if p.TraceID == s.publishedTraceID { - return nil - } - - // Pubish staged - s.staged.remoteAddr = remoteAddr - s.staged.dataCipher = newDataCipherFromKey(s.syn.SharedKey) - s.staged.relayIP = s.syn.RelayIP - s.staged.up = true - s.publish() - - s.publishedTraceID = p.TraceID - return nil -} - -func (s *stateServer) OnTimeoutTimer() peerState { - // TODO: We're down. - return nil -} - -// ---------------------------------------------------------------------------- - -type stateDialLocal struct { - *stateBase - syn synPacket -} - -func newStateDialLocal(b *stateBase) peerState { - // s := stateDialLocal{stateBase: b} - // TODO: check for peer local address. - return newStateDialDirect(b) -} - -func (s *stateDialLocal) Name() string { return "dial-local" } - -// ---------------------------------------------------------------------------- - -type stateDialDirect struct { - *stateBase - syn synPacket -} - -func newStateDialDirect(b *stateBase) peerState { - // If we don't have an address, dial via relay. - if b.staged.remoteAddr == zeroAddrPort { - return newStateNoPeer(b) - } - - s := &stateDialDirect{stateBase: b} - s.syn = synPacket{ - TraceID: newTraceID(), - SharedKey: s.staged.dataCipher.Key(), - ServerAddr: b.staged.remoteAddr, - } - - s.sendControlPacket(s.syn) - s.resetTimeoutTimer() - - return s -} - -func (s *stateDialDirect) Name() string { return "dial-direct" } - -func (s *stateDialDirect) OnSynAck(remoteAddr netip.AddrPort, p synAckPacket) peerState { - if p.TraceID != s.syn.TraceID { - // Hmm... - return nil - } - - s.sendControlPacket(ackPacket{TraceID: s.syn.TraceID}) - s.logf("GOT SYN-ACK! TODO!") - // client should continue to respond to synAck packets from server. - // return newStateClientConnected(s.stateBase, s.syn.TraceID) ... - return nil -} - -func (s *stateDialDirect) OnTimeoutTimer() peerState { - s.logf("Timeout when dialing") - return newStateDialLocal(s.stateBase) -} - -// ---------------------------------------------------------------------------- - -type stateClient struct { - sharedKey [32]byte - *stateBase -} - -func newStateClient(b *stateBase) peerState { - s := &stateClient{stateBase: b} - s.publish() - - s.staged.dataCipher = newDataCipher() - s.sharedKey = s.staged.dataCipher.Key() - - s.sendControlPacket(newPingPacket()) - s.resetPingTimer() - s.resetTimeoutTimer() - return s -} - -func (s *stateClient) Name() string { return "client" } - -// ---------------------------------------------------------------------------- - -type stateServer struct { - *stateBase -} - -func newStateServer(b *stateBase) peerState { - s := &stateServer{b} - s.publish() - s.stopPingTimer() - s.stopTimeoutTimer() - return s -} - -func (s *stateServer) Name() string { return "server" } - -// ---------------------------------------------------------------------------- - -type stateSelectRelay struct { - *stateBase -} - -func newStateSelectRelay(b *stateBase) peerState { - s := &stateSelectRelay{stateBase: b} - s.staged.dataCipher = nil - s.staged.up = false - s.publish() - - if relay := s.selectRelay(); relay != 0 { - s.staged.up = false - s.staged.relayIP = relay - return s.selectRole() - } - - s.resetPingTimer() - s.stopTimeoutTimer() - return s -} - -func (s *stateSelectRelay) selectRole() peerState { - if s.localIP < s.remoteIP { - return newStateServerRelayed(s.stateBase) - } - return newStateClientRelayed(s.stateBase) -} - -func (s *stateSelectRelay) Name() string { return "select-relay" } - -func (s *stateSelectRelay) OnPingTimer() peerState { - if relay := s.selectRelay(); relay != 0 { - s.logf("Got relay IP: %d", relay) - s.staged.relayIP = relay - return s.selectRole() - } - s.resetPingTimer() - return nil -} - -// ---------------------------------------------------------------------------- - -type stateClientRelayed struct { - sharedKey [32]byte - *stateBase -} - -func newStateClientRelayed(b *stateBase) peerState { - s := &stateClientRelayed{stateBase: b} - - s.staged.dataCipher = newDataCipher() - s.sharedKey = s.staged.dataCipher.Key() - s.publish() - - s.sendControlPacket(newPingPacket()) - s.resetPingTimer() - s.resetTimeoutTimer() - return s -} - -func (s *stateClientRelayed) Name() string { return "client-relayed" } - -// ---------------------------------------------------------------------------- - -type stateServerRelayed struct { - *stateBase -} - -func newStateServerRelayed(b *stateBase) peerState { - s := &stateServerRelayed{b} - s.stopPingTimer() - s.resetTimeoutTimer() - return s -} - -func (s *stateServerRelayed) Name() string { return "server-relayed" } - -func (s *stateServerRelayed) OnTimeoutTimer() peerState { - return newStateSelectRelay(s.stateBase) -} diff --git a/node/peer-super-states.go b/node/peer-super-states.go new file mode 100644 index 0000000..6e615ae --- /dev/null +++ b/node/peer-super-states.go @@ -0,0 +1,276 @@ +package node + +import ( + "math/rand" + "net/netip" + "time" + "vppn/m" +) + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) noPeer() stateFunc { + return s.peerUpdate(<-s.peerUpdates) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc { + return func() stateFunc { return s._peerUpdate(peer) } +} + +func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc { + defer s.publish() + + s.peer = peer + s.staged = peerRoutingData{} + + if s.peer == nil { + return s.noPeer + } + + s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) + s.staged.dataCipher = newDataCipher() + + if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + s.remotePub = true + s.staged.relay = peer.Mediator + s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port) + } + + if s.remotePub == s.localPub { + if s.localIP < s.remoteIP { + return s.serverAccept + } + return s.clientInit + } + + if s.remotePub { + return s.clientInit + } + return s.serverAccept +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) serverAccept() stateFunc { + s.logf("STATE: server-accept") + s.staged.up = false + s.staged.dataCipher = nil + s.staged.remoteAddr = zeroAddrPort + s.staged.relayIP = 0 + s.publish() + + var syn synPacket + + for { + select { + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case pkt := <-s.controlPackets: + switch p := pkt.Payload.(type) { + + case synPacket: + syn = p + s.staged.remoteAddr = pkt.RemoteAddr + s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey) + s.staged.relayIP = syn.RelayIP + s.publish() + s.sendControlPacket(newSynAckPacket(p.TraceID)) + + case ackPacket: + if p.TraceID != syn.TraceID { + continue + } + + // Publish. + return s.serverConnected(syn.TraceID) + } + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) serverConnected(traceID uint64) stateFunc { + s.logf("STATE: server-connected") + s.staged.up = true + s.publish() + return func() stateFunc { + return s._serverConnected(traceID) + } +} + +func (s *peerSuper) _serverConnected(traceID uint64) stateFunc { + + timeoutTimer := time.NewTimer(timeoutInterval) + defer timeoutTimer.Stop() + + for { + select { + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case pkt := <-s.controlPackets: + switch p := pkt.Payload.(type) { + + case ackPacket: + if p.TraceID != traceID { + return s.serverAccept + } + + s.sendControlPacket(ackPacket{TraceID: traceID}) + timeoutTimer.Reset(timeoutInterval) + } + + case <-timeoutTimer.C: + s.logf("server timeout") + return s.serverAccept + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) clientInit() stateFunc { + s.logf("STATE: client-init") + if !s.remotePub { + // TODO: Check local discovery for IP. + // TODO: Attempt UDP hole punch. + // TODO: client-relayed + return s.clientSelectRelay + } + + return s.clientDial +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) clientSelectRelay() stateFunc { + s.logf("STATE: client-select-relay") + + timer := time.NewTimer(0) + defer timer.Stop() + + for { + select { + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case <-timer.C: + ip := s.selectRelayIP() + if ip != 0 { + s.logf("Got relay: %d", ip) + s.staged.relayIP = ip + s.publish() + return s.clientDial + } + + s.logf("No relay available.") + timer.Reset(pingInterval) + } + } +} + +func (s *peerSuper) selectRelayIP() byte { + possible := make([]byte, 0, 8) + for i, peer := range s.peers { + if peer.CanRelay() { + possible = append(possible, byte(i)) + } + } + + if len(possible) == 0 { + return 0 + } + return possible[rand.Intn(len(possible))] +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) clientDial() stateFunc { + s.logf("STATE: client-dial") + + var ( + syn = synPacket{ + TraceID: newTraceID(), + SharedKey: s.staged.dataCipher.Key(), + RelayIP: s.staged.relayIP, + } + + timeout = time.NewTimer(dialTimeout) + ) + + defer timeout.Stop() + + s.sendControlPacket(syn) + + for { + select { + + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case pkt := <-s.controlPackets: + switch p := pkt.Payload.(type) { + case synAckPacket: + if p.TraceID != syn.TraceID { + continue // Hmm... + } + s.sendControlPacket(ackPacket{TraceID: syn.TraceID}) + return s.clientConnected(syn.TraceID) + } + + case <-timeout.C: + return s.clientInit + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) clientConnected(traceID uint64) stateFunc { + s.logf("STATE: client-connected") + s.staged.up = true + s.publish() + + return func() stateFunc { + return s._clientConnected(traceID) + } +} + +func (s *peerSuper) _clientConnected(traceID uint64) stateFunc { + + pingTimer := time.NewTimer(pingInterval) + timeoutTimer := time.NewTimer(timeoutInterval) + + defer pingTimer.Stop() + defer timeoutTimer.Stop() + + for { + select { + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case pkt := <-s.controlPackets: + switch p := pkt.Payload.(type) { + + case ackPacket: + if p.TraceID != traceID { + return s.clientInit + } + timeoutTimer.Reset(timeoutInterval) + } + + case <-pingTimer.C: + s.sendControlPacket(ackPacket{TraceID: traceID}) + pingTimer.Reset(pingInterval) + + case <-timeoutTimer.C: + s.logf("client timeout") + return s.clientInit + + } + } +} diff --git a/node/peer-super.go b/node/peer-super.go new file mode 100644 index 0000000..df1907f --- /dev/null +++ b/node/peer-super.go @@ -0,0 +1,80 @@ +package node + +import ( + "fmt" + "log" + "sync/atomic" + "vppn/m" +) + +type peerSuper struct { + // The purpose of this state machine is to manage this published data. + published *atomic.Pointer[peerRoutingData] + staged peerRoutingData // Local copy of shared data. See publish(). + + // The other remote peers. + peers *remotePeers + + // Immutable data. + localIP byte + localPub bool + remoteIP byte + privKey []byte + conn *connWriter + + // For sending to peer. + counter *uint64 + + // Mutable peer data. + peer *m.Peer + remotePub bool + + // Incoming events. + peerUpdates chan *m.Peer + controlPackets chan controlPacket + + // Buffers + buf []byte + encBuf []byte +} + +type stateFunc func() stateFunc + +func (s *peerSuper) Run() { + state := s.noPeer + for { + state = state() + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) logf(msg string, args ...any) { + log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) publish() { + data := s.staged + s.published.Store(&data) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { + buf := pkt.Marshal(s.buf) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(s.counter, 1), + SourceIP: s.localIP, + DestIP: s.remoteIP, + } + + buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf) + if s.staged.relayIP != 0 { + s.peers[s.staged.relayIP].RelayTo(s.remoteIP, buf) + } else { + s.conn.WriteTo(buf, s.staged.remoteAddr) + } +} diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index 3f3e0a0..50401b8 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -15,55 +15,20 @@ const ( func (rp *remotePeer) supervise(conf m.PeerConfig) { defer panicHandler() - base := &stateBase{ - published: rp.published, - peers: rp.peers, - localIP: rp.localIP, - remoteIP: rp.remoteIP, - privKey: conf.EncPrivKey, - localPub: addrIsValid(conf.PublicIP), - conn: rp.conn, - counter: &rp.counter, - pingTimer: time.NewTimer(time.Second), - timeoutTimer: time.NewTimer(time.Second), - buf: make([]byte, bufferSize), - encBuf: make([]byte, bufferSize), + super := &peerSuper{ + published: rp.published, + peers: rp.peers, + localIP: rp.localIP, + localPub: addrIsValid(conf.PublicIP), + remoteIP: rp.remoteIP, + privKey: conf.EncPrivKey, + conn: rp.conn, + counter: &rp.counter, + peerUpdates: rp.peerUpdates, + controlPackets: rp.controlPackets, + buf: make([]byte, bufferSize), + encBuf: make([]byte, bufferSize), } - var ( - curState peerState = newStateNoPeer(base) - nextState peerState - ) - - for { - nextState = nil - - select { - case peer := <-rp.peerUpdates: - nextState = curState.OnPeerUpdate(peer) - - case pkt := <-rp.controlPackets: - switch p := pkt.Payload.(type) { - case synPacket: - nextState = curState.OnSyn(pkt.RemoteAddr, p) - case synAckPacket: - nextState = curState.OnSynAck(pkt.RemoteAddr, p) - case ackPacket: - nextState = curState.OnAck(pkt.RemoteAddr, p) - default: - // Unknown packet type. - } - - case <-base.pingTimer.C: - nextState = curState.OnPingTimer() - - case <-base.timeoutTimer.C: - nextState = curState.OnTimeoutTimer() - } - - if nextState != nil { - rp.logf("%s --> %s", curState.Name(), nextState.Name()) - curState = nextState - } - } + go super.Run() } diff --git a/node/peer.go b/node/peer.go index bae2c9c..1fc3226 100644 --- a/node/peer.go +++ b/node/peer.go @@ -41,6 +41,10 @@ type remotePeer struct { // Used for sending control and data packets. Atomic access only. counter uint64 + // Only accessed in HandlePeerUpdate. Used to determine if we should send + // the peer update to the peerSuper. + peerVersion int64 + // For communicating with the supervisor thread. peerUpdates chan *m.Peer controlPackets chan controlPacket @@ -75,7 +79,12 @@ func (rp *remotePeer) logf(msg string, args ...any) { } func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) { - rp.peerUpdates <- peer + if peer == nil { + rp.peerUpdates <- peer + } else if peer.Version != rp.peerVersion { + rp.peerVersion = peer.Version + rp.peerUpdates <- peer + } } // ---------------------------------------------------------------------------- @@ -209,7 +218,7 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) { enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf) if routingData.relayIP != 0 { - rp.peers[routingData.relayIP].RelayFor(rp.remoteIP, enc) + rp.peers[routingData.relayIP].RelayTo(rp.remoteIP, enc) } else { rp.SendData(data) } @@ -224,7 +233,7 @@ func (rp *remotePeer) CanRelay() bool { // ---------------------------------------------------------------------------- -func (rp *remotePeer) RelayFor(destIP byte, data []byte) { +func (rp *remotePeer) RelayTo(destIP byte, data []byte) { rp.encryptAndSend(relayStreamID, destIP, data) } diff --git a/node/router.go b/node/router.go deleted file mode 100644 index 116b4d0..0000000 --- a/node/router.go +++ /dev/null @@ -1,7 +0,0 @@ -package node - -import ( - "net/netip" -) - -var zeroAddrPort = netip.AddrPort{}