31 Commits

Author SHA1 Message Date
jdl
926e111c3f Cleanup 2026-06-12 18:22:25 +02:00
9a3cb2d1c2 Refactor - now wireguard based. (#7) 2026-06-12 15:11:01 +00:00
jdl
5ae075647d Bug fixes 2025-09-25 10:03:37 +02:00
jdl
29bbb442c8 Cleanup 2025-09-25 09:08:10 +02:00
3d93c0206c client-interface-cleanup (#6)
Refactoring and code cleanup. Improved client command interface.
2025-09-17 08:00:12 +00:00
jdl
75c7c2d3d9 README cleanup. 2025-09-04 11:29:26 +02:00
jdl
c61319ed16 Cleanup 2025-09-03 20:41:35 +02:00
jdl
0a7328ed5f Refactor 2025-09-01 18:15:42 +02:00
jdl
5f0b00ff46 Refactor 2025-09-01 18:09:24 +02:00
jdl
e91cbfe957 Merge branch 'refactor-2025' 2025-09-01 18:05:42 +02:00
jdl
7c476fc332 WIP: Cleanup 2025-09-01 18:03:09 +02:00
jdl
69823d1d19 WIP 2025-09-01 18:00:41 +02:00
jdl
b7cb4e20f0 WIP 2025-08-26 19:50:59 +02:00
jdl
6382c13d1a WIP 2025-08-26 19:12:07 +02:00
jdl
1ca55158c2 wip 2025-08-26 17:01:38 +02:00
jdl
302d27692b WIP: Cleanup 2025-08-26 16:57:46 +02:00
jdl
31c48fbafd Cleanup 2025-08-26 16:20:47 +02:00
jdl
3c4534f620 WIP: Fixed rate limiting 2025-08-26 16:17:46 +02:00
jdl
169231d848 WIP: Apparently working? 2025-08-26 16:11:21 +02:00
jdl
f4589a1031 Don't crash 2025-08-26 15:45:06 +02:00
jdl
ab246b2a90 WIP 2025-08-26 15:33:27 +02:00
b9e773ec83 Update - modify hub to support multiple networks. (#4)
Co-authored-by: jdl <jdl@desktop>
Reviewed-on: #4
2025-04-12 11:43:18 +00:00
jdl
d558ebbd14 WIP 2025-04-06 07:51:47 +02:00
jdl
03b1bbcbcf Cleanup dependencies 2025-03-10 16:11:40 +01:00
jdl
8160eb5ad7 Attempt relayed connection if direct fails. 2025-03-08 10:41:50 +01:00
1d3cc1f959 refactor-for-testability (#3)
Co-authored-by: jdl <jdl@desktop>
Co-authored-by: jdl <jdl@crumpington.com>
Reviewed-on: #3
2025-03-01 20:02:27 +00:00
jdl
a0b5058544 Cleanup 2025-01-15 08:52:07 +01:00
jdl
232681fac6 Cleanup 2025-01-15 08:51:53 +01:00
jdl
6e7a2456b2 Cleanup 2025-01-15 08:45:01 +01:00
jdl
970490b17b Breaking change: new packet formats. 2025-01-13 16:43:27 +01:00
jdl
2bdd76e689 Better address discovery. 2025-01-12 20:31:36 +01:00
109 changed files with 4463 additions and 3423 deletions

View File

@@ -1,9 +1,5 @@
# vppn: Virtual Potentially Private Network # vppn: Virtual Potentially Private Network
## TODO
* Add `-force-init` argument to `node` main?
## Hub Server Configuration ## Hub Server Configuration
``` ```
@@ -13,7 +9,6 @@ adduser user
# Enable ssh. # Enable ssh.
cp -r ~/.ssh /home/user/ cp -r ~/.ssh /home/user/
chown -R user:user /home/user/.ssh chown -R user:user /home/user/.ssh
``` ```
Upload `hub` executable: Upload `hub` executable:
@@ -33,7 +28,6 @@ WorkingDirectory=/home/user/
ExecStart=/home/user/hub -listen <addr>:https -root-dir=/home/user ExecStart=/home/user/hub -listen <addr>:https -root-dir=/home/user
Restart=always Restart=always
RestartSec=8 RestartSec=8
TimeoutStopSec=24
[Install] [Install]
WantedBy=default.target WantedBy=default.target
@@ -43,6 +37,7 @@ Add and start the hub server:
``` ```
systemctl daemon-reload systemctl daemon-reload
systemctl enable hub
systemctl start hub systemctl start hub
``` ```
@@ -58,20 +53,29 @@ Sign-in and configure.
Install the binary somewhere, for example `~/bin/vppn`. Install the binary somewhere, for example `~/bin/vppn`.
Create systemd file in `/etc/systemd/system/vppn.service`. Add the API key for your network name in `~/.vppn/<netname>/apikey`.
Create systemd file in `/etc/systemd/system/vppn.service`.
``` ```
[Service] [Service]
AmbientCapabilities=CAP_NET_BIND_SERVICE CAP_NET_ADMIN AmbientCapabilities=AP_NET_ADMIN CAP_DAC_OVERRIDE CAP_CHOWN
Type=simple Type=simple
User=user User=user
WorkingDirectory=/home/user/ WorkingDirectory=/home/user/
ExecStart=/home/user/vppn -name vppn -hub-address https://my.hub -api-key 1234567890 ExecStart=/home/user/bin/vppn -name my_net_name -hub https://my.hub
Restart=always Restart=always
RestartSec=8 RestartSec=8
TimeoutStopSec=24 TimeoutStopSec=24
[Install] [Install]
WantedBy=default.target WantedBy=multi-user.target
```
Add and start the service:
```
systemctl daemon-reload
systemctl enable vppn
systemctl start vppn
``` ```

View File

@@ -1,11 +1,72 @@
package main package main
import ( import (
"flag"
"log" "log"
"vppn/node" "os"
"path/filepath"
"strings"
"vppn/peer"
"git.crumpington.com/lib/go/flock"
) )
func main() { func main() {
log.SetFlags(0) log.SetFlags(0)
node.Main()
name := flag.String("name", "", "network name (required)")
hub := flag.String("hub", "", "hub base URL (required)")
flag.Parse()
if *name == "" || *hub == "" {
flag.Usage()
os.Exit(1)
}
apiKey, err := loadAPIKey(*name)
if err != nil {
log.Fatalf("api key: %v", err)
}
// Directory existence is guaranteed by the apikey file read above.
lockFile, err := flock.TryLock(vppnPath(*name, "lock"))
if err != nil {
log.Fatalf("lock: %v", err)
}
if lockFile == nil {
log.Fatalf("already running for network %q", *name)
}
defer flock.Unlock(lockFile)
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)
if err != nil {
log.Fatalf("init: %v", err)
}
ifaceName := strings.TrimSuffix(state.LocalDomain, ".local")
app, err := peer.New(state, *hub, apiKey, ifaceName, state.LocalDomain, vppnPath(*name, "network.json"))
if err != nil {
log.Fatalf("start: %v", err)
}
if err := app.Run(); err != nil {
log.Fatalf("run: %v", err)
}
}
func loadAPIKey(name string) (string, error) {
data, err := os.ReadFile(vppnPath(name, "apikey"))
if err != nil {
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func vppnPath(name, file string) string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".vppn", name, file)
}
return filepath.Join(home, ".vppn", name, file)
} }

23
go.mod
View File

@@ -1,16 +1,23 @@
module vppn module vppn
go 1.23.2 go 1.25.1
require ( require (
git.crumpington.com/lib/go v0.8.1 git.crumpington.com/lib/go v0.9.1
git.crumpington.com/lib/webutil v0.0.7 golang.org/x/crypto v0.42.0
golang.org/x/crypto v0.29.0 golang.org/x/sys v0.36.0
golang.org/x/sys v0.27.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
) )
require ( require (
github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/google/go-cmp v0.6.0 // indirect
golang.org/x/net v0.31.0 // indirect github.com/josharian/native v1.1.0 // indirect
golang.org/x/text v0.20.0 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
golang.org/x/net v0.44.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect
) )

44
go.sum
View File

@@ -1,14 +1,30 @@
git.crumpington.com/lib/go v0.8.1 h1:rWjddllSxQ4yReraqDaGZAod4NpRD9LtGx1yV71ytcU= git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
git.crumpington.com/lib/go v0.8.1/go.mod h1:XjQaf2NFlje9BJ1EevZL8NNioPrAe7WwHpKUhcDw2Lk= git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
git.crumpington.com/lib/webutil v0.0.7 h1:1RG9CpuXYalT0NPj8fvxjOLV566LqL37APvAdASFzgA= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
git.crumpington.com/lib/webutil v0.0.7/go.mod h1:efIEiuK1uqFIhI/dlsWUHMsC5bXcEbJEjmdluRoFPPQ= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=

View File

@@ -19,8 +19,10 @@ import (
var migrations embed.FS var migrations embed.FS
type API struct { type API struct {
db *sql.DB db *sql.DB
lock sync.Mutex lock sync.Mutex
sessionsMu sync.Mutex
sessions map[string]*Session
} }
func New(dbPath string) (*API, error) { func New(dbPath string) (*API, error) {
@@ -34,10 +36,17 @@ func New(dbPath string) (*API, error) {
} }
a := &API{ a := &API{
db: sqlDB, db: sqlDB,
sessions: make(map[string]*Session),
} }
return a, a.ensurePassword() if err := a.ensurePassword(); err != nil {
return nil, err
}
go a.sweepSessions()
return a, nil
} }
func (a *API) ensurePassword() error { func (a *API) ensurePassword() error {
@@ -54,139 +63,153 @@ func (a *API) ensurePassword() error {
log.Printf("Setting password: %s", pwd) log.Printf("Setting password: %s", pwd)
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost) hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
if err != nil { if err != nil {
return err return err
} }
conf := &Config{ conf := &Config{ConfigID: 1, Password: hashed}
ConfigID: 1,
VPNNetwork: []byte{10, 1, 1, 0},
Password: hashed,
}
return db.Config_Insert(a.db, conf) return db.Config_Insert(a.db, conf)
} }
func (a *API) Config_Get() *Config { func (a *API) Config_Get() (*Config, error) {
conf, err := db.Config_Get(a.db, 1) return db.Config_Get(a.db, 1)
if err != nil {
panic(err)
}
return conf
} }
func (a *API) Config_Update(conf *Config) error { func (a *API) Config_Update(conf *Config) error {
return db.Config_Update(a.db, conf) return db.Config_Update(a.db, conf)
} }
func (a *API) Config_UpdatePassword(pwdHash []byte) error {
return db.Config_UpdatePassword(a.db, pwdHash)
}
func (a *API) Session_Delete(sessionID string) error { func (a *API) Session_Delete(sessionID string) error {
return db.Session_Delete(a.db, sessionID) a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
delete(a.sessions, sessionID)
return nil
} }
func (a *API) Session_Get(sessionID string) (*Session, error) { const (
if sessionID == "" { sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use
return a.session_CreatePub() sessionSweepEvery = time.Hour // cadence of expired-session eviction
)
// Session_Get returns a snapshot copy of the signed-in session for sessionID,
// or the zero Session if the cookie is missing/unknown/expired. It never
// creates a session, so anonymous requests cost no memory — a session is minted
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
// callers from racing on the shared struct.
func (a *API) Session_Get(sessionID string) (Session, error) {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
s, ok := a.sessions[sessionID]
if sessionID == "" || !ok {
return Session{}, nil
} }
session, err := db.Session_Get(a.db, sessionID) if timeSince(s.LastSeenAt) > sessionTTLSecs {
delete(a.sessions, sessionID)
return Session{}, nil
}
s.LastSeenAt = time.Now().Unix()
return *s, nil
}
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
// returning it so the caller can set the cookie. A new ID per sign-in rotates
// the session at the privilege boundary (session-fixation resistance).
func (a *API) Session_SignIn(pwd string) (Session, error) {
conf, err := a.Config_Get()
if err != nil { if err != nil {
return a.session_CreatePub() return Session{}, err
}
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
return Session{}, ErrNotAuthorized
} }
if timeSince(session.LastSeenAt) > 86400*21 { a.sessionsMu.Lock()
return a.session_CreatePub() defer a.sessionsMu.Unlock()
}
if timeSince(session.LastSeenAt) > 86400*7 {
session.LastSeenAt = time.Now().Unix()
if err := db.Session_UpdateLastSeenAt(a.db, session.SessionID); err != nil {
log.Printf("Failed to update session: %v", err)
}
}
return session, nil
}
func (a *API) session_CreatePub() (*Session, error) {
s := &Session{ s := &Session{
SessionID: idgen.NewToken(), SessionID: idgen.NewToken(),
CSRF: idgen.NewToken(), SignedIn: true,
SignedIn: false,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
LastSeenAt: time.Now().Unix(), LastSeenAt: time.Now().Unix(),
} }
err := db.Session_Insert(a.db, s) a.sessions[s.SessionID] = s
return s, err return *s, nil
} }
func (a *API) Session_DeleteBefore(timestamp int64) error { // sweepSessions periodically evicts sessions past their TTL. Without it, a
return db.Session_DeleteBefore(a.db, timestamp) // signed-in session whose ID is never presented again would linger forever
} // (Session_Get only evicts on a lookup of that same ID).
func (a *API) sweepSessions() {
func (a *API) Session_SignIn(s *Session, pwd string) error { for range time.Tick(sessionSweepEvery) {
conf := a.Config_Get() a.sessionsMu.Lock()
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { for id, s := range a.sessions {
return ErrNotAuthorized if timeSince(s.LastSeenAt) > sessionTTLSecs {
delete(a.sessions, id)
}
}
a.sessionsMu.Unlock()
} }
}
return db.Session_SetSignedIn(a.db, s.SessionID) func (a *API) Network_Create(n *Network) error {
n.NetworkID = idgen.NextID(0)
return db.Network_Insert(a.db, n)
}
func (a *API) Network_Delete(n *Network) error {
return db.Network_Delete(a.db, n.NetworkID)
}
func (a *API) Network_Get(id int64) (*Network, error) {
return db.Network_Get(a.db, id)
}
func (a *API) Network_List() ([]*Network, error) {
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
return db.Network_List(a.db, query)
} }
func (a *API) Peer_CreateNew(p *Peer) error { func (a *API) Peer_CreateNew(p *Peer) error {
p.Version = idgen.NextID(0) p.WGPubKey = []byte{}
p.PubKey = []byte{} p.SignPubKey = []byte{}
p.PubSignKey = []byte{}
p.APIKey = idgen.NewToken() p.APIKey = idgen.NewToken()
return db.Peer_Insert(a.db, p) return db.Peer_Insert(a.db, p)
} }
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) (*m.PeerConfig, error) { func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
a.lock.Lock() a.lock.Lock()
defer a.lock.Unlock() defer a.lock.Unlock()
peer.Version = idgen.NextID(0) // Re-read from DB inside the lock — the caller's copy was fetched before
peer.PubKey = args.EncPubKey // we held the lock, so it may be stale under concurrent requests.
peer.PubSignKey = args.PubSignKey current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
if err != nil {
if err := db.Peer_UpdateFull(a.db, peer); err != nil { return err
return nil, err }
if len(current.WGPubKey) != 0 {
return errors.New("peer already initialized")
} }
conf := a.Config_Get() peer.WGPubKey = args.WGPubKey
peer.SignPubKey = args.SignPubKey
return &m.PeerConfig{ return db.Peer_UpdateFull(a.db, peer)
PeerIP: peer.PeerIP,
Network: conf.VPNNetwork,
PublicIP: peer.PublicIP,
Port: peer.Port,
Relay: peer.Relay,
}, nil
} }
func (a *API) Peer_Update(p *Peer) error { func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
a.lock.Lock() return db.Peer_Delete(a.db, networkID, peerIP)
defer a.lock.Unlock()
p.Version = idgen.NextID(0)
return db.Peer_Update(a.db, p)
} }
func (a *API) Peer_Delete(ip byte) error { func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
return db.Peer_Delete(a.db, ip) return db.Peer_ListAll(a.db, networkID)
} }
func (a *API) Peer_List() ([]*Peer, error) { func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
return db.Peer_ListAll(a.db) return db.Peer_Get(a.db, networkID, ip)
}
func (a *API) Peer_Get(ip byte) (*Peer, error) {
return db.Peer_Get(a.db, ip)
} }
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) { func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {

View File

View File

@@ -16,13 +16,11 @@ type TX interface {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type Config struct { type Config struct {
ConfigID int64 ConfigID int64
HubAddress string Password []byte
VPNNetwork []byte
Password []byte
} }
const Config_SelectQuery = "SELECT ConfigID,HubAddress,VPNNetwork,Password FROM config" const Config_SelectQuery = "SELECT ConfigID,Password FROM config"
func Config_Insert( func Config_Insert(
tx TX, tx TX,
@@ -33,7 +31,7 @@ func Config_Insert(
return err return err
} }
_, err = tx.Exec("INSERT INTO config(ConfigID,HubAddress,VPNNetwork,Password) VALUES(?,?,?,?)", row.ConfigID, row.HubAddress, row.VPNNetwork, row.Password) _, err = tx.Exec("INSERT INTO config(ConfigID,Password) VALUES(?,?)", row.ConfigID, row.Password)
return err return err
} }
@@ -46,7 +44,7 @@ func Config_Update(
return err return err
} }
result, err := tx.Exec("UPDATE config SET HubAddress=?,VPNNetwork=? WHERE ConfigID=?", row.HubAddress, row.VPNNetwork, row.ConfigID) result, err := tx.Exec("UPDATE config SET Password=? WHERE ConfigID=?", row.Password, row.ConfigID)
if err != nil { if err != nil {
return err return err
} }
@@ -74,7 +72,7 @@ func Config_UpdateFull(
return err return err
} }
result, err := tx.Exec("UPDATE config SET HubAddress=?,VPNNetwork=?,Password=? WHERE ConfigID=?", row.HubAddress, row.VPNNetwork, row.Password, row.ConfigID) result, err := tx.Exec("UPDATE config SET Password=? WHERE ConfigID=?", row.Password, row.ConfigID)
if err != nil { if err != nil {
return err return err
} }
@@ -124,8 +122,10 @@ func Config_Get(
err error, err error,
) { ) {
row = &Config{} row = &Config{}
r := tx.QueryRow("SELECT ConfigID,HubAddress,VPNNetwork,Password FROM config WHERE ConfigID=?", ConfigID) r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID)
err = r.Scan(&row.ConfigID, &row.HubAddress, &row.VPNNetwork, &row.Password) if err = r.Scan(&row.ConfigID, &row.Password); err != nil {
row = nil
}
return return
} }
@@ -139,7 +139,9 @@ func Config_GetWhere(
) { ) {
row = &Config{} row = &Config{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
err = r.Scan(&row.ConfigID, &row.HubAddress, &row.VPNNetwork, &row.Password) if err = r.Scan(&row.ConfigID, &row.Password); err != nil {
row = nil
}
return return
} }
@@ -159,7 +161,7 @@ func Config_Iterate(
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
row := &Config{} row := &Config{}
err := rows.Scan(&row.ConfigID, &row.HubAddress, &row.VPNNetwork, &row.Password) err := rows.Scan(&row.ConfigID, &row.Password)
if !yield(row, err) { if !yield(row, err) {
return return
} }
@@ -185,164 +187,40 @@ func Config_List(
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Table: sessions // Table: networks
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type Session struct { type Network struct {
SessionID string NetworkID int64
CSRF string LocalDomain string
SignedIn bool Network []byte
CreatedAt int64
LastSeenAt int64
} }
const Session_SelectQuery = "SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions" const Network_SelectQuery = "SELECT NetworkID,LocalDomain,Network FROM networks"
func Session_Insert( func Network_Insert(
tx TX, tx TX,
row *Session, row *Network,
) (err error) { ) (err error) {
Session_Sanitize(row) Network_Sanitize(row)
if err = Session_Validate(row); err != nil { if err = Network_Validate(row); err != nil {
return err return err
} }
_, err = tx.Exec("INSERT INTO sessions(SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt) VALUES(?,?,?,?,?)", row.SessionID, row.CSRF, row.SignedIn, row.CreatedAt, row.LastSeenAt) _, err = tx.Exec("INSERT INTO networks(NetworkID,LocalDomain,Network) VALUES(?,?,?)", row.NetworkID, row.LocalDomain, row.Network)
return err return err
} }
func Session_Delete( func Network_UpdateFull(
tx TX, tx TX,
SessionID string, row *Network,
) (err error) { ) (err error) {
result, err := tx.Exec("DELETE FROM sessions WHERE SessionID=?", SessionID) Network_Sanitize(row)
if err != nil { if err = Network_Validate(row); err != nil {
return err return err
} }
n, err := result.RowsAffected() result, err := tx.Exec("UPDATE networks SET LocalDomain=?,Network=? WHERE NetworkID=?", row.LocalDomain, row.Network, row.NetworkID)
if err != nil {
panic(err)
}
switch n {
case 0:
return sql.ErrNoRows
case 1:
return nil
default:
panic("multiple rows deleted")
}
}
func Session_Get(
tx TX,
SessionID string,
) (
row *Session,
err error,
) {
row = &Session{}
r := tx.QueryRow("SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions WHERE SessionID=?", SessionID)
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
return
}
func Session_GetWhere(
tx TX,
query string,
args ...any,
) (
row *Session,
err error,
) {
row = &Session{}
r := tx.QueryRow(query, args...)
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
return
}
func Session_Iterate(
tx TX,
query string,
args ...any,
) iter.Seq2[*Session, error] {
rows, err := tx.Query(query, args...)
if err != nil {
return func(yield func(*Session, error) bool) {
yield(nil, err)
}
}
return func(yield func(*Session, error) bool) {
defer rows.Close()
for rows.Next() {
row := &Session{}
err := rows.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
if !yield(row, err) {
return
}
}
}
}
func Session_List(
tx TX,
query string,
args ...any,
) (
l []*Session,
err error,
) {
for row, err := range Session_Iterate(tx, query, args...) {
if err != nil {
return nil, err
}
l = append(l, row)
}
return l, nil
}
// ----------------------------------------------------------------------------
// Table: peers
// ----------------------------------------------------------------------------
type Peer struct {
PeerIP byte
Version int64
APIKey string
Name string
PublicIP []byte
Port uint16
Relay bool
PubKey []byte
PubSignKey []byte
}
const Peer_SelectQuery = "SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers"
func Peer_Insert(
tx TX,
row *Peer,
) (err error) {
Peer_Sanitize(row)
if err = Peer_Validate(row); err != nil {
return err
}
_, err = tx.Exec("INSERT INTO peers(PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey) VALUES(?,?,?,?,?,?,?,?,?)", row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey)
return err
}
func Peer_Update(
tx TX,
row *Peer,
) (err error) {
Peer_Sanitize(row)
if err = Peer_Validate(row); err != nil {
return err
}
result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.PeerIP)
if err != nil { if err != nil {
return err return err
} }
@@ -361,6 +239,133 @@ func Peer_Update(
} }
} }
func Network_Delete(
tx TX,
NetworkID int64,
) (err error) {
result, err := tx.Exec("DELETE FROM networks WHERE NetworkID=?", NetworkID)
if err != nil {
return err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
switch n {
case 0:
return sql.ErrNoRows
case 1:
return nil
default:
panic("multiple rows deleted")
}
}
func Network_Get(
tx TX,
NetworkID int64,
) (
row *Network,
err error,
) {
row = &Network{}
r := tx.QueryRow("SELECT NetworkID,LocalDomain,Network FROM networks WHERE NetworkID=?", NetworkID)
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil {
row = nil
}
return
}
func Network_GetWhere(
tx TX,
query string,
args ...any,
) (
row *Network,
err error,
) {
row = &Network{}
r := tx.QueryRow(query, args...)
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil {
row = nil
}
return
}
func Network_Iterate(
tx TX,
query string,
args ...any,
) iter.Seq2[*Network, error] {
rows, err := tx.Query(query, args...)
if err != nil {
return func(yield func(*Network, error) bool) {
yield(nil, err)
}
}
return func(yield func(*Network, error) bool) {
defer rows.Close()
for rows.Next() {
row := &Network{}
err := rows.Scan(&row.NetworkID, &row.LocalDomain, &row.Network)
if !yield(row, err) {
return
}
}
}
}
func Network_List(
tx TX,
query string,
args ...any,
) (
l []*Network,
err error,
) {
for row, err := range Network_Iterate(tx, query, args...) {
if err != nil {
return nil, err
}
l = append(l, row)
}
return l, nil
}
// ----------------------------------------------------------------------------
// Table: peers
// ----------------------------------------------------------------------------
type Peer struct {
NetworkID int64
PeerIP byte
APIKey string
Name string
Addr4 []byte
Addr6 []byte
Port uint16
Relay bool
WGPubKey []byte
SignPubKey []byte
}
const Peer_SelectQuery = "SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers"
func Peer_Insert(
tx TX,
row *Peer,
) (err error) {
Peer_Sanitize(row)
if err = Peer_Validate(row); err != nil {
return err
}
_, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey)
return err
}
func Peer_UpdateFull( func Peer_UpdateFull(
tx TX, tx TX,
row *Peer, row *Peer,
@@ -370,7 +375,7 @@ func Peer_UpdateFull(
return err return err
} }
result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=?,PubSignKey=? WHERE PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey, row.PeerIP) result, err := tx.Exec("UPDATE peers SET APIKey=?,Name=?,Addr4=?,Addr6=?,Port=?,Relay=?,WGPubKey=?,SignPubKey=? WHERE NetworkID=? AND PeerIP=?", row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey, row.NetworkID, row.PeerIP)
if err != nil { if err != nil {
return err return err
} }
@@ -391,9 +396,10 @@ func Peer_UpdateFull(
func Peer_Delete( func Peer_Delete(
tx TX, tx TX,
NetworkID int64,
PeerIP byte, PeerIP byte,
) (err error) { ) (err error) {
result, err := tx.Exec("DELETE FROM peers WHERE PeerIP=?", PeerIP) result, err := tx.Exec("DELETE FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
if err != nil { if err != nil {
return err return err
} }
@@ -414,14 +420,17 @@ func Peer_Delete(
func Peer_Get( func Peer_Get(
tx TX, tx TX,
NetworkID int64,
PeerIP byte, PeerIP byte,
) ( ) (
row *Peer, row *Peer,
err error, err error,
) { ) {
row = &Peer{} row = &Peer{}
r := tx.QueryRow("SELECT PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers WHERE PeerIP=?", PeerIP) r := tx.QueryRow("SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil {
row = nil
}
return return
} }
@@ -435,7 +444,9 @@ func Peer_GetWhere(
) { ) {
row = &Peer{} row = &Peer{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
err = r.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil {
row = nil
}
return return
} }
@@ -455,7 +466,7 @@ func Peer_Iterate(
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
row := &Peer{} row := &Peer{}
err := rows.Scan(&row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey)
if !yield(row, err) { if !yield(row, err) {
return return
} }

View File

@@ -3,67 +3,120 @@ package db
import ( import (
"errors" "errors"
"net/netip" "net/netip"
"net/url"
"strings" "strings"
) )
var ( var (
ErrInvalidIP = errors.New("invalid IP") ErrInvalidIP = errors.New("invalid IP")
ErrInvalidPort = errors.New("invalid port") ErrInvalidPeerIP = errors.New("invalid peer IP")
ErrNonPrivateIP = errors.New("non-private IP")
ErrInvalidPort = errors.New("invalid port")
ErrInvalidNetName = errors.New("invalid network name")
ErrNetNameNotLocal = errors.New("network name must end with .local")
ErrInvalidPeerName = errors.New("invalid peer name")
) )
func Config_Sanitize(c *Config) { func Config_Sanitize(c *Config) {
if u, err := url.Parse(c.HubAddress); err == nil {
c.HubAddress = u.String()
}
if addr, ok := netip.AddrFromSlice(c.VPNNetwork); ok {
c.VPNNetwork = addr.AsSlice()
}
} }
func Config_Validate(c *Config) error { func Config_Validate(c *Config) error {
if _, err := url.Parse(c.HubAddress); err != nil { return nil
return err }
func Network_Sanitize(n *Network) {
n.LocalDomain = strings.TrimSpace(n.LocalDomain)
if addr, ok := netip.AddrFromSlice(n.Network); ok {
n.Network = addr.AsSlice()
}
}
func Network_Validate(c *Network) error {
// 15 bytes is linux limit for network interface names. With ending .local,
// max length is 21.
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
return ErrInvalidNetName
} }
addr, ok := netip.AddrFromSlice(c.VPNNetwork) if !strings.HasSuffix(c.LocalDomain, ".local") {
return ErrNetNameNotLocal
}
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
if c >= 'a' && c <= 'z' {
continue
}
if c >= '0' && c <= '9' {
continue
}
return ErrInvalidNetName
}
addr, ok := netip.AddrFromSlice(c.Network)
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 { if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
return ErrInvalidIP return ErrInvalidIP
} }
return nil if !addr.IsPrivate() {
} return ErrNonPrivateIP
}
func Session_Sanitize(s *Session) {
}
func Session_Validate(s *Session) error {
return nil return nil
} }
func Peer_Sanitize(p *Peer) { func Peer_Sanitize(p *Peer) {
p.Name = strings.TrimSpace(p.Name) p.Name = strings.TrimSpace(p.Name)
if len(p.PublicIP) != 0 { if len(p.Addr4) != 0 {
addr, ok := netip.AddrFromSlice(p.PublicIP) if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
if ok && addr.Is4() { // Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
p.PublicIP = addr.AsSlice() p.Addr4 = addr.Unmap().AsSlice()
}
}
if len(p.Addr6) != 0 {
if addr, ok := netip.AddrFromSlice(p.Addr6); ok {
p.Addr6 = addr.AsSlice()
} }
} }
if p.Port == 0 { if p.Port == 0 {
p.Port = 456 p.Port = 51820
} }
} }
func Peer_Validate(p *Peer) error { func Peer_Validate(p *Peer) error {
if len(p.PublicIP) > 0 { if p.PeerIP < 1 || p.PeerIP > 254 {
_, ok := netip.AddrFromSlice(p.PublicIP) return ErrInvalidPeerIP
if !ok { }
if len(p.Addr4) > 0 {
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
return ErrInvalidIP
}
}
if len(p.Addr6) > 0 {
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
return ErrInvalidIP return ErrInvalidIP
} }
} }
if p.Port == 0 { if p.Port == 0 {
return ErrInvalidPort return ErrInvalidPort
} }
if len(p.Name) == 0 {
return ErrInvalidPeerName
}
for _, c := range p.Name {
if c >= 'a' && c <= 'z' {
continue
}
if c >= '0' && c <= '9' {
continue
}
if c == '-' {
continue
}
return ErrInvalidPeerName
}
return nil return nil
} }

