WIP
This commit is contained in:
		
							
								
								
									
										132
									
								
								peer/connreader2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								peer/connreader2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ConnReader struct { | ||||||
|  | 	// Input | ||||||
|  | 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | ||||||
|  |  | ||||||
|  | 	// Output | ||||||
|  | 	iface            io.Writer | ||||||
|  | 	forwardData      func(ip byte, pkt []byte) | ||||||
|  | 	handleControlMsg func(pkt any) | ||||||
|  |  | ||||||
|  | 	localIP byte | ||||||
|  | 	rt      *atomic.Pointer[RoutingTable] | ||||||
|  |  | ||||||
|  | 	buf    []byte | ||||||
|  | 	decBuf []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewConnReader( | ||||||
|  | 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | ||||||
|  | 	iface io.Writer, | ||||||
|  | 	forwardData func(ip byte, pkt []byte), | ||||||
|  | 	handleControlMsg func(pkt any), | ||||||
|  | 	rt *atomic.Pointer[RoutingTable], | ||||||
|  | ) *ConnReader { | ||||||
|  | 	return &ConnReader{ | ||||||
|  | 		readFromUDPAddrPort: readFromUDPAddrPort, | ||||||
|  | 		iface:               iface, | ||||||
|  | 		forwardData:         forwardData, | ||||||
|  | 		handleControlMsg:    handleControlMsg, | ||||||
|  | 		localIP:             rt.Load().LocalIP, | ||||||
|  | 		rt:                  rt, | ||||||
|  | 		buf:                 newBuf(), | ||||||
|  | 		decBuf:              newBuf(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ConnReader) Run() { | ||||||
|  | 	for { | ||||||
|  | 		r.handleNextPacket() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ConnReader) handleNextPacket() { | ||||||
|  | 	buf := r.buf[:bufferSize] | ||||||
|  | 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if n < headerSize { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) | ||||||
|  |  | ||||||
|  | 	buf = buf[:n] | ||||||
|  | 	h := parseHeader(buf) | ||||||
|  |  | ||||||
|  | 	peer := r.rt.Load().Peers[h.SourceIP] | ||||||
|  | 	//peer := rt.Peers[h.SourceIP] | ||||||
|  |  | ||||||
|  | 	switch h.StreamID { | ||||||
|  | 	case controlStreamID: | ||||||
|  | 		r.handleControlPacket(remoteAddr, peer, h, buf) | ||||||
|  | 	case dataStreamID: | ||||||
|  | 		r.handleDataPacket(peer, h, buf) | ||||||
|  | 	default: | ||||||
|  | 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ConnReader) handleControlPacket( | ||||||
|  | 	remoteAddr netip.AddrPort, | ||||||
|  | 	peer RemotePeer, | ||||||
|  | 	h header, | ||||||
|  | 	enc []byte, | ||||||
|  | ) { | ||||||
|  | 	if peer.ControlCipher == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if h.DestIP != r.localIP { | ||||||
|  | 		r.logf("Incorrect destination IP on control packet: %d", h.DestIP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.logf("Failed to decrypt control packet: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r.handleControlMsg(msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ConnReader) handleDataPacket( | ||||||
|  | 	peer RemotePeer, | ||||||
|  | 	h header, | ||||||
|  | 	enc []byte, | ||||||
|  | ) { | ||||||
|  | 	if !peer.Up { | ||||||
|  | 		r.logf("Not connected (recv).") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	data, err := peer.DecryptDataPacket(h, enc, r.decBuf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.logf("Failed to decrypt data packet: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if h.DestIP == r.localIP { | ||||||
|  | 		if _, err := r.iface.Write(data); err != nil { | ||||||
|  | 			log.Fatalf("Failed to write to interface: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r.forwardData(h.DestIP, data) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ConnReader) logf(format string, args ...any) { | ||||||
|  | 	log.Printf("[ConnReader] "+format, args...) | ||||||
|  | } | ||||||
| @@ -109,7 +109,7 @@ func newConnReadeTestHarness() (h connReaderTestHarness) { | |||||||
| func TestConnReader_handleControlPacket(t *testing.T) { | func TestConnReader_handleControlPacket(t *testing.T) { | ||||||
| 	h := newConnReadeTestHarness() | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
| 	pkt := synPacket{TraceID: 1234} | 	pkt := PacketSyn{TraceID: 1234} | ||||||
|  |  | ||||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||||
|  |  | ||||||
| @@ -119,7 +119,7 @@ func TestConnReader_handleControlPacket(t *testing.T) { | |||||||
| 		t.Fatal(h.Super.Messages) | 		t.Fatal(h.Super.Messages) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	msg := h.Super.Messages[0].(controlMsg[synPacket]) | 	msg := h.Super.Messages[0].(controlMsg[PacketSyn]) | ||||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { | 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||||
| 		t.Fatal(msg.Packet) | 		t.Fatal(msg.Packet) | ||||||
| 	} | 	} | ||||||
| @@ -141,7 +141,7 @@ func TestConnReader_handleNextPacket_short(t *testing.T) { | |||||||
| func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | ||||||
| 	h := newConnReadeTestHarness() | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
| 	pkt := synPacket{TraceID: 1234} | 	pkt := PacketSyn{TraceID: 1234} | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | ||||||
| 	var header header | 	var header header | ||||||
| @@ -160,7 +160,7 @@ func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | |||||||
| func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | ||||||
| 	h := newConnReadeTestHarness() | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
| 	pkt := synPacket{TraceID: 1234} | 	pkt := PacketSyn{TraceID: 1234} | ||||||
|  |  | ||||||
| 	//encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | 	//encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | ||||||
| @@ -180,7 +180,7 @@ func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | |||||||
| func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | ||||||
| 	h := newConnReadeTestHarness() | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
| 	pkt := synPacket{TraceID: 1234} | 	pkt := PacketSyn{TraceID: 1234} | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||||
| 	var header header | 	var header header | ||||||
| @@ -199,7 +199,7 @@ func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | |||||||
| func TestConnReader_handleControlPacket_modified(t *testing.T) { | func TestConnReader_handleControlPacket_modified(t *testing.T) { | ||||||
| 	h := newConnReadeTestHarness() | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
| 	pkt := synPacket{TraceID: 1234} | 	pkt := PacketSyn{TraceID: 1234} | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||||
| 	encrypted[len(encrypted)-1]++ | 	encrypted[len(encrypted)-1]++ | ||||||
| @@ -237,10 +237,10 @@ func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { | |||||||
| func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | ||||||
| 	h := newConnReadeTestHarness() | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
| 	pkt := ackPacket{TraceID: 1234} | 	pkt := PacketAck{TraceID: 1234} | ||||||
|  |  | ||||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||||
| 	*h.Remote.Counter = *h.Remote.Counter - 1 | 	*h.Remote.counter = *h.Remote.counter - 1 | ||||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() | 	h.R.handleNextPacket() | ||||||
| @@ -250,7 +250,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | |||||||
| 		t.Fatal(h.Super.Messages) | 		t.Fatal(h.Super.Messages) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	msg := h.Super.Messages[0].(controlMsg[ackPacket]) | 	msg := h.Super.Messages[0].(controlMsg[PacketAck]) | ||||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { | 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||||
| 		t.Fatal(msg.Packet) | 		t.Fatal(msg.Packet) | ||||||
| 	} | 	} | ||||||
| @@ -301,7 +301,7 @@ func TestConnReader_handleDataPacket_duplicate(t *testing.T) { | |||||||
| 	pkt := make([]byte, 123) | 	pkt := make([]byte, 123) | ||||||
|  |  | ||||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||||
| 	*h.Remote.Counter = *h.Remote.Counter - 1 | 	*h.Remote.counter = *h.Remote.counter - 1 | ||||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() | 	h.R.handleNextPacket() | ||||||
|   | |||||||
| @@ -37,13 +37,13 @@ func newConnWriter(conn udpWriter, localIP byte) *connWriter { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Not safe for concurrent use. Should only be called by supervisor. | // Not safe for concurrent use. Should only be called by supervisor. | ||||||
| func (w *connWriter) SendControlPacket(pkt marshaller, peer *RemotePeer) { | func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { | ||||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | ||||||
| 	w.writeTo(enc, peer.DirectAddr) | 	w.writeTo(enc, peer.DirectAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Relay control packet. Peer must not be nil. | // Relay control packet. Peer must not be nil. | ||||||
| func (w *connWriter) RelayControlPacket(pkt marshaller, peer, relay *RemotePeer) { | func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { | ||||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | ||||||
| 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) | 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) | ||||||
| 	w.writeTo(enc, relay.DirectAddr) | 	w.writeTo(enc, relay.DirectAddr) | ||||||
|   | |||||||
							
								
								
									
										109
									
								
								peer/connwriter2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								peer/connwriter2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ConnWriter struct { | ||||||
|  | 	wLock sync.Mutex // Lock around for sending on UDP Conn. | ||||||
|  |  | ||||||
|  | 	// Output. | ||||||
|  | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||||
|  |  | ||||||
|  | 	// Shared state. | ||||||
|  | 	rt *atomic.Pointer[RoutingTable] | ||||||
|  |  | ||||||
|  | 	// For sending control packets. | ||||||
|  | 	cBuf1 []byte | ||||||
|  | 	cBuf2 []byte | ||||||
|  |  | ||||||
|  | 	// For sending data packets. | ||||||
|  | 	dBuf1 []byte | ||||||
|  | 	dBuf2 []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewConnWriter( | ||||||
|  | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||||
|  | 	rt *atomic.Pointer[RoutingTable], | ||||||
|  | ) *ConnWriter { | ||||||
|  | 	return &ConnWriter{ | ||||||
|  | 		writeToUDPAddrPort: writeToUDPAddrPort, | ||||||
|  | 		rt:                 rt, | ||||||
|  | 		cBuf1:              newBuf(), | ||||||
|  | 		cBuf2:              newBuf(), | ||||||
|  | 		dBuf1:              newBuf(), | ||||||
|  | 		dBuf2:              newBuf(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Called by ConnReader to forward already encrypted bytes to another peer. | ||||||
|  | func (w *ConnWriter) Forward(ip byte, pkt []byte) { | ||||||
|  | 	peer := w.rt.Load().Peers[ip] | ||||||
|  | 	if !(peer.Up && peer.Direct) { | ||||||
|  | 		w.logf("Failed to forward to %d.", ip) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	w.writeTo(pkt, peer.DirectAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Called by IFReader to send data. Encryption will be applied, and packet will | ||||||
|  | // be relayed if appropriate. | ||||||
|  | func (w *ConnWriter) WriteData(ip byte, pkt []byte) { | ||||||
|  | 	rt := w.rt.Load() | ||||||
|  | 	peer := rt.Peers[ip] | ||||||
|  | 	if !peer.Up { | ||||||
|  | 		w.logf("Failed to send data to %d.", ip) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) | ||||||
|  |  | ||||||
|  | 	if peer.Direct { | ||||||
|  | 		w.writeTo(enc, peer.DirectAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	relay, ok := rt.GetRelay() | ||||||
|  | 	if !ok { | ||||||
|  | 		w.logf("Failed to send data to %d. No relay.", ip) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) | ||||||
|  | 	w.writeTo(enc, relay.DirectAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Called by Supervisor to send control packets. | ||||||
|  | func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { | ||||||
|  | 	enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) | ||||||
|  |  | ||||||
|  | 	if peer.Direct { | ||||||
|  | 		w.writeTo(enc, peer.DirectAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rt := w.rt.Load() | ||||||
|  | 	relay, ok := rt.GetRelay() | ||||||
|  | 	if !ok { | ||||||
|  | 		w.logf("Failed to send control to %d. No relay.", peer.IP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) | ||||||
|  | 	w.writeTo(enc, relay.DirectAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { | ||||||
|  | 	w.wLock.Lock() | ||||||
|  | 	if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { | ||||||
|  | 		w.logf("Failed to write to UDP port: %v", err) | ||||||
|  | 	} | ||||||
|  | 	w.wLock.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *ConnWriter) logf(s string, args ...any) { | ||||||
|  | 	log.Printf("[ConnWriter] "+s, args...) | ||||||
|  | } | ||||||
							
								
								
									
										145
									
								
								peer/connwriter2_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								peer/connwriter2_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,145 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteData_direct(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.WriteData(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteData_peerNotUp(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  | 	p1.RT.Load().Peers[2].Up = false | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.WriteData(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 0 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteData_relay(t *testing.T) { | ||||||
|  | 	p1, _, p3 := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	p1.RT.Load().Peers[2].Direct = false | ||||||
|  | 	p1.RT.Load().RelayIP = 3 | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.WriteData(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p3.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { | ||||||
|  | 	p1, _, p3 := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	p1.RT.Load().Peers[2].Direct = false | ||||||
|  | 	p1.RT.Load().Peers[3].Up = false | ||||||
|  | 	p1.RT.Load().RelayIP = 3 | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.WriteData(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p3.Conn.Packets() | ||||||
|  | 	if len(packets) != 0 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteControl_direct(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	orig := PacketProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteControl_relay(t *testing.T) { | ||||||
|  | 	p1, _, p3 := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	p1.RT.Load().Peers[2].Direct = false | ||||||
|  | 	p1.RT.Load().RelayIP = 3 | ||||||
|  |  | ||||||
|  | 	orig := PacketProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||||
|  |  | ||||||
|  | 	packets := p3.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { | ||||||
|  | 	p1, _, p3 := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	p1.RT.Load().Peers[2].Direct = false | ||||||
|  | 	p1.RT.Load().Peers[3].Up = false | ||||||
|  | 	p1.RT.Load().RelayIP = 3 | ||||||
|  |  | ||||||
|  | 	orig := PacketProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||||
|  |  | ||||||
|  | 	packets := p3.Conn.Packets() | ||||||
|  | 	if len(packets) != 0 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter__Forward(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.Forward(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter__Forward_notUp(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  | 	p1.RT.Load().Peers[2].Up = false | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.Forward(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 0 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConnWriter__Forward_notDirect(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  | 	p1.RT.Load().Peers[2].Direct = false | ||||||
|  |  | ||||||
|  | 	in := RandPacket() | ||||||
|  | 	p1.ConnWriter.Forward(2, in) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 0 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -17,25 +17,25 @@ type controlMsg[T any] struct { | |||||||
| func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | ||||||
| 	switch buf[0] { | 	switch buf[0] { | ||||||
|  |  | ||||||
| 	case packetTypeSyn: | 	case PacketTypeSyn: | ||||||
| 		packet, err := parseSynPacket(buf) | 		packet, err := ParsePacketSyn(buf) | ||||||
| 		return controlMsg[synPacket]{ | 		return controlMsg[PacketSyn]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case packetTypeAck: | 	case PacketTypeAck: | ||||||
| 		packet, err := parseAckPacket(buf) | 		packet, err := ParsePacketAck(buf) | ||||||
| 		return controlMsg[ackPacket]{ | 		return controlMsg[PacketAck]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case packetTypeProbe: | 	case PacketTypeProbe: | ||||||
| 		packet, err := parseProbePacket(buf) | 		packet, err := ParsePacketProbe(buf) | ||||||
| 		return controlMsg[probePacket]{ | 		return controlMsg[PacketProbe]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
|   | |||||||
| @@ -37,13 +37,13 @@ func generateKeys() cryptoKeys { | |||||||
| func encryptControlPacket( | func encryptControlPacket( | ||||||
| 	localIP byte, | 	localIP byte, | ||||||
| 	peer *RemotePeer, | 	peer *RemotePeer, | ||||||
| 	pkt marshaller, | 	pkt Marshaller, | ||||||
| 	tmp []byte, | 	tmp []byte, | ||||||
| 	out []byte, | 	out []byte, | ||||||
| ) []byte { | ) []byte { | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: controlStreamID, | 		StreamID: controlStreamID, | ||||||
| 		Counter:  atomic.AddUint64(peer.Counter, 1), | 		Counter:  atomic.AddUint64(peer.counter, 1), | ||||||
| 		SourceIP: localIP, | 		SourceIP: localIP, | ||||||
| 		DestIP:   peer.IP, | 		DestIP:   peer.IP, | ||||||
| 	} | 	} | ||||||
| @@ -66,7 +66,7 @@ func decryptControlPacket( | |||||||
| 		return nil, errDecryptionFailed | 		return nil, errDecryptionFailed | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if peer.DupCheck.IsDup(h.Counter) { | 	if peer.dupCheck.IsDup(h.Counter) { | ||||||
| 		return nil, errDuplicateSeqNum | 		return nil, errDuplicateSeqNum | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -89,7 +89,7 @@ func encryptDataPacket( | |||||||
| ) []byte { | ) []byte { | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: dataStreamID, | 		StreamID: dataStreamID, | ||||||
| 		Counter:  atomic.AddUint64(peer.Counter, 1), | 		Counter:  atomic.AddUint64(peer.counter, 1), | ||||||
| 		SourceIP: localIP, | 		SourceIP: localIP, | ||||||
| 		DestIP:   destIP, | 		DestIP:   destIP, | ||||||
| 	} | 	} | ||||||
| @@ -108,7 +108,7 @@ func decryptDataPacket( | |||||||
| 		return nil, errDecryptionFailed | 		return nil, errDecryptionFailed | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if peer.DupCheck.IsDup(h.Counter) { | 	if peer.dupCheck.IsDup(h.Counter) { | ||||||
| 		return nil, errDuplicateSeqNum | 		return nil, errDuplicateSeqNum | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { | |||||||
| 		out    = make([]byte, bufferSize) | 		out    = make([]byte, bufferSize) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	in := synPacket{ | 	in := PacketSyn{ | ||||||
| 		TraceID:   newTraceID(), | 		TraceID:   newTraceID(), | ||||||
| 		SharedKey: r1.DataCipher.Key(), | 		SharedKey: r1.DataCipher.Key(), | ||||||
| 		Direct:    true, | 		Direct:    true, | ||||||
| @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	msg, ok := iMsg.(controlMsg[synPacket]) | 	msg, ok := iMsg.(controlMsg[PacketSyn]) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		t.Fatal(ok) | 		t.Fatal(ok) | ||||||
| 	} | 	} | ||||||
| @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { | |||||||
| 		out    = make([]byte, bufferSize) | 		out    = make([]byte, bufferSize) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	in := synPacket{ | 	in := PacketSyn{ | ||||||
| 		TraceID:   newTraceID(), | 		TraceID:   newTraceID(), | ||||||
| 		SharedKey: r1.DataCipher.Key(), | 		SharedKey: r1.DataCipher.Key(), | ||||||
| 		Direct:    true, | 		Direct:    true, | ||||||
| @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { | |||||||
| 		out    = make([]byte, bufferSize) | 		out    = make([]byte, bufferSize) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	in := synPacket{ | 	in := PacketSyn{ | ||||||
| 		TraceID:   newTraceID(), | 		TraceID:   newTraceID(), | ||||||
| 		SharedKey: r1.DataCipher.Key(), | 		SharedKey: r1.DataCipher.Key(), | ||||||
| 		Direct:    true, | 		Direct:    true, | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								peer/data-flow.dot
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								peer/data-flow.dot
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | digraph d { | ||||||
|  |     ifReader   -> connWriter; | ||||||
|  |     connReader -> ifWriter; | ||||||
|  |     connReader -> connWriter; | ||||||
|  |     connReader -> supervisor; | ||||||
|  |     mcReader   -> supervisor; | ||||||
|  |     supervisor -> connWriter; | ||||||
|  |     supervisor -> mcWriter; | ||||||
|  |     hubPoller  -> supervisor; | ||||||
|  |  | ||||||
|  |     connWriter [shape="box"]; | ||||||
|  |     mcWriter [shape="box"]; | ||||||
|  |     ifWriter [shape="box"]; | ||||||
|  | } | ||||||
							
								
								
									
										90
									
								
								peer/files.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								peer/files.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,90 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"log" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"vppn/m" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type localConfig struct { | ||||||
|  | 	m.PeerConfig | ||||||
|  | 	PubKey      []byte | ||||||
|  | 	PrivKey     []byte | ||||||
|  | 	PubSignKey  []byte | ||||||
|  | 	PrivSignKey []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func configDir(netName string) string { | ||||||
|  | 	d, err := os.UserHomeDir() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to get user home directory: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return filepath.Join(d, ".vppn", netName) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func peerConfigPath(netName string) string { | ||||||
|  | 	return filepath.Join(configDir(netName), "peer-config.json") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func peerStatePath(netName string) string { | ||||||
|  | 	return filepath.Join(configDir(netName), "peer-state.json") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func storeJson(x any, outPath string) error { | ||||||
|  | 	outDir := filepath.Dir(outPath) | ||||||
|  | 	_ = os.MkdirAll(outDir, 0700) | ||||||
|  |  | ||||||
|  | 	tmpPath := outPath + ".tmp" | ||||||
|  | 	buf, err := json.Marshal(x) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	f, err := os.Create(tmpPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := f.Write(buf); err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := f.Sync(); err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := f.Close(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return os.Rename(tmpPath, outPath) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func storePeerConfig(netName string, pc localConfig) error { | ||||||
|  | 	return storeJson(pc, peerConfigPath(netName)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func storeNetworkState(netName string, ps m.NetworkState) error { | ||||||
|  | 	return storeJson(ps, peerStatePath(netName)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func loadJson(dataPath string, ptr any) error { | ||||||
|  | 	data, err := os.ReadFile(dataPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return json.Unmarshal(data, ptr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func loadPeerConfig(netName string) (pc localConfig, err error) { | ||||||
|  | 	return pc, loadJson(peerConfigPath(netName), &pc) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func loadNetworkState(netName string) (ps m.NetworkState, err error) { | ||||||
|  | 	return ps, loadJson(peerStatePath(netName), &ps) | ||||||
|  | } | ||||||
							
								
								
									
										57
									
								
								peer/files_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								peer/files_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestFilePaths(t *testing.T) { | ||||||
|  | 	confDir := configDir("netName") | ||||||
|  | 	if filepath.Base(confDir) != "netName" { | ||||||
|  | 		t.Fatal(confDir) | ||||||
|  | 	} | ||||||
|  | 	if filepath.Base(filepath.Dir(confDir)) != ".vppn" { | ||||||
|  | 		t.Fatal(confDir) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	path := peerConfigPath("netName") | ||||||
|  | 	if path != filepath.Join(confDir, "peer-config.json") { | ||||||
|  | 		t.Fatal(path) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	path = peerStatePath("netName") | ||||||
|  | 	if path != filepath.Join(confDir, "peer-state.json") { | ||||||
|  | 		t.Fatal(path) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStoreLoadJson(t *testing.T) { | ||||||
|  | 	type Object struct { | ||||||
|  | 		Name  string | ||||||
|  | 		Age   int | ||||||
|  | 		Price float64 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tmpDir := t.TempDir() | ||||||
|  | 	outPath := filepath.Join(tmpDir, "object.json") | ||||||
|  |  | ||||||
|  | 	obj := Object{ | ||||||
|  | 		Name:  "Jason", | ||||||
|  | 		Age:   22, | ||||||
|  | 		Price: 123.534, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := storeJson(obj, outPath); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	obj2 := Object{} | ||||||
|  | 	if err := loadJson(outPath, &obj2); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(obj, obj2) { | ||||||
|  | 		t.Fatal(obj, obj2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -3,15 +3,21 @@ package peer | |||||||
| import ( | import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	bufferSize = 1536 | 	bufferSize = 1536 | ||||||
|  |  | ||||||
| 	if_mtu       = 1200 | 	if_mtu       = 1200 | ||||||
| 	if_queue_len = 2048 | 	if_queue_len = 2048 | ||||||
|  |  | ||||||
| 	controlCipherOverhead = 16 | 	controlCipherOverhead = 16 | ||||||
| 	dataCipherOverhead    = 16 | 	dataCipherOverhead    = 16 | ||||||
| 	signOverhead          = 64 | 	signOverhead          = 64 | ||||||
|  |  | ||||||
|  | 	pingInterval    = 8 * time.Second | ||||||
|  | 	timeoutInterval = 30 * time.Second | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | ||||||
|   | |||||||
							
								
								
									
										100
									
								
								peer/hubpoller.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								peer/hubpoller.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"time" | ||||||
|  | 	"vppn/m" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type hubPoller struct { | ||||||
|  | 	client   *http.Client | ||||||
|  | 	req      *http.Request | ||||||
|  | 	versions [256]int64 | ||||||
|  | 	localIP  byte | ||||||
|  | 	netName  string | ||||||
|  | 	super    controlMsgHandler | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { | ||||||
|  | 	u, err := url.Parse(hubURL) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	u.Path = "/peer/fetch-state/" | ||||||
|  |  | ||||||
|  | 	client := &http.Client{Timeout: 8 * time.Second} | ||||||
|  |  | ||||||
|  | 	req := &http.Request{ | ||||||
|  | 		Method: http.MethodGet, | ||||||
|  | 		URL:    u, | ||||||
|  | 		Header: http.Header{}, | ||||||
|  | 	} | ||||||
|  | 	req.SetBasicAuth("", apiKey) | ||||||
|  |  | ||||||
|  | 	return &hubPoller{ | ||||||
|  | 		client:  client, | ||||||
|  | 		req:     req, | ||||||
|  | 		localIP: localIP, | ||||||
|  | 		netName: netName, | ||||||
|  | 		super:   super, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (hp *hubPoller) Run() { | ||||||
|  | 	state, err := loadNetworkState(hp.netName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Printf("Failed to load network state: %v", err) | ||||||
|  | 		log.Printf("Polling hub...") | ||||||
|  | 		hp.pollHub() | ||||||
|  | 	} else { | ||||||
|  | 		hp.applyNetworkState(state) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for range time.Tick(64 * time.Second) { | ||||||
|  | 		hp.pollHub() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (hp *hubPoller) pollHub() { | ||||||
|  | 	var state m.NetworkState | ||||||
|  |  | ||||||
|  | 	resp, err := hp.client.Do(hp.req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Printf("Failed to fetch peer state: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	body, err := io.ReadAll(resp.Body) | ||||||
|  | 	_ = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Printf("Failed to read body from hub: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := json.Unmarshal(body, &state); err != nil { | ||||||
|  | 		log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	hp.applyNetworkState(state) | ||||||
|  |  | ||||||
|  | 	if err := storeNetworkState(hp.netName, state); err != nil { | ||||||
|  | 		log.Printf("Failed to store network state: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||||
|  | 	for i, peer := range state.Peers { | ||||||
|  | 		if i != int(hp.localIP) { | ||||||
|  | 			if peer == nil || peer.Version != hp.versions[i] { | ||||||
|  | 				hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | ||||||
|  | 				if peer != nil { | ||||||
|  | 					hp.versions[i] = peer.Version | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										78
									
								
								peer/ifreader2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								peer/ifreader2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type IFReader struct { | ||||||
|  | 	iface      io.Reader | ||||||
|  | 	connWriter interface { | ||||||
|  | 		WriteData(ip byte, pkt []byte) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewIFReader( | ||||||
|  | 	iface io.Reader, | ||||||
|  | 	connWriter interface { | ||||||
|  | 		WriteData(ip byte, pkt []byte) | ||||||
|  | 	}, | ||||||
|  | ) *IFReader { | ||||||
|  | 	return &IFReader{iface, connWriter} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *IFReader) Run() { | ||||||
|  | 	packet := newBuf() | ||||||
|  | 	for { | ||||||
|  | 		r.handleNextPacket(packet) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *IFReader) handleNextPacket(packet []byte) { | ||||||
|  | 	packet = r.readNextPacket(packet) | ||||||
|  | 	if remoteIP, ok := r.parsePacket(packet); ok { | ||||||
|  | 		r.connWriter.WriteData(remoteIP, packet) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *IFReader) readNextPacket(buf []byte) []byte { | ||||||
|  | 	n, err := r.iface.Read(buf[:cap(buf)]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to read from interface: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return buf[:n] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *IFReader) parsePacket(buf []byte) (byte, bool) { | ||||||
|  | 	n := len(buf) | ||||||
|  | 	if n == 0 { | ||||||
|  | 		return 0, false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	version := buf[0] >> 4 | ||||||
|  |  | ||||||
|  | 	switch version { | ||||||
|  | 	case 4: | ||||||
|  | 		if n < 20 { | ||||||
|  | 			r.logf("Short IPv4 packet: %d", len(buf)) | ||||||
|  | 			return 0, false | ||||||
|  | 		} | ||||||
|  | 		return buf[19], true | ||||||
|  |  | ||||||
|  | 	case 6: | ||||||
|  | 		if len(buf) < 40 { | ||||||
|  | 			r.logf("Short IPv6 packet: %d", len(buf)) | ||||||
|  | 			return 0, false | ||||||
|  | 		} | ||||||
|  | 		return buf[39], true | ||||||
|  |  | ||||||
|  | 	default: | ||||||
|  | 		r.logf("Invalid IP packet version: %v", version) | ||||||
|  | 		return 0, false | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (*IFReader) logf(s string, args ...any) { | ||||||
|  | 	log.Printf("[IFReader] "+s, args...) | ||||||
|  | } | ||||||
							
								
								
									
										83
									
								
								peer/ifreader2_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								peer/ifreader2_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestIFReader_IPv4(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 1234) | ||||||
|  | 	pkt[0] = 4 << 4 | ||||||
|  | 	pkt[19] = 2 // IP. | ||||||
|  |  | ||||||
|  | 	p1.IFace.UserWrite(pkt) | ||||||
|  | 	p1.IFReader.handleNextPacket(newBuf()) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIFReader_IPv6(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 1234) | ||||||
|  | 	pkt[0] = 6 << 4 | ||||||
|  | 	pkt[39] = 2 // IP. | ||||||
|  |  | ||||||
|  | 	p1.IFace.UserWrite(pkt) | ||||||
|  | 	p1.IFReader.handleNextPacket(newBuf()) | ||||||
|  |  | ||||||
|  | 	packets := p2.Conn.Packets() | ||||||
|  | 	if len(packets) != 1 { | ||||||
|  | 		t.Fatal(packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | ||||||
|  | 	r := NewIFReader(nil, nil) | ||||||
|  | 	pkt := make([]byte, 0) | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { | ||||||
|  | 	r := NewIFReader(nil, nil) | ||||||
|  |  | ||||||
|  | 	for i := byte(1); i < 16; i++ { | ||||||
|  | 		if i == 4 || i == 6 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		pkt := make([]byte, 1234) | ||||||
|  | 		pkt[0] = i << 4 | ||||||
|  |  | ||||||
|  | 		if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 			t.Fatal(i, ip, ok) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIFReader_parsePacket_shortIPv4(t *testing.T) { | ||||||
|  | 	r := NewIFReader(nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 19) | ||||||
|  | 	pkt[0] = 4 << 4 | ||||||
|  |  | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIFReader_parsePacket_shortIPv6(t *testing.T) { | ||||||
|  | 	r := NewIFReader(nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 39) | ||||||
|  | 	pkt[0] = 6 << 4 | ||||||
|  |  | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -2,7 +2,6 @@ package peer | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"net" |  | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -34,6 +33,7 @@ func TestIFReader_parsePacket_ipv6(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /* | ||||||
| // Test that empty packets work as expected. | // Test that empty packets work as expected. | ||||||
| func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||||
| @@ -99,7 +99,7 @@ func TestIFReader_readNextpacket(t *testing.T) { | |||||||
| 		t.Fatalf("%s", pkt) | 		t.Fatalf("%s", pkt) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type sentPacket struct { | type sentPacket struct { | ||||||
|   | |||||||
| @@ -1,5 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import "io" |  | ||||||
|  |  | ||||||
| type ifWriter io.Writer |  | ||||||
| @@ -1,10 +1,19 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type UDPConn interface { | ||||||
|  | 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||||
|  | 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||||
|  | 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ifWriter io.Writer | ||||||
|  |  | ||||||
| type udpReader interface { | type udpReader interface { | ||||||
| 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||||
| } | } | ||||||
| @@ -13,7 +22,11 @@ type udpWriter interface { | |||||||
| 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| type marshaller interface { | type mcUDPWriter interface { | ||||||
|  | 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Marshaller interface { | ||||||
| 	Marshal([]byte) []byte | 	Marshal([]byte) []byte | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -22,6 +35,11 @@ type dataPacketSender interface { | |||||||
| 	RelayDataPacket(pkt []byte, peer, relay *RemotePeer) | 	RelayDataPacket(pkt []byte, peer, relay *RemotePeer) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type controlPacketSender interface { | ||||||
|  | 	SendControlPacket(pkt Marshaller, peer *RemotePeer) | ||||||
|  | 	RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) | ||||||
|  | } | ||||||
|  |  | ||||||
| type encryptedPacketSender interface { | type encryptedPacketSender interface { | ||||||
| 	SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) | 	SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) | ||||||
| } | } | ||||||
| @@ -29,7 +47,3 @@ type encryptedPacketSender interface { | |||||||
| type controlMsgHandler interface { | type controlMsgHandler interface { | ||||||
| 	HandleControlMsg(pkt any) | 	HandleControlMsg(pkt any) | ||||||
| } | } | ||||||
|  |  | ||||||
| type mcUDPWriter interface { |  | ||||||
| 	WriteToUDP([]byte, *net.UDPAddr) (int, error) |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r.super.HandleControlMsg(controlMsg[localDiscoveryPacket]{ | 	r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ | ||||||
| 		SrcIP:   h.SourceIP, | 		SrcIP:   h.SourceIP, | ||||||
| 		SrcAddr: remoteAddr, | 		SrcAddr: remoteAddr, | ||||||
| 	}) | 	}) | ||||||
|   | |||||||
							
								
								
									
										138
									
								
								peer/mcreader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								peer/mcreader_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type mcMockConn struct { | ||||||
|  | 	packets chan []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newMCMockConn() *mcMockConn { | ||||||
|  | 	return &mcMockConn{make(chan []byte, 32)} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) { | ||||||
|  | 	c.packets <- bytes.Clone(in) | ||||||
|  | 	return len(in), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { | ||||||
|  | 	buf := <-c.packets | ||||||
|  | 	b = b[:len(buf)] | ||||||
|  | 	copy(b, buf) | ||||||
|  | 	return len(b), netip.AddrPort{}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestMCReader(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	super := &mockControlMsgHandler{} | ||||||
|  | 	conn := newMCMockConn() | ||||||
|  |  | ||||||
|  | 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peer := &RemotePeer{ | ||||||
|  | 		IP:         1, | ||||||
|  | 		Up:         true, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  | 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peers[1].Store(peer) | ||||||
|  |  | ||||||
|  | 	w := newMCWriter(conn, 1, keys.PrivSignKey) | ||||||
|  | 	r := newMCReader(conn, super, peers) | ||||||
|  |  | ||||||
|  | 	w.SendLocalDiscovery() | ||||||
|  | 	r.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(super.Messages) != 1 { | ||||||
|  | 		t.Fatal(super.Messages) | ||||||
|  | 	} | ||||||
|  | 	msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery]) | ||||||
|  | 	if !ok || msg.SrcIP != 1 { | ||||||
|  | 		t.Fatal(ok, msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestMCReader_noHeader(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	super := &mockControlMsgHandler{} | ||||||
|  | 	conn := newMCMockConn() | ||||||
|  |  | ||||||
|  | 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peer := &RemotePeer{ | ||||||
|  | 		IP:         1, | ||||||
|  | 		Up:         true, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  | 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peers[1].Store(peer) | ||||||
|  |  | ||||||
|  | 	r := newMCReader(conn, super, peers) | ||||||
|  | 	conn.WriteToUDP([]byte("0123546789"), nil) | ||||||
|  | 	r.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(super.Messages) != 0 { | ||||||
|  | 		t.Fatal(super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestMCReader_noPeer(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	super := &mockControlMsgHandler{} | ||||||
|  | 	conn := newMCMockConn() | ||||||
|  |  | ||||||
|  | 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peer := &RemotePeer{ | ||||||
|  | 		IP:         1, | ||||||
|  | 		Up:         true, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  | 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peers[2] = &atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peers[1].Store(peer) | ||||||
|  |  | ||||||
|  | 	w := newMCWriter(conn, 2, keys.PrivSignKey) | ||||||
|  | 	r := newMCReader(conn, super, peers) | ||||||
|  |  | ||||||
|  | 	w.SendLocalDiscovery() | ||||||
|  | 	r.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(super.Messages) != 0 { | ||||||
|  | 		t.Fatal(super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestMCReader_badSignature(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	super := &mockControlMsgHandler{} | ||||||
|  | 	conn := newMCMockConn() | ||||||
|  |  | ||||||
|  | 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peer := &RemotePeer{ | ||||||
|  | 		IP:         1, | ||||||
|  | 		Up:         true, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  | 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||||
|  | 	peers[1].Store(peer) | ||||||
|  |  | ||||||
|  | 	w := newMCWriter(conn, 1, keys.PrivSignKey) | ||||||
|  | 	w.SendLocalDiscovery() | ||||||
|  |  | ||||||
|  | 	// Break signing. | ||||||
|  | 	packet := <-conn.packets | ||||||
|  | 	packet[0]++ | ||||||
|  | 	conn.packets <- packet | ||||||
|  |  | ||||||
|  | 	r := newMCReader(conn, super, peers) | ||||||
|  |  | ||||||
|  | 	r.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(super.Messages) != 0 { | ||||||
|  | 		t.Fatal(super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										31
									
								
								peer/mock-iface_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								peer/mock-iface_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "bytes" | ||||||
|  |  | ||||||
|  | type TestIFace struct { | ||||||
|  | 	out *bytes.Buffer // Toward the network. | ||||||
|  | 	in  *bytes.Buffer // From the network | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewTestIFace() *TestIFace { | ||||||
|  | 	return &TestIFace{ | ||||||
|  | 		out: &bytes.Buffer{}, | ||||||
|  | 		in:  &bytes.Buffer{}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (iface *TestIFace) Write(b []byte) (int, error) { | ||||||
|  | 	return iface.in.Write(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (iface *TestIFace) Read(b []byte) (int, error) { | ||||||
|  | 	return iface.out.Read(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (iface *TestIFace) UserWrite(b []byte) (int, error) { | ||||||
|  | 	return iface.out.Write(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (iface *TestIFace) UserRead(b []byte) (int, error) { | ||||||
|  | 	return iface.in.Read(b) | ||||||
|  | } | ||||||
							
								
								
									
										80
									
								
								peer/mock-network_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								peer/mock-network_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type TestPacket struct { | ||||||
|  | 	Addr netip.AddrPort | ||||||
|  | 	Data []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TestNetwork struct { | ||||||
|  | 	lock    sync.Mutex | ||||||
|  | 	packets map[netip.AddrPort]chan TestPacket | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewTestNetwork() *TestNetwork { | ||||||
|  | 	return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn { | ||||||
|  | 	n.lock.Lock() | ||||||
|  | 	defer n.lock.Unlock() | ||||||
|  | 	if _, ok := n.packets[localAddr]; !ok { | ||||||
|  | 		n.packets[localAddr] = make(chan TestPacket, 1024) | ||||||
|  | 	} | ||||||
|  | 	return &TestUDPConn{ | ||||||
|  | 		addr:    localAddr, | ||||||
|  | 		n:       n, | ||||||
|  | 		packets: n.packets[localAddr], | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) { | ||||||
|  | 	n.lock.Lock() | ||||||
|  | 	defer n.lock.Unlock() | ||||||
|  | 	if _, ok := n.packets[to]; !ok { | ||||||
|  | 		n.packets[to] = make(chan TestPacket, 1024) | ||||||
|  | 	} | ||||||
|  | 	n.packets[to] <- TestPacket{ | ||||||
|  | 		Addr: from, | ||||||
|  | 		Data: bytes.Clone(b), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TestUDPConn struct { | ||||||
|  | 	addr    netip.AddrPort | ||||||
|  | 	n       *TestNetwork | ||||||
|  | 	packets chan TestPacket | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { | ||||||
|  | 	c.n.write(b, c.addr, addr) | ||||||
|  | 	return len(b), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { | ||||||
|  | 	return c.WriteToUDPAddrPort(b, addr.AddrPort()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { | ||||||
|  | 	pkt := <-c.packets | ||||||
|  | 	b = b[:len(pkt.Data)] | ||||||
|  | 	copy(b, pkt.Data) | ||||||
|  | 	return len(b), pkt.Addr, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *TestUDPConn) Packets() (out []TestPacket) { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case pkt := <-c.packets: | ||||||
|  | 			out = append(out, pkt) | ||||||
|  | 		default: | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -70,7 +70,7 @@ func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | |||||||
| 	return w.Uint16(addrPort.Port()) | 	return w.Uint16(addrPort.Port()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { | func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter { | ||||||
| 	for _, addrPort := range l { | 	for _, addrPort := range l { | ||||||
| 		w.AddrPort(addrPort) | 		w.AddrPort(addrPort) | ||||||
| 	} | 	} | ||||||
| @@ -178,7 +178,7 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | |||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { | func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader { | ||||||
| 	for i := range x { | 	for i := range x { | ||||||
| 		r.AddrPort(&x[i]) | 		r.AddrPort(&x[i]) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -6,6 +6,26 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func TestBinWriteRead_invalidAddrPort(t *testing.T) { | ||||||
|  | 	addr := netip.AddrPort{} | ||||||
|  | 	buf := make([]byte, 1024) | ||||||
|  | 	buf = newBinWriter(buf). | ||||||
|  | 		AddrPort(addr). | ||||||
|  | 		Build() | ||||||
|  |  | ||||||
|  | 	var addr2 netip.AddrPort | ||||||
|  | 	err := newBinReader(buf). | ||||||
|  | 		AddrPort(&addr2). | ||||||
|  | 		Error() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if addr2.IsValid() { | ||||||
|  | 		t.Fatal(addr, addr2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestBinWriteRead(t *testing.T) { | func TestBinWriteRead(t *testing.T) { | ||||||
| 	buf := make([]byte, 1024) | 	buf := make([]byte, 1024) | ||||||
|  |  | ||||||
| @@ -35,7 +55,7 @@ func TestBinWriteRead(t *testing.T) { | |||||||
| 		Byte(in.Type). | 		Byte(in.Type). | ||||||
| 		Uint64(in.TraceID). | 		Uint64(in.TraceID). | ||||||
| 		AddrPort(in.DestAddr). | 		AddrPort(in.DestAddr). | ||||||
| 		AddrPortArray(in.Addrs). | 		AddrPort8(in.Addrs). | ||||||
| 		Build() | 		Build() | ||||||
|  |  | ||||||
| 	out := Item{} | 	out := Item{} | ||||||
| @@ -44,7 +64,7 @@ func TestBinWriteRead(t *testing.T) { | |||||||
| 		Byte(&out.Type). | 		Byte(&out.Type). | ||||||
| 		Uint64(&out.TraceID). | 		Uint64(&out.TraceID). | ||||||
| 		AddrPort(&out.DestAddr). | 		AddrPort(&out.DestAddr). | ||||||
| 		AddrPortArray(&out.Addrs). | 		AddrPort8(&out.Addrs). | ||||||
| 		Error() | 		Error() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|   | |||||||
| @@ -5,93 +5,70 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	packetTypeSyn = iota + 1 | 	PacketTypeSyn = iota + 1 | ||||||
| 	packetTypeSynAck | 	PacketTypeSynAck | ||||||
| 	packetTypeAck | 	PacketTypeAck | ||||||
| 	packetTypeProbe | 	PacketTypeProbe | ||||||
| 	packetTypeAddrDiscovery | 	PacketTypeAddrDiscovery | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type synPacket struct { | type PacketSyn struct { | ||||||
| 	TraceID uint64 // TraceID to match response w/ request. | 	TraceID uint64 // TraceID to match response w/ request. | ||||||
| 	// TODO: SentAt int64 // Unixmilli. | 	//SentAt        int64    // Unixmilli. | ||||||
|  | 	//SharedKeyType byte     // Currently only 1 is supported for AES. | ||||||
| 	SharedKey     [32]byte // Our shared key. | 	SharedKey     [32]byte // Our shared key. | ||||||
| 	Direct        bool | 	Direct        bool | ||||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p synPacket) Marshal(buf []byte) []byte { | func (p PacketSyn) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(packetTypeSyn). | 		Byte(PacketTypeSyn). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
|  | 		//Int64(p.SentAt). | ||||||
|  | 		//Byte(p.SharedKeyType). | ||||||
| 		SharedKey(p.SharedKey). | 		SharedKey(p.SharedKey). | ||||||
| 		Bool(p.Direct). | 		Bool(p.Direct). | ||||||
| 		AddrPort(p.PossibleAddrs[0]). | 		AddrPort8(p.PossibleAddrs). | ||||||
| 		AddrPort(p.PossibleAddrs[1]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[2]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[3]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[4]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[5]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[6]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[7]). |  | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseSynPacket(buf []byte) (p synPacket, err error) { | func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
|  | 		//Int64(&p.SentAt). | ||||||
|  | 		//Byte(&p.SharedKeyType). | ||||||
| 		SharedKey(&p.SharedKey). | 		SharedKey(&p.SharedKey). | ||||||
| 		Bool(&p.Direct). | 		Bool(&p.Direct). | ||||||
| 		AddrPort(&p.PossibleAddrs[0]). | 		AddrPort8(&p.PossibleAddrs). | ||||||
| 		AddrPort(&p.PossibleAddrs[1]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[2]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[3]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[4]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[5]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[6]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[7]). |  | ||||||
| 		Error() | 		Error() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type ackPacket struct { | type PacketAck struct { | ||||||
| 	TraceID       uint64 | 	TraceID       uint64 | ||||||
| 	ToAddr        netip.AddrPort | 	ToAddr        netip.AddrPort | ||||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p ackPacket) Marshal(buf []byte) []byte { | func (p PacketAck) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(packetTypeAck). | 		Byte(PacketTypeAck). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		AddrPort(p.ToAddr). | 		AddrPort(p.ToAddr). | ||||||
| 		AddrPort(p.PossibleAddrs[0]). | 		AddrPort8(p.PossibleAddrs). | ||||||
| 		AddrPort(p.PossibleAddrs[1]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[2]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[3]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[4]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[5]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[6]). |  | ||||||
| 		AddrPort(p.PossibleAddrs[7]). |  | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseAckPacket(buf []byte) (p ackPacket, err error) { | func ParsePacketAck(buf []byte) (p PacketAck, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		AddrPort(&p.ToAddr). | 		AddrPort(&p.ToAddr). | ||||||
| 		AddrPort(&p.PossibleAddrs[0]). | 		AddrPort8(&p.PossibleAddrs). | ||||||
| 		AddrPort(&p.PossibleAddrs[1]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[2]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[3]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[4]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[5]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[6]). |  | ||||||
| 		AddrPort(&p.PossibleAddrs[7]). |  | ||||||
| 		Error() | 		Error() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| @@ -100,18 +77,18 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) { | |||||||
|  |  | ||||||
| // A probeReqPacket is sent from a client to a server to determine if direct | // A probeReqPacket is sent from a client to a server to determine if direct | ||||||
| // UDP communication can be used. | // UDP communication can be used. | ||||||
| type probePacket struct { | type PacketProbe struct { | ||||||
| 	TraceID uint64 | 	TraceID uint64 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p probePacket) Marshal(buf []byte) []byte { | func (p PacketProbe) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(packetTypeProbe). | 		Byte(PacketTypeProbe). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseProbePacket(buf []byte) (p probePacket, err error) { | func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		Error() | 		Error() | ||||||
| @@ -120,4 +97,4 @@ func parseProbePacket(buf []byte) (p probePacket, err error) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type localDiscoveryPacket struct{} | type PacketLocalDiscovery struct{} | ||||||
|   | |||||||
| @@ -1 +1,66 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestSynPacket(t *testing.T) { | ||||||
|  | 	p := PacketSyn{ | ||||||
|  | 		TraceID: newTraceID(), | ||||||
|  | 		//SentAt:        time.Now().UnixMilli(), | ||||||
|  | 		//SharedKeyType: 1, | ||||||
|  | 		Direct: true, | ||||||
|  | 	} | ||||||
|  | 	rand.Read(p.SharedKey[:]) | ||||||
|  |  | ||||||
|  | 	p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234) | ||||||
|  | 	p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) | ||||||
|  | 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||||
|  |  | ||||||
|  | 	buf := p.Marshal(newBuf()) | ||||||
|  | 	p2, err := ParsePacketSyn(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if !reflect.DeepEqual(p, p2) { | ||||||
|  | 		t.Fatal(p2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestAckPacket(t *testing.T) { | ||||||
|  | 	p := PacketAck{ | ||||||
|  | 		TraceID: newTraceID(), | ||||||
|  | 		ToAddr:  netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100) | ||||||
|  | 	p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) | ||||||
|  | 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||||
|  |  | ||||||
|  | 	buf := p.Marshal(newBuf()) | ||||||
|  | 	p2, err := ParsePacketAck(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if !reflect.DeepEqual(p, p2) { | ||||||
|  | 		t.Fatal(p2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestProbePacket(t *testing.T) { | ||||||
|  | 	p := PacketProbe{ | ||||||
|  | 		TraceID: newTraceID(), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	buf := p.Marshal(newBuf()) | ||||||
|  | 	p2, err := ParsePacketProbe(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if !reflect.DeepEqual(p, p2) { | ||||||
|  | 		t.Fatal(p2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										125
									
								
								peer/peer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								peer/peer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	mrand "math/rand" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // A test peer. | ||||||
|  | type P struct { | ||||||
|  | 	cryptoKeys | ||||||
|  | 	RT         *atomic.Pointer[RoutingTable] | ||||||
|  | 	Conn       *TestUDPConn | ||||||
|  | 	IFace      *TestIFace | ||||||
|  | 	ConnWriter *ConnWriter | ||||||
|  | 	ConnReader *ConnReader | ||||||
|  | 	IFReader   *IFReader | ||||||
|  | 	Super      *Supervisor | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { | ||||||
|  | 	p := P{ | ||||||
|  | 		cryptoKeys: generateKeys(), | ||||||
|  | 		RT:         &atomic.Pointer[RoutingTable]{}, | ||||||
|  | 		IFace:      NewTestIFace(), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rt := NewRoutingTable(ip, addr) | ||||||
|  | 	p.RT.Store(&rt) | ||||||
|  | 	p.Conn = n.NewUDPConn(addr) | ||||||
|  | 	p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) | ||||||
|  | 	p.IFReader = NewIFReader(p.IFace, p.ConnWriter) | ||||||
|  |  | ||||||
|  | 	/* | ||||||
|  | 		   p.ConnReader = NewConnReader( | ||||||
|  | 				p.Conn.ReadFromUDPAddrPort, | ||||||
|  | 				p.IFace, | ||||||
|  | 				p.ConnWriter.Forward, | ||||||
|  | 				p.Super.HandleControlMsg, | ||||||
|  | 				p.RT) | ||||||
|  | 	*/ | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ConnectPeers(p1, p2 *P) { | ||||||
|  | 	rt1 := p1.RT.Load() | ||||||
|  | 	rt2 := p2.RT.Load() | ||||||
|  |  | ||||||
|  | 	ip1 := rt1.LocalIP | ||||||
|  | 	ip2 := rt2.LocalIP | ||||||
|  |  | ||||||
|  | 	rt1.Peers[ip2].Up = true | ||||||
|  | 	rt1.Peers[ip2].Direct = true | ||||||
|  | 	rt1.Peers[ip2].Relay = true | ||||||
|  | 	rt1.Peers[ip2].DirectAddr = rt2.LocalAddr | ||||||
|  | 	rt1.Peers[ip2].PubSignKey = p2.PubSignKey | ||||||
|  | 	rt1.Peers[ip2].ControlCipher = newControlCipher(p1.PrivKey, p2.PubKey) | ||||||
|  | 	rt1.Peers[ip2].DataCipher = newDataCipher() | ||||||
|  |  | ||||||
|  | 	rt2.Peers[ip1].Up = true | ||||||
|  | 	rt2.Peers[ip1].Direct = true | ||||||
|  | 	rt2.Peers[ip1].Relay = true | ||||||
|  | 	rt2.Peers[ip1].DirectAddr = rt1.LocalAddr | ||||||
|  | 	rt2.Peers[ip1].PubSignKey = p1.PubSignKey | ||||||
|  | 	rt2.Peers[ip1].ControlCipher = newControlCipher(p2.PrivKey, p1.PubKey) | ||||||
|  | 	rt2.Peers[ip1].DataCipher = rt1.Peers[ip2].DataCipher | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewPeersForTesting() (p1, p2, p3 P) { | ||||||
|  | 	n := NewTestNetwork() | ||||||
|  |  | ||||||
|  | 	p1 = NewPeerForTesting( | ||||||
|  | 		n, | ||||||
|  | 		1, | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)) | ||||||
|  |  | ||||||
|  | 	p2 = NewPeerForTesting( | ||||||
|  | 		n, | ||||||
|  | 		2, | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 200)) | ||||||
|  |  | ||||||
|  | 	p3 = NewPeerForTesting( | ||||||
|  | 		n, | ||||||
|  | 		3, | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 300)) | ||||||
|  |  | ||||||
|  | 	ConnectPeers(&p1, &p2) | ||||||
|  | 	ConnectPeers(&p1, &p3) | ||||||
|  | 	ConnectPeers(&p2, &p3) | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func RandPacket() []byte { | ||||||
|  | 	n := mrand.Intn(1200) | ||||||
|  | 	b := make([]byte, n) | ||||||
|  | 	rand.Read(b) | ||||||
|  | 	return b | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ModifyPacket(in []byte) []byte { | ||||||
|  | 	x := make([]byte, 1) | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		rand.Read(x) | ||||||
|  | 		out := bytes.Clone(in) | ||||||
|  | 		idx := mrand.Intn(len(out)) | ||||||
|  | 		if out[idx] != x[0] { | ||||||
|  | 			out[idx] = x[0] | ||||||
|  | 			return out | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type UnknownControlPacket struct { | ||||||
|  | 	TraceID uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p UnknownControlPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf).Byte(255).Uint64(p.TraceID).Build() | ||||||
|  | } | ||||||
							
								
								
									
										355
									
								
								peer/peerstates.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										355
									
								
								peer/peerstates.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,355 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 	"vppn/m" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/go/ratelimiter" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type PeerState interface { | ||||||
|  | 	OnPeerUpdate(*m.Peer) PeerState | ||||||
|  | 	OnSyn(controlMsg[PacketSyn]) PeerState | ||||||
|  | 	OnAck(controlMsg[PacketAck]) | ||||||
|  | 	OnProbe(controlMsg[PacketProbe]) PeerState | ||||||
|  | 	OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) | ||||||
|  | 	OnPingTimer() PeerState | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type State struct { | ||||||
|  | 	// Output. | ||||||
|  | 	publish           func(RemotePeer) | ||||||
|  | 	sendControlPacket func(RemotePeer, Marshaller) | ||||||
|  |  | ||||||
|  | 	// Immutable data. | ||||||
|  | 	localIP   byte | ||||||
|  | 	remoteIP  byte | ||||||
|  | 	privKey   []byte | ||||||
|  | 	localAddr netip.AddrPort // If valid, then local peer is publicly accessible. | ||||||
|  |  | ||||||
|  | 	pubAddrs *pubAddrStore | ||||||
|  |  | ||||||
|  | 	// The purpose of this state machine is to manage the RemotePeer object, | ||||||
|  | 	// publishing it as necessary. | ||||||
|  | 	staged RemotePeer // Local copy of shared data. See publish(). | ||||||
|  |  | ||||||
|  | 	// Mutable peer data. | ||||||
|  | 	peer *m.Peer | ||||||
|  |  | ||||||
|  | 	// We rate limit per remote endpoint because if we don't we tend to lose | ||||||
|  | 	// packets. | ||||||
|  | 	limiter *ratelimiter.Limiter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { | ||||||
|  | 	defer func() { | ||||||
|  | 		// Don't defer directly otherwise s.staged will be evaluated immediately | ||||||
|  | 		// and won't reflect changes made in the function. | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	if peer == nil { | ||||||
|  | 		return EnterStateDisconnected(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.peer = peer | ||||||
|  | 	s.staged.Relay = false | ||||||
|  | 	s.staged.Direct = false | ||||||
|  | 	s.staged.DirectAddr = netip.AddrPort{} | ||||||
|  | 	s.staged.PubSignKey = nil | ||||||
|  | 	s.staged.PubSignKey = peer.PubSignKey | ||||||
|  | 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | ||||||
|  | 	s.staged.DataCipher = newDataCipher() | ||||||
|  |  | ||||||
|  | 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||||
|  | 		s.staged.Relay = peer.Relay | ||||||
|  | 		s.staged.Direct = true | ||||||
|  | 		s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) | ||||||
|  |  | ||||||
|  | 		if s.localAddr.IsValid() && s.localIP < s.remoteIP { | ||||||
|  | 			return EnterStateServer(s) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return EnterStateClientDirect(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.localAddr.IsValid() { | ||||||
|  | 		s.staged.Direct = true | ||||||
|  | 		return EnterStateServer(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.localIP < s.remoteIP { | ||||||
|  | 		return EnterStateServer(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return EnterStateClientRelayed(s) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *State) logf(format string, args ...any) { | ||||||
|  | 	b := strings.Builder{} | ||||||
|  | 	name := "--" | ||||||
|  | 	if s.peer != nil { | ||||||
|  | 		name = s.peer.Name | ||||||
|  | 	} | ||||||
|  | 	b.WriteString(fmt.Sprintf("%30s: ", name)) | ||||||
|  |  | ||||||
|  | 	if s.staged.Direct { | ||||||
|  | 		b.WriteString("DIRECT  | ") | ||||||
|  | 	} else { | ||||||
|  | 		b.WriteString("RELAYED | ") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.staged.Up { | ||||||
|  | 		b.WriteString("UP   | ") | ||||||
|  | 	} else { | ||||||
|  | 		b.WriteString("DOWN | ") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Printf(b.String()+format, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||||
|  | 	if !addr.IsValid() { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	route := s.staged | ||||||
|  | 	route.Direct = true | ||||||
|  | 	route.DirectAddr = addr | ||||||
|  | 	s.Send(route, pkt) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *State) Send(peer RemotePeer, pkt Marshaller) { | ||||||
|  | 	if err := s.limiter.Limit(); err != nil { | ||||||
|  | 		s.logf("Rate limited.") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	s.sendControlPacket(peer, pkt) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type StateDisconnected struct{ *State } | ||||||
|  |  | ||||||
|  | func EnterStateDisconnected(s *State) PeerState { | ||||||
|  | 	s.logf("==> Disconnected") | ||||||
|  | 	s.peer = nil | ||||||
|  | 	s.staged.Up = false | ||||||
|  | 	s.staged.Relay = false | ||||||
|  | 	s.staged.Direct = false | ||||||
|  | 	s.staged.DirectAddr = netip.AddrPort{} | ||||||
|  | 	s.staged.PubSignKey = nil | ||||||
|  | 	s.staged.ControlCipher = nil | ||||||
|  | 	s.staged.DataCipher = nil | ||||||
|  | 	s.publish(s.staged) | ||||||
|  | 	return &StateDisconnected{State: s} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState             { return nil } | ||||||
|  | func (s *StateDisconnected) OnAck(controlMsg[PacketAck])                       {} | ||||||
|  | func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState         { return nil } | ||||||
|  | func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} | ||||||
|  | func (s *StateDisconnected) OnPingTimer() PeerState                            { return nil } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type StateServer struct { | ||||||
|  | 	*StateDisconnected | ||||||
|  | 	lastSeen   time.Time | ||||||
|  | 	synTraceID uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func EnterStateServer(s *State) PeerState { | ||||||
|  | 	s.logf("==> Server") | ||||||
|  | 	return &StateServer{StateDisconnected: &StateDisconnected{State: s}} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  | 	p := msg.Packet | ||||||
|  |  | ||||||
|  | 	// Before we can respond to this packet, we need to make sure the | ||||||
|  | 	// route is setup properly. | ||||||
|  | 	// | ||||||
|  | 	// The client will update the syn's TraceID whenever there's a change. | ||||||
|  | 	// The server will follow the client's request. | ||||||
|  | 	if p.TraceID != s.synTraceID || !s.staged.Up { | ||||||
|  | 		s.synTraceID = p.TraceID | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.staged.Direct = p.Direct | ||||||
|  | 		s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) | ||||||
|  | 		s.staged.DirectAddr = msg.SrcAddr | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Got SYN.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Always respond. | ||||||
|  | 	ack := PacketAck{ | ||||||
|  | 		TraceID:       p.TraceID, | ||||||
|  | 		ToAddr:        s.staged.DirectAddr, | ||||||
|  | 		PossibleAddrs: s.pubAddrs.Get(), | ||||||
|  | 	} | ||||||
|  | 	s.Send(s.staged, ack) | ||||||
|  |  | ||||||
|  | 	if p.Direct { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, addr := range msg.Packet.PossibleAddrs { | ||||||
|  | 		if !addr.IsValid() { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { | ||||||
|  | 	if msg.SrcAddr.IsValid() { | ||||||
|  | 		s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateServer) OnPingTimer() PeerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||||
|  | 		s.staged.Up = false | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Timeout.") | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type StateClientDirect struct { | ||||||
|  | 	*StateDisconnected | ||||||
|  | 	lastSeen time.Time | ||||||
|  | 	syn      PacketSyn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func EnterStateClientDirect(s *State) PeerState { | ||||||
|  | 	s.logf("==> ClientDirect") | ||||||
|  | 	return NewStateClientDirect(s) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewStateClientDirect(s *State) *StateClientDirect { | ||||||
|  | 	state := &StateClientDirect{ | ||||||
|  | 		StateDisconnected: &StateDisconnected{s}, | ||||||
|  | 		lastSeen:          time.Now(), // Avoid immediate timeout. | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	state.syn = PacketSyn{ | ||||||
|  | 		TraceID:       newTraceID(), | ||||||
|  | 		SharedKey:     s.staged.DataCipher.Key(), | ||||||
|  | 		Direct:        s.staged.Direct, | ||||||
|  | 		PossibleAddrs: s.pubAddrs.Get(), | ||||||
|  | 	} | ||||||
|  | 	state.Send(s.staged, state.syn) | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { | ||||||
|  | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.lastSeen = time.Now() | ||||||
|  |  | ||||||
|  | 	if !s.staged.Up { | ||||||
|  | 		s.staged.Up = true | ||||||
|  | 		s.publish(s.staged) | ||||||
|  | 		s.logf("Got ACK.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientDirect) OnPingTimer() PeerState { | ||||||
|  | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
|  | 		if s.staged.Up { | ||||||
|  | 			s.staged.Up = false | ||||||
|  | 			s.publish(s.staged) | ||||||
|  | 			s.logf("Timeout.") | ||||||
|  | 		} | ||||||
|  | 		return s.OnPeerUpdate(s.peer) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Send(s.staged, s.syn) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type StateClientRelayed struct { | ||||||
|  | 	*StateClientDirect | ||||||
|  | 	ack                PacketAck | ||||||
|  | 	probes             map[uint64]netip.AddrPort | ||||||
|  | 	localDiscoveryAddr netip.AddrPort | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func EnterStateClientRelayed(s *State) PeerState { | ||||||
|  | 	s.logf("==> ClientRelayed") | ||||||
|  | 	return &StateClientRelayed{ | ||||||
|  | 		StateClientDirect: NewStateClientDirect(s), | ||||||
|  | 		probes:            map[uint64]netip.AddrPort{}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { | ||||||
|  | 	s.ack = msg.Packet | ||||||
|  | 	s.StateClientDirect.OnAck(msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { | ||||||
|  | 	addr, ok := s.probes[msg.Packet.TraceID] | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.staged.DirectAddr = addr | ||||||
|  | 	s.staged.Direct = true | ||||||
|  | 	s.publish(s.staged) | ||||||
|  | 	return EnterStateClientDirect(s.StateClientDirect.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { | ||||||
|  | 	// The source port will be the multicast port, so we'll have to | ||||||
|  | 	// construct the correct address using the peer's listed port. | ||||||
|  | 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientRelayed) OnPingTimer() PeerState { | ||||||
|  | 	if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { | ||||||
|  | 		return nextState | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	clear(s.probes) | ||||||
|  | 	for _, addr := range s.ack.PossibleAddrs { | ||||||
|  | 		if !addr.IsValid() { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		s.sendProbeTo(addr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if s.localDiscoveryAddr.IsValid() { | ||||||
|  | 		s.sendProbeTo(s.localDiscoveryAddr) | ||||||
|  | 		s.localDiscoveryAddr = netip.AddrPort{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { | ||||||
|  | 	probe := PacketProbe{TraceID: newTraceID()} | ||||||
|  | 	s.probes[probe.TraceID] = addr | ||||||
|  | 	s.SendTo(probe, addr) | ||||||
|  | } | ||||||
							
								
								
									
										509
									
								
								peer/peerstates_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										509
									
								
								peer/peerstates_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,509 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 	"vppn/m" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/go/ratelimiter" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type PeerStateControlMsg struct { | ||||||
|  | 	Peer   RemotePeer | ||||||
|  | 	Packet any | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PeerStateTestHarness struct { | ||||||
|  | 	State     PeerState | ||||||
|  | 	Published RemotePeer | ||||||
|  | 	Sent      []PeerStateControlMsg | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||||
|  | 	h := &PeerStateTestHarness{} | ||||||
|  |  | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	state := &State{ | ||||||
|  | 		publish: func(rp RemotePeer) { | ||||||
|  | 			h.Published = rp | ||||||
|  | 		}, | ||||||
|  | 		sendControlPacket: func(rp RemotePeer, pkt Marshaller) { | ||||||
|  | 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||||
|  | 		}, | ||||||
|  | 		localIP:  2, | ||||||
|  | 		remoteIP: 3, | ||||||
|  | 		privKey:  keys.PrivKey, | ||||||
|  | 		pubAddrs: newPubAddrStore(netip.AddrPort{}), | ||||||
|  | 		limiter: ratelimiter.New(ratelimiter.Config{ | ||||||
|  | 			FillPeriod:   20 * time.Millisecond, | ||||||
|  | 			MaxWaitCount: 1, | ||||||
|  | 		}), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.State = EnterStateDisconnected(state) | ||||||
|  | 	return h | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | ||||||
|  | 	if s := h.State.OnPeerUpdate(p); s != nil { | ||||||
|  | 		h.State = s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { | ||||||
|  | 	if s := h.State.OnSyn(msg); s != nil { | ||||||
|  | 		h.State = s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { | ||||||
|  | 	if s := h.State.OnProbe(msg); s != nil { | ||||||
|  | 		h.State = s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) OnPingTimer() { | ||||||
|  | 	if s := h.State.OnPingTimer(); s != nil { | ||||||
|  | 		h.State = s | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	state := h.State.(*StateDisconnected) | ||||||
|  | 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||||
|  |  | ||||||
|  | 	peer := &m.Peer{ | ||||||
|  | 		PeerIP:     3, | ||||||
|  | 		PublicIP:   []byte{1, 1, 1, 3}, | ||||||
|  | 		Port:       456, | ||||||
|  | 		PubKey:     keys.PubKey, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.PeerUpdate(peer) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | 	return assertType[*StateServer](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	peer := &m.Peer{ | ||||||
|  | 		PeerIP:     3, | ||||||
|  | 		Port:       456, | ||||||
|  | 		PubKey:     keys.PubKey, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.PeerUpdate(peer) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | 	return assertType[*StateServer](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	peer := &m.Peer{ | ||||||
|  | 		PeerIP:     3, | ||||||
|  | 		PublicIP:   []byte{1, 2, 3, 4}, | ||||||
|  | 		Port:       456, | ||||||
|  | 		PubKey:     keys.PubKey, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.PeerUpdate(peer) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | 	return assertType[*StateClientDirect](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	state := h.State.(*StateDisconnected) | ||||||
|  | 	state.remoteIP = 1 | ||||||
|  |  | ||||||
|  | 	peer := &m.Peer{ | ||||||
|  | 		PeerIP:     3, | ||||||
|  | 		Port:       456, | ||||||
|  | 		PubKey:     keys.PubKey, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.PeerUpdate(peer) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | 	return assertType[*StateClientRelayed](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.PeerUpdate(nil) | ||||||
|  | 	assertType[*StateDisconnected](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  |  | ||||||
|  | 	state := h.State.(*StateDisconnected) | ||||||
|  | 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||||
|  |  | ||||||
|  | 	peer := &m.Peer{ | ||||||
|  | 		PeerIP:     3, | ||||||
|  | 		Port:       456, | ||||||
|  | 		PubKey:     keys.PubKey, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.PeerUpdate(peer) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | 	assertType[*StateServer](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigServer_Public(t) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerState_OnPeerUpdate_serverRelayed(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigServer_Relayed(t) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientDirect(t) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateServer_directSyn(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigServer_Relayed(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	synMsg := controlMsg[PacketSyn]{ | ||||||
|  | 		SrcIP:   3, | ||||||
|  | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
|  | 		Packet: PacketSyn{ | ||||||
|  | 			TraceID: newTraceID(), | ||||||
|  | 			//SentAt:        time.Now().UnixMilli(), | ||||||
|  | 			//SharedKeyType: 1, | ||||||
|  | 			Direct: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.State.OnSyn(synMsg) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | ||||||
|  | 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||||
|  | 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||||
|  | 	assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateServer_relayedSyn(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	state := h.ConfigServer_Relayed(t) | ||||||
|  |  | ||||||
|  | 	state.pubAddrs.Store(addrPort4(4, 5, 6, 7, 1234)) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	synMsg := controlMsg[PacketSyn]{ | ||||||
|  | 		SrcIP:   3, | ||||||
|  | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
|  | 		Packet: PacketSyn{ | ||||||
|  | 			TraceID: newTraceID(), | ||||||
|  | 			//SentAt:        time.Now().UnixMilli(), | ||||||
|  | 			//SharedKeyType: 1, | ||||||
|  | 			Direct: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) | ||||||
|  | 	synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) | ||||||
|  |  | ||||||
|  | 	h.State.OnSyn(synMsg) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, len(h.Sent), 3) | ||||||
|  |  | ||||||
|  | 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | ||||||
|  | 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||||
|  | 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||||
|  | 	assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) | ||||||
|  | 	assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
|  | 	assertType[PacketProbe](t, h.Sent[1].Packet) | ||||||
|  | 	assertType[PacketProbe](t, h.Sent[2].Packet) | ||||||
|  | 	assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||||
|  | 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateServer_onProbe(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigServer_Relayed(t) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	probeMsg := controlMsg[PacketProbe]{ | ||||||
|  | 		SrcIP:   3, | ||||||
|  | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
|  | 		Packet:  PacketProbe{TraceID: newTraceID()}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.State.OnProbe(probeMsg) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  |  | ||||||
|  | 	probe := assertType[PacketProbe](t, h.Sent[0].Packet) | ||||||
|  | 	assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) | ||||||
|  | 	assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateServer_OnPingTimer_timeout(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigServer_Relayed(t) | ||||||
|  |  | ||||||
|  | 	synMsg := controlMsg[PacketSyn]{ | ||||||
|  | 		SrcIP:   3, | ||||||
|  | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
|  | 		Packet: PacketSyn{ | ||||||
|  | 			TraceID: newTraceID(), | ||||||
|  | 			//SentAt:        time.Now().UnixMilli(), | ||||||
|  | 			//SharedKeyType: 1, | ||||||
|  | 			Direct: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.State.OnSyn(synMsg) | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
|  | 	// Ping shouldn't timeout. | ||||||
|  | 	h.OnPingTimer() | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
|  | 	// Advance the time, then ping. | ||||||
|  | 	state := assertType[*StateServer](t, h.State) | ||||||
|  | 	state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) | ||||||
|  |  | ||||||
|  | 	h.OnPingTimer() | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientDirect_OnAck(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientDirect(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{ | ||||||
|  | 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||||
|  | 	} | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientDirect(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{ | ||||||
|  | 		Packet: PacketAck{TraceID: syn.TraceID + 1}, | ||||||
|  | 	} | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientDirect_OnPingTimer(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientDirect(t) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	h.OnPingTimer() | ||||||
|  |  | ||||||
|  | 	// On ping timer, another syn should be sent. Additionally, we should remain | ||||||
|  | 	// in the same state. | ||||||
|  | 	assertEqual(t, len(h.Sent), 2) | ||||||
|  | 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||||
|  | 	assertType[*StateClientDirect](t, h.State) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientDirect(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{ | ||||||
|  | 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||||
|  | 	} | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
|  | 	state := assertType[*StateClientDirect](t, h.State) | ||||||
|  | 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||||
|  |  | ||||||
|  | 	h.OnPingTimer() | ||||||
|  |  | ||||||
|  | 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||||
|  | 	// will be sent when re-entering the state, but the connection should be down. | ||||||
|  | 	assertEqual(t, len(h.Sent), 2) | ||||||
|  | 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||||
|  | 	assertType[*StateClientDirect](t, h.State) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientRelayed_OnAck(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{ | ||||||
|  | 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||||
|  | 	} | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientRelayed_OnPingTimer_noAddrs(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  |  | ||||||
|  | 	// If we haven't had an ack yet, we won't have addresses to probe. Therefore | ||||||
|  | 	// we'll have just one more syn packet sent. | ||||||
|  | 	h.OnPingTimer() | ||||||
|  | 	assertEqual(t, len(h.Sent), 2) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  |  | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | ||||||
|  | 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||||
|  | 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||||
|  |  | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  |  | ||||||
|  | 	// Add a local discovery address. Note that the port will be configured port | ||||||
|  | 	// and no the one provided here. | ||||||
|  | 	h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ | ||||||
|  | 		SrcIP:   3, | ||||||
|  | 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	// We should see one SYN and three probe packets. | ||||||
|  | 	h.OnPingTimer() | ||||||
|  | 	assertEqual(t, len(h.Sent), 5) | ||||||
|  | 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||||
|  | 	assertType[PacketProbe](t, h.Sent[2].Packet) | ||||||
|  | 	assertType[PacketProbe](t, h.Sent[3].Packet) | ||||||
|  | 	assertType[PacketProbe](t, h.Sent[4].Packet) | ||||||
|  |  | ||||||
|  | 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) | ||||||
|  | 	assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) | ||||||
|  | 	assertEqual(t, h.Sent[4].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 456)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
|  | 	// On entering the state, a SYN should have been sent. | ||||||
|  | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{ | ||||||
|  | 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||||
|  | 	} | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
|  | 	state := assertType[*StateClientRelayed](t, h.State) | ||||||
|  | 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||||
|  |  | ||||||
|  | 	h.OnPingTimer() | ||||||
|  |  | ||||||
|  | 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||||
|  | 	// will be sent when re-entering the state, but the connection should be down. | ||||||
|  | 	assertEqual(t, len(h.Sent), 2) | ||||||
|  | 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||||
|  | 	assertType[*StateClientRelayed](t, h.State) | ||||||
|  | 	assertEqual(t, h.Published.Up, false) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
|  | 	h.OnProbe(controlMsg[PacketProbe]{ | ||||||
|  | 		Packet: PacketProbe{TraceID: newTraceID()}, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	assertType[*StateClientRelayed](t, h.State) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | ||||||
|  | 	h := NewPeerStateTestHarness() | ||||||
|  | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
|  | 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
|  | 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | ||||||
|  | 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||||
|  | 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||||
|  |  | ||||||
|  | 	h.State.OnAck(ack) | ||||||
|  | 	h.OnPingTimer() | ||||||
|  |  | ||||||
|  | 	probe := assertType[PacketProbe](t, h.Sent[2].Packet) | ||||||
|  | 	h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) | ||||||
|  |  | ||||||
|  | 	assertType[*StateClientDirect](t, h.State) | ||||||
|  | } | ||||||
							
								
								
									
										75
									
								
								peer/pubaddrs.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								peer/pubaddrs.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"runtime/debug" | ||||||
|  | 	"sort" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type pubAddrStore struct { | ||||||
|  | 	localPub  bool | ||||||
|  | 	localAddr netip.AddrPort | ||||||
|  | 	lastSeen  map[netip.AddrPort]time.Time | ||||||
|  | 	addrList  []netip.AddrPort | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { | ||||||
|  | 	return &pubAddrStore{ | ||||||
|  | 		localPub:  localAddr.IsValid(), | ||||||
|  | 		localAddr: localAddr, | ||||||
|  | 		lastSeen:  map[netip.AddrPort]time.Time{}, | ||||||
|  | 		addrList:  make([]netip.AddrPort, 0, 32), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||||
|  | 	if store.localPub { | ||||||
|  | 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !add.IsValid() { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, exists := store.lastSeen[add]; !exists { | ||||||
|  | 		store.addrList = append(store.addrList, add) | ||||||
|  | 	} | ||||||
|  | 	store.lastSeen[add] = time.Now() | ||||||
|  | 	store.sort() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | ||||||
|  | 	if store.localPub { | ||||||
|  | 		addrs[0] = store.localAddr | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	copy(addrs[:], store.addrList) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (store *pubAddrStore) Clean() { | ||||||
|  | 	if store.localPub { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for ip, lastSeen := range store.lastSeen { | ||||||
|  | 		if time.Since(lastSeen) > timeoutInterval { | ||||||
|  | 			delete(store.lastSeen, ip) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	store.addrList = store.addrList[:0] | ||||||
|  | 	for ip := range store.lastSeen { | ||||||
|  | 		store.addrList = append(store.addrList, ip) | ||||||
|  | 	} | ||||||
|  | 	store.sort() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (store *pubAddrStore) sort() { | ||||||
|  | 	sort.Slice(store.addrList, func(i, j int) bool { | ||||||
|  | 		return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										29
									
								
								peer/pubaddrs_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								peer/pubaddrs_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestPubAddrStore(t *testing.T) { | ||||||
|  | 	s := newPubAddrStore(netip.AddrPort{}) | ||||||
|  |  | ||||||
|  | 	l := []netip.AddrPort{ | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range l { | ||||||
|  | 		s.Store(l[i]) | ||||||
|  | 		time.Sleep(time.Millisecond) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Clean() | ||||||
|  |  | ||||||
|  | 	l2 := s.Get() | ||||||
|  | 	if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { | ||||||
|  | 		t.Fatal(l, l2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										137
									
								
								peer/routingtable.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								peer/routingtable.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,137 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // TODO: Remove | ||||||
|  | func NewRemotePeer(ip byte) *RemotePeer { | ||||||
|  | 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||||
|  | 	return &RemotePeer{ | ||||||
|  | 		IP:       ip, | ||||||
|  | 		counter:  &counter, | ||||||
|  | 		dupCheck: newDupCheck(0), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type RemotePeer struct { | ||||||
|  | 	localIP       byte | ||||||
|  | 	IP            byte           // VPN IP of peer (last byte). | ||||||
|  | 	Up            bool           // True if data can be sent on the peer. | ||||||
|  | 	Relay         bool           // True if the peer is a relay. | ||||||
|  | 	Direct        bool           // True if this is a direct connection. | ||||||
|  | 	DirectAddr    netip.AddrPort // Remote address if directly connected. | ||||||
|  | 	PubSignKey    []byte | ||||||
|  | 	ControlCipher *controlCipher | ||||||
|  | 	DataCipher    *dataCipher | ||||||
|  |  | ||||||
|  | 	counter  *uint64   // For sending to. Atomic access only. | ||||||
|  | 	dupCheck *dupCheck // For receiving from. Not safe for concurrent use. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | ||||||
|  | 	h := header{ | ||||||
|  | 		StreamID: dataStreamID, | ||||||
|  | 		Counter:  atomic.AddUint64(p.counter, 1), | ||||||
|  | 		SourceIP: p.localIP, | ||||||
|  | 		DestIP:   destIP, | ||||||
|  | 	} | ||||||
|  | 	return p.DataCipher.Encrypt(h, data, out) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Decrypts and de-dups incoming data packets. | ||||||
|  | func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { | ||||||
|  | 	dec, ok := p.DataCipher.Decrypt(enc, out) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errDecryptionFailed | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.dupCheck.IsDup(h.Counter) { | ||||||
|  | 		return nil, errDuplicateSeqNum | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return dec, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Peer must have a ControlCipher. | ||||||
|  | func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | ||||||
|  | 	h := header{ | ||||||
|  | 		StreamID: controlStreamID, | ||||||
|  | 		Counter:  atomic.AddUint64(p.counter, 1), | ||||||
|  | 		SourceIP: p.localIP, | ||||||
|  | 		DestIP:   p.IP, | ||||||
|  | 	} | ||||||
|  | 	tmp = pkt.Marshal(tmp) | ||||||
|  | 	return p.ControlCipher.Encrypt(h, tmp, out) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. | ||||||
|  | // | ||||||
|  | // This function also drops packets with duplicate sequence numbers. | ||||||
|  | func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { | ||||||
|  | 	out, ok := p.ControlCipher.Decrypt(enc, tmp) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errDecryptionFailed | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.dupCheck.IsDup(h.Counter) { | ||||||
|  | 		return nil, errDuplicateSeqNum | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg, err := parseControlMsg(h.SourceIP, fromAddr, out) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return msg, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type RoutingTable struct { | ||||||
|  | 	// The LocalIP is the configured IP address of the local peer on the VPN. | ||||||
|  | 	// | ||||||
|  | 	// This value is constant. | ||||||
|  | 	LocalIP byte | ||||||
|  |  | ||||||
|  | 	// The LocalAddr is the configured local public address of the peer on the | ||||||
|  | 	// internet. If LocalAddr.IsValid(), then the local peer has a public | ||||||
|  | 	// address. | ||||||
|  | 	// | ||||||
|  | 	// This value is constant. | ||||||
|  | 	LocalAddr netip.AddrPort | ||||||
|  |  | ||||||
|  | 	// The remote peer configurations. These are updated by | ||||||
|  | 	Peers [256]RemotePeer | ||||||
|  |  | ||||||
|  | 	// The current relay's VPN IP address, or zero if no relay is available. | ||||||
|  | 	RelayIP byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { | ||||||
|  | 	rt := RoutingTable{ | ||||||
|  | 		LocalIP:   localIP, | ||||||
|  | 		LocalAddr: localAddr, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range rt.Peers { | ||||||
|  | 		counter := uint64(time.Now().Unix()<<30 + 1) | ||||||
|  | 		rt.Peers[i] = RemotePeer{ | ||||||
|  | 			localIP:  localIP, | ||||||
|  | 			IP:       byte(i), | ||||||
|  | 			counter:  &counter, | ||||||
|  | 			dupCheck: newDupCheck(0), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return rt | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { | ||||||
|  | 	relay := rt.Peers[rt.RelayIP] | ||||||
|  | 	return relay, relay.Up && relay.Direct | ||||||
|  | } | ||||||
							
								
								
									
										169
									
								
								peer/routingtable_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								peer/routingtable_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,169 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptDataPacket(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  | 	orig := RandPacket() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptDataPacket(2, orig, newBuf()) | ||||||
|  |  | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  | 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||||
|  | 		t.Fatal(h) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec, err := peer1.DecryptDataPacket(h, enc, newBuf()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !bytes.Equal(orig, dec) { | ||||||
|  | 		t.Fatal(dec) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptDataPacket_packetAltered(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  | 	orig := RandPacket() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptDataPacket(2, orig, newBuf()) | ||||||
|  |  | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  |  | ||||||
|  | 	for range 2048 { | ||||||
|  | 		_, err := peer1.DecryptDataPacket(h, ModifyPacket(enc), newBuf()) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Fatal(enc) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptDataPacket_duplicateSequenceNumber(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  | 	orig := RandPacket() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptDataPacket(2, orig, newBuf()) | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  |  | ||||||
|  | 	if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err == nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptControlPacket(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	orig := PacketProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  | 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||||
|  | 		t.Fatal(h) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec, ok := ctrlMsg.(controlMsg[PacketProbe]) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ctrlMsg) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if dec.SrcIP != 1 || dec.SrcAddr != p1.RT.Load().LocalAddr { | ||||||
|  | 		t.Fatal(dec) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(dec.Packet, orig) { | ||||||
|  | 		t.Fatal(dec) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	orig := PacketProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  | 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||||
|  | 		t.Fatal(h) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for range 2048 { | ||||||
|  | 		ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, ModifyPacket(enc), newBuf()) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Fatal(ctrlMsg) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	orig := PacketProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  | 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||||
|  | 		t.Fatal(h) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRemotePeer_DecryptControlPacket_unknownPacketType(t *testing.T) { | ||||||
|  | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
|  | 	peer2 := p1.RT.Load().Peers[2] | ||||||
|  | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
|  | 	orig := UnknownControlPacket{TraceID: newTraceID()} | ||||||
|  |  | ||||||
|  | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
|  | 	h := parseHeader(enc) | ||||||
|  | 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||||
|  | 		t.Fatal(h) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,29 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"net/netip" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type RemotePeer struct { |  | ||||||
| 	IP            byte           // VPN IP of peer (last byte). |  | ||||||
| 	Up            bool           // True if data can be sent on the peer. |  | ||||||
| 	Relay         bool           // True if the peer is a relay. |  | ||||||
| 	Direct        bool           // True if this is a direct connection. |  | ||||||
| 	DirectAddr    netip.AddrPort // Remote address if directly connected. |  | ||||||
| 	PubSignKey    []byte |  | ||||||
| 	ControlCipher *controlCipher |  | ||||||
| 	DataCipher    *dataCipher |  | ||||||
|  |  | ||||||
| 	Counter  *uint64   // For sending to. Atomic access only. |  | ||||||
| 	DupCheck *dupCheck // For receiving from. Not safe for concurrent use. |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewRemotePeer(ip byte) *RemotePeer { |  | ||||||
| 	counter := uint64(time.Now().Unix()<<30 + 1) |  | ||||||
| 	return &RemotePeer{ |  | ||||||
| 		IP:       ip, |  | ||||||
| 		Counter:  &counter, |  | ||||||
| 		DupCheck: newDupCheck(0), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
							
								
								
									
										103
									
								
								peer/supervisor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								peer/supervisor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/go/ratelimiter" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type Supervisor struct { | ||||||
|  | 	messages chan any // Incoming control messages. | ||||||
|  | 	peers    [256]PeerState | ||||||
|  | 	pubAddrs *pubAddrStore | ||||||
|  | 	rt       *atomic.Pointer[RoutingTable] | ||||||
|  | 	staged   RoutingTable | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewSupervisor( | ||||||
|  | 	sendControl func(RemotePeer, Marshaller), | ||||||
|  | 	privKey []byte, | ||||||
|  | 	rt *atomic.Pointer[RoutingTable], | ||||||
|  | ) *Supervisor { | ||||||
|  | 	s := &Supervisor{ | ||||||
|  | 		messages: make(chan any, 1024), | ||||||
|  | 		pubAddrs: newPubAddrStore(rt.Load().LocalAddr), | ||||||
|  | 		rt:       rt, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	routes := rt.Load() | ||||||
|  |  | ||||||
|  | 	for i := range s.peers { | ||||||
|  | 		state := &State{ | ||||||
|  | 			publish:           s.Publish, | ||||||
|  | 			sendControlPacket: sendControl, | ||||||
|  | 			localIP:           routes.LocalIP, | ||||||
|  | 			remoteIP:          byte(i), | ||||||
|  | 			privKey:           privKey, | ||||||
|  | 			localAddr:         routes.LocalAddr, | ||||||
|  | 			pubAddrs:          s.pubAddrs, | ||||||
|  | 			staged:            routes.Peers[i], | ||||||
|  | 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||||
|  | 				FillPeriod:   20 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 1, | ||||||
|  | 			}), | ||||||
|  | 		} | ||||||
|  | 		s.peers[i] = state.OnPeerUpdate(nil) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Supervisor) HandleControlMsg(msg any) { | ||||||
|  | 	select { | ||||||
|  | 	case s.messages <- msg: | ||||||
|  | 	default: | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Supervisor) Run() { | ||||||
|  | 	for raw := range s.messages { | ||||||
|  | 		switch msg := raw.(type) { | ||||||
|  |  | ||||||
|  | 		case peerUpdateMsg: | ||||||
|  | 			s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) | ||||||
|  |  | ||||||
|  | 		case controlMsg[PacketSyn]: | ||||||
|  | 			if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { | ||||||
|  | 				s.peers[msg.SrcIP] = newState | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 		case controlMsg[PacketAck]: | ||||||
|  | 			s.peers[msg.SrcIP].OnAck(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[PacketProbe]: | ||||||
|  | 			if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { | ||||||
|  | 				s.peers[msg.SrcIP] = newState | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 		case controlMsg[PacketLocalDiscovery]: | ||||||
|  | 			s.peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||||
|  |  | ||||||
|  | 		case pingTimerMsg: | ||||||
|  | 			s.pubAddrs.Clean() | ||||||
|  | 			for i := range s.peers { | ||||||
|  | 				if newState := s.peers[i].OnPingTimer(); newState != nil { | ||||||
|  | 					s.peers[i] = newState | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 		default: | ||||||
|  | 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Supervisor) Publish(rp RemotePeer) { | ||||||
|  | 	s.staged.Peers[rp.IP] = rp | ||||||
|  | 	rt := s.staged // Copy. | ||||||
|  | 	s.rt.Store(&rt) | ||||||
|  | } | ||||||
							
								
								
									
										26
									
								
								peer/util_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								peer/util_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { | ||||||
|  | 	return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func assertType[T any](t *testing.T, obj any) T { | ||||||
|  | 	t.Helper() | ||||||
|  | 	x, ok := obj.(T) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal("invalid type", obj) | ||||||
|  | 	} | ||||||
|  | 	return x | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func assertEqual[T comparable](t *testing.T, a, b T) { | ||||||
|  | 	t.Helper() | ||||||
|  | 	if a != b { | ||||||
|  | 		t.Fatal(a, " != ", b) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user