Major update - symmetric encryption, UDP hole punching, code cleanup.

Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2024-12-24 18:37:43 +00:00
parent ee4f5e012c
commit 3bd73cfd34
48 changed files with 1739 additions and 1291 deletions

71
node/addrdiscovery.go Normal file
View File

@@ -0,0 +1,71 @@
package node
import (
"log"
"net/netip"
"time"
)
func addrDiscoveryServer() {
var (
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
)
for {
pkt := <-discoveryPackets
p, ok := pkt.Payload.(addrDiscoveryPacket)
if !ok {
continue
}
route := routingTable[pkt.SrcIP].Load()
if route == nil || !route.RemoteAddr.IsValid() {
continue
}
_sendControlPacket(addrDiscoveryPacket{
TraceID: p.TraceID,
ToAddr: pkt.SrcAddr,
}, *route, buf1, buf2)
}
}
func addrDiscoveryClient() {
var (
checkInterval = 8 * time.Second
timer = time.NewTimer(4 * time.Second)
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
addrPacket addrDiscoveryPacket
lAddr netip.AddrPort
)
for {
select {
case pkt := <-discoveryPackets:
p, ok := pkt.Payload.(addrDiscoveryPacket)
if !ok || p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr {
continue
}
log.Printf("Discovered local address: %v", p.ToAddr)
lAddr = p.ToAddr
localAddr.Store(&p.ToAddr)
case <-timer.C:
timer.Reset(checkInterval)
route := getRelayRoute()
if route == nil {
continue
}
addrPacket.TraceID = newTraceID()
_sendControlPacket(addrPacket, *route, buf1, buf2)
}
}
}

26
node/cipher-control.go Normal file
View File

@@ -0,0 +1,26 @@
package node
import "golang.org/x/crypto/nacl/box"
type controlCipher struct {
sharedKey [32]byte
}
func newControlCipher(privKey, pubKey []byte) *controlCipher {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
return &controlCipher{shared}
}
func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte {
const s = controlHeaderSize
out = out[:s+controlCipherOverhead+len(data)]
h.Marshal(out[:s])
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
return out
}
func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = controlHeaderSize
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
}

122
node/cipher-control_test.go Normal file
View File

@@ -0,0 +1,122 @@
package node
import (
"bytes"
"crypto/rand"
"reflect"
"testing"
"golang.org/x/crypto/nacl/box"
)
func newControlCipherForTesting() (c1, c2 *controlCipher) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
return newControlCipher(privKey1[:], pubKey2[:]),
newControlCipher(privKey2[:], pubKey1[:])
}
func TestControlCipher(t *testing.T) {
c1, c2 := newControlCipherForTesting()
maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := header{
StreamID: controlStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
h2 := header{}
h2.Parse(encrypted)
if !reflect.DeepEqual(h1, h2) {
t.Fatal(h1, h2)
}
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatal("not equal")
}
}
}
func TestControlCipher_ShortCiphertext(t *testing.T) {
c1, _ := newControlCipherForTesting()
shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
rand.Read(shortText)
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkControlCipher_Encrypt(b *testing.B) {
c1, _ := newControlCipherForTesting()
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = c1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkControlCipher_Decrypt(b *testing.B) {
c1, c2 := newControlCipherForTesting()
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = c2.Decrypt(encrypted, decrypted)
}
}

62
node/cipher-data.go Normal file
View File

@@ -0,0 +1,62 @@
package node
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
)
// TODO: Use [32]byte for simplicity everywhere.
type dataCipher struct {
key [32]byte
aead cipher.AEAD
}
func newDataCipher() *dataCipher {
key := [32]byte{}
if _, err := rand.Read(key[:]); err != nil {
panic(err)
}
return newDataCipherFromKey(key)
}
// key must be 32 bytes.
func newDataCipherFromKey(key [32]byte) *dataCipher {
block, err := aes.NewCipher(key[:])
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
panic(err)
}
return &dataCipher{key: key, aead: aead}
}
func (sc *dataCipher) Key() [32]byte {
return sc.key
}
func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte {
const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(out[:s])
sc.aead.Seal(out[s:s], out[:s], data, nil)
return out
}
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = dataHeaderSize
if len(encrypted) < s+dataCipherOverhead {
ok = false
return
}
var err error
data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
ok = err == nil
return
}

141
node/cipher-data_test.go Normal file
View File

