This commit is contained in:
jdl 2025-02-19 14:13:25 +01:00
parent affeb0b9ce
commit 08dc79283e
32 changed files with 873 additions and 1685 deletions

View File

@ -2,10 +2,10 @@ package main
import ( import (
"log" "log"
"vppn/node" "vppn/peer"
) )
func main() { func main() {
log.SetFlags(0) log.SetFlags(0)
node.Main() peer.Main()
} }

View File

@ -258,7 +258,6 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
default: default:
log.Printf("Dropping control packet.") log.Printf("Dropping control packet.")
} }
} }
func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) { func handleDataPacket(h header, data []byte, decBuf []byte, iface ifWriter, sender dataPacketSender) {

View File

@ -12,7 +12,7 @@ type connReader struct {
sender encryptedPacketSender sender encryptedPacketSender
super controlMsgHandler super controlMsgHandler
localIP byte localIP byte
peers [256]*atomic.Pointer[RemotePeer] peers [256]*atomic.Pointer[remotePeer]
buf []byte buf []byte
decBuf []byte decBuf []byte
@ -24,7 +24,7 @@ func newConnReader(
sender encryptedPacketSender, sender encryptedPacketSender,
super controlMsgHandler, super controlMsgHandler,
localIP byte, localIP byte,
peers [256]*atomic.Pointer[RemotePeer], peers [256]*atomic.Pointer[remotePeer],
) *connReader { ) *connReader {
return &connReader{ return &connReader{
conn: conn, conn: conn,
@ -79,7 +79,7 @@ func (r *connReader) handleNextPacket() {
} }
func (r *connReader) handleControlPacket( func (r *connReader) handleControlPacket(
peer *RemotePeer, peer *remotePeer,
addr netip.AddrPort, addr netip.AddrPort,
h header, h header,
enc []byte, enc []byte,
@ -102,7 +102,7 @@ func (r *connReader) handleControlPacket(
r.super.HandleControlMsg(msg) r.super.HandleControlMsg(msg)
} }
func (r *connReader) handleDataPacket(peer *RemotePeer, h header, enc []byte) { func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) {
if !peer.Up { if !peer.Up {
r.logf("Not connected (recv).") r.logf("Not connected (recv).")
return return

View File

@ -12,12 +12,12 @@ type ConnReader struct {
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error) readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)
// Output // Output
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
iface io.Writer iface io.Writer
forwardData func(ip byte, pkt []byte) handleControlMsg func(fromIP byte, pkt any)
handleControlMsg func(pkt any)
localIP byte localIP byte
rt *atomic.Pointer[RoutingTable] rt *atomic.Pointer[routingTable]
buf []byte buf []byte
decBuf []byte decBuf []byte
@ -25,15 +25,15 @@ type ConnReader struct {
func NewConnReader( func NewConnReader(
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error), readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
iface io.Writer, iface io.Writer,
forwardData func(ip byte, pkt []byte), handleControlMsg func(fromIP byte, pkt any),
handleControlMsg func(pkt any), rt *atomic.Pointer[routingTable],
rt *atomic.Pointer[RoutingTable],
) *ConnReader { ) *ConnReader {
return &ConnReader{ return &ConnReader{
readFromUDPAddrPort: readFromUDPAddrPort, readFromUDPAddrPort: readFromUDPAddrPort,
writeToUDPAddrPort: writeToUDPAddrPort,
iface: iface, iface: iface,
forwardData: forwardData,
handleControlMsg: handleControlMsg, handleControlMsg: handleControlMsg,
localIP: rt.Load().LocalIP, localIP: rt.Load().LocalIP,
rt: rt, rt: rt,
@ -50,7 +50,9 @@ func (r *ConnReader) Run() {
func (r *ConnReader) handleNextPacket() { func (r *ConnReader) handleNextPacket() {
buf := r.buf[:bufferSize] buf := r.buf[:bufferSize]
log.Printf("Getting next packet...")
n, remoteAddr, err := r.readFromUDPAddrPort(buf) n, remoteAddr, err := r.readFromUDPAddrPort(buf)
log.Printf("Packet from %v...", remoteAddr)
if err != nil { if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err) log.Fatalf("Failed to read from UDP port: %v", err)
} }
@ -64,14 +66,14 @@ func (r *ConnReader) handleNextPacket() {
buf = buf[:n] buf = buf[:n]
h := parseHeader(buf) h := parseHeader(buf)
peer := r.rt.Load().Peers[h.SourceIP] rt := r.rt.Load()
//peer := rt.Peers[h.SourceIP] peer := rt.Peers[h.SourceIP]
switch h.StreamID { switch h.StreamID {
case controlStreamID: case controlStreamID:
r.handleControlPacket(remoteAddr, peer, h, buf) r.handleControlPacket(remoteAddr, peer, h, buf)
case dataStreamID: case dataStreamID:
r.handleDataPacket(peer, h, buf) r.handleDataPacket(rt, peer, h, buf)
default: default:
r.logf("Unknown stream ID: %d", h.StreamID) r.logf("Unknown stream ID: %d", h.StreamID)
} }
@ -79,7 +81,7 @@ func (r *ConnReader) handleNextPacket() {
func (r *ConnReader) handleControlPacket( func (r *ConnReader) handleControlPacket(
remoteAddr netip.AddrPort, remoteAddr netip.AddrPort,
peer RemotePeer, peer remotePeer,
h header, h header,
enc []byte, enc []byte,
) { ) {
@ -98,11 +100,12 @@ func (r *ConnReader) handleControlPacket(
return return
} }
r.handleControlMsg(msg) r.handleControlMsg(h.SourceIP, msg)
} }
func (r *ConnReader) handleDataPacket( func (r *ConnReader) handleDataPacket(
peer RemotePeer, rt *routingTable,
peer remotePeer,
h header, h header,
enc []byte, enc []byte,
) { ) {
@ -124,7 +127,13 @@ func (r *ConnReader) handleDataPacket(
return return
} }
r.forwardData(h.DestIP, data) relay, ok := rt.GetRelay()
if !ok {
r.logf("Relay not available.")
return
}
r.writeToUDPAddrPort(data, relay.DirectAddr)
} }
func (r *ConnReader) logf(format string, args ...any) { func (r *ConnReader) logf(format string, args ...any) {

View File

@ -1,353 +0,0 @@
package peer
import (
"bytes"
"crypto/rand"
"net/netip"
"reflect"
"sync/atomic"
"testing"
)
type mockIfWriter struct {
Written [][]byte
}
func (w *mockIfWriter) Write(b []byte) (int, error) {
w.Written = append(w.Written, bytes.Clone(b))
return len(b), nil
}
type mockEncryptedPacket struct {
Packet []byte
Route *RemotePeer
}
type mockEncryptedPacketSender struct {
Sent []mockEncryptedPacket
}
func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *RemotePeer) {
m.Sent = append(m.Sent, mockEncryptedPacket{
Packet: bytes.Clone(pkt),
Route: route,
})
}
type mockControlMsgHandler struct {
Messages []any
}
func (m *mockControlMsgHandler) HandleControlMsg(pkt any) {
m.Messages = append(m.Messages, pkt)
}
type udpPipe struct {
packets chan []byte
}
func newUDPPipe() *udpPipe {
return &udpPipe{make(chan []byte, 1024)}
}
func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
p.packets <- bytes.Clone(b)
return len(b), nil
}
func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
packet := <-p.packets
copy(b, packet)
return len(packet), netip.AddrPort{}, nil
}
type connReaderTestHarness struct {
Pipe *udpPipe
R *connReader
WRemote *connWriter
WRelayRemote *connWriter
Remote *RemotePeer
RelayRemote *RemotePeer
IFace *mockIfWriter
Sender *mockEncryptedPacketSender
Super *mockControlMsgHandler
}
// Peer 2 is indirect, peer 3 is direct.
func newConnReadeTestHarness() (h connReaderTestHarness) {
pipe := newUDPPipe()
routes := [256]*atomic.Pointer[RemotePeer]{}
for i := range routes {
routes[i] = &atomic.Pointer[RemotePeer]{}
routes[i].Store(&RemotePeer{})
}
local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes()
routes[2].Store(local)
routes[3].Store(relayLocal)
h.Pipe = pipe
h.WRemote = newConnWriter(pipe, 2)
h.WRelayRemote = newConnWriter(pipe, 3)
h.Remote = remote
h.RelayRemote = relayRemote
h.IFace = &mockIfWriter{}
h.Sender = &mockEncryptedPacketSender{}
h.Super = &mockControlMsgHandler{}
h.R = newConnReader(
pipe,
h.IFace,
h.Sender,
h.Super,
1,
routes)
return h
}
// Testing that we can receive a control packet.
func TestConnReader_handleControlPacket(t *testing.T) {
h := newConnReadeTestHarness()
pkt := PacketSyn{TraceID: 1234}
h.WRemote.SendControlPacket(pkt, h.Remote)
h.R.handleNextPacket()
if len(h.Super.Messages) != 1 {
t.Fatal(h.Super.Messages)
}
msg := h.Super.Messages[0].(controlMsg[PacketSyn])
if !reflect.DeepEqual(pkt, msg.Packet) {
t.Fatal(msg.Packet)
}
}
// Testing that a short packet is ignored.
func TestConnReader_handleNextPacket_short(t *testing.T) {
h := newConnReadeTestHarness()
h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{})
h.R.handleNextPacket()
if len(h.Super.Messages) != 0 {
t.Fatal(h.Super.Messages)
}
}
// Testing that a packet with an unexpected stream ID is ignored.
func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) {
h := newConnReadeTestHarness()
pkt := PacketSyn{TraceID: 1234}
encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf())
var header header
header.Parse(encrypted)
header.StreamID = 100
header.Marshal(encrypted)
h.WRemote.writeTo(encrypted, netip.AddrPort{})
h.R.handleNextPacket()
if len(h.Super.Messages) != 0 {
t.Fatal(h.Super.Messages)
}
}
// Testing that control packet without matching control cipher is ignored.
func TestConnReader_handleControlPacket_noCipher(t *testing.T) {
h := newConnReadeTestHarness()
pkt := PacketSyn{TraceID: 1234}
//encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf())
var header header
header.Parse(encrypted)
header.SourceIP = 10
header.Marshal(encrypted)
h.WRemote.writeTo(encrypted, netip.AddrPort{})
h.R.handleNextPacket()
if len(h.Super.Messages) != 0 {
t.Fatal(h.Super.Messages)
}
}
// Testing that control packet with incrrect destination IP is ignored.
func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) {
h := newConnReadeTestHarness()
pkt := PacketSyn{TraceID: 1234}
encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf())
var header header
header.Parse(encrypted)
header.DestIP++
header.Marshal(encrypted)
h.WRemote.writeTo(encrypted, netip.AddrPort{})
h.R.handleNextPacket()
if len(h.Super.Messages) != 0 {
t.Fatal(h.Super.Messages)
}
}
// Testing that modified control packet is ignored.
func TestConnReader_handleControlPacket_modified(t *testing.T) {
h := newConnReadeTestHarness()
pkt := PacketSyn{TraceID: 1234}
encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf())
encrypted[len(encrypted)-1]++
h.WRemote.writeTo(encrypted, netip.AddrPort{})
h.R.handleNextPacket()
if len(h.Super.Messages) != 0 {
t.Fatal(h.Super.Messages)
}
}
type unknownPacket struct{}
func (p unknownPacket) Marshal(buf []byte) []byte {
buf = buf[:1]
buf[0] = 100
return buf
}
// Testing that an empty control packet is ignored.
func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) {
h := newConnReadeTestHarness()
pkt := unknownPacket{}
encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf())
h.WRemote.writeTo(encrypted, netip.AddrPort{})
h.R.handleNextPacket()
if len(h.Super.Messages) != 0 {
t.Fatal(h.Super.Messages)
}
}
// Testing that a duplicate control packet is ignored.
func TestConnReader_handleControlPacket_duplicate(t *testing.T) {
h := newConnReadeTestHarness()
pkt := PacketAck{TraceID: 1234}
h.WRemote.SendControlPacket(pkt, h.Remote)
*h.Remote.counter = *h.Remote.counter - 1
h.WRemote.SendControlPacket(pkt, h.Remote)
h.R.handleNextPacket()
h.R.handleNextPacket()
if len(h.Super.Messages) != 1 {
t.Fatal(h.Super.Messages)
}
msg := h.Super.Messages[0].(controlMsg[PacketAck])
if !reflect.DeepEqual(pkt, msg.Packet) {
t.Fatal(msg.Packet)
}
}
// Testing that we can receive a data packet.
func TestConnReader_handleDataPacket(t *testing.T) {
h := newConnReadeTestHarness()
pkt := make([]byte, 1024)
rand.Read(pkt)
h.WRemote.SendDataPacket(pkt, h.Remote)
h.R.handleNextPacket()
if len(h.IFace.Written) != 1 {
t.Fatal(h.IFace.Written)
}
if !bytes.Equal(pkt, h.IFace.Written[0]) {
t.Fatal(h.IFace.Written)
}
}
// Testing that data packet is ignored if route isn't up.
func TestConnReader_handleDataPacket_routeDown(t *testing.T) {
h := newConnReadeTestHarness()
pkt := make([]byte, 1024)
rand.Read(pkt)
h.WRemote.SendDataPacket(pkt, h.Remote)
route := h.R.peers[2].Load()
route.Up = false
h.R.handleNextPacket()
if len(h.IFace.Written) != 0 {
t.Fatal(h.IFace.Written)
}
}
// Testing that a duplicate data packet is ignored.
func TestConnReader_handleDataPacket_duplicate(t *testing.T) {
h := newConnReadeTestHarness()
pkt := make([]byte, 123)
h.WRemote.SendDataPacket(pkt, h.Remote)
*h.Remote.counter = *h.Remote.counter - 1
h.WRemote.SendDataPacket(pkt, h.Remote)
h.R.handleNextPacket()
h.R.handleNextPacket()
if len(h.IFace.Written) != 1 {
t.Fatal(h.IFace.Written)
}
if !bytes.Equal(pkt, h.IFace.Written[0]) {
t.Fatal(h.IFace.Written)
}
}
// Testing that we can relay a data packet.
func TestConnReader_handleDataPacket_relay(t *testing.T) {
h := newConnReadeTestHarness()
pkt := make([]byte, 1024)
rand.Read(pkt)
h.RelayRemote.IP = 3
h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote)
h.R.handleNextPacket()
if len(h.Sender.Sent) != 1 {
t.Fatal(h.Sender.Sent)
}
}
// Testing that we drop a relayed packet if destination is down.
func TestConnReader_handleDataPacket_relayDown(t *testing.T) {
h := newConnReadeTestHarness()
pkt := make([]byte, 1024)
rand.Read(pkt)
h.RelayRemote.IP = 3
relay := h.R.peers[3].Load()
relay.Up = false
h.WRemote.RelayDataPacket(pkt, h.RelayRemote, h.Remote)
h.R.handleNextPacket()
if len(h.Sender.Sent) != 0 {
t.Fatal(h.Sender.Sent)
}
}

View File

@ -1,80 +0,0 @@
package peer
import (
"log"
"net/netip"
"sync"
)
// ----------------------------------------------------------------------------
type connWriter struct {
localIP byte
conn udpWriter
// For sending control packets.
cBuf1 []byte
cBuf2 []byte
// For sending data packets.
dBuf1 []byte
dBuf2 []byte
// Lock around for sending on UDP Conn.
wLock sync.Mutex
}
func newConnWriter(conn udpWriter, localIP byte) *connWriter {
w := &connWriter{
localIP: localIP,
conn: conn,
cBuf1: make([]byte, bufferSize),
cBuf2: make([]byte, bufferSize),
dBuf1: make([]byte, bufferSize),
dBuf2: make([]byte, bufferSize),
}
return w
}
// Not safe for concurrent use. Should only be called by supervisor.
func (w *connWriter) SendControlPacket(pkt Marshaller, peer *RemotePeer) {
enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2)
w.writeTo(enc, peer.DirectAddr)
}
// Relay control packet. Peer must not be nil.
func (w *connWriter) RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) {
enc := encryptControlPacket(w.localIP, peer, pkt, w.cBuf1, w.cBuf2)
enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.cBuf1)
w.writeTo(enc, relay.DirectAddr)
}
// Not safe for concurrent use. Should only be called by ifReader.
func (w *connWriter) SendDataPacket(pkt []byte, peer *RemotePeer) {
enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1)
w.writeTo(enc, peer.DirectAddr)
}
// Relay a data packet. Peer must not be nil.
func (w *connWriter) RelayDataPacket(pkt []byte, peer, relay *RemotePeer) {
enc := encryptDataPacket(w.localIP, peer.IP, peer, pkt, w.dBuf1)
enc = encryptDataPacket(w.localIP, peer.IP, relay, enc, w.dBuf2)
w.writeTo(enc, relay.DirectAddr)
}
// Safe for concurrent use. Should only be called by connReader.
//
// This function will send pkt to the peer directly. This is used when a peer
// is acting as a relay and is forwarding already encrypted data for another
// peer.
func (w *connWriter) SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) {
w.writeTo(pkt, peer.DirectAddr)
}
func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) {
w.wLock.Lock()
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
log.Printf("[ConnWriter] Failed to write to UDP port: %v", err)
}
w.wLock.Unlock()
}

