From 9a3cb2d1c2bca9e1a671788c2c01aa096853af8e Mon Sep 17 00:00:00 2001 From: "J. David Lee" Date: Fri, 12 Jun 2026 15:11:01 +0000 Subject: [PATCH] Refactor - now wireguard based. (#7) --- cmd/vppn/main.go | 63 ++- go.mod | 8 + go.sum | 18 + hub/api/api.go | 147 ++++--- hub/api/db/generated.go | 204 ++-------- hub/api/db/sanitize-validate.go | 58 ++- hub/api/db/tables.defs | 28 +- hub/api/db/written.go | 32 -- hub/api/errors.go | 1 - hub/api/migrations/2024-11-30-init.sql | 29 +- hub/api/types.go | 8 +- hub/app.go | 18 +- hub/cookie.go | 12 +- hub/global.go | 2 +- hub/handler.go | 13 +- hub/handlers.go | 130 +++---- hub/main.go | 2 +- hub/routes.go | 2 - hub/templates/admin-network-create.html | 7 +- hub/templates/admin-network-list.html | 4 +- hub/templates/admin-password-edit.html | 3 +- hub/templates/admin-sign-out.html | 3 +- hub/templates/network/base.html | 2 +- hub/templates/network/network-delete.html | 3 +- hub/templates/network/network-view.html | 6 +- hub/templates/network/peer-create.html | 13 +- hub/templates/network/peer-delete.html | 5 +- hub/templates/network/peer-edit.html | 35 -- hub/templates/network/peer-view.html | 9 +- hub/templates/sign-in.html | 3 +- hub/util.go | 13 + m/models.go | 119 +++++- peer/app.go | 105 +++++ peer/app_test.go | 62 +++ peer/bitset.go | 21 - peer/bitset_test.go | 48 --- peer/cipher-control.go | 26 -- peer/cipher-control_test.go | 122 ------ peer/cipher-data.go | 61 --- peer/cipher-data_test.go | 141 ------- peer/connreader.go | 46 --- peer/control/ping.go | 76 ++++ peer/control/ping_test.go | 106 +++++ peer/control/role.go | 22 ++ peer/control_conn.go | 64 ++++ peer/controlmessage.go | 64 ---- peer/crypto.go | 30 -- peer/device.go | 74 ++++ peer/dupcheck.go | 76 ---- peer/dupcheck_test.go | 57 --- peer/errors.go | 8 - peer/fake_control_conn_test.go | 43 +++ peer/fake_wgdevice_test.go | 123 ++++++ peer/files.go | 115 ------ peer/files_test.go | 57 --- peer/globals.go | 109 ------ peer/header.go | 47 --- peer/header_test.go | 21 - peer/hosts.go | 128 +++++++ peer/hosts_test.go | 205 ++++++++++ peer/hub_poller.go | 153 ++++++++ peer/hub_poller_test.go | 80 ++++ peer/hubpoller.go | 111 ------ peer/ifreader.go | 73 ---- peer/ifreader_test.go | 81 ---- peer/init.go | 190 +++++++++ peer/interface.go | 137 ------- peer/interfaces.go | 28 ++ peer/json.go | 36 ++ peer/main.go | 209 ---------- peer/main_test.go | 5 - peer/mcreader.go | 47 --- peer/mcreader_test.go | 132 ------- peer/mcwriter.go | 54 --- peer/mcwriter_test.go | 98 ----- peer/mock-iface_test.go | 31 -- peer/mock-network_test.go | 80 ---- peer/multicast/broadcaster.go | 62 +++ peer/multicast/global.go | 9 + peer/multicast/packet.go | 54 +++ peer/multicast/packet_test.go | 38 ++ peer/multicast/receiver.go | 61 +++ peer/network_state.go | 18 + peer/network_state_test.go | 56 +++ peer/new.go | 109 ++++++ peer/on_hub.go | 114 ++++++ peer/on_hub_test.go | 299 +++++++++++++++ peer/on_multicast.go | 51 +++ peer/on_ping.go | 58 +++ peer/on_tick.go | 52 +++ peer/packets-util.go | 182 --------- peer/packets-util_test.go | 76 ---- peer/packets.go | 120 ------ peer/packets_test.go | 64 ---- peer/peer.go | 201 ---------- peer/ping.go | 21 + peer/pubaddrs.go | 86 ----- peer/pubaddrs_test.go | 29 -- peer/relayhandler.go | 54 --- peer/remote.go | 380 +++--------------- peer/remotefsm.go | 448 ---------------------- peer/statusserver.go | 71 ---- peer/wginterface/interface.go | 225 +++++++++++ peer/wginterface/manage.go | 184 +++++++++ peer/wginterface/manage_test.go | 303 +++++++++++++++ 105 files changed, 3776 insertions(+), 4251 deletions(-) delete mode 100644 hub/templates/network/peer-edit.html create mode 100644 peer/app.go create mode 100644 peer/app_test.go delete mode 100644 peer/bitset.go delete mode 100644 peer/bitset_test.go delete mode 100644 peer/cipher-control.go delete mode 100644 peer/cipher-control_test.go delete mode 100644 peer/cipher-data.go delete mode 100644 peer/cipher-data_test.go delete mode 100644 peer/connreader.go create mode 100644 peer/control/ping.go create mode 100644 peer/control/ping_test.go create mode 100644 peer/control/role.go create mode 100644 peer/control_conn.go delete mode 100644 peer/controlmessage.go delete mode 100644 peer/crypto.go create mode 100644 peer/device.go delete mode 100644 peer/dupcheck.go delete mode 100644 peer/dupcheck_test.go delete mode 100644 peer/errors.go create mode 100644 peer/fake_control_conn_test.go create mode 100644 peer/fake_wgdevice_test.go delete mode 100644 peer/files.go delete mode 100644 peer/files_test.go delete mode 100644 peer/globals.go delete mode 100644 peer/header.go delete mode 100644 peer/header_test.go create mode 100644 peer/hosts.go create mode 100644 peer/hosts_test.go create mode 100644 peer/hub_poller.go create mode 100644 peer/hub_poller_test.go delete mode 100644 peer/hubpoller.go delete mode 100644 peer/ifreader.go delete mode 100644 peer/ifreader_test.go create mode 100644 peer/init.go delete mode 100644 peer/interface.go create mode 100644 peer/interfaces.go create mode 100644 peer/json.go delete mode 100644 peer/main.go delete mode 100644 peer/main_test.go delete mode 100644 peer/mcreader.go delete mode 100644 peer/mcreader_test.go delete mode 100644 peer/mcwriter.go delete mode 100644 peer/mcwriter_test.go delete mode 100644 peer/mock-iface_test.go delete mode 100644 peer/mock-network_test.go create mode 100644 peer/multicast/broadcaster.go create mode 100644 peer/multicast/global.go create mode 100644 peer/multicast/packet.go create mode 100644 peer/multicast/packet_test.go create mode 100644 peer/multicast/receiver.go create mode 100644 peer/network_state.go create mode 100644 peer/network_state_test.go create mode 100644 peer/new.go create mode 100644 peer/on_hub.go create mode 100644 peer/on_hub_test.go create mode 100644 peer/on_multicast.go create mode 100644 peer/on_ping.go create mode 100644 peer/on_tick.go delete mode 100644 peer/packets-util.go delete mode 100644 peer/packets-util_test.go delete mode 100644 peer/packets.go delete mode 100644 peer/packets_test.go delete mode 100644 peer/peer.go create mode 100644 peer/ping.go delete mode 100644 peer/pubaddrs.go delete mode 100644 peer/pubaddrs_test.go delete mode 100644 peer/relayhandler.go delete mode 100644 peer/remotefsm.go delete mode 100644 peer/statusserver.go create mode 100644 peer/wginterface/interface.go create mode 100644 peer/wginterface/manage.go create mode 100644 peer/wginterface/manage_test.go diff --git a/cmd/vppn/main.go b/cmd/vppn/main.go index dada4cf..d0e5be0 100644 --- a/cmd/vppn/main.go +++ b/cmd/vppn/main.go @@ -1,11 +1,72 @@ package main import ( + "flag" "log" + "os" + "path/filepath" + "strings" + "vppn/peer" + + "git.crumpington.com/lib/go/flock" ) func main() { log.SetFlags(0) - peer.Main2() + + name := flag.String("name", "", "network name (required)") + hub := flag.String("hub", "", "hub base URL (required)") + flag.Parse() + + if *name == "" || *hub == "" { + flag.Usage() + os.Exit(1) + } + + apiKey, err := loadAPIKey(*name) + if err != nil { + log.Fatalf("api key: %v", err) + } + + // Directory existence is guaranteed by the apikey file read above. + lockFile, err := flock.TryLock(vppnPath(*name, "lock")) + if err != nil { + log.Fatalf("lock: %v", err) + } + if lockFile == nil { + log.Fatalf("already running for network %q", *name) + } + defer flock.Unlock(lockFile) + + state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey) + if err != nil { + log.Fatalf("init: %v", err) + } + + ifaceName := strings.TrimSuffix(state.LocalDomain, ".local") + app, err := peer.New(state, *hub, apiKey, ifaceName, state.LocalDomain, vppnPath(*name, "network.json")) + if err != nil { + log.Fatalf("start: %v", err) + } + + if err := app.Run(); err != nil { + log.Fatalf("run: %v", err) + } +} + +func loadAPIKey(name string) (string, error) { + data, err := os.ReadFile(vppnPath(name, "apikey")) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +func vppnPath(name, file string) string { + home, err := os.UserHomeDir() + if err != nil { + return filepath.Join(".vppn", name, file) + } + return filepath.Join(home, ".vppn", name, file) } diff --git a/go.mod b/go.mod index ad3add6..e35d73d 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,18 @@ require ( git.crumpington.com/lib/go v0.9.1 golang.org/x/crypto v0.42.0 golang.org/x/sys v0.36.0 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 ) require ( + github.com/google/go-cmp v0.6.0 // indirect + github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect golang.org/x/net v0.44.0 // indirect + golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.29.0 // indirect + golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect ) diff --git a/go.sum b/go.sum index 0f444e0..9d7f776 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,30 @@ git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4= git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= diff --git a/hub/api/api.go b/hub/api/api.go index 7a534ca..19838b8 100644 --- a/hub/api/api.go +++ b/hub/api/api.go @@ -19,8 +19,10 @@ import ( var migrations embed.FS type API struct { - db *sql.DB - lock sync.Mutex + db *sql.DB + lock sync.Mutex + sessionsMu sync.Mutex + sessions map[string]*Session } func New(dbPath string) (*API, error) { @@ -34,10 +36,17 @@ func New(dbPath string) (*API, error) { } a := &API{ - db: sqlDB, + db: sqlDB, + sessions: make(map[string]*Session), } - return a, a.ensurePassword() + if err := a.ensurePassword(); err != nil { + return nil, err + } + + go a.sweepSessions() + + return a, nil } func (a *API) ensurePassword() error { @@ -62,12 +71,8 @@ func (a *API) ensurePassword() error { return db.Config_Insert(a.db, conf) } -func (a *API) Config_Get() *Config { - conf, err := db.Config_Get(a.db, 1) - if err != nil { - panic(err) - } - return conf +func (a *API) Config_Get() (*Config, error) { + return db.Config_Get(a.db, 1) } func (a *API) Config_Update(conf *Config) error { @@ -75,56 +80,78 @@ func (a *API) Config_Update(conf *Config) error { } func (a *API) Session_Delete(sessionID string) error { - return db.Session_Delete(a.db, sessionID) + a.sessionsMu.Lock() + defer a.sessionsMu.Unlock() + delete(a.sessions, sessionID) + return nil } -func (a *API) Session_Get(sessionID string) (*Session, error) { - if sessionID == "" { - return a.session_CreatePub() +const ( + sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use + sessionSweepEvery = time.Hour // cadence of expired-session eviction +) + +// Session_Get returns a snapshot copy of the signed-in session for sessionID, +// or the zero Session if the cookie is missing/unknown/expired. It never +// creates a session, so anonymous requests cost no memory — a session is minted +// only by Session_SignIn. Returning a value (not the stored pointer) keeps +// callers from racing on the shared struct. +func (a *API) Session_Get(sessionID string) (Session, error) { + a.sessionsMu.Lock() + defer a.sessionsMu.Unlock() + + s, ok := a.sessions[sessionID] + + if sessionID == "" || !ok { + return Session{}, nil } - session, err := db.Session_Get(a.db, sessionID) + if timeSince(s.LastSeenAt) > sessionTTLSecs { + delete(a.sessions, sessionID) + return Session{}, nil + } + + s.LastSeenAt = time.Now().Unix() + return *s, nil +} + +// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session, +// returning it so the caller can set the cookie. A new ID per sign-in rotates +// the session at the privilege boundary (session-fixation resistance). +func (a *API) Session_SignIn(pwd string) (Session, error) { + conf, err := a.Config_Get() if err != nil { - return a.session_CreatePub() + return Session{}, err + } + if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { + return Session{}, ErrNotAuthorized } - if timeSince(session.LastSeenAt) > 86400*21 { - return a.session_CreatePub() - } - - if timeSince(session.LastSeenAt) > 86400*7 { - session.LastSeenAt = time.Now().Unix() - if err := db.Session_UpdateLastSeenAt(a.db, session.SessionID); err != nil { - log.Printf("Failed to update session: %v", err) - } - } - - return session, nil -} - -func (a *API) session_CreatePub() (*Session, error) { + a.sessionsMu.Lock() + defer a.sessionsMu.Unlock() s := &Session{ SessionID: idgen.NewToken(), - CSRF: idgen.NewToken(), - SignedIn: false, + SignedIn: true, CreatedAt: time.Now().Unix(), LastSeenAt: time.Now().Unix(), } - err := db.Session_Insert(a.db, s) - return s, err + a.sessions[s.SessionID] = s + return *s, nil } -func (a *API) Session_DeleteBefore(timestamp int64) error { - return db.Session_DeleteBefore(a.db, timestamp) -} - -func (a *API) Session_SignIn(s *Session, pwd string) error { - conf := a.Config_Get() - if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { - return ErrNotAuthorized +// sweepSessions periodically evicts sessions past their TTL. Without it, a +// signed-in session whose ID is never presented again would linger forever +// (Session_Get only evicts on a lookup of that same ID). +func (a *API) sweepSessions() { + for range time.Tick(sessionSweepEvery) { + a.sessionsMu.Lock() + for id, s := range a.sessions { + if timeSince(s.LastSeenAt) > sessionTTLSecs { + delete(a.sessions, id) + } + } + a.sessionsMu.Unlock() } - - return db.Session_SetSignedIn(a.db, s.SessionID) } func (a *API) Network_Create(n *Network) error { @@ -141,14 +168,13 @@ func (a *API) Network_Get(id int64) (*Network, error) { } func (a *API) Network_List() ([]*Network, error) { - const query = db.Network_SelectQuery + ` ORDER BY Name ASC` + const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC` return db.Network_List(a.db, query) } func (a *API) Peer_CreateNew(p *Peer) error { - p.Version = idgen.NextID(0) - p.PubKey = []byte{} - p.PubSignKey = []byte{} + p.WGPubKey = []byte{} + p.SignPubKey = []byte{} p.APIKey = idgen.NewToken() return db.Peer_Insert(a.db, p) @@ -158,21 +184,22 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error { a.lock.Lock() defer a.lock.Unlock() - peer.Version = idgen.NextID(0) - peer.PubKey = args.EncPubKey - peer.PubSignKey = args.PubSignKey + // Re-read from DB inside the lock — the caller's copy was fetched before + // we held the lock, so it may be stale under concurrent requests. + current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP) + if err != nil { + return err + } + if len(current.WGPubKey) != 0 { + return errors.New("peer already initialized") + } + + peer.WGPubKey = args.WGPubKey + peer.SignPubKey = args.SignPubKey return db.Peer_UpdateFull(a.db, peer) } -func (a *API) Peer_Update(p *Peer) error { - a.lock.Lock() - defer a.lock.Unlock() - - p.Version = idgen.NextID(0) - return db.Peer_Update(a.db, p) -} - func (a *API) Peer_Delete(networkID int64, peerIP byte) error { return db.Peer_Delete(a.db, networkID, peerIP) } diff --git a/hub/api/db/generated.go b/hub/api/db/generated.go index 88aec6c..ab8e095 100644 --- a/hub/api/db/generated.go +++ b/hub/api/db/generated.go @@ -123,7 +123,9 @@ func Config_Get( ) { row = &Config{} r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID) - err = r.Scan(&row.ConfigID, &row.Password) + if err = r.Scan(&row.ConfigID, &row.Password); err != nil { + row = nil + } return } @@ -137,7 +139,9 @@ func Config_GetWhere( ) { row = &Config{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.ConfigID, &row.Password) + if err = r.Scan(&row.ConfigID, &row.Password); err != nil { + row = nil + } return } @@ -182,135 +186,17 @@ func Config_List( return l, nil } -// ---------------------------------------------------------------------------- -// Table: sessions -// ---------------------------------------------------------------------------- - -type Session struct { - SessionID string - CSRF string - SignedIn bool - CreatedAt int64 - LastSeenAt int64 -} - -const Session_SelectQuery = "SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions" - -func Session_Insert( - tx TX, - row *Session, -) (err error) { - Session_Sanitize(row) - if err = Session_Validate(row); err != nil { - return err - } - - _, err = tx.Exec("INSERT INTO sessions(SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt) VALUES(?,?,?,?,?)", row.SessionID, row.CSRF, row.SignedIn, row.CreatedAt, row.LastSeenAt) - return err -} - -func Session_Delete( - tx TX, - SessionID string, -) (err error) { - result, err := tx.Exec("DELETE FROM sessions WHERE SessionID=?", SessionID) - if err != nil { - return err - } - - n, err := result.RowsAffected() - if err != nil { - panic(err) - } - switch n { - case 0: - return sql.ErrNoRows - case 1: - return nil - default: - panic("multiple rows deleted") - } -} - -func Session_Get( - tx TX, - SessionID string, -) ( - row *Session, - err error, -) { - row = &Session{} - r := tx.QueryRow("SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions WHERE SessionID=?", SessionID) - err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt) - return -} - -func Session_GetWhere( - tx TX, - query string, - args ...any, -) ( - row *Session, - err error, -) { - row = &Session{} - r := tx.QueryRow(query, args...) - err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt) - return -} - -func Session_Iterate( - tx TX, - query string, - args ...any, -) iter.Seq2[*Session, error] { - rows, err := tx.Query(query, args...) - if err != nil { - return func(yield func(*Session, error) bool) { - yield(nil, err) - } - } - - return func(yield func(*Session, error) bool) { - defer rows.Close() - for rows.Next() { - row := &Session{} - err := rows.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt) - if !yield(row, err) { - return - } - } - } -} - -func Session_List( - tx TX, - query string, - args ...any, -) ( - l []*Session, - err error, -) { - for row, err := range Session_Iterate(tx, query, args...) { - if err != nil { - return nil, err - } - l = append(l, row) - } - return l, nil -} - // ---------------------------------------------------------------------------- // Table: networks // ---------------------------------------------------------------------------- type Network struct { - NetworkID int64 - Name string - Network []byte + NetworkID int64 + LocalDomain string + Network []byte } -const Network_SelectQuery = "SELECT NetworkID,Name,Network FROM networks" +const Network_SelectQuery = "SELECT NetworkID,LocalDomain,Network FROM networks" func Network_Insert( tx TX, @@ -321,7 +207,7 @@ func Network_Insert( return err } - _, err = tx.Exec("INSERT INTO networks(NetworkID,Name,Network) VALUES(?,?,?)", row.NetworkID, row.Name, row.Network) + _, err = tx.Exec("INSERT INTO networks(NetworkID,LocalDomain,Network) VALUES(?,?,?)", row.NetworkID, row.LocalDomain, row.Network) return err } @@ -334,7 +220,7 @@ func Network_UpdateFull( return err } - result, err := tx.Exec("UPDATE networks SET Name=?,Network=? WHERE NetworkID=?", row.Name, row.Network, row.NetworkID) + result, err := tx.Exec("UPDATE networks SET LocalDomain=?,Network=? WHERE NetworkID=?", row.LocalDomain, row.Network, row.NetworkID) if err != nil { return err } @@ -384,8 +270,10 @@ func Network_Get( err error, ) { row = &Network{} - r := tx.QueryRow("SELECT NetworkID,Name,Network FROM networks WHERE NetworkID=?", NetworkID) - err = r.Scan(&row.NetworkID, &row.Name, &row.Network) + r := tx.QueryRow("SELECT NetworkID,LocalDomain,Network FROM networks WHERE NetworkID=?", NetworkID) + if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil { + row = nil + } return } @@ -399,7 +287,9 @@ func Network_GetWhere( ) { row = &Network{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.NetworkID, &row.Name, &row.Network) + if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil { + row = nil + } return } @@ -419,7 +309,7 @@ func Network_Iterate( defer rows.Close() for rows.Next() { row := &Network{} - err := rows.Scan(&row.NetworkID, &row.Name, &row.Network) + err := rows.Scan(&row.NetworkID, &row.LocalDomain, &row.Network) if !yield(row, err) { return } @@ -451,17 +341,17 @@ func Network_List( type Peer struct { NetworkID int64 PeerIP byte - Version int64 APIKey string Name string - PublicIP []byte + Addr4 []byte + Addr6 []byte Port uint16 Relay bool - PubKey []byte - PubSignKey []byte + WGPubKey []byte + SignPubKey []byte } -const Peer_SelectQuery = "SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers" +const Peer_SelectQuery = "SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers" func Peer_Insert( tx TX, @@ -472,38 +362,10 @@ func Peer_Insert( return err } - _, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey) + _, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey) return err } -func Peer_Update( - tx TX, - row *Peer, -) (err error) { - Peer_Sanitize(row) - if err = Peer_Validate(row); err != nil { - return err - } - - result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.NetworkID, row.PeerIP) - if err != nil { - return err - } - - n, err := result.RowsAffected() - if err != nil { - panic(err) - } - switch n { - case 0: - return sql.ErrNoRows - case 1: - return nil - default: - panic("multiple rows updated") - } -} - func Peer_UpdateFull( tx TX, row *Peer, @@ -513,7 +375,7 @@ func Peer_UpdateFull( return err } - result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=?,PubSignKey=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey, row.NetworkID, row.PeerIP) + result, err := tx.Exec("UPDATE peers SET APIKey=?,Name=?,Addr4=?,Addr6=?,Port=?,Relay=?,WGPubKey=?,SignPubKey=? WHERE NetworkID=? AND PeerIP=?", row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey, row.NetworkID, row.PeerIP) if err != nil { return err } @@ -565,8 +427,10 @@ func Peer_Get( err error, ) { row = &Peer{} - r := tx.QueryRow("SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP) - err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) + r := tx.QueryRow("SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP) + if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil { + row = nil + } return } @@ -580,7 +444,9 @@ func Peer_GetWhere( ) { row = &Peer{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) + if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil { + row = nil + } return } @@ -600,7 +466,7 @@ func Peer_Iterate( defer rows.Close() for rows.Next() { row := &Peer{} - err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) + err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey) if !yield(row, err) { return } diff --git a/hub/api/db/sanitize-validate.go b/hub/api/db/sanitize-validate.go index 71785e9..ffe1f7d 100644 --- a/hub/api/db/sanitize-validate.go +++ b/hub/api/db/sanitize-validate.go @@ -8,9 +8,11 @@ import ( var ( ErrInvalidIP = errors.New("invalid IP") + ErrInvalidPeerIP = errors.New("invalid peer IP") ErrNonPrivateIP = errors.New("non-private IP") ErrInvalidPort = errors.New("invalid port") ErrInvalidNetName = errors.New("invalid network name") + ErrNetNameNotLocal = errors.New("network name must end with .local") ErrInvalidPeerName = errors.New("invalid peer name") ) @@ -21,15 +23,8 @@ func Config_Validate(c *Config) error { return nil } -func Session_Sanitize(s *Session) { -} - -func Session_Validate(s *Session) error { - return nil -} - func Network_Sanitize(n *Network) { - n.Name = strings.TrimSpace(n.Name) + n.LocalDomain = strings.TrimSpace(n.LocalDomain) if addr, ok := netip.AddrFromSlice(n.Network); ok { n.Network = addr.AsSlice() @@ -37,12 +32,17 @@ func Network_Sanitize(n *Network) { } func Network_Validate(c *Network) error { - // 16 bytes is linux limit for network interface names. - if len(c.Name) == 0 || len(c.Name) > 16 { + // 15 bytes is linux limit for network interface names. With ending .local, + // max length is 21. + if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 { return ErrInvalidNetName } - for _, c := range c.Name { + if !strings.HasSuffix(c.LocalDomain, ".local") { + return ErrNetNameNotLocal + } + + for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") { if c >= 'a' && c <= 'z' { continue } @@ -66,21 +66,35 @@ func Network_Validate(c *Network) error { func Peer_Sanitize(p *Peer) { p.Name = strings.TrimSpace(p.Name) - if len(p.PublicIP) != 0 { - addr, ok := netip.AddrFromSlice(p.PublicIP) - if ok && addr.Is4() { - p.PublicIP = addr.AsSlice() + if len(p.Addr4) != 0 { + if addr, ok := netip.AddrFromSlice(p.Addr4); ok { + // Unmap so an IPv4-mapped form is stored canonically as 4 bytes. + p.Addr4 = addr.Unmap().AsSlice() + } + } + if len(p.Addr6) != 0 { + if addr, ok := netip.AddrFromSlice(p.Addr6); ok { + p.Addr6 = addr.AsSlice() } } if p.Port == 0 { - p.Port = 456 + p.Port = 51820 } } func Peer_Validate(p *Peer) error { - if len(p.PublicIP) > 0 { - _, ok := netip.AddrFromSlice(p.PublicIP) - if !ok { + if p.PeerIP < 1 || p.PeerIP > 254 { + return ErrInvalidPeerIP + } + if len(p.Addr4) > 0 { + // Must be a genuine IPv4 address (reject an IPv6 in the v4 field). + if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() { + return ErrInvalidIP + } + } + if len(p.Addr6) > 0 { + // Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field). + if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() { return ErrInvalidIP } } @@ -88,6 +102,9 @@ func Peer_Validate(p *Peer) error { return ErrInvalidPort } + if len(p.Name) == 0 { + return ErrInvalidPeerName + } for _, c := range p.Name { if c >= 'a' && c <= 'z' { continue @@ -95,10 +112,9 @@ func Peer_Validate(p *Peer) error { if c >= '0' && c <= '9' { continue } - if c == '.' || c == '-' || c == '_' { + if c == '-' { continue } - return ErrInvalidPeerName } diff --git a/hub/api/db/tables.defs b/hub/api/db/tables.defs index d6dc338..39665fa 100644 --- a/hub/api/db/tables.defs +++ b/hub/api/db/tables.defs @@ -3,29 +3,21 @@ TABLE config OF Config ( Password []byte ); -TABLE sessions OF Session NoUpdate ( - SessionID string PK, - CSRF string, - SignedIn bool, - CreatedAt int64, - LastSeenAt int64 -); - TABLE networks OF Network ( - NetworkID int64 PK, - Name string NoUpdate, - Network []byte NoUpdate + NetworkID int64 PK, + LocalDomain string NoUpdate, + Network []byte NoUpdate ); TABLE peers OF Peer ( NetworkID int64 PK, PeerIP byte PK, - Version int64, APIKey string NoUpdate, - Name string, - PublicIP []byte, - Port uint16, - Relay bool, - PubKey []byte NoUpdate, - PubSignKey []byte NoUpdate + Name string NoUpdate, + Addr4 []byte NoUpdate, + Addr6 []byte NoUpdate, + Port uint16 NoUpdate, + Relay bool NoUpdate, + WGPubKey []byte NoUpdate, + SignPubKey []byte NoUpdate ); diff --git a/hub/api/db/written.go b/hub/api/db/written.go index 6d61bb5..11251f2 100644 --- a/hub/api/db/written.go +++ b/hub/api/db/written.go @@ -1,31 +1,5 @@ package db -import "time" - -func Session_UpdateLastSeenAt( - tx TX, - id string, -) (err error) { - _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id) - return err -} - -func Session_SetSignedIn( - tx TX, - id string, -) (err error) { - _, err = tx.Exec("UPDATE sessions SET SignedIn=1 WHERE SessionID=?", id) - return err -} - -func Session_DeleteBefore( - tx TX, - timestamp int64, -) (err error) { - _, err = tx.Exec("DELETE FROM sessions WHERE LastSeenAtCreate Network
- -

-
- +

+
+


diff --git a/hub/templates/admin-network-list.html b/hub/templates/admin-network-list.html index 3626d6f..43d4036 100644 --- a/hub/templates/admin-network-list.html +++ b/hub/templates/admin-network-list.html @@ -9,7 +9,7 @@ - + @@ -18,7 +18,7 @@ diff --git a/hub/templates/admin-password-edit.html b/hub/templates/admin-password-edit.html index 32c8d43..e65d2d4 100644 --- a/hub/templates/admin-password-edit.html +++ b/hub/templates/admin-password-edit.html @@ -2,8 +2,7 @@

Change Password

- -

+


diff --git a/hub/templates/admin-sign-out.html b/hub/templates/admin-sign-out.html index 7141fb8..fd5aa19 100644 --- a/hub/templates/admin-sign-out.html +++ b/hub/templates/admin-sign-out.html @@ -2,8 +2,7 @@

Sign Out

- -

+

Cancel

diff --git a/hub/templates/network/base.html b/hub/templates/network/base.html index b773baf..3e848f6 100644 --- a/hub/templates/network/base.html +++ b/hub/templates/network/base.html @@ -17,7 +17,7 @@

Network: - {{.Network.Name}} + {{.Network.LocalDomain}}

{{block "body" .}}There's nothing here.{{end}} diff --git a/hub/templates/network/network-delete.html b/hub/templates/network/network-delete.html index 8b61116..f80bc79 100644 --- a/hub/templates/network/network-delete.html +++ b/hub/templates/network/network-delete.html @@ -5,8 +5,7 @@

You must first delete all peers.

{{- else -}} - - +

Cancel diff --git a/hub/templates/network/network-view.html b/hub/templates/network/network-view.html index 860814e..dc3f22a 100644 --- a/hub/templates/network/network-view.html +++ b/hub/templates/network/network-view.html @@ -22,7 +22,8 @@

- + + @@ -36,7 +37,8 @@ - + + diff --git a/hub/templates/network/peer-create.html b/hub/templates/network/peer-create.html index cc6b92f..303cb6e 100644 --- a/hub/templates/network/peer-create.html +++ b/hub/templates/network/peer-create.html @@ -2,7 +2,6 @@

New Peer

-


@@ -13,12 +12,16 @@

-
- +
+

-
- +
+ +

+

+
+

Cancel diff --git a/hub/templates/network/peer-edit.html b/hub/templates/network/peer-edit.html deleted file mode 100644 index a27f674..0000000 --- a/hub/templates/network/peer-edit.html +++ /dev/null @@ -1,35 +0,0 @@ -{{define "body" -}} -

Edit Peer

- -{{with .Peer -}} - - -

-
- -

-

-
- -

-

-
- -

-

-
- -

-

- -

-

- - Cancel -

- -{{- end}} -{{- end}} diff --git a/hub/templates/network/peer-view.html b/hub/templates/network/peer-view.html index 546f69b..11bf105 100644 --- a/hub/templates/network/peer-view.html +++ b/hub/templates/network/peer-view.html @@ -1,17 +1,17 @@ {{define "body" -}}

{{.Peer.Name}}

- Edit / Delete

{{with .Peer -}}
NameLocal Domain Network
- {{.Name}} + {{.LocalDomain}} {{ipToString .Network}}
PeerIP NamePublic IPIPv4IPv6 Port Relay
{{.Name}}{{ipToString .PublicIP}}{{ipToString .Addr4}}{{ipToString .Addr6}} {{.Port}} {{if .Relay}}T{{else}}F{{end}}
- - + + + - +
Peer IP{{.PeerIP}}
Public IP{{ipToString .PublicIP}}
Port{{.Port}}
IPv4 Address{{ipToString .Addr4}}
IPv6 Address{{ipToString .Addr6}}
WireGuard Port{{.Port}}
Relay{{if .Relay}}T{{else}}F{{end}}
WG Public Key{{wgKeyString .WGPubKey}}

@@ -19,7 +19,6 @@

{{.APIKey}}

- {{- end}} {{- end}} diff --git a/hub/templates/sign-in.html b/hub/templates/sign-in.html index 07e9d87..6d885ad 100644 --- a/hub/templates/sign-in.html +++ b/hub/templates/sign-in.html @@ -2,8 +2,7 @@

Sign In

- -

+


diff --git a/hub/util.go b/hub/util.go index 5503cc9..bf70738 100644 --- a/hub/util.go +++ b/hub/util.go @@ -38,6 +38,19 @@ func (app *App) sendJSON(w http.ResponseWriter, data any) error { return nil } +// addrFromBytes parses raw IP bytes (4 or 16) into a netip.Addr, unmapping +// IPv4-in-IPv6, returning the zero Addr for empty/invalid input. +func addrFromBytes(b []byte) netip.Addr { + if len(b) == 0 { + return netip.Addr{} + } + addr, ok := netip.AddrFromSlice(b) + if !ok { + return netip.Addr{} + } + return addr.Unmap() +} + func stringToIP(in string) ([]byte, error) { in = strings.TrimSpace(in) if len(in) == 0 { diff --git a/m/models.go b/m/models.go index 0bac684..386401c 100644 --- a/m/models.go +++ b/m/models.go @@ -1,28 +1,133 @@ // The package `m` contains models shared between the hub and peer programs. package m +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/netip" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + type PeerInitArgs struct { - EncPubKey []byte - PubSignKey []byte + WGPubKey []byte + SignPubKey []byte } type PeerInitResp struct { PeerIP byte Network []byte + LocalDomain string NetworkState NetworkState } +// Peer is the network membership record for a single peer, exchanged between +// the hub and peers. Addr4/Addr6 are the peer's public endpoint addresses (zero +// if it has none); Port is its WireGuard listen port, meaningful even for a +// non-public peer (it is the peer's own bind/beacon port). type Peer struct { PeerIP byte - Version int64 Name string - PublicIP []byte + Addr4 netip.Addr // zero if none + Addr6 netip.Addr // zero if none Port uint16 Relay bool - PubKey []byte - PubSignKey []byte + WGPubKey wgtypes.Key + SignPubKey [32]byte +} + +// IsPublic reports whether the peer advertises at least one reachable endpoint. +func (p Peer) IsPublic() bool { + return p.Addr4.IsValid() || p.Addr6.IsValid() +} + +// Endpoint4 returns the IPv4 endpoint (addr+port), or the zero AddrPort if the +// peer has no IPv4 address. +func (p Peer) Endpoint4() netip.AddrPort { + if !p.Addr4.IsValid() { + return netip.AddrPort{} + } + return netip.AddrPortFrom(p.Addr4, p.Port) +} + +// Endpoint6 returns the IPv6 endpoint (addr+port), or the zero AddrPort if the +// peer has no IPv6 address. +func (p Peer) Endpoint6() netip.AddrPort { + if !p.Addr6.IsValid() { + return netip.AddrPort{} + } + return netip.AddrPortFrom(p.Addr6, p.Port) +} + +// PreferredEndpoint returns the IPv4 endpoint if present, else IPv6. +func (p Peer) PreferredEndpoint() netip.AddrPort { + if ep := p.Endpoint4(); ep.IsValid() { + return ep + } + return p.Endpoint6() +} + +// peerJSON is the wire representation. netip.Addr fields round-trip as text +// strings automatically; only the fixed-size key arrays need base64 (otherwise +// encoding/json would emit them as arrays of numbers). +type peerJSON struct { + PeerIP byte + Name string + Addr4 netip.Addr + Addr6 netip.Addr + Port uint16 + Relay bool + WGPubKey string + SignPubKey string +} + +func (p Peer) MarshalJSON() ([]byte, error) { + return json.Marshal(peerJSON{ + PeerIP: p.PeerIP, + Name: p.Name, + Addr4: p.Addr4, + Addr6: p.Addr6, + Port: p.Port, + Relay: p.Relay, + WGPubKey: base64.StdEncoding.EncodeToString(p.WGPubKey[:]), + SignPubKey: base64.StdEncoding.EncodeToString(p.SignPubKey[:]), + }) +} + +func (p *Peer) UnmarshalJSON(data []byte) error { + var j peerJSON + if err := json.Unmarshal(data, &j); err != nil { + return err + } + wg, err := base64.StdEncoding.DecodeString(j.WGPubKey) + if err != nil { + return fmt.Errorf("decode WGPubKey: %w", err) + } + key, err := wgtypes.NewKey(wg) + if err != nil { + return fmt.Errorf("invalid WGPubKey: %w", err) + } + sign, err := base64.StdEncoding.DecodeString(j.SignPubKey) + if err != nil { + return fmt.Errorf("decode SignPubKey: %w", err) + } + if len(sign) != 32 { + return fmt.Errorf("invalid SignPubKey length: %d", len(sign)) + } + *p = Peer{ + PeerIP: j.PeerIP, + Name: j.Name, + Addr4: j.Addr4, + Addr6: j.Addr6, + Port: j.Port, + Relay: j.Relay, + WGPubKey: key, + SignPubKey: [32]byte(sign), + } + return nil } type NetworkState struct { - Peers [256]*Peer + Peers []Peer } diff --git a/peer/app.go b/peer/app.go new file mode 100644 index 0000000..d113ffb --- /dev/null +++ b/peer/app.go @@ -0,0 +1,105 @@ +package peer + +import ( + "net/netip" + "os" + "os/signal" + "syscall" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" + "vppn/peer/control" + "vppn/peer/multicast" + "vppn/peer/wginterface" +) + +var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisfies WGDevice + +const ( + ControlPort = 4561 + PingInterval = 8 * time.Second + TimeoutInterval = 30 * time.Second +) + +// scratchSize is large enough for the biggest buffer either the ping or the +// multicast path serializes through the shared App scratch. +const scratchSize = max(control.Size, multicast.SignedPacketSize) + +type PingEvent struct { + srcVPNIP netip.Addr + ping control.Ping +} + +// App is the peer application. All mutable state lives here and is +// accessed only from the Run goroutine. +type App struct { + // Identity + vpnIP netip.Addr + vpnNet netip.Prefix + privKey wgtypes.Key + pubKey wgtypes.Key + isRelay bool + isPublic bool + localDomain string + + // Infrastructure + dev WGDevice + controlConn ControlConn + + // Peer state + relay *Peer + peersByKey map[wgtypes.Key]*Peer + peersByIP map[netip.Addr]*Peer + + // Our own external endpoints, learned from Dst fields in incoming pings + selfV4 netip.AddrPort + selfV6 netip.AddrPort + + // Reusable serialization scratch for outgoing pings and multicast signature + // verification. Only touched from the Run goroutine. + scratch []byte + + // Event channels fed by background goroutines + hubAddCh <-chan m.Peer + hubRemoveCh <-chan wgtypes.Key + pingCh <-chan PingEvent + multicastCh <-chan multicast.Packet +} + +// Run is the main event loop. It runs until SIGTERM/SIGINT. +func (a *App) Run() error { + // Establish a clean hosts section before the first poll lands, clearing + // any stale entries left by a prior run (e.g. crash, or peers removed + // while we were down). + a.updateHosts() + + ticker := time.NewTicker(PingInterval) + defer ticker.Stop() + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(sig) + + for { + select { + case p := <-a.hubAddCh: + a.onAddPeer(p) + case key := <-a.hubRemoveCh: + a.onRemovePeer(key) + case e := <-a.pingCh: + a.onPing(e) + case e := <-a.multicastCh: + a.onMulticastDiscovery(e) + case <-ticker.C: + a.onTick() + case <-sig: + return a.onShutdown() + } + } +} + +func (a *App) onShutdown() error { + return wginterface.Delete(a.dev.Name()) +} diff --git a/peer/app_test.go b/peer/app_test.go new file mode 100644 index 0000000..6564ef9 --- /dev/null +++ b/peer/app_test.go @@ -0,0 +1,62 @@ +package peer + +import ( + "net/netip" + "testing" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" + "vppn/peer/multicast" +) + +// addRelayPeer adds a public relay peer and marks it Up so it satisfies +// CanRelay. It does not set a.relay — callers do that explicitly. +func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer { + t.Helper() + key := mustKey(t) + ip := netip.MustParseAddr(vpnIP) + a.onAddPeer(m.Peer{ + WGPubKey: key, + PeerIP: ip.As4()[3], + Addr4: ep.Addr(), + Port: ep.Port(), + Relay: true, + }) + p := a.peersByKey[key] + p.wgPeer.LastHandshakeTime = time.Now() + return p +} + +// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn. +// vpnIP is the local VPN address (e.g. "10.0.0.1"). +// isPublic / isRelay describe the local node's role. +func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) { + t.Helper() + privKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatalf("generate key: %v", err) + } + ip := netip.MustParseAddr(vpnIP) + dev := &fakeWGDevice{} + cc := &fakeControlConn{} + a := &App{ + vpnIP: ip, + vpnNet: netip.MustParsePrefix("10.0.0.0/24"), + privKey: privKey, + pubKey: privKey.PublicKey(), + isPublic: isPublic, + isRelay: isRelay, + dev: dev, + controlConn: cc, + peersByKey: make(map[wgtypes.Key]*Peer), + peersByIP: make(map[netip.Addr]*Peer), + scratch: make([]byte, scratchSize), + hubAddCh: make(chan m.Peer), + hubRemoveCh: make(chan wgtypes.Key), + pingCh: make(chan PingEvent), + multicastCh: make(chan multicast.Packet), + } + return a, dev, cc +} diff --git a/peer/bitset.go b/peer/bitset.go deleted file mode 100644 index 8d03b50..0000000 --- a/peer/bitset.go +++ /dev/null @@ -1,21 +0,0 @@ -package peer - -const bitSetSize = 512 // Multiple of 64. - -type bitSet [bitSetSize / 64]uint64 - -func (bs *bitSet) Set(i int) { - bs[i/64] |= 1 << (i % 64) -} - -func (bs *bitSet) Clear(i int) { - bs[i/64] &= ^(1 << (i % 64)) -} - -func (bs *bitSet) ClearAll() { - clear(bs[:]) -} - -func (bs *bitSet) Get(i int) bool { - return bs[i/64]&(1<<(i%64)) != 0 -} diff --git a/peer/bitset_test.go b/peer/bitset_test.go deleted file mode 100644 index 01ae82b..0000000 --- a/peer/bitset_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package peer - -import ( - "math/rand" - "testing" -) - -func TestBitSet(t *testing.T) { - state := make([]bool, bitSetSize) - for i := range state { - state[i] = rand.Float32() > 0.5 - } - - bs := bitSet{} - - for i := range state { - if state[i] { - bs.Set(i) - } - } - - for i := range state { - if bs.Get(i) != state[i] { - t.Fatal(i, state[i], bs.Get(i)) - } - } - - for i := range state { - if rand.Float32() > 0.5 { - state[i] = false - bs.Clear(i) - } - } - - for i := range state { - if bs.Get(i) != state[i] { - t.Fatal(i, state[i], bs.Get(i)) - } - } - - bs.ClearAll() - - for i := range state { - if bs.Get(i) { - t.Fatal(i, bs.Get(i)) - } - } -} diff --git a/peer/cipher-control.go b/peer/cipher-control.go deleted file mode 100644 index 178ff97..0000000 --- a/peer/cipher-control.go +++ /dev/null @@ -1,26 +0,0 @@ -package peer - -import "golang.org/x/crypto/nacl/box" - -type controlCipher struct { - sharedKey [32]byte -} - -func newControlCipher(privKey, pubKey []byte) *controlCipher { - shared := [32]byte{} - box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) - return &controlCipher{shared} -} - -func (cc *controlCipher) Encrypt(h Header, data, out []byte) []byte { - const s = controlHeaderSize - out = out[:s+controlCipherOverhead+len(data)] - h.Marshal(out[:s]) - box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) - return out -} - -func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { - const s = controlHeaderSize - return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) -} diff --git a/peer/cipher-control_test.go b/peer/cipher-control_test.go deleted file mode 100644 index abeb8d5..0000000 --- a/peer/cipher-control_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package peer - -import ( - "bytes" - "crypto/rand" - "reflect" - "testing" - - "golang.org/x/crypto/nacl/box" -) - -func newControlCipherForTesting() (c1, c2 *controlCipher) { - pubKey1, privKey1, err := box.GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - - pubKey2, privKey2, err := box.GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - - return newControlCipher(privKey1[:], pubKey2[:]), - newControlCipher(privKey2[:], pubKey1[:]) -} - -func TestControlCipher(t *testing.T) { - c1, c2 := newControlCipherForTesting() - - maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) - rand.Read(maxSizePlaintext) - - testCases := [][]byte{ - make([]byte, 0), - {1}, - {255}, - {1, 2, 3, 4, 5}, - []byte("Hello world"), - maxSizePlaintext, - } - - for _, plaintext := range testCases { - h1 := Header{ - StreamID: controlStreamID, - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - encrypted := make([]byte, bufferSize) - - encrypted = c1.Encrypt(h1, plaintext, encrypted) - - h2 := Header{} - h2.Parse(encrypted) - if !reflect.DeepEqual(h1, h2) { - t.Fatal(h1, h2) - } - - decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(decrypted, plaintext) { - t.Fatal("not equal") - } - } -} - -func TestControlCipher_ShortCiphertext(t *testing.T) { - c1, _ := newControlCipherForTesting() - shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) - rand.Read(shortText) - _, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) - if ok { - t.Fatal(ok) - } -} - -func BenchmarkControlCipher_Encrypt(b *testing.B) { - c1, _ := newControlCipherForTesting() - h1 := Header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - encrypted = c1.Encrypt(h1, plaintext, encrypted) - } -} - -func BenchmarkControlCipher_Decrypt(b *testing.B) { - c1, c2 := newControlCipherForTesting() - - h1 := Header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - encrypted = c1.Encrypt(h1, plaintext, encrypted) - - decrypted := make([]byte, bufferSize) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - decrypted, _ = c2.Decrypt(encrypted, decrypted) - } -} diff --git a/peer/cipher-data.go b/peer/cipher-data.go deleted file mode 100644 index 5ce8555..0000000 --- a/peer/cipher-data.go +++ /dev/null @@ -1,61 +0,0 @@ -package peer - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "log" -) - -type dataCipher struct { - key [32]byte - aead cipher.AEAD -} - -func newDataCipher() *dataCipher { - key := [32]byte{} - if _, err := rand.Read(key[:]); err != nil { - log.Fatalf("Failed to read random data: %v", err) - } - return newDataCipherFromKey(key) -} - -func newDataCipherFromKey(key [32]byte) *dataCipher { - block, err := aes.NewCipher(key[:]) - if err != nil { - log.Fatalf("Failed to create new cipher: %v", err) - } - - aead, err := cipher.NewGCM(block) - if err != nil { - log.Fatalf("Failed to create new GCM: %v", err) - } - - return &dataCipher{key: key, aead: aead} -} - -func (sc *dataCipher) Key() [32]byte { - return sc.key -} - -func (sc *dataCipher) Encrypt(h Header, data, out []byte) []byte { - const s = dataHeaderSize - out = out[:s+dataCipherOverhead+len(data)] - h.Marshal(out[:s]) - sc.aead.Seal(out[s:s], out[:s], data, nil) - return out -} - -func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { - const s = dataHeaderSize - if len(encrypted) < s+dataCipherOverhead { - ok = false - return - } - - var err error - - data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) - ok = err == nil - return -} diff --git a/peer/cipher-data_test.go b/peer/cipher-data_test.go deleted file mode 100644 index 4a388f8..0000000 --- a/peer/cipher-data_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package peer - -import ( - "bytes" - "crypto/rand" - mrand "math/rand/v2" - "reflect" - "testing" -) - -func TestDataCipher(t *testing.T) { - maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(maxSizePlaintext) - - testCases := [][]byte{ - make([]byte, 0), - {1}, - {255}, - {1, 2, 3, 4, 5}, - []byte("Hello world"), - maxSizePlaintext, - } - - for _, plaintext := range testCases { - h1 := Header{ - StreamID: dataStreamID, - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - h2 := Header{} - h2.Parse(encrypted) - - dc2 := newDataCipherFromKey(dc1.Key()) - - decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) - if !ok { - t.Fatal(ok) - } - - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("not equal") - } - - if !reflect.DeepEqual(h1, h2) { - t.Fatalf("%v != %v", h1, h2) - } - } -} - -func TestDataCipher_ModifyCiphertext(t *testing.T) { - maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(maxSizePlaintext) - - testCases := [][]byte{ - make([]byte, 0), - {1}, - {255}, - {1, 2, 3, 4, 5}, - []byte("Hello world"), - maxSizePlaintext, - } - - for _, plaintext := range testCases { - h1 := Header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - encrypted[mrand.IntN(len(encrypted))]++ - - dc2 := newDataCipherFromKey(dc1.Key()) - - _, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) - if ok { - t.Fatal(ok) - } - } -} - -func TestDataCipher_ShortCiphertext(t *testing.T) { - dc1 := newDataCipher() - shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) - rand.Read(shortText) - _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) - if ok { - t.Fatal(ok) - } -} - -func BenchmarkDataCipher_Encrypt(b *testing.B) { - h1 := Header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - b.ResetTimer() - for i := 0; i < b.N; i++ { - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - } -} - -func BenchmarkDataCipher_Decrypt(b *testing.B) { - h1 := Header{ - Counter: 235153, - SourceIP: 4, - DestIP: 88, - } - - plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) - rand.Read(plaintext) - - encrypted := make([]byte, bufferSize) - - dc1 := newDataCipher() - encrypted = dc1.Encrypt(h1, plaintext, encrypted) - - decrypted := make([]byte, bufferSize) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - decrypted, _ = dc1.Decrypt(encrypted, decrypted) - } -} diff --git a/peer/connreader.go b/peer/connreader.go deleted file mode 100644 index 5427227..0000000 --- a/peer/connreader.go +++ /dev/null @@ -1,46 +0,0 @@ -package peer - -import ( - "log" - "net" - "net/netip" -) - -type ConnReader struct { - Globals - conn *net.UDPConn - buf []byte -} - -func NewConnReader(g Globals, conn *net.UDPConn) *ConnReader { - return &ConnReader{ - Globals: g, - conn: conn, - buf: make([]byte, bufferSize), - } -} - -func (r *ConnReader) Run() { - for { - r.handleNextPacket() - } -} - -func (r *ConnReader) handleNextPacket() { - buf := r.buf[:bufferSize] - n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf) - if err != nil { - log.Fatalf("Failed to read from UDP port: %v", err) - } - - if n < headerSize { - return - } - - remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) - - buf = buf[:n] - h := parseHeader(buf) - - r.RemotePeers[h.SourceIP].Load().HandlePacket(h, remoteAddr, buf) -} diff --git a/peer/control/ping.go b/peer/control/ping.go new file mode 100644 index 0000000..8d0826e --- /dev/null +++ b/peer/control/ping.go @@ -0,0 +1,76 @@ +// Package control implements the VPN-internal peer control protocol. +// Peers exchange Ping packets over UDP on the VPN control port to maintain +// liveness and discover external endpoints for direct connection attempts. +package control + +import ( + "encoding/binary" + "fmt" + "net/netip" +) + +const ( + version = 1 + Size = 51 // 1 version + 8 PingTS + 6 SrcV4 + 18 SrcV6 + 18 Dst +) + +// Ping is the single control packet type exchanged between VPN peers. +// +// In each peer pair, the peer with the lower VPN IP is the client: it sets +// PingTS and sends pings on a timer. The server echoes PingTS back in its +// response, allowing the client to compute RTT = now - PingTS. +// +// Both client and server populate SrcV4, SrcV6, and Dst on every packet so +// endpoint information flows in both directions. +// +// Dst is the recipient's external endpoint as observed by the sender from the +// WireGuard handshake source. Zero if the sender has not observed a handshake +// from the recipient. +type Ping struct { + PingTS int64 // Client ping send time in nanoseconds. + SrcV4 netip.AddrPort // Sender's discovered IPv4 address and port. + SrcV6 netip.AddrPort // Sender's discovered IPv6 address and port. + Dst netip.AddrPort +} + +// Marshal encodes p into buf (which must be at least Size bytes) and returns +// buf[:Size]. Taking the buffer lets callers reuse one across sends; every +// field is written unconditionally so a reused buffer needs no pre-zeroing. +func (p Ping) Marshal(buf []byte) []byte { + buf[0] = version + binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS)) + if p.SrcV4.IsValid() { + a4 := p.SrcV4.Addr().As4() + copy(buf[9:13], a4[:]) + binary.BigEndian.PutUint16(buf[13:15], p.SrcV4.Port()) + } else { + clear(buf[9:15]) + } + a16 := p.SrcV6.Addr().As16() + copy(buf[15:31], a16[:]) + binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port()) + a16 = p.Dst.Addr().As16() + copy(buf[33:49], a16[:]) + binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port()) + return buf[:Size] +} + +// Unmarshal decodes a Ping from a fixed-size 51-byte array. +func Unmarshal(buf [Size]byte) (Ping, error) { + if buf[0] != version { + return Ping{}, fmt.Errorf("unknown ping version %d", buf[0]) + } + p := Ping{ + PingTS: int64(binary.BigEndian.Uint64(buf[1:9])), + } + if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() { + p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15])) + } + if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() { + p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33])) + } + if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() { + p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51])) + } + return p, nil +} diff --git a/peer/control/ping_test.go b/peer/control/ping_test.go new file mode 100644 index 0000000..df1af53 --- /dev/null +++ b/peer/control/ping_test.go @@ -0,0 +1,106 @@ +package control_test + +import ( + "net/netip" + "testing" + + "vppn/peer/control" +) + +func TestRoundTrip(t *testing.T) { + cases := []struct { + name string + ping control.Ping + }{ + { + name: "zero", + ping: control.Ping{}, + }, + { + name: "client ping", + ping: control.Ping{ + PingTS: 1234567890, + SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"), + Dst: netip.MustParseAddrPort("5.6.7.8:51820"), + }, + }, + { + name: "server response", + ping: control.Ping{ + PingTS: 1234567890, + SrcV4: netip.MustParseAddrPort("5.6.7.8:51820"), + Dst: netip.MustParseAddrPort("1.2.3.4:9999"), + }, + }, + { + name: "IPv6 only", + ping: control.Ping{ + PingTS: 999, + SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"), + Dst: netip.MustParseAddrPort("[2001:db8::2]:51820"), + }, + }, + { + name: "dual stack", + ping: control.Ping{ + PingTS: 555, + SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"), + SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"), + Dst: netip.MustParseAddrPort("5.6.7.8:9999"), + }, + }, + { + name: "no src known", + ping: control.Ping{ + Dst: netip.MustParseAddrPort("5.6.7.8:51820"), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var buf [control.Size]byte + tc.ping.Marshal(buf[:]) + got, err := control.Unmarshal(buf) + if err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if got != tc.ping { + t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, tc.ping) + } + }) + } +} + +func TestUnmarshalBadVersion(t *testing.T) { + var buf [control.Size]byte + buf[0] = 99 + if _, err := control.Unmarshal(buf); err == nil { + t.Fatal("expected error for unknown version, got nil") + } +} + +func TestZeroEncoding(t *testing.T) { + var buf [control.Size]byte + (control.Ping{}).Marshal(buf[:]) + for i, b := range buf { + if i == 0 { + continue // version byte + } + if b != 0 { + t.Fatalf("expected zero encoding at byte %d, got %d", i, b) + } + } +} + +func TestRoleFor(t *testing.T) { + lo := netip.MustParseAddr("10.0.0.1") + hi := netip.MustParseAddr("10.0.0.2") + + if control.RoleFor(lo, hi) != control.Client { + t.Error("lower IP should be client") + } + if control.RoleFor(hi, lo) != control.Server { + t.Error("higher IP should be server") + } +} diff --git a/peer/control/role.go b/peer/control/role.go new file mode 100644 index 0000000..8dbd2cd --- /dev/null +++ b/peer/control/role.go @@ -0,0 +1,22 @@ +package control + +import "net/netip" + +// Role identifies a peer's role in a ping exchange with a specific remote peer. +type Role string + +const ( + // Client initiates pings and measures RTT. + Client Role = "CLIENT" + // Server responds to pings. + Server Role = "SERVER" +) + +// RoleFor returns the Role of local relative to remote. +// The peer with the lower VPN IP is the client. +func RoleFor(local, remote netip.Addr) Role { + if local.Compare(remote) < 0 { + return Client + } + return Server +} diff --git a/peer/control_conn.go b/peer/control_conn.go new file mode 100644 index 0000000..d7f9746 --- /dev/null +++ b/peer/control_conn.go @@ -0,0 +1,64 @@ +package peer + +import ( + "log" + "net" + "net/netip" + + "vppn/peer/control" +) + +var _ ControlConn = (*udpControlConn)(nil) + +type udpControlConn struct { + conn *net.UDPConn +} + +// newUDPControlConn opens a UDP socket bound to localIP:port. +func newUDPControlConn(localIP netip.Addr, port uint16) (*udpControlConn, error) { + addr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(localIP, port)) + conn, err := net.ListenUDP("udp4", addr) + if err != nil { + return nil, err + } + return &udpControlConn{conn: conn}, nil +} + +func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error { + _, err := c.conn.WriteToUDP(ping.Marshal(buf), net.UDPAddrFromAddrPort(dst)) + return err +} + +// run reads incoming ping packets and forwards them to ch until ctx is done. +// Call this in a goroutine before starting the App event loop. +func (c *udpControlConn) run(ch chan<- PingEvent) { + var buf [control.Size]byte + for { + n, src, err := c.conn.ReadFromUDP(buf[:]) + if err != nil { + log.Printf("control read: %v", err) + continue + } + + if n != control.Size { + continue + } + + ping, err := control.Unmarshal(buf) + if err != nil { + log.Printf("control unmarshal: %v", err) + continue + } + + srcIP, ok := netip.AddrFromSlice(src.IP) + if !ok { + continue + } + + ch <- PingEvent{srcVPNIP: srcIP.Unmap(), ping: ping} + } +} + +func (c *udpControlConn) Close() error { + return c.conn.Close() +} diff --git a/peer/controlmessage.go b/peer/controlmessage.go deleted file mode 100644 index f327291..0000000 --- a/peer/controlmessage.go +++ /dev/null @@ -1,64 +0,0 @@ -package peer - -import ( - "net/netip" - "vppn/m" -) - -// ---------------------------------------------------------------------------- - -type controlMsg[T any] struct { - SrcIP byte - SrcAddr netip.AddrPort - Packet T -} - -func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { - switch buf[0] { - - case packetTypeInit: - packet, err := parsePacketInit(buf) - return controlMsg[packetInit]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - case packetTypeSyn: - packet, err := parsePacketSyn(buf) - return controlMsg[packetSyn]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - case packetTypeAck: - packet, err := parsePacketAck(buf) - return controlMsg[packetAck]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - case packetTypeProbe: - packet, err := parsePacketProbe(buf) - return controlMsg[packetProbe]{ - SrcIP: srcIP, - SrcAddr: srcAddr, - Packet: packet, - }, err - - default: - return nil, errUnknownPacketType - } -} - -// ---------------------------------------------------------------------------- - -type peerUpdateMsg struct { - Peer *m.Peer -} - -// ---------------------------------------------------------------------------- - -type pingTimerMsg struct{} diff --git a/peer/crypto.go b/peer/crypto.go deleted file mode 100644 index a533e6d..0000000 --- a/peer/crypto.go +++ /dev/null @@ -1,30 +0,0 @@ -package peer - -import ( - "crypto/rand" - "log" - - "golang.org/x/crypto/nacl/box" - "golang.org/x/crypto/nacl/sign" -) - -type cryptoKeys struct { - PubKey []byte - PrivKey []byte - PubSignKey []byte - PrivSignKey []byte -} - -func generateKeys() cryptoKeys { - pubKey, privKey, err := box.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate encryption keys: %v", err) - } - - pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) - if err != nil { - log.Fatalf("Failed to generate signing keys: %v", err) - } - - return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} -} diff --git a/peer/device.go b/peer/device.go new file mode 100644 index 0000000..3a3ba1c --- /dev/null +++ b/peer/device.go @@ -0,0 +1,74 @@ +package peer + +import ( + "errors" + "log" + "net/netip" + "syscall" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// devRetry calls fn up to 6 times with exponential backoff, retrying on EBUSY +// (transient netlink contention during WireGuard handshake/rekey). Fatal on +// any other error. +func devRetry(vpnIP netip.Addr, op string, fn func() error) { + const attempts = 6 + timeout := 10 * time.Millisecond + for i := range attempts { + err := fn() + if err == nil { + return + } + if errors.Is(err, syscall.EBUSY) && i < attempts-1 { + time.Sleep(timeout) + timeout *= 2 + continue + } + log.Fatalf("%s %v: %v", op, vpnIP, err) + } +} + +func (a *App) devPeers() []wgtypes.Peer { + peers, err := a.dev.Peers() + if err != nil { + log.Fatalf("Failed to get peers %v: %v", a.vpnIP, err) + } + return peers +} + +func (a *App) devAddPeer(p *Peer) { + log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String()) + devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) }) + p.State = StateRelayed +} + +func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) { + log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String()) + devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) }) + p.State = StateDirect +} + +func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) { + log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String()) + devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) }) + p.State = StateDirect // Dirrect connection. The app marks peer as relay. +} + +func (a *App) devPromote(p *Peer) { + log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String()) + devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) }) + p.State = StateDirect +} + +func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) { + log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String()) + devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) }) + p.State = StateProbing +} + +func (a *App) devRemove(p *Peer) { + log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String()) + devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) }) +} diff --git a/peer/dupcheck.go b/peer/dupcheck.go deleted file mode 100644 index 2394b15..0000000 --- a/peer/dupcheck.go +++ /dev/null @@ -1,76 +0,0 @@ -package peer - -type dupCheck struct { - bitSet - head int - tail int - headCounter uint64 - tailCounter uint64 // Also next expected counter value. -} - -func newDupCheck(headCounter uint64) *dupCheck { - return &dupCheck{ - headCounter: headCounter, - tailCounter: headCounter + 1, - tail: 1, - } -} - -func (dc *dupCheck) IsDup(counter uint64) bool { - - // Before head => it's late, say it's a dup. - if counter < dc.headCounter { - return true - } - - // It's within the counter bounds. - if counter < dc.tailCounter { - index := (int(counter-dc.headCounter) + dc.head) % bitSetSize - if dc.Get(index) { - return true - } - - dc.Set(index) - return false - } - - // It's more than 1 beyond the tail. - delta := counter - dc.tailCounter - - // Full clear. - if delta >= bitSetSize-1 { - dc.ClearAll() - dc.Set(0) - - dc.tail = 1 - dc.head = 2 - dc.tailCounter = counter + 1 - dc.headCounter = dc.tailCounter - bitSetSize + 1 - - return false - } - - // Clear if necessary. - for range delta { - dc.put(false) - } - - dc.put(true) - return false -} - -func (dc *dupCheck) put(set bool) { - if set { - dc.Set(dc.tail) - } else { - dc.Clear(dc.tail) - } - - dc.tail = (dc.tail + 1) % bitSetSize - dc.tailCounter++ - - if dc.head == dc.tail { - dc.head = (dc.head + 1) % bitSetSize - dc.headCounter++ - } -} diff --git a/peer/dupcheck_test.go b/peer/dupcheck_test.go deleted file mode 100644 index 2b50d74..0000000 --- a/peer/dupcheck_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package peer - -import ( - "testing" -) - -func TestDupCheck(t *testing.T) { - dc := newDupCheck(0) - - for i := range bitSetSize { - if dc.IsDup(uint64(i)) { - t.Fatal("!") - } - } - - type TestCase struct { - Counter uint64 - Dup bool - } - - testCases := []TestCase{ - {511, true}, - {0, true}, - {1, true}, - {2, true}, - {3, true}, - {63, true}, - {256, true}, - {510, true}, - {511, true}, - {512, false}, - {0, true}, - {512, true}, - {513, false}, - {517, false}, - {512, true}, - {513, true}, - {514, false}, - {515, false}, - {516, false}, - {517, true}, - {2512, false}, - {2512, true}, - {2001, true}, - {2002, false}, - {2002, true}, - {4000, false}, - {4000 - 511, true}, // Too old. - {4000 - 510, false}, // Just in the window. - } - - for i, tc := range testCases { - if ok := dc.IsDup(tc.Counter); ok != tc.Dup { - t.Fatal(i, ok, tc) - } - } -} diff --git a/peer/errors.go b/peer/errors.go deleted file mode 100644 index 5ab1df8..0000000 --- a/peer/errors.go +++ /dev/null @@ -1,8 +0,0 @@ -package peer - -import "errors" - -var ( - errMalformedPacket = errors.New("malformed packet") - errUnknownPacketType = errors.New("unknown packet type") -) diff --git a/peer/fake_control_conn_test.go b/peer/fake_control_conn_test.go new file mode 100644 index 0000000..640bb3f --- /dev/null +++ b/peer/fake_control_conn_test.go @@ -0,0 +1,43 @@ +package peer + +import ( + "net/netip" + "testing" + + "vppn/peer/control" +) + +type sentPing struct { + Dst netip.AddrPort + Ping control.Ping +} + +type fakeControlConn struct { + Sent []sentPing +} + +func (f *fakeControlConn) SendPing(dst netip.AddrPort, ping control.Ping, _ []byte) error { + f.Sent = append(f.Sent, sentPing{Dst: dst, Ping: ping}) + return nil +} + +func (f *fakeControlConn) AssertNone(t *testing.T) { + t.Helper() + if len(f.Sent) != 0 { + t.Fatalf("expected no pings sent, got %d: %v", len(f.Sent), f.Sent) + } +} + +func (f *fakeControlConn) AssertSent(t *testing.T, i int, dst netip.AddrPort, ping control.Ping) { + t.Helper() + if i >= len(f.Sent) { + t.Fatalf("no ping at index %d (have %d)", i, len(f.Sent)) + } + got := f.Sent[i] + if got.Dst != dst { + t.Errorf("ping[%d].Dst = %v, want %v", i, got.Dst, dst) + } + if got.Ping != ping { + t.Errorf("ping[%d].Ping = %+v, want %+v", i, got.Ping, ping) + } +} diff --git a/peer/fake_wgdevice_test.go b/peer/fake_wgdevice_test.go new file mode 100644 index 0000000..a21cd54 --- /dev/null +++ b/peer/fake_wgdevice_test.go @@ -0,0 +1,123 @@ +package peer + +import ( + "net/netip" + "sync" + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// fakeWGDevice records every call made to it. It is safe to read Calls after +// the event loop has processed the event under test (single-threaded loop +// means no extra synchronisation needed, but the mutex guards concurrent test +// helpers if needed). +type fakeWGDevice struct { + mu sync.Mutex + Calls []fakeCall + peers []wgtypes.Peer +} + +type fakeCall struct { + Method string + PubKey wgtypes.Key + Endpoint netip.AddrPort + VPNiP netip.Addr + Network netip.Prefix +} + +func (f *fakeWGDevice) record(c fakeCall) { + f.mu.Lock() + f.Calls = append(f.Calls, c) + f.mu.Unlock() +} + +func (f *fakeWGDevice) Name() string { return "wg-test" } + +func (f *fakeWGDevice) Peers() ([]wgtypes.Peer, error) { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]wgtypes.Peer, len(f.peers)) + copy(out, f.peers) + return out, nil +} + +func (f *fakeWGDevice) AddPeer(pubKey wgtypes.Key) error { + f.record(fakeCall{Method: "AddPeer", PubKey: pubKey}) + return nil +} + +func (f *fakeWGDevice) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error { + f.record(fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP}) + return nil +} + +func (f *fakeWGDevice) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error { + f.record(fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network}) + return nil +} + +func (f *fakeWGDevice) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error { + f.record(fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint}) + return nil +} + +func (f *fakeWGDevice) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error { + f.record(fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP}) + return nil +} + +func (f *fakeWGDevice) RemovePeer(pubKey wgtypes.Key) error { + f.record(fakeCall{Method: "RemovePeer", PubKey: pubKey}) + return nil +} + +// AssertNoCalls fails the test if any dev calls were recorded. +func (f *fakeWGDevice) AssertNoCalls(t *testing.T) { + t.Helper() + f.mu.Lock() + defer f.mu.Unlock() + if len(f.Calls) != 0 { + t.Fatalf("unexpected dev calls: %v", f.Calls) + } +} + +func (f *fakeWGDevice) AssertAddPeer(t *testing.T, i int, pubKey wgtypes.Key) { + t.Helper() + f.assertCall(t, i, fakeCall{Method: "AddPeer", PubKey: pubKey}) +} + +func (f *fakeWGDevice) AssertAddDirect(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) { + t.Helper() + f.assertCall(t, i, fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP}) +} + +func (f *fakeWGDevice) AssertSetRelay(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) { + t.Helper() + f.assertCall(t, i, fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network}) +} + +func (f *fakeWGDevice) AssertAddProbe(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort) { + t.Helper() + f.assertCall(t, i, fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint}) +} + +func (f *fakeWGDevice) AssertPromote(t *testing.T, i int, pubKey wgtypes.Key, vpnIP netip.Addr) { + t.Helper() + f.assertCall(t, i, fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP}) +} + +func (f *fakeWGDevice) AssertRemovePeer(t *testing.T, i int, pubKey wgtypes.Key) { + t.Helper() + f.assertCall(t, i, fakeCall{Method: "RemovePeer", PubKey: pubKey}) +} + +func (f *fakeWGDevice) assertCall(t *testing.T, i int, c fakeCall) { + t.Helper() + if len(f.Calls) <= i { + t.Fatalf("no call at index %d: %v", i, c) + } + if c != f.Calls[i] { + t.Fatalf("call[%d]: got %v, want %v", i, f.Calls[i], c) + } +} diff --git a/peer/files.go b/peer/files.go deleted file mode 100644 index 6e6afe5..0000000 --- a/peer/files.go +++ /dev/null @@ -1,115 +0,0 @@ -package peer - -import ( - "encoding/json" - "log" - "os" - "path/filepath" - "vppn/m" -) - -type LocalConfig struct { - LocalPeerIP byte - Network []byte - PubKey []byte - PrivKey []byte - PubSignKey []byte - PrivSignKey []byte -} - -type startupCount struct { - Count uint16 -} - -func configDir(netName string) string { - d, err := os.UserHomeDir() - if err != nil { - log.Fatalf("Failed to get user home directory: %v", err) - } - return filepath.Join(d, ".vppn", netName) -} - -func lockFilePath(netName string) string { - return filepath.Join(configDir(netName), "__lock__") -} - -func peerConfigPath(netName string) string { - return filepath.Join(configDir(netName), "config.json") -} - -func peerStatePath(netName string) string { - return filepath.Join(configDir(netName), "state.json") -} - -func startupCountPath(netName string) string { - return filepath.Join(configDir(netName), "startup_count.json") -} - -func statusSocketPath(netName string) string { - return filepath.Join(configDir(netName), "status.sock") -} - -func storeJson(x any, outPath string) error { - outDir := filepath.Dir(outPath) - _ = os.MkdirAll(outDir, 0700) - - tmpPath := outPath + ".tmp" - buf, err := json.Marshal(x) - if err != nil { - return err - } - - f, err := os.Create(tmpPath) - if err != nil { - return err - } - - if _, err := f.Write(buf); err != nil { - f.Close() - return err - } - - if err := f.Sync(); err != nil { - f.Close() - return err - } - - if err := f.Close(); err != nil { - return err - } - - return os.Rename(tmpPath, outPath) -} - -func storePeerConfig(netName string, pc LocalConfig) error { - return storeJson(pc, peerConfigPath(netName)) -} - -func storeNetworkState(netName string, ps m.NetworkState) error { - return storeJson(ps, peerStatePath(netName)) -} - -func loadJson(dataPath string, ptr any) error { - data, err := os.ReadFile(dataPath) - if err != nil { - return err - } - - return json.Unmarshal(data, ptr) -} - -func loadPeerConfig(netName string) (pc LocalConfig, err error) { - return pc, loadJson(peerConfigPath(netName), &pc) -} - -func loadNetworkState(netName string) (ps m.NetworkState, err error) { - return ps, loadJson(peerStatePath(netName), &ps) -} - -func loadStartupCount(netName string) (c startupCount, err error) { - return c, loadJson(startupCountPath(netName), &c) -} - -func storeStartupCount(netName string, c startupCount) error { - return storeJson(c, startupCountPath(netName)) -} diff --git a/peer/files_test.go b/peer/files_test.go deleted file mode 100644 index 5a7f334..0000000 --- a/peer/files_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package peer - -import ( - "path/filepath" - "reflect" - "testing" -) - -func TestFilePaths(t *testing.T) { - confDir := configDir("netName") - if filepath.Base(confDir) != "netName" { - t.Fatal(confDir) - } - if filepath.Base(filepath.Dir(confDir)) != ".vppn" { - t.Fatal(confDir) - } - - path := peerConfigPath("netName") - if path != filepath.Join(confDir, "config.json") { - t.Fatal(path) - } - - path = peerStatePath("netName") - if path != filepath.Join(confDir, "state.json") { - t.Fatal(path) - } -} - -func TestStoreLoadJson(t *testing.T) { - type Object struct { - Name string - Age int - Price float64 - } - - tmpDir := t.TempDir() - outPath := filepath.Join(tmpDir, "object.json") - - obj := Object{ - Name: "Jason", - Age: 22, - Price: 123.534, - } - - if err := storeJson(obj, outPath); err != nil { - t.Fatal(err) - } - - obj2 := Object{} - if err := loadJson(outPath, &obj2); err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(obj, obj2) { - t.Fatal(obj, obj2) - } -} diff --git a/peer/globals.go b/peer/globals.go deleted file mode 100644 index 861a319..0000000 --- a/peer/globals.go +++ /dev/null @@ -1,109 +0,0 @@ -package peer - -import ( - "io" - "net" - "net/netip" - "sync" - "sync/atomic" - "time" -) - -const ( - version = 1 - - bufferSize = 8192 // Enough for data packets and encryption buffers. - - if_mtu = 1200 - if_queue_len = 2048 - - controlCipherOverhead = 16 - dataCipherOverhead = 16 - signingOverhead = 64 - - pingInterval = 8 * time.Second - timeoutInterval = 30 * time.Second - broadcastInterval = 16 * time.Second - broadcastErrorTimeoutInterval = 8 * time.Second -) - -var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( - netip.AddrFrom4([4]byte{224, 0, 0, 157}), - 4560)) - -// ---------------------------------------------------------------------------- - -type Globals struct { - LocalConfig // Embed, immutable. - - // The number of startups - StartupCount uint16 - - // Local public address (if available). Immutable. - LocalAddr netip.AddrPort - - // True if local public address is valid. Immutable. - LocalAddrValid bool - - // All remote peers by VPN IP. - RemotePeers [256]*atomic.Pointer[Remote] - - // Discovered public addresses. - PubAddrs *pubAddrStore - - // Attempts to ensure that we have a relay available. - RelayHandler *relayHandler - - // Send UDP - Global function to write UDP packets. - SendUDP func(b []byte, addr netip.AddrPort) (n int, err error) - - // Global TUN interface. - IFace io.ReadWriteCloser - - // For trace ID. - NewTraceID func() uint64 -} - -func NewGlobals( - localConfig LocalConfig, - startupCount startupCount, - localAddr netip.AddrPort, - conn *net.UDPConn, - iface io.ReadWriteCloser, -) (g Globals) { - g.LocalConfig = localConfig - g.StartupCount = startupCount.Count - - g.LocalAddr = localAddr - g.LocalAddrValid = localAddr.IsValid() - - g.PubAddrs = newPubAddrStore(localAddr) - - g.RelayHandler = newRelayHandler() - - // Use a lock here avoids starvation, at least on my Linux machine. - sendLock := sync.Mutex{} - g.SendUDP = func(b []byte, addr netip.AddrPort) (int, error) { - sendLock.Lock() - n, err := conn.WriteToUDPAddrPort(b, addr) - sendLock.Unlock() - return n, err - } - - g.IFace = iface - - traceID := (uint64(g.StartupCount) << 48) + 1 - g.NewTraceID = func() uint64 { - return atomic.AddUint64(&traceID, 1) - } - - for i := range g.RemotePeers { - g.RemotePeers[i] = &atomic.Pointer[Remote]{} - } - - for i := range g.RemotePeers { - g.RemotePeers[i].Store(newRemote(g, byte(i))) - } - - return g -} diff --git a/peer/header.go b/peer/header.go deleted file mode 100644 index 887c4dd..0000000 --- a/peer/header.go +++ /dev/null @@ -1,47 +0,0 @@ -package peer - -import "unsafe" - -// ---------------------------------------------------------------------------- - -const ( - headerSize = 12 - controlHeaderSize = 24 - dataHeaderSize = 12 - - dataStreamID = 1 - controlStreamID = 2 -) - -type Header struct { - Version byte - StreamID byte - SourceIP byte - DestIP byte - Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic. -} - -func parseHeader(b []byte) (h Header) { - h.Version = b[0] - h.StreamID = b[1] - h.SourceIP = b[2] - h.DestIP = b[3] - h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) - return h -} - -func (h *Header) Parse(b []byte) { - h.Version = b[0] - h.StreamID = b[1] - h.SourceIP = b[2] - h.DestIP = b[3] - h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) -} - -func (h *Header) Marshal(buf []byte) { - buf[0] = h.Version - buf[1] = h.StreamID - buf[2] = h.SourceIP - buf[3] = h.DestIP - *(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter -} diff --git a/peer/header_test.go b/peer/header_test.go deleted file mode 100644 index f644a36..0000000 --- a/peer/header_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package peer - -import "testing" - -func TestHeaderMarshalParse(t *testing.T) { - nIn := Header{ - StreamID: 23, - Counter: 3212, - SourceIP: 34, - DestIP: 200, - } - - buf := make([]byte, headerSize) - nIn.Marshal(buf) - - nOut := Header{} - nOut.Parse(buf) - if nIn != nOut { - t.Fatal(nIn, nOut) - } -} diff --git a/peer/hosts.go b/peer/hosts.go new file mode 100644 index 0000000..fc47d20 --- /dev/null +++ b/peer/hosts.go @@ -0,0 +1,128 @@ +package peer + +import ( + "fmt" + "log" + "net/netip" + "os" + "sort" + "strings" + "syscall" + + "git.crumpington.com/lib/go/flock" +) + +const ( + hostsFile = "/etc/hosts" + hostsBegin = "# BEGIN vppn" + hostsEnd = "# END vppn" +) + +// hostMarkers returns the begin/end marker lines that delimit the managed +// section for localDomain. The domain is wrapped in parentheses so one domain's +// marker can never be a prefix of another's (e.g. "net" vs "net2") when +// multiple vppn instances share /etc/hosts. +func hostMarkers(localDomain string) (begin, end string) { + return hostsBegin + "(" + localDomain + ")", hostsEnd + "(" + localDomain + ")" +} + +// updateHosts rewrites the managed vppn section in /etc/hosts using the +// current peersByIP map. Peers without a Name are skipped. +func (a *App) updateHosts() { + if a.localDomain == "" { + return + } + if err := updateHosts(hostsFile, a.localDomain, a.peersByIP); err != nil { + log.Printf("Failed to update hosts file: %v", err) + } +} + +func updateHosts(hostsPath, localDomain string, peers map[netip.Addr]*Peer) error { + lockFile, err := flock.Lock(hostsPath + ".vppn.lock") + if err != nil { + return err + } + defer lockFile.Close() + + begin, end := hostMarkers(localDomain) + + info, err := os.Stat(hostsPath) + if err != nil { + return err + } + + raw, err := os.ReadFile(hostsPath) + if err != nil { + return err + } + + data := string(raw) + + before := strings.TrimSpace(data) + after := "" + + if idxBegin := strings.Index(data, begin); idxBegin != -1 { + idxEnd := strings.Index(data[idxBegin:], end) + if idxEnd != -1 { + after = strings.TrimSpace(data[idxBegin+idxEnd+len(end):]) + } + before = strings.TrimSpace(data[:idxBegin]) + } + + b := strings.Builder{} + b.WriteString(before) + b.WriteRune('\n') + b.WriteString(after) + b.WriteRune('\n') + b.WriteRune('\n') + + b.WriteString(begin) + b.WriteRune('\n') + + // Collect entries so we can sort by IP for stable output. Pad the IP + // column to the width of the widest possible address ("255.255.255.255") + // for readability. + type entry struct { + ip netip.Addr + host string + } + var entries []entry + for ip, p := range peers { + if p.Name == "" { + continue + } + entries = append(entries, entry{ip: ip, host: p.Name + "." + localDomain}) + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].ip.Less(entries[j].ip) + }) + + for _, e := range entries { + b.WriteString(fmt.Sprintf("%-15s %s\n", e.ip.String(), e.host)) + } + + b.WriteString(end) + b.WriteRune('\n') + + // Write to a temp file in the same directory, then rename over the + // original so readers never observe a partial file. Preserve the + // original's mode and ownership, since rename replaces the inode. + tmpPath := hostsPath + ".vppn.tmp" + if err := os.WriteFile(tmpPath, []byte(b.String()), info.Mode().Perm()); err != nil { + return err + } + + if st, ok := info.Sys().(*syscall.Stat_t); ok { + if err := os.Chown(tmpPath, int(st.Uid), int(st.Gid)); err != nil { + os.Remove(tmpPath) + return err + } + } + + if err := os.Rename(tmpPath, hostsPath); err != nil { + os.Remove(tmpPath) + return err + } + + return nil +} diff --git a/peer/hosts_test.go b/peer/hosts_test.go new file mode 100644 index 0000000..6b957c3 --- /dev/null +++ b/peer/hosts_test.go @@ -0,0 +1,205 @@ +package peer + +import ( + "net/netip" + "os" + "path/filepath" + "sort" + "strings" + "testing" +) + +// writeTempHosts creates a temp hosts file with the given content and returns +// its path. +func writeTempHosts(t *testing.T, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "hosts") + if err := os.WriteFile(path, []byte(content), 0600); err != nil { + t.Fatal(err) + } + return path +} + +// readManagedSection returns the lines between the begin/end markers for the +// given localDomain, plus everything outside the section ("outside"). +func readManagedSection(t *testing.T, path, localDomain string) (inside, outside []string) { + t.Helper() + raw, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + begin, end := hostMarkers(localDomain) + + inSection := false + for _, line := range strings.Split(string(raw), "\n") { + switch { + case strings.HasPrefix(line, begin): + inSection = true + case strings.HasPrefix(line, end): + inSection = false + case inSection: + if f := strings.Join(strings.Fields(line), " "); f != "" { + inside = append(inside, f) + } + default: + if f := strings.Join(strings.Fields(line), " "); f != "" { + outside = append(outside, f) + } + } + } + return inside, outside +} + +func peer(name string) *Peer { + return &Peer{Name: name} +} + +func TestUpdateHosts_AddsSection(t *testing.T) { + path := writeTempHosts(t, "127.0.0.1 localhost\n") + + peers := map[netip.Addr]*Peer{ + netip.MustParseAddr("10.11.12.1"): peer("hub"), + netip.MustParseAddr("10.11.12.10"): peer("laptop"), + } + + if err := updateHosts(path, "mynet.local", peers); err != nil { + t.Fatal(err) + } + + inside, outside := readManagedSection(t, path, "mynet.local") + + sort.Strings(inside) + want := []string{ + "10.11.12.1 hub.mynet.local", + "10.11.12.10 laptop.mynet.local", + } + if strings.Join(inside, "\n") != strings.Join(want, "\n") { + t.Errorf("managed section = %v, want %v", inside, want) + } + + if !contains(outside, "127.0.0.1 localhost") { + t.Errorf("original content lost; outside = %v", outside) + } +} + +func TestUpdateHosts_ReplacesExistingSection(t *testing.T) { + path := writeTempHosts(t, "127.0.0.1 localhost\n") + + // First write. + first := map[netip.Addr]*Peer{ + netip.MustParseAddr("10.11.12.1"): peer("hub"), + } + if err := updateHosts(path, "mynet.local", first); err != nil { + t.Fatal(err) + } + + // Second write with a different set of peers. + second := map[netip.Addr]*Peer{ + netip.MustParseAddr("10.11.12.20"): peer("phone"), + } + if err := updateHosts(path, "mynet.local", second); err != nil { + t.Fatal(err) + } + + inside, outside := readManagedSection(t, path, "mynet.local") + + if len(inside) != 1 || inside[0] != "10.11.12.20 phone.mynet.local" { + t.Errorf("section not replaced; inside = %v", inside) + } + if contains(inside, "10.11.12.1 hub.mynet.local") { + t.Errorf("stale entry remained; inside = %v", inside) + } + if !contains(outside, "127.0.0.1 localhost") { + t.Errorf("original content lost; outside = %v", outside) + } +} + +func TestUpdateHosts_SkipsEmptyNames(t *testing.T) { + path := writeTempHosts(t, "127.0.0.1 localhost\n") + + peers := map[netip.Addr]*Peer{ + netip.MustParseAddr("10.11.12.1"): peer("hub"), + netip.MustParseAddr("10.11.12.99"): peer(""), // no name + } + if err := updateHosts(path, "mynet.local", peers); err != nil { + t.Fatal(err) + } + + inside, _ := readManagedSection(t, path, "mynet.local") + if len(inside) != 1 || inside[0] != "10.11.12.1 hub.mynet.local" { + t.Errorf("expected only named peer; inside = %v", inside) + } +} + +func TestUpdateHosts_Idempotent(t *testing.T) { + path := writeTempHosts(t, "127.0.0.1 localhost\n") + + peers := map[netip.Addr]*Peer{ + netip.MustParseAddr("10.11.12.1"): peer("hub"), + } + if err := updateHosts(path, "mynet.local", peers); err != nil { + t.Fatal(err) + } + first, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if err := updateHosts(path, "mynet.local", peers); err != nil { + t.Fatal(err) + } + second, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(first) != string(second) { + t.Errorf("repeated update changed file:\nfirst:\n%s\nsecond:\n%s", first, second) + } +} + +// TestUpdateHosts_PrefixDomainsCoexist guards finding 4.4: two domains where +// one label is a prefix of the other ("net" vs "net2") must each manage their +// own section without clobbering the other's, even sharing one hosts file. +func TestUpdateHosts_PrefixDomainsCoexist(t *testing.T) { + path := writeTempHosts(t, "127.0.0.1 localhost\n") + + if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{ + netip.MustParseAddr("10.0.2.1"): peer("a"), + }); err != nil { + t.Fatal(err) + } + if err := updateHosts(path, "net.local", map[netip.Addr]*Peer{ + netip.MustParseAddr("10.0.1.1"): peer("b"), + }); err != nil { + t.Fatal(err) + } + + // Both sections coexist after writing the prefix domain. + if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.1 a.net2.local" { + t.Errorf("net2 section clobbered: %v", in) + } + if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" { + t.Errorf("net section wrong: %v", in) + } + + // Re-updating net2 must not disturb the net section. + if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{ + netip.MustParseAddr("10.0.2.2"): peer("c"), + }); err != nil { + t.Fatal(err) + } + if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" { + t.Errorf("net section disturbed by net2 update: %v", in) + } + if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.2 c.net2.local" { + t.Errorf("net2 section not updated: %v", in) + } +} + +func contains(ss []string, s string) bool { + for _, x := range ss { + if x == s { + return true + } + } + return false +} diff --git a/peer/hub_poller.go b/peer/hub_poller.go new file mode 100644 index 0000000..2138e86 --- /dev/null +++ b/peer/hub_poller.go @@ -0,0 +1,153 @@ +package peer + +import ( + "encoding/json" + "io" + "log" + "net/http" + "net/netip" + "net/url" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" +) + +const hubPollInterval = 64 * time.Second + +type HubPoller struct { + selfVPNIP netip.Addr + vpnNet netip.Prefix + hubURL string + apiKey string + statePath string // where the network state cache is persisted + addCh chan<- m.Peer + removeCh chan<- wgtypes.Key + known map[wgtypes.Key]struct{} // pubKeys currently configured +} + +func NewHubPoller( + selfVPNIP netip.Addr, + vpnNet netip.Prefix, + hubURL, apiKey string, + statePath string, + addCh chan<- m.Peer, + removeCh chan<- wgtypes.Key, +) (*HubPoller, error) { + u, err := url.Parse(hubURL) + if err != nil { + return nil, err + } + u.Path = "/peer/fetch-state/" + + return &HubPoller{ + selfVPNIP: selfVPNIP, + vpnNet: vpnNet, + hubURL: u.String(), + apiKey: apiKey, + statePath: statePath, + addCh: addCh, + removeCh: removeCh, + known: make(map[wgtypes.Key]struct{}), + }, nil +} + +func (hp *HubPoller) Run() { + // Prime from the on-disk cache before reaching the hub, so the peer + // configures WireGuard from its last known state even if the hub is down. + // known starts empty, so this emits every cached peer as an add; the first + // real poll then emits only deltas (adds for new peers, removes for gone). + if state, err := loadNetworkState(hp.statePath); err == nil { + hp.apply(state) + } + + hp.poll() + for range time.Tick(hubPollInterval) { + hp.poll() + } +} + +func (hp *HubPoller) poll() { + req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil) + if err != nil { + log.Printf("[HubPoller] build request: %v", err) + return + } + req.SetBasicAuth("", hp.apiKey) + + client := &http.Client{Timeout: 32 * time.Second} + resp, err := client.Do(req) + if err != nil { + log.Printf("[HubPoller] fetch: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("[HubPoller] unexpected status %d", resp.StatusCode) + return + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("[HubPoller] read body: %v", err) + return + } + + var state m.NetworkState + if err := json.Unmarshal(body, &state); err != nil { + log.Printf("[HubPoller] unmarshal: %v", err) + return + } + + // Persist only when the state actually changed, to avoid needless writes + // on every poll. + if hp.apply(state) { + if err := saveNetworkState(hp.statePath, state); err != nil { + log.Printf("[HubPoller] save state: %v", err) + } + } +} + +// apply diffs state against the set of known peers, emitting an add for each +// newly-seen peer and a remove for each that disappeared. It returns true if +// anything changed. A peer's config is immutable under a stable WG key (the hub +// has no peer-edit path), so a key already in known needs no re-emit. +func (hp *HubPoller) apply(state m.NetworkState) (changed bool) { + seen := make(map[wgtypes.Key]struct{}, len(hp.known)) + + netAddr := hp.vpnNet.Addr().As4() + + for _, p := range state.Peers { + if p.WGPubKey == (wgtypes.Key{}) { + continue + } + + octets := netAddr + octets[3] = p.PeerIP + vpnIP := netip.AddrFrom4(octets) + if vpnIP == hp.selfVPNIP { + continue + } + + seen[p.WGPubKey] = struct{}{} + + if _, ok := hp.known[p.WGPubKey]; ok { + continue + } + hp.known[p.WGPubKey] = struct{}{} + hp.addCh <- p + changed = true + } + + for key := range hp.known { + if _, ok := seen[key]; !ok { + delete(hp.known, key) + hp.removeCh <- key + changed = true + } + } + + return changed +} diff --git a/peer/hub_poller_test.go b/peer/hub_poller_test.go new file mode 100644 index 0000000..87efba3 --- /dev/null +++ b/peer/hub_poller_test.go @@ -0,0 +1,80 @@ +package peer + +import ( + "net/netip" + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" +) + +func testPoller(t *testing.T) (*HubPoller, chan m.Peer, chan wgtypes.Key) { + t.Helper() + addCh := make(chan m.Peer, 8) + removeCh := make(chan wgtypes.Key, 8) + hp := &HubPoller{ + selfVPNIP: netip.MustParseAddr("10.0.0.1"), + vpnNet: netip.MustParsePrefix("10.0.0.0/24"), + addCh: addCh, + removeCh: removeCh, + known: make(map[wgtypes.Key]struct{}), + } + return hp, addCh, removeCh +} + +func stateWith(key wgtypes.Key, peerIP byte) m.NetworkState { + return m.NetworkState{Peers: []m.Peer{{ + PeerIP: peerIP, + WGPubKey: key, + }}} +} + +func TestApply_EmitsAddsAndReportsChange(t *testing.T) { + hp, addCh, _ := testPoller(t) + key := mustKey(t) + + if changed := hp.apply(stateWith(key, 2)); !changed { + t.Fatal("expected changed=true on first apply") + } + if len(addCh) != 1 { + t.Fatalf("expected 1 add, got %d", len(addCh)) + } + if got := <-addCh; got.WGPubKey != key { + t.Errorf("add pubkey mismatch") + } +} + +func TestApply_NoChangeWhenKnown(t *testing.T) { + hp, addCh, _ := testPoller(t) + key := mustKey(t) + + hp.apply(stateWith(key, 2)) + <-addCh // drain initial add + + if changed := hp.apply(stateWith(key, 2)); changed { + t.Fatal("expected changed=false when peer already known") + } + if len(addCh) != 0 { + t.Fatalf("expected no re-emit, got %d adds", len(addCh)) + } +} + +func TestApply_RemovesVanishedPeer(t *testing.T) { + hp, addCh, removeCh := testPoller(t) + key := mustKey(t) + + hp.apply(stateWith(key, 2)) + <-addCh + + // Empty state: the peer is gone. + if changed := hp.apply(m.NetworkState{}); !changed { + t.Fatal("expected changed=true when peer vanishes") + } + if len(removeCh) != 1 { + t.Fatalf("expected 1 remove, got %d", len(removeCh)) + } + if got := <-removeCh; got != key { + t.Errorf("remove key mismatch") + } +} diff --git a/peer/hubpoller.go b/peer/hubpoller.go deleted file mode 100644 index a0d79d0..0000000 --- a/peer/hubpoller.go +++ /dev/null @@ -1,111 +0,0 @@ -package peer - -import ( - "encoding/json" - "io" - "log" - "net/http" - "net/url" - "time" - "vppn/m" -) - -type HubPoller struct { - Globals - client *http.Client - req *http.Request - versions [256]int64 - netName string -} - -func NewHubPoller( - g Globals, - netName, - hubURL, - apiKey string, -) (*HubPoller, error) { - u, err := url.Parse(hubURL) - if err != nil { - return nil, err - } - u.Path = "/peer/fetch-state/" - - client := &http.Client{Timeout: 8 * time.Second} - - req := &http.Request{ - Method: http.MethodGet, - URL: u, - Header: http.Header{}, - } - req.SetBasicAuth("", apiKey) - - return &HubPoller{ - Globals: g, - client: client, - req: req, - netName: netName, - }, nil -} - -func (hp *HubPoller) logf(s string, args ...any) { - log.Printf("[HubPoller] "+s, args...) -} - -func (hp *HubPoller) Run() { - state, err := loadNetworkState(hp.netName) - if err != nil { - hp.logf("Failed to load network state: %v", err) - hp.logf("Polling hub...") - hp.pollHub() - } else { - hp.applyNetworkState(state) - } - - for range time.Tick(64 * time.Second) { - hp.pollHub() - } -} - -func (hp *HubPoller) pollHub() { - var state m.NetworkState - - resp, err := hp.client.Do(hp.req) - if err != nil { - hp.logf("Failed to fetch peer state: %v", err) - return - } - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - hp.logf("Failed to read body from hub: %v", err) - return - } - - if err := json.Unmarshal(body, &state); err != nil { - hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body) - return - } - - if err := storeNetworkState(hp.netName, state); err != nil { - hp.logf("Failed to store network state: %v", err) - } - - hp.applyNetworkState(state) -} - -func (hp *HubPoller) applyNetworkState(state m.NetworkState) { - for i, peer := range state.Peers { - if i == int(hp.LocalPeerIP) { - continue - } - - if peer != nil && peer.Version == hp.versions[i] { - continue - } - - hp.RemotePeers[i].Load().HandlePeerUpdate(peerUpdateMsg{Peer: state.Peers[i]}) - if peer != nil { - hp.versions[i] = peer.Version - } - } -} diff --git a/peer/ifreader.go b/peer/ifreader.go deleted file mode 100644 index ebade54..0000000 --- a/peer/ifreader.go +++ /dev/null @@ -1,73 +0,0 @@ -package peer - -import ( - "log" -) - -type IFReader struct { - Globals -} - -func NewIFReader(g Globals) *IFReader { - return &IFReader{Globals: g} -} - -func (r *IFReader) Run() { - packet := make([]byte, bufferSize) - for { - r.handleNextPacket(packet) - } -} - -func (r *IFReader) handleNextPacket(packet []byte) { - packet = r.readNextPacket(packet) - remoteIP, ok := r.parsePacket(packet) - if !ok { - return - } - r.RemotePeers[remoteIP].Load().SendDataTo(packet) -} - -func (r *IFReader) readNextPacket(buf []byte) []byte { - n, err := r.IFace.Read(buf[:cap(buf)]) - if err != nil { - log.Fatalf("Failed to read from interface: %v", err) - } - - return buf[:n] -} - -// parsePacket returns the VPN ip for the packet, and a boolean indicating -// success. -func (r *IFReader) parsePacket(buf []byte) (byte, bool) { - n := len(buf) - if n == 0 { - return 0, false - } - - version := buf[0] >> 4 - - switch version { - case 4: - if n < 20 { - r.logf("Short IPv4 packet: %d", len(buf)) - return 0, false - } - return buf[19], true - - case 6: - if len(buf) < 40 { - r.logf("Short IPv6 packet: %d", len(buf)) - return 0, false - } - return buf[39], true - - default: - r.logf("Invalid IP packet version: %v", version) - return 0, false - } -} - -func (*IFReader) logf(s string, args ...any) { - log.Printf("[IFReader] "+s, args...) -} diff --git a/peer/ifreader_test.go b/peer/ifreader_test.go deleted file mode 100644 index 92ec5ac..0000000 --- a/peer/ifreader_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package peer - -/* -func TestIFReader_IPv4(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - pkt := make([]byte, 1234) - pkt[0] = 4 << 4 - pkt[19] = 2 // IP. - - p1.IFace.UserWrite(pkt) - p1.IFReader.handleNextPacket(newBuf()) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestIFReader_IPv6(t *testing.T) { - p1, p2, _ := NewPeersForTesting() - - pkt := make([]byte, 1234) - pkt[0] = 6 << 4 - pkt[39] = 2 // IP. - - p1.IFace.UserWrite(pkt) - p1.IFReader.handleNextPacket(newBuf()) - - packets := p2.Conn.Packets() - if len(packets) != 1 { - t.Fatal(packets) - } -} - -func TestIFReader_parsePacket_emptyPacket(t *testing.T) { - r := NewIFReader(nil, nil) - pkt := make([]byte, 0) - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { - r := NewIFReader(nil, nil) - - for i := byte(1); i < 16; i++ { - if i == 4 || i == 6 { - continue - } - pkt := make([]byte, 1234) - pkt[0] = i << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(i, ip, ok) - } - } -} - -func TestIFReader_parsePacket_shortIPv4(t *testing.T) { - r := NewIFReader(nil, nil) - - pkt := make([]byte, 19) - pkt[0] = 4 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} - -func TestIFReader_parsePacket_shortIPv6(t *testing.T) { - r := NewIFReader(nil, nil) - - pkt := make([]byte, 39) - pkt[0] = 6 << 4 - - if ip, ok := r.parsePacket(pkt); ok { - t.Fatal(ip, ok) - } -} -*/ diff --git a/peer/init.go b/peer/init.go new file mode 100644 index 0000000..e3a1699 --- /dev/null +++ b/peer/init.go @@ -0,0 +1,190 @@ +package peer + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/netip" + "os" + + "golang.org/x/crypto/nacl/sign" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" +) + +// LocalState is the persisted identity for this peer, written on first run and +// loaded on every subsequent run. +type LocalState struct { + PrivKey wgtypes.Key + SignKey [64]byte // nacl/sign Ed25519 private key + VPNIP netip.Addr + VPNNet netip.Prefix + WGPort uint16 + IsRelay bool + IsPublic bool + LocalDomain string +} + +// localStateJSON is the on-disk representation. +type localStateJSON struct { + PrivKey string + SignKey string + VPNIP netip.Addr + VPNNet netip.Prefix + WGPort uint16 + IsRelay bool + IsPublic bool + LocalDomain string +} + +// LoadOrInit loads LocalState from path, or registers with the hub and creates +// the file if it doesn't exist. +func LoadOrInit(statePath, hubURL, apiKey string) (LocalState, error) { + var state LocalState + switch err := loadJSON(statePath, &state); { + case err == nil: + return state, nil + case !os.IsNotExist(err): + // File exists but is unreadable/corrupt: surface it rather than + // silently regenerating a new identity and re-registering. + return LocalState{}, fmt.Errorf("load state: %w", err) + } + + privKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return LocalState{}, fmt.Errorf("generate key: %w", err) + } + + state, err = initFromHub(hubURL, apiKey, privKey) + if err != nil { + return LocalState{}, err + } + + if err := storeJSON(statePath, state); err != nil { + return LocalState{}, fmt.Errorf("save state: %w", err) + } + return state, nil +} + +func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) { + wgPubKey := privKey.PublicKey() + + signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) + if err != nil { + return LocalState{}, fmt.Errorf("generate sign key: %w", err) + } + + body, err := json.Marshal(m.PeerInitArgs{ + WGPubKey: wgPubKey[:], + SignPubKey: signPubKey[:], + }) + if err != nil { + return LocalState{}, fmt.Errorf("json error: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, hubURL+"/peer/init/", bytes.NewReader(body)) + if err != nil { + return LocalState{}, err + } + req.SetBasicAuth("", apiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return LocalState{}, fmt.Errorf("hub init: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return LocalState{}, fmt.Errorf("hub init: HTTP %d", resp.StatusCode) + } + + var r m.PeerInitResp + if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { + return LocalState{}, fmt.Errorf("hub init decode: %w", err) + } + + if len(r.Network) != 4 { + return LocalState{}, fmt.Errorf("hub init: invalid network %v", r.Network) + } + + netAddr := netip.AddrFrom4([4]byte(r.Network)) + octets := netAddr.As4() + octets[3] = r.PeerIP + vpnIP := netip.AddrFrom4(octets) + vpnNet := netip.PrefixFrom(netAddr, 24) + + var self *m.Peer + for i := range r.NetworkState.Peers { + if r.NetworkState.Peers[i].PeerIP == r.PeerIP { + self = &r.NetworkState.Peers[i] + break + } + } + if self == nil { + return LocalState{}, fmt.Errorf("hub init: no peer for own IP: %d", r.PeerIP) + } + + public := self.IsPublic() + + return LocalState{ + PrivKey: privKey, + SignKey: *signPrivKey, + VPNIP: vpnIP, + VPNNet: vpnNet, + WGPort: self.Port, + IsRelay: self.Relay && public, + IsPublic: public, + LocalDomain: r.LocalDomain, + }, nil +} + +func (s LocalState) MarshalJSON() ([]byte, error) { + return json.Marshal(localStateJSON{ + PrivKey: base64.StdEncoding.EncodeToString(s.PrivKey[:]), + SignKey: base64.StdEncoding.EncodeToString(s.SignKey[:]), + VPNIP: s.VPNIP, + VPNNet: s.VPNNet, + WGPort: s.WGPort, + IsRelay: s.IsRelay, + IsPublic: s.IsPublic, + LocalDomain: s.LocalDomain, + }) +} + +func (s *LocalState) UnmarshalJSON(data []byte) error { + var j localStateJSON + if err := json.Unmarshal(data, &j); err != nil { + return err + } + keyBytes, err := base64.StdEncoding.DecodeString(j.PrivKey) + if err != nil { + return fmt.Errorf("decode key: %w", err) + } + key, err := wgtypes.NewKey(keyBytes) + if err != nil { + return fmt.Errorf("invalid key: %w", err) + } + signKeyBytes, err := base64.StdEncoding.DecodeString(j.SignKey) + if err != nil { + return fmt.Errorf("decode sign key: %w", err) + } + if len(signKeyBytes) != 64 { + return fmt.Errorf("invalid sign key length: %d", len(signKeyBytes)) + } + *s = LocalState{ + PrivKey: key, + SignKey: [64]byte(signKeyBytes), + VPNIP: j.VPNIP, + VPNNet: j.VPNNet, + WGPort: j.WGPort, + IsRelay: j.IsRelay, + IsPublic: j.IsPublic, + LocalDomain: j.LocalDomain, + } + return nil +} diff --git a/peer/interface.go b/peer/interface.go deleted file mode 100644 index 0022392..0000000 --- a/peer/interface.go +++ /dev/null @@ -1,137 +0,0 @@ -package peer - -import ( - "fmt" - "io" - "net" - "os" - "syscall" - - "golang.org/x/sys/unix" -) - -func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { - if len(network) != 4 { - return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network)) - } - ip := net.IPv4(network[0], network[1], network[2], localIP) - - ////////////////////////// - // Create TUN Interface // - ////////////////////////// - - tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600) - if err != nil { - return nil, fmt.Errorf("failed to open TUN device: %w", err) - } - - // New interface request. - req, err := unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create new TUN interface request: %w", err) - } - - // Flags: - // - // IFF_NO_PI => don't add packet info data to packets sent to the interface. - // IFF_TUN => create a TUN device handling IP packets. - req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN) - - err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req) - if err != nil { - return nil, fmt.Errorf("failed to set TUN device settings: %w", err) - } - - // Name may not be exactly the same? - name = req.Name() - - ///////////// - // Set MTU // - ///////////// - - // We need a socket file descriptor to set other options for some reason. - sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) - if err != nil { - return nil, fmt.Errorf("failed to open socket: %w", err) - } - defer unix.Close(sockFD) - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create MTU interface request: %w", err) - } - - req.SetUint32(if_mtu) - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil { - return nil, fmt.Errorf("failed to set interface MTU: %w", err) - } - - ////////////////////// - // Set Queue Length // - ////////////////////// - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create IP interface request: %w", err) - } - - req.SetUint16(if_queue_len) - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil { - return nil, fmt.Errorf("failed to set interface queue length: %w", err) - } - - ///////////////////// - // Set IP and Mask // - ///////////////////// - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create IP interface request: %w", err) - } - - if err := req.SetInet4Addr(ip.To4()); err != nil { - return nil, fmt.Errorf("failed to set interface request IP: %w", err) - } - - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil { - return nil, fmt.Errorf("failed to set interface IP: %w", err) - } - - // SET MASK - must happen after setting address. - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create mask interface request: %w", err) - } - - if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil { - return nil, fmt.Errorf("failed to set interface request mask: %w", err) - } - - if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil { - return nil, fmt.Errorf("failed to set interface mask: %w", err) - } - - //////////////////////// - // Bring Interface Up // - //////////////////////// - - req, err = unix.NewIfreq(name) - if err != nil { - return nil, fmt.Errorf("failed to create up interface request: %w", err) - } - - // Get current flags. - if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil { - return nil, fmt.Errorf("failed to get interface flags: %w", err) - } - - flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING - - // Set UP flag / broadcast flags. - req.SetUint16(flags) - if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil { - return nil, fmt.Errorf("failed to set interface up: %w", err) - } - - return os.NewFile(uintptr(tunFD), "tun"), nil -} diff --git a/peer/interfaces.go b/peer/interfaces.go new file mode 100644 index 0000000..e231186 --- /dev/null +++ b/peer/interfaces.go @@ -0,0 +1,28 @@ +package peer + +import ( + "net/netip" + "vppn/peer/control" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// WGDevice is the subset of wginterface.Device used by App. +type WGDevice interface { + Name() string + Peers() ([]wgtypes.Peer, error) + AddPeer(pubKey wgtypes.Key) error + AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error + SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error + AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error + Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error + RemovePeer(pubKey wgtypes.Key) error +} + +// ControlConn sends pings to peers over the VPN control port. +// Reading is handled separately via run, which feeds the App's pingCh. +// buf is a caller-provided scratch buffer (at least control.Size bytes) used to +// marshal the ping; the caller reuses one across sends. +type ControlConn interface { + SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error +} diff --git a/peer/json.go b/peer/json.go new file mode 100644 index 0000000..9829820 --- /dev/null +++ b/peer/json.go @@ -0,0 +1,36 @@ +package peer + +import ( + "encoding/json" + "os" + "path/filepath" +) + +func loadJSON(path string, target any) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return json.Unmarshal(data, target) +} + +func storeJSON(path string, obj any) error { + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return err + } + data, err := json.MarshalIndent(obj, "", " ") + if err != nil { + return err + } + + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0600); err != nil { + return err + } + + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return err + } + return nil +} diff --git a/peer/main.go b/peer/main.go deleted file mode 100644 index 53c1bf8..0000000 --- a/peer/main.go +++ /dev/null @@ -1,209 +0,0 @@ -package peer - -import ( - "encoding/json" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "time" -) - -// Usage: -// -// vppn netName run -// vppn netName status -func Main2() { - printUsage := func() { - fmt.Fprintf(os.Stderr, `%s COMMAND [ARGUMENTS...] - -Available commands: - run - status - hosts -`, os.Args[0]) - os.Exit(1) - } - - if len(os.Args) < 2 { - printUsage() - } - - command := os.Args[1] - - switch command { - case "run": - main_run() - case "status": - main_status() - case "hosts": - main_hosts() - default: - printUsage() - } -} - -// ---------------------------------------------------------------------------- - -type mainArgs struct { - NetName string - HubAddress string - APIKey string -} - -func main_run() { - printUsage := func() { - fmt.Fprintf(os.Stderr, `Usage: %s run NETWORK_NAME HUB_ADDRESS API_KEY - - NETWORK_NAME - Unique name of the network interface created. The network name - shouldn't change between invocations of the application. - - HUB_ADDRESS - The address of the hub server. This should also contain the scheme, for - example https://hub.domain.com/. - - API_KEY - The API key assigned to this peer by the hub. - -`, os.Args[0]) - os.Exit(1) - } - - if len(os.Args) != 5 { - printUsage() - } - - args := mainArgs{ - NetName: os.Args[2], - HubAddress: os.Args[3], - APIKey: os.Args[4], - } - - newPeerMain(args).Run() -} - -// ---------------------------------------------------------------------------- - -func main_status() { - printUsage := func() { - fmt.Fprintf(os.Stderr, `Usage: %s status NETWORK_NAME - - NETWORK_NAME - Unique name of the network interface created. - -`, os.Args[0]) - os.Exit(1) - } - - if len(os.Args) != 3 { - printUsage() - } - - netName := os.Args[2] - report := fetchStatusReport(netName) - - fmt.Printf("\n%s Status\n\n", netName) - - if len(report.Network) != 4 { - fmt.Println("ERROR: Network isn't 4 bytes.") - fmt.Printf("Network: %v\n\n", report.Network) - } else { - nw := report.Network - fmt.Printf("%-8s %d.%d.%d.%d\n", "IP", nw[0], nw[1], nw[2], report.LocalPeerIP) - fmt.Printf("%-8s %d.%d.%d.%d/24\n", "Network", nw[0], nw[1], nw[2], nw[3]) - } - - if report.RelayPeerIP != 0 { - fmt.Printf("%-8s %d\n\n", "Relay", report.RelayPeerIP) - } else { - fmt.Printf("%-8s -\n\n", "Relay") - } - - for _, status := range report.Remotes { - fmt.Printf("%3d %s\n", status.PeerIP, status.Name) - fmt.Printf(" %-11s %v\n", "Up", status.Up) - - pubIP, ok := netip.AddrFromSlice(status.PublicIP) - if ok { - fmt.Printf(" %-11s %v\n", "Public IP", pubIP) - } else { - fmt.Printf(" %-11s\n", "Public IP") - } - fmt.Printf(" %-11s %d\n", "Port", status.Port) - fmt.Printf(" %-11s %v\n", "Relay", status.Relay) - fmt.Printf(" %-11s %v\n", "Server", status.Server) - fmt.Printf(" %-11s %v\n", "Direct", status.Direct) - if status.DirectAddr.IsValid() { - fmt.Printf(" %-11s %v\n", "Address", status.DirectAddr) - } - fmt.Println("") - } -} - -// ---------------------------------------------------------------------------- - -func main_hosts() { - printUsage := func() { - fmt.Fprintf(os.Stderr, `Usage: %s hosts NETWORK_NAME - - NETWORK_NAME - Unique name of the network interface created. - -`, os.Args[0]) - os.Exit(1) - } - - if len(os.Args) != 3 { - printUsage() - } - - netName := os.Args[2] - state, err := loadNetworkState(netName) - if err != nil { - log.Fatalf("Failed to load network state: %v", err) - } - - config, err := loadPeerConfig(netName) - if err != nil { - log.Fatalf("Failed to load config: %v", err) - } - - nw := config.Network - for _, peer := range state.Peers { - if peer == nil { - continue - } - fmt.Printf("%d.%d.%d.%d %s\n", - nw[0], nw[1], nw[2], peer.PeerIP, peer.Name) - } - fmt.Println("") -} - -// ---------------------------------------------------------------------------- - -func fetchStatusReport(netName string) StatusReport { - client := http.Client{ - Transport: &http.Transport{ - Dial: func(_, _ string) (net.Conn, error) { - return net.Dial("unix", statusSocketPath(netName)) - }, - }, - Timeout: 8 * time.Second, - } - - getURL := "http://unix" + statusSocketPath(netName) - resp, err := client.Get(getURL) - if err != nil { - log.Fatalf("Failed to get response: %v", err) - } - - report := StatusReport{} - if err := json.NewDecoder(resp.Body).Decode(&report); err != nil { - log.Fatalf("Failed to decode status report: %v", err) - } - - return report -} diff --git a/peer/main_test.go b/peer/main_test.go deleted file mode 100644 index c759212..0000000 --- a/peer/main_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package peer - -func newBuf() []byte { - return make([]byte, bufferSize) -} diff --git a/peer/mcreader.go b/peer/mcreader.go deleted file mode 100644 index e29bab6..0000000 --- a/peer/mcreader.go +++ /dev/null @@ -1,47 +0,0 @@ -package peer - -import ( - "log" - "net" - "time" -) - -func RunMCReader(g Globals) { - for { - runMCReaderInner(g) - time.Sleep(broadcastErrorTimeoutInterval) - } -} - -func runMCReaderInner(g Globals) { - var ( - buf = make([]byte, bufferSize) - logf = func(s string, args ...any) { - log.Printf("[MCReader] "+s, args...) - } - ) - - conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) - if err != nil { - logf("Failed to bind to multicast address: %v", err) - return - } - - for { - conn.SetReadDeadline(time.Now().Add(32 * time.Second)) - n, remoteAddr, err := conn.ReadFromUDPAddrPort(buf[:bufferSize]) - if err != nil { - logf("Failed to read from UDP port): %v", err) - return - } - - buf = buf[:n] - h, ok := headerFromLocalDiscoveryPacket(buf) - if !ok { - logf("Failed to open discovery packet?") - continue - } - - g.RemotePeers[h.SourceIP].Load().HandleLocalDiscoveryPacket(h, remoteAddr, buf) - } -} diff --git a/peer/mcreader_test.go b/peer/mcreader_test.go deleted file mode 100644 index 60feb44..0000000 --- a/peer/mcreader_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package peer - -/* -type mcMockConn struct { - packets chan []byte -} - -func newMCMockConn() *mcMockConn { - return &mcMockConn{make(chan []byte, 32)} -} - -func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) { - c.packets <- bytes.Clone(in) - return len(in), nil -} - -func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { - buf := <-c.packets - b = b[:len(buf)] - copy(b, buf) - return len(b), netip.AddrPort{}, nil -} - -func TestMCReader(t *testing.T) { - keys := generateKeys() - super := &mockControlMsgHandler{} - conn := newMCMockConn() - - peers := [256]*atomic.Pointer[RemotePeer]{} - peer := &RemotePeer{ - IP: 1, - Up: true, - PubSignKey: keys.PubSignKey, - } - peers[1] = &atomic.Pointer[RemotePeer]{} - peers[1].Store(peer) - - w := newMCWriter(conn, 1, keys.PrivSignKey) - r := newMCReader(conn, super, peers) - - w.SendLocalDiscovery() - r.handleNextPacket() - - if len(super.Messages) != 1 { - t.Fatal(super.Messages) - } - msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery]) - if !ok || msg.SrcIP != 1 { - t.Fatal(ok, msg) - } -} - -func TestMCReader_noHeader(t *testing.T) { - keys := generateKeys() - super := &mockControlMsgHandler{} - conn := newMCMockConn() - - peers := [256]*atomic.Pointer[RemotePeer]{} - peer := &RemotePeer{ - IP: 1, - Up: true, - PubSignKey: keys.PubSignKey, - } - peers[1] = &atomic.Pointer[RemotePeer]{} - peers[1].Store(peer) - - r := newMCReader(conn, super, peers) - conn.WriteToUDP([]byte("0123546789"), nil) - r.handleNextPacket() - - if len(super.Messages) != 0 { - t.Fatal(super.Messages) - } -} - -func TestMCReader_noPeer(t *testing.T) { - keys := generateKeys() - super := &mockControlMsgHandler{} - conn := newMCMockConn() - - peers := [256]*atomic.Pointer[RemotePeer]{} - peer := &RemotePeer{ - IP: 1, - Up: true, - PubSignKey: keys.PubSignKey, - } - peers[1] = &atomic.Pointer[RemotePeer]{} - peers[2] = &atomic.Pointer[RemotePeer]{} - peers[1].Store(peer) - - w := newMCWriter(conn, 2, keys.PrivSignKey) - r := newMCReader(conn, super, peers) - - w.SendLocalDiscovery() - r.handleNextPacket() - - if len(super.Messages) != 0 { - t.Fatal(super.Messages) - } -} - -func TestMCReader_badSignature(t *testing.T) { - keys := generateKeys() - super := &mockControlMsgHandler{} - conn := newMCMockConn() - - peers := [256]*atomic.Pointer[RemotePeer]{} - peer := &RemotePeer{ - IP: 1, - Up: true, - PubSignKey: keys.PubSignKey, - } - peers[1] = &atomic.Pointer[RemotePeer]{} - peers[1].Store(peer) - - w := newMCWriter(conn, 1, keys.PrivSignKey) - w.SendLocalDiscovery() - - // Break signing. - packet := <-conn.packets - packet[0]++ - conn.packets <- packet - - r := newMCReader(conn, super, peers) - - r.handleNextPacket() - - if len(super.Messages) != 0 { - t.Fatal(super.Messages) - } -} -*/ diff --git a/peer/mcwriter.go b/peer/mcwriter.go deleted file mode 100644 index 5430aac..0000000 --- a/peer/mcwriter.go +++ /dev/null @@ -1,54 +0,0 @@ -package peer - -import ( - "log" - "net" - "time" - - "golang.org/x/crypto/nacl/sign" -) - -func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { - h := Header{ - SourceIP: localIP, - DestIP: 255, - } - buf := make([]byte, headerSize) - h.Marshal(buf) - out := make([]byte, headerSize+signingOverhead) - return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) -} - -func headerFromLocalDiscoveryPacket(pkt []byte) (h Header, ok bool) { - if len(pkt) != headerSize+signingOverhead { - return - } - - h.Parse(pkt[signingOverhead:]) - ok = true - return -} - -func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { - _, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) - return ok -} - -// ---------------------------------------------------------------------------- - -func RunMCWriter(localIP byte, signingKey []byte) { - discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey) - - conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr) - if err != nil { - log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err) - } - - for range time.Tick(broadcastInterval) { - log.Printf("[MCWriter] Broadcasting on %v...", multicastAddr) - _, err := conn.WriteToUDP(discoveryPacket, multicastAddr) - if err != nil { - log.Printf("[MCWriter] Failed to write multicast: %v", err) - } - } -} diff --git a/peer/mcwriter_test.go b/peer/mcwriter_test.go deleted file mode 100644 index 74411f4..0000000 --- a/peer/mcwriter_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package peer - -/* -// ---------------------------------------------------------------------------- - -// Testing that we can create and verify a local discovery packet. -func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { - keys := generateKeys() - - created := createLocalDiscoveryPacket(55, keys.PrivSignKey) - - header, ok := headerFromLocalDiscoveryPacket(created) - if !ok { - t.Fatal(ok) - } - if header.SourceIP != 55 || header.DestIP != 255 { - t.Fatal(header) - } - - if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { - t.Fatal("Not valid") - } -} - -// Testing that we don't try to parse short packets. -func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { - keys := generateKeys() - - created := createLocalDiscoveryPacket(55, keys.PrivSignKey) - - _, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) - if ok { - t.Fatal(ok) - } -} - -// Testing that modifying a packet makes it invalid. -func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { - keys := generateKeys() - - created := createLocalDiscoveryPacket(55, keys.PrivSignKey) - buf := make([]byte, 1024) - for i := range created { - modified := bytes.Clone(created) - modified[i]++ - if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { - t.Fatal("Verification should have failed.") - } - } -} - -// ---------------------------------------------------------------------------- - -type testUDPWriter struct { - written [][]byte -} - -func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { - w.written = append(w.written, bytes.Clone(b)) - return len(b), nil -} - -func (w *testUDPWriter) Written() [][]byte { - out := w.written - w.written = [][]byte{} - return out -} - -// ---------------------------------------------------------------------------- - -// Testing that the mcWriter sends local discovery packets as expected. -func TestMCWriter_SendLocalDiscovery(t *testing.T) { - keys := generateKeys() - writer := &testUDPWriter{} - - mcw := newMCWriter(writer, 42, keys.PrivSignKey) - mcw.SendLocalDiscovery() - - out := writer.Written() - if len(out) != 1 { - t.Fatal(out) - } - - pkt := out[0] - - header, ok := headerFromLocalDiscoveryPacket(pkt) - if !ok { - t.Fatal(ok) - } - if header.SourceIP != 42 || header.DestIP != 255 { - t.Fatal(header) - } - - if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { - t.Fatal("Verification should succeed.") - } -} -*/ diff --git a/peer/mock-iface_test.go b/peer/mock-iface_test.go deleted file mode 100644 index ffef5d9..0000000 --- a/peer/mock-iface_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package peer - -import "bytes" - -type TestIFace struct { - out *bytes.Buffer // Toward the network. - in *bytes.Buffer // From the network -} - -func NewTestIFace() *TestIFace { - return &TestIFace{ - out: &bytes.Buffer{}, - in: &bytes.Buffer{}, - } -} - -func (iface *TestIFace) Write(b []byte) (int, error) { - return iface.in.Write(b) -} - -func (iface *TestIFace) Read(b []byte) (int, error) { - return iface.out.Read(b) -} - -func (iface *TestIFace) UserWrite(b []byte) (int, error) { - return iface.out.Write(b) -} - -func (iface *TestIFace) UserRead(b []byte) (int, error) { - return iface.in.Read(b) -} diff --git a/peer/mock-network_test.go b/peer/mock-network_test.go deleted file mode 100644 index 4b5240c..0000000 --- a/peer/mock-network_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package peer - -import ( - "bytes" - "net" - "net/netip" - "sync" -) - -type TestPacket struct { - Addr netip.AddrPort - Data []byte -} - -type TestNetwork struct { - lock sync.Mutex - packets map[netip.AddrPort]chan TestPacket -} - -func NewTestNetwork() *TestNetwork { - return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}} -} - -func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn { - n.lock.Lock() - defer n.lock.Unlock() - if _, ok := n.packets[localAddr]; !ok { - n.packets[localAddr] = make(chan TestPacket, 1024) - } - return &TestUDPConn{ - addr: localAddr, - n: n, - packets: n.packets[localAddr], - } -} - -func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) { - n.lock.Lock() - defer n.lock.Unlock() - if _, ok := n.packets[to]; !ok { - n.packets[to] = make(chan TestPacket, 1024) - } - n.packets[to] <- TestPacket{ - Addr: from, - Data: bytes.Clone(b), - } -} - -type TestUDPConn struct { - addr netip.AddrPort - n *TestNetwork - packets chan TestPacket -} - -func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - c.n.write(b, c.addr, addr) - return len(b), nil -} - -func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { - return c.WriteToUDPAddrPort(b, addr.AddrPort()) -} - -func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { - pkt := <-c.packets - b = b[:len(pkt.Data)] - copy(b, pkt.Data) - return len(b), pkt.Addr, nil -} - -func (c *TestUDPConn) Packets() (out []TestPacket) { - for { - select { - case pkt := <-c.packets: - out = append(out, pkt) - default: - return - } - } -} diff --git a/peer/multicast/broadcaster.go b/peer/multicast/broadcaster.go new file mode 100644 index 0000000..c349d6b --- /dev/null +++ b/peer/multicast/broadcaster.go @@ -0,0 +1,62 @@ +package multicast + +import ( + "log" + "net" + "net/netip" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( + netip.AddrFrom4([4]byte{224, 0, 0, 157}), + 4560)) + +func Broadcast( + selfVPNIP netip.Addr, + pubKey wgtypes.Key, + wgPort uint16, + signKey *[64]byte, +) { + for { + broadcastInner(selfVPNIP, pubKey, wgPort, signKey) + time.Sleep(errorTimeout) + } +} + +func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) { + conn, err := net.ListenMulticastUDP("udp", nil, addr) + if err != nil { + log.Printf("[MCBroadcast] bind: %v", err) + return + } + defer conn.Close() + + buf := make([]byte, BufferSize) + packet := Packet{ + PeerIP: selfVPNIP.As4()[3], + WGPubKey: pubKey, + WGPort: wgPort, + } + + // Re-sign on each send so the timestamp is fresh; a stale timestamp would be + // dropped by receivers' freshness gate. + send := func() error { + packet.Timestamp = time.Now().Unix() + payload := packet.Marshal(buf, signKey) + _, err := conn.WriteToUDP(payload, addr) + return err + } + + if err := send(); err != nil { + log.Printf("[MCBroadcast] write: %v", err) + } + + for range time.Tick(broadcastInterval) { + if err := send(); err != nil { + log.Printf("[MCBroadcast] write: %v", err) + return + } + } +} diff --git a/peer/multicast/global.go b/peer/multicast/global.go new file mode 100644 index 0000000..00ea4b7 --- /dev/null +++ b/peer/multicast/global.go @@ -0,0 +1,9 @@ +package multicast + +import "time" + +const ( + errorTimeout = 16 * time.Second + broadcastInterval = 16 * time.Second + maxPacketAge = time.Minute +) diff --git a/peer/multicast/packet.go b/peer/multicast/packet.go new file mode 100644 index 0000000..ed3db7c --- /dev/null +++ b/peer/multicast/packet.go @@ -0,0 +1,54 @@ +package multicast + +import ( + "encoding/binary" + "net/netip" + + "golang.org/x/crypto/nacl/sign" +) + +const ( + BufferSize = packetSize + SignedPacketSize + SignedPacketSize = packetSize + signSize + packetSize = 43 + signSize = 64 +) + +// Layout: +// +// [0] final octet of the sender's VPN IP +// [1:33] WG public key +// [33:35] WG listen port (big-endian uint16) +// [35:43] send time, Unix seconds (big-endian int64) — freshness/replay gate +type Packet struct { + PeerIP byte // Final octet of the sender's VPN IP. + WGPubKey [32]byte // WG public key. + WGPort uint16 // WG listen port. + Timestamp int64 // Unix timestamp. + Src netip.Addr // Source of packet. + Signed []byte // Raw signed message for verification (incoming packet). +} + +// Marshal the packet into a buffer with prefixed signature. +func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte { + buf[0] = p.PeerIP + copy(buf[1:33], p.WGPubKey[:]) + binary.BigEndian.PutUint16(buf[33:35], p.WGPort) + binary.BigEndian.PutUint64(buf[35:43], uint64(p.Timestamp)) + return sign.Sign(buf[packetSize:packetSize], buf[:packetSize], signKey) +} + +func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool { + _, ok := sign.Open(buf, p.Signed, pubKey) + return ok +} + +func Unmarshal(signed []byte) (p Packet) { + buf := signed[signSize:] + p.PeerIP = buf[0] + copy(p.WGPubKey[:], buf[1:33]) + p.WGPort = binary.BigEndian.Uint16(buf[33:35]) + p.Timestamp = int64(binary.BigEndian.Uint64(buf[35:43])) + p.Signed = signed + return +} diff --git a/peer/multicast/packet_test.go b/peer/multicast/packet_test.go new file mode 100644 index 0000000..6aed8d9 --- /dev/null +++ b/peer/multicast/packet_test.go @@ -0,0 +1,38 @@ +package multicast + +import ( + "crypto/rand" + "testing" + + "golang.org/x/crypto/nacl/sign" +) + +func TestPacket(t *testing.T) { + pub, priv, err := sign.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + p := Packet{ + PeerIP: 10, + WGPubKey: [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, + WGPort: 44, + Timestamp: 12948893, + } + + buf := make([]byte, BufferSize) + signed := p.Marshal(buf, priv) + if len(signed) != SignedPacketSize { + t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize) + } + + got := Unmarshal(signed) + if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey || + got.WGPort != p.WGPort || got.Timestamp != p.Timestamp { + t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p) + } + + if !got.Verify(nil, pub) { + t.Error("signature did not verify") + } +} diff --git a/peer/multicast/receiver.go b/peer/multicast/receiver.go new file mode 100644 index 0000000..2d1bfad --- /dev/null +++ b/peer/multicast/receiver.go @@ -0,0 +1,61 @@ +package multicast + +import ( + "bytes" + "fmt" + "log" + "net" + "net/netip" + "time" +) + +func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) { + for { + if err := receiver(vpnNet, selfVPNIP, ch); err != nil { + log.Printf("[MCReader] %v", err) + } + time.Sleep(errorTimeout) + } +} + +func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error { + selfIP := selfVPNIP.As4()[3] + + conn, err := net.ListenMulticastUDP("udp", nil, addr) + if err != nil { + return fmt.Errorf("bind: %w", err) + } + defer conn.Close() + + buf := make([]byte, BufferSize+1) // +1 to detect oversized packets + + for { + conn.SetReadDeadline(time.Now().Add(32 * time.Second)) + n, src, err := conn.ReadFromUDPAddrPort(buf) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + continue + } + return fmt.Errorf("read: %w", err) + } + + if n != SignedPacketSize { + continue + } + + packet := Unmarshal(buf[:n]) + + if packet.PeerIP == selfIP { + continue + } + + age := time.Since(time.Unix(packet.Timestamp, 0)) + if age > maxPacketAge || age < -maxPacketAge { + continue + } + + packet.Signed = bytes.Clone(packet.Signed) + packet.Src = src.Addr().Unmap() + ch <- packet + } +} diff --git a/peer/network_state.go b/peer/network_state.go new file mode 100644 index 0000000..7f80218 --- /dev/null +++ b/peer/network_state.go @@ -0,0 +1,18 @@ +package peer + +import "vppn/m" + +// loadNetworkState reads a cached network state from disk. Any error (most +// commonly a missing file on first run) is returned to the caller, which +// treats it as "no cache available". +func loadNetworkState(path string) (m.NetworkState, error) { + var state m.NetworkState + err := loadJSON(path, &state) + return state, err +} + +// saveNetworkState writes state to path atomically (see storeJSON), so a crash +// mid-write cannot leave a corrupt cache. +func saveNetworkState(path string, state m.NetworkState) error { + return storeJSON(path, state) +} diff --git a/peer/network_state_test.go b/peer/network_state_test.go new file mode 100644 index 0000000..be1fd23 --- /dev/null +++ b/peer/network_state_test.go @@ -0,0 +1,56 @@ +package peer + +import ( + "net/netip" + "path/filepath" + "reflect" + "testing" + + "vppn/m" +) + +func TestNetworkState_RoundTrip(t *testing.T) { + path := filepath.Join(t.TempDir(), "network.json") + + var sign1 [32]byte + copy(sign1[:], []byte("0123456789abcdef0123456789abcdef")) + + state := m.NetworkState{Peers: []m.Peer{ + { + PeerIP: 1, + Name: "hub", + Addr4: netip.MustParseAddr("10.11.12.1"), + Port: 51820, + Relay: true, + WGPubKey: mustKey(t), + SignPubKey: sign1, + }, + { + PeerIP: 10, + Name: "laptop", + Addr4: netip.MustParseAddr("10.11.12.10"), + Port: 51820, + WGPubKey: mustKey(t), + }, + }} + + if err := saveNetworkState(path, state); err != nil { + t.Fatal(err) + } + + got, err := loadNetworkState(path) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, state) { + t.Errorf("round-trip mismatch:\n got: %+v\nwant: %+v", got.Peers[1], state.Peers[1]) + } +} + +func TestNetworkState_LoadMissing(t *testing.T) { + path := filepath.Join(t.TempDir(), "does-not-exist.json") + if _, err := loadNetworkState(path); err == nil { + t.Fatal("expected error loading missing cache, got nil") + } +} diff --git a/peer/new.go b/peer/new.go new file mode 100644 index 0000000..ff75168 --- /dev/null +++ b/peer/new.go @@ -0,0 +1,109 @@ +package peer + +import ( + "fmt" + "net/netip" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" + "vppn/peer/multicast" + "vppn/peer/wginterface" +) + +// New constructs an App, creates the WireGuard interface, and starts the +// background goroutines (hub poller, multicast, control conn reader). +// The caller should invoke Run() to start the event loop. +func New( + state LocalState, + hubURL, apiKey string, + ifaceName string, + localDomain string, + networkStatePath string, +) (*App, error) { + + a4 := state.VPNIP.As4() + if err := wginterface.Create(ifaceName, a4[:], 24); err != nil { + return nil, fmt.Errorf("create WG interface: %w", err) + } + + dev, err := wginterface.Open(ifaceName) + if err != nil { + _ = wginterface.Delete(ifaceName) + return nil, fmt.Errorf("open WG device: %w", err) + } + + cc, err := newUDPControlConn(state.VPNIP, ControlPort) + if err != nil { + _ = dev.Close() + _ = wginterface.Delete(ifaceName) + return nil, fmt.Errorf("control conn: %w", err) + } + + cleanup := func() { + _ = cc.Close() + _ = dev.Close() + _ = wginterface.Delete(ifaceName) + } + + if err := dev.Configure(state.PrivKey, int(state.WGPort)); err != nil { + cleanup() + return nil, fmt.Errorf("configure WG device: %w", err) + } + + if state.IsRelay { + if err := dev.EnableForwarding(); err != nil { + cleanup() + return nil, fmt.Errorf("enable forwarding: %w", err) + } + } + + pingCh := make(chan PingEvent) + hubAddCh := make(chan m.Peer) + hubRemoveCh := make(chan wgtypes.Key) + multicastCh := make(chan multicast.Packet) + + poller, err := NewHubPoller( + state.VPNIP, + state.VPNNet, + hubURL, + apiKey, + networkStatePath, + hubAddCh, + hubRemoveCh) + if err != nil { + cleanup() + return nil, fmt.Errorf("hub poller: %w", err) + } + + go cc.run(pingCh) + go poller.Run() + + if !state.IsPublic { + go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey) + go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh) + } + + return &App{ + vpnIP: state.VPNIP, + vpnNet: state.VPNNet, + privKey: state.PrivKey, + pubKey: state.PrivKey.PublicKey(), + isRelay: state.IsRelay, + isPublic: state.IsPublic, + localDomain: localDomain, + + dev: dev, + controlConn: cc, + + peersByKey: make(map[wgtypes.Key]*Peer), + peersByIP: make(map[netip.Addr]*Peer), + + scratch: make([]byte, scratchSize), + + hubAddCh: hubAddCh, + hubRemoveCh: hubRemoveCh, + pingCh: pingCh, + multicastCh: multicastCh, + }, nil +} diff --git a/peer/on_hub.go b/peer/on_hub.go new file mode 100644 index 0000000..677263f --- /dev/null +++ b/peer/on_hub.go @@ -0,0 +1,114 @@ +package peer + +import ( + "log" + "math" + "net/netip" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" + "vppn/peer/control" +) + +func (a *App) onAddPeer(p m.Peer) { + a.onRemovePeer(p.WGPubKey) + + octets := a.vpnNet.Addr().As4() + octets[3] = p.PeerIP + vpnIP := netip.AddrFrom4(octets) + + peer := &Peer{ + wgPeer: wgtypes.Peer{PublicKey: p.WGPubKey}, + VPNIP: vpnIP, + Name: p.Name, + IsRelay: p.Relay, + IsPublic: p.IsPublic(), + EndpointV4: p.Endpoint4(), + EndpointV6: p.Endpoint6(), + RTT: time.Duration(math.MaxInt64) * time.Nanosecond, + Role: roleFor(a.isPublic, a.vpnIP, p.IsPublic(), vpnIP), + SignPubKey: p.SignPubKey, + } + + a.peersByKey[p.WGPubKey] = peer + a.peersByIP[peer.VPNIP] = peer + defer a.updateHosts() + + if !peer.IsPublic { + if a.isPublic { + // Public nodes accept traffic from non-public peers as soon as they + // initiate a handshake. Set /32 AllowedIPs now; WireGuard learns the + // endpoint from the incoming handshake automatically. + a.devPromote(peer) + } else { + a.devAddPeer(peer) + } + return + } + + a.devAddDirect(peer, peer.PreferredEndpoint()) +} + +func (a *App) onRemovePeer(key wgtypes.Key) { + peer, exists := a.peersByKey[key] + if !exists { + return + } + a.devRemove(peer) + delete(a.peersByKey, key) + delete(a.peersByIP, peer.VPNIP) + a.updateHosts() + + if peer == a.relay { + a.relay = nil + a.switchActiveRelay() + } +} + +// switchActiveRelay promotes the lowest-latency relay peer to active. +func (a *App) switchActiveRelay() { + if a.relay != nil { + // If we have a relay, it's public, so should go back to being a direct + // peer - this will convert it's /24 to a /32. + a.devAddDirect(a.relay, a.relay.PreferredEndpoint()) + a.relay = nil + } + + var best *Peer + for _, p := range a.peersByKey { + if !p.CanRelay() { + continue + } + + if best == nil || p.RTT < best.RTT { + best = p + } + } + if best == nil { + log.Printf("no relay available") + return + } + + a.devSetRelay(best, best.PreferredEndpoint()) + a.relay = best +} + +func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort { + // We always prefer v4 since all peers can connect to IPv4 addresses. + if v4.IsValid() { + return v4 + } + return v6 +} + +func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role { + if !selfIsPublic && peerIsPublic { + return control.Client + } + if selfIsPublic && !peerIsPublic { + return control.Server + } + return control.RoleFor(selfIP, peerVPNIP) +} diff --git a/peer/on_hub_test.go b/peer/on_hub_test.go new file mode 100644 index 0000000..230aa93 --- /dev/null +++ b/peer/on_hub_test.go @@ -0,0 +1,299 @@ +package peer + +import ( + "net/netip" + "testing" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/m" +) + +func mustKey(t *testing.T) wgtypes.Key { + t.Helper() + k, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatalf("generate key: %v", err) + } + return k.PublicKey() +} + +func TestOnAddPeer(t *testing.T) { + ep1 := netip.MustParseAddrPort("1.2.3.4:51820") + ep2 := netip.MustParseAddrPort("5.6.7.8:51820") + peerVPNIP := netip.MustParseAddr("10.0.0.2") + + testCases := []struct { + name string + setup func(a *App, key wgtypes.Key) + peer func(key wgtypes.Key) m.Peer + check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) + }{ + { + name: "non-public peer registered in WG via AddPeer", + peer: func(k wgtypes.Key) m.Peer { + return m.Peer{WGPubKey: k, PeerIP: 2} + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) { + p := a.peersByKey[key] + if p == nil { + t.Fatal("not in peersByKey") + } + if a.peersByIP[peerVPNIP] == nil { + t.Fatal("not in peersByIP") + } + if p.State != StateRelayed { + t.Fatalf("state = %v, want StateRelayed", p.State) + } + dev.AssertAddPeer(t, 0, key) + }, + }, + { + name: "public peer with endpoint registered via AddDirect", + peer: func(k wgtypes.Key) m.Peer { + return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()} + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) { + p := a.peersByKey[key] + if p == nil { + t.Fatal("not in peersByKey") + } + dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP) + }, + }, + { + name: "re-add removes old WG entry before adding new one", + setup: func(a *App, key wgtypes.Key) { + a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}) + }, + peer: func(k wgtypes.Key) m.Peer { + return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()} + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) { + if len(dev.Calls) != 2 { + t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls) + } + dev.AssertRemovePeer(t, 0, key) + dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP) + if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 { + t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP)) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + a, dev, _ := newTestApp(t, "10.0.0.1", false, false) + key := mustKey(t) + if tc.setup != nil { + tc.setup(a, key) + dev.Calls = nil + } + a.onAddPeer(tc.peer(key)) + tc.check(t, a, dev, key) + }) + } +} + +func TestOnRemovePeer(t *testing.T) { + ep1 := netip.MustParseAddrPort("1.2.3.4:51820") + ep2 := netip.MustParseAddrPort("5.6.7.8:51820") + + testCases := []struct { + name string + setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove + check func(t *testing.T, a *App, dev *fakeWGDevice) + }{ + { + name: "unknown key is a no-op", + setup: func(t *testing.T, a *App) wgtypes.Key { + return mustKey(t) + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + dev.AssertNoCalls(t) + if len(a.peersByKey) != 0 { + t.Errorf("peersByKey should be empty") + } + }, + }, + { + name: "StateRelayed peer removed from maps with RemovePeer", + setup: func(t *testing.T, a *App) wgtypes.Key { + key := mustKey(t) + a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2}) + return key + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 1 { + t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls) + } + dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey) + if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 { + t.Errorf("maps should be empty after remove") + } + }, + }, + { + name: "StateDirect peer removed from maps with RemovePeer", + setup: func(t *testing.T, a *App) wgtypes.Key { + key := mustKey(t) + a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}) + return key + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 1 { + t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls) + } + dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey) + if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 { + t.Errorf("maps should be empty after remove") + } + }, + }, + { + name: "removing active relay with no backup clears relay field", + setup: func(t *testing.T, a *App) wgtypes.Key { + relay := addRelayPeer(t, a, "10.0.0.10", ep1) + a.relay = relay + return relay.PubKey() + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 1 { + t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls) + } + dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey) + if a.relay != nil { + t.Errorf("relay should be nil after removing only relay") + } + }, + }, + { + name: "removing active relay elects backup via SetRelay", + setup: func(t *testing.T, a *App) wgtypes.Key { + relay1 := addRelayPeer(t, a, "10.0.0.10", ep1) + addRelayPeer(t, a, "10.0.0.11", ep2) + a.relay = relay1 + return relay1.PubKey() + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 2 { + t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls) + } + dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey) + dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet) + if a.relay == nil { + t.Errorf("relay should be set to backup after failover") + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + a, dev, _ := newTestApp(t, "10.0.0.1", false, false) + key := tc.setup(t, a) + dev.Calls = nil + a.onRemovePeer(key) + tc.check(t, a, dev) + }) + } +} + +func TestSwitchActiveRelay(t *testing.T) { + ep1 := netip.MustParseAddrPort("1.2.3.4:51820") + ep2 := netip.MustParseAddrPort("5.6.7.8:51820") + + testCases := []struct { + name string + setup func(t *testing.T, a *App) + check func(t *testing.T, a *App, dev *fakeWGDevice) + }{ + { + name: "no candidates leaves relay nil", + setup: func(t *testing.T, a *App) {}, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + dev.AssertNoCalls(t) + if a.relay != nil { + t.Error("relay should be nil") + } + }, + }, + { + name: "single candidate elected via SetRelay", + setup: func(t *testing.T, a *App) { + addRelayPeer(t, a, "10.0.0.10", ep1) + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 1 { + t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls) + } + dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet) + if a.relay == nil { + t.Error("relay should be set") + } + }, + }, + { + name: "measured RTT beats zero RTT", + setup: func(t *testing.T, a *App) { + r1 := addRelayPeer(t, a, "10.0.0.10", ep1) + r1.RTT = 10 * time.Millisecond + addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeaured) + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 1 { + t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls) + } + dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet) + }, + }, + { + name: "lower RTT wins", + setup: func(t *testing.T, a *App) { + r1 := addRelayPeer(t, a, "10.0.0.10", ep1) + r1.RTT = 5 * time.Millisecond + r2 := addRelayPeer(t, a, "10.0.0.11", ep2) + r2.RTT = 20 * time.Millisecond + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 1 { + t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls) + } + dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet) + }, + }, + { + name: "stale relay demoted to direct before backup elected", + setup: func(t *testing.T, a *App) { + old := addRelayPeer(t, a, "10.0.0.10", ep1) + old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick + a.relay = old + addRelayPeer(t, a, "10.0.0.11", ep2) + }, + check: func(t *testing.T, a *App, dev *fakeWGDevice) { + if len(dev.Calls) != 2 { + t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls) + } + if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 { + t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0]) + } + dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet) + if a.relay == nil || a.relay.EndpointV4 != ep2 { + t.Error("relay should be the backup peer") + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + a, dev, _ := newTestApp(t, "10.0.0.1", false, false) + tc.setup(t, a) + dev.Calls = nil + a.switchActiveRelay() + tc.check(t, a, dev) + }) + } +} diff --git a/peer/on_multicast.go b/peer/on_multicast.go new file mode 100644 index 0000000..0e05851 --- /dev/null +++ b/peer/on_multicast.go @@ -0,0 +1,51 @@ +package peer + +import ( + "net/netip" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/peer/multicast" +) + +func (a *App) onMulticastDiscovery(pkt multicast.Packet) { + if a.isPublic { + return + } + + // Locate the sender peer by its VPN IP (final octet carried in the beacon). + octets := a.vpnNet.Addr().As4() + octets[3] = pkt.PeerIP + vpnIP := netip.AddrFrom4(octets) + + peer, ok := a.peersByIP[vpnIP] + if !ok || peer.IsPublic || peer.State == StateDirect { + return + } + + // Authenticate the beacon against the peer's known sign key. scratch[:0] + // gives sign.Open an empty-but-capacity buffer to decode into. + if !pkt.Verify(a.scratch[:0], &peer.SignPubKey) { + return + } + + // The beacon is authentic but must also advertise the WG key the hub gave + // us for this peer; otherwise it's inconsistent — drop it. + if wgtypes.Key(pkt.WGPubKey) != peer.PubKey() { + return + } + + endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort) + if !endpoint.IsValid() { + return + } + + var v4, v6 netip.AddrPort + if pkt.Src.Is4() { + v4 = endpoint + } else { + v6 = endpoint + } + + a.addProbe(peer, v4, v6) +} diff --git a/peer/on_ping.go b/peer/on_ping.go new file mode 100644 index 0000000..2be8d0a --- /dev/null +++ b/peer/on_ping.go @@ -0,0 +1,58 @@ +package peer + +import ( + "net/netip" + "time" + + "vppn/peer/control" +) + +func (a *App) onPing(e PingEvent) { + peer, ok := a.peersByIP[e.srcVPNIP] + if !ok { + // TODO: Log here. + return + } + + now := time.Now() + + // If we're the server, respond - this is always necessary as it's used to + // know if peers are up or down. + if peer.Role == control.Server { + a.sendPing(peer, e.ping.PingTS) + } + + // Compute RTT from server echo. + if peer.Role == control.Client { + peer.RTT = now.Sub(time.Unix(0, e.ping.PingTS)) + } + + // If we're public, nothing more to do. + if a.isPublic { + return + } + + // We can only learn our own endpoint from directly-connected peers — Dst + // is the sender's observation of our WG handshake source. + if peer.State == StateDirect { + if dst := e.ping.Dst; dst.IsValid() { + if dst.Addr().Is4() { + a.selfV4 = dst + } else { + a.selfV6 = dst + } + } + return + } + + a.addProbe(peer, e.ping.SrcV4, e.ping.SrcV6) +} + +func (a *App) addProbe(peer *Peer, v4, v6 netip.AddrPort) { + endpoint := preferredEndpoint(v4, v6) + if !endpoint.IsValid() || endpoint == peer.PreferredEndpoint() { + return + } + peer.UpdateEndpoints(v4, v6) + a.devAddProbe(peer, endpoint) +} diff --git a/peer/on_tick.go b/peer/on_tick.go new file mode 100644 index 0000000..afdd5b6 --- /dev/null +++ b/peer/on_tick.go @@ -0,0 +1,52 @@ +package peer + +import ( + "log" + "time" + + "vppn/peer/control" + "vppn/peer/wginterface" +) + +func (a *App) onTick() { + wgPeers := a.devPeers() + + now := time.Now().UnixNano() + + for _, wgPeer := range wgPeers { + p, ok := a.peersByKey[wgPeer.PublicKey] + if !ok { + log.Printf("Wireguard peer not in index, removing: %v", wgPeer) + a.devRemove(&Peer{wgPeer: wgPeer}) + continue + } + p.wgPeer = wgPeer + + // Send pings to peers where we're the client. + if p.Role == control.Client { + a.sendPing(p, now) + } + + switch p.State { + case StateProbing: + // Promote probing peers to direct once alive (direct path confirmed + // working). + if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive { + a.devPromote(p) + } + + case StateDirect: + if p.IsPublic || a.isPublic || p.Up() { + break + } + // Stale non-public direct peer: demote to probing so WireGuard + // resumes handshake attempts on the direct path. + a.devAddProbe(p, p.WGEndpoint()) + } + } + + // Ensure we have a live relay (if we're not public). + if !a.isPublic && (a.relay == nil || !a.relay.Up()) { + a.switchActiveRelay() + } +} diff --git a/peer/packets-util.go b/peer/packets-util.go deleted file mode 100644 index 3bdfc67..0000000 --- a/peer/packets-util.go +++ /dev/null @@ -1,182 +0,0 @@ -package peer - -import ( - "net/netip" - "unsafe" -) - -// ---------------------------------------------------------------------------- - -type binWriter struct { - b []byte - i int -} - -func newBinWriter(buf []byte) *binWriter { - buf = buf[:cap(buf)] - return &binWriter{buf, 0} -} - -func (w *binWriter) Bool(b bool) *binWriter { - if b { - return w.Byte(1) - } - return w.Byte(0) -} - -func (w *binWriter) Byte(b byte) *binWriter { - w.b[w.i] = b - w.i++ - return w -} - -func (w *binWriter) SharedKey(key [32]byte) *binWriter { - copy(w.b[w.i:w.i+32], key[:]) - w.i += 32 - return w -} - -func (w *binWriter) Uint16(x uint16) *binWriter { - *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x - w.i += 2 - return w -} - -func (w *binWriter) Uint64(x uint64) *binWriter { - *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x - w.i += 8 - return w -} - -func (w *binWriter) Int64(x int64) *binWriter { - *(*int64)(unsafe.Pointer(&w.b[w.i])) = x - w.i += 8 - return w -} - -func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { - w.Bool(addrPort.IsValid()) - addr := addrPort.Addr().As16() - copy(w.b[w.i:w.i+16], addr[:]) - w.i += 16 - return w.Uint16(addrPort.Port()) -} - -func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter { - for _, addrPort := range l { - w.AddrPort(addrPort) - } - return w -} - -func (w *binWriter) Build() []byte { - return w.b[:w.i] -} - -// ---------------------------------------------------------------------------- - -type binReader struct { - b []byte - i int - err error -} - -func newBinReader(buf []byte) *binReader { - return &binReader{b: buf} -} - -func (r *binReader) hasBytes(n int) bool { - if r.err != nil || (len(r.b)-r.i) < n { - r.err = errMalformedPacket - return false - } - return true -} - -func (r *binReader) Bool(b *bool) *binReader { - var bb byte - r.Byte(&bb) - *b = bb != 0 - return r -} - -func (r *binReader) Byte(b *byte) *binReader { - if !r.hasBytes(1) { - return r - } - *b = r.b[r.i] - r.i++ - return r -} - -func (r *binReader) SharedKey(x *[32]byte) *binReader { - if !r.hasBytes(32) { - return r - } - *x = ([32]byte)(r.b[r.i : r.i+32]) - r.i += 32 - return r -} - -func (r *binReader) Uint16(x *uint16) *binReader { - if !r.hasBytes(2) { - return r - } - *x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) - r.i += 2 - return r -} - -func (r *binReader) Uint64(x *uint64) *binReader { - if !r.hasBytes(8) { - return r - } - *x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) - r.i += 8 - return r -} - -func (r *binReader) Int64(x *int64) *binReader { - if !r.hasBytes(8) { - return r - } - *x = *(*int64)(unsafe.Pointer(&r.b[r.i])) - r.i += 8 - return r -} - -func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { - if !r.hasBytes(19) { - return r - } - - var ( - valid bool - port uint16 - ) - - r.Bool(&valid) - addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() - r.i += 16 - - r.Uint16(&port) - - if valid { - *x = netip.AddrPortFrom(addr, port) - } else { - *x = netip.AddrPort{} - } - - return r -} - -func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader { - for i := range x { - r.AddrPort(&x[i]) - } - return r -} - -func (r *binReader) Error() error { - return r.err -} diff --git a/peer/packets-util_test.go b/peer/packets-util_test.go deleted file mode 100644 index 6e4a98c..0000000 --- a/peer/packets-util_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package peer - -import ( - "net/netip" - "reflect" - "testing" -) - -func TestBinWriteRead_invalidAddrPort(t *testing.T) { - addr := netip.AddrPort{} - buf := make([]byte, 1024) - buf = newBinWriter(buf). - AddrPort(addr). - Build() - - var addr2 netip.AddrPort - err := newBinReader(buf). - AddrPort(&addr2). - Error() - if err != nil { - t.Fatal(err) - } - - if addr2.IsValid() { - t.Fatal(addr, addr2) - } -} - -func TestBinWriteRead(t *testing.T) { - buf := make([]byte, 1024) - - type Item struct { - Type byte - TraceID uint64 - Addrs [8]netip.AddrPort - DestAddr netip.AddrPort - } - - in := Item{ - 1, - 2, - [8]netip.AddrPort{}, - netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), - } - - in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) - in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) - in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) - in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) - in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) - in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) - in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) - - buf = newBinWriter(buf). - Byte(in.Type). - Uint64(in.TraceID). - AddrPort(in.DestAddr). - AddrPort8(in.Addrs). - Build() - - out := Item{} - - err := newBinReader(buf). - Byte(&out.Type). - Uint64(&out.TraceID). - AddrPort(&out.DestAddr). - AddrPort8(&out.Addrs). - Error() - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(in, out) { - t.Fatal(in, out) - } -} diff --git a/peer/packets.go b/peer/packets.go deleted file mode 100644 index b673a4c..0000000 --- a/peer/packets.go +++ /dev/null @@ -1,120 +0,0 @@ -package peer - -import ( - "net/netip" -) - -const ( - packetTypeSyn = 1 - packetTypeInit = 2 - packetTypeAck = 3 - packetTypeProbe = 4 - packetTypeAddrDiscovery = 5 -) - -// ---------------------------------------------------------------------------- - -type packetInit struct { - TraceID uint64 - Direct bool - Version uint64 -} - -func (p packetInit) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeInit). - Uint64(p.TraceID). - Bool(p.Direct). - Uint64(p.Version). - Build() -} - -func parsePacketInit(buf []byte) (p packetInit, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - Bool(&p.Direct). - Uint64(&p.Version). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type packetSyn struct { - TraceID uint64 // TraceID to match response w/ request. - SharedKey [32]byte // Our shared key. - Direct bool - PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. -} - -func (p packetSyn) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeSyn). - Uint64(p.TraceID). - SharedKey(p.SharedKey). - Bool(p.Direct). - AddrPort8(p.PossibleAddrs). - Build() -} - -func parsePacketSyn(buf []byte) (p packetSyn, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - SharedKey(&p.SharedKey). - Bool(&p.Direct). - AddrPort8(&p.PossibleAddrs). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type packetAck struct { - TraceID uint64 - ToAddr netip.AddrPort - PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. -} - -func (p packetAck) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeAck). - Uint64(p.TraceID). - AddrPort(p.ToAddr). - AddrPort8(p.PossibleAddrs). - Build() -} - -func parsePacketAck(buf []byte) (p packetAck, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - AddrPort(&p.ToAddr). - AddrPort8(&p.PossibleAddrs). - Error() - return -} - -// ---------------------------------------------------------------------------- - -// A probeReqPacket is sent from a client to a server to determine if direct -// UDP communication can be used. -type packetProbe struct { - TraceID uint64 -} - -func (p packetProbe) Marshal(buf []byte) []byte { - return newBinWriter(buf). - Byte(packetTypeProbe). - Uint64(p.TraceID). - Build() -} - -func parsePacketProbe(buf []byte) (p packetProbe, err error) { - err = newBinReader(buf[1:]). - Uint64(&p.TraceID). - Error() - return -} - -// ---------------------------------------------------------------------------- - -type packetLocalDiscovery struct{} diff --git a/peer/packets_test.go b/peer/packets_test.go deleted file mode 100644 index f41817b..0000000 --- a/peer/packets_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package peer - -import ( - "crypto/rand" - "net/netip" - "reflect" - "testing" -) - -func TestSynPacket(t *testing.T) { - p := packetSyn{ - TraceID: 2342342345, - Direct: true, - } - rand.Read(p.SharedKey[:]) - - p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234) - p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) - p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) - - buf := p.Marshal(newBuf()) - p2, err := parsePacketSyn(buf) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(p, p2) { - t.Fatal(p2) - } -} - -func TestAckPacket(t *testing.T) { - p := packetAck{ - TraceID: 123213, - ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234), - } - - p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100) - p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399) - p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000) - - buf := p.Marshal(newBuf()) - p2, err := parsePacketAck(buf) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(p, p2) { - t.Fatal(p2) - } -} - -func TestProbePacket(t *testing.T) { - p := packetProbe{ - TraceID: 12345, - } - - buf := p.Marshal(newBuf()) - p2, err := parsePacketProbe(buf) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(p, p2) { - t.Fatal(p2) - } -} diff --git a/peer/peer.go b/peer/peer.go deleted file mode 100644 index f206141..0000000 --- a/peer/peer.go +++ /dev/null @@ -1,201 +0,0 @@ -package peer - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "log" - "math" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "vppn/m" - - "git.crumpington.com/lib/go/flock" -) - -type peerMain struct { - Globals - ifReader *IFReader - connReader *ConnReader - hubPoller *HubPoller - lockFile *os.File -} - -func newPeerMain(args mainArgs) *peerMain { - logf := func(s string, args ...any) { - log.Printf("[Main] "+s, args...) - } - - if err := os.MkdirAll(configDir(args.NetName), 0700); err != nil { - log.Fatalf("Failed to create config directory: %v", err) - } - - lockFile, err := flock.TryLock(lockFilePath(args.NetName)) - if err != nil { - log.Fatalf("Failed to open lock file: %v", err) - } - if lockFile == nil { - log.Fatalf("Failed to obtain file lock.") - } - - config, err := loadPeerConfig(args.NetName) - if err != nil { - logf("Failed to load configuration: %v", err) - logf("Initializing...") - initPeerWithHub(args) - - config, err = loadPeerConfig(args.NetName) - if err != nil { - log.Fatalf("Failed to load configuration: %v", err) - } - } - - state, err := loadNetworkState(args.NetName) - if err != nil { - log.Fatalf("Failed to load network state: %v", err) - } - - startupCount, err := loadStartupCount(args.NetName) - if err != nil { - if !os.IsNotExist(err) { - log.Fatalf("Failed to load startup count: %v", err) - } - } - - if startupCount.Count == math.MaxUint16 { - log.Fatalf("Startup counter overflow.") - } - startupCount.Count += 1 - - if err := storeStartupCount(args.NetName, startupCount); err != nil { - log.Fatalf("Failed to write startup count: %v", err) - } - - iface, err := openInterface(config.Network, config.LocalPeerIP, args.NetName) - if err != nil { - log.Fatalf("Failed to open interface: %v", err) - } - - localPeer := state.Peers[config.LocalPeerIP] - - myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", localPeer.Port)) - if err != nil { - log.Fatalf("Failed to resolve UDP address: %v", err) - } - - logf("Listening on %v...", myAddr) - conn, err := net.ListenUDP("udp", myAddr) - if err != nil { - log.Fatalf("Failed to open UDP port: %v", err) - } - - conn.SetReadBuffer(1024 * 1024 * 8) - conn.SetWriteBuffer(1024 * 1024 * 8) - - var localAddr netip.AddrPort - ip, localAddrValid := netip.AddrFromSlice(localPeer.PublicIP) - if localAddrValid { - localAddr = netip.AddrPortFrom(ip, localPeer.Port) - } - - g := NewGlobals(config, startupCount, localAddr, conn, iface) - - hubPoller, err := NewHubPoller(g, args.NetName, args.HubAddress, args.APIKey) - if err != nil { - log.Fatalf("Failed to create hub poller: %v", err) - } - - // Start status server. - go runStatusServer(g, statusSocketPath(args.NetName)) - - return &peerMain{ - Globals: g, - ifReader: NewIFReader(g), - connReader: NewConnReader(g, conn), - hubPoller: hubPoller, - lockFile: lockFile, - } -} - -func (p *peerMain) Run() { - for i := range p.RemotePeers { - remote := p.RemotePeers[i].Load() - go newRemoteFSM(remote).Run() - } - - go p.ifReader.Run() - go p.connReader.Run() - - if !p.LocalAddrValid { - go RunMCWriter(p.LocalPeerIP, p.PrivSignKey) - go RunMCReader(p.Globals) - } - - go p.hubPoller.Run() - - select {} -} - -func initPeerWithHub(args mainArgs) { - keys := generateKeys() - - initURL, err := url.Parse(args.HubAddress) - if err != nil { - log.Fatalf("Failed to parse hub URL: %v", err) - } - initURL.Path = "/peer/init/" - - initArgs := m.PeerInitArgs{ - EncPubKey: keys.PubKey, - PubSignKey: keys.PubSignKey, - } - - buf := &bytes.Buffer{} - if err := json.NewEncoder(buf).Encode(initArgs); err != nil { - log.Fatalf("Failed to encode init args: %v", err) - } - - req, err := http.NewRequest(http.MethodPost, initURL.String(), buf) - if err != nil { - log.Fatalf("Failed to construct request: %v", err) - } - req.SetBasicAuth("", args.APIKey) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - log.Fatalf("Failed to init with hub: %v", err) - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - log.Fatalf("Failed to read response body: %v", err) - } - - initResp := m.PeerInitResp{} - if err := json.Unmarshal(data, &initResp); err != nil { - log.Fatalf("Failed to parse configuration: %v\n%s", err, data) - } - - config := LocalConfig{} - config.LocalPeerIP = initResp.PeerIP - config.Network = initResp.Network - config.PubKey = keys.PubKey - config.PrivKey = keys.PrivKey - config.PubSignKey = keys.PubSignKey - config.PrivSignKey = keys.PrivSignKey - - if err := storeNetworkState(args.NetName, initResp.NetworkState); err != nil { - log.Fatalf("Failed to store network state: %v", err) - } - - if err := storePeerConfig(args.NetName, config); err != nil { - log.Fatalf("Failed to store configuration: %v", err) - } - - log.Print("Initialization successful.") -} diff --git a/peer/ping.go b/peer/ping.go new file mode 100644 index 0000000..6b51610 --- /dev/null +++ b/peer/ping.go @@ -0,0 +1,21 @@ +package peer + +import ( + "log" + "net/netip" + + "vppn/peer/control" +) + +func (a *App) sendPing(p *Peer, ts int64) { + ping := control.Ping{ + PingTS: ts, + SrcV4: a.selfV4, + SrcV6: a.selfV6, + Dst: p.WGEndpoint(), + } + dst := netip.AddrPortFrom(p.VPNIP, ControlPort) + if err := a.controlConn.SendPing(dst, ping, a.scratch); err != nil { + log.Printf("sendPing %v: %v", p.VPNIP, err) + } +} diff --git a/peer/pubaddrs.go b/peer/pubaddrs.go deleted file mode 100644 index 7945458..0000000 --- a/peer/pubaddrs.go +++ /dev/null @@ -1,86 +0,0 @@ -package peer - -import ( - "net/netip" - "sort" - "sync" - "time" -) - -type pubAddrStore struct { - lock sync.Mutex - localPub bool - localAddr netip.AddrPort - lastSeen map[netip.AddrPort]time.Time - addrList []netip.AddrPort -} - -func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore { - return &pubAddrStore{ - localPub: localAddr.IsValid(), - localAddr: localAddr, - lastSeen: map[netip.AddrPort]time.Time{}, - addrList: make([]netip.AddrPort, 0, 32), - } -} - -func (store *pubAddrStore) Store(addr netip.AddrPort) { - if store.localPub { - return - } - - if !addr.IsValid() { - return - } - - if addr.Addr().IsPrivate() { - return - } - - store.lock.Lock() - defer store.lock.Unlock() - - if _, exists := store.lastSeen[addr]; !exists { - store.addrList = append(store.addrList, addr) - } - store.lastSeen[addr] = time.Now() - store.sort() -} - -func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) { - store.lock.Lock() - defer store.lock.Unlock() - - store.clean() - - if store.localPub { - addrs[0] = store.localAddr - return - } - - copy(addrs[:], store.addrList) - return -} - -func (store *pubAddrStore) clean() { - if store.localPub { - return - } - - for ip, lastSeen := range store.lastSeen { - if time.Since(lastSeen) > timeoutInterval { - delete(store.lastSeen, ip) - } - } - store.addrList = store.addrList[:0] - for ip := range store.lastSeen { - store.addrList = append(store.addrList, ip) - } - store.sort() -} - -func (store *pubAddrStore) sort() { - sort.Slice(store.addrList, func(i, j int) bool { - return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]]) - }) -} diff --git a/peer/pubaddrs_test.go b/peer/pubaddrs_test.go deleted file mode 100644 index fa47c22..0000000 --- a/peer/pubaddrs_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package peer - -import ( - "net/netip" - "testing" - "time" -) - -func TestPubAddrStore(t *testing.T) { - s := newPubAddrStore(netip.AddrPort{}) - - l := []netip.AddrPort{ - netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20), - netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21), - netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22), - } - - for i := range l { - s.Store(l[i]) - time.Sleep(time.Millisecond) - } - - s.clean() - - l2 := s.Get() - if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] { - t.Fatal(l, l2) - } -} diff --git a/peer/relayhandler.go b/peer/relayhandler.go deleted file mode 100644 index a8e9b3d..0000000 --- a/peer/relayhandler.go +++ /dev/null @@ -1,54 +0,0 @@ -package peer - -import ( - "log" - "sync" - "sync/atomic" -) - -type relayHandler struct { - lock sync.Mutex - relays map[byte]*Remote - relay atomic.Pointer[Remote] -} - -func newRelayHandler() *relayHandler { - return &relayHandler{ - relays: make(map[byte]*Remote, 256), - } -} - -func (h *relayHandler) Add(r *Remote) { - h.lock.Lock() - defer h.lock.Unlock() - - h.relays[r.RemotePeerIP] = r - - if h.relay.Load() == nil { - log.Printf("Setting Relay: %v", r.conf().Peer.Name) - h.relay.Store(r) - } -} - -func (h *relayHandler) Remove(r *Remote) { - h.lock.Lock() - defer h.lock.Unlock() - - log.Printf("Removing relay %d...", r.RemotePeerIP) - delete(h.relays, r.RemotePeerIP) - - if h.relay.Load() == r { - // Remove current relay. - h.relay.Store(nil) - - // Find new relay. - for _, r := range h.relays { - h.relay.Store(r) - break - } - } -} - -func (h *relayHandler) Load() *Remote { - return h.relay.Load() -} diff --git a/peer/remote.go b/peer/remote.go index 9f815fa..d3b889a 100644 --- a/peer/remote.go +++ b/peer/remote.go @@ -1,351 +1,75 @@ package peer import ( - "fmt" - "log" "net/netip" - "strings" - "sync/atomic" - "vppn/m" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/peer/control" + "vppn/peer/wginterface" ) -// ---------------------------------------------------------------------------- -// The remoteConfig is the shared, immutable configuration for a remote -// peer. It's read and written atomically. See remote.config. -// ---------------------------------------------------------------------------- +type PeerState string -type remoteConfig struct { - Up bool // True if peer is up and we can send data. - Server bool // True if role is server. - Direct bool // True if this is a direct connection. - DirectAddr netip.AddrPort // Remote address if directly connected. - ControlCipher *controlCipher - DataCipher *dataCipher - Peer *m.Peer +const ( + StateRelayed = PeerState("RELAY") + StateProbing = PeerState("PROBE") + StateDirect = PeerState("DIRECT") +) + +type Peer struct { + wgPeer wgtypes.Peer + VPNIP netip.Addr // VPN IP address. + Name string // Human-readable DNS label. + IsRelay bool // Peer is a relay. + IsPublic bool // Peer has a public IP. + EndpointV4 netip.AddrPort // Reported IPv4 endpoint. + EndpointV6 netip.AddrPort // Reported IPv6 endpoint. + RTT time.Duration // Round-trip time. + State PeerState // Current routing state; updated on each devXxx call. + Role control.Role // Client initiates pings; server responds. + SignPubKey [32]byte // nacl/sign public key for verifying multicast beacons. } -// CanRelay returns true if the remote configuration is able to relay packets. -// to other hosts. -func (rc remoteConfig) CanRelay() bool { - return rc.Up && rc.Direct && rc.Peer.Relay +// PubKey is the wireguard public key. +func (p *Peer) PubKey() wgtypes.Key { + return p.wgPeer.PublicKey } -// A Remote represents a remote peer and contains functions for handling -// incoming control, data, and multicast packets, peer udpates, as well as -// sending, forwarding, and relaying packets. -type Remote struct { - Globals - RemotePeerIP byte // Immutable. - - dupCheck *dupCheck - sendCounter uint64 // init to startupCount << 48. Atomic access only. - - // config should be accessed via conf() and updateConf(...) methods. - config atomic.Pointer[remoteConfig] - messages chan any -} - -func newRemote(g Globals, remotePeerIP byte) *Remote { - r := &Remote{ - Globals: g, - RemotePeerIP: remotePeerIP, - dupCheck: newDupCheck(0), - sendCounter: (uint64(g.StartupCount) << 48) + 1, - messages: make(chan any, 8), +func (p *Peer) WGEndpoint() netip.AddrPort { + ep := p.wgPeer.Endpoint + if ep == nil { + return netip.AddrPort{} } - r.config.Store(&remoteConfig{}) - return r -} - -// ---------------------------------------------------------------------------- - -func (r *Remote) conf() remoteConfig { - return *(r.config.Load()) -} - -func (r *Remote) updateConf(conf remoteConfig) { - old := r.config.Load() - r.config.Store(&conf) - - if !old.CanRelay() && conf.CanRelay() { - r.RelayHandler.Add(r) - } - - if old.CanRelay() && !conf.CanRelay() { - r.RelayHandler.Remove(r) - } -} - -// ---------------------------------------------------------------------------- - -func (r *Remote) sendUDP(b []byte, addr netip.AddrPort) { - if _, err := r.SendUDP(b, addr); err != nil { - r.logf("Failed to send UDP packet: %v", err) - } -} - -// ---------------------------------------------------------------------------- - -func (r *Remote) encryptData(conf remoteConfig, destIP byte, packet []byte) []byte { - h := Header{ - StreamID: dataStreamID, - Counter: atomic.AddUint64(&r.sendCounter, 1), - SourceIP: r.Globals.LocalPeerIP, - DestIP: destIP, - } - return conf.DataCipher.Encrypt(h, packet, packet[len(packet):cap(packet)]) -} - -func (r *Remote) encryptControl(conf remoteConfig, packet []byte) []byte { - h := Header{ - StreamID: controlStreamID, - Counter: atomic.AddUint64(&r.sendCounter, 1), - SourceIP: r.LocalPeerIP, - DestIP: r.RemotePeerIP, - } - return conf.ControlCipher.Encrypt(h, packet, packet[len(packet):cap(packet)]) -} - -func (r *Remote) Status() (RemoteStatus, bool) { - conf := r.conf() - if conf.Peer == nil { - return RemoteStatus{}, false - } - - return RemoteStatus{ - PeerIP: conf.Peer.PeerIP, - Up: conf.Up, - Name: conf.Peer.Name, - PublicIP: conf.Peer.PublicIP, - Port: conf.Peer.Port, - Relay: conf.Peer.Relay, - Server: conf.Server, - Direct: conf.Direct, - DirectAddr: conf.DirectAddr, - }, true -} - -// ---------------------------------------------------------------------------- - -// SendDataTo sends a data packet to the remote, called by the IFReader. -func (r *Remote) SendDataTo(data []byte) { - conf := r.conf() - if !conf.Up { - r.logf("Cannot send: link down") - return - } - - // Direct: - - if conf.Direct { - r.sendUDP(r.encryptData(conf, conf.Peer.PeerIP, data), conf.DirectAddr) - return - } - - // Relayed: - relay := r.RelayHandler.Load() - - if relay == nil { - r.logf("Connot send: no relay") - return - } - - relay.relayData(conf.Peer.PeerIP, r.encryptData(conf, conf.Peer.PeerIP, data)) -} - -func (r *Remote) relayData(toIP byte, enc []byte) { - conf := r.conf() - if !conf.Up || !conf.Direct { - r.logf("Cannot relay: not up or not a direct connection") - return - } - r.sendUDP(r.encryptData(conf, toIP, enc), conf.DirectAddr) -} - -func (r *Remote) sendControl(conf remoteConfig, data []byte) { - // Direct: - - if conf.Direct { - enc := r.encryptControl(conf, data) - r.sendUDP(enc, conf.DirectAddr) - return - } - - // Relayed: - - relay := r.RelayHandler.Load() - - if relay == nil { - r.logf("Connot send: no relay") - return - } - - relay.relayData(conf.Peer.PeerIP, r.encryptControl(conf, data)) -} - -func (r *Remote) sendControlToAddr(buf []byte, addr netip.AddrPort) { - enc := r.encryptControl(r.conf(), buf) - r.sendUDP(enc, addr) -} - -func (r *Remote) forwardPacket(data []byte) { - conf := r.conf() - if !conf.Up || !conf.Direct { - r.logf("Cannot forward to %d: not a direct connection", conf.Peer.PeerIP) - return - } - r.sendUDP(data, conf.DirectAddr) -} - -// ---------------------------------------------------------------------------- - -// HandlePacket is called by the ConnReader to handle an incoming packet. -func (r *Remote) HandlePacket(h Header, srcAddr netip.AddrPort, data []byte) { - switch h.StreamID { - case controlStreamID: - r.handleControlPacket(h, srcAddr, data) - case dataStreamID: - r.handleDataPacket(h, data) - default: - r.logf("Unknown stream ID: %d", h.StreamID) - } -} - -// Handle a control packet. Decrypt, verify, etc. -func (r *Remote) handleControlPacket(h Header, srcAddr netip.AddrPort, data []byte) { - conf := r.conf() - if conf.ControlCipher == nil { - r.logf("No control cipher") - return - } - - dec, ok := conf.ControlCipher.Decrypt(data, data[len(data):cap(data)]) + addr, ok := netip.AddrFromSlice(ep.IP) if !ok { - r.logf("Failed to decrypt control packet") - return - } - - if r.dupCheck.IsDup(h.Counter) { - r.logf("Dropping control packet as duplicate: %d", h.Counter) - return - } - - msg, err := parseControlMsg(h.SourceIP, srcAddr, dec) - if err != nil { - r.logf("Failed to parse control packet: %v", err) - return - } - - select { - case r.messages <- msg: - default: - r.logf("Dropping control message") + return netip.AddrPort{} } + return netip.AddrPortFrom(addr.Unmap(), uint16(ep.Port)) } -func (r *Remote) handleDataPacket(h Header, data []byte) { - conf := r.conf() - if conf.DataCipher == nil { - return - } - - dec, ok := conf.DataCipher.Decrypt(data, data[len(data):cap(data)]) - if !ok { - r.logf("Failed to decrypt data packet") - return - } - - if r.dupCheck.IsDup(h.Counter) { - r.logf("Dropping data packet as duplicate: %d", h.Counter) - return - } - - // For local. - if h.DestIP == r.LocalPeerIP { - if _, err := r.IFace.Write(dec); err != nil { - // This could be a malformed packet from a peer, so we don't crash if it - // happens. - r.logf("Failed to write to interface: %v", err) - } - return - } - - // Forward. - dest := r.RemotePeers[h.DestIP].Load() - dest.forwardPacket(dec) +func (p *Peer) LastHandshakeTime() time.Time { + return p.wgPeer.LastHandshakeTime } -// ---------------------------------------------------------------------------- - -// HandleLocalDiscoveryPacket is called by the MCReader. -func (r *Remote) HandleLocalDiscoveryPacket(h Header, srcAddr netip.AddrPort, data []byte) { - conf := r.conf() - if conf.Peer == nil { - r.logf("No peer for discovery packet.") - return - } - - if conf.Peer.PubSignKey == nil { - r.logf("No signing key for discovery packet.") - return - } - - if !verifyLocalDiscoveryPacket(data, data[len(data):cap(data)], conf.Peer.PubSignKey) { - r.logf("Invalid signature on discovery packet.") - return - } - - msg := controlMsg[packetLocalDiscovery]{ - SrcIP: h.SourceIP, - SrcAddr: srcAddr, - } - - select { - case r.messages <- msg: - default: - r.logf("Dropping discovery message.") - } +func (p *Peer) Up() bool { + return time.Since(p.wgPeer.LastHandshakeTime) < wginterface.SessionTimeout } -// ---------------------------------------------------------------------------- - -// HandlePeerUpdate is called by the HubPoller when it gets a new version of -// the associated peer configuration. -func (r *Remote) HandlePeerUpdate(msg peerUpdateMsg) { - r.messages <- msg +func (p *Peer) CanRelay() bool { + return p.IsRelay && p.Up() } -// ---------------------------------------------------------------------------- - -func (s *Remote) logf(format string, args ...any) { - conf := s.conf() - - b := strings.Builder{} - name := "" - if conf.Peer != nil { - name = conf.Peer.Name - } - b.WriteString(fmt.Sprintf("%03d", s.RemotePeerIP)) - - b.WriteString(fmt.Sprintf("%30s: ", name)) - - if conf.Server { - b.WriteString("SERVER | ") - } else { - b.WriteString("CLIENT | ") - } - - if conf.Direct { - b.WriteString("DIRECT | ") - } else { - b.WriteString("RELAYED | ") - } - - if conf.Up { - b.WriteString("UP | ") - } else { - b.WriteString("DOWN | ") - } - - log.Printf(b.String()+format, args...) +func (p *Peer) PreferredEndpoint() netip.AddrPort { + return preferredEndpoint(p.EndpointV4, p.EndpointV6) +} + +func (p *Peer) UpdateEndpoints(v4, v6 netip.AddrPort) { + if v4.IsValid() { + p.EndpointV4 = v4 + } + if v6.IsValid() { + p.EndpointV6 = v6 + } } diff --git a/peer/remotefsm.go b/peer/remotefsm.go deleted file mode 100644 index 9f6a442..0000000 --- a/peer/remotefsm.go +++ /dev/null @@ -1,448 +0,0 @@ -package peer - -import ( - "net/netip" - "time" - "vppn/m" -) - -type stateFunc func(msg any) stateFunc - -type sentProbe struct { - SentAt time.Time - Addr netip.AddrPort -} - -type remoteFSM struct { - *Remote - - pingTimer *time.Ticker - lastSeen time.Time - traceID uint64 - probes map[uint64]sentProbe - - buf []byte -} - -func newRemoteFSM(r *Remote) *remoteFSM { - fsm := &remoteFSM{ - Remote: r, - pingTimer: time.NewTicker(timeoutInterval), - probes: map[uint64]sentProbe{}, - buf: make([]byte, bufferSize), - } - fsm.pingTimer.Stop() - return fsm -} - -func (r *remoteFSM) Run() { - go func() { - for range r.pingTimer.C { - r.messages <- pingTimerMsg{} - } - }() - state := r.enterDisconnected() - for msg := range r.messages { - state = state(msg) - } -} - -// ---------------------------------------------------------------------------- - -func (r *remoteFSM) enterDisconnected() stateFunc { - r.updateConf(remoteConfig{}) - return r.stateDisconnected -} - -func (r *remoteFSM) stateDisconnected(iMsg any) stateFunc { - switch msg := iMsg.(type) { - case peerUpdateMsg: - return r.enterPeerUpdating(msg.Peer) - case controlMsg[packetInit]: - r.logf("Unexpected INIT") - case controlMsg[packetSyn]: - r.logf("Unexpected SYN") - case controlMsg[packetAck]: - r.logf("Unexpected ACK") - case controlMsg[packetProbe]: - r.logf("Unexpected probe") - case controlMsg[packetLocalDiscovery]: - // Ignore - case pingTimerMsg: - r.logf("Unexpected ping") - default: - r.logf("Ignoring message: %#v", iMsg) - } - - return r.stateDisconnected -} - -// ---------------------------------------------------------------------------- - -func (r *remoteFSM) enterPeerUpdating(peer *m.Peer) stateFunc { - if peer == nil { - return r.enterDisconnected() - } - - conf := remoteConfig{ - Peer: peer, - ControlCipher: newControlCipher(r.PrivKey, peer.PubKey), - } - r.updateConf(conf) - - if _, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { - if r.LocalAddrValid && r.LocalPeerIP < peer.PeerIP { - return r.enterServer() - } - return r.enterClientInit() - } - - if r.LocalAddrValid || r.LocalPeerIP < peer.PeerIP { - return r.enterServer() - } - - return r.enterClientInit() -} - -// ---------------------------------------------------------------------------- - -func (r *remoteFSM) enterServer() stateFunc { - - conf := r.conf() - conf.Server = true - r.updateConf(conf) - r.logf("==> Server") - - r.pingTimer.Reset(pingInterval) - r.lastSeen = time.Now() - return r.stateServer -} - -func (r *remoteFSM) stateServer(iMsg any) stateFunc { - switch msg := iMsg.(type) { - case peerUpdateMsg: - return r.enterPeerUpdating(msg.Peer) - case controlMsg[packetInit]: - r.stateServer_onInit(msg) - case controlMsg[packetSyn]: - r.stateServer_onSyn(msg) - case controlMsg[packetAck]: - r.logf("Unexpected ACK") - case controlMsg[packetProbe]: - r.stateServer_onProbe(msg) - case controlMsg[packetLocalDiscovery]: - // Ignore - case pingTimerMsg: - r.stateServer_onPingTimer() - default: - r.logf("Unexpected message: %#v", iMsg) - } - - return r.stateServer -} - -func (r *remoteFSM) stateServer_onInit(msg controlMsg[packetInit]) { - conf := r.conf() - conf.Up = false - conf.Direct = msg.Packet.Direct - conf.DirectAddr = msg.SrcAddr - r.updateConf(conf) - - init := packetInit{ - TraceID: msg.Packet.TraceID, - Direct: conf.Direct, - Version: version, - } - - // Reset traceID to force state update on SYN. - r.traceID = 0 - r.sendControl(conf, init.Marshal(r.buf)) -} - -func (r *remoteFSM) stateServer_onSyn(msg controlMsg[packetSyn]) { - r.lastSeen = time.Now() - p := msg.Packet - - conf := r.conf() - - // New trace ID => Update the route configuration. - if p.TraceID != r.traceID { - r.traceID = p.TraceID - - conf.Up = true - conf.Direct = p.Direct - conf.DirectAddr = msg.SrcAddr - - conf.DataCipher = newDataCipherFromKey(p.SharedKey) - - r.updateConf(conf) - r.logf("Got SYN.") - } - - r.sendControl(conf, packetAck{ - TraceID: p.TraceID, - ToAddr: conf.DirectAddr, - PossibleAddrs: r.PubAddrs.Get(), - }.Marshal(r.buf)) - - if p.Direct { - return - } - - // Send probes if not a direct connection. The server sends probes without - // trace IDs unless responding to a client probe. - for _, addr := range msg.Packet.PossibleAddrs { - if !addr.IsValid() { - break - } - r.logf("Probing %v...", addr) - r.sendControlToAddr(packetProbe{}.Marshal(r.buf), addr) - } -} - -func (r *remoteFSM) stateServer_onProbe(msg controlMsg[packetProbe]) { - if !msg.SrcAddr.IsValid() { - return - } - - data := packetProbe{TraceID: msg.Packet.TraceID}.Marshal(r.buf) - r.sendControlToAddr(data, msg.SrcAddr) -} - -func (r *remoteFSM) stateServer_onPingTimer() { - conf := r.conf() - if time.Since(r.lastSeen) > timeoutInterval && conf.Up { - // Reset trace ID to ensure connection goes up on next SYN. - r.traceID = 0 - conf.Up = false - r.updateConf(conf) - r.logf("Timeout.") - } -} - -// ---------------------------------------------------------------------------- - -func (r *remoteFSM) enterClientInit() stateFunc { - conf := r.conf() - ip, ipValid := netip.AddrFromSlice(conf.Peer.PublicIP) - - conf.Up = false - conf.Server = false - conf.Direct = ipValid - conf.DirectAddr = netip.AddrPortFrom(ip, conf.Peer.Port) - conf.DataCipher = newDataCipher() - - r.updateConf(conf) - r.logf("==> ClientInit") - - r.lastSeen = time.Now() - r.pingTimer.Reset(pingInterval) - r.stateClientInit_sendInit() - return r.stateClientInit -} - -func (r *remoteFSM) stateClientInit(iMsg any) stateFunc { - switch msg := iMsg.(type) { - case peerUpdateMsg: - return r.enterPeerUpdating(msg.Peer) - case controlMsg[packetInit]: - return r.stateClientInit_onInit(msg) - case controlMsg[packetSyn]: - r.logf("Unexpected SYN") - case controlMsg[packetAck]: - r.logf("Unexpected ACK") - case controlMsg[packetProbe]: - // Ignore - case controlMsg[packetLocalDiscovery]: - // Ignore - case pingTimerMsg: - return r.stateClientInit_onPing() - default: - r.logf("Unexpected message: %#v", iMsg) - } - - return r.stateClientInit -} - -func (r *remoteFSM) stateClientInit_sendInit() { - conf := r.conf() - r.traceID = r.NewTraceID() - init := packetInit{ - TraceID: r.traceID, - Direct: conf.Direct, - Version: version, - } - r.sendControl(conf, init.Marshal(r.buf)) -} - -func (r *remoteFSM) stateClientInit_onInit(msg controlMsg[packetInit]) stateFunc { - if msg.Packet.TraceID != r.traceID { - r.logf("Invalid trace ID on INIT.") - return r.stateClientInit - } - r.logf("Got INIT version %d.", msg.Packet.Version) - return r.enterClient() -} - -func (r *remoteFSM) stateClientInit_onPing() stateFunc { - if time.Since(r.lastSeen) < timeoutInterval { - r.stateClientInit_sendInit() - return r.stateClientInit - } - - // Direct connect failed. Try indirect. - conf := r.conf() - - if conf.Direct { - conf.Direct = false - r.updateConf(conf) - r.lastSeen = time.Now() - r.stateClientInit_sendInit() - r.logf("Direct connection failed. Attempting indirect connection.") - return r.stateClientInit - } - - // Indirect failed. Re-enter init state. - r.logf("Timeout.") - return r.enterClientInit() -} - -// ---------------------------------------------------------------------------- - -func (r *remoteFSM) enterClient() stateFunc { - conf := r.conf() - clear(r.probes) - - r.traceID = r.NewTraceID() - r.stateClient_sendSyn(conf) - - r.pingTimer.Reset(pingInterval) - r.logf("==> Client") - return r.stateClient -} - -func (r *remoteFSM) stateClient(iMsg any) stateFunc { - switch msg := iMsg.(type) { - case peerUpdateMsg: - return r.enterPeerUpdating(msg.Peer) - case controlMsg[packetAck]: - r.stateClient_onAck(msg) - case controlMsg[packetProbe]: - r.stateClient_onProbe(msg) - case controlMsg[packetLocalDiscovery]: - r.stateClient_onLocalDiscovery(msg) - case pingTimerMsg: - return r.stateClient_onPingTimer() - default: - r.logf("Ignoring message: %v", iMsg) - } - return r.stateClient -} - -func (r *remoteFSM) stateClient_onAck(msg controlMsg[packetAck]) { - if msg.Packet.TraceID != r.traceID { - return - } - - r.lastSeen = time.Now() - - conf := r.conf() - if !conf.Up { - conf.Up = true - r.updateConf(conf) - r.logf("Got ACK.") - } - - if conf.Direct { - r.PubAddrs.Store(msg.Packet.ToAddr) - return - } - - // Relayed. - - r.stateClient_cleanProbes() - - for _, addr := range msg.Packet.PossibleAddrs { - if !addr.IsValid() { - break - } - r.stateClient_sendProbeTo(addr) - } -} - -func (r *remoteFSM) stateClient_cleanProbes() { - for key, sent := range r.probes { - if time.Since(sent.SentAt) > pingInterval { - delete(r.probes, key) - } - } -} - -func (r *remoteFSM) stateClient_sendProbeTo(addr netip.AddrPort) { - probe := packetProbe{TraceID: r.NewTraceID()} - r.probes[probe.TraceID] = sentProbe{ - SentAt: time.Now(), - Addr: addr, - } - r.logf("Probing %v...", addr) - r.sendControlToAddr(probe.Marshal(r.buf), addr) -} - -func (r *remoteFSM) stateClient_onProbe(msg controlMsg[packetProbe]) { - conf := r.conf() - if conf.Direct { - return - } - - r.stateClient_cleanProbes() - - sent, ok := r.probes[msg.Packet.TraceID] - if !ok { - return - } - - conf.Direct = true - conf.DirectAddr = sent.Addr - r.updateConf(conf) - - r.traceID = r.NewTraceID() - r.stateClient_sendSyn(conf) - r.logf("Successful probe to %v.", sent.Addr) -} - -func (r *remoteFSM) stateClient_onLocalDiscovery(msg controlMsg[packetLocalDiscovery]) { - conf := r.conf() - if conf.Direct { - return - } - - // The source port will be the multicast port, so we'll have to - // construct the correct address using the peer's listed port. - addr := netip.AddrPortFrom(msg.SrcAddr.Addr(), conf.Peer.Port) - r.stateClient_sendProbeTo(addr) -} - -func (r *remoteFSM) stateClient_onPingTimer() stateFunc { - conf := r.conf() - - if time.Since(r.lastSeen) > timeoutInterval { - if conf.Up { - r.logf("Timeout.") - } - return r.enterClientInit() - } - - r.stateClient_sendSyn(conf) - return r.stateClient -} - -func (r *remoteFSM) stateClient_sendSyn(conf remoteConfig) { - syn := packetSyn{ - TraceID: r.traceID, - SharedKey: conf.DataCipher.Key(), - Direct: conf.Direct, - PossibleAddrs: r.PubAddrs.Get(), - } - - r.sendControl(conf, syn.Marshal(r.buf)) -} diff --git a/peer/statusserver.go b/peer/statusserver.go deleted file mode 100644 index 1eebbfc..0000000 --- a/peer/statusserver.go +++ /dev/null @@ -1,71 +0,0 @@ -package peer - -import ( - "encoding/json" - "log" - "net" - "net/http" - "net/netip" - "os" -) - -type StatusReport struct { - LocalPeerIP byte - Network []byte - RelayPeerIP byte - Remotes []RemoteStatus -} - -type RemoteStatus struct { - PeerIP byte - Up bool - Name string - PublicIP []byte - Port uint16 - Relay bool - Server bool - Direct bool - DirectAddr netip.AddrPort -} - -func runStatusServer(g Globals, socketPath string) { - _ = os.RemoveAll(socketPath) - - handler := func(w http.ResponseWriter, r *http.Request) { - report := StatusReport{ - LocalPeerIP: g.LocalPeerIP, - Network: g.Network, - Remotes: make([]RemoteStatus, 0, 255), - } - - relay := g.RelayHandler.Load() - if relay != nil { - if relayStatus, ok := relay.Status(); ok { - report.RelayPeerIP = relayStatus.PeerIP - } - } - - for i := range g.RemotePeers { - remote := g.RemotePeers[i].Load() - status, ok := remote.Status() - if !ok { - continue - } - report.Remotes = append(report.Remotes, status) - } - - json.NewEncoder(w).Encode(report) - } - - server := http.Server{ - Handler: http.HandlerFunc(handler), - } - - unixListener, err := net.Listen("unix", socketPath) - if err != nil { - log.Fatalf("Failed to bind to unix socket: %v", err) - } - if err := server.Serve(unixListener); err != nil { - log.Fatalf("Failed to serve on unix socket: %v", err) - } -} diff --git a/peer/wginterface/interface.go b/peer/wginterface/interface.go new file mode 100644 index 0000000..3354eed --- /dev/null +++ b/peer/wginterface/interface.go @@ -0,0 +1,225 @@ +// Package wginterface demonstrates creating and destroying a WireGuard network +// interface using only raw system calls — no netlink library. +// +// Creating a typed interface (kind = "wireguard") requires the NETLINK_ROUTE +// protocol; there is no ioctl path for it. Everything else — assigning an IP +// address and bringing the link up — can be done with the older AF_INET ioctl +// interface, exactly as one would for a TUN device. +// +// The package requires CAP_NET_ADMIN and the wireguard kernel module. +package wginterface + +import ( + "encoding/binary" + "fmt" + "net" + "slices" + + "golang.org/x/sys/unix" +) + +// Create creates a WireGuard interface named name, assigns vpnIP/prefixLen to +// it, and brings it up. +func Create(name string, vpnIP net.IP, prefixLen int) error { + _ = Delete(name) // remove any stale interface left by a previous run + if err := nlNewLink(name); err != nil { + return fmt.Errorf("failed to create wireguard link: %w", err) + } + if err := ioctlSetAddr(name, vpnIP, prefixLen); err != nil { + _ = Delete(name) + return fmt.Errorf("assign address: %w", err) + } + if err := ioctlLinkUp(name); err != nil { + _ = Delete(name) + return fmt.Errorf("link up: %w", err) + } + return nil +} + +// Delete removes the named interface. +func Delete(name string) error { + return nlDelLink(name) +} + +// --------------------------------------------------------------------------- +// Netlink link management +// +// Creating a WireGuard interface requires an RTM_NEWLINK message with a nested +// IFLA_LINKINFO attribute whose IFLA_INFO_KIND is "wireguard". The full +// message layout is: +// +// nlmsghdr (16 bytes) +// ifinfomsg (16 bytes, all zeros for a new link) +// rtattr IFLA_IFNAME → name + \0 +// rtattr IFLA_LINKINFO +// rtattr IFLA_INFO_KIND → "wireguard" + \0 +// +// All multi-byte integers are in native byte order (little-endian on +// x86/arm64). Every attribute is padded to a 4-byte boundary; the len field +// in the header records the unpadded length but the attribute occupies the +// padded size. + +const ( + nlmsgHdrLen = 16 // sizeof(struct nlmsghdr) + sizeofIfInfo = 16 // sizeof(struct ifinfomsg) + + // Attribute types not exposed by the unix package at the level we need. + iflaLinkInfo = 18 // IFLA_LINKINFO — container for link-type attributes + iflaInfoKind = 1 // IFLA_INFO_KIND — link type string, nested inside IFLA_LINKINFO +) + +// nlNewLink creates the wireguard interface using Netlink. +func nlNewLink(name string) error { + // Build innermost attribute first, then wrap outward. + infoKind := nlAttr(iflaInfoKind, cstring("wireguard")) + linkInfo := nlAttr(iflaLinkInfo, infoKind) + ifName := nlAttr(unix.IFLA_IFNAME, cstring(name)) + + // ifinfomsg: all-zero = AF_UNSPEC, no index, no flags (kernel assigns index). + ifInfo := make([]byte, sizeofIfInfo) + + payload := slices.Concat(ifInfo, ifName, linkInfo) + flags := uint16(unix.NLM_F_REQUEST | unix.NLM_F_ACK | unix.NLM_F_CREATE | unix.NLM_F_EXCL) + return nlRoundtrip(unix.RTM_NEWLINK, flags, payload) +} + +func nlDelLink(name string) error { + iface, err := net.InterfaceByName(name) + if err != nil { + return err + } + + // For RTM_DELLINK the kernel identifies the link by ifi_index. ifi_index + // sits at byte offset 4 in the ifinfomsg struct. + ifInfo := make([]byte, sizeofIfInfo) + binary.NativeEndian.PutUint32(ifInfo[4:8], uint32(iface.Index)) + + return nlRoundtrip(unix.RTM_DELLINK, uint16(unix.NLM_F_REQUEST|unix.NLM_F_ACK), ifInfo) +} + +// nlRoundtrip opens a NETLINK_ROUTE socket, sends one request, reads the +// NLMSG_ERROR acknowledgement, and closes the socket. +func nlRoundtrip(msgType uint16, flags uint16, payload []byte) error { + fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) + if err != nil { + return fmt.Errorf("socket: %w", err) + } + defer unix.Close(fd) + + if err := unix.Bind(fd, &unix.SockaddrNetlink{Family: unix.AF_NETLINK}); err != nil { + return fmt.Errorf("bind: %w", err) + } + + msg := nlMsg(msgType, flags, payload) + if err := unix.Sendto(fd, msg, 0, &unix.SockaddrNetlink{Family: unix.AF_NETLINK}); err != nil { + return fmt.Errorf("sendto: %w", err) + } + + resp := make([]byte, 4096) + n, _, err := unix.Recvfrom(fd, resp, 0) + if err != nil { + return fmt.Errorf("recvfrom: %w", err) + } + return nlAckErr(resp[:n]) +} + +// nlMsg prepends an nlmsghdr to payload. +func nlMsg(msgType uint16, flags uint16, payload []byte) []byte { + buf := make([]byte, nlmsgHdrLen+len(payload)) + binary.NativeEndian.PutUint32(buf[0:4], uint32(len(buf))) // nlmsg_len + binary.NativeEndian.PutUint16(buf[4:6], msgType) // nlmsg_type + binary.NativeEndian.PutUint16(buf[6:8], flags) // nlmsg_flags + binary.NativeEndian.PutUint32(buf[8:12], 1) // nlmsg_seq + binary.NativeEndian.PutUint32(buf[12:16], 0) // nlmsg_pid (0 = kernel) + copy(buf[nlmsgHdrLen:], payload) + return buf +} + +// nlAckErr parses an NLMSG_ERROR response. The error field is a negated errno +// (0 = success, -EEXIST = interface exists, etc.). +func nlAckErr(resp []byte) error { + if len(resp) < nlmsgHdrLen+4 { + return fmt.Errorf("netlink response too short (%d bytes)", len(resp)) + } + if binary.NativeEndian.Uint16(resp[4:6]) != unix.NLMSG_ERROR { + return fmt.Errorf("unexpected nlmsg_type %d", binary.NativeEndian.Uint16(resp[4:6])) + } + // Error code follows the nlmsghdr; it is a signed int32 holding -errno. + code := int32(binary.NativeEndian.Uint32(resp[nlmsgHdrLen:])) + if code != 0 { + return unix.Errno(-code) + } + return nil +} + +// nlAttr encodes one netlink attribute: [len:u16][type:u16][data][pad to 4 +// bytes]. The len field counts the header + data (before padding); the +// allocation is padded so that the next attribute starts on a 4-byte boundary. +func nlAttr(attrType uint16, data []byte) []byte { + const hdr = 4 + attrLen := hdr + len(data) + padded := (attrLen + 3) &^ 3 + buf := make([]byte, padded) + binary.NativeEndian.PutUint16(buf[0:2], uint16(attrLen)) + binary.NativeEndian.PutUint16(buf[2:4], attrType) + copy(buf[hdr:], data) + return buf +} + +// --------------------------------------------------------------------------- +// ioctl-based address assignment and link-up +// +// These operations could also be done via RTM_NEWADDR / RTM_NEWLINK netlink +// messages, but the AF_INET ioctl interface is simpler. + +func ioctlSetAddr(name string, ip net.IP, prefixLen int) error { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer unix.Close(fd) + + req, err := unix.NewIfreq(name) + if err != nil { + return err + } + if err := req.SetInet4Addr(ip.To4()); err != nil { + return err + } + if err := unix.IoctlIfreq(fd, unix.SIOCSIFADDR, req); err != nil { + return err + } + + req, err = unix.NewIfreq(name) + if err != nil { + return err + } + mask := net.CIDRMask(prefixLen, 32) + if err := req.SetInet4Addr([]byte(mask)); err != nil { + return err + } + return unix.IoctlIfreq(fd, unix.SIOCSIFNETMASK, req) +} + +func ioctlLinkUp(name string) error { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer unix.Close(fd) + + req, err := unix.NewIfreq(name) + if err != nil { + return err + } + if err := unix.IoctlIfreq(fd, unix.SIOCGIFFLAGS, req); err != nil { + return err + } + req.SetUint16(req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING) + return unix.IoctlIfreq(fd, unix.SIOCSIFFLAGS, req) +} + +// cstring returns b as a null-terminated byte slice. +func cstring(s string) []byte { + return append([]byte(s), 0) +} diff --git a/peer/wginterface/manage.go b/peer/wginterface/manage.go new file mode 100644 index 0000000..4f8f9fe --- /dev/null +++ b/peer/wginterface/manage.go @@ -0,0 +1,184 @@ +package wginterface + +import ( + "fmt" + "net" + "net/netip" + "os" + "time" + + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + // RekeyTimeout is the WireGuard session lifetime before a new handshake + // is initiated. Sessions older than this but younger than SessionTimeout + // remain valid. + RekeyTimeout = 120 * time.Second + + // SessionTimeout is the WireGuard session lifetime after which sessions + // are rejected. A peer with LastHandshakeTime older than this is + // effectively disconnected. + SessionTimeout = 180 * time.Second +) + +const ProbeKeepalive = 8 * time.Second + +var zeroKeepalive = time.Duration(0) + +// Device wraps a wgctrl client bound to a named WireGuard interface. +type Device struct { + client *wgctrl.Client + name string +} + +// Open attaches to an existing WireGuard interface. +func Open(name string) (*Device, error) { + client, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("wgctrl: %w", err) + } + return &Device{client: client, name: name}, nil +} + +// Close releases the underlying wgctrl client. +func (d *Device) Close() error { + return d.client.Close() +} + +// Name returns the interface name. +func (d *Device) Name() string { + return d.name +} + +// Configure sets the device's private key and UDP listen port. +func (d *Device) Configure(privKey wgtypes.Key, listenPort int) error { + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + PrivateKey: &privKey, + ListenPort: &listenPort, + }) +} + +// Peers returns the current state of all peers on the device. +func (d *Device) Peers() ([]wgtypes.Peer, error) { + dev, err := d.client.Device(d.name) + if err != nil { + return nil, fmt.Errorf("get device %q: %w", d.name, err) + } + return dev.Peers, nil +} + +// Peer returns the current state of a single peer by public key. +func (d *Device) Peer(pubKey wgtypes.Key) (wgtypes.Peer, error) { + peers, err := d.Peers() + if err != nil { + return wgtypes.Peer{}, err + } + for _, p := range peers { + if p.PublicKey == pubKey { + return p, nil + } + } + return wgtypes.Peer{}, fmt.Errorf("peer %v not found in %q", pubKey, d.name) +} + +// AddPeer registers a peer with no AllowedIPs and no endpoint. WireGuard will +// accept handshakes from this peer but route no traffic to it yet. +func (d *Device) AddPeer(pubKey wgtypes.Key) error { + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: pubKey, + ReplaceAllowedIPs: true, + }}, + }) +} + +// SetRelay configures the relay peer with AllowedIPs covering the entire VPN +// network prefix. This is the fallback route for all VPN traffic. +func (d *Device) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error { + masked := network.Masked() + a4 := masked.Addr().As4() + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: pubKey, + Endpoint: net.UDPAddrFromAddrPort(endpoint), + AllowedIPs: []net.IPNet{{ + IP: net.IP(a4[:]), + Mask: net.CIDRMask(masked.Bits(), 32), + }}, + ReplaceAllowedIPs: true, + }}, + }) +} + +// AddProbe adds a peer with no AllowedIPs and a 5s keepalive. WireGuard will +// attempt handshakes without routing any traffic through this peer yet. +func (d *Device) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error { + keepalive := ProbeKeepalive + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: pubKey, + Endpoint: net.UDPAddrFromAddrPort(endpoint), + AllowedIPs: []net.IPNet{}, + ReplaceAllowedIPs: true, + PersistentKeepaliveInterval: &keepalive, + }}, + }) +} + +// Promote upgrades a probe entry to a /32 AllowedIPs and removes the probe +// keepalive, causing WireGuard to prefer this peer's direct path over the +// relay's wider route. +func (d *Device) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error { + a4 := vpnIP.As4() + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: pubKey, + AllowedIPs: []net.IPNet{{ + IP: net.IP(a4[:]), + Mask: net.CIDRMask(32, 32), + }}, + ReplaceAllowedIPs: true, + PersistentKeepaliveInterval: &zeroKeepalive, + }}, + }) +} + +// AddDirect adds a peer with a known endpoint and /32 AllowedIPs in one step, +// for peers with a stable public endpoint reported by the hub. +func (d *Device) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error { + a4 := vpnIP.As4() + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: pubKey, + Endpoint: net.UDPAddrFromAddrPort(endpoint), + AllowedIPs: []net.IPNet{{ + IP: net.IP(a4[:]), + Mask: net.CIDRMask(32, 32), + }}, + ReplaceAllowedIPs: true, + PersistentKeepaliveInterval: &zeroKeepalive, + }}, + }) +} + +// RemovePeer removes a peer from the device. +func (d *Device) RemovePeer(pubKey wgtypes.Key) error { + return d.client.ConfigureDevice(d.name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: pubKey, + Remove: true, + }}, + }) +} + +// EnableForwarding enables IPv4 forwarding globally and on the interface, +// required for relay peers that forward traffic between VPN peers. +func (d *Device) EnableForwarding() error { + if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0644); err != nil { + return err + } + path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/forwarding", d.name) + return os.WriteFile(path, []byte("1\n"), 0644) +} diff --git a/peer/wginterface/manage_test.go b/peer/wginterface/manage_test.go new file mode 100644 index 0000000..b809e2d --- /dev/null +++ b/peer/wginterface/manage_test.go @@ -0,0 +1,303 @@ +//go:build integration + +package wginterface_test + +import ( + "fmt" + "log" + "net" + "net/netip" + "os" + "strings" + "testing" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "vppn/peer/wginterface" +) + +const ( + testBasePort = 59100 +) + +func TestMain(m *testing.M) { + if os.Getuid() != 0 { + fmt.Fprintln(os.Stderr, "wginterface integration tests require root; skipping") + os.Exit(0) + } + os.Exit(m.Run()) +} + +type testPeer struct { + Name string + VpnIP netip.Addr + Port int + PrivKey wgtypes.Key + PubKey wgtypes.Key + Dev *wginterface.Device +} + +func (p *testPeer) Endpoint() netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), uint16(p.Port)) +} + +func newTestPeer(t *testing.T, name string, vpnIP netip.Addr, port int) *testPeer { + t.Helper() + + privKey, err := wgtypes.GenerateKey() + if err != nil { + t.Fatalf("generate key: %v", err) + } + + a4 := vpnIP.As4() + if err := wginterface.Create(name, net.IP(a4[:]), 24); err != nil { + t.Fatalf("create %s: %v", name, err) + } + t.Cleanup(func() { + if err := wginterface.Delete(name); err != nil { + log.Printf("Failed to delete interface %s: %v", name, err) + } + }) + + dev, err := wginterface.Open(name) + if err != nil { + t.Fatalf("open %s: %v", name, err) + } + t.Cleanup(func() { dev.Close() }) + + if err := dev.Configure(privKey, port); err != nil { + t.Fatalf("configure %s: %v", name, err) + } + + return &testPeer{ + Name: name, + VpnIP: vpnIP, + Port: port, + PrivKey: privKey, + PubKey: privKey.PublicKey(), + Dev: dev, + } +} + +// waitHandshake polls until the named peer has completed a handshake or the timeout elapses. +func waitHandshake(t *testing.T, dev *wginterface.Device, pubKey wgtypes.Key, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + p, err := dev.Peer(pubKey) + if err != nil { + t.Fatalf("peer lookup: %v", err) + } + if !p.LastHandshakeTime.IsZero() { + return + } + time.Sleep(200 * time.Millisecond) + } + t.Fatalf("no handshake within %v", timeout) +} + +func TestDirectHandshake(t *testing.T) { + p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1) + + if err := p1.Dev.AddDirect(p2.PubKey, p2.Endpoint(), p2.VpnIP); err != nil { + t.Fatalf("p1 AddDirect: %v", err) + } + if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil { + t.Fatalf("p2 AddDirect: %v", err) + } + + waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second) + waitHandshake(t, p2.Dev, p1.PubKey, 30*time.Second) +} + +func TestProbeAndPromote(t *testing.T) { + p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1) + + // p2 needs a peer entry for p1 so it can respond to the handshake initiation. + if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil { + t.Fatalf("p2 AddDirect: %v", err) + } + + if err := p1.Dev.AddProbe(p2.PubKey, p2.Endpoint()); err != nil { + t.Fatalf("AddProbe: %v", err) + } + waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second) + + if err := p1.Dev.Promote(p2.PubKey, p2.VpnIP); err != nil { + t.Fatalf("Promote: %v", err) + } + + peer, err := p1.Dev.Peer(p2.PubKey) + if err != nil { + t.Fatalf("Peer: %v", err) + } + checkAllowedIP(t, peer, p2.VpnIP, 32) +} + +func TestRelayHandshakes(t *testing.T) { + vpnNetwork := netip.MustParsePrefix("192.168.99.0/24") + + relay := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + peer1 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1) + peer2 := newTestPeer(t, "wgtest2", netip.MustParseAddr("192.168.99.3"), testBasePort+2) + + if err := relay.Dev.AddDirect(peer1.PubKey, peer1.Endpoint(), peer1.VpnIP); err != nil { + t.Fatalf("relay AddDirect peer1: %v", err) + } + if err := relay.Dev.AddDirect(peer2.PubKey, peer2.Endpoint(), peer2.VpnIP); err != nil { + t.Fatalf("relay AddDirect peer2: %v", err) + } + if err := peer1.Dev.SetRelay(relay.PubKey, relay.Endpoint(), vpnNetwork); err != nil { + t.Fatalf("peer1 SetRelay: %v", err) + } + if err := peer2.Dev.SetRelay(relay.PubKey, relay.Endpoint(), vpnNetwork); err != nil { + t.Fatalf("peer2 SetRelay: %v", err) + } + + waitHandshake(t, relay.Dev, peer1.PubKey, 30*time.Second) + waitHandshake(t, relay.Dev, peer2.PubKey, 30*time.Second) + waitHandshake(t, peer1.Dev, relay.PubKey, 30*time.Second) + waitHandshake(t, peer2.Dev, relay.PubKey, 30*time.Second) + + // relay has /32 entries for each peer + p, err := relay.Dev.Peer(peer1.PubKey) + if err != nil { + t.Fatalf("relay peer1: %v", err) + } + checkAllowedIP(t, p, peer1.VpnIP, 32) + + p, err = relay.Dev.Peer(peer2.PubKey) + if err != nil { + t.Fatalf("relay peer2: %v", err) + } + checkAllowedIP(t, p, peer2.VpnIP, 32) + + // peers have /24 fallback route via relay + p, err = peer1.Dev.Peer(relay.PubKey) + if err != nil { + t.Fatalf("peer1 relay: %v", err) + } + checkAllowedIP(t, p, vpnNetwork.Masked().Addr(), 24) + + p, err = peer2.Dev.Peer(relay.PubKey) + if err != nil { + t.Fatalf("peer2 relay: %v", err) + } + checkAllowedIP(t, p, vpnNetwork.Masked().Addr(), 24) +} + +func TestRemovePeer(t *testing.T) { + p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1) + + if err := p1.Dev.AddDirect(p2.PubKey, p2.Endpoint(), p2.VpnIP); err != nil { + t.Fatalf("AddDirect: %v", err) + } + if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil { + t.Fatalf("AddDirect: %v", err) + } + waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second) + + if err := p1.Dev.RemovePeer(p2.PubKey); err != nil { + t.Fatalf("RemovePeer: %v", err) + } + if _, err := p1.Dev.Peer(p2.PubKey); err == nil { + t.Fatal("expected error after RemovePeer, got nil") + } +} + +func TestEnableForwarding(t *testing.T) { + p := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + + if err := p.Dev.EnableForwarding(); err != nil { + t.Fatalf("EnableForwarding: %v", err) + } + + data, err := os.ReadFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/forwarding", p.Name)) + if err != nil { + t.Fatalf("read forwarding: %v", err) + } + if strings.TrimSpace(string(data)) != "1" { + t.Fatalf("expected forwarding=1, got %q", string(data)) + } +} + +func TestPromoteKeepalive(t *testing.T) { + p1 := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + p2 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1) + + if err := p2.Dev.AddDirect(p1.PubKey, p1.Endpoint(), p1.VpnIP); err != nil { + t.Fatalf("p2 AddDirect: %v", err) + } + if err := p1.Dev.AddProbe(p2.PubKey, p2.Endpoint()); err != nil { + t.Fatalf("AddProbe: %v", err) + } + waitHandshake(t, p1.Dev, p2.PubKey, 30*time.Second) + + if err := p1.Dev.Promote(p2.PubKey, p2.VpnIP); err != nil { + t.Fatalf("Promote: %v", err) + } + + peer, err := p1.Dev.Peer(p2.PubKey) + if err != nil { + t.Fatalf("Peer: %v", err) + } + if peer.PersistentKeepaliveInterval != 0 { + t.Fatalf("expected keepalive disabled after promote, got %v", peer.PersistentKeepaliveInterval) + } +} + +func TestPeersCount(t *testing.T) { + relay := newTestPeer(t, "wgtest0", netip.MustParseAddr("192.168.99.1"), testBasePort) + peer1 := newTestPeer(t, "wgtest1", netip.MustParseAddr("192.168.99.2"), testBasePort+1) + peer2 := newTestPeer(t, "wgtest2", netip.MustParseAddr("192.168.99.3"), testBasePort+2) + + if err := relay.Dev.AddDirect(peer1.PubKey, peer1.Endpoint(), peer1.VpnIP); err != nil { + t.Fatalf("AddDirect peer1: %v", err) + } + if err := relay.Dev.AddDirect(peer2.PubKey, peer2.Endpoint(), peer2.VpnIP); err != nil { + t.Fatalf("AddDirect peer2: %v", err) + } + + peers, err := relay.Dev.Peers() + if err != nil { + t.Fatalf("Peers: %v", err) + } + if len(peers) != 2 { + t.Fatalf("expected 2 peers, got %d", len(peers)) + } + + if err := relay.Dev.RemovePeer(peer1.PubKey); err != nil { + t.Fatalf("RemovePeer: %v", err) + } + + peers, err = relay.Dev.Peers() + if err != nil { + t.Fatalf("Peers after remove: %v", err) + } + if len(peers) != 1 { + t.Fatalf("expected 1 peer after remove, got %d", len(peers)) + } + if peers[0].PublicKey != peer2.PubKey { + t.Fatal("wrong peer remained after remove") + } +} + +// checkAllowedIP asserts that a peer has exactly one AllowedIP matching addr/bits. +func checkAllowedIP(t *testing.T, p wgtypes.Peer, addr netip.Addr, bits int) { + t.Helper() + if len(p.AllowedIPs) != 1 { + t.Fatalf("expected 1 AllowedIP, got %d", len(p.AllowedIPs)) + } + ones, _ := p.AllowedIPs[0].Mask.Size() + if ones != bits { + t.Fatalf("expected /%d, got /%d", bits, ones) + } + got := netip.AddrFrom4([4]byte(p.AllowedIPs[0].IP.To4())) + if got != addr { + t.Fatalf("expected AllowedIP %v, got %v", addr, got) + } +}