@@ -0,0 +1,141 @@
package node
import (
"bytes"
"crypto/rand"
mrand "math/rand/v2"
"reflect"
"testing"
)
func TestDataCipher(t *testing.T) {
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := header{
StreamID: dataStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
h2 := header{}
h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key())
decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatal("not equal")
}
if !reflect.DeepEqual(h1, h2) {
t.Fatalf("%v != %v", h1, h2)
}
}
}
func TestDataCipher_ModifyCiphertext(t *testing.T) {
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
encrypted[mrand.IntN(len(encrypted))]++
dc2 := newDataCipherFromKey(dc1.Key())
_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if ok {
t.Fatal(ok)
}
}
}
func TestDataCipher_ShortCiphertext(t *testing.T) {
dc1 := newDataCipher()
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
rand.Read(shortText)
_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkDataCipher_Encrypt(b *testing.B) {
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkDataCipher_Decrypt(b *testing.B) {
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = dc1.Decrypt(encrypted, decrypted)
}
}

View File

@@ -1,172 +1,49 @@
package node
import (
"io"
"log"
"net"
"net/netip"
"runtime/debug"
"sync"
"sync/atomic"
"vppn/fasttime"
)
// ----------------------------------------------------------------------------
type connWriter struct {
*net.UDPConn
lock sync.Mutex
localIP byte
buf []byte
buf2 []byte
counters [256]uint64
routing *routingTable
lock sync.Mutex
conn *net.UDPConn
}
func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter {
w := &connWriter{
UDPConn: conn,
localIP: localIP,
buf: make([]byte, bufferSize),
buf2: make([]byte, bufferSize),
routing: routing,
}
for i := range w.counters {
w.counters[i] = uint64(fasttime.Now() << 30)
}
return w
func newConnWriter(conn *net.UDPConn) *connWriter {
return &connWriter{conn: conn}
}
func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
dstPeer := w.routing.Get(remoteIP)
if dstPeer == nil {
log.Printf("No peer: %d", remoteIP)
return
}
if stream == streamData && !dstPeer.Up {
log.Printf("Peer down: %d", remoteIP)
return
}
var viaPeer *peer
if dstPeer.Mediated {
viaPeer = w.routing.mediator.Load()
if viaPeer == nil || viaPeer.Addr == nil {
log.Printf("Mediator not connected")
return
}
} else if dstPeer.Addr == nil {
log.Printf("Peer doesn't have address: %d", remoteIP)
return
}
w.WriteToPeer(dstPeer, viaPeer, stream, data)
}
func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) {
func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) {
w.lock.Lock()
addr := dstPeer.Addr
h := header{
Counter: atomic.AddUint64(&w.counters[dstPeer.IP], 1),
SourceIP: w.localIP,
DestIP: dstPeer.IP,
Stream: stream,
}
buf := encryptPacket(&h, dstPeer.SharedKey, data, w.buf)
if viaPeer != nil {
h := header{
Counter: atomic.AddUint64(&w.counters[viaPeer.IP], 1),
SourceIP: w.localIP,
DestIP: dstPeer.IP,
Forward: 1,
Stream: stream,
}
buf = encryptPacket(&h, viaPeer.SharedKey, buf, w.buf2)
addr = viaPeer.Addr
}
if _, err := w.WriteToUDPAddrPort(buf, *addr); err != nil {
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
debug.PrintStack()
log.Fatalf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()
}
func (w *connWriter) Forward(dstIP byte, packet []byte) {
dstPeer := w.routing.Get(dstIP)
if dstPeer == nil || dstPeer.Addr == nil {
log.Printf("No peer: %d", dstIP)
return
}
if _, err := w.WriteToUDPAddrPort(packet, *dstPeer.Addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
}
}
// ----------------------------------------------------------------------------
type connReader struct {
*net.UDPConn
localIP byte
dupChecks [256]*dupCheck
routing *routingTable
buf []byte
type ifWriter struct {
lock sync.Mutex
iface io.ReadWriteCloser
}
func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader {
r := &connReader{
UDPConn: conn,
localIP: localIP,
routing: routing,
buf: make([]byte, bufferSize),
}
for i := range r.dupChecks {
r.dupChecks[i] = newDupCheck(0)
}
return r
func newIFWriter(iface io.ReadWriteCloser) *ifWriter {
return &ifWriter{iface: iface}
}
func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte) {
var (
n int
err error
)
for {
n, remoteAddr, err = r.ReadFromUDPAddrPort(buf[:bufferSize])
if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err)
}
data = buf[:n]
if n < headerSize {
continue // Packet it soo short.
}
h.Parse(data)
peer := r.routing.Get(h.SourceIP)
if peer == nil {
continue
}
out, ok := decryptPacket(peer.SharedKey, data, r.buf)
if !ok {
continue
}
out, data = data, out
if r.dupChecks[h.SourceIP].IsDup(h.Counter) {
log.Printf("Duplicate: %d", h.Counter)
continue
}
return
func (w *ifWriter) Write(packet []byte) {
w.lock.Lock()
if _, err := w.iface.Write(packet); err != nil {
log.Fatalf("Failed to write to interface: %v", err)
}
w.lock.Unlock()
}

View File

@@ -1,50 +0,0 @@
package node
import (
"sync"
"vppn/fasttime"
"golang.org/x/crypto/nacl/box"
)
// Encrypting the packet will also set the header's DataSize field.
func encryptPacket(h *header, sharedKey, data, out []byte) []byte {
out = out[:headerSize]
h.Marshal(out)
b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey))
return out[:len(b)+headerSize]
}
func decryptPacket(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) {
return box.OpenAfterPrecomputation(
out[:0],
packetAndHeader[headerSize:],
(*[24]byte)(packetAndHeader[:headerSize]),
(*[32]byte)(sharedKey))
}
func computeSharedKey(peerPubKey, privKey []byte) []byte {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey))
return shared[:]
}
var (
traceIDLock sync.Mutex
traceIDTime uint64
traceIDCounter uint64
)
func newTraceID() (id uint64) {
traceIDLock.Lock()
defer traceIDLock.Unlock()
now := uint64(fasttime.Now())
if traceIDTime < now {
traceIDTime = now
traceIDCounter = 0
}
traceIDCounter++
return traceIDTime<<30 + traceIDCounter
}

View File

@@ -1,140 +0,0 @@
package node
import (
"bytes"
"crypto/rand"
"log"
"reflect"
"testing"
"golang.org/x/crypto/nacl/box"
)
func TestEncryptDecryptPacket(t *testing.T) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
log.Printf("\n%#v\n%#v\n%#v\n%#v\n", pubKey1, privKey1, pubKey2, privKey2)
sharedEncKey := [32]byte{}
box.Precompute(&sharedEncKey, pubKey2, privKey1)
sharedDecKey := [32]byte{}
box.Precompute(&sharedDecKey, pubKey1, privKey2)
original := make([]byte, if_mtu-64)
rand.Read(original)
h := header{
Counter: 2893749238,
SourceIP: 5,
DestIP: 12,
Forward: 1,
Stream: 1,
}
encrypted := make([]byte, bufferSize)
encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted)
decrypted := make([]byte, bufferSize)
var ok bool
decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted)
if !ok {
t.Fatal(ok)
}
var h2 header
h2.Parse(encrypted)
if !reflect.DeepEqual(h, h2) {
t.Fatal(h, h2)
}
if !bytes.Equal(original, decrypted) {
t.Fatal("mismatch")
}
}
func BenchmarkEncryptPacket(b *testing.B) {
_, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
b.Fatal(err)
}
pubKey2, _, err := box.GenerateKey(rand.Reader)
if err != nil {
b.Fatal(err)
}
sharedEncKey := [32]byte{}
box.Precompute(&sharedEncKey, pubKey2, privKey1)
original := make([]byte, if_mtu)
rand.Read(original)
nonce := make([]byte, headerSize)
rand.Read(nonce)
encrypted := make([]byte, bufferSize)
h := header{
Counter: 2893749238,
SourceIP: 5,
DestIP: 12,
Forward: 1,
Stream: 1,
}
for i := 0; i < b.N; i++ {
encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted)
}
}
func BenchmarkDecryptPacket(b *testing.B) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
b.Fatal(err)
}
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
if err != nil {
b.Fatal(err)
}
sharedEncKey := [32]byte{}
box.Precompute(&sharedEncKey, pubKey2, privKey1)
sharedDecKey := [32]byte{}
box.Precompute(&sharedDecKey, pubKey1, privKey2)
original := make([]byte, if_mtu)
rand.Read(original)
nonce := make([]byte, headerSize)
rand.Read(nonce)
h := header{
Counter: 2893749238,
SourceIP: 5,
DestIP: 12,
Forward: 1,
Stream: 1,
}
encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize))
decrypted := make([]byte, bufferSize)
var ok bool
for i := 0; i < b.N; i++ {
decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted)
if !ok {
panic(ok)
}
}
}

View File