View File

@@ -1,26 +1,23 @@
TABLE config OF Config ( TABLE config OF Config (
ConfigID int64 PK, ConfigID int64 PK,
HubAddress string, Password []byte
VPNNetwork []byte,
Password []byte NoUpdate
); );
TABLE sessions OF Session NoUpdate ( TABLE networks OF Network (
SessionID string PK, NetworkID int64 PK,
CSRF string, LocalDomain string NoUpdate,
SignedIn bool, Network []byte NoUpdate
CreatedAt int64,
LastSeenAt int64
); );
TABLE peers OF Peer ( TABLE peers OF Peer (
NetworkID int64 PK,
PeerIP byte PK, PeerIP byte PK,
Version int64,
APIKey string NoUpdate, APIKey string NoUpdate,
Name string, Name string NoUpdate,
PublicIP []byte, Addr4 []byte NoUpdate,
Port uint16, Addr6 []byte NoUpdate,
Relay bool, Port uint16 NoUpdate,
PubKey []byte NoUpdate, Relay bool NoUpdate,
PubSignKey []byte NoUpdate WGPubKey []byte NoUpdate,
SignPubKey []byte NoUpdate
); );

View File

@@ -1,41 +1,8 @@
package db package db
import "time" func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) {
const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC`
func Session_UpdateLastSeenAt( return Peer_List(tx, query, networkID)
tx TX,
id string,
) (err error) {
_, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id)
return err
}
func Session_SetSignedIn(
tx TX,
id string,
) (err error) {
_, err = tx.Exec("UPDATE sessions SET SignedIn=1 WHERE SessionID=?", id)
return err
}
func Session_DeleteBefore(
tx TX,
timestamp int64,
) (err error) {
_, err = tx.Exec("DELETE FROM sessions WHERE LastSeenAt<?", timestamp)
return err
}
func Config_UpdatePassword(
tx TX,
pwdHash []byte,
) (err error) {
_, err = tx.Exec("UPDATE config SET Password=? WHERE ConfigID=1", pwdHash)
return err
}
func Peer_ListAll(tx TX) ([]*Peer, error) {
return Peer_List(tx, Peer_SelectQuery)
} }
func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) { func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
@@ -44,8 +11,3 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
Peer_SelectQuery+` WHERE APIKey=?`, Peer_SelectQuery+` WHERE APIKey=?`,
apiKey) apiKey)
} }
func Peer_Exists(tx TX, ip byte) (exists bool, err error) {
err = tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM peers WHERE PeerIP=?)`, ip).Scan(&exists)
return
}

View File

@@ -7,7 +7,6 @@ import (
var ( var (
ErrNotAuthorized = errors.New("not authorized") ErrNotAuthorized = errors.New("not authorized")
ErrNoIPAvailable = errors.New("no IP address available")
ErrInvalidIP = db.ErrInvalidIP ErrInvalidIP = db.ErrInvalidIP
ErrInvalidPort = db.ErrInvalidPort ErrInvalidPort = db.ErrInvalidPort
) )

View File

@@ -1,27 +1,25 @@
CREATE TABLE config ( CREATE TABLE config (
ConfigID INTEGER NOT NULL PRIMARY KEY, -- Always 1. ConfigID INTEGER NOT NULL PRIMARY KEY, -- Always 1.
HubAddress TEXT NOT NULL, -- https://for.example.com
VPNNetwork BLOB NOT NULL, -- Network (/24), example 10.51.50.0
Password BLOB NOT NULL -- bcrypt password for web interface Password BLOB NOT NULL -- bcrypt password for web interface
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE TABLE sessions ( CREATE TABLE networks (
SessionID TEXT NOT NULL PRIMARY KEY, NetworkID INTEGER NOT NULL PRIMARY KEY,
CSRF TEXT NOT NULL, LocalDomain TEXT NOT NULL UNIQUE, -- Network/interface name.
SignedIn INTEGER NOT NULL, Network BLOB NOT NULL UNIQUE -- Network (/24), example 10.51.50.0
CreatedAt INTEGER NOT NULL,
LastSeenAt INTEGER NOT NULL
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE INDEX sessions_last_seen_index ON sessions(LastSeenAt);
CREATE TABLE peers ( CREATE TABLE peers (
PeerIP INTEGER NOT NULL PRIMARY KEY, -- Final byte. NetworkID INTEGER NOT NULL,
Version INTEGER NOT NULL, PeerIP INTEGER NOT NULL, -- Final byte of IP.
APIKey TEXT NOT NULL UNIQUE, APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key.
Name TEXT NOT NULL UNIQUE, -- For humans. Name TEXT NOT NULL, -- For humans.
PublicIP BLOB NOT NULL, Addr4 BLOB NOT NULL,
Addr6 BLOB NOT NULL,
Port INTEGER NOT NULL, Port INTEGER NOT NULL,
Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address. Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets.
PubKey BLOB NOT NULL WGPubKey BLOB NOT NULL,
SignPubKey BLOB NOT NULL,
UNIQUE(NetworkID, Name),
PRIMARY KEY(NetworkID, PeerIP)
) WITHOUT ROWID; ) WITHOUT ROWID;

View File

@@ -1 +0,0 @@
ALTER TABLE peers ADD COLUMN PubSignKey BLOB NOT NULL DEFAULT '';

View File

@@ -3,5 +3,12 @@ package api
import "vppn/hub/api/db" import "vppn/hub/api/db"
type Config = db.Config type Config = db.Config
type Session = db.Session type Network = db.Network
type Peer = db.Peer type Peer = db.Peer
type Session struct {
SessionID string
SignedIn bool
CreatedAt int64
LastSeenAt int64
}

View File

@@ -2,12 +2,13 @@ package hub
import ( import (
"embed" "embed"
"encoding/base64"
"html/template" "html/template"
"net/http" "net/http"
"path/filepath" "path/filepath"
"vppn/hub/api" "vppn/hub/api"
"git.crumpington.com/lib/webutil" "git.crumpington.com/lib/go/webutil"
) )
//go:embed static //go:embed static
@@ -47,6 +48,19 @@ func NewApp(conf Config) (*App, error) {
return app, nil return app, nil
} }
var templateFuncs = template.FuncMap{ func (app *App) Handler() http.Handler {
"ipToString": ipBytesTostring, cop := http.NewCrossOriginProtection()
return cop.Handler(app.mux)
}
var templateFuncs = template.FuncMap{
"ipToString": ipBytesTostring,
"wgKeyString": wgKeyString,
}
func wgKeyString(key []byte) string {
if len(key) == 0 {
return "not set"
}
return base64.StdEncoding.EncodeToString(key)
} }

View File

@@ -2,7 +2,6 @@ package hub
import ( import (
"net/http" "net/http"
"time"
) )
func (a *App) getCookie(r *http.Request, name string) string { func (a *App) getCookie(r *http.Request, name string) string {
@@ -26,9 +25,12 @@ func (a *App) setCookie(w http.ResponseWriter, name, value string) {
func (a *App) deleteCookie(w http.ResponseWriter, name string) { func (a *App) deleteCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: name, Name: name,
Value: "", Value: "",
Path: "/", Path: "/",
Expires: time.Unix(0, 0), Secure: !a.insecure,
SameSite: http.SameSiteStrictMode,
HttpOnly: true,
MaxAge: -1, // delete now
}) })
} }

42
hub/form.go Normal file
View File

@@ -0,0 +1,42 @@
package hub
import (
"net/url"
"vppn/hub/api"
"git.crumpington.com/lib/go/webutil"
)
func (app *App) formGetNetwork(form url.Values) (*api.Network, error) {
var id int64
if err := webutil.NewFormScanner(form).Scan("NetworkID", &id).Error(); err != nil {
return nil, err
}
return app.api.Network_Get(id)
}
func (app *App) formGetNetworkPeers(form url.Values) (*api.Network, []*api.Peer, error) {
n, err := app.formGetNetwork(form)
if err != nil {
return nil, nil, err
}
peers, err := app.api.Peer_List(n.NetworkID)
return n, peers, err
}
func (app *App) formGetPeer(form url.Values) (*api.Network, *api.Peer, error) {
net, err := app.formGetNetwork(form)
if err != nil {
return nil, nil, err
}
var ip byte
if err := webutil.NewFormScanner(form).Scan("PeerIP", &ip).Error(); err != nil {
return nil, nil, err
}
peer, err := app.api.Peer_Get(net.NetworkID, ip)
return net, peer, err
}

View File

@@ -1,5 +1,5 @@
package hub package hub
const ( const (
SESSION_ID_COOKIE_NAME = "SessionID" sessionIDCookieName = "SessionID"
) )

View File

@@ -5,14 +5,14 @@ import (
"net/http" "net/http"
"vppn/hub/api" "vppn/hub/api"
"git.crumpington.com/lib/webutil" "git.crumpington.com/lib/go/webutil"
) )
type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) error type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) error
func (app *App) handlePub(pattern string, fn handlerFunc) { func (app *App) handlePub(pattern string, fn handlerFunc) {
wrapped := func(w http.ResponseWriter, r *http.Request) { wrapped := func(w http.ResponseWriter, r *http.Request) {
sessionID := app.getCookie(r, SESSION_ID_COOKIE_NAME) sessionID := app.getCookie(r, sessionIDCookieName)
s, err := app.api.Session_Get(sessionID) s, err := app.api.Session_Get(sessionID)
if err != nil { if err != nil {
log.Printf("Failed to get session: %v", err) log.Printf("Failed to get session: %v", err)
@@ -20,22 +20,13 @@ func (app *App) handlePub(pattern string, fn handlerFunc) {
return return
} }
if s.SessionID != sessionID {
app.setCookie(w, SESSION_ID_COOKIE_NAME, s.SessionID)
}
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
r.ParseMultipartForm(64 * 1024) r.ParseMultipartForm(64 * 1024)
if r.FormValue("CSRF") != s.CSRF {
log.Printf("%s != %s", r.FormValue("CSRF"), s.CSRF)
http.Error(w, "CSRF mismatch", http.StatusBadRequest)
return
}
} else { } else {
r.ParseForm() r.ParseForm()
} }
if err := fn(s, w, r); err != nil { if err := fn(&s, w, r); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
} }

View File

@@ -5,18 +5,17 @@ import (
"errors" "errors"
"log" "log"
"net/http" "net/http"
"net/netip"
"strings"
"vppn/hub/api" "vppn/hub/api"
"vppn/m" "vppn/m"
"git.crumpington.com/lib/go/webutil" "git.crumpington.com/lib/go/webutil"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
if s.SignedIn { if s.SignedIn {
return a.redirect(w, r, "/admin/config/") return a.redirect(w, r, "/admin/network/list/")
} else { } else {
return a.redirect(w, r, "/sign-in/") return a.redirect(w, r, "/sign-in/")
} }
@@ -35,9 +34,11 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque
return err return err
} }
if err := a.api.Session_SignIn(s, pwd); err != nil { sess, err := a.api.Session_SignIn(pwd)
if err != nil {
return err return err
} }
a.setCookie(w, sessionIDCookieName, sess.SessionID)
return a.redirect(w, r, "/") return a.redirect(w, r, "/")
} }
@@ -50,58 +51,164 @@ func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http
if err := a.api.Session_Delete(s.SessionID); err != nil { if err := a.api.Session_Delete(s.SessionID); err != nil {
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err) log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
} }
a.deleteCookie(w, SESSION_ID_COOKIE_NAME) a.deleteCookie(w, sessionIDCookieName)
return a.redirect(w, r, "/") return a.redirect(w, r, "/")
} }
func (a *App) _adminConfig(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminNetworkList(s *api.Session, w http.ResponseWriter, r *http.Request) error {
peers, err := a.api.Peer_List() l, err := a.api.Network_List()
if err != nil { if err != nil {
return err return err
} }
return a.render("/admin-network-list.html", w, struct {
return a.render("/admin-config.html", w, struct { Session *api.Session
Session *api.Session Networks []*api.Network
Peers []*api.Peer }{s, l})
Config *api.Config
}{
s,
peers,
a.api.Config_Get(),
})
} }
func (a *App) _adminConfigEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminNetworkCreate(s *api.Session, w http.ResponseWriter, r *http.Request) error {
return a.render("/admin-config-edit.html", w, struct { return a.render("/admin-network-create.html", w, struct{ Session *api.Session }{s})
Session *api.Session
Config *api.Config
}{
s,
a.api.Config_Get(),
})
} }
func (a *App) _adminConfigEditSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminNetworkCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var ( n := &api.Network{}
conf = a.api.Config_Get() var netStr string
ipStr string
)
err := webutil.NewFormScanner(r.Form). err := webutil.NewFormScanner(r.Form).
Scan("HubAddress", &conf.HubAddress). Scan("LocalDomain", &n.LocalDomain).
Scan("VPNNetwork", &ipStr). Scan("Network", &netStr).
Error() Error()
if err != nil { if err != nil {
return err return err
} }
if conf.VPNNetwork, err = stringToIP(ipStr); err != nil { n.Network, err = stringToIP(netStr)
if err != nil {
return err return err
} }
if err := a.api.Config_Update(conf); err != nil {
if err := a.api.Network_Create(n); err != nil {
return err return err
} }
return a.redirect(w, r, "/admin/config/")
return a.redirect(w, r, "/admin/network/view/?NetworkID=%d", n.NetworkID)
}
func (a *App) _adminNetworkView(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, peers, err := a.formGetNetworkPeers(r.Form)
if err != nil {
return err
}
return a.render("/network/network-view.html", w, struct {
Session *api.Session
Network *api.Network
Peers []*api.Peer
}{s, n, peers})
}
func (a *App) _adminNetworkDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, peers, err := a.formGetNetworkPeers(r.Form)
if err != nil {
return err
}
return a.render("/network/network-delete.html", w, struct {
Session *api.Session
Network *api.Network
Peers []*api.Peer
}{s, n, peers})
}
func (a *App) _adminNetworkDeleteSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, err := a.formGetNetwork(r.Form)
if err != nil {
return err
}
if err = a.api.Network_Delete(n); err != nil {
return err
}
return a.redirect(w, r, "/admin/network/list/")
}
func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, err := a.formGetNetwork(r.Form)
if err != nil {
return err
}
return a.render("/network/peer-create.html", w, struct {
Session *api.Session
Network *api.Network
}{s, n})
}
func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var addr4Str, addr6Str string
p := &api.Peer{}
err := webutil.NewFormScanner(r.Form).
Scan("NetworkID", &p.NetworkID).
Scan("IP", &p.PeerIP).
Scan("Name", &p.Name).
Scan("Addr4", &addr4Str).
Scan("Addr6", &addr6Str).
Scan("Port", &p.Port).
Scan("Relay", &p.Relay).
Error()
if err != nil {
return err
}
if p.Addr4, err = stringToIP(addr4Str); err != nil {
return err
}
if p.Addr6, err = stringToIP(addr6Str); err != nil {
return err
}
if err := a.api.Peer_CreateNew(p); err != nil {
return err
}
return a.redirect(w, r, "/admin/peer/view/?NetworkID=%d&PeerIP=%d", p.NetworkID, p.PeerIP)
}
func (a *App) _adminPeerView(s *api.Session, w http.ResponseWriter, r *http.Request) error {
net, peer, err := a.formGetPeer(r.Form)
if err != nil {
return err
}
return a.render("/network/peer-view.html", w, struct {
Session *api.Session
Network *api.Network
Peer *api.Peer
}{s, net, peer})
}
func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, peer, err := a.formGetPeer(r.Form)
if err != nil {
return err
}
return a.render("/network/peer-delete.html", w, struct {
Session *api.Session
Network *api.Network
Peer *api.Peer
}{s, n, peer})
}
func (a *App) _adminPeerDeleteSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, peer, err := a.formGetPeer(r.Form)
if err != nil {
return err
}
if err := a.api.Peer_Delete(n.NetworkID, peer.PeerIP); err != nil {
return err
}
return a.redirect(w, r, "/admin/network/view/?NetworkID=%d", n.NetworkID)
} }
func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
@@ -110,13 +217,17 @@ func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.
func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var ( var (
conf = a.api.Config_Get()
curPwd string curPwd string
newPwd string newPwd string
newPwd2 string newPwd2 string
) )
err := webutil.NewFormScanner(r.Form). conf, err := a.api.Config_Get()
if err != nil {
return err
}
err = webutil.NewFormScanner(r.Form).
Scan("CurrentPassword", &curPwd). Scan("CurrentPassword", &curPwd).
Scan("NewPassword", &newPwd). Scan("NewPassword", &newPwd).
Scan("NewPassword2", &newPwd2). Scan("NewPassword2", &newPwd2).
@@ -143,205 +254,95 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
return err return err
} }
if err := a.api.Config_UpdatePassword(hash); err != nil { conf.Password = hash
if err := a.api.Config_Update(conf); err != nil {
return err return err
} }
return a.redirect(w, r, "/admin/config/") return a.redirect(w, r, "/admin/config/")
} }
func (a *App) _adminHosts(s *api.Session, w http.ResponseWriter, r *http.Request) error {
conf := a.api.Config_Get()
peers, err := a.api.Peer_List()
if err != nil {
return err
}
b := strings.Builder{}
for _, peer := range peers {
ip := conf.VPNNetwork
ip[3] = peer.PeerIP
b.WriteString(netip.AddrFrom4([4]byte(ip)).String())
b.WriteString(" ")
b.WriteString(peer.Name)
b.WriteString("\n")
}
w.Write([]byte(b.String()))
return nil
}
func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Request) error {
return a.render("/admin-peer-create.html", w, struct{ Session *api.Session }{s})
}
func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var ipStr string
p := &api.Peer{}
err := webutil.NewFormScanner(r.Form).
Scan("IP", &p.PeerIP).
Scan("Name", &p.Name).
Scan("PublicIP", &ipStr).
Scan("Port", &p.Port).
Scan("Relay", &p.Relay).
Error()
if err != nil {
return err
}
if p.PublicIP, err = stringToIP(ipStr); err != nil {
return err
}
if err := a.api.Peer_CreateNew(p); err != nil {
return err
}
return a.redirect(w, r, "/admin/peer/view/?PeerIP=%d", p.PeerIP)
}
func (a *App) _adminPeerView(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var peerIP byte
err := webutil.NewFormScanner(r.Form).Scan("PeerIP", &peerIP).Error()
if err != nil {
return err
}
peer, err := a.api.Peer_Get(peerIP)
if err != nil {
return err
}
return a.render("/admin-peer-view.html", w, struct {
Session *api.Session
Peer *api.Peer
}{s, peer})
}
func (a *App) _adminPeerEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var peerIP byte
err := webutil.NewFormScanner(r.Form).Scan("PeerIP", &peerIP).Error()
if err != nil {
return err
}
peer, err := a.api.Peer_Get(peerIP)
if err != nil {
return err
}
return a.render("/admin-peer-edit.html", w, struct {
Session *api.Session
Peer *api.Peer
}{s, peer})
}
func (a *App) _adminPeerEditSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var (
peerIP byte
ipStr string
)
err := webutil.NewFormScanner(r.Form).Scan("PeerIP", &peerIP).Error()
if err != nil {
return err
}
peer, err := a.api.Peer_Get(peerIP)
if err != nil {
return err
}
err = webutil.NewFormScanner(r.Form).
Scan("Name", &peer.Name).
Scan("PublicIP", &ipStr).
Scan("Port", &peer.Port).
Scan("Relay", &peer.Relay).
Error()
if err != nil {
return err
}
if peer.PublicIP, err = stringToIP(ipStr); err != nil {
return err
}
if err = a.api.Peer_Update(peer); err != nil {
return err
}
return a.redirect(w, r, "/admin/peer/view/?PeerIP=%d", peer.PeerIP)
}
func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var peerIP byte
err := webutil.NewFormScanner(r.Form).Scan("PeerIP", &peerIP).Error()
if err != nil {
return err
}
peer, err := a.api.Peer_Get(peerIP)
if err != nil {
return err
}
return a.render("/admin-peer-delete.html", w, struct {
Session *api.Session
Peer *api.Peer
}{s, peer})
}
func (a *App) _adminPeerDeleteSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var peerIP byte
err := webutil.NewFormScanner(r.Form).Scan("PeerIP", &peerIP).Error()
if err != nil {
return err
}
if err := a.api.Peer_Delete(peerIP); err != nil {
return err
}
return a.redirect(w, r, "/admin/peer/list/")
}
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error { func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
if len(peer.WGPubKey) != 0 {
http.Error(w, "Already initialized", http.StatusConflict)
return nil
}
args := m.PeerInitArgs{} args := m.PeerInitArgs{}
if err := json.NewDecoder(r.Body).Decode(&args); err != nil { if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
return err return err
} }
conf, err := a.api.Peer_Init(peer, args) if len(args.WGPubKey) != 32 {
http.Error(w, "invalid WGPubKey", http.StatusBadRequest)
return nil
}
if len(args.SignPubKey) != 32 {
http.Error(w, "invalid SignPubKey", http.StatusBadRequest)
return nil
}
net, err := a.api.Network_Get(peer.NetworkID)
if err != nil { if err != nil {
return err return err
} }
return a.sendJSON(w, conf) if err := a.api.Peer_Init(peer, args); err != nil {
return err
}
resp := m.PeerInitResp{
PeerIP: peer.PeerIP,
Network: net.Network,
LocalDomain: net.LocalDomain,
}
resp.NetworkState.Peers, err = a.peersList(net.NetworkID)
if err != nil {
return err
}
return a.sendJSON(w, resp)
} }
func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error { func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
peers, err := a.api.Peer_List() peers, err := a.peersList(peer.NetworkID)
if err != nil { if err != nil {
return err return err
} }
return a.sendJSON(w, m.NetworkState{Peers: peers})
}
state := m.NetworkState{} func (a *App) peersList(networkID int64) (peers []m.Peer, err error) {
l, err := a.api.Peer_List(networkID)
for _, p := range peers { if err != nil {
if len(p.PubKey) != 0 { return nil, err
state.Peers[p.PeerIP] = &m.Peer{
PeerIP: p.PeerIP,
Version: p.Version,
Name: p.Name,
PublicIP: p.PublicIP,
Port: p.Port,
Relay: p.Relay,
PubKey: p.PubKey,
PubSignKey: p.PubSignKey,
}
}
} }
return a.sendJSON(w, state) peers = make([]m.Peer, 0, len(l))
for _, p := range l {
if len(p.WGPubKey) == 0 {
continue
}
wgKey, err := wgtypes.NewKey(p.WGPubKey)
if err != nil {
continue // malformed key; skip rather than serve garbage
}
var signKey [32]byte
copy(signKey[:], p.SignPubKey)
peers = append(peers, m.Peer{
PeerIP: p.PeerIP,
Name: p.Name,
Addr4: addrFromBytes(p.Addr4),
Addr6: addrFromBytes(p.Addr6),
Port: p.Port,
Relay: p.Relay,
WGPubKey: wgKey,
SignPubKey: signKey,
})
}
return peers, nil
} }

