diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 09935ab..3a18bc8 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -49,8 +49,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error // ---------------------------------------------------------------------------- type peerUpdateMsg struct { - PeerIP byte - Peer *m.Peer + Peer *m.Peer } // ---------------------------------------------------------------------------- diff --git a/peer/hubpoller.go b/peer/hubpoller.go index 2b50495..238dfda 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -96,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { for i, peer := range state.Peers { if i != int(hp.localIP) { if peer == nil || peer.Version != hp.versions[i] { - hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) + hp.handleControlMsg(byte(i), peerUpdateMsg{Peer: state.Peers[i]}) if peer != nil { hp.versions[i] = peer.Version } diff --git a/peer/mcwriter.go b/peer/mcwriter.go index 5559547..29cf2be 100644 --- a/peer/mcwriter.go +++ b/peer/mcwriter.go @@ -44,7 +44,7 @@ func runMCWriter(localIP byte, signingKey []byte) { log.Fatalf("Failed to bind to multicast address: %v", err) } - for range time.Tick(16 * time.Second) { + for range time.Tick(8 * time.Second) { _, err := conn.WriteToUDP(discoveryPacket, multicastAddr) if err != nil { log.Printf("[MCWriter] Failed to write multicast: %v", err) diff --git a/peer/peerstates.go b/peer/peerstates.go index a68afb1..b5abfb7 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -12,12 +12,7 @@ import ( ) type peerState interface { - OnPeerUpdate(*m.Peer) peerState - OnSyn(controlMsg[packetSyn]) peerState - OnAck(controlMsg[packetAck]) - OnProbe(controlMsg[packetProbe]) peerState - OnLocalDiscovery(controlMsg[packetLocalDiscovery]) - OnPingTimer() peerState + OnMsg(raw any) peerState } // ---------------------------------------------------------------------------- @@ -26,6 +21,7 @@ type pState struct { // Output. publish func(remotePeer) sendControlPacket func(remotePeer, marshaller) + pingTimer *time.Ticker // Immutable data. localIP byte @@ -147,9 +143,20 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { type stateDisconnected struct{ *pState } func enterStateDisconnected(s *pState) peerState { + s.pingTimer.Stop() return &stateDisconnected{pState: s} } +func (s *stateDisconnected) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + default: + // TODO: Log. + return s + } +} + func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState { return s } func (s *stateDisconnected) OnAck(controlMsg[packetAck]) {} func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s } @@ -166,9 +173,26 @@ type stateServer struct { func enterStateServer(s *pState) peerState { s.logf("==> Server") + s.pingTimer.Reset(pingInterval) return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} } +func (s *stateServer) OnMsg(rawMsg any) peerState { + switch msg := rawMsg.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetSyn]: + return s.OnSyn(msg) + case controlMsg[packetProbe]: + return s.OnProbe(msg) + case pingTimerMsg: + return s.OnPingTimer() + default: + // TODO: Log + return s + } +} + func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -236,6 +260,7 @@ type stateClientDirect struct { func enterStateClientDirect(s *pState) peerState { s.logf("==> ClientDirect") + s.pingTimer.Reset(pingInterval) return newStateClientDirect(s) } @@ -255,6 +280,24 @@ func newStateClientDirect(s *pState) *stateClientDirect { return state } +func (s *stateClientDirect) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetAck]: + s.OnAck(msg) + return s + case pingTimerMsg: + if next := s.onPingTimer(); next != nil { + return next + } + return s + default: + // TODO: Log + return s + } +} + func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { if msg.Packet.TraceID != s.syn.TraceID { return @@ -271,13 +314,6 @@ func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { s.pubAddrs.Store(msg.Packet.ToAddr) } -func (s *stateClientDirect) OnPingTimer() peerState { - if next := s.onPingTimer(); next != nil { - return next - } - return s -} - func (s *stateClientDirect) onPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval { if s.staged.Up { @@ -297,21 +333,44 @@ func (s *stateClientDirect) onPingTimer() peerState { type stateClientRelayed struct { *stateClientDirect ack packetAck - probes map[uint64]netip.AddrPort - localDiscoveryAddr netip.AddrPort + probes map[uint64]netip.AddrPort // TODO: something better + localDiscoveryAddr netip.AddrPort // TODO: Remove } func enterStateClientRelayed(s *pState) peerState { s.logf("==> ClientRelayed") + s.pingTimer.Reset(pingInterval) return &stateClientRelayed{ stateClientDirect: newStateClientDirect(s), probes: map[uint64]netip.AddrPort{}, } } +func (s *stateClientRelayed) OnMsg(raw any) peerState { + switch msg := raw.(type) { + case peerUpdateMsg: + return s.OnPeerUpdate(msg.Peer) + case controlMsg[packetAck]: + s.OnAck(msg) + return s + case controlMsg[packetProbe]: + return s.OnProbe(msg) + case controlMsg[packetLocalDiscovery]: + s.OnLocalDiscovery(msg) + return s + case pingTimerMsg: + return s.OnPingTimer() + default: + // TODO: Log + return s + } +} + func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { s.ack = msg.Packet s.stateClientDirect.OnAck(msg) + + // TODO: Send probes now. } func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { @@ -330,6 +389,7 @@ func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscover // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) + // TODO: s.sendProbeTo(s.localDiscoveryAddr) } func (s *stateClientRelayed) OnPingTimer() peerState { diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index daf5c14..cbe2474 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -34,10 +34,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { sendControlPacket: func(rp remotePeer, pkt marshaller) { h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) }, - localIP: 2, - remoteIP: 3, - privKey: keys.PrivKey, - pubAddrs: newPubAddrStore(netip.AddrPort{}), + pingTimer: time.NewTicker(pingInterval), + localIP: 2, + remoteIP: 3, + privKey: keys.PrivKey, + pubAddrs: newPubAddrStore(netip.AddrPort{}), limiter: ratelimiter.New(ratelimiter.Config{ FillPeriod: 20 * time.Millisecond, MaxWaitCount: 1, @@ -49,27 +50,19 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { } func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { - if s := h.State.OnPeerUpdate(p); s != nil { - h.State = s - } + h.State = h.State.OnMsg(peerUpdateMsg{p}) } func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { - if s := h.State.OnSyn(msg); s != nil { - h.State = s - } + h.State = h.State.OnMsg(msg) } func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { - if s := h.State.OnProbe(msg); s != nil { - h.State = s - } + h.State = h.State.OnMsg(msg) } func (h *PeerStateTestHarness) OnPingTimer() { - if s := h.State.OnPingTimer(); s != nil { - h.State = s - } + h.State = h.State.OnMsg(pingTimerMsg{}) } func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { @@ -202,7 +195,7 @@ func TestStateServer_directSyn(t *testing.T) { }, } - h.State.OnSyn(synMsg) + h.State = h.State.OnMsg(synMsg) assertEqual(t, len(h.Sent), 1) ack := assertType[packetAck](t, h.Sent[0].Packet) @@ -233,7 +226,7 @@ func TestStateServer_relayedSyn(t *testing.T) { synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) - h.State.OnSyn(synMsg) + h.State = h.State.OnMsg(synMsg) assertEqual(t, len(h.Sent), 3) @@ -261,7 +254,7 @@ func TestStateServer_onProbe(t *testing.T) { Packet: packetProbe{TraceID: newTraceID()}, } - h.State.OnProbe(probeMsg) + h.State = h.State.OnMsg(probeMsg) assertEqual(t, len(h.Sent), 1) @@ -285,7 +278,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { }, } - h.State.OnSyn(synMsg) + h.State = h.State.OnMsg(synMsg) assertEqual(t, len(h.Sent), 1) assertEqual(t, h.Published.Up, true) @@ -314,7 +307,7 @@ func TestStateClientDirect_OnAck(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) } @@ -331,7 +324,7 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID + 1}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, false) } @@ -366,7 +359,7 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) state := assertType[*stateClientDirect](t, h.State) @@ -395,7 +388,7 @@ func TestStateClientRelayed_OnAck(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) } @@ -429,11 +422,11 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) // Add a local discovery address. Note that the port will be configured port // and no the one provided here. - h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ + h.State = h.State.OnMsg(controlMsg[packetLocalDiscovery]{ SrcIP: 3, SrcAddr: addrPort4(2, 2, 2, 3, 300), }) @@ -462,7 +455,7 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { ack := controlMsg[packetAck]{ Packet: packetAck{TraceID: syn.TraceID}, } - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) assertEqual(t, h.Published.Up, true) state := assertType[*stateClientRelayed](t, h.State) @@ -499,7 +492,7 @@ func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) - h.State.OnAck(ack) + h.State = h.State.OnMsg(ack) h.OnPingTimer() probe := assertType[packetProbe](t, h.Sent[2].Packet) diff --git a/peer/peersuper.go b/peer/peersuper.go index 7682d87..6fa724a 100644 --- a/peer/peersuper.go +++ b/peer/peersuper.go @@ -1,8 +1,6 @@ package peer import ( - "log" - "math/rand" "net/netip" "sync" "sync/atomic" @@ -44,6 +42,7 @@ func newSupervisor( state := &pState{ publish: s.publish, sendControlPacket: s.send, + pingTimer: time.NewTicker(timeoutInterval), localIP: routes.LocalIP, remoteIP: byte(i), privKey: privKey, @@ -55,7 +54,7 @@ func newSupervisor( MaxWaitCount: 1, }), } - s.peers[i] = newPeerSuper(state) + s.peers[i] = newPeerSuper(state, state.pingTimer) } return s @@ -105,7 +104,7 @@ func (s *supervisor) ensureRelay() { return } - // TODO: Random selection? + // TODO: Random selection? Something else? for _, peer := range s.staged.Peers { if peer.Up && peer.Direct && peer.Relay { s.staged.RelayIP = peer.IP @@ -117,14 +116,16 @@ func (s *supervisor) ensureRelay() { // ---------------------------------------------------------------------------- type peerSuper struct { - messages chan any - state peerState + messages chan any + state peerState + pingTimer *time.Ticker } -func newPeerSuper(state *pState) *peerSuper { +func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { return &peerSuper{ - messages: make(chan any, 8), - state: state.OnPeerUpdate(nil), + messages: make(chan any, 8), + state: state.OnPeerUpdate(nil), + pingTimer: pingTimer, } } @@ -136,37 +137,12 @@ func (s *peerSuper) HandleControlMsg(msg any) { } func (s *peerSuper) Run() { - go func() { - // Randomize ping timers. - time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) - for range time.Tick(pingInterval) { - s.messages <- pingTimerMsg{} - } - }() - - for rawMsg := range s.messages { - switch msg := rawMsg.(type) { - - case peerUpdateMsg: - s.state = s.state.OnPeerUpdate(msg.Peer) - - case controlMsg[packetSyn]: - s.state = s.state.OnSyn(msg) - - case controlMsg[packetAck]: - s.state.OnAck(msg) - - case controlMsg[packetProbe]: - s.state = s.state.OnProbe(msg) - - case controlMsg[packetLocalDiscovery]: - s.state.OnLocalDiscovery(msg) - - case pingTimerMsg: - s.state = s.state.OnPingTimer() - - default: - log.Printf("WARNING: unknown message type: %+v", msg) + for { + select { + case <-s.pingTimer.C: + s.state = s.state.OnMsg(pingTimerMsg{}) + case raw := <-s.messages: + s.state = s.state.OnMsg(raw) } } }