@@ -1,7 +1,5 @@
package node
import "log"
type dupCheck struct {
bitSet
head int
@@ -22,7 +20,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
// Before head => it's late, say it's a dup.
if counter < dc.headCounter {
log.Printf("Late: %d", counter)
return true
}
@@ -30,7 +27,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
if counter < dc.tailCounter {
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
if dc.Get(index) {
log.Printf("Dup: %d, %d", counter, dc.tailCounter)
return true
}

View File

@@ -1,7 +1,6 @@
package node
import (
"log"
"testing"
)
@@ -49,8 +48,6 @@ func TestDupCheck(t *testing.T) {
for i, tc := range testCases {
if ok := dc.IsDup(tc.Counter); ok != tc.Dup {
log.Printf("%b", dc.bitSet)
log.Printf("%+v", *dc)
t.Fatal(i, ok, tc)
}
}

73
node/globalfuncs.go Normal file
View File

@@ -0,0 +1,73 @@
package node
import (
"net/netip"
"sync/atomic"
)
func getRelayRoute() *peerRoute {
if ip := relayIP.Load(); ip != nil {
return routingTable[*ip].Load()
}
return nil
}
func getLocalAddr() netip.AddrPort {
if a := localAddr.Load(); a != nil {
return *a
}
return netip.AddrPort{}
}
func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) {
buf := pkt.Marshal(buf2)
h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
SourceIP: localIP,
DestIP: route.IP,
}
buf = route.ControlCipher.Encrypt(h, buf, buf1)
if route.Direct {
_conn.WriteTo(buf, route.RemoteAddr)
return
}
_relayPacket(route.IP, buf, buf2)
}
func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) {
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
SourceIP: localIP,
DestIP: route.IP,
}
enc := route.DataCipher.Encrypt(h, pkt, buf1)
if route.Direct {
_conn.WriteTo(enc, route.RemoteAddr)
return
}
_relayPacket(route.IP, enc, buf2)
}
func _relayPacket(destIP byte, data, buf []byte) {
relayRoute := getRelayRoute()
if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay {
return
}
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1),
SourceIP: localIP,
DestIP: destIP,
}
enc := relayRoute.DataCipher.Encrypt(h, data, buf)
_conn.WriteTo(enc, relayRoute.RemoteAddr)
}

View File

@@ -1,3 +1,86 @@
package node
const bufferSize = if_mtu + 128
import (
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
const (
bufferSize = 1536
if_mtu = 1200
if_queue_len = 2048
controlCipherOverhead = 16
dataCipherOverhead = 16
)
type peerRoute struct {
IP byte
Up bool // True if data can be sent on the route.
Relay bool // True if the peer is a relay.
Direct bool // True if this is a direct connection.
ControlCipher *controlCipher
DataCipher *dataCipher
RemoteAddr netip.AddrPort // Remote address if directly connected.
}
var (
// Configuration for this peer.
netName string
localIP byte
localPub bool
privateKey []byte
// Shared interface for writing.
_iface *ifWriter
// Shared connection for writing.
_conn *connWriter
// Counters for sending to each peer.
sendCounters [256]uint64 = func() (out [256]uint64) {
for i := range out {
out[i] = uint64(time.Now().Unix()<<30 + 1)
}
return
}()
// Duplicate checkers for incoming packets.
dupChecks [256]*dupCheck = func() (out [256]*dupCheck) {
for i := range out {
out[i] = newDupCheck(0)
}
return
}()
// Channels for incoming control packets.
controlPackets [256]chan controlPacket = func() (out [256]chan controlPacket) {
for i := range out {
out[i] = make(chan controlPacket, 256)
}
return
}()
// Channels for incoming peer updates from the hub.
peerUpdates [256]chan *m.Peer = func() (out [256]chan *m.Peer) {
for i := range out {
out[i] = make(chan *m.Peer)
}
return
}()
// Global routing table.
routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) {
for i := range out {
out[i] = &atomic.Pointer[peerRoute]{}
out[i].Store(&peerRoute{})
}
return
}()
// Managed by the relayManager.
discoveryPackets chan controlPacket
localAddr *atomic.Pointer[netip.AddrPort] // May be nil.
relayIP *atomic.Pointer[byte] // May be nil.
)

View File

@@ -2,32 +2,34 @@ package node
import "unsafe"
// ----------------------------------------------------------------------------
const (
headerSize = 24
streamData = 1
streamRouting = 2
headerSize = 12
controlStreamID = 2
controlHeaderSize = 24
dataStreamID = 1
dataHeaderSize = 12
)
type header struct {
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
StreamID byte
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
SourceIP byte
DestIP byte
Forward byte
Stream byte // See stream* constants.
}
func (hdr *header) Parse(nb []byte) {
hdr.Counter = *(*uint64)(unsafe.Pointer(&nb[0]))
hdr.SourceIP = nb[8]
hdr.DestIP = nb[9]
hdr.Forward = nb[10]
hdr.Stream = nb[11]
func (h *header) Parse(b []byte) {
h.StreamID = b[0]
h.Counter = *(*uint64)(unsafe.Pointer(&b[1]))
h.SourceIP = b[9]
h.DestIP = b[10]
}
func (hdr header) Marshal(buf []byte) {
*(*uint64)(unsafe.Pointer(&buf[0])) = hdr.Counter
buf[8] = hdr.SourceIP
buf[9] = hdr.DestIP
buf[10] = hdr.Forward
buf[11] = hdr.Stream
func (h *header) Marshal(buf []byte) {
buf[0] = h.StreamID
*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter
buf[9] = h.SourceIP
buf[10] = h.DestIP
buf[11] = 0
}

View File

@@ -4,11 +4,10 @@ import "testing"
func TestHeaderMarshalParse(t *testing.T) {
nIn := header{
StreamID: 23,
Counter: 3212,
SourceIP: 34,
DestIP: 200,
Forward: 1,
Stream: 44,
}
buf := make([]byte, headerSize)

94
node/hubpoller.go Normal file
View File

@@ -0,0 +1,94 @@
package node
import (
"encoding/json"
"io"
"log"
"net/http"
"net/url"
"time"
"vppn/m"
)
type hubPoller struct {
client *http.Client
req *http.Request
versions [256]int64
}
func newHubPoller(conf m.PeerConfig) *hubPoller {
u, err := url.Parse(conf.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
}
u.Path = "/peer/fetch-state/"
client := &http.Client{Timeout: 8 * time.Second}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
req.SetBasicAuth("", conf.APIKey)
return &hubPoller{
client: client,
req: req,
}
}
func (hp *hubPoller) Run() {
defer panicHandler()
state, err := loadNetworkState(netName)
if err != nil {
log.Printf("Failed to load network state: %v", err)
log.Printf("Polling hub...")
hp.pollHub()
} else {
hp.applyNetworkState(state)
}
for range time.Tick(64 * time.Second) {
hp.pollHub()
}
}
func (hp *hubPoller) pollHub() {
var state m.NetworkState
resp, err := hp.client.Do(hp.req)
if err != nil {
log.Printf("Failed to fetch peer state: %v", err)
return
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
log.Printf("Failed to read body from hub: %v", err)
return
}
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("Failed to unmarshal response from hub: %v", err)
return
}
hp.applyNetworkState(state)
if err := storeNetworkState(netName, state); err != nil {
log.Printf("Failed to store network state: %v", err)
}
}
func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
for i, peer := range state.Peers {
if i != int(localIP) {
if peer != nil && peer.Version != hp.versions[i] {
peerUpdates[i] <- state.Peers[i]
hp.versions[i] = peer.Version
}
}
}
}

View File

@@ -50,11 +50,6 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error)
}
}
const (
if_mtu = 1200
if_queue_len = 2048
)
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
if len(network) != 4 {
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))

View File

