Compare commits

..

No commits in common. "e91cbfe957b3f7ea96e8351355435f5630d047a6" and "f4589a1031e1fdeeec385f4efbeb3c8a829f861f" have entirely different histories.

42 changed files with 2428 additions and 1235 deletions

Binary file not shown.

Binary file not shown.

View File

@ -1 +0,0 @@
package peer

View File

@ -12,7 +12,7 @@ func newControlCipher(privKey, pubKey []byte) *controlCipher {
return &controlCipher{shared} return &controlCipher{shared}
} }
func (cc *controlCipher) Encrypt(h Header, data, out []byte) []byte { func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte {
const s = controlHeaderSize const s = controlHeaderSize
out = out[:s+controlCipherOverhead+len(data)] out = out[:s+controlCipherOverhead+len(data)]
h.Marshal(out[:s]) h.Marshal(out[:s])

View File

@ -40,7 +40,7 @@ func TestControlCipher(t *testing.T) {
} }
for _, plaintext := range testCases { for _, plaintext := range testCases {
h1 := Header{ h1 := header{
StreamID: controlStreamID, StreamID: controlStreamID,
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
@ -51,7 +51,7 @@ func TestControlCipher(t *testing.T) {
encrypted = c1.Encrypt(h1, plaintext, encrypted) encrypted = c1.Encrypt(h1, plaintext, encrypted)
h2 := Header{} h2 := header{}
h2.Parse(encrypted) h2.Parse(encrypted)
if !reflect.DeepEqual(h1, h2) { if !reflect.DeepEqual(h1, h2) {
t.Fatal(h1, h2) t.Fatal(h1, h2)
@ -80,7 +80,7 @@ func TestControlCipher_ShortCiphertext(t *testing.T) {
func BenchmarkControlCipher_Encrypt(b *testing.B) { func BenchmarkControlCipher_Encrypt(b *testing.B) {
c1, _ := newControlCipherForTesting() c1, _ := newControlCipherForTesting()
h1 := Header{ h1 := header{
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,
@ -100,7 +100,7 @@ func BenchmarkControlCipher_Encrypt(b *testing.B) {
func BenchmarkControlCipher_Decrypt(b *testing.B) { func BenchmarkControlCipher_Decrypt(b *testing.B) {
c1, c2 := newControlCipherForTesting() c1, c2 := newControlCipherForTesting()
h1 := Header{ h1 := header{
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,

View File

@ -38,7 +38,7 @@ func (sc *dataCipher) Key() [32]byte {
return sc.key return sc.key
} }
func (sc *dataCipher) Encrypt(h Header, data, out []byte) []byte { func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte {
const s = dataHeaderSize const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)] out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(out[:s]) h.Marshal(out[:s])

View File

@ -22,7 +22,7 @@ func TestDataCipher(t *testing.T) {
} }
for _, plaintext := range testCases { for _, plaintext := range testCases {
h1 := Header{ h1 := header{
StreamID: dataStreamID, StreamID: dataStreamID,
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
@ -33,7 +33,7 @@ func TestDataCipher(t *testing.T) {
dc1 := newDataCipher() dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted) encrypted = dc1.Encrypt(h1, plaintext, encrypted)
h2 := Header{} h2 := header{}
h2.Parse(encrypted) h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key()) dc2 := newDataCipherFromKey(dc1.Key())
@ -67,7 +67,7 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) {
} }
for _, plaintext := range testCases { for _, plaintext := range testCases {
h1 := Header{ h1 := header{
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,
@ -99,7 +99,7 @@ func TestDataCipher_ShortCiphertext(t *testing.T) {
} }
func BenchmarkDataCipher_Encrypt(b *testing.B) { func BenchmarkDataCipher_Encrypt(b *testing.B) {
h1 := Header{ h1 := header{
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,
@ -118,7 +118,7 @@ func BenchmarkDataCipher_Encrypt(b *testing.B) {
} }
func BenchmarkDataCipher_Decrypt(b *testing.B) { func BenchmarkDataCipher_Decrypt(b *testing.B) {
h1 := Header{ h1 := header{
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,

13
peer/cipher-discovery.go Normal file
View File

@ -0,0 +1,13 @@
package peer
/*
func signData(privKey *[64]byte, h header, data, out []byte) []byte {
out = out[:headerSize]
h.Marshal(out)
return sign.Sign(out, data, privKey)
}
func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) {
return sign.Open(out[:0], signed[headerSize:], pubKey)
}
*/

View File

@ -1 +0,0 @@
package peer

191
peer/crypto_test.go Normal file
View File

@ -0,0 +1,191 @@
package peer
import (
"net/netip"
"reflect"
"testing"
)
func newRoutePairForTesting() (*remotePeer, *remotePeer) {
keys1 := generateKeys()
keys2 := generateKeys()
r1 := newRemotePeer(1)
r1.PubSignKey = keys1.PubSignKey
r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey)
r1.DataCipher = newDataCipher()
r2 := newRemotePeer(2)
r2.PubSignKey = keys2.PubSignKey
r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey)
r2.DataCipher = r1.DataCipher
return r1, r2
}
func TestDecryptControlPacket(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
tmp = make([]byte, bufferSize)
out = make([]byte, bufferSize)
)
in := packetSyn{
TraceID: newTraceID(),
SharedKey: r1.DataCipher.Key(),
Direct: true,
}
enc := r1.EncryptControlPacket(in, tmp, out)
h := parseHeader(enc)
iMsg, err := r2.DecryptControlPacket(netip.AddrPort{}, h, enc, tmp)
if err != nil {
t.Fatal(err)
}
msg, ok := iMsg.(controlMsg[packetSyn])
if !ok {
t.Fatal(ok)
}
if !reflect.DeepEqual(msg.Packet, in) {
t.Fatal(msg)
}
}
/*
func TestDecryptControlPacket_decryptionFailed(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
tmp = make([]byte, bufferSize)
out = make([]byte, bufferSize)
)
in := packetSyn{
TraceID: newTraceID(),
SharedKey: r1.DataCipher.Key(),
Direct: true,
}
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
h := parseHeader(enc)
for i := range enc {
x := bytes.Clone(enc)
x[i]++
_, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp)
if !errors.Is(err, errDecryptionFailed) {
t.Fatal(i, err)
}
}
}
func TestDecryptControlPacket_duplicate(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
tmp = make([]byte, bufferSize)
out = make([]byte, bufferSize)
)
in := packetSyn{
TraceID: newTraceID(),
SharedKey: r1.DataCipher.Key(),
Direct: true,
}
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
h := parseHeader(enc)
if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil {
t.Fatal(err)
}
_, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp)
if !errors.Is(err, errDuplicateSeqNum) {
t.Fatal(err)
}
}
func TestDecryptControlPacket_invalidPacket(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
tmp = make([]byte, bufferSize)
out = make([]byte, bufferSize)
)
in := testPacket("hello!")
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
h := parseHeader(enc)
_, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp)
if !errors.Is(err, errUnknownPacketType) {
t.Fatal(err)
}
}
func TestDecryptDataPacket(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
out = make([]byte, bufferSize)
data = make([]byte, 1024)
)
rand.Read(data)
enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out)
h := parseHeader(enc)
out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, out) {
t.Fatal(data, out)
}
}
func TestDecryptDataPacket_incorrectCipher(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
out = make([]byte, bufferSize)
data = make([]byte, 1024)
)
rand.Read(data)
enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out))
h := parseHeader(enc)
r1.DataCipher = newDataCipher()
_, err := decryptDataPacket(r1, h, enc, bytes.Clone(out))
if !errors.Is(err, errDecryptionFailed) {
t.Fatal(err)
}
}
func TestDecryptDataPacket_duplicate(t *testing.T) {
var (
r1, r2 = newRoutePairForTesting()
out = make([]byte, bufferSize)
data = make([]byte, 1024)
)
rand.Read(data)
enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out))
h := parseHeader(enc)
_, err := decryptDataPacket(r1, h, enc, bytes.Clone(out))
if err != nil {
t.Fatal(err)
}
_, err = decryptDataPacket(r1, h, enc, bytes.Clone(out))
if !errors.Is(err, errDuplicateSeqNum) {
t.Fatal(err)
}
}
*/

View File

@ -1,11 +1,12 @@
digraph d { digraph d {
ifReader -> remote; ifReader -> connWriter;
connReader -> ifWriter;
connReader -> remote; connReader -> connWriter;
mcReader -> remote; connReader -> supervisor;
remote -> connWriter; mcReader -> supervisor;
remote -> ifWriter; supervisor -> connWriter;
hubPoller -> remote; supervisor -> mcWriter;
hubPoller -> supervisor;
connWriter [shape="box"]; connWriter [shape="box"];
mcWriter [shape="box"]; mcWriter [shape="box"];

View File

@ -51,7 +51,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
} }
// Clear if necessary. // Clear if necessary.
for range delta { for i := 0; i < int(delta); i++ {
dc.put(false) dc.put(false)
} }

View File

@ -3,6 +3,8 @@ package peer
import "errors" import "errors"
var ( var (
errDecryptionFailed = errors.New("decryption failed")
errDuplicateSeqNum = errors.New("duplicate sequence number")
errMalformedPacket = errors.New("malformed packet") errMalformedPacket = errors.New("malformed packet")
errUnknownPacketType = errors.New("unknown packet type") errUnknownPacketType = errors.New("unknown packet type")
) )

View File

@ -1,18 +1,15 @@
package peer package peer
import ( import (
"io"
"net" "net"
"net/netip" "net/netip"
"sync"
"sync/atomic"
"time" "time"
) )
const ( const (
version = 1 version = 1
bufferSize = 8192 // Enough for data packets and encryption buffers. bufferSize = 1536
if_mtu = 1200 if_mtu = 1200
if_queue_len = 2048 if_queue_len = 2048
@ -31,148 +28,10 @@ var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
netip.AddrFrom4([4]byte{224, 0, 0, 157}), netip.AddrFrom4([4]byte{224, 0, 0, 157}),
4560)) 4560))
<<<<<<< HEAD func newBuf() []byte {
// ---------------------------------------------------------------------------- return make([]byte, bufferSize)
type Globals struct {
LocalConfig // Embed, immutable.
// The number of startups
StartupCount uint16
// Local public address (if available). Immutable.
LocalAddr netip.AddrPort
// True if local public address is valid. Immutable.
LocalAddrValid bool
// All remote peers by VPN IP.
RemotePeers [256]*atomic.Pointer[Remote]
// Discovered public addresses.
PubAddrs *pubAddrStore
// Attempts to ensure that we have a relay available.
RelayHandler *relayHandler
// Send UDP - Global function to write UDP packets.
SendUDP func(b []byte, addr netip.AddrPort) (n int, err error)
// Global TUN interface.
IFace io.ReadWriteCloser
// For trace ID.
NewTraceID func() uint64
} }
func NewGlobals(
localConfig LocalConfig,
startupCount startupCount,
localAddr netip.AddrPort,
conn *net.UDPConn,
iface io.ReadWriteCloser,
) (g Globals) {
g.LocalConfig = localConfig
g.StartupCount = startupCount.Count
g.LocalAddr = localAddr
g.LocalAddrValid = localAddr.IsValid()
g.PubAddrs = newPubAddrStore(localAddr)
g.RelayHandler = newRelayHandler()
// Use a lock here avoids starvation, at least on my Linux machine.
sendLock := sync.Mutex{}
g.SendUDP = func(b []byte, addr netip.AddrPort) (int, error) {
sendLock.Lock()
n, err := conn.WriteToUDPAddrPort(b, addr)
sendLock.Unlock()
return n, err
}
g.IFace = iface
traceID := (uint64(g.StartupCount) << 48) + 1
g.NewTraceID = func() uint64 {
return atomic.AddUint64(&traceID, 1)
}
for i := range g.RemotePeers {
g.RemotePeers[i] = &atomic.Pointer[Remote]{}
}
for i := range g.RemotePeers {
g.RemotePeers[i].Store(newRemote(g, byte(i)))
}
return g
=======
type marshaller interface { type marshaller interface {
Marshal([]byte) []byte Marshal([]byte) []byte
>>>>>>> 69f2536 (WIP)
}
// ----------------------------------------------------------------------------
type Globals struct {
LocalConfig // Embed, immutable.
// Local public address (if available). Immutable.
LocalAddr netip.AddrPort
// True if local public address is valid. Immutable.
LocalAddrValid bool
// All remote peers by VPN IP.
RemotePeers [256]*atomic.Pointer[Remote]
// Discovered public addresses.
PubAddrs *pubAddrStore
// Attempts to ensure that we have a relay available.
RelayHandler *relayHandler
// Send UDP - Global function to write UDP packets.
SendUDP func(b []byte, addr netip.AddrPort) (n int, err error)
// Global TUN interface.
IFace io.ReadWriteCloser
}
func NewGlobals(
localConfig LocalConfig,
localAddr netip.AddrPort,
conn *net.UDPConn,
iface io.ReadWriteCloser,
) (g Globals) {
g.LocalConfig = localConfig
g.LocalAddr = localAddr
g.LocalAddrValid = localAddr.IsValid()
g.PubAddrs = newPubAddrStore(localAddr)
g.RelayHandler = newRelayHandler()
// Use a lock here avoids starvation, at least on my Linux machine.
sendLock := sync.Mutex{}
g.SendUDP = func(b []byte, addr netip.AddrPort) (int, error) {
sendLock.Lock()
n, err := conn.WriteToUDPAddrPort(b, addr)
sendLock.Unlock()
return n, err
}
g.IFace = iface
for i := range g.RemotePeers {
g.RemotePeers[i] = &atomic.Pointer[Remote]{}
}
for i := range g.RemotePeers {
g.RemotePeers[i].Store(newRemote(g, byte(i)))
}
return g
} }

View File

@ -6,14 +6,13 @@ import "unsafe"
const ( const (
headerSize = 12 headerSize = 12
controlStreamID = 2
controlHeaderSize = 24 controlHeaderSize = 24
dataStreamID = 1
dataHeaderSize = 12 dataHeaderSize = 12
dataStreamID = 1
controlStreamID = 2
) )
type Header struct { type header struct {
Version byte Version byte
StreamID byte StreamID byte
SourceIP byte SourceIP byte
@ -21,7 +20,7 @@ type Header struct {
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
} }
func parseHeader(b []byte) (h Header) { func parseHeader(b []byte) (h header) {
h.Version = b[0] h.Version = b[0]
h.StreamID = b[1] h.StreamID = b[1]
h.SourceIP = b[2] h.SourceIP = b[2]
@ -30,7 +29,7 @@ func parseHeader(b []byte) (h Header) {
return h return h
} }
func (h *Header) Parse(b []byte) { func (h *header) Parse(b []byte) {
h.Version = b[0] h.Version = b[0]
h.StreamID = b[1] h.StreamID = b[1]
h.SourceIP = b[2] h.SourceIP = b[2]
@ -38,7 +37,7 @@ func (h *Header) Parse(b []byte) {
h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
} }
func (h *Header) Marshal(buf []byte) { func (h *header) Marshal(buf []byte) {
buf[0] = h.Version buf[0] = h.Version
buf[1] = h.StreamID buf[1] = h.StreamID
buf[2] = h.SourceIP buf[2] = h.SourceIP

View File

@ -3,7 +3,7 @@ package peer
import "testing" import "testing"
func TestHeaderMarshalParse(t *testing.T) { func TestHeaderMarshalParse(t *testing.T) {
nIn := Header{ nIn := header{
StreamID: 23, StreamID: 23,
Counter: 3212, Counter: 3212,
SourceIP: 34, SourceIP: 34,
@ -13,7 +13,7 @@ func TestHeaderMarshalParse(t *testing.T) {
buf := make([]byte, headerSize) buf := make([]byte, headerSize)
nIn.Marshal(buf) nIn.Marshal(buf)
nOut := Header{} nOut := header{}
nOut.Parse(buf) nOut.Parse(buf)
if nIn != nOut { if nIn != nOut {
t.Fatal(nIn, nOut) t.Fatal(nIn, nOut)

View File

@ -10,20 +10,22 @@ import (
"vppn/m" "vppn/m"
) )
type HubPoller struct { type hubPoller struct {
Globals client *http.Client
client *http.Client req *http.Request
req *http.Request versions [256]int64
versions [256]int64 localIP byte
netName string netName string
handleControlMsg func(fromIP byte, msg any)
} }
func NewHubPoller( func newHubPoller(
g Globals, localIP byte,
netName, netName,
hubURL, hubURL,
apiKey string, apiKey string,
) (*HubPoller, error) { handleControlMsg func(byte, any),
) (*hubPoller, error) {
u, err := url.Parse(hubURL) u, err := url.Parse(hubURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -39,19 +41,20 @@ func NewHubPoller(
} }
req.SetBasicAuth("", apiKey) req.SetBasicAuth("", apiKey)
return &HubPoller{ return &hubPoller{
Globals: g, client: client,
client: client, req: req,
req: req, localIP: localIP,
netName: netName, netName: netName,
handleControlMsg: handleControlMsg,
}, nil }, nil
} }
func (hp *HubPoller) logf(s string, args ...any) { func (hp *hubPoller) logf(s string, args ...any) {
log.Printf("[HubPoller] "+s, args...) log.Printf("[HubPoller] "+s, args...)
} }
func (hp *HubPoller) Run() { func (hp *hubPoller) Run() {
state, err := loadNetworkState(hp.netName) state, err := loadNetworkState(hp.netName)
if err != nil { if err != nil {
hp.logf("Failed to load network state: %v", err) hp.logf("Failed to load network state: %v", err)
@ -66,7 +69,7 @@ func (hp *HubPoller) Run() {
} }
} }
func (hp *HubPoller) pollHub() { func (hp *hubPoller) pollHub() {
var state m.NetworkState var state m.NetworkState
resp, err := hp.client.Do(hp.req) resp, err := hp.client.Do(hp.req)
@ -86,26 +89,22 @@ func (hp *HubPoller) pollHub() {
return return
} }
hp.applyNetworkState(state)
if err := storeNetworkState(hp.netName, state); err != nil { if err := storeNetworkState(hp.netName, state); err != nil {
hp.logf("Failed to store network state: %v", err) hp.logf("Failed to store network state: %v", err)
} }
hp.applyNetworkState(state)
} }
func (hp *HubPoller) applyNetworkState(state m.NetworkState) { func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
for i, peer := range state.Peers { for i, peer := range state.Peers {
if i == int(hp.LocalPeerIP) { if i != int(hp.localIP) {
continue if peer == nil || peer.Version != hp.versions[i] {
} hp.handleControlMsg(byte(i), peerUpdateMsg{Peer: state.Peers[i]})
if peer != nil {
if peer != nil && peer.Version == hp.versions[i] { hp.versions[i] = peer.Version
continue }
} }
hp.RemotePeers[i].Load().HandlePeerUpdate(peerUpdateMsg{Peer: state.Peers[i]})
if peer != nil {
hp.versions[i] = peer.Version
} }
} }
} }

View File

@ -1,35 +1,67 @@
package peer package peer
import ( import (
"io"
"log" "log"
"net/netip"
"sync/atomic"
) )
type IFReader struct { type ifReader struct {
Globals iface io.Reader
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
rt *atomic.Pointer[routingTable]
buf1 []byte
buf2 []byte
} }
func NewIFReader(g Globals) *IFReader { func newIFReader(
return &IFReader{Globals: g} iface io.Reader,
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
rt *atomic.Pointer[routingTable],
) *ifReader {
return &ifReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()}
} }
func (r *IFReader) Run() { func (r *ifReader) Run() {
packet := make([]byte, bufferSize) packet := newBuf()
for { for {
r.handleNextPacket(packet) r.handleNextPacket(packet)
} }
} }
func (r *IFReader) handleNextPacket(packet []byte) { func (r *ifReader) handleNextPacket(packet []byte) {
packet = r.readNextPacket(packet) packet = r.readNextPacket(packet)
remoteIP, ok := r.parsePacket(packet) remoteIP, ok := r.parsePacket(packet)
if !ok { if !ok {
return return
} }
r.RemotePeers[remoteIP].Load().SendDataTo(packet)
rt := r.rt.Load()
peer := rt.Peers[remoteIP]
if !peer.Up {
r.logf("Peer %d not up.", peer.IP)
return
}
enc := peer.EncryptDataPacket(peer.IP, packet, r.buf1)
if peer.Direct {
r.writeToUDPAddrPort(enc, peer.DirectAddr)
return
}
relay, ok := rt.GetRelay()
if !ok {
r.logf("Relay not available for peer %d.", peer.IP)
return
}
enc = relay.EncryptDataPacket(peer.IP, enc, r.buf2)
r.writeToUDPAddrPort(enc, relay.DirectAddr)
} }
func (r *IFReader) readNextPacket(buf []byte) []byte { func (r *ifReader) readNextPacket(buf []byte) []byte {
n, err := r.IFace.Read(buf[:cap(buf)]) n, err := r.iface.Read(buf[:cap(buf)])
if err != nil { if err != nil {
log.Fatalf("Failed to read from interface: %v", err) log.Fatalf("Failed to read from interface: %v", err)
} }
@ -37,9 +69,7 @@ func (r *IFReader) readNextPacket(buf []byte) []byte {
return buf[:n] return buf[:n]
} }
// parsePacket returns the VPN ip for the packet, and a boolean indicating func (r *ifReader) parsePacket(buf []byte) (byte, bool) {
// success.
func (r *IFReader) parsePacket(buf []byte) (byte, bool) {
n := len(buf) n := len(buf)
if n == 0 { if n == 0 {
return 0, false return 0, false
@ -68,6 +98,6 @@ func (r *IFReader) parsePacket(buf []byte) (byte, bool) {
} }
} }
func (*IFReader) logf(s string, args ...any) { func (*ifReader) logf(s string, args ...any) {
log.Printf("[IFReader] "+s, args...) log.Printf("[IFReader] "+s, args...)
} }

View File

@ -1,5 +0,0 @@
package peer
func newBuf() []byte {
return make([]byte, bufferSize)
}

View File

@ -9,7 +9,7 @@ import (
) )
func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte {
h := Header{ h := header{
SourceIP: localIP, SourceIP: localIP,
DestIP: 255, DestIP: 255,
} }
@ -19,7 +19,7 @@ func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte {
return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) return sign.Sign(out[:0], buf, (*[64]byte)(signingKey))
} }
func headerFromLocalDiscoveryPacket(pkt []byte) (h Header, ok bool) { func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) {
if len(pkt) != headerSize+signOverhead { if len(pkt) != headerSize+signOverhead {
return return
} }
@ -36,7 +36,7 @@ func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
func RunMCWriter(localIP byte, signingKey []byte) { func runMCWriter(localIP byte, signingKey []byte) {
discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey) discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey)
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)

View File

@ -2,9 +2,17 @@ package peer
import ( import (
"net/netip" "net/netip"
"sync/atomic"
"time"
"unsafe" "unsafe"
) )
var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1
func newTraceID() uint64 {
return atomic.AddUint64(&traceIDCounter, 1)
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type binWriter struct { type binWriter struct {

View File

@ -9,8 +9,10 @@ import (
func TestSynPacket(t *testing.T) { func TestSynPacket(t *testing.T) {
p := packetSyn{ p := packetSyn{
TraceID: 2342342345, TraceID: newTraceID(),
Direct: true, //SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1,
Direct: true,
} }
rand.Read(p.SharedKey[:]) rand.Read(p.SharedKey[:])
@ -30,7 +32,7 @@ func TestSynPacket(t *testing.T) {
func TestAckPacket(t *testing.T) { func TestAckPacket(t *testing.T) {
p := packetAck{ p := packetAck{
TraceID: 123213, TraceID: newTraceID(),
ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234),
} }
@ -50,7 +52,7 @@ func TestAckPacket(t *testing.T) {
func TestProbePacket(t *testing.T) { func TestProbePacket(t *testing.T) {
p := packetProbe{ p := packetProbe{
TraceID: 12345, TraceID: newTraceID(),
} }
buf := p.Marshal(newBuf()) buf := p.Marshal(newBuf())

114
peer/peer_test.go Normal file
View File

@ -0,0 +1,114 @@
package peer
import (
"bytes"
"crypto/rand"
mrand "math/rand"
"net/netip"
"sync/atomic"
)
// A test peer.
type P struct {
cryptoKeys
RT *atomic.Pointer[routingTable]
Conn *TestUDPConn
IFace *TestIFace
ConnReader *connReader
IFReader *ifReader
}
func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P {
p := P{
cryptoKeys: generateKeys(),
RT: &atomic.Pointer[routingTable]{},
IFace: NewTestIFace(),
}
rt := newRoutingTable(ip, addr)
p.RT.Store(&rt)
p.Conn = n.NewUDPConn(addr)
//p.ConnWriter = NewConnWriter(p.Conn.WriteToUDPAddrPort, p.RT)
return p
}
func ConnectPeers(p1, p2 *P) {
rt1 := p1.RT.Load()
rt2 := p2.RT.Load()
ip1 := rt1.LocalIP
ip2 := rt2.LocalIP
rt1.Peers[ip2].Up = true
rt1.Peers[ip2].Direct = true
rt1.Peers[ip2].Relay = true
rt1.Peers[ip2].DirectAddr = rt2.LocalAddr
rt1.Peers[ip2].PubSignKey = p2.PubSignKey
rt1.Peers[ip2].ControlCipher = newControlCipher(p1.PrivKey, p2.PubKey)
rt1.Peers[ip2].DataCipher = newDataCipher()
rt2.Peers[ip1].Up = true
rt2.Peers[ip1].Direct = true
rt2.Peers[ip1].Relay = true
rt2.Peers[ip1].DirectAddr = rt1.LocalAddr
rt2.Peers[ip1].PubSignKey = p1.PubSignKey
rt2.Peers[ip1].ControlCipher = newControlCipher(p2.PrivKey, p1.PubKey)
rt2.Peers[ip1].DataCipher = rt1.Peers[ip2].DataCipher
}
func NewPeersForTesting() (p1, p2, p3 P) {
n := NewTestNetwork()
p1 = NewPeerForTesting(
n,
1,
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100))
p2 = NewPeerForTesting(
n,
2,
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 200))
p3 = NewPeerForTesting(
n,
3,
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 300))
ConnectPeers(&p1, &p2)
ConnectPeers(&p1, &p3)
ConnectPeers(&p2, &p3)
return
}
func RandPacket() []byte {
n := mrand.Intn(1200)
b := make([]byte, n)
rand.Read(b)
return b
}
func ModifyPacket(in []byte) []byte {
x := make([]byte, 1)
for {
rand.Read(x)
out := bytes.Clone(in)
idx := mrand.Intn(len(out))
if out[idx] != x[0] {
out[idx] = x[0]
return out
}
}
}
// ----------------------------------------------------------------------------
type UnknownControlPacket struct {
TraceID uint64
}
func (p UnknownControlPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).Byte(255).Uint64(p.TraceID).Build()
}

View File

@ -1,28 +0,0 @@
digraph d {
disconnected -> peerUpdating;
peerUpdating -> disconnected;
peerUpdating -> server;
peerUpdating -> clientInit;
server -> peerUpdating;
clientInit -> peerUpdating;
clientInit -> clientInit;
clientInit -> client;
client -> clientInit;
client -> peerUpdating;
clientInitializing -> clientSyncing;
clientSyncing -> clientInitializing;
clientSyncing -> clientUpIndirect;
clientSyncing -> clientUpDirect;
clientUpIndirect -> clientUpDirect;
clientUpIndirect -> clientInitializing;
clientUpDirect -> clientInitializing;
serverInitializing -> serverSyncing;
serverSyncing -> serverInitializing;
serverSyncing -> serverUpIndirect;
serverSyncing -> serverUpDirect;
serverUpIndirect -> serverUpDirect;
serverUpIndirect -> serverInitializing;
serverUpDirect -> serverInitializing;
}

371
peer/peerstates_test.go Normal file
View File

@ -0,0 +1,371 @@
package peer
import (
"testing"
"vppn/m"
)
// ----------------------------------------------------------------------------
func TestPeerState_OnPeerUpdate_nilPeer(t *testing.T) {
h := NewPeerStateTestHarness()
h.PeerUpdate(nil)
assertType[*stateDisconnected](t, h.State)
}
func TestPeerState_OnPeerUpdate_publicLocalIsServer(t *testing.T) {
keys := generateKeys()
h := NewPeerStateTestHarness()
state := h.State.(*stateDisconnected)
state.localAddr = addrPort4(1, 1, 1, 2, 200)
peer := &m.Peer{
PeerIP: 3,
Port: 456,
PubKey: keys.PubKey,
PubSignKey: keys.PubSignKey,
}
h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false)
assertType[*stateServer](t, h.State)
}
/*
func TestPeerState_OnPeerUpdate_clientDirect(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
}
/*
func TestPeerState_OnPeerUpdate_clientRelayed(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
}
/*
func TestStateServer_directSyn(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
assertEqual(t, h.Published.Up, false)
synMsg := controlMsg[packetSyn]{
SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: packetSyn{
TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1,
Direct: true,
},
}
h.State = h.State.OnMsg(synMsg)
assertEqual(t, len(h.Sent), 1)
ack := assertType[packetAck](t, h.Sent[0].Packet)
assertEqual(t, ack.TraceID, synMsg.Packet.TraceID)
assertEqual(t, h.Sent[0].Peer.IP, 3)
assertEqual(t, ack.PossibleAddrs[0].IsValid(), false)
assertEqual(t, h.Published.Up, true)
}
func TestStateServer_relayedSyn(t *testing.T) {
h := NewPeerStateTestHarness()
state := h.ConfigServer_Relayed(t)
state.pubAddrs.Store(addrPort4(4, 5, 6, 7, 1234))
assertEqual(t, h.Published.Up, false)
synMsg := controlMsg[packetSyn]{
SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: packetSyn{
TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1,
Direct: false,
},
}
synMsg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 3, 300)
synMsg.Packet.PossibleAddrs[1] = addrPort4(2, 2, 2, 3, 300)
h.State = h.State.OnMsg(synMsg)
assertEqual(t, len(h.Sent), 3)
ack := assertType[packetAck](t, h.Sent[0].Packet)
assertEqual(t, ack.TraceID, synMsg.Packet.TraceID)
assertEqual(t, h.Sent[0].Peer.IP, 3)
assertEqual(t, ack.PossibleAddrs[0], addrPort4(4, 5, 6, 7, 1234))
assertEqual(t, ack.PossibleAddrs[1].IsValid(), false)
assertEqual(t, h.Published.Up, true)
assertType[packetProbe](t, h.Sent[1].Packet)
assertType[packetProbe](t, h.Sent[2].Packet)
assertEqual(t, h.Sent[1].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300))
assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 300))
}
func TestStateServer_onProbe(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
assertEqual(t, h.Published.Up, false)
probeMsg := controlMsg[packetProbe]{
SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: packetProbe{TraceID: newTraceID()},
}
h.State = h.State.OnMsg(probeMsg)
assertEqual(t, len(h.Sent), 1)
probe := assertType[packetProbe](t, h.Sent[0].Packet)
assertEqual(t, probe.TraceID, probeMsg.Packet.TraceID)
assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 1, 1, 3, 300))
}
func TestStateServer_OnPingTimer_timeout(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
synMsg := controlMsg[packetSyn]{
SrcIP: 3,
SrcAddr: addrPort4(1, 1, 1, 3, 300),
Packet: packetSyn{
TraceID: newTraceID(),
//SentAt: time.Now().UnixMilli(),
//SharedKeyType: 1,
Direct: true,
},
}
h.State = h.State.OnMsg(synMsg)
assertEqual(t, len(h.Sent), 1)
assertEqual(t, h.Published.Up, true)
// Ping shouldn't timeout.
h.OnPingTimer()
assertEqual(t, h.Published.Up, true)
// Advance the time, then ping.
state := assertType[*stateServer](t, h.State)
state.lastSeen = time.Now().Add(-timeoutInterval - time.Second)
h.OnPingTimer()
assertEqual(t, h.Published.Up, false)
}
func TestStateClientDirect_OnAck(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
assertEqual(t, h.Published.Up, false)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID},
}
h.State = h.State.OnMsg(ack)
assertEqual(t, h.Published.Up, true)
}
func TestStateClientDirect_OnAck_incorrectTraceID(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
assertEqual(t, h.Published.Up, false)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID + 1},
}
h.State = h.State.OnMsg(ack)
assertEqual(t, h.Published.Up, false)
}
func TestStateClientDirect_OnPingTimer(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
assertType[packetSyn](t, h.Sent[0].Packet)
h.OnPingTimer()
// On ping timer, another syn should be sent. Additionally, we should remain
// in the same state.
assertEqual(t, len(h.Sent), 2)
assertType[packetSyn](t, h.Sent[1].Packet)
assertType[*stateClientDirect](t, h.State)
assertEqual(t, h.Published.Up, false)
}
func TestStateClientDirect_OnPingTimer_timeout(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
assertEqual(t, h.Published.Up, false)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID},
}
h.State = h.State.OnMsg(ack)
assertEqual(t, h.Published.Up, true)
state := assertType[*stateClientDirect](t, h.State)
state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second))
h.OnPingTimer()
// On ping timer, we should timeout, causing the client to reset. Another SYN
// will be sent when re-entering the state, but the connection should be down.
assertEqual(t, len(h.Sent), 2)
assertType[packetSyn](t, h.Sent[1].Packet)
assertType[*stateClientDirect](t, h.State)
assertEqual(t, h.Published.Up, false)
}
func TestStateClientRelayed_OnAck(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
assertEqual(t, h.Published.Up, false)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID},
}
h.State = h.State.OnMsg(ack)
assertEqual(t, h.Published.Up, true)
}
func TestStateClientRelayed_OnPingTimer_noAddrs(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
assertEqual(t, h.Published.Up, false)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
// If we haven't had an ack yet, we won't have addresses to probe. Therefore
// we'll have just one more syn packet sent.
h.OnPingTimer()
assertEqual(t, len(h.Sent), 2)
}
func TestStateClientRelayed_OnPingTimer_withAddrs(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
assertEqual(t, h.Published.Up, false)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}}
ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300)
ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300)
h.State = h.State.OnMsg(ack)
// Add a local discovery address. Note that the port will be configured port
// and no the one provided here.
h.State = h.State.OnMsg(controlMsg[packetLocalDiscovery]{
SrcIP: 3,
SrcAddr: addrPort4(2, 2, 2, 3, 300),
})
// We should see one SYN and three probe packets.
h.OnPingTimer()
assertEqual(t, len(h.Sent), 5)
assertType[packetSyn](t, h.Sent[1].Packet)
assertType[packetProbe](t, h.Sent[2].Packet)
assertType[packetProbe](t, h.Sent[3].Packet)
assertType[packetProbe](t, h.Sent[4].Packet)
assertEqual(t, h.Sent[2].Peer.DirectAddr, addrPort4(1, 1, 1, 1, 300))
assertEqual(t, h.Sent[3].Peer.DirectAddr, addrPort4(1, 1, 1, 2, 300))
assertEqual(t, h.Sent[4].Peer.DirectAddr, addrPort4(2, 2, 2, 3, 456))
}
func TestStateClientRelayed_OnPingTimer_timeout(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
// On entering the state, a SYN should have been sent.
assertEqual(t, len(h.Sent), 1)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID},
}
h.State = h.State.OnMsg(ack)
assertEqual(t, h.Published.Up, true)
state := assertType[*stateClientRelayed](t, h.State)
state.lastSeen = time.Now().Add(-(timeoutInterval + time.Second))
h.OnPingTimer()
// On ping timer, we should timeout, causing the client to reset. Another SYN
// will be sent when re-entering the state, but the connection should be down.
assertEqual(t, len(h.Sent), 2)
assertType[packetSyn](t, h.Sent[1].Packet)
assertType[*stateClientRelayed](t, h.State)
assertEqual(t, h.Published.Up, false)
}
func TestStateClientRelayed_OnProbe_unknownAddr(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
h.OnProbe(controlMsg[packetProbe]{
Packet: packetProbe{TraceID: newTraceID()},
})
assertType[*stateClientRelayed](t, h.State)
}
func TestStateClientRelayed_OnProbe_upgradeDirect(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
syn := assertType[packetSyn](t, h.Sent[0].Packet)
ack := controlMsg[packetAck]{Packet: packetAck{TraceID: syn.TraceID}}
ack.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 300)
ack.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 300)
h.State = h.State.OnMsg(ack)
h.OnPingTimer()
probe := assertType[packetProbe](t, h.Sent[2].Packet)
h.OnProbe(controlMsg[packetProbe]{Packet: probe})
assertType[*stateClientDirect](t, h.State)
}
*/

View File

@ -1 +1,148 @@
package peer package peer
import (
"net/netip"
"sync"
"sync/atomic"
"time"
"git.crumpington.com/lib/go/ratelimiter"
)
type supervisor struct {
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
staged routingTable
shared *atomic.Pointer[routingTable]
peers [256]*peerSuper
lock sync.Mutex
buf1 []byte
buf2 []byte
}
func newSupervisor(
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
rt *atomic.Pointer[routingTable],
privKey []byte,
) *supervisor {
routes := rt.Load()
s := &supervisor{
writeToUDPAddrPort: writeToUDPAddrPort,
staged: *routes,
shared: rt,
buf1: newBuf(),
buf2: newBuf(),
}
pubAddrs := newPubAddrStore(routes.LocalAddr)
for i := range s.peers {
state := &peerData{
publish: s.publish,
sendControlPacket: s.send,
pingTimer: time.NewTicker(timeoutInterval),
localIP: routes.LocalIP,
remoteIP: byte(i),
privKey: privKey,
localAddr: routes.LocalAddr,
pubAddrs: pubAddrs,
staged: routes.Peers[i],
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 20 * time.Millisecond,
MaxWaitCount: 1,
}),
}
s.peers[i] = newPeerSuper(state, state.pingTimer)
}
return s
}
func (s *supervisor) Start() {
for i := range s.peers {
go s.peers[i].Run()
}
}
func (s *supervisor) HandleControlMsg(destIP byte, msg any) {
s.peers[destIP].HandleControlMsg(msg)
}
func (s *supervisor) send(peer remotePeer, pkt marshaller) {
s.lock.Lock()
defer s.lock.Unlock()
enc := peer.EncryptControlPacket(pkt, s.buf1, s.buf2)
if peer.Direct {
s.writeToUDPAddrPort(enc, peer.DirectAddr)
return
}
relay, ok := s.staged.GetRelay()
if !ok {
return
}
enc = relay.EncryptDataPacket(peer.IP, enc, s.buf1)
s.writeToUDPAddrPort(enc, relay.DirectAddr)
}
func (s *supervisor) publish(rp remotePeer) {
s.lock.Lock()
defer s.lock.Unlock()
s.staged.Peers[rp.IP] = rp
s.ensureRelay()
copy := s.staged
s.shared.Store(&copy)
}
func (s *supervisor) ensureRelay() {
if _, ok := s.staged.GetRelay(); ok {
return
}
// TODO: Random selection? Something else?
for _, peer := range s.staged.Peers {
if peer.Up && peer.Direct && peer.Relay {
s.staged.RelayIP = peer.IP
return
}
}
}
// ----------------------------------------------------------------------------
type peerSuper struct {
messages chan any
state peerState
pingTimer *time.Ticker
}
func newPeerSuper(state *peerData, pingTimer *time.Ticker) *peerSuper {
return &peerSuper{
messages: make(chan any, 8),
state: initPeerState(state, nil),
pingTimer: pingTimer,
}
}
func (s *peerSuper) HandleControlMsg(msg any) {
select {
case s.messages <- msg:
default:
}
}
func (s *peerSuper) Run() {
for {
select {
case <-s.pingTimer.C:
s.state = s.state.OnMsg(pingTimerMsg{})
case raw := <-s.messages:
s.state = s.state.OnMsg(raw)
}
}
}

View File

@ -1 +0,0 @@
package peer

View File

@ -1,54 +0,0 @@
package peer
import (
"log"
"sync"
"sync/atomic"
)
type relayHandler struct {
lock sync.Mutex
relays map[byte]*Remote
relay atomic.Pointer[Remote]
}
func newRelayHandler() *relayHandler {
return &relayHandler{
relays: make(map[byte]*Remote, 256),
}
}
func (h *relayHandler) Add(r *Remote) {
h.lock.Lock()
defer h.lock.Unlock()
h.relays[r.RemotePeerIP] = r
if h.relay.Load() == nil {
log.Printf("Setting Relay: %v", r.conf().Peer.Name)
h.relay.Store(r)
}
}
func (h *relayHandler) Remove(r *Remote) {
h.lock.Lock()
defer h.lock.Unlock()
log.Printf("Removing relay %d...", r.RemotePeerIP)
delete(h.relays, r.RemotePeerIP)
if h.relay.Load() == r {
// Remove current relay.
h.relay.Store(nil)
// Find new relay.
for _, r := range h.relays {
h.relay.Store(r)
break
}
}
}
func (h *relayHandler) Load() *Remote {
return h.relay.Load()
}

View File

@ -1,429 +0,0 @@
package peer
import (
"fmt"
"log"
"net/netip"
"strings"
"sync/atomic"
<<<<<<< HEAD
"vppn/m"
=======
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
>>>>>>> 69f2536 (WIP)
)
// ----------------------------------------------------------------------------
// The remoteConfig is the shared, immutable configuration for a remote
// peer. It's read and written atomically. See remote.config.
// ----------------------------------------------------------------------------
type remoteConfig struct {
Up bool // True if peer is up and we can send data.
Server bool // True if role is server.
Direct bool // True if this is a direct connection.
DirectAddr netip.AddrPort // Remote address if directly connected.
ControlCipher *controlCipher
DataCipher *dataCipher
Peer *m.Peer
}
// CanRelay returns true if the remote configuration is able to relay packets.
// to other hosts.
func (rc remoteConfig) CanRelay() bool {
return rc.Up && rc.Direct && rc.Peer.Relay
}
// A Remote represents a remote peer and contains functions for handling
// incoming control, data, and multicast packets, peer udpates, as well as
// sending, forwarding, and relaying packets.
type Remote struct {
Globals
RemotePeerIP byte // Immutable.
<<<<<<< HEAD
=======
limiter *ratelimiter.Limiter
>>>>>>> 69f2536 (WIP)
dupCheck *dupCheck
sendCounter uint64 // init to startupCount << 48. Atomic access only.
// config should be accessed via conf() and updateConf(...) methods.
config atomic.Pointer[remoteConfig]
messages chan any
}
func newRemote(g Globals, remotePeerIP byte) *Remote {
r := &Remote{
Globals: g,
RemotePeerIP: remotePeerIP,
<<<<<<< HEAD
dupCheck: newDupCheck(0),
sendCounter: (uint64(g.StartupCount) << 48) + 1,
messages: make(chan any, 8),
=======
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 20 * time.Millisecond,
MaxWaitCount: 1,
}),
dupCheck: newDupCheck(0),
sendCounter: uint64(time.Now().Unix()<<30) + 1,
messages: make(chan any, 8),
>>>>>>> 69f2536 (WIP)
}
r.config.Store(&remoteConfig{})
return r
}
// ----------------------------------------------------------------------------
func (r *Remote) conf() remoteConfig {
return *(r.config.Load())
}
func (r *Remote) updateConf(conf remoteConfig) {
old := r.config.Load()
r.config.Store(&conf)
if !old.CanRelay() && conf.CanRelay() {
r.RelayHandler.Add(r)
}
if old.CanRelay() && !conf.CanRelay() {
r.RelayHandler.Remove(r)
}
}
// ----------------------------------------------------------------------------
func (r *Remote) sendUDP(b []byte, addr netip.AddrPort) {
<<<<<<< HEAD
if _, err := r.SendUDP(b, addr); err != nil {
r.logf("Failed to send UDP packet: %v", err)
=======
if err := r.limiter.Limit(); err != nil {
r.logf("Rate limiter")
return
}
if _, err := r.SendUDP(b, addr); err != nil {
r.logf("Failed to send URP packet: %v", err)
>>>>>>> 69f2536 (WIP)
}
}
// ----------------------------------------------------------------------------
<<<<<<< HEAD
func (r *Remote) encryptData(conf remoteConfig, destIP byte, packet []byte) []byte {
=======
func (r *Remote) encryptData(conf remoteConfig, packet []byte) []byte {
>>>>>>> 69f2536 (WIP)
h := Header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&r.sendCounter, 1),
SourceIP: r.Globals.LocalPeerIP,
<<<<<<< HEAD
DestIP: destIP,
=======
DestIP: r.RemotePeerIP,
>>>>>>> 69f2536 (WIP)
}
return conf.DataCipher.Encrypt(h, packet, packet[len(packet):cap(packet)])
}
func (r *Remote) encryptControl(conf remoteConfig, packet []byte) []byte {
h := Header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&r.sendCounter, 1),
SourceIP: r.LocalPeerIP,
DestIP: r.RemotePeerIP,
}
return conf.ControlCipher.Encrypt(h, packet, packet[len(packet):cap(packet)])
}
// ----------------------------------------------------------------------------
// SendDataTo sends a data packet to the remote, called by the IFReader.
func (r *Remote) SendDataTo(data []byte) {
conf := r.conf()
if !conf.Up {
r.logf("Cannot send: link down")
return
}
<<<<<<< HEAD
// Direct:
if conf.Direct {
r.sendUDP(r.encryptData(conf, conf.Peer.PeerIP, data), conf.DirectAddr)
return
}
// Relayed:
=======
if conf.Direct {
r.sendDataDirect(conf, data)
} else {
r.sendDataRelayed(conf, data)
}
}
// sendDataRelayed sends data to the remote via the relay.
func (r *Remote) sendDataRelayed(conf remoteConfig, data []byte) {
>>>>>>> 69f2536 (WIP)
relay := r.RelayHandler.Load()
if relay == nil {
r.logf("Connot send: no relay")
return
}
<<<<<<< HEAD
relay.relayData(conf.Peer.PeerIP, r.encryptData(conf, conf.Peer.PeerIP, data))
}
func (r *Remote) relayData(toIP byte, enc []byte) {
=======
relay.relayData(r.encryptData(conf, data))
}
// sendDataDirect sends data to the remote directly.
func (r *Remote) sendDataDirect(conf remoteConfig, data []byte) {
r.logf("Sending data direct...")
r.sendUDP(r.encryptData(conf, data), conf.DirectAddr)
}
func (r *Remote) relayData(enc []byte) {
>>>>>>> 69f2536 (WIP)
conf := r.conf()
if !conf.Up || !conf.Direct {
r.logf("Cannot relay: not up or not a direct connection")
return
}
<<<<<<< HEAD
r.sendUDP(r.encryptData(conf, toIP, enc), conf.DirectAddr)
}
func (r *Remote) sendControl(conf remoteConfig, data []byte) {
// Direct:
if conf.Direct {
enc := r.encryptControl(conf, data)
r.sendUDP(enc, conf.DirectAddr)
return
}
// Relayed:
=======
r.sendDataDirect(conf, enc)
}
func (r *Remote) sendControl(conf remoteConfig, data []byte) {
if conf.Direct {
r.sendControlDirect(conf, data)
} else {
r.sendControlRelayed(conf, data)
}
}
func (r *Remote) sendControlToAddr(buf []byte, addr netip.AddrPort) {
enc := r.encryptControl(r.conf(), buf)
r.sendUDP(enc, addr)
}
func (r *Remote) sendControlDirect(conf remoteConfig, data []byte) {
r.logf("Sending control direct...")
enc := r.encryptControl(conf, data)
r.sendUDP(enc, conf.DirectAddr)
}
func (r *Remote) sendControlRelayed(conf remoteConfig, data []byte) {
r.logf("Sending control relayed...")
>>>>>>> 69f2536 (WIP)
relay := r.RelayHandler.Load()
if relay == nil {
r.logf("Connot send: no relay")
return
}
<<<<<<< HEAD
relay.relayData(conf.Peer.PeerIP, r.encryptControl(conf, data))
}
func (r *Remote) sendControlToAddr(buf []byte, addr netip.AddrPort) {
enc := r.encryptControl(r.conf(), buf)
r.sendUDP(enc, addr)
=======
relay.relayData(r.encryptControl(conf, data))
>>>>>>> 69f2536 (WIP)
}
func (r *Remote) forwardPacket(data []byte) {
conf := r.conf()
if !conf.Up || !conf.Direct {
r.logf("Cannot forward to %d: not a direct connection", conf.Peer.PeerIP)
return
}
r.sendUDP(data, conf.DirectAddr)
}
// ----------------------------------------------------------------------------
// HandlePacket is called by the ConnReader to handle an incoming packet.
func (r *Remote) HandlePacket(h Header, srcAddr netip.AddrPort, data []byte) {
switch h.StreamID {
case controlStreamID:
r.handleControlPacket(h, srcAddr, data)
case dataStreamID:
r.handleDataPacket(h, data)
default:
r.logf("Unknown stream ID: %d", h.StreamID)
}
}
// Handle a control packet. Decrypt, verify, etc.
func (r *Remote) handleControlPacket(h Header, srcAddr netip.AddrPort, data []byte) {
conf := r.conf()
if conf.ControlCipher == nil {
r.logf("No control cipher")
return
}
dec, ok := conf.ControlCipher.Decrypt(data, data[len(data):cap(data)])
if !ok {
r.logf("Failed to decrypt control packet")
return
}
if r.dupCheck.IsDup(h.Counter) {
r.logf("Dropping control packet as duplicate: %d", h.Counter)
return
}
msg, err := parseControlMsg(h.SourceIP, srcAddr, dec)
if err != nil {
r.logf("Failed to parse control packet: %v", err)
return
}
select {
case r.messages <- msg:
default:
r.logf("Dropping control message")
}
}
func (r *Remote) handleDataPacket(h Header, data []byte) {
conf := r.conf()
if conf.DataCipher == nil {
return
}
dec, ok := conf.DataCipher.Decrypt(data, data[len(data):cap(data)])
if !ok {
r.logf("Failed to decrypt data packet")
return
}
if r.dupCheck.IsDup(h.Counter) {
r.logf("Dropping data packet as duplicate: %d", h.Counter)
return
}
// For local.
if h.DestIP == r.LocalPeerIP {
if _, err := r.IFace.Write(dec); err != nil {
<<<<<<< HEAD
// This could be a malformed packet from a peer, so we don't crash if it
// happens.
r.logf("Failed to write to interface: %v", err)
=======
log.Fatalf("Failed to write to interface: %v", err)
>>>>>>> 69f2536 (WIP)
}
return
}
// Forward.
dest := r.RemotePeers[h.DestIP].Load()
dest.forwardPacket(dec)
}
// ----------------------------------------------------------------------------
// HandleLocalDiscoveryPacket is called by the MCReader.
func (r *Remote) HandleLocalDiscoveryPacket(h Header, srcAddr netip.AddrPort, data []byte) {
conf := r.conf()
if conf.Peer.PubSignKey == nil {
r.logf("No signing key for discovery packet.")
return
}
if !verifyLocalDiscoveryPacket(data, data[len(data):cap(data)], conf.Peer.PubSignKey) {
r.logf("Invalid signature on discovery packet.")
return
}
msg := controlMsg[packetLocalDiscovery]{
SrcIP: h.SourceIP,
SrcAddr: srcAddr,
}
<<<<<<< HEAD
=======
r.logf("Got local discovery packet from %v.", srcAddr)
>>>>>>> 69f2536 (WIP)
select {
case r.messages <- msg:
default:
r.logf("Dropping discovery message.")
}
}
// ----------------------------------------------------------------------------
// HandlePeerUpdate is called by the HubPoller when it gets a new version of
// the associated peer configuration.
func (r *Remote) HandlePeerUpdate(msg peerUpdateMsg) {
r.messages <- msg
}
// ----------------------------------------------------------------------------
func (s *Remote) logf(format string, args ...any) {
conf := s.conf()
b := strings.Builder{}
name := ""
if conf.Peer != nil {
name = conf.Peer.Name
}
b.WriteString(fmt.Sprintf("%03d", s.RemotePeerIP))
b.WriteString(fmt.Sprintf("%30s: ", name))
if conf.Server {
b.WriteString("SERVER | ")
} else {
b.WriteString("CLIENT | ")
}
if conf.Direct {
b.WriteString("DIRECT | ")
} else {
b.WriteString("RELAYED | ")
}
if conf.Up {
b.WriteString("UP | ")
} else {
b.WriteString("DOWN | ")
}
log.Printf(b.String()+format, args...)
}

View File

@ -1,491 +0,0 @@
package peer
import (
"bytes"
"net/netip"
"time"
"vppn/m"
)
type stateFunc func(msg any) stateFunc
<<<<<<< HEAD
type sentProbe struct {
SentAt time.Time
Addr netip.AddrPort
}
=======
>>>>>>> 69f2536 (WIP)
type remoteFSM struct {
*Remote
pingTimer *time.Ticker
lastSeen time.Time
traceID uint64
probes map[uint64]sentProbe
sharedKey [32]byte
buf []byte
}
func newRemoteFSM(r *Remote) *remoteFSM {
fsm := &remoteFSM{
Remote: r,
pingTimer: time.NewTicker(timeoutInterval),
probes: map[uint64]sentProbe{},
buf: make([]byte, bufferSize),
}
fsm.pingTimer.Stop()
return fsm
}
func (r *remoteFSM) Run() {
go func() {
for range r.pingTimer.C {
r.messages <- pingTimerMsg{}
}
}()
state := r.enterDisconnected()
for msg := range r.messages {
state = state(msg)
}
}
// ----------------------------------------------------------------------------
func (r *remoteFSM) enterDisconnected() stateFunc {
r.updateConf(remoteConfig{})
return r.stateDisconnected
}
func (r *remoteFSM) stateDisconnected(iMsg any) stateFunc {
switch msg := iMsg.(type) {
case peerUpdateMsg:
return r.enterPeerUpdating(msg.Peer)
case controlMsg[packetInit]:
r.logf("Unexpected INIT")
case controlMsg[packetSyn]:
r.logf("Unexpected SYN")
case controlMsg[packetAck]:
r.logf("Unexpected ACK")
case controlMsg[packetProbe]:
r.logf("Unexpected probe")
case controlMsg[packetLocalDiscovery]:
// Ignore
case pingTimerMsg:
r.logf("Unexpected ping")
default:
r.logf("Ignoring message: %#v", iMsg)
}
return r.stateDisconnected
}
// ----------------------------------------------------------------------------
func (r *remoteFSM) enterPeerUpdating(peer *m.Peer) stateFunc {
if peer == nil {
return r.enterDisconnected()
}
conf := remoteConfig{
Peer: peer,
ControlCipher: newControlCipher(r.PrivKey, peer.PubKey),
}
r.updateConf(conf)
if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
if r.LocalAddrValid && r.LocalPeerIP < peer.PeerIP {
return r.enterServer()
}
return r.enterClientInit()
}
if r.LocalAddrValid || r.LocalPeerIP < peer.PeerIP {
return r.enterServer()
}
return r.enterClientInit()
}
// ----------------------------------------------------------------------------
func (r *remoteFSM) enterServer() stateFunc {
conf := r.conf()
conf.Server = true
r.updateConf(conf)
r.logf("==> Server")
r.pingTimer.Reset(pingInterval)
r.lastSeen = time.Now()
clear(r.sharedKey[:])
return r.stateServer
}
func (r *remoteFSM) stateServer(iMsg any) stateFunc {
switch msg := iMsg.(type) {
case peerUpdateMsg:
return r.enterPeerUpdating(msg.Peer)
case controlMsg[packetInit]:
r.stateServer_onInit(msg)
case controlMsg[packetSyn]:
r.stateServer_onSyn(msg)
case controlMsg[packetAck]:
r.logf("Unexpected ACK")
case controlMsg[packetProbe]:
r.stateServer_onProbe(msg)
case controlMsg[packetLocalDiscovery]:
// Ignore
case pingTimerMsg:
r.stateServer_onPingTimer()
default:
r.logf("Unexpected message: %#v", iMsg)
}
return r.stateServer
}
func (r *remoteFSM) stateServer_onInit(msg controlMsg[packetInit]) {
conf := r.conf()
conf.Up = false
conf.Direct = msg.Packet.Direct
conf.DirectAddr = msg.SrcAddr
r.updateConf(conf)
init := packetInit{
TraceID: msg.Packet.TraceID,
Direct: conf.Direct,
Version: version,
}
r.sendControl(conf, init.Marshal(r.buf))
}
func (r *remoteFSM) stateServer_onSyn(msg controlMsg[packetSyn]) {
<<<<<<< HEAD
=======
r.logf("Got SYN: %v", msg.Packet)
>>>>>>> 69f2536 (WIP)
r.lastSeen = time.Now()
p := msg.Packet
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
conf := r.conf()
<<<<<<< HEAD
logSyn := !conf.Up || conf.Direct != p.Direct
=======
if !conf.Up || conf.Direct != p.Direct {
r.logf("Got SYN.")
}
>>>>>>> 69f2536 (WIP)
conf.Up = true
conf.Direct = p.Direct
conf.DirectAddr = msg.SrcAddr
// Update data cipher if the key has changed.
if !bytes.Equal(r.sharedKey[:], p.SharedKey[:]) {
conf.DataCipher = newDataCipherFromKey(p.SharedKey)
copy(r.sharedKey[:], p.SharedKey[:])
}
r.updateConf(conf)
<<<<<<< HEAD
if logSyn {
r.logf("Got SYN.")
}
=======
>>>>>>> 69f2536 (WIP)
r.sendControl(conf, packetAck{
TraceID: p.TraceID,
ToAddr: conf.DirectAddr,
PossibleAddrs: r.PubAddrs.Get(),
}.Marshal(r.buf))
if p.Direct {
return
}
// Send probes if not a direct connection.
for _, addr := range msg.Packet.PossibleAddrs {
if !addr.IsValid() {
break
}
r.logf("Probing %v...", addr)
<<<<<<< HEAD
r.sendControlToAddr(packetProbe{TraceID: r.NewTraceID()}.Marshal(r.buf), addr)
=======
r.sendControlToAddr(packetProbe{TraceID: newTraceID()}.Marshal(r.buf), addr)
>>>>>>> 69f2536 (WIP)
}
}
func (r *remoteFSM) stateServer_onProbe(msg controlMsg[packetProbe]) {
if !msg.SrcAddr.IsValid() {
return
}
data := packetProbe{TraceID: msg.Packet.TraceID}.Marshal(r.buf)
r.sendControlToAddr(data, msg.SrcAddr)
}
func (r *remoteFSM) stateServer_onPingTimer() {
conf := r.conf()
if time.Since(r.lastSeen) > timeoutInterval && conf.Up {
conf.Up = false
r.updateConf(conf)
r.logf("Timeout.")
}
}
// ----------------------------------------------------------------------------
func (r *remoteFSM) enterClientInit() stateFunc {
conf := r.conf()
ip, ipValid := netip.AddrFromSlice(conf.Peer.PublicIP)
conf.Up = false
conf.Server = false
conf.Direct = ipValid
conf.DirectAddr = netip.AddrPortFrom(ip, conf.Peer.Port)
conf.DataCipher = newDataCipher()
r.updateConf(conf)
r.logf("==> ClientInit")
r.lastSeen = time.Now()
r.pingTimer.Reset(pingInterval)
r.stateClientInit_sendInit()
return r.stateClientInit
}
func (r *remoteFSM) stateClientInit(iMsg any) stateFunc {
switch msg := iMsg.(type) {
case peerUpdateMsg:
return r.enterPeerUpdating(msg.Peer)
case controlMsg[packetInit]:
return r.stateClientInit_onInit(msg)
case controlMsg[packetSyn]:
r.logf("Unexpected SYN")
case controlMsg[packetAck]:
r.logf("Unexpected ACK")
case controlMsg[packetProbe]:
// Ignore
case controlMsg[packetLocalDiscovery]:
// Ignore
case pingTimerMsg:
return r.stateClientInit_onPing()
default:
r.logf("Unexpected message: %#v", iMsg)
}
return r.stateClientInit
}
func (r *remoteFSM) stateClientInit_sendInit() {
conf := r.conf()
<<<<<<< HEAD
r.traceID = r.NewTraceID()
=======
r.traceID = newTraceID()
>>>>>>> 69f2536 (WIP)
init := packetInit{
TraceID: r.traceID,
Direct: conf.Direct,
Version: version,
}
r.sendControl(conf, init.Marshal(r.buf))
}
func (r *remoteFSM) stateClientInit_onInit(msg controlMsg[packetInit]) stateFunc {
if msg.Packet.TraceID != r.traceID {
r.logf("Invalid trace ID on INIT.")
return r.stateClientInit
}
r.logf("Got INIT version %d.", msg.Packet.Version)
return r.enterClient()
}
func (r *remoteFSM) stateClientInit_onPing() stateFunc {
if time.Since(r.lastSeen) < timeoutInterval {
r.stateClientInit_sendInit()
return r.stateClientInit
}
// Direct connect failed. Try indirect.
conf := r.conf()
if conf.Direct {
conf.Direct = false
r.updateConf(conf)
r.lastSeen = time.Now()
r.stateClientInit_sendInit()
r.logf("Direct connection failed. Attempting indirect connection.")
return r.stateClientInit
}
// Indirect failed. Re-enter init state.
r.logf("Timeout.")
return r.enterClientInit()
}
// ----------------------------------------------------------------------------
func (r *remoteFSM) enterClient() stateFunc {
conf := r.conf()
r.probes = make(map[uint64]sentProbe, 8)
<<<<<<< HEAD
r.traceID = r.NewTraceID()
=======
r.traceID = newTraceID()
>>>>>>> 69f2536 (WIP)
r.stateClient_sendSyn(conf)
r.pingTimer.Reset(pingInterval)
r.logf("==> Client")
return r.stateClient
}
func (r *remoteFSM) stateClient(iMsg any) stateFunc {
switch msg := iMsg.(type) {
case peerUpdateMsg:
return r.enterPeerUpdating(msg.Peer)
case controlMsg[packetAck]:
r.stateClient_onAck(msg)
case controlMsg[packetProbe]:
r.stateClient_onProbe(msg)
case controlMsg[packetLocalDiscovery]:
r.stateClient_onLocalDiscovery(msg)
case pingTimerMsg:
return r.stateClient_onPingTimer()
default:
r.logf("Ignoring message: %v", iMsg)
}
return r.stateClient
}
func (r *remoteFSM) stateClient_onAck(msg controlMsg[packetAck]) {
if msg.Packet.TraceID != r.traceID {
return
}
r.lastSeen = time.Now()
conf := r.conf()
if !conf.Up {
conf.Up = true
r.updateConf(conf)
r.logf("Got ACK.")
}
if conf.Direct {
r.PubAddrs.Store(msg.Packet.ToAddr)
return
}
// Relayed.
r.stateClient_cleanProbes()
for _, addr := range msg.Packet.PossibleAddrs {
if !addr.IsValid() {
break
}
r.stateClient_sendProbeTo(addr)
}
}
func (r *remoteFSM) stateClient_cleanProbes() {
for key, sent := range r.probes {
if time.Since(sent.SentAt) > pingInterval {
delete(r.probes, key)
}
}
}
func (r *remoteFSM) stateClient_sendProbeTo(addr netip.AddrPort) {
<<<<<<< HEAD
probe := packetProbe{TraceID: r.NewTraceID()}
=======
probe := packetProbe{TraceID: newTraceID()}
>>>>>>> 69f2536 (WIP)
r.probes[probe.TraceID] = sentProbe{
SentAt: time.Now(),
Addr: addr,
}
r.logf("Probing %v...", addr)
r.sendControlToAddr(probe.Marshal(r.buf), addr)
}
func (r *remoteFSM) stateClient_onProbe(msg controlMsg[packetProbe]) {
conf := r.conf()
if conf.Direct {
return
}
r.stateClient_cleanProbes()
sent, ok := r.probes[msg.Packet.TraceID]
if !ok {
return
}
conf.Direct = true
conf.DirectAddr = sent.Addr
r.updateConf(conf)
<<<<<<< HEAD
r.traceID = r.NewTraceID()
=======
r.traceID = newTraceID()
>>>>>>> 69f2536 (WIP)
r.stateClient_sendSyn(conf)
r.logf("Successful probe to %v.", sent.Addr)
}
func (r *remoteFSM) stateClient_onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
conf := r.conf()
if conf.Direct {
return
}
// The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port.
addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), conf.Peer.Port)
r.stateClient_sendProbeTo(addr)
}
func (r *remoteFSM) stateClient_onPingTimer() stateFunc {
conf := r.conf()
if time.Since(r.lastSeen) > timeoutInterval {
if conf.Up {
r.logf("Timeout.")
}
return r.enterClientInit()
}
<<<<<<< HEAD
=======
r.traceID = newTraceID()
>>>>>>> 69f2536 (WIP)
r.stateClient_sendSyn(conf)
return r.stateClient
}
func (r *remoteFSM) stateClient_sendSyn(conf remoteConfig) {
syn := packetSyn{
TraceID: r.traceID,
SharedKey: conf.DataCipher.Key(),
Direct: conf.Direct,
PossibleAddrs: r.PubAddrs.Get(),
}
r.sendControl(conf, syn.Marshal(r.buf))
}

View File

@ -1 +0,0 @@
package peer

138
peer/routingtable.go Normal file
View File

@ -0,0 +1,138 @@
package peer
import (
"net/netip"
"sync/atomic"
"time"
)
// TODO: Remove
func newRemotePeer(ip byte) *remotePeer {
counter := uint64(time.Now().Unix()<<30 + 1)
return &remotePeer{
IP: ip,
counter: &counter,
dupCheck: newDupCheck(0),
}
}
// ----------------------------------------------------------------------------
type remotePeer struct {
localIP byte
IP byte // VPN IP of peer (last byte).
Up bool // True if data can be sent on the peer.
Relay bool // True if the peer is a relay.
Direct bool // True if this is a direct connection.
DirectAddr netip.AddrPort // Remote address if directly connected.
PubSignKey []byte
ControlCipher *controlCipher
DataCipher *dataCipher
counter *uint64 // For sending to. Atomic access only.
dupCheck *dupCheck // For receiving from. Not safe for concurrent use.
}
func (p remotePeer) EncryptDataPacket(destIP byte, data, out []byte) []byte {
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(p.counter, 1),
SourceIP: p.localIP,
DestIP: destIP,
}
return p.DataCipher.Encrypt(h, data, out)
}
// Decrypts and de-dups incoming data packets.
func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error) {
dec, ok := p.DataCipher.Decrypt(enc, out)
if !ok {
return nil, errDecryptionFailed
}
if p.dupCheck.IsDup(h.Counter) {
return nil, errDuplicateSeqNum
}
return dec, nil
}
// Peer must have a ControlCipher.
func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte {
tmp = pkt.Marshal(tmp)
h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(p.counter, 1),
SourceIP: p.localIP,
DestIP: p.IP,
}
return p.ControlCipher.Encrypt(h, tmp, out)
}
// Returns a controlMsg[PacketType]. Peer must have a non-nil ControlCipher.
//
// This function also drops packets with duplicate sequence numbers.
func (p remotePeer) DecryptControlPacket(fromAddr netip.AddrPort, h header, enc, tmp []byte) (any, error) {
out, ok := p.ControlCipher.Decrypt(enc, tmp)
if !ok {
return nil, errDecryptionFailed
}
if p.dupCheck.IsDup(h.Counter) {
return nil, errDuplicateSeqNum
}
msg, err := parseControlMsg(h.SourceIP, fromAddr, out)
if err != nil {
return nil, err
}
return msg, nil
}
// ----------------------------------------------------------------------------
type routingTable struct {
// The LocalIP is the configured IP address of the local peer on the VPN.
//
// This value is constant.
LocalIP byte
// The LocalAddr is the configured local public address of the peer on the
// internet. If LocalAddr.IsValid(), then the local peer has a public
// address.
//
// This value is constant.
LocalAddr netip.AddrPort
// The remote peer configurations. These are updated by
Peers [256]remotePeer
// The current relay's VPN IP address, or zero if no relay is available.
RelayIP byte
}
func newRoutingTable(localIP byte, localAddr netip.AddrPort) routingTable {
rt := routingTable{
LocalIP: localIP,
LocalAddr: localAddr,
}
for i := range rt.Peers {
counter := uint64(time.Now().Unix()<<30 + 1)
rt.Peers[i] = remotePeer{
localIP: localIP,
IP: byte(i),
counter: &counter,
dupCheck: newDupCheck(0),
}
}
return rt
}
func (rt *routingTable) GetRelay() (remotePeer, bool) {
relay := rt.Peers[rt.RelayIP]
return relay, relay.Up && relay.Direct
}

169
peer/routingtable_test.go Normal file
View File

@ -0,0 +1,169 @@
package peer
import (
"bytes"
"reflect"
"testing"
)
func TestRemotePeer_DecryptDataPacket(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
orig := RandPacket()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
enc := peer2.EncryptDataPacket(2, orig, newBuf())
h := parseHeader(enc)
if h.DestIP != 2 || h.SourceIP != 1 {
t.Fatal(h)
}
dec, err := peer1.DecryptDataPacket(h, enc, newBuf())
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(orig, dec) {
t.Fatal(dec)
}
}
func TestRemotePeer_DecryptDataPacket_packetAltered(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
orig := RandPacket()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
enc := peer2.EncryptDataPacket(2, orig, newBuf())
h := parseHeader(enc)
for range 2048 {
_, err := peer1.DecryptDataPacket(h, ModifyPacket(enc), newBuf())
if err == nil {
t.Fatal(enc)
}
}
}
func TestRemotePeer_DecryptDataPacket_duplicateSequenceNumber(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
orig := RandPacket()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
enc := peer2.EncryptDataPacket(2, orig, newBuf())
h := parseHeader(enc)
if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err != nil {
t.Fatal(err)
}
if _, err := peer1.DecryptDataPacket(h, enc, newBuf()); err == nil {
t.Fatal(err)
}
}
func TestRemotePeer_DecryptControlPacket(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
orig := packetProbe{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())
h := parseHeader(enc)
if h.DestIP != 2 || h.SourceIP != 1 {
t.Fatal(h)
}
ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf())
if err != nil {
t.Fatal(err)
}
dec, ok := ctrlMsg.(controlMsg[packetProbe])
if !ok {
t.Fatal(ctrlMsg)
}
if dec.SrcIP != 1 || dec.SrcAddr != p1.RT.Load().LocalAddr {
t.Fatal(dec)
}
if !reflect.DeepEqual(dec.Packet, orig) {
t.Fatal(dec)
}
}
func TestRemotePeer_DecryptControlPacket_packetAltered(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
orig := packetProbe{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())
h := parseHeader(enc)
if h.DestIP != 2 || h.SourceIP != 1 {
t.Fatal(h)
}
for range 2048 {
ctrlMsg, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, ModifyPacket(enc), newBuf())
if err == nil {
t.Fatal(ctrlMsg)
}
}
}
func TestRemotePeer_DecryptControlPacket_duplicateSequenceNumber(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
orig := packetProbe{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())
h := parseHeader(enc)
if h.DestIP != 2 || h.SourceIP != 1 {
t.Fatal(h)
}
if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err != nil {
t.Fatal(err)
}
if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil {
t.Fatal(err)
}
}
func TestRemotePeer_DecryptControlPacket_unknownPacketType(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
peer2 := p1.RT.Load().Peers[2]
peer1 := p2.RT.Load().Peers[1]
orig := UnknownControlPacket{TraceID: newTraceID()}
enc := peer2.EncryptControlPacket(orig, newBuf(), newBuf())
h := parseHeader(enc)
if h.DestIP != 2 || h.SourceIP != 1 {
t.Fatal(h)
}
if _, err := peer1.DecryptControlPacket(p1.RT.Load().LocalAddr, h, enc, newBuf()); err == nil {
t.Fatal(err)
}
}

162
peer/state-client.go Normal file
View File

@ -0,0 +1,162 @@
package peer
import (
"net/netip"
"time"
)
type sentProbe struct {
SentAt time.Time
Addr netip.AddrPort
}
type stateClient struct {
*peerData
lastSeen time.Time
syn packetSyn
probes map[uint64]sentProbe
}
func enterStateClient(data *peerData) peerState {
ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP)
data.staged.Relay = data.peer.Relay && ipValid
data.staged.Direct = ipValid
data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port)
data.publish(data.staged)
state := &stateClient{
peerData: data,
lastSeen: time.Now(),
syn: packetSyn{
TraceID: newTraceID(),
SharedKey: data.staged.DataCipher.Key(),
Direct: data.staged.Direct,
PossibleAddrs: data.pubAddrs.Get(),
},
probes: map[uint64]sentProbe{},
}
state.Send(state.staged, state.syn)
data.pingTimer.Reset(pingInterval)
state.logf("==> Client")
return state
}
func (s *stateClient) logf(str string, args ...any) {
s.peerData.logf("CLNT | "+str, args...)
}
func (s *stateClient) OnMsg(raw any) peerState {
switch msg := raw.(type) {
case peerUpdateMsg:
return initPeerState(s.peerData, msg.Peer)
case controlMsg[packetAck]:
s.onAck(msg)
case controlMsg[packetProbe]:
return s.onProbe(msg)
case controlMsg[packetLocalDiscovery]:
s.onLocalDiscovery(msg)
case pingTimerMsg:
return s.onPingTimer()
default:
s.logf("Ignoring message: %v", raw)
}
return s
}
func (s *stateClient) onAck(msg controlMsg[packetAck]) {
if msg.Packet.TraceID != s.syn.TraceID {
return
}
s.lastSeen = time.Now()
if !s.staged.Up {
s.staged.Up = true
s.publish(s.staged)
s.logf("Got ACK.")
}
if s.staged.Direct {
s.pubAddrs.Store(msg.Packet.ToAddr)
return
}
// Relayed below.
s.cleanProbes()
for _, addr := range msg.Packet.PossibleAddrs {
if !addr.IsValid() {
break
}
s.sendProbeTo(addr)
}
}
func (s *stateClient) onPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval {
if s.staged.Up {
s.logf("Timeout.")
}
return initPeerState(s.peerData, s.peer)
}
s.Send(s.staged, s.syn)
return s
}
func (s *stateClient) onProbe(msg controlMsg[packetProbe]) peerState {
if s.staged.Direct {
return s
}
s.cleanProbes()
sent, ok := s.probes[msg.Packet.TraceID]
if !ok {
return s
}
s.staged.Direct = true
s.staged.DirectAddr = sent.Addr
s.publish(s.staged)
s.syn.TraceID = newTraceID()
s.syn.Direct = true
s.Send(s.staged, s.syn)
s.logf("Successful probe to %v.", sent.Addr)
return s
}
func (s *stateClient) onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
if s.staged.Direct {
return
}
// The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port.
addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
s.sendProbeTo(addr)
}
func (s *stateClient) cleanProbes() {
for key, sent := range s.probes {
if time.Since(sent.SentAt) > pingInterval {
delete(s.probes, key)
}
}
}
func (s *stateClient) sendProbeTo(addr netip.AddrPort) {
probe := packetProbe{TraceID: newTraceID()}
s.probes[probe.TraceID] = sentProbe{
SentAt: time.Now(),
Addr: addr,
}
s.logf("Probing %v...", addr)
s.SendTo(probe, addr)
}

193
peer/state-client_test.go Normal file
View File

@ -0,0 +1,193 @@
package peer
import (
"testing"
"time"
)
func TestStateClient_peerUpdate(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
h.PeerUpdate(nil)
assertType[*stateDisconnected](t, h.State)
}
func TestStateClient_initialPackets(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
assertEqual(t, len(h.Sent), 2)
assertType[packetInit](t, h.Sent[0].Packet)
assertType[packetSyn](t, h.Sent[1].Packet)
}
func TestStateClient_onAck_incorrectTraceID(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
h.Sent = h.Sent[:0]
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: newTraceID()},
}
h.OnAck(ack)
// Nothing should have happened.
assertType[*stateClient](t, h.State)
assertEqual(t, len(h.Sent), 0)
}
func TestStateClient_onAck_direct_downToUp(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
assertEqual(t, len(h.Sent), 2)
syn := assertType[packetSyn](t, h.Sent[1].Packet)
h.Sent = h.Sent[:0]
assertEqual(t, h.Published.Up, false)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID},
}
h.OnAck(ack)
assertEqual(t, len(h.Sent), 0)
}
func TestStateClient_onAck_relayed_sendsProbes(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
assertEqual(t, len(h.Sent), 2)
syn := assertType[packetSyn](t, h.Sent[1].Packet)
h.Sent = h.Sent[:0]
assertEqual(t, h.Published.Up, false)
ack := controlMsg[packetAck]{
Packet: packetAck{TraceID: syn.TraceID},
}
ack.Packet.PossibleAddrs[0] = addrPort4(1, 2, 3, 4, 100)
ack.Packet.PossibleAddrs[1] = addrPort4(2, 3, 4, 5, 200)
h.OnAck(ack)
assertEqual(t, len(h.Sent), 2)
assertType[packetProbe](t, h.Sent[0].Packet)
assertEqual(t, h.Sent[0].Peer.DirectAddr, ack.Packet.PossibleAddrs[0])
assertType[packetProbe](t, h.Sent[1].Packet)
assertEqual(t, h.Sent[1].Peer.DirectAddr, ack.Packet.PossibleAddrs[1])
}
func TestStateClient_onPing(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
h.Sent = h.Sent[:0]
h.OnPingTimer()
assertEqual(t, len(h.Sent), 1)
assertType[*stateClient](t, h.State)
assertType[packetSyn](t, h.Sent[0].Packet)
}
func TestStateClient_onPing_timeout(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
h.Sent = h.Sent[:0]
state := assertType[*stateClient](t, h.State)
state.lastSeen = time.Now().Add(-2 * timeoutInterval)
state.staged.Up = true
h.OnPingTimer()
newState := assertType[*stateClientInit](t, h.State)
assertEqual(t, newState.staged.Up, false)
assertEqual(t, len(h.Sent), 1)
assertType[packetInit](t, h.Sent[0].Packet)
}
func TestStateClient_onProbe_direct(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
h.Sent = h.Sent[:0]
probe := controlMsg[packetProbe]{
Packet: packetProbe{
TraceID: newTraceID(),
},
}
h.OnProbe(probe)
assertType[*stateClient](t, h.State)
assertEqual(t, len(h.Sent), 0)
}
func TestStateClient_onProbe_noMatch(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
h.Sent = h.Sent[:0]
probe := controlMsg[packetProbe]{
Packet: packetProbe{
TraceID: newTraceID(),
},
}
h.OnProbe(probe)
assertType[*stateClient](t, h.State)
assertEqual(t, len(h.Sent), 0)
}
func TestStateClient_onProbe_directUpgrade(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
state := assertType[*stateClient](t, h.State)
traceID := newTraceID()
state.probes[traceID] = sentProbe{
SentAt: time.Now(),
Addr: addrPort4(1, 2, 3, 4, 500),
}
probe := controlMsg[packetProbe]{
Packet: packetProbe{TraceID: traceID},
}
assertEqual(t, h.Published.Direct, false)
h.Sent = h.Sent[:0]
h.OnProbe(probe)
assertEqual(t, h.Published.Direct, true)
assertEqual(t, len(h.Sent), 1)
assertType[packetSyn](t, h.Sent[0].Packet)
}
func TestStateClient_onLocalDiscovery_direct(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientDirect(t)
h.Sent = h.Sent[:0]
pkt := controlMsg[packetLocalDiscovery]{
Packet: packetLocalDiscovery{},
}
h.OnLocalDiscovery(pkt)
assertType[*stateClient](t, h.State)
assertEqual(t, len(h.Sent), 0)
}
func TestStateClient_onLocalDiscovery_relayed(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientRelayed(t)
h.Sent = h.Sent[:0]
pkt := controlMsg[packetLocalDiscovery]{
SrcAddr: addrPort4(1, 2, 3, 4, 500),
Packet: packetLocalDiscovery{},
}
h.OnLocalDiscovery(pkt)
assertType[*stateClient](t, h.State)
assertEqual(t, len(h.Sent), 1)
assertType[packetProbe](t, h.Sent[0].Packet)
assertEqual(t, h.Sent[0].Peer.DirectAddr, addrPort4(1, 2, 3, 4, 456))
}

104
peer/state-clientinit.go Normal file
View File

@ -0,0 +1,104 @@
package peer
import (
"net/netip"
"time"
)
type stateClientInit struct {
*peerData
startedAt time.Time
traceID uint64
}
func enterStateClientInit(data *peerData) peerState {
ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP)
data.staged.Up = false
data.staged.Relay = false
data.staged.Direct = ipValid
data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port)
data.staged.PubSignKey = data.peer.PubSignKey
data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey)
data.staged.DataCipher = newDataCipher()
data.publish(data.staged)
state := &stateClientInit{
peerData: data,
startedAt: time.Now(),
traceID: newTraceID(),
}
state.sendInit()
data.pingTimer.Reset(pingInterval)
state.logf("==> ClientInit")
return state
}
func (s *stateClientInit) logf(str string, args ...any) {
s.peerData.logf("INIT | "+str, args...)
}
func (s *stateClientInit) OnMsg(raw any) peerState {
switch msg := raw.(type) {
case peerUpdateMsg:
return initPeerState(s.peerData, msg.Peer)
case controlMsg[packetInit]:
return s.onInit(msg)
case controlMsg[packetSyn]:
s.logf("Unexpected SYN")
return s
case controlMsg[packetAck]:
s.logf("Unexpected ACK")
return s
case controlMsg[packetProbe]:
return s
case controlMsg[packetLocalDiscovery]:
return s
case pingTimerMsg:
return s.onPing()
default:
s.logf("Ignoring message: %#v", raw)
return s
}
}
func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState {
if msg.Packet.TraceID != s.traceID {
s.logf("Invalid trace ID on INIT.")
return s
}
s.logf("Got INIT version %d.", msg.Packet.Version)
return enterStateClient(s.peerData)
}
func (s *stateClientInit) onPing() peerState {
if time.Since(s.startedAt) < timeoutInterval {
s.sendInit()
return s
}
if s.staged.Direct {
s.staged.Direct = false
s.publish(s.staged)
s.startedAt = time.Now()
s.sendInit()
s.logf("Direct connection failed. Attempting indirect connection.")
return s
}
s.logf("Timeout.")
return initPeerState(s.peerData, s.peer)
}
func (s *stateClientInit) sendInit() {
s.traceID = newTraceID()
init := packetInit{
TraceID: s.traceID,
Direct: s.staged.Direct,
Version: version,
}
s.Send(s.staged, init)
}

View File

@ -0,0 +1,92 @@
package peer
import (
"testing"
"time"
)
func TestPeerState_ClientInit_initWithIncorrectTraceID(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientInit(t)
// Should have sent the first init packet.
assertEqual(t, len(h.Sent), 1)
init := assertType[packetInit](t, h.Sent[0].Packet)
init.TraceID = newTraceID()
h.OnInit(controlMsg[packetInit]{Packet: init})
assertType[*stateClientInit](t, h.State)
}
func TestPeerState_ClientInit_init(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientInit(t)
// Should have sent the first init packet.
assertEqual(t, len(h.Sent), 1)
init := assertType[packetInit](t, h.Sent[0].Packet)
h.OnInit(controlMsg[packetInit]{Packet: init})
assertType[*stateClient](t, h.State)
}
func TestPeerState_ClientInit_onPing(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientInit(t)
// Should have sent the first init packet.
assertEqual(t, len(h.Sent), 1)
h.Sent = h.Sent[:0]
for range 3 {
h.OnPingTimer()
}
assertEqual(t, len(h.Sent), 3)
for i := range h.Sent {
assertType[packetInit](t, h.Sent[i].Packet)
}
}
func TestPeerState_ClientInit_onPingTimeout(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientInit(t)
state := assertType[*stateClientInit](t, h.State)
state.startedAt = time.Now().Add(-2 * timeoutInterval)
assertEqual(t, state.staged.Direct, true)
h.OnPingTimer()
// Should now try indirect connection.
state = assertType[*stateClientInit](t, h.State)
assertEqual(t, state.staged.Direct, false)
// Should re-initialize the peer after another timeout, so should be direct
// again.
state.startedAt = time.Now().Add(-2 * timeoutInterval)
h.OnPingTimer()
assertEqual(t, state.staged.Direct, true)
}
func TestPeerState_ClientInit_onPeerUpdate(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientInit(t)
h.PeerUpdate(nil)
// Should have moved into the client state due to timeout.
assertType[*stateDisconnected](t, h.State)
}
func TestPeerState_ClientInit_ignoreMessage(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigClientInit(t)
h.OnProbe(controlMsg[packetProbe]{})
// Shouldn't do anything.
assertType[*stateClientInit](t, h.State)
}

View File

@ -0,0 +1,50 @@
package peer
import "net/netip"
type stateDisconnected struct {
*peerData
}
func enterStateDisconnected(data *peerData) peerState {
data.staged.Up = false
data.staged.Relay = false
data.staged.Direct = false
data.staged.DirectAddr = netip.AddrPort{}
data.staged.PubSignKey = nil
data.staged.ControlCipher = nil
data.staged.DataCipher = nil
data.publish(data.staged)
data.pingTimer.Stop()
return &stateDisconnected{data}
}
func (s *stateDisconnected) OnMsg(raw any) peerState {
switch msg := raw.(type) {
case peerUpdateMsg:
return initPeerState(s.peerData, msg.Peer)
case controlMsg[packetInit]:
s.logf("Unexpected INIT")
return s
case controlMsg[packetSyn]:
s.logf("Unexpected SYN")
return s
case controlMsg[packetAck]:
s.logf("Unexpected ACK")
return s
case controlMsg[packetProbe]:
s.logf("Unexpected probe")
return s
case controlMsg[packetLocalDiscovery]:
return s
case pingTimerMsg:
s.logf("Unexpected ping")
return s
default:
s.logf("Ignoring message: %#v", raw)
return s
}
}

136
peer/state-server.go Normal file
View File

@ -0,0 +1,136 @@
package peer
import (
"net/netip"
"time"
)
type stateServer struct {
*peerData
lastSeen time.Time
synTraceID uint64 // Last syn trace ID.
}
func enterStateServer(data *peerData) peerState {
data.staged.Up = false
data.staged.Relay = false
data.staged.Direct = false
data.staged.DirectAddr = netip.AddrPort{}
data.staged.PubSignKey = data.peer.PubSignKey
data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey)
data.staged.DataCipher = nil
data.publish(data.staged)
data.pingTimer.Reset(pingInterval)
state := &stateServer{
peerData: data,
lastSeen: time.Now(),
}
state.logf("==> Server")
return state
}
func (s *stateServer) logf(str string, args ...any) {
s.peerData.logf("SRVR | "+str, args...)
}
func (s *stateServer) OnMsg(raw any) peerState {
switch msg := raw.(type) {
case peerUpdateMsg:
return initPeerState(s.peerData, msg.Peer)
case controlMsg[packetInit]:
return s.onInit(msg)
case controlMsg[packetSyn]:
return s.onSyn(msg)
case controlMsg[packetAck]:
s.logf("Unexpected ACK")
return s
case controlMsg[packetProbe]:
return s.onProbe(msg)
case controlMsg[packetLocalDiscovery]:
return s
case pingTimerMsg:
return s.onPingTimer()
default:
s.logf("Unexpected message: %#v", raw)
return s
}
}
func (s *stateServer) onInit(msg controlMsg[packetInit]) peerState {
s.staged.Up = false
s.staged.Direct = msg.Packet.Direct
s.staged.DirectAddr = msg.SrcAddr
s.publish(s.staged)
init := packetInit{
TraceID: msg.Packet.TraceID,
Direct: s.staged.Direct,
Version: version,
}
s.Send(s.staged, init)
return s
}
func (s *stateServer) onSyn(msg controlMsg[packetSyn]) peerState {
s.lastSeen = time.Now()
p := msg.Packet
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
//
// The client will update the syn's TraceID whenever there's a change.
// The server will follow the client's request.
if p.TraceID != s.synTraceID || !s.staged.Up {
s.synTraceID = p.TraceID
s.staged.Up = true
s.staged.Direct = p.Direct
s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
s.staged.DirectAddr = msg.SrcAddr
s.publish(s.staged)
s.logf("Got SYN.")
}
// Always respond.
s.Send(s.staged, packetAck{
TraceID: p.TraceID,
ToAddr: s.staged.DirectAddr,
PossibleAddrs: s.pubAddrs.Get(),
})
if p.Direct {
return s
}
// Send probes if not a direct connection.
for _, addr := range msg.Packet.PossibleAddrs {
if !addr.IsValid() {
break
}
s.logf("Probing %v...", addr)
s.SendTo(packetProbe{TraceID: newTraceID()}, addr)
}
return s
}
func (s *stateServer) onProbe(msg controlMsg[packetProbe]) peerState {
if msg.SrcAddr.IsValid() {
s.logf("Probe response %v...", msg.SrcAddr)
s.SendTo(packetProbe{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
}
return s
}
func (s *stateServer) onPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
s.staged.Up = false
s.publish(s.staged)
s.logf("Timeout.")
}
return s
}

164
peer/state-server_test.go Normal file
View File

@ -0,0 +1,164 @@
package peer
import (
"testing"
"time"
)
func TestStateServer_peerUpdate(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Public(t)
h.PeerUpdate(nil)
assertType[*stateDisconnected](t, h.State)
}
func TestStateServer_onInit(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Public(t)
msg := controlMsg[packetInit]{
SrcIP: 3,
SrcAddr: addrPort4(1, 2, 3, 4, 1000),
Packet: packetInit{
TraceID: newTraceID(),
Direct: true,
Version: 4,
},
}
h.OnInit(msg)
assertEqual(t, len(h.Sent), 1)
assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr)
resp := assertType[packetInit](t, h.Sent[0].Packet)
assertEqual(t, msg.Packet.TraceID, resp.TraceID)
assertEqual(t, resp.Version, version)
}
func TestStateServer_onSynDirect(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Public(t)
msg := controlMsg[packetSyn]{
SrcIP: 3,
SrcAddr: addrPort4(1, 2, 3, 4, 1000),
Packet: packetSyn{
TraceID: newTraceID(),
Direct: true,
},
}
msg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 1000)
msg.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 2000)
h.OnSyn(msg)
assertEqual(t, len(h.Sent), 1)
assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr)
resp := assertType[packetAck](t, h.Sent[0].Packet)
assertEqual(t, msg.Packet.TraceID, resp.TraceID)
}
func TestStateServer_onSynRelayed(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
msg := controlMsg[packetSyn]{
SrcIP: 3,
SrcAddr: addrPort4(1, 2, 3, 4, 1000),
Packet: packetSyn{
TraceID: newTraceID(),
},
}
msg.Packet.PossibleAddrs[0] = addrPort4(1, 1, 1, 1, 1000)
msg.Packet.PossibleAddrs[1] = addrPort4(1, 1, 1, 2, 2000)
h.OnSyn(msg)
assertEqual(t, len(h.Sent), 3)
assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr)
resp := assertType[packetAck](t, h.Sent[0].Packet)
assertEqual(t, msg.Packet.TraceID, resp.TraceID)
for i, pkt := range h.Sent[1:] {
assertEqual(t, pkt.Peer.DirectAddr, msg.Packet.PossibleAddrs[i])
assertType[packetProbe](t, pkt.Packet)
}
}
func TestStateServer_onProbe(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
msg := controlMsg[packetProbe]{
SrcIP: 3,
Packet: packetProbe{
TraceID: newTraceID(),
},
}
h.Sent = h.Sent[:0]
h.OnProbe(msg)
assertEqual(t, len(h.Sent), 0)
}
func TestStateServer_onProbe_valid(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
msg := controlMsg[packetProbe]{
SrcIP: 3,
SrcAddr: addrPort4(1, 2, 3, 4, 100),
Packet: packetProbe{
TraceID: newTraceID(),
},
}
h.Sent = h.Sent[:0]
h.OnProbe(msg)
assertEqual(t, len(h.Sent), 1)
assertType[packetProbe](t, h.Sent[0].Packet)
assertEqual(t, h.Sent[0].Peer.DirectAddr, msg.SrcAddr)
}
func TestStateServer_onPing(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
h.Sent = h.Sent[:0]
h.OnPingTimer()
assertEqual(t, len(h.Sent), 0)
assertType[*stateServer](t, h.State)
}
func TestStateServer_onPing_timeout(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
h.Sent = h.Sent[:0]
state := assertType[*stateServer](t, h.State)
state.staged.Up = true
state.lastSeen = time.Now().Add(-2 * timeoutInterval)
h.OnPingTimer()
state = assertType[*stateServer](t, h.State)
assertEqual(t, len(h.Sent), 0)
assertEqual(t, state.staged.Up, false)
}
func TestStateServer_onLocalDiscovery(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
msg := controlMsg[packetLocalDiscovery]{
SrcIP: 3,
SrcAddr: addrPort4(1, 2, 3, 4, 100),
}
h.OnLocalDiscovery(msg)
assertType[*stateServer](t, h.State)
}
func TestStateServer_onAck(t *testing.T) {
h := NewPeerStateTestHarness()
h.ConfigServer_Relayed(t)
msg := controlMsg[packetAck]{}
h.OnAck(msg)
assertType[*stateServer](t, h.State)
}

151
peer/state-util_test.go Normal file
View File

@ -0,0 +1,151 @@
package peer
import (
"net/netip"
"testing"
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
)
type PeerStateControlMsg struct {
Peer remotePeer
Packet any
}
type PeerStateTestHarness struct {
data *peerData
State peerState
Published remotePeer
Sent []PeerStateControlMsg
}
func NewPeerStateTestHarness() *PeerStateTestHarness {
h := &PeerStateTestHarness{}
keys := generateKeys()
state := &peerData{
publish: func(rp remotePeer) {
h.Published = rp
},
sendControlPacket: func(rp remotePeer, pkt marshaller) {
h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt})
},
pingTimer: time.NewTicker(pingInterval),
localIP: 2,
remoteIP: 3,
privKey: keys.PrivKey,
pubAddrs: newPubAddrStore(netip.AddrPort{}),
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 20 * time.Millisecond,
MaxWaitCount: 1,
}),
}
h.data = state
h.State = enterStateDisconnected(state)
return h
}
func (h *PeerStateTestHarness) PeerUpdate(p *m.Peer) {
h.State = h.State.OnMsg(peerUpdateMsg{p})
}
func (h *PeerStateTestHarness) OnInit(msg controlMsg[packetInit]) {
h.State = h.State.OnMsg(msg)
}
func (h *PeerStateTestHarness) OnSyn(msg controlMsg[packetSyn]) {
h.State = h.State.OnMsg(msg)
}
func (h *PeerStateTestHarness) OnAck(msg controlMsg[packetAck]) {
h.State = h.State.OnMsg(msg)
}
func (h *PeerStateTestHarness) OnProbe(msg controlMsg[packetProbe]) {
h.State = h.State.OnMsg(msg)
}
func (h *PeerStateTestHarness) OnLocalDiscovery(msg controlMsg[packetLocalDiscovery]) {
h.State = h.State.OnMsg(msg)
}
func (h *PeerStateTestHarness) OnPingTimer() {
h.State = h.State.OnMsg(pingTimerMsg{})
}
func (h *PeerStateTestHarness) ConfigServer_Public(t *testing.T) *stateServer {
keys := generateKeys()
state := h.State.(*stateDisconnected)
state.localAddr = addrPort4(1, 1, 1, 2, 200)
peer := &m.Peer{
PeerIP: 3,
PublicIP: []byte{1, 1, 1, 3},
Port: 456,
PubKey: keys.PubKey,
PubSignKey: keys.PubSignKey,
}
h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false)
return assertType[*stateServer](t, h.State)
}
func (h *PeerStateTestHarness) ConfigServer_Relayed(t *testing.T) *stateServer {
keys := generateKeys()
peer := &m.Peer{
PeerIP: 3,
Port: 456,
PubKey: keys.PubKey,
PubSignKey: keys.PubSignKey,
}
h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false)
return assertType[*stateServer](t, h.State)
}
func (h *PeerStateTestHarness) ConfigClientInit(t *testing.T) *stateClientInit {
// Remote IP should be less than local IP.
h.data.localIP = 4
keys := generateKeys()
peer := &m.Peer{
PeerIP: 3,
PublicIP: []byte{1, 2, 3, 4},
Port: 456,
PubKey: keys.PubKey,
PubSignKey: keys.PubSignKey,
}
h.PeerUpdate(peer)
assertEqual(t, h.Published.Up, false)
return assertType[*stateClientInit](t, h.State)
}
func (h *PeerStateTestHarness) ConfigClientDirect(t *testing.T) *stateClient {
h.ConfigClientInit(t)
init := assertType[packetInit](t, h.Sent[0].Packet)
h.OnInit(controlMsg[packetInit]{
Packet: init,
})
return assertType[*stateClient](t, h.State)
}
func (h *PeerStateTestHarness) ConfigClientRelayed(t *testing.T) *stateClient {
h.ConfigClientInit(t)
state := assertType[*stateClientInit](t, h.State)
state.peer.PublicIP = nil // Force relay.
init := assertType[packetInit](t, h.Sent[0].Packet)
h.OnInit(controlMsg[packetInit]{
Packet: init,
})
return assertType[*stateClient](t, h.State)
}

