wip: working
This commit is contained in:
		| @@ -49,8 +49,7 @@ func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type peerUpdateMsg struct { | ||||
| 	PeerIP byte | ||||
| 	Peer   *m.Peer | ||||
| 	Peer *m.Peer | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|   | ||||
| @@ -96,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||
| 	for i, peer := range state.Peers { | ||||
| 		if i != int(hp.localIP) { | ||||
| 			if peer == nil || peer.Version != hp.versions[i] { | ||||
| 				hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | ||||
| 				hp.handleControlMsg(byte(i), peerUpdateMsg{Peer: state.Peers[i]}) | ||||
| 				if peer != nil { | ||||
| 					hp.versions[i] = peer.Version | ||||
| 				} | ||||
|   | ||||
| @@ -44,7 +44,7 @@ func runMCWriter(localIP byte, signingKey []byte) { | ||||
| 		log.Fatalf("Failed to bind to multicast address: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	for range time.Tick(16 * time.Second) { | ||||
| 	for range time.Tick(8 * time.Second) { | ||||
| 		_, err := conn.WriteToUDP(discoveryPacket, multicastAddr) | ||||
| 		if err != nil { | ||||
| 			log.Printf("[MCWriter] Failed to write multicast: %v", err) | ||||
|   | ||||
| @@ -12,12 +12,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| type peerState interface { | ||||
| 	OnPeerUpdate(*m.Peer) peerState | ||||
| 	OnSyn(controlMsg[packetSyn]) peerState | ||||
| 	OnAck(controlMsg[packetAck]) | ||||
| 	OnProbe(controlMsg[packetProbe]) peerState | ||||
| 	OnLocalDiscovery(controlMsg[packetLocalDiscovery]) | ||||
| 	OnPingTimer() peerState | ||||
| 	OnMsg(raw any) peerState | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -26,6 +21,7 @@ type pState struct { | ||||
| 	// Output. | ||||
| 	publish           func(remotePeer) | ||||
| 	sendControlPacket func(remotePeer, marshaller) | ||||
| 	pingTimer         *time.Ticker | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	localIP   byte | ||||
| @@ -147,9 +143,20 @@ func (s *pState) Send(peer remotePeer, pkt marshaller) { | ||||
| type stateDisconnected struct{ *pState } | ||||
|  | ||||
| func enterStateDisconnected(s *pState) peerState { | ||||
| 	s.pingTimer.Stop() | ||||
| 	return &stateDisconnected{pState: s} | ||||
| } | ||||
|  | ||||
| func (s *stateDisconnected) OnMsg(raw any) peerState { | ||||
| 	switch msg := raw.(type) { | ||||
| 	case peerUpdateMsg: | ||||
| 		return s.OnPeerUpdate(msg.Peer) | ||||
| 	default: | ||||
| 		// TODO: Log. | ||||
| 		return s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState             { return s } | ||||
| func (s *stateDisconnected) OnAck(controlMsg[packetAck])                       {} | ||||
| func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState         { return s } | ||||
| @@ -166,9 +173,26 @@ type stateServer struct { | ||||
|  | ||||
| func enterStateServer(s *pState) peerState { | ||||
| 	s.logf("==> Server") | ||||
| 	s.pingTimer.Reset(pingInterval) | ||||
| 	return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} | ||||
| } | ||||
|  | ||||
| func (s *stateServer) OnMsg(rawMsg any) peerState { | ||||
| 	switch msg := rawMsg.(type) { | ||||
| 	case peerUpdateMsg: | ||||
| 		return s.OnPeerUpdate(msg.Peer) | ||||
| 	case controlMsg[packetSyn]: | ||||
| 		return s.OnSyn(msg) | ||||
| 	case controlMsg[packetProbe]: | ||||
| 		return s.OnProbe(msg) | ||||
| 	case pingTimerMsg: | ||||
| 		return s.OnPingTimer() | ||||
| 	default: | ||||
| 		// TODO: Log | ||||
| 		return s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | ||||
| 	s.lastSeen = time.Now() | ||||
| 	p := msg.Packet | ||||
| @@ -236,6 +260,7 @@ type stateClientDirect struct { | ||||
|  | ||||
| func enterStateClientDirect(s *pState) peerState { | ||||
| 	s.logf("==> ClientDirect") | ||||
| 	s.pingTimer.Reset(pingInterval) | ||||
| 	return newStateClientDirect(s) | ||||
| } | ||||
|  | ||||
| @@ -255,6 +280,24 @@ func newStateClientDirect(s *pState) *stateClientDirect { | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| func (s *stateClientDirect) OnMsg(raw any) peerState { | ||||
| 	switch msg := raw.(type) { | ||||
| 	case peerUpdateMsg: | ||||
| 		return s.OnPeerUpdate(msg.Peer) | ||||
| 	case controlMsg[packetAck]: | ||||
| 		s.OnAck(msg) | ||||
| 		return s | ||||
| 	case pingTimerMsg: | ||||
| 		if next := s.onPingTimer(); next != nil { | ||||
| 			return next | ||||
| 		} | ||||
| 		return s | ||||
| 	default: | ||||
| 		// TODO: Log | ||||
| 		return s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | ||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | ||||
| 		return | ||||
| @@ -271,13 +314,6 @@ func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | ||||
| 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||
| } | ||||
|  | ||||
| func (s *stateClientDirect) OnPingTimer() peerState { | ||||
| 	if next := s.onPingTimer(); next != nil { | ||||
| 		return next | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientDirect) onPingTimer() peerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval { | ||||
| 		if s.staged.Up { | ||||
| @@ -297,21 +333,44 @@ func (s *stateClientDirect) onPingTimer() peerState { | ||||
| type stateClientRelayed struct { | ||||
| 	*stateClientDirect | ||||
| 	ack                packetAck | ||||
| 	probes             map[uint64]netip.AddrPort | ||||
| 	localDiscoveryAddr netip.AddrPort | ||||
| 	probes             map[uint64]netip.AddrPort // TODO: something better | ||||
| 	localDiscoveryAddr netip.AddrPort            // TODO: Remove | ||||
| } | ||||
|  | ||||
| func enterStateClientRelayed(s *pState) peerState { | ||||
| 	s.logf("==> ClientRelayed") | ||||
| 	s.pingTimer.Reset(pingInterval) | ||||
| 	return &stateClientRelayed{ | ||||
| 		stateClientDirect: newStateClientDirect(s), | ||||
| 		probes:            map[uint64]netip.AddrPort{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed) OnMsg(raw any) peerState { | ||||
| 	switch msg := raw.(type) { | ||||
| 	case peerUpdateMsg: | ||||
| 		return s.OnPeerUpdate(msg.Peer) | ||||
| 	case controlMsg[packetAck]: | ||||
| 		s.OnAck(msg) | ||||
| 		return s | ||||
| 	case controlMsg[packetProbe]: | ||||
| 		return s.OnProbe(msg) | ||||
| 	case controlMsg[packetLocalDiscovery]: | ||||
| 		s.OnLocalDiscovery(msg) | ||||
| 		return s | ||||
| 	case pingTimerMsg: | ||||
| 		return s.OnPingTimer() | ||||
| 	default: | ||||
| 		// TODO: Log | ||||
| 		return s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { | ||||
| 	s.ack = msg.Packet | ||||
| 	s.stateClientDirect.OnAck(msg) | ||||
|  | ||||
| 	// TODO: Send probes now. | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { | ||||
| @@ -330,6 +389,7 @@ func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscover | ||||
| 	// The source port will be the multicast port, so we'll have to | ||||
| 	// construct the correct address using the peer's listed port. | ||||
| 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| 	// TODO:  s.sendProbeTo(s.localDiscoveryAddr) | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed) OnPingTimer() peerState { | ||||
|   | ||||
| @@ -34,10 +34,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||
| 		sendControlPacket: func(rp remotePeer, pkt marshaller) { | ||||
| 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||
| 		}, | ||||
| 		localIP:  2, | ||||
| 		remoteIP: 3, | ||||
| 		privKey:  keys.PrivKey, | ||||
| 		pubAddrs: newPubAddrStore(netip.AddrPort{}), | ||||
| 		pingTimer: time.NewTicker(pingInterval), | ||||
| 		localIP:   2, | ||||
| 		remoteIP:  3, | ||||
| 		privKey:   keys.PrivKey, | ||||
| 		pubAddrs:  newPubAddrStore(netip.AddrPort{}), | ||||
| 		limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 			FillPeriod:   20 * time.Millisecond, | ||||
| 			MaxWaitCount: 1, | ||||
| @@ -49,27 +50,19 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | ||||
| 	if s := h.State.OnPeerUpdate(p); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| 	h.State = h.State.OnMsg(peerUpdateMsg{p}) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | ||||
| 	if s := h.State.OnSyn(msg); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| 	h.State = h.State.OnMsg(msg) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { | ||||
| 	if s := h.State.OnProbe(msg); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| 	h.State = h.State.OnMsg(msg) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnPingTimer() { | ||||
| 	if s := h.State.OnPingTimer(); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| 	h.State = h.State.OnMsg(pingTimerMsg{}) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { | ||||
| @@ -202,7 +195,7 @@ func TestStateServer_directSyn(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnSyn(synMsg) | ||||
| 	h.State = h.State.OnMsg(synMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	ack := assertType[packetAck](t, h.Sent[0].Packet) | ||||
| @@ -233,7 +226,7 @@ func TestStateServer_relayedSyn(t *testing.T) { | ||||
| 	synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) | ||||
| 	synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) | ||||
|  | ||||
| 	h.State.OnSyn(synMsg) | ||||
| 	h.State = h.State.OnMsg(synMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 3) | ||||
|  | ||||
| @@ -261,7 +254,7 @@ func TestStateServer_onProbe(t *testing.T) { | ||||
| 		Packet:  packetProbe{TraceID: newTraceID()}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnProbe(probeMsg) | ||||
| 	h.State = h.State.OnMsg(probeMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
|  | ||||
| @@ -285,7 +278,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnSyn(synMsg) | ||||
| 	h.State = h.State.OnMsg(synMsg) | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| @@ -314,7 +307,7 @@ func TestStateClientDirect_OnAck(t *testing.T) { | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| } | ||||
|  | ||||
| @@ -331,7 +324,7 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID + 1}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| @@ -366,7 +359,7 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	state := assertType[*stateClientDirect](t, h.State) | ||||
| @@ -395,7 +388,7 @@ func TestStateClientRelayed_OnAck(t *testing.T) { | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| } | ||||
|  | ||||
| @@ -429,11 +422,11 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | ||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||
|  | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
|  | ||||
| 	// Add a local discovery address. Note that the port will be configured port | ||||
| 	// and no the one provided here. | ||||
| 	h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ | ||||
| 	h.State = h.State.OnMsg(controlMsg[packetLocalDiscovery]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | ||||
| 	}) | ||||
| @@ -462,7 +455,7 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	state := assertType[*stateClientRelayed](t, h.State) | ||||
| @@ -499,7 +492,7 @@ func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | ||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||
|  | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.State = h.State.OnMsg(ack) | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	probe := assertType[packetProbe](t, h.Sent[2].Packet) | ||||
|   | ||||
| @@ -1,8 +1,6 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| @@ -44,6 +42,7 @@ func newSupervisor( | ||||
| 		state := &pState{ | ||||
| 			publish:           s.publish, | ||||
| 			sendControlPacket: s.send, | ||||
| 			pingTimer:         time.NewTicker(timeoutInterval), | ||||
| 			localIP:           routes.LocalIP, | ||||
| 			remoteIP:          byte(i), | ||||
| 			privKey:           privKey, | ||||
| @@ -55,7 +54,7 @@ func newSupervisor( | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| 		s.peers[i] = newPeerSuper(state) | ||||
| 		s.peers[i] = newPeerSuper(state, state.pingTimer) | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| @@ -105,7 +104,7 @@ func (s *supervisor) ensureRelay() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// TODO: Random selection? | ||||
| 	// TODO: Random selection? Something else? | ||||
| 	for _, peer := range s.staged.Peers { | ||||
| 		if peer.Up && peer.Direct && peer.Relay { | ||||
| 			s.staged.RelayIP = peer.IP | ||||
| @@ -117,14 +116,16 @@ func (s *supervisor) ensureRelay() { | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type peerSuper struct { | ||||
| 	messages chan any | ||||
| 	state    peerState | ||||
| 	messages  chan any | ||||
| 	state     peerState | ||||
| 	pingTimer *time.Ticker | ||||
| } | ||||
|  | ||||
| func newPeerSuper(state *pState) *peerSuper { | ||||
| func newPeerSuper(state *pState, pingTimer *time.Ticker) *peerSuper { | ||||
| 	return &peerSuper{ | ||||
| 		messages: make(chan any, 8), | ||||
| 		state:    state.OnPeerUpdate(nil), | ||||
| 		messages:  make(chan any, 8), | ||||
| 		state:     state.OnPeerUpdate(nil), | ||||
| 		pingTimer: pingTimer, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -136,37 +137,12 @@ func (s *peerSuper) HandleControlMsg(msg any) { | ||||
| } | ||||
|  | ||||
| func (s *peerSuper) Run() { | ||||
| 	go func() { | ||||
| 		// Randomize ping timers. | ||||
| 		time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) | ||||
| 		for range time.Tick(pingInterval) { | ||||
| 			s.messages <- pingTimerMsg{} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	for rawMsg := range s.messages { | ||||
| 		switch msg := rawMsg.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			s.state = s.state.OnPeerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[packetSyn]: | ||||
| 			s.state = s.state.OnSyn(msg) | ||||
|  | ||||
| 		case controlMsg[packetAck]: | ||||
| 			s.state.OnAck(msg) | ||||
|  | ||||
| 		case controlMsg[packetProbe]: | ||||
| 			s.state = s.state.OnProbe(msg) | ||||
|  | ||||
| 		case controlMsg[packetLocalDiscovery]: | ||||
| 			s.state.OnLocalDiscovery(msg) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			s.state = s.state.OnPingTimer() | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-s.pingTimer.C: | ||||
| 			s.state = s.state.OnMsg(pingTimerMsg{}) | ||||
| 		case raw := <-s.messages: | ||||
| 			s.state = s.state.OnMsg(raw) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user