@@ -11,6 +11,7 @@ import (
"net/netip"
"os"
"runtime/debug"
"sync/atomic"
"vppn/m"
)
@@ -24,7 +25,6 @@ func Main() {
defer panicHandler()
var (
netName string
initURL string
listenIP string
port int
@@ -42,14 +42,14 @@ func Main() {
}
if initURL != "" {
mainInit(netName, initURL)
mainInit(initURL)
return
}
main(netName, listenIP, uint16(port))
main(listenIP, uint16(port))
}
func mainInit(netName, initURL string) {
func mainInit(initURL string) {
if _, err := loadPeerConfig(netName); err == nil {
log.Fatalf("Network is already initialized.")
}
@@ -79,15 +79,15 @@ func mainInit(netName, initURL string) {
// ----------------------------------------------------------------------------
func main(netName, listenIP string, port uint16) {
conf, err := loadPeerConfig(netName)
func main(listenIP string, port uint16) {
config, err := loadPeerConfig(netName)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
port = determinePort(conf.Port, port)
port = determinePort(config.Port, port)
iface, err := openInterface(conf.Network, conf.PeerIP, netName)
iface, err := openInterface(config.Network, config.PeerIP, netName)
if err != nil {
log.Fatalf("Failed to open interface: %v", err)
}
@@ -102,15 +102,38 @@ func main(netName, listenIP string, port uint16) {
log.Fatalf("Failed to open UDP port: %v", err)
}
routing := newRoutingTable()
// Intialize globals.
_iface = newIFWriter(iface)
_conn = newConnWriter(conn)
w := newConnWriter(conn, conf.PeerIP, routing)
r := newConnReader(conn, conf.PeerIP, routing)
localIP = config.PeerIP
discoveryPackets = make(chan controlPacket, 256)
localAddr = &atomic.Pointer[netip.AddrPort]{}
relayIP = &atomic.Pointer[byte]{}
router := newRouter(netName, conf, routing, w)
ip, ok := netip.AddrFromSlice(config.PublicIP)
if ok {
localPub = true
addr := netip.AddrPortFrom(ip, config.Port)
localAddr.Store(&addr)
}
go nodeConnReader(r, w, iface, router)
nodeIFaceReader(w, iface, router)
privateKey = config.PrivKey
// Start supervisors.
for i := range 256 {
go newPeerSupervisor(i).Run()
}
if localPub {
go addrDiscoveryServer()
} else {
go addrDiscoveryClient()
go relayManager()
}
go newHubPoller(config).Run()
go readFromConn(conn)
readFromIFace(iface)
}
// ----------------------------------------------------------------------------
@@ -127,62 +150,160 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ----------------------------------------------------------------------------
func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) {
func readFromConn(conn *net.UDPConn) {
defer panicHandler()
var (
remoteAddr netip.AddrPort
h header
buf = make([]byte, bufferSize)
data []byte
n int
err error
buf = make([]byte, bufferSize)
decBuf = make([]byte, bufferSize)
data []byte
h header
)
for {
remoteAddr, h, data = r.Read(buf)
if h.Forward != 0 {
w.Forward(h.DestIP, data)
continue
n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize])
if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err)
}
switch h.Stream {
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
case streamData:
if _, err = iface.Write(data); err != nil {
log.Printf("Malformed data from peer %d: %v", h.SourceIP, err)
}
data = buf[:n]
case streamRouting:
router.HandlePacket(h.SourceIP, remoteAddr, data)
if n < headerSize {
continue // Packet it soo short.
}
h.Parse(data)
switch h.StreamID {
case controlStreamID:
handleControlPacket(remoteAddr, h, data, decBuf)
case dataStreamID:
handleDataPacket(h, data, decBuf)
default:
log.Printf("Dropping unknown stream: %d", h.Stream)
log.Printf("Unknown stream ID: %d", h.StreamID)
}
}
}
func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
route := routingTable[h.SourceIP].Load()
if route.ControlCipher == nil {
//log.Printf("Not connected (control).")
return
}
if h.DestIP != localIP {
log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP)
return
}
out, ok := route.ControlCipher.Decrypt(data, decBuf)
if !ok {
//log.Printf("Failed to decrypt control packet.")
return
}
if len(out) == 0 {
//log.Printf("Empty control packet from: %d", h.SourceIP)
return
}
if dupChecks[h.SourceIP].IsDup(h.Counter) {
//log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter)
return
}
pkt := controlPacket{
SrcIP: h.SourceIP,
SrcAddr: addr,
}
if err := pkt.ParsePayload(out); err != nil {
log.Printf("Failed to parse control packet: %v", err)
return
}
switch pkt.Payload.(type) {
case addrDiscoveryPacket:
select {
case discoveryPackets <- pkt:
default:
log.Printf("Dropping discovery packet.")
}
default:
select {
case controlPackets[h.SourceIP] <- pkt:
default:
log.Printf("Dropping control packet.")
}
}
}
func handleDataPacket(h header, data []byte, decBuf []byte) {
route := routingTable[h.SourceIP].Load()
if !route.Up {
//log.Printf("Not connected (recv).")
return
}
dec, ok := route.DataCipher.Decrypt(data, decBuf)
if !ok {
log.Printf("Failed to decrypt data packet.")
return
}
if dupChecks[h.SourceIP].IsDup(h.Counter) {
//log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter)
return
}
if h.DestIP == localIP {
_iface.Write(dec)
return
}
destRoute := routingTable[h.DestIP].Load()
if !destRoute.Up {
log.Printf("Not connected (relay): %v", destRoute)
return
}
_conn.WriteTo(dec, destRoute.RemoteAddr)
}
// ----------------------------------------------------------------------------
func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) {
func readFromIFace(iface io.ReadWriteCloser) {
var (
buf = make([]byte, bufferSize)
packet []byte
packet = make([]byte, bufferSize)
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
remoteIP byte
err error
)
for {
packet, remoteIP, err = readNextPacket(iface, buf)
packet, remoteIP, err = readNextPacket(iface, packet)
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
if remoteIP == w.localIP {
continue // Don't write to self.
route := routingTable[remoteIP].Load()
if !route.Up {
log.Printf("Route not connected: %d", remoteIP)
continue
}
w.WriteTo(remoteIP, streamData, packet)
_sendDataPacket(route, packet, buf1, buf2)
}
}

View File

@@ -1 +0,0 @@
package node

163
node/packets-util.go Normal file
View File

@@ -0,0 +1,163 @@
package node
import (
"net/netip"
"sync/atomic"
"time"
"unsafe"
)
var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1
func newTraceID() uint64 {
return atomic.AddUint64(&traceIDCounter, 1)
}
// ----------------------------------------------------------------------------
type binWriter struct {
b []byte
i int
}
func newBinWriter(buf []byte) *binWriter {
buf = buf[:cap(buf)]
return &binWriter{buf, 0}
}
func (w *binWriter) Bool(b bool) *binWriter {
if b {
return w.Byte(1)
}
return w.Byte(0)
}
func (w *binWriter) Byte(b byte) *binWriter {
w.b[w.i] = b
w.i++
return w
}
func (w *binWriter) SharedKey(key [32]byte) *binWriter {
copy(w.b[w.i:w.i+32], key[:])
w.i += 32
return w
}
func (w *binWriter) Uint16(x uint16) *binWriter {
*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 2
return w
}
func (w *binWriter) Uint64(x uint64) *binWriter {
*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 8
return w
}
func (w *binWriter) Int64(x int64) *binWriter {
*(*int64)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 8
return w
}
func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
addr := addrPort.Addr().As16()
copy(w.b[w.i:w.i+16], addr[:])
w.i += 16
return w.Uint16(addrPort.Port())
}
func (w *binWriter) Build() []byte {
return w.b[:w.i]
}
// ----------------------------------------------------------------------------
type binReader struct {
b []byte
i int
err error
}
func newBinReader(buf []byte) *binReader {
return &binReader{b: buf}
}
func (r *binReader) hasBytes(n int) bool {
if r.err != nil || (len(r.b)-r.i) < n {
r.err = errMalformedPacket
return false
}
return true
}
func (r *binReader) Bool(b *bool) *binReader {
var bb byte
r.Byte(&bb)
*b = bb != 0
return r
}
func (r *binReader) Byte(b *byte) *binReader {
if !r.hasBytes(1) {
return r
}
*b = r.b[r.i]
r.i++
return r
}
func (r *binReader) SharedKey(x *[32]byte) *binReader {
if !r.hasBytes(32) {
return r
}
*x = ([32]byte)(r.b[r.i : r.i+32])
r.i += 32
return r
}
func (r *binReader) Uint16(x *uint16) *binReader {
if !r.hasBytes(2) {
return r
}
*x = *(*uint16)(unsafe.Pointer(&r.b[r.i]))
r.i += 2
return r
}
func (r *binReader) Uint64(x *uint64) *binReader {
if !r.hasBytes(8) {
return r
}
*x = *(*uint64)(unsafe.Pointer(&r.b[r.i]))
r.i += 8
return r
}
func (r *binReader) Int64(x *int64) *binReader {
if !r.hasBytes(8) {
return r
}
*x = *(*int64)(unsafe.Pointer(&r.b[r.i]))
r.i += 8
return r
}
func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
if !r.hasBytes(18) {
return r
}
addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
r.i += 16
var port uint16
r.Uint16(&port)
*x = netip.AddrPortFrom(addr, port)
return r
}
func (r *binReader) Error() error {
return r.err
}

