diff --git a/LICENSE b/LICENSE index 078df32..042a386 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 app +Copyright (c) 2024 John David Lee (johndavidlee@crumpington.com) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/README.md b/README.md index 3aa4d04..87e3072 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,8 @@ ## Roadmap -* Node: use symmetric encryption after handshake -* AEAD-AES uses a 12 byte nonce. We need to shrink the header: - * Remove Forward and replace it with a HeaderFlags bitfield. - * Forward, Asym/Sym, ... +* Use probe and relayed-probe packets vs ping/pong. +* Rename Mediator -> Relay * Use default port 456 * Remove signing key from hub * Peer: UDP hole-punching diff --git a/fasttime/time.go b/fasttime/time.go deleted file mode 100644 index 5c569ac..0000000 --- a/fasttime/time.go +++ /dev/null @@ -1,20 +0,0 @@ -package fasttime - -import ( - "sync/atomic" - "time" -) - -var _timestamp int64 = time.Now().Unix() - -func init() { - go func() { - for range time.Tick(1100 * time.Millisecond) { - atomic.StoreInt64(&_timestamp, time.Now().Unix()) - } - }() -} - -func Now() int64 { - return atomic.LoadInt64(&_timestamp) -} diff --git a/fasttime/time_test.go b/fasttime/time_test.go deleted file mode 100644 index b0a85d0..0000000 --- a/fasttime/time_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package fasttime - -import ( - "testing" - "time" -) - -func BenchmarkNow(b *testing.B) { - for i := 0; i < b.N; i++ { - Now() - } -} - -func BenchmarkTimeUnix(b *testing.B) { - for i := 0; i < b.N; i++ { - time.Now().Unix() - } -} diff --git a/hub/api/api.go b/hub/api/api.go index 053c574..975149d 100644 --- a/hub/api/api.go +++ b/hub/api/api.go @@ -15,7 +15,6 @@ import ( "git.crumpington.com/lib/go/sqliteutil" "golang.org/x/crypto/bcrypt" "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" ) //go:embed migrations @@ -146,7 +145,7 @@ type PeerCreateArgs struct { Name string PublicIP []byte Port uint16 - Mediator bool + Relay bool } // Create the intention to add a peer. The returned code is used to complete @@ -184,11 +183,6 @@ func (a *API) Peer_Create(creationCode string) (*m.PeerConfig, error) { return nil, err } - signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - return nil, err - } - // Get peer IP. peerIP := byte(0) @@ -208,15 +202,14 @@ func (a *API) Peer_Create(creationCode string) (*m.PeerConfig, error) { } peer := &Peer{ - PeerIP: peerIP, - Version: idgen.NextID(0), - APIKey: idgen.NewToken(), - Name: args.Name, - PublicIP: args.PublicIP, - Port: args.Port, - Mediator: args.Mediator, - EncPubKey: encPubKey[:], - SignPubKey: signPubKey[:], + PeerIP: peerIP, + Version: idgen.NextID(0), + APIKey: idgen.NewToken(), + Name: args.Name, + PublicIP: args.PublicIP, + Port: args.Port, + Relay: args.Relay, + PubKey: encPubKey[:], } if err := db.Peer_Insert(a.db, peer); err != nil { @@ -226,17 +219,15 @@ func (a *API) Peer_Create(creationCode string) (*m.PeerConfig, error) { conf := a.Config_Get() return &m.PeerConfig{ - PeerIP: peer.PeerIP, - HubAddress: conf.HubAddress, - APIKey: peer.APIKey, - Network: conf.VPNNetwork, - PublicIP: peer.PublicIP, - Port: peer.Port, - Mediator: peer.Mediator, - EncPubKey: encPubKey[:], - EncPrivKey: encPrivKey[:], - SignPubKey: signPubKey[:], - SignPrivKey: signPrivKey[:], + PeerIP: peer.PeerIP, + HubAddress: conf.HubAddress, + APIKey: peer.APIKey, + Network: conf.VPNNetwork, + PublicIP: peer.PublicIP, + Port: peer.Port, + Relay: peer.Relay, + PubKey: encPubKey[:], + PrivKey: encPrivKey[:], }, nil } diff --git a/hub/api/db/generated.go b/hub/api/db/generated.go index a23498d..1957b6f 100644 --- a/hub/api/db/generated.go +++ b/hub/api/db/generated.go @@ -307,18 +307,17 @@ func Session_List( // ---------------------------------------------------------------------------- type Peer struct { - PeerIP byte - Version int64 - APIKey string - Name string - PublicIP []byte - Port uint16 - Mediator bool - EncPubKey []byte - SignPubKey []byte + PeerIP byte + Version int64 + APIKey string + Name string + PublicIP []byte + Port uint16 + Relay bool + PubKey []byte } -const Peer_SelectQuery = "SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Mediator,EncPubKey,SignPubKey FROM peers" +const Peer_SelectQuery = "SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey FROM peers" func Peer_Insert( tx TX, @@ -329,7 +328,7 @@ func Peer_Insert( return err } - _, err = tx.Exec("INSERT INTO peers(PeerIP,Version,APIKey,Name,PublicIP,Port,Mediator,EncPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?)", row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Mediator, row.EncPubKey, row.SignPubKey) + _, err = tx.Exec("INSERT INTO peers(PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey) VALUES(?,?,?,?,?,?,?,?)", row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey) return err } @@ -342,7 +341,7 @@ func Peer_Update( return err } - result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Mediator=? WHERE PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Mediator, row.PeerIP) + result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.PeerIP) if err != nil { return err } @@ -370,7 +369,7 @@ func Peer_UpdateFull( return err } - result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Mediator=?,EncPubKey=?,SignPubKey=? WHERE PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Mediator, row.EncPubKey, row.SignPubKey, row.PeerIP) + result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=? WHERE PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PeerIP) if err != nil { return err } @@ -420,8 +419,8 @@ func Peer_Get( err error, ) { row = &Peer{} - r := tx.QueryRow("SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Mediator,EncPubKey,SignPubKey FROM peers WHERE PeerIP=?", PeerIP) - err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Mediator, &row.EncPubKey, &row.SignPubKey) + r := tx.QueryRow("SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey FROM peers WHERE PeerIP=?", PeerIP) + err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey) return } @@ -435,7 +434,7 @@ func Peer_GetWhere( ) { row = &Peer{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Mediator, &row.EncPubKey, &row.SignPubKey) + err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey) return } @@ -455,7 +454,7 @@ func Peer_Iterate( defer rows.Close() for rows.Next() { row := &Peer{} - err := rows.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Mediator, &row.EncPubKey, &row.SignPubKey) + err := rows.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey) if !yield(row, err) { return } diff --git a/hub/api/db/sanitize-validate.go b/hub/api/db/sanitize-validate.go index b4ed8ff..e06ad94 100644 --- a/hub/api/db/sanitize-validate.go +++ b/hub/api/db/sanitize-validate.go @@ -51,7 +51,7 @@ func Peer_Sanitize(p *Peer) { } } if p.Port == 0 { - p.Port = 515 + p.Port = 456 } } diff --git a/hub/api/db/tables.defs b/hub/api/db/tables.defs index c9e35e2..6df286f 100644 --- a/hub/api/db/tables.defs +++ b/hub/api/db/tables.defs @@ -20,7 +20,6 @@ TABLE peers OF Peer ( Name string, PublicIP []byte, Port uint16, - Mediator bool, - EncPubKey []byte NoUpdate, - SignPubKey []byte NoUpdate + Relay bool, + PubKey []byte NoUpdate ); diff --git a/hub/api/db/written.go b/hub/api/db/written.go index 65769c4..5b8bb15 100644 --- a/hub/api/db/written.go +++ b/hub/api/db/written.go @@ -1,12 +1,12 @@ package db -import "vppn/fasttime" +import "time" func Session_UpdateLastSeenAt( tx TX, id string, ) (err error) { - _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", fasttime.Now(), id) + _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id) return err } diff --git a/hub/api/migrations/2024-11-30-init.sql b/hub/api/migrations/2024-11-30-init.sql index eb5da37..ee37ddc 100644 --- a/hub/api/migrations/2024-11-30-init.sql +++ b/hub/api/migrations/2024-11-30-init.sql @@ -22,7 +22,6 @@ CREATE TABLE peers ( Name TEXT NOT NULL UNIQUE, -- For humans. PublicIP BLOB NOT NULL, Port INTEGER NOT NULL, - Mediator INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address. - EncPubKey BLOB NOT NULL, - SignPubKey BLOB NOT NULL + Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address. + PubKey BLOB NOT NULL ) WITHOUT ROWID; diff --git a/hub/handlers.go b/hub/handlers.go index f24aaaa..aabf3c7 100644 --- a/hub/handlers.go +++ b/hub/handlers.go @@ -4,6 +4,8 @@ import ( "errors" "log" "net/http" + "net/netip" + "strings" "vppn/hub/api" "vppn/m" @@ -155,6 +157,29 @@ func (a *App) _adminPeerList(s *api.Session, w http.ResponseWriter, r *http.Requ }) } +func (a *App) _adminHosts(s *api.Session, w http.ResponseWriter, r *http.Request) error { + conf := a.api.Config_Get() + + peers, err := a.api.Peer_List() + if err != nil { + return err + } + + b := strings.Builder{} + + for _, peer := range peers { + ip := conf.VPNNetwork + ip[3] = peer.PeerIP + b.WriteString(netip.AddrFrom4([4]byte(ip)).String()) + b.WriteString(" ") + b.WriteString(peer.Name) + b.WriteString("\n") + } + + w.Write([]byte(b.String())) + return nil +} + func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Request) error { return a.render("/admin-peer-create.html", w, struct{ Session *api.Session }{s}) } @@ -167,7 +192,7 @@ func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *h Scan("Name", &args.Name). Scan("PublicIP", &ipStr). Scan("Port", &args.Port). - Scan("Mediator", &args.Mediator). + Scan("Relay", &args.Relay). Error() if err != nil { return err @@ -249,7 +274,7 @@ func (a *App) _adminPeerEditSubmit(s *api.Session, w http.ResponseWriter, r *htt Scan("Name", &peer.Name). Scan("PublicIP", &ipStr). Scan("Port", &peer.Port). - Scan("Mediator", &peer.Mediator). + Scan("Relay", &peer.Relay). Error() if err != nil { return err @@ -311,19 +336,16 @@ func (a *App) _peerCreate(w http.ResponseWriter, r *http.Request) error { func (a *App) _peerFetchState(w http.ResponseWriter, r *http.Request) error { _, apiKey, ok := r.BasicAuth() if !ok { - log.Printf("1") return api.ErrNotAuthorized } peer, err := a.api.Peer_GetByAPIKey(apiKey) if err != nil { - log.Printf("2") return err } peers, err := a.api.Peer_List() if err != nil { - log.Printf("3") return err } @@ -339,14 +361,13 @@ func (a *App) _peerFetchState(w http.ResponseWriter, r *http.Request) error { for _, p := range peers { state.Peers[p.PeerIP] = &m.Peer{ - PeerIP: p.PeerIP, - Version: p.Version, - Name: p.Name, - PublicIP: p.PublicIP, - Port: p.Port, - Mediator: p.Mediator, - EncPubKey: p.EncPubKey, - SignPubKey: p.SignPubKey, + PeerIP: p.PeerIP, + Version: p.Version, + Name: p.Name, + PublicIP: p.PublicIP, + Port: p.Port, + Relay: p.Relay, + PubKey: p.PubKey, } } diff --git a/hub/routes.go b/hub/routes.go index 0fa47f2..a29736f 100644 --- a/hub/routes.go +++ b/hub/routes.go @@ -17,6 +17,7 @@ func (a *App) registerRoutes() { a.handleSignedIn("GET /admin/password/edit/", a._adminPasswordEdit) a.handleSignedIn("POST /admin/password/edit/", a._adminPasswordSubmit) a.handleSignedIn("GET /admin/peer/list/", a._adminPeerList) + a.handleSignedIn("GET /admin/peer/hosts/", a._adminHosts) a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate) a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit) a.handleSignedIn("GET /admin/peer/intent-created/", a._adminPeerIntentCreated) diff --git a/hub/templates/admin-peer-create.html b/hub/templates/admin-peer-create.html index f2f0c39..8225fc8 100644 --- a/hub/templates/admin-peer-create.html +++ b/hub/templates/admin-peer-create.html @@ -13,12 +13,12 @@


