refactor-for-testability #3
| @@ -24,22 +24,26 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (store *pubAddrStore) Store(add netip.AddrPort) { | func (store *pubAddrStore) Store(addr netip.AddrPort) { | ||||||
| 	if store.localPub { | 	if store.localPub { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if !addr.IsValid() { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if addr.Addr().IsPrivate() { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	store.lock.Lock() | 	store.lock.Lock() | ||||||
| 	defer store.lock.Unlock() | 	defer store.lock.Unlock() | ||||||
|  |  | ||||||
| 	if !add.IsValid() { | 	if _, exists := store.lastSeen[addr]; !exists { | ||||||
| 		return | 		store.addrList = append(store.addrList, addr) | ||||||
| 	} | 	} | ||||||
|  | 	store.lastSeen[addr] = time.Now() | ||||||
| 	if _, exists := store.lastSeen[add]; !exists { |  | ||||||
| 		store.addrList = append(store.addrList, add) |  | ||||||
| 	} |  | ||||||
| 	store.lastSeen[add] = time.Now() |  | ||||||
| 	store.sort() | 	store.sort() | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -61,13 +61,13 @@ func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { | |||||||
| 		return s | 		return s | ||||||
| 	} | 	} | ||||||
| 	s.logf("Got INIT version %d.", msg.Packet.Version) | 	s.logf("Got INIT version %d.", msg.Packet.Version) | ||||||
| 	return s.nextState() | 	return enterStateClient(s.peerData) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientInit2) onPing() peerState { | func (s *stateClientInit2) onPing() peerState { | ||||||
| 	if time.Since(s.startedAt) > timeoutInterval { | 	if time.Since(s.startedAt) > timeoutInterval { | ||||||
| 		s.logf("Init timeout. Assuming version 1.") | 		s.logf("Init timeout. Assuming version 1.") | ||||||
| 		return s.nextState() | 		return enterStateClient(s.peerData) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.sendInit() | 	s.sendInit() | ||||||
| @@ -83,11 +83,3 @@ func (s *stateClientInit2) sendInit() { | |||||||
| 	} | 	} | ||||||
| 	s.Send(s.staged, init) | 	s.Send(s.staged, init) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientInit2) nextState() peerState { |  | ||||||
| 	if s.staged.Direct { |  | ||||||
| 		return enterStateClientDirect2(s.peerData, s.staged.DirectAddr) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return enterStateClientRelayed2(s.peerData) |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -17,10 +17,12 @@ type stateClientRelayed2 struct { | |||||||
| 	probes   map[uint64]sentProbe | 	probes   map[uint64]sentProbe | ||||||
| } | } | ||||||
|  |  | ||||||
| func enterStateClientRelayed2(data *peerData) peerState { | func enterStateClient(data *peerData) peerState { | ||||||
| 	data.staged.Relay = false | 	ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) | ||||||
| 	data.staged.Direct = false |  | ||||||
| 	data.staged.DirectAddr = netip.AddrPort{} | 	data.staged.Relay = data.peer.Relay && ipValid | ||||||
|  | 	data.staged.Direct = ipValid | ||||||
|  | 	data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) | ||||||
| 	data.publish(data.staged) | 	data.publish(data.staged) | ||||||
|  |  | ||||||
| 	state := &stateClientRelayed2{ | 	state := &stateClientRelayed2{ | ||||||
| @@ -29,7 +31,7 @@ func enterStateClientRelayed2(data *peerData) peerState { | |||||||
| 		syn: packetSyn{ | 		syn: packetSyn{ | ||||||
| 			TraceID:       newTraceID(), | 			TraceID:       newTraceID(), | ||||||
| 			SharedKey:     data.staged.DataCipher.Key(), | 			SharedKey:     data.staged.DataCipher.Key(), | ||||||
| 			Direct:        false, | 			Direct:        data.staged.Direct, | ||||||
| 			PossibleAddrs: data.pubAddrs.Get(), | 			PossibleAddrs: data.pubAddrs.Get(), | ||||||
| 		}, | 		}, | ||||||
| 		probes: map[uint64]sentProbe{}, | 		probes: map[uint64]sentProbe{}, | ||||||
| @@ -39,7 +41,7 @@ func enterStateClientRelayed2(data *peerData) peerState { | |||||||
|  |  | ||||||
| 	data.pingTimer.Reset(pingInterval) | 	data.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
| 	state.logf("==> ClientRelayed") | 	state.logf("==> Client") | ||||||
| 	return state | 	return state | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -52,22 +54,22 @@ func (s *stateClientRelayed2) OnMsg(raw any) peerState { | |||||||
| 	case peerUpdateMsg: | 	case peerUpdateMsg: | ||||||
| 		return initPeerState(s.peerData, msg.Peer) | 		return initPeerState(s.peerData, msg.Peer) | ||||||
| 	case controlMsg[packetAck]: | 	case controlMsg[packetAck]: | ||||||
| 		return s.onAck(msg) | 		s.onAck(msg) | ||||||
| 	case controlMsg[packetProbe]: | 	case controlMsg[packetProbe]: | ||||||
| 		return s.onProbe(msg) | 		return s.onProbe(msg) | ||||||
| 	case controlMsg[packetLocalDiscovery]: | 	case controlMsg[packetLocalDiscovery]: | ||||||
| 		return s.onLocalDiscovery(msg) | 		s.onLocalDiscovery(msg) | ||||||
| 	case pingTimerMsg: | 	case pingTimerMsg: | ||||||
| 		return s.onPingTimer() | 		return s.onPingTimer() | ||||||
| 	default: | 	default: | ||||||
| 		s.logf("Ignoring message: %v", raw) | 		s.logf("Ignoring message: %v", raw) | ||||||
| 		return s |  | ||||||
| 	} | 	} | ||||||
|  | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) { | ||||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
| 		return s | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.lastSeen = time.Now() | 	s.lastSeen = time.Now() | ||||||
| @@ -78,7 +80,14 @@ func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | |||||||
| 		s.logf("Got ACK.") | 		s.logf("Got ACK.") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if s.staged.Direct { | ||||||
| 		s.pubAddrs.Store(msg.Packet.ToAddr) | 		s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Relayed below. | ||||||
|  |  | ||||||
|  | 	s.cleanProbes() | ||||||
|  |  | ||||||
| 	for _, addr := range msg.Packet.PossibleAddrs { | 	for _, addr := range msg.Packet.PossibleAddrs { | ||||||
| 		if !addr.IsValid() { | 		if !addr.IsValid() { | ||||||
| @@ -86,10 +95,6 @@ func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | |||||||
| 		} | 		} | ||||||
| 		s.sendProbeTo(addr) | 		s.sendProbeTo(addr) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.cleanProbes() |  | ||||||
|  |  | ||||||
| 	return s |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed2) onPingTimer() peerState { | func (s *stateClientRelayed2) onPingTimer() peerState { | ||||||
| @@ -105,6 +110,10 @@ func (s *stateClientRelayed2) onPingTimer() peerState { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { | func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	s.cleanProbes() | 	s.cleanProbes() | ||||||
|  |  | ||||||
| 	sent, ok := s.probes[msg.Packet.TraceID] | 	sent, ok := s.probes[msg.Packet.TraceID] | ||||||
| @@ -112,16 +121,27 @@ func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { | |||||||
| 		return s | 		return s | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	s.staged.Direct = true | ||||||
|  | 	s.staged.DirectAddr = sent.Addr | ||||||
|  | 	s.publish(s.staged) | ||||||
|  |  | ||||||
|  | 	s.syn.TraceID = newTraceID() | ||||||
|  | 	s.syn.Direct = true | ||||||
|  | 	s.Send(s.staged, s.syn) | ||||||
|  |  | ||||||
| 	s.logf("Successful probe.") | 	s.logf("Successful probe.") | ||||||
| 	return enterStateClientDirect2(s.peerData, sent.Addr) | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) peerState { |  | ||||||
| 	// The source port will be the multicast port, so we'll have to | 	// The source port will be the multicast port, so we'll have to | ||||||
| 	// construct the correct address using the peer's listed port. | 	// construct the correct address using the peer's listed port. | ||||||
| 	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | 	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
| 	s.sendProbeTo(addr) | 	s.sendProbeTo(addr) | ||||||
| 	return s |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed2) cleanProbes() { | func (s *stateClientRelayed2) cleanProbes() { | ||||||
| @@ -138,5 +158,6 @@ func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { | |||||||
| 		SentAt: time.Now(), | 		SentAt: time.Now(), | ||||||
| 		Addr:   addr, | 		Addr:   addr, | ||||||
| 	} | 	} | ||||||
|  | 	s.logf("Probing %v...", addr) | ||||||
| 	s.SendTo(probe, addr) | 	s.SendTo(probe, addr) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -104,6 +104,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { | |||||||
| 		if !addr.IsValid() { | 		if !addr.IsValid() { | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
|  | 		s.logf("Probing %v...", addr) | ||||||
| 		s.SendTo(packetProbe{TraceID: newTraceID()}, addr) | 		s.SendTo(packetProbe{TraceID: newTraceID()}, addr) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -112,6 +113,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { | |||||||
|  |  | ||||||
| func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { | func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||||
| 	if msg.SrcAddr.IsValid() { | 	if msg.SrcAddr.IsValid() { | ||||||
|  | 		s.logf("Probe response %v...", msg.SrcAddr) | ||||||
| 		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | 		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||||
| 	} | 	} | ||||||
| 	return s | 	return s | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user