refactor-for-testability #3
| @@ -84,7 +84,7 @@ func (r *connReader) handleControlPacket( | |||||||
| 	enc []byte, | 	enc []byte, | ||||||
| ) { | ) { | ||||||
| 	if peer.ControlCipher == nil { | 	if peer.ControlCipher == nil { | ||||||
| 		log.Printf("No control cipher for peer: %v", h) | 		r.logf("No control cipher for peer: %d", h.SourceIP) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -20,6 +20,8 @@ const ( | |||||||
|  |  | ||||||
| 	pingInterval                  = 8 * time.Second | 	pingInterval                  = 8 * time.Second | ||||||
| 	timeoutInterval               = 30 * time.Second | 	timeoutInterval               = 30 * time.Second | ||||||
|  | 	broadcastInterval             = 16 * time.Second | ||||||
|  | 	broadcastErrorTimeoutInterval = 8 * time.Second | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | ||||||
|   | |||||||
| @@ -50,11 +50,15 @@ func newHubPoller( | |||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (hp *hubPoller) logf(s string, args ...any) { | ||||||
|  | 	log.Printf("[HubPoller] "+s, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (hp *hubPoller) Run() { | func (hp *hubPoller) Run() { | ||||||
| 	state, err := loadNetworkState(hp.netName) | 	state, err := loadNetworkState(hp.netName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("Failed to load network state: %v", err) | 		hp.logf("Failed to load network state: %v", err) | ||||||
| 		log.Printf("Polling hub...") | 		hp.logf("Polling hub...") | ||||||
| 		hp.pollHub() | 		hp.pollHub() | ||||||
| 	} else { | 	} else { | ||||||
| 		hp.applyNetworkState(state) | 		hp.applyNetworkState(state) | ||||||
| @@ -70,25 +74,25 @@ func (hp *hubPoller) pollHub() { | |||||||
|  |  | ||||||
| 	resp, err := hp.client.Do(hp.req) | 	resp, err := hp.client.Do(hp.req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("Failed to fetch peer state: %v", err) | 		hp.logf("Failed to fetch peer state: %v", err) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	body, err := io.ReadAll(resp.Body) | 	body, err := io.ReadAll(resp.Body) | ||||||
| 	_ = resp.Body.Close() | 	_ = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("Failed to read body from hub: %v", err) | 		hp.logf("Failed to read body from hub: %v", err) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := json.Unmarshal(body, &state); err != nil { | 	if err := json.Unmarshal(body, &state); err != nil { | ||||||
| 		log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) | 		hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	hp.applyNetworkState(state) | 	hp.applyNetworkState(state) | ||||||
|  |  | ||||||
| 	if err := storeNetworkState(hp.netName, state); err != nil { | 	if err := storeNetworkState(hp.netName, state); err != nil { | ||||||
| 		log.Printf("Failed to store network state: %v", err) | 		hp.logf("Failed to store network state: %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,13 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import "log" |  | ||||||
|  |  | ||||||
| func logPacket(p []byte, notes string) { |  | ||||||
| 	h := parseHeader(p) |  | ||||||
| 	log.Printf(`Sending: Data: %v | From: %d | To:   %d | %s |  | ||||||
| `, |  | ||||||
| 		h.StreamID == dataStreamID, |  | ||||||
| 		h.SourceIP, |  | ||||||
| 		h.DestIP, |  | ||||||
| 		notes) |  | ||||||
| } |  | ||||||
| @@ -12,12 +12,12 @@ func runMCReader( | |||||||
| 	handleControlMsg func(destIP byte, msg any), | 	handleControlMsg func(destIP byte, msg any), | ||||||
| ) { | ) { | ||||||
| 	for { | 	for { | ||||||
| 		runMCReader2(rt, handleControlMsg) | 		runMCReaderInner(rt, handleControlMsg) | ||||||
| 		time.Sleep(8 * time.Second) | 		time.Sleep(broadcastErrorTimeoutInterval) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func runMCReader2( | func runMCReaderInner( | ||||||
| 	rt *atomic.Pointer[routingTable], | 	rt *atomic.Pointer[routingTable], | ||||||
| 	handleControlMsg func(destIP byte, msg any), | 	handleControlMsg func(destIP byte, msg any), | ||||||
| ) { | ) { | ||||||
|   | |||||||
| @@ -41,10 +41,10 @@ func runMCWriter(localIP byte, signingKey []byte) { | |||||||
|  |  | ||||||
| 	conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) | 	conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Failed to bind to multicast address: %v", err) | 		log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for range time.Tick(8 * time.Second) { | 	for range time.Tick(broadcastInterval) { | ||||||
| 		_, err := conn.WriteToUDP(discoveryPacket, multicastAddr) | 		_, err := conn.WriteToUDP(discoveryPacket, multicastAddr) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Printf("[MCWriter] Failed to write multicast: %v", err) | 			log.Printf("[MCWriter] Failed to write multicast: %v", err) | ||||||
|   | |||||||
							
								
								
									
										24
									
								
								peer/peer.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								peer/peer.go
									
									
									
									
									
								
							| @@ -32,10 +32,14 @@ type peerConfig struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func newPeerMain(conf peerConfig) *peerMain { | func newPeerMain(conf peerConfig) *peerMain { | ||||||
|  | 	logf := func(s string, args ...any) { | ||||||
|  | 		log.Printf("[Main] "+s, args...) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	config, err := loadPeerConfig(conf.NetName) | 	config, err := loadPeerConfig(conf.NetName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("Failed to load configuration: %v", err) | 		logf("Failed to load configuration: %v", err) | ||||||
| 		log.Printf("Initializing...") | 		logf("Initializing...") | ||||||
| 		initPeerWithHub(conf) | 		initPeerWithHub(conf) | ||||||
|  |  | ||||||
| 		config, err = loadPeerConfig(conf.NetName) | 		config, err = loadPeerConfig(conf.NetName) | ||||||
| @@ -54,7 +58,7 @@ func newPeerMain(conf peerConfig) *peerMain { | |||||||
| 		log.Fatalf("Failed to resolve UDP address: %v", err) | 		log.Fatalf("Failed to resolve UDP address: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Printf("Listening on %v...", myAddr) | 	logf("Listening on %v...", myAddr) | ||||||
| 	conn, err := net.ListenUDP("udp", myAddr) | 	conn, err := net.ListenUDP("udp", myAddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Failed to open UDP port: %v", err) | 		log.Fatalf("Failed to open UDP port: %v", err) | ||||||
| @@ -69,15 +73,15 @@ func newPeerMain(conf peerConfig) *peerMain { | |||||||
| 		writeLock.Lock() | 		writeLock.Lock() | ||||||
| 		n, err = conn.WriteToUDPAddrPort(b, addr) | 		n, err = conn.WriteToUDPAddrPort(b, addr) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Printf("Failed to write packet: %v", err) | 			logf("Failed to write packet: %v", err) | ||||||
| 		} | 		} | ||||||
| 		writeLock.Unlock() | 		writeLock.Unlock() | ||||||
| 		return n, err | 		return n, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var localAddr netip.AddrPort | 	var localAddr netip.AddrPort | ||||||
| 	ip, ok := netip.AddrFromSlice(config.PublicIP) | 	ip, localAddrValid := netip.AddrFromSlice(config.PublicIP) | ||||||
| 	if ok { | 	if localAddrValid { | ||||||
| 		localAddr = netip.AddrPortFrom(ip, config.Port) | 		localAddr = netip.AddrPortFrom(ip, config.Port) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -105,12 +109,18 @@ func newPeerMain(conf peerConfig) *peerMain { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (p *peerMain) Run() { | func (p *peerMain) Run() { | ||||||
|  |  | ||||||
| 	go p.ifReader.Run() | 	go p.ifReader.Run() | ||||||
| 	go p.connReader.Run() | 	go p.connReader.Run() | ||||||
| 	p.super.Start() | 	p.super.Start() | ||||||
|  |  | ||||||
|  | 	if !p.rt.Load().LocalAddr.IsValid() { | ||||||
| 		go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) | 		go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) | ||||||
| 		go runMCReader(p.rt, p.super.HandleControlMsg) | 		go runMCReader(p.rt, p.super.HandleControlMsg) | ||||||
| 	p.hubPoller.Run() | 	} | ||||||
|  |  | ||||||
|  | 	go p.hubPoller.Run() | ||||||
|  | 	select {} | ||||||
| } | } | ||||||
|  |  | ||||||
| func initPeerWithHub(conf peerConfig) { | func initPeerWithHub(conf peerConfig) { | ||||||
|   | |||||||
							
								
								
									
										162
									
								
								peer/state-client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								peer/state-client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,162 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type sentProbe struct { | ||||||
|  | 	SentAt time.Time | ||||||
|  | 	Addr   netip.AddrPort | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type stateClient struct { | ||||||
|  | 	*peerData | ||||||
|  | 	lastSeen time.Time | ||||||
|  | 	syn      packetSyn | ||||||
|  | 	probes   map[uint64]sentProbe | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func enterStateClient(data *peerData) peerState { | ||||||
|  | 	ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP) | ||||||
|  |  | ||||||
|  | 	data.staged.Relay = data.peer.Relay && ipValid | ||||||
|  | 	data.staged.Direct = ipValid | ||||||
|  | 	data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port) | ||||||
|  | 	data.publish(data.staged) | ||||||
|  |  | ||||||
|  | 	state := &stateClient{ | ||||||
|  | 		peerData: data, | ||||||
|  | 		lastSeen: time.Now(), | ||||||
|  | 		syn: packetSyn{ | ||||||
|  | 			TraceID:       newTraceID(), | ||||||
|  | 			SharedKey:     data.staged.DataCipher.Key(), | ||||||
|  | 			Direct:        data.staged.Direct, | ||||||
|  | 			PossibleAddrs: data.pubAddrs.Get(), | ||||||
|  | 		}, | ||||||
|  | 		probes: map[uint64]sentProbe{}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	state.Send(state.staged, state.syn) | ||||||
|  |  | ||||||
|  | 	data.pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
|  | 	state.logf("==> Client") | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) logf(str string, args ...any) { | ||||||
|  | 	s.peerData.logf("CLNT | "+str, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) OnMsg(raw any) peerState { | ||||||
|  | 	switch msg := raw.(type) { | ||||||
|  | 	case peerUpdateMsg: | ||||||
|  | 		return initPeerState(s.peerData, msg.Peer) | ||||||
|  | 	case controlMsg[packetAck]: | ||||||
|  | 		s.onAck(msg) | ||||||
|  | 	case controlMsg[packetProbe]: | ||||||
|  | 		return s.onProbe(msg) | ||||||
|  | 	case controlMsg[packetLocalDiscovery]: | ||||||
|  | 		s.onLocalDiscovery(msg) | ||||||
|  | 	case pingTimerMsg: | ||||||
|  | 		return s.onPingTimer() | ||||||
|  | 	default: | ||||||
|  | 		s.logf("Ignoring message: %v", raw) | ||||||
|  | 	} | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) onAck(msg controlMsg[packetAck]) { | ||||||
|  | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  |  | ||||||
|  | 	if !s.staged.Up { | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Got ACK.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Relayed below. | ||||||
|  |  | ||||||
|  | 	s.cleanProbes() | ||||||
|  |  | ||||||
|  | 	for _, addr := range msg.Packet.PossibleAddrs { | ||||||
|  | 		if !addr.IsValid() { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		s.sendProbeTo(addr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) onPingTimer() peerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
|  | 		if s.staged.Up { | ||||||
|  | 			s.logf("Timeout.") | ||||||
|  | 		} | ||||||
|  | 		return initPeerState(s.peerData, s.peer) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Send(s.staged, s.syn) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) onProbe(msg controlMsg[packetProbe]) peerState { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.cleanProbes() | ||||||
|  |  | ||||||
|  | 	sent, ok := s.probes[msg.Packet.TraceID] | ||||||
|  | 	if !ok { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.staged.Direct = true | ||||||
|  | 	s.staged.DirectAddr = sent.Addr | ||||||
|  | 	s.publish(s.staged) | ||||||
|  |  | ||||||
|  | 	s.syn.TraceID = newTraceID() | ||||||
|  | 	s.syn.Direct = true | ||||||
|  | 	s.Send(s.staged, s.syn) | ||||||
|  | 	s.logf("Successful probe.") | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// The source port will be the multicast port, so we'll have to | ||||||
|  | 	// construct the correct address using the peer's listed port. | ||||||
|  | 	addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
|  | 	s.sendProbeTo(addr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) cleanProbes() { | ||||||
|  | 	for key, sent := range s.probes { | ||||||
|  | 		if time.Since(sent.SentAt) > pingInterval { | ||||||
|  | 			delete(s.probes, key) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClient) sendProbeTo(addr netip.AddrPort) { | ||||||
|  | 	probe := packetProbe{TraceID: newTraceID()} | ||||||
|  | 	s.probes[probe.TraceID] = sentProbe{ | ||||||
|  | 		SentAt: time.Now(), | ||||||
|  | 		Addr:   addr, | ||||||
|  | 	} | ||||||
|  | 	s.logf("Probing %v...", addr) | ||||||
|  | 	s.SendTo(probe, addr) | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user