From 9fd6d90f9cdec592a2187a33ad06a06e6ff57ad5 Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 25 Feb 2025 02:43:29 +0100 Subject: [PATCH] wip: cleanup --- peer/connreader.go | 9 ++- peer/controlmessage.go | 8 ++ peer/globals.go | 2 + peer/logging.go | 13 ++++ peer/packets.go | 21 +++++- peer/peerstates.go | 94 +++++++++++++++++++++++- peer/peerstates_test.go | 5 ++ peer/peersuper.go | 2 +- peer/pubaddrs.go | 9 +-- peer/state-clientdirect.go | 85 +++++++++++++++++++++ peer/state-clientinit.go | 93 +++++++++++++++++++++++ peer/state-clientrelayed.go | 142 ++++++++++++++++++++++++++++++++++++ peer/state-disconnected.go | 33 +++++++++ peer/state-server.go | 127 ++++++++++++++++++++++++++++++++ peer/statedata.go | 28 +++++++ 15 files changed, 657 insertions(+), 14 deletions(-) create mode 100644 peer/logging.go create mode 100644 peer/state-clientdirect.go create mode 100644 peer/state-clientinit.go create mode 100644 peer/state-clientrelayed.go create mode 100644 peer/state-disconnected.go create mode 100644 peer/state-server.go create mode 100644 peer/statedata.go diff --git a/peer/connreader.go b/peer/connreader.go index b78e58f..0727ced 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -84,6 +84,7 @@ func (r *connReader) handleControlPacket( enc []byte, ) { if peer.ControlCipher == nil { + log.Printf("No control cipher for peer: %v", h) return } @@ -125,13 +126,13 @@ func (r *connReader) handleDataPacket( return } - relay, ok := rt.GetRelay() - if !ok { - r.logf("Relay not available.") + remote := rt.Peers[h.DestIP] + if !remote.Direct { + r.logf("Unable to relay data to %d.", h.DestIP) return } - r.writeToUDPAddrPort(data, relay.DirectAddr) + r.writeToUDPAddrPort(data, remote.DirectAddr) } func (r *connReader) logf(format string, args ...any) { diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 3a18bc8..33d4e9c 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -41,6 +41,14 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error Packet: packet, }, err + case packetTypeInit: + packet, err := parsePacketInit(buf) + return controlMsg[packetInit]{ + SrcIP: srcIP, + SrcAddr: srcAddr, + Packet: packet, + }, err + default: return nil, errUnknownPacketType } diff --git a/peer/globals.go b/peer/globals.go index f967c8a..cd0e1f6 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -7,6 +7,8 @@ import ( ) const ( + version = 1 + bufferSize = 1536 if_mtu = 1200 diff --git a/peer/logging.go b/peer/logging.go new file mode 100644 index 0000000..4906b04 --- /dev/null +++ b/peer/logging.go @@ -0,0 +1,13 @@ +package peer + +import "log" + +func logPacket(p []byte, notes string) { + h := parseHeader(p) + log.Printf(`Sending: Data: %v | From: %d | To: %d | %s +`, + h.StreamID == dataStreamID, + h.SourceIP, + h.DestIP, + notes) +} diff --git a/peer/packets.go b/peer/packets.go index 5be89b0..b673a4c 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -6,19 +6,38 @@ import ( const ( packetTypeSyn = 1 + packetTypeInit = 2 packetTypeAck = 3 packetTypeProbe = 4 packetTypeAddrDiscovery = 5 - packetTypeInit = 6 ) // ---------------------------------------------------------------------------- type packetInit struct { TraceID uint64 + Direct bool Version uint64 } +func (p packetInit) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeInit). + Uint64(p.TraceID). + Bool(p.Direct). + Uint64(p.Version). + Build() +} + +func parsePacketInit(buf []byte) (p packetInit, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Bool(&p.Direct). + Uint64(&p.Version). + Error() + return +} + // ---------------------------------------------------------------------------- type packetSyn struct { diff --git a/peer/peerstates.go b/peer/peerstates.go index b5abfb7..6c52f55 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -43,6 +43,7 @@ type pState struct { limiter *ratelimiter.Limiter } +/* func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { defer func() { // Don't defer directly otherwise s.staged will be evaluated immediately @@ -78,7 +79,7 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { return enterStateServer(s) } - return enterStateClientDirect(s) + return enterStateClientinit(s) } if s.localAddr.IsValid() { @@ -90,8 +91,9 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { return enterStateServer(s) } - return enterStateClientRelayed(s) + return enterStateClientinit(s) } +*/ func (s *pState) logf(format string, args ...any) { b := strings.Builder{} @@ -140,6 +142,7 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { // ---------------------------------------------------------------------------- +/* type stateDisconnected struct{ *pState } func enterStateDisconnected(s *pState) peerState { @@ -181,6 +184,8 @@ func (s *stateServer) OnMsg(rawMsg any) peerState { switch msg := rawMsg.(type) { case peerUpdateMsg: return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetInit]: + return s.OnInit(msg) case controlMsg[packetSyn]: return s.OnSyn(msg) case controlMsg[packetProbe]: @@ -193,6 +198,21 @@ func (s *stateServer) OnMsg(rawMsg any) peerState { } } +func (s *stateServer) OnInit(msg controlMsg[packetInit]) peerState { + s.logf("Responding to INIT.") + route := s.staged + route.Direct = msg.Packet.Direct + route.DirectAddr = msg.SrcAddr + + s.Send(route, packetInit{ + TraceID: msg.Packet.TraceID, + Direct: route.Direct, + Version: version, + }) + + return s +} + func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -252,6 +272,75 @@ func (s *stateServer) OnPingTimer() peerState { // ---------------------------------------------------------------------------- +type stateClientInit struct { + *stateDisconnected + startedAt time.Time + traceID uint64 +} + +func enterStateClientinit(s *pState) peerState { + s.logf("==> ClientInit") + s.pingTimer.Reset(pingInterval) + + state := &stateClientInit{ + stateDisconnected: &stateDisconnected{s}, + startedAt: time.Now(), + traceID: newTraceID(), + } + state.Send(s.staged, packetInit{ + TraceID: state.traceID, + Direct: s.staged.Direct, + Version: version, + }) + return state +} + +func (s *stateClientInit) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case pingTimerMsg: + return s.onPing() + default: + return s + } +} + +func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState { + if msg.Packet.TraceID != s.traceID { + s.logf("Invalid trace ID on INIT.") + return s + } + s.logf("Got INIT version %d.", msg.Packet.Version) + return s.nextState() +} + +func (s *stateClientInit) onPing() peerState { + if time.Since(s.startedAt) > timeoutInterval { + s.logf("Init timeout. Assuming version 1.") + return s.nextState() + } + + s.traceID = newTraceID() + s.Send(s.staged, packetInit{ + TraceID: s.traceID, + Direct: s.staged.Direct, + Version: version, + }) + return s +} + +func (s *stateClientInit) nextState() peerState { + if s.staged.Direct { + return enterStateClientDirect(s.pState) + } + return enterStateClientRelayed(s.pState) +} + +// ---------------------------------------------------------------------------- + type stateClientDirect struct { *stateDisconnected lastSeen time.Time @@ -418,3 +507,4 @@ func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { s.probes[probe.TraceID] = addr s.SendTo(probe, addr) } +*/ diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index cbe2474..15f7d18 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -53,6 +53,10 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { h.State = h.State.OnMsg(peerUpdateMsg{p}) } +func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) { + h.State = h.State.OnMsg(msg) +} + func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { h.State = h.State.OnMsg(msg) } @@ -110,6 +114,7 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDire h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) + return assertType[*stateClientDirect](t, h.State) } diff --git a/peer/peersuper.go b/peer/peersuper.go index 6fa724a..ec8c741 100644 --- a/peer/peersuper.go +++ b/peer/peersuper.go @@ -124,7 +124,7 @@ type peerSuper struct { func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { return &peerSuper{ messages: make(chan any, 8), - state: state.OnPeerUpdate(nil), + state: initPeerState(state, nil), pingTimer: pingTimer, } } diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go index 027057a..c56b28e 100644 --- a/peer/pubaddrs.go +++ b/peer/pubaddrs.go @@ -1,9 +1,7 @@ package peer import ( - "log" "net/netip" - "runtime/debug" "sort" "sync" "time" @@ -27,14 +25,13 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { } func (store *pubAddrStore) Store(add netip.AddrPort) { - store.lock.Lock() - defer store.lock.Unlock() - if store.localPub { - log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) return } + store.lock.Lock() + defer store.lock.Unlock() + if !add.IsValid() { return } diff --git a/peer/state-clientdirect.go b/peer/state-clientdirect.go new file mode 100644 index 0000000..c6c552d --- /dev/null +++ b/peer/state-clientdirect.go @@ -0,0 +1,85 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateClientDirect2 struct { + *peerData + lastSeen time.Time + syn packetSyn +} + +func enterStateClientDirect2(data *peerData, directAddr netip.AddrPort) peerState { + data.staged.Relay = data.peer.Relay + data.staged.Direct = true + data.staged.DirectAddr = directAddr + data.publish(data.staged) + + state := &stateClientDirect2{ + peerData: data, + lastSeen: time.Now(), + syn: packetSyn{ + TraceID: newTraceID(), + SharedKey: data.staged.DataCipher.Key(), + Direct: true, + }, + } + + state.Send(state.staged, state.syn) + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientDirect") + return state +} + +func (s *stateClientDirect2) logf(str string, args ...any) { + s.peerData.logf("CLNT | "+str, args...) +} + +func (s *stateClientDirect2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetAck]: + return s.onAck(msg) + case pingTimerMsg: + return s.onPingTimer() + case controlMsg[packetLocalDiscovery]: + return s + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateClientDirect2) onAck(msg controlMsg[packetAck]) peerState { + if msg.Packet.TraceID != s.syn.TraceID { + return s + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + s.pubAddrs.Store(msg.Packet.ToAddr) + return s +} + +func (s *stateClientDirect2) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.logf("Timeout.") + } + return initPeerState(s.peerData, s.peer) + } + + s.Send(s.staged, s.syn) + return s +} diff --git a/peer/state-clientinit.go b/peer/state-clientinit.go new file mode 100644 index 0000000..8a84100 --- /dev/null +++ b/peer/state-clientinit.go @@ -0,0 +1,93 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateClientInit2 struct { + *peerData + startedAt time.Time + traceID uint64 +} + +func enterStateClientInit2(data *peerData) peerState { + ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) + + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = ipValid + data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) + data.staged.PubSignKey = data.peer.PubSignKey + data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) + data.staged.DataCipher = newDataCipher() + + data.publish(data.staged) + + state := &stateClientInit2{ + peerData: data, + startedAt: time.Now(), + traceID: newTraceID(), + } + state.sendInit() + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientInit") + return state +} + +func (s *stateClientInit2) logf(str string, args ...any) { + s.peerData.logf("INIT | "+str, args...) +} + +func (s *stateClientInit2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case pingTimerMsg: + return s.onPing() + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { + if msg.Packet.TraceID != s.traceID { + s.logf("Invalid trace ID on INIT.") + return s + } + s.logf("Got INIT version %d.", msg.Packet.Version) + return s.nextState() +} + +func (s *stateClientInit2) onPing() peerState { + if time.Since(s.startedAt) > timeoutInterval { + s.logf("Init timeout. Assuming version 1.") + return s.nextState() + } + + s.sendInit() + return s +} + +func (s *stateClientInit2) sendInit() { + s.traceID = newTraceID() + init := packetInit{ + TraceID: s.traceID, + Direct: s.staged.Direct, + Version: version, + } + s.Send(s.staged, init) +} + +func (s *stateClientInit2) nextState() peerState { + if s.staged.Direct { + return enterStateClientDirect2(s.peerData, s.staged.DirectAddr) + } + + return enterStateClientRelayed2(s.peerData) +} diff --git a/peer/state-clientrelayed.go b/peer/state-clientrelayed.go new file mode 100644 index 0000000..737f0a9 --- /dev/null +++ b/peer/state-clientrelayed.go @@ -0,0 +1,142 @@ +package peer + +import ( + "net/netip" + "time" +) + +type sentProbe struct { + SentAt time.Time + Addr netip.AddrPort +} + +type stateClientRelayed2 struct { + *peerData + lastSeen time.Time + syn packetSyn + probes map[uint64]sentProbe +} + +func enterStateClientRelayed2(data *peerData) peerState { + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.publish(data.staged) + + state := &stateClientRelayed2{ + peerData: data, + lastSeen: time.Now(), + syn: packetSyn{ + TraceID: newTraceID(), + SharedKey: data.staged.DataCipher.Key(), + Direct: false, + PossibleAddrs: data.pubAddrs.Get(), + }, + probes: map[uint64]sentProbe{}, + } + + state.Send(state.staged, state.syn) + + data.pingTimer.Reset(pingInterval) + + state.logf("==> ClientRelayed") + return state +} + +func (s *stateClientRelayed2) logf(str string, args ...any) { + s.peerData.logf("CLNT | "+str, args...) +} + +func (s *stateClientRelayed2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetAck]: + return s.onAck(msg) + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + return s.onLocalDiscovery(msg) + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { + if msg.Packet.TraceID != s.syn.TraceID { + return s + } + + s.lastSeen = time.Now() + + if !s.staged.Up { + s.staged.Up = true + s.publish(s.staged) + s.logf("Got ACK.") + } + + s.pubAddrs.Store(msg.Packet.ToAddr) + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.sendProbeTo(addr) + } + + s.cleanProbes() + + return s +} + +func (s *stateClientRelayed2) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval { + if s.staged.Up { + s.logf("Timeout.") + } + return initPeerState(s.peerData, s.peer) + } + + s.Send(s.staged, s.syn) + return s +} + +func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { + s.cleanProbes() + + sent, ok := s.probes[msg.Packet.TraceID] + if !ok { + return s + } + + s.logf("Successful probe.") + return enterStateClientDirect2(s.peerData, sent.Addr) +} + +func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) peerState { + // The source port will be the multicast port, so we'll have to + // construct the correct address using the peer's listed port. + addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + s.sendProbeTo(addr) + return s +} + +func (s *stateClientRelayed2) cleanProbes() { + for key, sent := range s.probes { + if time.Since(sent.SentAt) > pingInterval { + delete(s.probes, key) + } + } +} + +func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { + probe := packetProbe{TraceID: newTraceID()} + s.probes[probe.TraceID] = sentProbe{ + SentAt: time.Now(), + Addr: addr, + } + s.SendTo(probe, addr) +} diff --git a/peer/state-disconnected.go b/peer/state-disconnected.go new file mode 100644 index 0000000..3fdbd23 --- /dev/null +++ b/peer/state-disconnected.go @@ -0,0 +1,33 @@ +package peer + +import "net/netip" + +type stateDisconnected2 struct { + *peerData +} + +func enterStateDisconnected2(data *peerData) peerState { + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.staged.PubSignKey = nil + data.staged.ControlCipher = nil + data.staged.DataCipher = nil + + data.publish(data.staged) + + data.pingTimer.Stop() + + return &stateDisconnected2{data} +} + +func (s *stateDisconnected2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + default: + s.logf("Ignoring message: %v", raw) + return s + } +} diff --git a/peer/state-server.go b/peer/state-server.go new file mode 100644 index 0000000..f3d19da --- /dev/null +++ b/peer/state-server.go @@ -0,0 +1,127 @@ +package peer + +import ( + "net/netip" + "time" +) + +type stateServer2 struct { + *peerData + lastSeen time.Time + synTraceID uint64 // Last syn trace ID. +} + +func enterStateServer2(data *peerData) peerState { + data.staged.Up = false + data.staged.Relay = false + data.staged.Direct = false + data.staged.DirectAddr = netip.AddrPort{} + data.staged.PubSignKey = data.peer.PubSignKey + data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey) + data.staged.DataCipher = nil + + data.publish(data.staged) + + data.pingTimer.Reset(pingInterval) + + state := &stateServer2{peerData: data} + state.logf("==> Server") + return state +} + +func (s *stateServer2) logf(str string, args ...any) { + s.peerData.logf("SRVR | "+str, args...) +} + +func (s *stateServer2) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return initPeerState(s.peerData, msg.Peer) + case controlMsg[packetInit]: + return s.onInit(msg) + case controlMsg[packetSyn]: + return s.onSyn(msg) + case controlMsg[packetProbe]: + return s.onProbe(msg) + case controlMsg[packetLocalDiscovery]: + return s + case pingTimerMsg: + return s.onPingTimer() + default: + s.logf("Ignoring message: %v", raw) + return s + } +} + +func (s *stateServer2) onInit(msg controlMsg[packetInit]) peerState { + s.staged.Up = false + s.staged.Direct = msg.Packet.Direct + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + + init := packetInit{ + TraceID: msg.Packet.TraceID, + Direct: s.staged.Direct, + Version: version, + } + + s.Send(s.staged, init) + + return s +} + +func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { + s.lastSeen = time.Now() + p := msg.Packet + + // Before we can respond to this packet, we need to make sure the + // route is setup properly. + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != s.synTraceID || !s.staged.Up { + s.synTraceID = p.TraceID + s.staged.Up = true + s.staged.Direct = p.Direct + s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.DirectAddr = msg.SrcAddr + s.publish(s.staged) + s.logf("Got SYN.") + } + + // Always respond. + s.Send(s.staged, packetAck{ + TraceID: p.TraceID, + ToAddr: s.staged.DirectAddr, + PossibleAddrs: s.pubAddrs.Get(), + }) + + if p.Direct { + return s + } + + for _, addr := range msg.Packet.PossibleAddrs { + if !addr.IsValid() { + break + } + s.SendTo(packetProbe{TraceID: newTraceID()}, addr) + } + + return s +} + +func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { + if msg.SrcAddr.IsValid() { + s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + } + return s +} + +func (s *stateServer2) onPingTimer() peerState { + if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { + s.staged.Up = false + s.publish(s.staged) + s.logf("Timeout.") + } + return s +} diff --git a/peer/statedata.go b/peer/statedata.go new file mode 100644 index 0000000..44330fa --- /dev/null +++ b/peer/statedata.go @@ -0,0 +1,28 @@ +package peer + +import ( + "net/netip" + "vppn/m" +) + +type peerData = pState + +func initPeerState(data *peerData, peer *m.Peer) peerState { + data.peer = peer + + if peer == nil { + return enterStateDisconnected2(data) + } + + if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + if data.localAddr.IsValid() && data.localIP < data.remoteIP { + return enterStateServer2(data) + } + return enterStateClientInit2(data) + } + + if data.localAddr.IsValid() || data.localIP < data.remoteIP { + return enterStateServer2(data) + } + return enterStateClientInit2(data) +}