- +

diff --git a/hub/templates/admin-peer-delete.html b/hub/templates/admin-peer-delete.html index a330eb8..9290f68 100644 --- a/hub/templates/admin-peer-delete.html +++ b/hub/templates/admin-peer-delete.html @@ -22,8 +22,8 @@

diff --git a/hub/templates/admin-peer-edit.html b/hub/templates/admin-peer-edit.html index c6081b1..da40de8 100644 --- a/hub/templates/admin-peer-edit.html +++ b/hub/templates/admin-peer-edit.html @@ -22,8 +22,8 @@

diff --git a/hub/templates/admin-peer-list.html b/hub/templates/admin-peer-list.html index 4acadc7..cb7c72c 100644 --- a/hub/templates/admin-peer-list.html +++ b/hub/templates/admin-peer-list.html @@ -2,7 +2,8 @@

Peers

- Add Peer + Add Peer / + Hosts

{{if .Peers -}} @@ -13,7 +14,7 @@ Name Public IP Port - Mediator + Relay @@ -27,7 +28,7 @@ {{.Name}} {{ipToString .PublicIP}} {{.Port}} - {{if .Mediator}}T{{else}}F{{end}} + {{if .Relay}}T{{else}}F{{end}} {{- end}} diff --git a/hub/templates/admin-peer-view.html b/hub/templates/admin-peer-view.html index 89ff754..e8d6f6e 100644 --- a/hub/templates/admin-peer-view.html +++ b/hub/templates/admin-peer-view.html @@ -12,7 +12,7 @@ Name{{.Name}} Public IP{{ipToString .PublicIP}} Port{{.Port}} - Mediator{{if .Mediator}}T{{else}}F{{end}} + Relay{{if .Relay}}T{{else}}F{{end}} API Key{{.APIKey}} {{- end}} diff --git a/m/models.go b/m/models.go index 29c39f9..345bf5d 100644 --- a/m/models.go +++ b/m/models.go @@ -2,28 +2,25 @@ package m type PeerConfig struct { - PeerIP byte - HubAddress string - Network []byte - APIKey string - PublicIP []byte - Port uint16 - Mediator bool - EncPubKey []byte - EncPrivKey []byte - SignPubKey []byte - SignPrivKey []byte + PeerIP byte + HubAddress string + Network []byte + APIKey string + PublicIP []byte + Port uint16 + Relay bool + PubKey []byte + PrivKey []byte } type Peer struct { - PeerIP byte - Version int64 - Name string - PublicIP []byte - Port uint16 - Mediator bool - EncPubKey []byte - SignPubKey []byte + PeerIP byte + Version int64 + Name string + PublicIP []byte + Port uint16 + Relay bool + PubKey []byte } type NetworkState struct { diff --git a/node/addrdiscovery.go b/node/addrdiscovery.go new file mode 100644 index 0000000..b62e13f --- /dev/null +++ b/node/addrdiscovery.go @@ -0,0 +1,71 @@ +package node + +import ( + "log" + "net/netip" + "time" +) + +func addrDiscoveryServer() { + var ( + buf1 = make([]byte, bufferSize) + buf2 = make([]byte, bufferSize) + ) + + for { + pkt := <-discoveryPackets + + p, ok := pkt.Payload.(addrDiscoveryPacket) + if !ok { + continue + } + + route := routingTable[pkt.SrcIP].Load() + if route == nil || !route.RemoteAddr.IsValid() { + continue + } + + _sendControlPacket(addrDiscoveryPacket{ + TraceID: p.TraceID, + ToAddr: pkt.SrcAddr, + }, *route, buf1, buf2) + } +} + +func addrDiscoveryClient() { + var ( + checkInterval = 8 * time.Second + timer = time.NewTimer(4 * time.Second) + + buf1 = make([]byte, bufferSize) + buf2 = make([]byte, bufferSize) + + addrPacket addrDiscoveryPacket + lAddr netip.AddrPort + ) + + for { + select { + case pkt := <-discoveryPackets: + p, ok := pkt.Payload.(addrDiscoveryPacket) + if !ok || p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr { + continue + } + + log.Printf("Discovered local address: %v", p.ToAddr) + lAddr = p.ToAddr + localAddr.Store(&p.ToAddr) + + case <-timer.C: + timer.Reset(checkInterval) + + route := getRelayRoute() + if route == nil { + continue + } + + addrPacket.TraceID = newTraceID() + _sendControlPacket(addrPacket, *route, buf1, buf2) + } + } +} diff --git a/node/cipher-control.go b/node/cipher-control.go new file mode 100644 index 0000000..bd11470 --- /dev/null +++ b/node/cipher-control.go @@ -0,0 +1,26 @@ +package node + +import "golang.org/x/crypto/nacl/box" + +type controlCipher struct { + sharedKey [32]byte +} + +func newControlCipher(privKey, pubKey []byte) *controlCipher { + shared := [32]byte{} + box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) + return &controlCipher{shared} +} + +func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte { + const s = controlHeaderSize + out = out[:s+controlCipherOverhead+len(data)] + h.Marshal(out[:s]) + box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) + return out +} + +func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = controlHeaderSize + return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) +} diff --git a/node/cipher-control_test.go b/node/cipher-control_test.go new file mode 100644 index 0000000..ab28860 --- /dev/null +++ b/node/cipher-control_test.go @@ -0,0 +1,122 @@ +package node + +import ( + "bytes" + "crypto/rand" + "reflect" + "testing" + + "golang.org/x/crypto/nacl/box" +) + +func newControlCipherForTesting() (c1, c2 *controlCipher) { + pubKey1, privKey1, err := box.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + pubKey2, privKey2, err := box.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + return newControlCipher(privKey1[:], pubKey2[:]), + newControlCipher(privKey2[:], pubKey1[:]) +} + +func TestControlCipher(t *testing.T) { + c1, c2 := newControlCipherForTesting() + + maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + StreamID: controlStreamID, + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + encrypted = c1.Encrypt(h1, plaintext, encrypted) + + h2 := header{} + h2.Parse(encrypted) + if !reflect.DeepEqual(h1, h2) { + t.Fatal(h1, h2) + } + + decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Fatal("not equal") + } + } +} + +func TestControlCipher_ShortCiphertext(t *testing.T) { + c1, _ := newControlCipherForTesting() + shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) + rand.Read(shortText) + _, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkControlCipher_Encrypt(b *testing.B) { + c1, _ := newControlCipherForTesting() + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = c1.Encrypt(h1, plaintext, encrypted) + } +} + +func BenchmarkControlCipher_Decrypt(b *testing.B) { + c1, c2 := newControlCipherForTesting() + + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + encrypted = c1.Encrypt(h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _ = c2.Decrypt(encrypted, decrypted) + } +} diff --git a/node/cipher-data.go b/node/cipher-data.go new file mode 100644 index 0000000..7cdc0d5 --- /dev/null +++ b/node/cipher-data.go @@ -0,0 +1,62 @@ +package node + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" +) + +// TODO: Use [32]byte for simplicity everywhere. +type dataCipher struct { + key [32]byte + aead cipher.AEAD +} + +func newDataCipher() *dataCipher { + key := [32]byte{} + if _, err := rand.Read(key[:]); err != nil { + panic(err) + } + return newDataCipherFromKey(key) +} + +// key must be 32 bytes. +func newDataCipherFromKey(key [32]byte) *dataCipher { + block, err := aes.NewCipher(key[:]) + if err != nil { + panic(err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + panic(err) + } + + return &dataCipher{key: key, aead: aead} +} + +func (sc *dataCipher) Key() [32]byte { + return sc.key +} + +func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte { + const s = dataHeaderSize + out = out[:s+dataCipherOverhead+len(data)] + h.Marshal(out[:s]) + sc.aead.Seal(out[s:s], out[:s], data, nil) + return out +} + +func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { + const s = dataHeaderSize + if len(encrypted) < s+dataCipherOverhead { + ok = false + return + } + + var err error + + data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) + ok = err == nil + return +} diff --git a/node/cipher-data_test.go b/node/cipher-data_test.go new file mode 100644 index 0000000..493c198 --- /dev/null +++ b/node/cipher-data_test.go @@ -0,0 +1,141 @@ +package node + +import ( + "bytes" + "crypto/rand" + mrand "math/rand/v2" + "reflect" + "testing" +) + +func TestDataCipher(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + StreamID: dataStreamID, + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + h2 := header{} + h2.Parse(encrypted) + + dc2 := newDataCipherFromKey(dc1.Key()) + + decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if !ok { + t.Fatal(ok) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("not equal") + } + + if !reflect.DeepEqual(h1, h2) { + t.Fatalf("%v != %v", h1, h2) + } + } +} + +func TestDataCipher_ModifyCiphertext(t *testing.T) { + maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(maxSizePlaintext) + + testCases := [][]byte{ + make([]byte, 0), + {1}, + {255}, + {1, 2, 3, 4, 5}, + []byte("Hello world"), + maxSizePlaintext, + } + + for _, plaintext := range testCases { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + encrypted[mrand.IntN(len(encrypted))]++ + + dc2 := newDataCipherFromKey(dc1.Key()) + + _, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) + if ok { + t.Fatal(ok) + } + } +} + +func TestDataCipher_ShortCiphertext(t *testing.T) { + dc1 := newDataCipher() + shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) + rand.Read(shortText) + _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) + if ok { + t.Fatal(ok) + } +} + +func BenchmarkDataCipher_Encrypt(b *testing.B) { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + b.ResetTimer() + for i := 0; i < b.N; i++ { + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + } +} + +func BenchmarkDataCipher_Decrypt(b *testing.B) { + h1 := header{ + Counter: 235153, + SourceIP: 4, + DestIP: 88, + } + + plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) + rand.Read(plaintext) + + encrypted := make([]byte, bufferSize) + + dc1 := newDataCipher() + encrypted = dc1.Encrypt(h1, plaintext, encrypted) + + decrypted := make([]byte, bufferSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + decrypted, _ = dc1.Decrypt(encrypted, decrypted) + } +} diff --git a/node/conn.go b/node/conn.go index 9224d57..7671f36 100644 --- a/node/conn.go +++ b/node/conn.go @@ -1,172 +1,49 @@ package node import ( + "io" "log" "net" "net/netip" + "runtime/debug" "sync" - "sync/atomic" - "vppn/fasttime" ) +// ---------------------------------------------------------------------------- + type connWriter struct { - *net.UDPConn - lock sync.Mutex - localIP byte - buf []byte - buf2 []byte - counters [256]uint64 - routing *routingTable + lock sync.Mutex + conn *net.UDPConn } -func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter { - w := &connWriter{ - UDPConn: conn, - localIP: localIP, - buf: make([]byte, bufferSize), - buf2: make([]byte, bufferSize), - routing: routing, - } - - for i := range w.counters { - w.counters[i] = uint64(fasttime.Now() << 30) - } - - return w +func newConnWriter(conn *net.UDPConn) *connWriter { + return &connWriter{conn: conn} } -func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) { - dstPeer := w.routing.Get(remoteIP) - if dstPeer == nil { - log.Printf("No peer: %d", remoteIP) - return - } - - if stream == streamData && !dstPeer.Up { - log.Printf("Peer down: %d", remoteIP) - return - } - - var viaPeer *peer - if dstPeer.Mediated { - viaPeer = w.routing.mediator.Load() - if viaPeer == nil || viaPeer.Addr == nil { - log.Printf("Mediator not connected") - return - } - } else if dstPeer.Addr == nil { - log.Printf("Peer doesn't have address: %d", remoteIP) - return - } - - w.WriteToPeer(dstPeer, viaPeer, stream, data) -} - -func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) { +func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) { w.lock.Lock() - - addr := dstPeer.Addr - - h := header{ - Counter: atomic.AddUint64(&w.counters[dstPeer.IP], 1), - SourceIP: w.localIP, - DestIP: dstPeer.IP, - Stream: stream, - } - - buf := encryptPacket(&h, dstPeer.SharedKey, data, w.buf) - - if viaPeer != nil { - h := header{ - Counter: atomic.AddUint64(&w.counters[viaPeer.IP], 1), - SourceIP: w.localIP, - DestIP: dstPeer.IP, - Forward: 1, - Stream: stream, - } - - buf = encryptPacket(&h, viaPeer.SharedKey, buf, w.buf2) - addr = viaPeer.Addr - } - - if _, err := w.WriteToUDPAddrPort(buf, *addr); err != nil { + if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { + debug.PrintStack() log.Fatalf("Failed to write to UDP port: %v", err) } w.lock.Unlock() } -func (w *connWriter) Forward(dstIP byte, packet []byte) { - dstPeer := w.routing.Get(dstIP) - if dstPeer == nil || dstPeer.Addr == nil { - log.Printf("No peer: %d", dstIP) - return - } - - if _, err := w.WriteToUDPAddrPort(packet, *dstPeer.Addr); err != nil { - log.Fatalf("Failed to write to UDP port: %v", err) - } -} - // ---------------------------------------------------------------------------- -type connReader struct { - *net.UDPConn - localIP byte - dupChecks [256]*dupCheck - routing *routingTable - buf []byte +type ifWriter struct { + lock sync.Mutex + iface io.ReadWriteCloser } -func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader { - r := &connReader{ - UDPConn: conn, - localIP: localIP, - routing: routing, - buf: make([]byte, bufferSize), - } - for i := range r.dupChecks { - r.dupChecks[i] = newDupCheck(0) - } - return r +func newIFWriter(iface io.ReadWriteCloser) *ifWriter { + return &ifWriter{iface: iface} } -func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data []byte) { - var ( - n int - err error - ) - - for { - n, remoteAddr, err = r.ReadFromUDPAddrPort(buf[:bufferSize]) - if err != nil { - log.Fatalf("Failed to read from UDP port: %v", err) - } - - data = buf[:n] - - if n < headerSize { - continue // Packet it soo short. - } - - h.Parse(data) - - peer := r.routing.Get(h.SourceIP) - if peer == nil { - continue - } - - out, ok := decryptPacket(peer.SharedKey, data, r.buf) - if !ok { - continue - } - - out, data = data, out - - if r.dupChecks[h.SourceIP].IsDup(h.Counter) { - log.Printf("Duplicate: %d", h.Counter) - continue - } - - return +func (w *ifWriter) Write(packet []byte) { + w.lock.Lock() + if _, err := w.iface.Write(packet); err != nil { + log.Fatalf("Failed to write to interface: %v", err) } + w.lock.Unlock() } diff --git a/node/crypto.go b/node/crypto.go deleted file mode 100644 index cc5904f..0000000 --- a/node/crypto.go +++ /dev/null @@ -1,50 +0,0 @@ -package node - -import ( - "sync" - "vppn/fasttime" - - "golang.org/x/crypto/nacl/box" -) - -// Encrypting the packet will also set the header's DataSize field. -func encryptPacket(h *header, sharedKey, data, out []byte) []byte { - out = out[:headerSize] - h.Marshal(out) - b := box.SealAfterPrecomputation(out[headerSize:headerSize], data, (*[24]byte)(out[:headerSize]), (*[32]byte)(sharedKey)) - return out[:len(b)+headerSize] -} - -func decryptPacket(sharedKey, packetAndHeader, out []byte) (decrypted []byte, ok bool) { - return box.OpenAfterPrecomputation( - out[:0], - packetAndHeader[headerSize:], - (*[24]byte)(packetAndHeader[:headerSize]), - (*[32]byte)(sharedKey)) -} - -func computeSharedKey(peerPubKey, privKey []byte) []byte { - shared := [32]byte{} - box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey)) - return shared[:] -} - -var ( - traceIDLock sync.Mutex - traceIDTime uint64 - traceIDCounter uint64 -) - -func newTraceID() (id uint64) { - traceIDLock.Lock() - defer traceIDLock.Unlock() - - now := uint64(fasttime.Now()) - if traceIDTime < now { - traceIDTime = now - traceIDCounter = 0 - } - traceIDCounter++ - - return traceIDTime<<30 + traceIDCounter -} diff --git a/node/crypto_test.go b/node/crypto_test.go deleted file mode 100644 index 0a651b0..0000000 --- a/node/crypto_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "log" - "reflect" - "testing" - - "golang.org/x/crypto/nacl/box" -) - -func TestEncryptDecryptPacket(t *testing.T) { - pubKey1, privKey1, err := box.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - - pubKey2, privKey2, err := box.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - - log.Printf("\n%#v\n%#v\n%#v\n%#v\n", pubKey1, privKey1, pubKey2, privKey2) - - sharedEncKey := [32]byte{} - box.Precompute(&sharedEncKey, pubKey2, privKey1) - - sharedDecKey := [32]byte{} - box.Precompute(&sharedDecKey, pubKey1, privKey2) - - original := make([]byte, if_mtu-64) - rand.Read(original) - - h := header{ - Counter: 2893749238, - SourceIP: 5, - DestIP: 12, - Forward: 1, - Stream: 1, - } - - encrypted := make([]byte, bufferSize) - encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) - - decrypted := make([]byte, bufferSize) - var ok bool - decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) - if !ok { - t.Fatal(ok) - } - - var h2 header - h2.Parse(encrypted) - - if !reflect.DeepEqual(h, h2) { - t.Fatal(h, h2) - } - - if !bytes.Equal(original, decrypted) { - t.Fatal("mismatch") - } -} - -func BenchmarkEncryptPacket(b *testing.B) { - _, privKey1, err := box.GenerateKey(rand.Reader) - if err != nil { - b.Fatal(err) - } - - pubKey2, _, err := box.GenerateKey(rand.Reader) - if err != nil { - b.Fatal(err) - } - - sharedEncKey := [32]byte{} - box.Precompute(&sharedEncKey, pubKey2, privKey1) - - original := make([]byte, if_mtu) - rand.Read(original) - - nonce := make([]byte, headerSize) - rand.Read(nonce) - - encrypted := make([]byte, bufferSize) - - h := header{ - Counter: 2893749238, - SourceIP: 5, - DestIP: 12, - Forward: 1, - Stream: 1, - } - - for i := 0; i < b.N; i++ { - encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) - } -} - -func BenchmarkDecryptPacket(b *testing.B) { - pubKey1, privKey1, err := box.GenerateKey(rand.Reader) - if err != nil { - b.Fatal(err) - } - - pubKey2, privKey2, err := box.GenerateKey(rand.Reader) - if err != nil { - b.Fatal(err) - } - - sharedEncKey := [32]byte{} - box.Precompute(&sharedEncKey, pubKey2, privKey1) - - sharedDecKey := [32]byte{} - box.Precompute(&sharedDecKey, pubKey1, privKey2) - - original := make([]byte, if_mtu) - rand.Read(original) - - nonce := make([]byte, headerSize) - rand.Read(nonce) - - h := header{ - Counter: 2893749238, - SourceIP: 5, - DestIP: 12, - Forward: 1, - Stream: 1, - } - - encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize)) - decrypted := make([]byte, bufferSize) - var ok bool - for i := 0; i < b.N; i++ { - decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) - if !ok { - panic(ok) - } - } -} diff --git a/node/dupcheck.go b/node/dupcheck.go index e960bd4..fac7a72 100644 --- a/node/dupcheck.go +++ b/node/dupcheck.go @@ -1,7 +1,5 @@ package node -import "log" - type dupCheck struct { bitSet head int @@ -22,7 +20,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool { // Before head => it's late, say it's a dup. if counter < dc.headCounter { - log.Printf("Late: %d", counter) return true } @@ -30,7 +27,6 @@ func (dc *dupCheck) IsDup(counter uint64) bool { if counter < dc.tailCounter { index := (int(counter-dc.headCounter) + dc.head) % bitSetSize if dc.Get(index) { - log.Printf("Dup: %d, %d", counter, dc.tailCounter) return true } diff --git a/node/dupcheck_test.go b/node/dupcheck_test.go index 9a939b5..2156b4e 100644 --- a/node/dupcheck_test.go +++ b/node/dupcheck_test.go @@ -1,7 +1,6 @@ package node import ( - "log" "testing" ) @@ -49,8 +48,6 @@ func TestDupCheck(t *testing.T) { for i, tc := range testCases { if ok := dc.IsDup(tc.Counter); ok != tc.Dup { - log.Printf("%b", dc.bitSet) - log.Printf("%+v", *dc) t.Fatal(i, ok, tc) } } diff --git a/node/globalfuncs.go b/node/globalfuncs.go new file mode 100644 index 0000000..98975da --- /dev/null +++ b/node/globalfuncs.go @@ -0,0 +1,73 @@ +package node + +import ( + "net/netip" + "sync/atomic" +) + +func getRelayRoute() *peerRoute { + if ip := relayIP.Load(); ip != nil { + return routingTable[*ip].Load() + } + return nil +} + +func getLocalAddr() netip.AddrPort { + if a := localAddr.Load(); a != nil { + return *a + } + return netip.AddrPort{} +} + +func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) { + buf := pkt.Marshal(buf2) + h := header{ + StreamID: controlStreamID, + Counter: atomic.AddUint64(&sendCounters[route.IP], 1), + SourceIP: localIP, + DestIP: route.IP, + } + buf = route.ControlCipher.Encrypt(h, buf, buf1) + + if route.Direct { + _conn.WriteTo(buf, route.RemoteAddr) + return + } + + _relayPacket(route.IP, buf, buf2) +} + +func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) { + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&sendCounters[route.IP], 1), + SourceIP: localIP, + DestIP: route.IP, + } + + enc := route.DataCipher.Encrypt(h, pkt, buf1) + + if route.Direct { + _conn.WriteTo(enc, route.RemoteAddr) + return + } + + _relayPacket(route.IP, enc, buf2) +} + +func _relayPacket(destIP byte, data, buf []byte) { + relayRoute := getRelayRoute() + if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay { + return + } + + h := header{ + StreamID: dataStreamID, + Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1), + SourceIP: localIP, + DestIP: destIP, + } + + enc := relayRoute.DataCipher.Encrypt(h, data, buf) + _conn.WriteTo(enc, relayRoute.RemoteAddr) +} diff --git a/node/globals.go b/node/globals.go index 172e6ef..3b8edea 100644 --- a/node/globals.go +++ b/node/globals.go @@ -1,3 +1,86 @@ package node -const bufferSize = if_mtu + 128 +import ( + "net/netip" + "sync/atomic" + "time" + "vppn/m" +) + +const ( + bufferSize = 1536 + if_mtu = 1200 + if_queue_len = 2048 + controlCipherOverhead = 16 + dataCipherOverhead = 16 +) + +type peerRoute struct { + IP byte + Up bool // True if data can be sent on the route. + Relay bool // True if the peer is a relay. + Direct bool // True if this is a direct connection. + ControlCipher *controlCipher + DataCipher *dataCipher + RemoteAddr netip.AddrPort // Remote address if directly connected. +} + +var ( + // Configuration for this peer. + netName string + localIP byte + localPub bool + privateKey []byte + + // Shared interface for writing. + _iface *ifWriter + + // Shared connection for writing. + _conn *connWriter + + // Counters for sending to each peer. + sendCounters [256]uint64 = func() (out [256]uint64) { + for i := range out { + out[i] = uint64(time.Now().Unix()<<30 + 1) + } + return + }() + + // Duplicate checkers for incoming packets. + dupChecks [256]*dupCheck = func() (out [256]*dupCheck) { + for i := range out { + out[i] = newDupCheck(0) + } + return + }() + + // Channels for incoming control packets. + controlPackets [256]chan controlPacket = func() (out [256]chan controlPacket) { + for i := range out { + out[i] = make(chan controlPacket, 256) + } + return + }() + + // Channels for incoming peer updates from the hub. + peerUpdates [256]chan *m.Peer = func() (out [256]chan *m.Peer) { + for i := range out { + out[i] = make(chan *m.Peer) + } + return + }() + + // Global routing table. + routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) { + for i := range out { + out[i] = &atomic.Pointer[peerRoute]{} + out[i].Store(&peerRoute{}) + } + return + }() + + // Managed by the relayManager. + discoveryPackets chan controlPacket + localAddr *atomic.Pointer[netip.AddrPort] // May be nil. + relayIP *atomic.Pointer[byte] // May be nil. +) diff --git a/node/header.go b/node/header.go index ed3671a..58ba852 100644 --- a/node/header.go +++ b/node/header.go @@ -2,32 +2,34 @@ package node import "unsafe" +// ---------------------------------------------------------------------------- + const ( - headerSize = 24 - streamData = 1 - streamRouting = 2 + headerSize = 12 + controlStreamID = 2 + controlHeaderSize = 24 + dataStreamID = 1 + dataHeaderSize = 12 ) type header struct { - Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. + StreamID byte + Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. SourceIP byte DestIP byte - Forward byte - Stream byte // See stream* constants. } -func (hdr *header) Parse(nb []byte) { - hdr.Counter = *(*uint64)(unsafe.Pointer(&nb[0])) - hdr.SourceIP = nb[8] - hdr.DestIP = nb[9] - hdr.Forward = nb[10] - hdr.Stream = nb[11] +func (h *header) Parse(b []byte) { + h.StreamID = b[0] + h.Counter = *(*uint64)(unsafe.Pointer(&b[1])) + h.SourceIP = b[9] + h.DestIP = b[10] } -func (hdr header) Marshal(buf []byte) { - *(*uint64)(unsafe.Pointer(&buf[0])) = hdr.Counter - buf[8] = hdr.SourceIP - buf[9] = hdr.DestIP - buf[10] = hdr.Forward - buf[11] = hdr.Stream +func (h *header) Marshal(buf []byte) { + buf[0] = h.StreamID + *(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter + buf[9] = h.SourceIP + buf[10] = h.DestIP + buf[11] = 0 } diff --git a/node/header_test.go b/node/header_test.go index 7a87354..9dbb061 100644 --- a/node/header_test.go +++ b/node/header_test.go @@ -4,11 +4,10 @@ import "testing" func TestHeaderMarshalParse(t *testing.T) { nIn := header{ + StreamID: 23, Counter: 3212, SourceIP: 34, DestIP: 200, - Forward: 1, - Stream: 44, } buf := make([]byte, headerSize) diff --git a/node/hubpoller.go b/node/hubpoller.go new file mode 100644 index 0000000..ac6b110 --- /dev/null +++ b/node/hubpoller.go @@ -0,0 +1,94 @@ +package node + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/url" + "time" + "vppn/m" +) + +type hubPoller struct { + client *http.Client + req *http.Request + versions [256]int64 +} + +func newHubPoller(conf m.PeerConfig) *hubPoller { + u, err := url.Parse(conf.HubAddress) + if err != nil { + log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err) + } + u.Path = "/peer/fetch-state/" + + client := &http.Client{Timeout: 8 * time.Second} + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + req.SetBasicAuth("", conf.APIKey) + + return &hubPoller{ + client: client, + req: req, + } +} + +func (hp *hubPoller) Run() { + defer panicHandler() + + state, err := loadNetworkState(netName) + if err != nil { + log.Printf("Failed to load network state: %v", err) + log.Printf("Polling hub...") + hp.pollHub() + } else { + hp.applyNetworkState(state) + } + + for range time.Tick(64 * time.Second) { + hp.pollHub() + } +} + +func (hp *hubPoller) pollHub() { + var state m.NetworkState + + resp, err := hp.client.Do(hp.req) + if err != nil { + log.Printf("Failed to fetch peer state: %v", err) + return + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + log.Printf("Failed to read body from hub: %v", err) + return + } + + if err := json.Unmarshal(body, &state); err != nil { + log.Printf("Failed to unmarshal response from hub: %v", err) + return + } + + hp.applyNetworkState(state) + + if err := storeNetworkState(netName, state); err != nil { + log.Printf("Failed to store network state: %v", err) + } +} + +func (hp *hubPoller) applyNetworkState(state m.NetworkState) { + for i, peer := range state.Peers { + if i != int(localIP) { + if peer != nil && peer.Version != hp.versions[i] { + peerUpdates[i] <- state.Peers[i] + hp.versions[i] = peer.Version + } + } + } +} diff --git a/node/interface.go b/node/interface.go index c5edf3e..4b492b4 100644 --- a/node/interface.go +++ b/node/interface.go @@ -50,11 +50,6 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) } } -const ( - if_mtu = 1200 - if_queue_len = 2048 -) - func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { if len(network) != 4 { return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) diff --git a/node/main.go b/node/main.go index cac6df8..ee2e7a7 100644 --- a/node/main.go +++ b/node/main.go @@ -11,6 +11,7 @@ import ( "net/netip" "os" "runtime/debug" + "sync/atomic" "vppn/m" ) @@ -24,7 +25,6 @@ func Main() { defer panicHandler() var ( - netName string initURL string listenIP string port int @@ -42,14 +42,14 @@ func Main() { } if initURL != "" { - mainInit(netName, initURL) + mainInit(initURL) return } - main(netName, listenIP, uint16(port)) + main(listenIP, uint16(port)) } -func mainInit(netName, initURL string) { +func mainInit(initURL string) { if _, err := loadPeerConfig(netName); err == nil { log.Fatalf("Network is already initialized.") } @@ -79,15 +79,15 @@ func mainInit(netName, initURL string) { // ---------------------------------------------------------------------------- -func main(netName, listenIP string, port uint16) { - conf, err := loadPeerConfig(netName) +func main(listenIP string, port uint16) { + config, err := loadPeerConfig(netName) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } - port = determinePort(conf.Port, port) + port = determinePort(config.Port, port) - iface, err := openInterface(conf.Network, conf.PeerIP, netName) + iface, err := openInterface(config.Network, config.PeerIP, netName) if err != nil { log.Fatalf("Failed to open interface: %v", err) } @@ -102,15 +102,38 @@ func main(netName, listenIP string, port uint16) { log.Fatalf("Failed to open UDP port: %v", err) } - routing := newRoutingTable() + // Intialize globals. + _iface = newIFWriter(iface) + _conn = newConnWriter(conn) - w := newConnWriter(conn, conf.PeerIP, routing) - r := newConnReader(conn, conf.PeerIP, routing) + localIP = config.PeerIP + discoveryPackets = make(chan controlPacket, 256) + localAddr = &atomic.Pointer[netip.AddrPort]{} + relayIP = &atomic.Pointer[byte]{} - router := newRouter(netName, conf, routing, w) + ip, ok := netip.AddrFromSlice(config.PublicIP) + if ok { + localPub = true + addr := netip.AddrPortFrom(ip, config.Port) + localAddr.Store(&addr) + } - go nodeConnReader(r, w, iface, router) - nodeIFaceReader(w, iface, router) + privateKey = config.PrivKey + + // Start supervisors. + for i := range 256 { + go newPeerSupervisor(i).Run() + } + + if localPub { + go addrDiscoveryServer() + } else { + go addrDiscoveryClient() + go relayManager() + } + go newHubPoller(config).Run() + go readFromConn(conn) + readFromIFace(iface) } // ---------------------------------------------------------------------------- @@ -127,62 +150,160 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { // ---------------------------------------------------------------------------- -func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) { +func readFromConn(conn *net.UDPConn) { + defer panicHandler() + var ( remoteAddr netip.AddrPort - h header - buf = make([]byte, bufferSize) - data []byte + n int err error + buf = make([]byte, bufferSize) + decBuf = make([]byte, bufferSize) + data []byte + h header ) for { - remoteAddr, h, data = r.Read(buf) - - if h.Forward != 0 { - w.Forward(h.DestIP, data) - continue + n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize]) + if err != nil { + log.Fatalf("Failed to read from UDP port: %v", err) } - switch h.Stream { + remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) - case streamData: - if _, err = iface.Write(data); err != nil { - log.Printf("Malformed data from peer %d: %v", h.SourceIP, err) - } + data = buf[:n] - case streamRouting: - router.HandlePacket(h.SourceIP, remoteAddr, data) + if n < headerSize { + continue // Packet it soo short. + } + + h.Parse(data) + switch h.StreamID { + case controlStreamID: + handleControlPacket(remoteAddr, h, data, decBuf) + + case dataStreamID: + handleDataPacket(h, data, decBuf) default: - log.Printf("Dropping unknown stream: %d", h.Stream) + log.Printf("Unknown stream ID: %d", h.StreamID) } } } +func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) { + route := routingTable[h.SourceIP].Load() + if route.ControlCipher == nil { + //log.Printf("Not connected (control).") + return + } + + if h.DestIP != localIP { + log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP) + return + } + + out, ok := route.ControlCipher.Decrypt(data, decBuf) + if !ok { + //log.Printf("Failed to decrypt control packet.") + return + } + + if len(out) == 0 { + //log.Printf("Empty control packet from: %d", h.SourceIP) + return + } + + if dupChecks[h.SourceIP].IsDup(h.Counter) { + //log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter) + return + } + + pkt := controlPacket{ + SrcIP: h.SourceIP, + SrcAddr: addr, + } + + if err := pkt.ParsePayload(out); err != nil { + log.Printf("Failed to parse control packet: %v", err) + return + } + + switch pkt.Payload.(type) { + + case addrDiscoveryPacket: + select { + case discoveryPackets <- pkt: + default: + log.Printf("Dropping discovery packet.") + } + + default: + select { + case controlPackets[h.SourceIP] <- pkt: + default: + log.Printf("Dropping control packet.") + } + } + +} + +func handleDataPacket(h header, data []byte, decBuf []byte) { + route := routingTable[h.SourceIP].Load() + if !route.Up { + //log.Printf("Not connected (recv).") + return + } + + dec, ok := route.DataCipher.Decrypt(data, decBuf) + if !ok { + log.Printf("Failed to decrypt data packet.") + return + } + + if dupChecks[h.SourceIP].IsDup(h.Counter) { + //log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter) + return + } + + if h.DestIP == localIP { + _iface.Write(dec) + return + } + + destRoute := routingTable[h.DestIP].Load() + if !destRoute.Up { + log.Printf("Not connected (relay): %v", destRoute) + return + } + + _conn.WriteTo(dec, destRoute.RemoteAddr) +} + // ---------------------------------------------------------------------------- -func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { - +func readFromIFace(iface io.ReadWriteCloser) { var ( - buf = make([]byte, bufferSize) - packet []byte + packet = make([]byte, bufferSize) + buf1 = make([]byte, bufferSize) + buf2 = make([]byte, bufferSize) remoteIP byte err error ) for { - - packet, remoteIP, err = readNextPacket(iface, buf) + packet, remoteIP, err = readNextPacket(iface, packet) if err != nil { log.Fatalf("Failed to read from interface: %v", err) } - if remoteIP == w.localIP { - continue // Don't write to self. + route := routingTable[remoteIP].Load() + if !route.Up { + log.Printf("Route not connected: %d", remoteIP) + continue } - w.WriteTo(remoteIP, streamData, packet) + _sendDataPacket(route, packet, buf1, buf2) } } diff --git a/node/node.go b/node/node.go deleted file mode 100644 index 2b4023a..0000000 --- a/node/node.go +++ /dev/null @@ -1 +0,0 @@ -package node diff --git a/node/packets-util.go b/node/packets-util.go new file mode 100644 index 0000000..af10eb5 --- /dev/null +++ b/node/packets-util.go @@ -0,0 +1,163 @@ +package node + +import ( + "net/netip" + "sync/atomic" + "time" + "unsafe" +) + +var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1 + +func newTraceID() uint64 { + return atomic.AddUint64(&traceIDCounter, 1) +} + +// ---------------------------------------------------------------------------- + +type binWriter struct { + b []byte + i int +} + +func newBinWriter(buf []byte) *binWriter { + buf = buf[:cap(buf)] + return &binWriter{buf, 0} +} + +func (w *binWriter) Bool(b bool) *binWriter { + if b { + return w.Byte(1) + } + return w.Byte(0) +} + +func (w *binWriter) Byte(b byte) *binWriter { + w.b[w.i] = b + w.i++ + return w +} + +func (w *binWriter) SharedKey(key [32]byte) *binWriter { + copy(w.b[w.i:w.i+32], key[:]) + w.i += 32 + return w +} + +func (w *binWriter) Uint16(x uint16) *binWriter { + *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 2 + return w +} + +func (w *binWriter) Uint64(x uint64) *binWriter { + *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) Int64(x int64) *binWriter { + *(*int64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { + addr := addrPort.Addr().As16() + copy(w.b[w.i:w.i+16], addr[:]) + w.i += 16 + return w.Uint16(addrPort.Port()) +} + +func (w *binWriter) Build() []byte { + return w.b[:w.i] +} + +// ---------------------------------------------------------------------------- + +type binReader struct { + b []byte + i int + err error +} + +func newBinReader(buf []byte) *binReader { + return &binReader{b: buf} +} + +func (r *binReader) hasBytes(n int) bool { + if r.err != nil || (len(r.b)-r.i) < n { + r.err = errMalformedPacket + return false + } + return true +} + +func (r *binReader) Bool(b *bool) *binReader { + var bb byte + r.Byte(&bb) + *b = bb != 0 + return r +} + +func (r *binReader) Byte(b *byte) *binReader { + if !r.hasBytes(1) { + return r + } + *b = r.b[r.i] + r.i++ + return r +} + +func (r *binReader) SharedKey(x *[32]byte) *binReader { + if !r.hasBytes(32) { + return r + } + *x = ([32]byte)(r.b[r.i : r.i+32]) + r.i += 32 + return r +} + +func (r *binReader) Uint16(x *uint16) *binReader { + if !r.hasBytes(2) { + return r + } + *x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) + r.i += 2 + return r +} + +func (r *binReader) Uint64(x *uint64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) Int64(x *int64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*int64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { + if !r.hasBytes(18) { + return r + } + addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() + r.i += 16 + + var port uint16 + r.Uint16(&port) + *x = netip.AddrPortFrom(addr, port) + return r +} + +func (r *binReader) Error() error { + return r.err +} diff --git a/node/packets-util_test.go b/node/packets-util_test.go new file mode 100644 index 0000000..06b0370 --- /dev/null +++ b/node/packets-util_test.go @@ -0,0 +1,40 @@ +package node + +import ( + "net/netip" + "reflect" + "testing" +) + +func TestBinWriteRead(t *testing.T) { + buf := make([]byte, 1024) + + type Item struct { + Type byte + TraceID uint64 + DestAddr netip.AddrPort + } + + in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} + + buf = newBinWriter(buf). + Byte(in.Type). + Uint64(in.TraceID). + AddrPort(in.DestAddr). + Build() + + out := Item{} + + err := newBinReader(buf). + Byte(&out.Type). + Uint64(&out.TraceID). + AddrPort(&out.DestAddr). + Error() + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal(in, out) + } +} diff --git a/node/packets.go b/node/packets.go new file mode 100644 index 0000000..267fed0 --- /dev/null +++ b/node/packets.go @@ -0,0 +1,140 @@ +package node + +import ( + "errors" + "net/netip" +) + +var ( + errMalformedPacket = errors.New("malformed packet") + errUnknownPacketType = errors.New("unknown packet type") +) + +const ( + packetTypeSyn = iota + 1 + packetTypeSynAck + packetTypeAck + packetTypeProbe + packetTypeAddrDiscovery +) + +// ---------------------------------------------------------------------------- + +type controlPacket struct { + SrcIP byte + SrcAddr netip.AddrPort + Payload any +} + +func (p *controlPacket) ParsePayload(buf []byte) (err error) { + switch buf[0] { + case packetTypeSyn: + p.Payload, err = parseSynPacket(buf) + case packetTypeSynAck: + p.Payload, err = parseSynAckPacket(buf) + case packetTypeProbe: + p.Payload, err = parseProbePacket(buf) + case packetTypeAddrDiscovery: + p.Payload, err = parseAddrDiscoveryPacket(buf) + default: + return errUnknownPacketType + } + return err +} + +// ---------------------------------------------------------------------------- + +type synPacket struct { + TraceID uint64 // TraceID to match response w/ request. + SharedKey [32]byte // Our shared key. + Direct bool + FromAddr netip.AddrPort // The client's sending address. +} + +func (p synPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSyn). + Uint64(p.TraceID). + SharedKey(p.SharedKey). + Bool(p.Direct). + AddrPort(p.FromAddr). + Build() +} + +func parseSynPacket(buf []byte) (p synPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + SharedKey(&p.SharedKey). + Bool(&p.Direct). + AddrPort(&p.FromAddr). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type synAckPacket struct { + TraceID uint64 + FromAddr netip.AddrPort +} + +func (p synAckPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSynAck). + Uint64(p.TraceID). + AddrPort(p.FromAddr). + Build() +} + +func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + AddrPort(&p.FromAddr). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type addrDiscoveryPacket struct { + TraceID uint64 + ToAddr netip.AddrPort +} + +func (p addrDiscoveryPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeAddrDiscovery). + Uint64(p.TraceID). + AddrPort(p.ToAddr). + Build() +} + +func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + AddrPort(&p.ToAddr). + Error() + return +} + +// ---------------------------------------------------------------------------- + +// A probeReqPacket is sent from a client to a server to determine if direct +// UDP communication can be used. +type probePacket struct { + TraceID uint64 +} + +func (p probePacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeProbe). + Uint64(p.TraceID). + Build() +} + +func parseProbePacket(buf []byte) (p probePacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Error() + return +} diff --git a/node/packets_test.go b/node/packets_test.go new file mode 100644 index 0000000..60295ec --- /dev/null +++ b/node/packets_test.go @@ -0,0 +1,42 @@ +package node + +import ( + "crypto/rand" + "net/netip" + "reflect" + "testing" +) + +func TestPacketSyn(t *testing.T) { + in := synPacket{ + TraceID: newTraceID(), + RelayIP: 4, + FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), + } + rand.Read(in.SharedKey[:]) + + out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize))) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal("\n", in, "\n", out) + } +} + +func TestPacketSynAck(t *testing.T) { + in := synAckPacket{ + TraceID: newTraceID(), + FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22), + } + + out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal("\n", in, "\n", out) + } +} diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go new file mode 100644 index 0000000..76e329c --- /dev/null +++ b/node/peer-supervisor.go @@ -0,0 +1,331 @@ +package node + +import ( + "fmt" + "log" + "net/netip" + "sync/atomic" + "time" + "vppn/m" +) + +const ( + pingInterval = 8 * time.Second + timeoutInterval = 25 * time.Second +) + +// ---------------------------------------------------------------------------- + +type peerSupervisor struct { + // The purpose of this state machine is to manage this published data. + published *atomic.Pointer[peerRoute] + staged peerRoute // Local copy of shared data. See publish(). + + // Immutable data. + remoteIP byte // Remote VPN IP. + + // Mutable peer data. + peer *m.Peer + remotePub bool + + // Incoming events. + peerUpdates chan *m.Peer + controlPackets chan controlPacket + + // Buffers for sending control packets. + buf1 []byte + buf2 []byte +} + +func newPeerSupervisor(i int) *peerSupervisor { + return &peerSupervisor{ + published: routingTable[i], + remoteIP: byte(i), + peerUpdates: peerUpdates[i], + controlPackets: controlPackets[i], + buf1: make([]byte, bufferSize), + buf2: make([]byte, bufferSize), + } +} + +type stateFunc func() stateFunc + +func (s *peerSupervisor) Run() { + state := s.noPeer + for { + state = state() + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { + _sendControlPacket(pkt, s.staged, s.buf1, s.buf2) + time.Sleep(500 * time.Millisecond) // Rate limit packets. +} + +func (s *peerSupervisor) sendControlPacketTo( + pkt interface{ Marshal([]byte) []byte }, + addr netip.AddrPort, +) { + if !addr.IsValid() { + s.logf("ERROR: Attepted to send packet to invalid address: %v", addr) + return + } + route := s.staged + route.Direct = true + route.RemoteAddr = addr + _sendControlPacket(pkt, route, s.buf1, s.buf2) + time.Sleep(500 * time.Millisecond) // Rate limit packets. +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) logf(msg string, args ...any) { + log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) publish() { + data := s.staged + s.published.Store(&data) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) noPeer() stateFunc { + return s.peerUpdate(<-s.peerUpdates) +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) peerUpdate(peer *m.Peer) stateFunc { + return func() stateFunc { return s._peerUpdate(peer) } +} + +func (s *peerSupervisor) _peerUpdate(peer *m.Peer) stateFunc { + defer s.publish() + + s.peer = peer + s.staged = peerRoute{} + + if s.peer == nil { + return s.noPeer + } + + s.staged.IP = s.remoteIP + s.staged.ControlCipher = newControlCipher(privateKey, peer.PubKey) + s.staged.DataCipher = newDataCipher() + + if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { + s.remotePub = true + s.staged.Relay = peer.Relay + s.staged.Direct = true + s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port) + } else if localPub { + s.staged.Direct = true + } + + if s.remotePub == localPub { + if localIP < s.remoteIP { + return s.server + } + return s.client + } + + if s.remotePub { + return s.client + } + return s.server +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) server() stateFunc { + logf := func(format string, args ...any) { s.logf("SERVER "+format, args...) } + + logf("DOWN") + + var ( + syn synPacket + timeoutTimer = time.NewTimer(timeoutInterval) + ) + + for { + select { + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case pkt := <-s.controlPackets: + switch p := pkt.Payload.(type) { + + case synPacket: + // Before we can respond to this packet, we need to make sure the + // route is setup properly. + // + // The client will update the syn's TraceID whenever there's a change. + // The server will follow the client's request. + if p.TraceID != syn.TraceID || !s.staged.Up { + if p.Direct { + logf("UP - Direct") + } else { + logf("UP - Relayed") + } + + syn = p + s.staged.Up = true + s.staged.Direct = syn.Direct + s.staged.DataCipher = newDataCipherFromKey(syn.SharedKey) + s.staged.RemoteAddr = pkt.SrcAddr + + s.publish() + } + + // We should always respond. + ack := synAckPacket{ + TraceID: syn.TraceID, + FromAddr: getLocalAddr(), + } + s.sendControlPacket(ack) + + if s.staged.Direct { + continue + } + + if !syn.FromAddr.IsValid() { + continue + } + + probe := probePacket{TraceID: newTraceID()} + s.sendControlPacketTo(probe, syn.FromAddr) + + case probePacket: + if pkt.SrcAddr.IsValid() { + s.sendControlPacketTo(probePacket{TraceID: p.TraceID}, pkt.SrcAddr) + } else { + logf("Invalid probe address") + } + } + + case <-timeoutTimer.C: + logf("Connection timeout") + s.staged.Up = false + s.publish() + } + } +} + +// ---------------------------------------------------------------------------- + +func (s *peerSupervisor) client() stateFunc { + logf := func(format string, args ...any) { s.logf("CLIENT "+format, args...) } + + logf("DOWN") + + var ( + syn = synPacket{ + TraceID: newTraceID(), + SharedKey: s.staged.DataCipher.Key(), + Direct: s.staged.Direct, + FromAddr: getLocalAddr(), + } + + ack synAckPacket + + probe probePacket + probeAddr netip.AddrPort + + lAddr netip.AddrPort + + timeoutTimer = time.NewTimer(timeoutInterval) + pingTimer = time.NewTimer(pingInterval) + ) + + defer timeoutTimer.Stop() + defer pingTimer.Stop() + + s.sendControlPacket(syn) + + for { + select { + + case peer := <-s.peerUpdates: + return s.peerUpdate(peer) + + case pkt := <-s.controlPackets: + switch p := pkt.Payload.(type) { + + case synAckPacket: + if p.TraceID != syn.TraceID { + continue // Hmm... + } + + ack = p + timeoutTimer.Reset(timeoutInterval) + + if !s.staged.Up { + if s.staged.Direct { + logf("UP - Direct") + } else { + logf("UP - Relayed") + } + + s.staged.Up = true + s.publish() + } + + case probePacket: + if s.staged.Direct { + continue + } + + if p.TraceID != probe.TraceID { + continue + } + + // Upgrade connection. + + logf("UP - Direct") + s.staged.Direct = true + s.staged.RemoteAddr = probeAddr + s.publish() + + syn.TraceID = newTraceID() + syn.Direct = true + syn.FromAddr = getLocalAddr() + s.sendControlPacket(syn) + } + + case <-pingTimer.C: + // Send syn. + + syn.FromAddr = getLocalAddr() + if syn.FromAddr != lAddr { + syn.TraceID = newTraceID() + lAddr = syn.FromAddr + } + + s.sendControlPacket(syn) + + pingTimer.Reset(pingInterval) + + if s.staged.Direct { + continue + } + + if !ack.FromAddr.IsValid() { + continue + } + + probe = probePacket{TraceID: newTraceID()} + probeAddr = ack.FromAddr + + s.sendControlPacketTo(probe, ack.FromAddr) + + case <-timeoutTimer.C: + logf("Connection timeout") + return s.peerUpdate(s.peer) + } + } +} diff --git a/node/peer.go b/node/peer.go deleted file mode 100644 index 2b4023a..0000000 --- a/node/peer.go +++ /dev/null @@ -1 +0,0 @@ -package node diff --git a/node/peerstate.go b/node/peerstate.go deleted file mode 100644 index 2b4023a..0000000 --- a/node/peerstate.go +++ /dev/null @@ -1 +0,0 @@ -package node diff --git a/node/peersupervisor.go b/node/peersupervisor.go deleted file mode 100644 index 14c9315..0000000 --- a/node/peersupervisor.go +++ /dev/null @@ -1,327 +0,0 @@ -package node - -import ( - "fmt" - "log" - "net/netip" - "time" - "vppn/m" -) - -const ( - connectTimeout = 6 * time.Second - pingInterval = 6 * time.Second - timeoutInterval = 20 * time.Second -) - -type routingPacketWrapper struct { - routingPacket - Addr netip.AddrPort // Source. -} - -type peerSupervisor struct { - // Constants: - localIP byte - localPublic bool - remoteIP byte - privKey []byte - - // Shared data: - w *connWriter - table *routingTable - - packets chan routingPacketWrapper - peerUpdates chan *m.Peer - - // Peer-related items. - version int64 // Ony accessed in HandlePeerUpdate. - peer *m.Peer - remoteAddrPort *netip.AddrPort - mediated bool - sharedKey []byte - - // Used by our state functions. - pingTimer *time.Timer - timeoutTimer *time.Timer - buf []byte -} - -// ---------------------------------------------------------------------------- - -func newPeerSupervisor( - conf m.PeerConfig, - remoteIP byte, - w *connWriter, - table *routingTable, -) *peerSupervisor { - s := &peerSupervisor{ - localIP: conf.PeerIP, - remoteIP: remoteIP, - privKey: conf.EncPrivKey, - w: w, - table: table, - packets: make(chan routingPacketWrapper, 256), - peerUpdates: make(chan *m.Peer, 1), - pingTimer: time.NewTimer(pingInterval), - timeoutTimer: time.NewTimer(timeoutInterval), - buf: make([]byte, bufferSize), - } - - _, s.localPublic = netip.AddrFromSlice(conf.PublicIP) - - go s.mainLoop() - return s -} - -func (s *peerSupervisor) logf(msg string, args ...any) { - msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg - log.Printf(msg, args...) -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) mainLoop() { - defer panicHandler() - state := s.stateInit - for { - state = state() - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) { - if p != nil { - if p.Version == s.version { - return - } - s.version = p.Version - } else { - s.version = 0 - } - - s.peerUpdates <- p -} - -func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) { - select { - case s.packets <- w: - default: - // Drop - } -} - -// ---------------------------------------------------------------------------- - -type stateFunc func() stateFunc - -func (s *peerSupervisor) stateInit() stateFunc { - if s.peer == nil { - return s.stateDisconnected - } - - addr, ok := netip.AddrFromSlice(s.peer.PublicIP) - if ok { - addrPort := netip.AddrPortFrom(addr, s.peer.Port) - s.remoteAddrPort = &addrPort - } else { - s.remoteAddrPort = nil - } - s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey) - - return s.stateSelectRole() -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) stateDisconnected() stateFunc { - s.clearRoutingTable() - - for { - select { - case <-s.packets: - // Drop - case s.peer = <-s.peerUpdates: - return s.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) stateSelectRole() stateFunc { - s.logf("STATE: SelectRole") - s.updateRoutingTable(false) - - if s.remoteAddrPort != nil { - s.mediated = false - - // If both remote and local are public, one side acts as client, and one - // side as server. - if s.localPublic && s.localIP < s.peer.PeerIP { - return s.stateAccept - } - return s.stateDial - } - - // We're public, remote is not => can only wait for connection - if s.localPublic { - s.mediated = false - return s.stateAccept - } - - // Both non-public => need to use mediator. - return s.stateMediated -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) stateAccept() stateFunc { - s.logf("STATE: Accept") - - for { - - select { - case pkt := <-s.packets: - switch pkt.Type { - - case packetTypePing: - s.remoteAddrPort = &pkt.Addr - s.updateRoutingTable(true) - s.sendPong(pkt.TraceID) - return s.stateConnected - - default: - // Still waiting for ping... - } - - case s.peer = <-s.peerUpdates: - return s.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) stateDial() stateFunc { - s.logf("STATE: Dial") - s.updateRoutingTable(false) - - s.sendPing() - - for { - select { - case pkt := <-s.packets: - - switch pkt.Type { - - case packetTypePong: - s.updateRoutingTable(true) - return s.stateConnected - - default: - // Ignore - } - - case <-s.pingTimer.C: - s.sendPing() - - case s.peer = <-s.peerUpdates: - return s.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) stateConnected() stateFunc { - s.logf("STATE: Connected") - - s.timeoutTimer.Reset(timeoutInterval) - - for { - select { - - case <-s.pingTimer.C: - s.sendPing() - - case <-s.timeoutTimer.C: - s.logf("Timeout") - return s.stateInit - - case pkt := <-s.packets: - switch pkt.Type { - case packetTypePing: - s.sendPong(pkt.TraceID) - - // Server should always follow remote port. - if s.localPublic { - if pkt.Addr != *s.remoteAddrPort { - s.remoteAddrPort = &pkt.Addr - s.updateRoutingTable(true) - } - } - - case packetTypePong: - s.timeoutTimer.Reset(timeoutInterval) - - default: - // Drop packet. - } - - case s.peer = <-s.peerUpdates: - s.logf("New peer: %v", s.peer) - return s.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) stateMediated() stateFunc { - s.logf("STATE: Mediated") - s.mediated = true - s.updateRoutingTable(true) - - for { - select { - case <-s.packets: - // Drop. - case s.peer = <-s.peerUpdates: - s.logf("New peer: %v", s.peer) - return s.stateInit - } - } -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) clearRoutingTable() { - s.table.Set(s.remoteIP, nil) -} - -func (s *peerSupervisor) updateRoutingTable(up bool) { - s.table.Set(s.remoteIP, &peer{ - Up: up, - Mediator: s.peer.Mediator, - Mediated: s.mediated, - IP: s.remoteIP, - Addr: s.remoteAddrPort, - SharedKey: s.sharedKey, - }) -} - -// ---------------------------------------------------------------------------- - -func (s *peerSupervisor) sendPing() uint64 { - traceID := newTraceID() - pkt := newRoutingPacket(packetTypePing, traceID) - s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) - s.pingTimer.Reset(pingInterval) - return traceID -} - -func (s *peerSupervisor) sendPong(traceID uint64) { - pkt := newRoutingPacket(packetTypePong, traceID) - s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) -} diff --git a/node/relaymanager.go b/node/relaymanager.go new file mode 100644 index 0000000..5c44ea8 --- /dev/null +++ b/node/relaymanager.go @@ -0,0 +1,40 @@ +package node + +import ( + "log" + "math/rand" + "time" +) + +func relayManager() { + time.Sleep(2 * time.Second) + updateRelayRoute() + + for range time.Tick(8 * time.Second) { + relay := getRelayRoute() + if relay == nil || !relay.Up || !relay.Relay { + updateRelayRoute() + } + } +} + +func updateRelayRoute() { + possible := make([]*peerRoute, 0, 8) + for i := range routingTable { + route := routingTable[i].Load() + if !route.Up || !route.Relay { + continue + } + possible = append(possible, route) + } + + if len(possible) == 0 { + log.Printf("No relay available.") + relayIP.Store(nil) + return + } + + ip := possible[rand.Intn(len(possible))].IP + log.Printf("New relay IP: %d", ip) + relayIP.Store(&ip) +} diff --git a/node/router.go b/node/router.go deleted file mode 100644 index 67c0756..0000000 --- a/node/router.go +++ /dev/null @@ -1,186 +0,0 @@ -package node - -import ( - "encoding/json" - "io" - "log" - "net/http" - "net/netip" - "net/url" - "sync/atomic" - "time" - "vppn/m" -) - -type peer struct { - Up bool // No data will be sent to peers that are down. - Mediator bool - Mediated bool - IP byte - Addr *netip.AddrPort // If we have direct connection, otherwise use mediator. - SharedKey []byte -} - -// ---------------------------------------------------------------------------- - -type routingTable struct { - table [256]*atomic.Pointer[peer] - mediator *atomic.Pointer[peer] -} - -func newRoutingTable() *routingTable { - r := routingTable{ - mediator: &atomic.Pointer[peer]{}, - } - - for i := range r.table { - r.table[i] = &atomic.Pointer[peer]{} - } - - return &r -} - -func (r *routingTable) Get(ip byte) *peer { - return r.table[ip].Load() -} - -func (r *routingTable) Set(ip byte, p *peer) { - r.table[ip].Store(p) -} - -// ---------------------------------------------------------------------------- - -type router struct { - *routingTable - netName string - peerSupers [256]*peerSupervisor -} - -func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w *connWriter) *router { - r := &router{ - netName: netName, - routingTable: routingData, - } - - for i := range r.peerSupers { - r.peerSupers[i] = newPeerSupervisor( - conf, - byte(i), - w, - r.routingTable) - } - - go r.selectMediator() - go r.pollHub(conf) - - return r -} - -// ---------------------------------------------------------------------------- - -func (r *router) HandlePacket(sourceIP byte, remoteAddr netip.AddrPort, data []byte) { - p := routingPacket{} - if err := p.Parse(data); err != nil { - log.Printf("Dropping malformed routing packet: %v", err) - return - } - - w := routingPacketWrapper{ - routingPacket: p, - Addr: remoteAddr, - } - - r.peerSupers[sourceIP].HandlePacket(w) -} - -// ---------------------------------------------------------------------------- - -func (r *router) pollHub(conf m.PeerConfig) { - defer panicHandler() - - u, err := url.Parse(conf.HubAddress) - if err != nil { - log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err) - } - u.Path = "/peer/fetch-state/" - - client := &http.Client{Timeout: 8 * time.Second} - - req := &http.Request{ - Method: http.MethodGet, - URL: u, - Header: http.Header{}, - } - req.SetBasicAuth("", conf.APIKey) - - state, err := loadNetworkState(r.netName) - if err != nil { - log.Printf("Failed to load network state: %v", err) - log.Printf("Polling hub...") - r._pollHub(conf, client, req) - } else { - r.applyNetworkState(conf, state) - } - - for range time.Tick(64 * time.Second) { - r._pollHub(conf, client, req) - } -} - -func (r *router) _pollHub(conf m.PeerConfig, client *http.Client, req *http.Request) { - var state m.NetworkState - - log.Printf("Fetching peer state from %s...", conf.HubAddress) - resp, err := client.Do(req) - if err != nil { - log.Printf("Failed to fetch peer state: %v", err) - return - } - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - log.Printf("Failed to read body from hub: %v", err) - return - } - - if err := json.Unmarshal(body, &state); err != nil { - log.Printf("Failed to unmarshal response from hub: %v", err) - return - } - - r.applyNetworkState(conf, state) - - if err := storeNetworkState(r.netName, state); err != nil { - log.Printf("Failed to store network state: %v", err) - } -} - -func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) { - for i := range state.Peers { - if i != int(conf.PeerIP) { - r.peerSupers[i].HandlePeerUpdate(state.Peers[i]) - } - } -} - -// ---------------------------------------------------------------------------- - -func (r *router) selectMediator() { - for range time.Tick(8 * time.Second) { - current := r.mediator.Load() - if current != nil && current.Up { - continue - } - - for i := range r.table { - peer := r.table[i].Load() - if peer != nil && peer.Up && peer.Mediator { - log.Printf("Got mediator: %v", *peer) - r.mediator.Store(peer) - return - } - } - - r.mediator.Store(nil) - } -} diff --git a/node/routingpacket.go b/node/routingpacket.go deleted file mode 100644 index 4e35055..0000000 --- a/node/routingpacket.go +++ /dev/null @@ -1,44 +0,0 @@ -package node - -import ( - "errors" - "unsafe" -) - -var errMalformedPacket = errors.New("malformed packet") - -const ( - packetTypeInvalid = iota - - // Used to maintain connection. - packetTypePing - packetTypePong -) - -type routingPacket struct { - Type byte // One of the packetType* constants. - TraceID uint64 // For matching requests and responses. -} - -func newRoutingPacket(reqType byte, traceID uint64) routingPacket { - return routingPacket{ - Type: reqType, - TraceID: traceID, - } -} - -func (p routingPacket) Marshal(buf []byte) []byte { - buf = buf[:32] // Reserve 32 bytes just in case we need to add anything. - buf[0] = p.Type - *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.TraceID) - return buf -} - -func (p *routingPacket) Parse(buf []byte) error { - if len(buf) != 32 { - return errMalformedPacket - } - p.Type = buf[0] - p.TraceID = *(*uint64)(unsafe.Pointer(&buf[1])) - return nil -} diff --git a/node/tmp-server.go b/node/tmp-server.go deleted file mode 100644 index 179a8a4..0000000 --- a/node/tmp-server.go +++ /dev/null @@ -1,185 +0,0 @@ -package node - -/* -var ( - network = []byte{10, 1, 1, 0} - serverIP = byte(1) - clientIP = byte(2) - port = uint16(5151) - netName = "testnet" - pubKey1 = []byte{0x43, 0xde, 0xd4, 0xb2, 0x1d, 0x71, 0x58, 0x9a, 0x96, 0x3a, 0x23, 0xfc, 0x2, 0xe, 0xfa, 0x42, 0x3, 0x94, 0xbc, 0xf8, 0x25, 0xf, 0x54, 0xcc, 0x98, 0x42, 0x8b, 0xe5, 0x27, 0x86, 0x49, 0x33} - privKey1 = []byte{0xae, 0x4d, 0xc5, 0xaa, 0xc9, 0xbc, 0x65, 0x41, 0x55, 0xb, 0x61, 0x52, 0xc4, 0x6c, 0xce, 0x2f, 0x1b, 0xf5, 0xb3, 0xbf, 0xb5, 0x54, 0x61, 0x7c, 0x26, 0x2e, 0xba, 0x5a, 0x19, 0xe2, 0x9c, 0xe0} - pubKey2 = []byte{0x8c, 0xfe, 0x12, 0xd9, 0x2d, 0x37, 0x5, 0x43, 0xab, 0x70, 0x59, 0x20, 0x3d, 0x82, 0x93, 0x9b, 0xb3, 0xaa, 0x35, 0x23, 0xc1, 0xb4, 0x4, 0x1f, 0x92, 0x97, 0x6f, 0xfd, 0x55, 0x17, 0x5a, 0x4b} - privKey2 = []byte{0xd9, 0xe1, 0xc6, 0x64, 0x3e, 0x29, 0x29, 0x78, 0x81, 0x53, 0xc2, 0x31, 0xd9, 0x34, 0x5b, 0x41, 0xf5, 0x80, 0xb0, 0x27, 0x9f, 0x65, 0x85, 0xd4, 0x78, 0xd5, 0x9, 0x2, 0xca, 0x56, 0x42, 0x80} -) - -func must(err error) { - if err != nil { - panic(err) - } -} - -type TmpNode struct { - network []byte - localIP byte - router *router - port uint16 - netName string - iface io.ReadWriteCloser - pubKey []byte - privKey []byte - w *connWriter - r *connReader -} - -// ---------------------------------------------------------------------------- - -func NewTmpNodeServer() *TmpNode { - n := &TmpNode{ - localIP: serverIP, - network: network, - router: &router{table: newPeerRepo()}, - port: port, - netName: netName, - pubKey: pubKey1, - privKey: privKey1, - } - - var err error - n.iface, err = openInterface(n.network, n.localIP, n.netName) - must(err) - - myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port)) - must(err) - - conn, err := net.ListenUDP("udp", myAddr) - must(err) - - n.w = newConnWriter(conn, n.localIP, n.router) - n.r = newConnReader(conn, n.localIP, n.router) - - n.router.table.Set(clientIP, &peer{ - IP: clientIP, - SharedKey: computeSharedKey(pubKey2, n.privKey), - }) - - return n -} - -// ---------------------------------------------------------------------------- - -func NewTmpNodeClient(srvAddrStr string) *TmpNode { - n := &TmpNode{ - localIP: clientIP, - network: network, - router: &router{table: newPeerRepo()}, - port: port, - netName: netName, - pubKey: pubKey2, - privKey: privKey2, - } - - var err error - n.iface, err = openInterface(n.network, n.localIP, n.netName) - must(err) - - myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", n.port)) - must(err) - - conn, err := net.ListenUDP("udp", myAddr) - must(err) - - n.w = newConnWriter(conn, n.localIP, n.router) - n.r = newConnReader(conn, n.localIP, n.router) - - serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port)) - must(err) - - n.router.table.Set(serverIP, &peer{ - IP: serverIP, - Addr: &serverAddr, - SharedKey: computeSharedKey(pubKey1, n.privKey), - }) - - return n -} - -// ---------------------------------------------------------------------------- - -func (n *TmpNode) RunServer() { - defer func() { - if r := recover(); r != nil { - fmt.Printf("%v", r) - debug.PrintStack() - } - }() - - // Get remoteAddr from a packet. - buf := make([]byte, bufferSize) - remoteAddr, h, _, err := n.r.Read(buf) - must(err) - log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr) - must(err) - - n.router.table.Set(h.SourceIP, &peer{ - IP: h.SourceIP, - Addr: &remoteAddr, - SharedKey: computeSharedKey(pubKey2, n.privKey), - }) - - go n.readFromIFace() - n.readFromConn() -} - -// ---------------------------------------------------------------------------- - -func (n *TmpNode) RunClient() { - defer func() { - if r := recover(); r != nil { - fmt.Printf("%v\n", r) - debug.PrintStack() - } - }() - - log.Printf("Sending to server...") - must(n.w.WriteTo(serverIP, 1, []byte{1, 2, 3, 4, 5, 6, 7, 8})) - - go n.readFromIFace() - n.readFromConn() -} - -func (n *TmpNode) readFromIFace() { - var ( - buf = make([]byte, bufferSize) - packet []byte - remoteIP byte - err error - ) - - for { - packet, remoteIP, err = readNextPacket(n.iface, buf) - must(err) - must(n.w.WriteTo(remoteIP, 1, packet)) - } -} - -func (node *TmpNode) readFromConn() { - var ( - buf = make([]byte, bufferSize) - packet []byte - err error - ) - - for { - _, _, packet, err = node.r.Read(buf) - must(err) - // We assume that we're only receiving packets from one source. - - _, err = node.iface.Write(packet) - if err != nil { - log.Printf("Got error: %v", err) - } - //must(err) - } -} -*/