View File

@ -1,109 +0,0 @@
package peer
import (
"log"
"net/netip"
"sync"
"sync/atomic"
)
type ConnWriter struct {
wLock sync.Mutex // Lock around for sending on UDP Conn.
// Output.
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
// Shared state.
rt *atomic.Pointer[RoutingTable]
// For sending control packets.
cBuf1 []byte
cBuf2 []byte
// For sending data packets.
dBuf1 []byte
dBuf2 []byte
}
func NewConnWriter(
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
rt *atomic.Pointer[RoutingTable],
) *ConnWriter {
return &ConnWriter{
writeToUDPAddrPort: writeToUDPAddrPort,
rt: rt,
cBuf1: newBuf(),
cBuf2: newBuf(),
dBuf1: newBuf(),
dBuf2: newBuf(),
}
}
// Called by ConnReader to forward already encrypted bytes to another peer.
func (w *ConnWriter) Forward(ip byte, pkt []byte) {
peer := w.rt.Load().Peers[ip]
if !(peer.Up && peer.Direct) {
w.logf("Failed to forward to %d.", ip)
return
}
w.writeTo(pkt, peer.DirectAddr)
}
// Called by IFReader to send data. Encryption will be applied, and packet will
// be relayed if appropriate.
func (w *ConnWriter) WriteData(ip byte, pkt []byte) {
rt := w.rt.Load()
peer := rt.Peers[ip]
if !peer.Up {
w.logf("Failed to send data to %d.", ip)
return
}
enc := peer.EncryptDataPacket(ip, pkt, w.dBuf1)
if peer.Direct {
w.writeTo(enc, peer.DirectAddr)
return
}
relay, ok := rt.GetRelay()
if !ok {
w.logf("Failed to send data to %d. No relay.", ip)
return
}
enc = relay.EncryptDataPacket(ip, enc, w.dBuf2)
w.writeTo(enc, relay.DirectAddr)
}
// Called by Supervisor to send control packets.
func (w *ConnWriter) WriteControl(peer RemotePeer, pkt Marshaller) {
enc := peer.EncryptControlPacket(pkt, w.cBuf2, w.cBuf1)
if peer.Direct {
w.writeTo(enc, peer.DirectAddr)
return
}
rt := w.rt.Load()
relay, ok := rt.GetRelay()
if !ok {
w.logf("Failed to send control to %d. No relay.", peer.IP)
return
}
enc = relay.EncryptDataPacket(peer.IP, enc, w.cBuf2)
w.writeTo(enc, relay.DirectAddr)
}
func (w *ConnWriter) writeTo(pkt []byte, addr netip.AddrPort) {
w.wLock.Lock()
if _, err := w.writeToUDPAddrPort(pkt, addr); err != nil {
w.logf("Failed to write to UDP port: %v", err)
}
w.wLock.Unlock()
}
func (w *ConnWriter) logf(s string, args ...any) {
log.Printf("[ConnWriter] "+s, args...)
}

View File

@ -1,145 +0,0 @@
package peer
import (
"testing"
)
func TestConnWriter_WriteData_direct(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
in := RandPacket()
p1.ConnWriter.WriteData(2, in)
packets := p2.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestConnWriter_WriteData_peerNotUp(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
p1.RT.Load().Peers[2].Up = false
in := RandPacket()
p1.ConnWriter.WriteData(2, in)
packets := p2.Conn.Packets()
if len(packets) != 0 {
t.Fatal(packets)
}
}
func TestConnWriter_WriteData_relay(t *testing.T) {
p1, _, p3 := NewPeersForTesting()
p1.RT.Load().Peers[2].Direct = false
p1.RT.Load().RelayIP = 3
in := RandPacket()
p1.ConnWriter.WriteData(2, in)
packets := p3.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestConnWriter_WriteData_relayNotAvailable(t *testing.T) {
p1, _, p3 := NewPeersForTesting()
p1.RT.Load().Peers[2].Direct = false
p1.RT.Load().Peers[3].Up = false
p1.RT.Load().RelayIP = 3
in := RandPacket()
p1.ConnWriter.WriteData(2, in)
packets := p3.Conn.Packets()
if len(packets) != 0 {
t.Fatal(packets)
}
}
func TestConnWriter_WriteControl_direct(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
orig := PacketProbe{TraceID: newTraceID()}
p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig)
packets := p2.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestConnWriter_WriteControl_relay(t *testing.T) {
p1, _, p3 := NewPeersForTesting()
p1.RT.Load().Peers[2].Direct = false
p1.RT.Load().RelayIP = 3
orig := PacketProbe{TraceID: newTraceID()}
p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig)
packets := p3.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestConnWriter_WriteControl_relayNotAvailable(t *testing.T) {
p1, _, p3 := NewPeersForTesting()
p1.RT.Load().Peers[2].Direct = false
p1.RT.Load().Peers[3].Up = false
p1.RT.Load().RelayIP = 3
orig := PacketProbe{TraceID: newTraceID()}
p1.ConnWriter.WriteControl(p1.RT.Load().Peers[2], orig)
packets := p3.Conn.Packets()
if len(packets) != 0 {
t.Fatal(packets)
}
}
func TestConnWriter__Forward(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
in := RandPacket()
p1.ConnWriter.Forward(2, in)
packets := p2.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestConnWriter__Forward_notUp(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
p1.RT.Load().Peers[2].Up = false
in := RandPacket()
p1.ConnWriter.Forward(2, in)
packets := p2.Conn.Packets()
if len(packets) != 0 {
t.Fatal(packets)
}
}
func TestConnWriter__Forward_notDirect(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
p1.RT.Load().Peers[2].Direct = false
in := RandPacket()
p1.ConnWriter.Forward(2, in)
packets := p2.Conn.Packets()
if len(packets) != 0 {
t.Fatal(packets)
}
}

View File

@ -1,240 +0,0 @@
package peer
import (
"bytes"
"net/netip"
"testing"
)
// ----------------------------------------------------------------------------
type testUDPPacket struct {
Addr netip.AddrPort
Data []byte
}
type testUDPAddrPortWriter struct {
written []testUDPPacket
}
func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
w.written = append(w.written, testUDPPacket{
Addr: addr,
Data: bytes.Clone(b),
})
return len(b), nil
}
func (w *testUDPAddrPortWriter) Written() []testUDPPacket {
out := w.written
w.written = []testUDPPacket{}
return out
}
// ----------------------------------------------------------------------------
type testPacket string
func (p testPacket) Marshal(b []byte) []byte {
b = b[:len(p)]
copy(b, []byte(p))
return b
}
// ----------------------------------------------------------------------------
func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *RemotePeer) {
localKeys := generateKeys()
remoteKeys := generateKeys()
local = NewRemotePeer(2)
local.Up = true
local.Relay = false
local.PubSignKey = remoteKeys.PubSignKey
local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey)
local.DataCipher = newDataCipher()
local.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100)
remote = NewRemotePeer(1)
remote.Up = true
remote.Relay = false
remote.PubSignKey = localKeys.PubSignKey
remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey)
remote.DataCipher = local.DataCipher
remote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)
rLocalKeys := generateKeys()
rRemoteKeys := generateKeys()
relayLocal = NewRemotePeer(3)
relayLocal.Up = true
relayLocal.Relay = true
relayLocal.Direct = true
relayLocal.PubSignKey = rRemoteKeys.PubSignKey
relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey)
relayLocal.DataCipher = newDataCipher()
relayLocal.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100)
relayRemote = NewRemotePeer(1)
relayRemote.Up = true
relayRemote.Relay = false
relayRemote.Direct = true
relayRemote.PubSignKey = rLocalKeys.PubSignKey
relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey)
relayRemote.DataCipher = relayLocal.DataCipher
relayRemote.DirectAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)
return
}
// ----------------------------------------------------------------------------
// Testing if we can send a control packet directly to the remote route.
func TestConnWriter_SendControlPacket_direct(t *testing.T) {
route, rRoute, _, _ := testConnWriter_getTestRoutes()
route.Direct = true
writer := &testUDPAddrPortWriter{}
w := newConnWriter(writer, rRoute.IP)
in := testPacket("hello world!")
w.SendControlPacket(in, route)
out := writer.Written()
if len(out) != 1 {
t.Fatal(out)
}
if out[0].Addr != route.DirectAddr {
t.Fatal(out[0])
}
dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024))
if !ok {
t.Fatal(ok)
}
if string(dec) != string(in) {
t.Fatal(dec)
}
}
// Testing if we can relay a packet via an intermediary.
func TestConnWriter_RelayControlPacket_relay(t *testing.T) {
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
writer := &testUDPAddrPortWriter{}
w := newConnWriter(writer, rRoute.IP)
in := testPacket("hello world!")
w.RelayControlPacket(in, route, relay)
out := writer.Written()
if len(out) != 1 {
t.Fatal(out)
}
if out[0].Addr != relay.DirectAddr {
t.Fatal(out[0])
}
dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024))
if !ok {
t.Fatal(ok)
}
dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024))
if !ok {
t.Fatal(ok)
}
if string(dec2) != string(in) {
t.Fatal(dec2)
}
}
// Testing that we can send a data packet directly to a remote route.
func TestConnWriter_SendDataPacket_direct(t *testing.T) {
route, rRoute, _, _ := testConnWriter_getTestRoutes()
route.Direct = true
writer := &testUDPAddrPortWriter{}
w := newConnWriter(writer, rRoute.IP)
in := []byte("hello world!")
w.SendDataPacket(in, route)
out := writer.Written()
if len(out) != 1 {
t.Fatal(out)
}
if out[0].Addr != route.DirectAddr {
t.Fatal(out[0])
}
dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(dec, in) {
t.Fatal(dec)
}
}
// Testing that we can relay a data packet via a relay.
func TestConnWriter_RelayDataPacket_relay(t *testing.T) {
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
writer := &testUDPAddrPortWriter{}
w := newConnWriter(writer, rRoute.IP)
in := []byte("Hello world!")
w.RelayDataPacket(in, route, relay)
out := writer.Written()
if len(out) != 1 {
t.Fatal(out)
}
if out[0].Addr != relay.DirectAddr {
t.Fatal(out[0])
}
dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024))
if !ok {
t.Fatal(ok)
}
dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(dec2, in) {
t.Fatal(dec2)
}
}
// Testing that we can send an already encrypted packet.
func TestConnWriter_SendEncryptedDataPacket(t *testing.T) {
route, rRoute, _, _ := testConnWriter_getTestRoutes()
writer := &testUDPAddrPortWriter{}
w := newConnWriter(writer, rRoute.IP)
in := []byte("Hello world!")
w.SendEncryptedDataPacket(in, route)
out := writer.Written()
if len(out) != 1 {
t.Fatal(out)
}
if out[0].Addr != route.DirectAddr {
t.Fatal(out[0])
}
if !bytes.Equal(out[0].Data, in) {
t.Fatal(out[0])
}
}