View File

@@ -6,7 +6,7 @@ import (
"net/http" "net/http"
"os" "os"
"git.crumpington.com/lib/webutil" "git.crumpington.com/lib/go/webutil"
) )
func Main() { func Main() {
@@ -31,7 +31,7 @@ func Main() {
srv := &http.Server{ srv := &http.Server{
Addr: conf.ListenAddr, Addr: conf.ListenAddr,
Handler: app.mux, Handler: app.Handler(),
} }
log.Fatal(webutil.ListenAndServe(srv)) log.Fatal(webutil.ListenAndServe(srv))

View File

@@ -9,22 +9,25 @@ func (a *App) registerRoutes() {
a.handleNotSignedIn("GET /sign-in/", a._signin) a.handleNotSignedIn("GET /sign-in/", a._signin)
a.handleNotSignedIn("POST /sign-in/", a._signinSubmit) a.handleNotSignedIn("POST /sign-in/", a._signinSubmit)
a.handleSignedIn("GET /admin/config/", a._adminConfig)
a.handleSignedIn("GET /admin/config/edit/", a._adminConfigEdit)
a.handleSignedIn("POST /admin/config/edit/", a._adminConfigEditSubmit)
a.handleSignedIn("GET /admin/sign-out/", a._adminSignOut) a.handleSignedIn("GET /admin/sign-out/", a._adminSignOut)
a.handleSignedIn("POST /admin/sign-out/", a._adminSignOutSubmit) a.handleSignedIn("POST /admin/sign-out/", a._adminSignOutSubmit)
a.handleSignedIn("GET /admin/password/edit/", a._adminPasswordEdit)
a.handleSignedIn("POST /admin/password/edit/", a._adminPasswordSubmit) a.handleSignedIn("GET /admin/network/list/", a._adminNetworkList)
a.handleSignedIn("GET /admin/peer/hosts/", a._adminHosts) a.handleSignedIn("GET /admin/network/create/", a._adminNetworkCreate)
a.handleSignedIn("POST /admin/network/create/", a._adminNetworkCreateSubmit)
a.handleSignedIn("GET /admin/network/delete/", a._adminNetworkDelete)
a.handleSignedIn("POST /admin/network/delete/", a._adminNetworkDeleteSubmit)
a.handleSignedIn("GET /admin/network/view/", a._adminNetworkView)
a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate) a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate)
a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit) a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit)
a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView) a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView)
a.handleSignedIn("GET /admin/peer/edit/", a._adminPeerEdit)
a.handleSignedIn("POST /admin/peer/edit/", a._adminPeerEditSubmit)
a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete) a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete)
a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit) a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit)
a.handleSignedIn("GET /admin/password/edit/", a._adminPasswordEdit)
a.handleSignedIn("POST /admin/password/edit/", a._adminPasswordSubmit)
a.handlePeer("POST /peer/init/", a._peerInit) a.handlePeer("POST /peer/init/", a._peerInit)
a.handlePeer("GET /peer/fetch-state/", a._peerFetchState) a.handlePeer("GET /peer/fetch-state/", a._peerFetchState)
} }

View File

@@ -1,20 +0,0 @@
{{define "body" -}}
<h2>Config</h2>
<form method="POST">
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<p>
<label>Hub Address</label><br>
<input type="url" name="HubAddress" value="{{.Config.HubAddress}}">
</p>
<p>
<label>VPN Network</label><br>
<input type="text" name="VPNNetwork" value="{{ipToString .Config.VPNNetwork}}">
</p>
<p>
<button type="submit">Save</button>
<a href="/admin/config/">Cancel</a>
</p>
</form>
{{- end}}

View File

@@ -0,0 +1,18 @@
{{define "body" -}}
<h2>Create Network</h2>
<form method="POST">
<p>
<label>Local Domain (ending with .local)</label><br>
<input type="text" name="LocalDomain">
</p>
<p>
<label>Network /24</label><br>
<input type="text" name="Network">
</p>
<p>
<button type="submit">Save</button>
<a href="/admin/network/list/">Cancel</a>
</p>
</form>
{{- end}}

View File

@@ -0,0 +1,38 @@
{{define "body" -}}
<h2>Networks</h2>
<p>
<a href="/admin/network/create/">Create</a>
</p>
{{if .Networks -}}
<table>
<thead>
<tr>
<th>Local Domain</th>
<th>Network</th>
</tr>
</thead>
<tbody>
{{range .Networks -}}
<tr>
<td>
<a href="/admin/network/view/?NetworkID={{.NetworkID}}">
{{.LocalDomain}}
</a>
</td>
<td>{{ipToString .Network}}</td>
</tr>
</tbody>
{{- end}}
</table>
{{- else}}
<p>No networks.</p>
{{- end}}
<h3>Settings</h3>
<ul>
<li><a href="/admin/password/edit/">Password</a></li>
</ul>
{{- end}}

View File

@@ -2,8 +2,7 @@
<h2>Change Password</h2> <h2>Change Password</h2>
<form method="POST"> <form method="POST">
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}"> <p>
<p>
<label>Current Password</label><br> <label>Current Password</label><br>
<input type="password" name="CurrentPassword"> <input type="password" name="CurrentPassword">
</p> </p>

View File

@@ -1,34 +0,0 @@
{{define "body" -}}
<h2>New Peer</h2>
<form method="POST">
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<p>
<label>IP</label><br>
<input type="number" name="IP" min="1" max="255" value="0">
</p>
<p>
<label>Name</label><br>
<input type="text" name="Name">
</p>
<p>
<label>Public IP</label><br>
<input type="text" name="PublicIP">
</p>
<p>
<label>Port</label><br>
<input type="number" name="Port" value="456">
</p>
<p>
<label>
<input type="checkbox" name="Relay">
Relay
</label>
</p>
<p>
<button type="submit">Save</button>
<a href="/admin/config/">Cancel</a>
</p>
</form>
{{- end}}

View File

@@ -1,36 +0,0 @@
{{define "body" -}}
<h2>Delete Peer</h2>
{{with .Peer -}}
<form method="POST">
<input type="hidden" name="CSRF" value="{{$.Session.CSRF}}">
<p>
<label>Peer IP</label><br>
<input type="number" name="PeerIP" value="{{.PeerIP}}" disabled>
</p>
<p>
<label>Name</label><br>
<input type="text" value="{{.Name}}" disabled>
</p>
<p>
<label>Public IP</label><br>
<input type="text" value="{{ipToString .PublicIP}}" disabled>
</p>
<p>
<label>Port</label><br>
<input type="number" value="{{.Port}}" disabled>
</p>
<p>
<label>
<input type="checkbox" {{if .Relay}}checked{{end}} disabled>
Relay
</label>
</p>
<p>
<button type="submit">Delete</button>
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}">Cancel</a>
</p>
</form>
{{- end}}
{{- end}}

View File

@@ -1,36 +0,0 @@
{{define "body" -}}
<h2>Edit Peer</h2>
{{with .Peer -}}
<form method="POST">
<input type="hidden" name="CSRF" value="{{$.Session.CSRF}}">
<p>
<label>Peer IP</label><br>
<input type="number" name="PeerIP" value="{{.PeerIP}}" disabled>
</p>
<p>
<label>Name</label><br>
<input type="text" name="Name" value="{{.Name}}">
</p>
<p>
<label>Public IP</label><br>
<input type="text" name="PublicIP" value="{{ipToString .PublicIP}}">
</p>
<p>
<label>Port</label><br>
<input type="number" name="Port" value="{{.Port}}">
</p>
<p>
<label>
<input type="checkbox" name="Relay" {{if .Relay}}checked{{end}}>
Relay
</label>
</p>
<p>
<button type="submit">Save</button>
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}">Cancel</a>
</p>
</form>
{{- end}}
{{- end}}

View File

@@ -1,13 +0,0 @@
{{define "body" -}}
<h2>Initialize Peer</h2>
<p>
Configure the peer with the following URL:
</p>
<pre>
{{.HubAddress}}/peer/init/?Code={{.Code}}
</pre>
<p>
<a href="/admin/config/">Done</a>
</p>
{{- end}}

View File

@@ -1,13 +0,0 @@
{{define "body" -}}
<h2>Create Peer</h2>
<p>
Configure the peer with the following URL:
</p>
<pre>
{{.HubAddress}}/peer/create/?Code={{.Code}}
</pre>
<p>
<a href="/admin/config/">Done</a>
</p>
{{- end}}

View File

@@ -1,20 +0,0 @@
{{define "body" -}}
<h2>Peer</h2>
<p>
<a href="/admin/peer/edit/?PeerIP={{.Peer.PeerIP}}">Edit</a> /
<a href="/admin/peer/delete/?PeerIP={{.Peer.PeerIP}}">Delete</a>
</p>
{{with .Peer -}}
<table class="def-list">
<tr><td>Peer IP</td><td>{{.PeerIP}}</td></tr>
<tr><td>Name</td><td>{{.Name}}</td></tr>
<tr><td>Public IP</td><td>{{ipToString .PublicIP}}</td></tr>
<tr><td>Port</td><td>{{.Port}}</td></tr>
<tr><td>Relay</td><td>{{if .Relay}}T{{else}}F{{end}}</td></tr>
<tr><td>API Key</td><td>{{.APIKey}}</td></tr>
</table>
{{- end}}
{{- end}}

View File

@@ -2,10 +2,9 @@
<h2>Sign Out</h2> <h2>Sign Out</h2>
<form method="POST"> <form method="POST">
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}"> <p>
<p>
<button type="submit">Sign Out</button> <button type="submit">Sign Out</button>
<a href="/admin/config/">Cancel</a> <a href="/">Cancel</a>
</p> </p>
</form> </form>
{{- end}} {{- end}}

View File

@@ -10,7 +10,7 @@
<h1>VPPN</h1> <h1>VPPN</h1>
<nav> <nav>
{{if .Session.SignedIn -}} {{if .Session.SignedIn -}}
<a href="/admin/config/">Home</a> / <a href="/admin/networks/list/">Home</a> /
<a href="/admin/sign-out/">Sign out</a> <a href="/admin/sign-out/">Sign out</a>
{{- end}} {{- end}}
</nav> </nav>

View File

@@ -0,0 +1,25 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>VPPN Hub</title>
<link rel="stylesheet" href="/static/new.min.css">
<link rel="stylesheet" href="/static/custom.css">
</head>
<body>
<header>
<h1>VPPN</h1>
<nav>
{{if .Session.SignedIn -}}
<a href="/admin/networks/list/">Home</a> /
<a href="/admin/sign-out/">Sign out</a>
{{- end}}
</nav>
</header>
<h2>
Network:
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">{{.Network.LocalDomain}}</a>
</h2>
{{block "body" .}}There's nothing here.{{end}}
</body>
</html>

View File

@@ -0,0 +1,15 @@
{{define "body" -}}
<h3>Delete</h3>
{{if .Peers -}}
<p>You must first delete all peers.</p>
{{- else -}}
<form method="POST">
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
<p>
<button type="submit">Delete</button>
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">Cancel</a>
</p>
</form>
{{- end}}
{{- end}}

View File

@@ -1,27 +1,19 @@
{{define "body" -}} {{define "body" -}}
<h2>Config</h2>
<p> <p>
<a href="/admin/config/edit/">Edit</a> / <a href="/admin/network/delete/?NetworkID={{.Network.NetworkID}}">Delete</a>
<a href="/admin/password/edit/">Change Password</a>
</p> </p>
<table class="def-list"> <table class="def-list">
<tr> <tr>
<td>Hub Address</td> <td>Network</td>
<td>{{.Config.HubAddress}}</td> <td>{{ipToString .Network.Network}}/24</td>
</tr>
<tr>
<td>VPN Network</td>
<td>{{ipToString .Config.VPNNetwork}}</td>
</tr> </tr>
</table> </table>
<h2>Peers</h2> <h3>Peers</h3>
<p> <p>
<a href="/admin/peer/create/">Add Peer</a> / <a href="/admin/peer/create/?NetworkID={{.Network.NetworkID}}">Create</a>
<a href="/admin/peer/hosts/">Hosts</a>
</p> </p>
{{if .Peers -}} {{if .Peers -}}
@@ -30,7 +22,8 @@
<tr> <tr>
<th>PeerIP</th> <th>PeerIP</th>
<th>Name</th> <th>Name</th>
<th>Public IP</th> <th>IPv4</th>
<th>IPv6</th>
<th>Port</th> <th>Port</th>
<th>Relay</th> <th>Relay</th>
</tr> </tr>
@@ -39,12 +32,13 @@
{{range .Peers -}} {{range .Peers -}}
<tr> <tr>
<td> <td>
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}"> <a href="/admin/peer/view/?NetworkID={{$.Network.NetworkID}}&PeerIP={{.PeerIP}}">
{{.PeerIP}} {{.PeerIP}}
</a> </a>
</td> </td>
<td>{{.Name}}</td> <td>{{.Name}}</td>
<td>{{ipToString .PublicIP}}</td> <td>{{ipToString .Addr4}}</td>
<td>{{ipToString .Addr6}}</td>
<td>{{.Port}}</td> <td>{{.Port}}</td>
<td>{{if .Relay}}T{{else}}F{{end}}</td> <td>{{if .Relay}}T{{else}}F{{end}}</td>
</tr> </tr>