40
node/packets-util_test.go Normal file
View File

@@ -0,0 +1,40 @@
package node
import (
"net/netip"
"reflect"
"testing"
)
func TestBinWriteRead(t *testing.T) {
buf := make([]byte, 1024)
type Item struct {
Type byte
TraceID uint64
DestAddr netip.AddrPort
}
in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)}
buf = newBinWriter(buf).
Byte(in.Type).
Uint64(in.TraceID).
AddrPort(in.DestAddr).
Build()
out := Item{}
err := newBinReader(buf).
Byte(&out.Type).
Uint64(&out.TraceID).
AddrPort(&out.DestAddr).
Error()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal(in, out)
}
}

140
node/packets.go Normal file
View File

@@ -0,0 +1,140 @@
package node
import (
"errors"
"net/netip"
)
var (
errMalformedPacket = errors.New("malformed packet")
errUnknownPacketType = errors.New("unknown packet type")
)
const (
packetTypeSyn = iota + 1
packetTypeSynAck
packetTypeAck
packetTypeProbe
packetTypeAddrDiscovery
)
// ----------------------------------------------------------------------------
type controlPacket struct {
SrcIP byte
SrcAddr netip.AddrPort
Payload any
}
func (p *controlPacket) ParsePayload(buf []byte) (err error) {
switch buf[0] {
case packetTypeSyn:
p.Payload, err = parseSynPacket(buf)
case packetTypeSynAck:
p.Payload, err = parseSynAckPacket(buf)
case packetTypeProbe:
p.Payload, err = parseProbePacket(buf)
case packetTypeAddrDiscovery:
p.Payload, err = parseAddrDiscoveryPacket(buf)
default:
return errUnknownPacketType
}
return err
}
// ----------------------------------------------------------------------------
type synPacket struct {
TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key.
Direct bool
FromAddr netip.AddrPort // The client's sending address.
}
func (p synPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSyn).
Uint64(p.TraceID).
SharedKey(p.SharedKey).
Bool(p.Direct).
AddrPort(p.FromAddr).
Build()
}
func parseSynPacket(buf []byte) (p synPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
Bool(&p.Direct).
AddrPort(&p.FromAddr).
Error()
return
}
// ----------------------------------------------------------------------------
type synAckPacket struct {
TraceID uint64
FromAddr netip.AddrPort
}
func (p synAckPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSynAck).
Uint64(p.TraceID).
AddrPort(p.FromAddr).
Build()
}
func parseSynAckPacket(buf []byte) (p synAckPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.FromAddr).
Error()
return
}
// ----------------------------------------------------------------------------
type addrDiscoveryPacket struct {
TraceID uint64
ToAddr netip.AddrPort
}
func (p addrDiscoveryPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeAddrDiscovery).
Uint64(p.TraceID).
AddrPort(p.ToAddr).
Build()
}
func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.ToAddr).
Error()
return
}
// ----------------------------------------------------------------------------
// A probeReqPacket is sent from a client to a server to determine if direct
// UDP communication can be used.
type probePacket struct {
TraceID uint64
}
func (p probePacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeProbe).
Uint64(p.TraceID).
Build()
}
func parseProbePacket(buf []byte) (p probePacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
Error()
return
}

42
node/packets_test.go Normal file
View File

@@ -0,0 +1,42 @@
package node
import (
"crypto/rand"
"net/netip"
"reflect"
"testing"
)
func TestPacketSyn(t *testing.T) {
in := synPacket{
TraceID: newTraceID(),
RelayIP: 4,
FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
}
rand.Read(in.SharedKey[:])
out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize)))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal("\n", in, "\n", out)
}
}
func TestPacketSynAck(t *testing.T) {
in := synAckPacket{
TraceID: newTraceID(),
FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
}
out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize)))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal("\n", in, "\n", out)
}
}

331
node/peer-supervisor.go Normal file
View File

