refactor-for-testability #3
| @@ -2,10 +2,10 @@ package main | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"log" | 	"log" | ||||||
| 	"vppn/node" | 	"vppn/peer" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| 	log.SetFlags(0) | 	log.SetFlags(0) | ||||||
| 	node.Main() | 	peer.Main() | ||||||
| } | } | ||||||
|   | |||||||
| @@ -258,7 +258,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { | |||||||
| 	default: | 	default: | ||||||
| 		log.Printf("Dropping control packet.") | 		log.Printf("Dropping control packet.") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { | func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ type connReader struct { | |||||||
| 	sender  encryptedPacketSender | 	sender  encryptedPacketSender | ||||||
| 	super   controlMsgHandler | 	super   controlMsgHandler | ||||||
| 	localIP byte | 	localIP byte | ||||||
| 	peers   [256]*atomic.Pointer[RemotePeer] | 	peers   [256]*atomic.Pointer[remotePeer] | ||||||
|  |  | ||||||
| 	buf    []byte | 	buf    []byte | ||||||
| 	decBuf []byte | 	decBuf []byte | ||||||
| @@ -24,7 +24,7 @@ func newConnReader( | |||||||
| 	sender encryptedPacketSender, | 	sender encryptedPacketSender, | ||||||
| 	super controlMsgHandler, | 	super controlMsgHandler, | ||||||
| 	localIP byte, | 	localIP byte, | ||||||
| 	peers [256]*atomic.Pointer[RemotePeer], | 	peers [256]*atomic.Pointer[remotePeer], | ||||||
| ) *connReader { | ) *connReader { | ||||||
| 	return &connReader{ | 	return &connReader{ | ||||||
| 		conn:    conn, | 		conn:    conn, | ||||||
| @@ -79,7 +79,7 @@ func (r *connReader) handleNextPacket() { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (r *connReader) handleControlPacket( | func (r *connReader) handleControlPacket( | ||||||
| 	peer *RemotePeer, | 	peer *remotePeer, | ||||||
| 	addr netip.AddrPort, | 	addr netip.AddrPort, | ||||||
| 	h header, | 	h header, | ||||||
| 	enc []byte, | 	enc []byte, | ||||||
| @@ -102,7 +102,7 @@ func (r *connReader) handleControlPacket( | |||||||
| 	r.super.HandleControlMsg(msg) | 	r.super.HandleControlMsg(msg) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *connReader) handleDataPacket(peer *RemotePeer, h header, enc []byte) { | func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { | ||||||
| 	if !peer.Up { | 	if !peer.Up { | ||||||
| 		r.logf("Not connected (recv).") | 		r.logf("Not connected (recv).") | ||||||
| 		return | 		return | ||||||
|   | |||||||
| @@ -12,12 +12,12 @@ type ConnReader struct { | |||||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) | ||||||
|  |  | ||||||
| 	// Output | 	// Output | ||||||
|  | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||||
| 	iface              io.Writer | 	iface              io.Writer | ||||||
| 	forwardData      func(ip byte, pkt []byte) | 	handleControlMsg   func(fromIP byte, pkt any) | ||||||
| 	handleControlMsg func(pkt any) |  | ||||||
|  |  | ||||||
| 	localIP byte | 	localIP byte | ||||||
| 	rt      *atomic.Pointer[RoutingTable] | 	rt      *atomic.Pointer[routingTable] | ||||||
|  |  | ||||||
| 	buf    []byte | 	buf    []byte | ||||||
| 	decBuf []byte | 	decBuf []byte | ||||||
| @@ -25,15 +25,15 @@ type ConnReader struct { | |||||||
|  |  | ||||||
| func NewConnReader( | func NewConnReader( | ||||||
| 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | 	readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), | ||||||
|  | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||||
| 	iface io.Writer, | 	iface io.Writer, | ||||||
| 	forwardData func(ip byte, pkt []byte), | 	handleControlMsg func(fromIP byte, pkt any), | ||||||
| 	handleControlMsg func(pkt any), | 	rt *atomic.Pointer[routingTable], | ||||||
| 	rt *atomic.Pointer[RoutingTable], |  | ||||||
| ) *ConnReader { | ) *ConnReader { | ||||||
| 	return &ConnReader{ | 	return &ConnReader{ | ||||||
| 		readFromUDPAddrPort: readFromUDPAddrPort, | 		readFromUDPAddrPort: readFromUDPAddrPort, | ||||||
|  | 		writeToUDPAddrPort:  writeToUDPAddrPort, | ||||||
| 		iface:               iface, | 		iface:               iface, | ||||||
| 		forwardData:         forwardData, |  | ||||||
| 		handleControlMsg:    handleControlMsg, | 		handleControlMsg:    handleControlMsg, | ||||||
| 		localIP:             rt.Load().LocalIP, | 		localIP:             rt.Load().LocalIP, | ||||||
| 		rt:                  rt, | 		rt:                  rt, | ||||||
| @@ -50,7 +50,9 @@ func (r *ConnReader) Run() { | |||||||
|  |  | ||||||
| func (r *ConnReader) handleNextPacket() { | func (r *ConnReader) handleNextPacket() { | ||||||
| 	buf := r.buf[:bufferSize] | 	buf := r.buf[:bufferSize] | ||||||
|  | 	log.Printf("Getting next packet...") | ||||||
| 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | 	n, remoteAddr, err := r.readFromUDPAddrPort(buf) | ||||||
|  | 	log.Printf("Packet from %v...", remoteAddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Failed to read from UDP port: %v", err) | 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||||
| 	} | 	} | ||||||
| @@ -64,14 +66,14 @@ func (r *ConnReader) handleNextPacket() { | |||||||
| 	buf = buf[:n] | 	buf = buf[:n] | ||||||
| 	h := parseHeader(buf) | 	h := parseHeader(buf) | ||||||
|  |  | ||||||
| 	peer := r.rt.Load().Peers[h.SourceIP] | 	rt := r.rt.Load() | ||||||
| 	//peer := rt.Peers[h.SourceIP] | 	peer := rt.Peers[h.SourceIP] | ||||||
|  |  | ||||||
| 	switch h.StreamID { | 	switch h.StreamID { | ||||||
| 	case controlStreamID: | 	case controlStreamID: | ||||||
| 		r.handleControlPacket(remoteAddr, peer, h, buf) | 		r.handleControlPacket(remoteAddr, peer, h, buf) | ||||||
| 	case dataStreamID: | 	case dataStreamID: | ||||||
| 		r.handleDataPacket(peer, h, buf) | 		r.handleDataPacket(rt, peer, h, buf) | ||||||
| 	default: | 	default: | ||||||
| 		r.logf("Unknown stream ID: %d", h.StreamID) | 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||||
| 	} | 	} | ||||||
| @@ -79,7 +81,7 @@ func (r *ConnReader) handleNextPacket() { | |||||||
|  |  | ||||||
| func (r *ConnReader) handleControlPacket( | func (r *ConnReader) handleControlPacket( | ||||||
| 	remoteAddr netip.AddrPort, | 	remoteAddr netip.AddrPort, | ||||||
| 	peer RemotePeer, | 	peer remotePeer, | ||||||
| 	h header, | 	h header, | ||||||
| 	enc []byte, | 	enc []byte, | ||||||
| ) { | ) { | ||||||
| @@ -98,11 +100,12 @@ func (r *ConnReader) handleControlPacket( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r.handleControlMsg(msg) | 	r.handleControlMsg(h.SourceIP, msg) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *ConnReader) handleDataPacket( | func (r *ConnReader) handleDataPacket( | ||||||
| 	peer RemotePeer, | 	rt *routingTable, | ||||||
|  | 	peer remotePeer, | ||||||
| 	h header, | 	h header, | ||||||
| 	enc []byte, | 	enc []byte, | ||||||
| ) { | ) { | ||||||
| @@ -124,7 +127,13 @@ func (r *ConnReader) handleDataPacket( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r.forwardData(h.DestIP, data) | 	relay, ok := rt.GetRelay() | ||||||
|  | 	if !ok { | ||||||
|  | 		r.logf("Relay not available.") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r.writeToUDPAddrPort(data, relay.DirectAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *ConnReader) logf(format string, args ...any) { | func (r *ConnReader) logf(format string, args ...any) { | ||||||
|   | |||||||
| @@ -1,353 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"crypto/rand" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"reflect" |  | ||||||
| 	"sync/atomic" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type mockIfWriter struct { |  | ||||||
| 	Written [][]byte |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *mockIfWriter) Write(b []byte) (int, error) { |  | ||||||
| 	w.Written = append(w.Written, bytes.Clone(b)) |  | ||||||
| 	return len(b), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type mockEncryptedPacket struct { |  | ||||||
| 	Packet []byte |  | ||||||
| 	Route  *RemotePeer |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type mockEncryptedPacketSender struct { |  | ||||||
| 	Sent []mockEncryptedPacket |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *RemotePeer) { |  | ||||||
| 	m.Sent = append(m.Sent, mockEncryptedPacket{ |  | ||||||
| 		Packet: bytes.Clone(pkt), |  | ||||||
| 		Route:  route, |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type mockControlMsgHandler struct { |  | ||||||
| 	Messages []any |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { |  | ||||||
| 	m.Messages = append(m.Messages, pkt) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type udpPipe struct { |  | ||||||
| 	packets chan []byte |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func newUDPPipe() *udpPipe { |  | ||||||
| 	return &udpPipe{make(chan []byte, 1024)} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { |  | ||||||
| 	p.packets <- bytes.Clone(b) |  | ||||||
| 	return len(b), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { |  | ||||||
| 	packet := <-p.packets |  | ||||||
| 	copy(b, packet) |  | ||||||
| 	return len(packet), netip.AddrPort{}, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type connReaderTestHarness struct { |  | ||||||
| 	Pipe         *udpPipe |  | ||||||
| 	R            *connReader |  | ||||||
| 	WRemote      *connWriter |  | ||||||
| 	WRelayRemote *connWriter |  | ||||||
| 	Remote       *RemotePeer |  | ||||||
| 	RelayRemote  *RemotePeer |  | ||||||
| 	IFace        *mockIfWriter |  | ||||||
| 	Sender       *mockEncryptedPacketSender |  | ||||||
| 	Super        *mockControlMsgHandler |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Peer 2 is indirect, peer 3 is direct. |  | ||||||
| func newConnReadeTestHarness() (h connReaderTestHarness) { |  | ||||||
| 	pipe := newUDPPipe() |  | ||||||
| 	routes := [256]*atomic.Pointer[RemotePeer]{} |  | ||||||
| 	for i := range routes { |  | ||||||
| 		routes[i] = &atomic.Pointer[RemotePeer]{} |  | ||||||
| 		routes[i].Store(&RemotePeer{}) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() |  | ||||||
| 	routes[2].Store(local) |  | ||||||
| 	routes[3].Store(relayLocal) |  | ||||||
|  |  | ||||||
| 	h.Pipe = pipe |  | ||||||
| 	h.WRemote = newConnWriter(pipe, 2) |  | ||||||
| 	h.WRelayRemote = newConnWriter(pipe, 3) |  | ||||||
|  |  | ||||||
| 	h.Remote = remote |  | ||||||
| 	h.RelayRemote = relayRemote |  | ||||||
| 	h.IFace = &mockIfWriter{} |  | ||||||
| 	h.Sender = &mockEncryptedPacketSender{} |  | ||||||
| 	h.Super = &mockControlMsgHandler{} |  | ||||||
| 	h.R = newConnReader( |  | ||||||
| 		pipe, |  | ||||||
| 		h.IFace, |  | ||||||
| 		h.Sender, |  | ||||||
| 		h.Super, |  | ||||||
| 		1, |  | ||||||
| 		routes) |  | ||||||
| 	return h |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can receive a control packet. |  | ||||||
| func TestConnReader_handleControlPacket(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := PacketSyn{TraceID: 1234} |  | ||||||
|  |  | ||||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) |  | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.Super.Messages) != 1 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	msg := h.Super.Messages[0].(controlMsg[PacketSyn]) |  | ||||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { |  | ||||||
| 		t.Fatal(msg.Packet) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that a short packet is ignored. |  | ||||||
| func TestConnReader_handleNextPacket_short(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.Super.Messages) != 0 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that a packet with an unexpected stream ID is ignored. |  | ||||||
| func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := PacketSyn{TraceID: 1234} |  | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) |  | ||||||
| 	var header header |  | ||||||
| 	header.Parse(encrypted) |  | ||||||
| 	header.StreamID = 100 |  | ||||||
| 	header.Marshal(encrypted) |  | ||||||
|  |  | ||||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	if len(h.Super.Messages) != 0 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that control packet without matching control cipher is ignored. |  | ||||||
| func TestConnReader_handleControlPacket_noCipher(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := PacketSyn{TraceID: 1234} |  | ||||||
|  |  | ||||||
| 	//encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) |  | ||||||
| 	encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) |  | ||||||
| 	var header header |  | ||||||
| 	header.Parse(encrypted) |  | ||||||
| 	header.SourceIP = 10 |  | ||||||
| 	header.Marshal(encrypted) |  | ||||||
|  |  | ||||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	if len(h.Super.Messages) != 0 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that control packet with incrrect destination IP is ignored. |  | ||||||
| func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := PacketSyn{TraceID: 1234} |  | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) |  | ||||||
| 	var header header |  | ||||||
| 	header.Parse(encrypted) |  | ||||||
| 	header.DestIP++ |  | ||||||
| 	header.Marshal(encrypted) |  | ||||||
|  |  | ||||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	if len(h.Super.Messages) != 0 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that modified control packet is ignored. |  | ||||||
| func TestConnReader_handleControlPacket_modified(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := PacketSyn{TraceID: 1234} |  | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) |  | ||||||
| 	encrypted[len(encrypted)-1]++ |  | ||||||
|  |  | ||||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	if len(h.Super.Messages) != 0 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type unknownPacket struct{} |  | ||||||
|  |  | ||||||
| func (p unknownPacket) Marshal(buf []byte) []byte { |  | ||||||
| 	buf = buf[:1] |  | ||||||
| 	buf[0] = 100 |  | ||||||
| 	return buf |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that an empty control packet is ignored. |  | ||||||
| func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := unknownPacket{} |  | ||||||
|  |  | ||||||
| 	encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) |  | ||||||
| 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	if len(h.Super.Messages) != 0 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that a duplicate control packet is ignored. |  | ||||||
| func TestConnReader_handleControlPacket_duplicate(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := PacketAck{TraceID: 1234} |  | ||||||
|  |  | ||||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) |  | ||||||
| 	*h.Remote.counter = *h.Remote.counter - 1 |  | ||||||
| 	h.WRemote.SendControlPacket(pkt, h.Remote) |  | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.Super.Messages) != 1 { |  | ||||||
| 		t.Fatal(h.Super.Messages) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	msg := h.Super.Messages[0].(controlMsg[PacketAck]) |  | ||||||
| 	if !reflect.DeepEqual(pkt, msg.Packet) { |  | ||||||
| 		t.Fatal(msg.Packet) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can receive a data packet. |  | ||||||
| func TestConnReader_handleDataPacket(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 1024) |  | ||||||
| 	rand.Read(pkt) |  | ||||||
|  |  | ||||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) |  | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.IFace.Written) != 1 { |  | ||||||
| 		t.Fatal(h.IFace.Written) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !bytes.Equal(pkt, h.IFace.Written[0]) { |  | ||||||
| 		t.Fatal(h.IFace.Written) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that data packet is ignored if route isn't up. |  | ||||||
| func TestConnReader_handleDataPacket_routeDown(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 1024) |  | ||||||
| 	rand.Read(pkt) |  | ||||||
|  |  | ||||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) |  | ||||||
| 	route := h.R.peers[2].Load() |  | ||||||
| 	route.Up = false |  | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.IFace.Written) != 0 { |  | ||||||
| 		t.Fatal(h.IFace.Written) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that a duplicate data packet is ignored. |  | ||||||
| func TestConnReader_handleDataPacket_duplicate(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 123) |  | ||||||
|  |  | ||||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) |  | ||||||
| 	*h.Remote.counter = *h.Remote.counter - 1 |  | ||||||
| 	h.WRemote.SendDataPacket(pkt, h.Remote) |  | ||||||
|  |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.IFace.Written) != 1 { |  | ||||||
| 		t.Fatal(h.IFace.Written) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !bytes.Equal(pkt, h.IFace.Written[0]) { |  | ||||||
| 		t.Fatal(h.IFace.Written) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can relay a data packet. |  | ||||||
| func TestConnReader_handleDataPacket_relay(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 1024) |  | ||||||
| 	rand.Read(pkt) |  | ||||||
|  |  | ||||||
| 	h.RelayRemote.IP = 3 |  | ||||||
| 	h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.Sender.Sent) != 1 { |  | ||||||
| 		t.Fatal(h.Sender.Sent) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we drop a relayed packet if destination is down. |  | ||||||
| func TestConnReader_handleDataPacket_relayDown(t *testing.T) { |  | ||||||
| 	h := newConnReadeTestHarness() |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 1024) |  | ||||||
| 	rand.Read(pkt) |  | ||||||
|  |  | ||||||
| 	h.RelayRemote.IP = 3 |  | ||||||
| 	relay := h.R.peers[3].Load() |  | ||||||
| 	relay.Up = false |  | ||||||
|  |  | ||||||
| 	h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) |  | ||||||
| 	h.R.handleNextPacket() |  | ||||||
|  |  | ||||||
| 	if len(h.Sender.Sent) != 0 { |  | ||||||
| 		t.Fatal(h.Sender.Sent) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,80 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"log" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"sync" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type connWriter struct { |  | ||||||
| 	localIP byte |  | ||||||
| 	conn    udpWriter |  | ||||||
|  |  | ||||||
| 	// For sending control packets. |  | ||||||
| 	cBuf1 []byte |  | ||||||
| 	cBuf2 []byte |  | ||||||
|  |  | ||||||
| 	// For sending data packets. |  | ||||||
| 	dBuf1 []byte |  | ||||||
| 	dBuf2 []byte |  | ||||||
|  |  | ||||||
| 	// Lock around for sending on UDP Conn. |  | ||||||
| 	wLock sync.Mutex |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func newConnWriter(conn udpWriter, localIP byte) *connWriter { |  | ||||||
| 	w := &connWriter{ |  | ||||||
| 		localIP: localIP, |  | ||||||
| 		conn:    conn, |  | ||||||
| 		cBuf1:   make([]byte, bufferSize), |  | ||||||
| 		cBuf2:   make([]byte, bufferSize), |  | ||||||
| 		dBuf1:   make([]byte, bufferSize), |  | ||||||
| 		dBuf2:   make([]byte, bufferSize), |  | ||||||
| 	} |  | ||||||
| 	return w |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Not safe for concurrent use. Should only be called by supervisor. |  | ||||||
| func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { |  | ||||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) |  | ||||||
| 	w.writeTo(enc, peer.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Relay control packet. Peer must not be nil. |  | ||||||
| func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { |  | ||||||
| 	enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) |  | ||||||
| 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) |  | ||||||
| 	w.writeTo(enc, relay.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Not safe for concurrent use. Should only be called by ifReader. |  | ||||||
| func (w *connWriter) SendDataPacket(pkt []byte, peer *RemotePeer) { |  | ||||||
| 	enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) |  | ||||||
| 	w.writeTo(enc, peer.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Relay a data packet. Peer must not be nil. |  | ||||||
| func (w *connWriter) RelayDataPacket(pkt []byte, peer, relay *RemotePeer) { |  | ||||||
| 	enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) |  | ||||||
| 	enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.dBuf2) |  | ||||||
| 	w.writeTo(enc, relay.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Safe for concurrent use. Should only be called by connReader. |  | ||||||
| // |  | ||||||
| // This function will send pkt to the peer directly. This is used when a peer |  | ||||||
| // is acting as a relay and is forwarding already encrypted data for another |  | ||||||
| // peer. |  | ||||||
| func (w *connWriter) SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) { |  | ||||||
| 	w.writeTo(pkt, peer.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { |  | ||||||
| 	w.wLock.Lock() |  | ||||||
| 	if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { |  | ||||||
| 		log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) |  | ||||||
| 	} |  | ||||||
| 	w.wLock.Unlock() |  | ||||||
| } |  | ||||||
| @@ -1,109 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"log" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"sync" |  | ||||||
| 	"sync/atomic" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type ConnWriter struct { |  | ||||||
| 	wLock sync.Mutex // Lock around for sending on UDP Conn. |  | ||||||
|  |  | ||||||
| 	// Output. |  | ||||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) |  | ||||||
|  |  | ||||||
| 	// Shared state. |  | ||||||
| 	rt *atomic.Pointer[RoutingTable] |  | ||||||
|  |  | ||||||
| 	// For sending control packets. |  | ||||||
| 	cBuf1 []byte |  | ||||||
| 	cBuf2 []byte |  | ||||||
|  |  | ||||||
| 	// For sending data packets. |  | ||||||
| 	dBuf1 []byte |  | ||||||
| 	dBuf2 []byte |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewConnWriter( |  | ||||||
| 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), |  | ||||||
| 	rt *atomic.Pointer[RoutingTable], |  | ||||||
| ) *ConnWriter { |  | ||||||
| 	return &ConnWriter{ |  | ||||||
| 		writeToUDPAddrPort: writeToUDPAddrPort, |  | ||||||
| 		rt:                 rt, |  | ||||||
| 		cBuf1:              newBuf(), |  | ||||||
| 		cBuf2:              newBuf(), |  | ||||||
| 		dBuf1:              newBuf(), |  | ||||||
| 		dBuf2:              newBuf(), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Called by ConnReader to forward already encrypted bytes to another peer. |  | ||||||
| func (w *ConnWriter) Forward(ip byte, pkt []byte) { |  | ||||||
| 	peer := w.rt.Load().Peers[ip] |  | ||||||
| 	if !(peer.Up && peer.Direct) { |  | ||||||
| 		w.logf("Failed to forward to %d.", ip) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	w.writeTo(pkt, peer.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Called by IFReader to send data. Encryption will be applied, and packet will |  | ||||||
| // be relayed if appropriate. |  | ||||||
| func (w *ConnWriter) WriteData(ip byte, pkt []byte) { |  | ||||||
| 	rt := w.rt.Load() |  | ||||||
| 	peer := rt.Peers[ip] |  | ||||||
| 	if !peer.Up { |  | ||||||
| 		w.logf("Failed to send data to %d.", ip) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) |  | ||||||
|  |  | ||||||
| 	if peer.Direct { |  | ||||||
| 		w.writeTo(enc, peer.DirectAddr) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	relay, ok := rt.GetRelay() |  | ||||||
| 	if !ok { |  | ||||||
| 		w.logf("Failed to send data to %d. No relay.", ip) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) |  | ||||||
| 	w.writeTo(enc, relay.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Called by Supervisor to send control packets. |  | ||||||
| func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { |  | ||||||
| 	enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) |  | ||||||
|  |  | ||||||
| 	if peer.Direct { |  | ||||||
| 		w.writeTo(enc, peer.DirectAddr) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	rt := w.rt.Load() |  | ||||||
| 	relay, ok := rt.GetRelay() |  | ||||||
| 	if !ok { |  | ||||||
| 		w.logf("Failed to send control to %d. No relay.", peer.IP) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) |  | ||||||
| 	w.writeTo(enc, relay.DirectAddr) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { |  | ||||||
| 	w.wLock.Lock() |  | ||||||
| 	if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { |  | ||||||
| 		w.logf("Failed to write to UDP port: %v", err) |  | ||||||
| 	} |  | ||||||
| 	w.wLock.Unlock() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *ConnWriter) logf(s string, args ...any) { |  | ||||||
| 	log.Printf("[ConnWriter] "+s, args...) |  | ||||||
| } |  | ||||||
| @@ -1,145 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteData_direct(t *testing.T) { |  | ||||||
| 	p1, p2, _ := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.WriteData(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p2.Conn.Packets() |  | ||||||
| 	if len(packets) != 1 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteData_peerNotUp(t *testing.T) { |  | ||||||
| 	p1, p2, _ := NewPeersForTesting() |  | ||||||
| 	p1.RT.Load().Peers[2].Up = false |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.WriteData(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p2.Conn.Packets() |  | ||||||
| 	if len(packets) != 0 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteData_relay(t *testing.T) { |  | ||||||
| 	p1, _, p3 := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	p1.RT.Load().Peers[2].Direct = false |  | ||||||
| 	p1.RT.Load().RelayIP = 3 |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.WriteData(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p3.Conn.Packets() |  | ||||||
| 	if len(packets) != 1 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { |  | ||||||
| 	p1, _, p3 := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	p1.RT.Load().Peers[2].Direct = false |  | ||||||
| 	p1.RT.Load().Peers[3].Up = false |  | ||||||
| 	p1.RT.Load().RelayIP = 3 |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.WriteData(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p3.Conn.Packets() |  | ||||||
| 	if len(packets) != 0 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteControl_direct(t *testing.T) { |  | ||||||
| 	p1, p2, _ := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	orig := PacketProbe{TraceID: newTraceID()} |  | ||||||
|  |  | ||||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) |  | ||||||
|  |  | ||||||
| 	packets := p2.Conn.Packets() |  | ||||||
| 	if len(packets) != 1 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteControl_relay(t *testing.T) { |  | ||||||
| 	p1, _, p3 := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	p1.RT.Load().Peers[2].Direct = false |  | ||||||
| 	p1.RT.Load().RelayIP = 3 |  | ||||||
|  |  | ||||||
| 	orig := PacketProbe{TraceID: newTraceID()} |  | ||||||
|  |  | ||||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) |  | ||||||
|  |  | ||||||
| 	packets := p3.Conn.Packets() |  | ||||||
| 	if len(packets) != 1 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { |  | ||||||
| 	p1, _, p3 := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	p1.RT.Load().Peers[2].Direct = false |  | ||||||
| 	p1.RT.Load().Peers[3].Up = false |  | ||||||
| 	p1.RT.Load().RelayIP = 3 |  | ||||||
|  |  | ||||||
| 	orig := PacketProbe{TraceID: newTraceID()} |  | ||||||
|  |  | ||||||
| 	p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) |  | ||||||
|  |  | ||||||
| 	packets := p3.Conn.Packets() |  | ||||||
| 	if len(packets) != 0 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter__Forward(t *testing.T) { |  | ||||||
| 	p1, p2, _ := NewPeersForTesting() |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.Forward(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p2.Conn.Packets() |  | ||||||
| 	if len(packets) != 1 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter__Forward_notUp(t *testing.T) { |  | ||||||
| 	p1, p2, _ := NewPeersForTesting() |  | ||||||
| 	p1.RT.Load().Peers[2].Up = false |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.Forward(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p2.Conn.Packets() |  | ||||||
| 	if len(packets) != 0 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConnWriter__Forward_notDirect(t *testing.T) { |  | ||||||
| 	p1, p2, _ := NewPeersForTesting() |  | ||||||
| 	p1.RT.Load().Peers[2].Direct = false |  | ||||||
|  |  | ||||||
| 	in := RandPacket() |  | ||||||
| 	p1.ConnWriter.Forward(2, in) |  | ||||||
|  |  | ||||||
| 	packets := p2.Conn.Packets() |  | ||||||
| 	if len(packets) != 0 { |  | ||||||
| 		t.Fatal(packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,240 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type testUDPPacket struct { |  | ||||||
| 	Addr netip.AddrPort |  | ||||||
| 	Data []byte |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type testUDPAddrPortWriter struct { |  | ||||||
| 	written []testUDPPacket |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { |  | ||||||
| 	w.written = append(w.written, testUDPPacket{ |  | ||||||
| 		Addr: addr, |  | ||||||
| 		Data: bytes.Clone(b), |  | ||||||
| 	}) |  | ||||||
| 	return len(b), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w *testUDPAddrPortWriter) Written() []testUDPPacket { |  | ||||||
| 	out := w.written |  | ||||||
| 	w.written = []testUDPPacket{} |  | ||||||
| 	return out |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type testPacket string |  | ||||||
|  |  | ||||||
| func (p testPacket) Marshal(b []byte) []byte { |  | ||||||
| 	b = b[:len(p)] |  | ||||||
| 	copy(b, []byte(p)) |  | ||||||
| 	return b |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *RemotePeer) { |  | ||||||
| 	localKeys := generateKeys() |  | ||||||
| 	remoteKeys := generateKeys() |  | ||||||
|  |  | ||||||
| 	local = NewRemotePeer(2) |  | ||||||
| 	local.Up = true |  | ||||||
| 	local.Relay = false |  | ||||||
| 	local.PubSignKey = remoteKeys.PubSignKey |  | ||||||
| 	local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) |  | ||||||
| 	local.DataCipher = newDataCipher() |  | ||||||
| 	local.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) |  | ||||||
|  |  | ||||||
| 	remote = NewRemotePeer(1) |  | ||||||
| 	remote.Up = true |  | ||||||
| 	remote.Relay = false |  | ||||||
| 	remote.PubSignKey = localKeys.PubSignKey |  | ||||||
| 	remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) |  | ||||||
| 	remote.DataCipher = local.DataCipher |  | ||||||
| 	remote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) |  | ||||||
|  |  | ||||||
| 	rLocalKeys := generateKeys() |  | ||||||
| 	rRemoteKeys := generateKeys() |  | ||||||
|  |  | ||||||
| 	relayLocal = NewRemotePeer(3) |  | ||||||
| 	relayLocal.Up = true |  | ||||||
| 	relayLocal.Relay = true |  | ||||||
| 	relayLocal.Direct = true |  | ||||||
| 	relayLocal.PubSignKey = rRemoteKeys.PubSignKey |  | ||||||
| 	relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) |  | ||||||
| 	relayLocal.DataCipher = newDataCipher() |  | ||||||
| 	relayLocal.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) |  | ||||||
|  |  | ||||||
| 	relayRemote = NewRemotePeer(1) |  | ||||||
| 	relayRemote.Up = true |  | ||||||
| 	relayRemote.Relay = false |  | ||||||
| 	relayRemote.Direct = true |  | ||||||
| 	relayRemote.PubSignKey = rLocalKeys.PubSignKey |  | ||||||
| 	relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) |  | ||||||
| 	relayRemote.DataCipher = relayLocal.DataCipher |  | ||||||
| 	relayRemote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) |  | ||||||
|  |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| // Testing if we can send a control packet directly to the remote route. |  | ||||||
| func TestConnWriter_SendControlPacket_direct(t *testing.T) { |  | ||||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() |  | ||||||
| 	route.Direct = true |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := testPacket("hello world!") |  | ||||||
|  |  | ||||||
| 	w.SendControlPacket(in, route) |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 1 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if out[0].Addr != route.DirectAddr { |  | ||||||
| 		t.Fatal(out[0]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatal(ok) |  | ||||||
| 	} |  | ||||||
| 	if string(dec) != string(in) { |  | ||||||
| 		t.Fatal(dec) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing if we can relay a packet via an intermediary. |  | ||||||
| func TestConnWriter_RelayControlPacket_relay(t *testing.T) { |  | ||||||
| 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := testPacket("hello world!") |  | ||||||
|  |  | ||||||
| 	w.RelayControlPacket(in, route, relay) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 1 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if out[0].Addr != relay.DirectAddr { |  | ||||||
| 		t.Fatal(out[0]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatal(ok) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatal(ok) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if string(dec2) != string(in) { |  | ||||||
| 		t.Fatal(dec2) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can send a data packet directly to a remote route. |  | ||||||
| func TestConnWriter_SendDataPacket_direct(t *testing.T) { |  | ||||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() |  | ||||||
| 	route.Direct = true |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
|  |  | ||||||
| 	in := []byte("hello world!") |  | ||||||
| 	w.SendDataPacket(in, route) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 1 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if out[0].Addr != route.DirectAddr { |  | ||||||
| 		t.Fatal(out[0]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatal(ok) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !bytes.Equal(dec, in) { |  | ||||||
| 		t.Fatal(dec) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can relay a data packet via a relay. |  | ||||||
| func TestConnWriter_RelayDataPacket_relay(t *testing.T) { |  | ||||||
| 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := []byte("Hello world!") |  | ||||||
|  |  | ||||||
| 	w.RelayDataPacket(in, route, relay) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 1 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if out[0].Addr != relay.DirectAddr { |  | ||||||
| 		t.Fatal(out[0]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatal(ok) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Fatal(ok) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !bytes.Equal(dec2, in) { |  | ||||||
| 		t.Fatal(dec2) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can send an already encrypted packet. |  | ||||||
| func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { |  | ||||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := []byte("Hello world!") |  | ||||||
|  |  | ||||||
| 	w.SendEncryptedDataPacket(in, route) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 1 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if out[0].Addr != route.DirectAddr { |  | ||||||
| 		t.Fatal(out[0]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !bytes.Equal(out[0].Data, in) { |  | ||||||
| 		t.Fatal(out[0]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -17,25 +17,25 @@ type controlMsg[T any] struct { | |||||||
| func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | ||||||
| 	switch buf[0] { | 	switch buf[0] { | ||||||
|  |  | ||||||
| 	case PacketTypeSyn: | 	case packetTypeSyn: | ||||||
| 		packet, err := ParsePacketSyn(buf) | 		packet, err := parsePacketSyn(buf) | ||||||
| 		return controlMsg[PacketSyn]{ | 		return controlMsg[packetSyn]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case PacketTypeAck: | 	case packetTypeAck: | ||||||
| 		packet, err := ParsePacketAck(buf) | 		packet, err := parsePacketAck(buf) | ||||||
| 		return controlMsg[PacketAck]{ | 		return controlMsg[packetAck]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
| 		}, err | 		}, err | ||||||
|  |  | ||||||
| 	case PacketTypeProbe: | 	case packetTypeProbe: | ||||||
| 		packet, err := ParsePacketProbe(buf) | 		packet, err := parsePacketProbe(buf) | ||||||
| 		return controlMsg[PacketProbe]{ | 		return controlMsg[packetProbe]{ | ||||||
| 			SrcIP:   srcIP, | 			SrcIP:   srcIP, | ||||||
| 			SrcAddr: srcAddr, | 			SrcAddr: srcAddr, | ||||||
| 			Packet:  packet, | 			Packet:  packet, | ||||||
|   | |||||||
| @@ -36,7 +36,7 @@ func generateKeys() cryptoKeys { | |||||||
| // Peer must have a ControlCipher. | // Peer must have a ControlCipher. | ||||||
| func encryptControlPacket( | func encryptControlPacket( | ||||||
| 	localIP byte, | 	localIP byte, | ||||||
| 	peer *RemotePeer, | 	peer *remotePeer, | ||||||
| 	pkt Marshaller, | 	pkt Marshaller, | ||||||
| 	tmp []byte, | 	tmp []byte, | ||||||
| 	out []byte, | 	out []byte, | ||||||
| @@ -55,7 +55,7 @@ func encryptControlPacket( | |||||||
| // | // | ||||||
| // This function also drops packets with duplicate sequence numbers. | // This function also drops packets with duplicate sequence numbers. | ||||||
| func decryptControlPacket( | func decryptControlPacket( | ||||||
| 	peer *RemotePeer, | 	peer *remotePeer, | ||||||
| 	fromAddr netip.AddrPort, | 	fromAddr netip.AddrPort, | ||||||
| 	h header, | 	h header, | ||||||
| 	encrypted []byte, | 	encrypted []byte, | ||||||
| @@ -83,7 +83,7 @@ func decryptControlPacket( | |||||||
| func encryptDataPacket( | func encryptDataPacket( | ||||||
| 	localIP byte, | 	localIP byte, | ||||||
| 	destIP byte, | 	destIP byte, | ||||||
| 	peer *RemotePeer, | 	peer *remotePeer, | ||||||
| 	data []byte, | 	data []byte, | ||||||
| 	out []byte, | 	out []byte, | ||||||
| ) []byte { | ) []byte { | ||||||
| @@ -98,7 +98,7 @@ func encryptDataPacket( | |||||||
|  |  | ||||||
| // Decrypts and de-dups incoming data packets. | // Decrypts and de-dups incoming data packets. | ||||||
| func decryptDataPacket( | func decryptDataPacket( | ||||||
| 	peer *RemotePeer, | 	peer *remotePeer, | ||||||
| 	h header, | 	h header, | ||||||
| 	encrypted []byte, | 	encrypted []byte, | ||||||
| 	out []byte, | 	out []byte, | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func newRoutePairForTesting() (*RemotePeer, *RemotePeer) { | func newRoutePairForTesting() (*remotePeer, *remotePeer) { | ||||||
| 	keys1 := generateKeys() | 	keys1 := generateKeys() | ||||||
| 	keys2 := generateKeys() | 	keys2 := generateKeys() | ||||||
|  |  | ||||||
| @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { | |||||||
| 		out    = make([]byte, bufferSize) | 		out    = make([]byte, bufferSize) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	in := PacketSyn{ | 	in := packetSyn{ | ||||||
| 		TraceID:   newTraceID(), | 		TraceID:   newTraceID(), | ||||||
| 		SharedKey: r1.DataCipher.Key(), | 		SharedKey: r1.DataCipher.Key(), | ||||||
| 		Direct:    true, | 		Direct:    true, | ||||||
| @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	msg, ok := iMsg.(controlMsg[PacketSyn]) | 	msg, ok := iMsg.(controlMsg[packetSyn]) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		t.Fatal(ok) | 		t.Fatal(ok) | ||||||
| 	} | 	} | ||||||
| @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { | |||||||
| 		out    = make([]byte, bufferSize) | 		out    = make([]byte, bufferSize) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	in := PacketSyn{ | 	in := packetSyn{ | ||||||
| 		TraceID:   newTraceID(), | 		TraceID:   newTraceID(), | ||||||
| 		SharedKey: r1.DataCipher.Key(), | 		SharedKey: r1.DataCipher.Key(), | ||||||
| 		Direct:    true, | 		Direct:    true, | ||||||
| @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { | |||||||
| 		out    = make([]byte, bufferSize) | 		out    = make([]byte, bufferSize) | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	in := PacketSyn{ | 	in := packetSyn{ | ||||||
| 		TraceID:   newTraceID(), | 		TraceID:   newTraceID(), | ||||||
| 		SharedKey: r1.DataCipher.Key(), | 		SharedKey: r1.DataCipher.Key(), | ||||||
| 		Direct:    true, | 		Direct:    true, | ||||||
| @@ -109,6 +109,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /* | ||||||
| 	func TestDecryptControlPacket_invalidPacket(t *testing.T) { | 	func TestDecryptControlPacket_invalidPacket(t *testing.T) { | ||||||
| 		var ( | 		var ( | ||||||
| 			r1, r2 = newRoutePairForTesting() | 			r1, r2 = newRoutePairForTesting() | ||||||
| @@ -126,7 +127,7 @@ func TestDecryptControlPacket_invalidPacket(t *testing.T) { | |||||||
| 			t.Fatal(err) | 			t.Fatal(err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | */ | ||||||
| func TestDecryptDataPacket(t *testing.T) { | func TestDecryptDataPacket(t *testing.T) { | ||||||
| 	var ( | 	var ( | ||||||
| 		r1, r2 = newRoutePairForTesting() | 		r1, r2 = newRoutePairForTesting() | ||||||
|   | |||||||
| @@ -16,10 +16,16 @@ type hubPoller struct { | |||||||
| 	versions         [256]int64 | 	versions         [256]int64 | ||||||
| 	localIP          byte | 	localIP          byte | ||||||
| 	netName          string | 	netName          string | ||||||
| 	super    controlMsgHandler | 	handleControlMsg func(fromIP byte, msg any) | ||||||
| } | } | ||||||
|  |  | ||||||
| func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { | func newHubPoller( | ||||||
|  | 	localIP byte, | ||||||
|  | 	netName, | ||||||
|  | 	hubURL, | ||||||
|  | 	apiKey string, | ||||||
|  | 	handleControlMsg func(byte, any), | ||||||
|  | ) (*hubPoller, error) { | ||||||
| 	u, err := url.Parse(hubURL) | 	u, err := url.Parse(hubURL) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -40,7 +46,7 @@ func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsg | |||||||
| 		req:              req, | 		req:              req, | ||||||
| 		localIP:          localIP, | 		localIP:          localIP, | ||||||
| 		netName:          netName, | 		netName:          netName, | ||||||
| 		super:   super, | 		handleControlMsg: handleControlMsg, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -90,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | |||||||
| 	for i, peer := range state.Peers { | 	for i, peer := range state.Peers { | ||||||
| 		if i != int(hp.localIP) { | 		if i != int(hp.localIP) { | ||||||
| 			if peer == nil || peer.Version != hp.versions[i] { | 			if peer == nil || peer.Version != hp.versions[i] { | ||||||
| 				hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | 				hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) | ||||||
| 				if peer != nil { | 				if peer != nil { | ||||||
| 					hp.versions[i] = peer.Version | 					hp.versions[i] = peer.Version | ||||||
| 				} | 				} | ||||||
|   | |||||||
							
								
								
									
										100
									
								
								peer/ifreader.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								peer/ifreader.go
									
									
									
									
									
								
							| @@ -1,100 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"io" |  | ||||||
| 	"log" |  | ||||||
| 	"sync/atomic" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type ifReader struct { |  | ||||||
| 	iface  io.Reader |  | ||||||
| 	peers  [256]*atomic.Pointer[RemotePeer] |  | ||||||
| 	relay  *atomic.Pointer[RemotePeer] |  | ||||||
| 	sender dataPacketSender |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func newIFReader( |  | ||||||
| 	iface io.Reader, |  | ||||||
| 	peers [256]*atomic.Pointer[RemotePeer], |  | ||||||
| 	relay *atomic.Pointer[RemotePeer], |  | ||||||
| 	sender dataPacketSender, |  | ||||||
| ) *ifReader { |  | ||||||
| 	return &ifReader{ |  | ||||||
| 		iface:  iface, |  | ||||||
| 		peers:  peers, |  | ||||||
| 		relay:  relay, |  | ||||||
| 		sender: sender, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r *ifReader) Run() { |  | ||||||
| 	var ( |  | ||||||
| 		packet   = make([]byte, bufferSize) |  | ||||||
| 		remoteIP byte |  | ||||||
| 		ok       bool |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		packet = r.readNextPacket(packet) |  | ||||||
| 		if remoteIP, ok = r.parsePacket(packet); ok { |  | ||||||
| 			r.sendPacket(packet, remoteIP) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { |  | ||||||
| 	peer := r.peers[remoteIP].Load() |  | ||||||
| 	if !peer.Up { |  | ||||||
| 		log.Printf("Peer not connected: %d", remoteIP) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Direct path => early return. |  | ||||||
| 	if peer.Direct { |  | ||||||
| 		r.sender.SendDataPacket(pkt, peer) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if relay := r.relay.Load(); relay != nil && relay.Up { |  | ||||||
| 		r.sender.RelayDataPacket(pkt, peer, relay) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Get next packet, returning packet, and destination ip. |  | ||||||
| func (r *ifReader) readNextPacket(buf []byte) []byte { |  | ||||||
| 	n, err := r.iface.Read(buf[:cap(buf)]) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Fatalf("Failed to read from interface: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return buf[:n] |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r *ifReader) parsePacket(buf []byte) (byte, bool) { |  | ||||||
| 	n := len(buf) |  | ||||||
| 	if n == 0 { |  | ||||||
| 		return 0, false |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	version := buf[0] >> 4 |  | ||||||
|  |  | ||||||
| 	switch version { |  | ||||||
| 	case 4: |  | ||||||
| 		if n < 20 { |  | ||||||
| 			log.Printf("Short IPv4 packet: %d", len(buf)) |  | ||||||
| 			return 0, false |  | ||||||
| 		} |  | ||||||
| 		return buf[19], true |  | ||||||
|  |  | ||||||
| 	case 6: |  | ||||||
| 		if len(buf) < 40 { |  | ||||||
| 			log.Printf("Short IPv6 packet: %d", len(buf)) |  | ||||||
| 			return 0, false |  | ||||||
| 		} |  | ||||||
| 		return buf[39], true |  | ||||||
|  |  | ||||||
| 	default: |  | ||||||
| 		log.Printf("Invalid IP packet version: %v", version) |  | ||||||
| 		return 0, false |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -3,22 +3,24 @@ package peer | |||||||
| import ( | import ( | ||||||
| 	"io" | 	"io" | ||||||
| 	"log" | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type IFReader struct { | type IFReader struct { | ||||||
| 	iface              io.Reader | 	iface              io.Reader | ||||||
| 	connWriter interface { | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||||
| 		WriteData(ip byte, pkt []byte) | 	rt                 *atomic.Pointer[routingTable] | ||||||
| 	} | 	buf1               []byte | ||||||
|  | 	buf2               []byte | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewIFReader( | func NewIFReader( | ||||||
| 	iface io.Reader, | 	iface io.Reader, | ||||||
| 	connWriter interface { | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||||
| 		WriteData(ip byte, pkt []byte) | 	rt *atomic.Pointer[routingTable], | ||||||
| 	}, |  | ||||||
| ) *IFReader { | ) *IFReader { | ||||||
| 	return &IFReader{iface, connWriter} | 	return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *IFReader) Run() { | func (r *IFReader) Run() { | ||||||
| @@ -30,9 +32,32 @@ func (r *IFReader) Run() { | |||||||
|  |  | ||||||
| func (r *IFReader) handleNextPacket(packet []byte) { | func (r *IFReader) handleNextPacket(packet []byte) { | ||||||
| 	packet = r.readNextPacket(packet) | 	packet = r.readNextPacket(packet) | ||||||
| 	if remoteIP, ok := r.parsePacket(packet); ok { | 	remoteIP, ok := r.parsePacket(packet) | ||||||
| 		r.connWriter.WriteData(remoteIP, packet) | 	if !ok { | ||||||
|  | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	rt := r.rt.Load() | ||||||
|  | 	peer := rt.Peers[remoteIP] | ||||||
|  | 	if !peer.Up { | ||||||
|  | 		r.logf("Peer %d not up.", peer.IP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1) | ||||||
|  | 	if peer.Direct { | ||||||
|  | 		r.writeToUDPAddrPort(enc, peer.DirectAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	relay, ok := rt.GetRelay() | ||||||
|  | 	if !ok { | ||||||
|  | 		r.logf("Relay not available for peer %d.", peer.IP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2) | ||||||
|  | 	r.writeToUDPAddrPort(enc, relay.DirectAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *IFReader) readNextPacket(buf []byte) []byte { | func (r *IFReader) readNextPacket(buf []byte) []byte { | ||||||
|   | |||||||
| @@ -1,9 +1,6 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestIFReader_IPv4(t *testing.T) { | func TestIFReader_IPv4(t *testing.T) { | ||||||
| 	p1, p2, _ := NewPeersForTesting() | 	p1, p2, _ := NewPeersForTesting() | ||||||
|  |  | ||||||
| @@ -81,3 +78,4 @@ func TestIFReader_parsePacket_shortIPv6(t *testing.T) { | |||||||
| 		t.Fatal(ip, ok) | 		t.Fatal(ip, ok) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -1,232 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"reflect" |  | ||||||
| 	"sync/atomic" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // Test that we parse IPv4 packets correctly. |  | ||||||
| func TestIFReader_parsePacket_ipv4(t *testing.T) { |  | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 1234) |  | ||||||
| 	pkt[0] = 4 << 4 |  | ||||||
| 	pkt[19] = 128 |  | ||||||
|  |  | ||||||
| 	if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { |  | ||||||
| 		t.Fatal(ip, ok) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Test that we parse IPv6 packets correctly. |  | ||||||
| func TestIFReader_parsePacket_ipv6(t *testing.T) { |  | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 1234) |  | ||||||
| 	pkt[0] = 6 << 4 |  | ||||||
| 	pkt[39] = 42 |  | ||||||
|  |  | ||||||
| 	if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { |  | ||||||
| 		t.Fatal(ip, ok) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| /* |  | ||||||
| // Test that empty packets work as expected. |  | ||||||
| func TestIFReader_parsePacket_emptyPacket(t *testing.T) { |  | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 0) |  | ||||||
| 	if ip, ok := r.parsePacket(pkt); ok { |  | ||||||
| 		t.Fatal(ip, ok) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Test that invalid IP versions fail. |  | ||||||
| func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { |  | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
|  |  | ||||||
| 	for i := byte(1); i < 16; i++ { |  | ||||||
| 		if i == 4 || i == 6 { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		pkt := make([]byte, 1234) |  | ||||||
| 		pkt[0] = i << 4 |  | ||||||
|  |  | ||||||
| 		if ip, ok := r.parsePacket(pkt); ok { |  | ||||||
| 			t.Fatal(i, ip, ok) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Test that short IPv4 packets fail. |  | ||||||
| func TestIFReader_parsePacket_shortIPv4(t *testing.T) { |  | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 19) |  | ||||||
| 	pkt[0] = 4 << 4 |  | ||||||
|  |  | ||||||
| 	if ip, ok := r.parsePacket(pkt); ok { |  | ||||||
| 		t.Fatal(ip, ok) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Test that short IPv6 packets fail. |  | ||||||
| func TestIFReader_parsePacket_shortIPv6(t *testing.T) { |  | ||||||
| 	r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
|  |  | ||||||
| 	pkt := make([]byte, 39) |  | ||||||
| 	pkt[0] = 6 << 4 |  | ||||||
|  |  | ||||||
| 	if ip, ok := r.parsePacket(pkt); ok { |  | ||||||
| 		t.Fatal(ip, ok) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Test that we can read a packet. |  | ||||||
| func TestIFReader_readNextpacket(t *testing.T) { |  | ||||||
| 	in, out := net.Pipe() |  | ||||||
| 	r := newIFReader(out, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) |  | ||||||
| 	defer in.Close() |  | ||||||
| 	defer out.Close() |  | ||||||
|  |  | ||||||
| 	go in.Write([]byte("hello world!")) |  | ||||||
|  |  | ||||||
| 	pkt := r.readNextPacket(make([]byte, bufferSize)) |  | ||||||
| 	if !bytes.Equal(pkt, []byte("hello world!")) { |  | ||||||
| 		t.Fatalf("%s", pkt) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| */ |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type sentPacket struct { |  | ||||||
| 	Relayed bool |  | ||||||
| 	Packet  []byte |  | ||||||
| 	Route   RemotePeer |  | ||||||
| 	Relay   RemotePeer |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type sendPacketTestHarness struct { |  | ||||||
| 	Packets []sentPacket |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *RemotePeer) { |  | ||||||
| 	h.Packets = append(h.Packets, sentPacket{ |  | ||||||
| 		Packet: bytes.Clone(pkt), |  | ||||||
| 		Route:  *route, |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *RemotePeer) { |  | ||||||
| 	h.Packets = append(h.Packets, sentPacket{ |  | ||||||
| 		Relayed: true, |  | ||||||
| 		Packet:  bytes.Clone(pkt), |  | ||||||
| 		Route:   *route, |  | ||||||
| 		Relay:   *relay, |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { |  | ||||||
| 	h := &sendPacketTestHarness{} |  | ||||||
|  |  | ||||||
| 	routes := [256]*atomic.Pointer[RemotePeer]{} |  | ||||||
| 	for i := range routes { |  | ||||||
| 		routes[i] = &atomic.Pointer[RemotePeer]{} |  | ||||||
| 		routes[i].Store(&RemotePeer{}) |  | ||||||
| 	} |  | ||||||
| 	relay := &atomic.Pointer[RemotePeer]{} |  | ||||||
| 	r := newIFReader(nil, routes, relay, h) |  | ||||||
| 	return r, h |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can send a packet directly. |  | ||||||
| func TestIFReader_sendPacket_direct(t *testing.T) { |  | ||||||
| 	r, h := newIFReaderForSendPacketTesting() |  | ||||||
|  |  | ||||||
| 	route := r.peers[2].Load() |  | ||||||
| 	route.Up = true |  | ||||||
| 	route.Direct = true |  | ||||||
|  |  | ||||||
| 	in := []byte("hello world") |  | ||||||
|  |  | ||||||
| 	r.sendPacket(in, 2) |  | ||||||
| 	if len(h.Packets) != 1 { |  | ||||||
| 		t.Fatal(h.Packets) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	expected := sentPacket{ |  | ||||||
| 		Relayed: false, |  | ||||||
| 		Packet:  in, |  | ||||||
| 		Route:   *route, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !reflect.DeepEqual(h.Packets[0], expected) { |  | ||||||
| 		t.Fatal(h.Packets[0]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we don't send a packet if route isn't up. |  | ||||||
| func TestIFReader_sendPacket_directNotUp(t *testing.T) { |  | ||||||
| 	r, h := newIFReaderForSendPacketTesting() |  | ||||||
|  |  | ||||||
| 	route := r.peers[2].Load() |  | ||||||
| 	route.Direct = true |  | ||||||
|  |  | ||||||
| 	in := []byte("hello world") |  | ||||||
|  |  | ||||||
| 	r.sendPacket(in, 2) |  | ||||||
| 	if len(h.Packets) != 0 { |  | ||||||
| 		t.Fatal(h.Packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can send a packet via a relay. |  | ||||||
| func TestIFReader_sendPacket_relayed(t *testing.T) { |  | ||||||
| 	r, h := newIFReaderForSendPacketTesting() |  | ||||||
|  |  | ||||||
| 	route := r.peers[2].Load() |  | ||||||
| 	route.Up = true |  | ||||||
| 	route.Direct = false |  | ||||||
|  |  | ||||||
| 	relay := r.peers[3].Load() |  | ||||||
| 	r.relay.Store(relay) |  | ||||||
| 	relay.Up = true |  | ||||||
| 	relay.Direct = true |  | ||||||
|  |  | ||||||
| 	in := []byte("hello world") |  | ||||||
|  |  | ||||||
| 	r.sendPacket(in, 2) |  | ||||||
| 	if len(h.Packets) != 1 { |  | ||||||
| 		t.Fatal(h.Packets) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	expected := sentPacket{ |  | ||||||
| 		Relayed: true, |  | ||||||
| 		Packet:  in, |  | ||||||
| 		Route:   *route, |  | ||||||
| 		Relay:   *relay, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !reflect.DeepEqual(h.Packets[0], expected) { |  | ||||||
| 		t.Fatal(h.Packets[0]) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we don't try to send on a nil relay IP. |  | ||||||
| func TestIFReader_sendPacket_nilRealy(t *testing.T) { |  | ||||||
| 	r, h := newIFReaderForSendPacketTesting() |  | ||||||
|  |  | ||||||
| 	route := r.peers[2].Load() |  | ||||||
| 	route.Up = true |  | ||||||
| 	route.Direct = false |  | ||||||
|  |  | ||||||
| 	in := []byte("hello world") |  | ||||||
|  |  | ||||||
| 	r.sendPacket(in, 2) |  | ||||||
| 	if len(h.Packets) != 0 { |  | ||||||
| 		t.Fatal(h.Packets) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
							
								
								
									
										177
									
								
								peer/interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								peer/interface.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,177 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"net" | ||||||
|  | 	"os" | ||||||
|  | 	"syscall" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/sys/unix" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Get next packet, returning packet, ip, and possible error. | ||||||
|  | func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { | ||||||
|  | 	var ( | ||||||
|  | 		version byte | ||||||
|  | 		ip      byte | ||||||
|  | 	) | ||||||
|  | 	for { | ||||||
|  | 		n, err := iface.Read(buf[:cap(buf)]) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, ip, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		buf = buf[:n] | ||||||
|  | 		version = buf[0] >> 4 | ||||||
|  |  | ||||||
|  | 		switch version { | ||||||
|  | 		case 4: | ||||||
|  | 			if n < 20 { | ||||||
|  | 				log.Printf("Short IPv4 packet: %d", len(buf)) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			ip = buf[19] | ||||||
|  |  | ||||||
|  | 		case 6: | ||||||
|  | 			if len(buf) < 40 { | ||||||
|  | 				log.Printf("Short IPv6 packet: %d", len(buf)) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			ip = buf[39] | ||||||
|  |  | ||||||
|  | 		default: | ||||||
|  | 			log.Printf("Invalid IP packet version: %v", version) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return buf, ip, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { | ||||||
|  | 	if len(network) != 4 { | ||||||
|  | 		return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) | ||||||
|  | 	} | ||||||
|  | 	ip := net.IPv4(network[0], network[1], network[2], localIP) | ||||||
|  |  | ||||||
|  | 	////////////////////////// | ||||||
|  | 	// Create TUN Interface // | ||||||
|  | 	////////////////////////// | ||||||
|  |  | ||||||
|  | 	tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to open TUN device: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// New interface request. | ||||||
|  | 	req, err := unix.NewIfreq(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Flags: | ||||||
|  | 	// | ||||||
|  | 	// IFF_NO_PI => don't add packet info data to packets sent to the interface. | ||||||
|  | 	// IFF_TUN   => create a TUN device handling IP packets. | ||||||
|  | 	req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) | ||||||
|  |  | ||||||
|  | 	err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set TUN device settings: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Name may not be exactly the same? | ||||||
|  | 	name = req.Name() | ||||||
|  |  | ||||||
|  | 	///////////// | ||||||
|  | 	// Set MTU // | ||||||
|  | 	///////////// | ||||||
|  |  | ||||||
|  | 	// We need a socket file descriptor to set other options for some reason. | ||||||
|  | 	sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to open socket: %w", err) | ||||||
|  | 	} | ||||||
|  | 	defer unix.Close(sockFD) | ||||||
|  |  | ||||||
|  | 	req, err = unix.NewIfreq(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create MTU interface request: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.SetUint32(if_mtu) | ||||||
|  | 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface MTU: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	////////////////////// | ||||||
|  | 	// Set Queue Length // | ||||||
|  | 	////////////////////// | ||||||
|  |  | ||||||
|  | 	req, err = unix.NewIfreq(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create IP interface request: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.SetUint16(if_queue_len) | ||||||
|  | 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface queue length: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	///////////////////// | ||||||
|  | 	// Set IP and Mask // | ||||||
|  | 	///////////////////// | ||||||
|  |  | ||||||
|  | 	req, err = unix.NewIfreq(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create IP interface request: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := req.SetInet4Addr(ip.To4()); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface request IP: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface IP: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// SET MASK - must happen after setting address. | ||||||
|  | 	req, err = unix.NewIfreq(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create mask interface request: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface request mask: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface mask: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	//////////////////////// | ||||||
|  | 	// Bring Interface Up // | ||||||
|  | 	//////////////////////// | ||||||
|  |  | ||||||
|  | 	req, err = unix.NewIfreq(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create up interface request: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get current flags. | ||||||
|  | 	if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to get interface flags: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING | ||||||
|  |  | ||||||
|  | 	// Set UP flag / broadcast flags. | ||||||
|  | 	req.SetUint16(flags) | ||||||
|  | 	if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to set interface up: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return os.NewFile(uintptr(tunFD), "tun"), nil | ||||||
|  | } | ||||||
| @@ -31,17 +31,17 @@ type Marshaller interface { | |||||||
| } | } | ||||||
|  |  | ||||||
| type dataPacketSender interface { | type dataPacketSender interface { | ||||||
| 	SendDataPacket(pkt []byte, peer *RemotePeer) | 	SendDataPacket(pkt []byte, peer *remotePeer) | ||||||
| 	RelayDataPacket(pkt []byte, peer, relay *RemotePeer) | 	RelayDataPacket(pkt []byte, peer, relay *remotePeer) | ||||||
| } | } | ||||||
|  |  | ||||||
| type controlPacketSender interface { | type controlPacketSender interface { | ||||||
| 	SendControlPacket(pkt Marshaller, peer *RemotePeer) | 	SendControlPacket(pkt Marshaller, peer *remotePeer) | ||||||
| 	RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) | 	RelayControlPacket(pkt Marshaller, peer, relay *remotePeer) | ||||||
| } | } | ||||||
|  |  | ||||||
| type encryptedPacketSender interface { | type encryptedPacketSender interface { | ||||||
| 	SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) | 	SendEncryptedDataPacket(pkt []byte, peer *remotePeer) | ||||||
| } | } | ||||||
|  |  | ||||||
| type controlMsgHandler interface { | type controlMsgHandler interface { | ||||||
|   | |||||||
							
								
								
									
										23
									
								
								peer/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								peer/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"flag" | ||||||
|  | 	"os" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Main() { | ||||||
|  | 	conf := Config{} | ||||||
|  |  | ||||||
|  | 	flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") | ||||||
|  | 	flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") | ||||||
|  | 	flag.StringVar(&conf.APIKey, "api-key", "", "[REQUIRED] The node's API key.") | ||||||
|  | 	flag.Parse() | ||||||
|  |  | ||||||
|  | 	if conf.NetName == "" || conf.HubAddress == "" || conf.APIKey == "" { | ||||||
|  | 		flag.Usage() | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	peer := New(conf) | ||||||
|  | 	peer.Run() | ||||||
|  | } | ||||||
| @@ -8,7 +8,7 @@ import ( | |||||||
| type mcReader struct { | type mcReader struct { | ||||||
| 	conn  udpReader | 	conn  udpReader | ||||||
| 	super controlMsgHandler | 	super controlMsgHandler | ||||||
| 	peers [256]*atomic.Pointer[RemotePeer] | 	peers [256]*atomic.Pointer[remotePeer] | ||||||
|  |  | ||||||
| 	incoming []byte | 	incoming []byte | ||||||
| 	buf      []byte | 	buf      []byte | ||||||
| @@ -17,7 +17,7 @@ type mcReader struct { | |||||||
| func newMCReader( | func newMCReader( | ||||||
| 	conn udpReader, | 	conn udpReader, | ||||||
| 	super controlMsgHandler, | 	super controlMsgHandler, | ||||||
| 	peers [256]*atomic.Pointer[RemotePeer], | 	peers [256]*atomic.Pointer[remotePeer], | ||||||
| ) *mcReader { | ) *mcReader { | ||||||
| 	return &mcReader{conn, super, peers, newBuf(), newBuf()} | 	return &mcReader{conn, super, peers, newBuf(), newBuf()} | ||||||
| } | } | ||||||
| @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ | 	r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{ | ||||||
| 		SrcIP:   h.SourceIP, | 		SrcIP:   h.SourceIP, | ||||||
| 		SrcAddr: remoteAddr, | 		SrcAddr: remoteAddr, | ||||||
| 	}) | 	}) | ||||||
|   | |||||||
| @@ -1,13 +1,6 @@ | |||||||
| package peer | package peer | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"bytes" |  | ||||||
| 	"net" |  | ||||||
| 	"net/netip" |  | ||||||
| 	"sync/atomic" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type mcMockConn struct { | type mcMockConn struct { | ||||||
| 	packets chan []byte | 	packets chan []byte | ||||||
| } | } | ||||||
| @@ -136,3 +129,4 @@ func TestMCReader_badSignature(t *testing.T) { | |||||||
| 		t.Fatal(super.Messages) | 		t.Fatal(super.Messages) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -5,41 +5,34 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	PacketTypeSyn = iota + 1 | 	packetTypeSyn           = 1 | ||||||
| 	PacketTypeSynAck | 	packetTypeAck           = 3 | ||||||
| 	PacketTypeAck | 	packetTypeProbe         = 4 | ||||||
| 	PacketTypeProbe | 	packetTypeAddrDiscovery = 5 | ||||||
| 	PacketTypeAddrDiscovery |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type PacketSyn struct { | type packetSyn struct { | ||||||
| 	TraceID       uint64   // TraceID to match response w/ request. | 	TraceID       uint64   // TraceID to match response w/ request. | ||||||
| 	//SentAt        int64    // Unixmilli. |  | ||||||
| 	//SharedKeyType byte     // Currently only 1 is supported for AES. |  | ||||||
| 	SharedKey     [32]byte // Our shared key. | 	SharedKey     [32]byte // Our shared key. | ||||||
| 	Direct        bool | 	Direct        bool | ||||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p PacketSyn) Marshal(buf []byte) []byte { | func (p packetSyn) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(PacketTypeSyn). | 		Byte(packetTypeSyn). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		//Int64(p.SentAt). |  | ||||||
| 		//Byte(p.SharedKeyType). |  | ||||||
| 		SharedKey(p.SharedKey). | 		SharedKey(p.SharedKey). | ||||||
| 		Bool(p.Direct). | 		Bool(p.Direct). | ||||||
| 		AddrPort8(p.PossibleAddrs). | 		AddrPort8(p.PossibleAddrs). | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { | func parsePacketSyn(buf []byte) (p packetSyn, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		//Int64(&p.SentAt). |  | ||||||
| 		//Byte(&p.SharedKeyType). |  | ||||||
| 		SharedKey(&p.SharedKey). | 		SharedKey(&p.SharedKey). | ||||||
| 		Bool(&p.Direct). | 		Bool(&p.Direct). | ||||||
| 		AddrPort8(&p.PossibleAddrs). | 		AddrPort8(&p.PossibleAddrs). | ||||||
| @@ -49,22 +42,22 @@ func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type PacketAck struct { | type packetAck struct { | ||||||
| 	TraceID       uint64 | 	TraceID       uint64 | ||||||
| 	ToAddr        netip.AddrPort | 	ToAddr        netip.AddrPort | ||||||
| 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p PacketAck) Marshal(buf []byte) []byte { | func (p packetAck) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(PacketTypeAck). | 		Byte(packetTypeAck). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		AddrPort(p.ToAddr). | 		AddrPort(p.ToAddr). | ||||||
| 		AddrPort8(p.PossibleAddrs). | 		AddrPort8(p.PossibleAddrs). | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func ParsePacketAck(buf []byte) (p PacketAck, err error) { | func parsePacketAck(buf []byte) (p packetAck, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		AddrPort(&p.ToAddr). | 		AddrPort(&p.ToAddr). | ||||||
| @@ -77,18 +70,18 @@ func ParsePacketAck(buf []byte) (p PacketAck, err error) { | |||||||
|  |  | ||||||
| // A probeReqPacket is sent from a client to a server to determine if direct | // A probeReqPacket is sent from a client to a server to determine if direct | ||||||
| // UDP communication can be used. | // UDP communication can be used. | ||||||
| type PacketProbe struct { | type packetProbe struct { | ||||||
| 	TraceID uint64 | 	TraceID uint64 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p PacketProbe) Marshal(buf []byte) []byte { | func (p packetProbe) Marshal(buf []byte) []byte { | ||||||
| 	return newBinWriter(buf). | 	return newBinWriter(buf). | ||||||
| 		Byte(PacketTypeProbe). | 		Byte(packetTypeProbe). | ||||||
| 		Uint64(p.TraceID). | 		Uint64(p.TraceID). | ||||||
| 		Build() | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { | func parsePacketProbe(buf []byte) (p packetProbe, err error) { | ||||||
| 	err = newBinReader(buf[1:]). | 	err = newBinReader(buf[1:]). | ||||||
| 		Uint64(&p.TraceID). | 		Uint64(&p.TraceID). | ||||||
| 		Error() | 		Error() | ||||||
| @@ -97,4 +90,4 @@ func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type PacketLocalDiscovery struct{} | type packetLocalDiscovery struct{} | ||||||
|   | |||||||
| @@ -8,7 +8,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestSynPacket(t *testing.T) { | func TestSynPacket(t *testing.T) { | ||||||
| 	p := PacketSyn{ | 	p := packetSyn{ | ||||||
| 		TraceID: newTraceID(), | 		TraceID: newTraceID(), | ||||||
| 		//SentAt:        time.Now().UnixMilli(), | 		//SentAt:        time.Now().UnixMilli(), | ||||||
| 		//SharedKeyType: 1, | 		//SharedKeyType: 1, | ||||||
| @@ -21,7 +21,7 @@ func TestSynPacket(t *testing.T) { | |||||||
| 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||||
|  |  | ||||||
| 	buf := p.Marshal(newBuf()) | 	buf := p.Marshal(newBuf()) | ||||||
| 	p2, err := ParsePacketSyn(buf) | 	p2, err := parsePacketSyn(buf) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -31,7 +31,7 @@ func TestSynPacket(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestAckPacket(t *testing.T) { | func TestAckPacket(t *testing.T) { | ||||||
| 	p := PacketAck{ | 	p := packetAck{ | ||||||
| 		TraceID: newTraceID(), | 		TraceID: newTraceID(), | ||||||
| 		ToAddr:  netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), | 		ToAddr:  netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), | ||||||
| 	} | 	} | ||||||
| @@ -41,7 +41,7 @@ func TestAckPacket(t *testing.T) { | |||||||
| 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | 	p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) | ||||||
|  |  | ||||||
| 	buf := p.Marshal(newBuf()) | 	buf := p.Marshal(newBuf()) | ||||||
| 	p2, err := ParsePacketAck(buf) | 	p2, err := parsePacketAck(buf) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -51,12 +51,12 @@ func TestAckPacket(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestProbePacket(t *testing.T) { | func TestProbePacket(t *testing.T) { | ||||||
| 	p := PacketProbe{ | 	p := packetProbe{ | ||||||
| 		TraceID: newTraceID(), | 		TraceID: newTraceID(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	buf := p.Marshal(newBuf()) | 	buf := p.Marshal(newBuf()) | ||||||
| 	p2, err := ParsePacketProbe(buf) | 	p2, err := parsePacketProbe(buf) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										161
									
								
								peer/peer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								peer/peer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,161 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"net/url" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"vppn/m" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Peer struct { | ||||||
|  | 	ifReader   *IFReader | ||||||
|  | 	connReader *ConnReader | ||||||
|  | 	iface      io.Writer | ||||||
|  | 	hubPoller  *hubPoller | ||||||
|  | 	super      *Super | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Config struct { | ||||||
|  | 	NetName    string | ||||||
|  | 	HubAddress string | ||||||
|  | 	APIKey     string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func New(conf Config) *Peer { | ||||||
|  | 	config, err := loadPeerConfig(conf.NetName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Printf("Failed to load configuration: %v", err) | ||||||
|  | 		log.Printf("Initializing...") | ||||||
|  | 		initPeerWithHub(conf) | ||||||
|  |  | ||||||
|  | 		config, err = loadPeerConfig(conf.NetName) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Fatalf("Failed to load configuration: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	iface, err := openInterface(config.Network, config.PeerIP, conf.NetName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to open interface: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to resolve UDP address: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Printf("Listening on %v...", myAddr) | ||||||
|  | 	conn, err := net.ListenUDP("udp", myAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to open UDP port: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	conn.SetReadBuffer(1024 * 1024 * 8) | ||||||
|  | 	conn.SetWriteBuffer(1024 * 1024 * 8) | ||||||
|  |  | ||||||
|  | 	// Wrap write function - this is necessary to avoid starvation. | ||||||
|  | 	writeLock := sync.Mutex{} | ||||||
|  | 	writeToUDPAddrPort := func(b []byte, addr netip.AddrPort) (n int, err error) { | ||||||
|  | 		writeLock.Lock() | ||||||
|  | 		n, err = conn.WriteToUDPAddrPort(b, addr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Printf("Failed to write packet: %v", err) | ||||||
|  | 		} | ||||||
|  | 		writeLock.Unlock() | ||||||
|  | 		return n, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var localAddr netip.AddrPort | ||||||
|  | 	ip, ok := netip.AddrFromSlice(config.PublicIP) | ||||||
|  | 	if ok { | ||||||
|  | 		localAddr = netip.AddrPortFrom(ip, config.Port) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rt := newRoutingTable(config.PeerIP, localAddr) | ||||||
|  | 	rtPtr := &atomic.Pointer[routingTable]{} | ||||||
|  | 	rtPtr.Store(&rt) | ||||||
|  |  | ||||||
|  | 	ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr) | ||||||
|  | 	super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey) | ||||||
|  | 	connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) | ||||||
|  | 	hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to create hub poller: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &Peer{ | ||||||
|  | 		iface:      iface, | ||||||
|  | 		ifReader:   ifReader, | ||||||
|  | 		connReader: connReader, | ||||||
|  | 		hubPoller:  hubPoller, | ||||||
|  | 		super:      super, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *Peer) Run() { | ||||||
|  | 	go p.ifReader.Run() | ||||||
|  | 	go p.connReader.Run() | ||||||
|  | 	p.super.Start() | ||||||
|  | 	p.hubPoller.Run() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func initPeerWithHub(conf Config) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	initURL, err := url.Parse(conf.HubAddress) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to parse hub URL: %v", err) | ||||||
|  | 	} | ||||||
|  | 	initURL.Path = "/peer/init/" | ||||||
|  |  | ||||||
|  | 	args := m.PeerInitArgs{ | ||||||
|  | 		EncPubKey:  keys.PubKey, | ||||||
|  | 		PubSignKey: keys.PubSignKey, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	buf := &bytes.Buffer{} | ||||||
|  | 	if err := json.NewEncoder(buf).Encode(args); err != nil { | ||||||
|  | 		log.Fatalf("Failed to encode init args: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to construct request: %v", err) | ||||||
|  | 	} | ||||||
|  | 	req.SetBasicAuth("", conf.APIKey) | ||||||
|  |  | ||||||
|  | 	resp, err := http.DefaultClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to init with hub: %v", err) | ||||||
|  | 	} | ||||||
|  | 	defer resp.Body.Close() | ||||||
|  |  | ||||||
|  | 	data, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to read response body: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	peerConfig := localConfig{} | ||||||
|  | 	if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { | ||||||
|  | 		log.Fatalf("Failed to parse configuration: %v\n%s", err, data) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	peerConfig.PubKey = keys.PubKey | ||||||
|  | 	peerConfig.PrivKey = keys.PrivKey | ||||||
|  | 	peerConfig.PubSignKey = keys.PubSignKey | ||||||
|  | 	peerConfig.PrivSignKey = keys.PrivSignKey | ||||||
|  |  | ||||||
|  | 	if err := storePeerConfig(conf.NetName, peerConfig); err != nil { | ||||||
|  | 		log.Fatalf("Failed to store configuration: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Print("Initialization successful.") | ||||||
|  | } | ||||||
| @@ -11,36 +11,25 @@ import ( | |||||||
| // A test peer. | // A test peer. | ||||||
| type P struct { | type P struct { | ||||||
| 	cryptoKeys | 	cryptoKeys | ||||||
| 	RT         *atomic.Pointer[RoutingTable] | 	RT         *atomic.Pointer[routingTable] | ||||||
| 	Conn       *TestUDPConn | 	Conn       *TestUDPConn | ||||||
| 	IFace      *TestIFace | 	IFace      *TestIFace | ||||||
| 	ConnWriter *ConnWriter |  | ||||||
| 	ConnReader *ConnReader | 	ConnReader *ConnReader | ||||||
| 	IFReader   *IFReader | 	IFReader   *IFReader | ||||||
| 	Super      *Supervisor |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { | func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { | ||||||
| 	p := P{ | 	p := P{ | ||||||
| 		cryptoKeys: generateKeys(), | 		cryptoKeys: generateKeys(), | ||||||
| 		RT:         &atomic.Pointer[RoutingTable]{}, | 		RT:         &atomic.Pointer[routingTable]{}, | ||||||
| 		IFace:      NewTestIFace(), | 		IFace:      NewTestIFace(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rt := NewRoutingTable(ip, addr) | 	rt := newRoutingTable(ip, addr) | ||||||
| 	p.RT.Store(&rt) | 	p.RT.Store(&rt) | ||||||
| 	p.Conn = n.NewUDPConn(addr) | 	p.Conn = n.NewUDPConn(addr) | ||||||
| 	p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) | 	//p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) | ||||||
| 	p.IFReader = NewIFReader(p.IFace, p.ConnWriter) |  | ||||||
|  |  | ||||||
| 	/* |  | ||||||
| 		   p.ConnReader = NewConnReader( |  | ||||||
| 				p.Conn.ReadFromUDPAddrPort, |  | ||||||
| 				p.IFace, |  | ||||||
| 				p.ConnWriter.Forward, |  | ||||||
| 				p.Super.HandleControlMsg, |  | ||||||
| 				p.RT) |  | ||||||
| 	*/ |  | ||||||
| 	return p | 	return p | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -11,21 +11,21 @@ import ( | |||||||
| 	"git.crumpington.com/lib/go/ratelimiter" | 	"git.crumpington.com/lib/go/ratelimiter" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type PeerState interface { | type peerState interface { | ||||||
| 	OnPeerUpdate(*m.Peer) PeerState | 	OnPeerUpdate(*m.Peer) peerState | ||||||
| 	OnSyn(controlMsg[PacketSyn]) PeerState | 	OnSyn(controlMsg[packetSyn]) peerState | ||||||
| 	OnAck(controlMsg[PacketAck]) | 	OnAck(controlMsg[packetAck]) | ||||||
| 	OnProbe(controlMsg[PacketProbe]) PeerState | 	OnProbe(controlMsg[packetProbe]) peerState | ||||||
| 	OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) | 	OnLocalDiscovery(controlMsg[packetLocalDiscovery]) | ||||||
| 	OnPingTimer() PeerState | 	OnPingTimer() peerState | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type State struct { | type pState struct { | ||||||
| 	// Output. | 	// Output. | ||||||
| 	publish           func(RemotePeer) | 	publish           func(remotePeer) | ||||||
| 	sendControlPacket func(RemotePeer, Marshaller) | 	sendControlPacket func(remotePeer, Marshaller) | ||||||
|  |  | ||||||
| 	// Immutable data. | 	// Immutable data. | ||||||
| 	localIP   byte | 	localIP   byte | ||||||
| @@ -37,7 +37,7 @@ type State struct { | |||||||
|  |  | ||||||
| 	// The purpose of this state machine is to manage the RemotePeer object, | 	// The purpose of this state machine is to manage the RemotePeer object, | ||||||
| 	// publishing it as necessary. | 	// publishing it as necessary. | ||||||
| 	staged RemotePeer // Local copy of shared data. See publish(). | 	staged remotePeer // Local copy of shared data. See publish(). | ||||||
|  |  | ||||||
| 	// Mutable peer data. | 	// Mutable peer data. | ||||||
| 	peer *m.Peer | 	peer *m.Peer | ||||||
| @@ -47,25 +47,28 @@ type State struct { | |||||||
| 	limiter *ratelimiter.Limiter | 	limiter *ratelimiter.Limiter | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { | func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		// Don't defer directly otherwise s.staged will be evaluated immediately | 		// Don't defer directly otherwise s.staged will be evaluated immediately | ||||||
| 		// and won't reflect changes made in the function. | 		// and won't reflect changes made in the function. | ||||||
| 		s.publish(s.staged) | 		s.publish(s.staged) | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	if peer == nil { |  | ||||||
| 		return EnterStateDisconnected(s) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	s.peer = peer | 	s.peer = peer | ||||||
|  |  | ||||||
| 	s.staged.localIP = s.localIP | 	s.staged.localIP = s.localIP | ||||||
| 	s.staged.IP = peer.PeerIP |  | ||||||
| 	s.staged.Up = false | 	s.staged.Up = false | ||||||
| 	s.staged.Relay = false | 	s.staged.Relay = false | ||||||
| 	s.staged.Direct = false | 	s.staged.Direct = false | ||||||
| 	s.staged.DirectAddr = netip.AddrPort{} | 	s.staged.DirectAddr = netip.AddrPort{} | ||||||
|  | 	s.staged.PubSignKey = nil | ||||||
|  | 	s.staged.ControlCipher = nil | ||||||
|  | 	s.staged.DataCipher = nil | ||||||
|  |  | ||||||
|  | 	if peer == nil { | ||||||
|  | 		return enterStateDisconnected(s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.staged.IP = peer.PeerIP | ||||||
| 	s.staged.PubSignKey = peer.PubSignKey | 	s.staged.PubSignKey = peer.PubSignKey | ||||||
| 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | 	s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) | ||||||
| 	s.staged.DataCipher = newDataCipher() | 	s.staged.DataCipher = newDataCipher() | ||||||
| @@ -76,30 +79,32 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { | |||||||
| 		s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) | 		s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) | ||||||
|  |  | ||||||
| 		if s.localAddr.IsValid() && s.localIP < s.remoteIP { | 		if s.localAddr.IsValid() && s.localIP < s.remoteIP { | ||||||
| 			return EnterStateServer(s) | 			return enterStateServer(s) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return EnterStateClientDirect(s) | 		return enterStateClientDirect(s) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.localAddr.IsValid() { | 	if s.localAddr.IsValid() { | ||||||
| 		s.staged.Direct = true | 		s.staged.Direct = true | ||||||
| 		return EnterStateServer(s) | 		return enterStateServer(s) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.localIP < s.remoteIP { | 	if s.localIP < s.remoteIP { | ||||||
| 		return EnterStateServer(s) | 		return enterStateServer(s) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return EnterStateClientRelayed(s) | 	return enterStateClientRelayed(s) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *State) logf(format string, args ...any) { | func (s *pState) logf(format string, args ...any) { | ||||||
| 	b := strings.Builder{} | 	b := strings.Builder{} | ||||||
| 	name := "" | 	name := "" | ||||||
| 	if s.peer != nil { | 	if s.peer != nil { | ||||||
| 		name = s.peer.Name | 		name = s.peer.Name | ||||||
| 	} | 	} | ||||||
|  | 	b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) | ||||||
|  |  | ||||||
| 	b.WriteString(fmt.Sprintf("%30s: ", name)) | 	b.WriteString(fmt.Sprintf("%30s: ", name)) | ||||||
|  |  | ||||||
| 	if s.staged.Direct { | 	if s.staged.Direct { | ||||||
| @@ -119,7 +124,7 @@ func (s *State) logf(format string, args ...any) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { | func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { | ||||||
| 	if !addr.IsValid() { | 	if !addr.IsValid() { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -129,7 +134,7 @@ func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { | |||||||
| 	s.Send(route, pkt) | 	s.Send(route, pkt) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *State) Send(peer RemotePeer, pkt Marshaller) { | func (s *pState) Send(peer remotePeer, pkt Marshaller) { | ||||||
| 	if err := s.limiter.Limit(); err != nil { | 	if err := s.limiter.Limit(); err != nil { | ||||||
| 		s.logf("Rate limited.") | 		s.logf("Rate limited.") | ||||||
| 		return | 		return | ||||||
| @@ -139,42 +144,32 @@ func (s *State) Send(peer RemotePeer, pkt Marshaller) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type StateDisconnected struct{ *State } | type stateDisconnected struct{ *pState } | ||||||
|  |  | ||||||
| func EnterStateDisconnected(s *State) PeerState { | func enterStateDisconnected(s *pState) peerState { | ||||||
| 	s.logf("==> Disconnected") | 	return &stateDisconnected{pState: s} | ||||||
| 	s.peer = nil |  | ||||||
| 	s.staged.Up = false |  | ||||||
| 	s.staged.Relay = false |  | ||||||
| 	s.staged.Direct = false |  | ||||||
| 	s.staged.DirectAddr = netip.AddrPort{} |  | ||||||
| 	s.staged.PubSignKey = nil |  | ||||||
| 	s.staged.ControlCipher = nil |  | ||||||
| 	s.staged.DataCipher = nil |  | ||||||
| 	s.publish(s.staged) |  | ||||||
| 	return &StateDisconnected{State: s} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState             { return nil } | func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState             { return s } | ||||||
| func (s *StateDisconnected) OnAck(controlMsg[PacketAck])                       {} | func (s *stateDisconnected) OnAck(controlMsg[packetAck])                       {} | ||||||
| func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState         { return nil } | func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState         { return s } | ||||||
| func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} | func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {} | ||||||
| func (s *StateDisconnected) OnPingTimer() PeerState                            { return nil } | func (s *stateDisconnected) OnPingTimer() peerState                            { return s } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type StateServer struct { | type stateServer struct { | ||||||
| 	*StateDisconnected | 	*stateDisconnected | ||||||
| 	lastSeen   time.Time | 	lastSeen   time.Time | ||||||
| 	synTraceID uint64 | 	synTraceID uint64 | ||||||
| } | } | ||||||
|  |  | ||||||
| func EnterStateServer(s *State) PeerState { | func enterStateServer(s *pState) peerState { | ||||||
| 	s.logf("==> Server") | 	s.logf("==> Server") | ||||||
| 	return &StateServer{StateDisconnected: &StateDisconnected{State: s}} | 	return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { | ||||||
| 	s.lastSeen = time.Now() | 	s.lastSeen = time.Now() | ||||||
| 	p := msg.Packet | 	p := msg.Packet | ||||||
|  |  | ||||||
| @@ -194,7 +189,7 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Always respond. | 	// Always respond. | ||||||
| 	ack := PacketAck{ | 	ack := packetAck{ | ||||||
| 		TraceID:       p.TraceID, | 		TraceID:       p.TraceID, | ||||||
| 		ToAddr:        s.staged.DirectAddr, | 		ToAddr:        s.staged.DirectAddr, | ||||||
| 		PossibleAddrs: s.pubAddrs.Get(), | 		PossibleAddrs: s.pubAddrs.Get(), | ||||||
| @@ -202,55 +197,55 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { | |||||||
| 	s.Send(s.staged, ack) | 	s.Send(s.staged, ack) | ||||||
|  |  | ||||||
| 	if p.Direct { | 	if p.Direct { | ||||||
| 		return nil | 		return s | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, addr := range msg.Packet.PossibleAddrs { | 	for _, addr := range msg.Packet.PossibleAddrs { | ||||||
| 		if !addr.IsValid() { | 		if !addr.IsValid() { | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 		s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) | 		s.SendTo(packetProbe{TraceID: newTraceID()}, addr) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { | func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState { | ||||||
| 	if msg.SrcAddr.IsValid() { | 	if msg.SrcAddr.IsValid() { | ||||||
| 		s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | 		s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateServer) OnPingTimer() PeerState { | func (s *stateServer) OnPingTimer() peerState { | ||||||
| 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | 	if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { | ||||||
| 		s.staged.Up = false | 		s.staged.Up = false | ||||||
| 		s.publish(s.staged) | 		s.publish(s.staged) | ||||||
| 		s.logf("Timeout.") | 		s.logf("Timeout.") | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type StateClientDirect struct { | type stateClientDirect struct { | ||||||
| 	*StateDisconnected | 	*stateDisconnected | ||||||
| 	lastSeen time.Time | 	lastSeen time.Time | ||||||
| 	syn      PacketSyn | 	syn      packetSyn | ||||||
| } | } | ||||||
|  |  | ||||||
| func EnterStateClientDirect(s *State) PeerState { | func enterStateClientDirect(s *pState) peerState { | ||||||
| 	s.logf("==> ClientDirect") | 	s.logf("==> ClientDirect") | ||||||
| 	return NewStateClientDirect(s) | 	return newStateClientDirect(s) | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewStateClientDirect(s *State) *StateClientDirect { | func newStateClientDirect(s *pState) *stateClientDirect { | ||||||
| 	state := &StateClientDirect{ | 	state := &stateClientDirect{ | ||||||
| 		StateDisconnected: &StateDisconnected{s}, | 		stateDisconnected: &stateDisconnected{s}, | ||||||
| 		lastSeen:          time.Now(), // Avoid immediate timeout. | 		lastSeen:          time.Now(), // Avoid immediate timeout. | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	state.syn = PacketSyn{ | 	state.syn = packetSyn{ | ||||||
| 		TraceID:       newTraceID(), | 		TraceID:       newTraceID(), | ||||||
| 		SharedKey:     s.staged.DataCipher.Key(), | 		SharedKey:     s.staged.DataCipher.Key(), | ||||||
| 		Direct:        s.staged.Direct, | 		Direct:        s.staged.Direct, | ||||||
| @@ -260,7 +255,7 @@ func NewStateClientDirect(s *State) *StateClientDirect { | |||||||
| 	return state | 	return state | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { | func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { | ||||||
| 	if msg.Packet.TraceID != s.syn.TraceID { | 	if msg.Packet.TraceID != s.syn.TraceID { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -276,7 +271,14 @@ func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { | |||||||
| 	s.pubAddrs.Store(msg.Packet.ToAddr) | 	s.pubAddrs.Store(msg.Packet.ToAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientDirect) OnPingTimer() PeerState { | func (s *stateClientDirect) OnPingTimer() peerState { | ||||||
|  | 	if next := s.onPingTimer(); next != nil { | ||||||
|  | 		return next | ||||||
|  | 	} | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *stateClientDirect) onPingTimer() peerState { | ||||||
| 	if time.Since(s.lastSeen) > timeoutInterval { | 	if time.Since(s.lastSeen) > timeoutInterval { | ||||||
| 		if s.staged.Up { | 		if s.staged.Up { | ||||||
| 			s.staged.Up = false | 			s.staged.Up = false | ||||||
| @@ -292,47 +294,47 @@ func (s *StateClientDirect) OnPingTimer() PeerState { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type StateClientRelayed struct { | type stateClientRelayed struct { | ||||||
| 	*StateClientDirect | 	*stateClientDirect | ||||||
| 	ack                PacketAck | 	ack                packetAck | ||||||
| 	probes             map[uint64]netip.AddrPort | 	probes             map[uint64]netip.AddrPort | ||||||
| 	localDiscoveryAddr netip.AddrPort | 	localDiscoveryAddr netip.AddrPort | ||||||
| } | } | ||||||
|  |  | ||||||
| func EnterStateClientRelayed(s *State) PeerState { | func enterStateClientRelayed(s *pState) peerState { | ||||||
| 	s.logf("==> ClientRelayed") | 	s.logf("==> ClientRelayed") | ||||||
| 	return &StateClientRelayed{ | 	return &stateClientRelayed{ | ||||||
| 		StateClientDirect: NewStateClientDirect(s), | 		stateClientDirect: newStateClientDirect(s), | ||||||
| 		probes:            map[uint64]netip.AddrPort{}, | 		probes:            map[uint64]netip.AddrPort{}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { | func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { | ||||||
| 	s.ack = msg.Packet | 	s.ack = msg.Packet | ||||||
| 	s.StateClientDirect.OnAck(msg) | 	s.stateClientDirect.OnAck(msg) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { | func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { | ||||||
| 	addr, ok := s.probes[msg.Packet.TraceID] | 	addr, ok := s.probes[msg.Packet.TraceID] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil | 		return s | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.staged.DirectAddr = addr | 	s.staged.DirectAddr = addr | ||||||
| 	s.staged.Direct = true | 	s.staged.Direct = true | ||||||
| 	s.publish(s.staged) | 	s.publish(s.staged) | ||||||
| 	return EnterStateClientDirect(s.StateClientDirect.State) | 	return enterStateClientDirect(s.stateClientDirect.pState) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { | func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { | ||||||
| 	// The source port will be the multicast port, so we'll have to | 	// The source port will be the multicast port, so we'll have to | ||||||
| 	// construct the correct address using the peer's listed port. | 	// construct the correct address using the peer's listed port. | ||||||
| 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | 	s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientRelayed) OnPingTimer() PeerState { | func (s *stateClientRelayed) OnPingTimer() peerState { | ||||||
| 	if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { | 	if next := s.stateClientDirect.onPingTimer(); next != nil { | ||||||
| 		return nextState | 		return next | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	clear(s.probes) | 	clear(s.probes) | ||||||
| @@ -348,11 +350,11 @@ func (s *StateClientRelayed) OnPingTimer() PeerState { | |||||||
| 		s.localDiscoveryAddr = netip.AddrPort{} | 		s.localDiscoveryAddr = netip.AddrPort{} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { | func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { | ||||||
| 	probe := PacketProbe{TraceID: newTraceID()} | 	probe := packetProbe{TraceID: newTraceID()} | ||||||
| 	s.probes[probe.TraceID] = addr | 	s.probes[probe.TraceID] = addr | ||||||
| 	s.SendTo(probe, addr) | 	s.SendTo(probe, addr) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -12,13 +12,13 @@ import ( | |||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type PeerStateControlMsg struct { | type PeerStateControlMsg struct { | ||||||
| 	Peer   RemotePeer | 	Peer   remotePeer | ||||||
| 	Packet any | 	Packet any | ||||||
| } | } | ||||||
|  |  | ||||||
| type PeerStateTestHarness struct { | type PeerStateTestHarness struct { | ||||||
| 	State     PeerState | 	State     peerState | ||||||
| 	Published RemotePeer | 	Published remotePeer | ||||||
| 	Sent      []PeerStateControlMsg | 	Sent      []PeerStateControlMsg | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -27,11 +27,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | |||||||
|  |  | ||||||
| 	keys := generateKeys() | 	keys := generateKeys() | ||||||
|  |  | ||||||
| 	state := &State{ | 	state := &pState{ | ||||||
| 		publish: func(rp RemotePeer) { | 		publish: func(rp remotePeer) { | ||||||
| 			h.Published = rp | 			h.Published = rp | ||||||
| 		}, | 		}, | ||||||
| 		sendControlPacket: func(rp RemotePeer, pkt Marshaller) { | 		sendControlPacket: func(rp remotePeer, pkt Marshaller) { | ||||||
| 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | 			h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) | ||||||
| 		}, | 		}, | ||||||
| 		localIP:  2, | 		localIP:  2, | ||||||
| @@ -44,7 +44,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { | |||||||
| 		}), | 		}), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.State = EnterStateDisconnected(state) | 	h.State = enterStateDisconnected(state) | ||||||
| 	return h | 	return h | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -54,13 +54,13 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { | func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { | ||||||
| 	if s := h.State.OnSyn(msg); s != nil { | 	if s := h.State.OnSyn(msg); s != nil { | ||||||
| 		h.State = s | 		h.State = s | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { | func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { | ||||||
| 	if s := h.State.OnProbe(msg); s != nil { | 	if s := h.State.OnProbe(msg); s != nil { | ||||||
| 		h.State = s | 		h.State = s | ||||||
| 	} | 	} | ||||||
| @@ -72,10 +72,10 @@ func (h *PeerStateTestHarness) OnPingTimer() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { | func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { | ||||||
| 	keys := generateKeys() | 	keys := generateKeys() | ||||||
|  |  | ||||||
| 	state := h.State.(*StateDisconnected) | 	state := h.State.(*stateDisconnected) | ||||||
| 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||||
|  |  | ||||||
| 	peer := &m.Peer{ | 	peer := &m.Peer{ | ||||||
| @@ -88,10 +88,10 @@ func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { | |||||||
|  |  | ||||||
| 	h.PeerUpdate(peer) | 	h.PeerUpdate(peer) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| 	return assertType[*StateServer](t, h.State) | 	return assertType[*stateServer](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { | func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { | ||||||
| 	keys := generateKeys() | 	keys := generateKeys() | ||||||
| 	peer := &m.Peer{ | 	peer := &m.Peer{ | ||||||
| 		PeerIP:     3, | 		PeerIP:     3, | ||||||
| @@ -102,10 +102,10 @@ func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { | |||||||
|  |  | ||||||
| 	h.PeerUpdate(peer) | 	h.PeerUpdate(peer) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| 	return assertType[*StateServer](t, h.State) | 	return assertType[*stateServer](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { | func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDirect { | ||||||
| 	keys := generateKeys() | 	keys := generateKeys() | ||||||
| 	peer := &m.Peer{ | 	peer := &m.Peer{ | ||||||
| 		PeerIP:     3, | 		PeerIP:     3, | ||||||
| @@ -117,13 +117,13 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDire | |||||||
|  |  | ||||||
| 	h.PeerUpdate(peer) | 	h.PeerUpdate(peer) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| 	return assertType[*StateClientDirect](t, h.State) | 	return assertType[*stateClientDirect](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { | func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClientRelayed { | ||||||
| 	keys := generateKeys() | 	keys := generateKeys() | ||||||
|  |  | ||||||
| 	state := h.State.(*StateDisconnected) | 	state := h.State.(*stateDisconnected) | ||||||
| 	state.remoteIP = 1 | 	state.remoteIP = 1 | ||||||
|  |  | ||||||
| 	peer := &m.Peer{ | 	peer := &m.Peer{ | ||||||
| @@ -135,7 +135,7 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel | |||||||
|  |  | ||||||
| 	h.PeerUpdate(peer) | 	h.PeerUpdate(peer) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| 	return assertType[*StateClientRelayed](t, h.State) | 	return assertType[*stateClientRelayed](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
| @@ -143,14 +143,14 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel | |||||||
| func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { | func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { | ||||||
| 	h := NewPeerStateTestHarness() | 	h := NewPeerStateTestHarness() | ||||||
| 	h.PeerUpdate(nil) | 	h.PeerUpdate(nil) | ||||||
| 	assertType[*StateDisconnected](t, h.State) | 	assertType[*stateDisconnected](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | ||||||
| 	keys := generateKeys() | 	keys := generateKeys() | ||||||
| 	h := NewPeerStateTestHarness() | 	h := NewPeerStateTestHarness() | ||||||
|  |  | ||||||
| 	state := h.State.(*StateDisconnected) | 	state := h.State.(*stateDisconnected) | ||||||
| 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | 	state.localAddr = addrPort4(1, 1, 1, 2, 200) | ||||||
|  |  | ||||||
| 	peer := &m.Peer{ | 	peer := &m.Peer{ | ||||||
| @@ -162,7 +162,7 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { | |||||||
|  |  | ||||||
| 	h.PeerUpdate(peer) | 	h.PeerUpdate(peer) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| 	assertType[*StateServer](t, h.State) | 	assertType[*stateServer](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { | func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { | ||||||
| @@ -191,10 +191,10 @@ func TestStateServer_directSyn(t *testing.T) { | |||||||
|  |  | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
| 	synMsg := controlMsg[PacketSyn]{ | 	synMsg := controlMsg[packetSyn]{ | ||||||
| 		SrcIP:   3, | 		SrcIP:   3, | ||||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
| 		Packet: PacketSyn{ | 		Packet: packetSyn{ | ||||||
| 			TraceID: newTraceID(), | 			TraceID: newTraceID(), | ||||||
| 			//SentAt:        time.Now().UnixMilli(), | 			//SentAt:        time.Now().UnixMilli(), | ||||||
| 			//SharedKeyType: 1, | 			//SharedKeyType: 1, | ||||||
| @@ -205,7 +205,7 @@ func TestStateServer_directSyn(t *testing.T) { | |||||||
| 	h.State.OnSyn(synMsg) | 	h.State.OnSyn(synMsg) | ||||||
|  |  | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | 	ack := assertType[packetAck](t, h.Sent[0].Packet) | ||||||
| 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||||
| 	assertEqual(t, h.Sent[0].Peer.IP, 3) | 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||||
| 	assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) | 	assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) | ||||||
| @@ -220,10 +220,10 @@ func TestStateServer_relayedSyn(t *testing.T) { | |||||||
|  |  | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
| 	synMsg := controlMsg[PacketSyn]{ | 	synMsg := controlMsg[packetSyn]{ | ||||||
| 		SrcIP:   3, | 		SrcIP:   3, | ||||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
| 		Packet: PacketSyn{ | 		Packet: packetSyn{ | ||||||
| 			TraceID: newTraceID(), | 			TraceID: newTraceID(), | ||||||
| 			//SentAt:        time.Now().UnixMilli(), | 			//SentAt:        time.Now().UnixMilli(), | ||||||
| 			//SharedKeyType: 1, | 			//SharedKeyType: 1, | ||||||
| @@ -237,15 +237,15 @@ func TestStateServer_relayedSyn(t *testing.T) { | |||||||
|  |  | ||||||
| 	assertEqual(t, len(h.Sent), 3) | 	assertEqual(t, len(h.Sent), 3) | ||||||
|  |  | ||||||
| 	ack := assertType[PacketAck](t, h.Sent[0].Packet) | 	ack := assertType[packetAck](t, h.Sent[0].Packet) | ||||||
| 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | 	assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) | ||||||
| 	assertEqual(t, h.Sent[0].Peer.IP, 3) | 	assertEqual(t, h.Sent[0].Peer.IP, 3) | ||||||
| 	assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) | 	assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) | ||||||
| 	assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) | 	assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| 	assertType[PacketProbe](t, h.Sent[1].Packet) | 	assertType[packetProbe](t, h.Sent[1].Packet) | ||||||
| 	assertType[PacketProbe](t, h.Sent[2].Packet) | 	assertType[packetProbe](t, h.Sent[2].Packet) | ||||||
| 	assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | 	assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||||
| 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) | 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) | ||||||
| } | } | ||||||
| @@ -255,17 +255,17 @@ func TestStateServer_onProbe(t *testing.T) { | |||||||
| 	h.ConfigServer_Relayed(t) | 	h.ConfigServer_Relayed(t) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
|  |  | ||||||
| 	probeMsg := controlMsg[PacketProbe]{ | 	probeMsg := controlMsg[packetProbe]{ | ||||||
| 		SrcIP:   3, | 		SrcIP:   3, | ||||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
| 		Packet:  PacketProbe{TraceID: newTraceID()}, | 		Packet:  packetProbe{TraceID: newTraceID()}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.State.OnProbe(probeMsg) | 	h.State.OnProbe(probeMsg) | ||||||
|  |  | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  |  | ||||||
| 	probe := assertType[PacketProbe](t, h.Sent[0].Packet) | 	probe := assertType[packetProbe](t, h.Sent[0].Packet) | ||||||
| 	assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) | 	assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) | ||||||
| 	assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | 	assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) | ||||||
| } | } | ||||||
| @@ -274,10 +274,10 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { | |||||||
| 	h := NewPeerStateTestHarness() | 	h := NewPeerStateTestHarness() | ||||||
| 	h.ConfigServer_Relayed(t) | 	h.ConfigServer_Relayed(t) | ||||||
|  |  | ||||||
| 	synMsg := controlMsg[PacketSyn]{ | 	synMsg := controlMsg[packetSyn]{ | ||||||
| 		SrcIP:   3, | 		SrcIP:   3, | ||||||
| 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | 		SrcAddr: addrPort4(1, 1, 1, 3, 300), | ||||||
| 		Packet: PacketSyn{ | 		Packet: packetSyn{ | ||||||
| 			TraceID: newTraceID(), | 			TraceID: newTraceID(), | ||||||
| 			//SentAt:        time.Now().UnixMilli(), | 			//SentAt:        time.Now().UnixMilli(), | ||||||
| 			//SharedKeyType: 1, | 			//SharedKeyType: 1, | ||||||
| @@ -294,7 +294,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { | |||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| 	// Advance the time, then ping. | 	// Advance the time, then ping. | ||||||
| 	state := assertType[*StateServer](t, h.State) | 	state := assertType[*stateServer](t, h.State) | ||||||
| 	state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) | 	state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) | ||||||
|  |  | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
| @@ -309,10 +309,10 @@ func TestStateClientDirect_OnAck(t *testing.T) { | |||||||
|  |  | ||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State.OnAck(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
| @@ -326,10 +326,10 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { | |||||||
|  |  | ||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: PacketAck{TraceID: syn.TraceID + 1}, | 		Packet: packetAck{TraceID: syn.TraceID + 1}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State.OnAck(ack) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| @@ -341,15 +341,15 @@ func TestStateClientDirect_OnPingTimer(t *testing.T) { | |||||||
|  |  | ||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	assertType[PacketSyn](t, h.Sent[0].Packet) | 	assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
|  |  | ||||||
| 	// On ping timer, another syn should be sent. Additionally, we should remain | 	// On ping timer, another syn should be sent. Additionally, we should remain | ||||||
| 	// in the same state. | 	// in the same state. | ||||||
| 	assertEqual(t, len(h.Sent), 2) | 	assertEqual(t, len(h.Sent), 2) | ||||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||||
| 	assertType[*StateClientDirect](t, h.State) | 	assertType[*stateClientDirect](t, h.State) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -361,15 +361,15 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | |||||||
|  |  | ||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State.OnAck(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| 	state := assertType[*StateClientDirect](t, h.State) | 	state := assertType[*stateClientDirect](t, h.State) | ||||||
| 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||||
|  |  | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
| @@ -377,8 +377,8 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { | |||||||
| 	// On ping timer, we should timeout, causing the client to reset. Another SYN | 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||||
| 	// will be sent when re-entering the state, but the connection should be down. | 	// will be sent when re-entering the state, but the connection should be down. | ||||||
| 	assertEqual(t, len(h.Sent), 2) | 	assertEqual(t, len(h.Sent), 2) | ||||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||||
| 	assertType[*StateClientDirect](t, h.State) | 	assertType[*stateClientDirect](t, h.State) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -390,10 +390,10 @@ func TestStateClientRelayed_OnAck(t *testing.T) { | |||||||
|  |  | ||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State.OnAck(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
| @@ -423,9 +423,9 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | |||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
|  |  | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | 	ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} | ||||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||||
|  |  | ||||||
| @@ -433,7 +433,7 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | |||||||
|  |  | ||||||
| 	// Add a local discovery address. Note that the port will be configured port | 	// Add a local discovery address. Note that the port will be configured port | ||||||
| 	// and no the one provided here. | 	// and no the one provided here. | ||||||
| 	h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ | 	h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ | ||||||
| 		SrcIP:   3, | 		SrcIP:   3, | ||||||
| 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | 		SrcAddr: addrPort4(2, 2, 2, 3, 300), | ||||||
| 	}) | 	}) | ||||||
| @@ -441,10 +441,10 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { | |||||||
| 	// We should see one SYN and three probe packets. | 	// We should see one SYN and three probe packets. | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
| 	assertEqual(t, len(h.Sent), 5) | 	assertEqual(t, len(h.Sent), 5) | ||||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||||
| 	assertType[PacketProbe](t, h.Sent[2].Packet) | 	assertType[packetProbe](t, h.Sent[2].Packet) | ||||||
| 	assertType[PacketProbe](t, h.Sent[3].Packet) | 	assertType[packetProbe](t, h.Sent[3].Packet) | ||||||
| 	assertType[PacketProbe](t, h.Sent[4].Packet) | 	assertType[packetProbe](t, h.Sent[4].Packet) | ||||||
|  |  | ||||||
| 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) | 	assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) | ||||||
| 	assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) | 	assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) | ||||||
| @@ -457,15 +457,15 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | |||||||
|  |  | ||||||
| 	// On entering the state, a SYN should have been sent. | 	// On entering the state, a SYN should have been sent. | ||||||
| 	assertEqual(t, len(h.Sent), 1) | 	assertEqual(t, len(h.Sent), 1) | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{ | 	ack := controlMsg[packetAck]{ | ||||||
| 		Packet: PacketAck{TraceID: syn.TraceID}, | 		Packet: packetAck{TraceID: syn.TraceID}, | ||||||
| 	} | 	} | ||||||
| 	h.State.OnAck(ack) | 	h.State.OnAck(ack) | ||||||
| 	assertEqual(t, h.Published.Up, true) | 	assertEqual(t, h.Published.Up, true) | ||||||
|  |  | ||||||
| 	state := assertType[*StateClientRelayed](t, h.State) | 	state := assertType[*stateClientRelayed](t, h.State) | ||||||
| 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | 	state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) | ||||||
|  |  | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
| @@ -473,8 +473,8 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { | |||||||
| 	// On ping timer, we should timeout, causing the client to reset. Another SYN | 	// On ping timer, we should timeout, causing the client to reset. Another SYN | ||||||
| 	// will be sent when re-entering the state, but the connection should be down. | 	// will be sent when re-entering the state, but the connection should be down. | ||||||
| 	assertEqual(t, len(h.Sent), 2) | 	assertEqual(t, len(h.Sent), 2) | ||||||
| 	assertType[PacketSyn](t, h.Sent[1].Packet) | 	assertType[packetSyn](t, h.Sent[1].Packet) | ||||||
| 	assertType[*StateClientRelayed](t, h.State) | 	assertType[*stateClientRelayed](t, h.State) | ||||||
| 	assertEqual(t, h.Published.Up, false) | 	assertEqual(t, h.Published.Up, false) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -482,28 +482,28 @@ func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { | |||||||
| 	h := NewPeerStateTestHarness() | 	h := NewPeerStateTestHarness() | ||||||
| 	h.ConfigClientRelayed(t) | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
| 	h.OnProbe(controlMsg[PacketProbe]{ | 	h.OnProbe(controlMsg[packetProbe]{ | ||||||
| 		Packet: PacketProbe{TraceID: newTraceID()}, | 		Packet: packetProbe{TraceID: newTraceID()}, | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	assertType[*StateClientRelayed](t, h.State) | 	assertType[*stateClientRelayed](t, h.State) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { | ||||||
| 	h := NewPeerStateTestHarness() | 	h := NewPeerStateTestHarness() | ||||||
| 	h.ConfigClientRelayed(t) | 	h.ConfigClientRelayed(t) | ||||||
|  |  | ||||||
| 	syn := assertType[PacketSyn](t, h.Sent[0].Packet) | 	syn := assertType[packetSyn](t, h.Sent[0].Packet) | ||||||
|  |  | ||||||
| 	ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} | 	ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} | ||||||
| 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | 	ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) | ||||||
| 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | 	ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) | ||||||
|  |  | ||||||
| 	h.State.OnAck(ack) | 	h.State.OnAck(ack) | ||||||
| 	h.OnPingTimer() | 	h.OnPingTimer() | ||||||
|  |  | ||||||
| 	probe := assertType[PacketProbe](t, h.Sent[2].Packet) | 	probe := assertType[packetProbe](t, h.Sent[2].Packet) | ||||||
| 	h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) | 	h.OnProbe(controlMsg[packetProbe]{Packet: probe}) | ||||||
|  |  | ||||||
| 	assertType[*StateClientDirect](t, h.State) | 	assertType[*stateClientDirect](t, h.State) | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										172
									
								
								peer/peersuper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								peer/peersuper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/go/ratelimiter" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Super struct { | ||||||
|  | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) | ||||||
|  | 	staged             routingTable | ||||||
|  | 	shared             *atomic.Pointer[routingTable] | ||||||
|  | 	peers              [256]*PeerSuper | ||||||
|  | 	lock               sync.Mutex | ||||||
|  |  | ||||||
|  | 	buf1 []byte | ||||||
|  | 	buf2 []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewSuper( | ||||||
|  | 	writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), | ||||||
|  | 	rt *atomic.Pointer[routingTable], | ||||||
|  | 	privKey []byte, | ||||||
|  | ) *Super { | ||||||
|  |  | ||||||
|  | 	routes := rt.Load() | ||||||
|  |  | ||||||
|  | 	s := &Super{ | ||||||
|  | 		writeToUDPAddrPort: writeToUDPAddrPort, | ||||||
|  | 		staged:             *routes, | ||||||
|  | 		shared:             rt, | ||||||
|  | 		buf1:               newBuf(), | ||||||
|  | 		buf2:               newBuf(), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	pubAddrs := newPubAddrStore(routes.LocalAddr) | ||||||
|  |  | ||||||
|  | 	for i := range s.peers { | ||||||
|  | 		state := &pState{ | ||||||
|  | 			publish:           s.publish, | ||||||
|  | 			sendControlPacket: s.send, | ||||||
|  | 			localIP:           routes.LocalIP, | ||||||
|  | 			remoteIP:          byte(i), | ||||||
|  | 			privKey:           privKey, | ||||||
|  | 			localAddr:         routes.LocalAddr, | ||||||
|  | 			pubAddrs:          pubAddrs, | ||||||
|  | 			staged:            routes.Peers[i], | ||||||
|  | 			limiter: ratelimiter.New(ratelimiter.Config{ | ||||||
|  | 				FillPeriod:   20 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 1, | ||||||
|  | 			}), | ||||||
|  | 		} | ||||||
|  | 		s.peers[i] = NewPeerSuper(state) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Super) Start() { | ||||||
|  | 	for i := range s.peers { | ||||||
|  | 		go s.peers[i].Run() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Super) HandleControlMsg(destIP byte, msg any) { | ||||||
|  | 	s.peers[destIP].HandleControlMsg(msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Super) send(peer remotePeer, pkt Marshaller) { | ||||||
|  | 	s.lock.Lock() | ||||||
|  | 	defer s.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2) | ||||||
|  | 	if peer.Direct { | ||||||
|  | 		s.writeToUDPAddrPort(enc, peer.DirectAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	relay, ok := s.staged.GetRelay() | ||||||
|  | 	if !ok { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1) | ||||||
|  | 	s.writeToUDPAddrPort(enc, relay.DirectAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Super) publish(rp remotePeer) { | ||||||
|  | 	s.lock.Lock() | ||||||
|  | 	defer s.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	s.staged.Peers[rp.IP] = rp | ||||||
|  | 	s.ensureRelay() | ||||||
|  | 	copy := s.staged | ||||||
|  | 	s.shared.Store(©) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *Super) ensureRelay() { | ||||||
|  | 	if _, ok := s.staged.GetRelay(); ok { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// TODO: Random selection? | ||||||
|  | 	for _, peer := range s.staged.Peers { | ||||||
|  | 		if peer.Up && peer.Direct && peer.Relay { | ||||||
|  | 			s.staged.RelayIP = peer.IP | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type PeerSuper struct { | ||||||
|  | 	messages chan any | ||||||
|  | 	state    peerState | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewPeerSuper(state *pState) *PeerSuper { | ||||||
|  | 	return &PeerSuper{ | ||||||
|  | 		messages: make(chan any, 8), | ||||||
|  | 		state:    state.OnPeerUpdate(nil), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *PeerSuper) HandleControlMsg(msg any) { | ||||||
|  | 	select { | ||||||
|  | 	case s.messages <- msg: | ||||||
|  | 	default: | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *PeerSuper) Run() { | ||||||
|  | 	go func() { | ||||||
|  | 		// Randomize ping timers. | ||||||
|  | 		time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) | ||||||
|  | 		for range time.Tick(pingInterval) { | ||||||
|  | 			s.messages <- pingTimerMsg{} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	for rawMsg := range s.messages { | ||||||
|  | 		switch msg := rawMsg.(type) { | ||||||
|  |  | ||||||
|  | 		case peerUpdateMsg: | ||||||
|  | 			s.state = s.state.OnPeerUpdate(msg.Peer) | ||||||
|  |  | ||||||
|  | 		case controlMsg[packetSyn]: | ||||||
|  | 			s.state = s.state.OnSyn(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[packetAck]: | ||||||
|  | 			s.state.OnAck(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[packetProbe]: | ||||||
|  | 			s.state = s.state.OnProbe(msg) | ||||||
|  |  | ||||||
|  | 		case controlMsg[packetLocalDiscovery]: | ||||||
|  | 			s.state.OnLocalDiscovery(msg) | ||||||
|  |  | ||||||
|  | 		case pingTimerMsg: | ||||||
|  | 			s.state = s.state.OnPingTimer() | ||||||
|  |  | ||||||
|  | 		default: | ||||||
|  | 			log.Printf("WARNING: unknown message type: %+v", msg) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -7,9 +7,9 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| // TODO: Remove | // TODO: Remove | ||||||
| func NewRemotePeer(ip byte) *RemotePeer { | func NewRemotePeer(ip byte) *remotePeer { | ||||||
| 	counter := uint64(time.Now().Unix()<<30 + 1) | 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||||
| 	return &RemotePeer{ | 	return &remotePeer{ | ||||||
| 		IP:       ip, | 		IP:       ip, | ||||||
| 		counter:  &counter, | 		counter:  &counter, | ||||||
| 		dupCheck: newDupCheck(0), | 		dupCheck: newDupCheck(0), | ||||||
| @@ -18,7 +18,7 @@ func NewRemotePeer(ip byte) *RemotePeer { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type RemotePeer struct { | type remotePeer struct { | ||||||
| 	localIP       byte | 	localIP       byte | ||||||
| 	IP            byte           // VPN IP of peer (last byte). | 	IP            byte           // VPN IP of peer (last byte). | ||||||
| 	Up            bool           // True if data can be sent on the peer. | 	Up            bool           // True if data can be sent on the peer. | ||||||
| @@ -33,7 +33,7 @@ type RemotePeer struct { | |||||||
| 	dupCheck *dupCheck // For receiving from. Not safe for concurrent use. | 	dupCheck *dupCheck // For receiving from. Not safe for concurrent use. | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: dataStreamID, | 		StreamID: dataStreamID, | ||||||
| 		Counter:  atomic.AddUint64(p.counter, 1), | 		Counter:  atomic.AddUint64(p.counter, 1), | ||||||
| @@ -44,7 +44,7 @@ func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Decrypts and de-dups incoming data packets. | // Decrypts and de-dups incoming data packets. | ||||||
| func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { | func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { | ||||||
| 	dec, ok := p.DataCipher.Decrypt(enc, out) | 	dec, ok := p.DataCipher.Decrypt(enc, out) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errDecryptionFailed | 		return nil, errDecryptionFailed | ||||||
| @@ -58,21 +58,22 @@ func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) | |||||||
| } | } | ||||||
|  |  | ||||||
| // Peer must have a ControlCipher. | // Peer must have a ControlCipher. | ||||||
| func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { | ||||||
|  | 	tmp = pkt.Marshal(tmp) | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: controlStreamID, | 		StreamID: controlStreamID, | ||||||
| 		Counter:  atomic.AddUint64(p.counter, 1), | 		Counter:  atomic.AddUint64(p.counter, 1), | ||||||
| 		SourceIP: p.localIP, | 		SourceIP: p.localIP, | ||||||
| 		DestIP:   p.IP, | 		DestIP:   p.IP, | ||||||
| 	} | 	} | ||||||
| 	tmp = pkt.Marshal(tmp) |  | ||||||
| 	return p.ControlCipher.Encrypt(h, tmp, out) | 	return p.ControlCipher.Encrypt(h, tmp, out) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. | // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. | ||||||
| // | // | ||||||
| // This function also drops packets with duplicate sequence numbers. | // This function also drops packets with duplicate sequence numbers. | ||||||
| func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { | func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { | ||||||
| 	out, ok := p.ControlCipher.Decrypt(enc, tmp) | 	out, ok := p.ControlCipher.Decrypt(enc, tmp) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errDecryptionFailed | 		return nil, errDecryptionFailed | ||||||
| @@ -92,7 +93,7 @@ func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| type RoutingTable struct { | type routingTable struct { | ||||||
| 	// The LocalIP is the configured IP address of the local peer on the VPN. | 	// The LocalIP is the configured IP address of the local peer on the VPN. | ||||||
| 	// | 	// | ||||||
| 	// This value is constant. | 	// This value is constant. | ||||||
| @@ -106,21 +107,21 @@ type RoutingTable struct { | |||||||
| 	LocalAddr netip.AddrPort | 	LocalAddr netip.AddrPort | ||||||
|  |  | ||||||
| 	// The remote peer configurations. These are updated by | 	// The remote peer configurations. These are updated by | ||||||
| 	Peers [256]RemotePeer | 	Peers [256]remotePeer | ||||||
|  |  | ||||||
| 	// The current relay's VPN IP address, or zero if no relay is available. | 	// The current relay's VPN IP address, or zero if no relay is available. | ||||||
| 	RelayIP byte | 	RelayIP byte | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { | func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable { | ||||||
| 	rt := RoutingTable{ | 	rt := routingTable{ | ||||||
| 		LocalIP:   localIP, | 		LocalIP:   localIP, | ||||||
| 		LocalAddr: localAddr, | 		LocalAddr: localAddr, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for i := range rt.Peers { | 	for i := range rt.Peers { | ||||||
| 		counter := uint64(time.Now().Unix()<<30 + 1) | 		counter := uint64(time.Now().Unix()<<30 + 1) | ||||||
| 		rt.Peers[i] = RemotePeer{ | 		rt.Peers[i] = remotePeer{ | ||||||
| 			localIP:  localIP, | 			localIP:  localIP, | ||||||
| 			IP:       byte(i), | 			IP:       byte(i), | ||||||
| 			counter:  &counter, | 			counter:  &counter, | ||||||
| @@ -131,7 +132,7 @@ func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { | |||||||
| 	return rt | 	return rt | ||||||
| } | } | ||||||
|  |  | ||||||
| func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { | func (rt *routingTable) GetRelay() (remotePeer, bool) { | ||||||
| 	relay := rt.Peers[rt.RelayIP] | 	relay := rt.Peers[rt.RelayIP] | ||||||
| 	return relay, relay.Up && relay.Direct | 	return relay, relay.Up && relay.Direct | ||||||
| } | } | ||||||
|   | |||||||
| @@ -74,7 +74,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { | |||||||
| 	peer2 := p1.RT.Load().Peers[2] | 	peer2 := p1.RT.Load().Peers[2] | ||||||
| 	peer1 := p2.RT.Load().Peers[1] | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
| 	orig := PacketProbe{TraceID: newTraceID()} | 	orig := packetProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
| @@ -88,7 +88,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	dec, ok := ctrlMsg.(controlMsg[PacketProbe]) | 	dec, ok := ctrlMsg.(controlMsg[packetProbe]) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		t.Fatal(ctrlMsg) | 		t.Fatal(ctrlMsg) | ||||||
| 	} | 	} | ||||||
| @@ -108,7 +108,7 @@ func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { | |||||||
| 	peer2 := p1.RT.Load().Peers[2] | 	peer2 := p1.RT.Load().Peers[2] | ||||||
| 	peer1 := p2.RT.Load().Peers[1] | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
| 	orig := PacketProbe{TraceID: newTraceID()} | 	orig := packetProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
| @@ -131,7 +131,7 @@ func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { | |||||||
| 	peer2 := p1.RT.Load().Peers[2] | 	peer2 := p1.RT.Load().Peers[2] | ||||||
| 	peer1 := p2.RT.Load().Peers[1] | 	peer1 := p2.RT.Load().Peers[1] | ||||||
|  |  | ||||||
| 	orig := PacketProbe{TraceID: newTraceID()} | 	orig := packetProbe{TraceID: newTraceID()} | ||||||
|  |  | ||||||
| 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | 	enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,103 +0,0 @@ | |||||||
| package peer |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"log" |  | ||||||
| 	"sync/atomic" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"git.crumpington.com/lib/go/ratelimiter" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
| type Supervisor struct { |  | ||||||
| 	messages chan any // Incoming control messages. |  | ||||||
| 	peers    [256]PeerState |  | ||||||
| 	pubAddrs *pubAddrStore |  | ||||||
| 	rt       *atomic.Pointer[RoutingTable] |  | ||||||
| 	staged   RoutingTable |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewSupervisor( |  | ||||||
| 	sendControl func(RemotePeer, Marshaller), |  | ||||||
| 	privKey []byte, |  | ||||||
| 	rt *atomic.Pointer[RoutingTable], |  | ||||||
| ) *Supervisor { |  | ||||||
| 	s := &Supervisor{ |  | ||||||
| 		messages: make(chan any, 1024), |  | ||||||
| 		pubAddrs: newPubAddrStore(rt.Load().LocalAddr), |  | ||||||
| 		rt:       rt, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	routes := rt.Load() |  | ||||||
|  |  | ||||||
| 	for i := range s.peers { |  | ||||||
| 		state := &State{ |  | ||||||
| 			publish:           s.Publish, |  | ||||||
| 			sendControlPacket: sendControl, |  | ||||||
| 			localIP:           routes.LocalIP, |  | ||||||
| 			remoteIP:          byte(i), |  | ||||||
| 			privKey:           privKey, |  | ||||||
| 			localAddr:         routes.LocalAddr, |  | ||||||
| 			pubAddrs:          s.pubAddrs, |  | ||||||
| 			staged:            routes.Peers[i], |  | ||||||
| 			limiter: ratelimiter.New(ratelimiter.Config{ |  | ||||||
| 				FillPeriod:   20 * time.Millisecond, |  | ||||||
| 				MaxWaitCount: 1, |  | ||||||
| 			}), |  | ||||||
| 		} |  | ||||||
| 		s.peers[i] = state.OnPeerUpdate(nil) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return s |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *Supervisor) HandleControlMsg(msg any) { |  | ||||||
| 	select { |  | ||||||
| 	case s.messages <- msg: |  | ||||||
| 	default: |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *Supervisor) Run() { |  | ||||||
| 	for raw := range s.messages { |  | ||||||
| 		switch msg := raw.(type) { |  | ||||||
|  |  | ||||||
| 		case peerUpdateMsg: |  | ||||||
| 			s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) |  | ||||||
|  |  | ||||||
| 		case controlMsg[PacketSyn]: |  | ||||||
| 			if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { |  | ||||||
| 				s.peers[msg.SrcIP] = newState |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 		case controlMsg[PacketAck]: |  | ||||||
| 			s.peers[msg.SrcIP].OnAck(msg) |  | ||||||
|  |  | ||||||
| 		case controlMsg[PacketProbe]: |  | ||||||
| 			if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { |  | ||||||
| 				s.peers[msg.SrcIP] = newState |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 		case controlMsg[PacketLocalDiscovery]: |  | ||||||
| 			s.peers[msg.SrcIP].OnLocalDiscovery(msg) |  | ||||||
|  |  | ||||||
| 		case pingTimerMsg: |  | ||||||
| 			s.pubAddrs.Clean() |  | ||||||
| 			for i := range s.peers { |  | ||||||
| 				if newState := s.peers[i].OnPingTimer(); newState != nil { |  | ||||||
| 					s.peers[i] = newState |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 		default: |  | ||||||
| 			log.Printf("WARNING: unknown message type: %+v", msg) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *Supervisor) Publish(rp RemotePeer) { |  | ||||||
| 	s.staged.Peers[rp.IP] = rp |  | ||||||
| 	rt := s.staged // Copy. |  | ||||||
| 	s.rt.Store(&rt) |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user