refactor-for-testability #3
| @@ -1,40 +1,44 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type connReader struct { | ||||
| 	conn    udpReader | ||||
| 	iface   ifWriter | ||||
| 	sender  encryptedPacketSender | ||||
| 	super   controlMsgHandler | ||||
| 	// Input | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | ||||
|  | ||||
| 	// Output | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	iface              io.Writer | ||||
| 	handleControlMsg   func(fromIP byte, pkt any) | ||||
|  | ||||
| 	localIP byte | ||||
| 	peers   [256]*atomic.Pointer[remotePeer] | ||||
| 	rt      *atomic.Pointer[routingTable] | ||||
|  | ||||
| 	buf    []byte | ||||
| 	decBuf []byte | ||||
| } | ||||
|  | ||||
| func newConnReader( | ||||
| 	conn udpReader, | ||||
| 	ifWriter ifWriter, | ||||
| 	sender encryptedPacketSender, | ||||
| 	super controlMsgHandler, | ||||
| 	localIP byte, | ||||
| 	peers [256]*atomic.Pointer[remotePeer], | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	iface io.Writer, | ||||
| 	handleControlMsg func(fromIP byte, pkt any), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| ) *connReader { | ||||
| 	return &connReader{ | ||||
| 		conn:    conn, | ||||
| 		iface:   ifWriter, | ||||
| 		sender:  sender, | ||||
| 		super:   super, | ||||
| 		localIP: localIP, | ||||
| 		peers:   peers, | ||||
| 		buf:     make([]byte, bufferSize), | ||||
| 		decBuf:  make([]byte, bufferSize), | ||||
| 		readFromUDPAddrPort: readFromUDPAddrPort, | ||||
| 		writeToUDPAddrPort:  writeToUDPAddrPort, | ||||
| 		iface:               iface, | ||||
| 		handleControlMsg:    handleControlMsg, | ||||
| 		localIP:             rt.Load().LocalIP, | ||||
| 		rt:                  rt, | ||||
| 		buf:                 newBuf(), | ||||
| 		decBuf:              newBuf(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -44,13 +48,11 @@ func (r *connReader) Run() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *connReader) logf(s string, args ...any) { | ||||
| 	log.Printf("[ConnReader] "+s, args...) | ||||
| } | ||||
|  | ||||
| func (r *connReader) handleNextPacket() { | ||||
| 	buf := r.buf[:bufferSize] | ||||
| 	n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf) | ||||
| 	log.Printf("Getting next packet...") | ||||
| 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||
| 	log.Printf("Packet from %v...", remoteAddr) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||
| 	} | ||||
| @@ -64,23 +66,22 @@ func (r *connReader) handleNextPacket() { | ||||
| 	buf = buf[:n] | ||||
| 	h := parseHeader(buf) | ||||
|  | ||||
| 	peer := r.peers[h.SourceIP].Load() | ||||
| 	rt := r.rt.Load() | ||||
| 	peer := rt.Peers[h.SourceIP] | ||||
|  | ||||
| 	switch h.StreamID { | ||||
| 	case controlStreamID: | ||||
| 		r.handleControlPacket(peer, remoteAddr, h, buf) | ||||
|  | ||||
| 		r.handleControlPacket(remoteAddr, peer, h, buf) | ||||
| 	case dataStreamID: | ||||
| 		r.handleDataPacket(peer, h, buf) | ||||
|  | ||||
| 		r.handleDataPacket(rt, peer, h, buf) | ||||
| 	default: | ||||
| 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *connReader) handleControlPacket( | ||||
| 	peer *remotePeer, | ||||
| 	addr netip.AddrPort, | ||||
| 	remoteAddr netip.AddrPort, | ||||
| 	peer remotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| @@ -93,22 +94,27 @@ func (r *connReader) handleControlPacket( | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	msg, err := decryptControlPacket(peer, addr, h, enc, r.decBuf) | ||||
| 	msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) | ||||
| 	if err != nil { | ||||
| 		r.logf("Failed to decrypt control packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.super.HandleControlMsg(msg) | ||||
| 	r.handleControlMsg(h.SourceIP, msg) | ||||
| } | ||||
|  | ||||
| func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { | ||||
| func (r *connReader) handleDataPacket( | ||||
| 	rt *routingTable, | ||||
| 	peer remotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| 	if !peer.Up { | ||||
| 		r.logf("Not connected (recv).") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	data, err := decryptDataPacket(peer, h, enc, r.decBuf) | ||||
| 	data, err := peer.DecryptDataPacket(h, enc, r.decBuf) | ||||
| 	if err != nil { | ||||
| 		r.logf("Failed to decrypt data packet: %v", err) | ||||
| 		return | ||||
| @@ -121,11 +127,15 @@ func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	destPeer := r.peers[h.DestIP].Load() | ||||
| 	if !destPeer.Up { | ||||
| 		r.logf("Not connected (relay): %d", destPeer.IP) | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		r.logf("Relay not available.") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.sender.SendEncryptedDataPacket(data, destPeer) | ||||
| 	r.writeToUDPAddrPort(data, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (r *connReader) logf(format string, args ...any) { | ||||
| 	log.Printf("[ConnReader] "+format, args...) | ||||
| } | ||||
|   | ||||
| @@ -1,141 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type ConnReader struct { | ||||
| 	// Input | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | ||||
|  | ||||
| 	// Output | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	iface              io.Writer | ||||
| 	handleControlMsg   func(fromIP byte, pkt any) | ||||
|  | ||||
| 	localIP byte | ||||
| 	rt      *atomic.Pointer[routingTable] | ||||
|  | ||||
| 	buf    []byte | ||||
| 	decBuf []byte | ||||
| } | ||||
|  | ||||
| func NewConnReader( | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	iface io.Writer, | ||||
| 	handleControlMsg func(fromIP byte, pkt any), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| ) *ConnReader { | ||||
| 	return &ConnReader{ | ||||
| 		readFromUDPAddrPort: readFromUDPAddrPort, | ||||
| 		writeToUDPAddrPort:  writeToUDPAddrPort, | ||||
| 		iface:               iface, | ||||
| 		handleControlMsg:    handleControlMsg, | ||||
| 		localIP:             rt.Load().LocalIP, | ||||
| 		rt:                  rt, | ||||
| 		buf:                 newBuf(), | ||||
| 		decBuf:              newBuf(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) Run() { | ||||
| 	for { | ||||
| 		r.handleNextPacket() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleNextPacket() { | ||||
| 	buf := r.buf[:bufferSize] | ||||
| 	log.Printf("Getting next packet...") | ||||
| 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||
| 	log.Printf("Packet from %v...", remoteAddr) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if n < headerSize { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) | ||||
|  | ||||
| 	buf = buf[:n] | ||||
| 	h := parseHeader(buf) | ||||
|  | ||||
| 	rt := r.rt.Load() | ||||
| 	peer := rt.Peers[h.SourceIP] | ||||
|  | ||||
| 	switch h.StreamID { | ||||
| 	case controlStreamID: | ||||
| 		r.handleControlPacket(remoteAddr, peer, h, buf) | ||||
| 	case dataStreamID: | ||||
| 		r.handleDataPacket(rt, peer, h, buf) | ||||
| 	default: | ||||
| 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleControlPacket( | ||||
| 	remoteAddr netip.AddrPort, | ||||
| 	peer remotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| 	if peer.ControlCipher == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.DestIP != r.localIP { | ||||
| 		r.logf("Incorrect destination IP on control packet: %d", h.DestIP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) | ||||
| 	if err != nil { | ||||
| 		r.logf("Failed to decrypt control packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.handleControlMsg(h.SourceIP, msg) | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleDataPacket( | ||||
| 	rt *routingTable, | ||||
| 	peer remotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| 	if !peer.Up { | ||||
| 		r.logf("Not connected (recv).") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	data, err := peer.DecryptDataPacket(h, enc, r.decBuf) | ||||
| 	if err != nil { | ||||
| 		r.logf("Failed to decrypt data packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.DestIP == r.localIP { | ||||
| 		if _, err := r.iface.Write(data); err != nil { | ||||
| 			log.Fatalf("Failed to write to interface: %v", err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		r.logf("Relay not available.") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.writeToUDPAddrPort(data, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) logf(format string, args ...any) { | ||||
| 	log.Printf("[ConnReader] "+format, args...) | ||||
| } | ||||
| @@ -37,7 +37,7 @@ func generateKeys() cryptoKeys { | ||||
| func encryptControlPacket( | ||||
| 	localIP byte, | ||||
| 	peer *remotePeer, | ||||
| 	pkt Marshaller, | ||||
| 	pkt marshaller, | ||||
| 	tmp []byte, | ||||
| 	out []byte, | ||||
| ) []byte { | ||||
|   | ||||
| @@ -13,12 +13,12 @@ func newRoutePairForTesting() (*remotePeer, *remotePeer) { | ||||
| 	keys1 := generateKeys() | ||||
| 	keys2 := generateKeys() | ||||
|  | ||||
| 	r1 := NewRemotePeer(1) | ||||
| 	r1 := newRemotePeer(1) | ||||
| 	r1.PubSignKey = keys1.PubSignKey | ||||
| 	r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) | ||||
| 	r1.DataCipher = newDataCipher() | ||||
|  | ||||
| 	r2 := NewRemotePeer(2) | ||||
| 	r2 := newRemotePeer(2) | ||||
| 	r2.PubSignKey = keys2.PubSignKey | ||||
| 	r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) | ||||
| 	r2.DataCipher = r1.DataCipher | ||||
|   | ||||
| @@ -27,3 +27,7 @@ var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | ||||
| func newBuf() []byte { | ||||
| 	return make([]byte, bufferSize) | ||||
| } | ||||
|  | ||||
| type marshaller interface { | ||||
| 	Marshal([]byte) []byte | ||||
| } | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import ( | ||||
| 	"sync/atomic" | ||||
| ) | ||||
| 
 | ||||
| type IFReader struct { | ||||
| type ifReader struct { | ||||
| 	iface              io.Reader | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	rt                 *atomic.Pointer[routingTable] | ||||
| @@ -15,22 +15,22 @@ type IFReader struct { | ||||
| 	buf2               []byte | ||||
| } | ||||
| 
 | ||||
| func NewIFReader( | ||||
| func newIFReader( | ||||
| 	iface io.Reader, | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| ) *IFReader { | ||||
| 	return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} | ||||
| ) *ifReader { | ||||
| 	return &ifReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} | ||||
| } | ||||
| 
 | ||||
| func (r *IFReader) Run() { | ||||
| func (r *ifReader) Run() { | ||||
| 	packet := newBuf() | ||||
| 	for { | ||||
| 		r.handleNextPacket(packet) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (r *IFReader) handleNextPacket(packet []byte) { | ||||
| func (r *ifReader) handleNextPacket(packet []byte) { | ||||
| 	packet = r.readNextPacket(packet) | ||||
| 	remoteIP, ok := r.parsePacket(packet) | ||||
| 	if !ok { | ||||
| @@ -60,7 +60,7 @@ func (r *IFReader) handleNextPacket(packet []byte) { | ||||
| 	r.writeToUDPAddrPort(enc, relay.DirectAddr) | ||||
| } | ||||
| 
 | ||||
| func (r *IFReader) readNextPacket(buf []byte) []byte { | ||||
| func (r *ifReader) readNextPacket(buf []byte) []byte { | ||||
| 	n, err := r.iface.Read(buf[:cap(buf)]) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from interface: %v", err) | ||||
| @@ -69,7 +69,7 @@ func (r *IFReader) readNextPacket(buf []byte) []byte { | ||||
| 	return buf[:n] | ||||
| } | ||||
| 
 | ||||
| func (r *IFReader) parsePacket(buf []byte) (byte, bool) { | ||||
| func (r *ifReader) parsePacket(buf []byte) (byte, bool) { | ||||
| 	n := len(buf) | ||||
| 	if n == 0 { | ||||
| 		return 0, false | ||||
| @@ -98,6 +98,6 @@ func (r *IFReader) parsePacket(buf []byte) (byte, bool) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (*IFReader) logf(s string, args ...any) { | ||||
| func (*ifReader) logf(s string, args ...any) { | ||||
| 	log.Printf("[IFReader] "+s, args...) | ||||
| } | ||||
| @@ -3,7 +3,6 @@ package peer | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"syscall" | ||||
| @@ -11,45 +10,6 @@ import ( | ||||
| 	"golang.org/x/sys/unix" | ||||
| ) | ||||
|  | ||||
| // Get next packet, returning packet, ip, and possible error. | ||||
| func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { | ||||
| 	var ( | ||||
| 		version byte | ||||
| 		ip      byte | ||||
| 	) | ||||
| 	for { | ||||
| 		n, err := iface.Read(buf[:cap(buf)]) | ||||
| 		if err != nil { | ||||
| 			return nil, ip, err | ||||
| 		} | ||||
|  | ||||
| 		buf = buf[:n] | ||||
| 		version = buf[0] >> 4 | ||||
|  | ||||
| 		switch version { | ||||
| 		case 4: | ||||
| 			if n < 20 { | ||||
| 				log.Printf("Short IPv4 packet: %d", len(buf)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip = buf[19] | ||||
|  | ||||
| 		case 6: | ||||
| 			if len(buf) < 40 { | ||||
| 				log.Printf("Short IPv6 packet: %d", len(buf)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip = buf[39] | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("Invalid IP packet version: %v", version) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		return buf, ip, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { | ||||
| 	if len(network) != 4 { | ||||
| 		return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) | ||||
|   | ||||
| @@ -1,49 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| ) | ||||
|  | ||||
| type UDPConn interface { | ||||
| 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||
| 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||
| 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||
| } | ||||
|  | ||||
| type ifWriter io.Writer | ||||
|  | ||||
| type udpReader interface { | ||||
| 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||
| } | ||||
|  | ||||
| type udpWriter interface { | ||||
| 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||
| } | ||||
|  | ||||
| type mcUDPWriter interface { | ||||
| 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||
| } | ||||
|  | ||||
| type Marshaller interface { | ||||
| 	Marshal([]byte) []byte | ||||
| } | ||||
|  | ||||
| type dataPacketSender interface { | ||||
| 	SendDataPacket(pkt []byte, peer *remotePeer) | ||||
| 	RelayDataPacket(pkt []byte, peer, relay *remotePeer) | ||||
| } | ||||
|  | ||||
| type controlPacketSender interface { | ||||
| 	SendControlPacket(pkt Marshaller, peer *remotePeer) | ||||
| 	RelayControlPacket(pkt Marshaller, peer, relay *remotePeer) | ||||
| } | ||||
|  | ||||
| type encryptedPacketSender interface { | ||||
| 	SendEncryptedDataPacket(pkt []byte, peer *remotePeer) | ||||
| } | ||||
|  | ||||
| type controlMsgHandler interface { | ||||
| 	HandleControlMsg(pkt any) | ||||
| } | ||||
| @@ -6,7 +6,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func Main() { | ||||
| 	conf := Config{} | ||||
| 	conf := peerConfig{} | ||||
|  | ||||
| 	flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") | ||||
| 	flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") | ||||
| @@ -18,6 +18,6 @@ func Main() { | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	peer := New(conf) | ||||
| 	peer := newPeerMain(conf) | ||||
| 	peer.Run() | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,6 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| type mcReader struct { | ||||
| 	conn  udpReader | ||||
| 	super controlMsgHandler | ||||
| @@ -55,3 +51,4 @@ func (r *mcReader) handleNextPacket() { | ||||
| 		SrcAddr: remoteAddr, | ||||
| 	}) | ||||
| } | ||||
| */ | ||||
|   | ||||
| @@ -1,8 +1,6 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
|  | ||||
| 	"golang.org/x/crypto/nacl/sign" | ||||
| ) | ||||
|  | ||||
| @@ -34,7 +32,9 @@ func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| /* | ||||
| type mcWriter struct { | ||||
|  | ||||
| 	conn            mcUDPWriter | ||||
| 	discoveryPacket []byte | ||||
| } | ||||
| @@ -50,4 +50,4 @@ func (w *mcWriter) SendLocalDiscovery() { | ||||
| 	if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { | ||||
| 		log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) | ||||
| 	} | ||||
| } | ||||
|   }*/ | ||||
|   | ||||
| @@ -1,11 +1,6 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // Testing that we can create and verify a local discovery packet. | ||||
| @@ -100,3 +95,4 @@ func TestMCWriter_SendLocalDiscovery(t *testing.T) { | ||||
| 		t.Fatal("Verification should succeed.") | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|   | ||||
							
								
								
									
										24
									
								
								peer/peer.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								peer/peer.go
									
									
									
									
									
								
							| @@ -15,21 +15,21 @@ import ( | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type Peer struct { | ||||
| 	ifReader   *IFReader | ||||
| 	connReader *ConnReader | ||||
| type peerMain struct { | ||||
| 	ifReader   *ifReader | ||||
| 	connReader *connReader | ||||
| 	iface      io.Writer | ||||
| 	hubPoller  *hubPoller | ||||
| 	super      *Super | ||||
| 	super      *supervisor | ||||
| } | ||||
|  | ||||
| type Config struct { | ||||
| type peerConfig struct { | ||||
| 	NetName    string | ||||
| 	HubAddress string | ||||
| 	APIKey     string | ||||
| } | ||||
|  | ||||
| func New(conf Config) *Peer { | ||||
| func newPeerMain(conf peerConfig) *peerMain { | ||||
| 	config, err := loadPeerConfig(conf.NetName) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to load configuration: %v", err) | ||||
| @@ -83,15 +83,15 @@ func New(conf Config) *Peer { | ||||
| 	rtPtr := &atomic.Pointer[routingTable]{} | ||||
| 	rtPtr.Store(&rt) | ||||
|  | ||||
| 	ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr) | ||||
| 	super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey) | ||||
| 	connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) | ||||
| 	ifReader := newIFReader(iface, writeToUDPAddrPort, rtPtr) | ||||
| 	super := newSupervisor(writeToUDPAddrPort, rtPtr, config.PrivKey) | ||||
| 	connReader := newConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) | ||||
| 	hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to create hub poller: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return &Peer{ | ||||
| 	return &peerMain{ | ||||
| 		iface:      iface, | ||||
| 		ifReader:   ifReader, | ||||
| 		connReader: connReader, | ||||
| @@ -100,14 +100,14 @@ func New(conf Config) *Peer { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *Peer) Run() { | ||||
| func (p *peerMain) Run() { | ||||
| 	go p.ifReader.Run() | ||||
| 	go p.connReader.Run() | ||||
| 	p.super.Start() | ||||
| 	p.hubPoller.Run() | ||||
| } | ||||
|  | ||||
| func initPeerWithHub(conf Config) { | ||||
| func initPeerWithHub(conf peerConfig) { | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	initURL, err := url.Parse(conf.HubAddress) | ||||
|   | ||||
| @@ -14,8 +14,8 @@ type P struct { | ||||
| 	RT         *atomic.Pointer[routingTable] | ||||
| 	Conn       *TestUDPConn | ||||
| 	IFace      *TestIFace | ||||
| 	ConnReader *ConnReader | ||||
| 	IFReader   *IFReader | ||||
| 	ConnReader *connReader | ||||
| 	IFReader   *ifReader | ||||
| } | ||||
|  | ||||
| func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { | ||||
|   | ||||
| @@ -25,7 +25,7 @@ type peerState interface { | ||||
| type pState struct { | ||||
| 	// Output. | ||||
| 	publish           func(remotePeer) | ||||
| 	sendControlPacket func(remotePeer, Marshaller) | ||||
| 	sendControlPacket func(remotePeer, marshaller) | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	localIP   byte | ||||
| @@ -124,7 +124,7 @@ func (s *pState) logf(format string, args ...any) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||
| func (s *pState) SendTo(pkt marshaller, addr netip.AddrPort) { | ||||
| 	if !addr.IsValid() { | ||||
| 		return | ||||
| 	} | ||||
| @@ -134,7 +134,7 @@ func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||
| 	s.Send(route, pkt) | ||||
| } | ||||
|  | ||||
| func (s *pState) Send(peer remotePeer, pkt Marshaller) { | ||||
| func (s *pState) Send(peer remotePeer, pkt marshaller) { | ||||
| 	if err := s.limiter.Limit(); err != nil { | ||||
| 		s.logf("Rate limited.") | ||||
| 		return | ||||
|   | ||||
| @@ -31,7 +31,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||
| 		publish: func(rp remotePeer) { | ||||
| 			h.Published = rp | ||||
| 		}, | ||||
| 		sendControlPacket: func(rp remotePeer, pkt Marshaller) { | ||||
| 		sendControlPacket: func(rp remotePeer, pkt marshaller) { | ||||
| 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||
| 		}, | ||||
| 		localIP:  2, | ||||
|   | ||||
| @@ -11,26 +11,26 @@ import ( | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| type Super struct { | ||||
| type supervisor struct { | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	staged             routingTable | ||||
| 	shared             *atomic.Pointer[routingTable] | ||||
| 	peers              [256]*PeerSuper | ||||
| 	peers              [256]*peerSuper | ||||
| 	lock               sync.Mutex | ||||
|  | ||||
| 	buf1 []byte | ||||
| 	buf2 []byte | ||||
| } | ||||
|  | ||||
| func NewSuper( | ||||
| func newSupervisor( | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| 	privKey []byte, | ||||
| ) *Super { | ||||
| ) *supervisor { | ||||
|  | ||||
| 	routes := rt.Load() | ||||
|  | ||||
| 	s := &Super{ | ||||
| 	s := &supervisor{ | ||||
| 		writeToUDPAddrPort: writeToUDPAddrPort, | ||||
| 		staged:             *routes, | ||||
| 		shared:             rt, | ||||
| @@ -55,23 +55,23 @@ func NewSuper( | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| 		s.peers[i] = NewPeerSuper(state) | ||||
| 		s.peers[i] = newPeerSuper(state) | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *Super) Start() { | ||||
| func (s *supervisor) Start() { | ||||
| 	for i := range s.peers { | ||||
| 		go s.peers[i].Run() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Super) HandleControlMsg(destIP byte, msg any) { | ||||
| func (s *supervisor) HandleControlMsg(destIP byte, msg any) { | ||||
| 	s.peers[destIP].HandleControlMsg(msg) | ||||
| } | ||||
|  | ||||
| func (s *Super) send(peer remotePeer, pkt Marshaller) { | ||||
| func (s *supervisor) send(peer remotePeer, pkt marshaller) { | ||||
| 	s.lock.Lock() | ||||
| 	defer s.lock.Unlock() | ||||
|  | ||||
| @@ -90,7 +90,7 @@ func (s *Super) send(peer remotePeer, pkt Marshaller) { | ||||
| 	s.writeToUDPAddrPort(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (s *Super) publish(rp remotePeer) { | ||||
| func (s *supervisor) publish(rp remotePeer) { | ||||
| 	s.lock.Lock() | ||||
| 	defer s.lock.Unlock() | ||||
|  | ||||
| @@ -100,7 +100,7 @@ func (s *Super) publish(rp remotePeer) { | ||||
| 	s.shared.Store(©) | ||||
| } | ||||
|  | ||||
| func (s *Super) ensureRelay() { | ||||
| func (s *supervisor) ensureRelay() { | ||||
| 	if _, ok := s.staged.GetRelay(); ok { | ||||
| 		return | ||||
| 	} | ||||
| @@ -116,26 +116,26 @@ func (s *Super) ensureRelay() { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PeerSuper struct { | ||||
| type peerSuper struct { | ||||
| 	messages chan any | ||||
| 	state    peerState | ||||
| } | ||||
|  | ||||
| func NewPeerSuper(state *pState) *PeerSuper { | ||||
| 	return &PeerSuper{ | ||||
| func newPeerSuper(state *pState) *peerSuper { | ||||
| 	return &peerSuper{ | ||||
| 		messages: make(chan any, 8), | ||||
| 		state:    state.OnPeerUpdate(nil), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *PeerSuper) HandleControlMsg(msg any) { | ||||
| func (s *peerSuper) HandleControlMsg(msg any) { | ||||
| 	select { | ||||
| 	case s.messages <- msg: | ||||
| 	default: | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *PeerSuper) Run() { | ||||
| func (s *peerSuper) Run() { | ||||
| 	go func() { | ||||
| 		// Randomize ping timers. | ||||
| 		time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| // TODO: Remove | ||||
| func NewRemotePeer(ip byte) *remotePeer { | ||||
| func newRemotePeer(ip byte) *remotePeer { | ||||
| 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||
| 	return &remotePeer{ | ||||
| 		IP:       ip, | ||||
| @@ -58,7 +58,7 @@ func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) | ||||
| } | ||||
|  | ||||
| // Peer must have a ControlCipher. | ||||
| func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | ||||
| func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte { | ||||
| 	tmp = pkt.Marshal(tmp) | ||||
| 	h := header{ | ||||
| 		StreamID: controlStreamID, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user