diff --git a/cmd/vppn/main.go b/cmd/vppn/main.go index 8c04016..5daa907 100644 --- a/cmd/vppn/main.go +++ b/cmd/vppn/main.go @@ -2,10 +2,10 @@ package main import ( "log" - "vppn/node" + "vppn/peer" ) func main() { log.SetFlags(0) - node.Main() + peer.Main() } diff --git a/node/main.go b/node/main.go index 8e53cb4..78611a8 100644 --- a/node/main.go +++ b/node/main.go @@ -258,7 +258,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { default: log.Printf("Dropping control packet.") } - } func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { diff --git a/peer/connreader.go b/peer/connreader.go index b127030..37a4c87 100644 --- a/peer/connreader.go +++ b/peer/connreader.go @@ -12,7 +12,7 @@ type connReader struct { sender encryptedPacketSender super controlMsgHandler localIP byte - peers [256]*atomic.Pointer[RemotePeer] + peers [256]*atomic.Pointer[remotePeer] buf []byte decBuf []byte @@ -24,7 +24,7 @@ func newConnReader( sender encryptedPacketSender, super controlMsgHandler, localIP byte, - peers [256]*atomic.Pointer[RemotePeer], + peers [256]*atomic.Pointer[remotePeer], ) *connReader { return &connReader{ conn: conn, @@ -79,7 +79,7 @@ func (r *connReader) handleNextPacket() { } func (r *connReader) handleControlPacket( - peer *RemotePeer, + peer *remotePeer, addr netip.AddrPort, h header, enc []byte, @@ -102,7 +102,7 @@ func (r *connReader) handleControlPacket( r.super.HandleControlMsg(msg) } -func (r *connReader) handleDataPacket(peer *RemotePeer, h header, enc []byte) { +func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) { if !peer.Up { r.logf("Not connected (recv).") return diff --git a/peer/connreader2.go b/peer/connreader2.go index d9feab8..9e870d7 100644 --- a/peer/connreader2.go +++ b/peer/connreader2.go @@ -12,12 +12,12 @@ type ConnReader struct { readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) // Output - iface io.Writer - forwardData func(ip byte, pkt []byte) - handleControlMsg func(pkt any) + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + iface io.Writer + handleControlMsg func(fromIP byte, pkt any) localIP byte - rt *atomic.Pointer[RoutingTable] + rt *atomic.Pointer[routingTable] buf []byte decBuf []byte @@ -25,15 +25,15 @@ type ConnReader struct { func NewConnReader( readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), iface io.Writer, - forwardData func(ip byte, pkt []byte), - handleControlMsg func(pkt any), - rt *atomic.Pointer[RoutingTable], + handleControlMsg func(fromIP byte, pkt any), + rt *atomic.Pointer[routingTable], ) *ConnReader { return &ConnReader{ readFromUDPAddrPort: readFromUDPAddrPort, + writeToUDPAddrPort: writeToUDPAddrPort, iface: iface, - forwardData: forwardData, handleControlMsg: handleControlMsg, localIP: rt.Load().LocalIP, rt: rt, @@ -50,7 +50,9 @@ func (r *ConnReader) Run() { func (r *ConnReader) handleNextPacket() { buf := r.buf[:bufferSize] + log.Printf("Getting next packet...") n, remoteAddr, err := r.readFromUDPAddrPort(buf) + log.Printf("Packet from %v...", remoteAddr) if err != nil { log.Fatalf("Failed to read from UDP port: %v", err) } @@ -64,14 +66,14 @@ func (r *ConnReader) handleNextPacket() { buf = buf[:n] h := parseHeader(buf) - peer := r.rt.Load().Peers[h.SourceIP] - //peer := rt.Peers[h.SourceIP] + rt := r.rt.Load() + peer := rt.Peers[h.SourceIP] switch h.StreamID { case controlStreamID: r.handleControlPacket(remoteAddr, peer, h, buf) case dataStreamID: - r.handleDataPacket(peer, h, buf) + r.handleDataPacket(rt, peer, h, buf) default: r.logf("Unknown stream ID: %d", h.StreamID) } @@ -79,7 +81,7 @@ func (r *ConnReader) handleNextPacket() { func (r *ConnReader) handleControlPacket( remoteAddr netip.AddrPort, - peer RemotePeer, + peer remotePeer, h header, enc []byte, ) { @@ -98,11 +100,12 @@ func (r *ConnReader) handleControlPacket( return } - r.handleControlMsg(msg) + r.handleControlMsg(h.SourceIP, msg) } func (r *ConnReader) handleDataPacket( - peer RemotePeer, + rt *routingTable, + peer remotePeer, h header, enc []byte, ) { @@ -124,7 +127,13 @@ func (r *ConnReader) handleDataPacket( return } - r.forwardData(h.DestIP, data) + relay, ok := rt.GetRelay() + if !ok { + r.logf("Relay not available.") + return + } + + r.writeToUDPAddrPort(data, relay.DirectAddr) } func (r *ConnReader) logf(format string, args ...any) { diff --git a/peer/connreader_test.go b/peer/connreader_test.go deleted file mode 100644 index 714f6f3..0000000 --- a/peer/connreader_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package peer - -import ( - "bytes" - "crypto/rand" - "net/netip" - "reflect" - "sync/atomic" - "testing" -) - -type mockIfWriter struct { - Written [][]byte -} - -func (w *mockIfWriter) Write(b []byte) (int, error) { - w.Written = append(w.Written, bytes.Clone(b)) - return len(b), nil -} - -type mockEncryptedPacket struct { - Packet []byte - Route *RemotePeer -} - -type mockEncryptedPacketSender struct { - Sent []mockEncryptedPacket -} - -func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *RemotePeer) { - m.Sent = append(m.Sent, mockEncryptedPacket{ - Packet: bytes.Clone(pkt), - Route: route, - }) -} - -type mockControlMsgHandler struct { - Messages []any -} - -func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { - m.Messages = append(m.Messages, pkt) -} - -type udpPipe struct { - packets chan []byte -} - -func newUDPPipe() *udpPipe { - return &udpPipe{make(chan []byte, 1024)} -} - -func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - p.packets <- bytes.Clone(b) - return len(b), nil -} - -func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { - packet := <-p.packets - copy(b, packet) - return len(packet), netip.AddrPort{}, nil -} - -type connReaderTestHarness struct { - Pipe *udpPipe - R *connReader - WRemote *connWriter - WRelayRemote *connWriter - Remote *RemotePeer - RelayRemote *RemotePeer - IFace *mockIfWriter - Sender *mockEncryptedPacketSender - Super *mockControlMsgHandler -} - -// Peer 2 is indirect, peer 3 is direct. -func newConnReadeTestHarness() (h connReaderTestHarness) { - pipe := newUDPPipe() - routes := [256]*atomic.Pointer[RemotePeer]{} - for i := range routes { - routes[i] = &atomic.Pointer[RemotePeer]{} - routes[i].Store(&RemotePeer{}) - } - - local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() - routes[2].Store(local) - routes[3].Store(relayLocal) - - h.Pipe = pipe - h.WRemote = newConnWriter(pipe, 2) - h.WRelayRemote = newConnWriter(pipe, 3) - - h.Remote = remote - h.RelayRemote = relayRemote - h.IFace = &mockIfWriter{} - h.Sender = &mockEncryptedPacketSender{} - h.Super = &mockControlMsgHandler{} - h.R = newConnReader( - pipe, - h.IFace, - h.Sender, - h.Super, - 1, - routes) - return h -} - -// Testing that we can receive a control packet. -func TestConnReader_handleControlPacket(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - h.WRemote.SendControlPacket(pkt, h.Remote) - - h.R.handleNextPacket() - - if len(h.Super.Messages) != 1 { - t.Fatal(h.Super.Messages) - } - - msg := h.Super.Messages[0].(controlMsg[PacketSyn]) - if !reflect.DeepEqual(pkt, msg.Packet) { - t.Fatal(msg.Packet) - } -} - -// Testing that a short packet is ignored. -func TestConnReader_handleNextPacket_short(t *testing.T) { - h := newConnReadeTestHarness() - - h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) - h.R.handleNextPacket() - - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that a packet with an unexpected stream ID is ignored. -func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) - var header header - header.Parse(encrypted) - header.StreamID = 100 - header.Marshal(encrypted) - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that control packet without matching control cipher is ignored. -func TestConnReader_handleControlPacket_noCipher(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - //encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) - encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) - var header header - header.Parse(encrypted) - header.SourceIP = 10 - header.Marshal(encrypted) - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that control packet with incrrect destination IP is ignored. -func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) - var header header - header.Parse(encrypted) - header.DestIP++ - header.Marshal(encrypted) - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that modified control packet is ignored. -func TestConnReader_handleControlPacket_modified(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketSyn{TraceID: 1234} - - encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) - encrypted[len(encrypted)-1]++ - - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -type unknownPacket struct{} - -func (p unknownPacket) Marshal(buf []byte) []byte { - buf = buf[:1] - buf[0] = 100 - return buf -} - -// Testing that an empty control packet is ignored. -func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := unknownPacket{} - - encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} - -// Testing that a duplicate control packet is ignored. -func TestConnReader_handleControlPacket_duplicate(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := PacketAck{TraceID: 1234} - - h.WRemote.SendControlPacket(pkt, h.Remote) - *h.Remote.counter = *h.Remote.counter - 1 - h.WRemote.SendControlPacket(pkt, h.Remote) - - h.R.handleNextPacket() - h.R.handleNextPacket() - - if len(h.Super.Messages) != 1 { - t.Fatal(h.Super.Messages) - } - - msg := h.Super.Messages[0].(controlMsg[PacketAck]) - if !reflect.DeepEqual(pkt, msg.Packet) { - t.Fatal(msg.Packet) - } -} - -// Testing that we can receive a data packet. -func TestConnReader_handleDataPacket(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.WRemote.SendDataPacket(pkt, h.Remote) - - h.R.handleNextPacket() - - if len(h.IFace.Written) != 1 { - t.Fatal(h.IFace.Written) - } - - if !bytes.Equal(pkt, h.IFace.Written[0]) { - t.Fatal(h.IFace.Written) - } -} - -// Testing that data packet is ignored if route isn't up. -func TestConnReader_handleDataPacket_routeDown(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.WRemote.SendDataPacket(pkt, h.Remote) - route := h.R.peers[2].Load() - route.Up = false - - h.R.handleNextPacket() - - if len(h.IFace.Written) != 0 { - t.Fatal(h.IFace.Written) - } -} - -// Testing that a duplicate data packet is ignored. -func TestConnReader_handleDataPacket_duplicate(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 123) - - h.WRemote.SendDataPacket(pkt, h.Remote) - *h.Remote.counter = *h.Remote.counter - 1 - h.WRemote.SendDataPacket(pkt, h.Remote) - - h.R.handleNextPacket() - h.R.handleNextPacket() - - if len(h.IFace.Written) != 1 { - t.Fatal(h.IFace.Written) - } - - if !bytes.Equal(pkt, h.IFace.Written[0]) { - t.Fatal(h.IFace.Written) - } -} - -// Testing that we can relay a data packet. -func TestConnReader_handleDataPacket_relay(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.RelayRemote.IP = 3 - h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) - h.R.handleNextPacket() - - if len(h.Sender.Sent) != 1 { - t.Fatal(h.Sender.Sent) - } - -} - -// Testing that we drop a relayed packet if destination is down. -func TestConnReader_handleDataPacket_relayDown(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := make([]byte, 1024) - rand.Read(pkt) - - h.RelayRemote.IP = 3 - relay := h.R.peers[3].Load() - relay.Up = false - - h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote) - h.R.handleNextPacket() - - if len(h.Sender.Sent) != 0 { - t.Fatal(h.Sender.Sent) - } -} diff --git a/peer/connwriter.go b/peer/connwriter.go deleted file mode 100644 index 8a09e35..0000000 --- a/peer/connwriter.go +++ /dev/null @@ -1,80 +0,0 @@ -package peer - -import ( - "log" - "net/netip" - "sync" -) - -// ---------------------------------------------------------------------------- - -type connWriter struct { - localIP byte - conn udpWriter - - // For sending control packets. - cBuf1 []byte - cBuf2 []byte - - // For sending data packets. - dBuf1 []byte - dBuf2 []byte - - // Lock around for sending on UDP Conn. - wLock sync.Mutex -} - -func newConnWriter(conn udpWriter, localIP byte) *connWriter { - w := &connWriter{ - localIP: localIP, - conn: conn, - cBuf1: make([]byte, bufferSize), - cBuf2: make([]byte, bufferSize), - dBuf1: make([]byte, bufferSize), - dBuf2: make([]byte, bufferSize), - } - return w -} - -// Not safe for concurrent use. Should only be called by supervisor. -func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) { - enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) - w.writeTo(enc, peer.DirectAddr) -} - -// Relay control packet. Peer must not be nil. -func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) { - enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2) - enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1) - w.writeTo(enc, relay.DirectAddr) -} - -// Not safe for concurrent use. Should only be called by ifReader. -func (w *connWriter) SendDataPacket(pkt []byte, peer *RemotePeer) { - enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) - w.writeTo(enc, peer.DirectAddr) -} - -// Relay a data packet. Peer must not be nil. -func (w *connWriter) RelayDataPacket(pkt []byte, peer, relay *RemotePeer) { - enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1) - enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.dBuf2) - w.writeTo(enc, relay.DirectAddr) -} - -// Safe for concurrent use. Should only be called by connReader. -// -// This function will send pkt to the peer directly. This is used when a peer -// is acting as a relay and is forwarding already encrypted data for another -// peer. -func (w *connWriter) SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) { - w.writeTo(pkt, peer.DirectAddr) -} - -func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { - w.wLock.Lock() - if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { - log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) - } - w.wLock.Unlock() -} diff --git a/peer/connwriter2.go b/peer/connwriter2.go deleted file mode 100644 index e58250d..0000000 --- a/peer/connwriter2.go +++ /dev/null @@ -1,109 +0,0 @@ -package peer - -import ( - "log" - "net/netip" - "sync" - "sync/atomic" -) - -type ConnWriter struct { - wLock sync.Mutex // Lock around for sending on UDP Conn. - - // Output. - writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) - - // Shared state. - rt *atomic.Pointer[RoutingTable] - - // For sending control packets. - cBuf1 []byte - cBuf2 []byte - - // For sending data packets. - dBuf1 []byte - dBuf2 []byte -} - -func NewConnWriter( - writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), - rt *atomic.Pointer[RoutingTable], -) *ConnWriter { - return &ConnWriter{ - writeToUDPAddrPort: writeToUDPAddrPort, - rt: rt, - cBuf1: newBuf(), - cBuf2: newBuf(), - dBuf1: newBuf(), - dBuf2: newBuf(), - } -} - -// Called by ConnReader to forward already encrypted bytes to another peer. -func (w *ConnWriter) Forward(ip byte, pkt []byte) { - peer := w.rt.Load().Peers[ip] - if !(peer.Up && peer.Direct) { - w.logf("Failed to forward to %d.", ip) - return - } - w.writeTo(pkt, peer.DirectAddr) -} - -// Called by IFReader to send data. Encryption will be applied, and packet will -// be relayed if appropriate. -func (w *ConnWriter) WriteData(ip byte, pkt []byte) { - rt := w.rt.Load() - peer := rt.Peers[ip] - if !peer.Up { - w.logf("Failed to send data to %d.", ip) - return - } - - enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1) - - if peer.Direct { - w.writeTo(enc, peer.DirectAddr) - return - } - - relay, ok := rt.GetRelay() - if !ok { - w.logf("Failed to send data to %d. No relay.", ip) - return - } - - enc = relay.EncryptDataPacket(ip, enc, w.dBuf2) - w.writeTo(enc, relay.DirectAddr) -} - -// Called by Supervisor to send control packets. -func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) { - enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1) - - if peer.Direct { - w.writeTo(enc, peer.DirectAddr) - return - } - - rt := w.rt.Load() - relay, ok := rt.GetRelay() - if !ok { - w.logf("Failed to send control to %d. No relay.", peer.IP) - return - } - - enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2) - w.writeTo(enc, relay.DirectAddr) -} - -func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) { - w.wLock.Lock() - if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil { - w.logf("Failed to write to UDP port: %v", err) - } - w.wLock.Unlock() -} - -func (w *ConnWriter) logf(s string, args ...any) { - log.Printf("[ConnWriter] "+s, args...) -} diff --git a/peer/connwriter2_test.go b/peer/connwriter2_test.go deleted file mode 100644 index f0bb00f..0000000 --- a/peer/connwriter2_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package peer - -import ( - "testing" -) - -func TestConnWriter_WriteData_direct(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteData_peerNotUp(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - p1.RT.Load().Peers[2].Up = false - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteData_relay(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().RelayIP = 3 - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p3.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().Peers[3].Up = false - p1.RT.Load().RelayIP = 3 - - in := RandPacket() - p1.ConnWriter.WriteData(2, in) - - packets := p3.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteControl_direct(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - orig := PacketProbe{TraceID: newTraceID()} - - p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteControl_relay(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().RelayIP = 3 - - orig := PacketProbe{TraceID: newTraceID()} - - p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) - - packets := p3.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) { - p1, _, p3 := NewPeersForTesting() - - p1.RT.Load().Peers[2].Direct = false - p1.RT.Load().Peers[3].Up = false - p1.RT.Load().RelayIP = 3 - - orig := PacketProbe{TraceID: newTraceID()} - - p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig) - - packets := p3.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter__Forward(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - in := RandPacket() - p1.ConnWriter.Forward(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestConnWriter__Forward_notUp(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - p1.RT.Load().Peers[2].Up = false - - in := RandPacket() - p1.ConnWriter.Forward(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} - -func TestConnWriter__Forward_notDirect(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - p1.RT.Load().Peers[2].Direct = false - - in := RandPacket() - p1.ConnWriter.Forward(2, in) - - packets := p2.Conn.Packets() - if len(packets) != 0 { - t.Fatal(packets) - } -} diff --git a/peer/connwriter_test.go b/peer/connwriter_test.go deleted file mode 100644 index d8c0365..0000000 --- a/peer/connwriter_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package peer - -import ( - "bytes" - "net/netip" - "testing" -) - -// ---------------------------------------------------------------------------- - -type testUDPPacket struct { - Addr netip.AddrPort - Data []byte -} - -type testUDPAddrPortWriter struct { - written []testUDPPacket -} - -func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - w.written = append(w.written, testUDPPacket{ - Addr: addr, - Data: bytes.Clone(b), - }) - return len(b), nil -} - -func (w *testUDPAddrPortWriter) Written() []testUDPPacket { - out := w.written - w.written = []testUDPPacket{} - return out -} - -// ---------------------------------------------------------------------------- - -type testPacket string - -func (p testPacket) Marshal(b []byte) []byte { - b = b[:len(p)] - copy(b, []byte(p)) - return b -} - -// ---------------------------------------------------------------------------- - -func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *RemotePeer) { - localKeys := generateKeys() - remoteKeys := generateKeys() - - local = NewRemotePeer(2) - local.Up = true - local.Relay = false - local.PubSignKey = remoteKeys.PubSignKey - local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) - local.DataCipher = newDataCipher() - local.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) - - remote = NewRemotePeer(1) - remote.Up = true - remote.Relay = false - remote.PubSignKey = localKeys.PubSignKey - remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) - remote.DataCipher = local.DataCipher - remote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) - - rLocalKeys := generateKeys() - rRemoteKeys := generateKeys() - - relayLocal = NewRemotePeer(3) - relayLocal.Up = true - relayLocal.Relay = true - relayLocal.Direct = true - relayLocal.PubSignKey = rRemoteKeys.PubSignKey - relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) - relayLocal.DataCipher = newDataCipher() - relayLocal.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) - - relayRemote = NewRemotePeer(1) - relayRemote.Up = true - relayRemote.Relay = false - relayRemote.Direct = true - relayRemote.PubSignKey = rLocalKeys.PubSignKey - relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) - relayRemote.DataCipher = relayLocal.DataCipher - relayRemote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) - - return -} - -// ---------------------------------------------------------------------------- - -// Testing if we can send a control packet directly to the remote route. -func TestConnWriter_SendControlPacket_direct(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - route.Direct = true - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.SendControlPacket(in, route) - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - if string(dec) != string(in) { - t.Fatal(dec) - } -} - -// Testing if we can relay a packet via an intermediary. -func TestConnWriter_RelayControlPacket_relay(t *testing.T) { - route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := testPacket("hello world!") - - w.RelayControlPacket(in, route, relay) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != relay.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if string(dec2) != string(in) { - t.Fatal(dec2) - } -} - -// Testing that we can send a data packet directly to a remote route. -func TestConnWriter_SendDataPacket_direct(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - route.Direct = true - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - - in := []byte("hello world!") - w.SendDataPacket(in, route) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(dec, in) { - t.Fatal(dec) - } -} - -// Testing that we can relay a data packet via a relay. -func TestConnWriter_RelayDataPacket_relay(t *testing.T) { - route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.RelayDataPacket(in, route, relay) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != relay.DirectAddr { - t.Fatal(out[0]) - } - - dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(dec2, in) { - t.Fatal(dec2) - } -} - -// Testing that we can send an already encrypted packet. -func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { - route, rRoute, _, _ := testConnWriter_getTestRoutes() - - writer := &testUDPAddrPortWriter{} - w := newConnWriter(writer, rRoute.IP) - in := []byte("Hello world!") - - w.SendEncryptedDataPacket(in, route) - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - if out[0].Addr != route.DirectAddr { - t.Fatal(out[0]) - } - - if !bytes.Equal(out[0].Data, in) { - t.Fatal(out[0]) - } -} diff --git a/peer/controlmessage.go b/peer/controlmessage.go index 7180dd0..09935ab 100644 --- a/peer/controlmessage.go +++ b/peer/controlmessage.go @@ -17,25 +17,25 @@ type controlMsg[T any] struct { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { switch buf[0] { - case PacketTypeSyn: - packet, err := ParsePacketSyn(buf) - return controlMsg[PacketSyn]{ + case packetTypeSyn: + packet, err := parsePacketSyn(buf) + return controlMsg[packetSyn]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case PacketTypeAck: - packet, err := ParsePacketAck(buf) - return controlMsg[PacketAck]{ + case packetTypeAck: + packet, err := parsePacketAck(buf) + return controlMsg[packetAck]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, }, err - case PacketTypeProbe: - packet, err := ParsePacketProbe(buf) - return controlMsg[PacketProbe]{ + case packetTypeProbe: + packet, err := parsePacketProbe(buf) + return controlMsg[packetProbe]{ SrcIP: srcIP, SrcAddr: srcAddr, Packet: packet, diff --git a/peer/crypto.go b/peer/crypto.go index dcc042b..e8afe60 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -36,7 +36,7 @@ func generateKeys() cryptoKeys { // Peer must have a ControlCipher. func encryptControlPacket( localIP byte, - peer *RemotePeer, + peer *remotePeer, pkt Marshaller, tmp []byte, out []byte, @@ -55,7 +55,7 @@ func encryptControlPacket( // // This function also drops packets with duplicate sequence numbers. func decryptControlPacket( - peer *RemotePeer, + peer *remotePeer, fromAddr netip.AddrPort, h header, encrypted []byte, @@ -83,7 +83,7 @@ func decryptControlPacket( func encryptDataPacket( localIP byte, destIP byte, - peer *RemotePeer, + peer *remotePeer, data []byte, out []byte, ) []byte { @@ -98,7 +98,7 @@ func encryptDataPacket( // Decrypts and de-dups incoming data packets. func decryptDataPacket( - peer *RemotePeer, + peer *remotePeer, h header, encrypted []byte, out []byte, diff --git a/peer/crypto_test.go b/peer/crypto_test.go index 824a43a..57adfd2 100644 --- a/peer/crypto_test.go +++ b/peer/crypto_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func newRoutePairForTesting() (*RemotePeer, *RemotePeer) { +func newRoutePairForTesting() (*remotePeer, *remotePeer) { keys1 := generateKeys() keys2 := generateKeys() @@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) { out = make([]byte, bufferSize) ) - in := PacketSyn{ + in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) { t.Fatal(err) } - msg, ok := iMsg.(controlMsg[PacketSyn]) + msg, ok := iMsg.(controlMsg[packetSyn]) if !ok { t.Fatal(ok) } @@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) { out = make([]byte, bufferSize) ) - in := PacketSyn{ + in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { out = make([]byte, bufferSize) ) - in := PacketSyn{ + in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, @@ -109,24 +109,25 @@ func TestDecryptControlPacket_duplicate(t *testing.T) { } } -func TestDecryptControlPacket_invalidPacket(t *testing.T) { - var ( - r1, r2 = newRoutePairForTesting() - tmp = make([]byte, bufferSize) - out = make([]byte, bufferSize) - ) +/* + func TestDecryptControlPacket_invalidPacket(t *testing.T) { + var ( + r1, r2 = newRoutePairForTesting() + tmp = make([]byte, bufferSize) + out = make([]byte, bufferSize) + ) - in := testPacket("hello!") + in := testPacket("hello!") - enc := encryptControlPacket(r1.IP, r2, in, tmp, out) - h := parseHeader(enc) + enc := encryptControlPacket(r1.IP, r2, in, tmp, out) + h := parseHeader(enc) - _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) - if !errors.Is(err, errUnknownPacketType) { - t.Fatal(err) + _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) + if !errors.Is(err, errUnknownPacketType) { + t.Fatal(err) + } } -} - +*/ func TestDecryptDataPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() diff --git a/peer/hubpoller.go b/peer/hubpoller.go index f608bd5..2b50495 100644 --- a/peer/hubpoller.go +++ b/peer/hubpoller.go @@ -11,15 +11,21 @@ import ( ) type hubPoller struct { - client *http.Client - req *http.Request - versions [256]int64 - localIP byte - netName string - super controlMsgHandler + client *http.Client + req *http.Request + versions [256]int64 + localIP byte + netName string + handleControlMsg func(fromIP byte, msg any) } -func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { +func newHubPoller( + localIP byte, + netName, + hubURL, + apiKey string, + handleControlMsg func(byte, any), +) (*hubPoller, error) { u, err := url.Parse(hubURL) if err != nil { return nil, err @@ -36,11 +42,11 @@ func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsg req.SetBasicAuth("", apiKey) return &hubPoller{ - client: client, - req: req, - localIP: localIP, - netName: netName, - super: super, + client: client, + req: req, + localIP: localIP, + netName: netName, + handleControlMsg: handleControlMsg, }, nil } @@ -90,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) { for i, peer := range state.Peers { if i != int(hp.localIP) { if peer == nil || peer.Version != hp.versions[i] { - hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) + hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) if peer != nil { hp.versions[i] = peer.Version } diff --git a/peer/ifreader.go b/peer/ifreader.go deleted file mode 100644 index 79ff441..0000000 --- a/peer/ifreader.go +++ /dev/null @@ -1,100 +0,0 @@ -package peer - -import ( - "io" - "log" - "sync/atomic" -) - -type ifReader struct { - iface io.Reader - peers [256]*atomic.Pointer[RemotePeer] - relay *atomic.Pointer[RemotePeer] - sender dataPacketSender -} - -func newIFReader( - iface io.Reader, - peers [256]*atomic.Pointer[RemotePeer], - relay *atomic.Pointer[RemotePeer], - sender dataPacketSender, -) *ifReader { - return &ifReader{ - iface: iface, - peers: peers, - relay: relay, - sender: sender, - } -} - -func (r *ifReader) Run() { - var ( - packet = make([]byte, bufferSize) - remoteIP byte - ok bool - ) - - for { - packet = r.readNextPacket(packet) - if remoteIP, ok = r.parsePacket(packet); ok { - r.sendPacket(packet, remoteIP) - } - } -} - -func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { - peer := r.peers[remoteIP].Load() - if !peer.Up { - log.Printf("Peer not connected: %d", remoteIP) - return - } - - // Direct path => early return. - if peer.Direct { - r.sender.SendDataPacket(pkt, peer) - return - } - - if relay := r.relay.Load(); relay != nil && relay.Up { - r.sender.RelayDataPacket(pkt, peer, relay) - } -} - -// Get next packet, returning packet, and destination ip. -func (r *ifReader) readNextPacket(buf []byte) []byte { - n, err := r.iface.Read(buf[:cap(buf)]) - if err != nil { - log.Fatalf("Failed to read from interface: %v", err) - } - - return buf[:n] -} - -func (r *ifReader) parsePacket(buf []byte) (byte, bool) { - n := len(buf) - if n == 0 { - return 0, false - } - - version := buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - log.Printf("Short IPv4 packet: %d", len(buf)) - return 0, false - } - return buf[19], true - - case 6: - if len(buf) < 40 { - log.Printf("Short IPv6 packet: %d", len(buf)) - return 0, false - } - return buf[39], true - - default: - log.Printf("Invalid IP packet version: %v", version) - return 0, false - } -} diff --git a/peer/ifreader2.go b/peer/ifreader2.go index c390e8f..22bd7cf 100644 --- a/peer/ifreader2.go +++ b/peer/ifreader2.go @@ -3,22 +3,24 @@ package peer import ( "io" "log" + "net/netip" + "sync/atomic" ) type IFReader struct { - iface io.Reader - connWriter interface { - WriteData(ip byte, pkt []byte) - } + iface io.Reader + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + rt *atomic.Pointer[routingTable] + buf1 []byte + buf2 []byte } func NewIFReader( iface io.Reader, - connWriter interface { - WriteData(ip byte, pkt []byte) - }, + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[routingTable], ) *IFReader { - return &IFReader{iface, connWriter} + return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()} } func (r *IFReader) Run() { @@ -30,9 +32,32 @@ func (r *IFReader) Run() { func (r *IFReader) handleNextPacket(packet []byte) { packet = r.readNextPacket(packet) - if remoteIP, ok := r.parsePacket(packet); ok { - r.connWriter.WriteData(remoteIP, packet) + remoteIP, ok := r.parsePacket(packet) + if !ok { + return } + + rt := r.rt.Load() + peer := rt.Peers[remoteIP] + if !peer.Up { + r.logf("Peer %d not up.", peer.IP) + return + } + + enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1) + if peer.Direct { + r.writeToUDPAddrPort(enc, peer.DirectAddr) + return + } + + relay, ok := rt.GetRelay() + if !ok { + r.logf("Relay not available for peer %d.", peer.IP) + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2) + r.writeToUDPAddrPort(enc, relay.DirectAddr) } func (r *IFReader) readNextPacket(buf []byte) []byte { diff --git a/peer/ifreader2_test.go b/peer/ifreader2_test.go index 779cf49..92ec5ac 100644 --- a/peer/ifreader2_test.go +++ b/peer/ifreader2_test.go @@ -1,9 +1,6 @@ package peer -import ( - "testing" -) - +/* func TestIFReader_IPv4(t *testing.T) { p1, p2, _ := NewPeersForTesting() @@ -81,3 +78,4 @@ func TestIFReader_parsePacket_shortIPv6(t *testing.T) { t.Fatal(ip, ok) } } +*/ diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go deleted file mode 100644 index 620d2b1..0000000 --- a/peer/ifreader_test.go +++ /dev/null @@ -1,232 +0,0 @@ -package peer - -import ( - "bytes" - "reflect" - "sync/atomic" - "testing" -) - -// Test that we parse IPv4 packets correctly. -func TestIFReader_parsePacket_ipv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 1234) - pkt[0] = 4 << 4 - pkt[19] = 128 - - if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { - t.Fatal(ip, ok) - } -} - -// Test that we parse IPv6 packets correctly. -func TestIFReader_parsePacket_ipv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 1234) - pkt[0] = 6 << 4 - pkt[39] = 42 - - if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { - t.Fatal(ip, ok) - } -} - -/* -// Test that empty packets work as expected. -func TestIFReader_parsePacket_emptyPacket(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 0) - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that invalid IP versions fail. -func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - for i := byte(1); i < 16; i++ { - if i == 4 || i == 6 { - continue - } - pkt := make([]byte, 1234) - pkt[0] = i << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(i, ip, ok) - } - } -} - -// Test that short IPv4 packets fail. -func TestIFReader_parsePacket_shortIPv4(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 19) - pkt[0] = 4 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that short IPv6 packets fail. -func TestIFReader_parsePacket_shortIPv6(t *testing.T) { - r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - - pkt := make([]byte, 39) - pkt[0] = 6 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -// Test that we can read a packet. -func TestIFReader_readNextpacket(t *testing.T) { - in, out := net.Pipe() - r := newIFReader(out, [256]*atomic.Pointer[RemotePeer]{}, nil, nil) - defer in.Close() - defer out.Close() - - go in.Write([]byte("hello world!")) - - pkt := r.readNextPacket(make([]byte, bufferSize)) - if !bytes.Equal(pkt, []byte("hello world!")) { - t.Fatalf("%s", pkt) - } -} -*/ -// ---------------------------------------------------------------------------- - -type sentPacket struct { - Relayed bool - Packet []byte - Route RemotePeer - Relay RemotePeer -} - -type sendPacketTestHarness struct { - Packets []sentPacket -} - -func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *RemotePeer) { - h.Packets = append(h.Packets, sentPacket{ - Packet: bytes.Clone(pkt), - Route: *route, - }) -} - -func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *RemotePeer) { - h.Packets = append(h.Packets, sentPacket{ - Relayed: true, - Packet: bytes.Clone(pkt), - Route: *route, - Relay: *relay, - }) -} - -func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { - h := &sendPacketTestHarness{} - - routes := [256]*atomic.Pointer[RemotePeer]{} - for i := range routes { - routes[i] = &atomic.Pointer[RemotePeer]{} - routes[i].Store(&RemotePeer{}) - } - relay := &atomic.Pointer[RemotePeer]{} - r := newIFReader(nil, routes, relay, h) - return r, h -} - -// Testing that we can send a packet directly. -func TestIFReader_sendPacket_direct(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Up = true - route.Direct = true - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 1 { - t.Fatal(h.Packets) - } - - expected := sentPacket{ - Relayed: false, - Packet: in, - Route: *route, - } - - if !reflect.DeepEqual(h.Packets[0], expected) { - t.Fatal(h.Packets[0]) - } -} - -// Testing that we don't send a packet if route isn't up. -func TestIFReader_sendPacket_directNotUp(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Direct = true - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 0 { - t.Fatal(h.Packets) - } -} - -// Testing that we can send a packet via a relay. -func TestIFReader_sendPacket_relayed(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Up = true - route.Direct = false - - relay := r.peers[3].Load() - r.relay.Store(relay) - relay.Up = true - relay.Direct = true - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 1 { - t.Fatal(h.Packets) - } - - expected := sentPacket{ - Relayed: true, - Packet: in, - Route: *route, - Relay: *relay, - } - - if !reflect.DeepEqual(h.Packets[0], expected) { - t.Fatal(h.Packets[0]) - } -} - -// Testing that we don't try to send on a nil relay IP. -func TestIFReader_sendPacket_nilRealy(t *testing.T) { - r, h := newIFReaderForSendPacketTesting() - - route := r.peers[2].Load() - route.Up = true - route.Direct = false - - in := []byte("hello world") - - r.sendPacket(in, 2) - if len(h.Packets) != 0 { - t.Fatal(h.Packets) - } -} diff --git a/peer/interface.go b/peer/interface.go new file mode 100644 index 0000000..7035c43 --- /dev/null +++ b/peer/interface.go @@ -0,0 +1,177 @@ +package peer + +import ( + "fmt" + "io" + "log" + "net" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +// Get next packet, returning packet, ip, and possible error. +func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) { + var ( + version byte + ip byte + ) + for { + n, err := iface.Read(buf[:cap(buf)]) + if err != nil { + return nil, ip, err + } + + buf = buf[:n] + version = buf[0] >> 4 + + switch version { + case 4: + if n < 20 { + log.Printf("Short IPv4 packet: %d", len(buf)) + continue + } + ip = buf[19] + + case 6: + if len(buf) < 40 { + log.Printf("Short IPv6 packet: %d", len(buf)) + continue + } + ip = buf[39] + + default: + log.Printf("Invalid IP packet version: %v", version) + continue + } + + return buf, ip, nil + } +} + +func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { + if len(network) != 4 { + return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) + } + ip := net.IPv4(network[0], network[1], network[2], localIP) + + ////////////////////////// + // Create TUN Interface // + ////////////////////////// + + tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open TUN device: %w", err) + } + + // New interface request. + req, err := unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) + } + + // Flags: + // + // IFF_NO_PI => don't add packet info data to packets sent to the interface. + // IFF_TUN => create a TUN device handling IP packets. + req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) + + err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) + if err != nil { + return nil, fmt.Errorf("failed to set TUN device settings: %w", err) + } + + // Name may not be exactly the same? + name = req.Name() + + ///////////// + // Set MTU // + ///////////// + + // We need a socket file descriptor to set other options for some reason. + sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return nil, fmt.Errorf("failed to open socket: %w", err) + } + defer unix.Close(sockFD) + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create MTU interface request: %w", err) + } + + req.SetUint32(if_mtu) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { + return nil, fmt.Errorf("failed to set interface MTU: %w", err) + } + + ////////////////////// + // Set Queue Length // + ////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + req.SetUint16(if_queue_len) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { + return nil, fmt.Errorf("failed to set interface queue length: %w", err) + } + + ///////////////////// + // Set IP and Mask // + ///////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create IP interface request: %w", err) + } + + if err := req.SetInet4Addr(ip.To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request IP: %w", err) + } + + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { + return nil, fmt.Errorf("failed to set interface IP: %w", err) + } + + // SET MASK - must happen after setting address. + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create mask interface request: %w", err) + } + + if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { + return nil, fmt.Errorf("failed to set interface request mask: %w", err) + } + + if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { + return nil, fmt.Errorf("failed to set interface mask: %w", err) + } + + //////////////////////// + // Bring Interface Up // + //////////////////////// + + req, err = unix.NewIfreq(name) + if err != nil { + return nil, fmt.Errorf("failed to create up interface request: %w", err) + } + + // Get current flags. + if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to get interface flags: %w", err) + } + + flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING + + // Set UP flag / broadcast flags. + req.SetUint16(flags) + if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { + return nil, fmt.Errorf("failed to set interface up: %w", err) + } + + return os.NewFile(uintptr(tunFD), "tun"), nil +} diff --git a/peer/interfaces.go b/peer/interfaces.go index 8e99e8b..d6b90d5 100644 --- a/peer/interfaces.go +++ b/peer/interfaces.go @@ -31,17 +31,17 @@ type Marshaller interface { } type dataPacketSender interface { - SendDataPacket(pkt []byte, peer *RemotePeer) - RelayDataPacket(pkt []byte, peer, relay *RemotePeer) + SendDataPacket(pkt []byte, peer *remotePeer) + RelayDataPacket(pkt []byte, peer, relay *remotePeer) } type controlPacketSender interface { - SendControlPacket(pkt Marshaller, peer *RemotePeer) - RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) + SendControlPacket(pkt Marshaller, peer *remotePeer) + RelayControlPacket(pkt Marshaller, peer, relay *remotePeer) } type encryptedPacketSender interface { - SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) + SendEncryptedDataPacket(pkt []byte, peer *remotePeer) } type controlMsgHandler interface { diff --git a/peer/main.go b/peer/main.go new file mode 100644 index 0000000..c1ce110 --- /dev/null +++ b/peer/main.go @@ -0,0 +1,23 @@ +package peer + +import ( + "flag" + "os" +) + +func Main() { + conf := Config{} + + flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.") + flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.") + flag.StringVar(&conf.APIKey, "api-key", "", "[REQUIRED] The node's API key.") + flag.Parse() + + if conf.NetName == "" || conf.HubAddress == "" || conf.APIKey == "" { + flag.Usage() + os.Exit(1) + } + + peer := New(conf) + peer.Run() +} diff --git a/peer/mcreader.go b/peer/mcreader.go index 38921f1..3410655 100644 --- a/peer/mcreader.go +++ b/peer/mcreader.go @@ -8,7 +8,7 @@ import ( type mcReader struct { conn udpReader super controlMsgHandler - peers [256]*atomic.Pointer[RemotePeer] + peers [256]*atomic.Pointer[remotePeer] incoming []byte buf []byte @@ -17,7 +17,7 @@ type mcReader struct { func newMCReader( conn udpReader, super controlMsgHandler, - peers [256]*atomic.Pointer[RemotePeer], + peers [256]*atomic.Pointer[remotePeer], ) *mcReader { return &mcReader{conn, super, peers, newBuf(), newBuf()} } @@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() { return } - r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ + r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{ SrcIP: h.SourceIP, SrcAddr: remoteAddr, }) diff --git a/peer/mcreader_test.go b/peer/mcreader_test.go index 50bf821..60feb44 100644 --- a/peer/mcreader_test.go +++ b/peer/mcreader_test.go @@ -1,13 +1,6 @@ package peer -import ( - "bytes" - "net" - "net/netip" - "sync/atomic" - "testing" -) - +/* type mcMockConn struct { packets chan []byte } @@ -136,3 +129,4 @@ func TestMCReader_badSignature(t *testing.T) { t.Fatal(super.Messages) } } +*/ diff --git a/peer/packets.go b/peer/packets.go index 596483d..b300dee 100644 --- a/peer/packets.go +++ b/peer/packets.go @@ -5,41 +5,34 @@ import ( ) const ( - PacketTypeSyn = iota + 1 - PacketTypeSynAck - PacketTypeAck - PacketTypeProbe - PacketTypeAddrDiscovery + packetTypeSyn = 1 + packetTypeAck = 3 + packetTypeProbe = 4 + packetTypeAddrDiscovery = 5 ) // ---------------------------------------------------------------------------- -type PacketSyn struct { - TraceID uint64 // TraceID to match response w/ request. - //SentAt int64 // Unixmilli. - //SharedKeyType byte // Currently only 1 is supported for AES. +type packetSyn struct { + TraceID uint64 // TraceID to match response w/ request. SharedKey [32]byte // Our shared key. Direct bool PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p PacketSyn) Marshal(buf []byte) []byte { +func (p packetSyn) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(PacketTypeSyn). + Byte(packetTypeSyn). Uint64(p.TraceID). - //Int64(p.SentAt). - //Byte(p.SharedKeyType). SharedKey(p.SharedKey). Bool(p.Direct). AddrPort8(p.PossibleAddrs). Build() } -func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { +func parsePacketSyn(buf []byte) (p packetSyn, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). - //Int64(&p.SentAt). - //Byte(&p.SharedKeyType). SharedKey(&p.SharedKey). Bool(&p.Direct). AddrPort8(&p.PossibleAddrs). @@ -49,22 +42,22 @@ func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { // ---------------------------------------------------------------------------- -type PacketAck struct { +type packetAck struct { TraceID uint64 ToAddr netip.AddrPort PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. } -func (p PacketAck) Marshal(buf []byte) []byte { +func (p packetAck) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(PacketTypeAck). + Byte(packetTypeAck). Uint64(p.TraceID). AddrPort(p.ToAddr). AddrPort8(p.PossibleAddrs). Build() } -func ParsePacketAck(buf []byte) (p PacketAck, err error) { +func parsePacketAck(buf []byte) (p packetAck, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). AddrPort(&p.ToAddr). @@ -77,18 +70,18 @@ func ParsePacketAck(buf []byte) (p PacketAck, err error) { // A probeReqPacket is sent from a client to a server to determine if direct // UDP communication can be used. -type PacketProbe struct { +type packetProbe struct { TraceID uint64 } -func (p PacketProbe) Marshal(buf []byte) []byte { +func (p packetProbe) Marshal(buf []byte) []byte { return newBinWriter(buf). - Byte(PacketTypeProbe). + Byte(packetTypeProbe). Uint64(p.TraceID). Build() } -func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { +func parsePacketProbe(buf []byte) (p packetProbe, err error) { err = newBinReader(buf[1:]). Uint64(&p.TraceID). Error() @@ -97,4 +90,4 @@ func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { // ---------------------------------------------------------------------------- -type PacketLocalDiscovery struct{} +type packetLocalDiscovery struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go index 3ddc1a0..c18b40a 100644 --- a/peer/packets_test.go +++ b/peer/packets_test.go @@ -8,7 +8,7 @@ import ( ) func TestSynPacket(t *testing.T) { - p := PacketSyn{ + p := packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -21,7 +21,7 @@ func TestSynPacket(t *testing.T) { p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) buf := p.Marshal(newBuf()) - p2, err := ParsePacketSyn(buf) + p2, err := parsePacketSyn(buf) if err != nil { t.Fatal(err) } @@ -31,7 +31,7 @@ func TestSynPacket(t *testing.T) { } func TestAckPacket(t *testing.T) { - p := PacketAck{ + p := packetAck{ TraceID: newTraceID(), ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), } @@ -41,7 +41,7 @@ func TestAckPacket(t *testing.T) { p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) buf := p.Marshal(newBuf()) - p2, err := ParsePacketAck(buf) + p2, err := parsePacketAck(buf) if err != nil { t.Fatal(err) } @@ -51,12 +51,12 @@ func TestAckPacket(t *testing.T) { } func TestProbePacket(t *testing.T) { - p := PacketProbe{ + p := packetProbe{ TraceID: newTraceID(), } buf := p.Marshal(newBuf()) - p2, err := ParsePacketProbe(buf) + p2, err := parsePacketProbe(buf) if err != nil { t.Fatal(err) } diff --git a/peer/peer.go b/peer/peer.go new file mode 100644 index 0000000..6dc925b --- /dev/null +++ b/peer/peer.go @@ -0,0 +1,161 @@ +package peer + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "net/netip" + "net/url" + "sync" + "sync/atomic" + "vppn/m" +) + +type Peer struct { + ifReader *IFReader + connReader *ConnReader + iface io.Writer + hubPoller *hubPoller + super *Super +} + +type Config struct { + NetName string + HubAddress string + APIKey string +} + +func New(conf Config) *Peer { + config, err := loadPeerConfig(conf.NetName) + if err != nil { + log.Printf("Failed to load configuration: %v", err) + log.Printf("Initializing...") + initPeerWithHub(conf) + + config, err = loadPeerConfig(conf.NetName) + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + } + + iface, err := openInterface(config.Network, config.PeerIP, conf.NetName) + if err != nil { + log.Fatalf("Failed to open interface: %v", err) + } + + myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port)) + if err != nil { + log.Fatalf("Failed to resolve UDP address: %v", err) + } + + log.Printf("Listening on %v...", myAddr) + conn, err := net.ListenUDP("udp", myAddr) + if err != nil { + log.Fatalf("Failed to open UDP port: %v", err) + } + + conn.SetReadBuffer(1024 * 1024 * 8) + conn.SetWriteBuffer(1024 * 1024 * 8) + + // Wrap write function - this is necessary to avoid starvation. + writeLock := sync.Mutex{} + writeToUDPAddrPort := func(b []byte, addr netip.AddrPort) (n int, err error) { + writeLock.Lock() + n, err = conn.WriteToUDPAddrPort(b, addr) + if err != nil { + log.Printf("Failed to write packet: %v", err) + } + writeLock.Unlock() + return n, err + } + + var localAddr netip.AddrPort + ip, ok := netip.AddrFromSlice(config.PublicIP) + if ok { + localAddr = netip.AddrPortFrom(ip, config.Port) + } + + rt := newRoutingTable(config.PeerIP, localAddr) + rtPtr := &atomic.Pointer[routingTable]{} + rtPtr.Store(&rt) + + ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr) + super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey) + connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr) + hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg) + if err != nil { + log.Fatalf("Failed to create hub poller: %v", err) + } + + return &Peer{ + iface: iface, + ifReader: ifReader, + connReader: connReader, + hubPoller: hubPoller, + super: super, + } +} + +func (p *Peer) Run() { + go p.ifReader.Run() + go p.connReader.Run() + p.super.Start() + p.hubPoller.Run() +} + +func initPeerWithHub(conf Config) { + keys := generateKeys() + + initURL, err := url.Parse(conf.HubAddress) + if err != nil { + log.Fatalf("Failed to parse hub URL: %v", err) + } + initURL.Path = "/peer/init/" + + args := m.PeerInitArgs{ + EncPubKey: keys.PubKey, + PubSignKey: keys.PubSignKey, + } + + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(args); err != nil { + log.Fatalf("Failed to encode init args: %v", err) + } + + req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) + if err != nil { + log.Fatalf("Failed to construct request: %v", err) + } + req.SetBasicAuth("", conf.APIKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatalf("Failed to init with hub: %v", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("Failed to read response body: %v", err) + } + + peerConfig := localConfig{} + if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil { + log.Fatalf("Failed to parse configuration: %v\n%s", err, data) + } + + peerConfig.PubKey = keys.PubKey + peerConfig.PrivKey = keys.PrivKey + peerConfig.PubSignKey = keys.PubSignKey + peerConfig.PrivSignKey = keys.PrivSignKey + + if err := storePeerConfig(conf.NetName, peerConfig); err != nil { + log.Fatalf("Failed to store configuration: %v", err) + } + + log.Print("Initialization successful.") +} diff --git a/peer/peer_test.go b/peer/peer_test.go index 414beaa..863ca8f 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -11,36 +11,25 @@ import ( // A test peer. type P struct { cryptoKeys - RT *atomic.Pointer[RoutingTable] + RT *atomic.Pointer[routingTable] Conn *TestUDPConn IFace *TestIFace - ConnWriter *ConnWriter ConnReader *ConnReader IFReader *IFReader - Super *Supervisor } func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { p := P{ cryptoKeys: generateKeys(), - RT: &atomic.Pointer[RoutingTable]{}, + RT: &atomic.Pointer[routingTable]{}, IFace: NewTestIFace(), } - rt := NewRoutingTable(ip, addr) + rt := newRoutingTable(ip, addr) p.RT.Store(&rt) p.Conn = n.NewUDPConn(addr) - p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) - p.IFReader = NewIFReader(p.IFace, p.ConnWriter) + //p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) - /* - p.ConnReader = NewConnReader( - p.Conn.ReadFromUDPAddrPort, - p.IFace, - p.ConnWriter.Forward, - p.Super.HandleControlMsg, - p.RT) - */ return p } diff --git a/peer/peerstates.go b/peer/peerstates.go index f12ab08..5ded157 100644 --- a/peer/peerstates.go +++ b/peer/peerstates.go @@ -11,21 +11,21 @@ import ( "git.crumpington.com/lib/go/ratelimiter" ) -type PeerState interface { - OnPeerUpdate(*m.Peer) PeerState - OnSyn(controlMsg[PacketSyn]) PeerState - OnAck(controlMsg[PacketAck]) - OnProbe(controlMsg[PacketProbe]) PeerState - OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) - OnPingTimer() PeerState +type peerState interface { + OnPeerUpdate(*m.Peer) peerState + OnSyn(controlMsg[packetSyn]) peerState + OnAck(controlMsg[packetAck]) + OnProbe(controlMsg[packetProbe]) peerState + OnLocalDiscovery(controlMsg[packetLocalDiscovery]) + OnPingTimer() peerState } // ---------------------------------------------------------------------------- -type State struct { +type pState struct { // Output. - publish func(RemotePeer) - sendControlPacket func(RemotePeer, Marshaller) + publish func(remotePeer) + sendControlPacket func(remotePeer, Marshaller) // Immutable data. localIP byte @@ -37,7 +37,7 @@ type State struct { // The purpose of this state machine is to manage the RemotePeer object, // publishing it as necessary. - staged RemotePeer // Local copy of shared data. See publish(). + staged remotePeer // Local copy of shared data. See publish(). // Mutable peer data. peer *m.Peer @@ -47,25 +47,28 @@ type State struct { limiter *ratelimiter.Limiter } -func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { +func (s *pState) OnPeerUpdate(peer *m.Peer) peerState { defer func() { // Don't defer directly otherwise s.staged will be evaluated immediately // and won't reflect changes made in the function. s.publish(s.staged) }() - if peer == nil { - return EnterStateDisconnected(s) - } - s.peer = peer - s.staged.localIP = s.localIP - s.staged.IP = peer.PeerIP s.staged.Up = false s.staged.Relay = false s.staged.Direct = false s.staged.DirectAddr = netip.AddrPort{} + s.staged.PubSignKey = nil + s.staged.ControlCipher = nil + s.staged.DataCipher = nil + + if peer == nil { + return enterStateDisconnected(s) + } + + s.staged.IP = peer.PeerIP s.staged.PubSignKey = peer.PubSignKey s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.DataCipher = newDataCipher() @@ -76,30 +79,32 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) if s.localAddr.IsValid() && s.localIP < s.remoteIP { - return EnterStateServer(s) + return enterStateServer(s) } - return EnterStateClientDirect(s) + return enterStateClientDirect(s) } if s.localAddr.IsValid() { s.staged.Direct = true - return EnterStateServer(s) + return enterStateServer(s) } if s.localIP < s.remoteIP { - return EnterStateServer(s) + return enterStateServer(s) } - return EnterStateClientRelayed(s) + return enterStateClientRelayed(s) } -func (s *State) logf(format string, args ...any) { +func (s *pState) logf(format string, args ...any) { b := strings.Builder{} name := "" if s.peer != nil { name = s.peer.Name } + b.WriteString(fmt.Sprintf("%03d", s.remoteIP)) + b.WriteString(fmt.Sprintf("%30s: ", name)) if s.staged.Direct { @@ -119,7 +124,7 @@ func (s *State) logf(format string, args ...any) { // ---------------------------------------------------------------------------- -func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { +func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) { if !addr.IsValid() { return } @@ -129,7 +134,7 @@ func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { s.Send(route, pkt) } -func (s *State) Send(peer RemotePeer, pkt Marshaller) { +func (s *pState) Send(peer remotePeer, pkt Marshaller) { if err := s.limiter.Limit(); err != nil { s.logf("Rate limited.") return @@ -139,42 +144,32 @@ func (s *State) Send(peer RemotePeer, pkt Marshaller) { // ---------------------------------------------------------------------------- -type StateDisconnected struct{ *State } +type stateDisconnected struct{ *pState } -func EnterStateDisconnected(s *State) PeerState { - s.logf("==> Disconnected") - s.peer = nil - s.staged.Up = false - s.staged.Relay = false - s.staged.Direct = false - s.staged.DirectAddr = netip.AddrPort{} - s.staged.PubSignKey = nil - s.staged.ControlCipher = nil - s.staged.DataCipher = nil - s.publish(s.staged) - return &StateDisconnected{State: s} +func enterStateDisconnected(s *pState) peerState { + return &stateDisconnected{pState: s} } -func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil } -func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {} -func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil } -func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} -func (s *StateDisconnected) OnPingTimer() PeerState { return nil } +func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState { return s } +func (s *stateDisconnected) OnAck(controlMsg[packetAck]) {} +func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s } +func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {} +func (s *stateDisconnected) OnPingTimer() peerState { return s } // ---------------------------------------------------------------------------- -type StateServer struct { - *StateDisconnected +type stateServer struct { + *stateDisconnected lastSeen time.Time synTraceID uint64 } -func EnterStateServer(s *State) PeerState { +func enterStateServer(s *pState) peerState { s.logf("==> Server") - return &StateServer{StateDisconnected: &StateDisconnected{State: s}} + return &stateServer{stateDisconnected: &stateDisconnected{pState: s}} } -func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { +func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState { s.lastSeen = time.Now() p := msg.Packet @@ -194,7 +189,7 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { } // Always respond. - ack := PacketAck{ + ack := packetAck{ TraceID: p.TraceID, ToAddr: s.staged.DirectAddr, PossibleAddrs: s.pubAddrs.Get(), @@ -202,55 +197,55 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { s.Send(s.staged, ack) if p.Direct { - return nil + return s } for _, addr := range msg.Packet.PossibleAddrs { if !addr.IsValid() { break } - s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) + s.SendTo(packetProbe{TraceID: newTraceID()}, addr) } - return nil + return s } -func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { +func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState { if msg.SrcAddr.IsValid() { - s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) + s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) } - return nil + return s } -func (s *StateServer) OnPingTimer() PeerState { +func (s *stateServer) OnPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { s.staged.Up = false s.publish(s.staged) s.logf("Timeout.") } - return nil + return s } // ---------------------------------------------------------------------------- -type StateClientDirect struct { - *StateDisconnected +type stateClientDirect struct { + *stateDisconnected lastSeen time.Time - syn PacketSyn + syn packetSyn } -func EnterStateClientDirect(s *State) PeerState { +func enterStateClientDirect(s *pState) peerState { s.logf("==> ClientDirect") - return NewStateClientDirect(s) + return newStateClientDirect(s) } -func NewStateClientDirect(s *State) *StateClientDirect { - state := &StateClientDirect{ - StateDisconnected: &StateDisconnected{s}, +func newStateClientDirect(s *pState) *stateClientDirect { + state := &stateClientDirect{ + stateDisconnected: &stateDisconnected{s}, lastSeen: time.Now(), // Avoid immediate timeout. } - state.syn = PacketSyn{ + state.syn = packetSyn{ TraceID: newTraceID(), SharedKey: s.staged.DataCipher.Key(), Direct: s.staged.Direct, @@ -260,7 +255,7 @@ func NewStateClientDirect(s *State) *StateClientDirect { return state } -func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { +func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) { if msg.Packet.TraceID != s.syn.TraceID { return } @@ -276,7 +271,14 @@ func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { s.pubAddrs.Store(msg.Packet.ToAddr) } -func (s *StateClientDirect) OnPingTimer() PeerState { +func (s *stateClientDirect) OnPingTimer() peerState { + if next := s.onPingTimer(); next != nil { + return next + } + return s +} + +func (s *stateClientDirect) onPingTimer() peerState { if time.Since(s.lastSeen) > timeoutInterval { if s.staged.Up { s.staged.Up = false @@ -292,47 +294,47 @@ func (s *StateClientDirect) OnPingTimer() PeerState { // ---------------------------------------------------------------------------- -type StateClientRelayed struct { - *StateClientDirect - ack PacketAck +type stateClientRelayed struct { + *stateClientDirect + ack packetAck probes map[uint64]netip.AddrPort localDiscoveryAddr netip.AddrPort } -func EnterStateClientRelayed(s *State) PeerState { +func enterStateClientRelayed(s *pState) peerState { s.logf("==> ClientRelayed") - return &StateClientRelayed{ - StateClientDirect: NewStateClientDirect(s), + return &stateClientRelayed{ + stateClientDirect: newStateClientDirect(s), probes: map[uint64]netip.AddrPort{}, } } -func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { +func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) { s.ack = msg.Packet - s.StateClientDirect.OnAck(msg) + s.stateClientDirect.OnAck(msg) } -func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { +func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState { addr, ok := s.probes[msg.Packet.TraceID] if !ok { - return nil + return s } s.staged.DirectAddr = addr s.staged.Direct = true s.publish(s.staged) - return EnterStateClientDirect(s.StateClientDirect.State) + return enterStateClientDirect(s.stateClientDirect.pState) } -func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { +func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { // The source port will be the multicast port, so we'll have to // construct the correct address using the peer's listed port. s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) } -func (s *StateClientRelayed) OnPingTimer() PeerState { - if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { - return nextState +func (s *stateClientRelayed) OnPingTimer() peerState { + if next := s.stateClientDirect.onPingTimer(); next != nil { + return next } clear(s.probes) @@ -348,11 +350,11 @@ func (s *StateClientRelayed) OnPingTimer() PeerState { s.localDiscoveryAddr = netip.AddrPort{} } - return nil + return s } -func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { - probe := PacketProbe{TraceID: newTraceID()} +func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) { + probe := packetProbe{TraceID: newTraceID()} s.probes[probe.TraceID] = addr s.SendTo(probe, addr) } diff --git a/peer/peerstates_test.go b/peer/peerstates_test.go index 16805d0..1ac531d 100644 --- a/peer/peerstates_test.go +++ b/peer/peerstates_test.go @@ -12,13 +12,13 @@ import ( // ---------------------------------------------------------------------------- type PeerStateControlMsg struct { - Peer RemotePeer + Peer remotePeer Packet any } type PeerStateTestHarness struct { - State PeerState - Published RemotePeer + State peerState + Published remotePeer Sent []PeerStateControlMsg } @@ -27,11 +27,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { keys := generateKeys() - state := &State{ - publish: func(rp RemotePeer) { + state := &pState{ + publish: func(rp remotePeer) { h.Published = rp }, - sendControlPacket: func(rp RemotePeer, pkt Marshaller) { + sendControlPacket: func(rp remotePeer, pkt Marshaller) { h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) }, localIP: 2, @@ -44,7 +44,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness { }), } - h.State = EnterStateDisconnected(state) + h.State = enterStateDisconnected(state) return h } @@ -54,13 +54,13 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) { } } -func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { +func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) { if s := h.State.OnSyn(msg); s != nil { h.State = s } } -func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { +func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) { if s := h.State.OnProbe(msg); s != nil { h.State = s } @@ -72,10 +72,10 @@ func (h *PeerStateTestHarness) OnPingTimer() { } } -func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { +func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer { keys := generateKeys() - state := h.State.(*StateDisconnected) + state := h.State.(*stateDisconnected) state.localAddr = addrPort4(1, 1, 1, 2, 200) peer := &m.Peer{ @@ -88,10 +88,10 @@ func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateServer](t, h.State) + return assertType[*stateServer](t, h.State) } -func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { +func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer { keys := generateKeys() peer := &m.Peer{ PeerIP: 3, @@ -102,10 +102,10 @@ func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateServer](t, h.State) + return assertType[*stateServer](t, h.State) } -func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { +func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDirect { keys := generateKeys() peer := &m.Peer{ PeerIP: 3, @@ -117,13 +117,13 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDire h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateClientDirect](t, h.State) + return assertType[*stateClientDirect](t, h.State) } -func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { +func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClientRelayed { keys := generateKeys() - state := h.State.(*StateDisconnected) + state := h.State.(*stateDisconnected) state.remoteIP = 1 peer := &m.Peer{ @@ -135,7 +135,7 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - return assertType[*StateClientRelayed](t, h.State) + return assertType[*stateClientRelayed](t, h.State) } // ---------------------------------------------------------------------------- @@ -143,14 +143,14 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { h := NewPeerStateTestHarness() h.PeerUpdate(nil) - assertType[*StateDisconnected](t, h.State) + assertType[*stateDisconnected](t, h.State) } func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { keys := generateKeys() h := NewPeerStateTestHarness() - state := h.State.(*StateDisconnected) + state := h.State.(*stateDisconnected) state.localAddr = addrPort4(1, 1, 1, 2, 200) peer := &m.Peer{ @@ -162,7 +162,7 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { h.PeerUpdate(peer) assertEqual(t, h.Published.Up, false) - assertType[*StateServer](t, h.State) + assertType[*stateServer](t, h.State) } func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { @@ -191,10 +191,10 @@ func TestStateServer_directSyn(t *testing.T) { assertEqual(t, h.Published.Up, false) - synMsg := controlMsg[PacketSyn]{ + synMsg := controlMsg[packetSyn]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketSyn{ + Packet: packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -205,7 +205,7 @@ func TestStateServer_directSyn(t *testing.T) { h.State.OnSyn(synMsg) assertEqual(t, len(h.Sent), 1) - ack := assertType[PacketAck](t, h.Sent[0].Packet) + ack := assertType[packetAck](t, h.Sent[0].Packet) assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) assertEqual(t, h.Sent[0].Peer.IP, 3) assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) @@ -220,10 +220,10 @@ func TestStateServer_relayedSyn(t *testing.T) { assertEqual(t, h.Published.Up, false) - synMsg := controlMsg[PacketSyn]{ + synMsg := controlMsg[packetSyn]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketSyn{ + Packet: packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -237,15 +237,15 @@ func TestStateServer_relayedSyn(t *testing.T) { assertEqual(t, len(h.Sent), 3) - ack := assertType[PacketAck](t, h.Sent[0].Packet) + ack := assertType[packetAck](t, h.Sent[0].Packet) assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) assertEqual(t, h.Sent[0].Peer.IP, 3) assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) assertEqual(t, h.Published.Up, true) - assertType[PacketProbe](t, h.Sent[1].Packet) - assertType[PacketProbe](t, h.Sent[2].Packet) + assertType[packetProbe](t, h.Sent[1].Packet) + assertType[packetProbe](t, h.Sent[2].Packet) assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) } @@ -255,17 +255,17 @@ func TestStateServer_onProbe(t *testing.T) { h.ConfigServer_Relayed(t) assertEqual(t, h.Published.Up, false) - probeMsg := controlMsg[PacketProbe]{ + probeMsg := controlMsg[packetProbe]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketProbe{TraceID: newTraceID()}, + Packet: packetProbe{TraceID: newTraceID()}, } h.State.OnProbe(probeMsg) assertEqual(t, len(h.Sent), 1) - probe := assertType[PacketProbe](t, h.Sent[0].Packet) + probe := assertType[packetProbe](t, h.Sent[0].Packet) assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) } @@ -274,10 +274,10 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigServer_Relayed(t) - synMsg := controlMsg[PacketSyn]{ + synMsg := controlMsg[packetSyn]{ SrcIP: 3, SrcAddr: addrPort4(1, 1, 1, 3, 300), - Packet: PacketSyn{ + Packet: packetSyn{ TraceID: newTraceID(), //SentAt: time.Now().UnixMilli(), //SharedKeyType: 1, @@ -294,7 +294,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) { assertEqual(t, h.Published.Up, true) // Advance the time, then ping. - state := assertType[*StateServer](t, h.State) + state := assertType[*stateServer](t, h.State) state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) h.OnPingTimer() @@ -309,10 +309,10 @@ func TestStateClientDirect_OnAck(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) @@ -326,10 +326,10 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID + 1}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID + 1}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, false) @@ -341,15 +341,15 @@ func TestStateClientDirect_OnPingTimer(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - assertType[PacketSyn](t, h.Sent[0].Packet) + assertType[packetSyn](t, h.Sent[0].Packet) h.OnPingTimer() // On ping timer, another syn should be sent. Additionally, we should remain // in the same state. assertEqual(t, len(h.Sent), 2) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[*StateClientDirect](t, h.State) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientDirect](t, h.State) assertEqual(t, h.Published.Up, false) } @@ -361,15 +361,15 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) - state := assertType[*StateClientDirect](t, h.State) + state := assertType[*stateClientDirect](t, h.State) state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) h.OnPingTimer() @@ -377,8 +377,8 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) { // On ping timer, we should timeout, causing the client to reset. Another SYN // will be sent when re-entering the state, but the connection should be down. assertEqual(t, len(h.Sent), 2) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[*StateClientDirect](t, h.State) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientDirect](t, h.State) assertEqual(t, h.Published.Up, false) } @@ -390,10 +390,10 @@ func TestStateClientRelayed_OnAck(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) @@ -423,9 +423,9 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) @@ -433,7 +433,7 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { // Add a local discovery address. Note that the port will be configured port // and no the one provided here. - h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ + h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{ SrcIP: 3, SrcAddr: addrPort4(2, 2, 2, 3, 300), }) @@ -441,10 +441,10 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) { // We should see one SYN and three probe packets. h.OnPingTimer() assertEqual(t, len(h.Sent), 5) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[PacketProbe](t, h.Sent[2].Packet) - assertType[PacketProbe](t, h.Sent[3].Packet) - assertType[PacketProbe](t, h.Sent[4].Packet) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[packetProbe](t, h.Sent[2].Packet) + assertType[packetProbe](t, h.Sent[3].Packet) + assertType[packetProbe](t, h.Sent[4].Packet) assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) @@ -457,15 +457,15 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { // On entering the state, a SYN should have been sent. assertEqual(t, len(h.Sent), 1) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{ - Packet: PacketAck{TraceID: syn.TraceID}, + ack := controlMsg[packetAck]{ + Packet: packetAck{TraceID: syn.TraceID}, } h.State.OnAck(ack) assertEqual(t, h.Published.Up, true) - state := assertType[*StateClientRelayed](t, h.State) + state := assertType[*stateClientRelayed](t, h.State) state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) h.OnPingTimer() @@ -473,8 +473,8 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) { // On ping timer, we should timeout, causing the client to reset. Another SYN // will be sent when re-entering the state, but the connection should be down. assertEqual(t, len(h.Sent), 2) - assertType[PacketSyn](t, h.Sent[1].Packet) - assertType[*StateClientRelayed](t, h.State) + assertType[packetSyn](t, h.Sent[1].Packet) + assertType[*stateClientRelayed](t, h.State) assertEqual(t, h.Published.Up, false) } @@ -482,28 +482,28 @@ func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigClientRelayed(t) - h.OnProbe(controlMsg[PacketProbe]{ - Packet: PacketProbe{TraceID: newTraceID()}, + h.OnProbe(controlMsg[packetProbe]{ + Packet: packetProbe{TraceID: newTraceID()}, }) - assertType[*StateClientRelayed](t, h.State) + assertType[*stateClientRelayed](t, h.State) } func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { h := NewPeerStateTestHarness() h.ConfigClientRelayed(t) - syn := assertType[PacketSyn](t, h.Sent[0].Packet) + syn := assertType[packetSyn](t, h.Sent[0].Packet) - ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} + ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}} ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) h.State.OnAck(ack) h.OnPingTimer() - probe := assertType[PacketProbe](t, h.Sent[2].Packet) - h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) + probe := assertType[packetProbe](t, h.Sent[2].Packet) + h.OnProbe(controlMsg[packetProbe]{Packet: probe}) - assertType[*StateClientDirect](t, h.State) + assertType[*stateClientDirect](t, h.State) } diff --git a/peer/peersuper.go b/peer/peersuper.go new file mode 100644 index 0000000..d402141 --- /dev/null +++ b/peer/peersuper.go @@ -0,0 +1,172 @@ +package peer + +import ( + "log" + "math/rand" + "net/netip" + "sync" + "sync/atomic" + "time" + + "git.crumpington.com/lib/go/ratelimiter" +) + +type Super struct { + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error) + staged routingTable + shared *atomic.Pointer[routingTable] + peers [256]*PeerSuper + lock sync.Mutex + + buf1 []byte + buf2 []byte +} + +func NewSuper( + writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error), + rt *atomic.Pointer[routingTable], + privKey []byte, +) *Super { + + routes := rt.Load() + + s := &Super{ + writeToUDPAddrPort: writeToUDPAddrPort, + staged: *routes, + shared: rt, + buf1: newBuf(), + buf2: newBuf(), + } + + pubAddrs := newPubAddrStore(routes.LocalAddr) + + for i := range s.peers { + state := &pState{ + publish: s.publish, + sendControlPacket: s.send, + localIP: routes.LocalIP, + remoteIP: byte(i), + privKey: privKey, + localAddr: routes.LocalAddr, + pubAddrs: pubAddrs, + staged: routes.Peers[i], + limiter: ratelimiter.New(ratelimiter.Config{ + FillPeriod: 20 * time.Millisecond, + MaxWaitCount: 1, + }), + } + s.peers[i] = NewPeerSuper(state) + } + + return s +} + +func (s *Super) Start() { + for i := range s.peers { + go s.peers[i].Run() + } +} + +func (s *Super) HandleControlMsg(destIP byte, msg any) { + s.peers[destIP].HandleControlMsg(msg) +} + +func (s *Super) send(peer remotePeer, pkt Marshaller) { + s.lock.Lock() + defer s.lock.Unlock() + + enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2) + if peer.Direct { + s.writeToUDPAddrPort(enc, peer.DirectAddr) + return + } + + relay, ok := s.staged.GetRelay() + if !ok { + return + } + + enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1) + s.writeToUDPAddrPort(enc, relay.DirectAddr) +} + +func (s *Super) publish(rp remotePeer) { + s.lock.Lock() + defer s.lock.Unlock() + + s.staged.Peers[rp.IP] = rp + s.ensureRelay() + copy := s.staged + s.shared.Store(©) +} + +func (s *Super) ensureRelay() { + if _, ok := s.staged.GetRelay(); ok { + return + } + + // TODO: Random selection? + for _, peer := range s.staged.Peers { + if peer.Up && peer.Direct && peer.Relay { + s.staged.RelayIP = peer.IP + return + } + } +} + +// ---------------------------------------------------------------------------- + +type PeerSuper struct { + messages chan any + state peerState +} + +func NewPeerSuper(state *pState) *PeerSuper { + return &PeerSuper{ + messages: make(chan any, 8), + state: state.OnPeerUpdate(nil), + } +} + +func (s *PeerSuper) HandleControlMsg(msg any) { + select { + case s.messages <- msg: + default: + } +} + +func (s *PeerSuper) Run() { + go func() { + // Randomize ping timers. + time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond) + for range time.Tick(pingInterval) { + s.messages <- pingTimerMsg{} + } + }() + + for rawMsg := range s.messages { + switch msg := rawMsg.(type) { + + case peerUpdateMsg: + s.state = s.state.OnPeerUpdate(msg.Peer) + + case controlMsg[packetSyn]: + s.state = s.state.OnSyn(msg) + + case controlMsg[packetAck]: + s.state.OnAck(msg) + + case controlMsg[packetProbe]: + s.state = s.state.OnProbe(msg) + + case controlMsg[packetLocalDiscovery]: + s.state.OnLocalDiscovery(msg) + + case pingTimerMsg: + s.state = s.state.OnPingTimer() + + default: + log.Printf("WARNING: unknown message type: %+v", msg) + } + } +} diff --git a/peer/routingtable.go b/peer/routingtable.go index 0943ab2..7bbf542 100644 --- a/peer/routingtable.go +++ b/peer/routingtable.go @@ -7,9 +7,9 @@ import ( ) // TODO: Remove -func NewRemotePeer(ip byte) *RemotePeer { +func NewRemotePeer(ip byte) *remotePeer { counter := uint64(time.Now().Unix()<<30 + 1) - return &RemotePeer{ + return &remotePeer{ IP: ip, counter: &counter, dupCheck: newDupCheck(0), @@ -18,7 +18,7 @@ func NewRemotePeer(ip byte) *RemotePeer { // ---------------------------------------------------------------------------- -type RemotePeer struct { +type remotePeer struct { localIP byte IP byte // VPN IP of peer (last byte). Up bool // True if data can be sent on the peer. @@ -33,7 +33,7 @@ type RemotePeer struct { dupCheck *dupCheck // For receiving from. Not safe for concurrent use. } -func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { +func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { h := header{ StreamID: dataStreamID, Counter: atomic.AddUint64(p.counter, 1), @@ -44,7 +44,7 @@ func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { } // Decrypts and de-dups incoming data packets. -func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { +func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { dec, ok := p.DataCipher.Decrypt(enc, out) if !ok { return nil, errDecryptionFailed @@ -58,21 +58,22 @@ func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) } // Peer must have a ControlCipher. -func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { +func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { + tmp = pkt.Marshal(tmp) h := header{ StreamID: controlStreamID, Counter: atomic.AddUint64(p.counter, 1), SourceIP: p.localIP, DestIP: p.IP, } - tmp = pkt.Marshal(tmp) + return p.ControlCipher.Encrypt(h, tmp, out) } // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. // // This function also drops packets with duplicate sequence numbers. -func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { +func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { out, ok := p.ControlCipher.Decrypt(enc, tmp) if !ok { return nil, errDecryptionFailed @@ -92,7 +93,7 @@ func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, // ---------------------------------------------------------------------------- -type RoutingTable struct { +type routingTable struct { // The LocalIP is the configured IP address of the local peer on the VPN. // // This value is constant. @@ -106,21 +107,21 @@ type RoutingTable struct { LocalAddr netip.AddrPort // The remote peer configurations. These are updated by - Peers [256]RemotePeer + Peers [256]remotePeer // The current relay's VPN IP address, or zero if no relay is available. RelayIP byte } -func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { - rt := RoutingTable{ +func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable { + rt := routingTable{ LocalIP: localIP, LocalAddr: localAddr, } for i := range rt.Peers { counter := uint64(time.Now().Unix()<<30 + 1) - rt.Peers[i] = RemotePeer{ + rt.Peers[i] = remotePeer{ localIP: localIP, IP: byte(i), counter: &counter, @@ -131,7 +132,7 @@ func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { return rt } -func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { +func (rt *routingTable) GetRelay() (remotePeer, bool) { relay := rt.Peers[rt.RelayIP] return relay, relay.Up && relay.Direct } diff --git a/peer/routingtable_test.go b/peer/routingtable_test.go index b5497a4..919449b 100644 --- a/peer/routingtable_test.go +++ b/peer/routingtable_test.go @@ -74,7 +74,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { peer2 := p1.RT.Load().Peers[2] peer1 := p2.RT.Load().Peers[1] - orig := PacketProbe{TraceID: newTraceID()} + orig := packetProbe{TraceID: newTraceID()} enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) @@ -88,7 +88,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) { t.Fatal(err) } - dec, ok := ctrlMsg.(controlMsg[PacketProbe]) + dec, ok := ctrlMsg.(controlMsg[packetProbe]) if !ok { t.Fatal(ctrlMsg) } @@ -108,7 +108,7 @@ func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) { peer2 := p1.RT.Load().Peers[2] peer1 := p2.RT.Load().Peers[1] - orig := PacketProbe{TraceID: newTraceID()} + orig := packetProbe{TraceID: newTraceID()} enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) @@ -131,7 +131,7 @@ func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) { peer2 := p1.RT.Load().Peers[2] peer1 := p2.RT.Load().Peers[1] - orig := PacketProbe{TraceID: newTraceID()} + orig := packetProbe{TraceID: newTraceID()} enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) diff --git a/peer/supervisor.go b/peer/supervisor.go deleted file mode 100644 index 0f82a3f..0000000 --- a/peer/supervisor.go +++ /dev/null @@ -1,103 +0,0 @@ -package peer - -import ( - "log" - "sync/atomic" - "time" - - "git.crumpington.com/lib/go/ratelimiter" -) - -// ---------------------------------------------------------------------------- - -type Supervisor struct { - messages chan any // Incoming control messages. - peers [256]PeerState - pubAddrs *pubAddrStore - rt *atomic.Pointer[RoutingTable] - staged RoutingTable -} - -func NewSupervisor( - sendControl func(RemotePeer, Marshaller), - privKey []byte, - rt *atomic.Pointer[RoutingTable], -) *Supervisor { - s := &Supervisor{ - messages: make(chan any, 1024), - pubAddrs: newPubAddrStore(rt.Load().LocalAddr), - rt: rt, - } - - routes := rt.Load() - - for i := range s.peers { - state := &State{ - publish: s.Publish, - sendControlPacket: sendControl, - localIP: routes.LocalIP, - remoteIP: byte(i), - privKey: privKey, - localAddr: routes.LocalAddr, - pubAddrs: s.pubAddrs, - staged: routes.Peers[i], - limiter: ratelimiter.New(ratelimiter.Config{ - FillPeriod: 20 * time.Millisecond, - MaxWaitCount: 1, - }), - } - s.peers[i] = state.OnPeerUpdate(nil) - } - - return s -} - -func (s *Supervisor) HandleControlMsg(msg any) { - select { - case s.messages <- msg: - default: - } -} - -func (s *Supervisor) Run() { - for raw := range s.messages { - switch msg := raw.(type) { - - case peerUpdateMsg: - s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer) - - case controlMsg[PacketSyn]: - if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil { - s.peers[msg.SrcIP] = newState - } - - case controlMsg[PacketAck]: - s.peers[msg.SrcIP].OnAck(msg) - - case controlMsg[PacketProbe]: - if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil { - s.peers[msg.SrcIP] = newState - } - - case controlMsg[PacketLocalDiscovery]: - s.peers[msg.SrcIP].OnLocalDiscovery(msg) - - case pingTimerMsg: - s.pubAddrs.Clean() - for i := range s.peers { - if newState := s.peers[i].OnPingTimer(); newState != nil { - s.peers[i] = newState - } - } - - default: - log.Printf("WARNING: unknown message type: %+v", msg) - } - } -} - -func (s *Supervisor) Publish(rp RemotePeer) { - s.staged.Peers[rp.IP] = rp - rt := s.staged // Copy. - s.rt.Store(&rt) -}