View File

@@ -0,0 +1,38 @@
{{define "body" -}}
<h3>New Peer</h3>
<form method="POST">
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
<p>
<label>IP</label><br>
<input type="number" name="IP" min="1" max="255" value="0">
</p>
<p>
<label>Name</label><br>
<input type="text" name="Name">
</p>
<p>
<label>IPv4 Address (optional)</label><br>
<input type="text" name="Addr4">
</p>
<p>
<label>IPv6 Address (optional)</label><br>
<input type="text" name="Addr6">
</p>
<p>
<label>WireGuard Port</label><br>
<input type="number" name="Port" value="51820">
</p>
<p>
<label>
<input type="checkbox" name="Relay">
Relay
</label>
</p>
<p>
<button type="submit">Save</button>
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">Cancel</a>
</p>
</form>
{{- end}}

View File

@@ -0,0 +1,14 @@
{{define "body" -}}
<h3>Delete {{.Peer.Name}}</h3>
{{with .Peer -}}
<form method="POST">
<input type="hidden" name="NetworkID" value="{{.NetworkID}}">
<input type="hidden" name="PeerIP" value="{{.PeerIP}}">
<p>
<button type="submit">Delete</button>
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}&NetworkID={{.NetworkID}}">Cancel</a>
</p>
</form>
{{- end}}
{{- end}}

View File

@@ -0,0 +1,24 @@
{{define "body" -}}
<h3>{{.Peer.Name}}</h3>
<p>
<a href="/admin/peer/delete/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Delete</a>
</p>
{{with .Peer -}}
<table class="def-list">
<tr><td>Peer IP</td><td>{{.PeerIP}}</td></tr>
<tr><td>IPv4 Address</td><td>{{ipToString .Addr4}}</td></tr>
<tr><td>IPv6 Address</td><td>{{ipToString .Addr6}}</td></tr>
<tr><td>WireGuard Port</td><td>{{.Port}}</td></tr>
<tr><td>Relay</td><td>{{if .Relay}}T{{else}}F{{end}}</td></tr>
<tr><td>WG Public Key</td><td>{{wgKeyString .WGPubKey}}</td></tr>
</table>
<details>
<summary>API Key</summary>
<p>{{.APIKey}}</p>
</details>
{{- end}}
{{- end}}

View File

@@ -2,8 +2,7 @@
<h2>Sign In</h2> <h2>Sign In</h2>
<form method="POST"> <form method="POST">
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}"> <p>
<p>
<label>Password</label><br> <label>Password</label><br>
<input type="password" name="Password"> <input type="password" name="Password">
</p> </p>

View File

@@ -38,6 +38,19 @@ func (app *App) sendJSON(w http.ResponseWriter, data any) error {
return nil return nil
} }
// addrFromBytes parses raw IP bytes (4 or 16) into a netip.Addr, unmapping
// IPv4-in-IPv6, returning the zero Addr for empty/invalid input.
func addrFromBytes(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
addr, ok := netip.AddrFromSlice(b)
if !ok {
return netip.Addr{}
}
return addr.Unmap()
}
func stringToIP(in string) ([]byte, error) { func stringToIP(in string) ([]byte, error) {
in = strings.TrimSpace(in) in = strings.TrimSpace(in)
if len(in) == 0 { if len(in) == 0 {

View File

@@ -1,30 +1,133 @@
// The package `m` contains models shared between the hub and peer programs. // The package `m` contains models shared between the hub and peer programs.
package m package m
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type PeerInitArgs struct { type PeerInitArgs struct {
EncPubKey []byte WGPubKey []byte
PubSignKey []byte SignPubKey []byte
} }
type PeerConfig struct { type PeerInitResp struct {
PeerIP byte PeerIP byte
Network []byte Network []byte
PublicIP []byte LocalDomain string
Port uint16 NetworkState NetworkState
Relay bool
} }
// Peer is the network membership record for a single peer, exchanged between
// the hub and peers. Addr4/Addr6 are the peer's public endpoint addresses (zero
// if it has none); Port is its WireGuard listen port, meaningful even for a
// non-public peer (it is the peer's own bind/beacon port).
type Peer struct { type Peer struct {
PeerIP byte PeerIP byte
Version int64
Name string Name string
PublicIP []byte Addr4 netip.Addr // zero if none
Addr6 netip.Addr // zero if none
Port uint16 Port uint16
Relay bool Relay bool
PubKey []byte WGPubKey wgtypes.Key
PubSignKey []byte SignPubKey [32]byte
}
// IsPublic reports whether the peer advertises at least one reachable endpoint.
func (p Peer) IsPublic() bool {
return p.Addr4.IsValid() || p.Addr6.IsValid()
}
// Endpoint4 returns the IPv4 endpoint (addr+port), or the zero AddrPort if the
// peer has no IPv4 address.
func (p Peer) Endpoint4() netip.AddrPort {
if !p.Addr4.IsValid() {
return netip.AddrPort{}
}
return netip.AddrPortFrom(p.Addr4, p.Port)
}
// Endpoint6 returns the IPv6 endpoint (addr+port), or the zero AddrPort if the
// peer has no IPv6 address.
func (p Peer) Endpoint6() netip.AddrPort {
if !p.Addr6.IsValid() {
return netip.AddrPort{}
}
return netip.AddrPortFrom(p.Addr6, p.Port)
}
// PreferredEndpoint returns the IPv4 endpoint if present, else IPv6.
func (p Peer) PreferredEndpoint() netip.AddrPort {
if ep := p.Endpoint4(); ep.IsValid() {
return ep
}
return p.Endpoint6()
}
// peerJSON is the wire representation. netip.Addr fields round-trip as text
// strings automatically; only the fixed-size key arrays need base64 (otherwise
// encoding/json would emit them as arrays of numbers).
type peerJSON struct {
PeerIP byte
Name string
Addr4 netip.Addr
Addr6 netip.Addr
Port uint16
Relay bool
WGPubKey string
SignPubKey string
}
func (p Peer) MarshalJSON() ([]byte, error) {
return json.Marshal(peerJSON{
PeerIP: p.PeerIP,
Name: p.Name,
Addr4: p.Addr4,
Addr6: p.Addr6,
Port: p.Port,
Relay: p.Relay,
WGPubKey: base64.StdEncoding.EncodeToString(p.WGPubKey[:]),
SignPubKey: base64.StdEncoding.EncodeToString(p.SignPubKey[:]),
})
}
func (p *Peer) UnmarshalJSON(data []byte) error {
var j peerJSON
if err := json.Unmarshal(data, &j); err != nil {
return err
}
wg, err := base64.StdEncoding.DecodeString(j.WGPubKey)
if err != nil {
return fmt.Errorf("decode WGPubKey: %w", err)
}
key, err := wgtypes.NewKey(wg)
if err != nil {
return fmt.Errorf("invalid WGPubKey: %w", err)
}
sign, err := base64.StdEncoding.DecodeString(j.SignPubKey)
if err != nil {
return fmt.Errorf("decode SignPubKey: %w", err)
}
if len(sign) != 32 {
return fmt.Errorf("invalid SignPubKey length: %d", len(sign))
}
*p = Peer{
PeerIP: j.PeerIP,
Name: j.Name,
Addr4: j.Addr4,
Addr6: j.Addr6,
Port: j.Port,
Relay: j.Relay,
WGPubKey: key,
SignPubKey: [32]byte(sign),
}
return nil
} }
type NetworkState struct { type NetworkState struct {
Peers [256]*Peer Peers []Peer
} }

View File

@@ -1,67 +0,0 @@
package node
import (
"log"
"net/netip"
"time"
)
func addrDiscoveryServer() {
var (
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
)
for {
msg := <-discoveryMessages
p := msg.Packet
route := routingTable[msg.SrcIP].Load()
if route == nil || !route.RemoteAddr.IsValid() {
continue
}
_sendControlPacket(addrDiscoveryPacket{
TraceID: p.TraceID,
ToAddr: msg.SrcAddr,
}, *route, buf1, buf2)
}
}
func addrDiscoveryClient() {
var (
checkInterval = 8 * time.Second
timer = time.NewTimer(4 * time.Second)
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
addrPacket addrDiscoveryPacket
lAddr netip.AddrPort
)
for {
select {
case msg := <-discoveryMessages:
p := msg.Packet
if p.TraceID != addrPacket.TraceID || !p.ToAddr.IsValid() || p.ToAddr == lAddr {
continue
}
log.Printf("Discovered local address: %v", p.ToAddr)
lAddr = p.ToAddr
localAddr.Store(&p.ToAddr)
case <-timer.C:
timer.Reset(checkInterval)
route := getRelayRoute()
if route == nil {
continue
}
addrPacket.TraceID = newTraceID()
_sendControlPacket(addrPacket, *route, buf1, buf2)
}
}
}

View File

@@ -1,21 +0,0 @@
package node
const bitSetSize = 512 // Multiple of 64.
type bitSet [bitSetSize / 64]uint64
func (bs *bitSet) Set(i int) {
bs[i/64] |= 1 << (i % 64)
}
func (bs *bitSet) Clear(i int) {
bs[i/64] &= ^(1 << (i % 64))
}
func (bs *bitSet) ClearAll() {
clear(bs[:])
}
func (bs *bitSet) Get(i int) bool {
return bs[i/64]&(1<<(i%64)) != 0
}

View File

@@ -1,48 +0,0 @@
package node
import (
"math/rand"
"testing"
)
func TestBitSet(t *testing.T) {
state := make([]bool, bitSetSize)
for i := range state {
state[i] = rand.Float32() > 0.5
}
bs := bitSet{}
for i := range state {
if state[i] {
bs.Set(i)
}
}
for i := range state {
if bs.Get(i) != state[i] {
t.Fatal(i, state[i], bs.Get(i))
}
}
for i := range state {
if rand.Float32() > 0.5 {
state[i] = false
bs.Clear(i)
}
}
for i := range state {
if bs.Get(i) != state[i] {
t.Fatal(i, state[i], bs.Get(i))
}
}
bs.ClearAll()
for i := range state {
if bs.Get(i) {
t.Fatal(i, bs.Get(i))
}
}
}

View File

@@ -1,26 +0,0 @@
package node
import "golang.org/x/crypto/nacl/box"
type controlCipher struct {
sharedKey [32]byte
}
func newControlCipher(privKey, pubKey []byte) *controlCipher {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
return &controlCipher{shared}
}
func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte {
const s = controlHeaderSize
out = out[:s+controlCipherOverhead+len(data)]
h.Marshal(out[:s])
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
return out
}
func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = controlHeaderSize
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
}

View File

@@ -1,122 +0,0 @@
package node
import (
"bytes"
"crypto/rand"
"reflect"
"testing"
"golang.org/x/crypto/nacl/box"
)
func newControlCipherForTesting() (c1, c2 *controlCipher) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
return newControlCipher(privKey1[:], pubKey2[:]),
newControlCipher(privKey2[:], pubKey1[:])
}
func TestControlCipher(t *testing.T) {
c1, c2 := newControlCipherForTesting()
maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := header{
StreamID: controlStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
h2 := header{}
h2.Parse(encrypted)
if !reflect.DeepEqual(h1, h2) {
t.Fatal(h1, h2)
}
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatal("not equal")
}
}
}
func TestControlCipher_ShortCiphertext(t *testing.T) {
c1, _ := newControlCipherForTesting()
shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
rand.Read(shortText)
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkControlCipher_Encrypt(b *testing.B) {
c1, _ := newControlCipherForTesting()
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = c1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkControlCipher_Decrypt(b *testing.B) {
c1, c2 := newControlCipherForTesting()
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = c2.Decrypt(encrypted, decrypted)
}
}

View File

@@ -1,60 +0,0 @@
package node
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
)
type dataCipher struct {
key [32]byte
aead cipher.AEAD
}
func newDataCipher() *dataCipher {
key := [32]byte{}
if _, err := rand.Read(key[:]); err != nil {
panic(err)
}
return newDataCipherFromKey(key)
}
func newDataCipherFromKey(key [32]byte) *dataCipher {
block, err := aes.NewCipher(key[:])
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
panic(err)
}
return &dataCipher{key: key, aead: aead}
}
func (sc *dataCipher) Key() [32]byte {
return sc.key
}
func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte {
const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(out[:s])
sc.aead.Seal(out[s:s], out[:s], data, nil)
return out
}
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = dataHeaderSize
if len(encrypted) < s+dataCipherOverhead {
ok = false
return
}
var err error
data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
ok = err == nil
return
}

View File

@@ -1,141 +0,0 @@
package node
import (
"bytes"
"crypto/rand"
mrand "math/rand/v2"
"reflect"
"testing"
)
func TestDataCipher(t *testing.T) {
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := header{
StreamID: dataStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
h2 := header{}
h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key())
decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatal("not equal")
}
if !reflect.DeepEqual(h1, h2) {
t.Fatalf("%v != %v", h1, h2)
}
}
}
func TestDataCipher_ModifyCiphertext(t *testing.T) {
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
encrypted[mrand.IntN(len(encrypted))]++
dc2 := newDataCipherFromKey(dc1.Key())
_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if ok {
t.Fatal(ok)
}
}
}
func TestDataCipher_ShortCiphertext(t *testing.T) {
dc1 := newDataCipher()
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
rand.Read(shortText)
_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkDataCipher_Encrypt(b *testing.B) {
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkDataCipher_Decrypt(b *testing.B) {
h1 := header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = dc1.Decrypt(encrypted, decrypted)
}
}

View File

@@ -1,13 +0,0 @@
package node
/*
func signData(privKey *[64]byte, h header, data, out []byte) []byte {
out = out[:headerSize]
h.Marshal(out)
return sign.Sign(out, data, privKey)
}
func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) {
return sign.Open(out[:0], signed[headerSize:], pubKey)
}
*/

View File

@@ -1,11 +0,0 @@
package node
import "vppn/m"
type localConfig struct {
m.PeerConfig
PubKey []byte
PrivKey []byte
PubSignKey []byte
PrivSignKey []byte
}

View File

@@ -1,50 +0,0 @@
package node
import (
"io"
"log"
"net"
"net/netip"
"sync"
)
// ----------------------------------------------------------------------------
type connWriter struct {
lock sync.Mutex
conn *net.UDPConn
}
func newConnWriter(conn *net.UDPConn) *connWriter {
return &connWriter{conn: conn}
}
func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) {
// Even though a conn is safe for concurrent use, it turns out that a mutex
// in Go is more fair when there's contention. Without this lock, control
// packets may fail to be sent in a timely manner causing timeouts.
w.lock.Lock()
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
log.Printf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()
}
// ----------------------------------------------------------------------------
type ifWriter struct {
lock sync.Mutex
iface io.ReadWriteCloser
}
func newIFWriter(iface io.ReadWriteCloser) *ifWriter {
return &ifWriter{iface: iface}
}
func (w *ifWriter) Write(packet []byte) {
w.lock.Lock()
if _, err := w.iface.Write(packet); err != nil {
log.Fatalf("Failed to write to interface: %v", err)
}
w.lock.Unlock()
}

View File

@@ -1,76 +0,0 @@
package node
type dupCheck struct {
bitSet
head int
tail int
headCounter uint64
tailCounter uint64 // Also next expected counter value.
}
func newDupCheck(headCounter uint64) *dupCheck {
return &dupCheck{
headCounter: headCounter,
tailCounter: headCounter + 1,
tail: 1,
}
}
func (dc *dupCheck) IsDup(counter uint64) bool {
// Before head => it's late, say it's a dup.
if counter < dc.headCounter {
return true
}
// It's within the counter bounds.
if counter < dc.tailCounter {
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
if dc.Get(index) {
return true
}
dc.Set(index)
return false
}
// It's more than 1 beyond the tail.
delta := counter - dc.tailCounter
// Full clear.
if delta >= bitSetSize {
dc.ClearAll()
dc.Set(0)
dc.tail = 1
dc.head = 2
dc.tailCounter = counter + 1
dc.headCounter = dc.tailCounter - bitSetSize
return false
}
// Clear if necessary.
for i := 0; i < int(delta); i++ {
dc.put(false)
}
dc.put(true)
return false
}
func (dc *dupCheck) put(set bool) {
if set {
dc.Set(dc.tail)
} else {
dc.Clear(dc.tail)
}
dc.tail = (dc.tail + 1) % bitSetSize
dc.tailCounter++
if dc.head == dc.tail {
dc.head = (dc.head + 1) % bitSetSize
dc.headCounter++
}
}

View File

@@ -1,54 +0,0 @@
package node
import (
"testing"
)
func TestDupCheck(t *testing.T) {
dc := newDupCheck(0)
for i := range bitSetSize {
if dc.IsDup(uint64(i)) {
t.Fatal("!")
}
}
type TestCase struct {
Counter uint64
Dup bool
}
testCases := []TestCase{
{0, true},
{1, true},
{2, true},
{3, true},
{63, true},
{256, true},
{510, true},
{511, true},
{512, false},
{0, true},
{512, true},
{513, false},
{517, false},
{512, true},
{513, true},
{514, false},
{515, false},
{516, false},
{517, true},
{2512, false},
{2000, true},
{2001, false},
{4000, false},
{4000 - 512, true}, // Too old.
{4000 - 511, false}, // Just in the window.
}
for i, tc := range testCases {
if ok := dc.IsDup(tc.Counter); ok != tc.Dup {
t.Fatal(i, ok, tc)
}
}
}

View File

@@ -1,82 +0,0 @@
package node
import (
"encoding/json"
"log"
"os"
"path/filepath"
"vppn/m"
)
func configDir(netName string) string {
d, err := os.UserHomeDir()
if err != nil {
log.Fatalf("Failed to get user home directory: %v", err)
}
return filepath.Join(d, ".vppn", netName)
}
func peerConfigPath(netName string) string {
return filepath.Join(configDir(netName), "peer-config.json")
}
func peerStatePath(netName string) string {
return filepath.Join(configDir(netName), "peer-state.json")
}
func storeJson(x any, outPath string) error {
outDir := filepath.Dir(outPath)
_ = os.MkdirAll(outDir, 0700)
tmpPath := outPath + ".tmp"
buf, err := json.Marshal(x)
if err != nil {
return err
}
f, err := os.Create(tmpPath)
if err != nil {
return err
}
if _, err := f.Write(buf); err != nil {
f.Close()
return err
}
if err := f.Sync(); err != nil {
f.Close()
return err
}
if err := f.Close(); err != nil {
return err
}
return os.Rename(tmpPath, outPath)
}
func storePeerConfig(netName string, pc localConfig) error {
return storeJson(pc, peerConfigPath(netName))
}
func storeNetworkState(netName string, ps m.NetworkState) error {
return storeJson(ps, peerStatePath(netName))
}
func loadJson(dataPath string, ptr any) error {
data, err := os.ReadFile(dataPath)
if err != nil {
return err
}
return json.Unmarshal(data, ptr)
}
func loadPeerConfig(netName string) (pc localConfig, err error) {
return pc, loadJson(peerConfigPath(netName), &pc)
}
func loadNetworkState(netName string) (ps m.NetworkState, err error) {
return ps, loadJson(peerStatePath(netName), &ps)
}

View File

@@ -1,73 +0,0 @@
package node
import (
"net/netip"
"sync/atomic"
)
func getRelayRoute() *peerRoute {
if ip := relayIP.Load(); ip != nil {
return routingTable[*ip].Load()
}
return nil
}
func getLocalAddr() netip.AddrPort {
if a := localAddr.Load(); a != nil {
return *a
}
return netip.AddrPort{}
}
func _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute, buf1, buf2 []byte) {
buf := pkt.Marshal(buf2)
h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
SourceIP: localIP,
DestIP: route.IP,
}
buf = route.ControlCipher.Encrypt(h, buf, buf1)
if route.Direct {
_conn.WriteTo(buf, route.RemoteAddr)
return
}
_relayPacket(route.IP, buf, buf2)
}
func _sendDataPacket(route *peerRoute, pkt, buf1, buf2 []byte) {
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&sendCounters[route.IP], 1),
SourceIP: localIP,
DestIP: route.IP,
}
enc := route.DataCipher.Encrypt(h, pkt, buf1)
if route.Direct {
_conn.WriteTo(enc, route.RemoteAddr)
return
}
_relayPacket(route.IP, enc, buf2)
}
func _relayPacket(destIP byte, data, buf []byte) {
relayRoute := getRelayRoute()
if relayRoute == nil || !relayRoute.Up || !relayRoute.Relay {
return
}
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&sendCounters[relayRoute.IP], 1),
SourceIP: localIP,
DestIP: destIP,
}
enc := relayRoute.DataCipher.Encrypt(h, data, buf)
_conn.WriteTo(enc, relayRoute.RemoteAddr)
}

View File

@@ -1,87 +0,0 @@
package node
import (
"net"
"net/netip"
"net/url"
"sync/atomic"
"time"
)
const (
bufferSize = 1536
if_mtu = 1200
if_queue_len = 2048
controlCipherOverhead = 16
dataCipherOverhead = 16
signOverhead = 64
)
var (
multicastIP = netip.AddrFrom4([4]byte{224, 0, 0, 157})
multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(multicastIP, 4560))
)
type peerRoute struct {
IP byte
Up bool // True if data can be sent on the route.
Relay bool // True if the peer is a relay.
Direct bool // True if this is a direct connection.
PubSignKey []byte
ControlCipher *controlCipher
DataCipher *dataCipher
RemoteAddr netip.AddrPort // Remote address if directly connected.
}
var (
hubURL *url.URL
apiKey string
// Configuration for this peer.
netName string
localIP byte
localPub bool
privKey []byte
privSignKey []byte
// Shared interface for writing.
_iface *ifWriter
// Shared connection for writing.
_conn *connWriter
// Counters for sending to each peer.
sendCounters [256]uint64 = func() (out [256]uint64) {
for i := range out {
out[i] = uint64(time.Now().Unix()<<30 + 1)
}
return
}()
// Duplicate checkers for incoming packets.
dupChecks [256]*dupCheck = func() (out [256]*dupCheck) {
for i := range out {
out[i] = newDupCheck(0)
}
return
}()
// Messages for the supervisor.
messages = make(chan any, 512)
// Global routing table.
routingTable [256]*atomic.Pointer[peerRoute] = func() (out [256]*atomic.Pointer[peerRoute]) {
for i := range out {
out[i] = &atomic.Pointer[peerRoute]{}
out[i].Store(&peerRoute{})
}
return
}()
// Managed by the addrDiscovery* functions.
discoveryMessages = make(chan controlMsg[addrDiscoveryPacket], 256)
// Managed by the relayManager.
localAddr = &atomic.Pointer[netip.AddrPort]{}
relayIP = &atomic.Pointer[byte]{}
)