109
peer/statedata.go Normal file
View File

@ -0,0 +1,109 @@
package peer
import (
"fmt"
"log"
"net/netip"
"strings"
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
)
type peerState interface {
OnMsg(raw any) peerState
}
// ----------------------------------------------------------------------------
type peerData struct {
// Output.
publish func(remotePeer)
sendControlPacket func(remotePeer, marshaller)
pingTimer *time.Ticker
// Immutable data.
localIP byte
remoteIP byte
privKey []byte
localAddr netip.AddrPort // If valid, then local peer is publicly accessible.
pubAddrs *pubAddrStore
// The purpose of this state machine is to manage the RemotePeer object,
// publishing it as necessary.
staged remotePeer // Local copy of shared data. See publish().
// Mutable peer data.
peer *m.Peer
// We rate limit per remote endpoint because if we don't we tend to lose
// packets.
limiter *ratelimiter.Limiter
}
func (s *peerData) logf(format string, args ...any) {
b := strings.Builder{}
name := ""
if s.peer != nil {
name = s.peer.Name
}
b.WriteString(fmt.Sprintf("%03d", s.remoteIP))
b.WriteString(fmt.Sprintf("%30s: ", name))
if s.staged.Direct {
b.WriteString("DIRECT | ")
} else {
b.WriteString("RELAYED | ")
}
if s.staged.Up {
b.WriteString("UP | ")
} else {
b.WriteString("DOWN | ")
}
log.Printf(b.String()+format, args...)
}
// ----------------------------------------------------------------------------
func (s *peerData) SendTo(pkt marshaller, addr netip.AddrPort) {
if !addr.IsValid() {
return
}
route := s.staged
route.Direct = true
route.DirectAddr = addr
s.Send(route, pkt)
}
func (s *peerData) Send(peer remotePeer, pkt marshaller) {
if err := s.limiter.Limit(); err != nil {
s.logf("Rate limited.")
return
}
s.sendControlPacket(peer, pkt)
}
func initPeerState(data *peerData, peer *m.Peer) peerState {
data.peer = peer
if peer == nil {
return enterStateDisconnected(data)
}
if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
if data.localAddr.IsValid() && data.localIP < data.remoteIP {
return enterStateServer(data)
}
return enterStateClientInit(data)
}
if data.localAddr.IsValid() || data.localIP < data.remoteIP {
return enterStateServer(data)
}
return enterStateClientInit(data)
}