@@ -0,0 +1,331 @@
package node
import (
"fmt"
"log"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
const (
pingInterval = 8 * time.Second
timeoutInterval = 25 * time.Second
)
// ----------------------------------------------------------------------------
type peerSupervisor struct {
// The purpose of this state machine is to manage this published data.
published *atomic.Pointer[peerRoute]
staged peerRoute // Local copy of shared data. See publish().
// Immutable data.
remoteIP byte // Remote VPN IP.
// Mutable peer data.
peer *m.Peer
remotePub bool
// Incoming events.
peerUpdates chan *m.Peer
controlPackets chan controlPacket
// Buffers for sending control packets.
buf1 []byte
buf2 []byte
}
func newPeerSupervisor(i int) *peerSupervisor {
return &peerSupervisor{
published: routingTable[i],
remoteIP: byte(i),
peerUpdates: peerUpdates[i],
controlPackets: controlPackets[i],
buf1: make([]byte, bufferSize),
buf2: make([]byte, bufferSize),
}
}
type stateFunc func() stateFunc
func (s *peerSupervisor) Run() {
state := s.noPeer
for {
state = state()
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
_sendControlPacket(pkt, s.staged, s.buf1, s.buf2)
time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
func (s *peerSupervisor) sendControlPacketTo(
pkt interface{ Marshal([]byte) []byte },
addr netip.AddrPort,
) {
if !addr.IsValid() {
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
return
}
route := s.staged
route.Direct = true
route.RemoteAddr = addr
_sendControlPacket(pkt, route, s.buf1, s.buf2)
time.Sleep(500 * time.Millisecond) // Rate limit packets.
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) logf(msg string, args ...any) {
log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) publish() {
data := s.staged
s.published.Store(&data)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) noPeer() stateFunc {
return s.peerUpdate(<-s.peerUpdates)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc {
return func() stateFunc { return s._peerUpdate(peer) }
}
func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc {
defer s.publish()
s.peer = peer
s.staged = peerRoute{}
if s.peer == nil {
return s.noPeer
}
s.staged.IP = s.remoteIP
s.staged.ControlCipher = newControlCipher(privateKey, peer.PubKey)
s.staged.DataCipher = newDataCipher()
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true
s.staged.Relay = peer.Relay
s.staged.Direct = true
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
} else if localPub {
s.staged.Direct = true
}
if s.remotePub == localPub {
if localIP < s.remoteIP {
return s.server
}
return s.client
}
if s.remotePub {
return s.client
}
return s.server
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) server() stateFunc {
logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) }
logf("DOWN")
var (
syn synPacket
timeoutTimer = time.NewTimer(timeoutInterval)
)
for {
select {
case peer := <-s.peerUpdates:
return s.peerUpdate(peer)
case pkt := <-s.controlPackets:
switch p := pkt.Payload.(type) {
case synPacket:
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
//
// The client will update the syn's TraceID whenever there's a change.
// The server will follow the client's request.
if p.TraceID != syn.TraceID || !s.staged.Up {
if p.Direct {
logf("UP - Direct")
} else {
logf("UP - Relayed")
}
syn = p
s.staged.Up = true
s.staged.Direct = syn.Direct
s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey)
s.staged.RemoteAddr = pkt.SrcAddr
s.publish()
}
// We should always respond.
ack := synAckPacket{
TraceID: syn.TraceID,
FromAddr: getLocalAddr(),
}
s.sendControlPacket(ack)
if s.staged.Direct {
continue
}
if !syn.FromAddr.IsValid() {
continue
}
probe := probePacket{TraceID: newTraceID()}
s.sendControlPacketTo(probe, syn.FromAddr)
case probePacket:
if pkt.SrcAddr.IsValid() {
s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr)
} else {
logf("Invalid probe address")
}
}
case <-timeoutTimer.C:
logf("Connection timeout")
s.staged.Up = false
s.publish()
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) client() stateFunc {
logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) }
logf("DOWN")
var (
syn = synPacket{
TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(),
Direct: s.staged.Direct,
FromAddr: getLocalAddr(),
}
ack synAckPacket
probe probePacket
probeAddr netip.AddrPort
lAddr netip.AddrPort
timeoutTimer = time.NewTimer(timeoutInterval)
pingTimer = time.NewTimer(pingInterval)
)
defer timeoutTimer.Stop()
defer pingTimer.Stop()
s.sendControlPacket(syn)
for {
select {
case peer := <-s.peerUpdates:
return s.peerUpdate(peer)
case pkt := <-s.controlPackets:
switch p := pkt.Payload.(type) {
case synAckPacket:
if p.TraceID != syn.TraceID {
continue // Hmm...
}
ack = p
timeoutTimer.Reset(timeoutInterval)
if !s.staged.Up {
if s.staged.Direct {
logf("UP - Direct")
} else {
logf("UP - Relayed")
}
s.staged.Up = true
s.publish()
}
case probePacket:
if s.staged.Direct {
continue
}
if p.TraceID != probe.TraceID {
continue
}
// Upgrade connection.
logf("UP - Direct")
s.staged.Direct = true
s.staged.RemoteAddr = probeAddr
s.publish()
syn.TraceID = newTraceID()
syn.Direct = true
syn.FromAddr = getLocalAddr()
s.sendControlPacket(syn)
}
case <-pingTimer.C:
// Send syn.
syn.FromAddr = getLocalAddr()
if syn.FromAddr != lAddr {
syn.TraceID = newTraceID()
lAddr = syn.FromAddr
}
s.sendControlPacket(syn)
pingTimer.Reset(pingInterval)
if s.staged.Direct {
continue
}
if !ack.FromAddr.IsValid() {
continue
}
probe = probePacket{TraceID: newTraceID()}
probeAddr = ack.FromAddr
s.sendControlPacketTo(probe, ack.FromAddr)
case <-timeoutTimer.C:
logf("Connection timeout")
return s.peerUpdate(s.peer)
}
}
}

View File

@@ -1 +0,0 @@
package node

View File

@@ -1 +0,0 @@
package node

View File