View File

@ -17,25 +17,25 @@ type controlMsg[T any] struct {
func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) {
switch buf[0] { switch buf[0] {
case PacketTypeSyn: case packetTypeSyn:
packet, err := ParsePacketSyn(buf) packet, err := parsePacketSyn(buf)
return controlMsg[PacketSyn]{ return controlMsg[packetSyn]{
SrcIP: srcIP, SrcIP: srcIP,
SrcAddr: srcAddr, SrcAddr: srcAddr,
Packet: packet, Packet: packet,
}, err }, err
case PacketTypeAck: case packetTypeAck:
packet, err := ParsePacketAck(buf) packet, err := parsePacketAck(buf)
return controlMsg[PacketAck]{ return controlMsg[packetAck]{
SrcIP: srcIP, SrcIP: srcIP,
SrcAddr: srcAddr, SrcAddr: srcAddr,
Packet: packet, Packet: packet,
}, err }, err
case PacketTypeProbe: case packetTypeProbe:
packet, err := ParsePacketProbe(buf) packet, err := parsePacketProbe(buf)
return controlMsg[PacketProbe]{ return controlMsg[packetProbe]{
SrcIP: srcIP, SrcIP: srcIP,
SrcAddr: srcAddr, SrcAddr: srcAddr,
Packet: packet, Packet: packet,

View File

@ -36,7 +36,7 @@ func generateKeys() cryptoKeys {
// Peer must have a ControlCipher. // Peer must have a ControlCipher.
func encryptControlPacket( func encryptControlPacket(
localIP byte, localIP byte,
peer *RemotePeer, peer *remotePeer,
pkt Marshaller, pkt Marshaller,
tmp []byte, tmp []byte,
out []byte, out []byte,
@ -55,7 +55,7 @@ func encryptControlPacket(
// //
// This function also drops packets with duplicate sequence numbers. // This function also drops packets with duplicate sequence numbers.
func decryptControlPacket( func decryptControlPacket(
peer *RemotePeer, peer *remotePeer,
fromAddr netip.AddrPort, fromAddr netip.AddrPort,
h header, h header,
encrypted []byte, encrypted []byte,
@ -83,7 +83,7 @@ func decryptControlPacket(
func encryptDataPacket( func encryptDataPacket(
localIP byte, localIP byte,
destIP byte, destIP byte,
peer *RemotePeer, peer *remotePeer,
data []byte, data []byte,
out []byte, out []byte,
) []byte { ) []byte {
@ -98,7 +98,7 @@ func encryptDataPacket(
// Decrypts and de-dups incoming data packets. // Decrypts and de-dups incoming data packets.
func decryptDataPacket( func decryptDataPacket(
peer *RemotePeer, peer *remotePeer,
h header, h header,
encrypted []byte, encrypted []byte,
out []byte, out []byte,

View File

@ -9,7 +9,7 @@ import (
"testing" "testing"
) )
func newRoutePairForTesting() (*RemotePeer, *RemotePeer) { func newRoutePairForTesting() (*remotePeer, *remotePeer) {
keys1 := generateKeys() keys1 := generateKeys()
keys2 := generateKeys() keys2 := generateKeys()
@ -33,7 +33,7 @@ func TestDecryptControlPacket(t *testing.T) {
out = make([]byte, bufferSize) out = make([]byte, bufferSize)
) )
in := PacketSyn{ in := packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
SharedKey: r1.DataCipher.Key(), SharedKey: r1.DataCipher.Key(),
Direct: true, Direct: true,
@ -47,7 +47,7 @@ func TestDecryptControlPacket(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
msg, ok := iMsg.(controlMsg[PacketSyn]) msg, ok := iMsg.(controlMsg[packetSyn])
if !ok { if !ok {
t.Fatal(ok) t.Fatal(ok)
} }
@ -64,7 +64,7 @@ func TestDecryptControlPacket_decryptionFailed(t *testing.T) {
out = make([]byte, bufferSize) out = make([]byte, bufferSize)
) )
in := PacketSyn{ in := packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
SharedKey: r1.DataCipher.Key(), SharedKey: r1.DataCipher.Key(),
Direct: true, Direct: true,
@ -90,7 +90,7 @@ func TestDecryptControlPacket_duplicate(t *testing.T) {
out = make([]byte, bufferSize) out = make([]byte, bufferSize)
) )
in := PacketSyn{ in := packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
SharedKey: r1.DataCipher.Key(), SharedKey: r1.DataCipher.Key(),
Direct: true, Direct: true,
@ -109,7 +109,8 @@ func TestDecryptControlPacket_duplicate(t *testing.T) {
} }
} }
func TestDecryptControlPacket_invalidPacket(t *testing.T) { /*
func TestDecryptControlPacket_invalidPacket(t *testing.T) {
var ( var (
r1, r2 = newRoutePairForTesting() r1, r2 = newRoutePairForTesting()
tmp = make([]byte, bufferSize) tmp = make([]byte, bufferSize)
@ -125,8 +126,8 @@ func TestDecryptControlPacket_invalidPacket(t *testing.T) {
if !errors.Is(err, errUnknownPacketType) { if !errors.Is(err, errUnknownPacketType) {
t.Fatal(err) t.Fatal(err)
} }
} }
*/
func TestDecryptDataPacket(t *testing.T) { func TestDecryptDataPacket(t *testing.T) {
var ( var (
r1, r2 = newRoutePairForTesting() r1, r2 = newRoutePairForTesting()

View File

@ -16,10 +16,16 @@ type hubPoller struct {
versions [256]int64 versions [256]int64
localIP byte localIP byte
netName string netName string
super controlMsgHandler handleControlMsg func(fromIP byte, msg any)
} }
func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsgHandler) (*hubPoller, error) { func newHubPoller(
localIP byte,
netName,
hubURL,
apiKey string,
handleControlMsg func(byte, any),
) (*hubPoller, error) {
u, err := url.Parse(hubURL) u, err := url.Parse(hubURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -40,7 +46,7 @@ func newHubPoller(localIP byte, netName, hubURL, apiKey string, super controlMsg
req: req, req: req,
localIP: localIP, localIP: localIP,
netName: netName, netName: netName,
super: super, handleControlMsg: handleControlMsg,
}, nil }, nil
} }
@ -90,7 +96,7 @@ func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
for i, peer := range state.Peers { for i, peer := range state.Peers {
if i != int(hp.localIP) { if i != int(hp.localIP) {
if peer == nil || peer.Version != hp.versions[i] { if peer == nil || peer.Version != hp.versions[i] {
hp.super.HandleControlMsg(peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}) hp.handleControlMsg(byte(i), peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]})
if peer != nil { if peer != nil {
hp.versions[i] = peer.Version hp.versions[i] = peer.Version
} }

View File

@ -1,100 +0,0 @@
package peer
import (
"io"
"log"
"sync/atomic"
)
type ifReader struct {
iface io.Reader
peers [256]*atomic.Pointer[RemotePeer]
relay *atomic.Pointer[RemotePeer]
sender dataPacketSender
}
func newIFReader(
iface io.Reader,
peers [256]*atomic.Pointer[RemotePeer],
relay *atomic.Pointer[RemotePeer],
sender dataPacketSender,
) *ifReader {
return &ifReader{
iface: iface,
peers: peers,
relay: relay,
sender: sender,
}
}
func (r *ifReader) Run() {
var (
packet = make([]byte, bufferSize)
remoteIP byte
ok bool
)
for {
packet = r.readNextPacket(packet)
if remoteIP, ok = r.parsePacket(packet); ok {
r.sendPacket(packet, remoteIP)
}
}
}
func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) {
peer := r.peers[remoteIP].Load()
if !peer.Up {
log.Printf("Peer not connected: %d", remoteIP)
return
}
// Direct path => early return.
if peer.Direct {
r.sender.SendDataPacket(pkt, peer)
return
}
if relay := r.relay.Load(); relay != nil && relay.Up {
r.sender.RelayDataPacket(pkt, peer, relay)
}
}
// Get next packet, returning packet, and destination ip.
func (r *ifReader) readNextPacket(buf []byte) []byte {
n, err := r.iface.Read(buf[:cap(buf)])
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
return buf[:n]
}
func (r *ifReader) parsePacket(buf []byte) (byte, bool) {
n := len(buf)
if n == 0 {
return 0, false
}
version := buf[0] >> 4
switch version {
case 4:
if n < 20 {
log.Printf("Short IPv4 packet: %d", len(buf))
return 0, false
}
return buf[19], true
case 6:
if len(buf) < 40 {
log.Printf("Short IPv6 packet: %d", len(buf))
return 0, false
}
return buf[39], true
default:
log.Printf("Invalid IP packet version: %v", version)
return 0, false
}
}

View File

@ -3,22 +3,24 @@ package peer
import ( import (
"io" "io"
"log" "log"
"net/netip"
"sync/atomic"
) )
type IFReader struct { type IFReader struct {
iface io.Reader iface io.Reader
connWriter interface { writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
WriteData(ip byte, pkt []byte) rt *atomic.Pointer[routingTable]
} buf1 []byte
buf2 []byte
} }
func NewIFReader( func NewIFReader(
iface io.Reader, iface io.Reader,
connWriter interface { writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
WriteData(ip byte, pkt []byte) rt *atomic.Pointer[routingTable],
},
) *IFReader { ) *IFReader {
return &IFReader{iface, connWriter} return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()}
} }
func (r *IFReader) Run() { func (r *IFReader) Run() {
@ -30,9 +32,32 @@ func (r *IFReader) Run() {
func (r *IFReader) handleNextPacket(packet []byte) { func (r *IFReader) handleNextPacket(packet []byte) {
packet = r.readNextPacket(packet) packet = r.readNextPacket(packet)
if remoteIP, ok := r.parsePacket(packet); ok { remoteIP, ok := r.parsePacket(packet)
r.connWriter.WriteData(remoteIP, packet) if !ok {
return
} }
rt := r.rt.Load()
peer := rt.Peers[remoteIP]
if !peer.Up {
r.logf("Peer %d not up.", peer.IP)
return
}
enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1)
if peer.Direct {
r.writeToUDPAddrPort(enc, peer.DirectAddr)
return
}
relay, ok := rt.GetRelay()
if !ok {
r.logf("Relay not available for peer %d.", peer.IP)
return
}
enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2)
r.writeToUDPAddrPort(enc, relay.DirectAddr)
} }
func (r *IFReader) readNextPacket(buf []byte) []byte { func (r *IFReader) readNextPacket(buf []byte) []byte {

View File

@ -1,9 +1,6 @@
package peer package peer
import ( /*
"testing"
)
func TestIFReader_IPv4(t *testing.T) { func TestIFReader_IPv4(t *testing.T) {
p1, p2, _ := NewPeersForTesting() p1, p2, _ := NewPeersForTesting()
@ -81,3 +78,4 @@ func TestIFReader_parsePacket_shortIPv6(t *testing.T) {
t.Fatal(ip, ok) t.Fatal(ip, ok)
} }
} }
*/

View File

@ -1,232 +0,0 @@
package peer
import (
"bytes"
"reflect"
"sync/atomic"
"testing"
)
// Test that we parse IPv4 packets correctly.
func TestIFReader_parsePacket_ipv4(t *testing.T) {
r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
pkt := make([]byte, 1234)
pkt[0] = 4 << 4
pkt[19] = 128
if ip, ok := r.parsePacket(pkt); !ok || ip != 128 {
t.Fatal(ip, ok)
}
}
// Test that we parse IPv6 packets correctly.
func TestIFReader_parsePacket_ipv6(t *testing.T) {
r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
pkt := make([]byte, 1234)
pkt[0] = 6 << 4
pkt[39] = 42
if ip, ok := r.parsePacket(pkt); !ok || ip != 42 {
t.Fatal(ip, ok)
}
}
/*
// Test that empty packets work as expected.
func TestIFReader_parsePacket_emptyPacket(t *testing.T) {
r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
pkt := make([]byte, 0)
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(ip, ok)
}
}
// Test that invalid IP versions fail.
func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) {
r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
for i := byte(1); i < 16; i++ {
if i == 4 || i == 6 {
continue
}
pkt := make([]byte, 1234)
pkt[0] = i << 4
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(i, ip, ok)
}
}
}
// Test that short IPv4 packets fail.
func TestIFReader_parsePacket_shortIPv4(t *testing.T) {
r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
pkt := make([]byte, 19)
pkt[0] = 4 << 4
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(ip, ok)
}
}
// Test that short IPv6 packets fail.
func TestIFReader_parsePacket_shortIPv6(t *testing.T) {
r := newIFReader(nil, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
pkt := make([]byte, 39)
pkt[0] = 6 << 4
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(ip, ok)
}
}
// Test that we can read a packet.
func TestIFReader_readNextpacket(t *testing.T) {
in, out := net.Pipe()
r := newIFReader(out, [256]*atomic.Pointer[RemotePeer]{}, nil, nil)
defer in.Close()
defer out.Close()
go in.Write([]byte("hello world!"))
pkt := r.readNextPacket(make([]byte, bufferSize))
if !bytes.Equal(pkt, []byte("hello world!")) {
t.Fatalf("%s", pkt)
}
}
*/
// ----------------------------------------------------------------------------
type sentPacket struct {
Relayed bool
Packet []byte
Route RemotePeer
Relay RemotePeer
}
type sendPacketTestHarness struct {
Packets []sentPacket
}
func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *RemotePeer) {
h.Packets = append(h.Packets, sentPacket{
Packet: bytes.Clone(pkt),
Route: *route,
})
}
func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *RemotePeer) {
h.Packets = append(h.Packets, sentPacket{
Relayed: true,
Packet: bytes.Clone(pkt),
Route: *route,
Relay: *relay,
})
}
func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) {
h := &sendPacketTestHarness{}
routes := [256]*atomic.Pointer[RemotePeer]{}
for i := range routes {
routes[i] = &atomic.Pointer[RemotePeer]{}
routes[i].Store(&RemotePeer{})
}
relay := &atomic.Pointer[RemotePeer]{}
r := newIFReader(nil, routes, relay, h)
return r, h
}
// Testing that we can send a packet directly.
func TestIFReader_sendPacket_direct(t *testing.T) {
r, h := newIFReaderForSendPacketTesting()
route := r.peers[2].Load()
route.Up = true
route.Direct = true
in := []byte("hello world")
r.sendPacket(in, 2)
if len(h.Packets) != 1 {
t.Fatal(h.Packets)
}
expected := sentPacket{
Relayed: false,
Packet: in,
Route: *route,
}
if !reflect.DeepEqual(h.Packets[0], expected) {
t.Fatal(h.Packets[0])
}
}
// Testing that we don't send a packet if route isn't up.
func TestIFReader_sendPacket_directNotUp(t *testing.T) {
r, h := newIFReaderForSendPacketTesting()
route := r.peers[2].Load()
route.Direct = true
in := []byte("hello world")
r.sendPacket(in, 2)
if len(h.Packets) != 0 {
t.Fatal(h.Packets)
}
}
// Testing that we can send a packet via a relay.
func TestIFReader_sendPacket_relayed(t *testing.T) {
r, h := newIFReaderForSendPacketTesting()
route := r.peers[2].Load()
route.Up = true
route.Direct = false
relay := r.peers[3].Load()
r.relay.Store(relay)
relay.Up = true
relay.Direct = true
in := []byte("hello world")
r.sendPacket(in, 2)
if len(h.Packets) != 1 {
t.Fatal(h.Packets)
}
expected := sentPacket{
Relayed: true,
Packet: in,
Route: *route,
Relay: *relay,
}
if !reflect.DeepEqual(h.Packets[0], expected) {
t.Fatal(h.Packets[0])
}
}
// Testing that we don't try to send on a nil relay IP.
func TestIFReader_sendPacket_nilRealy(t *testing.T) {
r, h := newIFReaderForSendPacketTesting()
route := r.peers[2].Load()
route.Up = true
route.Direct = false
in := []byte("hello world")
r.sendPacket(in, 2)
if len(h.Packets) != 0 {
t.Fatal(h.Packets)
}
}

177
peer/interface.go Normal file
View File

@ -0,0 +1,177 @@
package peer
import (
"fmt"
"io"
"log"
"net"
"os"
"syscall"
"golang.org/x/sys/unix"
)
// Get next packet, returning packet, ip, and possible error.
func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) {
var (
version byte
ip byte
)
for {
n, err := iface.Read(buf[:cap(buf)])
if err != nil {
return nil, ip, err
}
buf = buf[:n]
version = buf[0] >> 4
switch version {
case 4:
if n < 20 {
log.Printf("Short IPv4 packet: %d", len(buf))
continue
}
ip = buf[19]
case 6:
if len(buf) < 40 {
log.Printf("Short IPv6 packet: %d", len(buf))
continue
}
ip = buf[39]
default:
log.Printf("Invalid IP packet version: %v", version)
continue
}
return buf, ip, nil
}
}
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
if len(network) != 4 {
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))
}
ip := net.IPv4(network[0], network[1], network[2], localIP)
//////////////////////////
// Create TUN Interface //
//////////////////////////
tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600)
if err != nil {
return nil, fmt.Errorf("failed to open TUN device: %w", err)
}
// New interface request.
req, err := unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create new TUN interface request: %w", err)
}
// Flags:
//
// IFF_NO_PI => don't add packet info data to packets sent to the interface.
// IFF_TUN => create a TUN device handling IP packets.
req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN)
err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req)
if err != nil {
return nil, fmt.Errorf("failed to set TUN device settings: %w", err)
}
// Name may not be exactly the same?
name = req.Name()
/////////////
// Set MTU //
/////////////
// We need a socket file descriptor to set other options for some reason.
sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return nil, fmt.Errorf("failed to open socket: %w", err)
}
defer unix.Close(sockFD)
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create MTU interface request: %w", err)
}
req.SetUint32(if_mtu)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil {
return nil, fmt.Errorf("failed to set interface MTU: %w", err)
}
//////////////////////
// Set Queue Length //
//////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
}
req.SetUint16(if_queue_len)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil {
return nil, fmt.Errorf("failed to set interface queue length: %w", err)
}
/////////////////////
// Set IP and Mask //
/////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
}
if err := req.SetInet4Addr(ip.To4()); err != nil {
return nil, fmt.Errorf("failed to set interface request IP: %w", err)
}
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil {
return nil, fmt.Errorf("failed to set interface IP: %w", err)
}
// SET MASK - must happen after setting address.
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create mask interface request: %w", err)
}
if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil {
return nil, fmt.Errorf("failed to set interface request mask: %w", err)
}
if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil {
return nil, fmt.Errorf("failed to set interface mask: %w", err)
}
////////////////////////
// Bring Interface Up //
////////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create up interface request: %w", err)
}
// Get current flags.
if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil {
return nil, fmt.Errorf("failed to get interface flags: %w", err)
}
flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING
// Set UP flag / broadcast flags.
req.SetUint16(flags)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil {
return nil, fmt.Errorf("failed to set interface up: %w", err)
}
return os.NewFile(uintptr(tunFD), "tun"), nil
}

View File

@ -31,17 +31,17 @@ type Marshaller interface {
} }
type dataPacketSender interface { type dataPacketSender interface {
SendDataPacket(pkt []byte, peer *RemotePeer) SendDataPacket(pkt []byte, peer *remotePeer)
RelayDataPacket(pkt []byte, peer, relay *RemotePeer) RelayDataPacket(pkt []byte, peer, relay *remotePeer)
} }
type controlPacketSender interface { type controlPacketSender interface {
SendControlPacket(pkt Marshaller, peer *RemotePeer) SendControlPacket(pkt Marshaller, peer *remotePeer)
RelayControlPacket(pkt Marshaller, peer, relay *RemotePeer) RelayControlPacket(pkt Marshaller, peer, relay *remotePeer)
} }
type encryptedPacketSender interface { type encryptedPacketSender interface {
SendEncryptedDataPacket(pkt []byte, peer *RemotePeer) SendEncryptedDataPacket(pkt []byte, peer *remotePeer)
} }
type controlMsgHandler interface { type controlMsgHandler interface {

23
peer/main.go Normal file
View File

@ -0,0 +1,23 @@
package peer
import (
"flag"
"os"
)
func Main() {
conf := Config{}
flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.")
flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.")
flag.StringVar(&conf.APIKey, "api-key", "", "[REQUIRED] The node's API key.")
flag.Parse()
if conf.NetName == "" || conf.HubAddress == "" || conf.APIKey == "" {
flag.Usage()
os.Exit(1)
}
peer := New(conf)
peer.Run()
}

View File

@ -8,7 +8,7 @@ import (
type mcReader struct { type mcReader struct {
conn udpReader conn udpReader
super controlMsgHandler super controlMsgHandler
peers [256]*atomic.Pointer[RemotePeer] peers [256]*atomic.Pointer[remotePeer]
incoming []byte incoming []byte
buf []byte buf []byte
@ -17,7 +17,7 @@ type mcReader struct {
func newMCReader( func newMCReader(
conn udpReader, conn udpReader,
super controlMsgHandler, super controlMsgHandler,
peers [256]*atomic.Pointer[RemotePeer], peers [256]*atomic.Pointer[remotePeer],
) *mcReader { ) *mcReader {
return &mcReader{conn, super, peers, newBuf(), newBuf()} return &mcReader{conn, super, peers, newBuf(), newBuf()}
} }
@ -50,7 +50,7 @@ func (r *mcReader) handleNextPacket() {
return return
} }
r.super.HandleControlMsg(controlMsg[PacketLocalDiscovery]{ r.super.HandleControlMsg(controlMsg[packetLocalDiscovery]{
SrcIP: h.SourceIP, SrcIP: h.SourceIP,
SrcAddr: remoteAddr, SrcAddr: remoteAddr,
}) })

View File

@ -1,13 +1,6 @@
package peer package peer
import ( /*
"bytes"
"net"
"net/netip"
"sync/atomic"
"testing"
)
type mcMockConn struct { type mcMockConn struct {
packets chan []byte packets chan []byte
} }
@ -136,3 +129,4 @@ func TestMCReader_badSignature(t *testing.T) {
t.Fatal(super.Messages) t.Fatal(super.Messages)
} }
} }
*/

View File

@ -5,41 +5,34 @@ import (
) )
const ( const (
PacketTypeSyn = iota + 1 packetTypeSyn = 1
PacketTypeSynAck packetTypeAck = 3
PacketTypeAck packetTypeProbe = 4
PacketTypeProbe packetTypeAddrDiscovery = 5
PacketTypeAddrDiscovery
) )
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type PacketSyn struct { type packetSyn struct {
TraceID uint64 // TraceID to match response w/ request. TraceID uint64 // TraceID to match response w/ request.
//SentAt int64 // Unixmilli.
//SharedKeyType byte // Currently only 1 is supported for AES.
SharedKey [32]byte // Our shared key. SharedKey [32]byte // Our shared key.
Direct bool Direct bool
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
} }
func (p PacketSyn) Marshal(buf []byte) []byte { func (p packetSyn) Marshal(buf []byte) []byte {
return newBinWriter(buf). return newBinWriter(buf).
Byte(PacketTypeSyn). Byte(packetTypeSyn).
Uint64(p.TraceID). Uint64(p.TraceID).
//Int64(p.SentAt).
//Byte(p.SharedKeyType).
SharedKey(p.SharedKey). SharedKey(p.SharedKey).
Bool(p.Direct). Bool(p.Direct).
AddrPort8(p.PossibleAddrs). AddrPort8(p.PossibleAddrs).
Build() Build()
} }
func ParsePacketSyn(buf []byte) (p PacketSyn, err error) { func parsePacketSyn(buf []byte) (p packetSyn, err error) {
err = newBinReader(buf[1:]). err = newBinReader(buf[1:]).
Uint64(&p.TraceID). Uint64(&p.TraceID).
//Int64(&p.SentAt).
//Byte(&p.SharedKeyType).
SharedKey(&p.SharedKey). SharedKey(&p.SharedKey).
Bool(&p.Direct). Bool(&p.Direct).
AddrPort8(&p.PossibleAddrs). AddrPort8(&p.PossibleAddrs).
@ -49,22 +42,22 @@ func ParsePacketSyn(buf []byte) (p PacketSyn, err error) {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type PacketAck struct { type packetAck struct {
TraceID uint64 TraceID uint64
ToAddr netip.AddrPort ToAddr netip.AddrPort
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
} }
func (p PacketAck) Marshal(buf []byte) []byte { func (p packetAck) Marshal(buf []byte) []byte {
return newBinWriter(buf). return newBinWriter(buf).
Byte(PacketTypeAck). Byte(packetTypeAck).
Uint64(p.TraceID). Uint64(p.TraceID).
AddrPort(p.ToAddr). AddrPort(p.ToAddr).
AddrPort8(p.PossibleAddrs). AddrPort8(p.PossibleAddrs).
Build() Build()
} }
func ParsePacketAck(buf []byte) (p PacketAck, err error) { func parsePacketAck(buf []byte) (p packetAck, err error) {
err = newBinReader(buf[1:]). err = newBinReader(buf[1:]).
Uint64(&p.TraceID). Uint64(&p.TraceID).
AddrPort(&p.ToAddr). AddrPort(&p.ToAddr).
@ -77,18 +70,18 @@ func ParsePacketAck(buf []byte) (p PacketAck, err error) {
// A probeReqPacket is sent from a client to a server to determine if direct // A probeReqPacket is sent from a client to a server to determine if direct
// UDP communication can be used. // UDP communication can be used.
type PacketProbe struct { type packetProbe struct {
TraceID uint64 TraceID uint64
} }
func (p PacketProbe) Marshal(buf []byte) []byte { func (p packetProbe) Marshal(buf []byte) []byte {
return newBinWriter(buf). return newBinWriter(buf).
Byte(PacketTypeProbe). Byte(packetTypeProbe).
Uint64(p.TraceID). Uint64(p.TraceID).
Build() Build()
} }
func ParsePacketProbe(buf []byte) (p PacketProbe, err error) { func parsePacketProbe(buf []byte) (p packetProbe, err error) {
err = newBinReader(buf[1:]). err = newBinReader(buf[1:]).
Uint64(&p.TraceID). Uint64(&p.TraceID).
Error() Error()
@ -97,4 +90,4 @@ func ParsePacketProbe(buf []byte) (p PacketProbe, err error) {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type PacketLocalDiscovery struct{} type packetLocalDiscovery struct{}

View File

@ -8,7 +8,7 @@ import (
) )
func TestSynPacket(t *testing.T) { func TestSynPacket(t *testing.T) {
p := PacketSyn{ p := packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(), //SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1, //SharedKeyType: 1,
@ -21,7 +21,7 @@ func TestSynPacket(t *testing.T) {
p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000)
buf := p.Marshal(newBuf()) buf := p.Marshal(newBuf())
p2, err := ParsePacketSyn(buf) p2, err := parsePacketSyn(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -31,7 +31,7 @@ func TestSynPacket(t *testing.T) {
} }
func TestAckPacket(t *testing.T) { func TestAckPacket(t *testing.T) {
p := PacketAck{ p := packetAck{
TraceID: newTraceID(), TraceID: newTraceID(),
ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234),
} }
@ -41,7 +41,7 @@ func TestAckPacket(t *testing.T) {
p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000)
buf := p.Marshal(newBuf()) buf := p.Marshal(newBuf())
p2, err := ParsePacketAck(buf) p2, err := parsePacketAck(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -51,12 +51,12 @@ func TestAckPacket(t *testing.T) {
} }
func TestProbePacket(t *testing.T) { func TestProbePacket(t *testing.T) {
p := PacketProbe{ p := packetProbe{
TraceID: newTraceID(), TraceID: newTraceID(),
} }
buf := p.Marshal(newBuf()) buf := p.Marshal(newBuf())
p2, err := ParsePacketProbe(buf) p2, err := parsePacketProbe(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

161
peer/peer.go Normal file
View File

@ -0,0 +1,161 @@
package peer
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"sync/atomic"
"vppn/m"
)
type Peer struct {
ifReader *IFReader
connReader *ConnReader
iface io.Writer
hubPoller *hubPoller
super *Super
}
type Config struct {
NetName string
HubAddress string
APIKey string
}
func New(conf Config) *Peer {
config, err := loadPeerConfig(conf.NetName)
if err != nil {
log.Printf("Failed to load configuration: %v", err)
log.Printf("Initializing...")
initPeerWithHub(conf)
config, err = loadPeerConfig(conf.NetName)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
}
iface, err := openInterface(config.Network, config.PeerIP, conf.NetName)
if err != nil {
log.Fatalf("Failed to open interface: %v", err)
}
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port))
if err != nil {
log.Fatalf("Failed to resolve UDP address: %v", err)
}
log.Printf("Listening on %v...", myAddr)
conn, err := net.ListenUDP("udp", myAddr)
if err != nil {
log.Fatalf("Failed to open UDP port: %v", err)
}
conn.SetReadBuffer(1024 * 1024 * 8)
conn.SetWriteBuffer(1024 * 1024 * 8)
// Wrap write function - this is necessary to avoid starvation.
writeLock := sync.Mutex{}
writeToUDPAddrPort := func(b []byte, addr netip.AddrPort) (n int, err error) {
writeLock.Lock()
n, err = conn.WriteToUDPAddrPort(b, addr)
if err != nil {
log.Printf("Failed to write packet: %v", err)
}
writeLock.Unlock()
return n, err
}
var localAddr netip.AddrPort
ip, ok := netip.AddrFromSlice(config.PublicIP)
if ok {
localAddr = netip.AddrPortFrom(ip, config.Port)
}
rt := newRoutingTable(config.PeerIP, localAddr)
rtPtr := &atomic.Pointer[routingTable]{}
rtPtr.Store(&rt)
ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr)
super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey)
connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr)
hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg)
if err != nil {
log.Fatalf("Failed to create hub poller: %v", err)
}
return &Peer{
iface: iface,
ifReader: ifReader,
connReader: connReader,
hubPoller: hubPoller,
super: super,
}
}
func (p *Peer) Run() {
go p.ifReader.Run()
go p.connReader.Run()
p.super.Start()
p.hubPoller.Run()
}
func initPeerWithHub(conf Config) {
keys := generateKeys()
initURL, err := url.Parse(conf.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub URL: %v", err)
}
initURL.Path = "/peer/init/"
args := m.PeerInitArgs{
EncPubKey: keys.PubKey,
PubSignKey: keys.PubSignKey,
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(args); err != nil {
log.Fatalf("Failed to encode init args: %v", err)
}
req, err := http.NewRequest(http.MethodPost, initURL.String(), buf)
if err != nil {
log.Fatalf("Failed to construct request: %v", err)
}
req.SetBasicAuth("", conf.APIKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Fatalf("Failed to init with hub: %v", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed to read response body: %v", err)
}
peerConfig := localConfig{}
if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil {
log.Fatalf("Failed to parse configuration: %v\n%s", err, data)
}
peerConfig.PubKey = keys.PubKey
peerConfig.PrivKey = keys.PrivKey
peerConfig.PubSignKey = keys.PubSignKey
peerConfig.PrivSignKey = keys.PrivSignKey
if err := storePeerConfig(conf.NetName, peerConfig); err != nil {
log.Fatalf("Failed to store configuration: %v", err)
}
log.Print("Initialization successful.")
}

View File

@ -11,36 +11,25 @@ import (
// A test peer. // A test peer.
type P struct { type P struct {
cryptoKeys cryptoKeys
RT *atomic.Pointer[RoutingTable] RT *atomic.Pointer[routingTable]
Conn *TestUDPConn Conn *TestUDPConn
IFace *TestIFace IFace *TestIFace
ConnWriter *ConnWriter
ConnReader *ConnReader ConnReader *ConnReader
IFReader *IFReader IFReader *IFReader
Super *Supervisor
} }
func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P { func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P {
p := P{ p := P{
cryptoKeys: generateKeys(), cryptoKeys: generateKeys(),
RT: &atomic.Pointer[RoutingTable]{}, RT: &atomic.Pointer[routingTable]{},
IFace: NewTestIFace(), IFace: NewTestIFace(),
} }
rt := NewRoutingTable(ip, addr) rt := newRoutingTable(ip, addr)
p.RT.Store(&rt) p.RT.Store(&rt)
p.Conn = n.NewUDPConn(addr) p.Conn = n.NewUDPConn(addr)
p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT) //p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT)
p.IFReader = NewIFReader(p.IFace, p.ConnWriter)
/*
p.ConnReader = NewConnReader(
p.Conn.ReadFromUDPAddrPort,
p.IFace,
p.ConnWriter.Forward,
p.Super.HandleControlMsg,
p.RT)
*/
return p return p
} }

View File

@ -11,21 +11,21 @@ import (
"git.crumpington.com/lib/go/ratelimiter" "git.crumpington.com/lib/go/ratelimiter"
) )
type PeerState interface { type peerState interface {
OnPeerUpdate(*m.Peer) PeerState OnPeerUpdate(*m.Peer) peerState
OnSyn(controlMsg[PacketSyn]) PeerState OnSyn(controlMsg[packetSyn]) peerState
OnAck(controlMsg[PacketAck]) OnAck(controlMsg[packetAck])
OnProbe(controlMsg[PacketProbe]) PeerState OnProbe(controlMsg[packetProbe]) peerState
OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) OnLocalDiscovery(controlMsg[packetLocalDiscovery])
OnPingTimer() PeerState OnPingTimer() peerState
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type State struct { type pState struct {
// Output. // Output.
publish func(RemotePeer) publish func(remotePeer)
sendControlPacket func(RemotePeer, Marshaller) sendControlPacket func(remotePeer, Marshaller)
// Immutable data. // Immutable data.
localIP byte localIP byte
@ -37,7 +37,7 @@ type State struct {
// The purpose of this state machine is to manage the RemotePeer object, // The purpose of this state machine is to manage the RemotePeer object,
// publishing it as necessary. // publishing it as necessary.
staged RemotePeer // Local copy of shared data. See publish(). staged remotePeer // Local copy of shared data. See publish().
// Mutable peer data. // Mutable peer data.
peer *m.Peer peer *m.Peer
@ -47,25 +47,28 @@ type State struct {
limiter *ratelimiter.Limiter limiter *ratelimiter.Limiter
} }
func (s *State) OnPeerUpdate(peer *m.Peer) PeerState { func (s *pState) OnPeerUpdate(peer *m.Peer) peerState {
defer func() { defer func() {
// Don't defer directly otherwise s.staged will be evaluated immediately // Don't defer directly otherwise s.staged will be evaluated immediately
// and won't reflect changes made in the function. // and won't reflect changes made in the function.
s.publish(s.staged) s.publish(s.staged)
}() }()
if peer == nil {
return EnterStateDisconnected(s)
}
s.peer = peer s.peer = peer
s.staged.localIP = s.localIP s.staged.localIP = s.localIP
s.staged.IP = peer.PeerIP
s.staged.Up = false s.staged.Up = false
s.staged.Relay = false s.staged.Relay = false
s.staged.Direct = false s.staged.Direct = false
s.staged.DirectAddr = netip.AddrPort{} s.staged.DirectAddr = netip.AddrPort{}
s.staged.PubSignKey = nil
s.staged.ControlCipher = nil
s.staged.DataCipher = nil
if peer == nil {
return enterStateDisconnected(s)
}
s.staged.IP = peer.PeerIP
s.staged.PubSignKey = peer.PubSignKey s.staged.PubSignKey = peer.PubSignKey
s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey) s.staged.ControlCipher = newControlCipher(s.privKey, peer.PubKey)
s.staged.DataCipher = newDataCipher() s.staged.DataCipher = newDataCipher()
@ -76,30 +79,32 @@ func (s *State) OnPeerUpdate(peer *m.Peer) PeerState {
s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port) s.staged.DirectAddr = netip.AddrPortFrom(ip, peer.Port)
if s.localAddr.IsValid() && s.localIP < s.remoteIP { if s.localAddr.IsValid() && s.localIP < s.remoteIP {
return EnterStateServer(s) return enterStateServer(s)
} }
return EnterStateClientDirect(s) return enterStateClientDirect(s)
} }
if s.localAddr.IsValid() { if s.localAddr.IsValid() {
s.staged.Direct = true s.staged.Direct = true
return EnterStateServer(s) return enterStateServer(s)
} }
if s.localIP < s.remoteIP { if s.localIP < s.remoteIP {
return EnterStateServer(s) return enterStateServer(s)
} }
return EnterStateClientRelayed(s) return enterStateClientRelayed(s)
} }
func (s *State) logf(format string, args ...any) { func (s *pState) logf(format string, args ...any) {
b := strings.Builder{} b := strings.Builder{}
name := "" name := ""
if s.peer != nil { if s.peer != nil {
name = s.peer.Name name = s.peer.Name
} }
b.WriteString(fmt.Sprintf("%03d", s.remoteIP))
b.WriteString(fmt.Sprintf("%30s: ", name)) b.WriteString(fmt.Sprintf("%30s: ", name))
if s.staged.Direct { if s.staged.Direct {
@ -119,7 +124,7 @@ func (s *State) logf(format string, args ...any) {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) { func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) {
if !addr.IsValid() { if !addr.IsValid() {
return return
} }
@ -129,7 +134,7 @@ func (s *State) SendTo(pkt Marshaller, addr netip.AddrPort) {
s.Send(route, pkt) s.Send(route, pkt)
} }
func (s *State) Send(peer RemotePeer, pkt Marshaller) { func (s *pState) Send(peer remotePeer, pkt Marshaller) {
if err := s.limiter.Limit(); err != nil { if err := s.limiter.Limit(); err != nil {
s.logf("Rate limited.") s.logf("Rate limited.")
return return
@ -139,42 +144,32 @@ func (s *State) Send(peer RemotePeer, pkt Marshaller) {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type StateDisconnected struct{ *State } type stateDisconnected struct{ *pState }
func EnterStateDisconnected(s *State) PeerState { func enterStateDisconnected(s *pState) peerState {
s.logf("==> Disconnected") return &stateDisconnected{pState: s}
s.peer = nil
s.staged.Up = false
s.staged.Relay = false
s.staged.Direct = false
s.staged.DirectAddr = netip.AddrPort{}
s.staged.PubSignKey = nil
s.staged.ControlCipher = nil
s.staged.DataCipher = nil
s.publish(s.staged)
return &StateDisconnected{State: s}
} }
func (s *StateDisconnected) OnSyn(controlMsg[PacketSyn]) PeerState { return nil } func (s *stateDisconnected) OnSyn(controlMsg[packetSyn]) peerState { return s }
func (s *StateDisconnected) OnAck(controlMsg[PacketAck]) {} func (s *stateDisconnected) OnAck(controlMsg[packetAck]) {}
func (s *StateDisconnected) OnProbe(controlMsg[PacketProbe]) PeerState { return nil } func (s *stateDisconnected) OnProbe(controlMsg[packetProbe]) peerState { return s }
func (s *StateDisconnected) OnLocalDiscovery(controlMsg[PacketLocalDiscovery]) {} func (s *stateDisconnected) OnLocalDiscovery(controlMsg[packetLocalDiscovery]) {}
func (s *StateDisconnected) OnPingTimer() PeerState { return nil } func (s *stateDisconnected) OnPingTimer() peerState { return s }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type StateServer struct { type stateServer struct {
*StateDisconnected *stateDisconnected
lastSeen time.Time lastSeen time.Time
synTraceID uint64 synTraceID uint64
} }
func EnterStateServer(s *State) PeerState { func enterStateServer(s *pState) peerState {
s.logf("==> Server") s.logf("==> Server")
return &StateServer{StateDisconnected: &StateDisconnected{State: s}} return &stateServer{stateDisconnected: &stateDisconnected{pState: s}}
} }
func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState { func (s *stateServer) OnSyn(msg controlMsg[packetSyn]) peerState {
s.lastSeen = time.Now() s.lastSeen = time.Now()
p := msg.Packet p := msg.Packet
@ -194,7 +189,7 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState {
} }
// Always respond. // Always respond.
ack := PacketAck{ ack := packetAck{
TraceID: p.TraceID, TraceID: p.TraceID,
ToAddr: s.staged.DirectAddr, ToAddr: s.staged.DirectAddr,
PossibleAddrs: s.pubAddrs.Get(), PossibleAddrs: s.pubAddrs.Get(),
@ -202,55 +197,55 @@ func (s *StateServer) OnSyn(msg controlMsg[PacketSyn]) PeerState {
s.Send(s.staged, ack) s.Send(s.staged, ack)
if p.Direct { if p.Direct {
return nil return s
} }
for _, addr := range msg.Packet.PossibleAddrs { for _, addr := range msg.Packet.PossibleAddrs {
if !addr.IsValid() { if !addr.IsValid() {
break break
} }
s.SendTo(PacketProbe{TraceID: newTraceID()}, addr) s.SendTo(packetProbe{TraceID: newTraceID()}, addr)
} }
return nil return s
} }
func (s *StateServer) OnProbe(msg controlMsg[PacketProbe]) PeerState { func (s *stateServer) OnProbe(msg controlMsg[packetProbe]) peerState {
if msg.SrcAddr.IsValid() { if msg.SrcAddr.IsValid() {
s.SendTo(PacketProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr) s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
} }
return nil return s
} }
func (s *StateServer) OnPingTimer() PeerState { func (s *stateServer) OnPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up { if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
s.staged.Up = false s.staged.Up = false
s.publish(s.staged) s.publish(s.staged)
s.logf("Timeout.") s.logf("Timeout.")
} }
return nil return s
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type StateClientDirect struct { type stateClientDirect struct {
*StateDisconnected *stateDisconnected
lastSeen time.Time lastSeen time.Time
syn PacketSyn syn packetSyn
} }
func EnterStateClientDirect(s *State) PeerState { func enterStateClientDirect(s *pState) peerState {
s.logf("==> ClientDirect") s.logf("==> ClientDirect")
return NewStateClientDirect(s) return newStateClientDirect(s)
} }
func NewStateClientDirect(s *State) *StateClientDirect { func newStateClientDirect(s *pState) *stateClientDirect {
state := &StateClientDirect{ state := &stateClientDirect{
StateDisconnected: &StateDisconnected{s}, stateDisconnected: &stateDisconnected{s},
lastSeen: time.Now(), // Avoid immediate timeout. lastSeen: time.Now(), // Avoid immediate timeout.
} }
state.syn = PacketSyn{ state.syn = packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(), SharedKey: s.staged.DataCipher.Key(),
Direct: s.staged.Direct, Direct: s.staged.Direct,
@ -260,7 +255,7 @@ func NewStateClientDirect(s *State) *StateClientDirect {
return state return state
} }
func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) { func (s *stateClientDirect) OnAck(msg controlMsg[packetAck]) {
if msg.Packet.TraceID != s.syn.TraceID { if msg.Packet.TraceID != s.syn.TraceID {
return return
} }
@ -276,7 +271,14 @@ func (s *StateClientDirect) OnAck(msg controlMsg[PacketAck]) {
s.pubAddrs.Store(msg.Packet.ToAddr) s.pubAddrs.Store(msg.Packet.ToAddr)
} }
func (s *StateClientDirect) OnPingTimer() PeerState { func (s *stateClientDirect) OnPingTimer() peerState {
if next := s.onPingTimer(); next != nil {
return next
}
return s
}
func (s *stateClientDirect) onPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval { if time.Since(s.lastSeen) > timeoutInterval {
if s.staged.Up { if s.staged.Up {
s.staged.Up = false s.staged.Up = false
@ -292,47 +294,47 @@ func (s *StateClientDirect) OnPingTimer() PeerState {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type StateClientRelayed struct { type stateClientRelayed struct {
*StateClientDirect *stateClientDirect
ack PacketAck ack packetAck
probes map[uint64]netip.AddrPort probes map[uint64]netip.AddrPort
localDiscoveryAddr netip.AddrPort localDiscoveryAddr netip.AddrPort
} }
func EnterStateClientRelayed(s *State) PeerState { func enterStateClientRelayed(s *pState) peerState {
s.logf("==> ClientRelayed") s.logf("==> ClientRelayed")
return &StateClientRelayed{ return &stateClientRelayed{
StateClientDirect: NewStateClientDirect(s), stateClientDirect: newStateClientDirect(s),
probes: map[uint64]netip.AddrPort{}, probes: map[uint64]netip.AddrPort{},
} }
} }
func (s *StateClientRelayed) OnAck(msg controlMsg[PacketAck]) { func (s *stateClientRelayed) OnAck(msg controlMsg[packetAck]) {
s.ack = msg.Packet s.ack = msg.Packet
s.StateClientDirect.OnAck(msg) s.stateClientDirect.OnAck(msg)
} }
func (s *StateClientRelayed) OnProbe(msg controlMsg[PacketProbe]) PeerState { func (s *stateClientRelayed) OnProbe(msg controlMsg[packetProbe]) peerState {
addr, ok := s.probes[msg.Packet.TraceID] addr, ok := s.probes[msg.Packet.TraceID]
if !ok { if !ok {
return nil return s
} }
s.staged.DirectAddr = addr s.staged.DirectAddr = addr
s.staged.Direct = true s.staged.Direct = true
s.publish(s.staged) s.publish(s.staged)
return EnterStateClientDirect(s.StateClientDirect.State) return enterStateClientDirect(s.stateClientDirect.pState)
} }
func (s *StateClientRelayed) OnLocalDiscovery(msg controlMsg[PacketLocalDiscovery]) { func (s *stateClientRelayed) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
// The source port will be the multicast port, so we'll have to // The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port. // construct the correct address using the peer's listed port.
s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port) s.localDiscoveryAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
} }
func (s *StateClientRelayed) OnPingTimer() PeerState { func (s *stateClientRelayed) OnPingTimer() peerState {
if nextState := s.StateClientDirect.OnPingTimer(); nextState != nil { if next := s.stateClientDirect.onPingTimer(); next != nil {
return nextState return next
} }
clear(s.probes) clear(s.probes)
@ -348,11 +350,11 @@ func (s *StateClientRelayed) OnPingTimer() PeerState {
s.localDiscoveryAddr = netip.AddrPort{} s.localDiscoveryAddr = netip.AddrPort{}
} }
return nil return s
} }
func (s *StateClientRelayed) sendProbeTo(addr netip.AddrPort) { func (s *stateClientRelayed) sendProbeTo(addr netip.AddrPort) {
probe := PacketProbe{TraceID: newTraceID()} probe := packetProbe{TraceID: newTraceID()}
s.probes[probe.TraceID] = addr s.probes[probe.TraceID] = addr
s.SendTo(probe, addr) s.SendTo(probe, addr)
} }

View File

@ -12,13 +12,13 @@ import (
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type PeerStateControlMsg struct { type PeerStateControlMsg struct {
Peer RemotePeer Peer remotePeer
Packet any Packet any
} }
type PeerStateTestHarness struct { type PeerStateTestHarness struct {
State PeerState State peerState
Published RemotePeer Published remotePeer
Sent []PeerStateControlMsg Sent []PeerStateControlMsg
} }
@ -27,11 +27,11 @@ func NewPeerStateTestHarness() *PeerStateTestHarness {
keys := generateKeys() keys := generateKeys()
state := &State{ state := &pState{
publish: func(rp RemotePeer) { publish: func(rp remotePeer) {
h.Published = rp h.Published = rp
}, },
sendControlPacket: func(rp RemotePeer, pkt Marshaller) { sendControlPacket: func(rp remotePeer, pkt Marshaller) {
h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt}) h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt})
}, },
localIP: 2, localIP: 2,
@ -44,7 +44,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness {
}), }),
} }
h.State = EnterStateDisconnected(state) h.State = enterStateDisconnected(state)
return h return h
} }
@ -54,13 +54,13 @@ func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) {
} }
} }
func (h *PeerStateTestHarness) OnSyn(msg controlMsg[PacketSyn]) { func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) {
if s := h.State.OnSyn(msg); s != nil { if s := h.State.OnSyn(msg); s != nil {
h.State = s h.State = s
} }
} }
func (h *PeerStateTestHarness) OnProbe(msg controlMsg[PacketProbe]) { func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) {
if s := h.State.OnProbe(msg); s != nil { if s := h.State.OnProbe(msg); s != nil {
h.State = s h.State = s
} }
@ -72,10 +72,10 @@ func (h *PeerStateTestHarness) OnPingTimer() {
} }
} }
func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer { func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer {
keys := generateKeys() keys := generateKeys()
state := h.State.(*StateDisconnected) state := h.State.(*stateDisconnected)
state.localAddr = addrPort4(1, 1, 1, 2, 200) state.localAddr = addrPort4(1, 1, 1, 2, 200)
peer := &m.Peer{ peer := &m.Peer{
@ -88,10 +88,10 @@ func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *StateServer {
h.PeerUpdate(peer) h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
return assertType[*StateServer](t, h.State) return assertType[*stateServer](t, h.State)
} }
func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer { func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer {
keys := generateKeys() keys := generateKeys()
peer := &m.Peer{ peer := &m.Peer{
PeerIP: 3, PeerIP: 3,
@ -102,10 +102,10 @@ func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *StateServer {
h.PeerUpdate(peer) h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
return assertType[*StateServer](t, h.State) return assertType[*stateServer](t, h.State)
} }
func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDirect { func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClientDirect {
keys := generateKeys() keys := generateKeys()
peer := &m.Peer{ peer := &m.Peer{
PeerIP: 3, PeerIP: 3,
@ -117,13 +117,13 @@ func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *StateClientDire
h.PeerUpdate(peer) h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
return assertType[*StateClientDirect](t, h.State) return assertType[*stateClientDirect](t, h.State)
} }
func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRelayed { func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClientRelayed {
keys := generateKeys() keys := generateKeys()
state := h.State.(*StateDisconnected) state := h.State.(*stateDisconnected)
state.remoteIP = 1 state.remoteIP = 1
peer := &m.Peer{ peer := &m.Peer{
@ -135,7 +135,7 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel
h.PeerUpdate(peer) h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
return assertType[*StateClientRelayed](t, h.State) return assertType[*stateClientRelayed](t, h.State)
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -143,14 +143,14 @@ func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *StateClientRel
func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) { func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) {
h := NewPeerStateTestHarness() h := NewPeerStateTestHarness()
h.PeerUpdate(nil) h.PeerUpdate(nil)
assertType[*StateDisconnected](t, h.State) assertType[*stateDisconnected](t, h.State)
} }
func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) { func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) {
keys := generateKeys() keys := generateKeys()
h := NewPeerStateTestHarness() h := NewPeerStateTestHarness()
state := h.State.(*StateDisconnected) state := h.State.(*stateDisconnected)
state.localAddr = addrPort4(1, 1, 1, 2, 200) state.localAddr = addrPort4(1, 1, 1, 2, 200)
peer := &m.Peer{ peer := &m.Peer{
@ -162,7 +162,7 @@ func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) {
h.PeerUpdate(peer) h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
assertType[*StateServer](t, h.State) assertType[*stateServer](t, h.State)
} }
func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) { func TestPeerState_OnPeerUpdate_serverDirect(t *testing.T) {
@ -191,10 +191,10 @@ func TestStateServer_directSyn(t *testing.T) {
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
synMsg := controlMsg[PacketSyn]{ synMsg := controlMsg[packetSyn]{
SrcIP: 3, SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300), SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: PacketSyn{ Packet: packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(), //SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1, //SharedKeyType: 1,
@ -205,7 +205,7 @@ func TestStateServer_directSyn(t *testing.T) {
h.State.OnSyn(synMsg) h.State.OnSyn(synMsg)
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
ack := assertType[PacketAck](t, h.Sent[0].Packet) ack := assertType[packetAck](t, h.Sent[0].Packet)
assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) assertEqual(t, ack.TraceID, synMsg.Packet.TraceID)
assertEqual(t, h.Sent[0].Peer.IP, 3) assertEqual(t, h.Sent[0].Peer.IP, 3)
assertEqual(t, ack.PossibleAddrs[0].IsValid(), false) assertEqual(t, ack.PossibleAddrs[0].IsValid(), false)
@ -220,10 +220,10 @@ func TestStateServer_relayedSyn(t *testing.T) {
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
synMsg := controlMsg[PacketSyn]{ synMsg := controlMsg[packetSyn]{
SrcIP: 3, SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300), SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: PacketSyn{ Packet: packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(), //SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1, //SharedKeyType: 1,
@ -237,15 +237,15 @@ func TestStateServer_relayedSyn(t *testing.T) {
assertEqual(t, len(h.Sent), 3) assertEqual(t, len(h.Sent), 3)
ack := assertType[PacketAck](t, h.Sent[0].Packet) ack := assertType[packetAck](t, h.Sent[0].Packet)
assertEqual(t, ack.TraceID, synMsg.Packet.TraceID) assertEqual(t, ack.TraceID, synMsg.Packet.TraceID)
assertEqual(t, h.Sent[0].Peer.IP, 3) assertEqual(t, h.Sent[0].Peer.IP, 3)
assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234)) assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234))
assertEqual(t, ack.PossibleAddrs[1].IsValid(), false) assertEqual(t, ack.PossibleAddrs[1].IsValid(), false)
assertEqual(t, h.Published.Up, true) assertEqual(t, h.Published.Up, true)
assertType[PacketProbe](t, h.Sent[1].Packet) assertType[packetProbe](t, h.Sent[1].Packet)
assertType[PacketProbe](t, h.Sent[2].Packet) assertType[packetProbe](t, h.Sent[2].Packet)
assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300))
assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300)) assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300))
} }
@ -255,17 +255,17 @@ func TestStateServer_onProbe(t *testing.T) {
h.ConfigServer_Relayed(t) h.ConfigServer_Relayed(t)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
probeMsg := controlMsg[PacketProbe]{ probeMsg := controlMsg[packetProbe]{
SrcIP: 3, SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300), SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: PacketProbe{TraceID: newTraceID()}, Packet: packetProbe{TraceID: newTraceID()},
} }
h.State.OnProbe(probeMsg) h.State.OnProbe(probeMsg)
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
probe := assertType[PacketProbe](t, h.Sent[0].Packet) probe := assertType[packetProbe](t, h.Sent[0].Packet)
assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID) assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID)
assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300)) assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300))
} }
@ -274,10 +274,10 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) {
h := NewPeerStateTestHarness() h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t) h.ConfigServer_Relayed(t)
synMsg := controlMsg[PacketSyn]{ synMsg := controlMsg[packetSyn]{
SrcIP: 3, SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300), SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: PacketSyn{ Packet: packetSyn{
TraceID: newTraceID(), TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(), //SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1, //SharedKeyType: 1,
@ -294,7 +294,7 @@ func TestStateServer_OnPingTimer_timeout(t *testing.T) {
assertEqual(t, h.Published.Up, true) assertEqual(t, h.Published.Up, true)
// Advance the time, then ping. // Advance the time, then ping.
state := assertType[*StateServer](t, h.State) state := assertType[*stateServer](t, h.State)
state.lastSeen = time.Now().Add(-timeoutInterval - time.Second) state.lastSeen = time.Now().Add(-timeoutInterval - time.Second)
h.OnPingTimer() h.OnPingTimer()
@ -309,10 +309,10 @@ func TestStateClientDirect_OnAck(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{ ack := controlMsg[packetAck]{
Packet: PacketAck{TraceID: syn.TraceID}, Packet: packetAck{TraceID: syn.TraceID},
} }
h.State.OnAck(ack) h.State.OnAck(ack)
assertEqual(t, h.Published.Up, true) assertEqual(t, h.Published.Up, true)
@ -326,10 +326,10 @@ func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{ ack := controlMsg[packetAck]{
Packet: PacketAck{TraceID: syn.TraceID + 1}, Packet: packetAck{TraceID: syn.TraceID + 1},
} }
h.State.OnAck(ack) h.State.OnAck(ack)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
@ -341,15 +341,15 @@ func TestStateClientDirect_OnPingTimer(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
assertType[PacketSyn](t, h.Sent[0].Packet) assertType[packetSyn](t, h.Sent[0].Packet)
h.OnPingTimer() h.OnPingTimer()
// On ping timer, another syn should be sent. Additionally, we should remain // On ping timer, another syn should be sent. Additionally, we should remain
// in the same state. // in the same state.
assertEqual(t, len(h.Sent), 2) assertEqual(t, len(h.Sent), 2)
assertType[PacketSyn](t, h.Sent[1].Packet) assertType[packetSyn](t, h.Sent[1].Packet)
assertType[*StateClientDirect](t, h.State) assertType[*stateClientDirect](t, h.State)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
} }
@ -361,15 +361,15 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{ ack := controlMsg[packetAck]{
Packet: PacketAck{TraceID: syn.TraceID}, Packet: packetAck{TraceID: syn.TraceID},
} }
h.State.OnAck(ack) h.State.OnAck(ack)
assertEqual(t, h.Published.Up, true) assertEqual(t, h.Published.Up, true)
state := assertType[*StateClientDirect](t, h.State) state := assertType[*stateClientDirect](t, h.State)
state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second))
h.OnPingTimer() h.OnPingTimer()
@ -377,8 +377,8 @@ func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) {
// On ping timer, we should timeout, causing the client to reset. Another SYN // On ping timer, we should timeout, causing the client to reset. Another SYN
// will be sent when re-entering the state, but the connection should be down. // will be sent when re-entering the state, but the connection should be down.
assertEqual(t, len(h.Sent), 2) assertEqual(t, len(h.Sent), 2)
assertType[PacketSyn](t, h.Sent[1].Packet) assertType[packetSyn](t, h.Sent[1].Packet)
assertType[*StateClientDirect](t, h.State) assertType[*stateClientDirect](t, h.State)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
} }
@ -390,10 +390,10 @@ func TestStateClientRelayed_OnAck(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{ ack := controlMsg[packetAck]{
Packet: PacketAck{TraceID: syn.TraceID}, Packet: packetAck{TraceID: syn.TraceID},
} }
h.State.OnAck(ack) h.State.OnAck(ack)
assertEqual(t, h.Published.Up, true) assertEqual(t, h.Published.Up, true)
@ -423,9 +423,9 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}}
ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300)
ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300)
@ -433,7 +433,7 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) {
// Add a local discovery address. Note that the port will be configured port // Add a local discovery address. Note that the port will be configured port
// and no the one provided here. // and no the one provided here.
h.State.OnLocalDiscovery(controlMsg[PacketLocalDiscovery]{ h.State.OnLocalDiscovery(controlMsg[packetLocalDiscovery]{
SrcIP: 3, SrcIP: 3,
SrcAddr: addrPort4(2, 2, 2, 3, 300), SrcAddr: addrPort4(2, 2, 2, 3, 300),
}) })
@ -441,10 +441,10 @@ func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) {
// We should see one SYN and three probe packets. // We should see one SYN and three probe packets.
h.OnPingTimer() h.OnPingTimer()
assertEqual(t, len(h.Sent), 5) assertEqual(t, len(h.Sent), 5)
assertType[PacketSyn](t, h.Sent[1].Packet) assertType[packetSyn](t, h.Sent[1].Packet)
assertType[PacketProbe](t, h.Sent[2].Packet) assertType[packetProbe](t, h.Sent[2].Packet)
assertType[PacketProbe](t, h.Sent[3].Packet) assertType[packetProbe](t, h.Sent[3].Packet)
assertType[PacketProbe](t, h.Sent[4].Packet) assertType[packetProbe](t, h.Sent[4].Packet)
assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300)) assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300))
assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300)) assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300))
@ -457,15 +457,15 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) {
// On entering the state, a SYN should have been sent. // On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1) assertEqual(t, len(h.Sent), 1)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{ ack := controlMsg[packetAck]{
Packet: PacketAck{TraceID: syn.TraceID}, Packet: packetAck{TraceID: syn.TraceID},
} }
h.State.OnAck(ack) h.State.OnAck(ack)
assertEqual(t, h.Published.Up, true) assertEqual(t, h.Published.Up, true)
state := assertType[*StateClientRelayed](t, h.State) state := assertType[*stateClientRelayed](t, h.State)
state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second)) state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second))
h.OnPingTimer() h.OnPingTimer()
@ -473,8 +473,8 @@ func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) {
// On ping timer, we should timeout, causing the client to reset. Another SYN // On ping timer, we should timeout, causing the client to reset. Another SYN
// will be sent when re-entering the state, but the connection should be down. // will be sent when re-entering the state, but the connection should be down.
assertEqual(t, len(h.Sent), 2) assertEqual(t, len(h.Sent), 2)
assertType[PacketSyn](t, h.Sent[1].Packet) assertType[packetSyn](t, h.Sent[1].Packet)
assertType[*StateClientRelayed](t, h.State) assertType[*stateClientRelayed](t, h.State)
assertEqual(t, h.Published.Up, false) assertEqual(t, h.Published.Up, false)
} }
@ -482,28 +482,28 @@ func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) {
h := NewPeerStateTestHarness() h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t) h.ConfigClientRelayed(t)
h.OnProbe(controlMsg[PacketProbe]{ h.OnProbe(controlMsg[packetProbe]{
Packet: PacketProbe{TraceID: newTraceID()}, Packet: packetProbe{TraceID: newTraceID()},
}) })
assertType[*StateClientRelayed](t, h.State) assertType[*stateClientRelayed](t, h.State)
} }
func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) { func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) {
h := NewPeerStateTestHarness() h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t) h.ConfigClientRelayed(t)
syn := assertType[PacketSyn](t, h.Sent[0].Packet) syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[PacketAck]{Packet: PacketAck{TraceID: syn.TraceID}} ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}}
ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300) ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300)
ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300) ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300)
h.State.OnAck(ack) h.State.OnAck(ack)
h.OnPingTimer() h.OnPingTimer()
probe := assertType[PacketProbe](t, h.Sent[2].Packet) probe := assertType[packetProbe](t, h.Sent[2].Packet)
h.OnProbe(controlMsg[PacketProbe]{Packet: probe}) h.OnProbe(controlMsg[packetProbe]{Packet: probe})
assertType[*StateClientDirect](t, h.State) assertType[*stateClientDirect](t, h.State)
} }

