wip
This commit is contained in:
		| @@ -2,10 +2,10 @@ package main | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"vppn/node" | ||||
| 	"vppn/peer" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	log.SetFlags(0) | ||||
| 	node.Main() | ||||
| 	peer.Main() | ||||
| } | ||||
|   | ||||
| @@ -258,7 +258,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | ||||
| 	default: | ||||
| 		log.Printf("Dropping control packet.") | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { | ||||
|   | ||||
| @@ -12,7 +12,7 @@ type connReader struct { | ||||
| 	sender  encryptedPacketSender | ||||
| 	super   controlMsgHandler | ||||
| 	localIP byte | ||||
| 	peers   [256]*atomic.Pointer[RemotePeer] | ||||
| 	peers   [256]*atomic.Pointer[remotePeer] | ||||
|  | ||||
| 	buf    []byte | ||||
| 	decBuf []byte | ||||
| @@ -24,7 +24,7 @@ func newConnReader( | ||||
| 	sender encryptedPacketSender, | ||||
| 	super controlMsgHandler, | ||||
| 	localIP byte, | ||||
| 	peers [256]*atomic.Pointer[RemotePeer], | ||||
| 	peers [256]*atomic.Pointer[remotePeer], | ||||
| ) *connReader { | ||||
| 	return &connReader{ | ||||
| 		conn:    conn, | ||||
| @@ -79,7 +79,7 @@ func (r *connReader) handleNextPacket() { | ||||
| } | ||||
|  | ||||
| func (r *connReader) handleControlPacket( | ||||
| 	peer *RemotePeer, | ||||
| 	peer *remotePeer, | ||||
| 	addr netip.AddrPort, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| @@ -102,7 +102,7 @@ func (r *connReader) handleControlPacket( | ||||
| 	r.super.HandleControlMsg(msg) | ||||
| } | ||||
|  | ||||
| func (r *connReader) handleDataPacket(peer *RemotePeer, h header, enc []byte) { | ||||
| func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { | ||||
| 	if !peer.Up { | ||||
| 		r.logf("Not connected (recv).") | ||||
| 		return | ||||
|   | ||||
| @@ -12,12 +12,12 @@ type ConnReader struct { | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | ||||
|  | ||||
| 	// Output | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	iface              io.Writer | ||||
| 	forwardData      func(ip byte, pkt []byte) | ||||
| 	handleControlMsg func(pkt any) | ||||
| 	handleControlMsg   func(fromIP byte, pkt any) | ||||
|  | ||||
| 	localIP byte | ||||
| 	rt      *atomic.Pointer[RoutingTable] | ||||
| 	rt      *atomic.Pointer[routingTable] | ||||
|  | ||||
| 	buf    []byte | ||||
| 	decBuf []byte | ||||
| @@ -25,15 +25,15 @@ type ConnReader struct { | ||||
|  | ||||
| func NewConnReader( | ||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	iface io.Writer, | ||||
| 	forwardData func(ip byte, pkt []byte), | ||||
| 	handleControlMsg func(pkt any), | ||||
| 	rt *atomic.Pointer[RoutingTable], | ||||
| 	handleControlMsg func(fromIP byte, pkt any), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| ) *ConnReader { | ||||
| 	return &ConnReader{ | ||||
| 		readFromUDPAddrPort: readFromUDPAddrPort, | ||||
| 		writeToUDPAddrPort:  writeToUDPAddrPort, | ||||
| 		iface:               iface, | ||||
| 		forwardData:         forwardData, | ||||
| 		handleControlMsg:    handleControlMsg, | ||||
| 		localIP:             rt.Load().LocalIP, | ||||
| 		rt:                  rt, | ||||
| @@ -50,7 +50,9 @@ func (r *ConnReader) Run() { | ||||
|  | ||||
| func (r *ConnReader) handleNextPacket() { | ||||
| 	buf := r.buf[:bufferSize] | ||||
| 	log.Printf("Getting next packet...") | ||||
| 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||
| 	log.Printf("Packet from %v...", remoteAddr) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||
| 	} | ||||
| @@ -64,14 +66,14 @@ func (r *ConnReader) handleNextPacket() { | ||||
| 	buf = buf[:n] | ||||
| 	h := parseHeader(buf) | ||||
|  | ||||
| 	peer := r.rt.Load().Peers[h.SourceIP] | ||||
| 	//peer := rt.Peers[h.SourceIP] | ||||
| 	rt := r.rt.Load() | ||||
| 	peer := rt.Peers[h.SourceIP] | ||||
|  | ||||
| 	switch h.StreamID { | ||||
| 	case controlStreamID: | ||||
| 		r.handleControlPacket(remoteAddr, peer, h, buf) | ||||
| 	case dataStreamID: | ||||
| 		r.handleDataPacket(peer, h, buf) | ||||
| 		r.handleDataPacket(rt, peer, h, buf) | ||||
| 	default: | ||||
| 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||
| 	} | ||||
| @@ -79,7 +81,7 @@ func (r *ConnReader) handleNextPacket() { | ||||
|  | ||||
| func (r *ConnReader) handleControlPacket( | ||||
| 	remoteAddr netip.AddrPort, | ||||
| 	peer RemotePeer, | ||||
| 	peer remotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| @@ -98,11 +100,12 @@ func (r *ConnReader) handleControlPacket( | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.handleControlMsg(msg) | ||||
| 	r.handleControlMsg(h.SourceIP, msg) | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) handleDataPacket( | ||||
| 	peer RemotePeer, | ||||
| 	rt *routingTable, | ||||
| 	peer remotePeer, | ||||
| 	h header, | ||||
| 	enc []byte, | ||||
| ) { | ||||
| @@ -124,7 +127,13 @@ func (r *ConnReader) handleDataPacket( | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.forwardData(h.DestIP, data) | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		r.logf("Relay not available.") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.writeToUDPAddrPort(data, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (r *ConnReader) logf(format string, args ...any) { | ||||
|   | ||||
| @@ -1,353 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/rand" | ||||
| 	"net/netip" | ||||
| 	"reflect" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type mockIfWriter struct { | ||||
| 	Written [][]byte | ||||
| } | ||||
|  | ||||
| func (w *mockIfWriter) Write(b []byte) (int, error) { | ||||
| 	w.Written = append(w.Written, bytes.Clone(b)) | ||||
| 	return len(b), nil | ||||
| } | ||||
|  | ||||
| type mockEncryptedPacket struct { | ||||
| 	Packet []byte | ||||
| 	Route  *RemotePeer | ||||
| } | ||||
|  | ||||
| type mockEncryptedPacketSender struct { | ||||
| 	Sent []mockEncryptedPacket | ||||
| } | ||||
|  | ||||
| func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *RemotePeer) { | ||||
| 	m.Sent = append(m.Sent, mockEncryptedPacket{ | ||||
| 		Packet: bytes.Clone(pkt), | ||||
| 		Route:  route, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type mockControlMsgHandler struct { | ||||
| 	Messages []any | ||||
| } | ||||
|  | ||||
| func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { | ||||
| 	m.Messages = append(m.Messages, pkt) | ||||
| } | ||||
|  | ||||
| type udpPipe struct { | ||||
| 	packets chan []byte | ||||
| } | ||||
|  | ||||
| func newUDPPipe() *udpPipe { | ||||
| 	return &udpPipe{make(chan []byte, 1024)} | ||||
| } | ||||
|  | ||||
| func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { | ||||
| 	p.packets <- bytes.Clone(b) | ||||
| 	return len(b), nil | ||||
| } | ||||
|  | ||||
| func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { | ||||
| 	packet := <-p.packets | ||||
| 	copy(b, packet) | ||||
| 	return len(packet), netip.AddrPort{}, nil | ||||
| } | ||||
|  | ||||
| type connReaderTestHarness struct { | ||||
| 	Pipe         *udpPipe | ||||
| 	R            *connReader | ||||
| 	WRemote      *connWriter | ||||
| 	WRelayRemote *connWriter | ||||
| 	Remote       *RemotePeer | ||||
| 	RelayRemote  *RemotePeer | ||||
| 	IFace        *mockIfWriter | ||||
| 	Sender       *mockEncryptedPacketSender | ||||
| 	Super        *mockControlMsgHandler | ||||
| } | ||||
|  | ||||
| // Peer 2 is indirect, peer 3 is direct. | ||||
| func newConnReadeTestHarness() (h connReaderTestHarness) { | ||||
| 	pipe := newUDPPipe() | ||||
| 	routes := [256]*atomic.Pointer[RemotePeer]{} | ||||
| 	for i := range routes { | ||||
| 		routes[i] = &atomic.Pointer[RemotePeer]{} | ||||
| 		routes[i].Store(&RemotePeer{}) | ||||
| 	} | ||||
|  | ||||
| 	local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() | ||||
| 	routes[2].Store(local) | ||||
| 	routes[3].Store(relayLocal) | ||||
|  | ||||
| 	h.Pipe = pipe | ||||
| 	h.WRemote = newConnWriter(pipe, 2) | ||||
| 	h.WRelayRemote = newConnWriter(pipe, 3) | ||||
|  | ||||
| 	h.Remote = remote | ||||
| 	h.RelayRemote = relayRemote | ||||
| 	h.IFace = &mockIfWriter{} | ||||
| 	h.Sender = &mockEncryptedPacketSender{} | ||||
| 	h.Super = &mockControlMsgHandler{} | ||||
| 	h.R = newConnReader( | ||||
| 		pipe, | ||||
| 		h.IFace, | ||||
| 		h.Sender, | ||||
| 		h.Super, | ||||
| 		1, | ||||
| 		routes) | ||||
| 	return h | ||||
| } | ||||
|  | ||||
| // Testing that we can receive a control packet. | ||||
| func TestConnReader_handleControlPacket(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.Super.Messages) != 1 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
|  | ||||
| 	msg := h.Super.Messages[0].(controlMsg[PacketSyn]) | ||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||
| 		t.Fatal(msg.Packet) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that a short packet is ignored. | ||||
| func TestConnReader_handleNextPacket_short(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.Super.Messages) != 0 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that a packet with an unexpected stream ID is ignored. | ||||
| func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	var header header | ||||
| 	header.Parse(encrypted) | ||||
| 	header.StreamID = 100 | ||||
| 	header.Marshal(encrypted) | ||||
|  | ||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||
| 	h.R.handleNextPacket() | ||||
| 	if len(h.Super.Messages) != 0 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that control packet without matching control cipher is ignored. | ||||
| func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	//encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	var header header | ||||
| 	header.Parse(encrypted) | ||||
| 	header.SourceIP = 10 | ||||
| 	header.Marshal(encrypted) | ||||
|  | ||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||
| 	h.R.handleNextPacket() | ||||
| 	if len(h.Super.Messages) != 0 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that control packet with incrrect destination IP is ignored. | ||||
| func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	var header header | ||||
| 	header.Parse(encrypted) | ||||
| 	header.DestIP++ | ||||
| 	header.Marshal(encrypted) | ||||
|  | ||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||
| 	h.R.handleNextPacket() | ||||
| 	if len(h.Super.Messages) != 0 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that modified control packet is ignored. | ||||
| func TestConnReader_handleControlPacket_modified(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := PacketSyn{TraceID: 1234} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	encrypted[len(encrypted)-1]++ | ||||
|  | ||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||
| 	h.R.handleNextPacket() | ||||
| 	if len(h.Super.Messages) != 0 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type unknownPacket struct{} | ||||
|  | ||||
| func (p unknownPacket) Marshal(buf []byte) []byte { | ||||
| 	buf = buf[:1] | ||||
| 	buf[0] = 100 | ||||
| 	return buf | ||||
| } | ||||
|  | ||||
| // Testing that an empty control packet is ignored. | ||||
| func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := unknownPacket{} | ||||
|  | ||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) | ||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||
| 	h.R.handleNextPacket() | ||||
| 	if len(h.Super.Messages) != 0 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that a duplicate control packet is ignored. | ||||
| func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := PacketAck{TraceID: 1234} | ||||
|  | ||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||
| 	*h.Remote.counter = *h.Remote.counter - 1 | ||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.Super.Messages) != 1 { | ||||
| 		t.Fatal(h.Super.Messages) | ||||
| 	} | ||||
|  | ||||
| 	msg := h.Super.Messages[0].(controlMsg[PacketAck]) | ||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||
| 		t.Fatal(msg.Packet) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we can receive a data packet. | ||||
| func TestConnReader_handleDataPacket(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := make([]byte, 1024) | ||||
| 	rand.Read(pkt) | ||||
|  | ||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.IFace.Written) != 1 { | ||||
| 		t.Fatal(h.IFace.Written) | ||||
| 	} | ||||
|  | ||||
| 	if !bytes.Equal(pkt, h.IFace.Written[0]) { | ||||
| 		t.Fatal(h.IFace.Written) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that data packet is ignored if route isn't up. | ||||
| func TestConnReader_handleDataPacket_routeDown(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := make([]byte, 1024) | ||||
| 	rand.Read(pkt) | ||||
|  | ||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||
| 	route := h.R.peers[2].Load() | ||||
| 	route.Up = false | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.IFace.Written) != 0 { | ||||
| 		t.Fatal(h.IFace.Written) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that a duplicate data packet is ignored. | ||||
| func TestConnReader_handleDataPacket_duplicate(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := make([]byte, 123) | ||||
|  | ||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||
| 	*h.Remote.counter = *h.Remote.counter - 1 | ||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||
|  | ||||
| 	h.R.handleNextPacket() | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.IFace.Written) != 1 { | ||||
| 		t.Fatal(h.IFace.Written) | ||||
| 	} | ||||
|  | ||||
| 	if !bytes.Equal(pkt, h.IFace.Written[0]) { | ||||
| 		t.Fatal(h.IFace.Written) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we can relay a data packet. | ||||
| func TestConnReader_handleDataPacket_relay(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := make([]byte, 1024) | ||||
| 	rand.Read(pkt) | ||||
|  | ||||
| 	h.RelayRemote.IP = 3 | ||||
| 	h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.Sender.Sent) != 1 { | ||||
| 		t.Fatal(h.Sender.Sent) | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| // Testing that we drop a relayed packet if destination is down. | ||||
| func TestConnReader_handleDataPacket_relayDown(t *testing.T) { | ||||
| 	h := newConnReadeTestHarness() | ||||
|  | ||||
| 	pkt := make([]byte, 1024) | ||||
| 	rand.Read(pkt) | ||||
|  | ||||
| 	h.RelayRemote.IP = 3 | ||||
| 	relay := h.R.peers[3].Load() | ||||
| 	relay.Up = false | ||||
|  | ||||
| 	h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) | ||||
| 	h.R.handleNextPacket() | ||||
|  | ||||
| 	if len(h.Sender.Sent) != 0 { | ||||
| 		t.Fatal(h.Sender.Sent) | ||||
| 	} | ||||
| } | ||||
| @@ -1,80 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type connWriter struct { | ||||
| 	localIP byte | ||||
| 	conn    udpWriter | ||||
|  | ||||
| 	// For sending control packets. | ||||
| 	cBuf1 []byte | ||||
| 	cBuf2 []byte | ||||
|  | ||||
| 	// For sending data packets. | ||||
| 	dBuf1 []byte | ||||
| 	dBuf2 []byte | ||||
|  | ||||
| 	// Lock around for sending on UDP Conn. | ||||
| 	wLock sync.Mutex | ||||
| } | ||||
|  | ||||
| func newConnWriter(conn udpWriter, localIP byte) *connWriter { | ||||
| 	w := &connWriter{ | ||||
| 		localIP: localIP, | ||||
| 		conn:    conn, | ||||
| 		cBuf1:   make([]byte, bufferSize), | ||||
| 		cBuf2:   make([]byte, bufferSize), | ||||
| 		dBuf1:   make([]byte, bufferSize), | ||||
| 		dBuf2:   make([]byte, bufferSize), | ||||
| 	} | ||||
| 	return w | ||||
| } | ||||
|  | ||||
| // Not safe for concurrent use. Should only be called by supervisor. | ||||
| func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { | ||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | ||||
| 	w.writeTo(enc, peer.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Relay control packet. Peer must not be nil. | ||||
| func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { | ||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) | ||||
| 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Not safe for concurrent use. Should only be called by ifReader. | ||||
| func (w *connWriter) SendDataPacket(pkt []byte, peer *RemotePeer) { | ||||
| 	enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) | ||||
| 	w.writeTo(enc, peer.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Relay a data packet. Peer must not be nil. | ||||
| func (w *connWriter) RelayDataPacket(pkt []byte, peer, relay *RemotePeer) { | ||||
| 	enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) | ||||
| 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.dBuf2) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Safe for concurrent use. Should only be called by connReader. | ||||
| // | ||||
| // This function will send pkt to the peer directly. This is used when a peer | ||||
| // is acting as a relay and is forwarding already encrypted data for another | ||||
| // peer. | ||||
| func (w *connWriter) SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) { | ||||
| 	w.writeTo(pkt, peer.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { | ||||
| 	w.wLock.Lock() | ||||
| 	if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { | ||||
| 		log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) | ||||
| 	} | ||||
| 	w.wLock.Unlock() | ||||
| } | ||||
| @@ -1,109 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type ConnWriter struct { | ||||
| 	wLock sync.Mutex // Lock around for sending on UDP Conn. | ||||
|  | ||||
| 	// Output. | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
|  | ||||
| 	// Shared state. | ||||
| 	rt *atomic.Pointer[RoutingTable] | ||||
|  | ||||
| 	// For sending control packets. | ||||
| 	cBuf1 []byte | ||||
| 	cBuf2 []byte | ||||
|  | ||||
| 	// For sending data packets. | ||||
| 	dBuf1 []byte | ||||
| 	dBuf2 []byte | ||||
| } | ||||
|  | ||||
| func NewConnWriter( | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	rt *atomic.Pointer[RoutingTable], | ||||
| ) *ConnWriter { | ||||
| 	return &ConnWriter{ | ||||
| 		writeToUDPAddrPort: writeToUDPAddrPort, | ||||
| 		rt:                 rt, | ||||
| 		cBuf1:              newBuf(), | ||||
| 		cBuf2:              newBuf(), | ||||
| 		dBuf1:              newBuf(), | ||||
| 		dBuf2:              newBuf(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Called by ConnReader to forward already encrypted bytes to another peer. | ||||
| func (w *ConnWriter) Forward(ip byte, pkt []byte) { | ||||
| 	peer := w.rt.Load().Peers[ip] | ||||
| 	if !(peer.Up && peer.Direct) { | ||||
| 		w.logf("Failed to forward to %d.", ip) | ||||
| 		return | ||||
| 	} | ||||
| 	w.writeTo(pkt, peer.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Called by IFReader to send data. Encryption will be applied, and packet will | ||||
| // be relayed if appropriate. | ||||
| func (w *ConnWriter) WriteData(ip byte, pkt []byte) { | ||||
| 	rt := w.rt.Load() | ||||
| 	peer := rt.Peers[ip] | ||||
| 	if !peer.Up { | ||||
| 		w.logf("Failed to send data to %d.", ip) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) | ||||
|  | ||||
| 	if peer.Direct { | ||||
| 		w.writeTo(enc, peer.DirectAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		w.logf("Failed to send data to %d. No relay.", ip) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| // Called by Supervisor to send control packets. | ||||
| func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { | ||||
| 	enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) | ||||
|  | ||||
| 	if peer.Direct { | ||||
| 		w.writeTo(enc, peer.DirectAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rt := w.rt.Load() | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		w.logf("Failed to send control to %d. No relay.", peer.IP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) | ||||
| 	w.writeTo(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { | ||||
| 	w.wLock.Lock() | ||||
| 	if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { | ||||
| 		w.logf("Failed to write to UDP port: %v", err) | ||||
| 	} | ||||
| 	w.wLock.Unlock() | ||||
| } | ||||
|  | ||||
| func (w *ConnWriter) logf(s string, args ...any) { | ||||
| 	log.Printf("[ConnWriter] "+s, args...) | ||||
| } | ||||
| @@ -1,145 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestConnWriter_WriteData_direct(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteData_peerNotUp(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	p1.RT.Load().Peers[2].Up = false | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteData_relay(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().Peers[3].Up = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.WriteData(2, in) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteControl_direct(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteControl_relay(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { | ||||
| 	p1, _, p3 := NewPeersForTesting() | ||||
|  | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
| 	p1.RT.Load().Peers[3].Up = false | ||||
| 	p1.RT.Load().RelayIP = 3 | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) | ||||
|  | ||||
| 	packets := p3.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter__Forward(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.Forward(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 1 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter__Forward_notUp(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	p1.RT.Load().Peers[2].Up = false | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.Forward(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConnWriter__Forward_notDirect(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
| 	p1.RT.Load().Peers[2].Direct = false | ||||
|  | ||||
| 	in := RandPacket() | ||||
| 	p1.ConnWriter.Forward(2, in) | ||||
|  | ||||
| 	packets := p2.Conn.Packets() | ||||
| 	if len(packets) != 0 { | ||||
| 		t.Fatal(packets) | ||||
| 	} | ||||
| } | ||||
| @@ -1,240 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type testUDPPacket struct { | ||||
| 	Addr netip.AddrPort | ||||
| 	Data []byte | ||||
| } | ||||
|  | ||||
| type testUDPAddrPortWriter struct { | ||||
| 	written []testUDPPacket | ||||
| } | ||||
|  | ||||
| func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { | ||||
| 	w.written = append(w.written, testUDPPacket{ | ||||
| 		Addr: addr, | ||||
| 		Data: bytes.Clone(b), | ||||
| 	}) | ||||
| 	return len(b), nil | ||||
| } | ||||
|  | ||||
| func (w *testUDPAddrPortWriter) Written() []testUDPPacket { | ||||
| 	out := w.written | ||||
| 	w.written = []testUDPPacket{} | ||||
| 	return out | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type testPacket string | ||||
|  | ||||
| func (p testPacket) Marshal(b []byte) []byte { | ||||
| 	b = b[:len(p)] | ||||
| 	copy(b, []byte(p)) | ||||
| 	return b | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *RemotePeer) { | ||||
| 	localKeys := generateKeys() | ||||
| 	remoteKeys := generateKeys() | ||||
|  | ||||
| 	local = NewRemotePeer(2) | ||||
| 	local.Up = true | ||||
| 	local.Relay = false | ||||
| 	local.PubSignKey = remoteKeys.PubSignKey | ||||
| 	local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) | ||||
| 	local.DataCipher = newDataCipher() | ||||
| 	local.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) | ||||
|  | ||||
| 	remote = NewRemotePeer(1) | ||||
| 	remote.Up = true | ||||
| 	remote.Relay = false | ||||
| 	remote.PubSignKey = localKeys.PubSignKey | ||||
| 	remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) | ||||
| 	remote.DataCipher = local.DataCipher | ||||
| 	remote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) | ||||
|  | ||||
| 	rLocalKeys := generateKeys() | ||||
| 	rRemoteKeys := generateKeys() | ||||
|  | ||||
| 	relayLocal = NewRemotePeer(3) | ||||
| 	relayLocal.Up = true | ||||
| 	relayLocal.Relay = true | ||||
| 	relayLocal.Direct = true | ||||
| 	relayLocal.PubSignKey = rRemoteKeys.PubSignKey | ||||
| 	relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) | ||||
| 	relayLocal.DataCipher = newDataCipher() | ||||
| 	relayLocal.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) | ||||
|  | ||||
| 	relayRemote = NewRemotePeer(1) | ||||
| 	relayRemote.Up = true | ||||
| 	relayRemote.Relay = false | ||||
| 	relayRemote.Direct = true | ||||
| 	relayRemote.PubSignKey = rLocalKeys.PubSignKey | ||||
| 	relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) | ||||
| 	relayRemote.DataCipher = relayLocal.DataCipher | ||||
| 	relayRemote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // Testing if we can send a control packet directly to the remote route. | ||||
| func TestConnWriter_SendControlPacket_direct(t *testing.T) { | ||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||
| 	route.Direct = true | ||||
|  | ||||
| 	writer := &testUDPAddrPortWriter{} | ||||
| 	w := newConnWriter(writer, rRoute.IP) | ||||
| 	in := testPacket("hello world!") | ||||
|  | ||||
| 	w.SendControlPacket(in, route) | ||||
| 	out := writer.Written() | ||||
| 	if len(out) != 1 { | ||||
| 		t.Fatal(out) | ||||
| 	} | ||||
|  | ||||
| 	if out[0].Addr != route.DirectAddr { | ||||
| 		t.Fatal(out[0]) | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
| 	if string(dec) != string(in) { | ||||
| 		t.Fatal(dec) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing if we can relay a packet via an intermediary. | ||||
| func TestConnWriter_RelayControlPacket_relay(t *testing.T) { | ||||
| 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | ||||
|  | ||||
| 	writer := &testUDPAddrPortWriter{} | ||||
| 	w := newConnWriter(writer, rRoute.IP) | ||||
| 	in := testPacket("hello world!") | ||||
|  | ||||
| 	w.RelayControlPacket(in, route, relay) | ||||
|  | ||||
| 	out := writer.Written() | ||||
| 	if len(out) != 1 { | ||||
| 		t.Fatal(out) | ||||
| 	} | ||||
|  | ||||
| 	if out[0].Addr != relay.DirectAddr { | ||||
| 		t.Fatal(out[0]) | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
|  | ||||
| 	dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
|  | ||||
| 	if string(dec2) != string(in) { | ||||
| 		t.Fatal(dec2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we can send a data packet directly to a remote route. | ||||
| func TestConnWriter_SendDataPacket_direct(t *testing.T) { | ||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||
| 	route.Direct = true | ||||
|  | ||||
| 	writer := &testUDPAddrPortWriter{} | ||||
| 	w := newConnWriter(writer, rRoute.IP) | ||||
|  | ||||
| 	in := []byte("hello world!") | ||||
| 	w.SendDataPacket(in, route) | ||||
|  | ||||
| 	out := writer.Written() | ||||
| 	if len(out) != 1 { | ||||
| 		t.Fatal(out) | ||||
| 	} | ||||
|  | ||||
| 	if out[0].Addr != route.DirectAddr { | ||||
| 		t.Fatal(out[0]) | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
|  | ||||
| 	if !bytes.Equal(dec, in) { | ||||
| 		t.Fatal(dec) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we can relay a data packet via a relay. | ||||
| func TestConnWriter_RelayDataPacket_relay(t *testing.T) { | ||||
| 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | ||||
|  | ||||
| 	writer := &testUDPAddrPortWriter{} | ||||
| 	w := newConnWriter(writer, rRoute.IP) | ||||
| 	in := []byte("Hello world!") | ||||
|  | ||||
| 	w.RelayDataPacket(in, route, relay) | ||||
|  | ||||
| 	out := writer.Written() | ||||
| 	if len(out) != 1 { | ||||
| 		t.Fatal(out) | ||||
| 	} | ||||
|  | ||||
| 	if out[0].Addr != relay.DirectAddr { | ||||
| 		t.Fatal(out[0]) | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
|  | ||||
| 	dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
|  | ||||
| 	if !bytes.Equal(dec2, in) { | ||||
| 		t.Fatal(dec2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we can send an already encrypted packet. | ||||
| func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { | ||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||
|  | ||||
| 	writer := &testUDPAddrPortWriter{} | ||||
| 	w := newConnWriter(writer, rRoute.IP) | ||||
| 	in := []byte("Hello world!") | ||||
|  | ||||
| 	w.SendEncryptedDataPacket(in, route) | ||||
|  | ||||
| 	out := writer.Written() | ||||
| 	if len(out) != 1 { | ||||
| 		t.Fatal(out) | ||||
| 	} | ||||
|  | ||||
| 	if out[0].Addr != route.DirectAddr { | ||||
| 		t.Fatal(out[0]) | ||||
| 	} | ||||
|  | ||||
| 	if !bytes.Equal(out[0].Data, in) { | ||||
| 		t.Fatal(out[0]) | ||||
| 	} | ||||
| } | ||||
| @@ -17,25 +17,25 @@ type controlMsg[T any] struct { | ||||
| func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | ||||
| 	switch buf[0] { | ||||
|  | ||||
| 	case PacketTypeSyn: | ||||
| 		packet, err := ParsePacketSyn(buf) | ||||
| 		return controlMsg[PacketSyn]{ | ||||
| 	case packetTypeSyn: | ||||
| 		packet, err := parsePacketSyn(buf) | ||||
| 		return controlMsg[packetSyn]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	case PacketTypeAck: | ||||
| 		packet, err := ParsePacketAck(buf) | ||||
| 		return controlMsg[PacketAck]{ | ||||
| 	case packetTypeAck: | ||||
| 		packet, err := parsePacketAck(buf) | ||||
| 		return controlMsg[packetAck]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
| 		}, err | ||||
|  | ||||
| 	case PacketTypeProbe: | ||||
| 		packet, err := ParsePacketProbe(buf) | ||||
| 		return controlMsg[PacketProbe]{ | ||||
| 	case packetTypeProbe: | ||||
| 		packet, err := parsePacketProbe(buf) | ||||
| 		return controlMsg[packetProbe]{ | ||||
| 			SrcIP:   srcIP, | ||||
| 			SrcAddr: srcAddr, | ||||
| 			Packet:  packet, | ||||
|   | ||||
| @@ -36,7 +36,7 @@ func generateKeys() cryptoKeys { | ||||
| // Peer must have a ControlCipher. | ||||
| func encryptControlPacket( | ||||
| 	localIP byte, | ||||
| 	peer *RemotePeer, | ||||
| 	peer *remotePeer, | ||||
| 	pkt Marshaller, | ||||
| 	tmp []byte, | ||||
| 	out []byte, | ||||
| @@ -55,7 +55,7 @@ func encryptControlPacket( | ||||
| // | ||||
| // This function also drops packets with duplicate sequence numbers. | ||||
| func decryptControlPacket( | ||||
| 	peer *RemotePeer, | ||||
| 	peer *remotePeer, | ||||
| 	fromAddr netip.AddrPort, | ||||
| 	h header, | ||||
| 	encrypted []byte, | ||||
| @@ -83,7 +83,7 @@ func decryptControlPacket( | ||||
| func encryptDataPacket( | ||||
| 	localIP byte, | ||||
| 	destIP byte, | ||||
| 	peer *RemotePeer, | ||||
| 	peer *remotePeer, | ||||
| 	data []byte, | ||||
| 	out []byte, | ||||
| ) []byte { | ||||
| @@ -98,7 +98,7 @@ func encryptDataPacket( | ||||
|  | ||||
| // Decrypts and de-dups incoming data packets. | ||||
| func decryptDataPacket( | ||||
| 	peer *RemotePeer, | ||||
| 	peer *remotePeer, | ||||
| 	h header, | ||||
| 	encrypted []byte, | ||||
| 	out []byte, | ||||
|   | ||||
| @@ -9,7 +9,7 @@ import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func newRoutePairForTesting() (*RemotePeer, *RemotePeer) { | ||||
| func newRoutePairForTesting() (*remotePeer, *remotePeer) { | ||||
| 	keys1 := generateKeys() | ||||
| 	keys2 := generateKeys() | ||||
|  | ||||
| @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { | ||||
| 		out    = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	in := PacketSyn{ | ||||
| 	in := packetSyn{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: r1.DataCipher.Key(), | ||||
| 		Direct:    true, | ||||
| @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	msg, ok := iMsg.(controlMsg[PacketSyn]) | ||||
| 	msg, ok := iMsg.(controlMsg[packetSyn]) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ok) | ||||
| 	} | ||||
| @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { | ||||
| 		out    = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	in := PacketSyn{ | ||||
| 	in := packetSyn{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: r1.DataCipher.Key(), | ||||
| 		Direct:    true, | ||||
| @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { | ||||
| 		out    = make([]byte, bufferSize) | ||||
| 	) | ||||
|  | ||||
| 	in := PacketSyn{ | ||||
| 	in := packetSyn{ | ||||
| 		TraceID:   newTraceID(), | ||||
| 		SharedKey: r1.DataCipher.Key(), | ||||
| 		Direct:    true, | ||||
| @@ -109,7 +109,8 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecryptControlPacket_invalidPacket(t *testing.T) { | ||||
| /* | ||||
| 	func TestDecryptControlPacket_invalidPacket(t *testing.T) { | ||||
| 		var ( | ||||
| 			r1, r2 = newRoutePairForTesting() | ||||
| 			tmp    = make([]byte, bufferSize) | ||||
| @@ -125,8 +126,8 @@ func TestDecryptControlPacket_invalidPacket(t *testing.T) { | ||||
| 		if !errors.Is(err, errUnknownPacketType) { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| } | ||||
|  | ||||
| 	} | ||||
| */ | ||||
| func TestDecryptDataPacket(t *testing.T) { | ||||
| 	var ( | ||||
| 		r1, r2 = newRoutePairForTesting() | ||||
|   | ||||
| @@ -16,10 +16,16 @@ type hubPoller struct { | ||||
| 	versions         [256]int64 | ||||
| 	localIP          byte | ||||
| 	netName          string | ||||
| 	super    controlMsgHandler | ||||
| 	handleControlMsg func(fromIP byte, msg any) | ||||
| } | ||||
|  | ||||
| func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { | ||||
| func newHubPoller( | ||||
| 	localIP byte, | ||||
| 	netName, | ||||
| 	hubURL, | ||||
| 	apiKey string, | ||||
| 	handleControlMsg func(byte, any), | ||||
| ) (*hubPoller, error) { | ||||
| 	u, err := url.Parse(hubURL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -40,7 +46,7 @@ func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsg | ||||
| 		req:              req, | ||||
| 		localIP:          localIP, | ||||
| 		netName:          netName, | ||||
| 		super:   super, | ||||
| 		handleControlMsg: handleControlMsg, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -90,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||
| 	for i, peer := range state.Peers { | ||||
| 		if i != int(hp.localIP) { | ||||
| 			if peer == nil || peer.Version != hp.versions[i] { | ||||
| 				hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | ||||
| 				hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | ||||
| 				if peer != nil { | ||||
| 					hp.versions[i] = peer.Version | ||||
| 				} | ||||
|   | ||||
							
								
								
									
										100
									
								
								peer/ifreader.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								peer/ifreader.go
									
									
									
									
									
								
							| @@ -1,100 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type ifReader struct { | ||||
| 	iface  io.Reader | ||||
| 	peers  [256]*atomic.Pointer[RemotePeer] | ||||
| 	relay  *atomic.Pointer[RemotePeer] | ||||
| 	sender dataPacketSender | ||||
| } | ||||
|  | ||||
| func newIFReader( | ||||
| 	iface io.Reader, | ||||
| 	peers [256]*atomic.Pointer[RemotePeer], | ||||
| 	relay *atomic.Pointer[RemotePeer], | ||||
| 	sender dataPacketSender, | ||||
| ) *ifReader { | ||||
| 	return &ifReader{ | ||||
| 		iface:  iface, | ||||
| 		peers:  peers, | ||||
| 		relay:  relay, | ||||
| 		sender: sender, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ifReader) Run() { | ||||
| 	var ( | ||||
| 		packet   = make([]byte, bufferSize) | ||||
| 		remoteIP byte | ||||
| 		ok       bool | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		packet = r.readNextPacket(packet) | ||||
| 		if remoteIP, ok = r.parsePacket(packet); ok { | ||||
| 			r.sendPacket(packet, remoteIP) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { | ||||
| 	peer := r.peers[remoteIP].Load() | ||||
| 	if !peer.Up { | ||||
| 		log.Printf("Peer not connected: %d", remoteIP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Direct path => early return. | ||||
| 	if peer.Direct { | ||||
| 		r.sender.SendDataPacket(pkt, peer) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if relay := r.relay.Load(); relay != nil && relay.Up { | ||||
| 		r.sender.RelayDataPacket(pkt, peer, relay) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Get next packet, returning packet, and destination ip. | ||||
| func (r *ifReader) readNextPacket(buf []byte) []byte { | ||||
| 	n, err := r.iface.Read(buf[:cap(buf)]) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read from interface: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return buf[:n] | ||||
| } | ||||
|  | ||||
| func (r *ifReader) parsePacket(buf []byte) (byte, bool) { | ||||
| 	n := len(buf) | ||||
| 	if n == 0 { | ||||
| 		return 0, false | ||||
| 	} | ||||
|  | ||||
| 	version := buf[0] >> 4 | ||||
|  | ||||
| 	switch version { | ||||
| 	case 4: | ||||
| 		if n < 20 { | ||||
| 			log.Printf("Short IPv4 packet: %d", len(buf)) | ||||
| 			return 0, false | ||||
| 		} | ||||
| 		return buf[19], true | ||||
|  | ||||
| 	case 6: | ||||
| 		if len(buf) < 40 { | ||||
| 			log.Printf("Short IPv6 packet: %d", len(buf)) | ||||
| 			return 0, false | ||||
| 		} | ||||
| 		return buf[39], true | ||||
|  | ||||
| 	default: | ||||
| 		log.Printf("Invalid IP packet version: %v", version) | ||||
| 		return 0, false | ||||
| 	} | ||||
| } | ||||
| @@ -3,22 +3,24 @@ package peer | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type IFReader struct { | ||||
| 	iface              io.Reader | ||||
| 	connWriter interface { | ||||
| 		WriteData(ip byte, pkt []byte) | ||||
| 	} | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	rt                 *atomic.Pointer[routingTable] | ||||
| 	buf1               []byte | ||||
| 	buf2               []byte | ||||
| } | ||||
|  | ||||
| func NewIFReader( | ||||
| 	iface io.Reader, | ||||
| 	connWriter interface { | ||||
| 		WriteData(ip byte, pkt []byte) | ||||
| 	}, | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| ) *IFReader { | ||||
| 	return &IFReader{iface, connWriter} | ||||
| 	return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} | ||||
| } | ||||
|  | ||||
| func (r *IFReader) Run() { | ||||
| @@ -30,9 +32,32 @@ func (r *IFReader) Run() { | ||||
|  | ||||
| func (r *IFReader) handleNextPacket(packet []byte) { | ||||
| 	packet = r.readNextPacket(packet) | ||||
| 	if remoteIP, ok := r.parsePacket(packet); ok { | ||||
| 		r.connWriter.WriteData(remoteIP, packet) | ||||
| 	remoteIP, ok := r.parsePacket(packet) | ||||
| 	if !ok { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rt := r.rt.Load() | ||||
| 	peer := rt.Peers[remoteIP] | ||||
| 	if !peer.Up { | ||||
| 		r.logf("Peer %d not up.", peer.IP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1) | ||||
| 	if peer.Direct { | ||||
| 		r.writeToUDPAddrPort(enc, peer.DirectAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	relay, ok := rt.GetRelay() | ||||
| 	if !ok { | ||||
| 		r.logf("Relay not available for peer %d.", peer.IP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2) | ||||
| 	r.writeToUDPAddrPort(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (r *IFReader) readNextPacket(buf []byte) []byte { | ||||
|   | ||||
| @@ -1,9 +1,6 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| func TestIFReader_IPv4(t *testing.T) { | ||||
| 	p1, p2, _ := NewPeersForTesting() | ||||
|  | ||||
| @@ -81,3 +78,4 @@ func TestIFReader_parsePacket_shortIPv6(t *testing.T) { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|   | ||||
| @@ -1,232 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"reflect" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| // Test that we parse IPv4 packets correctly. | ||||
| func TestIFReader_parsePacket_ipv4(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 1234) | ||||
| 	pkt[0] = 4 << 4 | ||||
| 	pkt[19] = 128 | ||||
|  | ||||
| 	if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test that we parse IPv6 packets correctly. | ||||
| func TestIFReader_parsePacket_ipv6(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 1234) | ||||
| 	pkt[0] = 6 << 4 | ||||
| 	pkt[39] = 42 | ||||
|  | ||||
| 	if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| /* | ||||
| // Test that empty packets work as expected. | ||||
| func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 0) | ||||
| 	if ip, ok := r.parsePacket(pkt); ok { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test that invalid IP versions fail. | ||||
| func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
|  | ||||
| 	for i := byte(1); i < 16; i++ { | ||||
| 		if i == 4 || i == 6 { | ||||
| 			continue | ||||
| 		} | ||||
| 		pkt := make([]byte, 1234) | ||||
| 		pkt[0] = i << 4 | ||||
|  | ||||
| 		if ip, ok := r.parsePacket(pkt); ok { | ||||
| 			t.Fatal(i, ip, ok) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test that short IPv4 packets fail. | ||||
| func TestIFReader_parsePacket_shortIPv4(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 19) | ||||
| 	pkt[0] = 4 << 4 | ||||
|  | ||||
| 	if ip, ok := r.parsePacket(pkt); ok { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test that short IPv6 packets fail. | ||||
| func TestIFReader_parsePacket_shortIPv6(t *testing.T) { | ||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
|  | ||||
| 	pkt := make([]byte, 39) | ||||
| 	pkt[0] = 6 << 4 | ||||
|  | ||||
| 	if ip, ok := r.parsePacket(pkt); ok { | ||||
| 		t.Fatal(ip, ok) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test that we can read a packet. | ||||
| func TestIFReader_readNextpacket(t *testing.T) { | ||||
| 	in, out := net.Pipe() | ||||
| 	r := newIFReader(out, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) | ||||
| 	defer in.Close() | ||||
| 	defer out.Close() | ||||
|  | ||||
| 	go in.Write([]byte("hello world!")) | ||||
|  | ||||
| 	pkt := r.readNextPacket(make([]byte, bufferSize)) | ||||
| 	if !bytes.Equal(pkt, []byte("hello world!")) { | ||||
| 		t.Fatalf("%s", pkt) | ||||
| 	} | ||||
| } | ||||
| */ | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type sentPacket struct { | ||||
| 	Relayed bool | ||||
| 	Packet  []byte | ||||
| 	Route   RemotePeer | ||||
| 	Relay   RemotePeer | ||||
| } | ||||
|  | ||||
| type sendPacketTestHarness struct { | ||||
| 	Packets []sentPacket | ||||
| } | ||||
|  | ||||
| func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *RemotePeer) { | ||||
| 	h.Packets = append(h.Packets, sentPacket{ | ||||
| 		Packet: bytes.Clone(pkt), | ||||
| 		Route:  *route, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *RemotePeer) { | ||||
| 	h.Packets = append(h.Packets, sentPacket{ | ||||
| 		Relayed: true, | ||||
| 		Packet:  bytes.Clone(pkt), | ||||
| 		Route:   *route, | ||||
| 		Relay:   *relay, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { | ||||
| 	h := &sendPacketTestHarness{} | ||||
|  | ||||
| 	routes := [256]*atomic.Pointer[RemotePeer]{} | ||||
| 	for i := range routes { | ||||
| 		routes[i] = &atomic.Pointer[RemotePeer]{} | ||||
| 		routes[i].Store(&RemotePeer{}) | ||||
| 	} | ||||
| 	relay := &atomic.Pointer[RemotePeer]{} | ||||
| 	r := newIFReader(nil, routes, relay, h) | ||||
| 	return r, h | ||||
| } | ||||
|  | ||||
| // Testing that we can send a packet directly. | ||||
| func TestIFReader_sendPacket_direct(t *testing.T) { | ||||
| 	r, h := newIFReaderForSendPacketTesting() | ||||
|  | ||||
| 	route := r.peers[2].Load() | ||||
| 	route.Up = true | ||||
| 	route.Direct = true | ||||
|  | ||||
| 	in := []byte("hello world") | ||||
|  | ||||
| 	r.sendPacket(in, 2) | ||||
| 	if len(h.Packets) != 1 { | ||||
| 		t.Fatal(h.Packets) | ||||
| 	} | ||||
|  | ||||
| 	expected := sentPacket{ | ||||
| 		Relayed: false, | ||||
| 		Packet:  in, | ||||
| 		Route:   *route, | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(h.Packets[0], expected) { | ||||
| 		t.Fatal(h.Packets[0]) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we don't send a packet if route isn't up. | ||||
| func TestIFReader_sendPacket_directNotUp(t *testing.T) { | ||||
| 	r, h := newIFReaderForSendPacketTesting() | ||||
|  | ||||
| 	route := r.peers[2].Load() | ||||
| 	route.Direct = true | ||||
|  | ||||
| 	in := []byte("hello world") | ||||
|  | ||||
| 	r.sendPacket(in, 2) | ||||
| 	if len(h.Packets) != 0 { | ||||
| 		t.Fatal(h.Packets) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we can send a packet via a relay. | ||||
| func TestIFReader_sendPacket_relayed(t *testing.T) { | ||||
| 	r, h := newIFReaderForSendPacketTesting() | ||||
|  | ||||
| 	route := r.peers[2].Load() | ||||
| 	route.Up = true | ||||
| 	route.Direct = false | ||||
|  | ||||
| 	relay := r.peers[3].Load() | ||||
| 	r.relay.Store(relay) | ||||
| 	relay.Up = true | ||||
| 	relay.Direct = true | ||||
|  | ||||
| 	in := []byte("hello world") | ||||
|  | ||||
| 	r.sendPacket(in, 2) | ||||
| 	if len(h.Packets) != 1 { | ||||
| 		t.Fatal(h.Packets) | ||||
| 	} | ||||
|  | ||||
| 	expected := sentPacket{ | ||||
| 		Relayed: true, | ||||
| 		Packet:  in, | ||||
| 		Route:   *route, | ||||
| 		Relay:   *relay, | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(h.Packets[0], expected) { | ||||
| 		t.Fatal(h.Packets[0]) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Testing that we don't try to send on a nil relay IP. | ||||
| func TestIFReader_sendPacket_nilRealy(t *testing.T) { | ||||
| 	r, h := newIFReaderForSendPacketTesting() | ||||
|  | ||||
| 	route := r.peers[2].Load() | ||||
| 	route.Up = true | ||||
| 	route.Direct = false | ||||
|  | ||||
| 	in := []byte("hello world") | ||||
|  | ||||
| 	r.sendPacket(in, 2) | ||||
| 	if len(h.Packets) != 0 { | ||||
| 		t.Fatal(h.Packets) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										177
									
								
								peer/interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								peer/interface.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,177 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"golang.org/x/sys/unix" | ||||
| ) | ||||
|  | ||||
| // Get next packet, returning packet, ip, and possible error. | ||||
| func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { | ||||
| 	var ( | ||||
| 		version byte | ||||
| 		ip      byte | ||||
| 	) | ||||
| 	for { | ||||
| 		n, err := iface.Read(buf[:cap(buf)]) | ||||
| 		if err != nil { | ||||
| 			return nil, ip, err | ||||
| 		} | ||||
|  | ||||
| 		buf = buf[:n] | ||||
| 		version = buf[0] >> 4 | ||||
|  | ||||
| 		switch version { | ||||
| 		case 4: | ||||
| 			if n < 20 { | ||||
| 				log.Printf("Short IPv4 packet: %d", len(buf)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip = buf[19] | ||||
|  | ||||
| 		case 6: | ||||
| 			if len(buf) < 40 { | ||||
| 				log.Printf("Short IPv6 packet: %d", len(buf)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip = buf[39] | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("Invalid IP packet version: %v", version) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		return buf, ip, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { | ||||
| 	if len(network) != 4 { | ||||
| 		return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) | ||||
| 	} | ||||
| 	ip := net.IPv4(network[0], network[1], network[2], localIP) | ||||
|  | ||||
| 	////////////////////////// | ||||
| 	// Create TUN Interface // | ||||
| 	////////////////////////// | ||||
|  | ||||
| 	tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to open TUN device: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// New interface request. | ||||
| 	req, err := unix.NewIfreq(name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Flags: | ||||
| 	// | ||||
| 	// IFF_NO_PI => don't add packet info data to packets sent to the interface. | ||||
| 	// IFF_TUN   => create a TUN device handling IP packets. | ||||
| 	req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) | ||||
|  | ||||
| 	err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set TUN device settings: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Name may not be exactly the same? | ||||
| 	name = req.Name() | ||||
|  | ||||
| 	///////////// | ||||
| 	// Set MTU // | ||||
| 	///////////// | ||||
|  | ||||
| 	// We need a socket file descriptor to set other options for some reason. | ||||
| 	sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to open socket: %w", err) | ||||
| 	} | ||||
| 	defer unix.Close(sockFD) | ||||
|  | ||||
| 	req, err = unix.NewIfreq(name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create MTU interface request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	req.SetUint32(if_mtu) | ||||
| 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface MTU: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	////////////////////// | ||||
| 	// Set Queue Length // | ||||
| 	////////////////////// | ||||
|  | ||||
| 	req, err = unix.NewIfreq(name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create IP interface request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	req.SetUint16(if_queue_len) | ||||
| 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface queue length: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	///////////////////// | ||||
| 	// Set IP and Mask // | ||||
| 	///////////////////// | ||||
|  | ||||
| 	req, err = unix.NewIfreq(name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create IP interface request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := req.SetInet4Addr(ip.To4()); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface request IP: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface IP: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// SET MASK - must happen after setting address. | ||||
| 	req, err = unix.NewIfreq(name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create mask interface request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface request mask: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface mask: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	//////////////////////// | ||||
| 	// Bring Interface Up // | ||||
| 	//////////////////////// | ||||
|  | ||||
| 	req, err = unix.NewIfreq(name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create up interface request: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Get current flags. | ||||
| 	if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to get interface flags: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING | ||||
|  | ||||
| 	// Set UP flag / broadcast flags. | ||||
| 	req.SetUint16(flags) | ||||
| 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set interface up: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	return os.NewFile(uintptr(tunFD), "tun"), nil | ||||
| } | ||||
| @@ -31,17 +31,17 @@ type Marshaller interface { | ||||
| } | ||||
|  | ||||
| type dataPacketSender interface { | ||||
| 	SendDataPacket(pkt []byte, peer *RemotePeer) | ||||
| 	RelayDataPacket(pkt []byte, peer, relay *RemotePeer) | ||||
| 	SendDataPacket(pkt []byte, peer *remotePeer) | ||||
| 	RelayDataPacket(pkt []byte, peer, relay *remotePeer) | ||||
| } | ||||
|  | ||||
| type controlPacketSender interface { | ||||
| 	SendControlPacket(pkt Marshaller, peer *RemotePeer) | ||||
| 	RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) | ||||
| 	SendControlPacket(pkt Marshaller, peer *remotePeer) | ||||
| 	RelayControlPacket(pkt Marshaller, peer, relay *remotePeer) | ||||
| } | ||||
|  | ||||
| type encryptedPacketSender interface { | ||||
| 	SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) | ||||
| 	SendEncryptedDataPacket(pkt []byte, peer *remotePeer) | ||||
| } | ||||
|  | ||||
| type controlMsgHandler interface { | ||||
|   | ||||
							
								
								
									
										23
									
								
								peer/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								peer/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| func Main() { | ||||
| 	conf := Config{} | ||||
|  | ||||
| 	flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") | ||||
| 	flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") | ||||
| 	flag.StringVar(&conf.APIKey, "api-key", "", "[REQUIRED] The node's API key.") | ||||
| 	flag.Parse() | ||||
|  | ||||
| 	if conf.NetName == "" || conf.HubAddress == "" || conf.APIKey == "" { | ||||
| 		flag.Usage() | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	peer := New(conf) | ||||
| 	peer.Run() | ||||
| } | ||||
| @@ -8,7 +8,7 @@ import ( | ||||
| type mcReader struct { | ||||
| 	conn  udpReader | ||||
| 	super controlMsgHandler | ||||
| 	peers [256]*atomic.Pointer[RemotePeer] | ||||
| 	peers [256]*atomic.Pointer[remotePeer] | ||||
|  | ||||
| 	incoming []byte | ||||
| 	buf      []byte | ||||
| @@ -17,7 +17,7 @@ type mcReader struct { | ||||
| func newMCReader( | ||||
| 	conn udpReader, | ||||
| 	super controlMsgHandler, | ||||
| 	peers [256]*atomic.Pointer[RemotePeer], | ||||
| 	peers [256]*atomic.Pointer[remotePeer], | ||||
| ) *mcReader { | ||||
| 	return &mcReader{conn, super, peers, newBuf(), newBuf()} | ||||
| } | ||||
| @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ | ||||
| 	r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{ | ||||
| 		SrcIP:   h.SourceIP, | ||||
| 		SrcAddr: remoteAddr, | ||||
| 	}) | ||||
|   | ||||
| @@ -1,13 +1,6 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| type mcMockConn struct { | ||||
| 	packets chan []byte | ||||
| } | ||||
| @@ -136,3 +129,4 @@ func TestMCReader_badSignature(t *testing.T) { | ||||
| 		t.Fatal(super.Messages) | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|   | ||||
| @@ -5,41 +5,34 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	PacketTypeSyn = iota + 1 | ||||
| 	PacketTypeSynAck | ||||
| 	PacketTypeAck | ||||
| 	PacketTypeProbe | ||||
| 	PacketTypeAddrDiscovery | ||||
| 	packetTypeSyn           = 1 | ||||
| 	packetTypeAck           = 3 | ||||
| 	packetTypeProbe         = 4 | ||||
| 	packetTypeAddrDiscovery = 5 | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PacketSyn struct { | ||||
| type packetSyn struct { | ||||
| 	TraceID       uint64   // TraceID to match response w/ request. | ||||
| 	//SentAt        int64    // Unixmilli. | ||||
| 	//SharedKeyType byte     // Currently only 1 is supported for AES. | ||||
| 	SharedKey     [32]byte // Our shared key. | ||||
| 	Direct        bool | ||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||
| } | ||||
|  | ||||
| func (p PacketSyn) Marshal(buf []byte) []byte { | ||||
| func (p packetSyn) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(PacketTypeSyn). | ||||
| 		Byte(packetTypeSyn). | ||||
| 		Uint64(p.TraceID). | ||||
| 		//Int64(p.SentAt). | ||||
| 		//Byte(p.SharedKeyType). | ||||
| 		SharedKey(p.SharedKey). | ||||
| 		Bool(p.Direct). | ||||
| 		AddrPort8(p.PossibleAddrs). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { | ||||
| func parsePacketSyn(buf []byte) (p packetSyn, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		//Int64(&p.SentAt). | ||||
| 		//Byte(&p.SharedKeyType). | ||||
| 		SharedKey(&p.SharedKey). | ||||
| 		Bool(&p.Direct). | ||||
| 		AddrPort8(&p.PossibleAddrs). | ||||
| @@ -49,22 +42,22 @@ func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PacketAck struct { | ||||
| type packetAck struct { | ||||
| 	TraceID       uint64 | ||||
| 	ToAddr        netip.AddrPort | ||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||
| } | ||||
|  | ||||
| func (p PacketAck) Marshal(buf []byte) []byte { | ||||
| func (p packetAck) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(PacketTypeAck). | ||||
| 		Byte(packetTypeAck). | ||||
| 		Uint64(p.TraceID). | ||||
| 		AddrPort(p.ToAddr). | ||||
| 		AddrPort8(p.PossibleAddrs). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func ParsePacketAck(buf []byte) (p PacketAck, err error) { | ||||
| func parsePacketAck(buf []byte) (p packetAck, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		AddrPort(&p.ToAddr). | ||||
| @@ -77,18 +70,18 @@ func ParsePacketAck(buf []byte) (p PacketAck, err error) { | ||||
|  | ||||
| // A probeReqPacket is sent from a client to a server to determine if direct | ||||
| // UDP communication can be used. | ||||
| type PacketProbe struct { | ||||
| type packetProbe struct { | ||||
| 	TraceID uint64 | ||||
| } | ||||
|  | ||||
| func (p PacketProbe) Marshal(buf []byte) []byte { | ||||
| func (p packetProbe) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(PacketTypeProbe). | ||||
| 		Byte(packetTypeProbe). | ||||
| 		Uint64(p.TraceID). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { | ||||
| func parsePacketProbe(buf []byte) (p packetProbe, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		Error() | ||||
| @@ -97,4 +90,4 @@ func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PacketLocalDiscovery struct{} | ||||
| type packetLocalDiscovery struct{} | ||||
|   | ||||
| @@ -8,7 +8,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestSynPacket(t *testing.T) { | ||||
| 	p := PacketSyn{ | ||||
| 	p := packetSyn{ | ||||
| 		TraceID: newTraceID(), | ||||
| 		//SentAt:        time.Now().UnixMilli(), | ||||
| 		//SharedKeyType: 1, | ||||
| @@ -21,7 +21,7 @@ func TestSynPacket(t *testing.T) { | ||||
| 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||
|  | ||||
| 	buf := p.Marshal(newBuf()) | ||||
| 	p2, err := ParsePacketSyn(buf) | ||||
| 	p2, err := parsePacketSyn(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -31,7 +31,7 @@ func TestSynPacket(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestAckPacket(t *testing.T) { | ||||
| 	p := PacketAck{ | ||||
| 	p := packetAck{ | ||||
| 		TraceID: newTraceID(), | ||||
| 		ToAddr:  netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), | ||||
| 	} | ||||
| @@ -41,7 +41,7 @@ func TestAckPacket(t *testing.T) { | ||||
| 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||
|  | ||||
| 	buf := p.Marshal(newBuf()) | ||||
| 	p2, err := ParsePacketAck(buf) | ||||
| 	p2, err := parsePacketAck(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -51,12 +51,12 @@ func TestAckPacket(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestProbePacket(t *testing.T) { | ||||
| 	p := PacketProbe{ | ||||
| 	p := packetProbe{ | ||||
| 		TraceID: newTraceID(), | ||||
| 	} | ||||
|  | ||||
| 	buf := p.Marshal(newBuf()) | ||||
| 	p2, err := ParsePacketProbe(buf) | ||||
| 	p2, err := parsePacketProbe(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										161
									
								
								peer/peer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								peer/peer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,161 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type Peer struct { | ||||
| 	ifReader   *IFReader | ||||
| 	connReader *ConnReader | ||||
| 	iface      io.Writer | ||||
| 	hubPoller  *hubPoller | ||||
| 	super      *Super | ||||
| } | ||||
|  | ||||
| type Config struct { | ||||
| 	NetName    string | ||||
| 	HubAddress string | ||||
| 	APIKey     string | ||||
| } | ||||
|  | ||||
| func New(conf Config) *Peer { | ||||
| 	config, err := loadPeerConfig(conf.NetName) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to load configuration: %v", err) | ||||
| 		log.Printf("Initializing...") | ||||
| 		initPeerWithHub(conf) | ||||
|  | ||||
| 		config, err = loadPeerConfig(conf.NetName) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("Failed to load configuration: %v", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	iface, err := openInterface(config.Network, config.PeerIP, conf.NetName) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to open interface: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to resolve UDP address: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("Listening on %v...", myAddr) | ||||
| 	conn, err := net.ListenUDP("udp", myAddr) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to open UDP port: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	conn.SetReadBuffer(1024 * 1024 * 8) | ||||
| 	conn.SetWriteBuffer(1024 * 1024 * 8) | ||||
|  | ||||
| 	// Wrap write function - this is necessary to avoid starvation. | ||||
| 	writeLock := sync.Mutex{} | ||||
| 	writeToUDPAddrPort := func(b []byte, addr netip.AddrPort) (n int, err error) { | ||||
| 		writeLock.Lock() | ||||
| 		n, err = conn.WriteToUDPAddrPort(b, addr) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to write packet: %v", err) | ||||
| 		} | ||||
| 		writeLock.Unlock() | ||||
| 		return n, err | ||||
| 	} | ||||
|  | ||||
| 	var localAddr netip.AddrPort | ||||
| 	ip, ok := netip.AddrFromSlice(config.PublicIP) | ||||
| 	if ok { | ||||
| 		localAddr = netip.AddrPortFrom(ip, config.Port) | ||||
| 	} | ||||
|  | ||||
| 	rt := newRoutingTable(config.PeerIP, localAddr) | ||||
| 	rtPtr := &atomic.Pointer[routingTable]{} | ||||
| 	rtPtr.Store(&rt) | ||||
|  | ||||
| 	ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr) | ||||
| 	super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey) | ||||
| 	connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) | ||||
| 	hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to create hub poller: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return &Peer{ | ||||
| 		iface:      iface, | ||||
| 		ifReader:   ifReader, | ||||
| 		connReader: connReader, | ||||
| 		hubPoller:  hubPoller, | ||||
| 		super:      super, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *Peer) Run() { | ||||
| 	go p.ifReader.Run() | ||||
| 	go p.connReader.Run() | ||||
| 	p.super.Start() | ||||
| 	p.hubPoller.Run() | ||||
| } | ||||
|  | ||||
| func initPeerWithHub(conf Config) { | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	initURL, err := url.Parse(conf.HubAddress) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to parse hub URL: %v", err) | ||||
| 	} | ||||
| 	initURL.Path = "/peer/init/" | ||||
|  | ||||
| 	args := m.PeerInitArgs{ | ||||
| 		EncPubKey:  keys.PubKey, | ||||
| 		PubSignKey: keys.PubSignKey, | ||||
| 	} | ||||
|  | ||||
| 	buf := &bytes.Buffer{} | ||||
| 	if err := json.NewEncoder(buf).Encode(args); err != nil { | ||||
| 		log.Fatalf("Failed to encode init args: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to construct request: %v", err) | ||||
| 	} | ||||
| 	req.SetBasicAuth("", conf.APIKey) | ||||
|  | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to init with hub: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	data, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	peerConfig := localConfig{} | ||||
| 	if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { | ||||
| 		log.Fatalf("Failed to parse configuration: %v\n%s", err, data) | ||||
| 	} | ||||
|  | ||||
| 	peerConfig.PubKey = keys.PubKey | ||||
| 	peerConfig.PrivKey = keys.PrivKey | ||||
| 	peerConfig.PubSignKey = keys.PubSignKey | ||||
| 	peerConfig.PrivSignKey = keys.PrivSignKey | ||||
|  | ||||
| 	if err := storePeerConfig(conf.NetName, peerConfig); err != nil { | ||||
| 		log.Fatalf("Failed to store configuration: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	log.Print("Initialization successful.") | ||||
| } | ||||
| @@ -11,36 +11,25 @@ import ( | ||||
| // A test peer. | ||||
| type P struct { | ||||
| 	cryptoKeys | ||||
| 	RT         *atomic.Pointer[RoutingTable] | ||||
| 	RT         *atomic.Pointer[routingTable] | ||||
| 	Conn       *TestUDPConn | ||||
| 	IFace      *TestIFace | ||||
| 	ConnWriter *ConnWriter | ||||
| 	ConnReader *ConnReader | ||||
| 	IFReader   *IFReader | ||||
| 	Super      *Supervisor | ||||
| } | ||||
|  | ||||
| func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { | ||||
| 	p := P{ | ||||
| 		cryptoKeys: generateKeys(), | ||||
| 		RT:         &atomic.Pointer[RoutingTable]{}, | ||||
| 		RT:         &atomic.Pointer[routingTable]{}, | ||||
| 		IFace:      NewTestIFace(), | ||||
| 	} | ||||
|  | ||||
| 	rt := NewRoutingTable(ip, addr) | ||||
| 	rt := newRoutingTable(ip, addr) | ||||
| 	p.RT.Store(&rt) | ||||
| 	p.Conn = n.NewUDPConn(addr) | ||||
| 	p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) | ||||
| 	p.IFReader = NewIFReader(p.IFace, p.ConnWriter) | ||||
| 	//p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) | ||||
|  | ||||
| 	/* | ||||
| 		   p.ConnReader = NewConnReader( | ||||
| 				p.Conn.ReadFromUDPAddrPort, | ||||
| 				p.IFace, | ||||
| 				p.ConnWriter.Forward, | ||||
| 				p.Super.HandleControlMsg, | ||||
| 				p.RT) | ||||
| 	*/ | ||||
| 	return p | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -11,21 +11,21 @@ import ( | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| type PeerState interface { | ||||
| 	OnPeerUpdate(*m.Peer) PeerState | ||||
| 	OnSyn(controlMsg[PacketSyn]) PeerState | ||||
| 	OnAck(controlMsg[PacketAck]) | ||||
| 	OnProbe(controlMsg[PacketProbe]) PeerState | ||||
| 	OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) | ||||
| 	OnPingTimer() PeerState | ||||
| type peerState interface { | ||||
| 	OnPeerUpdate(*m.Peer) peerState | ||||
| 	OnSyn(controlMsg[packetSyn]) peerState | ||||
| 	OnAck(controlMsg[packetAck]) | ||||
| 	OnProbe(controlMsg[packetProbe]) peerState | ||||
| 	OnLocalDiscovery(controlMsg[packetLocalDiscovery]) | ||||
| 	OnPingTimer() peerState | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type State struct { | ||||
| type pState struct { | ||||
| 	// Output. | ||||
| 	publish           func(RemotePeer) | ||||
| 	sendControlPacket func(RemotePeer, Marshaller) | ||||
| 	publish           func(remotePeer) | ||||
| 	sendControlPacket func(remotePeer, Marshaller) | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	localIP   byte | ||||
| @@ -37,7 +37,7 @@ type State struct { | ||||
|  | ||||
| 	// The purpose of this state machine is to manage the RemotePeer object, | ||||
| 	// publishing it as necessary. | ||||
| 	staged RemotePeer // Local copy of shared data. See publish(). | ||||
| 	staged remotePeer // Local copy of shared data. See publish(). | ||||
|  | ||||
| 	// Mutable peer data. | ||||
| 	peer *m.Peer | ||||
| @@ -47,25 +47,28 @@ type State struct { | ||||
| 	limiter *ratelimiter.Limiter | ||||
| } | ||||
|  | ||||
| func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { | ||||
| func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | ||||
| 	defer func() { | ||||
| 		// Don't defer directly otherwise s.staged will be evaluated immediately | ||||
| 		// and won't reflect changes made in the function. | ||||
| 		s.publish(s.staged) | ||||
| 	}() | ||||
|  | ||||
| 	if peer == nil { | ||||
| 		return EnterStateDisconnected(s) | ||||
| 	} | ||||
|  | ||||
| 	s.peer = peer | ||||
|  | ||||
| 	s.staged.localIP = s.localIP | ||||
| 	s.staged.IP = peer.PeerIP | ||||
| 	s.staged.Up = false | ||||
| 	s.staged.Relay = false | ||||
| 	s.staged.Direct = false | ||||
| 	s.staged.DirectAddr = netip.AddrPort{} | ||||
| 	s.staged.PubSignKey = nil | ||||
| 	s.staged.ControlCipher = nil | ||||
| 	s.staged.DataCipher = nil | ||||
|  | ||||
| 	if peer == nil { | ||||
| 		return enterStateDisconnected(s) | ||||
| 	} | ||||
|  | ||||
| 	s.staged.IP = peer.PeerIP | ||||
| 	s.staged.PubSignKey = peer.PubSignKey | ||||
| 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | ||||
| 	s.staged.DataCipher = newDataCipher() | ||||
| @@ -76,30 +79,32 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { | ||||
| 		s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) | ||||
|  | ||||
| 		if s.localAddr.IsValid() && s.localIP < s.remoteIP { | ||||
| 			return EnterStateServer(s) | ||||
| 			return enterStateServer(s) | ||||
| 		} | ||||
|  | ||||
| 		return EnterStateClientDirect(s) | ||||
| 		return enterStateClientDirect(s) | ||||
| 	} | ||||
|  | ||||
| 	if s.localAddr.IsValid() { | ||||
| 		s.staged.Direct = true | ||||
| 		return EnterStateServer(s) | ||||
| 		return enterStateServer(s) | ||||
| 	} | ||||
|  | ||||
| 	if s.localIP < s.remoteIP { | ||||
| 		return EnterStateServer(s) | ||||
| 		return enterStateServer(s) | ||||
| 	} | ||||
|  | ||||
| 	return EnterStateClientRelayed(s) | ||||
| 	return enterStateClientRelayed(s) | ||||
| } | ||||
|  | ||||
| func (s *State) logf(format string, args ...any) { | ||||
| func (s *pState) logf(format string, args ...any) { | ||||
| 	b := strings.Builder{} | ||||
| 	name := "" | ||||
| 	if s.peer != nil { | ||||
| 		name = s.peer.Name | ||||
| 	} | ||||
| 	b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) | ||||
|  | ||||
| 	b.WriteString(fmt.Sprintf("%30s: ", name)) | ||||
|  | ||||
| 	if s.staged.Direct { | ||||
| @@ -119,7 +124,7 @@ func (s *State) logf(format string, args ...any) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||
| func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||
| 	if !addr.IsValid() { | ||||
| 		return | ||||
| 	} | ||||
| @@ -129,7 +134,7 @@ func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||
| 	s.Send(route, pkt) | ||||
| } | ||||
|  | ||||
| func (s *State) Send(peer RemotePeer, pkt Marshaller) { | ||||
| func (s *pState) Send(peer remotePeer, pkt Marshaller) { | ||||
| 	if err := s.limiter.Limit(); err != nil { | ||||
| 		s.logf("Rate limited.") | ||||
| 		return | ||||
| @@ -139,42 +144,32 @@ func (s *State) Send(peer RemotePeer, pkt Marshaller) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateDisconnected struct{ *State } | ||||
| type stateDisconnected struct{ *pState } | ||||
|  | ||||
| func EnterStateDisconnected(s *State) PeerState { | ||||
| 	s.logf("==> Disconnected") | ||||
| 	s.peer = nil | ||||
| 	s.staged.Up = false | ||||
| 	s.staged.Relay = false | ||||
| 	s.staged.Direct = false | ||||
| 	s.staged.DirectAddr = netip.AddrPort{} | ||||
| 	s.staged.PubSignKey = nil | ||||
| 	s.staged.ControlCipher = nil | ||||
| 	s.staged.DataCipher = nil | ||||
| 	s.publish(s.staged) | ||||
| 	return &StateDisconnected{State: s} | ||||
| func enterStateDisconnected(s *pState) peerState { | ||||
| 	return &stateDisconnected{pState: s} | ||||
| } | ||||
|  | ||||
| func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState             { return nil } | ||||
| func (s *StateDisconnected) OnAck(controlMsg[PacketAck])                       {} | ||||
| func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState         { return nil } | ||||
| func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} | ||||
| func (s *StateDisconnected) OnPingTimer() PeerState                            { return nil } | ||||
| func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState             { return s } | ||||
| func (s *stateDisconnected) OnAck(controlMsg[packetAck])                       {} | ||||
| func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState         { return s } | ||||
| func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {} | ||||
| func (s *stateDisconnected) OnPingTimer() peerState                            { return s } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateServer struct { | ||||
| 	*StateDisconnected | ||||
| type stateServer struct { | ||||
| 	*stateDisconnected | ||||
| 	lastSeen   time.Time | ||||
| 	synTraceID uint64 | ||||
| } | ||||
|  | ||||
| func EnterStateServer(s *State) PeerState { | ||||
| func enterStateServer(s *pState) peerState { | ||||
| 	s.logf("==> Server") | ||||
| 	return &StateServer{StateDisconnected: &StateDisconnected{State: s}} | ||||
| 	return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} | ||||
| } | ||||
|  | ||||
| func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | ||||
| func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | ||||
| 	s.lastSeen = time.Now() | ||||
| 	p := msg.Packet | ||||
|  | ||||
| @@ -194,7 +189,7 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | ||||
| 	} | ||||
|  | ||||
| 	// Always respond. | ||||
| 	ack := PacketAck{ | ||||
| 	ack := packetAck{ | ||||
| 		TraceID:       p.TraceID, | ||||
| 		ToAddr:        s.staged.DirectAddr, | ||||
| 		PossibleAddrs: s.pubAddrs.Get(), | ||||
| @@ -202,55 +197,55 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | ||||
| 	s.Send(s.staged, ack) | ||||
|  | ||||
| 	if p.Direct { | ||||
| 		return nil | ||||
| 		return s | ||||
| 	} | ||||
|  | ||||
| 	for _, addr := range msg.Packet.PossibleAddrs { | ||||
| 		if !addr.IsValid() { | ||||
| 			break | ||||
| 		} | ||||
| 		s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) | ||||
| 		s.SendTo(packetProbe{TraceID: newTraceID()}, addr) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { | ||||
| func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState { | ||||
| 	if msg.SrcAddr.IsValid() { | ||||
| 		s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||
| 		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||
| 	} | ||||
| 	return nil | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *StateServer) OnPingTimer() PeerState { | ||||
| func (s *stateServer) OnPingTimer() peerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||
| 		s.staged.Up = false | ||||
| 		s.publish(s.staged) | ||||
| 		s.logf("Timeout.") | ||||
| 	} | ||||
| 	return nil | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateClientDirect struct { | ||||
| 	*StateDisconnected | ||||
| type stateClientDirect struct { | ||||
| 	*stateDisconnected | ||||
| 	lastSeen time.Time | ||||
| 	syn      PacketSyn | ||||
| 	syn      packetSyn | ||||
| } | ||||
|  | ||||
| func EnterStateClientDirect(s *State) PeerState { | ||||
| func enterStateClientDirect(s *pState) peerState { | ||||
| 	s.logf("==> ClientDirect") | ||||
| 	return NewStateClientDirect(s) | ||||
| 	return newStateClientDirect(s) | ||||
| } | ||||
|  | ||||
| func NewStateClientDirect(s *State) *StateClientDirect { | ||||
| 	state := &StateClientDirect{ | ||||
| 		StateDisconnected: &StateDisconnected{s}, | ||||
| func newStateClientDirect(s *pState) *stateClientDirect { | ||||
| 	state := &stateClientDirect{ | ||||
| 		stateDisconnected: &stateDisconnected{s}, | ||||
| 		lastSeen:          time.Now(), // Avoid immediate timeout. | ||||
| 	} | ||||
|  | ||||
| 	state.syn = PacketSyn{ | ||||
| 	state.syn = packetSyn{ | ||||
| 		TraceID:       newTraceID(), | ||||
| 		SharedKey:     s.staged.DataCipher.Key(), | ||||
| 		Direct:        s.staged.Direct, | ||||
| @@ -260,7 +255,7 @@ func NewStateClientDirect(s *State) *StateClientDirect { | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { | ||||
| func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | ||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | ||||
| 		return | ||||
| 	} | ||||
| @@ -276,7 +271,14 @@ func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { | ||||
| 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||
| } | ||||
|  | ||||
| func (s *StateClientDirect) OnPingTimer() PeerState { | ||||
| func (s *stateClientDirect) OnPingTimer() peerState { | ||||
| 	if next := s.onPingTimer(); next != nil { | ||||
| 		return next | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientDirect) onPingTimer() peerState { | ||||
| 	if time.Since(s.lastSeen) > timeoutInterval { | ||||
| 		if s.staged.Up { | ||||
| 			s.staged.Up = false | ||||
| @@ -292,47 +294,47 @@ func (s *StateClientDirect) OnPingTimer() PeerState { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type StateClientRelayed struct { | ||||
| 	*StateClientDirect | ||||
| 	ack                PacketAck | ||||
| type stateClientRelayed struct { | ||||
| 	*stateClientDirect | ||||
| 	ack                packetAck | ||||
| 	probes             map[uint64]netip.AddrPort | ||||
| 	localDiscoveryAddr netip.AddrPort | ||||
| } | ||||
|  | ||||
| func EnterStateClientRelayed(s *State) PeerState { | ||||
| func enterStateClientRelayed(s *pState) peerState { | ||||
| 	s.logf("==> ClientRelayed") | ||||
| 	return &StateClientRelayed{ | ||||
| 		StateClientDirect: NewStateClientDirect(s), | ||||
| 	return &stateClientRelayed{ | ||||
| 		stateClientDirect: newStateClientDirect(s), | ||||
| 		probes:            map[uint64]netip.AddrPort{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { | ||||
| func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { | ||||
| 	s.ack = msg.Packet | ||||
| 	s.StateClientDirect.OnAck(msg) | ||||
| 	s.stateClientDirect.OnAck(msg) | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { | ||||
| func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { | ||||
| 	addr, ok := s.probes[msg.Packet.TraceID] | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 		return s | ||||
| 	} | ||||
|  | ||||
| 	s.staged.DirectAddr = addr | ||||
| 	s.staged.Direct = true | ||||
| 	s.publish(s.staged) | ||||
| 	return EnterStateClientDirect(s.StateClientDirect.State) | ||||
| 	return enterStateClientDirect(s.stateClientDirect.pState) | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { | ||||
| func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { | ||||
| 	// The source port will be the multicast port, so we'll have to | ||||
| 	// construct the correct address using the peer's listed port. | ||||
| 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) OnPingTimer() PeerState { | ||||
| 	if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { | ||||
| 		return nextState | ||||
| func (s *stateClientRelayed) OnPingTimer() peerState { | ||||
| 	if next := s.stateClientDirect.onPingTimer(); next != nil { | ||||
| 		return next | ||||
| 	} | ||||
|  | ||||
| 	clear(s.probes) | ||||
| @@ -348,11 +350,11 @@ func (s *StateClientRelayed) OnPingTimer() PeerState { | ||||
| 		s.localDiscoveryAddr = netip.AddrPort{} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { | ||||
| 	probe := PacketProbe{TraceID: newTraceID()} | ||||
| func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { | ||||
| 	probe := packetProbe{TraceID: newTraceID()} | ||||
| 	s.probes[probe.TraceID] = addr | ||||
| 	s.SendTo(probe, addr) | ||||
| } | ||||
|   | ||||
| @@ -12,13 +12,13 @@ import ( | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PeerStateControlMsg struct { | ||||
| 	Peer   RemotePeer | ||||
| 	Peer   remotePeer | ||||
| 	Packet any | ||||
| } | ||||
|  | ||||
| type PeerStateTestHarness struct { | ||||
| 	State     PeerState | ||||
| 	Published RemotePeer | ||||
| 	State     peerState | ||||
| 	Published remotePeer | ||||
| 	Sent      []PeerStateControlMsg | ||||
| } | ||||
|  | ||||
| @@ -27,11 +27,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||
|  | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	state := &State{ | ||||
| 		publish: func(rp RemotePeer) { | ||||
| 	state := &pState{ | ||||
| 		publish: func(rp remotePeer) { | ||||
| 			h.Published = rp | ||||
| 		}, | ||||
| 		sendControlPacket: func(rp RemotePeer, pkt Marshaller) { | ||||
| 		sendControlPacket: func(rp remotePeer, pkt Marshaller) { | ||||
| 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||
| 		}, | ||||
| 		localIP:  2, | ||||
| @@ -44,7 +44,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | ||||
| 		}), | ||||
| 	} | ||||
|  | ||||
| 	h.State = EnterStateDisconnected(state) | ||||
| 	h.State = enterStateDisconnected(state) | ||||
| 	return h | ||||
| } | ||||
|  | ||||
| @@ -54,13 +54,13 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { | ||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | ||||
| 	if s := h.State.OnSyn(msg); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { | ||||
| func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { | ||||
| 	if s := h.State.OnProbe(msg); s != nil { | ||||
| 		h.State = s | ||||
| 	} | ||||
| @@ -72,10 +72,10 @@ func (h *PeerStateTestHarness) OnPingTimer() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { | ||||
| func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	state := h.State.(*StateDisconnected) | ||||
| 	state := h.State.(*stateDisconnected) | ||||
| 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||
|  | ||||
| 	peer := &m.Peer{ | ||||
| @@ -88,10 +88,10 @@ func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateServer](t, h.State) | ||||
| 	return assertType[*stateServer](t, h.State) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { | ||||
| func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { | ||||
| 	keys := generateKeys() | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| @@ -102,10 +102,10 @@ func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateServer](t, h.State) | ||||
| 	return assertType[*stateServer](t, h.State) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { | ||||
| func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDirect { | ||||
| 	keys := generateKeys() | ||||
| 	peer := &m.Peer{ | ||||
| 		PeerIP:     3, | ||||
| @@ -117,13 +117,13 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDire | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateClientDirect](t, h.State) | ||||
| 	return assertType[*stateClientDirect](t, h.State) | ||||
| } | ||||
|  | ||||
| func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { | ||||
| func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClientRelayed { | ||||
| 	keys := generateKeys() | ||||
|  | ||||
| 	state := h.State.(*StateDisconnected) | ||||
| 	state := h.State.(*stateDisconnected) | ||||
| 	state.remoteIP = 1 | ||||
|  | ||||
| 	peer := &m.Peer{ | ||||
| @@ -135,7 +135,7 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	return assertType[*StateClientRelayed](t, h.State) | ||||
| 	return assertType[*stateClientRelayed](t, h.State) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -143,14 +143,14 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel | ||||
| func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.PeerUpdate(nil) | ||||
| 	assertType[*StateDisconnected](t, h.State) | ||||
| 	assertType[*stateDisconnected](t, h.State) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | ||||
| 	keys := generateKeys() | ||||
| 	h := NewPeerStateTestHarness() | ||||
|  | ||||
| 	state := h.State.(*StateDisconnected) | ||||
| 	state := h.State.(*stateDisconnected) | ||||
| 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||
|  | ||||
| 	peer := &m.Peer{ | ||||
| @@ -162,7 +162,7 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | ||||
|  | ||||
| 	h.PeerUpdate(peer) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| 	assertType[*StateServer](t, h.State) | ||||
| 	assertType[*stateServer](t, h.State) | ||||
| } | ||||
|  | ||||
| func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { | ||||
| @@ -191,10 +191,10 @@ func TestStateServer_directSyn(t *testing.T) { | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	synMsg := controlMsg[PacketSyn]{ | ||||
| 	synMsg := controlMsg[packetSyn]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet: PacketSyn{ | ||||
| 		Packet: packetSyn{ | ||||
| 			TraceID: newTraceID(), | ||||
| 			//SentAt:        time.Now().UnixMilli(), | ||||
| 			//SharedKeyType: 1, | ||||
| @@ -205,7 +205,7 @@ func TestStateServer_directSyn(t *testing.T) { | ||||
| 	h.State.OnSyn(synMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | ||||
| 	ack := assertType[packetAck](t, h.Sent[0].Packet) | ||||
| 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||
| 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||
| 	assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) | ||||
| @@ -220,10 +220,10 @@ func TestStateServer_relayedSyn(t *testing.T) { | ||||
|  | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	synMsg := controlMsg[PacketSyn]{ | ||||
| 	synMsg := controlMsg[packetSyn]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet: PacketSyn{ | ||||
| 		Packet: packetSyn{ | ||||
| 			TraceID: newTraceID(), | ||||
| 			//SentAt:        time.Now().UnixMilli(), | ||||
| 			//SharedKeyType: 1, | ||||
| @@ -237,15 +237,15 @@ func TestStateServer_relayedSyn(t *testing.T) { | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 3) | ||||
|  | ||||
| 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | ||||
| 	ack := assertType[packetAck](t, h.Sent[0].Packet) | ||||
| 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||
| 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||
| 	assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) | ||||
| 	assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	assertType[PacketProbe](t, h.Sent[1].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[2].Packet) | ||||
| 	assertType[packetProbe](t, h.Sent[1].Packet) | ||||
| 	assertType[packetProbe](t, h.Sent[2].Packet) | ||||
| 	assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||
| 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) | ||||
| } | ||||
| @@ -255,17 +255,17 @@ func TestStateServer_onProbe(t *testing.T) { | ||||
| 	h.ConfigServer_Relayed(t) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
|  | ||||
| 	probeMsg := controlMsg[PacketProbe]{ | ||||
| 	probeMsg := controlMsg[packetProbe]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet:  PacketProbe{TraceID: newTraceID()}, | ||||
| 		Packet:  packetProbe{TraceID: newTraceID()}, | ||||
| 	} | ||||
|  | ||||
| 	h.State.OnProbe(probeMsg) | ||||
|  | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
|  | ||||
| 	probe := assertType[PacketProbe](t, h.Sent[0].Packet) | ||||
| 	probe := assertType[packetProbe](t, h.Sent[0].Packet) | ||||
| 	assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) | ||||
| 	assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||
| } | ||||
| @@ -274,10 +274,10 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigServer_Relayed(t) | ||||
|  | ||||
| 	synMsg := controlMsg[PacketSyn]{ | ||||
| 	synMsg := controlMsg[packetSyn]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||
| 		Packet: PacketSyn{ | ||||
| 		Packet: packetSyn{ | ||||
| 			TraceID: newTraceID(), | ||||
| 			//SentAt:        time.Now().UnixMilli(), | ||||
| 			//SharedKeyType: 1, | ||||
| @@ -294,7 +294,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	// Advance the time, then ping. | ||||
| 	state := assertType[*StateServer](t, h.State) | ||||
| 	state := assertType[*stateServer](t, h.State) | ||||
| 	state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
| @@ -309,10 +309,10 @@ func TestStateClientDirect_OnAck(t *testing.T) { | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| @@ -326,10 +326,10 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID + 1}, | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID + 1}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| @@ -341,15 +341,15 @@ func TestStateClientDirect_OnPingTimer(t *testing.T) { | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	// On ping timer, another syn should be sent. Additionally, we should remain | ||||
| 	// in the same state. | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*StateClientDirect](t, h.State) | ||||
| 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*stateClientDirect](t, h.State) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| @@ -361,15 +361,15 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	state := assertType[*StateClientDirect](t, h.State) | ||||
| 	state := assertType[*stateClientDirect](t, h.State) | ||||
| 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
| @@ -377,8 +377,8 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | ||||
| 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||
| 	// will be sent when re-entering the state, but the connection should be down. | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*StateClientDirect](t, h.State) | ||||
| 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*stateClientDirect](t, h.State) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| @@ -390,10 +390,10 @@ func TestStateClientRelayed_OnAck(t *testing.T) { | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
| @@ -423,9 +423,9 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
|  | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | ||||
| 	ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} | ||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||
|  | ||||
| @@ -433,7 +433,7 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | ||||
|  | ||||
| 	// Add a local discovery address. Note that the port will be configured port | ||||
| 	// and no the one provided here. | ||||
| 	h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ | ||||
| 	h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ | ||||
| 		SrcIP:   3, | ||||
| 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | ||||
| 	}) | ||||
| @@ -441,10 +441,10 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | ||||
| 	// We should see one SYN and three probe packets. | ||||
| 	h.OnPingTimer() | ||||
| 	assertEqual(t, len(h.Sent), 5) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[2].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[3].Packet) | ||||
| 	assertType[PacketProbe](t, h.Sent[4].Packet) | ||||
| 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[packetProbe](t, h.Sent[2].Packet) | ||||
| 	assertType[packetProbe](t, h.Sent[3].Packet) | ||||
| 	assertType[packetProbe](t, h.Sent[4].Packet) | ||||
|  | ||||
| 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) | ||||
| 	assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) | ||||
| @@ -457,15 +457,15 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | ||||
|  | ||||
| 	// On entering the state, a SYN should have been sent. | ||||
| 	assertEqual(t, len(h.Sent), 1) | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{ | ||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | ||||
| 	ack := controlMsg[packetAck]{ | ||||
| 		Packet: packetAck{TraceID: syn.TraceID}, | ||||
| 	} | ||||
| 	h.State.OnAck(ack) | ||||
| 	assertEqual(t, h.Published.Up, true) | ||||
|  | ||||
| 	state := assertType[*StateClientRelayed](t, h.State) | ||||
| 	state := assertType[*stateClientRelayed](t, h.State) | ||||
| 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||
|  | ||||
| 	h.OnPingTimer() | ||||
| @@ -473,8 +473,8 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | ||||
| 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||
| 	// will be sent when re-entering the state, but the connection should be down. | ||||
| 	assertEqual(t, len(h.Sent), 2) | ||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*StateClientRelayed](t, h.State) | ||||
| 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||
| 	assertType[*stateClientRelayed](t, h.State) | ||||
| 	assertEqual(t, h.Published.Up, false) | ||||
| } | ||||
|  | ||||
| @@ -482,28 +482,28 @@ func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	h.OnProbe(controlMsg[PacketProbe]{ | ||||
| 		Packet: PacketProbe{TraceID: newTraceID()}, | ||||
| 	h.OnProbe(controlMsg[packetProbe]{ | ||||
| 		Packet: packetProbe{TraceID: newTraceID()}, | ||||
| 	}) | ||||
|  | ||||
| 	assertType[*StateClientRelayed](t, h.State) | ||||
| 	assertType[*stateClientRelayed](t, h.State) | ||||
| } | ||||
|  | ||||
| func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | ||||
| 	h := NewPeerStateTestHarness() | ||||
| 	h.ConfigClientRelayed(t) | ||||
|  | ||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | ||||
| 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||
|  | ||||
| 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | ||||
| 	ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} | ||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||
|  | ||||
| 	h.State.OnAck(ack) | ||||
| 	h.OnPingTimer() | ||||
|  | ||||
| 	probe := assertType[PacketProbe](t, h.Sent[2].Packet) | ||||
| 	h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) | ||||
| 	probe := assertType[packetProbe](t, h.Sent[2].Packet) | ||||
| 	h.OnProbe(controlMsg[packetProbe]{Packet: probe}) | ||||
|  | ||||
| 	assertType[*StateClientDirect](t, h.State) | ||||
| 	assertType[*stateClientDirect](t, h.State) | ||||
| } | ||||
|   | ||||
							
								
								
									
										172
									
								
								peer/peersuper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								peer/peersuper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| type Super struct { | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||
| 	staged             routingTable | ||||
| 	shared             *atomic.Pointer[routingTable] | ||||
| 	peers              [256]*PeerSuper | ||||
| 	lock               sync.Mutex | ||||
|  | ||||
| 	buf1 []byte | ||||
| 	buf2 []byte | ||||
| } | ||||
|  | ||||
| func NewSuper( | ||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||
| 	rt *atomic.Pointer[routingTable], | ||||
| 	privKey []byte, | ||||
| ) *Super { | ||||
|  | ||||
| 	routes := rt.Load() | ||||
|  | ||||
| 	s := &Super{ | ||||
| 		writeToUDPAddrPort: writeToUDPAddrPort, | ||||
| 		staged:             *routes, | ||||
| 		shared:             rt, | ||||
| 		buf1:               newBuf(), | ||||
| 		buf2:               newBuf(), | ||||
| 	} | ||||
|  | ||||
| 	pubAddrs := newPubAddrStore(routes.LocalAddr) | ||||
|  | ||||
| 	for i := range s.peers { | ||||
| 		state := &pState{ | ||||
| 			publish:           s.publish, | ||||
| 			sendControlPacket: s.send, | ||||
| 			localIP:           routes.LocalIP, | ||||
| 			remoteIP:          byte(i), | ||||
| 			privKey:           privKey, | ||||
| 			localAddr:         routes.LocalAddr, | ||||
| 			pubAddrs:          pubAddrs, | ||||
| 			staged:            routes.Peers[i], | ||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 				FillPeriod:   20 * time.Millisecond, | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| 		s.peers[i] = NewPeerSuper(state) | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *Super) Start() { | ||||
| 	for i := range s.peers { | ||||
| 		go s.peers[i].Run() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Super) HandleControlMsg(destIP byte, msg any) { | ||||
| 	s.peers[destIP].HandleControlMsg(msg) | ||||
| } | ||||
|  | ||||
| func (s *Super) send(peer remotePeer, pkt Marshaller) { | ||||
| 	s.lock.Lock() | ||||
| 	defer s.lock.Unlock() | ||||
|  | ||||
| 	enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2) | ||||
| 	if peer.Direct { | ||||
| 		s.writeToUDPAddrPort(enc, peer.DirectAddr) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	relay, ok := s.staged.GetRelay() | ||||
| 	if !ok { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1) | ||||
| 	s.writeToUDPAddrPort(enc, relay.DirectAddr) | ||||
| } | ||||
|  | ||||
| func (s *Super) publish(rp remotePeer) { | ||||
| 	s.lock.Lock() | ||||
| 	defer s.lock.Unlock() | ||||
|  | ||||
| 	s.staged.Peers[rp.IP] = rp | ||||
| 	s.ensureRelay() | ||||
| 	copy := s.staged | ||||
| 	s.shared.Store(©) | ||||
| } | ||||
|  | ||||
| func (s *Super) ensureRelay() { | ||||
| 	if _, ok := s.staged.GetRelay(); ok { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// TODO: Random selection? | ||||
| 	for _, peer := range s.staged.Peers { | ||||
| 		if peer.Up && peer.Direct && peer.Relay { | ||||
| 			s.staged.RelayIP = peer.IP | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type PeerSuper struct { | ||||
| 	messages chan any | ||||
| 	state    peerState | ||||
| } | ||||
|  | ||||
| func NewPeerSuper(state *pState) *PeerSuper { | ||||
| 	return &PeerSuper{ | ||||
| 		messages: make(chan any, 8), | ||||
| 		state:    state.OnPeerUpdate(nil), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *PeerSuper) HandleControlMsg(msg any) { | ||||
| 	select { | ||||
| 	case s.messages <- msg: | ||||
| 	default: | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *PeerSuper) Run() { | ||||
| 	go func() { | ||||
| 		// Randomize ping timers. | ||||
| 		time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) | ||||
| 		for range time.Tick(pingInterval) { | ||||
| 			s.messages <- pingTimerMsg{} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	for rawMsg := range s.messages { | ||||
| 		switch msg := rawMsg.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			s.state = s.state.OnPeerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[packetSyn]: | ||||
| 			s.state = s.state.OnSyn(msg) | ||||
|  | ||||
| 		case controlMsg[packetAck]: | ||||
| 			s.state.OnAck(msg) | ||||
|  | ||||
| 		case controlMsg[packetProbe]: | ||||
| 			s.state = s.state.OnProbe(msg) | ||||
|  | ||||
| 		case controlMsg[packetLocalDiscovery]: | ||||
| 			s.state.OnLocalDiscovery(msg) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			s.state = s.state.OnPingTimer() | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -7,9 +7,9 @@ import ( | ||||
| ) | ||||
|  | ||||
| // TODO: Remove | ||||
| func NewRemotePeer(ip byte) *RemotePeer { | ||||
| func NewRemotePeer(ip byte) *remotePeer { | ||||
| 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||
| 	return &RemotePeer{ | ||||
| 	return &remotePeer{ | ||||
| 		IP:       ip, | ||||
| 		counter:  &counter, | ||||
| 		dupCheck: newDupCheck(0), | ||||
| @@ -18,7 +18,7 @@ func NewRemotePeer(ip byte) *RemotePeer { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type RemotePeer struct { | ||||
| type remotePeer struct { | ||||
| 	localIP       byte | ||||
| 	IP            byte           // VPN IP of peer (last byte). | ||||
| 	Up            bool           // True if data can be sent on the peer. | ||||
| @@ -33,7 +33,7 @@ type RemotePeer struct { | ||||
| 	dupCheck *dupCheck // For receiving from. Not safe for concurrent use. | ||||
| } | ||||
|  | ||||
| func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | ||||
| func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | ||||
| 	h := header{ | ||||
| 		StreamID: dataStreamID, | ||||
| 		Counter:  atomic.AddUint64(p.counter, 1), | ||||
| @@ -44,7 +44,7 @@ func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | ||||
| } | ||||
|  | ||||
| // Decrypts and de-dups incoming data packets. | ||||
| func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { | ||||
| func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { | ||||
| 	dec, ok := p.DataCipher.Decrypt(enc, out) | ||||
| 	if !ok { | ||||
| 		return nil, errDecryptionFailed | ||||
| @@ -58,21 +58,22 @@ func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) | ||||
| } | ||||
|  | ||||
| // Peer must have a ControlCipher. | ||||
| func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | ||||
| func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | ||||
| 	tmp = pkt.Marshal(tmp) | ||||
| 	h := header{ | ||||
| 		StreamID: controlStreamID, | ||||
| 		Counter:  atomic.AddUint64(p.counter, 1), | ||||
| 		SourceIP: p.localIP, | ||||
| 		DestIP:   p.IP, | ||||
| 	} | ||||
| 	tmp = pkt.Marshal(tmp) | ||||
|  | ||||
| 	return p.ControlCipher.Encrypt(h, tmp, out) | ||||
| } | ||||
|  | ||||
| // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. | ||||
| // | ||||
| // This function also drops packets with duplicate sequence numbers. | ||||
| func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { | ||||
| func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { | ||||
| 	out, ok := p.ControlCipher.Decrypt(enc, tmp) | ||||
| 	if !ok { | ||||
| 		return nil, errDecryptionFailed | ||||
| @@ -92,7 +93,7 @@ func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type RoutingTable struct { | ||||
| type routingTable struct { | ||||
| 	// The LocalIP is the configured IP address of the local peer on the VPN. | ||||
| 	// | ||||
| 	// This value is constant. | ||||
| @@ -106,21 +107,21 @@ type RoutingTable struct { | ||||
| 	LocalAddr netip.AddrPort | ||||
|  | ||||
| 	// The remote peer configurations. These are updated by | ||||
| 	Peers [256]RemotePeer | ||||
| 	Peers [256]remotePeer | ||||
|  | ||||
| 	// The current relay's VPN IP address, or zero if no relay is available. | ||||
| 	RelayIP byte | ||||
| } | ||||
|  | ||||
| func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { | ||||
| 	rt := RoutingTable{ | ||||
| func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable { | ||||
| 	rt := routingTable{ | ||||
| 		LocalIP:   localIP, | ||||
| 		LocalAddr: localAddr, | ||||
| 	} | ||||
|  | ||||
| 	for i := range rt.Peers { | ||||
| 		counter := uint64(time.Now().Unix()<<30 + 1) | ||||
| 		rt.Peers[i] = RemotePeer{ | ||||
| 		rt.Peers[i] = remotePeer{ | ||||
| 			localIP:  localIP, | ||||
| 			IP:       byte(i), | ||||
| 			counter:  &counter, | ||||
| @@ -131,7 +132,7 @@ func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { | ||||
| 	return rt | ||||
| } | ||||
|  | ||||
| func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { | ||||
| func (rt *routingTable) GetRelay() (remotePeer, bool) { | ||||
| 	relay := rt.Peers[rt.RelayIP] | ||||
| 	return relay, relay.Up && relay.Direct | ||||
| } | ||||
|   | ||||
| @@ -74,7 +74,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
| 	orig := packetProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
| @@ -88,7 +88,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := ctrlMsg.(controlMsg[PacketProbe]) | ||||
| 	dec, ok := ctrlMsg.(controlMsg[packetProbe]) | ||||
| 	if !ok { | ||||
| 		t.Fatal(ctrlMsg) | ||||
| 	} | ||||
| @@ -108,7 +108,7 @@ func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
| 	orig := packetProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
| @@ -131,7 +131,7 @@ func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { | ||||
| 	peer2 := p1.RT.Load().Peers[2] | ||||
| 	peer1 := p2.RT.Load().Peers[1] | ||||
|  | ||||
| 	orig := PacketProbe{TraceID: newTraceID()} | ||||
| 	orig := packetProbe{TraceID: newTraceID()} | ||||
|  | ||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||
|  | ||||
|   | ||||
| @@ -1,103 +0,0 @@ | ||||
| package peer | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.crumpington.com/lib/go/ratelimiter" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type Supervisor struct { | ||||
| 	messages chan any // Incoming control messages. | ||||
| 	peers    [256]PeerState | ||||
| 	pubAddrs *pubAddrStore | ||||
| 	rt       *atomic.Pointer[RoutingTable] | ||||
| 	staged   RoutingTable | ||||
| } | ||||
|  | ||||
| func NewSupervisor( | ||||
| 	sendControl func(RemotePeer, Marshaller), | ||||
| 	privKey []byte, | ||||
| 	rt *atomic.Pointer[RoutingTable], | ||||
| ) *Supervisor { | ||||
| 	s := &Supervisor{ | ||||
| 		messages: make(chan any, 1024), | ||||
| 		pubAddrs: newPubAddrStore(rt.Load().LocalAddr), | ||||
| 		rt:       rt, | ||||
| 	} | ||||
|  | ||||
| 	routes := rt.Load() | ||||
|  | ||||
| 	for i := range s.peers { | ||||
| 		state := &State{ | ||||
| 			publish:           s.Publish, | ||||
| 			sendControlPacket: sendControl, | ||||
| 			localIP:           routes.LocalIP, | ||||
| 			remoteIP:          byte(i), | ||||
| 			privKey:           privKey, | ||||
| 			localAddr:         routes.LocalAddr, | ||||
| 			pubAddrs:          s.pubAddrs, | ||||
| 			staged:            routes.Peers[i], | ||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||
| 				FillPeriod:   20 * time.Millisecond, | ||||
| 				MaxWaitCount: 1, | ||||
| 			}), | ||||
| 		} | ||||
| 		s.peers[i] = state.OnPeerUpdate(nil) | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *Supervisor) HandleControlMsg(msg any) { | ||||
| 	select { | ||||
| 	case s.messages <- msg: | ||||
| 	default: | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Supervisor) Run() { | ||||
| 	for raw := range s.messages { | ||||
| 		switch msg := raw.(type) { | ||||
|  | ||||
| 		case peerUpdateMsg: | ||||
| 			s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) | ||||
|  | ||||
| 		case controlMsg[PacketSyn]: | ||||
| 			if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { | ||||
| 				s.peers[msg.SrcIP] = newState | ||||
| 			} | ||||
|  | ||||
| 		case controlMsg[PacketAck]: | ||||
| 			s.peers[msg.SrcIP].OnAck(msg) | ||||
|  | ||||
| 		case controlMsg[PacketProbe]: | ||||
| 			if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { | ||||
| 				s.peers[msg.SrcIP] = newState | ||||
| 			} | ||||
|  | ||||
| 		case controlMsg[PacketLocalDiscovery]: | ||||
| 			s.peers[msg.SrcIP].OnLocalDiscovery(msg) | ||||
|  | ||||
| 		case pingTimerMsg: | ||||
| 			s.pubAddrs.Clean() | ||||
| 			for i := range s.peers { | ||||
| 				if newState := s.peers[i].OnPingTimer(); newState != nil { | ||||
| 					s.peers[i] = newState | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Supervisor) Publish(rp RemotePeer) { | ||||
| 	s.staged.Peers[rp.IP] = rp | ||||
| 	rt := s.staged // Copy. | ||||
| 	s.rt.Store(&rt) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user