View File

@@ -1,37 +0,0 @@
package node
import "unsafe"
// ----------------------------------------------------------------------------
const (
headerSize = 12
controlStreamID = 2
controlHeaderSize = 24
dataStreamID = 1
dataHeaderSize = 12
)
type header struct {
Version byte
StreamID byte
SourceIP byte
DestIP byte
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
}
func (h *header) Parse(b []byte) {
h.Version = b[0]
h.StreamID = b[1]
h.SourceIP = b[2]
h.DestIP = b[3]
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
}
func (h *header) Marshal(buf []byte) {
buf[0] = h.Version
buf[1] = h.StreamID
buf[2] = h.SourceIP
buf[3] = h.DestIP
*(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter
}

View File

@@ -1,21 +0,0 @@
package node
import "testing"
func TestHeaderMarshalParse(t *testing.T) {
nIn := header{
StreamID: 23,
Counter: 3212,
SourceIP: 34,
DestIP: 200,
}
buf := make([]byte, headerSize)
nIn.Marshal(buf)
nOut := header{}
nOut.Parse(buf)
if nIn != nOut {
t.Fatal(nIn, nOut)
}
}

View File

@@ -1,92 +0,0 @@
package node
import (
"encoding/json"
"io"
"log"
"net/http"
"time"
"vppn/m"
)
type hubPoller struct {
client *http.Client
req *http.Request
versions [256]int64
}
func newHubPoller() *hubPoller {
u := *hubURL
u.Path = "/peer/fetch-state/"
client := &http.Client{Timeout: 8 * time.Second}
req := &http.Request{
Method: http.MethodGet,
URL: &u,
Header: http.Header{},
}
req.SetBasicAuth("", apiKey)
return &hubPoller{
client: client,
req: req,
}
}
func (hp *hubPoller) Run() {
defer panicHandler()
state, err := loadNetworkState(netName)
if err != nil {
log.Printf("Failed to load network state: %v", err)
log.Printf("Polling hub...")
hp.pollHub()
} else {
hp.applyNetworkState(state)
}
for range time.Tick(64 * time.Second) {
hp.pollHub()
}
}
func (hp *hubPoller) pollHub() {
var state m.NetworkState
resp, err := hp.client.Do(hp.req)
if err != nil {
log.Printf("Failed to fetch peer state: %v", err)
return
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
log.Printf("Failed to read body from hub: %v", err)
return
}
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("Failed to unmarshal response from hub: %v\n%s", err, body)
return
}
hp.applyNetworkState(state)
if err := storeNetworkState(netName, state); err != nil {
log.Printf("Failed to store network state: %v", err)
}
}
func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
for i, peer := range state.Peers {
if i != int(localIP) {
if peer == nil || peer.Version != hp.versions[i] {
messages <- peerUpdateMsg{PeerIP: byte(i), Peer: state.Peers[i]}
if peer != nil {
hp.versions[i] = peer.Version
}
}
}
}
}

View File

@@ -1,177 +0,0 @@
package node
import (
"fmt"
"io"
"log"
"net"
"os"
"syscall"
"golang.org/x/sys/unix"
)
// Get next packet, returning packet, ip, and possible error.
func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) {
var (
version byte
ip byte
)
for {
n, err := iface.Read(buf[:cap(buf)])
if err != nil {
return nil, ip, err
}
buf = buf[:n]
version = buf[0] >> 4
switch version {
case 4:
if n < 20 {
log.Printf("Short IPv4 packet: %d", len(buf))
continue
}
ip = buf[19]
case 6:
if len(buf) < 40 {
log.Printf("Short IPv6 packet: %d", len(buf))
continue
}
ip = buf[39]
default:
log.Printf("Invalid IP packet version: %v", version)
continue
}
return buf, ip, nil
}
}
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
if len(network) != 4 {
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))
}
ip := net.IPv4(network[0], network[1], network[2], localIP)
//////////////////////////
// Create TUN Interface //
//////////////////////////
tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600)
if err != nil {
return nil, fmt.Errorf("failed to open TUN device: %w", err)
}
// New interface request.
req, err := unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create new TUN interface request: %w", err)
}
// Flags:
//
// IFF_NO_PI => don't add packet info data to packets sent to the interface.
// IFF_TUN => create a TUN device handling IP packets.
req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN)
err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req)
if err != nil {
return nil, fmt.Errorf("failed to set TUN device settings: %w", err)
}
// Name may not be exactly the same?
name = req.Name()
/////////////
// Set MTU //
/////////////
// We need a socket file descriptor to set other options for some reason.
sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return nil, fmt.Errorf("failed to open socket: %w", err)
}
defer unix.Close(sockFD)
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create MTU interface request: %w", err)
}
req.SetUint32(if_mtu)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil {
return nil, fmt.Errorf("failed to set interface MTU: %w", err)
}
//////////////////////
// Set Queue Length //
//////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
}
req.SetUint16(if_queue_len)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil {
return nil, fmt.Errorf("failed to set interface queue length: %w", err)
}
/////////////////////
// Set IP and Mask //
/////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
}
if err := req.SetInet4Addr(ip.To4()); err != nil {
return nil, fmt.Errorf("failed to set interface request IP: %w", err)
}
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil {
return nil, fmt.Errorf("failed to set interface IP: %w", err)
}
// SET MASK - must happen after setting address.
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create mask interface request: %w", err)
}
if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil {
return nil, fmt.Errorf("failed to set interface request mask: %w", err)
}
if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil {
return nil, fmt.Errorf("failed to set interface mask: %w", err)
}
////////////////////////
// Bring Interface Up //
////////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create up interface request: %w", err)
}
// Get current flags.
if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil {
return nil, fmt.Errorf("failed to get interface flags: %w", err)
}
flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING
// Set UP flag / broadcast flags.
req.SetUint16(flags)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil {
return nil, fmt.Errorf("failed to set interface up: %w", err)
}
return os.NewFile(uintptr(tunFD), "tun"), nil
}

View File

@@ -1,97 +0,0 @@
package node
import (
"log"
"net"
"time"
"golang.org/x/crypto/nacl/sign"
)
func localDiscovery() {
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
if err != nil {
log.Printf("Failed to bind to multicast address: %v", err)
return
}
go sendLocalDiscovery(conn)
go recvLocalDiscovery(conn)
}
func sendLocalDiscovery(conn *net.UDPConn) {
var (
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
)
for range time.Tick(32 * time.Second) {
signed := buildLocalDiscoveryPacket(buf1, buf2)
if _, err := conn.WriteToUDP(signed, multicastAddr); err != nil {
log.Printf("Failed to write multicast UDP packet: %v", err)
}
}
}
func recvLocalDiscovery(conn *net.UDPConn) {
var (
raw = make([]byte, bufferSize)
buf = make([]byte, bufferSize)
)
for {
n, remoteAddr, err := conn.ReadFromUDPAddrPort(raw[:bufferSize])
if err != nil {
log.Fatalf("Failed to read from UDP port (multicast): %v", err)
}
raw = raw[:n]
h, ok := openLocalDiscoveryPacket(raw, buf)
if !ok {
log.Printf("Failed to open discovery packet?")
continue
}
msg := controlMsg[localDiscoveryPacket]{
SrcIP: h.SourceIP,
SrcAddr: remoteAddr,
Packet: localDiscoveryPacket{},
}
select {
case messages <- msg:
default:
log.Printf("Dropping local discovery message.")
}
}
}
func buildLocalDiscoveryPacket(buf1, buf2 []byte) []byte {
h := header{
StreamID: controlStreamID,
Counter: 0,
SourceIP: localIP,
DestIP: 255,
}
out := buf1[:headerSize]
h.Marshal(out)
return sign.Sign(buf2[:0], out, (*[64]byte)(privSignKey))
}
func openLocalDiscoveryPacket(raw, buf []byte) (h header, ok bool) {
if len(raw) != headerSize+signOverhead {
ok = false
return
}
h.Parse(raw[signOverhead:])
route := routingTable[h.SourceIP].Load()
if route == nil || route.PubSignKey == nil {
log.Printf("Missing signing key: %d", h.SourceIP)
ok = false
return
}
_, ok = sign.Open(buf[:0], raw, (*[32]byte)(route.PubSignKey))
return
}

View File

@@ -1,35 +0,0 @@
package node
import (
"bytes"
"crypto/rand"
"testing"
"golang.org/x/crypto/nacl/sign"
)
func TestLocalDiscoveryPacketSigning(t *testing.T) {
localIP = 32
var (
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
pubSignKey, privSigKey, _ = sign.GenerateKey(rand.Reader)
)
privSignKey = privSigKey[:]
route := routingTable[localIP].Load()
route.IP = byte(localIP)
route.PubSignKey = pubSignKey[0:32]
routingTable[localIP].Store(route)
out := buildLocalDiscoveryPacket(buf1, buf2)
h, ok := openLocalDiscoveryPacket(bytes.Clone(out), buf1)
if !ok {
t.Fatal(h, ok)
}
if h.StreamID != controlStreamID || h.SourceIP != localIP || h.DestIP != 255 {
t.Fatal(h)
}
}

View File

@@ -1,331 +0,0 @@
package node
import (
"bytes"
"crypto/rand"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"runtime/debug"
"time"
"vppn/m"
"golang.org/x/crypto/nacl/box"
"golang.org/x/crypto/nacl/sign"
)
func panicHandler() {
if r := recover(); r != nil {
log.Fatalf("\n %v\n\nstacktrace from panic: %s\n", r, string(debug.Stack()))
}
}
func Main() {
defer panicHandler()
var hubAddress string
flag.StringVar(&netName, "name", "", "[REQUIRED] The network name.")
flag.StringVar(&hubAddress, "hub-address", "", "[REQUIRED] The hub address.")
flag.StringVar(&apiKey, "api-key", "", "[REQUIRED] The node's API key.")
flag.Parse()
if netName == "" || hubAddress == "" || apiKey == "" {
flag.Usage()
os.Exit(1)
}
var err error
hubURL, err = url.Parse(hubAddress)
if err != nil {
log.Fatalf("Failed to parse hub address: %v", err)
}
main()
}
func initPeerWithHub() {
encPubKey, encPrivKey, err := box.GenerateKey(rand.Reader)
if err != nil {
log.Fatalf("Failed to generate encryption keys: %v", err)
}
signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader)
if err != nil {
log.Fatalf("Failed to generate signing keys: %v", err)
}
initURL := *hubURL
initURL.Path = "/peer/init/"
args := m.PeerInitArgs{
EncPubKey: encPubKey[:],
PubSignKey: signPubKey[:],
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(args); err != nil {
log.Fatalf("Failed to encode init args: %v", err)
}
req, err := http.NewRequest(http.MethodPost, initURL.String(), buf)
if err != nil {
log.Fatalf("Failed to construct request: %v", err)
}
req.SetBasicAuth("", apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Fatalf("Failed to init with hub: %v", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed to read response body: %v", err)
}
peerConfig := localConfig{}
if err := json.Unmarshal(data, &peerConfig.PeerConfig); err != nil {
log.Fatalf("Failed to parse configuration: %v\n%s", err, data)
}
peerConfig.PubKey = encPubKey[:]
peerConfig.PrivKey = encPrivKey[:]
peerConfig.PubSignKey = signPubKey[:]
peerConfig.PrivSignKey = signPrivKey[:]
if err := storePeerConfig(netName, peerConfig); err != nil {
log.Fatalf("Failed to store configuration: %v", err)
}
log.Print("Initialization successful.")
}
// ----------------------------------------------------------------------------
func main() {
config, err := loadPeerConfig(netName)
if err != nil {
log.Printf("Failed to load configuration: %v", err)
log.Printf("Initializing...")
initPeerWithHub()
config, err = loadPeerConfig(netName)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
}
iface, err := openInterface(config.Network, config.PeerIP, netName)
if err != nil {
log.Fatalf("Failed to open interface: %v", err)
}
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", config.Port))
if err != nil {
log.Fatalf("Failed to resolve UDP address: %v", err)
}
conn, err := net.ListenUDP("udp", myAddr)
if err != nil {
log.Fatalf("Failed to open UDP port: %v", err)
}
conn.SetReadBuffer(1024 * 1024 * 8)
conn.SetWriteBuffer(1024 * 1024 * 8)
// Intialize globals.
_iface = newIFWriter(iface)
_conn = newConnWriter(conn)
localIP = config.PeerIP
ip, ok := netip.AddrFromSlice(config.PublicIP)
if ok {
localPub = true
addr := netip.AddrPortFrom(ip, config.Port)
localAddr.Store(&addr)
}
privKey = config.PrivKey
privSignKey = config.PrivSignKey
if localPub {
go addrDiscoveryServer()
} else {
go addrDiscoveryClient()
go relayManager()
go localDiscovery()
}
go func() {
for range time.Tick(pingInterval) {
messages <- pingTimerMsg{}
}
}()
go startPeerSuper()
go newHubPoller().Run()
go readFromConn(conn)
readFromIFace(iface)
}
// ----------------------------------------------------------------------------
func readFromConn(conn *net.UDPConn) {
defer panicHandler()
var (
remoteAddr netip.AddrPort
n int
err error
buf = make([]byte, bufferSize)
decBuf = make([]byte, bufferSize)
data []byte
h header
)
for {
n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize])
if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err)
}
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
data = buf[:n]
if n < headerSize {
continue // Packet it soo short.
}
h.Parse(data)
switch h.StreamID {
case controlStreamID:
handleControlPacket(remoteAddr, h, data, decBuf)
case dataStreamID:
handleDataPacket(h, data, decBuf)
default:
log.Printf("Unknown stream ID: %d", h.StreamID)
}
}
}
func handleControlPacket(addr netip.AddrPort, h header, data, decBuf []byte) {
route := routingTable[h.SourceIP].Load()
if route.ControlCipher == nil {
//log.Printf("Not connected (control).")
return
}
if h.DestIP != localIP {
log.Printf("Incorrect destination IP on control packet: %d != %d", h.DestIP, localIP)
return
}
out, ok := route.ControlCipher.Decrypt(data, decBuf)
if !ok {
log.Printf("Failed to decrypt control packet.")
return
}
if len(out) == 0 {
log.Printf("Empty control packet from: %d", h.SourceIP)
return
}
if dupChecks[h.SourceIP].IsDup(h.Counter) {
log.Printf("[%03d] Duplicate control packet: %d", h.SourceIP, h.Counter)
return
}
msg, err := parseControlMsg(h.SourceIP, addr, out)
if err != nil {
log.Printf("Failed to parse control packet: %v", err)
return
}
if dm, ok := msg.(controlMsg[addrDiscoveryPacket]); ok {
discoveryMessages <- dm
return
}
select {
case messages <- msg:
default:
log.Printf("Dropping control packet.")
}
}
func handleDataPacket(h header, data []byte, decBuf []byte) {
route := routingTable[h.SourceIP].Load()
if !route.Up {
log.Printf("Not connected (recv).")
return
}
dec, ok := route.DataCipher.Decrypt(data, decBuf)
if !ok {
log.Printf("Failed to decrypt data packet.")
return
}
if dupChecks[h.SourceIP].IsDup(h.Counter) {
log.Printf("[%03d] Duplicate data packet: %d", h.SourceIP, h.Counter)
return
}
if h.DestIP == localIP {
_iface.Write(dec)
return
}
destRoute := routingTable[h.DestIP].Load()
if !destRoute.Up {
log.Printf("Not connected (relay): %d", destRoute.IP)
return
}
_conn.WriteTo(dec, destRoute.RemoteAddr)
}
// ----------------------------------------------------------------------------
func readFromIFace(iface io.ReadWriteCloser) {
var (
packet = make([]byte, bufferSize)
buf1 = make([]byte, bufferSize)
buf2 = make([]byte, bufferSize)
remoteIP byte
err error
)
for {
packet, remoteIP, err = readNextPacket(iface, packet)
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
route := routingTable[remoteIP].Load()
if !route.Up {
log.Printf("Route not connected: %d", remoteIP)
continue
}
_sendDataPacket(route, packet, buf1, buf2)
}
}

View File

@@ -1,67 +0,0 @@
package node
import (
"net/netip"
"vppn/m"
)
// ----------------------------------------------------------------------------
type controlMsg[T any] struct {
SrcIP byte
SrcAddr netip.AddrPort
Packet T
}
func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) {
switch buf[0] {
case packetTypeSyn:
packet, err := parseSynPacket(buf)
return controlMsg[synPacket]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
case packetTypeSynAck:
packet, err := parseAckPacket(buf)
return controlMsg[ackPacket]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
case packetTypeProbe:
packet, err := parseProbePacket(buf)
return controlMsg[probePacket]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
case packetTypeAddrDiscovery:
packet, err := parseAddrDiscoveryPacket(buf)
return controlMsg[addrDiscoveryPacket]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
default:
return nil, errUnknownPacketType
}
}
// ----------------------------------------------------------------------------
type peerUpdateMsg struct {
PeerIP byte
Peer *m.Peer
}
// ----------------------------------------------------------------------------
type pingTimerMsg struct{}
// ----------------------------------------------------------------------------

View File

@@ -1,163 +0,0 @@
package node
import (
"net/netip"
"sync/atomic"
"time"
"unsafe"
)
var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1
func newTraceID() uint64 {
return atomic.AddUint64(&traceIDCounter, 1)
}
// ----------------------------------------------------------------------------
type binWriter struct {
b []byte
i int
}
func newBinWriter(buf []byte) *binWriter {
buf = buf[:cap(buf)]
return &binWriter{buf, 0}
}
func (w *binWriter) Bool(b bool) *binWriter {
if b {
return w.Byte(1)
}
return w.Byte(0)
}
func (w *binWriter) Byte(b byte) *binWriter {
w.b[w.i] = b
w.i++
return w
}
func (w *binWriter) SharedKey(key [32]byte) *binWriter {
copy(w.b[w.i:w.i+32], key[:])
w.i += 32
return w
}
func (w *binWriter) Uint16(x uint16) *binWriter {
*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 2
return w
}
func (w *binWriter) Uint64(x uint64) *binWriter {
*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 8
return w
}
func (w *binWriter) Int64(x int64) *binWriter {
*(*int64)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 8
return w
}
func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
addr := addrPort.Addr().As16()
copy(w.b[w.i:w.i+16], addr[:])
w.i += 16
return w.Uint16(addrPort.Port())
}
func (w *binWriter) Build() []byte {
return w.b[:w.i]
}
// ----------------------------------------------------------------------------
type binReader struct {
b []byte
i int
err error
}
func newBinReader(buf []byte) *binReader {
return &binReader{b: buf}
}
func (r *binReader) hasBytes(n int) bool {
if r.err != nil || (len(r.b)-r.i) < n {
r.err = errMalformedPacket
return false
}
return true
}
func (r *binReader) Bool(b *bool) *binReader {
var bb byte
r.Byte(&bb)
*b = bb != 0
return r
}
func (r *binReader) Byte(b *byte) *binReader {
if !r.hasBytes(1) {
return r
}
*b = r.b[r.i]
r.i++
return r
}
func (r *binReader) SharedKey(x *[32]byte) *binReader {
if !r.hasBytes(32) {
return r
}
*x = ([32]byte)(r.b[r.i : r.i+32])
r.i += 32
return r
}
func (r *binReader) Uint16(x *uint16) *binReader {
if !r.hasBytes(2) {
return r
}
*x = *(*uint16)(unsafe.Pointer(&r.b[r.i]))
r.i += 2
return r
}
func (r *binReader) Uint64(x *uint64) *binReader {
if !r.hasBytes(8) {
return r
}
*x = *(*uint64)(unsafe.Pointer(&r.b[r.i]))
r.i += 8
return r
}
func (r *binReader) Int64(x *int64) *binReader {
if !r.hasBytes(8) {
return r
}
*x = *(*int64)(unsafe.Pointer(&r.b[r.i]))
r.i += 8
return r
}
func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
if !r.hasBytes(18) {
return r
}
addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
r.i += 16
var port uint16
r.Uint16(&port)
*x = netip.AddrPortFrom(addr, port)
return r
}
func (r *binReader) Error() error {
return r.err
}

View File

@@ -1,40 +0,0 @@
package node
import (
"net/netip"
"reflect"
"testing"
)
func TestBinWriteRead(t *testing.T) {
buf := make([]byte, 1024)
type Item struct {
Type byte
TraceID uint64
DestAddr netip.AddrPort
}
in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)}
buf = newBinWriter(buf).
Byte(in.Type).
Uint64(in.TraceID).
AddrPort(in.DestAddr).
Build()
out := Item{}
err := newBinReader(buf).
Byte(&out.Type).
Uint64(&out.TraceID).
AddrPort(&out.DestAddr).
Error()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal(in, out)
}
}

View File

