refactor-for-testability #3
| @@ -49,7 +49,6 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | |||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type peerUpdateMsg struct { | type peerUpdateMsg struct { | ||||||
| 	PeerIP byte |  | ||||||
| 	Peer *m.Peer | 	Peer *m.Peer | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -96,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | |||||||
| 	for i, peer := range state.Peers { | 	for i, peer := range state.Peers { | ||||||
| 		if i != int(hp.localIP) { | 		if i != int(hp.localIP) { | ||||||
| 			if peer == nil || peer.Version != hp.versions[i] { | 			if peer == nil || peer.Version != hp.versions[i] { | ||||||
| 				hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | 				hp.handleControlMsg(byte(i), peerUpdateMsg{Peer: state.Peers[i]}) | ||||||
| 				if peer != nil { | 				if peer != nil { | ||||||
| 					hp.versions[i] = peer.Version | 					hp.versions[i] = peer.Version | ||||||
| 				} | 				} | ||||||
|   | |||||||
| @@ -44,7 +44,7 @@ func runMCWriter(localIP byte, signingKey []byte) { | |||||||
| 		log.Fatalf("Failed to bind to multicast address: %v", err) | 		log.Fatalf("Failed to bind to multicast address: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for range time.Tick(16 * time.Second) { | 	for range time.Tick(8 * time.Second) { | ||||||
| 		_, err := conn.WriteToUDP(discoveryPacket, multicastAddr) | 		_, err := conn.WriteToUDP(discoveryPacket, multicastAddr) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Printf("[MCWriter] Failed to write multicast: %v", err) | 			log.Printf("[MCWriter] Failed to write multicast: %v", err) | ||||||
|   | |||||||
| @@ -12,12 +12,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type peerState interface { | type peerState interface { | ||||||
| 	OnPeerUpdate(*m.Peer) peerState | 	OnMsg(raw any) peerState | ||||||
| 	OnSyn(controlMsg[packetSyn]) peerState |  | ||||||
| 	OnAck(controlMsg[packetAck]) |  | ||||||
| 	OnProbe(controlMsg[packetProbe]) peerState |  | ||||||
| 	OnLocalDiscovery(controlMsg[packetLocalDiscovery]) |  | ||||||
| 	OnPingTimer() peerState |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
| @@ -26,6 +21,7 @@ type pState struct { | |||||||
| 	// Output. | 	// Output. | ||||||
| 	publish           func(remotePeer) | 	publish           func(remotePeer) | ||||||
| 	sendControlPacket func(remotePeer, marshaller) | 	sendControlPacket func(remotePeer, marshaller) | ||||||
|  | 	pingTimer         *time.Ticker | ||||||
|  |  | ||||||
| 	// Immutable data. | 	// Immutable data. | ||||||
| 	localIP   byte | 	localIP   byte | ||||||
| @@ -147,9 +143,20 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { | |||||||
| type stateDisconnected struct{ *pState } | type stateDisconnected struct{ *pState } | ||||||
|  |  | ||||||
| func enterStateDisconnected(s *pState) peerState { | func enterStateDisconnected(s *pState) peerState { | ||||||
|  | 	s.pingTimer.Stop() | ||||||
| 	return &stateDisconnected{pState: s} | 	return &stateDisconnected{pState: s} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateDisconnected) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return s.OnPeerUpdate(msg.Peer) | ||||||
|  | 	default: | ||||||
|  | 		// TODO: Log. | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState             { return s } | func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState             { return s } | ||||||
| func (s *stateDisconnected) OnAck(controlMsg[packetAck])                       {} | func (s *stateDisconnected) OnAck(controlMsg[packetAck])                       {} | ||||||
| func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState         { return s } | func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState         { return s } | ||||||
| @@ -166,9 +173,26 @@ type stateServer struct { | |||||||
|  |  | ||||||
| func enterStateServer(s *pState) peerState { | func enterStateServer(s *pState) peerState { | ||||||
| 	s.logf("==> Server") | 	s.logf("==> Server") | ||||||
|  | 	s.pingTimer.Reset(pingInterval) | ||||||
| 	return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} | 	return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateServer) OnMsg(rawMsg any) peerState { | ||||||
|  | 	switch msg := rawMsg.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return s.OnPeerUpdate(msg.Peer) | ||||||
|  | 	case controlMsg[packetSyn]: | ||||||
|  | 		return s.OnSyn(msg) | ||||||
|  | 	case controlMsg[packetProbe]: | ||||||
|  | 		return s.OnProbe(msg) | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.OnPingTimer() | ||||||
|  | 	default: | ||||||
|  | 		// TODO: Log | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | ||||||
| 	s.lastSeen = time.Now() | 	s.lastSeen = time.Now() | ||||||
| 	p := msg.Packet | 	p := msg.Packet | ||||||
| @@ -236,6 +260,7 @@ type stateClientDirect struct { | |||||||
|  |  | ||||||
| func enterStateClientDirect(s *pState) peerState { | func enterStateClientDirect(s *pState) peerState { | ||||||
| 	s.logf("==> ClientDirect") | 	s.logf("==> ClientDirect") | ||||||
|  | 	s.pingTimer.Reset(pingInterval) | ||||||
| 	return newStateClientDirect(s) | 	return newStateClientDirect(s) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -255,6 +280,24 @@ func newStateClientDirect(s *pState) *stateClientDirect { | |||||||
| 	return state | 	return state | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateClientDirect) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return s.OnPeerUpdate(msg.Peer) | ||||||
|  | 	case controlMsg[packetAck]: | ||||||
|  | 		s.OnAck(msg) | ||||||
|  | 		return s | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		if next := s.onPingTimer(); next != nil { | ||||||
|  | 			return next | ||||||
|  | 		} | ||||||
|  | 		return s | ||||||
|  | 	default: | ||||||
|  | 		// TODO: Log | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | ||||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
| 		return | 		return | ||||||
| @@ -271,13 +314,6 @@ func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | |||||||
| 	s.pubAddrs.Store(msg.Packet.ToAddr) | 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientDirect) OnPingTimer() peerState { |  | ||||||
| 	if next := s.onPingTimer(); next != nil { |  | ||||||
| 		return next |  | ||||||
| 	} |  | ||||||
| 	return s |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *stateClientDirect) onPingTimer() peerState { | func (s *stateClientDirect) onPingTimer() peerState { | ||||||
| 	if time.Since(s.lastSeen) > timeoutInterval { | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
| 		if s.staged.Up { | 		if s.staged.Up { | ||||||
| @@ -297,21 +333,44 @@ func (s *stateClientDirect) onPingTimer() peerState { | |||||||
| type stateClientRelayed struct { | type stateClientRelayed struct { | ||||||
| 	*stateClientDirect | 	*stateClientDirect | ||||||
| 	ack                packetAck | 	ack                packetAck | ||||||
| 	probes             map[uint64]netip.AddrPort | 	probes             map[uint64]netip.AddrPort // TODO: something better | ||||||
| 	localDiscoveryAddr netip.AddrPort | 	localDiscoveryAddr netip.AddrPort            // TODO: Remove | ||||||
| } | } | ||||||
|  |  | ||||||
| func enterStateClientRelayed(s *pState) peerState { | func enterStateClientRelayed(s *pState) peerState { | ||||||
| 	s.logf("==> ClientRelayed") | 	s.logf("==> ClientRelayed") | ||||||
|  | 	s.pingTimer.Reset(pingInterval) | ||||||
| 	return &stateClientRelayed{ | 	return &stateClientRelayed{ | ||||||
| 		stateClientDirect: newStateClientDirect(s), | 		stateClientDirect: newStateClientDirect(s), | ||||||
| 		probes:            map[uint64]netip.AddrPort{}, | 		probes:            map[uint64]netip.AddrPort{}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return s.OnPeerUpdate(msg.Peer) | ||||||
|  | 	case controlMsg[packetAck]: | ||||||
|  | 		s.OnAck(msg) | ||||||
|  | 		return s | ||||||
|  | 	case controlMsg[packetProbe]: | ||||||
|  | 		return s.OnProbe(msg) | ||||||
|  | 	case controlMsg[packetLocalDiscovery]: | ||||||
|  | 		s.OnLocalDiscovery(msg) | ||||||
|  | 		return s | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.OnPingTimer() | ||||||
|  | 	default: | ||||||
|  | 		// TODO: Log | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { | func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { | ||||||
| 	s.ack = msg.Packet | 	s.ack = msg.Packet | ||||||
| 	s.stateClientDirect.OnAck(msg) | 	s.stateClientDirect.OnAck(msg) | ||||||
|  |  | ||||||
|  | 	// TODO: Send probes now. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { | func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { | ||||||
| @@ -330,6 +389,7 @@ func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscover | |||||||
| 	// The source port will be the multicast port, so we'll have to | 	// The source port will be the multicast port, so we'll have to | ||||||
| 	// construct the correct address using the peer's listed port. | 	// construct the correct address using the peer's listed port. | ||||||
| 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
|  | 	// TODO:  s.sendProbeTo(s.localDiscoveryAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed) OnPingTimer() peerState { | func (s *stateClientRelayed) OnPingTimer() peerState { | ||||||
|   | |||||||
| @@ -34,6 +34,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | |||||||
| 		sendControlPacket: func(rp remotePeer, pkt marshaller) { | 		sendControlPacket: func(rp remotePeer, pkt marshaller) { | ||||||
| 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||||
| 		}, | 		}, | ||||||
|  | 		pingTimer: time.NewTicker(pingInterval), | ||||||
| 		localIP:   2, | 		localIP:   2, | ||||||
| 		remoteIP:  3, | 		remoteIP:  3, | ||||||
| 		privKey:   keys.PrivKey, | 		privKey:   keys.PrivKey, | ||||||
| @@ -49,27 +50,19 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | ||||||
| 	if s := h.State.OnPeerUpdate(p); s != nil { | 	h.State = h.State.OnMsg(peerUpdateMsg{p}) | ||||||
| 		h.State = s |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | ||||||
| 	if s := h.State.OnSyn(msg); s != nil { | 	h.State = h.State.OnMsg(msg) | ||||||
| 		h.State = s |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { | func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { | ||||||
| 	if s := h.State.OnProbe(msg); s != nil { | 	h.State = h.State.OnMsg(msg) | ||||||
| 		h.State = s |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) OnPingTimer() { | func (h *PeerStateTestHarness) OnPingTimer() { | ||||||
| 	if s := h.State.OnPingTimer(); s != nil { | 	h.State = h.State.OnMsg(pingTimerMsg{}) | ||||||
| 		h.State = s |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { | func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { | ||||||
| @@ -202,7 +195,7 @@ func TestStateServer_directSyn(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.State.OnSyn(synMsg) | 	h.State = h.State.OnMsg(synMsg) | ||||||
|  |  | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	ack := assertType[packetAck](t, h.Sent[0].Packet) | 	ack := assertType[packetAck](t, h.Sent[0].Packet) | ||||||
| @@ -233,7 +226,7 @@ func TestStateServer_relayedSyn(t *testing.T) { | |||||||
| 	synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) | 	synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) | ||||||
| 	synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) | 	synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) | ||||||
|  |  | ||||||
| 	h.State.OnSyn(synMsg) | 	h.State = h.State.OnMsg(synMsg) | ||||||
|  |  | ||||||
| 	assertEqual(t, len(h.Sent), 3) | 	assertEqual(t, len(h.Sent), 3) | ||||||
|  |  | ||||||
| @@ -261,7 +254,7 @@ func TestStateServer_onProbe(t *testing.T) { | |||||||
| 		Packet:  packetProbe{TraceID: newTraceID()}, | 		Packet:  packetProbe{TraceID: newTraceID()}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.State.OnProbe(probeMsg) | 	h.State = h.State.OnMsg(probeMsg) | ||||||
|  |  | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  |  | ||||||
| @@ -285,7 +278,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.State.OnSyn(synMsg) | 	h.State = h.State.OnMsg(synMsg) | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| @@ -314,7 +307,7 @@ func TestStateClientDirect_OnAck(t *testing.T) { | |||||||
| 	ack := controlMsg[packetAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -331,7 +324,7 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { | |||||||
| 	ack := controlMsg[packetAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: packetAck{TraceID: syn.TraceID + 1}, | 		Packet: packetAck{TraceID: syn.TraceID + 1}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -366,7 +359,7 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | |||||||
| 	ack := controlMsg[packetAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| 	state := assertType[*stateClientDirect](t, h.State) | 	state := assertType[*stateClientDirect](t, h.State) | ||||||
| @@ -395,7 +388,7 @@ func TestStateClientRelayed_OnAck(t *testing.T) { | |||||||
| 	ack := controlMsg[packetAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -429,11 +422,11 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | |||||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||||
|  |  | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
|  |  | ||||||
| 	// Add a local discovery address. Note that the port will be configured port | 	// Add a local discovery address. Note that the port will be configured port | ||||||
| 	// and no the one provided here. | 	// and no the one provided here. | ||||||
| 	h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ | 	h.State = h.State.OnMsg(controlMsg[packetLocalDiscovery]{ | ||||||
| 		SrcIP:   3, | 		SrcIP:   3, | ||||||
| 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | ||||||
| 	}) | 	}) | ||||||
| @@ -462,7 +455,7 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | |||||||
| 	ack := controlMsg[packetAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| 	state := assertType[*stateClientRelayed](t, h.State) | 	state := assertType[*stateClientRelayed](t, h.State) | ||||||
| @@ -499,7 +492,7 @@ func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | |||||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||||
|  |  | ||||||
| 	h.State.OnAck(ack) | 	h.State = h.State.OnMsg(ack) | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
|  |  | ||||||
| 	probe := assertType[packetProbe](t, h.Sent[2].Packet) | 	probe := assertType[packetProbe](t, h.Sent[2].Packet) | ||||||
|   | |||||||
| @@ -1,8 +1,6 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"log" |  | ||||||
| 	"math/rand" |  | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| @@ -44,6 +42,7 @@ func newSupervisor( | |||||||
| 		state := &pState{ | 		state := &pState{ | ||||||
| 			publish:           s.publish, | 			publish:           s.publish, | ||||||
| 			sendControlPacket: s.send, | 			sendControlPacket: s.send, | ||||||
|  | 			pingTimer:         time.NewTicker(timeoutInterval), | ||||||
| 			localIP:           routes.LocalIP, | 			localIP:           routes.LocalIP, | ||||||
| 			remoteIP:          byte(i), | 			remoteIP:          byte(i), | ||||||
| 			privKey:           privKey, | 			privKey:           privKey, | ||||||
| @@ -55,7 +54,7 @@ func newSupervisor( | |||||||
| 				MaxWaitCount: 1, | 				MaxWaitCount: 1, | ||||||
| 			}), | 			}), | ||||||
| 		} | 		} | ||||||
| 		s.peers[i] = newPeerSuper(state) | 		s.peers[i] = newPeerSuper(state, state.pingTimer) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return s | 	return s | ||||||
| @@ -105,7 +104,7 @@ func (s *supervisor) ensureRelay() { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// TODO: Random selection? | 	// TODO: Random selection? Something else? | ||||||
| 	for _, peer := range s.staged.Peers { | 	for _, peer := range s.staged.Peers { | ||||||
| 		if peer.Up && peer.Direct && peer.Relay { | 		if peer.Up && peer.Direct && peer.Relay { | ||||||
| 			s.staged.RelayIP = peer.IP | 			s.staged.RelayIP = peer.IP | ||||||
| @@ -119,12 +118,14 @@ func (s *supervisor) ensureRelay() { | |||||||
| type peerSuper struct { | type peerSuper struct { | ||||||
| 	messages  chan any | 	messages  chan any | ||||||
| 	state     peerState | 	state     peerState | ||||||
|  | 	pingTimer *time.Ticker | ||||||
| } | } | ||||||
|  |  | ||||||
| func newPeerSuper(state *pState) *peerSuper { | func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { | ||||||
| 	return &peerSuper{ | 	return &peerSuper{ | ||||||
| 		messages:  make(chan any, 8), | 		messages:  make(chan any, 8), | ||||||
| 		state:     state.OnPeerUpdate(nil), | 		state:     state.OnPeerUpdate(nil), | ||||||
|  | 		pingTimer: pingTimer, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -136,37 +137,12 @@ func (s *peerSuper) HandleControlMsg(msg any) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *peerSuper) Run() { | func (s *peerSuper) Run() { | ||||||
| 	go func() { | 	for { | ||||||
| 		// Randomize ping timers. | 		select { | ||||||
| 		time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) | 		case <-s.pingTimer.C: | ||||||
| 		for range time.Tick(pingInterval) { | 			s.state = s.state.OnMsg(pingTimerMsg{}) | ||||||
| 			s.messages <- pingTimerMsg{} | 		case raw := <-s.messages: | ||||||
| 		} | 			s.state = s.state.OnMsg(raw) | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	for rawMsg := range s.messages { |  | ||||||
| 		switch msg := rawMsg.(type) { |  | ||||||
|  |  | ||||||
| 		case peerUpdateMsg: |  | ||||||
| 			s.state = s.state.OnPeerUpdate(msg.Peer) |  | ||||||
|  |  | ||||||
| 		case controlMsg[packetSyn]: |  | ||||||
| 			s.state = s.state.OnSyn(msg) |  | ||||||
|  |  | ||||||
| 		case controlMsg[packetAck]: |  | ||||||
| 			s.state.OnAck(msg) |  | ||||||
|  |  | ||||||
| 		case controlMsg[packetProbe]: |  | ||||||
| 			s.state = s.state.OnProbe(msg) |  | ||||||
|  |  | ||||||
| 		case controlMsg[packetLocalDiscovery]: |  | ||||||
| 			s.state.OnLocalDiscovery(msg) |  | ||||||
|  |  | ||||||
| 		case pingTimerMsg: |  | ||||||
| 			s.state = s.state.OnPingTimer() |  | ||||||
|  |  | ||||||
| 		default: |  | ||||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user