From c7bc1ecc458da56ce33aaf1a3d4b382d8dfd7901 Mon Sep 17 00:00:00 2001
From: jdl
Date: Thu, 19 Dec 2024 06:45:17 +0100
Subject: [PATCH 01/18] WIP: separate data and control ciphers.
---
node/conn.go | 6 +-
node/crypto.go | 4 +-
node/crypto_test.go | 19 +++---
node/datacipher.go | 97 ++++++++++++++++++++++++++++
node/datacipher_test.go | 138 ++++++++++++++++++++++++++++++++++++++++
node/dupcheck.go | 4 --
node/dupcheck_test.go | 3 -
node/routingpacket.go | 4 +-
8 files changed, 249 insertions(+), 26 deletions(-)
create mode 100644 node/datacipher.go
create mode 100644 node/datacipher_test.go
diff --git a/node/conn.go b/node/conn.go
index 9224d57..0823e31 100644
--- a/node/conn.go
+++ b/node/conn.go
@@ -74,7 +74,7 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt
Stream: stream,
}
- buf := encryptPacket(&h, dstPeer.SharedKey, data, w.buf)
+ buf := encryptPacketAsym(&h, dstPeer.SharedKey, data, w.buf)
if viaPeer != nil {
h := header{
@@ -85,7 +85,7 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt
Stream: stream,
}
- buf = encryptPacket(&h, viaPeer.SharedKey, buf, w.buf2)
+ buf = encryptPacketAsym(&h, viaPeer.SharedKey, buf, w.buf2)
addr = viaPeer.Addr
}
@@ -155,7 +155,7 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data
continue
}
- out, ok := decryptPacket(peer.SharedKey, data, r.buf)
+ out, ok := decryptPacketAsym(peer.SharedKey, data, r.buf)
if !ok {
continue
}
diff --git a/node/crypto.go b/node/crypto.go
index cc5904f..0f7710f 100644
--- a/node/crypto.go
+++ b/node/crypto.go
@@ -8,14 +8,14 @@ import (
)
// Encrypting the packet will also set the header's DataSize field.
-func encryptPacket(h *header, sharedKey, data, out []byte) []byte {
+func encryptPacketAsym(h *header, sharedKey, data, out []byte) []byte {
out = out[:headerSize]
h.Marshal(out)
b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey))
return out[:len(b)+headerSize]
}
-func decryptPacket(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) {
+func decryptPacketAsym(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) {
return box.OpenAfterPrecomputation(
out[:0],
packetAndHeader[headerSize:],
diff --git a/node/crypto_test.go b/node/crypto_test.go
index 0a651b0..76f408f 100644
--- a/node/crypto_test.go
+++ b/node/crypto_test.go
@@ -3,14 +3,13 @@ package node
import (
"bytes"
"crypto/rand"
- "log"
"reflect"
"testing"
"golang.org/x/crypto/nacl/box"
)
-func TestEncryptDecryptPacket(t *testing.T) {
+func TestEncryptDecryptAsym(t *testing.T) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
@@ -21,8 +20,6 @@ func TestEncryptDecryptPacket(t *testing.T) {
t.Fatal(err)
}
- log.Printf("\n%#v\n%#v\n%#v\n%#v\n", pubKey1, privKey1, pubKey2, privKey2)
-
sharedEncKey := [32]byte{}
box.Precompute(&sharedEncKey, pubKey2, privKey1)
@@ -41,11 +38,11 @@ func TestEncryptDecryptPacket(t *testing.T) {
}
encrypted := make([]byte, bufferSize)
- encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted)
+ encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted)
decrypted := make([]byte, bufferSize)
var ok bool
- decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted)
+ decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted)
if !ok {
t.Fatal(ok)
}
@@ -62,7 +59,7 @@ func TestEncryptDecryptPacket(t *testing.T) {
}
}
-func BenchmarkEncryptPacket(b *testing.B) {
+func BenchmarkEncryptAsym(b *testing.B) {
_, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
b.Fatal(err)
@@ -93,11 +90,11 @@ func BenchmarkEncryptPacket(b *testing.B) {
}
for i := 0; i < b.N; i++ {
- encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted)
+ encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted)
}
}
-func BenchmarkDecryptPacket(b *testing.B) {
+func BenchmarkDecryptAsym(b *testing.B) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
b.Fatal(err)
@@ -128,11 +125,11 @@ func BenchmarkDecryptPacket(b *testing.B) {
Stream: 1,
}
- encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize))
+ encrypted := encryptPacketAsym(&h, sharedEncKey[:], original, make([]byte, bufferSize))
decrypted := make([]byte, bufferSize)
var ok bool
for i := 0; i < b.N; i++ {
- decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted)
+ decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted)
if !ok {
panic(ok)
}
diff --git a/node/datacipher.go b/node/datacipher.go
new file mode 100644
index 0000000..b631d14
--- /dev/null
+++ b/node/datacipher.go
@@ -0,0 +1,97 @@
+package node
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "unsafe"
+)
+
+// ----------------------------------------------------------------------------
+
+const (
+ dataStreamID = 1
+ dataHeaderSize = 12
+ dataCipherOverhead = 16 + 1
+)
+
+type dataHeader struct {
+ Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
+ SourceIP byte
+ DestIP byte
+}
+
+func (h *dataHeader) Parse(b []byte) {
+ h.Counter = *(*uint64)(unsafe.Pointer(&b[0]))
+ h.SourceIP = b[8]
+ h.DestIP = b[9]
+}
+
+func (h *dataHeader) Marshal(buf []byte) {
+ *(*uint64)(unsafe.Pointer(&buf[0])) = h.Counter
+ buf[8] = h.SourceIP
+ buf[9] = h.DestIP
+ buf[10] = 0
+ buf[11] = 0
+}
+
+// ----------------------------------------------------------------------------
+
+type dataCipher struct {
+ key []byte
+ aead cipher.AEAD
+}
+
+func newDataCipher() *dataCipher {
+ key := make([]byte, 32)
+ if _, err := rand.Read(key); err != nil {
+ panic(err)
+ }
+ return newDataCipherFromKey(key)
+}
+
+// key must be 32 bytes.
+func newDataCipherFromKey(key []byte) *dataCipher {
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ panic(err)
+ }
+
+ aead, err := cipher.NewGCM(block)
+ if err != nil {
+ panic(err)
+ }
+
+ return &dataCipher{key: key, aead: aead}
+}
+
+func (sc *dataCipher) Key() []byte {
+ return sc.key
+}
+
+func (sc *dataCipher) Encrypt(h *dataHeader, data, out []byte) []byte {
+ out = out[:dataHeaderSize+dataCipherOverhead+len(data)]
+ out[0] = dataStreamID
+
+ h.Marshal(out[1:])
+
+ const s = dataHeaderSize
+ sc.aead.Seal(out[1+s:1+s], out[1:1+s], data, nil)
+ return out
+}
+
+func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, h dataHeader, ok bool) {
+ const s = dataHeaderSize
+ if len(encrypted) < s+dataCipherOverhead {
+ ok = false
+ return
+ }
+
+ h.Parse(encrypted[1 : 1+s])
+
+ var err error
+
+ data, err = sc.aead.Open(out[:0], encrypted[1:1+s], encrypted[1+s:], nil)
+ ok = err == nil
+ return
+}
diff --git a/node/datacipher_test.go b/node/datacipher_test.go
new file mode 100644
index 0000000..8a0f012
--- /dev/null
+++ b/node/datacipher_test.go
@@ -0,0 +1,138 @@
+package node
+
+import (
+ "bytes"
+ "crypto/rand"
+ mrand "math/rand/v2"
+ "reflect"
+ "testing"
+)
+
+func TestDataCipher(t *testing.T) {
+ maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
+ rand.Read(maxSizePlaintext)
+
+ testCases := [][]byte{
+ make([]byte, 0),
+ {1},
+ {255},
+ {1, 2, 3, 4, 5},
+ []byte("Hello world"),
+ maxSizePlaintext,
+ }
+
+ for _, plaintext := range testCases {
+ h1 := dataHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ encrypted := make([]byte, bufferSize)
+
+ dc1 := newDataCipher()
+ encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+
+ dc2 := newDataCipherFromKey(dc1.Key())
+
+ decrypted, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
+ if !ok {
+ t.Fatal(ok)
+ }
+
+ if !bytes.Equal(plaintext, decrypted) {
+ t.Fatal("not equal")
+ }
+
+ if !reflect.DeepEqual(h1, h2) {
+ t.Fatalf("%v != %v", h1, h2)
+ }
+ }
+}
+
+func TestDataCipher_ModifyCiphertext(t *testing.T) {
+ maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
+ rand.Read(maxSizePlaintext)
+
+ testCases := [][]byte{
+ make([]byte, 0),
+ {1},
+ {255},
+ {1, 2, 3, 4, 5},
+ []byte("Hello world"),
+ maxSizePlaintext,
+ }
+
+ for _, plaintext := range testCases {
+ h1 := dataHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ encrypted := make([]byte, bufferSize)
+
+ dc1 := newDataCipher()
+ encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+ encrypted[mrand.IntN(len(encrypted))]++
+
+ dc2 := newDataCipherFromKey(dc1.Key())
+
+ _, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
+ if ok {
+ t.Fatal(ok, h2)
+ }
+ }
+}
+
+func TestDataCipher_ShortCiphertext(t *testing.T) {
+ dc1 := newDataCipher()
+ shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
+ rand.Read(shortText)
+ _, _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
+ if ok {
+ t.Fatal(ok)
+ }
+}
+
+func BenchmarkDataCipher_Encrypt(b *testing.B) {
+ h1 := dataHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
+ rand.Read(plaintext)
+
+ encrypted := make([]byte, bufferSize)
+
+ dc1 := newDataCipher()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+ }
+}
+
+func BenchmarkDataCipher_Decrypt(b *testing.B) {
+ h1 := dataHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
+ rand.Read(plaintext)
+
+ encrypted := make([]byte, bufferSize)
+
+ dc1 := newDataCipher()
+ encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+
+ decrypted := make([]byte, bufferSize)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ decrypted, _, _ = dc1.Decrypt(encrypted, decrypted)
+ }
+}
diff --git a/node/dupcheck.go b/node/dupcheck.go
index e960bd4..fac7a72 100644
--- a/node/dupcheck.go
+++ b/node/dupcheck.go
@@ -1,7 +1,5 @@
package node
-import "log"
-
type dupCheck struct {
bitSet
head int
@@ -22,7 +20,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
// Before head => it's late, say it's a dup.
if counter < dc.headCounter {
- log.Printf("Late: %d", counter)
return true
}
@@ -30,7 +27,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
if counter < dc.tailCounter {
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
if dc.Get(index) {
- log.Printf("Dup: %d, %d", counter, dc.tailCounter)
return true
}
diff --git a/node/dupcheck_test.go b/node/dupcheck_test.go
index 9a939b5..2156b4e 100644
--- a/node/dupcheck_test.go
+++ b/node/dupcheck_test.go
@@ -1,7 +1,6 @@
package node
import (
- "log"
"testing"
)
@@ -49,8 +48,6 @@ func TestDupCheck(t *testing.T) {
for i, tc := range testCases {
if ok := dc.IsDup(tc.Counter); ok != tc.Dup {
- log.Printf("%b", dc.bitSet)
- log.Printf("%+v", *dc)
t.Fatal(i, ok, tc)
}
}
diff --git a/node/routingpacket.go b/node/routingpacket.go
index 4e35055..64f0374 100644
--- a/node/routingpacket.go
+++ b/node/routingpacket.go
@@ -8,10 +8,8 @@ import (
var errMalformedPacket = errors.New("malformed packet")
const (
- packetTypeInvalid = iota
-
// Used to maintain connection.
- packetTypePing
+ packetTypePing = iota + 1
packetTypePong
)
--
2.39.5
From 0ae0f31eae0e270413cd8256eb12cae515e2db8f Mon Sep 17 00:00:00 2001
From: jdl
Date: Thu, 19 Dec 2024 20:53:52 +0100
Subject: [PATCH 02/18] wip
---
node/cipher-data.go | 61 ++++++++++
...datacipher_test.go => cipher-data_test.go} | 28 +++--
node/cipher-routing.go | 26 ++++
node/cipher-routing_test.go | 114 ++++++++++++++++++
node/cipher.go | 6 +
node/conn.go | 39 +++++-
node/datacipher.go | 97 ---------------
node/header.go | 35 ++++++
node/packets.go | 89 ++++++++++++++
node/packets_test.go | 42 +++++++
node/peersupervisor.go | 14 +--
node/router.go | 18 ++-
node/routingpacket.go | 9 --
13 files changed, 443 insertions(+), 135 deletions(-)
create mode 100644 node/cipher-data.go
rename node/{datacipher_test.go => cipher-data_test.go} (78%)
create mode 100644 node/cipher-routing.go
create mode 100644 node/cipher-routing_test.go
create mode 100644 node/cipher.go
delete mode 100644 node/datacipher.go
create mode 100644 node/packets.go
create mode 100644 node/packets_test.go
diff --git a/node/cipher-data.go b/node/cipher-data.go
new file mode 100644
index 0000000..c0fc273
--- /dev/null
+++ b/node/cipher-data.go
@@ -0,0 +1,61 @@
+package node
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+)
+
+type dataCipher struct {
+ key []byte
+ aead cipher.AEAD
+}
+
+func newDataCipher() *dataCipher {
+ key := make([]byte, 32)
+ if _, err := rand.Read(key); err != nil {
+ panic(err)
+ }
+ return newDataCipherFromKey(key)
+}
+
+// key must be 32 bytes.
+func newDataCipherFromKey(key []byte) *dataCipher {
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ panic(err)
+ }
+
+ aead, err := cipher.NewGCM(block)
+ if err != nil {
+ panic(err)
+ }
+
+ return &dataCipher{key: key, aead: aead}
+}
+
+func (sc *dataCipher) Key() []byte {
+ return sc.key
+}
+
+func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte {
+ const s = dataHeaderSize
+ out = out[:s+dataCipherOverhead+len(data)]
+ h.Marshal(dataStreamID, out[:s])
+ sc.aead.Seal(out[s:s], out[:s], data, nil)
+ return out
+}
+
+func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
+ const s = dataHeaderSize
+ if len(encrypted) < s+dataCipherOverhead {
+ ok = false
+ return
+ }
+
+ var err error
+
+ data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
+ ok = err == nil
+ return
+}
diff --git a/node/datacipher_test.go b/node/cipher-data_test.go
similarity index 78%
rename from node/datacipher_test.go
rename to node/cipher-data_test.go
index 8a0f012..d1523d8 100644
--- a/node/datacipher_test.go
+++ b/node/cipher-data_test.go
@@ -22,7 +22,7 @@ func TestDataCipher(t *testing.T) {
}
for _, plaintext := range testCases {
- h1 := dataHeader{
+ h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -31,11 +31,13 @@ func TestDataCipher(t *testing.T) {
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
- encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+ encrypted = dc1.Encrypt(h1, plaintext, encrypted)
+ h2 := xHeader{}
+ h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key())
- decrypted, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
+ decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if !ok {
t.Fatal(ok)
}
@@ -64,7 +66,7 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) {
}
for _, plaintext := range testCases {
- h1 := dataHeader{
+ h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -73,14 +75,14 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) {
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
- encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+ encrypted = dc1.Encrypt(h1, plaintext, encrypted)
encrypted[mrand.IntN(len(encrypted))]++
dc2 := newDataCipherFromKey(dc1.Key())
- _, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
+ _, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if ok {
- t.Fatal(ok, h2)
+ t.Fatal(ok)
}
}
}
@@ -89,14 +91,14 @@ func TestDataCipher_ShortCiphertext(t *testing.T) {
dc1 := newDataCipher()
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
rand.Read(shortText)
- _, _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
+ _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkDataCipher_Encrypt(b *testing.B) {
- h1 := dataHeader{
+ h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -110,12 +112,12 @@ func BenchmarkDataCipher_Encrypt(b *testing.B) {
dc1 := newDataCipher()
b.ResetTimer()
for i := 0; i < b.N; i++ {
- encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+ encrypted = dc1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkDataCipher_Decrypt(b *testing.B) {
- h1 := dataHeader{
+ h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -127,12 +129,12 @@ func BenchmarkDataCipher_Decrypt(b *testing.B) {
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
- encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
+ encrypted = dc1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
- decrypted, _, _ = dc1.Decrypt(encrypted, decrypted)
+ decrypted, _ = dc1.Decrypt(encrypted, decrypted)
}
}
diff --git a/node/cipher-routing.go b/node/cipher-routing.go
new file mode 100644
index 0000000..795ac7a
--- /dev/null
+++ b/node/cipher-routing.go
@@ -0,0 +1,26 @@
+package node
+
+import "golang.org/x/crypto/nacl/box"
+
+type routingCipher struct {
+ sharedKey [32]byte
+}
+
+func newRoutingCipher(privKey, pubKey []byte) routingCipher {
+ shared := [32]byte{}
+ box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
+ return routingCipher{shared}
+}
+
+func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte {
+ const s = routingHeaderSize
+ out = out[:s+routingCipherOverhead+len(data)]
+ h.Marshal(routingStreamID, out[:s])
+ box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey)
+ return out
+}
+
+func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
+ const s = routingHeaderSize
+ return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey)
+}
diff --git a/node/cipher-routing_test.go b/node/cipher-routing_test.go
new file mode 100644
index 0000000..09824f7
--- /dev/null
+++ b/node/cipher-routing_test.go
@@ -0,0 +1,114 @@
+package node
+
+import (
+ "bytes"
+ "crypto/rand"
+ "testing"
+
+ "golang.org/x/crypto/nacl/box"
+)
+
+func newRoutingCipherForTesting() (c1, c2 routingCipher) {
+ pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
+ if err != nil {
+ panic(err)
+ }
+
+ pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
+ if err != nil {
+ panic(err)
+ }
+
+ return newRoutingCipher(privKey1[:], pubKey2[:]),
+ newRoutingCipher(privKey2[:], pubKey1[:])
+}
+
+func TestRoutingCipher(t *testing.T) {
+ c1, c2 := newRoutingCipherForTesting()
+
+ maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
+ rand.Read(maxSizePlaintext)
+
+ testCases := [][]byte{
+ make([]byte, 0),
+ {1},
+ {255},
+ {1, 2, 3, 4, 5},
+ []byte("Hello world"),
+ maxSizePlaintext,
+ }
+
+ for _, plaintext := range testCases {
+ h1 := xHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ encrypted := make([]byte, bufferSize)
+
+ encrypted = c1.Encrypt(h1, plaintext, encrypted)
+
+ decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
+ if !ok {
+ t.Fatal(ok)
+ }
+
+ if !bytes.Equal(decrypted, plaintext) {
+ t.Fatal("not equal")
+ }
+ }
+}
+
+func TestRoutingCipher_ShortCiphertext(t *testing.T) {
+ c1, _ := newRoutingCipherForTesting()
+ shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1)
+ rand.Read(shortText)
+ _, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
+ if ok {
+ t.Fatal(ok)
+ }
+}
+
+func BenchmarkRoutingCipher_Encrypt(b *testing.B) {
+ c1, _ := newRoutingCipherForTesting()
+ h1 := xHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
+ rand.Read(plaintext)
+
+ encrypted := make([]byte, bufferSize)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ encrypted = c1.Encrypt(h1, plaintext, encrypted)
+ }
+}
+
+func BenchmarkRoutingCipher_Decrypt(b *testing.B) {
+ c1, c2 := newRoutingCipherForTesting()
+
+ h1 := xHeader{
+ Counter: 235153,
+ SourceIP: 4,
+ DestIP: 88,
+ }
+
+ plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
+ rand.Read(plaintext)
+
+ encrypted := make([]byte, bufferSize)
+
+ encrypted = c1.Encrypt(h1, plaintext, encrypted)
+
+ decrypted := make([]byte, bufferSize)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ decrypted, _ = c2.Decrypt(encrypted, decrypted)
+ }
+}
diff --git a/node/cipher.go b/node/cipher.go
new file mode 100644
index 0000000..cb7accd
--- /dev/null
+++ b/node/cipher.go
@@ -0,0 +1,6 @@
+package node
+
+type packetCipher interface {
+ Encrypt(h xHeader, data, out []byte) []byte
+ Decrypt(encrypted, out []byte) (data []byte, ok bool)
+}
diff --git a/node/conn.go b/node/conn.go
index 0823e31..8a57641 100644
--- a/node/conn.go
+++ b/node/conn.go
@@ -35,6 +35,33 @@ func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *conn
return w
}
+/*
+ func (w *connWriter) SendRouting(remoteIP byte, data []byte) {
+ dstPeer := w.routing.Get(remoteIP)
+ if dstPeer == nil {
+ log.Printf("No peer: %d", remoteIP)
+ return
+ }
+
+ var viaPeer *peer
+
+ if dstPeer.Addr == zeroAddrPort {
+ viaPeer = w.routing.Mediator()
+ if viaPeer == nil {
+ log.Printf("No mediator: %d", remoteIP)
+ return
+ }
+ }
+
+ w.sendRouting(dstPeer, viaPeer, data)
+ }
+*/
+
+func (w *connWriter) SendData(remoteIP byte, data []byte) {
+ // TODO
+}
+
+// TODO: deprecated
func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
dstPeer := w.routing.Get(remoteIP)
if dstPeer == nil {
@@ -50,11 +77,11 @@ func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
var viaPeer *peer
if dstPeer.Mediated {
viaPeer = w.routing.mediator.Load()
- if viaPeer == nil || viaPeer.Addr == nil {
+ if viaPeer == nil || viaPeer.Addr == zeroAddrPort {
log.Printf("Mediator not connected")
return
}
- } else if dstPeer.Addr == nil {
+ } else if dstPeer.Addr == zeroAddrPort {
log.Printf("Peer doesn't have address: %d", remoteIP)
return
}
@@ -62,6 +89,7 @@ func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
w.WriteToPeer(dstPeer, viaPeer, stream, data)
}
+// TODO: deprecated
func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) {
w.lock.Lock()
@@ -89,20 +117,21 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt
addr = viaPeer.Addr
}
- if _, err := w.WriteToUDPAddrPort(buf, *addr); err != nil {
+ if _, err := w.WriteToUDPAddrPort(buf, addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()
}
+// TODO: deprecated
func (w *connWriter) Forward(dstIP byte, packet []byte) {
dstPeer := w.routing.Get(dstIP)
- if dstPeer == nil || dstPeer.Addr == nil {
+ if dstPeer == nil || dstPeer.Addr == zeroAddrPort {
log.Printf("No peer: %d", dstIP)
return
}
- if _, err := w.WriteToUDPAddrPort(packet, *dstPeer.Addr); err != nil {
+ if _, err := w.WriteToUDPAddrPort(packet, dstPeer.Addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
}
}
diff --git a/node/datacipher.go b/node/datacipher.go
deleted file mode 100644
index b631d14..0000000
--- a/node/datacipher.go
+++ /dev/null
@@ -1,97 +0,0 @@
-package node
-
-import (
- "crypto/aes"
- "crypto/cipher"
- "crypto/rand"
- "unsafe"
-)
-
-// ----------------------------------------------------------------------------
-
-const (
- dataStreamID = 1
- dataHeaderSize = 12
- dataCipherOverhead = 16 + 1
-)
-
-type dataHeader struct {
- Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
- SourceIP byte
- DestIP byte
-}
-
-func (h *dataHeader) Parse(b []byte) {
- h.Counter = *(*uint64)(unsafe.Pointer(&b[0]))
- h.SourceIP = b[8]
- h.DestIP = b[9]
-}
-
-func (h *dataHeader) Marshal(buf []byte) {
- *(*uint64)(unsafe.Pointer(&buf[0])) = h.Counter
- buf[8] = h.SourceIP
- buf[9] = h.DestIP
- buf[10] = 0
- buf[11] = 0
-}
-
-// ----------------------------------------------------------------------------
-
-type dataCipher struct {
- key []byte
- aead cipher.AEAD
-}
-
-func newDataCipher() *dataCipher {
- key := make([]byte, 32)
- if _, err := rand.Read(key); err != nil {
- panic(err)
- }
- return newDataCipherFromKey(key)
-}
-
-// key must be 32 bytes.
-func newDataCipherFromKey(key []byte) *dataCipher {
- block, err := aes.NewCipher(key)
- if err != nil {
- panic(err)
- }
-
- aead, err := cipher.NewGCM(block)
- if err != nil {
- panic(err)
- }
-
- return &dataCipher{key: key, aead: aead}
-}
-
-func (sc *dataCipher) Key() []byte {
- return sc.key
-}
-
-func (sc *dataCipher) Encrypt(h *dataHeader, data, out []byte) []byte {
- out = out[:dataHeaderSize+dataCipherOverhead+len(data)]
- out[0] = dataStreamID
-
- h.Marshal(out[1:])
-
- const s = dataHeaderSize
- sc.aead.Seal(out[1+s:1+s], out[1:1+s], data, nil)
- return out
-}
-
-func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, h dataHeader, ok bool) {
- const s = dataHeaderSize
- if len(encrypted) < s+dataCipherOverhead {
- ok = false
- return
- }
-
- h.Parse(encrypted[1 : 1+s])
-
- var err error
-
- data, err = sc.aead.Open(out[:0], encrypted[1:1+s], encrypted[1+s:], nil)
- ok = err == nil
- return
-}
diff --git a/node/header.go b/node/header.go
index ed3671a..a409576 100644
--- a/node/header.go
+++ b/node/header.go
@@ -2,6 +2,41 @@ package node
import "unsafe"
+// ----------------------------------------------------------------------------
+
+const (
+ routingStreamID = 2
+ routingHeaderSize = 24
+ routingCipherOverhead = 16
+
+ dataStreamID = 1
+ dataHeaderSize = 12
+ dataCipherOverhead = 16
+)
+
+// TODO: Rename
+type xHeader struct {
+ Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
+ SourceIP byte
+ DestIP byte
+}
+
+func (h *xHeader) Parse(b []byte) {
+ h.Counter = *(*uint64)(unsafe.Pointer(&b[1]))
+ h.SourceIP = b[9]
+ h.DestIP = b[10]
+}
+
+func (h *xHeader) Marshal(streamID byte, buf []byte) {
+ buf[0] = streamID
+ *(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter
+ buf[9] = h.SourceIP
+ buf[10] = h.DestIP
+ buf[11] = 0
+}
+
+// ----------------------------------------------------------------------------
+// TODO: Remove this code.
const (
headerSize = 24
streamData = 1
diff --git a/node/packets.go b/node/packets.go
new file mode 100644
index 0000000..75f4e6e
--- /dev/null
+++ b/node/packets.go
@@ -0,0 +1,89 @@
+package node
+
+import (
+ "errors"
+ "net/netip"
+ "time"
+ "unsafe"
+)
+
+var errMalformedPacket = errors.New("malformed packet")
+
+const (
+ packetTypePing = iota + 1
+ packetTypePong
+)
+
+// ----------------------------------------------------------------------------
+
+type packetWrapper struct {
+ SrcIP byte
+ RemoteAddr netip.AddrPort
+ Packet any
+}
+
+// ----------------------------------------------------------------------------
+
+// A pingPacket is sent from a node acting as a client, to a node acting
+// as a server. It always contains the shared key the client is expecting
+// to use for data encryption with the server.
+type pingPacket struct {
+ SentAt int64 // UnixMilli.
+ SharedKey [32]byte
+}
+
+func newPingPacket(sharedKey []byte) (pp pingPacket) {
+ pp.SentAt = time.Now().UnixMilli()
+ copy(pp.SharedKey[:], sharedKey)
+ return
+}
+
+func (p pingPacket) Marshal(buf []byte) []byte {
+ buf = buf[:41]
+ buf[0] = packetTypePing
+ *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt)
+ copy(buf[9:41], p.SharedKey[:])
+ return buf
+}
+
+func (p *pingPacket) Parse(buf []byte) error {
+ if len(buf) != 41 {
+ return errMalformedPacket
+ }
+ p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
+ copy(p.SharedKey[:], buf[9:41])
+ return nil
+}
+
+// ----------------------------------------------------------------------------
+
+// A pongPacket is sent by a node in a server role in response to a pingPacket.
+type pongPacket struct {
+ SentAt int64 // UnixMilli.
+ RecvdAt int64 // UnixMilli.
+}
+
+func newPongPacket(sentAt int64) (pp pongPacket) {
+ pp.SentAt = sentAt
+ pp.RecvdAt = time.Now().UnixMilli()
+ return
+}
+
+func (p pongPacket) Marshal(buf []byte) []byte {
+ buf = buf[:17]
+ buf[0] = packetTypePong
+ *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt)
+ *(*uint64)(unsafe.Pointer(&buf[9])) = uint64(p.RecvdAt)
+
+ return buf
+}
+
+func (p *pongPacket) Parse(buf []byte) error {
+ if len(buf) != 17 {
+ return errMalformedPacket
+ }
+ p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
+ p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9]))
+
+ return nil
+}
diff --git a/node/packets_test.go b/node/packets_test.go
new file mode 100644
index 0000000..bd89215
--- /dev/null
+++ b/node/packets_test.go
@@ -0,0 +1,42 @@
+package node
+
+import (
+ "crypto/rand"
+ "reflect"
+ "testing"
+)
+
+func TestPacketPing(t *testing.T) {
+ sharedKey := make([]byte, 32)
+ rand.Read(sharedKey)
+
+ buf := make([]byte, bufferSize)
+
+ p := newPingPacket(sharedKey)
+ out := p.Marshal(buf)
+
+ p2 := pingPacket{}
+ if err := p2.Parse(out); err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(p, p2) {
+ t.Fatal(p, p2)
+ }
+}
+
+func TestPacketPong(t *testing.T) {
+ buf := make([]byte, bufferSize)
+
+ p := newPongPacket(123566)
+ out := p.Marshal(buf)
+
+ p2 := pongPacket{}
+ if err := p2.Parse(out); err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(p, p2) {
+ t.Fatal(p, p2)
+ }
+}
diff --git a/node/peersupervisor.go b/node/peersupervisor.go
index 14c9315..bdcf03f 100644
--- a/node/peersupervisor.go
+++ b/node/peersupervisor.go
@@ -36,7 +36,7 @@ type peerSupervisor struct {
// Peer-related items.
version int64 // Ony accessed in HandlePeerUpdate.
peer *m.Peer
- remoteAddrPort *netip.AddrPort
+ remoteAddrPort netip.AddrPort
mediated bool
sharedKey []byte
@@ -123,9 +123,9 @@ func (s *peerSupervisor) stateInit() stateFunc {
addr, ok := netip.AddrFromSlice(s.peer.PublicIP)
if ok {
addrPort := netip.AddrPortFrom(addr, s.peer.Port)
- s.remoteAddrPort = &addrPort
+ s.remoteAddrPort = addrPort
} else {
- s.remoteAddrPort = nil
+ s.remoteAddrPort = zeroAddrPort
}
s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey)
@@ -153,7 +153,7 @@ func (s *peerSupervisor) stateSelectRole() stateFunc {
s.logf("STATE: SelectRole")
s.updateRoutingTable(false)
- if s.remoteAddrPort != nil {
+ if s.remoteAddrPort != zeroAddrPort {
s.mediated = false
// If both remote and local are public, one side acts as client, and one
@@ -186,7 +186,7 @@ func (s *peerSupervisor) stateAccept() stateFunc {
switch pkt.Type {
case packetTypePing:
- s.remoteAddrPort = &pkt.Addr
+ s.remoteAddrPort = pkt.Addr
s.updateRoutingTable(true)
s.sendPong(pkt.TraceID)
return s.stateConnected
@@ -256,8 +256,8 @@ func (s *peerSupervisor) stateConnected() stateFunc {
// Server should always follow remote port.
if s.localPublic {
- if pkt.Addr != *s.remoteAddrPort {
- s.remoteAddrPort = &pkt.Addr
+ if pkt.Addr != s.remoteAddrPort {
+ s.remoteAddrPort = pkt.Addr
s.updateRoutingTable(true)
}
}
diff --git a/node/router.go b/node/router.go
index 67c0756..c99f763 100644
--- a/node/router.go
+++ b/node/router.go
@@ -12,12 +12,18 @@ import (
"vppn/m"
)
+var zeroAddrPort = netip.AddrPort{}
+
type peer struct {
- Up bool // No data will be sent to peers that are down.
- Mediator bool
+ IP byte // The VPN IP.
+ Up bool // No data will be sent to peers that are down.
+ Addr netip.AddrPort // If we have direct connection, otherwise use mediator.
+ Mediator bool // True if the peer will mediate.
+ RoutingCipher routingCipher
+ DataCipher dataCipher
+
+ // TODO: Deprecated below.
Mediated bool
- IP byte
- Addr *netip.AddrPort // If we have direct connection, otherwise use mediator.
SharedKey []byte
}
@@ -48,6 +54,10 @@ func (r *routingTable) Set(ip byte, p *peer) {
r.table[ip].Store(p)
}
+func (r *routingTable) Mediator() *peer {
+ return r.mediator.Load()
+}
+
// ----------------------------------------------------------------------------
type router struct {
diff --git a/node/routingpacket.go b/node/routingpacket.go
index 64f0374..1b5aed1 100644
--- a/node/routingpacket.go
+++ b/node/routingpacket.go
@@ -1,18 +1,9 @@
package node
import (
- "errors"
"unsafe"
)
-var errMalformedPacket = errors.New("malformed packet")
-
-const (
- // Used to maintain connection.
- packetTypePing = iota + 1
- packetTypePong
-)
-
type routingPacket struct {
Type byte // One of the packetType* constants.
TraceID uint64 // For matching requests and responses.
--
2.39.5
From 8ab61584690d71e647f0e2e4d76f71a669477657 Mon Sep 17 00:00:00 2001
From: jdl
Date: Fri, 20 Dec 2024 15:55:46 +0100
Subject: [PATCH 03/18] WIP: Working
---
README.md | 1 +
node/addrutil.go | 8 +
node/cipher-control.go | 26 +++
...routing_test.go => cipher-control_test.go} | 38 ++--
node/cipher-data.go | 15 +-
node/cipher-data_test.go | 1 +
node/cipher-routing.go | 26 ---
node/conn.go | 43 ++++
node/header.go | 17 +-
node/header_test.go | 7 +-
node/main.go | 59 +++--
node/packets.go | 17 +-
node/packets_test.go | 8 +-
node/peer-pollhub.go | 97 +++++++++
node/peer-supervisor.go | 197 +++++++++++++++++
node/peer.go | 205 ++++++++++++++++++
node/peersupervisor.go | 12 +-
node/router.go | 2 +-
18 files changed, 664 insertions(+), 115 deletions(-)
create mode 100644 node/addrutil.go
create mode 100644 node/cipher-control.go
rename node/{cipher-routing_test.go => cipher-control_test.go} (62%)
delete mode 100644 node/cipher-routing.go
create mode 100644 node/peer-pollhub.go
create mode 100644 node/peer-supervisor.go
diff --git a/README.md b/README.md
index 3aa4d04..b9d291e 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
## Roadmap
+* Rename Mediator -> Relay
* Node: use symmetric encryption after handshake
* AEAD-AES uses a 12 byte nonce. We need to shrink the header:
* Remove Forward and replace it with a HeaderFlags bitfield.
diff --git a/node/addrutil.go b/node/addrutil.go
new file mode 100644
index 0000000..590c80c
--- /dev/null
+++ b/node/addrutil.go
@@ -0,0 +1,8 @@
+package node
+
+import "net/netip"
+
+func addrIsValid(in []byte) bool {
+ _, ok := netip.AddrFromSlice(in)
+ return ok
+}
diff --git a/node/cipher-control.go b/node/cipher-control.go
new file mode 100644
index 0000000..e9b56d5
--- /dev/null
+++ b/node/cipher-control.go
@@ -0,0 +1,26 @@
+package node
+
+import "golang.org/x/crypto/nacl/box"
+
+type controlCipher struct {
+ sharedKey [32]byte
+}
+
+func newControlCipher(privKey, pubKey []byte) *controlCipher {
+ shared := [32]byte{}
+ box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
+ return &controlCipher{shared}
+}
+
+func (cc *controlCipher) Encrypt(h xHeader, data, out []byte) []byte {
+ const s = controlHeaderSize
+ out = out[:s+controlCipherOverhead+len(data)]
+ h.Marshal(out[:s])
+ box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
+ return out
+}
+
+func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
+ const s = controlHeaderSize
+ return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
+}
diff --git a/node/cipher-routing_test.go b/node/cipher-control_test.go
similarity index 62%
rename from node/cipher-routing_test.go
rename to node/cipher-control_test.go
index 09824f7..c571aa2 100644
--- a/node/cipher-routing_test.go
+++ b/node/cipher-control_test.go
@@ -3,12 +3,13 @@ package node
import (
"bytes"
"crypto/rand"
+ "reflect"
"testing"
"golang.org/x/crypto/nacl/box"
)
-func newRoutingCipherForTesting() (c1, c2 routingCipher) {
+func newControlCipherForTesting() (c1, c2 *controlCipher) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
@@ -19,14 +20,14 @@ func newRoutingCipherForTesting() (c1, c2 routingCipher) {
panic(err)
}
- return newRoutingCipher(privKey1[:], pubKey2[:]),
- newRoutingCipher(privKey2[:], pubKey1[:])
+ return newControlCipher(privKey1[:], pubKey2[:]),
+ newControlCipher(privKey2[:], pubKey1[:])
}
-func TestRoutingCipher(t *testing.T) {
- c1, c2 := newRoutingCipherForTesting()
+func TestControlCipher(t *testing.T) {
+ c1, c2 := newControlCipherForTesting()
- maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
+ maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
@@ -40,6 +41,7 @@ func TestRoutingCipher(t *testing.T) {
for _, plaintext := range testCases {
h1 := xHeader{
+ StreamID: controlStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -49,6 +51,12 @@ func TestRoutingCipher(t *testing.T) {
encrypted = c1.Encrypt(h1, plaintext, encrypted)
+ h2 := xHeader{}
+ h2.Parse(encrypted)
+ if !reflect.DeepEqual(h1, h2) {
+ t.Fatal(h1, h2)
+ }
+
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
if !ok {
t.Fatal(ok)
@@ -60,9 +68,9 @@ func TestRoutingCipher(t *testing.T) {
}
}
-func TestRoutingCipher_ShortCiphertext(t *testing.T) {
- c1, _ := newRoutingCipherForTesting()
- shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1)
+func TestControlCipher_ShortCiphertext(t *testing.T) {
+ c1, _ := newControlCipherForTesting()
+ shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
rand.Read(shortText)
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
@@ -70,15 +78,15 @@ func TestRoutingCipher_ShortCiphertext(t *testing.T) {
}
}
-func BenchmarkRoutingCipher_Encrypt(b *testing.B) {
- c1, _ := newRoutingCipherForTesting()
+func BenchmarkControlCipher_Encrypt(b *testing.B) {
+ c1, _ := newControlCipherForTesting()
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
- plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
+ plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
@@ -89,8 +97,8 @@ func BenchmarkRoutingCipher_Encrypt(b *testing.B) {
}
}
-func BenchmarkRoutingCipher_Decrypt(b *testing.B) {
- c1, c2 := newRoutingCipherForTesting()
+func BenchmarkControlCipher_Decrypt(b *testing.B) {
+ c1, c2 := newControlCipherForTesting()
h1 := xHeader{
Counter: 235153,
@@ -98,7 +106,7 @@ func BenchmarkRoutingCipher_Decrypt(b *testing.B) {
DestIP: 88,
}
- plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
+ plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
diff --git a/node/cipher-data.go b/node/cipher-data.go
index c0fc273..26d3121 100644
--- a/node/cipher-data.go
+++ b/node/cipher-data.go
@@ -6,22 +6,23 @@ import (
"crypto/rand"
)
+// TODO: Use [32]byte for simplicity everywhere.
type dataCipher struct {
- key []byte
+ key [32]byte
aead cipher.AEAD
}
func newDataCipher() *dataCipher {
- key := make([]byte, 32)
- if _, err := rand.Read(key); err != nil {
+ key := [32]byte{}
+ if _, err := rand.Read(key[:]); err != nil {
panic(err)
}
return newDataCipherFromKey(key)
}
// key must be 32 bytes.
-func newDataCipherFromKey(key []byte) *dataCipher {
- block, err := aes.NewCipher(key)
+func newDataCipherFromKey(key [32]byte) *dataCipher {
+ block, err := aes.NewCipher(key[:])
if err != nil {
panic(err)
}
@@ -34,14 +35,14 @@ func newDataCipherFromKey(key []byte) *dataCipher {
return &dataCipher{key: key, aead: aead}
}
-func (sc *dataCipher) Key() []byte {
+func (sc *dataCipher) Key() [32]byte {
return sc.key
}
func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte {
const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)]
- h.Marshal(dataStreamID, out[:s])
+ h.Marshal(out[:s])
sc.aead.Seal(out[s:s], out[:s], data, nil)
return out
}
diff --git a/node/cipher-data_test.go b/node/cipher-data_test.go
index d1523d8..c3892bb 100644
--- a/node/cipher-data_test.go
+++ b/node/cipher-data_test.go
@@ -23,6 +23,7 @@ func TestDataCipher(t *testing.T) {
for _, plaintext := range testCases {
h1 := xHeader{
+ StreamID: dataStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
diff --git a/node/cipher-routing.go b/node/cipher-routing.go
deleted file mode 100644
index 795ac7a..0000000
--- a/node/cipher-routing.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package node
-
-import "golang.org/x/crypto/nacl/box"
-
-type routingCipher struct {
- sharedKey [32]byte
-}
-
-func newRoutingCipher(privKey, pubKey []byte) routingCipher {
- shared := [32]byte{}
- box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
- return routingCipher{shared}
-}
-
-func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte {
- const s = routingHeaderSize
- out = out[:s+routingCipherOverhead+len(data)]
- h.Marshal(routingStreamID, out[:s])
- box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey)
- return out
-}
-
-func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
- const s = routingHeaderSize
- return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey)
-}
diff --git a/node/conn.go b/node/conn.go
index 8a57641..7f7e4e3 100644
--- a/node/conn.go
+++ b/node/conn.go
@@ -1,6 +1,7 @@
package node
import (
+ "io"
"log"
"net"
"net/netip"
@@ -9,6 +10,48 @@ import (
"vppn/fasttime"
)
+// ----------------------------------------------------------------------------
+
+type connWriter2 struct {
+ lock sync.Mutex
+ conn *net.UDPConn
+}
+
+func newConnWriter2(conn *net.UDPConn) *connWriter2 {
+ return &connWriter2{conn: conn}
+}
+
+func (w *connWriter2) WriteTo(packet []byte, addr netip.AddrPort) {
+ w.lock.Lock()
+ if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
+ log.Fatalf("Failed to write to UDP port: %v", err)
+ }
+ w.lock.Unlock()
+}
+
+// ----------------------------------------------------------------------------
+
+type ifWriter struct {
+ lock sync.Mutex
+ iface io.ReadWriteCloser
+}
+
+func newIFWriter(iface io.ReadWriteCloser) *ifWriter {
+ return &ifWriter{iface: iface}
+}
+
+func (w *ifWriter) Write(packet []byte) {
+ w.lock.Lock()
+ if _, err := w.iface.Write(packet); err != nil {
+ log.Fatalf("Failed to write to interface: %v", err)
+ }
+ w.lock.Unlock()
+}
+
+// ----------------------------------------------------------------------------
+
+// TODO: Delete below??
+
type connWriter struct {
*net.UDPConn
lock sync.Mutex
diff --git a/node/header.go b/node/header.go
index a409576..d2eb142 100644
--- a/node/header.go
+++ b/node/header.go
@@ -5,30 +5,33 @@ import "unsafe"
// ----------------------------------------------------------------------------
const (
- routingStreamID = 2
- routingHeaderSize = 24
- routingCipherOverhead = 16
+ controlStreamID = 2
+ controlHeaderSize = 24
+ controlCipherOverhead = 16
dataStreamID = 1
dataHeaderSize = 12
dataCipherOverhead = 16
+
+ forwardStreamID = 3
)
-// TODO: Rename
type xHeader struct {
+ StreamID byte
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte
DestIP byte
}
func (h *xHeader) Parse(b []byte) {
+ h.StreamID = b[0]
h.Counter = *(*uint64)(unsafe.Pointer(&b[1]))
h.SourceIP = b[9]
h.DestIP = b[10]
}
-func (h *xHeader) Marshal(streamID byte, buf []byte) {
- buf[0] = streamID
+func (h *xHeader) Marshal(buf []byte) {
+ buf[0] = h.StreamID
*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter
buf[9] = h.SourceIP
buf[10] = h.DestIP
@@ -40,7 +43,7 @@ func (h *xHeader) Marshal(streamID byte, buf []byte) {
const (
headerSize = 24
streamData = 1
- streamRouting = 2
+ streamControl = 2
)
type header struct {
diff --git a/node/header_test.go b/node/header_test.go
index 7a87354..0205d87 100644
--- a/node/header_test.go
+++ b/node/header_test.go
@@ -3,18 +3,17 @@ package node
import "testing"
func TestHeaderMarshalParse(t *testing.T) {
- nIn := header{
+ nIn := xHeader{
+ StreamID: 23,
Counter: 3212,
SourceIP: 34,
DestIP: 200,
- Forward: 1,
- Stream: 44,
}
buf := make([]byte, headerSize)
nIn.Marshal(buf)
- nOut := header{}
+ nOut := xHeader{}
nOut.Parse(buf)
if nIn != nOut {
t.Fatal(nIn, nOut)
diff --git a/node/main.go b/node/main.go
index cac6df8..f5c9bc7 100644
--- a/node/main.go
+++ b/node/main.go
@@ -102,15 +102,19 @@ func main(netName, listenIP string, port uint16) {
log.Fatalf("Failed to open UDP port: %v", err)
}
- routing := newRoutingTable()
+ connWriter := newConnWriter2(conn)
+ ifWriter := newIFWriter(iface)
- w := newConnWriter(conn, conf.PeerIP, routing)
- r := newConnReader(conn, conf.PeerIP, routing)
+ peers := remotePeers{}
- router := newRouter(netName, conf, routing, w)
+ for i := range peers {
+ peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter)
+ }
+
+ go newHubPoller(netName, conf, peers).Run()
+ go readFromConn(conn, peers)
+ readFromIFace(iface, peers)
- go nodeConnReader(r, w, iface, router)
- nodeIFaceReader(w, iface, router)
}
// ----------------------------------------------------------------------------
@@ -127,43 +131,39 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ----------------------------------------------------------------------------
-func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) {
+func readFromConn(conn *net.UDPConn, peers remotePeers) {
+
defer panicHandler()
+
var (
remoteAddr netip.AddrPort
- h header
+ n int
+ err error
buf = make([]byte, bufferSize)
data []byte
- err error
+ h xHeader
)
for {
- remoteAddr, h, data = r.Read(buf)
-
- if h.Forward != 0 {
- w.Forward(h.DestIP, data)
- continue
+ n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize])
+ if err != nil {
+ log.Fatalf("Failed to read from UDP port: %v", err)
}
- switch h.Stream {
+ data = buf[:n]
- case streamData:
- if _, err = iface.Write(data); err != nil {
- log.Printf("Malformed data from peer %d: %v", h.SourceIP, err)
- }
-
- case streamRouting:
- router.HandlePacket(h.SourceIP, remoteAddr, data)
-
- default:
- log.Printf("Dropping unknown stream: %d", h.Stream)
+ if n < headerSize {
+ continue // Packet it soo short.
}
+
+ h.Parse(data)
+ peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
}
}
// ----------------------------------------------------------------------------
-func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) {
+func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) {
var (
buf = make([]byte, bufferSize)
@@ -173,16 +173,11 @@ func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) {
)
for {
-
packet, remoteIP, err = readNextPacket(iface, buf)
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
- if remoteIP == w.localIP {
- continue // Don't write to self.
- }
-
- w.WriteTo(remoteIP, streamData, packet)
+ peers[remoteIP].SendData(packet)
}
}
diff --git a/node/packets.go b/node/packets.go
index 75f4e6e..d197f58 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -16,10 +16,10 @@ const (
// ----------------------------------------------------------------------------
-type packetWrapper struct {
+type controlPacket struct {
SrcIP byte
RemoteAddr netip.AddrPort
- Packet any
+ Payload any
}
// ----------------------------------------------------------------------------
@@ -46,13 +46,13 @@ func (p pingPacket) Marshal(buf []byte) []byte {
return buf
}
-func (p *pingPacket) Parse(buf []byte) error {
+func parsePingPacket(buf []byte) (p pingPacket, err error) {
if len(buf) != 41 {
- return errMalformedPacket
+ return p, errMalformedPacket
}
p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
copy(p.SharedKey[:], buf[9:41])
- return nil
+ return
}
// ----------------------------------------------------------------------------
@@ -78,12 +78,11 @@ func (p pongPacket) Marshal(buf []byte) []byte {
return buf
}
-func (p *pongPacket) Parse(buf []byte) error {
+func parsePongPacket(buf []byte) (p pongPacket, err error) {
if len(buf) != 17 {
- return errMalformedPacket
+ return p, errMalformedPacket
}
p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9]))
-
- return nil
+ return
}
diff --git a/node/packets_test.go b/node/packets_test.go
index bd89215..b385c2b 100644
--- a/node/packets_test.go
+++ b/node/packets_test.go
@@ -15,8 +15,8 @@ func TestPacketPing(t *testing.T) {
p := newPingPacket(sharedKey)
out := p.Marshal(buf)
- p2 := pingPacket{}
- if err := p2.Parse(out); err != nil {
+ p2, err := parsePingPacket(out)
+ if err != nil {
t.Fatal(err)
}
@@ -31,8 +31,8 @@ func TestPacketPong(t *testing.T) {
p := newPongPacket(123566)
out := p.Marshal(buf)
- p2 := pongPacket{}
- if err := p2.Parse(out); err != nil {
+ p2, err := parsePongPacket(out)
+ if err != nil {
t.Fatal(err)
}
diff --git a/node/peer-pollhub.go b/node/peer-pollhub.go
new file mode 100644
index 0000000..aa1c91b
--- /dev/null
+++ b/node/peer-pollhub.go
@@ -0,0 +1,97 @@
+package node
+
+import (
+ "encoding/json"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "time"
+ "vppn/m"
+)
+
+type hubPoller struct {
+ netName string
+ localIP byte
+ client *http.Client
+ req *http.Request
+ peers remotePeers
+}
+
+func newHubPoller(netName string, conf m.PeerConfig, peers remotePeers) *hubPoller {
+ u, err := url.Parse(conf.HubAddress)
+ if err != nil {
+ log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
+ }
+ u.Path = "/peer/fetch-state/"
+
+ client := &http.Client{Timeout: 8 * time.Second}
+
+ req := &http.Request{
+ Method: http.MethodGet,
+ URL: u,
+ Header: http.Header{},
+ }
+ req.SetBasicAuth("", conf.APIKey)
+
+ return &hubPoller{
+ netName: netName,
+ localIP: conf.PeerIP,
+ client: client,
+ req: req,
+ peers: peers,
+ }
+}
+
+func (hp *hubPoller) Run() {
+ defer panicHandler()
+
+ state, err := loadNetworkState(hp.netName)
+ if err != nil {
+ log.Printf("Failed to load network state: %v", err)
+ log.Printf("Polling hub...")
+ hp.pollHub()
+ } else {
+ hp.applyNetworkState(state)
+ }
+
+ for range time.Tick(64 * time.Second) {
+ hp.pollHub()
+ }
+}
+
+func (hp *hubPoller) pollHub() {
+ var state m.NetworkState
+
+ log.Printf("Fetching peer state...")
+ resp, err := hp.client.Do(hp.req)
+ if err != nil {
+ log.Printf("Failed to fetch peer state: %v", err)
+ return
+ }
+ body, err := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if err != nil {
+ log.Printf("Failed to read body from hub: %v", err)
+ return
+ }
+
+ if err := json.Unmarshal(body, &state); err != nil {
+ log.Printf("Failed to unmarshal response from hub: %v", err)
+ return
+ }
+
+ hp.applyNetworkState(state)
+
+ if err := storeNetworkState(hp.netName, state); err != nil {
+ log.Printf("Failed to store network state: %v", err)
+ }
+}
+
+func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
+ for i := range state.Peers {
+ if i != int(hp.localIP) {
+ hp.peers[i].HandlePeerUpdate(state.Peers[i])
+ }
+ }
+}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
new file mode 100644
index 0000000..cfcb43b
--- /dev/null
+++ b/node/peer-supervisor.go
@@ -0,0 +1,197 @@
+package node
+
+import (
+ "net/netip"
+ "sync/atomic"
+ "time"
+ "vppn/m"
+)
+
+const (
+ connectTimeout = 6 * time.Second
+ pingInterval = 6 * time.Second
+ timeoutInterval = 20 * time.Second
+)
+
+type stateFunc func() stateFunc
+
+type peerSuper struct {
+ *remotePeer
+
+ peer *m.Peer
+ remotePublic bool
+ peerData peerData
+
+ pktBuf []byte
+ encBuf []byte
+}
+
+func newPeerSuper(rp *remotePeer) *peerSuper {
+ return &peerSuper{
+ remotePeer: rp,
+ peer: nil,
+ pktBuf: make([]byte, bufferSize),
+ encBuf: make([]byte, bufferSize),
+ }
+}
+
+func (rp *peerSuper) Run() {
+ defer panicHandler()
+ state := rp.stateInit
+ for {
+ state = state()
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *peerSuper) stateInit() stateFunc {
+ //rp.logf("STATE: Init")
+ x := peerData{}
+ rp.shared.Store(&x)
+
+ rp.peerData.controlCipher = nil
+ rp.peerData.dataCipher = nil
+ rp.peerData.remoteAddr = zeroAddrPort
+
+ if rp.peer == nil {
+ return rp.stateDisconnected
+ }
+
+ var addr netip.Addr
+ addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
+ if rp.remotePublic {
+ rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
+ }
+
+ rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
+
+ return rp.stateSelectRole()
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *peerSuper) stateDisconnected() stateFunc {
+ //rp.logf("STATE: Disconnected")
+ for {
+ select {
+ case <-rp.controlPackets:
+ // Drop
+ case rp.peer = <-rp.peerUpdates:
+ return rp.stateInit
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *peerSuper) stateSelectRole() stateFunc {
+ rp.logf("STATE: SelectRole")
+
+ if !rp.localPublic && !rp.remotePublic {
+ // TODO!
+ return rp.stateDisconnected
+ }
+
+ if !rp.localPublic {
+ return rp.stateServer
+ } else if !rp.remotePublic {
+ return rp.stateClient
+ }
+
+ if rp.localIP < rp.peer.PeerIP {
+ return rp.stateClient
+ }
+ return rp.stateServer
+}
+
+// The remote is a server.
+func (rp *peerSuper) stateServer() stateFunc {
+ rp.logf("STATE: Server")
+ rp.peerData.dataCipher = newDataCipher()
+ rp.updateShared()
+
+ var (
+ pingTimer = time.NewTimer(pingInterval)
+ ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
+ )
+ defer pingTimer.Stop()
+
+ ping.SentAt = time.Now().UnixMilli()
+ rp.sendControlPacket(ping)
+
+ for {
+ select {
+ case <-pingTimer.C:
+ ping.SentAt = time.Now().UnixMilli()
+ rp.sendControlPacket(ping)
+ pingTimer.Reset(pingInterval)
+
+ case <-rp.controlPackets:
+ // Ignore
+
+ case rp.peer = <-rp.peerUpdates:
+ return rp.stateInit
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+// The remote is a client.
+func (rp *peerSuper) stateClient() stateFunc {
+ rp.logf("STATE: Client")
+ rp.updateShared()
+
+ // TODO: Could use timeout to set dataCipher to nil.
+ var currentKey = [32]byte{}
+
+ for {
+ select {
+ case cPkt := <-rp.controlPackets:
+ if cPkt.RemoteAddr != rp.peerData.remoteAddr {
+ rp.peerData.remoteAddr = cPkt.RemoteAddr
+ rp.logf("Got new remote address: %v", cPkt.RemoteAddr)
+ rp.updateShared()
+ }
+
+ ping, ok := cPkt.Payload.(pingPacket)
+ if !ok {
+ continue
+ }
+
+ if ping.SharedKey != currentKey {
+ rp.logf("Connected with new shared key")
+ currentKey = ping.SharedKey
+ rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
+ rp.updateShared()
+ }
+
+ rp.sendControlPacket(newPongPacket(ping.SentAt))
+
+ case rp.peer = <-rp.peerUpdates:
+ return rp.stateInit
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *peerSuper) updateShared() {
+ data := rp.peerData
+ rp.shared.Store(&data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
+ buf := pkt.Marshal(rp.pktBuf)
+ h := xHeader{
+ StreamID: controlStreamID,
+ Counter: atomic.AddUint64(&rp.counter, 1),
+ SourceIP: rp.localIP,
+ DestIP: rp.remoteIP,
+ }
+ buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
+ rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
+}
diff --git a/node/peer.go b/node/peer.go
index 2b4023a..19cddfd 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -1 +1,206 @@
package node
+
+import (
+ "fmt"
+ "log"
+ "net/netip"
+ "sync/atomic"
+ "time"
+ "vppn/m"
+)
+
+type remotePeers [256]*remotePeer
+
+type peerData struct {
+ controlCipher *controlCipher
+ dataCipher *dataCipher
+ remoteAddr netip.AddrPort
+}
+
+type remotePeer struct {
+ // Immutable data.
+ localIP byte
+ remoteIP byte
+ privKey []byte
+ localPublic bool // True if local node is public.
+ iface *ifWriter
+ conn *connWriter2
+
+ // Shared state.
+ shared *atomic.Pointer[peerData]
+
+ // Only used in HandlePeerUpdate.
+ peerVersion int64
+
+ // Only used in HandlePacket / Not synchronized.
+ dupCheck *dupCheck
+ decryptBuf []byte
+
+ // Only used in SendData / Not synchronized.
+ encryptBuf []byte
+
+ // Used for sending control and data packets. Atomic access only.
+ counter uint64
+
+ // For communicating with the supervisor thread.
+ peerUpdates chan *m.Peer
+ controlPackets chan controlPacket
+}
+
+func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter2) *remotePeer {
+ rp := &remotePeer{
+ localIP: conf.PeerIP,
+ remoteIP: remoteIP,
+ privKey: conf.EncPrivKey,
+ localPublic: addrIsValid(conf.PublicIP),
+ iface: iface,
+ conn: conn,
+ shared: &atomic.Pointer[peerData]{},
+ dupCheck: newDupCheck(0),
+ decryptBuf: make([]byte, bufferSize),
+ encryptBuf: make([]byte, bufferSize),
+ counter: uint64(time.Now().Unix()) << 30,
+ peerUpdates: make(chan *m.Peer),
+ controlPackets: make(chan controlPacket, 512),
+ }
+
+ pd := peerData{}
+ rp.shared.Store(&pd)
+
+ go newPeerSuper(rp).Run()
+
+ return rp
+}
+
+func (rp *remotePeer) logf(msg string, args ...any) {
+ log.Printf(fmt.Sprintf("[%03d] ", rp.remoteIP)+msg, args...)
+}
+
+func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
+ if peer != nil && peer.Version != rp.peerVersion {
+ rp.peerUpdates <- peer
+ rp.peerVersion = peer.Version
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+// HandlePacket accepts a raw data packet coming in from the network.
+//
+// This function is called by a single thread.
+func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h xHeader, data []byte) {
+ switch h.StreamID {
+ case controlStreamID:
+ rp.handleControlPacket(addr, h, data)
+
+ case dataStreamID:
+ rp.handleDataPacket(data)
+
+ case forwardStreamID:
+ fallthrough
+ // TODO
+ //rp.handleForwardPacket(h, data)
+ default:
+ rp.logf("Unknown stream ID: %d", h.StreamID)
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h xHeader, data []byte) {
+ shared := rp.shared.Load()
+ if shared.controlCipher == nil {
+ rp.logf("Not connected (control).")
+ return
+ }
+
+ out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf)
+ if !ok {
+ rp.logf("Failed to decrypt control packet.")
+ return
+ }
+
+ if len(out) == 0 {
+ rp.logf("Empty control packet from: %d", h.SourceIP)
+ return
+ }
+
+ if rp.dupCheck.IsDup(h.Counter) {
+ rp.logf("Duplicate control packet: %d", h.Counter)
+ return
+ }
+
+ if h.DestIP != rp.localIP {
+ // TODO: Forward control packet.
+ // TODO: Probably this should be dropped.
+ // Control packets should be forwarded as data for efficiency.
+ return
+ }
+
+ pkt := controlPacket{
+ SrcIP: h.SourceIP,
+ RemoteAddr: addr,
+ }
+
+ var err error
+
+ switch out[0] {
+ case packetTypePing:
+ pkt.Payload, err = parsePingPacket(out)
+ case packetTypePong:
+ pkt.Payload, err = parsePongPacket(out)
+ default:
+ rp.logf("Unknown control packet type: %d", out[0])
+ return
+ }
+
+ if err != nil {
+ rp.logf("Failed to parse control packet: %v", err)
+ return
+ }
+
+ select {
+ case rp.controlPackets <- pkt:
+ default:
+ rp.logf("Dropping control packet.")
+ }
+}
+
+func (rp *remotePeer) handleDataPacket(data []byte) {
+ shared := rp.shared.Load()
+ if shared.dataCipher == nil {
+ rp.logf("Not connected (recv).")
+ return
+ }
+
+ dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
+ if !ok {
+ rp.logf("Failed to decrypt data packet.")
+ return
+ }
+
+ rp.iface.Write(dec)
+}
+
+// ----------------------------------------------------------------------------
+
+// SendData sends data coming from the interface going to the network.
+//
+// This function is called by a single thread.
+func (rp *remotePeer) SendData(data []byte) {
+ shared := rp.shared.Load()
+ if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
+ rp.logf("Not connected (send).")
+ return
+ }
+
+ h := xHeader{
+ StreamID: dataStreamID,
+ Counter: atomic.AddUint64(&rp.counter, 1),
+ SourceIP: rp.localIP,
+ DestIP: rp.remoteIP,
+ }
+
+ enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
+ rp.conn.WriteTo(enc, shared.remoteAddr)
+}
diff --git a/node/peersupervisor.go b/node/peersupervisor.go
index bdcf03f..90763b4 100644
--- a/node/peersupervisor.go
+++ b/node/peersupervisor.go
@@ -8,12 +8,6 @@ import (
"vppn/m"
)
-const (
- connectTimeout = 6 * time.Second
- pingInterval = 6 * time.Second
- timeoutInterval = 20 * time.Second
-)
-
type routingPacketWrapper struct {
routingPacket
Addr netip.AddrPort // Source.
@@ -113,8 +107,6 @@ func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) {
// ----------------------------------------------------------------------------
-type stateFunc func() stateFunc
-
func (s *peerSupervisor) stateInit() stateFunc {
if s.peer == nil {
return s.stateDisconnected
@@ -316,12 +308,12 @@ func (s *peerSupervisor) updateRoutingTable(up bool) {
func (s *peerSupervisor) sendPing() uint64 {
traceID := newTraceID()
pkt := newRoutingPacket(packetTypePing, traceID)
- s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))
+ s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf))
s.pingTimer.Reset(pingInterval)
return traceID
}
func (s *peerSupervisor) sendPong(traceID uint64) {
pkt := newRoutingPacket(packetTypePong, traceID)
- s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))
+ s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf))
}
diff --git a/node/router.go b/node/router.go
index c99f763..0e74d14 100644
--- a/node/router.go
+++ b/node/router.go
@@ -19,7 +19,7 @@ type peer struct {
Up bool // No data will be sent to peers that are down.
Addr netip.AddrPort // If we have direct connection, otherwise use mediator.
Mediator bool // True if the peer will mediate.
- RoutingCipher routingCipher
+ RoutingCipher controlCipher
DataCipher dataCipher
// TODO: Deprecated below.
--
2.39.5
From f87c2e59b4203bb1104eca53956a70c4e7e2c7d8 Mon Sep 17 00:00:00 2001
From: jdl
Date: Fri, 20 Dec 2024 16:13:59 +0100
Subject: [PATCH 04/18] wip: cleaning
---
node/cipher.go | 6 -
node/conn.go | 205 +-------------------------
node/crypto.go | 50 -------
node/crypto_test.go | 137 ------------------
node/header.go | 34 +----
node/interface.go | 2 +-
node/main.go | 2 +-
node/node.go | 1 -
node/peer.go | 4 +-
node/peerstate.go | 1 -
node/peersupervisor.go | 319 -----------------------------------------
node/router.go | 189 ------------------------
node/routingpacket.go | 33 -----
node/tmp-server.go | 185 ------------------------
14 files changed, 10 insertions(+), 1158 deletions(-)
delete mode 100644 node/cipher.go
delete mode 100644 node/crypto.go
delete mode 100644 node/crypto_test.go
delete mode 100644 node/node.go
delete mode 100644 node/peerstate.go
delete mode 100644 node/peersupervisor.go
delete mode 100644 node/routingpacket.go
delete mode 100644 node/tmp-server.go
diff --git a/node/cipher.go b/node/cipher.go
deleted file mode 100644
index cb7accd..0000000
--- a/node/cipher.go
+++ /dev/null
@@ -1,6 +0,0 @@
-package node
-
-type packetCipher interface {
- Encrypt(h xHeader, data, out []byte) []byte
- Decrypt(encrypted, out []byte) (data []byte, ok bool)
-}
diff --git a/node/conn.go b/node/conn.go
index 7f7e4e3..344d8d5 100644
--- a/node/conn.go
+++ b/node/conn.go
@@ -6,22 +6,20 @@ import (
"net"
"net/netip"
"sync"
- "sync/atomic"
- "vppn/fasttime"
)
// ----------------------------------------------------------------------------
-type connWriter2 struct {
+type connWriter struct {
lock sync.Mutex
conn *net.UDPConn
}
-func newConnWriter2(conn *net.UDPConn) *connWriter2 {
- return &connWriter2{conn: conn}
+func newConnWriter(conn *net.UDPConn) *connWriter {
+ return &connWriter{conn: conn}
}
-func (w *connWriter2) WriteTo(packet []byte, addr netip.AddrPort) {
+func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) {
w.lock.Lock()
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
@@ -47,198 +45,3 @@ func (w *ifWriter) Write(packet []byte) {
}
w.lock.Unlock()
}
-
-// ----------------------------------------------------------------------------
-
-// TODO: Delete below??
-
-type connWriter struct {
- *net.UDPConn
- lock sync.Mutex
- localIP byte
- buf []byte
- buf2 []byte
- counters [256]uint64
- routing *routingTable
-}
-
-func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter {
- w := &connWriter{
- UDPConn: conn,
- localIP: localIP,
- buf: make([]byte, bufferSize),
- buf2: make([]byte, bufferSize),
- routing: routing,
- }
-
- for i := range w.counters {
- w.counters[i] = uint64(fasttime.Now() << 30)
- }
-
- return w
-}
-
-/*
- func (w *connWriter) SendRouting(remoteIP byte, data []byte) {
- dstPeer := w.routing.Get(remoteIP)
- if dstPeer == nil {
- log.Printf("No peer: %d", remoteIP)
- return
- }
-
- var viaPeer *peer
-
- if dstPeer.Addr == zeroAddrPort {
- viaPeer = w.routing.Mediator()
- if viaPeer == nil {
- log.Printf("No mediator: %d", remoteIP)
- return
- }
- }
-
- w.sendRouting(dstPeer, viaPeer, data)
- }
-*/
-
-func (w *connWriter) SendData(remoteIP byte, data []byte) {
- // TODO
-}
-
-// TODO: deprecated
-func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
- dstPeer := w.routing.Get(remoteIP)
- if dstPeer == nil {
- log.Printf("No peer: %d", remoteIP)
- return
- }
-
- if stream == streamData && !dstPeer.Up {
- log.Printf("Peer down: %d", remoteIP)
- return
- }
-
- var viaPeer *peer
- if dstPeer.Mediated {
- viaPeer = w.routing.mediator.Load()
- if viaPeer == nil || viaPeer.Addr == zeroAddrPort {
- log.Printf("Mediator not connected")
- return
- }
- } else if dstPeer.Addr == zeroAddrPort {
- log.Printf("Peer doesn't have address: %d", remoteIP)
- return
- }
-
- w.WriteToPeer(dstPeer, viaPeer, stream, data)
-}
-
-// TODO: deprecated
-func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) {
- w.lock.Lock()
-
- addr := dstPeer.Addr
-
- h := header{
- Counter: atomic.AddUint64(&w.counters[dstPeer.IP], 1),
- SourceIP: w.localIP,
- DestIP: dstPeer.IP,
- Stream: stream,
- }
-
- buf := encryptPacketAsym(&h, dstPeer.SharedKey, data, w.buf)
-
- if viaPeer != nil {
- h := header{
- Counter: atomic.AddUint64(&w.counters[viaPeer.IP], 1),
- SourceIP: w.localIP,
- DestIP: dstPeer.IP,
- Forward: 1,
- Stream: stream,
- }
-
- buf = encryptPacketAsym(&h, viaPeer.SharedKey, buf, w.buf2)
- addr = viaPeer.Addr
- }
-
- if _, err := w.WriteToUDPAddrPort(buf, addr); err != nil {
- log.Fatalf("Failed to write to UDP port: %v", err)
- }
- w.lock.Unlock()
-}
-
-// TODO: deprecated
-func (w *connWriter) Forward(dstIP byte, packet []byte) {
- dstPeer := w.routing.Get(dstIP)
- if dstPeer == nil || dstPeer.Addr == zeroAddrPort {
- log.Printf("No peer: %d", dstIP)
- return
- }
-
- if _, err := w.WriteToUDPAddrPort(packet, dstPeer.Addr); err != nil {
- log.Fatalf("Failed to write to UDP port: %v", err)
- }
-}
-
-// ----------------------------------------------------------------------------
-
-type connReader struct {
- *net.UDPConn
- localIP byte
- dupChecks [256]*dupCheck
- routing *routingTable
- buf []byte
-}
-
-func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader {
- r := &connReader{
- UDPConn: conn,
- localIP: localIP,
- routing: routing,
- buf: make([]byte, bufferSize),
- }
- for i := range r.dupChecks {
- r.dupChecks[i] = newDupCheck(0)
- }
- return r
-}
-
-func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte) {
- var (
- n int
- err error
- )
-
- for {
- n, remoteAddr, err = r.ReadFromUDPAddrPort(buf[:bufferSize])
- if err != nil {
- log.Fatalf("Failed to read from UDP port: %v", err)
- }
-
- data = buf[:n]
-
- if n < headerSize {
- continue // Packet it soo short.
- }
-
- h.Parse(data)
-
- peer := r.routing.Get(h.SourceIP)
- if peer == nil {
- continue
- }
-
- out, ok := decryptPacketAsym(peer.SharedKey, data, r.buf)
- if !ok {
- continue
- }
-
- out, data = data, out
-
- if r.dupChecks[h.SourceIP].IsDup(h.Counter) {
- log.Printf("Duplicate: %d", h.Counter)
- continue
- }
-
- return
- }
-}
diff --git a/node/crypto.go b/node/crypto.go
deleted file mode 100644
index 0f7710f..0000000
--- a/node/crypto.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package node
-
-import (
- "sync"
- "vppn/fasttime"
-
- "golang.org/x/crypto/nacl/box"
-)
-
-// Encrypting the packet will also set the header's DataSize field.
-func encryptPacketAsym(h *header, sharedKey, data, out []byte) []byte {
- out = out[:headerSize]
- h.Marshal(out)
- b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey))
- return out[:len(b)+headerSize]
-}
-
-func decryptPacketAsym(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) {
- return box.OpenAfterPrecomputation(
- out[:0],
- packetAndHeader[headerSize:],
- (*[24]byte)(packetAndHeader[:headerSize]),
- (*[32]byte)(sharedKey))
-}
-
-func computeSharedKey(peerPubKey, privKey []byte) []byte {
- shared := [32]byte{}
- box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey))
- return shared[:]
-}
-
-var (
- traceIDLock sync.Mutex
- traceIDTime uint64
- traceIDCounter uint64
-)
-
-func newTraceID() (id uint64) {
- traceIDLock.Lock()
- defer traceIDLock.Unlock()
-
- now := uint64(fasttime.Now())
- if traceIDTime < now {
- traceIDTime = now
- traceIDCounter = 0
- }
- traceIDCounter++
-
- return traceIDTime<<30 + traceIDCounter
-}
diff --git a/node/crypto_test.go b/node/crypto_test.go
deleted file mode 100644
index 76f408f..0000000
--- a/node/crypto_test.go
+++ /dev/null
@@ -1,137 +0,0 @@
-package node
-
-import (
- "bytes"
- "crypto/rand"
- "reflect"
- "testing"
-
- "golang.org/x/crypto/nacl/box"
-)
-
-func TestEncryptDecryptAsym(t *testing.T) {
- pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
-
- pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
-
- sharedEncKey := [32]byte{}
- box.Precompute(&sharedEncKey, pubKey2, privKey1)
-
- sharedDecKey := [32]byte{}
- box.Precompute(&sharedDecKey, pubKey1, privKey2)
-
- original := make([]byte, if_mtu-64)
- rand.Read(original)
-
- h := header{
- Counter: 2893749238,
- SourceIP: 5,
- DestIP: 12,
- Forward: 1,
- Stream: 1,
- }
-
- encrypted := make([]byte, bufferSize)
- encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted)
-
- decrypted := make([]byte, bufferSize)
- var ok bool
- decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted)
- if !ok {
- t.Fatal(ok)
- }
-
- var h2 header
- h2.Parse(encrypted)
-
- if !reflect.DeepEqual(h, h2) {
- t.Fatal(h, h2)
- }
-
- if !bytes.Equal(original, decrypted) {
- t.Fatal("mismatch")
- }
-}
-
-func BenchmarkEncryptAsym(b *testing.B) {
- _, privKey1, err := box.GenerateKey(rand.Reader)
- if err != nil {
- b.Fatal(err)
- }
-
- pubKey2, _, err := box.GenerateKey(rand.Reader)
- if err != nil {
- b.Fatal(err)
- }
-
- sharedEncKey := [32]byte{}
- box.Precompute(&sharedEncKey, pubKey2, privKey1)
-
- original := make([]byte, if_mtu)
- rand.Read(original)
-
- nonce := make([]byte, headerSize)
- rand.Read(nonce)
-
- encrypted := make([]byte, bufferSize)
-
- h := header{
- Counter: 2893749238,
- SourceIP: 5,
- DestIP: 12,
- Forward: 1,
- Stream: 1,
- }
-
- for i := 0; i < b.N; i++ {
- encrypted = encryptPacketAsym(&h, sharedEncKey[:], original, encrypted)
- }
-}
-
-func BenchmarkDecryptAsym(b *testing.B) {
- pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
- if err != nil {
- b.Fatal(err)
- }
-
- pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
- if err != nil {
- b.Fatal(err)
- }
-
- sharedEncKey := [32]byte{}
- box.Precompute(&sharedEncKey, pubKey2, privKey1)
-
- sharedDecKey := [32]byte{}
- box.Precompute(&sharedDecKey, pubKey1, privKey2)
-
- original := make([]byte, if_mtu)
- rand.Read(original)
-
- nonce := make([]byte, headerSize)
- rand.Read(nonce)
-
- h := header{
- Counter: 2893749238,
- SourceIP: 5,
- DestIP: 12,
- Forward: 1,
- Stream: 1,
- }
-
- encrypted := encryptPacketAsym(&h, sharedEncKey[:], original, make([]byte, bufferSize))
- decrypted := make([]byte, bufferSize)
- var ok bool
- for i := 0; i < b.N; i++ {
- decrypted, ok = decryptPacketAsym(sharedDecKey[:], encrypted, decrypted)
- if !ok {
- panic(ok)
- }
- }
-}
diff --git a/node/header.go b/node/header.go
index d2eb142..f2e300f 100644
--- a/node/header.go
+++ b/node/header.go
@@ -5,6 +5,8 @@ import "unsafe"
// ----------------------------------------------------------------------------
const (
+ headerSize = 12
+
controlStreamID = 2
controlHeaderSize = 24
controlCipherOverhead = 16
@@ -37,35 +39,3 @@ func (h *xHeader) Marshal(buf []byte) {
buf[10] = h.DestIP
buf[11] = 0
}
-
-// ----------------------------------------------------------------------------
-// TODO: Remove this code.
-const (
- headerSize = 24
- streamData = 1
- streamControl = 2
-)
-
-type header struct {
- Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
- SourceIP byte
- DestIP byte
- Forward byte
- Stream byte // See stream* constants.
-}
-
-func (hdr *header) Parse(nb []byte) {
- hdr.Counter = *(*uint64)(unsafe.Pointer(&nb[0]))
- hdr.SourceIP = nb[8]
- hdr.DestIP = nb[9]
- hdr.Forward = nb[10]
- hdr.Stream = nb[11]
-}
-
-func (hdr header) Marshal(buf []byte) {
- *(*uint64)(unsafe.Pointer(&buf[0])) = hdr.Counter
- buf[8] = hdr.SourceIP
- buf[9] = hdr.DestIP
- buf[10] = hdr.Forward
- buf[11] = hdr.Stream
-}
diff --git a/node/interface.go b/node/interface.go
index c5edf3e..e066b2b 100644
--- a/node/interface.go
+++ b/node/interface.go
@@ -51,7 +51,7 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error)
}
const (
- if_mtu = 1200
+ if_mtu = 1350
if_queue_len = 2048
)
diff --git a/node/main.go b/node/main.go
index f5c9bc7..00d7f9c 100644
--- a/node/main.go
+++ b/node/main.go
@@ -102,7 +102,7 @@ func main(netName, listenIP string, port uint16) {
log.Fatalf("Failed to open UDP port: %v", err)
}
- connWriter := newConnWriter2(conn)
+ connWriter := newConnWriter(conn)
ifWriter := newIFWriter(iface)
peers := remotePeers{}
diff --git a/node/node.go b/node/node.go
deleted file mode 100644
index 2b4023a..0000000
--- a/node/node.go
+++ /dev/null
@@ -1 +0,0 @@
-package node
diff --git a/node/peer.go b/node/peer.go
index 19cddfd..ab7ca77 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -24,7 +24,7 @@ type remotePeer struct {
privKey []byte
localPublic bool // True if local node is public.
iface *ifWriter
- conn *connWriter2
+ conn *connWriter
// Shared state.
shared *atomic.Pointer[peerData]
@@ -47,7 +47,7 @@ type remotePeer struct {
controlPackets chan controlPacket
}
-func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter2) *remotePeer {
+func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter) *remotePeer {
rp := &remotePeer{
localIP: conf.PeerIP,
remoteIP: remoteIP,
diff --git a/node/peerstate.go b/node/peerstate.go
deleted file mode 100644
index 2b4023a..0000000
--- a/node/peerstate.go
+++ /dev/null
@@ -1 +0,0 @@
-package node
diff --git a/node/peersupervisor.go b/node/peersupervisor.go
deleted file mode 100644
index 90763b4..0000000
--- a/node/peersupervisor.go
+++ /dev/null
@@ -1,319 +0,0 @@
-package node
-
-import (
- "fmt"
- "log"
- "net/netip"
- "time"
- "vppn/m"
-)
-
-type routingPacketWrapper struct {
- routingPacket
- Addr netip.AddrPort // Source.
-}
-
-type peerSupervisor struct {
- // Constants:
- localIP byte
- localPublic bool
- remoteIP byte
- privKey []byte
-
- // Shared data:
- w *connWriter
- table *routingTable
-
- packets chan routingPacketWrapper
- peerUpdates chan *m.Peer
-
- // Peer-related items.
- version int64 // Ony accessed in HandlePeerUpdate.
- peer *m.Peer
- remoteAddrPort netip.AddrPort
- mediated bool
- sharedKey []byte
-
- // Used by our state functions.
- pingTimer *time.Timer
- timeoutTimer *time.Timer
- buf []byte
-}
-
-// ----------------------------------------------------------------------------
-
-func newPeerSupervisor(
- conf m.PeerConfig,
- remoteIP byte,
- w *connWriter,
- table *routingTable,
-) *peerSupervisor {
- s := &peerSupervisor{
- localIP: conf.PeerIP,
- remoteIP: remoteIP,
- privKey: conf.EncPrivKey,
- w: w,
- table: table,
- packets: make(chan routingPacketWrapper, 256),
- peerUpdates: make(chan *m.Peer, 1),
- pingTimer: time.NewTimer(pingInterval),
- timeoutTimer: time.NewTimer(timeoutInterval),
- buf: make([]byte, bufferSize),
- }
-
- _, s.localPublic = netip.AddrFromSlice(conf.PublicIP)
-
- go s.mainLoop()
- return s
-}
-
-func (s *peerSupervisor) logf(msg string, args ...any) {
- msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg
- log.Printf(msg, args...)
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) mainLoop() {
- defer panicHandler()
- state := s.stateInit
- for {
- state = state()
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) {
- if p != nil {
- if p.Version == s.version {
- return
- }
- s.version = p.Version
- } else {
- s.version = 0
- }
-
- s.peerUpdates <- p
-}
-
-func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) {
- select {
- case s.packets <- w:
- default:
- // Drop
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateInit() stateFunc {
- if s.peer == nil {
- return s.stateDisconnected
- }
-
- addr, ok := netip.AddrFromSlice(s.peer.PublicIP)
- if ok {
- addrPort := netip.AddrPortFrom(addr, s.peer.Port)
- s.remoteAddrPort = addrPort
- } else {
- s.remoteAddrPort = zeroAddrPort
- }
- s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey)
-
- return s.stateSelectRole()
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateDisconnected() stateFunc {
- s.clearRoutingTable()
-
- for {
- select {
- case <-s.packets:
- // Drop
- case s.peer = <-s.peerUpdates:
- return s.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateSelectRole() stateFunc {
- s.logf("STATE: SelectRole")
- s.updateRoutingTable(false)
-
- if s.remoteAddrPort != zeroAddrPort {
- s.mediated = false
-
- // If both remote and local are public, one side acts as client, and one
- // side as server.
- if s.localPublic && s.localIP < s.peer.PeerIP {
- return s.stateAccept
- }
- return s.stateDial
- }
-
- // We're public, remote is not => can only wait for connection
- if s.localPublic {
- s.mediated = false
- return s.stateAccept
- }
-
- // Both non-public => need to use mediator.
- return s.stateMediated
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateAccept() stateFunc {
- s.logf("STATE: Accept")
-
- for {
-
- select {
- case pkt := <-s.packets:
- switch pkt.Type {
-
- case packetTypePing:
- s.remoteAddrPort = pkt.Addr
- s.updateRoutingTable(true)
- s.sendPong(pkt.TraceID)
- return s.stateConnected
-
- default:
- // Still waiting for ping...
- }
-
- case s.peer = <-s.peerUpdates:
- return s.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateDial() stateFunc {
- s.logf("STATE: Dial")
- s.updateRoutingTable(false)
-
- s.sendPing()
-
- for {
- select {
- case pkt := <-s.packets:
-
- switch pkt.Type {
-
- case packetTypePong:
- s.updateRoutingTable(true)
- return s.stateConnected
-
- default:
- // Ignore
- }
-
- case <-s.pingTimer.C:
- s.sendPing()
-
- case s.peer = <-s.peerUpdates:
- return s.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateConnected() stateFunc {
- s.logf("STATE: Connected")
-
- s.timeoutTimer.Reset(timeoutInterval)
-
- for {
- select {
-
- case <-s.pingTimer.C:
- s.sendPing()
-
- case <-s.timeoutTimer.C:
- s.logf("Timeout")
- return s.stateInit
-
- case pkt := <-s.packets:
- switch pkt.Type {
- case packetTypePing:
- s.sendPong(pkt.TraceID)
-
- // Server should always follow remote port.
- if s.localPublic {
- if pkt.Addr != s.remoteAddrPort {
- s.remoteAddrPort = pkt.Addr
- s.updateRoutingTable(true)
- }
- }
-
- case packetTypePong:
- s.timeoutTimer.Reset(timeoutInterval)
-
- default:
- // Drop packet.
- }
-
- case s.peer = <-s.peerUpdates:
- s.logf("New peer: %v", s.peer)
- return s.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) stateMediated() stateFunc {
- s.logf("STATE: Mediated")
- s.mediated = true
- s.updateRoutingTable(true)
-
- for {
- select {
- case <-s.packets:
- // Drop.
- case s.peer = <-s.peerUpdates:
- s.logf("New peer: %v", s.peer)
- return s.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) clearRoutingTable() {
- s.table.Set(s.remoteIP, nil)
-}
-
-func (s *peerSupervisor) updateRoutingTable(up bool) {
- s.table.Set(s.remoteIP, &peer{
- Up: up,
- Mediator: s.peer.Mediator,
- Mediated: s.mediated,
- IP: s.remoteIP,
- Addr: s.remoteAddrPort,
- SharedKey: s.sharedKey,
- })
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) sendPing() uint64 {
- traceID := newTraceID()
- pkt := newRoutingPacket(packetTypePing, traceID)
- s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf))
- s.pingTimer.Reset(pingInterval)
- return traceID
-}
-
-func (s *peerSupervisor) sendPong(traceID uint64) {
- pkt := newRoutingPacket(packetTypePong, traceID)
- s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf))
-}
diff --git a/node/router.go b/node/router.go
index 0e74d14..116b4d0 100644
--- a/node/router.go
+++ b/node/router.go
@@ -1,196 +1,7 @@
package node
import (
- "encoding/json"
- "io"
- "log"
- "net/http"
"net/netip"
- "net/url"
- "sync/atomic"
- "time"
- "vppn/m"
)
var zeroAddrPort = netip.AddrPort{}
-
-type peer struct {
- IP byte // The VPN IP.
- Up bool // No data will be sent to peers that are down.
- Addr netip.AddrPort // If we have direct connection, otherwise use mediator.
- Mediator bool // True if the peer will mediate.
- RoutingCipher controlCipher
- DataCipher dataCipher
-
- // TODO: Deprecated below.
- Mediated bool
- SharedKey []byte
-}
-
-// ----------------------------------------------------------------------------
-
-type routingTable struct {
- table [256]*atomic.Pointer[peer]
- mediator *atomic.Pointer[peer]
-}
-
-func newRoutingTable() *routingTable {
- r := routingTable{
- mediator: &atomic.Pointer[peer]{},
- }
-
- for i := range r.table {
- r.table[i] = &atomic.Pointer[peer]{}
- }
-
- return &r
-}
-
-func (r *routingTable) Get(ip byte) *peer {
- return r.table[ip].Load()
-}
-
-func (r *routingTable) Set(ip byte, p *peer) {
- r.table[ip].Store(p)
-}
-
-func (r *routingTable) Mediator() *peer {
- return r.mediator.Load()
-}
-
-// ----------------------------------------------------------------------------
-
-type router struct {
- *routingTable
- netName string
- peerSupers [256]*peerSupervisor
-}
-
-func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w *connWriter) *router {
- r := &router{
- netName: netName,
- routingTable: routingData,
- }
-
- for i := range r.peerSupers {
- r.peerSupers[i] = newPeerSupervisor(
- conf,
- byte(i),
- w,
- r.routingTable)
- }
-
- go r.selectMediator()
- go r.pollHub(conf)
-
- return r
-}
-
-// ----------------------------------------------------------------------------
-
-func (r *router) HandlePacket(sourceIP byte, remoteAddr netip.AddrPort, data []byte) {
- p := routingPacket{}
- if err := p.Parse(data); err != nil {
- log.Printf("Dropping malformed routing packet: %v", err)
- return
- }
-
- w := routingPacketWrapper{
- routingPacket: p,
- Addr: remoteAddr,
- }
-
- r.peerSupers[sourceIP].HandlePacket(w)
-}
-
-// ----------------------------------------------------------------------------
-
-func (r *router) pollHub(conf m.PeerConfig) {
- defer panicHandler()
-
- u, err := url.Parse(conf.HubAddress)
- if err != nil {
- log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
- }
- u.Path = "/peer/fetch-state/"
-
- client := &http.Client{Timeout: 8 * time.Second}
-
- req := &http.Request{
- Method: http.MethodGet,
- URL: u,
- Header: http.Header{},
- }
- req.SetBasicAuth("", conf.APIKey)
-
- state, err := loadNetworkState(r.netName)
- if err != nil {
- log.Printf("Failed to load network state: %v", err)
- log.Printf("Polling hub...")
- r._pollHub(conf, client, req)
- } else {
- r.applyNetworkState(conf, state)
- }
-
- for range time.Tick(64 * time.Second) {
- r._pollHub(conf, client, req)
- }
-}
-
-func (r *router) _pollHub(conf m.PeerConfig, client *http.Client, req *http.Request) {
- var state m.NetworkState
-
- log.Printf("Fetching peer state from %s...", conf.HubAddress)
- resp, err := client.Do(req)
- if err != nil {
- log.Printf("Failed to fetch peer state: %v", err)
- return
- }
- body, err := io.ReadAll(resp.Body)
- _ = resp.Body.Close()
- if err != nil {
- log.Printf("Failed to read body from hub: %v", err)
- return
- }
-
- if err := json.Unmarshal(body, &state); err != nil {
- log.Printf("Failed to unmarshal response from hub: %v", err)
- return
- }
-
- r.applyNetworkState(conf, state)
-
- if err := storeNetworkState(r.netName, state); err != nil {
- log.Printf("Failed to store network state: %v", err)
- }
-}
-
-func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) {
- for i := range state.Peers {
- if i != int(conf.PeerIP) {
- r.peerSupers[i].HandlePeerUpdate(state.Peers[i])
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (r *router) selectMediator() {
- for range time.Tick(8 * time.Second) {
- current := r.mediator.Load()
- if current != nil && current.Up {
- continue
- }
-
- for i := range r.table {
- peer := r.table[i].Load()
- if peer != nil && peer.Up && peer.Mediator {
- log.Printf("Got mediator: %v", *peer)
- r.mediator.Store(peer)
- return
- }
- }
-
- r.mediator.Store(nil)
- }
-}
diff --git a/node/routingpacket.go b/node/routingpacket.go
deleted file mode 100644
index 1b5aed1..0000000
--- a/node/routingpacket.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package node
-
-import (
- "unsafe"
-)
-
-type routingPacket struct {
- Type byte // One of the packetType* constants.
- TraceID uint64 // For matching requests and responses.
-}
-
-func newRoutingPacket(reqType byte, traceID uint64) routingPacket {
- return routingPacket{
- Type: reqType,
- TraceID: traceID,
- }
-}
-
-func (p routingPacket) Marshal(buf []byte) []byte {
- buf = buf[:32] // Reserve 32 bytes just in case we need to add anything.
- buf[0] = p.Type
- *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.TraceID)
- return buf
-}
-
-func (p *routingPacket) Parse(buf []byte) error {
- if len(buf) != 32 {
- return errMalformedPacket
- }
- p.Type = buf[0]
- p.TraceID = *(*uint64)(unsafe.Pointer(&buf[1]))
- return nil
-}
diff --git a/node/tmp-server.go b/node/tmp-server.go
deleted file mode 100644
index 179a8a4..0000000
--- a/node/tmp-server.go
+++ /dev/null
@@ -1,185 +0,0 @@
-package node
-
-/*
-var (
- network = []byte{10, 1, 1, 0}
- serverIP = byte(1)
- clientIP = byte(2)
- port = uint16(5151)
- netName = "testnet"
- pubKey1 = []byte{0x43, 0xde, 0xd4, 0xb2, 0x1d, 0x71, 0x58, 0x9a, 0x96, 0x3a, 0x23, 0xfc, 0x2, 0xe, 0xfa, 0x42, 0x3, 0x94, 0xbc, 0xf8, 0x25, 0xf, 0x54, 0xcc, 0x98, 0x42, 0x8b, 0xe5, 0x27, 0x86, 0x49, 0x33}
- privKey1 = []byte{0xae, 0x4d, 0xc5, 0xaa, 0xc9, 0xbc, 0x65, 0x41, 0x55, 0xb, 0x61, 0x52, 0xc4, 0x6c, 0xce, 0x2f, 0x1b, 0xf5, 0xb3, 0xbf, 0xb5, 0x54, 0x61, 0x7c, 0x26, 0x2e, 0xba, 0x5a, 0x19, 0xe2, 0x9c, 0xe0}
- pubKey2 = []byte{0x8c, 0xfe, 0x12, 0xd9, 0x2d, 0x37, 0x5, 0x43, 0xab, 0x70, 0x59, 0x20, 0x3d, 0x82, 0x93, 0x9b, 0xb3, 0xaa, 0x35, 0x23, 0xc1, 0xb4, 0x4, 0x1f, 0x92, 0x97, 0x6f, 0xfd, 0x55, 0x17, 0x5a, 0x4b}
- privKey2 = []byte{0xd9, 0xe1, 0xc6, 0x64, 0x3e, 0x29, 0x29, 0x78, 0x81, 0x53, 0xc2, 0x31, 0xd9, 0x34, 0x5b, 0x41, 0xf5, 0x80, 0xb0, 0x27, 0x9f, 0x65, 0x85, 0xd4, 0x78, 0xd5, 0x9, 0x2, 0xca, 0x56, 0x42, 0x80}
-)
-
-func must(err error) {
- if err != nil {
- panic(err)
- }
-}
-
-type TmpNode struct {
- network []byte
- localIP byte
- router *router
- port uint16
- netName string
- iface io.ReadWriteCloser
- pubKey []byte
- privKey []byte
- w *connWriter
- r *connReader
-}
-
-// ----------------------------------------------------------------------------
-
-func NewTmpNodeServer() *TmpNode {
- n := &TmpNode{
- localIP: serverIP,
- network: network,
- router: &router{table: newPeerRepo()},
- port: port,
- netName: netName,
- pubKey: pubKey1,
- privKey: privKey1,
- }
-
- var err error
- n.iface, err = openInterface(n.network, n.localIP, n.netName)
- must(err)
-
- myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port))
- must(err)
-
- conn, err := net.ListenUDP("udp", myAddr)
- must(err)
-
- n.w = newConnWriter(conn, n.localIP, n.router)
- n.r = newConnReader(conn, n.localIP, n.router)
-
- n.router.table.Set(clientIP, &peer{
- IP: clientIP,
- SharedKey: computeSharedKey(pubKey2, n.privKey),
- })
-
- return n
-}
-
-// ----------------------------------------------------------------------------
-
-func NewTmpNodeClient(srvAddrStr string) *TmpNode {
- n := &TmpNode{
- localIP: clientIP,
- network: network,
- router: &router{table: newPeerRepo()},
- port: port,
- netName: netName,
- pubKey: pubKey2,
- privKey: privKey2,
- }
-
- var err error
- n.iface, err = openInterface(n.network, n.localIP, n.netName)
- must(err)
-
- myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port))
- must(err)
-
- conn, err := net.ListenUDP("udp", myAddr)
- must(err)
-
- n.w = newConnWriter(conn, n.localIP, n.router)
- n.r = newConnReader(conn, n.localIP, n.router)
-
- serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port))
- must(err)
-
- n.router.table.Set(serverIP, &peer{
- IP: serverIP,
- Addr: &serverAddr,
- SharedKey: computeSharedKey(pubKey1, n.privKey),
- })
-
- return n
-}
-
-// ----------------------------------------------------------------------------
-
-func (n *TmpNode) RunServer() {
- defer func() {
- if r := recover(); r != nil {
- fmt.Printf("%v", r)
- debug.PrintStack()
- }
- }()
-
- // Get remoteAddr from a packet.
- buf := make([]byte, bufferSize)
- remoteAddr, h, _, err := n.r.Read(buf)
- must(err)
- log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr)
- must(err)
-
- n.router.table.Set(h.SourceIP, &peer{
- IP: h.SourceIP,
- Addr: &remoteAddr,
- SharedKey: computeSharedKey(pubKey2, n.privKey),
- })
-
- go n.readFromIFace()
- n.readFromConn()
-}
-
-// ----------------------------------------------------------------------------
-
-func (n *TmpNode) RunClient() {
- defer func() {
- if r := recover(); r != nil {
- fmt.Printf("%v\n", r)
- debug.PrintStack()
- }
- }()
-
- log.Printf("Sending to server...")
- must(n.w.WriteTo(serverIP, 1, []byte{1, 2, 3, 4, 5, 6, 7, 8}))
-
- go n.readFromIFace()
- n.readFromConn()
-}
-
-func (n *TmpNode) readFromIFace() {
- var (
- buf = make([]byte, bufferSize)
- packet []byte
- remoteIP byte
- err error
- )
-
- for {
- packet, remoteIP, err = readNextPacket(n.iface, buf)
- must(err)
- must(n.w.WriteTo(remoteIP, 1, packet))
- }
-}
-
-func (node *TmpNode) readFromConn() {
- var (
- buf = make([]byte, bufferSize)
- packet []byte
- err error
- )
-
- for {
- _, _, packet, err = node.r.Read(buf)
- must(err)
- // We assume that we're only receiving packets from one source.
-
- _, err = node.iface.Write(packet)
- if err != nil {
- log.Printf("Got error: %v", err)
- }
- //must(err)
- }
-}
-*/
--
2.39.5
From 1be5c791867895b2a5d7b3327520423368bbf6a1 Mon Sep 17 00:00:00 2001
From: jdl
Date: Fri, 20 Dec 2024 16:26:20 +0100
Subject: [PATCH 05/18] cleanup
---
node/cipher-control.go | 2 +-
node/cipher-control_test.go | 8 ++++----
node/cipher-data.go | 2 +-
node/cipher-data_test.go | 10 +++++-----
node/globals.go | 8 +++++++-
node/header.go | 23 +++++++++--------------
node/header_test.go | 4 ++--
node/interface.go | 5 -----
node/main.go | 2 +-
node/peer-supervisor.go | 2 +-
node/peer.go | 7 ++++---
11 files changed, 35 insertions(+), 38 deletions(-)
diff --git a/node/cipher-control.go b/node/cipher-control.go
index e9b56d5..bd11470 100644
--- a/node/cipher-control.go
+++ b/node/cipher-control.go
@@ -12,7 +12,7 @@ func newControlCipher(privKey, pubKey []byte) *controlCipher {
return &controlCipher{shared}
}
-func (cc *controlCipher) Encrypt(h xHeader, data, out []byte) []byte {
+func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte {
const s = controlHeaderSize
out = out[:s+controlCipherOverhead+len(data)]
h.Marshal(out[:s])
diff --git a/node/cipher-control_test.go b/node/cipher-control_test.go
index c571aa2..ab28860 100644
--- a/node/cipher-control_test.go
+++ b/node/cipher-control_test.go
@@ -40,7 +40,7 @@ func TestControlCipher(t *testing.T) {
}
for _, plaintext := range testCases {
- h1 := xHeader{
+ h1 := header{
StreamID: controlStreamID,
Counter: 235153,
SourceIP: 4,
@@ -51,7 +51,7 @@ func TestControlCipher(t *testing.T) {
encrypted = c1.Encrypt(h1, plaintext, encrypted)
- h2 := xHeader{}
+ h2 := header{}
h2.Parse(encrypted)
if !reflect.DeepEqual(h1, h2) {
t.Fatal(h1, h2)
@@ -80,7 +80,7 @@ func TestControlCipher_ShortCiphertext(t *testing.T) {
func BenchmarkControlCipher_Encrypt(b *testing.B) {
c1, _ := newControlCipherForTesting()
- h1 := xHeader{
+ h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -100,7 +100,7 @@ func BenchmarkControlCipher_Encrypt(b *testing.B) {
func BenchmarkControlCipher_Decrypt(b *testing.B) {
c1, c2 := newControlCipherForTesting()
- h1 := xHeader{
+ h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
diff --git a/node/cipher-data.go b/node/cipher-data.go
index 26d3121..7cdc0d5 100644
--- a/node/cipher-data.go
+++ b/node/cipher-data.go
@@ -39,7 +39,7 @@ func (sc *dataCipher) Key() [32]byte {
return sc.key
}
-func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte {
+func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte {
const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(out[:s])
diff --git a/node/cipher-data_test.go b/node/cipher-data_test.go
index c3892bb..493c198 100644
--- a/node/cipher-data_test.go
+++ b/node/cipher-data_test.go
@@ -22,7 +22,7 @@ func TestDataCipher(t *testing.T) {
}
for _, plaintext := range testCases {
- h1 := xHeader{
+ h1 := header{
StreamID: dataStreamID,
Counter: 235153,
SourceIP: 4,
@@ -33,7 +33,7 @@ func TestDataCipher(t *testing.T) {
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
- h2 := xHeader{}
+ h2 := header{}
h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key())
@@ -67,7 +67,7 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) {
}
for _, plaintext := range testCases {
- h1 := xHeader{
+ h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -99,7 +99,7 @@ func TestDataCipher_ShortCiphertext(t *testing.T) {
}
func BenchmarkDataCipher_Encrypt(b *testing.B) {
- h1 := xHeader{
+ h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@@ -118,7 +118,7 @@ func BenchmarkDataCipher_Encrypt(b *testing.B) {
}
func BenchmarkDataCipher_Decrypt(b *testing.B) {
- h1 := xHeader{
+ h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
diff --git a/node/globals.go b/node/globals.go
index 172e6ef..d646e71 100644
--- a/node/globals.go
+++ b/node/globals.go
@@ -1,3 +1,9 @@
package node
-const bufferSize = if_mtu + 128
+const (
+ bufferSize = 1536
+ if_mtu = 1200
+ if_queue_len = 2048
+ controlCipherOverhead = 16
+ dataCipherOverhead = 16
+)
diff --git a/node/header.go b/node/header.go
index f2e300f..97e5872 100644
--- a/node/header.go
+++ b/node/header.go
@@ -5,34 +5,29 @@ import "unsafe"
// ----------------------------------------------------------------------------
const (
- headerSize = 12
-
- controlStreamID = 2
- controlHeaderSize = 24
- controlCipherOverhead = 16
-
- dataStreamID = 1
- dataHeaderSize = 12
- dataCipherOverhead = 16
-
- forwardStreamID = 3
+ headerSize = 12
+ controlStreamID = 2
+ controlHeaderSize = 24
+ dataStreamID = 1
+ dataHeaderSize = 12
+ forwardStreamID = 3
)
-type xHeader struct {
+type header struct {
StreamID byte
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte
DestIP byte
}
-func (h *xHeader) Parse(b []byte) {
+func (h *header) Parse(b []byte) {
h.StreamID = b[0]
h.Counter = *(*uint64)(unsafe.Pointer(&b[1]))
h.SourceIP = b[9]
h.DestIP = b[10]
}
-func (h *xHeader) Marshal(buf []byte) {
+func (h *header) Marshal(buf []byte) {
buf[0] = h.StreamID
*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter
buf[9] = h.SourceIP
diff --git a/node/header_test.go b/node/header_test.go
index 0205d87..9dbb061 100644
--- a/node/header_test.go
+++ b/node/header_test.go
@@ -3,7 +3,7 @@ package node
import "testing"
func TestHeaderMarshalParse(t *testing.T) {
- nIn := xHeader{
+ nIn := header{
StreamID: 23,
Counter: 3212,
SourceIP: 34,
@@ -13,7 +13,7 @@ func TestHeaderMarshalParse(t *testing.T) {
buf := make([]byte, headerSize)
nIn.Marshal(buf)
- nOut := xHeader{}
+ nOut := header{}
nOut.Parse(buf)
if nIn != nOut {
t.Fatal(nIn, nOut)
diff --git a/node/interface.go b/node/interface.go
index e066b2b..4b492b4 100644
--- a/node/interface.go
+++ b/node/interface.go
@@ -50,11 +50,6 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error)
}
}
-const (
- if_mtu = 1350
- if_queue_len = 2048
-)
-
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
if len(network) != 4 {
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))
diff --git a/node/main.go b/node/main.go
index 00d7f9c..19252ff 100644
--- a/node/main.go
+++ b/node/main.go
@@ -141,7 +141,7 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) {
err error
buf = make([]byte, bufferSize)
data []byte
- h xHeader
+ h header
)
for {
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index cfcb43b..e4dd881 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -186,7 +186,7 @@ func (rp *peerSuper) updateShared() {
func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
buf := pkt.Marshal(rp.pktBuf)
- h := xHeader{
+ h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
diff --git a/node/peer.go b/node/peer.go
index ab7ca77..a344472 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -15,6 +15,7 @@ type peerData struct {
controlCipher *controlCipher
dataCipher *dataCipher
remoteAddr netip.AddrPort
+ relayIP byte // Non-zero if we should relay.
}
type remotePeer struct {
@@ -88,7 +89,7 @@ func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
// HandlePacket accepts a raw data packet coming in from the network.
//
// This function is called by a single thread.
-func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h xHeader, data []byte) {
+func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
switch h.StreamID {
case controlStreamID:
rp.handleControlPacket(addr, h, data)
@@ -107,7 +108,7 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h xHeader, data []byte)
// ----------------------------------------------------------------------------
-func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h xHeader, data []byte) {
+func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
shared := rp.shared.Load()
if shared.controlCipher == nil {
rp.logf("Not connected (control).")
@@ -194,7 +195,7 @@ func (rp *remotePeer) SendData(data []byte) {
return
}
- h := xHeader{
+ h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
--
2.39.5
From 5b34b3311b92cf4ee6a899e2a13e61210547c619 Mon Sep 17 00:00:00 2001
From: jdl
Date: Fri, 20 Dec 2024 17:11:20 +0100
Subject: [PATCH 06/18] wip: trying to get relaying to work.
---
node/main.go | 13 ++++--
node/peer-supervisor.go | 100 ++++++++++++++++++++++++++++++++++++----
node/peer.go | 56 ++++++++++++++++------
3 files changed, 142 insertions(+), 27 deletions(-)
diff --git a/node/main.go b/node/main.go
index 19252ff..9273823 100644
--- a/node/main.go
+++ b/node/main.go
@@ -108,11 +108,11 @@ func main(netName, listenIP string, port uint16) {
peers := remotePeers{}
for i := range peers {
- peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter)
+ peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter, &peers)
}
go newHubPoller(netName, conf, peers).Run()
- go readFromConn(conn, peers)
+ go readFromConn(conf.PeerIP, conn, peers)
readFromIFace(iface, peers)
}
@@ -131,7 +131,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ----------------------------------------------------------------------------
-func readFromConn(conn *net.UDPConn, peers remotePeers) {
+func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) {
defer panicHandler()
@@ -157,7 +157,12 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) {
}
h.Parse(data)
- peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
+
+ if h.DestIP == localIP {
+ peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
+ } else {
+ peers[h.DestIP].ForwardPacket(data)
+ }
}
}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index e4dd881..cc615ab 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -1,6 +1,8 @@
package node
import (
+ "log"
+ "math/rand"
"net/netip"
"sync/atomic"
"time"
@@ -47,12 +49,15 @@ func (rp *peerSuper) Run() {
func (rp *peerSuper) stateInit() stateFunc {
//rp.logf("STATE: Init")
+
x := peerData{}
rp.shared.Store(&x)
+ rp.peerData.relay = false
rp.peerData.controlCipher = nil
rp.peerData.dataCipher = nil
rp.peerData.remoteAddr = zeroAddrPort
+ rp.peerData.relayIP = 0
if rp.peer == nil {
return rp.stateDisconnected
@@ -62,6 +67,8 @@ func (rp *peerSuper) stateInit() stateFunc {
addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
if rp.remotePublic {
rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
+ } else {
+ rp.peerData.relay = false
}
rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
@@ -89,8 +96,7 @@ func (rp *peerSuper) stateSelectRole() stateFunc {
rp.logf("STATE: SelectRole")
if !rp.localPublic && !rp.remotePublic {
- // TODO!
- return rp.stateDisconnected
+ return rp.stateSelectMediator
}
if !rp.localPublic {
@@ -99,12 +105,55 @@ func (rp *peerSuper) stateSelectRole() stateFunc {
return rp.stateClient
}
- if rp.localIP < rp.peer.PeerIP {
+ if rp.localIP < rp.remoteIP {
return rp.stateClient
}
return rp.stateServer
}
+// ----------------------------------------------------------------------------
+
+func (rp *peerSuper) stateSelectMediator() stateFunc {
+ rp.logf("STATE: SelectMediator")
+
+ for {
+ log.Printf("Selecting mediator...")
+ if ip := rp.selectMediator(); ip != 0 {
+ rp.logf("Got mediator: %d", ip)
+ rp.peerData.relayIP = ip
+
+ if rp.localIP < rp.remoteIP {
+ return rp.stateClient
+ }
+ return rp.stateServer
+ }
+
+ select {
+ case <-time.After(pingInterval):
+ continue
+ case rp.peer = <-rp.peerUpdates:
+ return rp.stateInit
+ }
+ }
+
+}
+
+func (rp *peerSuper) selectMediator() byte {
+ possible := make([]byte, 0, 8)
+ for _, peer := range rp.peers {
+ if peer.canRelay() {
+ rp.logf("relay: %v", peer.shared.Load())
+ possible = append(possible, peer.remoteIP)
+ }
+ }
+ if len(possible) == 0 {
+ return 0
+ }
+ return possible[rand.Intn(len(possible))]
+}
+
+// ----------------------------------------------------------------------------
+
// The remote is a server.
func (rp *peerSuper) stateServer() stateFunc {
rp.logf("STATE: Server")
@@ -112,10 +161,12 @@ func (rp *peerSuper) stateServer() stateFunc {
rp.updateShared()
var (
- pingTimer = time.NewTimer(pingInterval)
- ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
+ pingTimer = time.NewTimer(pingInterval)
+ timeoutTimer = time.NewTimer(timeoutInterval)
+ ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
)
defer pingTimer.Stop()
+ defer timeoutTimer.Stop()
ping.SentAt = time.Now().UnixMilli()
rp.sendControlPacket(ping)
@@ -127,8 +178,18 @@ func (rp *peerSuper) stateServer() stateFunc {
rp.sendControlPacket(ping)
pingTimer.Reset(pingInterval)
- case <-rp.controlPackets:
- // Ignore
+ case cPkt := <-rp.controlPackets:
+ if _, ok := cPkt.Payload.(pongPacket); ok {
+ timeoutTimer.Reset(timeoutInterval)
+ }
+
+ case <-timeoutTimer.C:
+ if rp.peerData.relayIP != 0 {
+ rp.logf("Timeout (server, relay)")
+ return rp.stateSelectMediator
+ } else {
+ rp.logf("Timeout (server)")
+ }
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
@@ -143,8 +204,12 @@ func (rp *peerSuper) stateClient() stateFunc {
rp.logf("STATE: Client")
rp.updateShared()
- // TODO: Could use timeout to set dataCipher to nil.
- var currentKey = [32]byte{}
+ var (
+ currentKey = [32]byte{}
+ timeoutTimer = time.NewTimer(timeoutInterval)
+ )
+
+ defer timeoutTimer.Stop()
for {
select {
@@ -163,12 +228,22 @@ func (rp *peerSuper) stateClient() stateFunc {
if ping.SharedKey != currentKey {
rp.logf("Connected with new shared key")
currentKey = ping.SharedKey
+ rp.peerData.up = true
rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
rp.updateShared()
}
+ timeoutTimer.Reset(timeoutInterval)
rp.sendControlPacket(newPongPacket(ping.SentAt))
+ case <-timeoutTimer.C:
+ if rp.peerData.relayIP != 0 {
+ rp.logf("Timeout (server, relay)")
+ return rp.stateSelectMediator
+ } else {
+ rp.logf("Timeout (server)")
+ }
+
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
@@ -193,5 +268,10 @@ func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte })
DestIP: rp.remoteIP,
}
buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
- rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
+ if rp.peerData.relayIP == 0 {
+ rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
+ return
+ }
+
+ rp.peers[rp.peerData.relayIP].RelayControlData(buf)
}
diff --git a/node/peer.go b/node/peer.go
index a344472..d999339 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -12,6 +12,8 @@ import (
type remotePeers [256]*remotePeer
type peerData struct {
+ up bool
+ relay bool
controlCipher *controlCipher
dataCipher *dataCipher
remoteAddr netip.AddrPort
@@ -28,6 +30,7 @@ type remotePeer struct {
conn *connWriter
// Shared state.
+ peers *remotePeers
shared *atomic.Pointer[peerData]
// Only used in HandlePeerUpdate.
@@ -48,7 +51,7 @@ type remotePeer struct {
controlPackets chan controlPacket
}
-func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter) *remotePeer {
+func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter, peers *remotePeers) *remotePeer {
rp := &remotePeer{
localIP: conf.PeerIP,
remoteIP: remoteIP,
@@ -56,6 +59,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
localPublic: addrIsValid(conf.PublicIP),
iface: iface,
conn: conn,
+ peers: peers,
shared: &atomic.Pointer[peerData]{},
dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize),
@@ -97,10 +101,6 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
case dataStreamID:
rp.handleDataPacket(data)
- case forwardStreamID:
- fallthrough
- // TODO
- //rp.handleForwardPacket(h, data)
default:
rp.logf("Unknown stream ID: %d", h.StreamID)
}
@@ -115,6 +115,11 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
return
}
+ if h.DestIP != rp.localIP {
+ rp.logf("Incorrect destination IP on control packet.")
+ return
+ }
+
out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt control packet.")
@@ -131,13 +136,6 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
return
}
- if h.DestIP != rp.localIP {
- // TODO: Forward control packet.
- // TODO: Probably this should be dropped.
- // Control packets should be forwarded as data for efficiency.
- return
- }
-
pkt := controlPacket{
SrcIP: h.SourceIP,
RemoteAddr: addr,
@@ -167,6 +165,8 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
}
}
+// ----------------------------------------------------------------------------
+
func (rp *remotePeer) handleDataPacket(data []byte) {
shared := rp.shared.Load()
if shared.dataCipher == nil {
@@ -189,6 +189,29 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
//
// This function is called by a single thread.
func (rp *remotePeer) SendData(data []byte) {
+ rp.sendData(dataStreamID, data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) RelayControlData(data []byte) {
+ rp.sendData(forwardStreamID, data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) ForwardPacket(data []byte) {
+ shared := rp.shared.Load()
+ if shared.remoteAddr == zeroAddrPort {
+ rp.logf("Not connected (forward).")
+ return
+ }
+ rp.conn.WriteTo(data, shared.remoteAddr)
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) sendData(streamID byte, data []byte) {
shared := rp.shared.Load()
if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send).")
@@ -196,7 +219,7 @@ func (rp *remotePeer) SendData(data []byte) {
}
h := header{
- StreamID: dataStreamID,
+ StreamID: streamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
DestIP: rp.remoteIP,
@@ -205,3 +228,10 @@ func (rp *remotePeer) SendData(data []byte) {
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
rp.conn.WriteTo(enc, shared.remoteAddr)
}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) canRelay() bool {
+ shared := rp.shared.Load()
+ return shared.relay && shared.up
+}
--
2.39.5
From c7d3fe1ed8961714916d2433efe50c50ebf944cd Mon Sep 17 00:00:00 2001
From: jdl
Date: Fri, 20 Dec 2024 21:06:16 +0100
Subject: [PATCH 07/18] wip: client/server working.
---
node/conn.go | 2 +
node/header.go | 2 +-
node/main.go | 10 +-
node/packets.go | 21 ++-
node/packets_test.go | 2 +-
node/peer-states.go | 214 +++++++++++++++++++++++++++++
node/peer-supervisor.go | 291 ++++++----------------------------------
node/peer.go | 129 ++++++++++--------
8 files changed, 354 insertions(+), 317 deletions(-)
create mode 100644 node/peer-states.go
diff --git a/node/conn.go b/node/conn.go
index 344d8d5..7671f36 100644
--- a/node/conn.go
+++ b/node/conn.go
@@ -5,6 +5,7 @@ import (
"log"
"net"
"net/netip"
+ "runtime/debug"
"sync"
)
@@ -22,6 +23,7 @@ func newConnWriter(conn *net.UDPConn) *connWriter {
func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) {
w.lock.Lock()
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
+ debug.PrintStack()
log.Fatalf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()
diff --git a/node/header.go b/node/header.go
index 97e5872..1a022a2 100644
--- a/node/header.go
+++ b/node/header.go
@@ -10,7 +10,7 @@ const (
controlHeaderSize = 24
dataStreamID = 1
dataHeaderSize = 12
- forwardStreamID = 3
+ relayStreamID = 3
)
type header struct {
diff --git a/node/main.go b/node/main.go
index 9273823..35a00e6 100644
--- a/node/main.go
+++ b/node/main.go
@@ -114,7 +114,6 @@ func main(netName, listenIP string, port uint16) {
go newHubPoller(netName, conf, peers).Run()
go readFromConn(conf.PeerIP, conn, peers)
readFromIFace(iface, peers)
-
}
// ----------------------------------------------------------------------------
@@ -157,12 +156,7 @@ func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) {
}
h.Parse(data)
-
- if h.DestIP == localIP {
- peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
- } else {
- peers[h.DestIP].ForwardPacket(data)
- }
+ peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
}
}
@@ -183,6 +177,6 @@ func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) {
log.Fatalf("Failed to read from interface: %v", err)
}
- peers[remoteIP].SendData(packet)
+ peers[remoteIP].HandleInterfacePacket(packet)
}
}
diff --git a/node/packets.go b/node/packets.go
index d197f58..57c7341 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -7,7 +7,10 @@ import (
"unsafe"
)
-var errMalformedPacket = errors.New("malformed packet")
+var (
+ errMalformedPacket = errors.New("malformed packet")
+ errUnknownPacketType = errors.New("unknown packet type")
+)
const (
packetTypePing = iota + 1
@@ -22,6 +25,18 @@ type controlPacket struct {
Payload any
}
+func (p *controlPacket) ParsePayload(buf []byte) (err error) {
+ switch buf[0] {
+ case packetTypePing:
+ p.Payload, err = parsePingPacket(buf)
+ case packetTypePong:
+ p.Payload, err = parsePongPacket(buf)
+ default:
+ return errUnknownPacketType
+ }
+ return err
+}
+
// ----------------------------------------------------------------------------
// A pingPacket is sent from a node acting as a client, to a node acting
@@ -32,9 +47,9 @@ type pingPacket struct {
SharedKey [32]byte
}
-func newPingPacket(sharedKey []byte) (pp pingPacket) {
+func newPingPacket(sharedKey [32]byte) (pp pingPacket) {
pp.SentAt = time.Now().UnixMilli()
- copy(pp.SharedKey[:], sharedKey)
+ copy(pp.SharedKey[:], sharedKey[:])
return
}
diff --git a/node/packets_test.go b/node/packets_test.go
index b385c2b..da242d4 100644
--- a/node/packets_test.go
+++ b/node/packets_test.go
@@ -12,7 +12,7 @@ func TestPacketPing(t *testing.T) {
buf := make([]byte, bufferSize)
- p := newPingPacket(sharedKey)
+ p := newPingPacket([32]byte(sharedKey))
out := p.Marshal(buf)
p2, err := parsePingPacket(out)
diff --git a/node/peer-states.go b/node/peer-states.go
new file mode 100644
index 0000000..c3c0904
--- /dev/null
+++ b/node/peer-states.go
@@ -0,0 +1,214 @@
+package node
+
+import (
+ "fmt"
+ "log"
+ "net/netip"
+ "sync/atomic"
+ "time"
+ "vppn/m"
+)
+
+type peerState interface {
+ Name() string
+ OnPeerUpdate(*m.Peer) peerState
+ OnPing(netip.AddrPort, pingPacket) peerState
+ OnPong(netip.AddrPort, pongPacket) peerState
+ OnPingTimer() peerState
+ OnTimeoutTimer() peerState
+}
+
+// ----------------------------------------------------------------------------
+
+type stateBase struct {
+ // The purpose of this state machine is to manage this published data.
+ published *atomic.Pointer[peerData]
+
+ // The other remote peers.
+ peers *remotePeers
+
+ // Immutable data.
+ localIP byte
+ localPub bool
+ remoteIP byte
+ privKey []byte
+ conn *connWriter
+
+ // For sending to peer.
+ counter *uint64
+
+ // Mutable peer data.
+ peer *m.Peer
+ remotePub bool
+ data peerData // Local copy of shared data. See publish().
+
+ // Timers
+ pingTimer *time.Timer
+ timeoutTimer *time.Timer
+
+ buf []byte
+ encBuf []byte
+}
+
+func (sb *stateBase) Name() string { return "idle" }
+
+func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
+ // Both nil: no change.
+ if peer == nil && s.peer == nil {
+ return nil
+ }
+
+ // No change.
+ if peer != nil && s.peer != nil && s.peer.Version == peer.Version {
+ return nil
+ }
+
+ s.peer = peer
+
+ s.data = peerData{}
+ s.data.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
+
+ ip, isValid := netip.AddrFromSlice(peer.PublicIP)
+ if isValid {
+ s.remotePub = true
+ s.data.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
+ s.data.relay = peer.Mediator
+
+ if s.localPub && s.localIP < s.remoteIP {
+ return newStateServer(s)
+ }
+ return newStateClient(s)
+ }
+
+ if s.localPub {
+ return newStateServer(s)
+ }
+
+ // TODO: return newStateMediated(a/b)
+
+ return nil
+}
+
+func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
+func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil }
+func (s *stateBase) OnPingTimer() peerState { return nil }
+func (s *stateBase) OnTimeoutTimer() peerState { return nil }
+
+// Helpers.
+
+func (s *stateBase) resetPingTimer() { s.pingTimer.Reset(pingInterval) }
+func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) }
+func (s *stateBase) stopPingTimer() { s.pingTimer.Stop() }
+func (s *stateBase) stopTimeoutTimer() { s.timeoutTimer.Stop() }
+
+func (s *stateBase) logf(msg string, args ...any) {
+ log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
+}
+
+func (s *stateBase) publish() {
+ data := s.data
+ s.published.Store(&data)
+}
+
+func (s *stateBase) sendPing(sharedKey [32]byte) {
+ s.sendControlPacket(newPingPacket(sharedKey))
+}
+
+func (s *stateBase) sendPong(ping pingPacket) {
+ s.sendControlPacket(newPongPacket(ping.SentAt))
+}
+
+func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
+ buf := pkt.Marshal(s.buf)
+ h := header{
+ StreamID: controlStreamID,
+ Counter: atomic.AddUint64(s.counter, 1),
+ SourceIP: s.localIP,
+ DestIP: s.remoteIP,
+ }
+
+ buf = s.data.controlCipher.Encrypt(h, buf, s.encBuf)
+ if s.data.relayIP == 0 {
+ s.conn.WriteTo(buf, s.data.remoteAddr)
+ return
+ }
+
+ // TODO: Relay!
+}
+
+// ----------------------------------------------------------------------------
+
+type stateClient struct {
+ sharedKey [32]byte
+ *stateBase
+}
+
+func newStateClient(b *stateBase) peerState {
+ s := &stateClient{stateBase: b}
+ s.publish()
+
+ s.data.dataCipher = newDataCipher()
+ s.sharedKey = s.data.dataCipher.Key()
+
+ s.sendPing(s.sharedKey)
+ s.resetPingTimer()
+ s.resetTimeoutTimer()
+ return s
+}
+
+func (s *stateClient) Name() string { return "client" }
+
+func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState {
+ if !s.data.up {
+ s.data.up = true
+ s.publish()
+ }
+ s.resetTimeoutTimer()
+ return nil
+}
+
+func (s *stateClient) OnPingTimer() peerState {
+ s.sendPing(s.sharedKey)
+ s.resetPingTimer()
+ return nil
+}
+
+func (s *stateClient) OnTimeoutTimer() peerState {
+ s.data.up = false
+ s.publish()
+ return nil
+}
+
+// ----------------------------------------------------------------------------
+
+type stateServer struct {
+ *stateBase
+}
+
+func newStateServer(b *stateBase) peerState {
+ s := &stateServer{b}
+ s.publish()
+ s.stopPingTimer()
+ s.stopTimeoutTimer()
+ return s
+}
+
+func (s *stateServer) Name() string { return "server" }
+
+func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState {
+ if addr != s.data.remoteAddr {
+ s.logf("Got new peer address: %v", addr)
+ s.data.remoteAddr = addr
+ s.data.up = true
+ s.publish()
+ }
+
+ if s.data.dataCipher == nil || p.SharedKey != s.data.dataCipher.Key() {
+ s.logf("Got new shared key.")
+ s.data.dataCipher = newDataCipherFromKey(p.SharedKey)
+ s.publish()
+ }
+
+ s.sendPong(p)
+ return nil
+}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index cc615ab..2c46ad2 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -1,10 +1,6 @@
package node
import (
- "log"
- "math/rand"
- "net/netip"
- "sync/atomic"
"time"
"vppn/m"
)
@@ -15,263 +11,64 @@ const (
timeoutInterval = 20 * time.Second
)
-type stateFunc func() stateFunc
-
-type peerSuper struct {
- *remotePeer
-
- peer *m.Peer
- remotePublic bool
- peerData peerData
-
- pktBuf []byte
- encBuf []byte
-}
-
-func newPeerSuper(rp *remotePeer) *peerSuper {
- return &peerSuper{
- remotePeer: rp,
- peer: nil,
- pktBuf: make([]byte, bufferSize),
- encBuf: make([]byte, bufferSize),
- }
-}
-
-func (rp *peerSuper) Run() {
+func (rp *remotePeer) supervise(
+ conf m.PeerConfig,
+ remoteIP byte,
+ conn *connWriter,
+ peers *remotePeers,
+) {
defer panicHandler()
- state := rp.stateInit
- for {
- state = state()
- }
-}
-// ----------------------------------------------------------------------------
-
-func (rp *peerSuper) stateInit() stateFunc {
- //rp.logf("STATE: Init")
-
- x := peerData{}
- rp.shared.Store(&x)
-
- rp.peerData.relay = false
- rp.peerData.controlCipher = nil
- rp.peerData.dataCipher = nil
- rp.peerData.remoteAddr = zeroAddrPort
- rp.peerData.relayIP = 0
-
- if rp.peer == nil {
- return rp.stateDisconnected
+ base := &stateBase{
+ published: rp.published,
+ peers: rp.peers,
+ localIP: rp.localIP,
+ remoteIP: rp.remoteIP,
+ privKey: conf.EncPrivKey,
+ localPub: addrIsValid(conf.PublicIP),
+ conn: rp.conn,
+ counter: &rp.counter,
+ pingTimer: time.NewTimer(time.Second),
+ timeoutTimer: time.NewTimer(time.Second),
+ buf: make([]byte, bufferSize),
+ encBuf: make([]byte, bufferSize),
}
- var addr netip.Addr
- addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
- if rp.remotePublic {
- rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
- } else {
- rp.peerData.relay = false
- }
-
- rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
-
- return rp.stateSelectRole()
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *peerSuper) stateDisconnected() stateFunc {
- //rp.logf("STATE: Disconnected")
- for {
- select {
- case <-rp.controlPackets:
- // Drop
- case rp.peer = <-rp.peerUpdates:
- return rp.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *peerSuper) stateSelectRole() stateFunc {
- rp.logf("STATE: SelectRole")
-
- if !rp.localPublic && !rp.remotePublic {
- return rp.stateSelectMediator
- }
-
- if !rp.localPublic {
- return rp.stateServer
- } else if !rp.remotePublic {
- return rp.stateClient
- }
-
- if rp.localIP < rp.remoteIP {
- return rp.stateClient
- }
- return rp.stateServer
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *peerSuper) stateSelectMediator() stateFunc {
- rp.logf("STATE: SelectMediator")
-
- for {
- log.Printf("Selecting mediator...")
- if ip := rp.selectMediator(); ip != 0 {
- rp.logf("Got mediator: %d", ip)
- rp.peerData.relayIP = ip
-
- if rp.localIP < rp.remoteIP {
- return rp.stateClient
- }
- return rp.stateServer
- }
-
- select {
- case <-time.After(pingInterval):
- continue
- case rp.peer = <-rp.peerUpdates:
- return rp.stateInit
- }
- }
-
-}
-
-func (rp *peerSuper) selectMediator() byte {
- possible := make([]byte, 0, 8)
- for _, peer := range rp.peers {
- if peer.canRelay() {
- rp.logf("relay: %v", peer.shared.Load())
- possible = append(possible, peer.remoteIP)
- }
- }
- if len(possible) == 0 {
- return 0
- }
- return possible[rand.Intn(len(possible))]
-}
-
-// ----------------------------------------------------------------------------
-
-// The remote is a server.
-func (rp *peerSuper) stateServer() stateFunc {
- rp.logf("STATE: Server")
- rp.peerData.dataCipher = newDataCipher()
- rp.updateShared()
+ base.pingTimer.Stop()
+ base.timeoutTimer.Stop()
var (
- pingTimer = time.NewTimer(pingInterval)
- timeoutTimer = time.NewTimer(timeoutInterval)
- ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
+ curState peerState = base
+ nextState peerState
)
- defer pingTimer.Stop()
- defer timeoutTimer.Stop()
-
- ping.SentAt = time.Now().UnixMilli()
- rp.sendControlPacket(ping)
for {
+ nextState = nil
+
select {
- case <-pingTimer.C:
- ping.SentAt = time.Now().UnixMilli()
- rp.sendControlPacket(ping)
- pingTimer.Reset(pingInterval)
+ case peer := <-rp.peerUpdates:
+ nextState = curState.OnPeerUpdate(peer)
- case cPkt := <-rp.controlPackets:
- if _, ok := cPkt.Payload.(pongPacket); ok {
- timeoutTimer.Reset(timeoutInterval)
+ case pkt := <-rp.controlPackets:
+ switch p := pkt.Payload.(type) {
+ case pingPacket:
+ nextState = curState.OnPing(pkt.RemoteAddr, p)
+ case pongPacket:
+ nextState = curState.OnPong(pkt.RemoteAddr, p)
+ default:
+ // Unknown packet type.
}
- case <-timeoutTimer.C:
- if rp.peerData.relayIP != 0 {
- rp.logf("Timeout (server, relay)")
- return rp.stateSelectMediator
- } else {
- rp.logf("Timeout (server)")
- }
+ case <-base.pingTimer.C:
+ nextState = curState.OnPingTimer()
- case rp.peer = <-rp.peerUpdates:
- return rp.stateInit
+ case <-base.timeoutTimer.C:
+ nextState = curState.OnTimeoutTimer()
+ }
+
+ if nextState != nil {
+ rp.logf("%s --> %s", curState.Name(), nextState.Name())
+ curState = nextState
}
}
}
-
-// ----------------------------------------------------------------------------
-
-// The remote is a client.
-func (rp *peerSuper) stateClient() stateFunc {
- rp.logf("STATE: Client")
- rp.updateShared()
-
- var (
- currentKey = [32]byte{}
- timeoutTimer = time.NewTimer(timeoutInterval)
- )
-
- defer timeoutTimer.Stop()
-
- for {
- select {
- case cPkt := <-rp.controlPackets:
- if cPkt.RemoteAddr != rp.peerData.remoteAddr {
- rp.peerData.remoteAddr = cPkt.RemoteAddr
- rp.logf("Got new remote address: %v", cPkt.RemoteAddr)
- rp.updateShared()
- }
-
- ping, ok := cPkt.Payload.(pingPacket)
- if !ok {
- continue
- }
-
- if ping.SharedKey != currentKey {
- rp.logf("Connected with new shared key")
- currentKey = ping.SharedKey
- rp.peerData.up = true
- rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
- rp.updateShared()
- }
-
- timeoutTimer.Reset(timeoutInterval)
- rp.sendControlPacket(newPongPacket(ping.SentAt))
-
- case <-timeoutTimer.C:
- if rp.peerData.relayIP != 0 {
- rp.logf("Timeout (server, relay)")
- return rp.stateSelectMediator
- } else {
- rp.logf("Timeout (server)")
- }
-
- case rp.peer = <-rp.peerUpdates:
- return rp.stateInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *peerSuper) updateShared() {
- data := rp.peerData
- rp.shared.Store(&data)
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
- buf := pkt.Marshal(rp.pktBuf)
- h := header{
- StreamID: controlStreamID,
- Counter: atomic.AddUint64(&rp.counter, 1),
- SourceIP: rp.localIP,
- DestIP: rp.remoteIP,
- }
- buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
- if rp.peerData.relayIP == 0 {
- rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
- return
- }
-
- rp.peers[rp.peerData.relayIP].RelayControlData(buf)
-}
diff --git a/node/peer.go b/node/peer.go
index d999339..3cc308e 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -22,19 +22,14 @@ type peerData struct {
type remotePeer struct {
// Immutable data.
- localIP byte
- remoteIP byte
- privKey []byte
- localPublic bool // True if local node is public.
- iface *ifWriter
- conn *connWriter
+ localIP byte
+ remoteIP byte
+ iface *ifWriter
+ conn *connWriter
// Shared state.
- peers *remotePeers
- shared *atomic.Pointer[peerData]
-
- // Only used in HandlePeerUpdate.
- peerVersion int64
+ peers *remotePeers
+ published *atomic.Pointer[peerData]
// Only used in HandlePacket / Not synchronized.
dupCheck *dupCheck
@@ -55,12 +50,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
rp := &remotePeer{
localIP: conf.PeerIP,
remoteIP: remoteIP,
- privKey: conf.EncPrivKey,
- localPublic: addrIsValid(conf.PublicIP),
iface: iface,
conn: conn,
peers: peers,
- shared: &atomic.Pointer[peerData]{},
+ published: &atomic.Pointer[peerData]{},
dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize),
encryptBuf: make([]byte, bufferSize),
@@ -70,10 +63,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
}
pd := peerData{}
- rp.shared.Store(&pd)
-
- go newPeerSuper(rp).Run()
+ rp.published.Store(&pd)
+ //go newPeerSuper(rp).Run()
+ go rp.supervise(conf, remoteIP, conn, peers)
return rp
}
@@ -82,10 +75,7 @@ func (rp *remotePeer) logf(msg string, args ...any) {
}
func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
- if peer != nil && peer.Version != rp.peerVersion {
- rp.peerUpdates <- peer
- rp.peerVersion = peer.Version
- }
+ rp.peerUpdates <- peer
}
// ----------------------------------------------------------------------------
@@ -101,6 +91,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
case dataStreamID:
rp.handleDataPacket(data)
+ case relayStreamID:
+ rp.handleRelayPacket(h, data)
+
default:
rp.logf("Unknown stream ID: %d", h.StreamID)
}
@@ -109,8 +102,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
- shared := rp.shared.Load()
+ shared := rp.published.Load()
if shared.controlCipher == nil {
+ log.Printf("Shared: %+v", *shared)
rp.logf("Not connected (control).")
return
}
@@ -141,19 +135,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
RemoteAddr: addr,
}
- var err error
-
- switch out[0] {
- case packetTypePing:
- pkt.Payload, err = parsePingPacket(out)
- case packetTypePong:
- pkt.Payload, err = parsePongPacket(out)
- default:
- rp.logf("Unknown control packet type: %d", out[0])
- return
- }
-
- if err != nil {
+ if err := pkt.ParsePayload(out); err != nil {
rp.logf("Failed to parse control packet: %v", err)
return
}
@@ -168,7 +150,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleDataPacket(data []byte) {
- shared := rp.shared.Load()
+ shared := rp.published.Load()
if shared.dataCipher == nil {
rp.logf("Not connected (recv).")
return
@@ -185,34 +167,65 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
// ----------------------------------------------------------------------------
+func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
+ shared := rp.published.Load()
+ if shared.dataCipher == nil {
+ rp.logf("Not connected (recv).")
+ return
+ }
+
+ dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
+ if !ok {
+ rp.logf("Failed to decrypt data packet.")
+ return
+ }
+
+ rp.peers[h.DestIP].sendDirect(dec)
+}
+
+// ----------------------------------------------------------------------------
+
// SendData sends data coming from the interface going to the network.
//
// This function is called by a single thread.
func (rp *remotePeer) SendData(data []byte) {
- rp.sendData(dataStreamID, data)
+ rp.sendData(dataStreamID, rp.remoteIP, data)
}
-// ----------------------------------------------------------------------------
+func (rp *remotePeer) HandleInterfacePacket(data []byte) {
+ shared := rp.published.Load()
-func (rp *remotePeer) RelayControlData(data []byte) {
- rp.sendData(forwardStreamID, data)
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) ForwardPacket(data []byte) {
- shared := rp.shared.Load()
- if shared.remoteAddr == zeroAddrPort {
- rp.logf("Not connected (forward).")
+ if shared.dataCipher == nil {
+ rp.logf("Not connected (handle interface).")
return
}
- rp.conn.WriteTo(data, shared.remoteAddr)
+
+ h := header{
+ StreamID: dataStreamID,
+ Counter: atomic.AddUint64(&rp.counter, 1),
+ SourceIP: rp.localIP,
+ DestIP: rp.remoteIP,
+ }
+
+ enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
+
+ if shared.relayIP != 0 {
+ rp.peers[shared.relayIP].RelayData(shared.relayIP, enc)
+ } else {
+ rp.SendData(data)
+ }
}
// ----------------------------------------------------------------------------
-func (rp *remotePeer) sendData(streamID byte, data []byte) {
- shared := rp.shared.Load()
+func (rp *remotePeer) RelayData(destIP byte, data []byte) {
+ rp.sendData(relayStreamID, destIP, data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) {
+ shared := rp.published.Load()
if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send).")
return
@@ -222,16 +235,18 @@ func (rp *remotePeer) sendData(streamID byte, data []byte) {
StreamID: streamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
- DestIP: rp.remoteIP,
+ DestIP: destIP,
}
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
rp.conn.WriteTo(enc, shared.remoteAddr)
}
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) canRelay() bool {
- shared := rp.shared.Load()
- return shared.relay && shared.up
+func (rp *remotePeer) sendDirect(data []byte) {
+ shared := rp.published.Load()
+ if shared.remoteAddr == zeroAddrPort {
+ rp.logf("Not connected (send).")
+ return
+ }
+ rp.conn.WriteTo(data, shared.remoteAddr)
}
--
2.39.5
From eb18dd1fa011905f4bca93c896de7b929525961b Mon Sep 17 00:00:00 2001
From: jdl
Date: Sat, 21 Dec 2024 06:55:28 +0100
Subject: [PATCH 08/18] wip: relaying working
---
node/peer-states.go | 201 ++++++++++++++++++++++++++++++++++------
node/peer-supervisor.go | 7 +-
node/peer.go | 80 +++++++++-------
3 files changed, 219 insertions(+), 69 deletions(-)
diff --git a/node/peer-states.go b/node/peer-states.go
index c3c0904..7a1de54 100644
--- a/node/peer-states.go
+++ b/node/peer-states.go
@@ -3,6 +3,7 @@ package node
import (
"fmt"
"log"
+ "math/rand"
"net/netip"
"sync/atomic"
"time"
@@ -22,7 +23,7 @@ type peerState interface {
type stateBase struct {
// The purpose of this state machine is to manage this published data.
- published *atomic.Pointer[peerData]
+ published *atomic.Pointer[peerRoutingData]
// The other remote peers.
peers *remotePeers
@@ -38,9 +39,9 @@ type stateBase struct {
counter *uint64
// Mutable peer data.
- peer *m.Peer
- remotePub bool
- data peerData // Local copy of shared data. See publish().
+ peer *m.Peer
+ remotePub bool
+ routingData peerRoutingData // Local copy of shared data. See publish().
// Timers
pingTimer *time.Timer
@@ -63,16 +64,24 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
return nil
}
- s.peer = peer
+ return s.selectStateFromPeer(peer)
+}
- s.data = peerData{}
- s.data.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
+func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState {
+ s.peer = peer
+ s.routingData = peerRoutingData{}
+
+ if peer == nil {
+ return newStateNoPeer(s)
+ }
+
+ s.routingData.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
ip, isValid := netip.AddrFromSlice(peer.PublicIP)
if isValid {
s.remotePub = true
- s.data.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
- s.data.relay = peer.Mediator
+ s.routingData.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
+ s.routingData.relay = peer.Mediator
if s.localPub && s.localIP < s.remoteIP {
return newStateServer(s)
@@ -84,9 +93,7 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
return newStateServer(s)
}
- // TODO: return newStateMediated(a/b)
-
- return nil
+ return newStateSelectRelay(s)
}
func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
@@ -106,10 +113,24 @@ func (s *stateBase) logf(msg string, args ...any) {
}
func (s *stateBase) publish() {
- data := s.data
+ data := s.routingData
s.published.Store(&data)
}
+func (s *stateBase) selectRelay() byte {
+ possible := make([]byte, 0, 8)
+ for i, peer := range s.peers {
+ if peer.CanRelay() {
+ possible = append(possible, byte(i))
+ }
+ }
+
+ if len(possible) == 0 {
+ return 0
+ }
+ return possible[rand.Intn(len(possible))]
+}
+
func (s *stateBase) sendPing(sharedKey [32]byte) {
s.sendControlPacket(newPingPacket(sharedKey))
}
@@ -127,13 +148,22 @@ func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
DestIP: s.remoteIP,
}
- buf = s.data.controlCipher.Encrypt(h, buf, s.encBuf)
- if s.data.relayIP == 0 {
- s.conn.WriteTo(buf, s.data.remoteAddr)
- return
+ buf = s.routingData.controlCipher.Encrypt(h, buf, s.encBuf)
+ if s.routingData.relayIP != 0 {
+ s.peers[s.routingData.relayIP].RelayFor(s.remoteIP, buf)
+ } else {
+ s.conn.WriteTo(buf, s.routingData.remoteAddr)
}
+}
- // TODO: Relay!
+// ----------------------------------------------------------------------------
+
+type stateNoPeer struct{ *stateBase }
+
+func newStateNoPeer(b *stateBase) *stateNoPeer {
+ s := &stateNoPeer{b}
+ s.publish()
+ return s
}
// ----------------------------------------------------------------------------
@@ -147,8 +177,8 @@ func newStateClient(b *stateBase) peerState {
s := &stateClient{stateBase: b}
s.publish()
- s.data.dataCipher = newDataCipher()
- s.sharedKey = s.data.dataCipher.Key()
+ s.routingData.dataCipher = newDataCipher()
+ s.sharedKey = s.routingData.dataCipher.Key()
s.sendPing(s.sharedKey)
s.resetPingTimer()
@@ -159,8 +189,8 @@ func newStateClient(b *stateBase) peerState {
func (s *stateClient) Name() string { return "client" }
func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState {
- if !s.data.up {
- s.data.up = true
+ if !s.routingData.up {
+ s.routingData.up = true
s.publish()
}
s.resetTimeoutTimer()
@@ -174,7 +204,7 @@ func (s *stateClient) OnPingTimer() peerState {
}
func (s *stateClient) OnTimeoutTimer() peerState {
- s.data.up = false
+ s.routingData.up = false
s.publish()
return nil
}
@@ -196,19 +226,134 @@ func newStateServer(b *stateBase) peerState {
func (s *stateServer) Name() string { return "server" }
func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState {
- if addr != s.data.remoteAddr {
+ if addr != s.routingData.remoteAddr {
s.logf("Got new peer address: %v", addr)
- s.data.remoteAddr = addr
- s.data.up = true
+ s.routingData.remoteAddr = addr
+ s.routingData.up = true
s.publish()
}
- if s.data.dataCipher == nil || p.SharedKey != s.data.dataCipher.Key() {
+ if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() {
s.logf("Got new shared key.")
- s.data.dataCipher = newDataCipherFromKey(p.SharedKey)
+ s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey)
+ s.routingData.up = true
s.publish()
}
s.sendPong(p)
return nil
}
+
+// ----------------------------------------------------------------------------
+
+type stateSelectRelay struct {
+ *stateBase
+}
+
+func newStateSelectRelay(b *stateBase) peerState {
+ s := &stateSelectRelay{stateBase: b}
+ s.routingData.dataCipher = nil
+ s.routingData.up = false
+ s.publish()
+
+ if relay := s.selectRelay(); relay != 0 {
+ s.routingData.up = false
+ s.routingData.relayIP = relay
+ return s.selectRole()
+ }
+
+ s.resetPingTimer()
+ s.stopTimeoutTimer()
+ return s
+}
+
+func (s *stateSelectRelay) selectRole() peerState {
+ if s.localIP < s.remoteIP {
+ return newStateServerRelayed(s.stateBase)
+ }
+ return newStateClientRelayed(s.stateBase)
+}
+
+func (s *stateSelectRelay) Name() string { return "select-relay" }
+
+func (s *stateSelectRelay) OnPingTimer() peerState {
+ if relay := s.selectRelay(); relay != 0 {
+ s.routingData.relayIP = relay
+ return s.selectRole()
+ }
+ s.resetPingTimer()
+ return nil
+}
+
+// ----------------------------------------------------------------------------
+
+type stateClientRelayed struct {
+ sharedKey [32]byte
+ *stateBase
+}
+
+func newStateClientRelayed(b *stateBase) peerState {
+ s := &stateClientRelayed{stateBase: b}
+
+ s.routingData.dataCipher = newDataCipher()
+ s.sharedKey = s.routingData.dataCipher.Key()
+ s.publish()
+
+ s.sendPing(s.sharedKey)
+ s.resetPingTimer()
+ s.resetTimeoutTimer()
+ return s
+}
+
+func (s *stateClientRelayed) Name() string { return "client-relayed" }
+
+func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState {
+ if !s.routingData.up {
+ s.routingData.up = true
+ s.publish()
+ }
+ s.resetTimeoutTimer()
+ return nil
+}
+
+func (s *stateClientRelayed) OnPingTimer() peerState {
+ s.sendPing(s.sharedKey)
+ s.resetPingTimer()
+ return nil
+}
+
+func (s *stateClientRelayed) OnTimeoutTimer() peerState {
+ return newStateSelectRelay(s.stateBase)
+}
+
+// ----------------------------------------------------------------------------
+
+type stateServerRelayed struct {
+ *stateBase
+}
+
+func newStateServerRelayed(b *stateBase) peerState {
+ s := &stateServerRelayed{b}
+ s.stopPingTimer()
+ s.resetTimeoutTimer()
+ return s
+}
+
+func (s *stateServerRelayed) Name() string { return "server-relayed" }
+
+func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState {
+ if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() {
+ s.logf("Got new shared key.")
+ s.routingData.up = true
+ s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey)
+ s.publish()
+ }
+
+ s.sendPong(p)
+ s.resetTimeoutTimer()
+ return nil
+}
+
+func (s *stateServerRelayed) OnTimeoutTimer() peerState {
+ return newStateSelectRelay(s.stateBase)
+}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index 2c46ad2..ac2508e 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -11,12 +11,7 @@ const (
timeoutInterval = 20 * time.Second
)
-func (rp *remotePeer) supervise(
- conf m.PeerConfig,
- remoteIP byte,
- conn *connWriter,
- peers *remotePeers,
-) {
+func (rp *remotePeer) supervise(conf m.PeerConfig) {
defer panicHandler()
base := &stateBase{
diff --git a/node/peer.go b/node/peer.go
index 3cc308e..bae2c9c 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -11,7 +11,7 @@ import (
type remotePeers [256]*remotePeer
-type peerData struct {
+type peerRoutingData struct {
up bool
relay bool
controlCipher *controlCipher
@@ -29,7 +29,7 @@ type remotePeer struct {
// Shared state.
peers *remotePeers
- published *atomic.Pointer[peerData]
+ published *atomic.Pointer[peerRoutingData]
// Only used in HandlePacket / Not synchronized.
dupCheck *dupCheck
@@ -53,7 +53,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
iface: iface,
conn: conn,
peers: peers,
- published: &atomic.Pointer[peerData]{},
+ published: &atomic.Pointer[peerRoutingData]{},
dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize),
encryptBuf: make([]byte, bufferSize),
@@ -62,11 +62,11 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
controlPackets: make(chan controlPacket, 512),
}
- pd := peerData{}
+ pd := peerRoutingData{}
rp.published.Store(&pd)
//go newPeerSuper(rp).Run()
- go rp.supervise(conf, remoteIP, conn, peers)
+ go rp.supervise(conf)
return rp
}
@@ -102,9 +102,8 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
- shared := rp.published.Load()
- if shared.controlCipher == nil {
- log.Printf("Shared: %+v", *shared)
+ routingData := rp.published.Load()
+ if routingData.controlCipher == nil {
rp.logf("Not connected (control).")
return
}
@@ -114,7 +113,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
return
}
- out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf)
+ out, ok := routingData.controlCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt control packet.")
return
@@ -150,13 +149,13 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleDataPacket(data []byte) {
- shared := rp.published.Load()
- if shared.dataCipher == nil {
+ routingData := rp.published.Load()
+ if routingData.dataCipher == nil {
rp.logf("Not connected (recv).")
return
}
- dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
+ dec, ok := routingData.dataCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt data packet.")
return
@@ -168,19 +167,19 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
- shared := rp.published.Load()
- if shared.dataCipher == nil {
+ routingData := rp.published.Load()
+ if routingData.dataCipher == nil {
rp.logf("Not connected (recv).")
return
}
- dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
+ dec, ok := routingData.dataCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt data packet.")
return
}
- rp.peers[h.DestIP].sendDirect(dec)
+ rp.peers[h.DestIP].SendAsIs(dec)
}
// ----------------------------------------------------------------------------
@@ -189,13 +188,13 @@ func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
//
// This function is called by a single thread.
func (rp *remotePeer) SendData(data []byte) {
- rp.sendData(dataStreamID, rp.remoteIP, data)
+ rp.encryptAndSend(dataStreamID, rp.remoteIP, data)
}
func (rp *remotePeer) HandleInterfacePacket(data []byte) {
- shared := rp.published.Load()
+ routingData := rp.published.Load()
- if shared.dataCipher == nil {
+ if routingData.dataCipher == nil {
rp.logf("Not connected (handle interface).")
return
}
@@ -207,10 +206,10 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) {
DestIP: rp.remoteIP,
}
- enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
+ enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf)
- if shared.relayIP != 0 {
- rp.peers[shared.relayIP].RelayData(shared.relayIP, enc)
+ if routingData.relayIP != 0 {
+ rp.peers[routingData.relayIP].RelayFor(rp.remoteIP, enc)
} else {
rp.SendData(data)
}
@@ -218,16 +217,23 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) {
// ----------------------------------------------------------------------------
-func (rp *remotePeer) RelayData(destIP byte, data []byte) {
- rp.sendData(relayStreamID, destIP, data)
+func (rp *remotePeer) CanRelay() bool {
+ data := rp.published.Load()
+ return data.relay && data.up
}
// ----------------------------------------------------------------------------
-func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) {
- shared := rp.published.Load()
- if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
- rp.logf("Not connected (send).")
+func (rp *remotePeer) RelayFor(destIP byte, data []byte) {
+ rp.encryptAndSend(relayStreamID, destIP, data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (rp *remotePeer) encryptAndSend(streamID byte, destIP byte, data []byte) {
+ routingData := rp.published.Load()
+ if routingData.dataCipher == nil || routingData.remoteAddr == zeroAddrPort {
+ rp.logf("Not connected (encrypt and send).")
return
}
@@ -238,15 +244,19 @@ func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) {
DestIP: destIP,
}
- enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
- rp.conn.WriteTo(enc, shared.remoteAddr)
+ enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf)
+ rp.conn.WriteTo(enc, routingData.remoteAddr)
}
-func (rp *remotePeer) sendDirect(data []byte) {
- shared := rp.published.Load()
- if shared.remoteAddr == zeroAddrPort {
- rp.logf("Not connected (send).")
+// ----------------------------------------------------------------------------
+
+// SendAsIs is used when forwarding already-encrypted data from one peer to
+// another.
+func (rp *remotePeer) SendAsIs(data []byte) {
+ routingData := rp.published.Load()
+ if routingData.remoteAddr == zeroAddrPort {
+ rp.logf("Not connected (send direct).")
return
}
- rp.conn.WriteTo(data, shared.remoteAddr)
+ rp.conn.WriteTo(data, routingData.remoteAddr)
}
--
2.39.5
From 1d68e4f79e4c11a356a96da98097489590d0e00a Mon Sep 17 00:00:00 2001
From: jdl
Date: Sat, 21 Dec 2024 20:28:04 +0100
Subject: [PATCH 09/18] wip: working, modifying logic to allow local discovery
and hole punching in the future.
---
README.md | 5 +-
node/main.go | 4 +-
node/packets-util.go | 165 ++++++++++++++++++++++++++++++++++++++
node/packets-util_test.go | 40 +++++++++
node/packets.go | 100 ++++++++++++++++++++---
node/packets_test.go | 49 +++++++++++
node/peer-states.go | 97 +++++++++++++---------
node/peer-supervisor.go | 6 ++
8 files changed, 407 insertions(+), 59 deletions(-)
create mode 100644 node/packets-util.go
create mode 100644 node/packets-util_test.go
diff --git a/README.md b/README.md
index b9d291e..87e3072 100644
--- a/README.md
+++ b/README.md
@@ -2,11 +2,8 @@
## Roadmap
+* Use probe and relayed-probe packets vs ping/pong.
* Rename Mediator -> Relay
-* Node: use symmetric encryption after handshake
-* AEAD-AES uses a 12 byte nonce. We need to shrink the header:
- * Remove Forward and replace it with a HeaderFlags bitfield.
- * Forward, Asym/Sym, ...
* Use default port 456
* Remove signing key from hub
* Peer: UDP hole-punching
diff --git a/node/main.go b/node/main.go
index 35a00e6..c291e73 100644
--- a/node/main.go
+++ b/node/main.go
@@ -112,7 +112,7 @@ func main(netName, listenIP string, port uint16) {
}
go newHubPoller(netName, conf, peers).Run()
- go readFromConn(conf.PeerIP, conn, peers)
+ go readFromConn(conn, peers)
readFromIFace(iface, peers)
}
@@ -130,7 +130,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ----------------------------------------------------------------------------
-func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) {
+func readFromConn(conn *net.UDPConn, peers remotePeers) {
defer panicHandler()
diff --git a/node/packets-util.go b/node/packets-util.go
new file mode 100644
index 0000000..8a6e13a
--- /dev/null
+++ b/node/packets-util.go
@@ -0,0 +1,165 @@
+package node
+
+import (
+ "net/netip"
+ "sync/atomic"
+ "unsafe"
+ "vppn/fasttime"
+)
+
+var (
+ traceIDCounter uint64
+)
+
+func newTraceID() uint64 {
+ return uint64(fasttime.Now()<<30) + atomic.AddUint64(&traceIDCounter, 1)
+}
+
+// ----------------------------------------------------------------------------
+
+type binWriter struct {
+ b []byte
+ i int
+}
+
+func newBinWriter(buf []byte) *binWriter {
+ buf = buf[:cap(buf)]
+ return &binWriter{buf, 0}
+}
+
+func (w *binWriter) Bool(b bool) *binWriter {
+ if b {
+ return w.Byte(1)
+ }
+ return w.Byte(0)
+}
+
+func (w *binWriter) Byte(b byte) *binWriter {
+ w.b[w.i] = b
+ w.i++
+ return w
+}
+
+func (w *binWriter) SharedKey(key [32]byte) *binWriter {
+ copy(w.b[w.i:w.i+32], key[:])
+ w.i += 32
+ return w
+}
+
+func (w *binWriter) Uint16(x uint16) *binWriter {
+ *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x
+ w.i += 2
+ return w
+}
+
+func (w *binWriter) Uint64(x uint64) *binWriter {
+ *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x
+ w.i += 8
+ return w
+}
+
+func (w *binWriter) Int64(x int64) *binWriter {
+ *(*int64)(unsafe.Pointer(&w.b[w.i])) = x
+ w.i += 8
+ return w
+}
+
+func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
+ addr := addrPort.Addr().As16()
+ copy(w.b[w.i:w.i+16], addr[:])
+ w.i += 16
+ return w.Uint16(addrPort.Port())
+}
+
+func (w *binWriter) Build() []byte {
+ return w.b[:w.i]
+}
+
+// ----------------------------------------------------------------------------
+
+type binReader struct {
+ b []byte
+ i int
+ err error
+}
+
+func newBinReader(buf []byte) *binReader {
+ return &binReader{b: buf}
+}
+
+func (r *binReader) hasBytes(n int) bool {
+ if r.err != nil || (len(r.b)-r.i) < n {
+ r.err = errMalformedPacket
+ return false
+ }
+ return true
+}
+
+func (r *binReader) Bool(b *bool) *binReader {
+ var bb byte
+ r.Byte(&bb)
+ *b = bb != 0
+ return r
+}
+
+func (r *binReader) Byte(b *byte) *binReader {
+ if !r.hasBytes(1) {
+ return r
+ }
+ *b = r.b[r.i]
+ r.i++
+ return r
+}
+
+func (r *binReader) SharedKey(x *[32]byte) *binReader {
+ if !r.hasBytes(32) {
+ return r
+ }
+ *x = ([32]byte)(r.b[r.i : r.i+32])
+ r.i += 32
+ return r
+}
+
+func (r *binReader) Uint16(x *uint16) *binReader {
+ if !r.hasBytes(2) {
+ return r
+ }
+ *x = *(*uint16)(unsafe.Pointer(&r.b[r.i]))
+ r.i += 2
+ return r
+}
+
+func (r *binReader) Uint64(x *uint64) *binReader {
+ if !r.hasBytes(8) {
+ return r
+ }
+ *x = *(*uint64)(unsafe.Pointer(&r.b[r.i]))
+ r.i += 8
+ return r
+}
+
+func (r *binReader) Int64(x *int64) *binReader {
+ if !r.hasBytes(8) {
+ return r
+ }
+ *x = *(*int64)(unsafe.Pointer(&r.b[r.i]))
+ r.i += 8
+ return r
+}
+
+func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
+ if !r.hasBytes(18) {
+ return r
+ }
+ addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16]))
+ addr = addr.Unmap()
+ r.i += 16
+ var port uint16
+ r.Uint16(&port)
+ *x = netip.AddrPortFrom(addr, port)
+ return r
+}
+
+func (r *binReader) Error() error {
+ return r.err
+}
diff --git a/node/packets-util_test.go b/node/packets-util_test.go
new file mode 100644
index 0000000..06b0370
--- /dev/null
+++ b/node/packets-util_test.go
@@ -0,0 +1,40 @@
+package node
+
+import (
+ "net/netip"
+ "reflect"
+ "testing"
+)
+
+func TestBinWriteRead(t *testing.T) {
+ buf := make([]byte, 1024)
+
+ type Item struct {
+ Type byte
+ TraceID uint64
+ DestAddr netip.AddrPort
+ }
+
+ in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)}
+
+ buf = newBinWriter(buf).
+ Byte(in.Type).
+ Uint64(in.TraceID).
+ AddrPort(in.DestAddr).
+ Build()
+
+ out := Item{}
+
+ err := newBinReader(buf).
+ Byte(&out.Type).
+ Uint64(&out.TraceID).
+ AddrPort(&out.DestAddr).
+ Error()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(in, out) {
+ t.Fatal(in, out)
+ }
+}
diff --git a/node/packets.go b/node/packets.go
index 57c7341..bbc1262 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -13,8 +13,12 @@ var (
)
const (
- packetTypePing = iota + 1
+ packetTypeSyn = iota + 1
+ packetTypeSynAck
+ packetTypeAck
+ packetTypePing
packetTypePong
+ packetTypeRelayed
)
// ----------------------------------------------------------------------------
@@ -31,6 +35,8 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
p.Payload, err = parsePingPacket(buf)
case packetTypePong:
p.Payload, err = parsePongPacket(buf)
+ case packetTypeSyn:
+ p.Payload, err = parseSynPacket(buf)
default:
return errUnknownPacketType
}
@@ -39,34 +45,102 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
// ----------------------------------------------------------------------------
+type synPacket struct {
+ TraceID uint64 // TraceID to match response w/ request.
+ SharedKey [32]byte // Our shared key.
+ ServerAddr netip.AddrPort // The address we're sending to.
+ Direct bool // True if this is request isn't relayed.
+}
+
+func (p synPacket) Marshal(buf []byte) []byte {
+ return newBinWriter(buf).
+ Byte(packetTypeSyn).
+ Uint64(p.TraceID).
+ SharedKey(p.SharedKey).
+ AddrPort(p.ServerAddr).
+ Bool(p.Direct).
+ Build()
+}
+
+func parseSynPacket(buf []byte) (p synPacket, err error) {
+ err = newBinReader(buf[1:]).
+ Uint64(&p.TraceID).
+ SharedKey(&p.SharedKey).
+ AddrPort(&p.ServerAddr).
+ Bool(&p.Direct).
+ Error()
+ return
+}
+
+// ----------------------------------------------------------------------------
+
+type synAckPacket struct {
+ TraceID uint64
+}
+
+func (p synAckPacket) Marshal(buf []byte) []byte {
+ return newBinWriter(buf).
+ Byte(packetTypeSynAck).
+ Uint64(p.TraceID).
+ Build()
+}
+
+func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
+ err = newBinReader(buf[1:]).
+ Uint64(&p.TraceID).
+ Error()
+ return
+}
+
+// ----------------------------------------------------------------------------
+
+type ackPacket struct {
+ TraceID uint64
+}
+
+func (p ackPacket) Marshal(buf []byte) []byte {
+ return newBinWriter(buf).
+ Byte(packetTypeSynAck).
+ Uint64(p.TraceID).
+ Build()
+}
+
+func parseAckPacket(buf []byte) (p ackPacket, err error) {
+ err = newBinReader(buf[1:]).
+ Uint64(&p.TraceID).
+ Error()
+ return
+}
+
+// ----------------------------------------------------------------------------
+
// A pingPacket is sent from a node acting as a client, to a node acting
// as a server. It always contains the shared key the client is expecting
// to use for data encryption with the server.
type pingPacket struct {
- SentAt int64 // UnixMilli.
+ SentAt int64 // UnixMilli. // Not used. Use traceID.
SharedKey [32]byte
}
func newPingPacket(sharedKey [32]byte) (pp pingPacket) {
pp.SentAt = time.Now().UnixMilli()
- copy(pp.SharedKey[:], sharedKey[:])
+ pp.SharedKey = sharedKey
return
}
func (p pingPacket) Marshal(buf []byte) []byte {
- buf = buf[:41]
- buf[0] = packetTypePing
- *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt)
- copy(buf[9:41], p.SharedKey[:])
- return buf
+ return newBinWriter(buf).
+ Byte(packetTypePing).
+ Int64(p.SentAt).
+ SharedKey(p.SharedKey).
+ Build()
}
func parsePingPacket(buf []byte) (p pingPacket, err error) {
- if len(buf) != 41 {
- return p, errMalformedPacket
- }
- p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
- copy(p.SharedKey[:], buf[9:41])
+ err = newBinReader(buf[1:]).
+ Int64(&p.SentAt).
+ SharedKey(&p.SharedKey).
+ Error()
return
}
diff --git a/node/packets_test.go b/node/packets_test.go
index da242d4..6d96ccb 100644
--- a/node/packets_test.go
+++ b/node/packets_test.go
@@ -2,10 +2,59 @@ package node
import (
"crypto/rand"
+ "net/netip"
"reflect"
"testing"
)
+func TestPacketSyn(t *testing.T) {
+ in := synPacket{
+ TraceID: newTraceID(),
+ Direct: true,
+ ServerAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 34),
+ }
+ rand.Read(in.SharedKey[:])
+
+ out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize)))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(in, out) {
+ t.Fatal("\n", in, "\n", out)
+ }
+}
+
+func TestPacketSynAck(t *testing.T) {
+ in := synAckPacket{
+ TraceID: newTraceID(),
+ }
+
+ out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize)))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(in, out) {
+ t.Fatal("\n", in, "\n", out)
+ }
+}
+
+func TestPacketAck(t *testing.T) {
+ in := ackPacket{
+ TraceID: newTraceID(),
+ }
+
+ out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize)))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(in, out) {
+ t.Fatal("\n", in, "\n", out)
+ }
+}
+
func TestPacketPing(t *testing.T) {
sharedKey := make([]byte, 32)
rand.Read(sharedKey)
diff --git a/node/peer-states.go b/node/peer-states.go
index 7a1de54..35ebc0b 100644
--- a/node/peer-states.go
+++ b/node/peer-states.go
@@ -12,7 +12,14 @@ import (
type peerState interface {
Name() string
+ OnSyn(netip.AddrPort, synPacket) peerState
+ OnSynAck(netip.AddrPort, synAckPacket) peerState
+ OnAck(netip.AddrPort, ackPacket) peerState
+
+ // When the peer is updated, we reset. Handled by base state.
OnPeerUpdate(*m.Peer) peerState
+
+ // To determe up / dataCipher. Handled by base state.
OnPing(netip.AddrPort, pingPacket) peerState
OnPong(netip.AddrPort, pongPacket) peerState
OnPingTimer() peerState
@@ -24,6 +31,7 @@ type peerState interface {
type stateBase struct {
// The purpose of this state machine is to manage this published data.
published *atomic.Pointer[peerRoutingData]
+ staged peerRoutingData // Local copy of shared data. See publish().
// The other remote peers.
peers *remotePeers
@@ -39,9 +47,8 @@ type stateBase struct {
counter *uint64
// Mutable peer data.
- peer *m.Peer
- remotePub bool
- routingData peerRoutingData // Local copy of shared data. See publish().
+ peer *m.Peer
+ remotePub bool
// Timers
pingTimer *time.Timer
@@ -69,19 +76,19 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState {
s.peer = peer
- s.routingData = peerRoutingData{}
+ s.staged = peerRoutingData{}
+ defer s.publish()
if peer == nil {
return newStateNoPeer(s)
}
-
- s.routingData.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
+ s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
ip, isValid := netip.AddrFromSlice(peer.PublicIP)
if isValid {
s.remotePub = true
- s.routingData.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
- s.routingData.relay = peer.Mediator
+ s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
+ s.staged.relay = peer.Mediator
if s.localPub && s.localIP < s.remoteIP {
return newStateServer(s)
@@ -96,10 +103,16 @@ func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState {
return newStateSelectRelay(s)
}
-func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
-func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil }
-func (s *stateBase) OnPingTimer() peerState { return nil }
-func (s *stateBase) OnTimeoutTimer() peerState { return nil }
+func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState { return nil }
+func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil }
+func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState { return nil }
+func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
+func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil }
+func (s *stateBase) OnPingTimer() peerState { return nil }
+
+func (s *stateBase) OnTimeoutTimer() peerState {
+ return s.selectStateFromPeer(s.peer)
+}
// Helpers.
@@ -113,7 +126,7 @@ func (s *stateBase) logf(msg string, args ...any) {
}
func (s *stateBase) publish() {
- data := s.routingData
+ data := s.staged
s.published.Store(&data)
}
@@ -148,11 +161,11 @@ func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
DestIP: s.remoteIP,
}
- buf = s.routingData.controlCipher.Encrypt(h, buf, s.encBuf)
- if s.routingData.relayIP != 0 {
- s.peers[s.routingData.relayIP].RelayFor(s.remoteIP, buf)
+ buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf)
+ if s.staged.relayIP != 0 {
+ s.peers[s.staged.relayIP].RelayFor(s.remoteIP, buf)
} else {
- s.conn.WriteTo(buf, s.routingData.remoteAddr)
+ s.conn.WriteTo(buf, s.staged.remoteAddr)
}
}
@@ -162,6 +175,8 @@ type stateNoPeer struct{ *stateBase }
func newStateNoPeer(b *stateBase) *stateNoPeer {
s := &stateNoPeer{b}
+ s.pingTimer.Stop()
+ s.timeoutTimer.Stop()
s.publish()
return s
}
@@ -177,8 +192,8 @@ func newStateClient(b *stateBase) peerState {
s := &stateClient{stateBase: b}
s.publish()
- s.routingData.dataCipher = newDataCipher()
- s.sharedKey = s.routingData.dataCipher.Key()
+ s.staged.dataCipher = newDataCipher()
+ s.sharedKey = s.staged.dataCipher.Key()
s.sendPing(s.sharedKey)
s.resetPingTimer()
@@ -189,8 +204,8 @@ func newStateClient(b *stateBase) peerState {
func (s *stateClient) Name() string { return "client" }
func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState {
- if !s.routingData.up {
- s.routingData.up = true
+ if !s.staged.up {
+ s.staged.up = true
s.publish()
}
s.resetTimeoutTimer()
@@ -204,7 +219,7 @@ func (s *stateClient) OnPingTimer() peerState {
}
func (s *stateClient) OnTimeoutTimer() peerState {
- s.routingData.up = false
+ s.staged.up = false
s.publish()
return nil
}
@@ -226,17 +241,17 @@ func newStateServer(b *stateBase) peerState {
func (s *stateServer) Name() string { return "server" }
func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState {
- if addr != s.routingData.remoteAddr {
+ if addr != s.staged.remoteAddr {
s.logf("Got new peer address: %v", addr)
- s.routingData.remoteAddr = addr
- s.routingData.up = true
+ s.staged.remoteAddr = addr
+ s.staged.up = true
s.publish()
}
- if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() {
+ if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() {
s.logf("Got new shared key.")
- s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey)
- s.routingData.up = true
+ s.staged.dataCipher = newDataCipherFromKey(p.SharedKey)
+ s.staged.up = true
s.publish()
}
@@ -252,13 +267,13 @@ type stateSelectRelay struct {
func newStateSelectRelay(b *stateBase) peerState {
s := &stateSelectRelay{stateBase: b}
- s.routingData.dataCipher = nil
- s.routingData.up = false
+ s.staged.dataCipher = nil
+ s.staged.up = false
s.publish()
if relay := s.selectRelay(); relay != 0 {
- s.routingData.up = false
- s.routingData.relayIP = relay
+ s.staged.up = false
+ s.staged.relayIP = relay
return s.selectRole()
}
@@ -278,7 +293,8 @@ func (s *stateSelectRelay) Name() string { return "select-relay" }
func (s *stateSelectRelay) OnPingTimer() peerState {
if relay := s.selectRelay(); relay != 0 {
- s.routingData.relayIP = relay
+ s.logf("Got relay IP: %d", relay)
+ s.staged.relayIP = relay
return s.selectRole()
}
s.resetPingTimer()
@@ -295,8 +311,8 @@ type stateClientRelayed struct {
func newStateClientRelayed(b *stateBase) peerState {
s := &stateClientRelayed{stateBase: b}
- s.routingData.dataCipher = newDataCipher()
- s.sharedKey = s.routingData.dataCipher.Key()
+ s.staged.dataCipher = newDataCipher()
+ s.sharedKey = s.staged.dataCipher.Key()
s.publish()
s.sendPing(s.sharedKey)
@@ -308,10 +324,11 @@ func newStateClientRelayed(b *stateBase) peerState {
func (s *stateClientRelayed) Name() string { return "client-relayed" }
func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState {
- if !s.routingData.up {
- s.routingData.up = true
+ if !s.staged.up {
+ s.staged.up = true
s.publish()
}
+
s.resetTimeoutTimer()
return nil
}
@@ -342,10 +359,10 @@ func newStateServerRelayed(b *stateBase) peerState {
func (s *stateServerRelayed) Name() string { return "server-relayed" }
func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState {
- if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() {
+ if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() {
s.logf("Got new shared key.")
- s.routingData.up = true
- s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey)
+ s.staged.up = true
+ s.staged.dataCipher = newDataCipherFromKey(p.SharedKey)
s.publish()
}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index ac2508e..08691aa 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -46,6 +46,12 @@ func (rp *remotePeer) supervise(conf m.PeerConfig) {
case pkt := <-rp.controlPackets:
switch p := pkt.Payload.(type) {
+ case synPacket:
+ nextState = curState.OnSyn(pkt.RemoteAddr, p)
+ case synAckPacket:
+ nextState = curState.OnSynAck(pkt.RemoteAddr, p)
+ case ackPacket:
+ nextState = curState.OnAck(pkt.RemoteAddr, p)
case pingPacket:
nextState = curState.OnPing(pkt.RemoteAddr, p)
case pongPacket:
--
2.39.5
From a263f65c5d036d72f9e732a98046f25bd1d4a1b7 Mon Sep 17 00:00:00 2001
From: jdl
Date: Sun, 22 Dec 2024 13:58:09 +0100
Subject: [PATCH 10/18] wip
---
node/packets.go | 24 +++--
node/peer-states.go | 233 ++++++++++++++++++++++------------------
node/peer-supervisor.go | 10 +-
3 files changed, 147 insertions(+), 120 deletions(-)
diff --git a/node/packets.go b/node/packets.go
index bbc1262..ffda859 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -37,6 +37,10 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
p.Payload, err = parsePongPacket(buf)
case packetTypeSyn:
p.Payload, err = parseSynPacket(buf)
+ case packetTypeSynAck:
+ p.Payload, err = parseSynAckPacket(buf)
+ case packetTypeAck:
+ p.Payload, err = parseAckPacket(buf)
default:
return errUnknownPacketType
}
@@ -49,7 +53,7 @@ type synPacket struct {
TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key.
ServerAddr netip.AddrPort // The address we're sending to.
- Direct bool // True if this is request isn't relayed.
+ RelayIP byte
}
func (p synPacket) Marshal(buf []byte) []byte {
@@ -58,7 +62,7 @@ func (p synPacket) Marshal(buf []byte) []byte {
Uint64(p.TraceID).
SharedKey(p.SharedKey).
AddrPort(p.ServerAddr).
- Bool(p.Direct).
+ Byte(p.RelayIP).
Build()
}
@@ -67,7 +71,7 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
AddrPort(&p.ServerAddr).
- Bool(&p.Direct).
+ Byte(&p.RelayIP).
Error()
return
}
@@ -78,6 +82,10 @@ type synAckPacket struct {
TraceID uint64
}
+func newSynAckPacket(traceID uint64) synAckPacket {
+ return synAckPacket{traceID}
+}
+
func (p synAckPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSynAck).
@@ -100,7 +108,7 @@ type ackPacket struct {
func (p ackPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
- Byte(packetTypeSynAck).
+ Byte(packetTypeAck).
Uint64(p.TraceID).
Build()
}
@@ -118,13 +126,11 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) {
// as a server. It always contains the shared key the client is expecting
// to use for data encryption with the server.
type pingPacket struct {
- SentAt int64 // UnixMilli. // Not used. Use traceID.
- SharedKey [32]byte
+ SentAt int64 // UnixMilli. // Not used. Use traceID.
}
-func newPingPacket(sharedKey [32]byte) (pp pingPacket) {
+func newPingPacket() (pp pingPacket) {
pp.SentAt = time.Now().UnixMilli()
- pp.SharedKey = sharedKey
return
}
@@ -132,14 +138,12 @@ func (p pingPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypePing).
Int64(p.SentAt).
- SharedKey(p.SharedKey).
Build()
}
func parsePingPacket(buf []byte) (p pingPacket, err error) {
err = newBinReader(buf[1:]).
Int64(&p.SentAt).
- SharedKey(&p.SharedKey).
Error()
return
}
diff --git a/node/peer-states.go b/node/peer-states.go
index 35ebc0b..39990bd 100644
--- a/node/peer-states.go
+++ b/node/peer-states.go
@@ -16,14 +16,11 @@ type peerState interface {
OnSynAck(netip.AddrPort, synAckPacket) peerState
OnAck(netip.AddrPort, ackPacket) peerState
- // When the peer is updated, we reset. Handled by base state.
- OnPeerUpdate(*m.Peer) peerState
-
- // To determe up / dataCipher. Handled by base state.
- OnPing(netip.AddrPort, pingPacket) peerState
- OnPong(netip.AddrPort, pongPacket) peerState
OnPingTimer() peerState
OnTimeoutTimer() peerState
+
+ // When the peer is updated, we reset. Handled by base state.
+ OnPeerUpdate(*m.Peer) peerState
}
// ----------------------------------------------------------------------------
@@ -82,37 +79,39 @@ func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState {
if peer == nil {
return newStateNoPeer(s)
}
+
s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
+ s.staged.dataCipher = newDataCipher()
+
+ s.resetPingTimer()
+ s.resetTimeoutTimer()
ip, isValid := netip.AddrFromSlice(peer.PublicIP)
if isValid {
s.remotePub = true
s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
s.staged.relay = peer.Mediator
+ }
- if s.localPub && s.localIP < s.remoteIP {
- return newStateServer(s)
+ if s.remotePub == s.localPub {
+ if s.localIP < s.remoteIP {
+ return newStateServer2(s)
}
- return newStateClient(s)
+ return newStateDialLocal(s)
}
- if s.localPub {
- return newStateServer(s)
+ if s.remotePub {
+ return newStateDialLocal(s)
}
-
- return newStateSelectRelay(s)
+ return newStateServer2(s)
}
func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState { return nil }
func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil }
func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState { return nil }
-func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
-func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil }
-func (s *stateBase) OnPingTimer() peerState { return nil }
-func (s *stateBase) OnTimeoutTimer() peerState {
- return s.selectStateFromPeer(s.peer)
-}
+func (s *stateBase) OnPingTimer() peerState { return nil }
+func (s *stateBase) OnTimeoutTimer() peerState { return nil }
// Helpers.
@@ -144,14 +143,6 @@ func (s *stateBase) selectRelay() byte {
return possible[rand.Intn(len(possible))]
}
-func (s *stateBase) sendPing(sharedKey [32]byte) {
- s.sendControlPacket(newPingPacket(sharedKey))
-}
-
-func (s *stateBase) sendPong(ping pingPacket) {
- s.sendControlPacket(newPongPacket(ping.SentAt))
-}
-
func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
buf := pkt.Marshal(s.buf)
h := header{
@@ -183,6 +174,117 @@ func newStateNoPeer(b *stateBase) *stateNoPeer {
// ----------------------------------------------------------------------------
+type stateServer2 struct {
+ *stateBase
+ syn synPacket
+ publishedTraceID uint64
+}
+
+// TODO: Server should send SynAck packets on a loop.
+func newStateServer2(b *stateBase) peerState {
+ s := &stateServer2{stateBase: b}
+ s.resetTimeoutTimer()
+ return s
+}
+
+func (s *stateServer2) Name() string { return "server" }
+
+func (s *stateServer2) OnSyn(remoteAddr netip.AddrPort, p synPacket) peerState {
+ s.syn = p
+ s.sendControlPacket(newSynAckPacket(p.TraceID))
+ return nil
+}
+
+func (s *stateServer2) OnAck(remoteAddr netip.AddrPort, p ackPacket) peerState {
+ if p.TraceID != s.syn.TraceID {
+ return nil
+ }
+
+ s.resetTimeoutTimer()
+
+ if p.TraceID == s.publishedTraceID {
+ return nil
+ }
+
+ // Pubish staged
+ s.staged.remoteAddr = remoteAddr
+ s.staged.dataCipher = newDataCipherFromKey(s.syn.SharedKey)
+ s.staged.relayIP = s.syn.RelayIP
+ s.staged.up = true
+ s.publish()
+
+ s.publishedTraceID = p.TraceID
+ return nil
+}
+
+func (s *stateServer) OnTimeoutTimer() peerState {
+ // TODO: We're down.
+ return nil
+}
+
+// ----------------------------------------------------------------------------
+
+type stateDialLocal struct {
+ *stateBase
+ syn synPacket
+}
+
+func newStateDialLocal(b *stateBase) peerState {
+ // s := stateDialLocal{stateBase: b}
+ // TODO: check for peer local address.
+ return newStateDialDirect(b)
+}
+
+func (s *stateDialLocal) Name() string { return "dial-local" }
+
+// ----------------------------------------------------------------------------
+
+type stateDialDirect struct {
+ *stateBase
+ syn synPacket
+}
+
+func newStateDialDirect(b *stateBase) peerState {
+ // If we don't have an address, dial via relay.
+ if b.staged.remoteAddr == zeroAddrPort {
+ return newStateNoPeer(b)
+ }
+
+ s := &stateDialDirect{stateBase: b}
+ s.syn = synPacket{
+ TraceID: newTraceID(),
+ SharedKey: s.staged.dataCipher.Key(),
+ ServerAddr: b.staged.remoteAddr,
+ }
+
+ s.sendControlPacket(s.syn)
+ s.resetTimeoutTimer()
+
+ return s
+}
+
+func (s *stateDialDirect) Name() string { return "dial-direct" }
+
+func (s *stateDialDirect) OnSynAck(remoteAddr netip.AddrPort, p synAckPacket) peerState {
+ if p.TraceID != s.syn.TraceID {
+ // Hmm...
+ return nil
+ }
+
+ s.sendControlPacket(ackPacket{TraceID: s.syn.TraceID})
+ s.logf("GOT SYN-ACK! TODO!")
+ // client should continue to respond to synAck packets from server.
+ // return newStateClientConnected(s.stateBase, s.syn.TraceID) ...
+ return nil
+}
+
+func (s *stateDialDirect) OnTimeoutTimer() peerState {
+ s.logf("Timeout when dialing")
+ return newStateDialLocal(s.stateBase)
+}
+
+// ----------------------------------------------------------------------------
+
type stateClient struct {
sharedKey [32]byte
*stateBase
@@ -195,7 +297,7 @@ func newStateClient(b *stateBase) peerState {
s.staged.dataCipher = newDataCipher()
s.sharedKey = s.staged.dataCipher.Key()
- s.sendPing(s.sharedKey)
+ s.sendControlPacket(newPingPacket())
s.resetPingTimer()
s.resetTimeoutTimer()
return s
@@ -203,27 +305,6 @@ func newStateClient(b *stateBase) peerState {
func (s *stateClient) Name() string { return "client" }
-func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState {
- if !s.staged.up {
- s.staged.up = true
- s.publish()
- }
- s.resetTimeoutTimer()
- return nil
-}
-
-func (s *stateClient) OnPingTimer() peerState {
- s.sendPing(s.sharedKey)
- s.resetPingTimer()
- return nil
-}
-
-func (s *stateClient) OnTimeoutTimer() peerState {
- s.staged.up = false
- s.publish()
- return nil
-}
-
// ----------------------------------------------------------------------------
type stateServer struct {
@@ -240,25 +321,6 @@ func newStateServer(b *stateBase) peerState {
func (s *stateServer) Name() string { return "server" }
-func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState {
- if addr != s.staged.remoteAddr {
- s.logf("Got new peer address: %v", addr)
- s.staged.remoteAddr = addr
- s.staged.up = true
- s.publish()
- }
-
- if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() {
- s.logf("Got new shared key.")
- s.staged.dataCipher = newDataCipherFromKey(p.SharedKey)
- s.staged.up = true
- s.publish()
- }
-
- s.sendPong(p)
- return nil
-}
-
// ----------------------------------------------------------------------------
type stateSelectRelay struct {
@@ -315,7 +377,7 @@ func newStateClientRelayed(b *stateBase) peerState {
s.sharedKey = s.staged.dataCipher.Key()
s.publish()
- s.sendPing(s.sharedKey)
+ s.sendControlPacket(newPingPacket())
s.resetPingTimer()
s.resetTimeoutTimer()
return s
@@ -323,26 +385,6 @@ func newStateClientRelayed(b *stateBase) peerState {
func (s *stateClientRelayed) Name() string { return "client-relayed" }
-func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState {
- if !s.staged.up {
- s.staged.up = true
- s.publish()
- }
-
- s.resetTimeoutTimer()
- return nil
-}
-
-func (s *stateClientRelayed) OnPingTimer() peerState {
- s.sendPing(s.sharedKey)
- s.resetPingTimer()
- return nil
-}
-
-func (s *stateClientRelayed) OnTimeoutTimer() peerState {
- return newStateSelectRelay(s.stateBase)
-}
-
// ----------------------------------------------------------------------------
type stateServerRelayed struct {
@@ -358,19 +400,6 @@ func newStateServerRelayed(b *stateBase) peerState {
func (s *stateServerRelayed) Name() string { return "server-relayed" }
-func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState {
- if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() {
- s.logf("Got new shared key.")
- s.staged.up = true
- s.staged.dataCipher = newDataCipherFromKey(p.SharedKey)
- s.publish()
- }
-
- s.sendPong(p)
- s.resetTimeoutTimer()
- return nil
-}
-
func (s *stateServerRelayed) OnTimeoutTimer() peerState {
return newStateSelectRelay(s.stateBase)
}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index 08691aa..3f3e0a0 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -6,6 +6,7 @@ import (
)
const (
+ dialTimeout = 8 * time.Second
connectTimeout = 6 * time.Second
pingInterval = 6 * time.Second
timeoutInterval = 20 * time.Second
@@ -29,11 +30,8 @@ func (rp *remotePeer) supervise(conf m.PeerConfig) {
encBuf: make([]byte, bufferSize),
}
- base.pingTimer.Stop()
- base.timeoutTimer.Stop()
-
var (
- curState peerState = base
+ curState peerState = newStateNoPeer(base)
nextState peerState
)
@@ -52,10 +50,6 @@ func (rp *remotePeer) supervise(conf m.PeerConfig) {
nextState = curState.OnSynAck(pkt.RemoteAddr, p)
case ackPacket:
nextState = curState.OnAck(pkt.RemoteAddr, p)
- case pingPacket:
- nextState = curState.OnPing(pkt.RemoteAddr, p)
- case pongPacket:
- nextState = curState.OnPong(pkt.RemoteAddr, p)
default:
// Unknown packet type.
}
--
2.39.5
From 08f11ce82ba98c5abd5d5319b7077822ae0bac73 Mon Sep 17 00:00:00 2001
From: jdl
Date: Sun, 22 Dec 2024 19:17:58 +0100
Subject: [PATCH 11/18] WIP: client/server/relay working.
---
node/globals.go | 8 +-
node/packets.go | 75 +------
node/packets_test.go | 40 +---
node/peer-states.go | 405 --------------------------------------
node/peer-super-states.go | 276 ++++++++++++++++++++++++++
node/peer-super.go | 80 ++++++++
node/peer-supervisor.go | 63 ++----
node/peer.go | 15 +-
node/router.go | 7 -
9 files changed, 393 insertions(+), 576 deletions(-)
delete mode 100644 node/peer-states.go
create mode 100644 node/peer-super-states.go
create mode 100644 node/peer-super.go
delete mode 100644 node/router.go
diff --git a/node/globals.go b/node/globals.go
index d646e71..b78c2c9 100644
--- a/node/globals.go
+++ b/node/globals.go
@@ -1,9 +1,15 @@
package node
+import "net/netip"
+
const (
bufferSize = 1536
- if_mtu = 1200
+ if_mtu = 1400
if_queue_len = 2048
controlCipherOverhead = 16
dataCipherOverhead = 16
)
+
+var (
+ zeroAddrPort = netip.AddrPort{}
+)
diff --git a/node/packets.go b/node/packets.go
index ffda859..0126359 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -3,8 +3,6 @@ package node
import (
"errors"
"net/netip"
- "time"
- "unsafe"
)
var (
@@ -31,10 +29,6 @@ type controlPacket struct {
func (p *controlPacket) ParsePayload(buf []byte) (err error) {
switch buf[0] {
- case packetTypePing:
- p.Payload, err = parsePingPacket(buf)
- case packetTypePong:
- p.Payload, err = parsePongPacket(buf)
case packetTypeSyn:
p.Payload, err = parseSynPacket(buf)
case packetTypeSynAck:
@@ -50,10 +44,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
// ----------------------------------------------------------------------------
type synPacket struct {
- TraceID uint64 // TraceID to match response w/ request.
- SharedKey [32]byte // Our shared key.
- ServerAddr netip.AddrPort // The address we're sending to.
- RelayIP byte
+ TraceID uint64 // TraceID to match response w/ request.
+ SharedKey [32]byte // Our shared key.
+ RelayIP byte
}
func (p synPacket) Marshal(buf []byte) []byte {
@@ -61,7 +54,6 @@ func (p synPacket) Marshal(buf []byte) []byte {
Byte(packetTypeSyn).
Uint64(p.TraceID).
SharedKey(p.SharedKey).
- AddrPort(p.ServerAddr).
Byte(p.RelayIP).
Build()
}
@@ -70,7 +62,6 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
- AddrPort(&p.ServerAddr).
Byte(&p.RelayIP).
Error()
return
@@ -119,63 +110,3 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) {
Error()
return
}
-
-// ----------------------------------------------------------------------------
-
-// A pingPacket is sent from a node acting as a client, to a node acting
-// as a server. It always contains the shared key the client is expecting
-// to use for data encryption with the server.
-type pingPacket struct {
- SentAt int64 // UnixMilli. // Not used. Use traceID.
-}
-
-func newPingPacket() (pp pingPacket) {
- pp.SentAt = time.Now().UnixMilli()
- return
-}
-
-func (p pingPacket) Marshal(buf []byte) []byte {
- return newBinWriter(buf).
- Byte(packetTypePing).
- Int64(p.SentAt).
- Build()
-}
-
-func parsePingPacket(buf []byte) (p pingPacket, err error) {
- err = newBinReader(buf[1:]).
- Int64(&p.SentAt).
- Error()
- return
-}
-
-// ----------------------------------------------------------------------------
-
-// A pongPacket is sent by a node in a server role in response to a pingPacket.
-type pongPacket struct {
- SentAt int64 // UnixMilli.
- RecvdAt int64 // UnixMilli.
-}
-
-func newPongPacket(sentAt int64) (pp pongPacket) {
- pp.SentAt = sentAt
- pp.RecvdAt = time.Now().UnixMilli()
- return
-}
-
-func (p pongPacket) Marshal(buf []byte) []byte {
- buf = buf[:17]
- buf[0] = packetTypePong
- *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt)
- *(*uint64)(unsafe.Pointer(&buf[9])) = uint64(p.RecvdAt)
-
- return buf
-}
-
-func parsePongPacket(buf []byte) (p pongPacket, err error) {
- if len(buf) != 17 {
- return p, errMalformedPacket
- }
- p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
- p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9]))
- return
-}
diff --git a/node/packets_test.go b/node/packets_test.go
index 6d96ccb..660d30e 100644
--- a/node/packets_test.go
+++ b/node/packets_test.go
@@ -2,16 +2,13 @@ package node
import (
"crypto/rand"
- "net/netip"
"reflect"
"testing"
)
func TestPacketSyn(t *testing.T) {
in := synPacket{
- TraceID: newTraceID(),
- Direct: true,
- ServerAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 34),
+ TraceID: newTraceID(),
}
rand.Read(in.SharedKey[:])
@@ -54,38 +51,3 @@ func TestPacketAck(t *testing.T) {
t.Fatal("\n", in, "\n", out)
}
}
-
-func TestPacketPing(t *testing.T) {
- sharedKey := make([]byte, 32)
- rand.Read(sharedKey)
-
- buf := make([]byte, bufferSize)
-
- p := newPingPacket([32]byte(sharedKey))
- out := p.Marshal(buf)
-
- p2, err := parsePingPacket(out)
- if err != nil {
- t.Fatal(err)
- }
-
- if !reflect.DeepEqual(p, p2) {
- t.Fatal(p, p2)
- }
-}
-
-func TestPacketPong(t *testing.T) {
- buf := make([]byte, bufferSize)
-
- p := newPongPacket(123566)
- out := p.Marshal(buf)
-
- p2, err := parsePongPacket(out)
- if err != nil {
- t.Fatal(err)
- }
-
- if !reflect.DeepEqual(p, p2) {
- t.Fatal(p, p2)
- }
-}
diff --git a/node/peer-states.go b/node/peer-states.go
deleted file mode 100644
index 39990bd..0000000
--- a/node/peer-states.go
+++ /dev/null
@@ -1,405 +0,0 @@
-package node
-
-import (
- "fmt"
- "log"
- "math/rand"
- "net/netip"
- "sync/atomic"
- "time"
- "vppn/m"
-)
-
-type peerState interface {
- Name() string
- OnSyn(netip.AddrPort, synPacket) peerState
- OnSynAck(netip.AddrPort, synAckPacket) peerState
- OnAck(netip.AddrPort, ackPacket) peerState
-
- OnPingTimer() peerState
- OnTimeoutTimer() peerState
-
- // When the peer is updated, we reset. Handled by base state.
- OnPeerUpdate(*m.Peer) peerState
-}
-
-// ----------------------------------------------------------------------------
-
-type stateBase struct {
- // The purpose of this state machine is to manage this published data.
- published *atomic.Pointer[peerRoutingData]
- staged peerRoutingData // Local copy of shared data. See publish().
-
- // The other remote peers.
- peers *remotePeers
-
- // Immutable data.
- localIP byte
- localPub bool
- remoteIP byte
- privKey []byte
- conn *connWriter
-
- // For sending to peer.
- counter *uint64
-
- // Mutable peer data.
- peer *m.Peer
- remotePub bool
-
- // Timers
- pingTimer *time.Timer
- timeoutTimer *time.Timer
-
- buf []byte
- encBuf []byte
-}
-
-func (sb *stateBase) Name() string { return "idle" }
-
-func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
- // Both nil: no change.
- if peer == nil && s.peer == nil {
- return nil
- }
-
- // No change.
- if peer != nil && s.peer != nil && s.peer.Version == peer.Version {
- return nil
- }
-
- return s.selectStateFromPeer(peer)
-}
-
-func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState {
- s.peer = peer
- s.staged = peerRoutingData{}
- defer s.publish()
-
- if peer == nil {
- return newStateNoPeer(s)
- }
-
- s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
- s.staged.dataCipher = newDataCipher()
-
- s.resetPingTimer()
- s.resetTimeoutTimer()
-
- ip, isValid := netip.AddrFromSlice(peer.PublicIP)
- if isValid {
- s.remotePub = true
- s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
- s.staged.relay = peer.Mediator
- }
-
- if s.remotePub == s.localPub {
- if s.localIP < s.remoteIP {
- return newStateServer2(s)
- }
- return newStateDialLocal(s)
- }
-
- if s.remotePub {
- return newStateDialLocal(s)
- }
- return newStateServer2(s)
-}
-
-func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState { return nil }
-func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil }
-func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState { return nil }
-
-func (s *stateBase) OnPingTimer() peerState { return nil }
-func (s *stateBase) OnTimeoutTimer() peerState { return nil }
-
-// Helpers.
-
-func (s *stateBase) resetPingTimer() { s.pingTimer.Reset(pingInterval) }
-func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) }
-func (s *stateBase) stopPingTimer() { s.pingTimer.Stop() }
-func (s *stateBase) stopTimeoutTimer() { s.timeoutTimer.Stop() }
-
-func (s *stateBase) logf(msg string, args ...any) {
- log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
-}
-
-func (s *stateBase) publish() {
- data := s.staged
- s.published.Store(&data)
-}
-
-func (s *stateBase) selectRelay() byte {
- possible := make([]byte, 0, 8)
- for i, peer := range s.peers {
- if peer.CanRelay() {
- possible = append(possible, byte(i))
- }
- }
-
- if len(possible) == 0 {
- return 0
- }
- return possible[rand.Intn(len(possible))]
-}
-
-func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
- buf := pkt.Marshal(s.buf)
- h := header{
- StreamID: controlStreamID,
- Counter: atomic.AddUint64(s.counter, 1),
- SourceIP: s.localIP,
- DestIP: s.remoteIP,
- }
-
- buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf)
- if s.staged.relayIP != 0 {
- s.peers[s.staged.relayIP].RelayFor(s.remoteIP, buf)
- } else {
- s.conn.WriteTo(buf, s.staged.remoteAddr)
- }
-}
-
-// ----------------------------------------------------------------------------
-
-type stateNoPeer struct{ *stateBase }
-
-func newStateNoPeer(b *stateBase) *stateNoPeer {
- s := &stateNoPeer{b}
- s.pingTimer.Stop()
- s.timeoutTimer.Stop()
- s.publish()
- return s
-}
-
-// ----------------------------------------------------------------------------
-
-type stateServer2 struct {
- *stateBase
- syn synPacket
- publishedTraceID uint64
-}
-
-// TODO: Server should send SynAck packets on a loop.
-func newStateServer2(b *stateBase) peerState {
- s := &stateServer2{stateBase: b}
- s.resetTimeoutTimer()
- return s
-}
-
-func (s *stateServer2) Name() string { return "server" }
-
-func (s *stateServer2) OnSyn(remoteAddr netip.AddrPort, p synPacket) peerState {
- s.syn = p
- s.sendControlPacket(newSynAckPacket(p.TraceID))
- return nil
-}
-
-func (s *stateServer2) OnAck(remoteAddr netip.AddrPort, p ackPacket) peerState {
- if p.TraceID != s.syn.TraceID {
- return nil
- }
-
- s.resetTimeoutTimer()
-
- if p.TraceID == s.publishedTraceID {
- return nil
- }
-
- // Pubish staged
- s.staged.remoteAddr = remoteAddr
- s.staged.dataCipher = newDataCipherFromKey(s.syn.SharedKey)
- s.staged.relayIP = s.syn.RelayIP
- s.staged.up = true
- s.publish()
-
- s.publishedTraceID = p.TraceID
- return nil
-}
-
-func (s *stateServer) OnTimeoutTimer() peerState {
- // TODO: We're down.
- return nil
-}
-
-// ----------------------------------------------------------------------------
-
-type stateDialLocal struct {
- *stateBase
- syn synPacket
-}
-
-func newStateDialLocal(b *stateBase) peerState {
- // s := stateDialLocal{stateBase: b}
- // TODO: check for peer local address.
- return newStateDialDirect(b)
-}
-
-func (s *stateDialLocal) Name() string { return "dial-local" }
-
-// ----------------------------------------------------------------------------
-
-type stateDialDirect struct {
- *stateBase
- syn synPacket
-}
-
-func newStateDialDirect(b *stateBase) peerState {
- // If we don't have an address, dial via relay.
- if b.staged.remoteAddr == zeroAddrPort {
- return newStateNoPeer(b)
- }
-
- s := &stateDialDirect{stateBase: b}
- s.syn = synPacket{
- TraceID: newTraceID(),
- SharedKey: s.staged.dataCipher.Key(),
- ServerAddr: b.staged.remoteAddr,
- }
-
- s.sendControlPacket(s.syn)
- s.resetTimeoutTimer()
-
- return s
-}
-
-func (s *stateDialDirect) Name() string { return "dial-direct" }
-
-func (s *stateDialDirect) OnSynAck(remoteAddr netip.AddrPort, p synAckPacket) peerState {
- if p.TraceID != s.syn.TraceID {
- // Hmm...
- return nil
- }
-
- s.sendControlPacket(ackPacket{TraceID: s.syn.TraceID})
- s.logf("GOT SYN-ACK! TODO!")
- // client should continue to respond to synAck packets from server.
- // return newStateClientConnected(s.stateBase, s.syn.TraceID) ...
- return nil
-}
-
-func (s *stateDialDirect) OnTimeoutTimer() peerState {
- s.logf("Timeout when dialing")
- return newStateDialLocal(s.stateBase)
-}
-
-// ----------------------------------------------------------------------------
-
-type stateClient struct {
- sharedKey [32]byte
- *stateBase
-}
-
-func newStateClient(b *stateBase) peerState {
- s := &stateClient{stateBase: b}
- s.publish()
-
- s.staged.dataCipher = newDataCipher()
- s.sharedKey = s.staged.dataCipher.Key()
-
- s.sendControlPacket(newPingPacket())
- s.resetPingTimer()
- s.resetTimeoutTimer()
- return s
-}
-
-func (s *stateClient) Name() string { return "client" }
-
-// ----------------------------------------------------------------------------
-
-type stateServer struct {
- *stateBase
-}
-
-func newStateServer(b *stateBase) peerState {
- s := &stateServer{b}
- s.publish()
- s.stopPingTimer()
- s.stopTimeoutTimer()
- return s
-}
-
-func (s *stateServer) Name() string { return "server" }
-
-// ----------------------------------------------------------------------------
-
-type stateSelectRelay struct {
- *stateBase
-}
-
-func newStateSelectRelay(b *stateBase) peerState {
- s := &stateSelectRelay{stateBase: b}
- s.staged.dataCipher = nil
- s.staged.up = false
- s.publish()
-
- if relay := s.selectRelay(); relay != 0 {
- s.staged.up = false
- s.staged.relayIP = relay
- return s.selectRole()
- }
-
- s.resetPingTimer()
- s.stopTimeoutTimer()
- return s
-}
-
-func (s *stateSelectRelay) selectRole() peerState {
- if s.localIP < s.remoteIP {
- return newStateServerRelayed(s.stateBase)
- }
- return newStateClientRelayed(s.stateBase)
-}
-
-func (s *stateSelectRelay) Name() string { return "select-relay" }
-
-func (s *stateSelectRelay) OnPingTimer() peerState {
- if relay := s.selectRelay(); relay != 0 {
- s.logf("Got relay IP: %d", relay)
- s.staged.relayIP = relay
- return s.selectRole()
- }
- s.resetPingTimer()
- return nil
-}
-
-// ----------------------------------------------------------------------------
-
-type stateClientRelayed struct {
- sharedKey [32]byte
- *stateBase
-}
-
-func newStateClientRelayed(b *stateBase) peerState {
- s := &stateClientRelayed{stateBase: b}
-
- s.staged.dataCipher = newDataCipher()
- s.sharedKey = s.staged.dataCipher.Key()
- s.publish()
-
- s.sendControlPacket(newPingPacket())
- s.resetPingTimer()
- s.resetTimeoutTimer()
- return s
-}
-
-func (s *stateClientRelayed) Name() string { return "client-relayed" }
-
-// ----------------------------------------------------------------------------
-
-type stateServerRelayed struct {
- *stateBase
-}
-
-func newStateServerRelayed(b *stateBase) peerState {
- s := &stateServerRelayed{b}
- s.stopPingTimer()
- s.resetTimeoutTimer()
- return s
-}
-
-func (s *stateServerRelayed) Name() string { return "server-relayed" }
-
-func (s *stateServerRelayed) OnTimeoutTimer() peerState {
- return newStateSelectRelay(s.stateBase)
-}
diff --git a/node/peer-super-states.go b/node/peer-super-states.go
new file mode 100644
index 0000000..6e615ae
--- /dev/null
+++ b/node/peer-super-states.go
@@ -0,0 +1,276 @@
+package node
+
+import (
+ "math/rand"
+ "net/netip"
+ "time"
+ "vppn/m"
+)
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) noPeer() stateFunc {
+ return s.peerUpdate(<-s.peerUpdates)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc {
+ return func() stateFunc { return s._peerUpdate(peer) }
+}
+
+func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc {
+ defer s.publish()
+
+ s.peer = peer
+ s.staged = peerRoutingData{}
+
+ if s.peer == nil {
+ return s.noPeer
+ }
+
+ s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
+ s.staged.dataCipher = newDataCipher()
+
+ if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
+ s.remotePub = true
+ s.staged.relay = peer.Mediator
+ s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
+ }
+
+ if s.remotePub == s.localPub {
+ if s.localIP < s.remoteIP {
+ return s.serverAccept
+ }
+ return s.clientInit
+ }
+
+ if s.remotePub {
+ return s.clientInit
+ }
+ return s.serverAccept
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) serverAccept() stateFunc {
+ s.logf("STATE: server-accept")
+ s.staged.up = false
+ s.staged.dataCipher = nil
+ s.staged.remoteAddr = zeroAddrPort
+ s.staged.relayIP = 0
+ s.publish()
+
+ var syn synPacket
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+
+ case synPacket:
+ syn = p
+ s.staged.remoteAddr = pkt.RemoteAddr
+ s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey)
+ s.staged.relayIP = syn.RelayIP
+ s.publish()
+ s.sendControlPacket(newSynAckPacket(p.TraceID))
+
+ case ackPacket:
+ if p.TraceID != syn.TraceID {
+ continue
+ }
+
+ // Publish.
+ return s.serverConnected(syn.TraceID)
+ }
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) serverConnected(traceID uint64) stateFunc {
+ s.logf("STATE: server-connected")
+ s.staged.up = true
+ s.publish()
+ return func() stateFunc {
+ return s._serverConnected(traceID)
+ }
+}
+
+func (s *peerSuper) _serverConnected(traceID uint64) stateFunc {
+
+ timeoutTimer := time.NewTimer(timeoutInterval)
+ defer timeoutTimer.Stop()
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+
+ case ackPacket:
+ if p.TraceID != traceID {
+ return s.serverAccept
+ }
+
+ s.sendControlPacket(ackPacket{TraceID: traceID})
+ timeoutTimer.Reset(timeoutInterval)
+ }
+
+ case <-timeoutTimer.C:
+ s.logf("server timeout")
+ return s.serverAccept
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) clientInit() stateFunc {
+ s.logf("STATE: client-init")
+ if !s.remotePub {
+ // TODO: Check local discovery for IP.
+ // TODO: Attempt UDP hole punch.
+ // TODO: client-relayed
+ return s.clientSelectRelay
+ }
+
+ return s.clientDial
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) clientSelectRelay() stateFunc {
+ s.logf("STATE: client-select-relay")
+
+ timer := time.NewTimer(0)
+ defer timer.Stop()
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case <-timer.C:
+ ip := s.selectRelayIP()
+ if ip != 0 {
+ s.logf("Got relay: %d", ip)
+ s.staged.relayIP = ip
+ s.publish()
+ return s.clientDial
+ }
+
+ s.logf("No relay available.")
+ timer.Reset(pingInterval)
+ }
+ }
+}
+
+func (s *peerSuper) selectRelayIP() byte {
+ possible := make([]byte, 0, 8)
+ for i, peer := range s.peers {
+ if peer.CanRelay() {
+ possible = append(possible, byte(i))
+ }
+ }
+
+ if len(possible) == 0 {
+ return 0
+ }
+ return possible[rand.Intn(len(possible))]
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) clientDial() stateFunc {
+ s.logf("STATE: client-dial")
+
+ var (
+ syn = synPacket{
+ TraceID: newTraceID(),
+ SharedKey: s.staged.dataCipher.Key(),
+ RelayIP: s.staged.relayIP,
+ }
+
+ timeout = time.NewTimer(dialTimeout)
+ )
+
+ defer timeout.Stop()
+
+ s.sendControlPacket(syn)
+
+ for {
+ select {
+
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+ case synAckPacket:
+ if p.TraceID != syn.TraceID {
+ continue // Hmm...
+ }
+ s.sendControlPacket(ackPacket{TraceID: syn.TraceID})
+ return s.clientConnected(syn.TraceID)
+ }
+
+ case <-timeout.C:
+ return s.clientInit
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) clientConnected(traceID uint64) stateFunc {
+ s.logf("STATE: client-connected")
+ s.staged.up = true
+ s.publish()
+
+ return func() stateFunc {
+ return s._clientConnected(traceID)
+ }
+}
+
+func (s *peerSuper) _clientConnected(traceID uint64) stateFunc {
+
+ pingTimer := time.NewTimer(pingInterval)
+ timeoutTimer := time.NewTimer(timeoutInterval)
+
+ defer pingTimer.Stop()
+ defer timeoutTimer.Stop()
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+
+ case ackPacket:
+ if p.TraceID != traceID {
+ return s.clientInit
+ }
+ timeoutTimer.Reset(timeoutInterval)
+ }
+
+ case <-pingTimer.C:
+ s.sendControlPacket(ackPacket{TraceID: traceID})
+ pingTimer.Reset(pingInterval)
+
+ case <-timeoutTimer.C:
+ s.logf("client timeout")
+ return s.clientInit
+
+ }
+ }
+}
diff --git a/node/peer-super.go b/node/peer-super.go
new file mode 100644
index 0000000..df1907f
--- /dev/null
+++ b/node/peer-super.go
@@ -0,0 +1,80 @@
+package node
+
+import (
+ "fmt"
+ "log"
+ "sync/atomic"
+ "vppn/m"
+)
+
+type peerSuper struct {
+ // The purpose of this state machine is to manage this published data.
+ published *atomic.Pointer[peerRoutingData]
+ staged peerRoutingData // Local copy of shared data. See publish().
+
+ // The other remote peers.
+ peers *remotePeers
+
+ // Immutable data.
+ localIP byte
+ localPub bool
+ remoteIP byte
+ privKey []byte
+ conn *connWriter
+
+ // For sending to peer.
+ counter *uint64
+
+ // Mutable peer data.
+ peer *m.Peer
+ remotePub bool
+
+ // Incoming events.
+ peerUpdates chan *m.Peer
+ controlPackets chan controlPacket
+
+ // Buffers
+ buf []byte
+ encBuf []byte
+}
+
+type stateFunc func() stateFunc
+
+func (s *peerSuper) Run() {
+ state := s.noPeer
+ for {
+ state = state()
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) logf(msg string, args ...any) {
+ log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) publish() {
+ data := s.staged
+ s.published.Store(&data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
+ buf := pkt.Marshal(s.buf)
+ h := header{
+ StreamID: controlStreamID,
+ Counter: atomic.AddUint64(s.counter, 1),
+ SourceIP: s.localIP,
+ DestIP: s.remoteIP,
+ }
+
+ buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf)
+ if s.staged.relayIP != 0 {
+ s.peers[s.staged.relayIP].RelayTo(s.remoteIP, buf)
+ } else {
+ s.conn.WriteTo(buf, s.staged.remoteAddr)
+ }
+}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index 3f3e0a0..50401b8 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -15,55 +15,20 @@ const (
func (rp *remotePeer) supervise(conf m.PeerConfig) {
defer panicHandler()
- base := &stateBase{
- published: rp.published,
- peers: rp.peers,
- localIP: rp.localIP,
- remoteIP: rp.remoteIP,
- privKey: conf.EncPrivKey,
- localPub: addrIsValid(conf.PublicIP),
- conn: rp.conn,
- counter: &rp.counter,
- pingTimer: time.NewTimer(time.Second),
- timeoutTimer: time.NewTimer(time.Second),
- buf: make([]byte, bufferSize),
- encBuf: make([]byte, bufferSize),
+ super := &peerSuper{
+ published: rp.published,
+ peers: rp.peers,
+ localIP: rp.localIP,
+ localPub: addrIsValid(conf.PublicIP),
+ remoteIP: rp.remoteIP,
+ privKey: conf.EncPrivKey,
+ conn: rp.conn,
+ counter: &rp.counter,
+ peerUpdates: rp.peerUpdates,
+ controlPackets: rp.controlPackets,
+ buf: make([]byte, bufferSize),
+ encBuf: make([]byte, bufferSize),
}
- var (
- curState peerState = newStateNoPeer(base)
- nextState peerState
- )
-
- for {
- nextState = nil
-
- select {
- case peer := <-rp.peerUpdates:
- nextState = curState.OnPeerUpdate(peer)
-
- case pkt := <-rp.controlPackets:
- switch p := pkt.Payload.(type) {
- case synPacket:
- nextState = curState.OnSyn(pkt.RemoteAddr, p)
- case synAckPacket:
- nextState = curState.OnSynAck(pkt.RemoteAddr, p)
- case ackPacket:
- nextState = curState.OnAck(pkt.RemoteAddr, p)
- default:
- // Unknown packet type.
- }
-
- case <-base.pingTimer.C:
- nextState = curState.OnPingTimer()
-
- case <-base.timeoutTimer.C:
- nextState = curState.OnTimeoutTimer()
- }
-
- if nextState != nil {
- rp.logf("%s --> %s", curState.Name(), nextState.Name())
- curState = nextState
- }
- }
+ go super.Run()
}
diff --git a/node/peer.go b/node/peer.go
index bae2c9c..1fc3226 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -41,6 +41,10 @@ type remotePeer struct {
// Used for sending control and data packets. Atomic access only.
counter uint64
+ // Only accessed in HandlePeerUpdate. Used to determine if we should send
+ // the peer update to the peerSuper.
+ peerVersion int64
+
// For communicating with the supervisor thread.
peerUpdates chan *m.Peer
controlPackets chan controlPacket
@@ -75,7 +79,12 @@ func (rp *remotePeer) logf(msg string, args ...any) {
}
func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
- rp.peerUpdates <- peer
+ if peer == nil {
+ rp.peerUpdates <- peer
+ } else if peer.Version != rp.peerVersion {
+ rp.peerVersion = peer.Version
+ rp.peerUpdates <- peer
+ }
}
// ----------------------------------------------------------------------------
@@ -209,7 +218,7 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) {
enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf)
if routingData.relayIP != 0 {
- rp.peers[routingData.relayIP].RelayFor(rp.remoteIP, enc)
+ rp.peers[routingData.relayIP].RelayTo(rp.remoteIP, enc)
} else {
rp.SendData(data)
}
@@ -224,7 +233,7 @@ func (rp *remotePeer) CanRelay() bool {
// ----------------------------------------------------------------------------
-func (rp *remotePeer) RelayFor(destIP byte, data []byte) {
+func (rp *remotePeer) RelayTo(destIP byte, data []byte) {
rp.encryptAndSend(relayStreamID, destIP, data)
}
diff --git a/node/router.go b/node/router.go
deleted file mode 100644
index 116b4d0..0000000
--- a/node/router.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package node
-
-import (
- "net/netip"
-)
-
-var zeroAddrPort = netip.AddrPort{}
--
2.39.5
From 51d7b5f08627f348418c050d9005db865fc22079 Mon Sep 17 00:00:00 2001
From: jdl
Date: Mon, 23 Dec 2024 06:05:50 +0100
Subject: [PATCH 12/18] wip working - modifying
---
node/packets.go | 75 ++++++++++++++++++++++++++++++++-------
node/peer-super-states.go | 18 ++++++----
node/peer-super.go | 19 ++++++++--
node/peer-supervisor.go | 2 +-
node/peer.go | 31 ++++++++--------
5 files changed, 109 insertions(+), 36 deletions(-)
diff --git a/node/packets.go b/node/packets.go
index 0126359..04db2a9 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -14,9 +14,8 @@ const (
packetTypeSyn = iota + 1
packetTypeSynAck
packetTypeAck
- packetTypePing
- packetTypePong
- packetTypeRelayed
+ packetTypeAddrReq
+ packetTypeAddrResp
)
// ----------------------------------------------------------------------------
@@ -35,6 +34,10 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
p.Payload, err = parseSynAckPacket(buf)
case packetTypeAck:
p.Payload, err = parseAckPacket(buf)
+ case packetTypeAddrReq:
+ p.Payload, err = parseAddrReqPacket(buf)
+ case packetTypeAddrResp:
+ p.Payload, err = parseAddrRespPacket(buf)
default:
return errUnknownPacketType
}
@@ -70,21 +73,66 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
// ----------------------------------------------------------------------------
type synAckPacket struct {
- TraceID uint64
-}
-
-func newSynAckPacket(traceID uint64) synAckPacket {
- return synAckPacket{traceID}
+ TraceID uint64
+ RecvAddr netip.AddrPort
}
func (p synAckPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSynAck).
Uint64(p.TraceID).
+ AddrPort(p.RecvAddr).
Build()
}
func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
+ err = newBinReader(buf[1:]).
+ Uint64(&p.TraceID).
+ AddrPort(&p.RecvAddr).
+ Error()
+ return
+}
+
+// ----------------------------------------------------------------------------
+
+type ackPacket struct {
+ TraceID uint64
+ SendAddr netip.AddrPort // Address of the sender.
+ RecvAddr netip.AddrPort // Address of the recipient as seen by sender.
+}
+
+func (p ackPacket) Marshal(buf []byte) []byte {
+ return newBinWriter(buf).
+ Byte(packetTypeAck).
+ Uint64(p.TraceID).
+ AddrPort(p.SendAddr).
+ AddrPort(p.RecvAddr).
+ Build()
+}
+
+func parseAckPacket(buf []byte) (p ackPacket, err error) {
+ err = newBinReader(buf[1:]).
+ Uint64(&p.TraceID).
+ AddrPort(&p.SendAddr).
+ AddrPort(&p.RecvAddr).
+ Error()
+ return
+}
+
+// ----------------------------------------------------------------------------
+
+type addrReqPacket struct {
+ TraceID uint64
+}
+
+func (p addrReqPacket) Marshal(buf []byte) []byte {
+ return newBinWriter(buf).
+ Byte(packetTypeAddrReq).
+ Uint64(p.TraceID).
+ Build()
+}
+
+func parseAddrReqPacket(buf []byte) (p addrReqPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
Error()
@@ -93,20 +141,23 @@ func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
// ----------------------------------------------------------------------------
-type ackPacket struct {
+type addrRespPacket struct {
TraceID uint64
+ Addr netip.AddrPort
}
-func (p ackPacket) Marshal(buf []byte) []byte {
+func (p addrRespPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
- Byte(packetTypeAck).
+ Byte(packetTypeAddrResp).
Uint64(p.TraceID).
+ AddrPort(p.Addr).
Build()
}
-func parseAckPacket(buf []byte) (p ackPacket, err error) {
+func parseAddrRespPacket(buf []byte) (p addrRespPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
+ AddrPort(&p.Addr).
Error()
return
}
diff --git a/node/peer-super-states.go b/node/peer-super-states.go
index 6e615ae..2d888df 100644
--- a/node/peer-super-states.go
+++ b/node/peer-super-states.go
@@ -23,7 +23,7 @@ func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc {
defer s.publish()
s.peer = peer
- s.staged = peerRoutingData{}
+ s.staged = peerRouteInfo{}
if s.peer == nil {
return s.noPeer
@@ -77,7 +77,10 @@ func (s *peerSuper) serverAccept() stateFunc {
s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey)
s.staged.relayIP = syn.RelayIP
s.publish()
- s.sendControlPacket(newSynAckPacket(p.TraceID))
+ s.sendControlPacket(synAckPacket{
+ TraceID: syn.TraceID,
+ RecvAddr: pkt.RemoteAddr,
+ })
case ackPacket:
if p.TraceID != syn.TraceID {
@@ -120,7 +123,7 @@ func (s *peerSuper) _serverConnected(traceID uint64) stateFunc {
return s.serverAccept
}
- s.sendControlPacket(ackPacket{TraceID: traceID})
+ s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
timeoutTimer.Reset(timeoutInterval)
}
@@ -218,8 +221,8 @@ func (s *peerSuper) clientDial() stateFunc {
if p.TraceID != syn.TraceID {
continue // Hmm...
}
- s.sendControlPacket(ackPacket{TraceID: syn.TraceID})
- return s.clientConnected(syn.TraceID)
+ s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
+ return s.clientConnected(p)
}
case <-timeout.C:
@@ -230,13 +233,14 @@ func (s *peerSuper) clientDial() stateFunc {
// ----------------------------------------------------------------------------
-func (s *peerSuper) clientConnected(traceID uint64) stateFunc {
+func (s *peerSuper) clientConnected(p synAckPacket) stateFunc {
s.logf("STATE: client-connected")
s.staged.up = true
+ s.staged.localAddr = p.RecvAddr
s.publish()
return func() stateFunc {
- return s._clientConnected(traceID)
+ return s._clientConnected(p.TraceID)
}
}
diff --git a/node/peer-super.go b/node/peer-super.go
index df1907f..f5e2436 100644
--- a/node/peer-super.go
+++ b/node/peer-super.go
@@ -9,8 +9,8 @@ import (
type peerSuper struct {
// The purpose of this state machine is to manage this published data.
- published *atomic.Pointer[peerRoutingData]
- staged peerRoutingData // Local copy of shared data. See publish().
+ published *atomic.Pointer[peerRouteInfo]
+ staged peerRouteInfo // Local copy of shared data. See publish().
// The other remote peers.
peers *remotePeers
@@ -78,3 +78,18 @@ func (s *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
s.conn.WriteTo(buf, s.staged.remoteAddr)
}
}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSuper) sendControlPacketDirect(pkt interface{ Marshal([]byte) []byte }) {
+ buf := pkt.Marshal(s.buf)
+ h := header{
+ StreamID: controlStreamID,
+ Counter: atomic.AddUint64(s.counter, 1),
+ SourceIP: s.localIP,
+ DestIP: s.remoteIP,
+ }
+
+ buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf)
+ s.conn.WriteTo(buf, s.staged.remoteAddr)
+}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index 50401b8..dc7d2c6 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -16,7 +16,7 @@ func (rp *remotePeer) supervise(conf m.PeerConfig) {
defer panicHandler()
super := &peerSuper{
- published: rp.published,
+ published: rp.route,
peers: rp.peers,
localIP: rp.localIP,
localPub: addrIsValid(conf.PublicIP),
diff --git a/node/peer.go b/node/peer.go
index 1fc3226..b829b39 100644
--- a/node/peer.go
+++ b/node/peer.go
@@ -11,13 +11,16 @@ import (
type remotePeers [256]*remotePeer
-type peerRoutingData struct {
+// ----------------------------------------------------------------------------
+
+type peerRouteInfo struct {
up bool
relay bool
controlCipher *controlCipher
dataCipher *dataCipher
remoteAddr netip.AddrPort
- relayIP byte // Non-zero if we should relay.
+ localAddr netip.AddrPort // Local address as seen by the remote.
+ relayIP byte // Non-zero if we should relay.
}
type remotePeer struct {
@@ -28,8 +31,8 @@ type remotePeer struct {
conn *connWriter
// Shared state.
- peers *remotePeers
- published *atomic.Pointer[peerRoutingData]
+ peers *remotePeers
+ route *atomic.Pointer[peerRouteInfo]
// Only used in HandlePacket / Not synchronized.
dupCheck *dupCheck
@@ -57,7 +60,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
iface: iface,
conn: conn,
peers: peers,
- published: &atomic.Pointer[peerRoutingData]{},
+ route: &atomic.Pointer[peerRouteInfo]{},
dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize),
encryptBuf: make([]byte, bufferSize),
@@ -66,8 +69,8 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
controlPackets: make(chan controlPacket, 512),
}
- pd := peerRoutingData{}
- rp.published.Store(&pd)
+ pd := peerRouteInfo{}
+ rp.route.Store(&pd)
//go newPeerSuper(rp).Run()
go rp.supervise(conf)
@@ -111,7 +114,7 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
- routingData := rp.published.Load()
+ routingData := rp.route.Load()
if routingData.controlCipher == nil {
rp.logf("Not connected (control).")
return
@@ -158,7 +161,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleDataPacket(data []byte) {
- routingData := rp.published.Load()
+ routingData := rp.route.Load()
if routingData.dataCipher == nil {
rp.logf("Not connected (recv).")
return
@@ -176,7 +179,7 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
- routingData := rp.published.Load()
+ routingData := rp.route.Load()
if routingData.dataCipher == nil {
rp.logf("Not connected (recv).")
return
@@ -201,7 +204,7 @@ func (rp *remotePeer) SendData(data []byte) {
}
func (rp *remotePeer) HandleInterfacePacket(data []byte) {
- routingData := rp.published.Load()
+ routingData := rp.route.Load()
if routingData.dataCipher == nil {
rp.logf("Not connected (handle interface).")
@@ -227,7 +230,7 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) CanRelay() bool {
- data := rp.published.Load()
+ data := rp.route.Load()
return data.relay && data.up
}
@@ -240,7 +243,7 @@ func (rp *remotePeer) RelayTo(destIP byte, data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) encryptAndSend(streamID byte, destIP byte, data []byte) {
- routingData := rp.published.Load()
+ routingData := rp.route.Load()
if routingData.dataCipher == nil || routingData.remoteAddr == zeroAddrPort {
rp.logf("Not connected (encrypt and send).")
return
@@ -262,7 +265,7 @@ func (rp *remotePeer) encryptAndSend(streamID byte, destIP byte, data []byte) {
// SendAsIs is used when forwarding already-encrypted data from one peer to
// another.
func (rp *remotePeer) SendAsIs(data []byte) {
- routingData := rp.published.Load()
+ routingData := rp.route.Load()
if routingData.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send direct).")
return
--
2.39.5
From 869bbfb3d623e4206b968b23bbe10bfb50002b2f Mon Sep 17 00:00:00 2001
From: jdl
Date: Mon, 23 Dec 2024 08:08:23 +0100
Subject: [PATCH 13/18] wip
---
node/globalfuncs.go | 79 +++++++++
node/globals.go | 48 +++++-
node/header.go | 1 -
node/main.go | 158 ++++++++++++++---
node/peer-pollhub.go | 30 ++--
node/peer-super-states.go | 280 ------------------------------
node/peer-super.go | 95 -----------
node/peer-supervisor.go | 350 ++++++++++++++++++++++++++++++++++++--
node/peer.go | 274 -----------------------------
9 files changed, 606 insertions(+), 709 deletions(-)
create mode 100644 node/globalfuncs.go
delete mode 100644 node/peer-super-states.go
delete mode 100644 node/peer-super.go
delete mode 100644 node/peer.go
diff --git a/node/globalfuncs.go b/node/globalfuncs.go
new file mode 100644
index 0000000..9ddf90c
--- /dev/null
+++ b/node/globalfuncs.go
@@ -0,0 +1,79 @@
+package node
+
+import (
+ "log"
+ "sync/atomic"
+)
+
+func _sendControlPacket(
+ pkt interface{ Marshal([]byte) []byte },
+ route peerRoute,
+ buf1 []byte,
+ buf2 []byte,
+) {
+ buf := pkt.Marshal(buf1)
+ h1 := header{
+ StreamID: controlStreamID,
+ Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
+ SourceIP: localIP,
+ DestIP: route.IP,
+ }
+ buf = route.ControlCipher.Encrypt(h1, buf, buf2)
+
+ if route.RelayIP == 0 {
+ _conn.WriteTo(buf, route.RemoteAddr)
+ return
+ }
+
+ relayRoute := routingTable[route.RelayIP].Load()
+ if !relayRoute.Up || !relayRoute.Relay {
+ log.Print("Failed to send control packet: relay not available.")
+ return
+ }
+
+ h2 := header{
+ StreamID: dataStreamID,
+ Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1),
+ SourceIP: localIP,
+ DestIP: route.IP,
+ }
+ buf = relayRoute.DataCipher.Encrypt(h2, buf, buf1)
+ _conn.WriteTo(buf, relayRoute.RemoteAddr)
+}
+
+func _sendDataPacket(
+ pkt []byte,
+ route *peerRoute,
+ buf1 []byte,
+ buf2 []byte,
+) {
+ h := header{
+ StreamID: dataStreamID,
+ Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
+ SourceIP: localIP,
+ DestIP: route.IP,
+ }
+
+ enc := route.DataCipher.Encrypt(h, pkt, buf1)
+
+ if route.RelayIP == 0 {
+ _conn.WriteTo(enc, route.RemoteAddr)
+ return
+ }
+
+ relayRoute := routingTable[route.RelayIP].Load()
+ if !relayRoute.Up || !relayRoute.Relay {
+ log.Print("Failed to send data packet: relay not available.")
+ return
+ }
+
+ h2 := header{
+ StreamID: dataStreamID,
+ Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1),
+ SourceIP: localIP,
+ DestIP: route.IP,
+ }
+
+ enc = relayRoute.DataCipher.Encrypt(h2, enc, buf2)
+ _conn.WriteTo(enc, relayRoute.RemoteAddr)
+}
diff --git a/node/globals.go b/node/globals.go
index b78c2c9..db1e792 100644
--- a/node/globals.go
+++ b/node/globals.go
@@ -1,15 +1,57 @@
package node
-import "net/netip"
+import (
+ "net/netip"
+ "sync/atomic"
+ "vppn/m"
+)
+
+var zeroAddrPort = netip.AddrPort{}
const (
bufferSize = 1536
- if_mtu = 1400
+ if_mtu = 1200
if_queue_len = 2048
controlCipherOverhead = 16
dataCipherOverhead = 16
)
+type peerRoute struct {
+ IP byte
+ Up bool // True if data can be sent on the route.
+ Relay bool // True if the peer is a relay.
+ ControlCipher *controlCipher
+ DataCipher *dataCipher
+ RemoteAddr netip.AddrPort // Remote address if directly connected.
+ LocalAddr netip.AddrPort // Local address as seen by the remote.
+ RelayIP byte // Non-zero if we should relay.
+}
+
+// Configuration for this peer.
var (
- zeroAddrPort = netip.AddrPort{}
+ netName string
+ localIP byte
+ localPub bool
+ privateKey []byte
)
+
+// Shared interface for writing.
+var _iface *ifWriter
+
+// Shared connection for writing.
+var _conn *connWriter
+
+// Counters for sending to each peer.
+var sendCounters [256]uint64
+
+// Duplicate checkers for incoming packets.
+var dupChecks [256]*dupCheck
+
+// Channels for incoming control packets.
+var controlPackets [256]chan controlPacket
+
+// Channels for incoming peer updates from the hub.
+var peerUpdates [256]chan *m.Peer
+
+// Global routing table.
+var routingTable [256]*atomic.Pointer[peerRoute]
diff --git a/node/header.go b/node/header.go
index 1a022a2..fd28962 100644
--- a/node/header.go
+++ b/node/header.go
@@ -10,7 +10,6 @@ const (
controlHeaderSize = 24
dataStreamID = 1
dataHeaderSize = 12
- relayStreamID = 3
)
type header struct {
diff --git a/node/main.go b/node/main.go
index c291e73..e2e5c42 100644
--- a/node/main.go
+++ b/node/main.go
@@ -11,6 +11,8 @@ import (
"net/netip"
"os"
"runtime/debug"
+ "sync/atomic"
+ "time"
"vppn/m"
)
@@ -24,7 +26,6 @@ func Main() {
defer panicHandler()
var (
- netName string
initURL string
listenIP string
port int
@@ -42,14 +43,14 @@ func Main() {
}
if initURL != "" {
- mainInit(netName, initURL)
+ mainInit(initURL)
return
}
- main(netName, listenIP, uint16(port))
+ main(listenIP, uint16(port))
}
-func mainInit(netName, initURL string) {
+func mainInit(initURL string) {
if _, err := loadPeerConfig(netName); err == nil {
log.Fatalf("Network is already initialized.")
}
@@ -79,15 +80,15 @@ func mainInit(netName, initURL string) {
// ----------------------------------------------------------------------------
-func main(netName, listenIP string, port uint16) {
- conf, err := loadPeerConfig(netName)
+func main(listenIP string, port uint16) {
+ config, err := loadPeerConfig(netName)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
- port = determinePort(conf.Port, port)
+ port = determinePort(config.Port, port)
- iface, err := openInterface(conf.Network, conf.PeerIP, netName)
+ iface, err := openInterface(config.Network, config.PeerIP, netName)
if err != nil {
log.Fatalf("Failed to open interface: %v", err)
}
@@ -102,18 +103,34 @@ func main(netName, listenIP string, port uint16) {
log.Fatalf("Failed to open UDP port: %v", err)
}
- connWriter := newConnWriter(conn)
- ifWriter := newIFWriter(iface)
+ // Intialize globals.
+ localIP = config.PeerIP
+ localPub = addrIsValid(config.PublicIP)
+ privateKey = config.EncPrivKey
- peers := remotePeers{}
+ _iface = newIFWriter(iface)
+ _conn = newConnWriter(conn)
- for i := range peers {
- peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter, &peers)
+ for i := range 256 {
+ sendCounters[i] = uint64(time.Now().Unix()<<30) + 1
+ dupChecks[i] = newDupCheck(0)
+ controlPackets[i] = make(chan controlPacket, 256)
+ peerUpdates[i] = make(chan *m.Peer)
+ routingTable[i] = &atomic.Pointer[peerRoute]{}
+ route := peerRoute{IP: byte(i)}
+ routingTable[i].Store(&route)
}
- go newHubPoller(netName, conf, peers).Run()
- go readFromConn(conn, peers)
- readFromIFace(iface, peers)
+ // Start supervisors.
+ for i := range 256 {
+ go newPeerSupervisor(i).Run()
+ }
+
+ // --------------------
+
+ go newHubPoller(config).Run()
+ go readFromConn(conn)
+ readFromIFace(iface)
}
// ----------------------------------------------------------------------------
@@ -130,7 +147,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ----------------------------------------------------------------------------
-func readFromConn(conn *net.UDPConn, peers remotePeers) {
+func readFromConn(conn *net.UDPConn) {
defer panicHandler()
@@ -139,6 +156,7 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) {
n int
err error
buf = make([]byte, bufferSize)
+ decBuf = make([]byte, bufferSize)
data []byte
h header
)
@@ -156,27 +174,119 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) {
}
h.Parse(data)
- peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
+ switch h.StreamID {
+ case controlStreamID:
+ handleControlPacket(remoteAddr, h, data, decBuf)
+
+ case dataStreamID:
+ handleDataPacket(h, data, decBuf)
+
+ default:
+ log.Printf("Unknown stream ID: %d", h.StreamID)
+ }
}
}
+func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
+ route := routingTable[h.SourceIP].Load()
+ if route.ControlCipher == nil {
+ log.Printf("Not connected (control).")
+ return
+ }
+
+ if h.DestIP != localIP {
+ log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP)
+ return
+ }
+
+ out, ok := route.ControlCipher.Decrypt(data, decBuf)
+ if !ok {
+ log.Printf("Failed to decrypt control packet.")
+ return
+ }
+
+ if len(out) == 0 {
+ log.Printf("Empty control packet from: %d", h.SourceIP)
+ return
+ }
+
+ if dupChecks[h.SourceIP].IsDup(h.Counter) {
+ log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter)
+ return
+ }
+
+ pkt := controlPacket{
+ SrcIP: h.SourceIP,
+ RemoteAddr: addr,
+ }
+
+ if err := pkt.ParsePayload(out); err != nil {
+ log.Printf("Failed to parse control packet: %v", err)
+ return
+ }
+
+ select {
+ case controlPackets[h.SourceIP] <- pkt:
+ default:
+ log.Printf("Dropping control packet.")
+ }
+}
+
+func handleDataPacket(h header, data []byte, decBuf []byte) {
+ route := routingTable[h.SourceIP].Load()
+ if !route.Up {
+ log.Printf("Not connected (recv).")
+ return
+ }
+
+ dec, ok := route.DataCipher.Decrypt(data, decBuf)
+ if !ok {
+ log.Printf("Failed to decrypt data packet.")
+ return
+ }
+
+ if dupChecks[h.SourceIP].IsDup(h.Counter) {
+ log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter)
+ return
+ }
+
+ if h.DestIP == localIP {
+ _iface.Write(dec)
+ return
+ }
+
+ destRoute := routingTable[h.DestIP].Load()
+ if !destRoute.Up || destRoute.RelayIP != 0 {
+ log.Printf("Not connected (relay)")
+ return
+ }
+
+ _conn.WriteTo(dec, destRoute.RemoteAddr)
+}
+
// ----------------------------------------------------------------------------
-func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) {
-
+func readFromIFace(iface io.ReadWriteCloser) {
var (
- buf = make([]byte, bufferSize)
- packet []byte
+ packet = make([]byte, bufferSize)
+ buf1 = make([]byte, bufferSize)
+ buf2 = make([]byte, bufferSize)
remoteIP byte
err error
)
for {
- packet, remoteIP, err = readNextPacket(iface, buf)
+ packet, remoteIP, err = readNextPacket(iface, packet)
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
- peers[remoteIP].HandleInterfacePacket(packet)
+ route := routingTable[remoteIP].Load()
+ if !route.Up {
+ log.Printf("Route not connected: %d", remoteIP)
+ continue
+ }
+
+ _sendDataPacket(packet, route, buf1, buf2)
}
}
diff --git a/node/peer-pollhub.go b/node/peer-pollhub.go
index aa1c91b..ef36431 100644
--- a/node/peer-pollhub.go
+++ b/node/peer-pollhub.go
@@ -11,14 +11,12 @@ import (
)
type hubPoller struct {
- netName string
- localIP byte
- client *http.Client
- req *http.Request
- peers remotePeers
+ client *http.Client
+ req *http.Request
+ versions [256]int64
}
-func newHubPoller(netName string, conf m.PeerConfig, peers remotePeers) *hubPoller {
+func newHubPoller(conf m.PeerConfig) *hubPoller {
u, err := url.Parse(conf.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
@@ -35,18 +33,15 @@ func newHubPoller(netName string, conf m.PeerConfig, peers remotePeers) *hubPoll
req.SetBasicAuth("", conf.APIKey)
return &hubPoller{
- netName: netName,
- localIP: conf.PeerIP,
- client: client,
- req: req,
- peers: peers,
+ client: client,
+ req: req,
}
}
func (hp *hubPoller) Run() {
defer panicHandler()
- state, err := loadNetworkState(hp.netName)
+ state, err := loadNetworkState(netName)
if err != nil {
log.Printf("Failed to load network state: %v", err)
log.Printf("Polling hub...")
@@ -83,15 +78,18 @@ func (hp *hubPoller) pollHub() {
hp.applyNetworkState(state)
- if err := storeNetworkState(hp.netName, state); err != nil {
+ if err := storeNetworkState(netName, state); err != nil {
log.Printf("Failed to store network state: %v", err)
}
}
func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
- for i := range state.Peers {
- if i != int(hp.localIP) {
- hp.peers[i].HandlePeerUpdate(state.Peers[i])
+ for i, peer := range state.Peers {
+ if i != int(localIP) {
+ if peer != nil && peer.Version != hp.versions[i] {
+ peerUpdates[i] <- state.Peers[i]
+ hp.versions[i] = peer.Version
+ }
}
}
}
diff --git a/node/peer-super-states.go b/node/peer-super-states.go
deleted file mode 100644
index 2d888df..0000000
--- a/node/peer-super-states.go
+++ /dev/null
@@ -1,280 +0,0 @@
-package node
-
-import (
- "math/rand"
- "net/netip"
- "time"
- "vppn/m"
-)
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) noPeer() stateFunc {
- return s.peerUpdate(<-s.peerUpdates)
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc {
- return func() stateFunc { return s._peerUpdate(peer) }
-}
-
-func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc {
- defer s.publish()
-
- s.peer = peer
- s.staged = peerRouteInfo{}
-
- if s.peer == nil {
- return s.noPeer
- }
-
- s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
- s.staged.dataCipher = newDataCipher()
-
- if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
- s.remotePub = true
- s.staged.relay = peer.Mediator
- s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
- }
-
- if s.remotePub == s.localPub {
- if s.localIP < s.remoteIP {
- return s.serverAccept
- }
- return s.clientInit
- }
-
- if s.remotePub {
- return s.clientInit
- }
- return s.serverAccept
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) serverAccept() stateFunc {
- s.logf("STATE: server-accept")
- s.staged.up = false
- s.staged.dataCipher = nil
- s.staged.remoteAddr = zeroAddrPort
- s.staged.relayIP = 0
- s.publish()
-
- var syn synPacket
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case pkt := <-s.controlPackets:
- switch p := pkt.Payload.(type) {
-
- case synPacket:
- syn = p
- s.staged.remoteAddr = pkt.RemoteAddr
- s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey)
- s.staged.relayIP = syn.RelayIP
- s.publish()
- s.sendControlPacket(synAckPacket{
- TraceID: syn.TraceID,
- RecvAddr: pkt.RemoteAddr,
- })
-
- case ackPacket:
- if p.TraceID != syn.TraceID {
- continue
- }
-
- // Publish.
- return s.serverConnected(syn.TraceID)
- }
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) serverConnected(traceID uint64) stateFunc {
- s.logf("STATE: server-connected")
- s.staged.up = true
- s.publish()
- return func() stateFunc {
- return s._serverConnected(traceID)
- }
-}
-
-func (s *peerSuper) _serverConnected(traceID uint64) stateFunc {
-
- timeoutTimer := time.NewTimer(timeoutInterval)
- defer timeoutTimer.Stop()
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case pkt := <-s.controlPackets:
- switch p := pkt.Payload.(type) {
-
- case ackPacket:
- if p.TraceID != traceID {
- return s.serverAccept
- }
-
- s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
- timeoutTimer.Reset(timeoutInterval)
- }
-
- case <-timeoutTimer.C:
- s.logf("server timeout")
- return s.serverAccept
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) clientInit() stateFunc {
- s.logf("STATE: client-init")
- if !s.remotePub {
- // TODO: Check local discovery for IP.
- // TODO: Attempt UDP hole punch.
- // TODO: client-relayed
- return s.clientSelectRelay
- }
-
- return s.clientDial
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) clientSelectRelay() stateFunc {
- s.logf("STATE: client-select-relay")
-
- timer := time.NewTimer(0)
- defer timer.Stop()
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case <-timer.C:
- ip := s.selectRelayIP()
- if ip != 0 {
- s.logf("Got relay: %d", ip)
- s.staged.relayIP = ip
- s.publish()
- return s.clientDial
- }
-
- s.logf("No relay available.")
- timer.Reset(pingInterval)
- }
- }
-}
-
-func (s *peerSuper) selectRelayIP() byte {
- possible := make([]byte, 0, 8)
- for i, peer := range s.peers {
- if peer.CanRelay() {
- possible = append(possible, byte(i))
- }
- }
-
- if len(possible) == 0 {
- return 0
- }
- return possible[rand.Intn(len(possible))]
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) clientDial() stateFunc {
- s.logf("STATE: client-dial")
-
- var (
- syn = synPacket{
- TraceID: newTraceID(),
- SharedKey: s.staged.dataCipher.Key(),
- RelayIP: s.staged.relayIP,
- }
-
- timeout = time.NewTimer(dialTimeout)
- )
-
- defer timeout.Stop()
-
- s.sendControlPacket(syn)
-
- for {
- select {
-
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case pkt := <-s.controlPackets:
- switch p := pkt.Payload.(type) {
- case synAckPacket:
- if p.TraceID != syn.TraceID {
- continue // Hmm...
- }
- s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
- return s.clientConnected(p)
- }
-
- case <-timeout.C:
- return s.clientInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) clientConnected(p synAckPacket) stateFunc {
- s.logf("STATE: client-connected")
- s.staged.up = true
- s.staged.localAddr = p.RecvAddr
- s.publish()
-
- return func() stateFunc {
- return s._clientConnected(p.TraceID)
- }
-}
-
-func (s *peerSuper) _clientConnected(traceID uint64) stateFunc {
-
- pingTimer := time.NewTimer(pingInterval)
- timeoutTimer := time.NewTimer(timeoutInterval)
-
- defer pingTimer.Stop()
- defer timeoutTimer.Stop()
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case pkt := <-s.controlPackets:
- switch p := pkt.Payload.(type) {
-
- case ackPacket:
- if p.TraceID != traceID {
- return s.clientInit
- }
- timeoutTimer.Reset(timeoutInterval)
- }
-
- case <-pingTimer.C:
- s.sendControlPacket(ackPacket{TraceID: traceID})
- pingTimer.Reset(pingInterval)
-
- case <-timeoutTimer.C:
- s.logf("client timeout")
- return s.clientInit
-
- }
- }
-}
diff --git a/node/peer-super.go b/node/peer-super.go
deleted file mode 100644
index f5e2436..0000000
--- a/node/peer-super.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package node
-
-import (
- "fmt"
- "log"
- "sync/atomic"
- "vppn/m"
-)
-
-type peerSuper struct {
- // The purpose of this state machine is to manage this published data.
- published *atomic.Pointer[peerRouteInfo]
- staged peerRouteInfo // Local copy of shared data. See publish().
-
- // The other remote peers.
- peers *remotePeers
-
- // Immutable data.
- localIP byte
- localPub bool
- remoteIP byte
- privKey []byte
- conn *connWriter
-
- // For sending to peer.
- counter *uint64
-
- // Mutable peer data.
- peer *m.Peer
- remotePub bool
-
- // Incoming events.
- peerUpdates chan *m.Peer
- controlPackets chan controlPacket
-
- // Buffers
- buf []byte
- encBuf []byte
-}
-
-type stateFunc func() stateFunc
-
-func (s *peerSuper) Run() {
- state := s.noPeer
- for {
- state = state()
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) logf(msg string, args ...any) {
- log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) publish() {
- data := s.staged
- s.published.Store(&data)
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
- buf := pkt.Marshal(s.buf)
- h := header{
- StreamID: controlStreamID,
- Counter: atomic.AddUint64(s.counter, 1),
- SourceIP: s.localIP,
- DestIP: s.remoteIP,
- }
-
- buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf)
- if s.staged.relayIP != 0 {
- s.peers[s.staged.relayIP].RelayTo(s.remoteIP, buf)
- } else {
- s.conn.WriteTo(buf, s.staged.remoteAddr)
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSuper) sendControlPacketDirect(pkt interface{ Marshal([]byte) []byte }) {
- buf := pkt.Marshal(s.buf)
- h := header{
- StreamID: controlStreamID,
- Counter: atomic.AddUint64(s.counter, 1),
- SourceIP: s.localIP,
- DestIP: s.remoteIP,
- }
-
- buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf)
- s.conn.WriteTo(buf, s.staged.remoteAddr)
-}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index dc7d2c6..6741f48 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -1,6 +1,11 @@
package node
import (
+ "fmt"
+ "log"
+ "math/rand"
+ "net/netip"
+ "sync/atomic"
"time"
"vppn/m"
)
@@ -12,23 +17,336 @@ const (
timeoutInterval = 20 * time.Second
)
-func (rp *remotePeer) supervise(conf m.PeerConfig) {
- defer panicHandler()
+// ----------------------------------------------------------------------------
- super := &peerSuper{
- published: rp.route,
- peers: rp.peers,
- localIP: rp.localIP,
- localPub: addrIsValid(conf.PublicIP),
- remoteIP: rp.remoteIP,
- privKey: conf.EncPrivKey,
- conn: rp.conn,
- counter: &rp.counter,
- peerUpdates: rp.peerUpdates,
- controlPackets: rp.controlPackets,
- buf: make([]byte, bufferSize),
- encBuf: make([]byte, bufferSize),
+type peerSupervisor struct {
+ // The purpose of this state machine is to manage this published data.
+ published *atomic.Pointer[peerRoute]
+ staged peerRoute // Local copy of shared data. See publish().
+
+ // Immutable data.
+ remoteIP byte // Remote VPN IP.
+
+ // Mutable peer data.
+ peer *m.Peer
+ remotePub bool
+
+ // Incoming events.
+ peerUpdates chan *m.Peer
+ controlPackets chan controlPacket
+
+ // Buffers for sending control packets.
+ buf1 []byte
+ buf2 []byte
+}
+
+func newPeerSupervisor(i int) *peerSupervisor {
+ return &peerSupervisor{
+ published: routingTable[i],
+ remoteIP: byte(i),
+ peerUpdates: peerUpdates[i],
+ controlPackets: controlPackets[i],
+ buf1: make([]byte, bufferSize),
+ buf2: make([]byte, bufferSize),
+ }
+}
+
+type stateFunc func() stateFunc
+
+func (s *peerSupervisor) Run() {
+ state := s.noPeer
+ for {
+ state = state()
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
+ _sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) logf(msg string, args ...any) {
+ log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) publish() {
+ data := s.staged
+ s.published.Store(&data)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) noPeer() stateFunc {
+ return s.peerUpdate(<-s.peerUpdates)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
+ return func() stateFunc { return s._peerUpdate(peer) }
+}
+
+func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
+ defer s.publish()
+
+ s.peer = peer
+ s.staged = peerRoute{}
+
+ if s.peer == nil {
+ return s.noPeer
}
- go super.Run()
+ s.staged.IP = s.remoteIP
+ s.staged.ControlCipher = newControlCipher(privateKey, peer.EncPubKey)
+ s.staged.DataCipher = newDataCipher()
+
+ if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
+ s.remotePub = true
+ s.staged.Relay = peer.Mediator
+ s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
+ }
+
+ if s.remotePub == localPub {
+ if localIP < s.remoteIP {
+ return s.serverAccept
+ }
+ return s.clientInit
+ }
+
+ if s.remotePub {
+ return s.clientInit
+ }
+ return s.serverAccept
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) serverAccept() stateFunc {
+ s.logf("STATE: server-accept")
+ s.staged.Up = false
+ s.staged.DataCipher = nil
+ s.staged.RemoteAddr = zeroAddrPort
+ s.staged.RelayIP = 0
+ s.publish()
+
+ var syn synPacket
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+
+ case synPacket:
+ syn = p
+ s.staged.RemoteAddr = pkt.RemoteAddr
+ s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
+ s.staged.RelayIP = syn.RelayIP
+ s.publish()
+ s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
+
+ case ackPacket:
+ if p.TraceID != syn.TraceID {
+ continue
+ }
+
+ // Publish.
+ return s.serverConnected(syn.TraceID)
+ }
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc {
+ s.logf("STATE: server-connected")
+ s.staged.Up = true
+ s.publish()
+ return func() stateFunc {
+ return s._serverConnected(traceID)
+ }
+}
+
+func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc {
+
+ timeoutTimer := time.NewTimer(timeoutInterval)
+ defer timeoutTimer.Stop()
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+
+ case ackPacket:
+ if p.TraceID != traceID {
+ return s.serverAccept
+ }
+ s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
+ timeoutTimer.Reset(timeoutInterval)
+ }
+
+ case <-timeoutTimer.C:
+ s.logf("server timeout")
+ return s.serverAccept
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) clientInit() stateFunc {
+ s.logf("STATE: client-init")
+ if !s.remotePub {
+ // TODO: Check local discovery for IP.
+ // TODO: Attempt UDP hole punch.
+ // TODO: client-relayed
+ return s.clientSelectRelay
+ }
+
+ return s.clientDial
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) clientSelectRelay() stateFunc {
+ s.logf("STATE: client-select-relay")
+
+ timer := time.NewTimer(0)
+ defer timer.Stop()
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case <-timer.C:
+ relay := s.selectRelay()
+ if relay != nil {
+ s.logf("Got relay: %d", relay.IP)
+ s.staged.RelayIP = relay.IP
+ s.staged.LocalAddr = relay.LocalAddr
+ s.publish()
+ return s.clientDial
+ }
+
+ s.logf("No relay available.")
+ timer.Reset(pingInterval)
+ }
+ }
+}
+
+func (s *peerSupervisor) selectRelay() *peerRoute {
+ possible := make([]*peerRoute, 0, 8)
+ for i := range routingTable {
+ route := routingTable[i].Load()
+ if !route.Up || !route.Relay {
+ continue
+ }
+ possible = append(possible, route)
+ }
+
+ if len(possible) == 0 {
+ return nil
+ }
+ return possible[rand.Intn(len(possible))]
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) clientDial() stateFunc {
+ s.logf("STATE: client-dial")
+
+ var (
+ syn = synPacket{
+ TraceID: newTraceID(),
+ SharedKey: s.staged.DataCipher.Key(),
+ RelayIP: s.staged.RelayIP,
+ }
+
+ timeout = time.NewTimer(dialTimeout)
+ )
+
+ defer timeout.Stop()
+
+ s.sendControlPacket(syn)
+
+ for {
+ select {
+
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+ case synAckPacket:
+ if p.TraceID != syn.TraceID {
+ continue // Hmm...
+ }
+ s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
+ return s.clientConnected(p)
+ }
+
+ case <-timeout.C:
+ return s.clientInit
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc {
+ s.logf("STATE: client-connected")
+ s.staged.Up = true
+ s.staged.LocalAddr = p.RecvAddr
+ s.publish()
+
+ return func() stateFunc {
+ return s._clientConnected(p.TraceID)
+ }
+}
+
+func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc {
+
+ pingTimer := time.NewTimer(pingInterval)
+ timeoutTimer := time.NewTimer(timeoutInterval)
+
+ defer pingTimer.Stop()
+ defer timeoutTimer.Stop()
+
+ for {
+ select {
+ case peer := <-s.peerUpdates:
+ return s.peerUpdate(peer)
+
+ case pkt := <-s.controlPackets:
+ switch p := pkt.Payload.(type) {
+
+ case ackPacket:
+ if p.TraceID != traceID {
+ return s.clientInit
+ }
+ timeoutTimer.Reset(timeoutInterval)
+ }
+
+ case <-pingTimer.C:
+ s.sendControlPacket(ackPacket{TraceID: traceID})
+ pingTimer.Reset(pingInterval)
+
+ case <-timeoutTimer.C:
+ s.logf("client timeout")
+ return s.clientInit
+
+ }
+ }
}
diff --git a/node/peer.go b/node/peer.go
deleted file mode 100644
index b829b39..0000000
--- a/node/peer.go
+++ /dev/null
@@ -1,274 +0,0 @@
-package node
-
-import (
- "fmt"
- "log"
- "net/netip"
- "sync/atomic"
- "time"
- "vppn/m"
-)
-
-type remotePeers [256]*remotePeer
-
-// ----------------------------------------------------------------------------
-
-type peerRouteInfo struct {
- up bool
- relay bool
- controlCipher *controlCipher
- dataCipher *dataCipher
- remoteAddr netip.AddrPort
- localAddr netip.AddrPort // Local address as seen by the remote.
- relayIP byte // Non-zero if we should relay.
-}
-
-type remotePeer struct {
- // Immutable data.
- localIP byte
- remoteIP byte
- iface *ifWriter
- conn *connWriter
-
- // Shared state.
- peers *remotePeers
- route *atomic.Pointer[peerRouteInfo]
-
- // Only used in HandlePacket / Not synchronized.
- dupCheck *dupCheck
- decryptBuf []byte
-
- // Only used in SendData / Not synchronized.
- encryptBuf []byte
-
- // Used for sending control and data packets. Atomic access only.
- counter uint64
-
- // Only accessed in HandlePeerUpdate. Used to determine if we should send
- // the peer update to the peerSuper.
- peerVersion int64
-
- // For communicating with the supervisor thread.
- peerUpdates chan *m.Peer
- controlPackets chan controlPacket
-}
-
-func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter, peers *remotePeers) *remotePeer {
- rp := &remotePeer{
- localIP: conf.PeerIP,
- remoteIP: remoteIP,
- iface: iface,
- conn: conn,
- peers: peers,
- route: &atomic.Pointer[peerRouteInfo]{},
- dupCheck: newDupCheck(0),
- decryptBuf: make([]byte, bufferSize),
- encryptBuf: make([]byte, bufferSize),
- counter: uint64(time.Now().Unix()) << 30,
- peerUpdates: make(chan *m.Peer),
- controlPackets: make(chan controlPacket, 512),
- }
-
- pd := peerRouteInfo{}
- rp.route.Store(&pd)
-
- //go newPeerSuper(rp).Run()
- go rp.supervise(conf)
- return rp
-}
-
-func (rp *remotePeer) logf(msg string, args ...any) {
- log.Printf(fmt.Sprintf("[%03d] ", rp.remoteIP)+msg, args...)
-}
-
-func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
- if peer == nil {
- rp.peerUpdates <- peer
- } else if peer.Version != rp.peerVersion {
- rp.peerVersion = peer.Version
- rp.peerUpdates <- peer
- }
-}
-
-// ----------------------------------------------------------------------------
-
-// HandlePacket accepts a raw data packet coming in from the network.
-//
-// This function is called by a single thread.
-func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
- switch h.StreamID {
- case controlStreamID:
- rp.handleControlPacket(addr, h, data)
-
- case dataStreamID:
- rp.handleDataPacket(data)
-
- case relayStreamID:
- rp.handleRelayPacket(h, data)
-
- default:
- rp.logf("Unknown stream ID: %d", h.StreamID)
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
- routingData := rp.route.Load()
- if routingData.controlCipher == nil {
- rp.logf("Not connected (control).")
- return
- }
-
- if h.DestIP != rp.localIP {
- rp.logf("Incorrect destination IP on control packet.")
- return
- }
-
- out, ok := routingData.controlCipher.Decrypt(data, rp.decryptBuf)
- if !ok {
- rp.logf("Failed to decrypt control packet.")
- return
- }
-
- if len(out) == 0 {
- rp.logf("Empty control packet from: %d", h.SourceIP)
- return
- }
-
- if rp.dupCheck.IsDup(h.Counter) {
- rp.logf("Duplicate control packet: %d", h.Counter)
- return
- }
-
- pkt := controlPacket{
- SrcIP: h.SourceIP,
- RemoteAddr: addr,
- }
-
- if err := pkt.ParsePayload(out); err != nil {
- rp.logf("Failed to parse control packet: %v", err)
- return
- }
-
- select {
- case rp.controlPackets <- pkt:
- default:
- rp.logf("Dropping control packet.")
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) handleDataPacket(data []byte) {
- routingData := rp.route.Load()
- if routingData.dataCipher == nil {
- rp.logf("Not connected (recv).")
- return
- }
-
- dec, ok := routingData.dataCipher.Decrypt(data, rp.decryptBuf)
- if !ok {
- rp.logf("Failed to decrypt data packet.")
- return
- }
-
- rp.iface.Write(dec)
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
- routingData := rp.route.Load()
- if routingData.dataCipher == nil {
- rp.logf("Not connected (recv).")
- return
- }
-
- dec, ok := routingData.dataCipher.Decrypt(data, rp.decryptBuf)
- if !ok {
- rp.logf("Failed to decrypt data packet.")
- return
- }
-
- rp.peers[h.DestIP].SendAsIs(dec)
-}
-
-// ----------------------------------------------------------------------------
-
-// SendData sends data coming from the interface going to the network.
-//
-// This function is called by a single thread.
-func (rp *remotePeer) SendData(data []byte) {
- rp.encryptAndSend(dataStreamID, rp.remoteIP, data)
-}
-
-func (rp *remotePeer) HandleInterfacePacket(data []byte) {
- routingData := rp.route.Load()
-
- if routingData.dataCipher == nil {
- rp.logf("Not connected (handle interface).")
- return
- }
-
- h := header{
- StreamID: dataStreamID,
- Counter: atomic.AddUint64(&rp.counter, 1),
- SourceIP: rp.localIP,
- DestIP: rp.remoteIP,
- }
-
- enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf)
-
- if routingData.relayIP != 0 {
- rp.peers[routingData.relayIP].RelayTo(rp.remoteIP, enc)
- } else {
- rp.SendData(data)
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) CanRelay() bool {
- data := rp.route.Load()
- return data.relay && data.up
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) RelayTo(destIP byte, data []byte) {
- rp.encryptAndSend(relayStreamID, destIP, data)
-}
-
-// ----------------------------------------------------------------------------
-
-func (rp *remotePeer) encryptAndSend(streamID byte, destIP byte, data []byte) {
- routingData := rp.route.Load()
- if routingData.dataCipher == nil || routingData.remoteAddr == zeroAddrPort {
- rp.logf("Not connected (encrypt and send).")
- return
- }
-
- h := header{
- StreamID: streamID,
- Counter: atomic.AddUint64(&rp.counter, 1),
- SourceIP: rp.localIP,
- DestIP: destIP,
- }
-
- enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf)
- rp.conn.WriteTo(enc, routingData.remoteAddr)
-}
-
-// ----------------------------------------------------------------------------
-
-// SendAsIs is used when forwarding already-encrypted data from one peer to
-// another.
-func (rp *remotePeer) SendAsIs(data []byte) {
- routingData := rp.route.Load()
- if routingData.remoteAddr == zeroAddrPort {
- rp.logf("Not connected (send direct).")
- return
- }
- rp.conn.WriteTo(data, routingData.remoteAddr)
-}
--
2.39.5
From a6e022e57075682a9bd44df9835c722bf9dbd999 Mon Sep 17 00:00:00 2001
From: jdl
Date: Mon, 23 Dec 2024 08:15:02 +0100
Subject: [PATCH 14/18] Cleanup: working
---
node/globalfuncs.go | 47 ++++++++------------------
node/globals.go | 44 ++++++++++++------------
node/{peer-pollhub.go => hubpoller.go} | 0
node/main.go | 2 +-
4 files changed, 37 insertions(+), 56 deletions(-)
rename node/{peer-pollhub.go => hubpoller.go} (100%)
diff --git a/node/globalfuncs.go b/node/globalfuncs.go
index 9ddf90c..406588e 100644
--- a/node/globalfuncs.go
+++ b/node/globalfuncs.go
@@ -5,48 +5,25 @@ import (
"sync/atomic"
)
-func _sendControlPacket(
- pkt interface{ Marshal([]byte) []byte },
- route peerRoute,
- buf1 []byte,
- buf2 []byte,
-) {
- buf := pkt.Marshal(buf1)
- h1 := header{
+func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) {
+ buf := pkt.Marshal(buf2)
+ h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
SourceIP: localIP,
DestIP: route.IP,
}
- buf = route.ControlCipher.Encrypt(h1, buf, buf2)
+ buf = route.ControlCipher.Encrypt(h, buf, buf1)
if route.RelayIP == 0 {
_conn.WriteTo(buf, route.RemoteAddr)
return
}
- relayRoute := routingTable[route.RelayIP].Load()
- if !relayRoute.Up || !relayRoute.Relay {
- log.Print("Failed to send control packet: relay not available.")
- return
- }
-
- h2 := header{
- StreamID: dataStreamID,
- Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1),
- SourceIP: localIP,
- DestIP: route.IP,
- }
- buf = relayRoute.DataCipher.Encrypt(h2, buf, buf1)
- _conn.WriteTo(buf, relayRoute.RemoteAddr)
+ _relayPacket(route.RelayIP, route.IP, buf, buf2)
}
-func _sendDataPacket(
- pkt []byte,
- route *peerRoute,
- buf1 []byte,
- buf2 []byte,
-) {
+func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) {
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
@@ -61,19 +38,23 @@ func _sendDataPacket(
return
}
- relayRoute := routingTable[route.RelayIP].Load()
+ _relayPacket(route.RelayIP, route.IP, enc, buf2)
+}
+
+func _relayPacket(relayIP, destIP byte, data, buf []byte) {
+ relayRoute := routingTable[relayIP].Load()
if !relayRoute.Up || !relayRoute.Relay {
log.Print("Failed to send data packet: relay not available.")
return
}
- h2 := header{
+ h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1),
SourceIP: localIP,
- DestIP: route.IP,
+ DestIP: destIP,
}
- enc = relayRoute.DataCipher.Encrypt(h2, enc, buf2)
+ enc := relayRoute.DataCipher.Encrypt(h, data, buf)
_conn.WriteTo(enc, relayRoute.RemoteAddr)
}
diff --git a/node/globals.go b/node/globals.go
index db1e792..f782cb5 100644
--- a/node/globals.go
+++ b/node/globals.go
@@ -27,31 +27,31 @@ type peerRoute struct {
RelayIP byte // Non-zero if we should relay.
}
-// Configuration for this peer.
var (
+ // Configuration for this peer.
netName string
localIP byte
localPub bool
privateKey []byte
+
+ // Shared interface for writing.
+ _iface *ifWriter
+
+ // Shared connection for writing.
+ _conn *connWriter
+
+ // Counters for sending to each peer.
+ sendCounters [256]uint64
+
+ // Duplicate checkers for incoming packets.
+ dupChecks [256]*dupCheck
+
+ // Channels for incoming control packets.
+ controlPackets [256]chan controlPacket
+
+ // Channels for incoming peer updates from the hub.
+ peerUpdates [256]chan *m.Peer
+
+ // Global routing table.
+ routingTable [256]*atomic.Pointer[peerRoute]
)
-
-// Shared interface for writing.
-var _iface *ifWriter
-
-// Shared connection for writing.
-var _conn *connWriter
-
-// Counters for sending to each peer.
-var sendCounters [256]uint64
-
-// Duplicate checkers for incoming packets.
-var dupChecks [256]*dupCheck
-
-// Channels for incoming control packets.
-var controlPackets [256]chan controlPacket
-
-// Channels for incoming peer updates from the hub.
-var peerUpdates [256]chan *m.Peer
-
-// Global routing table.
-var routingTable [256]*atomic.Pointer[peerRoute]
diff --git a/node/peer-pollhub.go b/node/hubpoller.go
similarity index 100%
rename from node/peer-pollhub.go
rename to node/hubpoller.go
diff --git a/node/main.go b/node/main.go
index e2e5c42..d9d865a 100644
--- a/node/main.go
+++ b/node/main.go
@@ -287,6 +287,6 @@ func readFromIFace(iface io.ReadWriteCloser) {
continue
}
- _sendDataPacket(packet, route, buf1, buf2)
+ _sendDataPacket(route, packet, buf1, buf2)
}
}
--
2.39.5
From 6a6e30feb90e6dd57a5c2ff53fb3c86b0a125ab9 Mon Sep 17 00:00:00 2001
From: jdl
Date: Mon, 23 Dec 2024 09:34:11 +0100
Subject: [PATCH 15/18] Cleanup, hub updates
---
hub/api/api.go | 45 ++++++++++-------------
hub/api/db/generated.go | 33 +++++++++--------
hub/api/db/sanitize-validate.go | 2 +-
hub/api/db/tables.defs | 5 ++-
hub/api/migrations/2024-11-30-init.sql | 5 ++-
hub/handlers.go | 47 +++++++++++++++++-------
hub/routes.go | 1 +
hub/templates/admin-peer-create.html | 6 ++--
hub/templates/admin-peer-delete.html | 4 +--
hub/templates/admin-peer-edit.html | 4 +--
hub/templates/admin-peer-list.html | 7 ++--
hub/templates/admin-peer-view.html | 2 +-
m/models.go | 35 +++++++++---------
node/main.go | 2 +-
node/packets.go | 49 --------------------------
node/packets_test.go | 4 ++-
node/peer-supervisor.go | 4 +--
17 files changed, 108 insertions(+), 147 deletions(-)
diff --git a/hub/api/api.go b/hub/api/api.go
index 053c574..975149d 100644
--- a/hub/api/api.go
+++ b/hub/api/api.go
@@ -15,7 +15,6 @@ import (
"git.crumpington.com/lib/go/sqliteutil"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/nacl/box"
- "golang.org/x/crypto/nacl/sign"
)
//go:embed migrations
@@ -146,7 +145,7 @@ type PeerCreateArgs struct {
Name string
PublicIP []byte
Port uint16
- Mediator bool
+ Relay bool
}
// Create the intention to add a peer. The returned code is used to complete
@@ -184,11 +183,6 @@ func (a *API) Peer_Create(creationCode string) (*m.PeerConfig, error) {
return nil, err
}
- signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader)
- if err != nil {
- return nil, err
- }
-
// Get peer IP.
peerIP := byte(0)
@@ -208,15 +202,14 @@ func (a *API) Peer_Create(creationCode string) (*m.PeerConfig, error) {
}
peer := &Peer{
- PeerIP: peerIP,
- Version: idgen.NextID(0),
- APIKey: idgen.NewToken(),
- Name: args.Name,
- PublicIP: args.PublicIP,
- Port: args.Port,
- Mediator: args.Mediator,
- EncPubKey: encPubKey[:],
- SignPubKey: signPubKey[:],
+ PeerIP: peerIP,
+ Version: idgen.NextID(0),
+ APIKey: idgen.NewToken(),
+ Name: args.Name,
+ PublicIP: args.PublicIP,
+ Port: args.Port,
+ Relay: args.Relay,
+ PubKey: encPubKey[:],
}
if err := db.Peer_Insert(a.db, peer); err != nil {
@@ -226,17 +219,15 @@ func (a *API) Peer_Create(creationCode string) (*m.PeerConfig, error) {
conf := a.Config_Get()
return &m.PeerConfig{
- PeerIP: peer.PeerIP,
- HubAddress: conf.HubAddress,
- APIKey: peer.APIKey,
- Network: conf.VPNNetwork,
- PublicIP: peer.PublicIP,
- Port: peer.Port,
- Mediator: peer.Mediator,
- EncPubKey: encPubKey[:],
- EncPrivKey: encPrivKey[:],
- SignPubKey: signPubKey[:],
- SignPrivKey: signPrivKey[:],
+ PeerIP: peer.PeerIP,
+ HubAddress: conf.HubAddress,
+ APIKey: peer.APIKey,
+ Network: conf.VPNNetwork,
+ PublicIP: peer.PublicIP,
+ Port: peer.Port,
+ Relay: peer.Relay,
+ PubKey: encPubKey[:],
+ PrivKey: encPrivKey[:],
}, nil
}
diff --git a/hub/api/db/generated.go b/hub/api/db/generated.go
index a23498d..1957b6f 100644
--- a/hub/api/db/generated.go
+++ b/hub/api/db/generated.go
@@ -307,18 +307,17 @@ func Session_List(
// ----------------------------------------------------------------------------
type Peer struct {
- PeerIP byte
- Version int64
- APIKey string
- Name string
- PublicIP []byte
- Port uint16
- Mediator bool
- EncPubKey []byte
- SignPubKey []byte
+ PeerIP byte
+ Version int64
+ APIKey string
+ Name string
+ PublicIP []byte
+ Port uint16
+ Relay bool
+ PubKey []byte
}
-const Peer_SelectQuery = "SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Mediator,EncPubKey,SignPubKey FROM peers"
+const Peer_SelectQuery = "SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey FROM peers"
func Peer_Insert(
tx TX,
@@ -329,7 +328,7 @@ func Peer_Insert(
return err
}
- _, err = tx.Exec("INSERT INTO peers(PeerIP,Version,APIKey,Name,PublicIP,Port,Mediator,EncPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?)", row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Mediator, row.EncPubKey, row.SignPubKey)
+ _, err = tx.Exec("INSERT INTO peers(PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey) VALUES(?,?,?,?,?,?,?,?)", row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey)
return err
}
@@ -342,7 +341,7 @@ func Peer_Update(
return err
}
- result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Mediator=? WHERE PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Mediator, row.PeerIP)
+ result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.PeerIP)
if err != nil {
return err
}
@@ -370,7 +369,7 @@ func Peer_UpdateFull(
return err
}
- result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Mediator=?,EncPubKey=?,SignPubKey=? WHERE PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Mediator, row.EncPubKey, row.SignPubKey, row.PeerIP)
+ result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=? WHERE PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PeerIP)
if err != nil {
return err
}
@@ -420,8 +419,8 @@ func Peer_Get(
err error,
) {
row = &Peer{}
- r := tx.QueryRow("SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Mediator,EncPubKey,SignPubKey FROM peers WHERE PeerIP=?", PeerIP)
- err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Mediator, &row.EncPubKey, &row.SignPubKey)
+ r := tx.QueryRow("SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey FROM peers WHERE PeerIP=?", PeerIP)
+ err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey)
return
}
@@ -435,7 +434,7 @@ func Peer_GetWhere(
) {
row = &Peer{}
r := tx.QueryRow(query, args...)
- err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Mediator, &row.EncPubKey, &row.SignPubKey)
+ err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey)
return
}
@@ -455,7 +454,7 @@ func Peer_Iterate(
defer rows.Close()
for rows.Next() {
row := &Peer{}
- err := rows.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Mediator, &row.EncPubKey, &row.SignPubKey)
+ err := rows.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey)
if !yield(row, err) {
return
}
diff --git a/hub/api/db/sanitize-validate.go b/hub/api/db/sanitize-validate.go
index b4ed8ff..e06ad94 100644
--- a/hub/api/db/sanitize-validate.go
+++ b/hub/api/db/sanitize-validate.go
@@ -51,7 +51,7 @@ func Peer_Sanitize(p *Peer) {
}
}
if p.Port == 0 {
- p.Port = 515
+ p.Port = 456
}
}
diff --git a/hub/api/db/tables.defs b/hub/api/db/tables.defs
index c9e35e2..6df286f 100644
--- a/hub/api/db/tables.defs
+++ b/hub/api/db/tables.defs
@@ -20,7 +20,6 @@ TABLE peers OF Peer (
Name string,
PublicIP []byte,
Port uint16,
- Mediator bool,
- EncPubKey []byte NoUpdate,
- SignPubKey []byte NoUpdate
+ Relay bool,
+ PubKey []byte NoUpdate
);
diff --git a/hub/api/migrations/2024-11-30-init.sql b/hub/api/migrations/2024-11-30-init.sql
index eb5da37..ee37ddc 100644
--- a/hub/api/migrations/2024-11-30-init.sql
+++ b/hub/api/migrations/2024-11-30-init.sql
@@ -22,7 +22,6 @@ CREATE TABLE peers (
Name TEXT NOT NULL UNIQUE, -- For humans.
PublicIP BLOB NOT NULL,
Port INTEGER NOT NULL,
- Mediator INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address.
- EncPubKey BLOB NOT NULL,
- SignPubKey BLOB NOT NULL
+ Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address.
+ PubKey BLOB NOT NULL
) WITHOUT ROWID;
diff --git a/hub/handlers.go b/hub/handlers.go
index f24aaaa..aabf3c7 100644
--- a/hub/handlers.go
+++ b/hub/handlers.go
@@ -4,6 +4,8 @@ import (
"errors"
"log"
"net/http"
+ "net/netip"
+ "strings"
"vppn/hub/api"
"vppn/m"
@@ -155,6 +157,29 @@ func (a *App) _adminPeerList(s *api.Session, w http.ResponseWriter, r *http.Requ
})
}
+func (a *App) _adminHosts(s *api.Session, w http.ResponseWriter, r *http.Request) error {
+ conf := a.api.Config_Get()
+
+ peers, err := a.api.Peer_List()
+ if err != nil {
+ return err
+ }
+
+ b := strings.Builder{}
+
+ for _, peer := range peers {
+ ip := conf.VPNNetwork
+ ip[3] = peer.PeerIP
+ b.WriteString(netip.AddrFrom4([4]byte(ip)).String())
+ b.WriteString(" ")
+ b.WriteString(peer.Name)
+ b.WriteString("\n")
+ }
+
+ w.Write([]byte(b.String()))
+ return nil
+}
+
func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Request) error {
return a.render("/admin-peer-create.html", w, struct{ Session *api.Session }{s})
}
@@ -167,7 +192,7 @@ func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *h
Scan("Name", &args.Name).
Scan("PublicIP", &ipStr).
Scan("Port", &args.Port).
- Scan("Mediator", &args.Mediator).
+ Scan("Relay", &args.Relay).
Error()
if err != nil {
return err
@@ -249,7 +274,7 @@ func (a *App) _adminPeerEditSubmit(s *api.Session, w http.ResponseWriter, r *htt
Scan("Name", &peer.Name).
Scan("PublicIP", &ipStr).
Scan("Port", &peer.Port).
- Scan("Mediator", &peer.Mediator).
+ Scan("Relay", &peer.Relay).
Error()
if err != nil {
return err
@@ -311,19 +336,16 @@ func (a *App) _peerCreate(w http.ResponseWriter, r *http.Request) error {
func (a *App) _peerFetchState(w http.ResponseWriter, r *http.Request) error {
_, apiKey, ok := r.BasicAuth()
if !ok {
- log.Printf("1")
return api.ErrNotAuthorized
}
peer, err := a.api.Peer_GetByAPIKey(apiKey)
if err != nil {
- log.Printf("2")
return err
}
peers, err := a.api.Peer_List()
if err != nil {
- log.Printf("3")
return err
}
@@ -339,14 +361,13 @@ func (a *App) _peerFetchState(w http.ResponseWriter, r *http.Request) error {
for _, p := range peers {
state.Peers[p.PeerIP] = &m.Peer{
- PeerIP: p.PeerIP,
- Version: p.Version,
- Name: p.Name,
- PublicIP: p.PublicIP,
- Port: p.Port,
- Mediator: p.Mediator,
- EncPubKey: p.EncPubKey,
- SignPubKey: p.SignPubKey,
+ PeerIP: p.PeerIP,
+ Version: p.Version,
+ Name: p.Name,
+ PublicIP: p.PublicIP,
+ Port: p.Port,
+ Relay: p.Relay,
+ PubKey: p.PubKey,
}
}
diff --git a/hub/routes.go b/hub/routes.go
index 0fa47f2..a29736f 100644
--- a/hub/routes.go
+++ b/hub/routes.go
@@ -17,6 +17,7 @@ func (a *App) registerRoutes() {
a.handleSignedIn("GET /admin/password/edit/", a._adminPasswordEdit)
a.handleSignedIn("POST /admin/password/edit/", a._adminPasswordSubmit)
a.handleSignedIn("GET /admin/peer/list/", a._adminPeerList)
+ a.handleSignedIn("GET /admin/peer/hosts/", a._adminHosts)
a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate)
a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit)
a.handleSignedIn("GET /admin/peer/intent-created/", a._adminPeerIntentCreated)
diff --git a/hub/templates/admin-peer-create.html b/hub/templates/admin-peer-create.html
index f2f0c39..8225fc8 100644
--- a/hub/templates/admin-peer-create.html
+++ b/hub/templates/admin-peer-create.html
@@ -13,12 +13,12 @@
-
+
diff --git a/hub/templates/admin-peer-delete.html b/hub/templates/admin-peer-delete.html
index a330eb8..9290f68 100644
--- a/hub/templates/admin-peer-delete.html
+++ b/hub/templates/admin-peer-delete.html
@@ -22,8 +22,8 @@
diff --git a/hub/templates/admin-peer-edit.html b/hub/templates/admin-peer-edit.html
index c6081b1..da40de8 100644
--- a/hub/templates/admin-peer-edit.html
+++ b/hub/templates/admin-peer-edit.html
@@ -22,8 +22,8 @@
diff --git a/hub/templates/admin-peer-list.html b/hub/templates/admin-peer-list.html
index 4acadc7..cb7c72c 100644
--- a/hub/templates/admin-peer-list.html
+++ b/hub/templates/admin-peer-list.html
@@ -2,7 +2,8 @@
Peers
- Add Peer
+ Add Peer /
+ Hosts
{{if .Peers -}}
@@ -13,7 +14,7 @@
Name |
Public IP |
Port |
- Mediator |
+ Relay |
@@ -27,7 +28,7 @@
{{.Name}} |
{{ipToString .PublicIP}} |
{{.Port}} |
- {{if .Mediator}}T{{else}}F{{end}} |
+ {{if .Relay}}T{{else}}F{{end}} |
{{- end}}
diff --git a/hub/templates/admin-peer-view.html b/hub/templates/admin-peer-view.html
index 89ff754..e8d6f6e 100644
--- a/hub/templates/admin-peer-view.html
+++ b/hub/templates/admin-peer-view.html
@@ -12,7 +12,7 @@
Name | {{.Name}} |
Public IP | {{ipToString .PublicIP}} |
Port | {{.Port}} |
- Mediator | {{if .Mediator}}T{{else}}F{{end}} |
+ Relay | {{if .Relay}}T{{else}}F{{end}} |
API Key | {{.APIKey}} |
{{- end}}
diff --git a/m/models.go b/m/models.go
index 29c39f9..345bf5d 100644
--- a/m/models.go
+++ b/m/models.go
@@ -2,28 +2,25 @@
package m
type PeerConfig struct {
- PeerIP byte
- HubAddress string
- Network []byte
- APIKey string
- PublicIP []byte
- Port uint16
- Mediator bool
- EncPubKey []byte
- EncPrivKey []byte
- SignPubKey []byte
- SignPrivKey []byte
+ PeerIP byte
+ HubAddress string
+ Network []byte
+ APIKey string
+ PublicIP []byte
+ Port uint16
+ Relay bool
+ PubKey []byte
+ PrivKey []byte
}
type Peer struct {
- PeerIP byte
- Version int64
- Name string
- PublicIP []byte
- Port uint16
- Mediator bool
- EncPubKey []byte
- SignPubKey []byte
+ PeerIP byte
+ Version int64
+ Name string
+ PublicIP []byte
+ Port uint16
+ Relay bool
+ PubKey []byte
}
type NetworkState struct {
diff --git a/node/main.go b/node/main.go
index d9d865a..419f644 100644
--- a/node/main.go
+++ b/node/main.go
@@ -106,7 +106,7 @@ func main(listenIP string, port uint16) {
// Intialize globals.
localIP = config.PeerIP
localPub = addrIsValid(config.PublicIP)
- privateKey = config.EncPrivKey
+ privateKey = config.PrivKey
_iface = newIFWriter(iface)
_conn = newConnWriter(conn)
diff --git a/node/packets.go b/node/packets.go
index 04db2a9..f6d92e1 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -14,8 +14,6 @@ const (
packetTypeSyn = iota + 1
packetTypeSynAck
packetTypeAck
- packetTypeAddrReq
- packetTypeAddrResp
)
// ----------------------------------------------------------------------------
@@ -34,10 +32,6 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
p.Payload, err = parseSynAckPacket(buf)
case packetTypeAck:
p.Payload, err = parseAckPacket(buf)
- case packetTypeAddrReq:
- p.Payload, err = parseAddrReqPacket(buf)
- case packetTypeAddrResp:
- p.Payload, err = parseAddrRespPacket(buf)
default:
return errUnknownPacketType
}
@@ -118,46 +112,3 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) {
Error()
return
}
-
-// ----------------------------------------------------------------------------
-
-type addrReqPacket struct {
- TraceID uint64
-}
-
-func (p addrReqPacket) Marshal(buf []byte) []byte {
- return newBinWriter(buf).
- Byte(packetTypeAddrReq).
- Uint64(p.TraceID).
- Build()
-}
-
-func parseAddrReqPacket(buf []byte) (p addrReqPacket, err error) {
- err = newBinReader(buf[1:]).
- Uint64(&p.TraceID).
- Error()
- return
-}
-
-// ----------------------------------------------------------------------------
-
-type addrRespPacket struct {
- TraceID uint64
- Addr netip.AddrPort
-}
-
-func (p addrRespPacket) Marshal(buf []byte) []byte {
- return newBinWriter(buf).
- Byte(packetTypeAddrResp).
- Uint64(p.TraceID).
- AddrPort(p.Addr).
- Build()
-}
-
-func parseAddrRespPacket(buf []byte) (p addrRespPacket, err error) {
- err = newBinReader(buf[1:]).
- Uint64(&p.TraceID).
- AddrPort(&p.Addr).
- Error()
- return
-}
diff --git a/node/packets_test.go b/node/packets_test.go
index 660d30e..bd83080 100644
--- a/node/packets_test.go
+++ b/node/packets_test.go
@@ -2,6 +2,7 @@ package node
import (
"crypto/rand"
+ "net/netip"
"reflect"
"testing"
)
@@ -24,7 +25,8 @@ func TestPacketSyn(t *testing.T) {
func TestPacketSynAck(t *testing.T) {
in := synAckPacket{
- TraceID: newTraceID(),
+ TraceID: newTraceID(),
+ RecvAddr: netip.AddrPort{},
}
out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize)))
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index 6741f48..e47d0ae 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -102,12 +102,12 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
}
s.staged.IP = s.remoteIP
- s.staged.ControlCipher = newControlCipher(privateKey, peer.EncPubKey)
+ s.staged.ControlCipher = newControlCipher(privateKey, peer.PubKey)
s.staged.DataCipher = newDataCipher()
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true
- s.staged.Relay = peer.Mediator
+ s.staged.Relay = peer.Relay
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
}
--
2.39.5
From f8a0df0263204bba1ca5fb31078e894e3a6d7cf4 Mon Sep 17 00:00:00 2001
From: jdl
Date: Mon, 23 Dec 2024 20:28:49 +0100
Subject: [PATCH 16/18] wip: working - moving on to single relay w/ address
discovery
---
node/globals.go | 12 ++-
node/hubpoller.go | 1 -
node/main.go | 14 ++-
node/packets.go | 49 ++++++----
node/peer-supervisor.go | 203 ++++++++++++++++++----------------------
5 files changed, 142 insertions(+), 137 deletions(-)
diff --git a/node/globals.go b/node/globals.go
index f782cb5..25eee33 100644
--- a/node/globals.go
+++ b/node/globals.go
@@ -23,8 +23,10 @@ type peerRoute struct {
ControlCipher *controlCipher
DataCipher *dataCipher
RemoteAddr netip.AddrPort // Remote address if directly connected.
- LocalAddr netip.AddrPort // Local address as seen by the remote.
- RelayIP byte // Non-zero if we should relay.
+ // TODO: Remove this and use global localAddr and relayIP.
+ // Replace w/ a Direct boolean.
+ LocalAddr netip.AddrPort // Local address as seen by the remote.
+ RelayIP byte // Non-zero if we should relay.
}
var (
@@ -32,6 +34,7 @@ var (
netName string
localIP byte
localPub bool
+ localAddr netip.AddrPort
privateKey []byte
// Shared interface for writing.
@@ -54,4 +57,9 @@ var (
// Global routing table.
routingTable [256]*atomic.Pointer[peerRoute]
+
+ // TODO: use relay for local address discovery. This should be new stream ID,
+ // managed by a single thread.
+ // localAddr *atomic.Pointer[netip.AddrPort]
+ // relayIP *atomic.Pointer[byte]
)
diff --git a/node/hubpoller.go b/node/hubpoller.go
index ef36431..ac6b110 100644
--- a/node/hubpoller.go
+++ b/node/hubpoller.go
@@ -58,7 +58,6 @@ func (hp *hubPoller) Run() {
func (hp *hubPoller) pollHub() {
var state m.NetworkState
- log.Printf("Fetching peer state...")
resp, err := hp.client.Do(hp.req)
if err != nil {
log.Printf("Failed to fetch peer state: %v", err)
diff --git a/node/main.go b/node/main.go
index 419f644..70857c3 100644
--- a/node/main.go
+++ b/node/main.go
@@ -105,7 +105,13 @@ func main(listenIP string, port uint16) {
// Intialize globals.
localIP = config.PeerIP
- localPub = addrIsValid(config.PublicIP)
+
+ ip, ok := netip.AddrFromSlice(config.PublicIP)
+ if ok {
+ localPub = true
+ localAddr = netip.AddrPortFrom(ip, config.Port)
+ }
+
privateKey = config.PrivKey
_iface = newIFWriter(iface)
@@ -178,6 +184,8 @@ func readFromConn(conn *net.UDPConn) {
case controlStreamID:
handleControlPacket(remoteAddr, h, data, decBuf)
+ // TODO: discoveryStreamID
+
case dataStreamID:
handleDataPacket(h, data, decBuf)
@@ -216,8 +224,8 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
}
pkt := controlPacket{
- SrcIP: h.SourceIP,
- RemoteAddr: addr,
+ SrcIP: h.SourceIP,
+ SrcAddr: addr,
}
if err := pkt.ParsePayload(out); err != nil {
diff --git a/node/packets.go b/node/packets.go
index f6d92e1..f0ea736 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -2,6 +2,7 @@ package node
import (
"errors"
+ "log"
"net/netip"
)
@@ -14,14 +15,15 @@ const (
packetTypeSyn = iota + 1
packetTypeSynAck
packetTypeAck
+ packetTypeProbe
)
// ----------------------------------------------------------------------------
type controlPacket struct {
- SrcIP byte
- RemoteAddr netip.AddrPort
- Payload any
+ SrcIP byte
+ SrcAddr netip.AddrPort
+ Payload any
}
func (p *controlPacket) ParsePayload(buf []byte) (err error) {
@@ -30,8 +32,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
p.Payload, err = parseSynPacket(buf)
case packetTypeSynAck:
p.Payload, err = parseSynAckPacket(buf)
- case packetTypeAck:
- p.Payload, err = parseAckPacket(buf)
+ case packetTypeProbe:
+ log.Printf("Got probe...")
+ p.Payload, err = parseProbePacket(buf)
default:
return errUnknownPacketType
}
@@ -44,6 +47,7 @@ type synPacket struct {
TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key.
RelayIP byte
+ FromAddr netip.AddrPort // The client's sending address.
}
func (p synPacket) Marshal(buf []byte) []byte {
@@ -52,6 +56,7 @@ func (p synPacket) Marshal(buf []byte) []byte {
Uint64(p.TraceID).
SharedKey(p.SharedKey).
Byte(p.RelayIP).
+ AddrPort(p.FromAddr).
Build()
}
@@ -60,6 +65,7 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
Byte(&p.RelayIP).
+ AddrPort(&p.FromAddr).
Error()
return
}
@@ -68,47 +74,54 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
type synAckPacket struct {
TraceID uint64
- RecvAddr netip.AddrPort
+ FromAddr netip.AddrPort
+ ToAddr netip.AddrPort
}
func (p synAckPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSynAck).
Uint64(p.TraceID).
- AddrPort(p.RecvAddr).
+ AddrPort(p.FromAddr).
+ AddrPort(p.ToAddr).
Build()
}
func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
- AddrPort(&p.RecvAddr).
+ AddrPort(&p.FromAddr).
+ AddrPort(&p.ToAddr).
Error()
return
}
// ----------------------------------------------------------------------------
-type ackPacket struct {
+type addrDiscoveryPacket struct {
TraceID uint64
- SendAddr netip.AddrPort // Address of the sender.
- RecvAddr netip.AddrPort // Address of the recipient as seen by sender.
+ FromAddr netip.AddrPort
+ ToAddr netip.AddrPort
}
-func (p ackPacket) Marshal(buf []byte) []byte {
+// ----------------------------------------------------------------------------
+
+// A probeReqPacket is sent from a client to a server to determine if direct
+// UDP communication can be used.
+type probePacket struct {
+ TraceID uint64
+}
+
+func (p probePacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
- Byte(packetTypeAck).
+ Byte(packetTypeProbe).
Uint64(p.TraceID).
- AddrPort(p.SendAddr).
- AddrPort(p.RecvAddr).
Build()
}
-func parseAckPacket(buf []byte) (p ackPacket, err error) {
+func parseProbePacket(buf []byte) (p probePacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
- AddrPort(&p.SendAddr).
- AddrPort(&p.RecvAddr).
Error()
return
}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index e47d0ae..e4f056e 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -66,6 +66,36 @@ func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
}
+func (s *peerSupervisor) sendControlPacketTo(
+ pkt interface{ Marshal([]byte) []byte },
+ addr netip.AddrPort,
+) {
+ if !addr.IsValid() {
+ s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
+ return
+ }
+ route := s.staged
+ route.RelayIP = 0
+ route.RemoteAddr = addr
+ _sendControlPacket(pkt, route, s.buf1, s.buf2)
+}
+
+// ----------------------------------------------------------------------------
+
+func (s *peerSupervisor) getLocalAddr() netip.AddrPort {
+ if localPub {
+ return localAddr
+ }
+
+ if s.staged.RelayIP != 0 {
+ if addr := routingTable[s.staged.RelayIP].Load().LocalAddr; addr.IsValid() {
+ return addr
+ }
+ }
+
+ return s.staged.LocalAddr
+}
+
// ----------------------------------------------------------------------------
func (s *peerSupervisor) logf(msg string, args ...any) {
@@ -113,7 +143,7 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
if s.remotePub == localPub {
if localIP < s.remoteIP {
- return s.serverAccept
+ return s.server
}
return s.clientInit
}
@@ -121,18 +151,13 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
if s.remotePub {
return s.clientInit
}
- return s.serverAccept
+ return s.server
}
// ----------------------------------------------------------------------------
-func (s *peerSupervisor) serverAccept() stateFunc {
- s.logf("STATE: server-accept")
- s.staged.Up = false
- s.staged.DataCipher = nil
- s.staged.RemoteAddr = zeroAddrPort
- s.staged.RelayIP = 0
- s.publish()
+func (s *peerSupervisor) server() stateFunc {
+ s.logf("STATE: server")
var syn synPacket
@@ -145,60 +170,37 @@ func (s *peerSupervisor) serverAccept() stateFunc {
switch p := pkt.Payload.(type) {
case synPacket:
- syn = p
- s.staged.RemoteAddr = pkt.RemoteAddr
- s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
- s.staged.RelayIP = syn.RelayIP
- s.publish()
- s.sendControlPacket(synAckPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
-
- case ackPacket:
+ // Before we can respond to this packet, we need to make sure the
+ // route is setup properly.
if p.TraceID != syn.TraceID {
- continue
+ syn = p
+ s.staged.Up = true
+ s.staged.RemoteAddr = pkt.SrcAddr
+ s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
+ s.staged.RelayIP = syn.RelayIP
+ s.staged.LocalAddr = s.getLocalAddr()
+ s.publish()
}
- // Publish.
- return s.serverConnected(syn.TraceID)
- }
- }
- }
-}
+ // We should always respond.
+ s.sendControlPacket(synAckPacket{
+ TraceID: syn.TraceID,
+ FromAddr: s.staged.LocalAddr,
+ ToAddr: pkt.SrcAddr,
+ })
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) serverConnected(traceID uint64) stateFunc {
- s.logf("STATE: server-connected")
- s.staged.Up = true
- s.publish()
- return func() stateFunc {
- return s._serverConnected(traceID)
- }
-}
-
-func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc {
-
- timeoutTimer := time.NewTimer(timeoutInterval)
- defer timeoutTimer.Stop()
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case pkt := <-s.controlPackets:
- switch p := pkt.Payload.(type) {
-
- case ackPacket:
- if p.TraceID != traceID {
- return s.serverAccept
+ // If we're relayed, attempt to probe the client.
+ if s.staged.RelayIP != 0 && syn.FromAddr.IsValid() {
+ probe := probePacket{TraceID: newTraceID()}
+ s.logf("SERVER sending probe %v: %v", probe, syn.FromAddr)
+ s.sendControlPacketTo(probe, syn.FromAddr)
}
- s.sendControlPacket(ackPacket{TraceID: traceID, RecvAddr: pkt.RemoteAddr})
- timeoutTimer.Reset(timeoutInterval)
- }
- case <-timeoutTimer.C:
- s.logf("server timeout")
- return s.serverAccept
+ case probePacket:
+ s.logf("SERVER got probe: %v", p)
+ s.logf("SERVER sending probe: %v", pkt.SrcAddr)
+ s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr)
+ }
}
}
}
@@ -208,13 +210,10 @@ func (s *peerSupervisor) _serverConnected(traceID uint64) stateFunc {
func (s *peerSupervisor) clientInit() stateFunc {
s.logf("STATE: client-init")
if !s.remotePub {
- // TODO: Check local discovery for IP.
- // TODO: Attempt UDP hole punch.
- // TODO: client-relayed
return s.clientSelectRelay
}
- return s.clientDial
+ return s.client
}
// ----------------------------------------------------------------------------
@@ -237,7 +236,7 @@ func (s *peerSupervisor) clientSelectRelay() stateFunc {
s.staged.RelayIP = relay.IP
s.staged.LocalAddr = relay.LocalAddr
s.publish()
- return s.clientDial
+ return s.client
}
s.logf("No relay available.")
@@ -264,20 +263,26 @@ func (s *peerSupervisor) selectRelay() *peerRoute {
// ----------------------------------------------------------------------------
-func (s *peerSupervisor) clientDial() stateFunc {
- s.logf("STATE: client-dial")
+func (s *peerSupervisor) client() stateFunc {
+ s.logf("STATE: client")
var (
syn = synPacket{
TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(),
RelayIP: s.staged.RelayIP,
+ FromAddr: s.getLocalAddr(),
}
+ ack synAckPacket
- timeout = time.NewTimer(dialTimeout)
+ probe = probePacket{TraceID: newTraceID()}
+
+ timeoutTimer = time.NewTimer(timeoutInterval)
+ pingTimer = time.NewTimer(pingInterval)
)
- defer timeout.Stop()
+ defer timeoutTimer.Stop()
+ defer pingTimer.Stop()
s.sendControlPacket(syn)
@@ -289,64 +294,36 @@ func (s *peerSupervisor) clientDial() stateFunc {
case pkt := <-s.controlPackets:
switch p := pkt.Payload.(type) {
+
case synAckPacket:
if p.TraceID != syn.TraceID {
+ s.logf("Bad traceID?")
continue // Hmm...
}
- s.sendControlPacket(ackPacket{TraceID: syn.TraceID, RecvAddr: pkt.RemoteAddr})
- return s.clientConnected(p)
- }
-
- case <-timeout.C:
- return s.clientInit
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) clientConnected(p synAckPacket) stateFunc {
- s.logf("STATE: client-connected")
- s.staged.Up = true
- s.staged.LocalAddr = p.RecvAddr
- s.publish()
-
- return func() stateFunc {
- return s._clientConnected(p.TraceID)
- }
-}
-
-func (s *peerSupervisor) _clientConnected(traceID uint64) stateFunc {
-
- pingTimer := time.NewTimer(pingInterval)
- timeoutTimer := time.NewTimer(timeoutInterval)
-
- defer pingTimer.Stop()
- defer timeoutTimer.Stop()
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case pkt := <-s.controlPackets:
- switch p := pkt.Payload.(type) {
-
- case ackPacket:
- if p.TraceID != traceID {
- return s.clientInit
- }
+ ack = p
timeoutTimer.Reset(timeoutInterval)
+
+ if !s.staged.Up {
+ s.staged.Up = true
+ s.staged.LocalAddr = p.ToAddr
+ s.publish()
+ }
+
+ case probePacket:
+ s.logf("CLIENT got probe: %v", p)
}
case <-pingTimer.C:
- s.sendControlPacket(ackPacket{TraceID: traceID})
+ s.sendControlPacket(syn)
pingTimer.Reset(pingInterval)
- case <-timeoutTimer.C:
- s.logf("client timeout")
- return s.clientInit
+ if s.staged.RelayIP != 0 && ack.FromAddr.IsValid() {
+ s.logf("CLIENT sending probe %v: %v", probe, ack.FromAddr)
+ s.sendControlPacketTo(probe, ack.FromAddr)
+ }
+ case <-timeoutTimer.C:
+ return s.clientInit
}
}
}
--
2.39.5
From 640f4b998605abfd01db32509097b08f6bf8465b Mon Sep 17 00:00:00 2001
From: jdl
Date: Tue, 24 Dec 2024 19:34:16 +0100
Subject: [PATCH 17/18] Cleanup. Direct, relayed, and hole-punching is working.
---
fasttime/time.go | 20 ----
fasttime/time_test.go | 18 ----
hub/api/db/written.go | 4 +-
node/addrdiscovery.go | 71 +++++++++++++
node/addrutil.go | 8 --
node/globalfuncs.go | 31 ++++--
node/globals.go | 53 +++++++---
node/header.go | 2 +-
node/main.go | 69 +++++++------
node/packets-util.go | 12 +--
node/packets.go | 35 +++++--
node/packets_test.go | 21 +---
node/peer-supervisor.go | 222 ++++++++++++++++++++--------------------
node/relaymanager.go | 40 ++++++++
14 files changed, 357 insertions(+), 249 deletions(-)
delete mode 100644 fasttime/time.go
delete mode 100644 fasttime/time_test.go
create mode 100644 node/addrdiscovery.go
delete mode 100644 node/addrutil.go
create mode 100644 node/relaymanager.go
diff --git a/fasttime/time.go b/fasttime/time.go
deleted file mode 100644
index 5c569ac..0000000
--- a/fasttime/time.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package fasttime
-
-import (
- "sync/atomic"
- "time"
-)
-
-var _timestamp int64 = time.Now().Unix()
-
-func init() {
- go func() {
- for range time.Tick(1100 * time.Millisecond) {
- atomic.StoreInt64(&_timestamp, time.Now().Unix())
- }
- }()
-}
-
-func Now() int64 {
- return atomic.LoadInt64(&_timestamp)
-}
diff --git a/fasttime/time_test.go b/fasttime/time_test.go
deleted file mode 100644
index b0a85d0..0000000
--- a/fasttime/time_test.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package fasttime
-
-import (
- "testing"
- "time"
-)
-
-func BenchmarkNow(b *testing.B) {
- for i := 0; i < b.N; i++ {
- Now()
- }
-}
-
-func BenchmarkTimeUnix(b *testing.B) {
- for i := 0; i < b.N; i++ {
- time.Now().Unix()
- }
-}
diff --git a/hub/api/db/written.go b/hub/api/db/written.go
index 65769c4..5b8bb15 100644
--- a/hub/api/db/written.go
+++ b/hub/api/db/written.go
@@ -1,12 +1,12 @@
package db
-import "vppn/fasttime"
+import "time"
func Session_UpdateLastSeenAt(
tx TX,
id string,
) (err error) {
- _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", fasttime.Now(), id)
+ _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id)
return err
}
diff --git a/node/addrdiscovery.go b/node/addrdiscovery.go
new file mode 100644
index 0000000..b62e13f
--- /dev/null
+++ b/node/addrdiscovery.go
@@ -0,0 +1,71 @@
+package node
+
+import (
+ "log"
+ "net/netip"
+ "time"
+)
+
+func addrDiscoveryServer() {
+ var (
+ buf1 = make([]byte, bufferSize)
+ buf2 = make([]byte, bufferSize)
+ )
+
+ for {
+ pkt := <-discoveryPackets
+
+ p, ok := pkt.Payload.(addrDiscoveryPacket)
+ if !ok {
+ continue
+ }
+
+ route := routingTable[pkt.SrcIP].Load()
+ if route == nil || !route.RemoteAddr.IsValid() {
+ continue
+ }
+
+ _sendControlPacket(addrDiscoveryPacket{
+ TraceID: p.TraceID,
+ ToAddr: pkt.SrcAddr,
+ }, *route, buf1, buf2)
+ }
+}
+
+func addrDiscoveryClient() {
+ var (
+ checkInterval = 8 * time.Second
+ timer = time.NewTimer(4 * time.Second)
+
+ buf1 = make([]byte, bufferSize)
+ buf2 = make([]byte, bufferSize)
+
+ addrPacket addrDiscoveryPacket
+ lAddr netip.AddrPort
+ )
+
+ for {
+ select {
+ case pkt := <-discoveryPackets:
+ p, ok := pkt.Payload.(addrDiscoveryPacket)
+ if !ok || p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr {
+ continue
+ }
+
+ log.Printf("Discovered local address: %v", p.ToAddr)
+ lAddr = p.ToAddr
+ localAddr.Store(&p.ToAddr)
+
+ case <-timer.C:
+ timer.Reset(checkInterval)
+
+ route := getRelayRoute()
+ if route == nil {
+ continue
+ }
+
+ addrPacket.TraceID = newTraceID()
+ _sendControlPacket(addrPacket, *route, buf1, buf2)
+ }
+ }
+}
diff --git a/node/addrutil.go b/node/addrutil.go
deleted file mode 100644
index 590c80c..0000000
--- a/node/addrutil.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package node
-
-import "net/netip"
-
-func addrIsValid(in []byte) bool {
- _, ok := netip.AddrFromSlice(in)
- return ok
-}
diff --git a/node/globalfuncs.go b/node/globalfuncs.go
index 406588e..98975da 100644
--- a/node/globalfuncs.go
+++ b/node/globalfuncs.go
@@ -1,10 +1,24 @@
package node
import (
- "log"
+ "net/netip"
"sync/atomic"
)
+func getRelayRoute() *peerRoute {
+ if ip := relayIP.Load(); ip != nil {
+ return routingTable[*ip].Load()
+ }
+ return nil
+}
+
+func getLocalAddr() netip.AddrPort {
+ if a := localAddr.Load(); a != nil {
+ return *a
+ }
+ return netip.AddrPort{}
+}
+
func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) {
buf := pkt.Marshal(buf2)
h := header{
@@ -15,12 +29,12 @@ func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute
}
buf = route.ControlCipher.Encrypt(h, buf, buf1)
- if route.RelayIP == 0 {
+ if route.Direct {
_conn.WriteTo(buf, route.RemoteAddr)
return
}
- _relayPacket(route.RelayIP, route.IP, buf, buf2)
+ _relayPacket(route.IP, buf, buf2)
}
func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) {
@@ -33,18 +47,17 @@ func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) {
enc := route.DataCipher.Encrypt(h, pkt, buf1)
- if route.RelayIP == 0 {
+ if route.Direct {
_conn.WriteTo(enc, route.RemoteAddr)
return
}
- _relayPacket(route.RelayIP, route.IP, enc, buf2)
+ _relayPacket(route.IP, enc, buf2)
}
-func _relayPacket(relayIP, destIP byte, data, buf []byte) {
- relayRoute := routingTable[relayIP].Load()
- if !relayRoute.Up || !relayRoute.Relay {
- log.Print("Failed to send data packet: relay not available.")
+func _relayPacket(destIP byte, data, buf []byte) {
+ relayRoute := getRelayRoute()
+ if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay {
return
}
diff --git a/node/globals.go b/node/globals.go
index 25eee33..3b8edea 100644
--- a/node/globals.go
+++ b/node/globals.go
@@ -3,11 +3,10 @@ package node
import (
"net/netip"
"sync/atomic"
+ "time"
"vppn/m"
)
-var zeroAddrPort = netip.AddrPort{}
-
const (
bufferSize = 1536
if_mtu = 1200
@@ -20,13 +19,10 @@ type peerRoute struct {
IP byte
Up bool // True if data can be sent on the route.
Relay bool // True if the peer is a relay.
+ Direct bool // True if this is a direct connection.
ControlCipher *controlCipher
DataCipher *dataCipher
RemoteAddr netip.AddrPort // Remote address if directly connected.
- // TODO: Remove this and use global localAddr and relayIP.
- // Replace w/ a Direct boolean.
- LocalAddr netip.AddrPort // Local address as seen by the remote.
- RelayIP byte // Non-zero if we should relay.
}
var (
@@ -34,7 +30,6 @@ var (
netName string
localIP byte
localPub bool
- localAddr netip.AddrPort
privateKey []byte
// Shared interface for writing.
@@ -44,22 +39,48 @@ var (
_conn *connWriter
// Counters for sending to each peer.
- sendCounters [256]uint64
+ sendCounters [256]uint64 = func() (out [256]uint64) {
+ for i := range out {
+ out[i] = uint64(time.Now().Unix()<<30 + 1)
+ }
+ return
+ }()
// Duplicate checkers for incoming packets.
- dupChecks [256]*dupCheck
+ dupChecks [256]*dupCheck = func() (out [256]*dupCheck) {
+ for i := range out {
+ out[i] = newDupCheck(0)
+ }
+ return
+ }()
// Channels for incoming control packets.
- controlPackets [256]chan controlPacket
+ controlPackets [256]chan controlPacket = func() (out [256]chan controlPacket) {
+ for i := range out {
+ out[i] = make(chan controlPacket, 256)
+ }
+ return
+ }()
// Channels for incoming peer updates from the hub.
- peerUpdates [256]chan *m.Peer
+ peerUpdates [256]chan *m.Peer = func() (out [256]chan *m.Peer) {
+ for i := range out {
+ out[i] = make(chan *m.Peer)
+ }
+ return
+ }()
// Global routing table.
- routingTable [256]*atomic.Pointer[peerRoute]
+ routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) {
+ for i := range out {
+ out[i] = &atomic.Pointer[peerRoute]{}
+ out[i].Store(&peerRoute{})
+ }
+ return
+ }()
- // TODO: use relay for local address discovery. This should be new stream ID,
- // managed by a single thread.
- // localAddr *atomic.Pointer[netip.AddrPort]
- // relayIP *atomic.Pointer[byte]
+ // Managed by the relayManager.
+ discoveryPackets chan controlPacket
+ localAddr *atomic.Pointer[netip.AddrPort] // May be nil.
+ relayIP *atomic.Pointer[byte] // May be nil.
)
diff --git a/node/header.go b/node/header.go
index fd28962..58ba852 100644
--- a/node/header.go
+++ b/node/header.go
@@ -14,7 +14,7 @@ const (
type header struct {
StreamID byte
- Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
+ Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
SourceIP byte
DestIP byte
}
diff --git a/node/main.go b/node/main.go
index 70857c3..ee2e7a7 100644
--- a/node/main.go
+++ b/node/main.go
@@ -12,7 +12,6 @@ import (
"os"
"runtime/debug"
"sync/atomic"
- "time"
"vppn/m"
)
@@ -104,36 +103,34 @@ func main(listenIP string, port uint16) {
}
// Intialize globals.
+ _iface = newIFWriter(iface)
+ _conn = newConnWriter(conn)
+
localIP = config.PeerIP
+ discoveryPackets = make(chan controlPacket, 256)
+ localAddr = &atomic.Pointer[netip.AddrPort]{}
+ relayIP = &atomic.Pointer[byte]{}
ip, ok := netip.AddrFromSlice(config.PublicIP)
if ok {
localPub = true
- localAddr = netip.AddrPortFrom(ip, config.Port)
+ addr := netip.AddrPortFrom(ip, config.Port)
+ localAddr.Store(&addr)
}
privateKey = config.PrivKey
- _iface = newIFWriter(iface)
- _conn = newConnWriter(conn)
-
- for i := range 256 {
- sendCounters[i] = uint64(time.Now().Unix()<<30) + 1
- dupChecks[i] = newDupCheck(0)
- controlPackets[i] = make(chan controlPacket, 256)
- peerUpdates[i] = make(chan *m.Peer)
- routingTable[i] = &atomic.Pointer[peerRoute]{}
- route := peerRoute{IP: byte(i)}
- routingTable[i].Store(&route)
- }
-
// Start supervisors.
for i := range 256 {
go newPeerSupervisor(i).Run()
}
- // --------------------
-
+ if localPub {
+ go addrDiscoveryServer()
+ } else {
+ go addrDiscoveryClient()
+ go relayManager()
+ }
go newHubPoller(config).Run()
go readFromConn(conn)
readFromIFace(iface)
@@ -173,6 +170,8 @@ func readFromConn(conn *net.UDPConn) {
log.Fatalf("Failed to read from UDP port: %v", err)
}
+ remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
+
data = buf[:n]
if n < headerSize {
@@ -184,8 +183,6 @@ func readFromConn(conn *net.UDPConn) {
case controlStreamID:
handleControlPacket(remoteAddr, h, data, decBuf)
- // TODO: discoveryStreamID
-
case dataStreamID:
handleDataPacket(h, data, decBuf)
@@ -198,7 +195,7 @@ func readFromConn(conn *net.UDPConn) {
func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
route := routingTable[h.SourceIP].Load()
if route.ControlCipher == nil {
- log.Printf("Not connected (control).")
+ //log.Printf("Not connected (control).")
return
}
@@ -209,17 +206,17 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
out, ok := route.ControlCipher.Decrypt(data, decBuf)
if !ok {
- log.Printf("Failed to decrypt control packet.")
+ //log.Printf("Failed to decrypt control packet.")
return
}
if len(out) == 0 {
- log.Printf("Empty control packet from: %d", h.SourceIP)
+ //log.Printf("Empty control packet from: %d", h.SourceIP)
return
}
if dupChecks[h.SourceIP].IsDup(h.Counter) {
- log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter)
+ //log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter)
return
}
@@ -233,17 +230,29 @@ func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
return
}
- select {
- case controlPackets[h.SourceIP] <- pkt:
+ switch pkt.Payload.(type) {
+
+ case addrDiscoveryPacket:
+ select {
+ case discoveryPackets <- pkt:
+ default:
+ log.Printf("Dropping discovery packet.")
+ }
+
default:
- log.Printf("Dropping control packet.")
+ select {
+ case controlPackets[h.SourceIP] <- pkt:
+ default:
+ log.Printf("Dropping control packet.")
+ }
}
+
}
func handleDataPacket(h header, data []byte, decBuf []byte) {
route := routingTable[h.SourceIP].Load()
if !route.Up {
- log.Printf("Not connected (recv).")
+ //log.Printf("Not connected (recv).")
return
}
@@ -254,7 +263,7 @@ func handleDataPacket(h header, data []byte, decBuf []byte) {
}
if dupChecks[h.SourceIP].IsDup(h.Counter) {
- log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter)
+ //log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter)
return
}
@@ -264,8 +273,8 @@ func handleDataPacket(h header, data []byte, decBuf []byte) {
}
destRoute := routingTable[h.DestIP].Load()
- if !destRoute.Up || destRoute.RelayIP != 0 {
- log.Printf("Not connected (relay)")
+ if !destRoute.Up {
+ log.Printf("Not connected (relay): %v", destRoute)
return
}
diff --git a/node/packets-util.go b/node/packets-util.go
index 8a6e13a..af10eb5 100644
--- a/node/packets-util.go
+++ b/node/packets-util.go
@@ -3,16 +3,14 @@ package node
import (
"net/netip"
"sync/atomic"
+ "time"
"unsafe"
- "vppn/fasttime"
)
-var (
- traceIDCounter uint64
-)
+var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1
func newTraceID() uint64 {
- return uint64(fasttime.Now()<<30) + atomic.AddUint64(&traceIDCounter, 1)
+ return atomic.AddUint64(&traceIDCounter, 1)
}
// ----------------------------------------------------------------------------
@@ -151,9 +149,9 @@ func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
if !r.hasBytes(18) {
return r
}
- addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16]))
- addr = addr.Unmap()
+ addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
r.i += 16
+
var port uint16
r.Uint16(&port)
*x = netip.AddrPortFrom(addr, port)
diff --git a/node/packets.go b/node/packets.go
index f0ea736..267fed0 100644
--- a/node/packets.go
+++ b/node/packets.go
@@ -2,7 +2,6 @@ package node
import (
"errors"
- "log"
"net/netip"
)
@@ -16,6 +15,7 @@ const (
packetTypeSynAck
packetTypeAck
packetTypeProbe
+ packetTypeAddrDiscovery
)
// ----------------------------------------------------------------------------
@@ -33,8 +33,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
case packetTypeSynAck:
p.Payload, err = parseSynAckPacket(buf)
case packetTypeProbe:
- log.Printf("Got probe...")
p.Payload, err = parseProbePacket(buf)
+ case packetTypeAddrDiscovery:
+ p.Payload, err = parseAddrDiscoveryPacket(buf)
default:
return errUnknownPacketType
}
@@ -46,7 +47,7 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) {
type synPacket struct {
TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key.
- RelayIP byte
+ Direct bool
FromAddr netip.AddrPort // The client's sending address.
}
@@ -55,7 +56,7 @@ func (p synPacket) Marshal(buf []byte) []byte {
Byte(packetTypeSyn).
Uint64(p.TraceID).
SharedKey(p.SharedKey).
- Byte(p.RelayIP).
+ Bool(p.Direct).
AddrPort(p.FromAddr).
Build()
}
@@ -64,7 +65,7 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
- Byte(&p.RelayIP).
+ Bool(&p.Direct).
AddrPort(&p.FromAddr).
Error()
return
@@ -75,7 +76,6 @@ func parseSynPacket(buf []byte) (p synPacket, err error) {
type synAckPacket struct {
TraceID uint64
FromAddr netip.AddrPort
- ToAddr netip.AddrPort
}
func (p synAckPacket) Marshal(buf []byte) []byte {
@@ -83,7 +83,6 @@ func (p synAckPacket) Marshal(buf []byte) []byte {
Byte(packetTypeSynAck).
Uint64(p.TraceID).
AddrPort(p.FromAddr).
- AddrPort(p.ToAddr).
Build()
}
@@ -91,7 +90,6 @@ func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.FromAddr).
- AddrPort(&p.ToAddr).
Error()
return
}
@@ -99,9 +97,24 @@ func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
// ----------------------------------------------------------------------------
type addrDiscoveryPacket struct {
- TraceID uint64
- FromAddr netip.AddrPort
- ToAddr netip.AddrPort
+ TraceID uint64
+ ToAddr netip.AddrPort
+}
+
+func (p addrDiscoveryPacket) Marshal(buf []byte) []byte {
+ return newBinWriter(buf).
+ Byte(packetTypeAddrDiscovery).
+ Uint64(p.TraceID).
+ AddrPort(p.ToAddr).
+ Build()
+}
+
+func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) {
+ err = newBinReader(buf[1:]).
+ Uint64(&p.TraceID).
+ AddrPort(&p.ToAddr).
+ Error()
+ return
}
// ----------------------------------------------------------------------------
diff --git a/node/packets_test.go b/node/packets_test.go
index bd83080..60295ec 100644
--- a/node/packets_test.go
+++ b/node/packets_test.go
@@ -9,7 +9,9 @@ import (
func TestPacketSyn(t *testing.T) {
in := synPacket{
- TraceID: newTraceID(),
+ TraceID: newTraceID(),
+ RelayIP: 4,
+ FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
}
rand.Read(in.SharedKey[:])
@@ -26,7 +28,7 @@ func TestPacketSyn(t *testing.T) {
func TestPacketSynAck(t *testing.T) {
in := synAckPacket{
TraceID: newTraceID(),
- RecvAddr: netip.AddrPort{},
+ FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
}
out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize)))
@@ -38,18 +40,3 @@ func TestPacketSynAck(t *testing.T) {
t.Fatal("\n", in, "\n", out)
}
}
-
-func TestPacketAck(t *testing.T) {
- in := ackPacket{
- TraceID: newTraceID(),
- }
-
- out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize)))
- if err != nil {
- t.Fatal(err)
- }
-
- if !reflect.DeepEqual(in, out) {
- t.Fatal("\n", in, "\n", out)
- }
-}
diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go
index e4f056e..76e329c 100644
--- a/node/peer-supervisor.go
+++ b/node/peer-supervisor.go
@@ -3,7 +3,6 @@ package node
import (
"fmt"
"log"
- "math/rand"
"net/netip"
"sync/atomic"
"time"
@@ -11,10 +10,8 @@ import (
)
const (
- dialTimeout = 8 * time.Second
- connectTimeout = 6 * time.Second
- pingInterval = 6 * time.Second
- timeoutInterval = 20 * time.Second
+ pingInterval = 8 * time.Second
+ timeoutInterval = 25 * time.Second
)
// ----------------------------------------------------------------------------
@@ -64,6 +61,7 @@ func (s *peerSupervisor) Run() {
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
+ time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
func (s *peerSupervisor) sendControlPacketTo(
@@ -75,25 +73,10 @@ func (s *peerSupervisor) sendControlPacketTo(
return
}
route := s.staged
- route.RelayIP = 0
+ route.Direct = true
route.RemoteAddr = addr
_sendControlPacket(pkt, route, s.buf1, s.buf2)
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) getLocalAddr() netip.AddrPort {
- if localPub {
- return localAddr
- }
-
- if s.staged.RelayIP != 0 {
- if addr := routingTable[s.staged.RelayIP].Load().LocalAddr; addr.IsValid() {
- return addr
- }
- }
-
- return s.staged.LocalAddr
+ time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
// ----------------------------------------------------------------------------
@@ -138,18 +121,21 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true
s.staged.Relay = peer.Relay
+ s.staged.Direct = true
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
+ } else if localPub {
+ s.staged.Direct = true
}
if s.remotePub == localPub {
if localIP < s.remoteIP {
return s.server
}
- return s.clientInit
+ return s.client
}
if s.remotePub {
- return s.clientInit
+ return s.client
}
return s.server
}
@@ -157,9 +143,14 @@ func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
// ----------------------------------------------------------------------------
func (s *peerSupervisor) server() stateFunc {
- s.logf("STATE: server")
+ logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) }
- var syn synPacket
+ logf("DOWN")
+
+ var (
+ syn synPacket
+ timeoutTimer = time.NewTimer(timeoutInterval)
+ )
for {
select {
@@ -172,110 +163,80 @@ func (s *peerSupervisor) server() stateFunc {
case synPacket:
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
- if p.TraceID != syn.TraceID {
+ //
+ // The client will update the syn's TraceID whenever there's a change.
+ // The server will follow the client's request.
+ if p.TraceID != syn.TraceID || !s.staged.Up {
+ if p.Direct {
+ logf("UP - Direct")
+ } else {
+ logf("UP - Relayed")
+ }
+
syn = p
s.staged.Up = true
- s.staged.RemoteAddr = pkt.SrcAddr
+ s.staged.Direct = syn.Direct
s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
- s.staged.RelayIP = syn.RelayIP
- s.staged.LocalAddr = s.getLocalAddr()
+ s.staged.RemoteAddr = pkt.SrcAddr
+
s.publish()
}
// We should always respond.
- s.sendControlPacket(synAckPacket{
+ ack := synAckPacket{
TraceID: syn.TraceID,
- FromAddr: s.staged.LocalAddr,
- ToAddr: pkt.SrcAddr,
- })
+ FromAddr: getLocalAddr(),
+ }
+ s.sendControlPacket(ack)
- // If we're relayed, attempt to probe the client.
- if s.staged.RelayIP != 0 && syn.FromAddr.IsValid() {
- probe := probePacket{TraceID: newTraceID()}
- s.logf("SERVER sending probe %v: %v", probe, syn.FromAddr)
- s.sendControlPacketTo(probe, syn.FromAddr)
+ if s.staged.Direct {
+ continue
}
+ if !syn.FromAddr.IsValid() {
+ continue
+ }
+
+ probe := probePacket{TraceID: newTraceID()}
+ s.sendControlPacketTo(probe, syn.FromAddr)
+
case probePacket:
- s.logf("SERVER got probe: %v", p)
- s.logf("SERVER sending probe: %v", pkt.SrcAddr)
- s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr)
- }
- }
- }
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) clientInit() stateFunc {
- s.logf("STATE: client-init")
- if !s.remotePub {
- return s.clientSelectRelay
- }
-
- return s.client
-}
-
-// ----------------------------------------------------------------------------
-
-func (s *peerSupervisor) clientSelectRelay() stateFunc {
- s.logf("STATE: client-select-relay")
-
- timer := time.NewTimer(0)
- defer timer.Stop()
-
- for {
- select {
- case peer := <-s.peerUpdates:
- return s.peerUpdate(peer)
-
- case <-timer.C:
- relay := s.selectRelay()
- if relay != nil {
- s.logf("Got relay: %d", relay.IP)
- s.staged.RelayIP = relay.IP
- s.staged.LocalAddr = relay.LocalAddr
- s.publish()
- return s.client
+ if pkt.SrcAddr.IsValid() {
+ s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr)
+ } else {
+ logf("Invalid probe address")
+ }
}
- s.logf("No relay available.")
- timer.Reset(pingInterval)
+ case <-timeoutTimer.C:
+ logf("Connection timeout")
+ s.staged.Up = false
+ s.publish()
}
}
}
-func (s *peerSupervisor) selectRelay() *peerRoute {
- possible := make([]*peerRoute, 0, 8)
- for i := range routingTable {
- route := routingTable[i].Load()
- if !route.Up || !route.Relay {
- continue
- }
- possible = append(possible, route)
- }
-
- if len(possible) == 0 {
- return nil
- }
- return possible[rand.Intn(len(possible))]
-}
-
// ----------------------------------------------------------------------------
func (s *peerSupervisor) client() stateFunc {
- s.logf("STATE: client")
+ logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) }
+
+ logf("DOWN")
var (
syn = synPacket{
TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(),
- RelayIP: s.staged.RelayIP,
- FromAddr: s.getLocalAddr(),
+ Direct: s.staged.Direct,
+ FromAddr: getLocalAddr(),
}
+
ack synAckPacket
- probe = probePacket{TraceID: newTraceID()}
+ probe probePacket
+ probeAddr netip.AddrPort
+
+ lAddr netip.AddrPort
timeoutTimer = time.NewTimer(timeoutInterval)
pingTimer = time.NewTimer(pingInterval)
@@ -297,33 +258,74 @@ func (s *peerSupervisor) client() stateFunc {
case synAckPacket:
if p.TraceID != syn.TraceID {
- s.logf("Bad traceID?")
continue // Hmm...
}
+
ack = p
timeoutTimer.Reset(timeoutInterval)
if !s.staged.Up {
+ if s.staged.Direct {
+ logf("UP - Direct")
+ } else {
+ logf("UP - Relayed")
+ }
+
s.staged.Up = true
- s.staged.LocalAddr = p.ToAddr
s.publish()
}
case probePacket:
- s.logf("CLIENT got probe: %v", p)
+ if s.staged.Direct {
+ continue
+ }
+
+ if p.TraceID != probe.TraceID {
+ continue
+ }
+
+ // Upgrade connection.
+
+ logf("UP - Direct")
+ s.staged.Direct = true
+ s.staged.RemoteAddr = probeAddr
+ s.publish()
+
+ syn.TraceID = newTraceID()
+ syn.Direct = true
+ syn.FromAddr = getLocalAddr()
+ s.sendControlPacket(syn)
}
case <-pingTimer.C:
- s.sendControlPacket(syn)
- pingTimer.Reset(pingInterval)
+ // Send syn.
- if s.staged.RelayIP != 0 && ack.FromAddr.IsValid() {
- s.logf("CLIENT sending probe %v: %v", probe, ack.FromAddr)
- s.sendControlPacketTo(probe, ack.FromAddr)
+ syn.FromAddr = getLocalAddr()
+ if syn.FromAddr != lAddr {
+ syn.TraceID = newTraceID()
+ lAddr = syn.FromAddr
}
+ s.sendControlPacket(syn)
+
+ pingTimer.Reset(pingInterval)
+
+ if s.staged.Direct {
+ continue
+ }
+
+ if !ack.FromAddr.IsValid() {
+ continue
+ }
+
+ probe = probePacket{TraceID: newTraceID()}
+ probeAddr = ack.FromAddr
+
+ s.sendControlPacketTo(probe, ack.FromAddr)
+
case <-timeoutTimer.C:
- return s.clientInit
+ logf("Connection timeout")
+ return s.peerUpdate(s.peer)
}
}
}
diff --git a/node/relaymanager.go b/node/relaymanager.go
new file mode 100644
index 0000000..5c44ea8
--- /dev/null
+++ b/node/relaymanager.go
@@ -0,0 +1,40 @@
+package node
+
+import (
+ "log"
+ "math/rand"
+ "time"
+)
+
+func relayManager() {
+ time.Sleep(2 * time.Second)
+ updateRelayRoute()
+
+ for range time.Tick(8 * time.Second) {
+ relay := getRelayRoute()
+ if relay == nil || !relay.Up || !relay.Relay {
+ updateRelayRoute()
+ }
+ }
+}
+
+func updateRelayRoute() {
+ possible := make([]*peerRoute, 0, 8)
+ for i := range routingTable {
+ route := routingTable[i].Load()
+ if !route.Up || !route.Relay {
+ continue
+ }
+ possible = append(possible, route)
+ }
+
+ if len(possible) == 0 {
+ log.Printf("No relay available.")
+ relayIP.Store(nil)
+ return
+ }
+
+ ip := possible[rand.Intn(len(possible))].IP
+ log.Printf("New relay IP: %d", ip)
+ relayIP.Store(&ip)
+}
--
2.39.5
From ef4ef33579afb46cc31339e5ecee3688af0016fb Mon Sep 17 00:00:00 2001
From: jdl
Date: Tue, 24 Dec 2024 19:35:50 +0100
Subject: [PATCH 18/18] WIP
---
LICENSE | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/LICENSE b/LICENSE
index 078df32..042a386 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2024 app
+Copyright (c) 2024 John David Lee (johndavidlee@crumpington.com)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
--
2.39.5