@@ -1,120 +0,0 @@
package node
import (
"errors"
"net/netip"
)
var (
errMalformedPacket = errors.New("malformed packet")
errUnknownPacketType = errors.New("unknown packet type")
)
const (
packetTypeSyn = iota + 1
packetTypeSynAck
packetTypeAck
packetTypeProbe
packetTypeAddrDiscovery
)
// ----------------------------------------------------------------------------
type synPacket struct {
TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key.
Direct bool
FromAddr netip.AddrPort // The client's sending address.
}
func (p synPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSyn).
Uint64(p.TraceID).
SharedKey(p.SharedKey).
Bool(p.Direct).
AddrPort(p.FromAddr).
Build()
}
func parseSynPacket(buf []byte) (p synPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
Bool(&p.Direct).
AddrPort(&p.FromAddr).
Error()
return
}
// ----------------------------------------------------------------------------
type ackPacket struct {
TraceID uint64
FromAddr netip.AddrPort
}
func (p ackPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSynAck).
Uint64(p.TraceID).
AddrPort(p.FromAddr).
Build()
}
func parseAckPacket(buf []byte) (p ackPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.FromAddr).
Error()
return
}
// ----------------------------------------------------------------------------
type addrDiscoveryPacket struct {
TraceID uint64
ToAddr netip.AddrPort
}
func (p addrDiscoveryPacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeAddrDiscovery).
Uint64(p.TraceID).
AddrPort(p.ToAddr).
Build()
}
func parseAddrDiscoveryPacket(buf []byte) (p addrDiscoveryPacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.ToAddr).
Error()
return
}
// ----------------------------------------------------------------------------
// A probeReqPacket is sent from a client to a server to determine if direct
// UDP communication can be used.
type probePacket struct {
TraceID uint64
}
func (p probePacket) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeProbe).
Uint64(p.TraceID).
Build()
}
func parseProbePacket(buf []byte) (p probePacket, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
Error()
return
}
// ----------------------------------------------------------------------------
type localDiscoveryPacket struct{}

View File

@@ -1,41 +0,0 @@
package node
import (
"crypto/rand"
"net/netip"
"reflect"
"testing"
)
func TestPacketSyn(t *testing.T) {
in := synPacket{
TraceID: newTraceID(),
FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
}
rand.Read(in.SharedKey[:])
out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize)))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal("\n", in, "\n", out)
}
}
func TestPacketSynAck(t *testing.T) {
in := ackPacket{
TraceID: newTraceID(),
FromAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{4, 5, 6, 7}), 22),
}
out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize)))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal("\n", in, "\n", out)
}
}

View File

@@ -1,40 +0,0 @@
package node
import (
"log"
"math/rand"
"time"
)
func relayManager() {
time.Sleep(2 * time.Second)
updateRelayRoute()
for range time.Tick(8 * time.Second) {
relay := getRelayRoute()
if relay == nil || !relay.Up || !relay.Relay {
updateRelayRoute()
}
}
}
func updateRelayRoute() {
possible := make([]*peerRoute, 0, 8)
for i := range routingTable {
route := routingTable[i].Load()
if !route.Up || !route.Relay {
continue
}
possible = append(possible, route)
}
if len(possible) == 0 {
log.Printf("No relay available.")
relayIP.Store(nil)
return
}
ip := possible[rand.Intn(len(possible))].IP
log.Printf("New relay IP: %d", ip)
relayIP.Store(&ip)
}

View File

@@ -1,395 +0,0 @@
package node
import (
"fmt"
"log"
"net/netip"
"strings"
"sync/atomic"
"time"
"vppn/m"
"git.crumpington.com/lib/go/ratelimiter"
)
const (
pingInterval = 8 * time.Second
timeoutInterval = 25 * time.Second
)
// ----------------------------------------------------------------------------
func startPeerSuper() {
peers := [256]peerState{}
for i := range peers {
data := &peerStateData{
published: routingTable[i],
remoteIP: byte(i),
buf1: make([]byte, bufferSize),
buf2: make([]byte, bufferSize),
limiter: ratelimiter.New(ratelimiter.Config{
FillPeriod: 50 * time.Millisecond,
MaxWaitCount: 1,
}),
}
peers[i] = data.OnPeerUpdate(nil)
}
go runPeerSuper(peers)
}
func runPeerSuper(peers [256]peerState) {
for raw := range messages {
switch msg := raw.(type) {
case peerUpdateMsg:
peers[msg.PeerIP] = peers[msg.PeerIP].OnPeerUpdate(msg.Peer)
case controlMsg[synPacket]:
peers[msg.SrcIP].OnSyn(msg)
case controlMsg[ackPacket]:
peers[msg.SrcIP].OnAck(msg)
case controlMsg[probePacket]:
peers[msg.SrcIP].OnProbe(msg)
case controlMsg[localDiscoveryPacket]:
peers[msg.SrcIP].OnLocalDiscovery(msg)
case pingTimerMsg:
for i := range peers {
if newState := peers[i].OnPingTimer(); newState != nil {
peers[i] = newState
}
}
default:
log.Printf("WARNING: unknown message type: %+v", msg)
}
}
}
// ----------------------------------------------------------------------------
type peerState interface {
OnPeerUpdate(*m.Peer) peerState
OnSyn(controlMsg[synPacket])
OnAck(controlMsg[ackPacket])
OnProbe(controlMsg[probePacket])
OnLocalDiscovery(controlMsg[localDiscoveryPacket])
OnPingTimer() peerState
}
// ----------------------------------------------------------------------------
type peerStateData struct {
// The purpose of this state machine is to manage this published data.
published *atomic.Pointer[peerRoute]
staged peerRoute // Local copy of shared data. See publish().
// Immutable data.
remoteIP byte // Remote VPN IP.
// Mutable peer data.
peer *m.Peer
remotePub bool
// Buffers for sending control packets.
buf1 []byte
buf2 []byte
// For logging. Set per-state.
client bool
limiter *ratelimiter.Limiter
}
// ----------------------------------------------------------------------------
func (s *peerStateData) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
s._sendControlPacket(pkt, s.staged)
}
func (s *peerStateData) sendControlPacketTo(pkt interface{ Marshal([]byte) []byte }, addr netip.AddrPort) {
if !addr.IsValid() {
s.logf("ERROR: Attepted to send packet to invalid address: %v", addr)
return
}
route := s.staged
route.Direct = true
route.RemoteAddr = addr
s._sendControlPacket(pkt, route)
}
func (s *peerStateData) _sendControlPacket(pkt interface{ Marshal([]byte) []byte }, route peerRoute) {
if err := s.limiter.Limit(); err != nil {
s.logf("Not sending control packet: rate limited.") // Shouldn't happen.
return
}
_sendControlPacket(pkt, route, s.buf1, s.buf2)
}
// ----------------------------------------------------------------------------
func (s *peerStateData) publish() {
data := s.staged
s.published.Store(&data)
}
func (s *peerStateData) logf(format string, args ...any) {
b := strings.Builder{}
b.WriteString(fmt.Sprintf("%30s: ", s.peer.Name))
if s.client {
b.WriteString("CLIENT|")
} else {
b.WriteString("SERVER|")
}
if s.staged.Direct {
b.WriteString("DIRECT |")
} else {
b.WriteString("RELAYED|")
}
if s.staged.Up {
b.WriteString("UP |")
} else {
b.WriteString("DOWN|")
}
log.Printf(b.String()+format, args...)
}
// ----------------------------------------------------------------------------
func (s *peerStateData) OnPeerUpdate(peer *m.Peer) peerState {
defer s.publish()
if peer == nil {
return enterStateDisconnected(s)
}
s.peer = peer
s.staged.IP = s.remoteIP
s.staged.PubSignKey = peer.PubSignKey
s.staged.ControlCipher = newControlCipher(privKey, peer.PubKey)
s.staged.DataCipher = newDataCipher()
if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid {
s.remotePub = true
s.staged.Relay = peer.Relay
s.staged.Direct = true
s.staged.RemoteAddr = netip.AddrPortFrom(ip, peer.Port)
} else if localPub {
s.staged.Direct = true
}
if s.remotePub == localPub {
if localIP < s.remoteIP {
return enterStateServer(s)
}
return enterStateClient(s)
}
if s.remotePub {
return enterStateClient(s)
}
return enterStateServer(s)
}
// ----------------------------------------------------------------------------
type stateDisconnected struct {
*peerStateData
}
func enterStateDisconnected(s *peerStateData) peerState {
s.peer = nil
s.staged = peerRoute{}
s.publish()
return &stateDisconnected{s}
}
func (s *stateDisconnected) OnSyn(controlMsg[synPacket]) {}
func (s *stateDisconnected) OnAck(controlMsg[ackPacket]) {}
func (s *stateDisconnected) OnProbe(controlMsg[probePacket]) {}
func (s *stateDisconnected) OnLocalDiscovery(controlMsg[localDiscoveryPacket]) {}
func (s *stateDisconnected) OnPingTimer() peerState {
return nil
}
// ----------------------------------------------------------------------------
type stateServer struct {
*stateDisconnected
lastSeen time.Time
synTraceID uint64
}
func enterStateServer(s *peerStateData) peerState {
s.client = false
return &stateServer{stateDisconnected: &stateDisconnected{s}}
}
func (s *stateServer) OnSyn(msg controlMsg[synPacket]) {
s.lastSeen = time.Now()
p := msg.Packet
// Before we can respond to this packet, we need to make sure the
// route is setup properly.
//
// The client will update the syn's TraceID whenever there's a change.
// The server will follow the client's request.
if p.TraceID != s.synTraceID || !s.staged.Up {
s.synTraceID = p.TraceID
s.staged.Up = true
s.staged.Direct = p.Direct
s.staged.DataCipher = newDataCipherFromKey(p.SharedKey)
s.staged.RemoteAddr = msg.SrcAddr
s.publish()
s.logf("Got syn.")
}
// Always respond.
ack := ackPacket{
TraceID: p.TraceID,
FromAddr: getLocalAddr(),
}
s.sendControlPacket(ack)
if !s.staged.Direct && p.FromAddr.IsValid() {
s.sendControlPacketTo(probePacket{TraceID: newTraceID()}, p.FromAddr)
}
}
func (s *stateServer) OnProbe(msg controlMsg[probePacket]) {
if !msg.SrcAddr.IsValid() {
s.logf("Invalid probe address.")
return
}
s.sendControlPacketTo(probePacket{TraceID: msg.Packet.TraceID}, msg.SrcAddr)
}
func (s *stateServer) OnPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval && s.staged.Up {
s.staged.Up = false
s.publish()
s.logf("Connection timeout.")
}
return nil
}
// ----------------------------------------------------------------------------
type stateClient struct {
*stateDisconnected
lastSeen time.Time
syn synPacket
ack ackPacket
probeTraceID uint64
probeAddr netip.AddrPort
localProbeTraceID uint64
localProbeAddr netip.AddrPort
}
func enterStateClient(s *peerStateData) peerState {
s.client = true
ss := &stateClient{stateDisconnected: &stateDisconnected{s}}
ss.syn = synPacket{
TraceID: newTraceID(),
SharedKey: s.staged.DataCipher.Key(),
Direct: s.staged.Direct,
FromAddr: getLocalAddr(),
}
ss.sendSyn()
return ss
}
func (s *stateClient) OnAck(msg controlMsg[ackPacket]) {
if msg.Packet.TraceID != s.syn.TraceID {
s.logf("Ack has incorrect trace ID")
return
}
s.ack = msg.Packet
s.lastSeen = time.Now()
if !s.staged.Up {
s.staged.Up = true
s.logf("Got ack.")
s.publish()
} else {
}
}
func (s *stateClient) OnProbe(msg controlMsg[probePacket]) {
if s.staged.Direct {
return
}
switch msg.Packet.TraceID {
case s.probeTraceID:
s.staged.RemoteAddr = s.probeAddr
case s.localProbeTraceID:
s.staged.RemoteAddr = s.localProbeAddr
default:
return
}
s.staged.Direct = true
s.publish()
s.syn.TraceID = newTraceID()
s.syn.Direct = true
s.syn.FromAddr = getLocalAddr()
s.sendControlPacket(s.syn)
s.logf("Established direct connection to %s.", s.staged.RemoteAddr.String())
}
func (s *stateClient) OnLocalDiscovery(msg controlMsg[localDiscoveryPacket]) {
if s.staged.Direct {
return
}
// Send probe.
//
// The source port will be the multicast port, so we'll have to
// construct the correct address using the peer's listed port.
s.localProbeTraceID = newTraceID()
s.localProbeAddr = netip.AddrPortFrom(msg.SrcAddr.Addr(), s.peer.Port)
s.sendControlPacketTo(probePacket{TraceID: s.localProbeTraceID}, s.localProbeAddr)
}
func (s *stateClient) OnPingTimer() peerState {
if time.Since(s.lastSeen) > timeoutInterval {
if s.staged.Up {
s.logf("Connection timeout.")
}
return s.OnPeerUpdate(s.peer)
}
s.sendSyn()
if !s.staged.Direct && s.ack.FromAddr.IsValid() {
s.probeTraceID = newTraceID()
s.probeAddr = s.ack.FromAddr
s.sendControlPacketTo(probePacket{TraceID: s.probeTraceID}, s.probeAddr)
}
return nil
}
func (s *stateClient) sendSyn() {
localAddr := getLocalAddr()
if localAddr != s.syn.FromAddr {
s.syn.TraceID = newTraceID()
s.syn.FromAddr = localAddr
}
s.sendControlPacket(s.syn)
}

105
peer/app.go Normal file
View File

@@ -0,0 +1,105 @@
package peer
import (
"net/netip"
"os"
"os/signal"
"syscall"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/control"
"vppn/peer/multicast"
"vppn/peer/wginterface"
)
var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisfies WGDevice
const (
ControlPort = 4561
PingInterval = 8 * time.Second
TimeoutInterval = 30 * time.Second
)
// scratchSize is large enough for the biggest buffer either the ping or the
// multicast path serializes through the shared App scratch.
const scratchSize = max(control.Size, multicast.SignedPacketSize)
type PingEvent struct {
srcVPNIP netip.Addr
ping control.Ping
}
// App is the peer application. All mutable state lives here and is
// accessed only from the Run goroutine.
type App struct {
// Identity
vpnIP netip.Addr
vpnNet netip.Prefix
privKey wgtypes.Key
pubKey wgtypes.Key
isRelay bool
isPublic bool
localDomain string
// Infrastructure
dev WGDevice
controlConn ControlConn
// Peer state
relay *Peer
peersByKey map[wgtypes.Key]*Peer
peersByIP map[netip.Addr]*Peer
// Our own external endpoints, learned from Dst fields in incoming pings
selfV4 netip.AddrPort
selfV6 netip.AddrPort
// Reusable serialization scratch for outgoing pings and multicast signature
// verification. Only touched from the Run goroutine.
scratch []byte
// Event channels fed by background goroutines
hubAddCh <-chan m.Peer
hubRemoveCh <-chan wgtypes.Key
pingCh <-chan PingEvent
multicastCh <-chan multicast.Packet
}
// Run is the main event loop. It runs until SIGTERM/SIGINT.
func (a *App) Run() error {
// Establish a clean hosts section before the first poll lands, clearing
// any stale entries left by a prior run (e.g. crash, or peers removed
// while we were down).
a.updateHosts()
ticker := time.NewTicker(PingInterval)
defer ticker.Stop()
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(sig)
for {
select {
case p := <-a.hubAddCh:
a.onAddPeer(p)
case key := <-a.hubRemoveCh:
a.onRemovePeer(key)
case e := <-a.pingCh:
a.onPing(e)
case e := <-a.multicastCh:
a.onMulticastDiscovery(e)
case <-ticker.C:
a.onTick()
case <-sig:
return a.onShutdown()
}
}
}
func (a *App) onShutdown() error {
return wginterface.Delete(a.dev.Name())
}

62
peer/app_test.go Normal file
View File

@@ -0,0 +1,62 @@
package peer
import (
"net/netip"
"testing"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/multicast"
)
// addRelayPeer adds a public relay peer and marks it Up so it satisfies
// CanRelay. It does not set a.relay — callers do that explicitly.
func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer {
t.Helper()
key := mustKey(t)
ip := netip.MustParseAddr(vpnIP)
a.onAddPeer(m.Peer{
WGPubKey: key,
PeerIP: ip.As4()[3],
Addr4: ep.Addr(),
Port: ep.Port(),
Relay: true,
})
p := a.peersByKey[key]
p.wgPeer.LastHandshakeTime = time.Now()
return p
}
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
// vpnIP is the local VPN address (e.g. "10.0.0.1").
// isPublic / isRelay describe the local node's role.
func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) {
t.Helper()
privKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
ip := netip.MustParseAddr(vpnIP)
dev := &fakeWGDevice{}
cc := &fakeControlConn{}
a := &App{
vpnIP: ip,
vpnNet: netip.MustParsePrefix("10.0.0.0/24"),
privKey: privKey,
pubKey: privKey.PublicKey(),
isPublic: isPublic,
isRelay: isRelay,
dev: dev,
controlConn: cc,
peersByKey: make(map[wgtypes.Key]*Peer),
peersByIP: make(map[netip.Addr]*Peer),
scratch: make([]byte, scratchSize),
hubAddCh: make(chan m.Peer),
hubRemoveCh: make(chan wgtypes.Key),
pingCh: make(chan PingEvent),
multicastCh: make(chan multicast.Packet),
}
return a, dev, cc
}

76
peer/control/ping.go Normal file
View File

@@ -0,0 +1,76 @@
// Package control implements the VPN-internal peer control protocol.
// Peers exchange Ping packets over UDP on the VPN control port to maintain
// liveness and discover external endpoints for direct connection attempts.
package control
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
version = 1
Size = 51 // 1 version + 8 PingTS + 6 SrcV4 + 18 SrcV6 + 18 Dst
)
// Ping is the single control packet type exchanged between VPN peers.
//
// In each peer pair, the peer with the lower VPN IP is the client: it sets
// PingTS and sends pings on a timer. The server echoes PingTS back in its
// response, allowing the client to compute RTT = now - PingTS.
//
// Both client and server populate SrcV4, SrcV6, and Dst on every packet so
// endpoint information flows in both directions.
//
// Dst is the recipient's external endpoint as observed by the sender from the
// WireGuard handshake source. Zero if the sender has not observed a handshake
// from the recipient.
type Ping struct {
PingTS int64 // Client ping send time in nanoseconds.
SrcV4 netip.AddrPort // Sender's discovered IPv4 address and port.
SrcV6 netip.AddrPort // Sender's discovered IPv6 address and port.
Dst netip.AddrPort
}
// Marshal encodes p into buf (which must be at least Size bytes) and returns
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
// field is written unconditionally so a reused buffer needs no pre-zeroing.
func (p Ping) Marshal(buf []byte) []byte {
buf[0] = version
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
if p.SrcV4.IsValid() {
a4 := p.SrcV4.Addr().As4()
copy(buf[9:13], a4[:])
binary.BigEndian.PutUint16(buf[13:15], p.SrcV4.Port())
} else {
clear(buf[9:15])
}
a16 := p.SrcV6.Addr().As16()
copy(buf[15:31], a16[:])
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
a16 = p.Dst.Addr().As16()
copy(buf[33:49], a16[:])
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
return buf[:Size]
}
// Unmarshal decodes a Ping from a fixed-size 51-byte array.
func Unmarshal(buf [Size]byte) (Ping, error) {
if buf[0] != version {
return Ping{}, fmt.Errorf("unknown ping version %d", buf[0])
}
p := Ping{
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
}
if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() {
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
}
if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() {
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
}
if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() {
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
}
return p, nil
}

106
peer/control/ping_test.go Normal file
View File

@@ -0,0 +1,106 @@
package control_test
import (
"net/netip"
"testing"
"vppn/peer/control"
)
func TestRoundTrip(t *testing.T) {
cases := []struct {
name string
ping control.Ping
}{
{
name: "zero",
ping: control.Ping{},
},
{
name: "client ping",
ping: control.Ping{
PingTS: 1234567890,
SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"),
Dst: netip.MustParseAddrPort("5.6.7.8:51820"),
},
},
{
name: "server response",
ping: control.Ping{
PingTS: 1234567890,
SrcV4: netip.MustParseAddrPort("5.6.7.8:51820"),
Dst: netip.MustParseAddrPort("1.2.3.4:9999"),
},
},
{
name: "IPv6 only",
ping: control.Ping{
PingTS: 999,
SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"),
Dst: netip.MustParseAddrPort("[2001:db8::2]:51820"),
},
},
{
name: "dual stack",
ping: control.Ping{
PingTS: 555,
SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"),
SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"),
Dst: netip.MustParseAddrPort("5.6.7.8:9999"),
},
},
{
name: "no src known",
ping: control.Ping{
Dst: netip.MustParseAddrPort("5.6.7.8:51820"),
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var buf [control.Size]byte
tc.ping.Marshal(buf[:])
got, err := control.Unmarshal(buf)
if err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if got != tc.ping {
t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, tc.ping)
}
})
}
}
func TestUnmarshalBadVersion(t *testing.T) {
var buf [control.Size]byte
buf[0] = 99
if _, err := control.Unmarshal(buf); err == nil {
t.Fatal("expected error for unknown version, got nil")
}
}
func TestZeroEncoding(t *testing.T) {
var buf [control.Size]byte
(control.Ping{}).Marshal(buf[:])
for i, b := range buf {
if i == 0 {
continue // version byte
}
if b != 0 {
t.Fatalf("expected zero encoding at byte %d, got %d", i, b)
}
}
}
func TestRoleFor(t *testing.T) {
lo := netip.MustParseAddr("10.0.0.1")
hi := netip.MustParseAddr("10.0.0.2")
if control.RoleFor(lo, hi) != control.Client {
t.Error("lower IP should be client")
}
if control.RoleFor(hi, lo) != control.Server {
t.Error("higher IP should be server")
}
}

22
peer/control/role.go Normal file
View File

@@ -0,0 +1,22 @@
package control
import "net/netip"
// Role identifies a peer's role in a ping exchange with a specific remote peer.
type Role string
const (
// Client initiates pings and measures RTT.
Client Role = "CLIENT"
// Server responds to pings.
Server Role = "SERVER"
)
// RoleFor returns the Role of local relative to remote.
// The peer with the lower VPN IP is the client.
func RoleFor(local, remote netip.Addr) Role {
if local.Compare(remote) < 0 {
return Client
}
return Server
}

64
peer/control_conn.go Normal file
View File