@@ -1,327 +0,0 @@
package node
import (
"fmt"
"log"
"net/netip"
"time"
"vppn/m"
)
const (
connectTimeout = 6 * time.Second
pingInterval = 6 * time.Second
timeoutInterval = 20 * time.Second
)
type routingPacketWrapper struct {
routingPacket
Addr netip.AddrPort // Source.
}
type peerSupervisor struct {
// Constants:
localIP byte
localPublic bool
remoteIP byte
privKey []byte
// Shared data:
w *connWriter
table *routingTable
packets chan routingPacketWrapper
peerUpdates chan *m.Peer
// Peer-related items.
version int64 // Ony accessed in HandlePeerUpdate.
peer *m.Peer
remoteAddrPort *netip.AddrPort
mediated bool
sharedKey []byte
// Used by our state functions.
pingTimer *time.Timer
timeoutTimer *time.Timer
buf []byte
}
// ----------------------------------------------------------------------------
func newPeerSupervisor(
conf m.PeerConfig,
remoteIP byte,
w *connWriter,
table *routingTable,
) *peerSupervisor {
s := &peerSupervisor{
localIP: conf.PeerIP,
remoteIP: remoteIP,
privKey: conf.EncPrivKey,
w: w,
table: table,
packets: make(chan routingPacketWrapper, 256),
peerUpdates: make(chan *m.Peer, 1),
pingTimer: time.NewTimer(pingInterval),
timeoutTimer: time.NewTimer(timeoutInterval),
buf: make([]byte, bufferSize),
}
_, s.localPublic = netip.AddrFromSlice(conf.PublicIP)
go s.mainLoop()
return s
}
func (s *peerSupervisor) logf(msg string, args ...any) {
msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg
log.Printf(msg, args...)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) mainLoop() {
defer panicHandler()
state := s.stateInit
for {
state = state()
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) {
if p != nil {
if p.Version == s.version {
return
}
s.version = p.Version
} else {
s.version = 0
}
s.peerUpdates <- p
}
func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) {
select {
case s.packets <- w:
default:
// Drop
}
}
// ----------------------------------------------------------------------------
type stateFunc func() stateFunc
func (s *peerSupervisor) stateInit() stateFunc {
if s.peer == nil {
return s.stateDisconnected
}
addr, ok := netip.AddrFromSlice(s.peer.PublicIP)
if ok {
addrPort := netip.AddrPortFrom(addr, s.peer.Port)
s.remoteAddrPort = &addrPort
} else {
s.remoteAddrPort = nil
}
s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey)
return s.stateSelectRole()
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateDisconnected() stateFunc {
s.clearRoutingTable()
for {
select {
case <-s.packets:
// Drop
case s.peer = <-s.peerUpdates:
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateSelectRole() stateFunc {
s.logf("STATE: SelectRole")
s.updateRoutingTable(false)
if s.remoteAddrPort != nil {
s.mediated = false
// If both remote and local are public, one side acts as client, and one
// side as server.
if s.localPublic && s.localIP < s.peer.PeerIP {
return s.stateAccept
}
return s.stateDial
}
// We're public, remote is not => can only wait for connection
if s.localPublic {
s.mediated = false
return s.stateAccept
}
// Both non-public => need to use mediator.
return s.stateMediated
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateAccept() stateFunc {
s.logf("STATE: Accept")
for {
select {
case pkt := <-s.packets:
switch pkt.Type {
case packetTypePing:
s.remoteAddrPort = &pkt.Addr
s.updateRoutingTable(true)
s.sendPong(pkt.TraceID)
return s.stateConnected
default:
// Still waiting for ping...
}
case s.peer = <-s.peerUpdates:
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateDial() stateFunc {
s.logf("STATE: Dial")
s.updateRoutingTable(false)
s.sendPing()
for {
select {
case pkt := <-s.packets:
switch pkt.Type {
case packetTypePong:
s.updateRoutingTable(true)
return s.stateConnected
default:
// Ignore
}
case <-s.pingTimer.C:
s.sendPing()
case s.peer = <-s.peerUpdates:
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateConnected() stateFunc {
s.logf("STATE: Connected")
s.timeoutTimer.Reset(timeoutInterval)
for {
select {
case <-s.pingTimer.C:
s.sendPing()
case <-s.timeoutTimer.C:
s.logf("Timeout")
return s.stateInit
case pkt := <-s.packets:
switch pkt.Type {
case packetTypePing:
s.sendPong(pkt.TraceID)
// Server should always follow remote port.
if s.localPublic {
if pkt.Addr != *s.remoteAddrPort {
s.remoteAddrPort = &pkt.Addr
s.updateRoutingTable(true)
}
}
case packetTypePong:
s.timeoutTimer.Reset(timeoutInterval)
default:
// Drop packet.
}
case s.peer = <-s.peerUpdates:
s.logf("New peer: %v", s.peer)
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateMediated() stateFunc {
s.logf("STATE: Mediated")
s.mediated = true
s.updateRoutingTable(true)
for {
select {
case <-s.packets:
// Drop.
case s.peer = <-s.peerUpdates:
s.logf("New peer: %v", s.peer)
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) clearRoutingTable() {
s.table.Set(s.remoteIP, nil)
}
func (s *peerSupervisor) updateRoutingTable(up bool) {
s.table.Set(s.remoteIP, &peer{
Up: up,
Mediator: s.peer.Mediator,
Mediated: s.mediated,
IP: s.remoteIP,
Addr: s.remoteAddrPort,
SharedKey: s.sharedKey,
})
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) sendPing() uint64 {
traceID := newTraceID()
pkt := newRoutingPacket(packetTypePing, traceID)
s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))
s.pingTimer.Reset(pingInterval)
return traceID
}
func (s *peerSupervisor) sendPong(traceID uint64) {
pkt := newRoutingPacket(packetTypePong, traceID)
s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))
}

40
node/relaymanager.go Normal file
View File

@@ -0,0 +1,40 @@
package node
import (
"log"
"math/rand"
"time"
)
func relayManager() {
time.Sleep(2 * time.Second)
updateRelayRoute()
for range time.Tick(8 * time.Second) {
relay := getRelayRoute()
if relay == nil || !relay.Up || !relay.Relay {
updateRelayRoute()
}
}
}
func updateRelayRoute() {
possible := make([]*peerRoute, 0, 8)
for i := range routingTable {
route := routingTable[i].Load()
if !route.Up || !route.Relay {
continue
}
possible = append(possible, route)
}
if len(possible) == 0 {
log.Printf("No relay available.")
relayIP.Store(nil)
return
}
ip := possible[rand.Intn(len(possible))].IP
log.Printf("New relay IP: %d", ip)
relayIP.Store(&ip)
}

View File

@@ -1,186 +0,0 @@
package node
import (
"encoding/json"
"io"
"log"
"net/http"
"net/netip"
"net/url"
"sync/atomic"
"time"
"vppn/m"
)
type peer struct {
Up bool // No data will be sent to peers that are down.
Mediator bool
Mediated bool
IP byte
Addr *netip.AddrPort // If we have direct connection, otherwise use mediator.
SharedKey []byte
}
// ----------------------------------------------------------------------------
type routingTable struct {
table [256]*atomic.Pointer[peer]
mediator *atomic.Pointer[peer]
}
func newRoutingTable() *routingTable {
r := routingTable{
mediator: &atomic.Pointer[peer]{},
}
for i := range r.table {
r.table[i] = &atomic.Pointer[peer]{}
}
return &r
}
func (r *routingTable) Get(ip byte) *peer {
return r.table[ip].Load()
}
func (r *routingTable) Set(ip byte, p *peer) {
r.table[ip].Store(p)
}
// ----------------------------------------------------------------------------
type router struct {
*routingTable
netName string
peerSupers [256]*peerSupervisor
}
func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w *connWriter) *router {
r := &router{
netName: netName,
routingTable: routingData,
}
for i := range r.peerSupers {
r.peerSupers[i] = newPeerSupervisor(
conf,
byte(i),
w,
r.routingTable)
}
go r.selectMediator()
go r.pollHub(conf)
return r
}
// ----------------------------------------------------------------------------
func (r *router) HandlePacket(sourceIP byte, remoteAddr netip.AddrPort, data []byte) {
p := routingPacket{}
if err := p.Parse(data); err != nil {
log.Printf("Dropping malformed routing packet: %v", err)
return
}
w := routingPacketWrapper{
routingPacket: p,
Addr: remoteAddr,
}
r.peerSupers[sourceIP].HandlePacket(w)
}
// ----------------------------------------------------------------------------
func (r *router) pollHub(conf m.PeerConfig) {
defer panicHandler()
u, err := url.Parse(conf.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
}
u.Path = "/peer/fetch-state/"
client := &http.Client{Timeout: 8 * time.Second}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
req.SetBasicAuth("", conf.APIKey)
state, err := loadNetworkState(r.netName)
if err != nil {
log.Printf("Failed to load network state: %v", err)
log.Printf("Polling hub...")
r._pollHub(conf, client, req)
} else {
r.applyNetworkState(conf, state)
}
for range time.Tick(64 * time.Second) {
r._pollHub(conf, client, req)
}
}
func (r *router) _pollHub(conf m.PeerConfig, client *http.Client, req *http.Request) {
var state m.NetworkState
log.Printf("Fetching peer state from %s...", conf.HubAddress)
resp, err := client.Do(req)
if err != nil {
log.Printf("Failed to fetch peer state: %v", err)
return
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
log.Printf("Failed to read body from hub: %v", err)
return
}
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("Failed to unmarshal response from hub: %v", err)
return
}
r.applyNetworkState(conf, state)
if err := storeNetworkState(r.netName, state); err != nil {
log.Printf("Failed to store network state: %v", err)
}
}
func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) {
for i := range state.Peers {
if i != int(conf.PeerIP) {
r.peerSupers[i].HandlePeerUpdate(state.Peers[i])
}
}
}
// ----------------------------------------------------------------------------
func (r *router) selectMediator() {
for range time.Tick(8 * time.Second) {
current := r.mediator.Load()
if current != nil && current.Up {
continue
}
for i := range r.table {
peer := r.table[i].Load()
if peer != nil && peer.Up && peer.Mediator {
log.Printf("Got mediator: %v", *peer)
r.mediator.Store(peer)
return
}
}
r.mediator.Store(nil)
}
}

View File

@@ -1,44 +0,0 @@
package node
import (
"errors"
"unsafe"
)
var errMalformedPacket = errors.New("malformed packet")
const (
packetTypeInvalid = iota
// Used to maintain connection.
packetTypePing
packetTypePong
)
type routingPacket struct {
Type byte // One of the packetType* constants.
TraceID uint64 // For matching requests and responses.
}
func newRoutingPacket(reqType byte, traceID uint64) routingPacket {
return routingPacket{
Type: reqType,
TraceID: traceID,
}
}
func (p routingPacket) Marshal(buf []byte) []byte {
buf = buf[:32] // Reserve 32 bytes just in case we need to add anything.
buf[0] = p.Type
*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.TraceID)
return buf
}
func (p *routingPacket) Parse(buf []byte) error {
if len(buf) != 32 {
return errMalformedPacket
}
p.Type = buf[0]
p.TraceID = *(*uint64)(unsafe.Pointer(&buf[1]))
return nil
}

View File

@@ -1,185 +0,0 @@
package node
/*
var (
network = []byte{10, 1, 1, 0}
serverIP = byte(1)
clientIP = byte(2)
port = uint16(5151)
netName = "testnet"
pubKey1 = []byte{0x43, 0xde, 0xd4, 0xb2, 0x1d, 0x71, 0x58, 0x9a, 0x96, 0x3a, 0x23, 0xfc, 0x2, 0xe, 0xfa, 0x42, 0x3, 0x94, 0xbc, 0xf8, 0x25, 0xf, 0x54, 0xcc, 0x98, 0x42, 0x8b, 0xe5, 0x27, 0x86, 0x49, 0x33}
privKey1 = []byte{0xae, 0x4d, 0xc5, 0xaa, 0xc9, 0xbc, 0x65, 0x41, 0x55, 0xb, 0x61, 0x52, 0xc4, 0x6c, 0xce, 0x2f, 0x1b, 0xf5, 0xb3, 0xbf, 0xb5, 0x54, 0x61, 0x7c, 0x26, 0x2e, 0xba, 0x5a, 0x19, 0xe2, 0x9c, 0xe0}
pubKey2 = []byte{0x8c, 0xfe, 0x12, 0xd9, 0x2d, 0x37, 0x5, 0x43, 0xab, 0x70, 0x59, 0x20, 0x3d, 0x82, 0x93, 0x9b, 0xb3, 0xaa, 0x35, 0x23, 0xc1, 0xb4, 0x4, 0x1f, 0x92, 0x97, 0x6f, 0xfd, 0x55, 0x17, 0x5a, 0x4b}
privKey2 = []byte{0xd9, 0xe1, 0xc6, 0x64, 0x3e, 0x29, 0x29, 0x78, 0x81, 0x53, 0xc2, 0x31, 0xd9, 0x34, 0x5b, 0x41, 0xf5, 0x80, 0xb0, 0x27, 0x9f, 0x65, 0x85, 0xd4, 0x78, 0xd5, 0x9, 0x2, 0xca, 0x56, 0x42, 0x80}
)
func must(err error) {
if err != nil {
panic(err)
}
}
type TmpNode struct {
network []byte
localIP byte
router *router
port uint16
netName string
iface io.ReadWriteCloser
pubKey []byte
privKey []byte
w *connWriter
r *connReader
}
// ----------------------------------------------------------------------------
func NewTmpNodeServer() *TmpNode {
n := &TmpNode{
localIP: serverIP,
network: network,
router: &router{table: newPeerRepo()},
port: port,
netName: netName,
pubKey: pubKey1,
privKey: privKey1,
}
var err error
n.iface, err = openInterface(n.network, n.localIP, n.netName)
must(err)
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port))
must(err)
conn, err := net.ListenUDP("udp", myAddr)
must(err)
n.w = newConnWriter(conn, n.localIP, n.router)
n.r = newConnReader(conn, n.localIP, n.router)
n.router.table.Set(clientIP, &peer{
IP: clientIP,
SharedKey: computeSharedKey(pubKey2, n.privKey),
})
return n
}
// ----------------------------------------------------------------------------
func NewTmpNodeClient(srvAddrStr string) *TmpNode {
n := &TmpNode{
localIP: clientIP,
network: network,
router: &router{table: newPeerRepo()},
port: port,
netName: netName,
pubKey: pubKey2,
privKey: privKey2,
}
var err error
n.iface, err = openInterface(n.network, n.localIP, n.netName)
must(err)
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port))
must(err)
conn, err := net.ListenUDP("udp", myAddr)
must(err)
n.w = newConnWriter(conn, n.localIP, n.router)
n.r = newConnReader(conn, n.localIP, n.router)
serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port))
must(err)
n.router.table.Set(serverIP, &peer{
IP: serverIP,
Addr: &serverAddr,
SharedKey: computeSharedKey(pubKey1, n.privKey),
})
return n
}
// ----------------------------------------------------------------------------
func (n *TmpNode) RunServer() {
defer func() {
if r := recover(); r != nil {
fmt.Printf("%v", r)
debug.PrintStack()
}
}()
// Get remoteAddr from a packet.
buf := make([]byte, bufferSize)
remoteAddr, h, _, err := n.r.Read(buf)
must(err)
log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr)
must(err)
n.router.table.Set(h.SourceIP, &peer{
IP: h.SourceIP,
Addr: &remoteAddr,
SharedKey: computeSharedKey(pubKey2, n.privKey),
})
go n.readFromIFace()
n.readFromConn()
}
// ----------------------------------------------------------------------------
func (n *TmpNode) RunClient() {
defer func() {
if r := recover(); r != nil {
fmt.Printf("%v\n", r)
debug.PrintStack()
}
}()
log.Printf("Sending to server...")
must(n.w.WriteTo(serverIP, 1, []byte{1, 2, 3, 4, 5, 6, 7, 8}))
go n.readFromIFace()
n.readFromConn()
}
func (n *TmpNode) readFromIFace() {
var (
buf = make([]byte, bufferSize)
packet []byte
remoteIP byte
err error
)
for {
packet, remoteIP, err = readNextPacket(n.iface, buf)
must(err)
must(n.w.WriteTo(remoteIP, 1, packet))
}
}
func (node *TmpNode) readFromConn() {
var (
buf = make([]byte, bufferSize)
packet []byte
err error
)
for {
_, _, packet, err = node.r.Read(buf)
must(err)
// We assume that we're only receiving packets from one source.
_, err = node.iface.Write(packet)
if err != nil {
log.Printf("Got error: %v", err)
}
//must(err)
}
}
*/