172
peer/peersuper.go Normal file
View File

@ -0,0 +1,172 @@
package peer
import (
"log"
"math/rand"
"net/netip"
"sync"
"sync/atomic"
"time"
"git.crumpington.com/lib/go/ratelimiter"
)
type Super struct {
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
staged routingTable
shared *atomic.Pointer[routingTable]
peers [256]*PeerSuper
lock sync.Mutex
buf1 []byte
buf2 []byte
}
func NewSuper(
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
rt *atomic.Pointer[routingTable],
privKey []byte,
) *Super {
routes := rt.Load()
s := &Super{
writeToUDPAddrPort: writeToUDPAddrPort,
staged: *routes,
shared: rt,
buf1: newBuf(),
buf2: newBuf(),
}
pubAddrs := newPubAddrStore(routes.LocalAddr)
for i := range s.peers {
state := &pState{
publish: s.publish,
sendControlPacket: s.send,
localIP: routes.LocalIP,
remoteIP: byte(i),
privKey: privKey,
localAddr: routes.LocalAddr,
pubAddrs: pubAddrs,
staged: routes.Peers[i],
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 20 * time.Millisecond,
MaxWaitCount: 1,
}),
}
s.peers[i] = NewPeerSuper(state)
}
return s
}
func (s *Super) Start() {
for i := range s.peers {
go s.peers[i].Run()
}
}
func (s *Super) HandleControlMsg(destIP byte, msg any) {
s.peers[destIP].HandleControlMsg(msg)
}
func (s *Super) send(peer remotePeer, pkt Marshaller) {
s.lock.Lock()
defer s.lock.Unlock()
enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2)
if peer.Direct {
s.writeToUDPAddrPort(enc, peer.DirectAddr)
return
}
relay, ok := s.staged.GetRelay()
if !ok {
return
}
enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1)
s.writeToUDPAddrPort(enc, relay.DirectAddr)
}
func (s *Super) publish(rp remotePeer) {
s.lock.Lock()
defer s.lock.Unlock()
s.staged.Peers[rp.IP] = rp
s.ensureRelay()
copy := s.staged
s.shared.Store(&copy)
}
func (s *Super) ensureRelay() {
if _, ok := s.staged.GetRelay(); ok {
return
}
// TODO: Random selection?
for _, peer := range s.staged.Peers {
if peer.Up && peer.Direct && peer.Relay {
s.staged.RelayIP = peer.IP
return
}
}
}
// ----------------------------------------------------------------------------
type PeerSuper struct {
messages chan any
state peerState
}
func NewPeerSuper(state *pState) *PeerSuper {
return &PeerSuper{
messages: make(chan any, 8),
state: state.OnPeerUpdate(nil),
}
}
func (s *PeerSuper) HandleControlMsg(msg any) {
select {
case s.messages <- msg:
default:
}
}
func (s *PeerSuper) Run() {
go func() {
// Randomize ping timers.
time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond)
for range time.Tick(pingInterval) {
s.messages <- pingTimerMsg{}
}
}()
for rawMsg := range s.messages {
switch msg := rawMsg.(type) {
case peerUpdateMsg:
s.state = s.state.OnPeerUpdate(msg.Peer)
case controlMsg[packetSyn]:
s.state = s.state.OnSyn(msg)
case controlMsg[packetAck]:
s.state.OnAck(msg)
case controlMsg[packetProbe]:
s.state = s.state.OnProbe(msg)
case controlMsg[packetLocalDiscovery]:
s.state.OnLocalDiscovery(msg)
case pingTimerMsg:
s.state = s.state.OnPingTimer()
default:
log.Printf("WARNING: unknown message type: %+v", msg)
}
}
}