@@ -0,0 +1,64 @@
package peer
import (
"log"
"net"
"net/netip"
"vppn/peer/control"
)
var _ ControlConn = (*udpControlConn)(nil)
type udpControlConn struct {
conn *net.UDPConn
}
// newUDPControlConn opens a UDP socket bound to localIP:port.
func newUDPControlConn(localIP netip.Addr, port uint16) (*udpControlConn, error) {
addr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(localIP, port))
conn, err := net.ListenUDP("udp4", addr)
if err != nil {
return nil, err
}
return &udpControlConn{conn: conn}, nil
}
func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error {
_, err := c.conn.WriteToUDP(ping.Marshal(buf), net.UDPAddrFromAddrPort(dst))
return err
}
// run reads incoming ping packets and forwards them to ch until ctx is done.
// Call this in a goroutine before starting the App event loop.
func (c *udpControlConn) run(ch chan<- PingEvent) {
var buf [control.Size]byte
for {
n, src, err := c.conn.ReadFromUDP(buf[:])
if err != nil {
log.Printf("control read: %v", err)
continue
}
if n != control.Size {
continue
}
ping, err := control.Unmarshal(buf)
if err != nil {
log.Printf("control unmarshal: %v", err)
continue
}
srcIP, ok := netip.AddrFromSlice(src.IP)
if !ok {
continue
}
ch <- PingEvent{srcVPNIP: srcIP.Unmap(), ping: ping}
}
}
func (c *udpControlConn) Close() error {
return c.conn.Close()
}

13
peer/data-flow.dot Normal file
View File

@@ -0,0 +1,13 @@
digraph d {
ifReader -> remote;
connReader -> remote;
mcReader -> remote;
remote -> connWriter;
remote -> ifWriter;
hubPoller -> remote;
connWriter [shape="box"];
mcWriter [shape="box"];
ifWriter [shape="box"];
}

79
peer/device.go Normal file
View File

@@ -0,0 +1,79 @@
package peer
import (
"errors"
"log"
"net/netip"
"syscall"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// devRetry calls fn up to 6 times with exponential backoff, retrying on EBUSY
// (transient netlink contention during WireGuard handshake/rekey). Fatal on
// any other error.
func devRetry(vpnIP netip.Addr, op string, fn func() error) {
const attempts = 6
timeout := 10 * time.Millisecond
for i := range attempts {
err := fn()
if err == nil {
return
}
if errors.Is(err, syscall.EBUSY) && i < attempts-1 {
time.Sleep(timeout)
timeout *= 2
continue
}
log.Fatalf("%s %v: %v", op, vpnIP, err)
}
}
func (a *App) devPeers() []wgtypes.Peer {
peers, err := a.dev.Peers()
if err != nil {
log.Fatalf("Failed to get peers %v: %v", a.vpnIP, err)
}
return peers
}
func (a *App) devAddPeer(p *Peer) {
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
p.State = StateRelayed
}
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
p.State = StateDirect
}
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
p.State = StateDirect // Dirrect connection. The app marks peer as relay.
}
func (a *App) devPromote(p *Peer) {
ep := p.WGEndpoint()
if ep.IsValid() {
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
} else {
log.Printf("PROMOTED: %s - %s (no IP)", p.Name, p.VPNIP.String())
}
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
p.State = StateDirect
}
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
p.State = StateProbing
}
func (a *App) devRemove(p *Peer) {
log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String())
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
}

View File

@@ -0,0 +1,43 @@
package peer
import (
"net/netip"
"testing"
"vppn/peer/control"
)
type sentPing struct {
Dst netip.AddrPort
Ping control.Ping
}
type fakeControlConn struct {
Sent []sentPing
}
func (f *fakeControlConn) SendPing(dst netip.AddrPort, ping control.Ping, _ []byte) error {
f.Sent = append(f.Sent, sentPing{Dst: dst, Ping: ping})
return nil
}
func (f *fakeControlConn) AssertNone(t *testing.T) {
t.Helper()
if len(f.Sent) != 0 {
t.Fatalf("expected no pings sent, got %d: %v", len(f.Sent), f.Sent)
}
}
func (f *fakeControlConn) AssertSent(t *testing.T, i int, dst netip.AddrPort, ping control.Ping) {
t.Helper()
if i >= len(f.Sent) {
t.Fatalf("no ping at index %d (have %d)", i, len(f.Sent))
}
got := f.Sent[i]
if got.Dst != dst {
t.Errorf("ping[%d].Dst = %v, want %v", i, got.Dst, dst)
}
if got.Ping != ping {
t.Errorf("ping[%d].Ping = %+v, want %+v", i, got.Ping, ping)
}
}

123
peer/fake_wgdevice_test.go Normal file
View File

@@ -0,0 +1,123 @@
package peer
import (
"net/netip"
"sync"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// fakeWGDevice records every call made to it. It is safe to read Calls after
// the event loop has processed the event under test (single-threaded loop
// means no extra synchronisation needed, but the mutex guards concurrent test
// helpers if needed).
type fakeWGDevice struct {
mu sync.Mutex
Calls []fakeCall
peers []wgtypes.Peer
}
type fakeCall struct {
Method string
PubKey wgtypes.Key
Endpoint netip.AddrPort
VPNiP netip.Addr
Network netip.Prefix
}
func (f *fakeWGDevice) record(c fakeCall) {
f.mu.Lock()
f.Calls = append(f.Calls, c)
f.mu.Unlock()
}
func (f *fakeWGDevice) Name() string { return "wg-test" }
func (f *fakeWGDevice) Peers() ([]wgtypes.Peer, error) {
f.mu.Lock()
defer f.mu.Unlock()
out := make([]wgtypes.Peer, len(f.peers))
copy(out, f.peers)
return out, nil
}
func (f *fakeWGDevice) AddPeer(pubKey wgtypes.Key) error {
f.record(fakeCall{Method: "AddPeer", PubKey: pubKey})
return nil
}
func (f *fakeWGDevice) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error {
f.record(fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
return nil
}
func (f *fakeWGDevice) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error {
f.record(fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
return nil
}
func (f *fakeWGDevice) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
f.record(fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
return nil
}
func (f *fakeWGDevice) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error {
f.record(fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
return nil
}
func (f *fakeWGDevice) RemovePeer(pubKey wgtypes.Key) error {
f.record(fakeCall{Method: "RemovePeer", PubKey: pubKey})
return nil
}
// AssertNoCalls fails the test if any dev calls were recorded.
func (f *fakeWGDevice) AssertNoCalls(t *testing.T) {
t.Helper()
f.mu.Lock()
defer f.mu.Unlock()
if len(f.Calls) != 0 {
t.Fatalf("unexpected dev calls: %v", f.Calls)
}
}
func (f *fakeWGDevice) AssertAddPeer(t *testing.T, i int, pubKey wgtypes.Key) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "AddPeer", PubKey: pubKey})
}
func (f *fakeWGDevice) AssertAddDirect(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
}
func (f *fakeWGDevice) AssertSetRelay(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
}
func (f *fakeWGDevice) AssertAddProbe(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
}
func (f *fakeWGDevice) AssertPromote(t *testing.T, i int, pubKey wgtypes.Key, vpnIP netip.Addr) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
}
func (f *fakeWGDevice) AssertRemovePeer(t *testing.T, i int, pubKey wgtypes.Key) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "RemovePeer", PubKey: pubKey})
}
func (f *fakeWGDevice) assertCall(t *testing.T, i int, c fakeCall) {
t.Helper()
if len(f.Calls) <= i {
t.Fatalf("no call at index %d: %v", i, c)
}
if c != f.Calls[i] {
t.Fatalf("call[%d]: got %v, want %v", i, f.Calls[i], c)
}
}

128
peer/hosts.go Normal file
View File

@@ -0,0 +1,128 @@
package peer
import (
"fmt"
"log"
"net/netip"
"os"
"sort"
"strings"
"syscall"
"git.crumpington.com/lib/go/flock"
)
const (
hostsFile = "/etc/hosts"
hostsBegin = "# BEGIN vppn"
hostsEnd = "# END vppn"
)
// hostMarkers returns the begin/end marker lines that delimit the managed
// section for localDomain. The domain is wrapped in parentheses so one domain's
// marker can never be a prefix of another's (e.g. "net" vs "net2") when
// multiple vppn instances share /etc/hosts.
func hostMarkers(localDomain string) (begin, end string) {
return hostsBegin + "(" + localDomain + ")", hostsEnd + "(" + localDomain + ")"
}
// updateHosts rewrites the managed vppn section in /etc/hosts using the
// current peersByIP map. Peers without a Name are skipped.
func (a *App) updateHosts() {
if a.localDomain == "" {
return
}
if err := updateHosts(hostsFile, a.localDomain, a.peersByIP); err != nil {
log.Printf("Failed to update hosts file: %v", err)
}
}
func updateHosts(hostsPath, localDomain string, peers map[netip.Addr]*Peer) error {
lockFile, err := flock.Lock(hostsPath + ".vppn.lock")
if err != nil {
return err
}
defer lockFile.Close()
begin, end := hostMarkers(localDomain)
info, err := os.Stat(hostsPath)
if err != nil {
return err
}
raw, err := os.ReadFile(hostsPath)
if err != nil {
return err
}
data := string(raw)
before := strings.TrimSpace(data)
after := ""
if idxBegin := strings.Index(data, begin); idxBegin != -1 {
idxEnd := strings.Index(data[idxBegin:], end)
if idxEnd != -1 {
after = strings.TrimSpace(data[idxBegin+idxEnd+len(end):])
}
before = strings.TrimSpace(data[:idxBegin])
}
b := strings.Builder{}
b.WriteString(before)
b.WriteRune('\n')
b.WriteString(after)
b.WriteRune('\n')
b.WriteRune('\n')
b.WriteString(begin)
b.WriteRune('\n')
// Collect entries so we can sort by IP for stable output. Pad the IP
// column to the width of the widest possible address ("255.255.255.255")
// for readability.
type entry struct {
ip netip.Addr
host string
}
var entries []entry
for ip, p := range peers {
if p.Name == "" {
continue
}
entries = append(entries, entry{ip: ip, host: p.Name + "." + localDomain})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].ip.Less(entries[j].ip)
})
for _, e := range entries {
b.WriteString(fmt.Sprintf("%-15s %s\n", e.ip.String(), e.host))
}
b.WriteString(end)
b.WriteRune('\n')
// Write to a temp file in the same directory, then rename over the
// original so readers never observe a partial file. Preserve the
// original's mode and ownership, since rename replaces the inode.
tmpPath := hostsPath + ".vppn.tmp"
if err := os.WriteFile(tmpPath, []byte(b.String()), info.Mode().Perm()); err != nil {
return err
}
if st, ok := info.Sys().(*syscall.Stat_t); ok {
if err := os.Chown(tmpPath, int(st.Uid), int(st.Gid)); err != nil {
os.Remove(tmpPath)
return err
}
}
if err := os.Rename(tmpPath, hostsPath); err != nil {
os.Remove(tmpPath)
return err
}
return nil
}

205
peer/hosts_test.go Normal file
View File

@@ -0,0 +1,205 @@
package peer
import (
"net/netip"
"os"
"path/filepath"
"sort"
"strings"
"testing"
)
// writeTempHosts creates a temp hosts file with the given content and returns
// its path.
func writeTempHosts(t *testing.T, content string) string {
t.Helper()
path := filepath.Join(t.TempDir(), "hosts")
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
t.Fatal(err)
}
return path
}
// readManagedSection returns the lines between the begin/end markers for the
// given localDomain, plus everything outside the section ("outside").
func readManagedSection(t *testing.T, path, localDomain string) (inside, outside []string) {
t.Helper()
raw, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
begin, end := hostMarkers(localDomain)
inSection := false
for _, line := range strings.Split(string(raw), "\n") {
switch {
case strings.HasPrefix(line, begin):
inSection = true
case strings.HasPrefix(line, end):
inSection = false
case inSection:
if f := strings.Join(strings.Fields(line), " "); f != "" {
inside = append(inside, f)
}
default:
if f := strings.Join(strings.Fields(line), " "); f != "" {
outside = append(outside, f)
}
}
}
return inside, outside
}
func peer(name string) *Peer {
return &Peer{Name: name}
}
func TestUpdateHosts_AddsSection(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
peers := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
netip.MustParseAddr("10.11.12.10"): peer("laptop"),
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
inside, outside := readManagedSection(t, path, "mynet.local")
sort.Strings(inside)
want := []string{
"10.11.12.1 hub.mynet.local",
"10.11.12.10 laptop.mynet.local",
}
if strings.Join(inside, "\n") != strings.Join(want, "\n") {
t.Errorf("managed section = %v, want %v", inside, want)
}
if !contains(outside, "127.0.0.1 localhost") {
t.Errorf("original content lost; outside = %v", outside)
}
}
func TestUpdateHosts_ReplacesExistingSection(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
// First write.
first := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
}
if err := updateHosts(path, "mynet.local", first); err != nil {
t.Fatal(err)
}
// Second write with a different set of peers.
second := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.20"): peer("phone"),
}
if err := updateHosts(path, "mynet.local", second); err != nil {
t.Fatal(err)
}
inside, outside := readManagedSection(t, path, "mynet.local")
if len(inside) != 1 || inside[0] != "10.11.12.20 phone.mynet.local" {
t.Errorf("section not replaced; inside = %v", inside)
}
if contains(inside, "10.11.12.1 hub.mynet.local") {
t.Errorf("stale entry remained; inside = %v", inside)
}
if !contains(outside, "127.0.0.1 localhost") {
t.Errorf("original content lost; outside = %v", outside)
}
}
func TestUpdateHosts_SkipsEmptyNames(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
peers := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
netip.MustParseAddr("10.11.12.99"): peer(""), // no name
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
inside, _ := readManagedSection(t, path, "mynet.local")
if len(inside) != 1 || inside[0] != "10.11.12.1 hub.mynet.local" {
t.Errorf("expected only named peer; inside = %v", inside)
}
}
func TestUpdateHosts_Idempotent(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
peers := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
first, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
second, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
if string(first) != string(second) {
t.Errorf("repeated update changed file:\nfirst:\n%s\nsecond:\n%s", first, second)
}
}
// TestUpdateHosts_PrefixDomainsCoexist guards finding 4.4: two domains where
// one label is a prefix of the other ("net" vs "net2") must each manage their
// own section without clobbering the other's, even sharing one hosts file.
func TestUpdateHosts_PrefixDomainsCoexist(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{
netip.MustParseAddr("10.0.2.1"): peer("a"),
}); err != nil {
t.Fatal(err)
}
if err := updateHosts(path, "net.local", map[netip.Addr]*Peer{
netip.MustParseAddr("10.0.1.1"): peer("b"),
}); err != nil {
t.Fatal(err)
}
// Both sections coexist after writing the prefix domain.
if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.1 a.net2.local" {
t.Errorf("net2 section clobbered: %v", in)
}
if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" {
t.Errorf("net section wrong: %v", in)
}
// Re-updating net2 must not disturb the net section.
if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{
netip.MustParseAddr("10.0.2.2"): peer("c"),
}); err != nil {
t.Fatal(err)
}
if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" {
t.Errorf("net section disturbed by net2 update: %v", in)
}
if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.2 c.net2.local" {
t.Errorf("net2 section not updated: %v", in)
}
}
func contains(ss []string, s string) bool {
for _, x := range ss {
if x == s {
return true
}
}
return false
}

153
peer/hub_poller.go Normal file
View File

@@ -0,0 +1,153 @@
package peer
import (
"encoding/json"
"io"
"log"
"net/http"
"net/netip"
"net/url"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
const hubPollInterval = 64 * time.Second
type HubPoller struct {
selfVPNIP netip.Addr
vpnNet netip.Prefix
hubURL string
apiKey string
statePath string // where the network state cache is persisted
addCh chan<- m.Peer
removeCh chan<- wgtypes.Key
known map[wgtypes.Key]struct{} // pubKeys currently configured
}
func NewHubPoller(
selfVPNIP netip.Addr,
vpnNet netip.Prefix,
hubURL, apiKey string,
statePath string,
addCh chan<- m.Peer,
removeCh chan<- wgtypes.Key,
) (*HubPoller, error) {
u, err := url.Parse(hubURL)
if err != nil {
return nil, err
}
u.Path = "/peer/fetch-state/"
return &HubPoller{
selfVPNIP: selfVPNIP,
vpnNet: vpnNet,
hubURL: u.String(),
apiKey: apiKey,
statePath: statePath,
addCh: addCh,
removeCh: removeCh,
known: make(map[wgtypes.Key]struct{}),
}, nil
}
func (hp *HubPoller) Run() {
// Prime from the on-disk cache before reaching the hub, so the peer
// configures WireGuard from its last known state even if the hub is down.
// known starts empty, so this emits every cached peer as an add; the first
// real poll then emits only deltas (adds for new peers, removes for gone).
if state, err := loadNetworkState(hp.statePath); err == nil {
hp.apply(state)
}
hp.poll()
for range time.Tick(hubPollInterval) {
hp.poll()
}
}
func (hp *HubPoller) poll() {
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
if err != nil {
log.Printf("[HubPoller] build request: %v", err)
return
}
req.SetBasicAuth("", hp.apiKey)
client := &http.Client{Timeout: 32 * time.Second}
resp, err := client.Do(req)
if err != nil {
log.Printf("[HubPoller] fetch: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("[HubPoller] unexpected status %d", resp.StatusCode)
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("[HubPoller] read body: %v", err)
return
}
var state m.NetworkState
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("[HubPoller] unmarshal: %v", err)
return
}
// Persist only when the state actually changed, to avoid needless writes
// on every poll.
if hp.apply(state) {
if err := saveNetworkState(hp.statePath, state); err != nil {
log.Printf("[HubPoller] save state: %v", err)
}
}
}
// apply diffs state against the set of known peers, emitting an add for each
// newly-seen peer and a remove for each that disappeared. It returns true if
// anything changed. A peer's config is immutable under a stable WG key (the hub
// has no peer-edit path), so a key already in known needs no re-emit.
func (hp *HubPoller) apply(state m.NetworkState) (changed bool) {
seen := make(map[wgtypes.Key]struct{}, len(hp.known))
netAddr := hp.vpnNet.Addr().As4()
for _, p := range state.Peers {
if p.WGPubKey == (wgtypes.Key{}) {
continue
}
octets := netAddr
octets[3] = p.PeerIP
vpnIP := netip.AddrFrom4(octets)
if vpnIP == hp.selfVPNIP {
continue
}
seen[p.WGPubKey] = struct{}{}
if _, ok := hp.known[p.WGPubKey]; ok {
continue
}
hp.known[p.WGPubKey] = struct{}{}
hp.addCh <- p
changed = true
}
for key := range hp.known {
if _, ok := seen[key]; !ok {
delete(hp.known, key)
hp.removeCh <- key
changed = true
}
}
return changed
}

80
peer/hub_poller_test.go Normal file
View File

@@ -0,0 +1,80 @@
package peer
import (
"net/netip"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
func testPoller(t *testing.T) (*HubPoller, chan m.Peer, chan wgtypes.Key) {
t.Helper()
addCh := make(chan m.Peer, 8)
removeCh := make(chan wgtypes.Key, 8)
hp := &HubPoller{
selfVPNIP: netip.MustParseAddr("10.0.0.1"),
vpnNet: netip.MustParsePrefix("10.0.0.0/24"),
addCh: addCh,
removeCh: removeCh,
known: make(map[wgtypes.Key]struct{}),
}
return hp, addCh, removeCh
}
func stateWith(key wgtypes.Key, peerIP byte) m.NetworkState {
return m.NetworkState{Peers: []m.Peer{{
PeerIP: peerIP,
WGPubKey: key,
}}}
}
func TestApply_EmitsAddsAndReportsChange(t *testing.T) {
hp, addCh, _ := testPoller(t)
key := mustKey(t)
if changed := hp.apply(stateWith(key, 2)); !changed {
t.Fatal("expected changed=true on first apply")
}
if len(addCh) != 1 {
t.Fatalf("expected 1 add, got %d", len(addCh))
}
if got := <-addCh; got.WGPubKey != key {
t.Errorf("add pubkey mismatch")
}
}
func TestApply_NoChangeWhenKnown(t *testing.T) {
hp, addCh, _ := testPoller(t)
key := mustKey(t)
hp.apply(stateWith(key, 2))
<-addCh // drain initial add
if changed := hp.apply(stateWith(key, 2)); changed {
t.Fatal("expected changed=false when peer already known")
}
if len(addCh) != 0 {
t.Fatalf("expected no re-emit, got %d adds", len(addCh))
}
}
func TestApply_RemovesVanishedPeer(t *testing.T) {
hp, addCh, removeCh := testPoller(t)
key := mustKey(t)
hp.apply(stateWith(key, 2))
<-addCh
// Empty state: the peer is gone.
if changed := hp.apply(m.NetworkState{}); !changed {
t.Fatal("expected changed=true when peer vanishes")
}
if len(removeCh) != 1 {
t.Fatalf("expected 1 remove, got %d", len(removeCh))
}
if got := <-removeCh; got != key {
t.Errorf("remove key mismatch")
}
}

190
peer/init.go Normal file
View File

@@ -0,0 +1,190 @@
package peer
import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"os"
"golang.org/x/crypto/nacl/sign"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
// LocalState is the persisted identity for this peer, written on first run and
// loaded on every subsequent run.
type LocalState struct {
PrivKey wgtypes.Key
SignKey [64]byte // nacl/sign Ed25519 private key
VPNIP netip.Addr
VPNNet netip.Prefix
WGPort uint16
IsRelay bool
IsPublic bool
LocalDomain string
}
// localStateJSON is the on-disk representation.
type localStateJSON struct {
PrivKey string
SignKey string
VPNIP netip.Addr
VPNNet netip.Prefix
WGPort uint16
IsRelay bool
IsPublic bool
LocalDomain string
}
// LoadOrInit loads LocalState from path, or registers with the hub and creates
// the file if it doesn't exist.
func LoadOrInit(statePath, hubURL, apiKey string) (LocalState, error) {
var state LocalState
switch err := loadJSON(statePath, &state); {
case err == nil:
return state, nil
case !os.IsNotExist(err):
// File exists but is unreadable/corrupt: surface it rather than
// silently regenerating a new identity and re-registering.
return LocalState{}, fmt.Errorf("load state: %w", err)
}
privKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return LocalState{}, fmt.Errorf("generate key: %w", err)
}
state, err = initFromHub(hubURL, apiKey, privKey)
if err != nil {
return LocalState{}, err
}
if err := storeJSON(statePath, state); err != nil {
return LocalState{}, fmt.Errorf("save state: %w", err)
}
return state, nil
}
func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) {
wgPubKey := privKey.PublicKey()
signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader)
if err != nil {
return LocalState{}, fmt.Errorf("generate sign key: %w", err)
}
body, err := json.Marshal(m.PeerInitArgs{
WGPubKey: wgPubKey[:],
SignPubKey: signPubKey[:],
})
if err != nil {
return LocalState{}, fmt.Errorf("json error: %w", err)
}
req, err := http.NewRequest(http.MethodPost, hubURL+"/peer/init/", bytes.NewReader(body))
if err != nil {
return LocalState{}, err
}
req.SetBasicAuth("", apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return LocalState{}, fmt.Errorf("hub init: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return LocalState{}, fmt.Errorf("hub init: HTTP %d", resp.StatusCode)
}
var r m.PeerInitResp
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
return LocalState{}, fmt.Errorf("hub init decode: %w", err)
}
if len(r.Network) != 4 {
return LocalState{}, fmt.Errorf("hub init: invalid network %v", r.Network)
}
netAddr := netip.AddrFrom4([4]byte(r.Network))
octets := netAddr.As4()
octets[3] = r.PeerIP
vpnIP := netip.AddrFrom4(octets)
vpnNet := netip.PrefixFrom(netAddr, 24)
var self *m.Peer
for i := range r.NetworkState.Peers {
if r.NetworkState.Peers[i].PeerIP == r.PeerIP {
self = &r.NetworkState.Peers[i]
break
}
}
if self == nil {
return LocalState{}, fmt.Errorf("hub init: no peer for own IP: %d", r.PeerIP)
}
public := self.IsPublic()
return LocalState{
PrivKey: privKey,
SignKey: *signPrivKey,
VPNIP: vpnIP,
VPNNet: vpnNet,
WGPort: self.Port,
IsRelay: self.Relay && public,
IsPublic: public,
LocalDomain: r.LocalDomain,
}, nil
}
func (s LocalState) MarshalJSON() ([]byte, error) {
return json.Marshal(localStateJSON{
PrivKey: base64.StdEncoding.EncodeToString(s.PrivKey[:]),
SignKey: base64.StdEncoding.EncodeToString(s.SignKey[:]),
VPNIP: s.VPNIP,
VPNNet: s.VPNNet,
WGPort: s.WGPort,
IsRelay: s.IsRelay,
IsPublic: s.IsPublic,
LocalDomain: s.LocalDomain,
})
}
func (s *LocalState) UnmarshalJSON(data []byte) error {
var j localStateJSON
if err := json.Unmarshal(data, &j); err != nil {
return err
}
keyBytes, err := base64.StdEncoding.DecodeString(j.PrivKey)
if err != nil {
return fmt.Errorf("decode key: %w", err)
}
key, err := wgtypes.NewKey(keyBytes)
if err != nil {
return fmt.Errorf("invalid key: %w", err)
}
signKeyBytes, err := base64.StdEncoding.DecodeString(j.SignKey)
if err != nil {
return fmt.Errorf("decode sign key: %w", err)
}
if len(signKeyBytes) != 64 {
return fmt.Errorf("invalid sign key length: %d", len(signKeyBytes))
}
*s = LocalState{
PrivKey: key,
SignKey: [64]byte(signKeyBytes),
VPNIP: j.VPNIP,
VPNNet: j.VPNNet,
WGPort: j.WGPort,
IsRelay: j.IsRelay,
IsPublic: j.IsPublic,
LocalDomain: j.LocalDomain,
}
return nil
}

