refactor-for-testability #3
| @@ -24,22 +24,26 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||
| func (store *pubAddrStore) Store(addr netip.AddrPort) { | ||||
| 	if store.localPub { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !addr.IsValid() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if addr.Addr().IsPrivate() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	store.lock.Lock() | ||||
| 	defer store.lock.Unlock() | ||||
|  | ||||
| 	if !add.IsValid() { | ||||
| 		return | ||||
| 	if _, exists := store.lastSeen[addr]; !exists { | ||||
| 		store.addrList = append(store.addrList, addr) | ||||
| 	} | ||||
|  | ||||
| 	if _, exists := store.lastSeen[add]; !exists { | ||||
| 		store.addrList = append(store.addrList, add) | ||||
| 	} | ||||
| 	store.lastSeen[add] = time.Now() | ||||
| 	store.lastSeen[addr] = time.Now() | ||||
| 	store.sort() | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -61,13 +61,13 @@ func (s *stateClientInit2) onInit(msg controlMsg[packetInit]) peerState { | ||||
| 		return s | ||||
| 	} | ||||
| 	s.logf("Got INIT version %d.", msg.Packet.Version) | ||||
| 	return s.nextState() | ||||
| 	return enterStateClient(s.peerData) | ||||
| } | ||||
|  | ||||
| func (s *stateClientInit2) onPing() peerState { | ||||
| 	if time.Since(s.startedAt) > timeoutInterval { | ||||
| 		s.logf("Init timeout. Assuming version 1.") | ||||
| 		return s.nextState() | ||||
| 		return enterStateClient(s.peerData) | ||||
| 	} | ||||
|  | ||||
| 	s.sendInit() | ||||
| @@ -83,11 +83,3 @@ func (s *stateClientInit2) sendInit() { | ||||
| 	} | ||||
| 	s.Send(s.staged, init) | ||||
| } | ||||
|  | ||||
| func (s *stateClientInit2) nextState() peerState { | ||||
| 	if s.staged.Direct { | ||||
| 		return enterStateClientDirect2(s.peerData, s.staged.DirectAddr) | ||||
| 	} | ||||
|  | ||||
| 	return enterStateClientRelayed2(s.peerData) | ||||
| } | ||||
|   | ||||
| @@ -17,10 +17,12 @@ type stateClientRelayed2 struct { | ||||
| 	probes   map[uint64]sentProbe | ||||
| } | ||||
|  | ||||
| func enterStateClientRelayed2(data *peerData) peerState { | ||||
| 	data.staged.Relay = false | ||||
| 	data.staged.Direct = false | ||||
| 	data.staged.DirectAddr = netip.AddrPort{} | ||||
| func enterStateClient(data *peerData) peerState { | ||||
| 	ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) | ||||
|  | ||||
| 	data.staged.Relay = data.peer.Relay && ipValid | ||||
| 	data.staged.Direct = ipValid | ||||
| 	data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) | ||||
| 	data.publish(data.staged) | ||||
|  | ||||
| 	state := &stateClientRelayed2{ | ||||
| @@ -29,7 +31,7 @@ func enterStateClientRelayed2(data *peerData) peerState { | ||||
| 		syn: packetSyn{ | ||||
| 			TraceID:       newTraceID(), | ||||
| 			SharedKey:     data.staged.DataCipher.Key(), | ||||
| 			Direct:        false, | ||||
| 			Direct:        data.staged.Direct, | ||||
| 			PossibleAddrs: data.pubAddrs.Get(), | ||||
| 		}, | ||||
| 		probes: map[uint64]sentProbe{}, | ||||
| @@ -39,7 +41,7 @@ func enterStateClientRelayed2(data *peerData) peerState { | ||||
|  | ||||
| 	data.pingTimer.Reset(pingInterval) | ||||
|  | ||||
| 	state.logf("==> ClientRelayed") | ||||
| 	state.logf("==> Client") | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| @@ -52,22 +54,22 @@ func (s *stateClientRelayed2) OnMsg(raw any) peerState { | ||||
| 	case peerUpdateMsg: | ||||
| 		return initPeerState(s.peerData, msg.Peer) | ||||
| 	case controlMsg[packetAck]: | ||||
| 		return s.onAck(msg) | ||||
| 		s.onAck(msg) | ||||
| 	case controlMsg[packetProbe]: | ||||
| 		return s.onProbe(msg) | ||||
| 	case controlMsg[packetLocalDiscovery]: | ||||
| 		return s.onLocalDiscovery(msg) | ||||
| 		s.onLocalDiscovery(msg) | ||||
| 	case pingTimerMsg: | ||||
| 		return s.onPingTimer() | ||||
| 	default: | ||||
| 		s.logf("Ignoring message: %v", raw) | ||||
| 		return s | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | ||||
| func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) { | ||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | ||||
| 		return s | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	s.lastSeen = time.Now() | ||||
| @@ -78,7 +80,14 @@ func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | ||||
| 		s.logf("Got ACK.") | ||||
| 	} | ||||
|  | ||||
| 	if s.staged.Direct { | ||||
| 		s.pubAddrs.Store(msg.Packet.ToAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Relayed below. | ||||
|  | ||||
| 	s.cleanProbes() | ||||
|  | ||||
| 	for _, addr := range msg.Packet.PossibleAddrs { | ||||
| 		if !addr.IsValid() { | ||||
| @@ -86,10 +95,6 @@ func (s *stateClientRelayed2) onAck(msg controlMsg[packetAck]) peerState { | ||||
| 		} | ||||
| 		s.sendProbeTo(addr) | ||||
| 	} | ||||
|  | ||||
| 	s.cleanProbes() | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed2) onPingTimer() peerState { | ||||
| @@ -105,6 +110,10 @@ func (s *stateClientRelayed2) onPingTimer() peerState { | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||
| 	if s.staged.Direct { | ||||
| 		return s | ||||
| 	} | ||||
|  | ||||
| 	s.cleanProbes() | ||||
|  | ||||
| 	sent, ok := s.probes[msg.Packet.TraceID] | ||||
| @@ -112,16 +121,27 @@ func (s *stateClientRelayed2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||
| 		return s | ||||
| 	} | ||||
|  | ||||
| 	s.staged.Direct = true | ||||
| 	s.staged.DirectAddr = sent.Addr | ||||
| 	s.publish(s.staged) | ||||
|  | ||||
| 	s.syn.TraceID = newTraceID() | ||||
| 	s.syn.Direct = true | ||||
| 	s.Send(s.staged, s.syn) | ||||
|  | ||||
| 	s.logf("Successful probe.") | ||||
| 	return enterStateClientDirect2(s.peerData, sent.Addr) | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) peerState { | ||||
| func (s *stateClientRelayed2) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { | ||||
| 	if s.staged.Direct { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// The source port will be the multicast port, so we'll have to | ||||
| 	// construct the correct address using the peer's listed port. | ||||
| 	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| 	s.sendProbeTo(addr) | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed2) cleanProbes() { | ||||
| @@ -138,5 +158,6 @@ func (s *stateClientRelayed2) sendProbeTo(addr netip.AddrPort) { | ||||
| 		SentAt: time.Now(), | ||||
| 		Addr:   addr, | ||||
| 	} | ||||
| 	s.logf("Probing %v...", addr) | ||||
| 	s.SendTo(probe, addr) | ||||
| } | ||||
|   | ||||
| @@ -104,6 +104,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { | ||||
| 		if !addr.IsValid() { | ||||
| 			break | ||||
| 		} | ||||
| 		s.logf("Probing %v...", addr) | ||||
| 		s.SendTo(packetProbe{TraceID: newTraceID()}, addr) | ||||
| 	} | ||||
|  | ||||
| @@ -112,6 +113,7 @@ func (s *stateServer2) onSyn(msg controlMsg[packetSyn]) peerState { | ||||
|  | ||||
| func (s *stateServer2) onProbe(msg controlMsg[packetProbe]) peerState { | ||||
| 	if msg.SrcAddr.IsValid() { | ||||
| 		s.logf("Probe response %v...", msg.SrcAddr) | ||||
| 		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||
| 	} | ||||
| 	return s | ||||
|   | ||||
		Reference in New Issue
	
	Block a user