View File

@ -7,9 +7,9 @@ import (
) )
// TODO: Remove // TODO: Remove
func NewRemotePeer(ip byte) *RemotePeer { func NewRemotePeer(ip byte) *remotePeer {
counter := uint64(time.Now().Unix()<<30 + 1) counter := uint64(time.Now().Unix()<<30 + 1)
return &RemotePeer{ return &remotePeer{
IP: ip, IP: ip,
counter: &counter, counter: &counter,
dupCheck: newDupCheck(0), dupCheck: newDupCheck(0),
@ -18,7 +18,7 @@ func NewRemotePeer(ip byte) *RemotePeer {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type RemotePeer struct { type remotePeer struct {
localIP byte localIP byte
IP byte // VPN IP of peer (last byte). IP byte // VPN IP of peer (last byte).
Up bool // True if data can be sent on the peer. Up bool // True if data can be sent on the peer.
@ -33,7 +33,7 @@ type RemotePeer struct {
dupCheck *dupCheck // For receiving from. Not safe for concurrent use. dupCheck *dupCheck // For receiving from. Not safe for concurrent use.
} }
func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte { func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte {
h := header{ h := header{
StreamID: dataStreamID, StreamID: dataStreamID,
Counter: atomic.AddUint64(p.counter, 1), Counter: atomic.AddUint64(p.counter, 1),
@ -44,7 +44,7 @@ func (p RemotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte {
} }
// Decrypts and de-dups incoming data packets. // Decrypts and de-dups incoming data packets.
func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) { func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) {
dec, ok := p.DataCipher.Decrypt(enc, out) dec, ok := p.DataCipher.Decrypt(enc, out)
if !ok { if !ok {
return nil, errDecryptionFailed return nil, errDecryptionFailed
@ -58,21 +58,22 @@ func (p RemotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error)
} }
// Peer must have a ControlCipher. // Peer must have a ControlCipher.
func (p RemotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte { func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte {
tmp = pkt.Marshal(tmp)
h := header{ h := header{
StreamID: controlStreamID, StreamID: controlStreamID,
Counter: atomic.AddUint64(p.counter, 1), Counter: atomic.AddUint64(p.counter, 1),
SourceIP: p.localIP, SourceIP: p.localIP,
DestIP: p.IP, DestIP: p.IP,
} }
tmp = pkt.Marshal(tmp)
return p.ControlCipher.Encrypt(h, tmp, out) return p.ControlCipher.Encrypt(h, tmp, out)
} }
// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher. // Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher.
// //
// This function also drops packets with duplicate sequence numbers. // This function also drops packets with duplicate sequence numbers.
func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) { func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) {
out, ok := p.ControlCipher.Decrypt(enc, tmp) out, ok := p.ControlCipher.Decrypt(enc, tmp)
if !ok { if !ok {
return nil, errDecryptionFailed return nil, errDecryptionFailed
@ -92,7 +93,7 @@ func (p RemotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc,
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type RoutingTable struct { type routingTable struct {
// The LocalIP is the configured IP address of the local peer on the VPN. // The LocalIP is the configured IP address of the local peer on the VPN.
// //
// This value is constant. // This value is constant.
@ -106,21 +107,21 @@ type RoutingTable struct {
LocalAddr netip.AddrPort LocalAddr netip.AddrPort
// The remote peer configurations. These are updated by // The remote peer configurations. These are updated by
Peers [256]RemotePeer Peers [256]remotePeer
// The current relay's VPN IP address, or zero if no relay is available. // The current relay's VPN IP address, or zero if no relay is available.
RelayIP byte RelayIP byte
} }
func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable { func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable {
rt := RoutingTable{ rt := routingTable{
LocalIP: localIP, LocalIP: localIP,
LocalAddr: localAddr, LocalAddr: localAddr,
} }
for i := range rt.Peers { for i := range rt.Peers {
counter := uint64(time.Now().Unix()<<30 + 1) counter := uint64(time.Now().Unix()<<30 + 1)
rt.Peers[i] = RemotePeer{ rt.Peers[i] = remotePeer{
localIP: localIP, localIP: localIP,
IP: byte(i), IP: byte(i),
counter: &counter, counter: &counter,
@ -131,7 +132,7 @@ func NewRoutingTable(localIP byte, localAddr netip.AddrPort) RoutingTable {
return rt return rt
} }
func (rt *RoutingTable) GetRelay() (RemotePeer, bool) { func (rt *routingTable) GetRelay() (remotePeer, bool) {
relay := rt.Peers[rt.RelayIP] relay := rt.Peers[rt.RelayIP]
return relay, relay.Up && relay.Direct return relay, relay.Up && relay.Direct
} }

View File

@ -74,7 +74,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) {
peer2 := p1.RT.Load().Peers[2] peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1] peer1 := p2.RT.Load().Peers[1]
orig := PacketProbe{TraceID: newTraceID()} orig := packetProbe{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())
@ -88,7 +88,7 @@ func TestRemotePeer_DecryptControlPacket(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
dec, ok := ctrlMsg.(controlMsg[PacketProbe]) dec, ok := ctrlMsg.(controlMsg[packetProbe])
if !ok { if !ok {
t.Fatal(ctrlMsg) t.Fatal(ctrlMsg)
} }
@ -108,7 +108,7 @@ func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) {
peer2 := p1.RT.Load().Peers[2] peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1] peer1 := p2.RT.Load().Peers[1]
orig := PacketProbe{TraceID: newTraceID()} orig := packetProbe{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())
@ -131,7 +131,7 @@ func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) {
peer2 := p1.RT.Load().Peers[2] peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1] peer1 := p2.RT.Load().Peers[1]
orig := PacketProbe{TraceID: newTraceID()} orig := packetProbe{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf()) enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())

View File

@ -1,103 +0,0 @@
package peer
import (
"log"
"sync/atomic"
"time"
"git.crumpington.com/lib/go/ratelimiter"
)
// ----------------------------------------------------------------------------
type Supervisor struct {
messages chan any // Incoming control messages.
peers [256]PeerState
pubAddrs *pubAddrStore
rt *atomic.Pointer[RoutingTable]
staged RoutingTable
}
func NewSupervisor(
sendControl func(RemotePeer, Marshaller),
privKey []byte,
rt *atomic.Pointer[RoutingTable],
) *Supervisor {
s := &Supervisor{
messages: make(chan any, 1024),
pubAddrs: newPubAddrStore(rt.Load().LocalAddr),
rt: rt,
}
routes := rt.Load()
for i := range s.peers {
state := &State{
publish: s.Publish,
sendControlPacket: sendControl,
localIP: routes.LocalIP,
remoteIP: byte(i),
privKey: privKey,
localAddr: routes.LocalAddr,
pubAddrs: s.pubAddrs,
staged: routes.Peers[i],
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 20 * time.Millisecond,
MaxWaitCount: 1,
}),
}
s.peers[i] = state.OnPeerUpdate(nil)
}
return s
}
func (s *Supervisor) HandleControlMsg(msg any) {
select {
case s.messages <- msg:
default:
}
}
func (s *Supervisor) Run() {
for raw := range s.messages {
switch msg := raw.(type) {
case peerUpdateMsg:
s.peers[msg.PeerIP] = s.peers[msg.PeerIP].OnPeerUpdate(msg.Peer)
case controlMsg[PacketSyn]:
if newState := s.peers[msg.SrcIP].OnSyn(msg); newState != nil {
s.peers[msg.SrcIP] = newState
}
case controlMsg[PacketAck]:
s.peers[msg.SrcIP].OnAck(msg)
case controlMsg[PacketProbe]:
if newState := s.peers[msg.SrcIP].OnProbe(msg); newState != nil {
s.peers[msg.SrcIP] = newState
}
case controlMsg[PacketLocalDiscovery]:
s.peers[msg.SrcIP].OnLocalDiscovery(msg)
case pingTimerMsg:
s.pubAddrs.Clean()
for i := range s.peers {
if newState := s.peers[i].OnPingTimer(); newState != nil {
s.peers[i] = newState
}
}
default:
log.Printf("WARNING: unknown message type: %+v", msg)
}
}
}
func (s *Supervisor) Publish(rp RemotePeer) {
s.staged.Peers[rp.IP] = rp
rt := s.staged // Copy.
s.rt.Store(&rt)
}