WIP
This commit is contained in:
		
							
								
								
									
										132
									
								
								peer/connreader2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								peer/connreader2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type ConnReader struct { | ||||
| 	// Input | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | ||||
|  | ||||
| 	// Output | ||||
| 	iface            io.Writer | ||||
| 	forwardData      func(ip byte, pkt []byte) | ||||
| 	handleControlMsg func(pkt any) | ||||
|  | ||||
| 	localIP byte | ||||
| 	rt      *atomic.Pointer[RoutingTable] | ||||
|  | ||||
| 	buf    []byte | ||||
| 	decBuf []byte | ||||
| } | ||||
|  | ||||
| func NewConnReader( | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | ||||
| 	iface io.Writer, | ||||
| 	forwardData func(ip byte, pkt []byte), | ||||
| 	handleControlMsg func(pkt any), | ||||
| 	rt *atomic.Pointer[RoutingTable], | ||||
| ) *ConnReader { | ||||
| 	return &ConnReader{ | ||||
| 		readFromUDPAddrPort: readFromUDPAddrPort, | ||||
| 		iface:               iface, | ||||
| 		forwardData:         forwardData, | ||||
| 		handleControlMsg:    handleControlMsg, | ||||
| 		localIP:             rt.Load().LocalIP, | ||||
| 		rt:                  rt, | ||||
| 		buf:                 newBuf(), | ||||
| 		decBuf:              newBuf(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) Run() { | ||||
| 	for { | ||||
| 		r.handleNextPacket() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleNextPacket() { | ||||
| 	buf := r.buf[:bufferSize] | ||||
| 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if n < headerSize { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) | ||||
|  | ||||
| 	buf = buf[:n] | ||||
| 	h := parseHeader(buf) | ||||
|  | ||||
| 	peer := r.rt.Load().Peers[h.SourceIP] | ||||
| 	//peer := rt.Peers[h.SourceIP] | ||||
|  | ||||
| 	switch h.StreamID { | ||||
| 	case controlStreamID: | ||||
| 		r.handleControlPacket(remoteAddr, peer, h, buf) | ||||
| 	case dataStreamID: | ||||
| 		r.handleDataPacket(peer, h, buf) | ||||
| 	default: | ||||
| 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleControlPacket( | ||||
| 	remoteAddr netip.AddrPort, | ||||
| 	peer RemotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| 	if peer.ControlCipher == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.DestIP != r.localIP { | ||||
| 		r.logf("Incorrect destination IP on control packet: %d", h.DestIP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf) | ||||
| 	if err != nil { | ||||
| 		r.logf("Failed to decrypt control packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.handleControlMsg(msg) | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleDataPacket( | ||||
| 	peer RemotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| 	if !peer.Up { | ||||
| 		r.logf("Not connected (recv).") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	data, err := peer.DecryptDataPacket(h, enc, r.decBuf) | ||||
| 	if err != nil { | ||||
| 		r.logf("Failed to decrypt data packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.DestIP == r.localIP { | ||||
| 		if _, err := r.iface.Write(data); err != nil { | ||||
| 			log.Fatalf("Failed to write to interface: %v", err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.forwardData(h.DestIP, data) | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) logf(format string, args ...any) { | ||||
| 	log.Printf("[ConnReader] "+format, args...) | ||||
| } | ||||
| @@ -109,7 +109,7 @@ func newConnReadeTestHarness() (h connReaderTestHarness) { | ||||
| func TestConnReader_handleControlPacket(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := synPacket{TraceID: 1234} | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||
|  | ||||
| @@ -119,7 +119,7 @@ func TestConnReader_handleControlPacket(t *testing.T) { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
|  | ||||
| 	msg := h.Super.Messages[0].(controlMsg[synPacket]) | ||||
| 	msg := h.Super.Messages[0].(controlMsg[PacketSyn]) | ||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||
| 		t.Fatal(msg.Packet) | ||||
| 	} | ||||
| @@ -141,7 +141,7 @@ func TestConnReader_handleNextPacket_short(t *testing.T) { | ||||
| func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := synPacket{TraceID: 1234} | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	var header header | ||||
| @@ -160,7 +160,7 @@ func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | ||||
| func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := synPacket{TraceID: 1234} | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	//encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | ||||
| @@ -180,7 +180,7 @@ func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | ||||
| func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := synPacket{TraceID: 1234} | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	var header header | ||||
| @@ -199,7 +199,7 @@ func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | ||||
| func TestConnReader_handleControlPacket_modified(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := synPacket{TraceID: 1234} | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	encrypted[len(encrypted)-1]++ | ||||
| @@ -237,10 +237,10 @@ func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { | ||||
| func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := ackPacket{TraceID: 1234} | ||||
| 	pkt := PacketAck{TraceID: 1234} | ||||
|  | ||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||
| 	*h.Remote.Counter = *h.Remote.Counter - 1 | ||||
| 	*h.Remote.counter = *h.Remote.counter - 1 | ||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
| @@ -250,7 +250,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
|  | ||||
| 	msg := h.Super.Messages[0].(controlMsg[ackPacket]) | ||||
| 	msg := h.Super.Messages[0].(controlMsg[PacketAck]) | ||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||
| 		t.Fatal(msg.Packet) | ||||
| 	} | ||||
| @@ -301,7 +301,7 @@ func TestConnReader_handleDataPacket_duplicate(t *testing.T) { | ||||
| 	pkt := make([]byte, 123) | ||||
|  | ||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||
| 	*h.Remote.Counter = *h.Remote.Counter - 1 | ||||
| 	*h.Remote.counter = *h.Remote.counter - 1 | ||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
|   | ||||
| @@ -37,13 +37,13 @@ func newConnWriter(conn udpWriter, localIP byte) *connWriter { | ||||
| } | ||||
|  | ||||
| // Not safe for concurrent use. Should only be called by supervisor. | ||||
| func (w *connWriter) SendControlPacket(pkt marshaller, peer *RemotePeer) { | ||||
| func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { | ||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | ||||
| 	w.writeTo(enc, peer.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Relay control packet. Peer must not be nil. | ||||
| func (w *connWriter) RelayControlPacket(pkt marshaller, peer, relay *RemotePeer) { | ||||
| func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { | ||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | ||||
| 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
|   | ||||
							
								
								
									
										109
									
								
								peer/connwriter2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								peer/connwriter2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type ConnWriter struct { | ||||
| 	wLock sync.Mutex // Lock around for sending on UDP Conn. | ||||
|  | ||||
| 	// Output. | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
|  | ||||
| 	// Shared state. | ||||
| 	rt *atomic.Pointer[RoutingTable] | ||||
|  | ||||
| 	// For sending control packets. | ||||
| 	cBuf1 []byte | ||||
| 	cBuf2 []byte | ||||
|  | ||||
| 	// For sending data packets. | ||||
| 	dBuf1 []byte | ||||
| 	dBuf2 []byte | ||||
| } | ||||
|  | ||||
| func NewConnWriter( | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	rt *atomic.Pointer[RoutingTable], | ||||
| ) *ConnWriter { | ||||
| 	return &ConnWriter{ | ||||
| 		writeToUDPAddrPort: writeToUDPAddrPort, | ||||
| 		rt:                 rt, | ||||
| 		cBuf1:              newBuf(), | ||||
| 		cBuf2:              newBuf(), | ||||
| 		dBuf1:              newBuf(), | ||||
| 		dBuf2:              newBuf(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Called by ConnReader to forward already encrypted bytes to another peer. | ||||
| func (w *ConnWriter) Forward(ip byte, pkt []byte) { | ||||
| 	peer := w.rt.Load().Peers[ip] | ||||
| 	if !(peer.Up && peer.Direct) { | ||||
| 		w.logf("Failed to forward to %d.", ip) | ||||
| 		return | ||||
| 	} | ||||
| 	w.writeTo(pkt, peer.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Called by IFReader to send data. Encryption will be applied, and packet will | ||||
| // be relayed if appropriate. | ||||
| func (w *ConnWriter) WriteData(ip byte, pkt []byte) { | ||||
| 	rt := w.rt.Load() | ||||
| 	peer := rt.Peers[ip] | ||||
| 	if !peer.Up { | ||||
| 		w.logf("Failed to send data to %d.", ip) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) | ||||
|  | ||||
| 	if peer.Direct { | ||||
| 		w.writeTo(enc, peer.DirectAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		w.logf("Failed to send data to %d. No relay.", ip) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Called by Supervisor to send control packets. | ||||
| func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { | ||||
| 	enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) | ||||
|  | ||||
| 	if peer.Direct { | ||||
| 		w.writeTo(enc, peer.DirectAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rt := w.rt.Load() | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		w.logf("Failed to send control to %d. No relay.", peer.IP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { | ||||
| 	w.wLock.Lock() | ||||
| 	if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { | ||||
| 		w.logf("Failed to write to UDP port: %v", err) | ||||
| 	} | ||||
| 	w.wLock.Unlock() | ||||
| } | ||||
|  | ||||
| func (w *ConnWriter) logf(s string, args ...any) { | ||||
| 	log.Printf("[ConnWriter] "+s, args...) | ||||
| } | ||||
							
								
								
									
										145
									
								
								peer/connwriter2_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								peer/connwriter2_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,145 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestConnWriter_WriteData_direct(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteData_peerNotUp(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	p1.RT.Load().Peers[2].Up = false | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteData_relay(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().Peers[3].Up = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteControl_direct(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteControl_relay(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().Peers[3].Up = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter__Forward(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.Forward(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter__Forward_notUp(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	p1.RT.Load().Peers[2].Up = false | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.Forward(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter__Forward_notDirect(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.Forward(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
| @@ -17,25 +17,25 @@ type controlMsg[T any] struct { | ||||
| func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | ||||
| 	switch buf[0] { | ||||
|  | ||||
| 	case packetTypeSyn: | ||||
| 		packet, err := parseSynPacket(buf) | ||||
| 		return controlMsg[synPacket]{ | ||||
| 	case PacketTypeSyn: | ||||
| 		packet, err := ParsePacketSyn(buf) | ||||
| 		return controlMsg[PacketSyn]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	case packetTypeAck: | ||||
| 		packet, err := parseAckPacket(buf) | ||||
| 		return controlMsg[ackPacket]{ | ||||
| 	case PacketTypeAck: | ||||
| 		packet, err := ParsePacketAck(buf) | ||||
| 		return controlMsg[PacketAck]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	case packetTypeProbe: | ||||
| 		packet, err := parseProbePacket(buf) | ||||
| 		return controlMsg[probePacket]{ | ||||
| 	case PacketTypeProbe: | ||||
| 		packet, err := ParsePacketProbe(buf) | ||||
| 		return controlMsg[PacketProbe]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
|   | ||||
| @@ -37,13 +37,13 @@ func generateKeys() cryptoKeys { | ||||
| func encryptControlPacket( | ||||
| 	localIP byte, | ||||
| 	peer *RemotePeer, | ||||
| 	pkt marshaller, | ||||
| 	pkt Marshaller, | ||||
| 	tmp []byte, | ||||
| 	out []byte, | ||||
| ) []byte { | ||||
| 	h := header{ | ||||
| 		StreamID: controlStreamID, | ||||
| 		Counter:  atomic.AddUint64(peer.Counter, 1), | ||||
| 		Counter:  atomic.AddUint64(peer.counter, 1), | ||||
| 		SourceIP: localIP, | ||||
| 		DestIP:   peer.IP, | ||||
| 	} | ||||
| @@ -66,7 +66,7 @@ func decryptControlPacket( | ||||
| 		return nil, errDecryptionFailed | ||||
| 	} | ||||
|  | ||||
| 	if peer.DupCheck.IsDup(h.Counter) { | ||||
| 	if peer.dupCheck.IsDup(h.Counter) { | ||||
| 		return nil, errDuplicateSeqNum | ||||
| 	} | ||||
|  | ||||
| @@ -89,7 +89,7 @@ func encryptDataPacket( | ||||
| ) []byte { | ||||
| 	h := header{ | ||||
| 		StreamID: dataStreamID, | ||||
| 		Counter:  atomic.AddUint64(peer.Counter, 1), | ||||
| 		Counter:  atomic.AddUint64(peer.counter, 1), | ||||
| 		SourceIP: localIP, | ||||
| 		DestIP:   destIP, | ||||
| 	} | ||||
| @@ -108,7 +108,7 @@ func decryptDataPacket( | ||||
| 		return nil, errDecryptionFailed | ||||
| 	} | ||||
|  | ||||
| 	if peer.DupCheck.IsDup(h.Counter) { | ||||
| 	if peer.dupCheck.IsDup(h.Counter) { | ||||
| 		return nil, errDuplicateSeqNum | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { | ||||
| 		out    = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	in := synPacket{ | ||||
| 	in := PacketSyn{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: r1.DataCipher.Key(), | ||||
| 		Direct:    true, | ||||
| @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	msg, ok := iMsg.(controlMsg[synPacket]) | ||||
| 	msg, ok := iMsg.(controlMsg[PacketSyn]) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
| @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { | ||||
| 		out    = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	in := synPacket{ | ||||
| 	in := PacketSyn{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: r1.DataCipher.Key(), | ||||
| 		Direct:    true, | ||||
| @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { | ||||
| 		out    = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	in := synPacket{ | ||||
| 	in := PacketSyn{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: r1.DataCipher.Key(), | ||||
| 		Direct:    true, | ||||
|   | ||||
							
								
								
									
										14
									
								
								peer/data-flow.dot
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								peer/data-flow.dot
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| digraph d { | ||||
|     ifReader   -> connWriter; | ||||
|     connReader -> ifWriter; | ||||
|     connReader -> connWriter; | ||||
|     connReader -> supervisor; | ||||
|     mcReader   -> supervisor; | ||||
|     supervisor -> connWriter; | ||||
|     supervisor -> mcWriter; | ||||
|     hubPoller  -> supervisor; | ||||
|  | ||||
|     connWriter [shape="box"]; | ||||
|     mcWriter [shape="box"]; | ||||
|     ifWriter [shape="box"]; | ||||
| } | ||||
							
								
								
									
										90
									
								
								peer/files.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								peer/files.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,90 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type localConfig struct { | ||||
| 	m.PeerConfig | ||||
| 	PubKey      []byte | ||||
| 	PrivKey     []byte | ||||
| 	PubSignKey  []byte | ||||
| 	PrivSignKey []byte | ||||
| } | ||||
|  | ||||
| func configDir(netName string) string { | ||||
| 	d, err := os.UserHomeDir() | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to get user home directory: %v", err) | ||||
| 	} | ||||
| 	return filepath.Join(d, ".vppn", netName) | ||||
| } | ||||
|  | ||||
| func peerConfigPath(netName string) string { | ||||
| 	return filepath.Join(configDir(netName), "peer-config.json") | ||||
| } | ||||
|  | ||||
| func peerStatePath(netName string) string { | ||||
| 	return filepath.Join(configDir(netName), "peer-state.json") | ||||
| } | ||||
|  | ||||
| func storeJson(x any, outPath string) error { | ||||
| 	outDir := filepath.Dir(outPath) | ||||
| 	_ = os.MkdirAll(outDir, 0700) | ||||
|  | ||||
| 	tmpPath := outPath + ".tmp" | ||||
| 	buf, err := json.Marshal(x) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	f, err := os.Create(tmpPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if _, err := f.Write(buf); err != nil { | ||||
| 		f.Close() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := f.Sync(); err != nil { | ||||
| 		f.Close() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := f.Close(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return os.Rename(tmpPath, outPath) | ||||
| } | ||||
|  | ||||
| func storePeerConfig(netName string, pc localConfig) error { | ||||
| 	return storeJson(pc, peerConfigPath(netName)) | ||||
| } | ||||
|  | ||||
| func storeNetworkState(netName string, ps m.NetworkState) error { | ||||
| 	return storeJson(ps, peerStatePath(netName)) | ||||
| } | ||||
|  | ||||
| func loadJson(dataPath string, ptr any) error { | ||||
| 	data, err := os.ReadFile(dataPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return json.Unmarshal(data, ptr) | ||||
| } | ||||
|  | ||||
| func loadPeerConfig(netName string) (pc localConfig, err error) { | ||||
| 	return pc, loadJson(peerConfigPath(netName), &pc) | ||||
| } | ||||
|  | ||||
| func loadNetworkState(netName string) (ps m.NetworkState, err error) { | ||||
| 	return ps, loadJson(peerStatePath(netName), &ps) | ||||
| } | ||||
							
								
								
									
										57
									
								
								peer/files_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								peer/files_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"path/filepath" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestFilePaths(t *testing.T) { | ||||
| 	confDir := configDir("netName") | ||||
| 	if filepath.Base(confDir) != "netName" { | ||||
| 		t.Fatal(confDir) | ||||
| 	} | ||||
| 	if filepath.Base(filepath.Dir(confDir)) != ".vppn" { | ||||
| 		t.Fatal(confDir) | ||||
| 	} | ||||
|  | ||||
| 	path := peerConfigPath("netName") | ||||
| 	if path != filepath.Join(confDir, "peer-config.json") { | ||||
| 		t.Fatal(path) | ||||
| 	} | ||||
|  | ||||
| 	path = peerStatePath("netName") | ||||
| 	if path != filepath.Join(confDir, "peer-state.json") { | ||||
| 		t.Fatal(path) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestStoreLoadJson(t *testing.T) { | ||||
| 	type Object struct { | ||||
| 		Name  string | ||||
| 		Age   int | ||||
| 		Price float64 | ||||
| 	} | ||||
|  | ||||
| 	tmpDir := t.TempDir() | ||||
| 	outPath := filepath.Join(tmpDir, "object.json") | ||||
|  | ||||
| 	obj := Object{ | ||||
| 		Name:  "Jason", | ||||
| 		Age:   22, | ||||
| 		Price: 123.534, | ||||
| 	} | ||||
|  | ||||
| 	if err := storeJson(obj, outPath); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	obj2 := Object{} | ||||
| 	if err := loadJson(outPath, &obj2); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(obj, obj2) { | ||||
| 		t.Fatal(obj, obj2) | ||||
| 	} | ||||
| } | ||||
| @@ -3,15 +3,21 @@ package peer | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	bufferSize            = 1536 | ||||
| 	if_mtu                = 1200 | ||||
| 	if_queue_len          = 2048 | ||||
| 	bufferSize = 1536 | ||||
|  | ||||
| 	if_mtu       = 1200 | ||||
| 	if_queue_len = 2048 | ||||
|  | ||||
| 	controlCipherOverhead = 16 | ||||
| 	dataCipherOverhead    = 16 | ||||
| 	signOverhead          = 64 | ||||
|  | ||||
| 	pingInterval    = 8 * time.Second | ||||
| 	timeoutInterval = 30 * time.Second | ||||
| ) | ||||
|  | ||||
| var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | ||||
|   | ||||
							
								
								
									
										100
									
								
								peer/hubpoller.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								peer/hubpoller.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type hubPoller struct { | ||||
| 	client   *http.Client | ||||
| 	req      *http.Request | ||||
| 	versions [256]int64 | ||||
| 	localIP  byte | ||||
| 	netName  string | ||||
| 	super    controlMsgHandler | ||||
| } | ||||
|  | ||||
| func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { | ||||
| 	u, err := url.Parse(hubURL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	u.Path = "/peer/fetch-state/" | ||||
|  | ||||
| 	client := &http.Client{Timeout: 8 * time.Second} | ||||
|  | ||||
| 	req := &http.Request{ | ||||
| 		Method: http.MethodGet, | ||||
| 		URL:    u, | ||||
| 		Header: http.Header{}, | ||||
| 	} | ||||
| 	req.SetBasicAuth("", apiKey) | ||||
|  | ||||
| 	return &hubPoller{ | ||||
| 		client:  client, | ||||
| 		req:     req, | ||||
| 		localIP: localIP, | ||||
| 		netName: netName, | ||||
| 		super:   super, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (hp *hubPoller) Run() { | ||||
| 	state, err := loadNetworkState(hp.netName) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to load network state: %v", err) | ||||
| 		log.Printf("Polling hub...") | ||||
| 		hp.pollHub() | ||||
| 	} else { | ||||
| 		hp.applyNetworkState(state) | ||||
| 	} | ||||
|  | ||||
| 	for range time.Tick(64 * time.Second) { | ||||
| 		hp.pollHub() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (hp *hubPoller) pollHub() { | ||||
| 	var state m.NetworkState | ||||
|  | ||||
| 	resp, err := hp.client.Do(hp.req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to fetch peer state: %v", err) | ||||
| 		return | ||||
| 	} | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	_ = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to read body from hub: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if err := json.Unmarshal(body, &state); err != nil { | ||||
| 		log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	hp.applyNetworkState(state) | ||||
|  | ||||
| 	if err := storeNetworkState(hp.netName, state); err != nil { | ||||
| 		log.Printf("Failed to store network state: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||
| 	for i, peer := range state.Peers { | ||||
| 		if i != int(hp.localIP) { | ||||
| 			if peer == nil || peer.Version != hp.versions[i] { | ||||
| 				hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | ||||
| 				if peer != nil { | ||||
| 					hp.versions[i] = peer.Version | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										78
									
								
								peer/ifreader2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								peer/ifreader2.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| ) | ||||
|  | ||||
| type IFReader struct { | ||||
| 	iface      io.Reader | ||||
| 	connWriter interface { | ||||
| 		WriteData(ip byte, pkt []byte) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func NewIFReader( | ||||
| 	iface io.Reader, | ||||
| 	connWriter interface { | ||||
| 		WriteData(ip byte, pkt []byte) | ||||
| 	}, | ||||
| ) *IFReader { | ||||
| 	return &IFReader{iface, connWriter} | ||||
| } | ||||
|  | ||||
| func (r *IFReader) Run() { | ||||
| 	packet := newBuf() | ||||
| 	for { | ||||
| 		r.handleNextPacket(packet) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *IFReader) handleNextPacket(packet []byte) { | ||||
| 	packet = r.readNextPacket(packet) | ||||
| 	if remoteIP, ok := r.parsePacket(packet); ok { | ||||
| 		r.connWriter.WriteData(remoteIP, packet) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *IFReader) readNextPacket(buf []byte) []byte { | ||||
| 	n, err := r.iface.Read(buf[:cap(buf)]) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from interface: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return buf[:n] | ||||
| } | ||||
|  | ||||
| func (r *IFReader) parsePacket(buf []byte) (byte, bool) { | ||||
| 	n := len(buf) | ||||
| 	if n == 0 { | ||||
| 		return 0, false | ||||
| 	} | ||||
|  | ||||
| 	version := buf[0] >> 4 | ||||
|  | ||||
| 	switch version { | ||||
| 	case 4: | ||||
| 		if n < 20 { | ||||
| 			r.logf("Short IPv4 packet: %d", len(buf)) | ||||
| 			return 0, false | ||||
| 		} | ||||
| 		return buf[19], true | ||||
|  | ||||
| 	case 6: | ||||
| 		if len(buf) < 40 { | ||||
| 			r.logf("Short IPv6 packet: %d", len(buf)) | ||||
| 			return 0, false | ||||
| 		} | ||||
| 		return buf[39], true | ||||
|  | ||||
| 	default: | ||||
| 		r.logf("Invalid IP packet version: %v", version) | ||||
| 		return 0, false | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (*IFReader) logf(s string, args ...any) { | ||||
| 	log.Printf("[IFReader] "+s, args...) | ||||
| } | ||||
							
								
								
									
										83
									
								
								peer/ifreader2_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								peer/ifreader2_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestIFReader_IPv4(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	pkt := make([]byte, 1234) | ||||
| 	pkt[0] = 4 << 4 | ||||
| 	pkt[19] = 2 // IP. | ||||
|  | ||||
| 	p1.IFace.UserWrite(pkt) | ||||
| 	p1.IFReader.handleNextPacket(newBuf()) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIFReader_IPv6(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	pkt := make([]byte, 1234) | ||||
| 	pkt[0] = 6 << 4 | ||||
| 	pkt[39] = 2 // IP. | ||||
|  | ||||
| 	p1.IFace.UserWrite(pkt) | ||||
| 	p1.IFReader.handleNextPacket(newBuf()) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | ||||
| 	r := NewIFReader(nil, nil) | ||||
| 	pkt := make([]byte, 0) | ||||
| 	if ip, ok := r.parsePacket(pkt); ok { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { | ||||
| 	r := NewIFReader(nil, nil) | ||||
|  | ||||
| 	for i := byte(1); i < 16; i++ { | ||||
| 		if i == 4 || i == 6 { | ||||
| 			continue | ||||
| 		} | ||||
| 		pkt := make([]byte, 1234) | ||||
| 		pkt[0] = i << 4 | ||||
|  | ||||
| 		if ip, ok := r.parsePacket(pkt); ok { | ||||
| 			t.Fatal(i, ip, ok) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIFReader_parsePacket_shortIPv4(t *testing.T) { | ||||
| 	r := NewIFReader(nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 19) | ||||
| 	pkt[0] = 4 << 4 | ||||
|  | ||||
| 	if ip, ok := r.parsePacket(pkt); ok { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIFReader_parsePacket_shortIPv6(t *testing.T) { | ||||
| 	r := NewIFReader(nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 39) | ||||
| 	pkt[0] = 6 << 4 | ||||
|  | ||||
| 	if ip, ok := r.parsePacket(pkt); ok { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
| @@ -2,7 +2,6 @@ package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net" | ||||
| 	"reflect" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| @@ -34,6 +33,7 @@ func TestIFReader_parsePacket_ipv6(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| /* | ||||
| // Test that empty packets work as expected. | ||||
| func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
| @@ -99,7 +99,7 @@ func TestIFReader_readNextpacket(t *testing.T) { | ||||
| 		t.Fatalf("%s", pkt) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| */ | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type sentPacket struct { | ||||
|   | ||||
| @@ -1,5 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import "io" | ||||
|  | ||||
| type ifWriter io.Writer | ||||
| @@ -1,10 +1,19 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| ) | ||||
|  | ||||
| type UDPConn interface { | ||||
| 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||
| 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||
| 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||
| } | ||||
|  | ||||
| type ifWriter io.Writer | ||||
|  | ||||
| type udpReader interface { | ||||
| 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||
| } | ||||
| @@ -13,7 +22,11 @@ type udpWriter interface { | ||||
| 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||
| } | ||||
|  | ||||
| type marshaller interface { | ||||
| type mcUDPWriter interface { | ||||
| 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||
| } | ||||
|  | ||||
| type Marshaller interface { | ||||
| 	Marshal([]byte) []byte | ||||
| } | ||||
|  | ||||
| @@ -22,6 +35,11 @@ type dataPacketSender interface { | ||||
| 	RelayDataPacket(pkt []byte, peer, relay *RemotePeer) | ||||
| } | ||||
|  | ||||
| type controlPacketSender interface { | ||||
| 	SendControlPacket(pkt Marshaller, peer *RemotePeer) | ||||
| 	RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) | ||||
| } | ||||
|  | ||||
| type encryptedPacketSender interface { | ||||
| 	SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) | ||||
| } | ||||
| @@ -29,7 +47,3 @@ type encryptedPacketSender interface { | ||||
| type controlMsgHandler interface { | ||||
| 	HandleControlMsg(pkt any) | ||||
| } | ||||
|  | ||||
| type mcUDPWriter interface { | ||||
| 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||
| } | ||||
|   | ||||
| @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.super.HandleControlMsg(controlMsg[localDiscoveryPacket]{ | ||||
| 	r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ | ||||
| 		SrcIP:   h.SourceIP, | ||||
| 		SrcAddr: remoteAddr, | ||||
| 	}) | ||||
|   | ||||
							
								
								
									
										138
									
								
								peer/mcreader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								peer/mcreader_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type mcMockConn struct { | ||||
| 	packets chan []byte | ||||
| } | ||||
|  | ||||
| func newMCMockConn() *mcMockConn { | ||||
| 	return &mcMockConn{make(chan []byte, 32)} | ||||
| } | ||||
|  | ||||
| func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) { | ||||
| 	c.packets <- bytes.Clone(in) | ||||
| 	return len(in), nil | ||||
| } | ||||
|  | ||||
| func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { | ||||
| 	buf := <-c.packets | ||||
| 	b = b[:len(buf)] | ||||
| 	copy(b, buf) | ||||
| 	return len(b), netip.AddrPort{}, nil | ||||
| } | ||||
|  | ||||
| func TestMCReader(t *testing.T) { | ||||
| 	keys := generateKeys() | ||||
| 	super := &mockControlMsgHandler{} | ||||
| 	conn := newMCMockConn() | ||||
|  | ||||
| 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||
| 	peer := &RemotePeer{ | ||||
| 		IP:         1, | ||||
| 		Up:         true, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
| 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||
| 	peers[1].Store(peer) | ||||
|  | ||||
| 	w := newMCWriter(conn, 1, keys.PrivSignKey) | ||||
| 	r := newMCReader(conn, super, peers) | ||||
|  | ||||
| 	w.SendLocalDiscovery() | ||||
| 	r.handleNextPacket() | ||||
|  | ||||
| 	if len(super.Messages) != 1 { | ||||
| 		t.Fatal(super.Messages) | ||||
| 	} | ||||
| 	msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery]) | ||||
| 	if !ok || msg.SrcIP != 1 { | ||||
| 		t.Fatal(ok, msg) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMCReader_noHeader(t *testing.T) { | ||||
| 	keys := generateKeys() | ||||
| 	super := &mockControlMsgHandler{} | ||||
| 	conn := newMCMockConn() | ||||
|  | ||||
| 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||
| 	peer := &RemotePeer{ | ||||
| 		IP:         1, | ||||
| 		Up:         true, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
| 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||
| 	peers[1].Store(peer) | ||||
|  | ||||
| 	r := newMCReader(conn, super, peers) | ||||
| 	conn.WriteToUDP([]byte("0123546789"), nil) | ||||
| 	r.handleNextPacket() | ||||
|  | ||||
| 	if len(super.Messages) != 0 { | ||||
| 		t.Fatal(super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMCReader_noPeer(t *testing.T) { | ||||
| 	keys := generateKeys() | ||||
| 	super := &mockControlMsgHandler{} | ||||
| 	conn := newMCMockConn() | ||||
|  | ||||
| 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||
| 	peer := &RemotePeer{ | ||||
| 		IP:         1, | ||||
| 		Up:         true, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
| 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||
| 	peers[2] = &atomic.Pointer[RemotePeer]{} | ||||
| 	peers[1].Store(peer) | ||||
|  | ||||
| 	w := newMCWriter(conn, 2, keys.PrivSignKey) | ||||
| 	r := newMCReader(conn, super, peers) | ||||
|  | ||||
| 	w.SendLocalDiscovery() | ||||
| 	r.handleNextPacket() | ||||
|  | ||||
| 	if len(super.Messages) != 0 { | ||||
| 		t.Fatal(super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMCReader_badSignature(t *testing.T) { | ||||
| 	keys := generateKeys() | ||||
| 	super := &mockControlMsgHandler{} | ||||
| 	conn := newMCMockConn() | ||||
|  | ||||
| 	peers := [256]*atomic.Pointer[RemotePeer]{} | ||||
| 	peer := &RemotePeer{ | ||||
| 		IP:         1, | ||||
| 		Up:         true, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
| 	peers[1] = &atomic.Pointer[RemotePeer]{} | ||||
| 	peers[1].Store(peer) | ||||
|  | ||||
| 	w := newMCWriter(conn, 1, keys.PrivSignKey) | ||||
| 	w.SendLocalDiscovery() | ||||
|  | ||||
| 	// Break signing. | ||||
| 	packet := <-conn.packets | ||||
| 	packet[0]++ | ||||
| 	conn.packets <- packet | ||||
|  | ||||
| 	r := newMCReader(conn, super, peers) | ||||
|  | ||||
| 	r.handleNextPacket() | ||||
|  | ||||
| 	if len(super.Messages) != 0 { | ||||
| 		t.Fatal(super.Messages) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										31
									
								
								peer/mock-iface_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								peer/mock-iface_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| package peer | ||||
|  | ||||
| import "bytes" | ||||
|  | ||||
| type TestIFace struct { | ||||
| 	out *bytes.Buffer // Toward the network. | ||||
| 	in  *bytes.Buffer // From the network | ||||
| } | ||||
|  | ||||
| func NewTestIFace() *TestIFace { | ||||
| 	return &TestIFace{ | ||||
| 		out: &bytes.Buffer{}, | ||||
| 		in:  &bytes.Buffer{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (iface *TestIFace) Write(b []byte) (int, error) { | ||||
| 	return iface.in.Write(b) | ||||
| } | ||||
|  | ||||
| func (iface *TestIFace) Read(b []byte) (int, error) { | ||||
| 	return iface.out.Read(b) | ||||
| } | ||||
|  | ||||
| func (iface *TestIFace) UserWrite(b []byte) (int, error) { | ||||
| 	return iface.out.Write(b) | ||||
| } | ||||
|  | ||||
| func (iface *TestIFace) UserRead(b []byte) (int, error) { | ||||
| 	return iface.in.Read(b) | ||||
| } | ||||
							
								
								
									
										80
									
								
								peer/mock-network_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								peer/mock-network_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| type TestPacket struct { | ||||
| 	Addr netip.AddrPort | ||||
| 	Data []byte | ||||
| } | ||||
|  | ||||
| type TestNetwork struct { | ||||
| 	lock    sync.Mutex | ||||
| 	packets map[netip.AddrPort]chan TestPacket | ||||
| } | ||||
|  | ||||
| func NewTestNetwork() *TestNetwork { | ||||
| 	return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}} | ||||
| } | ||||
|  | ||||
| func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn { | ||||
| 	n.lock.Lock() | ||||
| 	defer n.lock.Unlock() | ||||
| 	if _, ok := n.packets[localAddr]; !ok { | ||||
| 		n.packets[localAddr] = make(chan TestPacket, 1024) | ||||
| 	} | ||||
| 	return &TestUDPConn{ | ||||
| 		addr:    localAddr, | ||||
| 		n:       n, | ||||
| 		packets: n.packets[localAddr], | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) { | ||||
| 	n.lock.Lock() | ||||
| 	defer n.lock.Unlock() | ||||
| 	if _, ok := n.packets[to]; !ok { | ||||
| 		n.packets[to] = make(chan TestPacket, 1024) | ||||
| 	} | ||||
| 	n.packets[to] <- TestPacket{ | ||||
| 		Addr: from, | ||||
| 		Data: bytes.Clone(b), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type TestUDPConn struct { | ||||
| 	addr    netip.AddrPort | ||||
| 	n       *TestNetwork | ||||
| 	packets chan TestPacket | ||||
| } | ||||
|  | ||||
| func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { | ||||
| 	c.n.write(b, c.addr, addr) | ||||
| 	return len(b), nil | ||||
| } | ||||
|  | ||||
| func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { | ||||
| 	return c.WriteToUDPAddrPort(b, addr.AddrPort()) | ||||
| } | ||||
|  | ||||
| func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { | ||||
| 	pkt := <-c.packets | ||||
| 	b = b[:len(pkt.Data)] | ||||
| 	copy(b, pkt.Data) | ||||
| 	return len(b), pkt.Addr, nil | ||||
| } | ||||
|  | ||||
| func (c *TestUDPConn) Packets() (out []TestPacket) { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case pkt := <-c.packets: | ||||
| 			out = append(out, pkt) | ||||
| 		default: | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -70,7 +70,7 @@ func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | ||||
| 	return w.Uint16(addrPort.Port()) | ||||
| } | ||||
|  | ||||
| func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { | ||||
| func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter { | ||||
| 	for _, addrPort := range l { | ||||
| 		w.AddrPort(addrPort) | ||||
| 	} | ||||
| @@ -178,7 +178,7 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { | ||||
| func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader { | ||||
| 	for i := range x { | ||||
| 		r.AddrPort(&x[i]) | ||||
| 	} | ||||
|   | ||||
| @@ -6,6 +6,26 @@ import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestBinWriteRead_invalidAddrPort(t *testing.T) { | ||||
| 	addr := netip.AddrPort{} | ||||
| 	buf := make([]byte, 1024) | ||||
| 	buf = newBinWriter(buf). | ||||
| 		AddrPort(addr). | ||||
| 		Build() | ||||
|  | ||||
| 	var addr2 netip.AddrPort | ||||
| 	err := newBinReader(buf). | ||||
| 		AddrPort(&addr2). | ||||
| 		Error() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if addr2.IsValid() { | ||||
| 		t.Fatal(addr, addr2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBinWriteRead(t *testing.T) { | ||||
| 	buf := make([]byte, 1024) | ||||
|  | ||||
| @@ -35,7 +55,7 @@ func TestBinWriteRead(t *testing.T) { | ||||
| 		Byte(in.Type). | ||||
| 		Uint64(in.TraceID). | ||||
| 		AddrPort(in.DestAddr). | ||||
| 		AddrPortArray(in.Addrs). | ||||
| 		AddrPort8(in.Addrs). | ||||
| 		Build() | ||||
|  | ||||
| 	out := Item{} | ||||
| @@ -44,7 +64,7 @@ func TestBinWriteRead(t *testing.T) { | ||||
| 		Byte(&out.Type). | ||||
| 		Uint64(&out.TraceID). | ||||
| 		AddrPort(&out.DestAddr). | ||||
| 		AddrPortArray(&out.Addrs). | ||||
| 		AddrPort8(&out.Addrs). | ||||
| 		Error() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
|   | ||||
| @@ -5,93 +5,70 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	packetTypeSyn = iota + 1 | ||||
| 	packetTypeSynAck | ||||
| 	packetTypeAck | ||||
| 	packetTypeProbe | ||||
| 	packetTypeAddrDiscovery | ||||
| 	PacketTypeSyn = iota + 1 | ||||
| 	PacketTypeSynAck | ||||
| 	PacketTypeAck | ||||
| 	PacketTypeProbe | ||||
| 	PacketTypeAddrDiscovery | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type synPacket struct { | ||||
| type PacketSyn struct { | ||||
| 	TraceID uint64 // TraceID to match response w/ request. | ||||
| 	// TODO: SentAt int64 // Unixmilli. | ||||
| 	//SentAt        int64    // Unixmilli. | ||||
| 	//SharedKeyType byte     // Currently only 1 is supported for AES. | ||||
| 	SharedKey     [32]byte // Our shared key. | ||||
| 	Direct        bool | ||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||
| } | ||||
|  | ||||
| func (p synPacket) Marshal(buf []byte) []byte { | ||||
| func (p PacketSyn) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeSyn). | ||||
| 		Byte(PacketTypeSyn). | ||||
| 		Uint64(p.TraceID). | ||||
| 		//Int64(p.SentAt). | ||||
| 		//Byte(p.SharedKeyType). | ||||
| 		SharedKey(p.SharedKey). | ||||
| 		Bool(p.Direct). | ||||
| 		AddrPort(p.PossibleAddrs[0]). | ||||
| 		AddrPort(p.PossibleAddrs[1]). | ||||
| 		AddrPort(p.PossibleAddrs[2]). | ||||
| 		AddrPort(p.PossibleAddrs[3]). | ||||
| 		AddrPort(p.PossibleAddrs[4]). | ||||
| 		AddrPort(p.PossibleAddrs[5]). | ||||
| 		AddrPort(p.PossibleAddrs[6]). | ||||
| 		AddrPort(p.PossibleAddrs[7]). | ||||
| 		AddrPort8(p.PossibleAddrs). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
| func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		//Int64(&p.SentAt). | ||||
| 		//Byte(&p.SharedKeyType). | ||||
| 		SharedKey(&p.SharedKey). | ||||
| 		Bool(&p.Direct). | ||||
| 		AddrPort(&p.PossibleAddrs[0]). | ||||
| 		AddrPort(&p.PossibleAddrs[1]). | ||||
| 		AddrPort(&p.PossibleAddrs[2]). | ||||
| 		AddrPort(&p.PossibleAddrs[3]). | ||||
| 		AddrPort(&p.PossibleAddrs[4]). | ||||
| 		AddrPort(&p.PossibleAddrs[5]). | ||||
| 		AddrPort(&p.PossibleAddrs[6]). | ||||
| 		AddrPort(&p.PossibleAddrs[7]). | ||||
| 		AddrPort8(&p.PossibleAddrs). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type ackPacket struct { | ||||
| type PacketAck struct { | ||||
| 	TraceID       uint64 | ||||
| 	ToAddr        netip.AddrPort | ||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||
| } | ||||
|  | ||||
| func (p ackPacket) Marshal(buf []byte) []byte { | ||||
| func (p PacketAck) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeAck). | ||||
| 		Byte(PacketTypeAck). | ||||
| 		Uint64(p.TraceID). | ||||
| 		AddrPort(p.ToAddr). | ||||
| 		AddrPort(p.PossibleAddrs[0]). | ||||
| 		AddrPort(p.PossibleAddrs[1]). | ||||
| 		AddrPort(p.PossibleAddrs[2]). | ||||
| 		AddrPort(p.PossibleAddrs[3]). | ||||
| 		AddrPort(p.PossibleAddrs[4]). | ||||
| 		AddrPort(p.PossibleAddrs[5]). | ||||
| 		AddrPort(p.PossibleAddrs[6]). | ||||
| 		AddrPort(p.PossibleAddrs[7]). | ||||
| 		AddrPort8(p.PossibleAddrs). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||
| func ParsePacketAck(buf []byte) (p PacketAck, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.ToAddr). | ||||
| 		AddrPort(&p.PossibleAddrs[0]). | ||||
| 		AddrPort(&p.PossibleAddrs[1]). | ||||
| 		AddrPort(&p.PossibleAddrs[2]). | ||||
| 		AddrPort(&p.PossibleAddrs[3]). | ||||
| 		AddrPort(&p.PossibleAddrs[4]). | ||||
| 		AddrPort(&p.PossibleAddrs[5]). | ||||
| 		AddrPort(&p.PossibleAddrs[6]). | ||||
| 		AddrPort(&p.PossibleAddrs[7]). | ||||
| 		AddrPort8(&p.PossibleAddrs). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
| @@ -100,18 +77,18 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||
|  | ||||
| // A probeReqPacket is sent from a client to a server to determine if direct | ||||
| // UDP communication can be used. | ||||
| type probePacket struct { | ||||
| type PacketProbe struct { | ||||
| 	TraceID uint64 | ||||
| } | ||||
|  | ||||
| func (p probePacket) Marshal(buf []byte) []byte { | ||||
| func (p PacketProbe) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypeProbe). | ||||
| 		Byte(PacketTypeProbe). | ||||
| 		Uint64(p.TraceID). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parseProbePacket(buf []byte) (p probePacket, err error) { | ||||
| func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		Error() | ||||
| @@ -120,4 +97,4 @@ func parseProbePacket(buf []byte) (p probePacket, err error) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type localDiscoveryPacket struct{} | ||||
| type PacketLocalDiscovery struct{} | ||||
|   | ||||
| @@ -1 +1,66 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"net/netip" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestSynPacket(t *testing.T) { | ||||
| 	p := PacketSyn{ | ||||
| 		TraceID: newTraceID(), | ||||
| 		//SentAt:        time.Now().UnixMilli(), | ||||
| 		//SharedKeyType: 1, | ||||
| 		Direct: true, | ||||
| 	} | ||||
| 	rand.Read(p.SharedKey[:]) | ||||
|  | ||||
| 	p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234) | ||||
| 	p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) | ||||
| 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||
|  | ||||
| 	buf := p.Marshal(newBuf()) | ||||
| 	p2, err := ParsePacketSyn(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if !reflect.DeepEqual(p, p2) { | ||||
| 		t.Fatal(p2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAckPacket(t *testing.T) { | ||||
| 	p := PacketAck{ | ||||
| 		TraceID: newTraceID(), | ||||
| 		ToAddr:  netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), | ||||
| 	} | ||||
|  | ||||
| 	p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100) | ||||
| 	p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) | ||||
| 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||
|  | ||||
| 	buf := p.Marshal(newBuf()) | ||||
| 	p2, err := ParsePacketAck(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if !reflect.DeepEqual(p, p2) { | ||||
| 		t.Fatal(p2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestProbePacket(t *testing.T) { | ||||
| 	p := PacketProbe{ | ||||
| 		TraceID: newTraceID(), | ||||
| 	} | ||||
|  | ||||
| 	buf := p.Marshal(newBuf()) | ||||
| 	p2, err := ParsePacketProbe(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if !reflect.DeepEqual(p, p2) { | ||||
| 		t.Fatal(p2) | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										125
									
								
								peer/peer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								peer/peer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/rand" | ||||
| 	mrand "math/rand" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| // A test peer. | ||||
| type P struct { | ||||
| 	cryptoKeys | ||||
| 	RT         *atomic.Pointer[RoutingTable] | ||||
| 	Conn       *TestUDPConn | ||||
| 	IFace      *TestIFace | ||||
| 	ConnWriter *ConnWriter | ||||
| 	ConnReader *ConnReader | ||||
| 	IFReader   *IFReader | ||||
| 	Super      *Supervisor | ||||
| } | ||||
|  | ||||
| func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { | ||||
| 	p := P{ | ||||
| 		cryptoKeys: generateKeys(), | ||||
| 		RT:         &atomic.Pointer[RoutingTable]{}, | ||||
| 		IFace:      NewTestIFace(), | ||||
| 	} | ||||
|  | ||||
| 	rt := NewRoutingTable(ip, addr) | ||||
| 	p.RT.Store(&rt) | ||||
| 	p.Conn = n.NewUDPConn(addr) | ||||
| 	p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) | ||||
| 	p.IFReader = NewIFReader(p.IFace, p.ConnWriter) | ||||
|  | ||||
| 	/* | ||||
| 		   p.ConnReader = NewConnReader( | ||||
| 				p.Conn.ReadFromUDPAddrPort, | ||||
| 				p.IFace, | ||||
| 				p.ConnWriter.Forward, | ||||
| 				p.Super.HandleControlMsg, | ||||
| 				p.RT) | ||||
| 	*/ | ||||
| 	return p | ||||
| } | ||||
|  | ||||
| func ConnectPeers(p1, p2 *P) { | ||||
| 	rt1 := p1.RT.Load() | ||||
| 	rt2 := p2.RT.Load() | ||||
|  | ||||
| 	ip1 := rt1.LocalIP | ||||
| 	ip2 := rt2.LocalIP | ||||
|  | ||||
| 	rt1.Peers[ip2].Up = true | ||||
| 	rt1.Peers[ip2].Direct = true | ||||
| 	rt1.Peers[ip2].Relay = true | ||||
| 	rt1.Peers[ip2].DirectAddr = rt2.LocalAddr | ||||
| 	rt1.Peers[ip2].PubSignKey = p2.PubSignKey | ||||
| 	rt1.Peers[ip2].ControlCipher = newControlCipher(p1.PrivKey, p2.PubKey) | ||||
| 	rt1.Peers[ip2].DataCipher = newDataCipher() | ||||
|  | ||||
| 	rt2.Peers[ip1].Up = true | ||||
| 	rt2.Peers[ip1].Direct = true | ||||
| 	rt2.Peers[ip1].Relay = true | ||||
| 	rt2.Peers[ip1].DirectAddr = rt1.LocalAddr | ||||
| 	rt2.Peers[ip1].PubSignKey = p1.PubSignKey | ||||
| 	rt2.Peers[ip1].ControlCipher = newControlCipher(p2.PrivKey, p1.PubKey) | ||||
| 	rt2.Peers[ip1].DataCipher = rt1.Peers[ip2].DataCipher | ||||
| } | ||||
|  | ||||
| func NewPeersForTesting() (p1, p2, p3 P) { | ||||
| 	n := NewTestNetwork() | ||||
|  | ||||
| 	p1 = NewPeerForTesting( | ||||
| 		n, | ||||
| 		1, | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)) | ||||
|  | ||||
| 	p2 = NewPeerForTesting( | ||||
| 		n, | ||||
| 		2, | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 200)) | ||||
|  | ||||
| 	p3 = NewPeerForTesting( | ||||
| 		n, | ||||
| 		3, | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 300)) | ||||
|  | ||||
| 	ConnectPeers(&p1, &p2) | ||||
| 	ConnectPeers(&p1, &p3) | ||||
| 	ConnectPeers(&p2, &p3) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func RandPacket() []byte { | ||||
| 	n := mrand.Intn(1200) | ||||
| 	b := make([]byte, n) | ||||
| 	rand.Read(b) | ||||
| 	return b | ||||
| } | ||||
|  | ||||
| func ModifyPacket(in []byte) []byte { | ||||
| 	x := make([]byte, 1) | ||||
|  | ||||
| 	for { | ||||
| 		rand.Read(x) | ||||
| 		out := bytes.Clone(in) | ||||
| 		idx := mrand.Intn(len(out)) | ||||
| 		if out[idx] != x[0] { | ||||
| 			out[idx] = x[0] | ||||
| 			return out | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type UnknownControlPacket struct { | ||||
| 	TraceID uint64 | ||||
| } | ||||
|  | ||||
| func (p UnknownControlPacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf).Byte(255).Uint64(p.TraceID).Build() | ||||
| } | ||||
							
								
								
									
										355
									
								
								peer/peerstates.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										355
									
								
								peer/peerstates.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,355 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
|  | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| type PeerState interface { | ||||
| 	OnPeerUpdate(*m.Peer) PeerState | ||||
| 	OnSyn(controlMsg[PacketSyn]) PeerState | ||||
| 	OnAck(controlMsg[PacketAck]) | ||||
| 	OnProbe(controlMsg[PacketProbe]) PeerState | ||||
| 	OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) | ||||
| 	OnPingTimer() PeerState | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type State struct { | ||||
| 	// Output. | ||||
| 	publish           func(RemotePeer) | ||||
| 	sendControlPacket func(RemotePeer, Marshaller) | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	localIP   byte | ||||
| 	remoteIP  byte | ||||
| 	privKey   []byte | ||||
| 	localAddr netip.AddrPort // If valid, then local peer is publicly accessible. | ||||
|  | ||||
| 	pubAddrs *pubAddrStore | ||||
|  | ||||
| 	// The purpose of this state machine is to manage the RemotePeer object, | ||||
| 	// publishing it as necessary. | ||||
| 	staged RemotePeer // Local copy of shared data. See publish(). | ||||
|  | ||||
| 	// Mutable peer data. | ||||
| 	peer *m.Peer | ||||
|  | ||||
| 	// We rate limit per remote endpoint because if we don't we tend to lose | ||||
| 	// packets. | ||||
| 	limiter *ratelimiter.Limiter | ||||
| } | ||||
|  | ||||
| func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { | ||||
| 	defer func() { | ||||
| 		// Don't defer directly otherwise s.staged will be evaluated immediately | ||||
| 		// and won't reflect changes made in the function. | ||||
| 		s.publish(s.staged) | ||||
| 	}() | ||||
|  | ||||
| 	if peer == nil { | ||||
| 		return EnterStateDisconnected(s) | ||||
| 	} | ||||
|  | ||||
| 	s.peer = peer | ||||
| 	s.staged.Relay = false | ||||
| 	s.staged.Direct = false | ||||
| 	s.staged.DirectAddr = netip.AddrPort{} | ||||
| 	s.staged.PubSignKey = nil | ||||
| 	s.staged.PubSignKey = peer.PubSignKey | ||||
| 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | ||||
| 	s.staged.DataCipher = newDataCipher() | ||||
|  | ||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||
| 		s.staged.Relay = peer.Relay | ||||
| 		s.staged.Direct = true | ||||
| 		s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) | ||||
|  | ||||
| 		if s.localAddr.IsValid() && s.localIP < s.remoteIP { | ||||
| 			return EnterStateServer(s) | ||||
| 		} | ||||
|  | ||||
| 		return EnterStateClientDirect(s) | ||||
| 	} | ||||
|  | ||||
| 	if s.localAddr.IsValid() { | ||||
| 		s.staged.Direct = true | ||||
| 		return EnterStateServer(s) | ||||
| 	} | ||||
|  | ||||
| 	if s.localIP < s.remoteIP { | ||||
| 		return EnterStateServer(s) | ||||
| 	} | ||||
|  | ||||
| 	return EnterStateClientRelayed(s) | ||||
| } | ||||
|  | ||||
| func (s *State) logf(format string, args ...any) { | ||||
| 	b := strings.Builder{} | ||||
| 	name := "--" | ||||
| 	if s.peer != nil { | ||||
| 		name = s.peer.Name | ||||
| 	} | ||||
| 	b.WriteString(fmt.Sprintf("%30s: ", name)) | ||||
|  | ||||
| 	if s.staged.Direct { | ||||
| 		b.WriteString("DIRECT  | ") | ||||
| 	} else { | ||||
| 		b.WriteString("RELAYED | ") | ||||
| 	} | ||||
|  | ||||
| 	if s.staged.Up { | ||||
| 		b.WriteString("UP   | ") | ||||
| 	} else { | ||||
| 		b.WriteString("DOWN | ") | ||||
| 	} | ||||
|  | ||||
| 	log.Printf(b.String()+format, args...) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||
| 	if !addr.IsValid() { | ||||
| 		return | ||||
| 	} | ||||
| 	route := s.staged | ||||
| 	route.Direct = true | ||||
| 	route.DirectAddr = addr | ||||
| 	s.Send(route, pkt) | ||||
| } | ||||
|  | ||||
| func (s *State) Send(peer RemotePeer, pkt Marshaller) { | ||||
| 	if err := s.limiter.Limit(); err != nil { | ||||
| 		s.logf("Rate limited.") | ||||
| 		return | ||||
| 	} | ||||
| 	s.sendControlPacket(peer, pkt) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateDisconnected struct{ *State } | ||||
|  | ||||
| func EnterStateDisconnected(s *State) PeerState { | ||||
| 	s.logf("==> Disconnected") | ||||
| 	s.peer = nil | ||||
| 	s.staged.Up = false | ||||
| 	s.staged.Relay = false | ||||
| 	s.staged.Direct = false | ||||
| 	s.staged.DirectAddr = netip.AddrPort{} | ||||
| 	s.staged.PubSignKey = nil | ||||
| 	s.staged.ControlCipher = nil | ||||
| 	s.staged.DataCipher = nil | ||||
| 	s.publish(s.staged) | ||||
| 	return &StateDisconnected{State: s} | ||||
| } | ||||
|  | ||||
| func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState             { return nil } | ||||
| func (s *StateDisconnected) OnAck(controlMsg[PacketAck])                       {} | ||||
| func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState         { return nil } | ||||
| func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} | ||||
| func (s *StateDisconnected) OnPingTimer() PeerState                            { return nil } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateServer struct { | ||||
| 	*StateDisconnected | ||||
| 	lastSeen   time.Time | ||||
| 	synTraceID uint64 | ||||
| } | ||||
|  | ||||
| func EnterStateServer(s *State) PeerState { | ||||
| 	s.logf("==> Server") | ||||
| 	return &StateServer{StateDisconnected: &StateDisconnected{State: s}} | ||||
| } | ||||
|  | ||||
| func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | ||||
| 	s.lastSeen = time.Now() | ||||
| 	p := msg.Packet | ||||
|  | ||||
| 	// Before we can respond to this packet, we need to make sure the | ||||
| 	// route is setup properly. | ||||
| 	// | ||||
| 	// The client will update the syn's TraceID whenever there's a change. | ||||
| 	// The server will follow the client's request. | ||||
| 	if p.TraceID != s.synTraceID || !s.staged.Up { | ||||
| 		s.synTraceID = p.TraceID | ||||
| 		s.staged.Up = true | ||||
| 		s.staged.Direct = p.Direct | ||||
| 		s.staged.DataCipher = newDataCipherFromKey(p.SharedKey) | ||||
| 		s.staged.DirectAddr = msg.SrcAddr | ||||
| 		s.publish(s.staged) | ||||
| 		s.logf("Got SYN.") | ||||
| 	} | ||||
|  | ||||
| 	// Always respond. | ||||
| 	ack := PacketAck{ | ||||
| 		TraceID:       p.TraceID, | ||||
| 		ToAddr:        s.staged.DirectAddr, | ||||
| 		PossibleAddrs: s.pubAddrs.Get(), | ||||
| 	} | ||||
| 	s.Send(s.staged, ack) | ||||
|  | ||||
| 	if p.Direct { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	for _, addr := range msg.Packet.PossibleAddrs { | ||||
| 		if !addr.IsValid() { | ||||
| 			break | ||||
| 		} | ||||
| 		s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { | ||||
| 	if msg.SrcAddr.IsValid() { | ||||
| 		s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *StateServer) OnPingTimer() PeerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||
| 		s.staged.Up = false | ||||
| 		s.publish(s.staged) | ||||
| 		s.logf("Timeout.") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateClientDirect struct { | ||||
| 	*StateDisconnected | ||||
| 	lastSeen time.Time | ||||
| 	syn      PacketSyn | ||||
| } | ||||
|  | ||||
| func EnterStateClientDirect(s *State) PeerState { | ||||
| 	s.logf("==> ClientDirect") | ||||
| 	return NewStateClientDirect(s) | ||||
| } | ||||
|  | ||||
| func NewStateClientDirect(s *State) *StateClientDirect { | ||||
| 	state := &StateClientDirect{ | ||||
| 		StateDisconnected: &StateDisconnected{s}, | ||||
| 		lastSeen:          time.Now(), // Avoid immediate timeout. | ||||
| 	} | ||||
|  | ||||
| 	state.syn = PacketSyn{ | ||||
| 		TraceID:       newTraceID(), | ||||
| 		SharedKey:     s.staged.DataCipher.Key(), | ||||
| 		Direct:        s.staged.Direct, | ||||
| 		PossibleAddrs: s.pubAddrs.Get(), | ||||
| 	} | ||||
| 	state.Send(s.staged, state.syn) | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { | ||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	s.lastSeen = time.Now() | ||||
|  | ||||
| 	if !s.staged.Up { | ||||
| 		s.staged.Up = true | ||||
| 		s.publish(s.staged) | ||||
| 		s.logf("Got ACK.") | ||||
| 	} | ||||
|  | ||||
| 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||
| } | ||||
|  | ||||
| func (s *StateClientDirect) OnPingTimer() PeerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval { | ||||
| 		if s.staged.Up { | ||||
| 			s.staged.Up = false | ||||
| 			s.publish(s.staged) | ||||
| 			s.logf("Timeout.") | ||||
| 		} | ||||
| 		return s.OnPeerUpdate(s.peer) | ||||
| 	} | ||||
|  | ||||
| 	s.Send(s.staged, s.syn) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateClientRelayed struct { | ||||
| 	*StateClientDirect | ||||
| 	ack                PacketAck | ||||
| 	probes             map[uint64]netip.AddrPort | ||||
| 	localDiscoveryAddr netip.AddrPort | ||||
| } | ||||
|  | ||||
| func EnterStateClientRelayed(s *State) PeerState { | ||||
| 	s.logf("==> ClientRelayed") | ||||
| 	return &StateClientRelayed{ | ||||
| 		StateClientDirect: NewStateClientDirect(s), | ||||
| 		probes:            map[uint64]netip.AddrPort{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { | ||||
| 	s.ack = msg.Packet | ||||
| 	s.StateClientDirect.OnAck(msg) | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { | ||||
| 	addr, ok := s.probes[msg.Packet.TraceID] | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	s.staged.DirectAddr = addr | ||||
| 	s.staged.Direct = true | ||||
| 	s.publish(s.staged) | ||||
| 	return EnterStateClientDirect(s.StateClientDirect.State) | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { | ||||
| 	// The source port will be the multicast port, so we'll have to | ||||
| 	// construct the correct address using the peer's listed port. | ||||
| 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnPingTimer() PeerState { | ||||
| 	if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { | ||||
| 		return nextState | ||||
| 	} | ||||
|  | ||||
| 	clear(s.probes) | ||||
| 	for _, addr := range s.ack.PossibleAddrs { | ||||
| 		if !addr.IsValid() { | ||||
| 			break | ||||
| 		} | ||||
| 		s.sendProbeTo(addr) | ||||
| 	} | ||||
|  | ||||
| 	if s.localDiscoveryAddr.IsValid() { | ||||
| 		s.sendProbeTo(s.localDiscoveryAddr) | ||||
| 		s.localDiscoveryAddr = netip.AddrPort{} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { | ||||
| 	probe := PacketProbe{TraceID: newTraceID()} | ||||
| 	s.probes[probe.TraceID] = addr | ||||
| 	s.SendTo(probe, addr) | ||||
| } | ||||
							
								
								
									
										509
									
								
								peer/peerstates_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										509
									
								
								peer/peerstates_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,509 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
|  | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PeerStateControlMsg struct { | ||||
| 	Peer   RemotePeer | ||||
| 	Packet any | ||||
| } | ||||
|  | ||||
| type PeerStateTestHarness struct { | ||||
| 	State     PeerState | ||||
| 	Published RemotePeer | ||||
| 	Sent      []PeerStateControlMsg | ||||
| } | ||||
|  | ||||
| func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||
| 	h := &PeerStateTestHarness{} | ||||
|  | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	state := &State{ | ||||
| 		publish: func(rp RemotePeer) { | ||||
| 			h.Published = rp | ||||
| 		}, | ||||
| 		sendControlPacket: func(rp RemotePeer, pkt Marshaller) { | ||||
| 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||
| 		}, | ||||
| 		localIP:  2, | ||||
| 		remoteIP: 3, | ||||
| 		privKey:  keys.PrivKey, | ||||
| 		pubAddrs: newPubAddrStore(netip.AddrPort{}), | ||||
| 		limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 			FillPeriod:   20 * time.Millisecond, | ||||
| 			MaxWaitCount: 1, | ||||
| 		}), | ||||
| 	} | ||||
|  | ||||
| 	h.State = EnterStateDisconnected(state) | ||||
| 	return h | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | ||||
| 	if s := h.State.OnPeerUpdate(p); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { | ||||
| 	if s := h.State.OnSyn(msg); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { | ||||
| 	if s := h.State.OnProbe(msg); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnPingTimer() { | ||||
| 	if s := h.State.OnPingTimer(); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	state := h.State.(*StateDisconnected) | ||||
| 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||
|  | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| 		PublicIP:   []byte{1, 1, 1, 3}, | ||||
| 		Port:       456, | ||||
| 		PubKey:     keys.PubKey, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateServer](t, h.State) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { | ||||
| 	keys := generateKeys() | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| 		Port:       456, | ||||
| 		PubKey:     keys.PubKey, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateServer](t, h.State) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { | ||||
| 	keys := generateKeys() | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| 		PublicIP:   []byte{1, 2, 3, 4}, | ||||
| 		Port:       456, | ||||
| 		PubKey:     keys.PubKey, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateClientDirect](t, h.State) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	state := h.State.(*StateDisconnected) | ||||
| 	state.remoteIP = 1 | ||||
|  | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| 		Port:       456, | ||||
| 		PubKey:     keys.PubKey, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateClientRelayed](t, h.State) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.PeerUpdate(nil) | ||||
| 	assertType[*StateDisconnected](t, h.State) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | ||||
| 	keys := generateKeys() | ||||
| 	h := NewPeerStateTestHarness() | ||||
|  | ||||
| 	state := h.State.(*StateDisconnected) | ||||
| 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||
|  | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| 		Port:       456, | ||||
| 		PubKey:     keys.PubKey, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	assertType[*StateServer](t, h.State) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigServer_Public(t) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_serverRelayed(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigServer_Relayed(t) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientDirect(t) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
| } | ||||
|  | ||||
| func TestStateServer_directSyn(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigServer_Relayed(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	synMsg := controlMsg[PacketSyn]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet: PacketSyn{ | ||||
| 			TraceID: newTraceID(), | ||||
| 			//SentAt:        time.Now().UnixMilli(), | ||||
| 			//SharedKeyType: 1, | ||||
| 			Direct: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnSyn(synMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | ||||
| 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||
| 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||
| 	assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| } | ||||
|  | ||||
| func TestStateServer_relayedSyn(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	state := h.ConfigServer_Relayed(t) | ||||
|  | ||||
| 	state.pubAddrs.Store(addrPort4(4, 5, 6, 7, 1234)) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	synMsg := controlMsg[PacketSyn]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet: PacketSyn{ | ||||
| 			TraceID: newTraceID(), | ||||
| 			//SentAt:        time.Now().UnixMilli(), | ||||
| 			//SharedKeyType: 1, | ||||
| 			Direct: false, | ||||
| 		}, | ||||
| 	} | ||||
| 	synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300) | ||||
| 	synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300) | ||||
|  | ||||
| 	h.State.OnSyn(synMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 3) | ||||
|  | ||||
| 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | ||||
| 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||
| 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||
| 	assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) | ||||
| 	assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	assertType[PacketProbe](t, h.Sent[1].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[2].Packet) | ||||
| 	assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||
| 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) | ||||
| } | ||||
|  | ||||
| func TestStateServer_onProbe(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigServer_Relayed(t) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	probeMsg := controlMsg[PacketProbe]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet:  PacketProbe{TraceID: newTraceID()}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnProbe(probeMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
|  | ||||
| 	probe := assertType[PacketProbe](t, h.Sent[0].Packet) | ||||
| 	assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) | ||||
| 	assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||
| } | ||||
|  | ||||
| func TestStateServer_OnPingTimer_timeout(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigServer_Relayed(t) | ||||
|  | ||||
| 	synMsg := controlMsg[PacketSyn]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet: PacketSyn{ | ||||
| 			TraceID: newTraceID(), | ||||
| 			//SentAt:        time.Now().UnixMilli(), | ||||
| 			//SharedKeyType: 1, | ||||
| 			Direct: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnSyn(synMsg) | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	// Ping shouldn't timeout. | ||||
| 	h.OnPingTimer() | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	// Advance the time, then ping. | ||||
| 	state := assertType[*StateServer](t, h.State) | ||||
| 	state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| func TestStateClientDirect_OnAck(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientDirect(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| } | ||||
|  | ||||
| func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientDirect(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID + 1}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| func TestStateClientDirect_OnPingTimer(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientDirect(t) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	// On ping timer, another syn should be sent. Additionally, we should remain | ||||
| 	// in the same state. | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*StateClientDirect](t, h.State) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientDirect(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	state := assertType[*StateClientDirect](t, h.State) | ||||
| 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||
| 	// will be sent when re-entering the state, but the connection should be down. | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*StateClientDirect](t, h.State) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnAck(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnPingTimer_noAddrs(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
|  | ||||
| 	// If we haven't had an ack yet, we won't have addresses to probe. Therefore | ||||
| 	// we'll have just one more syn packet sent. | ||||
| 	h.OnPingTimer() | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
|  | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | ||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||
|  | ||||
| 	h.State.OnAck(ack) | ||||
|  | ||||
| 	// Add a local discovery address. Note that the port will be configured port | ||||
| 	// and no the one provided here. | ||||
| 	h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | ||||
| 	}) | ||||
|  | ||||
| 	// We should see one SYN and three probe packets. | ||||
| 	h.OnPingTimer() | ||||
| 	assertEqual(t, len(h.Sent), 5) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[2].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[3].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[4].Packet) | ||||
|  | ||||
| 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) | ||||
| 	assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) | ||||
| 	assertEqual(t, h.Sent[4].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 456)) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	state := assertType[*StateClientRelayed](t, h.State) | ||||
| 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||
| 	// will be sent when re-entering the state, but the connection should be down. | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*StateClientRelayed](t, h.State) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	h.OnProbe(controlMsg[PacketProbe]{ | ||||
| 		Packet: PacketProbe{TraceID: newTraceID()}, | ||||
| 	}) | ||||
|  | ||||
| 	assertType[*StateClientRelayed](t, h.State) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | ||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||
|  | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	probe := assertType[PacketProbe](t, h.Sent[2].Packet) | ||||
| 	h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) | ||||
|  | ||||
| 	assertType[*StateClientDirect](t, h.State) | ||||
| } | ||||
							
								
								
									
										75
									
								
								peer/pubaddrs.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								peer/pubaddrs.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"runtime/debug" | ||||
| 	"sort" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type pubAddrStore struct { | ||||
| 	localPub  bool | ||||
| 	localAddr netip.AddrPort | ||||
| 	lastSeen  map[netip.AddrPort]time.Time | ||||
| 	addrList  []netip.AddrPort | ||||
| } | ||||
|  | ||||
| func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { | ||||
| 	return &pubAddrStore{ | ||||
| 		localPub:  localAddr.IsValid(), | ||||
| 		localAddr: localAddr, | ||||
| 		lastSeen:  map[netip.AddrPort]time.Time{}, | ||||
| 		addrList:  make([]netip.AddrPort, 0, 32), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) Store(add netip.AddrPort) { | ||||
| 	if store.localPub { | ||||
| 		log.Printf("OOPS: Local pub but storage attempt: %s", debug.Stack()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !add.IsValid() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if _, exists := store.lastSeen[add]; !exists { | ||||
| 		store.addrList = append(store.addrList, add) | ||||
| 	} | ||||
| 	store.lastSeen[add] = time.Now() | ||||
| 	store.sort() | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { | ||||
| 	if store.localPub { | ||||
| 		addrs[0] = store.localAddr | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	copy(addrs[:], store.addrList) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) Clean() { | ||||
| 	if store.localPub { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for ip, lastSeen := range store.lastSeen { | ||||
| 		if time.Since(lastSeen) > timeoutInterval { | ||||
| 			delete(store.lastSeen, ip) | ||||
| 		} | ||||
| 	} | ||||
| 	store.addrList = store.addrList[:0] | ||||
| 	for ip := range store.lastSeen { | ||||
| 		store.addrList = append(store.addrList, ip) | ||||
| 	} | ||||
| 	store.sort() | ||||
| } | ||||
|  | ||||
| func (store *pubAddrStore) sort() { | ||||
| 	sort.Slice(store.addrList, func(i, j int) bool { | ||||
| 		return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										29
									
								
								peer/pubaddrs_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								peer/pubaddrs_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestPubAddrStore(t *testing.T) { | ||||
| 	s := newPubAddrStore(netip.AddrPort{}) | ||||
|  | ||||
| 	l := []netip.AddrPort{ | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), | ||||
| 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), | ||||
| 	} | ||||
|  | ||||
| 	for i := range l { | ||||
| 		s.Store(l[i]) | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 	} | ||||
|  | ||||
| 	s.Clean() | ||||
|  | ||||
| 	l2 := s.Get() | ||||
| 	if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { | ||||
| 		t.Fatal(l, l2) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										137
									
								
								peer/routingtable.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								peer/routingtable.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,137 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // TODO: Remove | ||||
| func NewRemotePeer(ip byte) *RemotePeer { | ||||
| 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||
| 	return &RemotePeer{ | ||||
| 		IP:       ip, | ||||
| 		counter:  &counter, | ||||
| 		dupCheck: newDupCheck(0), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type RemotePeer struct { | ||||
| 	localIP       byte | ||||
| 	IP            byte           // VPN IP of peer (last byte). | ||||
| 	Up            bool           // True if data can be sent on the peer. | ||||
| 	Relay         bool           // True if the peer is a relay. | ||||
| 	Direct        bool           // True if this is a direct connection. | ||||
| 	DirectAddr    netip.AddrPort // Remote address if directly connected. | ||||
| 	PubSignKey    []byte | ||||
| 	ControlCipher *controlCipher | ||||
| 	DataCipher    *dataCipher | ||||
|  | ||||
| 	counter  *uint64   // For sending to. Atomic access only. | ||||
| 	dupCheck *dupCheck // For receiving from. Not safe for concurrent use. | ||||
| } | ||||
|  | ||||
| func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | ||||
| 	h := header{ | ||||
| 		StreamID: dataStreamID, | ||||
| 		Counter:  atomic.AddUint64(p.counter, 1), | ||||
| 		SourceIP: p.localIP, | ||||
| 		DestIP:   destIP, | ||||
| 	} | ||||
| 	return p.DataCipher.Encrypt(h, data, out) | ||||
| } | ||||
|  | ||||
| // Decrypts and de-dups incoming data packets. | ||||
| func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { | ||||
| 	dec, ok := p.DataCipher.Decrypt(enc, out) | ||||
| 	if !ok { | ||||
| 		return nil, errDecryptionFailed | ||||
| 	} | ||||
|  | ||||
| 	if p.dupCheck.IsDup(h.Counter) { | ||||
| 		return nil, errDuplicateSeqNum | ||||
| 	} | ||||
|  | ||||
| 	return dec, nil | ||||
| } | ||||
|  | ||||
| // Peer must have a ControlCipher. | ||||
| func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | ||||
| 	h := header{ | ||||
| 		StreamID: controlStreamID, | ||||
| 		Counter:  atomic.AddUint64(p.counter, 1), | ||||
| 		SourceIP: p.localIP, | ||||
| 		DestIP:   p.IP, | ||||
| 	} | ||||
| 	tmp = pkt.Marshal(tmp) | ||||
| 	return p.ControlCipher.Encrypt(h, tmp, out) | ||||
| } | ||||
|  | ||||
| // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. | ||||
| // | ||||
| // This function also drops packets with duplicate sequence numbers. | ||||
| func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { | ||||
| 	out, ok := p.ControlCipher.Decrypt(enc, tmp) | ||||
| 	if !ok { | ||||
| 		return nil, errDecryptionFailed | ||||
| 	} | ||||
|  | ||||
| 	if p.dupCheck.IsDup(h.Counter) { | ||||
| 		return nil, errDuplicateSeqNum | ||||
| 	} | ||||
|  | ||||
| 	msg, err := parseControlMsg(h.SourceIP, fromAddr, out) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return msg, nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type RoutingTable struct { | ||||
| 	// The LocalIP is the configured IP address of the local peer on the VPN. | ||||
| 	// | ||||
| 	// This value is constant. | ||||
| 	LocalIP byte | ||||
|  | ||||
| 	// The LocalAddr is the configured local public address of the peer on the | ||||
| 	// internet. If LocalAddr.IsValid(), then the local peer has a public | ||||
| 	// address. | ||||
| 	// | ||||
| 	// This value is constant. | ||||
| 	LocalAddr netip.AddrPort | ||||
|  | ||||
| 	// The remote peer configurations. These are updated by | ||||
| 	Peers [256]RemotePeer | ||||
|  | ||||
| 	// The current relay's VPN IP address, or zero if no relay is available. | ||||
| 	RelayIP byte | ||||
| } | ||||
|  | ||||
| func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { | ||||
| 	rt := RoutingTable{ | ||||
| 		LocalIP:   localIP, | ||||
| 		LocalAddr: localAddr, | ||||
| 	} | ||||
|  | ||||
| 	for i := range rt.Peers { | ||||
| 		counter := uint64(time.Now().Unix()<<30 + 1) | ||||
| 		rt.Peers[i] = RemotePeer{ | ||||
| 			localIP:  localIP, | ||||
| 			IP:       byte(i), | ||||
| 			counter:  &counter, | ||||
| 			dupCheck: newDupCheck(0), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return rt | ||||
| } | ||||
|  | ||||
| func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { | ||||
| 	relay := rt.Peers[rt.RelayIP] | ||||
| 	return relay, relay.Up && relay.Direct | ||||
| } | ||||
							
								
								
									
										169
									
								
								peer/routingtable_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								peer/routingtable_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,169 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestRemotePeer_DecryptDataPacket(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	orig := RandPacket() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	enc := peer2.EncryptDataPacket(2, orig, newBuf()) | ||||
|  | ||||
| 	h := parseHeader(enc) | ||||
| 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||
| 		t.Fatal(h) | ||||
| 	} | ||||
|  | ||||
| 	dec, err := peer1.DecryptDataPacket(h, enc, newBuf()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if !bytes.Equal(orig, dec) { | ||||
| 		t.Fatal(dec) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRemotePeer_DecryptDataPacket_packetAltered(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	orig := RandPacket() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	enc := peer2.EncryptDataPacket(2, orig, newBuf()) | ||||
|  | ||||
| 	h := parseHeader(enc) | ||||
|  | ||||
| 	for range 2048 { | ||||
| 		_, err := peer1.DecryptDataPacket(h, ModifyPacket(enc), newBuf()) | ||||
| 		if err == nil { | ||||
| 			t.Fatal(enc) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRemotePeer_DecryptDataPacket_duplicateSequenceNumber(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	orig := RandPacket() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	enc := peer2.EncryptDataPacket(2, orig, newBuf()) | ||||
| 	h := parseHeader(enc) | ||||
|  | ||||
| 	if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err == nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRemotePeer_DecryptControlPacket(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
| 	h := parseHeader(enc) | ||||
| 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||
| 		t.Fatal(h) | ||||
| 	} | ||||
|  | ||||
| 	ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := ctrlMsg.(controlMsg[PacketProbe]) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ctrlMsg) | ||||
| 	} | ||||
|  | ||||
| 	if dec.SrcIP != 1 || dec.SrcAddr != p1.RT.Load().LocalAddr { | ||||
| 		t.Fatal(dec) | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(dec.Packet, orig) { | ||||
| 		t.Fatal(dec) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
| 	h := parseHeader(enc) | ||||
| 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||
| 		t.Fatal(h) | ||||
| 	} | ||||
|  | ||||
| 	for range 2048 { | ||||
| 		ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, ModifyPacket(enc), newBuf()) | ||||
| 		if err == nil { | ||||
| 			t.Fatal(ctrlMsg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
| 	h := parseHeader(enc) | ||||
| 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||
| 		t.Fatal(h) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRemotePeer_DecryptControlPacket_unknownPacketType(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := UnknownControlPacket{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
| 	h := parseHeader(enc) | ||||
| 	if h.DestIP != 2 || h.SourceIP != 1 { | ||||
| 		t.Fatal(h) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
| @@ -1,29 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type RemotePeer struct { | ||||
| 	IP            byte           // VPN IP of peer (last byte). | ||||
| 	Up            bool           // True if data can be sent on the peer. | ||||
| 	Relay         bool           // True if the peer is a relay. | ||||
| 	Direct        bool           // True if this is a direct connection. | ||||
| 	DirectAddr    netip.AddrPort // Remote address if directly connected. | ||||
| 	PubSignKey    []byte | ||||
| 	ControlCipher *controlCipher | ||||
| 	DataCipher    *dataCipher | ||||
|  | ||||
| 	Counter  *uint64   // For sending to. Atomic access only. | ||||
| 	DupCheck *dupCheck // For receiving from. Not safe for concurrent use. | ||||
| } | ||||
|  | ||||
| func NewRemotePeer(ip byte) *RemotePeer { | ||||
| 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||
| 	return &RemotePeer{ | ||||
| 		IP:       ip, | ||||
| 		Counter:  &counter, | ||||
| 		DupCheck: newDupCheck(0), | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										103
									
								
								peer/supervisor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								peer/supervisor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type Supervisor struct { | ||||
| 	messages chan any // Incoming control messages. | ||||
| 	peers    [256]PeerState | ||||
| 	pubAddrs *pubAddrStore | ||||
| 	rt       *atomic.Pointer[RoutingTable] | ||||
| 	staged   RoutingTable | ||||
| } | ||||
|  | ||||
| func NewSupervisor( | ||||
| 	sendControl func(RemotePeer, Marshaller), | ||||
| 	privKey []byte, | ||||
| 	rt *atomic.Pointer[RoutingTable], | ||||
| ) *Supervisor { | ||||
| 	s := &Supervisor{ | ||||
| 		messages: make(chan any, 1024), | ||||
| 		pubAddrs: newPubAddrStore(rt.Load().LocalAddr), | ||||
| 		rt:       rt, | ||||
| 	} | ||||
|  | ||||
| 	routes := rt.Load() | ||||
|  | ||||
| 	for i := range s.peers { | ||||
| 		state := &State{ | ||||
| 			publish:           s.Publish, | ||||
| 			sendControlPacket: sendControl, | ||||
| 			localIP:           routes.LocalIP, | ||||
| 			remoteIP:          byte(i), | ||||
| 			privKey:           privKey, | ||||
| 			localAddr:         routes.LocalAddr, | ||||
| 			pubAddrs:          s.pubAddrs, | ||||
| 			staged:            routes.Peers[i], | ||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 				FillPeriod:   20 * time.Millisecond, | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| 		s.peers[i] = state.OnPeerUpdate(nil) | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *Supervisor) HandleControlMsg(msg any) { | ||||
| 	select { | ||||
| 	case s.messages <- msg: | ||||
| 	default: | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Supervisor) Run() { | ||||
| 	for raw := range s.messages { | ||||
| 		switch msg := raw.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[PacketSyn]: | ||||
| 			if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { | ||||
| 				s.peers[msg.SrcIP] = newState | ||||
| 			} | ||||
|  | ||||
| 		case controlMsg[PacketAck]: | ||||
| 			s.peers[msg.SrcIP].OnAck(msg) | ||||
|  | ||||
| 		case controlMsg[PacketProbe]: | ||||
| 			if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { | ||||
| 				s.peers[msg.SrcIP] = newState | ||||
| 			} | ||||
|  | ||||
| 		case controlMsg[PacketLocalDiscovery]: | ||||
| 			s.peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			s.pubAddrs.Clean() | ||||
| 			for i := range s.peers { | ||||
| 				if newState := s.peers[i].OnPingTimer(); newState != nil { | ||||
| 					s.peers[i] = newState | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Supervisor) Publish(rp RemotePeer) { | ||||
| 	s.staged.Peers[rp.IP] = rp | ||||
| 	rt := s.staged // Copy. | ||||
| 	s.rt.Store(&rt) | ||||
| } | ||||
							
								
								
									
										26
									
								
								peer/util_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								peer/util_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func addrPort4(a, b, c, d byte, port uint16) netip.AddrPort { | ||||
| 	return netip.AddrPortFrom(netip.AddrFrom4([4]byte{a, b, c, d}), port) | ||||
| } | ||||
|  | ||||
| func assertType[T any](t *testing.T, obj any) T { | ||||
| 	t.Helper() | ||||
| 	x, ok := obj.(T) | ||||
| 	if !ok { | ||||
| 		t.Fatal("invalid type", obj) | ||||
| 	} | ||||
| 	return x | ||||
| } | ||||
|  | ||||
| func assertEqual[T comparable](t *testing.T, a, b T) { | ||||
| 	t.Helper() | ||||
| 	if a != b { | ||||
| 		t.Fatal(a, " != ", b) | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user