refactor-for-testability #3
| @@ -50,9 +50,7 @@ func (r *connReader) Run() { | |||||||
|  |  | ||||||
| func (r *connReader) handleNextPacket() { | func (r *connReader) handleNextPacket() { | ||||||
| 	buf := r.buf[:bufferSize] | 	buf := r.buf[:bufferSize] | ||||||
| 	log.Printf("Getting next packet...") |  | ||||||
| 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||||
| 	log.Printf("Packet from %v...", remoteAddr) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Failed to read from UDP port: %v", err) | 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -51,9 +51,7 @@ func newHubPoller( | |||||||
| } | } | ||||||
|  |  | ||||||
| func (hp *hubPoller) Run() { | func (hp *hubPoller) Run() { | ||||||
| 	log.Printf("Running hub poller...") |  | ||||||
| 	state, err := loadNetworkState(hp.netName) | 	state, err := loadNetworkState(hp.netName) | ||||||
| 	log.Printf("Got state (%s) : %v", hp.netName, state) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("Failed to load network state: %v", err) | 		log.Printf("Failed to load network state: %v", err) | ||||||
| 		log.Printf("Polling hub...") | 		log.Printf("Polling hub...") | ||||||
|   | |||||||
| @@ -1,54 +1,70 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| /* | import ( | ||||||
| type mcReader struct { | 	"log" | ||||||
| 	conn  udpReader | 	"net" | ||||||
| 	super controlMsgHandler | 	"sync/atomic" | ||||||
| 	peers [256]*atomic.Pointer[remotePeer] | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
| 	incoming []byte | func runMCReader( | ||||||
| 	buf      []byte | 	rt *atomic.Pointer[routingTable], | ||||||
| } | 	handleControlMsg func(destIP byte, msg any), | ||||||
|  | ) { | ||||||
| func newMCReader( |  | ||||||
| 	conn udpReader, |  | ||||||
| 	super controlMsgHandler, |  | ||||||
| 	peers [256]*atomic.Pointer[remotePeer], |  | ||||||
| ) *mcReader { |  | ||||||
| 	return &mcReader{conn, super, peers, newBuf(), newBuf()} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r *mcReader) Run() { |  | ||||||
| 	for { | 	for { | ||||||
| 		r.handleNextPacket() | 		runMCReader2(rt, handleControlMsg) | ||||||
|  | 		time.Sleep(8 * time.Second) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *mcReader) handleNextPacket() { | func runMCReader2( | ||||||
| 	incoming := r.incoming[:bufferSize] | 	rt *atomic.Pointer[routingTable], | ||||||
| 	n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(incoming) | 	handleControlMsg func(destIP byte, msg any), | ||||||
|  | ) { | ||||||
|  | 	var ( | ||||||
|  | 		raw  = newBuf() | ||||||
|  | 		buf  = newBuf() | ||||||
|  | 		logf = func(s string, args ...any) { | ||||||
|  | 			log.Printf("[MCReader] "+s, args...) | ||||||
|  | 		} | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Failed to read from UDP multicast port: %v", err) | 		logf("Failed to bind to multicast address: %v", err) | ||||||
|  | 		return | ||||||
| 	} | 	} | ||||||
| 	incoming = incoming[:n] |  | ||||||
|  |  | ||||||
| 	h, ok := headerFromLocalDiscoveryPacket(incoming) | 	for { | ||||||
|  | 		conn.SetReadDeadline(time.Now().Add(32 * time.Second)) | ||||||
|  | 		n, remoteAddr, err := conn.ReadFromUDPAddrPort(raw[:bufferSize]) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logf("Failed to read from UDP port): %v", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		raw = raw[:n] | ||||||
|  | 		h, ok := headerFromLocalDiscoveryPacket(raw) | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 		return | 			logf("Failed to open discovery packet?") | ||||||
|  | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	peer := r.peers[h.SourceIP].Load() | 		peer := rt.Load().Peers[h.SourceIP] | ||||||
| 	if peer == nil || peer.PubSignKey == nil { | 		if peer.PubSignKey == nil { | ||||||
| 		return | 			logf("No signing key for peer %d.", h.SourceIP) | ||||||
|  | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	if !verifyLocalDiscoveryPacket(incoming, r.buf, peer.PubSignKey) { | 		if !verifyLocalDiscoveryPacket(raw, buf, peer.PubSignKey) { | ||||||
| 		return | 			logf("Invalid signature from peer: %d", h.SourceIP) | ||||||
|  | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{ | 		msg := controlMsg[packetLocalDiscovery]{ | ||||||
| 			SrcIP:   h.SourceIP, | 			SrcIP:   h.SourceIP, | ||||||
| 			SrcAddr: remoteAddr, | 			SrcAddr: remoteAddr, | ||||||
| 	}) |  | ||||||
| 		} | 		} | ||||||
| */ | 		handleControlMsg(h.SourceIP, msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,6 +1,10 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"golang.org/x/crypto/nacl/sign" | 	"golang.org/x/crypto/nacl/sign" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -32,22 +36,18 @@ func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| /* | func runMCWriter(localIP byte, signingKey []byte) { | ||||||
| type mcWriter struct { | 	discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey) | ||||||
|  |  | ||||||
| 	conn            mcUDPWriter | 	conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) | ||||||
| 	discoveryPacket []byte | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to bind to multicast address: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| func newMCWriter(conn mcUDPWriter, localIP byte, signingKey []byte) *mcWriter { | 	for range time.Tick(16 * time.Second) { | ||||||
| 	return &mcWriter{ | 		_, err := conn.WriteToUDP(discoveryPacket, multicastAddr) | ||||||
| 		conn:            conn, | 		if err != nil { | ||||||
| 		discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), | 			log.Printf("[MCWriter] Failed to write multicast: %v", err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| func (w *mcWriter) SendLocalDiscovery() { |  | ||||||
| 	if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { |  | ||||||
| 		log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) |  | ||||||
| } | } | ||||||
|   }*/ |  | ||||||
|   | |||||||
| @@ -16,6 +16,8 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type peerMain struct { | type peerMain struct { | ||||||
|  | 	conf       localConfig | ||||||
|  | 	rt         *atomic.Pointer[routingTable] | ||||||
| 	ifReader   *ifReader | 	ifReader   *ifReader | ||||||
| 	connReader *connReader | 	connReader *connReader | ||||||
| 	iface      io.Writer | 	iface      io.Writer | ||||||
| @@ -92,6 +94,8 @@ func newPeerMain(conf peerConfig) *peerMain { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &peerMain{ | 	return &peerMain{ | ||||||
|  | 		conf:       config, | ||||||
|  | 		rt:         rtPtr, | ||||||
| 		iface:      iface, | 		iface:      iface, | ||||||
| 		ifReader:   ifReader, | 		ifReader:   ifReader, | ||||||
| 		connReader: connReader, | 		connReader: connReader, | ||||||
| @@ -104,6 +108,8 @@ func (p *peerMain) Run() { | |||||||
| 	go p.ifReader.Run() | 	go p.ifReader.Run() | ||||||
| 	go p.connReader.Run() | 	go p.connReader.Run() | ||||||
| 	p.super.Start() | 	p.super.Start() | ||||||
|  | 	go runMCWriter(p.conf.PeerIP, p.conf.PrivSignKey) | ||||||
|  | 	go runMCReader(p.rt, p.super.HandleControlMsg) | ||||||
| 	p.hubPoller.Run() | 	p.hubPoller.Run() | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -70,7 +70,6 @@ func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | |||||||
|  |  | ||||||
| 	s.staged.IP = peer.PeerIP | 	s.staged.IP = peer.PeerIP | ||||||
| 	s.staged.PubSignKey = peer.PubSignKey | 	s.staged.PubSignKey = peer.PubSignKey | ||||||
| 	log.Printf("New cipher: %x, %x", s.privKey, peer.PubKey) |  | ||||||
| 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | ||||||
| 	s.staged.DataCipher = newDataCipher() | 	s.staged.DataCipher = newDataCipher() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,10 +5,12 @@ import ( | |||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"runtime/debug" | 	"runtime/debug" | ||||||
| 	"sort" | 	"sort" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type pubAddrStore struct { | type pubAddrStore struct { | ||||||
|  | 	lock      sync.Mutex | ||||||
| 	localPub  bool | 	localPub  bool | ||||||
| 	localAddr netip.AddrPort | 	localAddr netip.AddrPort | ||||||
| 	lastSeen  map[netip.AddrPort]time.Time | 	lastSeen  map[netip.AddrPort]time.Time | ||||||
| @@ -25,6 +27,9 @@ func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (store *pubAddrStore) Store(add netip.AddrPort) { | func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||||
|  | 	store.lock.Lock() | ||||||
|  | 	defer store.lock.Unlock() | ||||||
|  |  | ||||||
| 	if store.localPub { | 	if store.localPub { | ||||||
| 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) | 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) | ||||||
| 		return | 		return | ||||||
| @@ -42,6 +47,11 @@ func (store *pubAddrStore) Store(add netip.AddrPort) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | ||||||
|  | 	store.lock.Lock() | ||||||
|  | 	defer store.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	store.clean() | ||||||
|  |  | ||||||
| 	if store.localPub { | 	if store.localPub { | ||||||
| 		addrs[0] = store.localAddr | 		addrs[0] = store.localAddr | ||||||
| 		return | 		return | ||||||
| @@ -51,7 +61,7 @@ func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func (store *pubAddrStore) Clean() { | func (store *pubAddrStore) clean() { | ||||||
| 	if store.localPub { | 	if store.localPub { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -20,7 +20,7 @@ func TestPubAddrStore(t *testing.T) { | |||||||
| 		time.Sleep(time.Millisecond) | 		time.Sleep(time.Millisecond) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.Clean() | 	s.clean() | ||||||
|  |  | ||||||
| 	l2 := s.Get() | 	l2 := s.Get() | ||||||
| 	if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { | 	if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"log" |  | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -68,8 +67,6 @@ func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte | |||||||
| 		DestIP:   p.IP, | 		DestIP:   p.IP, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Printf("Encrypting with header: %#v", h) |  | ||||||
|  |  | ||||||
| 	return p.ControlCipher.Encrypt(h, tmp, out) | 	return p.ControlCipher.Encrypt(h, tmp, out) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user