28
peer/interfaces.go Normal file
View File

@@ -0,0 +1,28 @@
package peer
import (
"net/netip"
"vppn/peer/control"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// WGDevice is the subset of wginterface.Device used by App.
type WGDevice interface {
Name() string
Peers() ([]wgtypes.Peer, error)
AddPeer(pubKey wgtypes.Key) error
AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error
SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error
AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error
Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error
RemovePeer(pubKey wgtypes.Key) error
}
// ControlConn sends pings to peers over the VPN control port.
// Reading is handled separately via run, which feeds the App's pingCh.
// buf is a caller-provided scratch buffer (at least control.Size bytes) used to
// marshal the ping; the caller reuses one across sends.
type ControlConn interface {
SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error
}

36
peer/json.go Normal file
View File

@@ -0,0 +1,36 @@
package peer
import (
"encoding/json"
"os"
"path/filepath"
)
func loadJSON(path string, target any) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return json.Unmarshal(data, target)
}
func storeJSON(path string, obj any) error {
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
data, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tmpPath := path + ".tmp"
if err := os.WriteFile(tmpPath, data, 0600); err != nil {
return err
}
if err := os.Rename(tmpPath, path); err != nil {
os.Remove(tmpPath)
return err
}
return nil
}

View File

@@ -0,0 +1,62 @@
package multicast
import (
"log"
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
4560))
func Broadcast(
selfVPNIP netip.Addr,
pubKey wgtypes.Key,
wgPort uint16,
signKey *[64]byte,
) {
for {
broadcastInner(selfVPNIP, pubKey, wgPort, signKey)
time.Sleep(errorTimeout)
}
}
func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
conn, err := net.ListenMulticastUDP("udp", nil, addr)
if err != nil {
log.Printf("[MCBroadcast] bind: %v", err)
return
}
defer conn.Close()
buf := make([]byte, BufferSize)
packet := Packet{
PeerIP: selfVPNIP.As4()[3],
WGPubKey: pubKey,
WGPort: wgPort,
}
// Re-sign on each send so the timestamp is fresh; a stale timestamp would be
// dropped by receivers' freshness gate.
send := func() error {
packet.Timestamp = time.Now().Unix()
payload := packet.Marshal(buf, signKey)
_, err := conn.WriteToUDP(payload, addr)
return err
}
if err := send(); err != nil {
log.Printf("[MCBroadcast] write: %v", err)
}
for range time.Tick(broadcastInterval) {
if err := send(); err != nil {
log.Printf("[MCBroadcast] write: %v", err)
return
}
}
}

9
peer/multicast/global.go Normal file
View File

@@ -0,0 +1,9 @@
package multicast
import "time"
const (
errorTimeout = 16 * time.Second
broadcastInterval = 16 * time.Second
maxPacketAge = time.Minute
)

54
peer/multicast/packet.go Normal file
View File

@@ -0,0 +1,54 @@
package multicast
import (
"encoding/binary"
"net/netip"
"golang.org/x/crypto/nacl/sign"
)
const (
BufferSize = packetSize + SignedPacketSize
SignedPacketSize = packetSize + signSize
packetSize = 43
signSize = 64
)
// Layout:
//
// [0] final octet of the sender's VPN IP
// [1:33] WG public key
// [33:35] WG listen port (big-endian uint16)
// [35:43] send time, Unix seconds (big-endian int64) — freshness/replay gate
type Packet struct {
PeerIP byte // Final octet of the sender's VPN IP.
WGPubKey [32]byte // WG public key.
WGPort uint16 // WG listen port.
Timestamp int64 // Unix timestamp.
Src netip.Addr // Source of packet.
Signed []byte // Raw signed message for verification (incoming packet).
}
// Marshal the packet into a buffer with prefixed signature.
func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte {
buf[0] = p.PeerIP
copy(buf[1:33], p.WGPubKey[:])
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
binary.BigEndian.PutUint64(buf[35:43], uint64(p.Timestamp))
return sign.Sign(buf[packetSize:packetSize], buf[:packetSize], signKey)
}
func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool {
_, ok := sign.Open(buf, p.Signed, pubKey)
return ok
}
func Unmarshal(signed []byte) (p Packet) {
buf := signed[signSize:]
p.PeerIP = buf[0]
copy(p.WGPubKey[:], buf[1:33])
p.WGPort = binary.BigEndian.Uint16(buf[33:35])
p.Timestamp = int64(binary.BigEndian.Uint64(buf[35:43]))
p.Signed = signed
return
}

View File

@@ -0,0 +1,38 @@
package multicast
import (
"crypto/rand"
"testing"
"golang.org/x/crypto/nacl/sign"
)
func TestPacket(t *testing.T) {
pub, priv, err := sign.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
p := Packet{
PeerIP: 10,
WGPubKey: [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
WGPort: 44,
Timestamp: 12948893,
}
buf := make([]byte, BufferSize)
signed := p.Marshal(buf, priv)
if len(signed) != SignedPacketSize {
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
}
got := Unmarshal(signed)
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)
}
if !got.Verify(nil, pub) {
t.Error("signature did not verify")
}
}

View File

@@ -0,0 +1,61 @@
package multicast
import (
"bytes"
"fmt"
"log"
"net"
"net/netip"
"time"
)
func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) {
for {
if err := receiver(vpnNet, selfVPNIP, ch); err != nil {
log.Printf("[MCReader] %v", err)
}
time.Sleep(errorTimeout)
}
}
func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error {
selfIP := selfVPNIP.As4()[3]
conn, err := net.ListenMulticastUDP("udp", nil, addr)
if err != nil {
return fmt.Errorf("bind: %w", err)
}
defer conn.Close()
buf := make([]byte, BufferSize+1) // +1 to detect oversized packets
for {
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
n, src, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
return fmt.Errorf("read: %w", err)
}
if n != SignedPacketSize {
continue
}
packet := Unmarshal(buf[:n])
if packet.PeerIP == selfIP {
continue
}
age := time.Since(time.Unix(packet.Timestamp, 0))
if age > maxPacketAge || age < -maxPacketAge {
continue
}
packet.Signed = bytes.Clone(packet.Signed)
packet.Src = src.Addr().Unmap()
ch <- packet
}
}

18
peer/network_state.go Normal file
View File

@@ -0,0 +1,18 @@
package peer
import "vppn/m"
// loadNetworkState reads a cached network state from disk. Any error (most
// commonly a missing file on first run) is returned to the caller, which
// treats it as "no cache available".
func loadNetworkState(path string) (m.NetworkState, error) {
var state m.NetworkState
err := loadJSON(path, &state)
return state, err
}
// saveNetworkState writes state to path atomically (see storeJSON), so a crash
// mid-write cannot leave a corrupt cache.
func saveNetworkState(path string, state m.NetworkState) error {
return storeJSON(path, state)
}

View File

@@ -0,0 +1,56 @@
package peer
import (
"net/netip"
"path/filepath"
"reflect"
"testing"
"vppn/m"
)
func TestNetworkState_RoundTrip(t *testing.T) {
path := filepath.Join(t.TempDir(), "network.json")
var sign1 [32]byte
copy(sign1[:], []byte("0123456789abcdef0123456789abcdef"))
state := m.NetworkState{Peers: []m.Peer{
{
PeerIP: 1,
Name: "hub",
Addr4: netip.MustParseAddr("10.11.12.1"),
Port: 51820,
Relay: true,
WGPubKey: mustKey(t),
SignPubKey: sign1,
},
{
PeerIP: 10,
Name: "laptop",
Addr4: netip.MustParseAddr("10.11.12.10"),
Port: 51820,
WGPubKey: mustKey(t),
},
}}
if err := saveNetworkState(path, state); err != nil {
t.Fatal(err)
}
got, err := loadNetworkState(path)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, state) {
t.Errorf("round-trip mismatch:\n got: %+v\nwant: %+v", got.Peers[1], state.Peers[1])
}
}
func TestNetworkState_LoadMissing(t *testing.T) {
path := filepath.Join(t.TempDir(), "does-not-exist.json")
if _, err := loadNetworkState(path); err == nil {
t.Fatal("expected error loading missing cache, got nil")
}
}

109
peer/new.go Normal file
View File

@@ -0,0 +1,109 @@
package peer
import (
"fmt"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/multicast"
"vppn/peer/wginterface"
)
// New constructs an App, creates the WireGuard interface, and starts the
// background goroutines (hub poller, multicast, control conn reader).
// The caller should invoke Run() to start the event loop.
func New(
state LocalState,
hubURL, apiKey string,
ifaceName string,
localDomain string,
networkStatePath string,
) (*App, error) {
a4 := state.VPNIP.As4()
if err := wginterface.Create(ifaceName, a4[:], 24); err != nil {
return nil, fmt.Errorf("create WG interface: %w", err)
}
dev, err := wginterface.Open(ifaceName)
if err != nil {
_ = wginterface.Delete(ifaceName)
return nil, fmt.Errorf("open WG device: %w", err)
}
cc, err := newUDPControlConn(state.VPNIP, ControlPort)
if err != nil {
_ = dev.Close()
_ = wginterface.Delete(ifaceName)
return nil, fmt.Errorf("control conn: %w", err)
}
cleanup := func() {
_ = cc.Close()
_ = dev.Close()
_ = wginterface.Delete(ifaceName)
}
if err := dev.Configure(state.PrivKey, int(state.WGPort)); err != nil {
cleanup()
return nil, fmt.Errorf("configure WG device: %w", err)
}
if state.IsRelay {
if err := dev.EnableForwarding(); err != nil {
cleanup()
return nil, fmt.Errorf("enable forwarding: %w", err)
}
}
pingCh := make(chan PingEvent)
hubAddCh := make(chan m.Peer)
hubRemoveCh := make(chan wgtypes.Key)
multicastCh := make(chan multicast.Packet)
poller, err := NewHubPoller(
state.VPNIP,
state.VPNNet,
hubURL,
apiKey,
networkStatePath,
hubAddCh,
hubRemoveCh)
if err != nil {
cleanup()
return nil, fmt.Errorf("hub poller: %w", err)
}
go cc.run(pingCh)
go poller.Run()
if !state.IsPublic {
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh)
}
return &App{
vpnIP: state.VPNIP,
vpnNet: state.VPNNet,
privKey: state.PrivKey,
pubKey: state.PrivKey.PublicKey(),
isRelay: state.IsRelay,
isPublic: state.IsPublic,
localDomain: localDomain,
dev: dev,
controlConn: cc,
peersByKey: make(map[wgtypes.Key]*Peer),
peersByIP: make(map[netip.Addr]*Peer),
scratch: make([]byte, scratchSize),
hubAddCh: hubAddCh,
hubRemoveCh: hubRemoveCh,
pingCh: pingCh,
multicastCh: multicastCh,
}, nil
}

114
peer/on_hub.go Normal file
View File

@@ -0,0 +1,114 @@
package peer
import (
"log"
"math"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/control"
)
func (a *App) onAddPeer(p m.Peer) {
a.onRemovePeer(p.WGPubKey)
octets := a.vpnNet.Addr().As4()
octets[3] = p.PeerIP
vpnIP := netip.AddrFrom4(octets)
peer := &Peer{
wgPeer: wgtypes.Peer{PublicKey: p.WGPubKey},
VPNIP: vpnIP,
Name: p.Name,
IsRelay: p.Relay,
IsPublic: p.IsPublic(),
EndpointV4: p.Endpoint4(),
EndpointV6: p.Endpoint6(),
RTT: time.Duration(math.MaxInt64) * time.Nanosecond,
Role: roleFor(a.isPublic, a.vpnIP, p.IsPublic(), vpnIP),
SignPubKey: p.SignPubKey,
}
a.peersByKey[p.WGPubKey] = peer
a.peersByIP[peer.VPNIP] = peer
defer a.updateHosts()
if !peer.IsPublic {
if a.isPublic {
// Public nodes accept traffic from non-public peers as soon as they
// initiate a handshake. Set /32 AllowedIPs now; WireGuard learns the
// endpoint from the incoming handshake automatically.
a.devPromote(peer)
} else {
a.devAddPeer(peer)
}
return
}
a.devAddDirect(peer, peer.PreferredEndpoint())
}
func (a *App) onRemovePeer(key wgtypes.Key) {
peer, exists := a.peersByKey[key]
if !exists {
return
}
a.devRemove(peer)
delete(a.peersByKey, key)
delete(a.peersByIP, peer.VPNIP)
a.updateHosts()
if peer == a.relay {
a.relay = nil
a.switchActiveRelay()
}
}
// switchActiveRelay promotes the lowest-latency relay peer to active.
func (a *App) switchActiveRelay() {
if a.relay != nil {
// If we have a relay, it's public, so should go back to being a direct
// peer - this will convert it's /24 to a /32.
a.devAddDirect(a.relay, a.relay.PreferredEndpoint())
a.relay = nil
}
var best *Peer
for _, p := range a.peersByKey {
if !p.CanRelay() {
continue
}
if best == nil || p.RTT < best.RTT {
best = p
}
}
if best == nil {
log.Printf("no relay available")
return
}
a.devSetRelay(best, best.PreferredEndpoint())
a.relay = best
}
func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort {
// We always prefer v4 since all peers can connect to IPv4 addresses.
if v4.IsValid() {
return v4
}
return v6
}
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
if !selfIsPublic && peerIsPublic {
return control.Client
}
if selfIsPublic && !peerIsPublic {
return control.Server
}
return control.RoleFor(selfIP, peerVPNIP)
}

299
peer/on_hub_test.go Normal file
View File

@@ -0,0 +1,299 @@
package peer
import (
"net/netip"
"testing"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
func mustKey(t *testing.T) wgtypes.Key {
t.Helper()
k, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
return k.PublicKey()
}
func TestOnAddPeer(t *testing.T) {
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
peerVPNIP := netip.MustParseAddr("10.0.0.2")
testCases := []struct {
name string
setup func(a *App, key wgtypes.Key)
peer func(key wgtypes.Key) m.Peer
check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
}{
{
name: "non-public peer registered in WG via AddPeer",
peer: func(k wgtypes.Key) m.Peer {
return m.Peer{WGPubKey: k, PeerIP: 2}
},
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
p := a.peersByKey[key]
if p == nil {
t.Fatal("not in peersByKey")
}
if a.peersByIP[peerVPNIP] == nil {
t.Fatal("not in peersByIP")
}
if p.State != StateRelayed {
t.Fatalf("state = %v, want StateRelayed", p.State)
}
dev.AssertAddPeer(t, 0, key)
},
},
{
name: "public peer with endpoint registered via AddDirect",
peer: func(k wgtypes.Key) m.Peer {
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}
},
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
p := a.peersByKey[key]
if p == nil {
t.Fatal("not in peersByKey")
}
dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP)
},
},
{
name: "re-add removes old WG entry before adding new one",
setup: func(a *App, key wgtypes.Key) {
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
},
peer: func(k wgtypes.Key) m.Peer {
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()}
},
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
if len(dev.Calls) != 2 {
t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, key)
dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP)
if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 {
t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP))
}
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
key := mustKey(t)
if tc.setup != nil {
tc.setup(a, key)
dev.Calls = nil
}
a.onAddPeer(tc.peer(key))
tc.check(t, a, dev, key)
})
}
}
func TestOnRemovePeer(t *testing.T) {
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
testCases := []struct {
name string
setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove
check func(t *testing.T, a *App, dev *fakeWGDevice)
}{
{
name: "unknown key is a no-op",
setup: func(t *testing.T, a *App) wgtypes.Key {
return mustKey(t)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
dev.AssertNoCalls(t)
if len(a.peersByKey) != 0 {
t.Errorf("peersByKey should be empty")
}
},
},
{
name: "StateRelayed peer removed from maps with RemovePeer",
setup: func(t *testing.T, a *App) wgtypes.Key {
key := mustKey(t)
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2})
return key
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
t.Errorf("maps should be empty after remove")
}
},
},
{
name: "StateDirect peer removed from maps with RemovePeer",
setup: func(t *testing.T, a *App) wgtypes.Key {
key := mustKey(t)
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
return key
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
t.Errorf("maps should be empty after remove")
}
},
},
{
name: "removing active relay with no backup clears relay field",
setup: func(t *testing.T, a *App) wgtypes.Key {
relay := addRelayPeer(t, a, "10.0.0.10", ep1)
a.relay = relay
return relay.PubKey()
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
if a.relay != nil {
t.Errorf("relay should be nil after removing only relay")
}
},
},
{
name: "removing active relay elects backup via SetRelay",
setup: func(t *testing.T, a *App) wgtypes.Key {
relay1 := addRelayPeer(t, a, "10.0.0.10", ep1)
addRelayPeer(t, a, "10.0.0.11", ep2)
a.relay = relay1
return relay1.PubKey()
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 2 {
t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
if a.relay == nil {
t.Errorf("relay should be set to backup after failover")
}
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
key := tc.setup(t, a)
dev.Calls = nil
a.onRemovePeer(key)
tc.check(t, a, dev)
})
}
}
func TestSwitchActiveRelay(t *testing.T) {
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
testCases := []struct {
name string
setup func(t *testing.T, a *App)
check func(t *testing.T, a *App, dev *fakeWGDevice)
}{
{
name: "no candidates leaves relay nil",
setup: func(t *testing.T, a *App) {},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
dev.AssertNoCalls(t)
if a.relay != nil {
t.Error("relay should be nil")
}
},
},
{
name: "single candidate elected via SetRelay",
setup: func(t *testing.T, a *App) {
addRelayPeer(t, a, "10.0.0.10", ep1)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
}
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
if a.relay == nil {
t.Error("relay should be set")
}
},
},
{
name: "measured RTT beats zero RTT",
setup: func(t *testing.T, a *App) {
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
r1.RTT = 10 * time.Millisecond
addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeaured)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
}
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
},
},
{
name: "lower RTT wins",
setup: func(t *testing.T, a *App) {
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
r1.RTT = 5 * time.Millisecond
r2 := addRelayPeer(t, a, "10.0.0.11", ep2)
r2.RTT = 20 * time.Millisecond
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
}
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
},
},
{
name: "stale relay demoted to direct before backup elected",
setup: func(t *testing.T, a *App) {
old := addRelayPeer(t, a, "10.0.0.10", ep1)
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
a.relay = old
addRelayPeer(t, a, "10.0.0.11", ep2)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 2 {
t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls)
}
if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 {
t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0])
}
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
if a.relay == nil || a.relay.EndpointV4 != ep2 {
t.Error("relay should be the backup peer")
}
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
tc.setup(t, a)
dev.Calls = nil
a.switchActiveRelay()
tc.check(t, a, dev)
})
}
}

Some files were not shown because too many files have changed in this diff Show More