14 Commits

Author SHA1 Message Date
jdl
3ebfe754e7 Removed accidentially committed compiled file. 2025-09-17 09:59:15 +02:00
jdl
069243e5d4 Cleanup / fixes 2025-09-16 21:05:19 +02:00
jdl
bd78ffd669 Bug fix! 2025-09-16 20:38:43 +02:00
jdl
a90ab3f5d6 Added hosts command 2025-09-16 14:18:18 +02:00
jdl
650c74c013 Added unix socket client timeout 2025-09-15 21:18:27 +02:00
jdl
b308150d21 cleanup 2025-09-15 21:15:25 +02:00
jdl
a0b7ecbfe0 Fix syn-ack bug 2025-09-15 21:15:05 +02:00
jdl
69dff24344 Cleanup. 2025-09-15 15:08:28 +02:00
jdl
257fac67ce Cleanup 2025-09-15 15:05:55 +02:00
jdl
2ff8aaf5c4 Cleanup 2025-09-15 15:05:15 +02:00
jdl
fccc4f7d57 WIP 2025-09-15 04:56:23 +02:00
jdl
c6d35856bc FSM logic cleanup 2025-09-15 04:29:30 +02:00
jdl
5844584219 FSM logic cleanup 2025-09-15 04:24:00 +02:00
jdl
e458e43d83 WIP 2025-09-15 04:07:56 +02:00
106 changed files with 4274 additions and 3790 deletions

View File

@@ -1,5 +1,9 @@
# vppn: Virtual Potentially Private Network # vppn: Virtual Potentially Private Network
## TO DO
* Double buffering in IFReader and ConnReader ?
## Hub Server Configuration ## Hub Server Configuration
``` ```
@@ -53,17 +57,15 @@ Sign-in and configure.
Install the binary somewhere, for example `~/bin/vppn`. Install the binary somewhere, for example `~/bin/vppn`.
Add the API key for your network name in `~/.vppn/<netname>/apikey`.
Create systemd file in `/etc/systemd/system/vppn.service`. Create systemd file in `/etc/systemd/system/vppn.service`.
``` ```
[Service] [Service]
AmbientCapabilities=AP_NET_ADMIN CAP_DAC_OVERRIDE CAP_CHOWN AmbientCapabilities=CAP_NET_BIND_SERVICE CAP_NET_ADMIN
Type=simple Type=simple
User=user User=user
WorkingDirectory=/home/user/ WorkingDirectory=/home/user/
ExecStart=/home/user/bin/vppn -name my_net_name -hub https://my.hub ExecStart=/home/user/vppn run my_net_name https://my.hub my_api_key
Restart=always Restart=always
RestartSec=8 RestartSec=8
TimeoutStopSec=24 TimeoutStopSec=24

View File

@@ -1,72 +1,11 @@
package main package main
import ( import (
"flag"
"log" "log"
"os"
"path/filepath"
"strings"
"vppn/peer" "vppn/peer"
"git.crumpington.com/lib/go/flock"
) )
func main() { func main() {
log.SetFlags(0) log.SetFlags(0)
peer.Main2()
name := flag.String("name", "", "network name (required)")
hub := flag.String("hub", "", "hub base URL (required)")
flag.Parse()
if *name == "" || *hub == "" {
flag.Usage()
os.Exit(1)
}
apiKey, err := loadAPIKey(*name)
if err != nil {
log.Fatalf("api key: %v", err)
}
// Directory existence is guaranteed by the apikey file read above.
lockFile, err := flock.TryLock(vppnPath(*name, "lock"))
if err != nil {
log.Fatalf("lock: %v", err)
}
if lockFile == nil {
log.Fatalf("already running for network %q", *name)
}
defer flock.Unlock(lockFile)
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)
if err != nil {
log.Fatalf("init: %v", err)
}
ifaceName := strings.TrimSuffix(state.LocalDomain, ".local")
app, err := peer.New(state, *hub, apiKey, ifaceName, state.LocalDomain, vppnPath(*name, "network.json"))
if err != nil {
log.Fatalf("start: %v", err)
}
if err := app.Run(); err != nil {
log.Fatalf("run: %v", err)
}
}
func loadAPIKey(name string) (string, error) {
data, err := os.ReadFile(vppnPath(name, "apikey"))
if err != nil {
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func vppnPath(name, file string) string {
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".vppn", name, file)
}
return filepath.Join(home, ".vppn", name, file)
} }

8
go.mod
View File

@@ -6,18 +6,10 @@ require (
git.crumpington.com/lib/go v0.9.1 git.crumpington.com/lib/go v0.9.1
golang.org/x/crypto v0.42.0 golang.org/x/crypto v0.42.0
golang.org/x/sys v0.36.0 golang.org/x/sys v0.36.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
) )
require ( require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
golang.org/x/net v0.44.0 // indirect golang.org/x/net v0.44.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect golang.org/x/text v0.29.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect
) )

18
go.sum
View File

@@ -1,30 +1,12 @@
git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4= git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg= git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=

View File

@@ -19,10 +19,8 @@ import (
var migrations embed.FS var migrations embed.FS
type API struct { type API struct {
db *sql.DB db *sql.DB
lock sync.Mutex lock sync.Mutex
sessionsMu sync.Mutex
sessions map[string]*Session
} }
func New(dbPath string) (*API, error) { func New(dbPath string) (*API, error) {
@@ -36,17 +34,10 @@ func New(dbPath string) (*API, error) {
} }
a := &API{ a := &API{
db: sqlDB, db: sqlDB,
sessions: make(map[string]*Session),
} }
if err := a.ensurePassword(); err != nil { return a, a.ensurePassword()
return nil, err
}
go a.sweepSessions()
return a, nil
} }
func (a *API) ensurePassword() error { func (a *API) ensurePassword() error {
@@ -71,8 +62,12 @@ func (a *API) ensurePassword() error {
return db.Config_Insert(a.db, conf) return db.Config_Insert(a.db, conf)
} }
func (a *API) Config_Get() (*Config, error) { func (a *API) Config_Get() *Config {
return db.Config_Get(a.db, 1) conf, err := db.Config_Get(a.db, 1)
if err != nil {
panic(err)
}
return conf
} }
func (a *API) Config_Update(conf *Config) error { func (a *API) Config_Update(conf *Config) error {
@@ -80,78 +75,56 @@ func (a *API) Config_Update(conf *Config) error {
} }
func (a *API) Session_Delete(sessionID string) error { func (a *API) Session_Delete(sessionID string) error {
a.sessionsMu.Lock() return db.Session_Delete(a.db, sessionID)
defer a.sessionsMu.Unlock()
delete(a.sessions, sessionID)
return nil
} }
const ( func (a *API) Session_Get(sessionID string) (*Session, error) {
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use if sessionID == "" {
sessionSweepEvery = time.Hour // cadence of expired-session eviction return a.session_CreatePub()
)
// Session_Get returns a snapshot copy of the signed-in session for sessionID,
// or the zero Session if the cookie is missing/unknown/expired. It never
// creates a session, so anonymous requests cost no memory — a session is minted
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
// callers from racing on the shared struct.
func (a *API) Session_Get(sessionID string) (Session, error) {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
s, ok := a.sessions[sessionID]
if sessionID == "" || !ok {
return Session{}, nil
} }
if timeSince(s.LastSeenAt) > sessionTTLSecs { session, err := db.Session_Get(a.db, sessionID)
delete(a.sessions, sessionID)
return Session{}, nil
}
s.LastSeenAt = time.Now().Unix()
return *s, nil
}
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
// returning it so the caller can set the cookie. A new ID per sign-in rotates
// the session at the privilege boundary (session-fixation resistance).
func (a *API) Session_SignIn(pwd string) (Session, error) {
conf, err := a.Config_Get()
if err != nil { if err != nil {
return Session{}, err return a.session_CreatePub()
}
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
return Session{}, ErrNotAuthorized
} }
a.sessionsMu.Lock() if timeSince(session.LastSeenAt) > 86400*21 {
defer a.sessionsMu.Unlock() return a.session_CreatePub()
}
if timeSince(session.LastSeenAt) > 86400*7 {
session.LastSeenAt = time.Now().Unix()
if err := db.Session_UpdateLastSeenAt(a.db, session.SessionID); err != nil {
log.Printf("Failed to update session: %v", err)
}
}
return session, nil
}
func (a *API) session_CreatePub() (*Session, error) {
s := &Session{ s := &Session{
SessionID: idgen.NewToken(), SessionID: idgen.NewToken(),
SignedIn: true, CSRF: idgen.NewToken(),
SignedIn: false,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
LastSeenAt: time.Now().Unix(), LastSeenAt: time.Now().Unix(),
} }
a.sessions[s.SessionID] = s err := db.Session_Insert(a.db, s)
return *s, nil return s, err
} }
// sweepSessions periodically evicts sessions past their TTL. Without it, a func (a *API) Session_DeleteBefore(timestamp int64) error {
// signed-in session whose ID is never presented again would linger forever return db.Session_DeleteBefore(a.db, timestamp)
// (Session_Get only evicts on a lookup of that same ID). }
func (a *API) sweepSessions() {
for range time.Tick(sessionSweepEvery) { func (a *API) Session_SignIn(s *Session, pwd string) error {
a.sessionsMu.Lock() conf := a.Config_Get()
for id, s := range a.sessions { if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
if timeSince(s.LastSeenAt) > sessionTTLSecs { return ErrNotAuthorized
delete(a.sessions, id)
}
}
a.sessionsMu.Unlock()
} }
return db.Session_SetSignedIn(a.db, s.SessionID)
} }
func (a *API) Network_Create(n *Network) error { func (a *API) Network_Create(n *Network) error {
@@ -168,13 +141,14 @@ func (a *API) Network_Get(id int64) (*Network, error) {
} }
func (a *API) Network_List() ([]*Network, error) { func (a *API) Network_List() ([]*Network, error) {
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC` const query = db.Network_SelectQuery + ` ORDER BY Name ASC`
return db.Network_List(a.db, query) return db.Network_List(a.db, query)
} }
func (a *API) Peer_CreateNew(p *Peer) error { func (a *API) Peer_CreateNew(p *Peer) error {
p.WGPubKey = []byte{} p.Version = idgen.NextID(0)
p.SignPubKey = []byte{} p.PubKey = []byte{}
p.PubSignKey = []byte{}
p.APIKey = idgen.NewToken() p.APIKey = idgen.NewToken()
return db.Peer_Insert(a.db, p) return db.Peer_Insert(a.db, p)
@@ -184,22 +158,21 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
a.lock.Lock() a.lock.Lock()
defer a.lock.Unlock() defer a.lock.Unlock()
// Re-read from DB inside the lock — the caller's copy was fetched before peer.Version = idgen.NextID(0)
// we held the lock, so it may be stale under concurrent requests. peer.PubKey = args.EncPubKey
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP) peer.PubSignKey = args.PubSignKey
if err != nil {
return err
}
if len(current.WGPubKey) != 0 {
return errors.New("peer already initialized")
}
peer.WGPubKey = args.WGPubKey
peer.SignPubKey = args.SignPubKey
return db.Peer_UpdateFull(a.db, peer) return db.Peer_UpdateFull(a.db, peer)
} }
func (a *API) Peer_Update(p *Peer) error {
a.lock.Lock()
defer a.lock.Unlock()
p.Version = idgen.NextID(0)
return db.Peer_Update(a.db, p)
}
func (a *API) Peer_Delete(networkID int64, peerIP byte) error { func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
return db.Peer_Delete(a.db, networkID, peerIP) return db.Peer_Delete(a.db, networkID, peerIP)
} }

View File

@@ -123,9 +123,7 @@ func Config_Get(
) { ) {
row = &Config{} row = &Config{}
r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID) r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID)
if err = r.Scan(&row.ConfigID, &row.Password); err != nil { err = r.Scan(&row.ConfigID, &row.Password)
row = nil
}
return return
} }
@@ -139,9 +137,7 @@ func Config_GetWhere(
) { ) {
row = &Config{} row = &Config{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
if err = r.Scan(&row.ConfigID, &row.Password); err != nil { err = r.Scan(&row.ConfigID, &row.Password)
row = nil
}
return return
} }
@@ -186,17 +182,135 @@ func Config_List(
return l, nil return l, nil
} }
// ----------------------------------------------------------------------------
// Table: sessions
// ----------------------------------------------------------------------------
type Session struct {
SessionID string
CSRF string
SignedIn bool
CreatedAt int64
LastSeenAt int64
}
const Session_SelectQuery = "SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions"
func Session_Insert(
tx TX,
row *Session,
) (err error) {
Session_Sanitize(row)
if err = Session_Validate(row); err != nil {
return err
}
_, err = tx.Exec("INSERT INTO sessions(SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt) VALUES(?,?,?,?,?)", row.SessionID, row.CSRF, row.SignedIn, row.CreatedAt, row.LastSeenAt)
return err
}
func Session_Delete(
tx TX,
SessionID string,
) (err error) {
result, err := tx.Exec("DELETE FROM sessions WHERE SessionID=?", SessionID)
if err != nil {
return err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
switch n {
case 0:
return sql.ErrNoRows
case 1:
return nil
default:
panic("multiple rows deleted")
}
}
func Session_Get(
tx TX,
SessionID string,
) (
row *Session,
err error,
) {
row = &Session{}
r := tx.QueryRow("SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions WHERE SessionID=?", SessionID)
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
return
}
func Session_GetWhere(
tx TX,
query string,
args ...any,
) (
row *Session,
err error,
) {
row = &Session{}
r := tx.QueryRow(query, args...)
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
return
}
func Session_Iterate(
tx TX,
query string,
args ...any,
) iter.Seq2[*Session, error] {
rows, err := tx.Query(query, args...)
if err != nil {
return func(yield func(*Session, error) bool) {
yield(nil, err)
}
}
return func(yield func(*Session, error) bool) {
defer rows.Close()
for rows.Next() {
row := &Session{}
err := rows.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
if !yield(row, err) {
return
}
}
}
}
func Session_List(
tx TX,
query string,
args ...any,
) (
l []*Session,
err error,
) {
for row, err := range Session_Iterate(tx, query, args...) {
if err != nil {
return nil, err
}
l = append(l, row)
}
return l, nil
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Table: networks // Table: networks
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type Network struct { type Network struct {
NetworkID int64 NetworkID int64
LocalDomain string Name string
Network []byte Network []byte
} }
const Network_SelectQuery = "SELECT NetworkID,LocalDomain,Network FROM networks" const Network_SelectQuery = "SELECT NetworkID,Name,Network FROM networks"
func Network_Insert( func Network_Insert(
tx TX, tx TX,
@@ -207,7 +321,7 @@ func Network_Insert(
return err return err
} }
_, err = tx.Exec("INSERT INTO networks(NetworkID,LocalDomain,Network) VALUES(?,?,?)", row.NetworkID, row.LocalDomain, row.Network) _, err = tx.Exec("INSERT INTO networks(NetworkID,Name,Network) VALUES(?,?,?)", row.NetworkID, row.Name, row.Network)
return err return err
} }
@@ -220,7 +334,7 @@ func Network_UpdateFull(
return err return err
} }
result, err := tx.Exec("UPDATE networks SET LocalDomain=?,Network=? WHERE NetworkID=?", row.LocalDomain, row.Network, row.NetworkID) result, err := tx.Exec("UPDATE networks SET Name=?,Network=? WHERE NetworkID=?", row.Name, row.Network, row.NetworkID)
if err != nil { if err != nil {
return err return err
} }
@@ -270,10 +384,8 @@ func Network_Get(
err error, err error,
) { ) {
row = &Network{} row = &Network{}
r := tx.QueryRow("SELECT NetworkID,LocalDomain,Network FROM networks WHERE NetworkID=?", NetworkID) r := tx.QueryRow("SELECT NetworkID,Name,Network FROM networks WHERE NetworkID=?", NetworkID)
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil { err = r.Scan(&row.NetworkID, &row.Name, &row.Network)
row = nil
}
return return
} }
@@ -287,9 +399,7 @@ func Network_GetWhere(
) { ) {
row = &Network{} row = &Network{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil { err = r.Scan(&row.NetworkID, &row.Name, &row.Network)
row = nil
}
return return
} }
@@ -309,7 +419,7 @@ func Network_Iterate(
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
row := &Network{} row := &Network{}
err := rows.Scan(&row.NetworkID, &row.LocalDomain, &row.Network) err := rows.Scan(&row.NetworkID, &row.Name, &row.Network)
if !yield(row, err) { if !yield(row, err) {
return return
} }
@@ -341,17 +451,17 @@ func Network_List(
type Peer struct { type Peer struct {
NetworkID int64 NetworkID int64
PeerIP byte PeerIP byte
Version int64
APIKey string APIKey string
Name string Name string
Addr4 []byte PublicIP []byte
Addr6 []byte
Port uint16 Port uint16
Relay bool Relay bool
WGPubKey []byte PubKey []byte
SignPubKey []byte PubSignKey []byte
} }
const Peer_SelectQuery = "SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers" const Peer_SelectQuery = "SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers"
func Peer_Insert( func Peer_Insert(
tx TX, tx TX,
@@ -362,10 +472,38 @@ func Peer_Insert(
return err return err
} }
_, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey) _, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey)
return err return err
} }
func Peer_Update(
tx TX,
row *Peer,
) (err error) {
Peer_Sanitize(row)
if err = Peer_Validate(row); err != nil {
return err
}
result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.NetworkID, row.PeerIP)
if err != nil {
return err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
switch n {
case 0:
return sql.ErrNoRows
case 1:
return nil
default:
panic("multiple rows updated")
}
}
func Peer_UpdateFull( func Peer_UpdateFull(
tx TX, tx TX,
row *Peer, row *Peer,
@@ -375,7 +513,7 @@ func Peer_UpdateFull(
return err return err
} }
result, err := tx.Exec("UPDATE peers SET APIKey=?,Name=?,Addr4=?,Addr6=?,Port=?,Relay=?,WGPubKey=?,SignPubKey=? WHERE NetworkID=? AND PeerIP=?", row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey, row.NetworkID, row.PeerIP) result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=?,PubSignKey=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey, row.NetworkID, row.PeerIP)
if err != nil { if err != nil {
return err return err
} }
@@ -427,10 +565,8 @@ func Peer_Get(
err error, err error,
) { ) {
row = &Peer{} row = &Peer{}
r := tx.QueryRow("SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP) r := tx.QueryRow("SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil { err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
row = nil
}
return return
} }
@@ -444,9 +580,7 @@ func Peer_GetWhere(
) { ) {
row = &Peer{} row = &Peer{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil { err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
row = nil
}
return return
} }
@@ -466,7 +600,7 @@ func Peer_Iterate(
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
row := &Peer{} row := &Peer{}
err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey) err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
if !yield(row, err) { if !yield(row, err) {
return return
} }

View File

@@ -8,11 +8,9 @@ import (
var ( var (
ErrInvalidIP = errors.New("invalid IP") ErrInvalidIP = errors.New("invalid IP")
ErrInvalidPeerIP = errors.New("invalid peer IP")
ErrNonPrivateIP = errors.New("non-private IP") ErrNonPrivateIP = errors.New("non-private IP")
ErrInvalidPort = errors.New("invalid port") ErrInvalidPort = errors.New("invalid port")
ErrInvalidNetName = errors.New("invalid network name") ErrInvalidNetName = errors.New("invalid network name")
ErrNetNameNotLocal = errors.New("network name must end with .local")
ErrInvalidPeerName = errors.New("invalid peer name") ErrInvalidPeerName = errors.New("invalid peer name")
) )
@@ -23,8 +21,15 @@ func Config_Validate(c *Config) error {
return nil return nil
} }
func Session_Sanitize(s *Session) {
}
func Session_Validate(s *Session) error {
return nil
}
func Network_Sanitize(n *Network) { func Network_Sanitize(n *Network) {
n.LocalDomain = strings.TrimSpace(n.LocalDomain) n.Name = strings.TrimSpace(n.Name)
if addr, ok := netip.AddrFromSlice(n.Network); ok { if addr, ok := netip.AddrFromSlice(n.Network); ok {
n.Network = addr.AsSlice() n.Network = addr.AsSlice()
@@ -32,17 +37,12 @@ func Network_Sanitize(n *Network) {
} }
func Network_Validate(c *Network) error { func Network_Validate(c *Network) error {
// 15 bytes is linux limit for network interface names. With ending .local, // 16 bytes is linux limit for network interface names.
// max length is 21. if len(c.Name) == 0 || len(c.Name) > 16 {
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
return ErrInvalidNetName return ErrInvalidNetName
} }
if !strings.HasSuffix(c.LocalDomain, ".local") { for _, c := range c.Name {
return ErrNetNameNotLocal
}
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
if c >= 'a' && c <= 'z' { if c >= 'a' && c <= 'z' {
continue continue
} }
@@ -66,35 +66,21 @@ func Network_Validate(c *Network) error {
func Peer_Sanitize(p *Peer) { func Peer_Sanitize(p *Peer) {
p.Name = strings.TrimSpace(p.Name) p.Name = strings.TrimSpace(p.Name)
if len(p.Addr4) != 0 { if len(p.PublicIP) != 0 {
if addr, ok := netip.AddrFromSlice(p.Addr4); ok { addr, ok := netip.AddrFromSlice(p.PublicIP)
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes. if ok && addr.Is4() {
p.Addr4 = addr.Unmap().AsSlice() p.PublicIP = addr.AsSlice()
}
}
if len(p.Addr6) != 0 {
if addr, ok := netip.AddrFromSlice(p.Addr6); ok {
p.Addr6 = addr.AsSlice()
} }
} }
if p.Port == 0 { if p.Port == 0 {
p.Port = 51820 p.Port = 456
} }
} }
func Peer_Validate(p *Peer) error { func Peer_Validate(p *Peer) error {
if p.PeerIP < 1 || p.PeerIP > 254 { if len(p.PublicIP) > 0 {
return ErrInvalidPeerIP _, ok := netip.AddrFromSlice(p.PublicIP)
} if !ok {
if len(p.Addr4) > 0 {
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
return ErrInvalidIP
}
}
if len(p.Addr6) > 0 {
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
return ErrInvalidIP return ErrInvalidIP
} }
} }
@@ -102,9 +88,6 @@ func Peer_Validate(p *Peer) error {
return ErrInvalidPort return ErrInvalidPort
} }
if len(p.Name) == 0 {
return ErrInvalidPeerName
}
for _, c := range p.Name { for _, c := range p.Name {
if c >= 'a' && c <= 'z' { if c >= 'a' && c <= 'z' {
continue continue
@@ -112,9 +95,10 @@ func Peer_Validate(p *Peer) error {
if c >= '0' && c <= '9' { if c >= '0' && c <= '9' {
continue continue
} }
if c == '-' { if c == '.' || c == '-' || c == '_' {
continue continue
} }
return ErrInvalidPeerName return ErrInvalidPeerName
} }

View File

@@ -3,21 +3,29 @@ TABLE config OF Config (
Password []byte Password []byte
); );
TABLE sessions OF Session NoUpdate (
SessionID string PK,
CSRF string,
SignedIn bool,
CreatedAt int64,
LastSeenAt int64
);
TABLE networks OF Network ( TABLE networks OF Network (
NetworkID int64 PK, NetworkID int64 PK,
LocalDomain string NoUpdate, Name string NoUpdate,
Network []byte NoUpdate Network []byte NoUpdate
); );
TABLE peers OF Peer ( TABLE peers OF Peer (
NetworkID int64 PK, NetworkID int64 PK,
PeerIP byte PK, PeerIP byte PK,
Version int64,
APIKey string NoUpdate, APIKey string NoUpdate,
Name string NoUpdate, Name string,
Addr4 []byte NoUpdate, PublicIP []byte,
Addr6 []byte NoUpdate, Port uint16,
Port uint16 NoUpdate, Relay bool,
Relay bool NoUpdate, PubKey []byte NoUpdate,
WGPubKey []byte NoUpdate, PubSignKey []byte NoUpdate
SignPubKey []byte NoUpdate
); );

View File

@@ -1,5 +1,31 @@
package db package db
import "time"
func Session_UpdateLastSeenAt(
tx TX,
id string,
) (err error) {
_, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id)
return err
}
func Session_SetSignedIn(
tx TX,
id string,
) (err error) {
_, err = tx.Exec("UPDATE sessions SET SignedIn=1 WHERE SessionID=?", id)
return err
}
func Session_DeleteBefore(
tx TX,
timestamp int64,
) (err error) {
_, err = tx.Exec("DELETE FROM sessions WHERE LastSeenAt<?", timestamp)
return err
}
func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) { func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) {
const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC` const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC`
return Peer_List(tx, query, networkID) return Peer_List(tx, query, networkID)
@@ -11,3 +37,9 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
Peer_SelectQuery+` WHERE APIKey=?`, Peer_SelectQuery+` WHERE APIKey=?`,
apiKey) apiKey)
} }
func Peer_Exists(tx TX, networkID int64, ip byte) (exists bool, err error) {
const query = `SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=? AND PeerIP=?)`
err = tx.QueryRow(query, networkID, ip).Scan(&exists)
return
}

View File

@@ -7,6 +7,7 @@ import (
var ( var (
ErrNotAuthorized = errors.New("not authorized") ErrNotAuthorized = errors.New("not authorized")
ErrNoIPAvailable = errors.New("no IP address available")
ErrInvalidIP = db.ErrInvalidIP ErrInvalidIP = db.ErrInvalidIP
ErrInvalidPort = db.ErrInvalidPort ErrInvalidPort = db.ErrInvalidPort
) )

View File

@@ -3,23 +3,32 @@ CREATE TABLE config (
Password BLOB NOT NULL -- bcrypt password for web interface Password BLOB NOT NULL -- bcrypt password for web interface
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE TABLE sessions (
SessionID TEXT NOT NULL PRIMARY KEY,
CSRF TEXT NOT NULL,
SignedIn INTEGER NOT NULL,
CreatedAt INTEGER NOT NULL,
LastSeenAt INTEGER NOT NULL
) WITHOUT ROWID;
CREATE INDEX sessions_last_seen_index ON sessions(LastSeenAt);
CREATE TABLE networks ( CREATE TABLE networks (
NetworkID INTEGER NOT NULL PRIMARY KEY, NetworkID INTEGER NOT NULL PRIMARY KEY,
LocalDomain TEXT NOT NULL UNIQUE, -- Network/interface name. Name TEXT NOT NULL UNIQUE, -- Network/interface name.
Network BLOB NOT NULL UNIQUE -- Network (/24), example 10.51.50.0 Network BLOB NOT NULL UNIQUE -- Network (/24), example 10.51.50.0
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE TABLE peers ( CREATE TABLE peers (
NetworkID INTEGER NOT NULL, NetworkID INTEGER NOT NULL,
PeerIP INTEGER NOT NULL, -- Final byte of IP. PeerIP INTEGER NOT NULL, -- Final byte of IP.
APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key. Version INTEGER NOT NULL, -- Changes when updated.
Name TEXT NOT NULL, -- For humans. APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key.
Addr4 BLOB NOT NULL, Name TEXT NOT NULL UNIQUE, -- For humans.
Addr6 BLOB NOT NULL, PublicIP BLOB NOT NULL,
Port INTEGER NOT NULL, Port INTEGER NOT NULL,
Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address.
WGPubKey BLOB NOT NULL, PubKey BLOB NOT NULL,
SignPubKey BLOB NOT NULL, PubSignKey BLOB NOT NULL,
UNIQUE(NetworkID, Name),
PRIMARY KEY(NetworkID, PeerIP) PRIMARY KEY(NetworkID, PeerIP)
) WITHOUT ROWID; ) WITHOUT ROWID;

View File

@@ -3,12 +3,6 @@ package api
import "vppn/hub/api/db" import "vppn/hub/api/db"
type Config = db.Config type Config = db.Config
type Session = db.Session
type Network = db.Network type Network = db.Network
type Peer = db.Peer type Peer = db.Peer
type Session struct {
SessionID string
SignedIn bool
CreatedAt int64
LastSeenAt int64
}

View File

@@ -2,7 +2,6 @@ package hub
import ( import (
"embed" "embed"
"encoding/base64"
"html/template" "html/template"
"net/http" "net/http"
"path/filepath" "path/filepath"
@@ -48,19 +47,6 @@ func NewApp(conf Config) (*App, error) {
return app, nil return app, nil
} }
func (app *App) Handler() http.Handler {
cop := http.NewCrossOriginProtection()
return cop.Handler(app.mux)
}
var templateFuncs = template.FuncMap{ var templateFuncs = template.FuncMap{
"ipToString": ipBytesTostring, "ipToString": ipBytesTostring,
"wgKeyString": wgKeyString,
}
func wgKeyString(key []byte) string {
if len(key) == 0 {
return "not set"
}
return base64.StdEncoding.EncodeToString(key)
} }

View File

@@ -2,6 +2,7 @@ package hub
import ( import (
"net/http" "net/http"
"time"
) )
func (a *App) getCookie(r *http.Request, name string) string { func (a *App) getCookie(r *http.Request, name string) string {
@@ -25,12 +26,9 @@ func (a *App) setCookie(w http.ResponseWriter, name, value string) {
func (a *App) deleteCookie(w http.ResponseWriter, name string) { func (a *App) deleteCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: name, Name: name,
Value: "", Value: "",
Path: "/", Path: "/",
Secure: !a.insecure, Expires: time.Unix(0, 0),
SameSite: http.SameSiteStrictMode,
HttpOnly: true,
MaxAge: -1, // delete now
}) })
} }

View File

@@ -1,5 +1,5 @@
package hub package hub
const ( const (
sessionIDCookieName = "SessionID" SESSION_ID_COOKIE_NAME = "SessionID"
) )

View File

@@ -12,7 +12,7 @@ type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) er
func (app *App) handlePub(pattern string, fn handlerFunc) { func (app *App) handlePub(pattern string, fn handlerFunc) {
wrapped := func(w http.ResponseWriter, r *http.Request) { wrapped := func(w http.ResponseWriter, r *http.Request) {
sessionID := app.getCookie(r, sessionIDCookieName) sessionID := app.getCookie(r, SESSION_ID_COOKIE_NAME)
s, err := app.api.Session_Get(sessionID) s, err := app.api.Session_Get(sessionID)
if err != nil { if err != nil {
log.Printf("Failed to get session: %v", err) log.Printf("Failed to get session: %v", err)
@@ -20,13 +20,22 @@ func (app *App) handlePub(pattern string, fn handlerFunc) {
return return
} }
if s.SessionID != sessionID {
app.setCookie(w, SESSION_ID_COOKIE_NAME, s.SessionID)
}
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
r.ParseMultipartForm(64 * 1024) r.ParseMultipartForm(64 * 1024)
if r.FormValue("CSRF") != s.CSRF {
log.Printf("%s != %s", r.FormValue("CSRF"), s.CSRF)
http.Error(w, "CSRF mismatch", http.StatusBadRequest)
return
}
} else { } else {
r.ParseForm() r.ParseForm()
} }
if err := fn(&s, w, r); err != nil { if err := fn(s, w, r); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
} }

View File

@@ -5,12 +5,13 @@ import (
"errors" "errors"
"log" "log"
"net/http" "net/http"
"net/netip"
"strings"
"vppn/hub/api" "vppn/hub/api"
"vppn/m" "vppn/m"
"git.crumpington.com/lib/go/webutil" "git.crumpington.com/lib/go/webutil"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
@@ -34,11 +35,9 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque
return err return err
} }
sess, err := a.api.Session_SignIn(pwd) if err := a.api.Session_SignIn(s, pwd); err != nil {
if err != nil {
return err return err
} }
a.setCookie(w, sessionIDCookieName, sess.SessionID)
return a.redirect(w, r, "/") return a.redirect(w, r, "/")
} }
@@ -51,7 +50,7 @@ func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http
if err := a.api.Session_Delete(s.SessionID); err != nil { if err := a.api.Session_Delete(s.SessionID); err != nil {
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err) log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
} }
a.deleteCookie(w, sessionIDCookieName) a.deleteCookie(w, SESSION_ID_COOKIE_NAME)
return a.redirect(w, r, "/") return a.redirect(w, r, "/")
} }
@@ -75,7 +74,7 @@ func (a *App) _adminNetworkCreateSubmit(s *api.Session, w http.ResponseWriter, r
var netStr string var netStr string
err := webutil.NewFormScanner(r.Form). err := webutil.NewFormScanner(r.Form).
Scan("LocalDomain", &n.LocalDomain). Scan("Name", &n.Name).
Scan("Network", &netStr). Scan("Network", &netStr).
Error() Error()
if err != nil { if err != nil {
@@ -145,15 +144,14 @@ func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Re
} }
func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var addr4Str, addr6Str string var ipStr string
p := &api.Peer{} p := &api.Peer{}
err := webutil.NewFormScanner(r.Form). err := webutil.NewFormScanner(r.Form).
Scan("NetworkID", &p.NetworkID). Scan("NetworkID", &p.NetworkID).
Scan("IP", &p.PeerIP). Scan("IP", &p.PeerIP).
Scan("Name", &p.Name). Scan("Name", &p.Name).
Scan("Addr4", &addr4Str). Scan("PublicIP", &ipStr).
Scan("Addr6", &addr6Str).
Scan("Port", &p.Port). Scan("Port", &p.Port).
Scan("Relay", &p.Relay). Scan("Relay", &p.Relay).
Error() Error()
@@ -161,10 +159,7 @@ func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *h
return err return err
} }
if p.Addr4, err = stringToIP(addr4Str); err != nil { if p.PublicIP, err = stringToIP(ipStr); err != nil {
return err
}
if p.Addr6, err = stringToIP(addr6Str); err != nil {
return err return err
} }
@@ -187,6 +182,48 @@ func (a *App) _adminPeerView(s *api.Session, w http.ResponseWriter, r *http.Requ
}{s, net, peer}) }{s, net, peer})
} }
func (a *App) _adminPeerEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
net, peer, err := a.formGetPeer(r.Form)
if err != nil {
return err
}
return a.render("/network/peer-edit.html", w, struct {
Session *api.Session
Network *api.Network
Peer *api.Peer
}{s, net, peer})
}
func (a *App) _adminPeerEditSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
_, peer, err := a.formGetPeer(r.Form)
if err != nil {
return err
}
var ipStr string
err = webutil.NewFormScanner(r.Form).
Scan("Name", &peer.Name).
Scan("PublicIP", &ipStr).
Scan("Port", &peer.Port).
Scan("Relay", &peer.Relay).
Error()
if err != nil {
return err
}
if peer.PublicIP, err = stringToIP(ipStr); err != nil {
return err
}
if err = a.api.Peer_Update(peer); err != nil {
return err
}
return a.redirect(w, r, "/admin/peer/view/?NetworkID=%d&PeerIP=%d", peer.NetworkID, peer.PeerIP)
}
func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, peer, err := a.formGetPeer(r.Form) n, peer, err := a.formGetPeer(r.Form)
if err != nil { if err != nil {
@@ -211,23 +248,40 @@ func (a *App) _adminPeerDeleteSubmit(s *api.Session, w http.ResponseWriter, r *h
return a.redirect(w, r, "/admin/network/view/?NetworkID=%d", n.NetworkID) return a.redirect(w, r, "/admin/network/view/?NetworkID=%d", n.NetworkID)
} }
func (a *App) _adminNetworkHosts(s *api.Session, w http.ResponseWriter, r *http.Request) error {
n, peers, err := a.formGetNetworkPeers(r.Form)
if err != nil {
return err
}
b := strings.Builder{}
for _, peer := range peers {
ip := n.Network
ip[3] = peer.PeerIP
b.WriteString(netip.AddrFrom4([4]byte(ip)).String())
b.WriteString(" ")
b.WriteString(peer.Name)
b.WriteString("\n")
}
w.Write([]byte(b.String()))
return nil
}
func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
return a.render("/admin-password-edit.html", w, struct{ Session *api.Session }{s}) return a.render("/admin-password-edit.html", w, struct{ Session *api.Session }{s})
} }
func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
var ( var (
conf = a.api.Config_Get()
curPwd string curPwd string
newPwd string newPwd string
newPwd2 string newPwd2 string
) )
conf, err := a.api.Config_Get() err := webutil.NewFormScanner(r.Form).
if err != nil {
return err
}
err = webutil.NewFormScanner(r.Form).
Scan("CurrentPassword", &curPwd). Scan("CurrentPassword", &curPwd).
Scan("NewPassword", &newPwd). Scan("NewPassword", &newPwd).
Scan("NewPassword2", &newPwd2). Scan("NewPassword2", &newPwd2).
@@ -264,25 +318,11 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
} }
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error { func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
if len(peer.WGPubKey) != 0 {
http.Error(w, "Already initialized", http.StatusConflict)
return nil
}
args := m.PeerInitArgs{} args := m.PeerInitArgs{}
if err := json.NewDecoder(r.Body).Decode(&args); err != nil { if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
return err return err
} }
if len(args.WGPubKey) != 32 {
http.Error(w, "invalid WGPubKey", http.StatusBadRequest)
return nil
}
if len(args.SignPubKey) != 32 {
http.Error(w, "invalid SignPubKey", http.StatusBadRequest)
return nil
}
net, err := a.api.Network_Get(peer.NetworkID) net, err := a.api.Network_Get(peer.NetworkID)
if err != nil { if err != nil {
return err return err
@@ -293,12 +333,11 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
} }
resp := m.PeerInitResp{ resp := m.PeerInitResp{
PeerIP: peer.PeerIP, PeerIP: peer.PeerIP,
Network: net.Network, Network: net.Network,
LocalDomain: net.LocalDomain,
} }
resp.NetworkState.Peers, err = a.peersList(net.NetworkID) resp.NetworkState.Peers, err = a.peersArray(net.NetworkID)
if err != nil { if err != nil {
return err return err
} }
@@ -307,42 +346,34 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
} }
func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error { func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
peers, err := a.peersList(peer.NetworkID)
peers, err := a.peersArray(peer.NetworkID)
if err != nil { if err != nil {
return err return err
} }
return a.sendJSON(w, m.NetworkState{Peers: peers}) return a.sendJSON(w, m.NetworkState{Peers: peers})
} }
func (a *App) peersList(networkID int64) (peers []m.Peer, err error) { func (a *App) peersArray(networkID int64) (peers [256]*m.Peer, err error) {
l, err := a.api.Peer_List(networkID) l, err := a.api.Peer_List(networkID)
if err != nil { if err != nil {
return nil, err return peers, err
} }
peers = make([]m.Peer, 0, len(l))
for _, p := range l { for _, p := range l {
if len(p.WGPubKey) == 0 { if len(p.PubKey) != 0 {
continue peers[p.PeerIP] = &m.Peer{
PeerIP: p.PeerIP,
Version: p.Version,
Name: p.Name,
PublicIP: p.PublicIP,
Port: p.Port,
Relay: p.Relay,
PubKey: p.PubKey,
PubSignKey: p.PubSignKey,
}
} }
wgKey, err := wgtypes.NewKey(p.WGPubKey)
if err != nil {
continue // malformed key; skip rather than serve garbage
}
var signKey [32]byte
copy(signKey[:], p.SignPubKey)
peers = append(peers, m.Peer{
PeerIP: p.PeerIP,
Name: p.Name,
Addr4: addrFromBytes(p.Addr4),
Addr6: addrFromBytes(p.Addr6),
Port: p.Port,
Relay: p.Relay,
WGPubKey: wgKey,
SignPubKey: signKey,
})
} }
return peers, nil return
} }

View File

@@ -31,7 +31,7 @@ func Main() {
srv := &http.Server{ srv := &http.Server{
Addr: conf.ListenAddr, Addr: conf.ListenAddr,
Handler: app.Handler(), Handler: app.mux,
} }
log.Fatal(webutil.ListenAndServe(srv)) log.Fatal(webutil.ListenAndServe(srv))

View File

@@ -19,9 +19,12 @@ func (a *App) registerRoutes() {
a.handleSignedIn("POST /admin/network/delete/", a._adminNetworkDeleteSubmit) a.handleSignedIn("POST /admin/network/delete/", a._adminNetworkDeleteSubmit)
a.handleSignedIn("GET /admin/network/view/", a._adminNetworkView) a.handleSignedIn("GET /admin/network/view/", a._adminNetworkView)
a.handleSignedIn("GET /admin/network/hosts/", a._adminNetworkHosts)
a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate) a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate)
a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit) a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit)
a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView) a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView)
a.handleSignedIn("GET /admin/peer/edit/", a._adminPeerEdit)
a.handleSignedIn("POST /admin/peer/edit/", a._adminPeerEditSubmit)
a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete) a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete)
a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit) a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit)

View File

@@ -2,9 +2,10 @@
<h2>Create Network</h2> <h2>Create Network</h2>
<form method="POST"> <form method="POST">
<p> <input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<label>Local Domain (ending with .local)</label><br> <p>
<input type="text" name="LocalDomain"> <label>Name</label><br>
<input type="text" name="Name">
</p> </p>
<p> <p>
<label>Network /24</label><br> <label>Network /24</label><br>

View File

@@ -9,7 +9,7 @@
<table> <table>
<thead> <thead>
<tr> <tr>
<th>Local Domain</th> <th>Name</th>
<th>Network</th> <th>Network</th>
</tr> </tr>
</thead> </thead>
@@ -18,7 +18,7 @@
<tr> <tr>
<td> <td>
<a href="/admin/network/view/?NetworkID={{.NetworkID}}"> <a href="/admin/network/view/?NetworkID={{.NetworkID}}">
{{.LocalDomain}} {{.Name}}
</a> </a>
</td> </td>
<td>{{ipToString .Network}}</td> <td>{{ipToString .Network}}</td>

View File

@@ -2,7 +2,8 @@
<h2>Change Password</h2> <h2>Change Password</h2>
<form method="POST"> <form method="POST">
<p> <input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<p>
<label>Current Password</label><br> <label>Current Password</label><br>
<input type="password" name="CurrentPassword"> <input type="password" name="CurrentPassword">
</p> </p>

View File

@@ -2,7 +2,8 @@
<h2>Sign Out</h2> <h2>Sign Out</h2>
<form method="POST"> <form method="POST">
<p> <input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<p>
<button type="submit">Sign Out</button> <button type="submit">Sign Out</button>
<a href="/">Cancel</a> <a href="/">Cancel</a>
</p> </p>

View File

@@ -17,7 +17,7 @@
</header> </header>
<h2> <h2>
Network: Network:
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">{{.Network.LocalDomain}}</a> <a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">{{.Network.Name}}</a>
</h2> </h2>
{{block "body" .}}There's nothing here.{{end}} {{block "body" .}}There's nothing here.{{end}}

View File

@@ -5,7 +5,8 @@
<p>You must first delete all peers.</p> <p>You must first delete all peers.</p>
{{- else -}} {{- else -}}
<form method="POST"> <form method="POST">
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}"> <input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
<p> <p>
<button type="submit">Delete</button> <button type="submit">Delete</button>
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">Cancel</a> <a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">Cancel</a>

View File

@@ -1,6 +1,7 @@
{{define "body" -}} {{define "body" -}}
<p> <p>
<a href="/admin/network/delete/?NetworkID={{.Network.NetworkID}}">Delete</a> <a href="/admin/network/delete/?NetworkID={{.Network.NetworkID}}">Delete</a> /
<a href="/admin/network/hosts/?NetworkID={{.Network.NetworkID}}">Hosts</a>
</p> </p>
<table class="def-list"> <table class="def-list">
@@ -22,8 +23,7 @@
<tr> <tr>
<th>PeerIP</th> <th>PeerIP</th>
<th>Name</th> <th>Name</th>
<th>IPv4</th> <th>Public IP</th>
<th>IPv6</th>
<th>Port</th> <th>Port</th>
<th>Relay</th> <th>Relay</th>
</tr> </tr>
@@ -37,8 +37,7 @@
</a> </a>
</td> </td>
<td>{{.Name}}</td> <td>{{.Name}}</td>
<td>{{ipToString .Addr4}}</td> <td>{{ipToString .PublicIP}}</td>
<td>{{ipToString .Addr6}}</td>
<td>{{.Port}}</td> <td>{{.Port}}</td>
<td>{{if .Relay}}T{{else}}F{{end}}</td> <td>{{if .Relay}}T{{else}}F{{end}}</td>
</tr> </tr>

View File

@@ -2,6 +2,7 @@
<h3>New Peer</h3> <h3>New Peer</h3>
<form method="POST"> <form method="POST">
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}"> <input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
<p> <p>
<label>IP</label><br> <label>IP</label><br>
@@ -12,16 +13,12 @@
<input type="text" name="Name"> <input type="text" name="Name">
</p> </p>
<p> <p>
<label>IPv4 Address (optional)</label><br> <label>Public IP</label><br>
<input type="text" name="Addr4"> <input type="text" name="PublicIP">
</p> </p>
<p> <p>
<label>IPv6 Address (optional)</label><br> <label>Port</label><br>
<input type="text" name="Addr6"> <input type="number" name="Port" value="456">
</p>
<p>
<label>WireGuard Port</label><br>
<input type="number" name="Port" value="51820">
</p> </p>
<p> <p>
<label> <label>

View File

@@ -3,8 +3,9 @@
{{with .Peer -}} {{with .Peer -}}
<form method="POST"> <form method="POST">
<input type="hidden" name="NetworkID" value="{{.NetworkID}}"> <input type="hidden" name="CSRF" value="{{$.Session.CSRF}}">
<input type="hidden" name="PeerIP" value="{{.PeerIP}}"> <input type="hidden" name="NetworkID" value="{{.NetworkID}}">
<input type="hidden" name="NetworkID" value="{{.PeerIP}}">
<p> <p>
<button type="submit">Delete</button> <button type="submit">Delete</button>
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}&NetworkID={{.NetworkID}}">Cancel</a> <a href="/admin/peer/view/?PeerIP={{.PeerIP}}&NetworkID={{.NetworkID}}">Cancel</a>

View File

@@ -0,0 +1,35 @@
{{define "body" -}}
<h2>Edit Peer</h2>
{{with .Peer -}}
<form method="POST">
<input type="hidden" name="CSRF" value="{{$.Session.CSRF}}">
<p>
<label>Peer IP</label><br>
<input type="text" value="{{.PeerIP}}" disabled>
</p>
<p>
<label>Name</label><br>
<input type="text" name="Name" value="{{.Name}}">
</p>
<p>
<label>Public IP</label><br>
<input type="text" name="PublicIP" value="{{ipToString .PublicIP}}">
</p>
<p>
<label>Port</label><br>
<input type="number" name="Port" value="{{.Port}}">
</p>
<p>
<label>
<input type="checkbox" name="Relay" {{if .Relay}}checked{{end}}>
Relay
</label>
</p>
<p>
<button type="submit">Save</button>
<a href="/admin/peer/view/?NetworkID={{$.Network.NetworkID}}&PeerIP={{.PeerIP}}">Cancel</a>
</p>
</form>
{{- end}}
{{- end}}

View File

@@ -1,17 +1,17 @@
{{define "body" -}} {{define "body" -}}
<h3>{{.Peer.Name}}</h3> <h3>{{.Peer.Name}}</h3>
<p> <p>
<a href="/admin/peer/edit/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Edit</a> /
<a href="/admin/peer/delete/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Delete</a> <a href="/admin/peer/delete/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Delete</a>
</p> </p>
{{with .Peer -}} {{with .Peer -}}
<table class="def-list"> <table class="def-list">
<tr><td>Peer IP</td><td>{{.PeerIP}}</td></tr> <tr><td>Peer IP</td><td>{{.PeerIP}}</td></tr>
<tr><td>IPv4 Address</td><td>{{ipToString .Addr4}}</td></tr> <tr><td>Public IP</td><td>{{ipToString .PublicIP}}</td></tr>
<tr><td>IPv6 Address</td><td>{{ipToString .Addr6}}</td></tr> <tr><td>Port</td><td>{{.Port}}</td></tr>
<tr><td>WireGuard Port</td><td>{{.Port}}</td></tr>
<tr><td>Relay</td><td>{{if .Relay}}T{{else}}F{{end}}</td></tr> <tr><td>Relay</td><td>{{if .Relay}}T{{else}}F{{end}}</td></tr>
<tr><td>WG Public Key</td><td>{{wgKeyString .WGPubKey}}</td></tr> </td></tr>
</table> </table>
<details> <details>
@@ -19,6 +19,7 @@
<p>{{.APIKey}}</p> <p>{{.APIKey}}</p>
</details> </details>
{{- end}} {{- end}}
{{- end}} {{- end}}

View File

@@ -2,7 +2,8 @@
<h2>Sign In</h2> <h2>Sign In</h2>
<form method="POST"> <form method="POST">
<p> <input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
<p>
<label>Password</label><br> <label>Password</label><br>
<input type="password" name="Password"> <input type="password" name="Password">
</p> </p>

View File

@@ -38,19 +38,6 @@ func (app *App) sendJSON(w http.ResponseWriter, data any) error {
return nil return nil
} }
// addrFromBytes parses raw IP bytes (4 or 16) into a netip.Addr, unmapping
// IPv4-in-IPv6, returning the zero Addr for empty/invalid input.
func addrFromBytes(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
addr, ok := netip.AddrFromSlice(b)
if !ok {
return netip.Addr{}
}
return addr.Unmap()
}
func stringToIP(in string) ([]byte, error) { func stringToIP(in string) ([]byte, error) {
in = strings.TrimSpace(in) in = strings.TrimSpace(in)
if len(in) == 0 { if len(in) == 0 {

View File

@@ -1,133 +1,28 @@
// The package `m` contains models shared between the hub and peer programs. // The package `m` contains models shared between the hub and peer programs.
package m package m
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type PeerInitArgs struct { type PeerInitArgs struct {
WGPubKey []byte EncPubKey []byte
SignPubKey []byte PubSignKey []byte
} }
type PeerInitResp struct { type PeerInitResp struct {
PeerIP byte PeerIP byte
Network []byte Network []byte
LocalDomain string
NetworkState NetworkState NetworkState NetworkState
} }
// Peer is the network membership record for a single peer, exchanged between
// the hub and peers. Addr4/Addr6 are the peer's public endpoint addresses (zero
// if it has none); Port is its WireGuard listen port, meaningful even for a
// non-public peer (it is the peer's own bind/beacon port).
type Peer struct { type Peer struct {
PeerIP byte PeerIP byte
Version int64
Name string Name string
Addr4 netip.Addr // zero if none PublicIP []byte
Addr6 netip.Addr // zero if none
Port uint16 Port uint16
Relay bool Relay bool
WGPubKey wgtypes.Key PubKey []byte
SignPubKey [32]byte PubSignKey []byte
}
// IsPublic reports whether the peer advertises at least one reachable endpoint.
func (p Peer) IsPublic() bool {
return p.Addr4.IsValid() || p.Addr6.IsValid()
}
// Endpoint4 returns the IPv4 endpoint (addr+port), or the zero AddrPort if the
// peer has no IPv4 address.
func (p Peer) Endpoint4() netip.AddrPort {
if !p.Addr4.IsValid() {
return netip.AddrPort{}
}
return netip.AddrPortFrom(p.Addr4, p.Port)
}
// Endpoint6 returns the IPv6 endpoint (addr+port), or the zero AddrPort if the
// peer has no IPv6 address.
func (p Peer) Endpoint6() netip.AddrPort {
if !p.Addr6.IsValid() {
return netip.AddrPort{}
}
return netip.AddrPortFrom(p.Addr6, p.Port)
}
// PreferredEndpoint returns the IPv4 endpoint if present, else IPv6.
func (p Peer) PreferredEndpoint() netip.AddrPort {
if ep := p.Endpoint4(); ep.IsValid() {
return ep
}
return p.Endpoint6()
}
// peerJSON is the wire representation. netip.Addr fields round-trip as text
// strings automatically; only the fixed-size key arrays need base64 (otherwise
// encoding/json would emit them as arrays of numbers).
type peerJSON struct {
PeerIP byte
Name string
Addr4 netip.Addr
Addr6 netip.Addr
Port uint16
Relay bool
WGPubKey string
SignPubKey string
}
func (p Peer) MarshalJSON() ([]byte, error) {
return json.Marshal(peerJSON{
PeerIP: p.PeerIP,
Name: p.Name,
Addr4: p.Addr4,
Addr6: p.Addr6,
Port: p.Port,
Relay: p.Relay,
WGPubKey: base64.StdEncoding.EncodeToString(p.WGPubKey[:]),
SignPubKey: base64.StdEncoding.EncodeToString(p.SignPubKey[:]),
})
}
func (p *Peer) UnmarshalJSON(data []byte) error {
var j peerJSON
if err := json.Unmarshal(data, &j); err != nil {
return err
}
wg, err := base64.StdEncoding.DecodeString(j.WGPubKey)
if err != nil {
return fmt.Errorf("decode WGPubKey: %w", err)
}
key, err := wgtypes.NewKey(wg)
if err != nil {
return fmt.Errorf("invalid WGPubKey: %w", err)
}
sign, err := base64.StdEncoding.DecodeString(j.SignPubKey)
if err != nil {
return fmt.Errorf("decode SignPubKey: %w", err)
}
if len(sign) != 32 {
return fmt.Errorf("invalid SignPubKey length: %d", len(sign))
}
*p = Peer{
PeerIP: j.PeerIP,
Name: j.Name,
Addr4: j.Addr4,
Addr6: j.Addr6,
Port: j.Port,
Relay: j.Relay,
WGPubKey: key,
SignPubKey: [32]byte(sign),
}
return nil
} }
type NetworkState struct { type NetworkState struct {
Peers []Peer Peers [256]*Peer
} }

View File

@@ -1,105 +0,0 @@
package peer
import (
"net/netip"
"os"
"os/signal"
"syscall"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/control"
"vppn/peer/multicast"
"vppn/peer/wginterface"
)
var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisfies WGDevice
const (
ControlPort = 4561
PingInterval = 8 * time.Second
TimeoutInterval = 30 * time.Second
)
// scratchSize is large enough for the biggest buffer either the ping or the
// multicast path serializes through the shared App scratch.
const scratchSize = max(control.Size, multicast.SignedPacketSize)
type PingEvent struct {
srcVPNIP netip.Addr
ping control.Ping
}
// App is the peer application. All mutable state lives here and is
// accessed only from the Run goroutine.
type App struct {
// Identity
vpnIP netip.Addr
vpnNet netip.Prefix
privKey wgtypes.Key
pubKey wgtypes.Key
isRelay bool
isPublic bool
localDomain string
// Infrastructure
dev WGDevice
controlConn ControlConn
// Peer state
relay *Peer
peersByKey map[wgtypes.Key]*Peer
peersByIP map[netip.Addr]*Peer
// Our own external endpoints, learned from Dst fields in incoming pings
selfV4 netip.AddrPort
selfV6 netip.AddrPort
// Reusable serialization scratch for outgoing pings and multicast signature
// verification. Only touched from the Run goroutine.
scratch []byte
// Event channels fed by background goroutines
hubAddCh <-chan m.Peer
hubRemoveCh <-chan wgtypes.Key
pingCh <-chan PingEvent
multicastCh <-chan multicast.Packet
}
// Run is the main event loop. It runs until SIGTERM/SIGINT.
func (a *App) Run() error {
// Establish a clean hosts section before the first poll lands, clearing
// any stale entries left by a prior run (e.g. crash, or peers removed
// while we were down).
a.updateHosts()
ticker := time.NewTicker(PingInterval)
defer ticker.Stop()
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(sig)
for {
select {
case p := <-a.hubAddCh:
a.onAddPeer(p)
case key := <-a.hubRemoveCh:
a.onRemovePeer(key)
case e := <-a.pingCh:
a.onPing(e)
case e := <-a.multicastCh:
a.onMulticastDiscovery(e)
case <-ticker.C:
a.onTick()
case <-sig:
return a.onShutdown()
}
}
}
func (a *App) onShutdown() error {
return wginterface.Delete(a.dev.Name())
}

View File

@@ -1,62 +0,0 @@
package peer
import (
"net/netip"
"testing"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/multicast"
)
// addRelayPeer adds a public relay peer and marks it Up so it satisfies
// CanRelay. It does not set a.relay — callers do that explicitly.
func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer {
t.Helper()
key := mustKey(t)
ip := netip.MustParseAddr(vpnIP)
a.onAddPeer(m.Peer{
WGPubKey: key,
PeerIP: ip.As4()[3],
Addr4: ep.Addr(),
Port: ep.Port(),
Relay: true,
})
p := a.peersByKey[key]
p.wgPeer.LastHandshakeTime = time.Now()
return p
}
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
// vpnIP is the local VPN address (e.g. "10.0.0.1").
// isPublic / isRelay describe the local node's role.
func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) {
t.Helper()
privKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
ip := netip.MustParseAddr(vpnIP)
dev := &fakeWGDevice{}
cc := &fakeControlConn{}
a := &App{
vpnIP: ip,
vpnNet: netip.MustParsePrefix("10.0.0.0/24"),
privKey: privKey,
pubKey: privKey.PublicKey(),
isPublic: isPublic,
isRelay: isRelay,
dev: dev,
controlConn: cc,
peersByKey: make(map[wgtypes.Key]*Peer),
peersByIP: make(map[netip.Addr]*Peer),
scratch: make([]byte, scratchSize),
hubAddCh: make(chan m.Peer),
hubRemoveCh: make(chan wgtypes.Key),
pingCh: make(chan PingEvent),
multicastCh: make(chan multicast.Packet),
}
return a, dev, cc
}

21
peer/bitset.go Normal file
View File

@@ -0,0 +1,21 @@
package peer
const bitSetSize = 512 // Multiple of 64.
type bitSet [bitSetSize / 64]uint64
func (bs *bitSet) Set(i int) {
bs[i/64] |= 1 << (i % 64)
}
func (bs *bitSet) Clear(i int) {
bs[i/64] &= ^(1 << (i % 64))
}
func (bs *bitSet) ClearAll() {
clear(bs[:])
}
func (bs *bitSet) Get(i int) bool {
return bs[i/64]&(1<<(i%64)) != 0
}

48
peer/bitset_test.go Normal file
View File

@@ -0,0 +1,48 @@
package peer
import (
"math/rand"
"testing"
)
func TestBitSet(t *testing.T) {
state := make([]bool, bitSetSize)
for i := range state {
state[i] = rand.Float32() > 0.5
}
bs := bitSet{}
for i := range state {
if state[i] {
bs.Set(i)
}
}
for i := range state {
if bs.Get(i) != state[i] {
t.Fatal(i, state[i], bs.Get(i))
}
}
for i := range state {
if rand.Float32() > 0.5 {
state[i] = false
bs.Clear(i)
}
}
for i := range state {
if bs.Get(i) != state[i] {
t.Fatal(i, state[i], bs.Get(i))
}
}
bs.ClearAll()
for i := range state {
if bs.Get(i) {
t.Fatal(i, bs.Get(i))
}
}
}

26
peer/cipher-control.go Normal file
View File

@@ -0,0 +1,26 @@
package peer
import "golang.org/x/crypto/nacl/box"
type controlCipher struct {
sharedKey [32]byte
}
func newControlCipher(privKey, pubKey []byte) *controlCipher {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
return &controlCipher{shared}
}
func (cc *controlCipher) Encrypt(h Header, data, out []byte) []byte {
const s = controlHeaderSize
out = out[:s+controlCipherOverhead+len(data)]
h.Marshal(out[:s])
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
return out
}
func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = controlHeaderSize
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
}

122
peer/cipher-control_test.go Normal file
View File

@@ -0,0 +1,122 @@
package peer
import (
"bytes"
"crypto/rand"
"reflect"
"testing"
"golang.org/x/crypto/nacl/box"
)
func newControlCipherForTesting() (c1, c2 *controlCipher) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
return newControlCipher(privKey1[:], pubKey2[:]),
newControlCipher(privKey2[:], pubKey1[:])
}
func TestControlCipher(t *testing.T) {
c1, c2 := newControlCipherForTesting()
maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := Header{
StreamID: controlStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
h2 := Header{}
h2.Parse(encrypted)
if !reflect.DeepEqual(h1, h2) {
t.Fatal(h1, h2)
}
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatal("not equal")
}
}
}
func TestControlCipher_ShortCiphertext(t *testing.T) {
c1, _ := newControlCipherForTesting()
shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
rand.Read(shortText)
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkControlCipher_Encrypt(b *testing.B) {
c1, _ := newControlCipherForTesting()
h1 := Header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = c1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkControlCipher_Decrypt(b *testing.B) {
c1, c2 := newControlCipherForTesting()
h1 := Header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = c2.Decrypt(encrypted, decrypted)
}
}

61
peer/cipher-data.go Normal file
View File

@@ -0,0 +1,61 @@
package peer
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"log"
)
type dataCipher struct {
key [32]byte
aead cipher.AEAD
}
func newDataCipher() *dataCipher {
key := [32]byte{}
if _, err := rand.Read(key[:]); err != nil {
log.Fatalf("Failed to read random data: %v", err)
}
return newDataCipherFromKey(key)
}
func newDataCipherFromKey(key [32]byte) *dataCipher {
block, err := aes.NewCipher(key[:])
if err != nil {
log.Fatalf("Failed to create new cipher: %v", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
log.Fatalf("Failed to create new GCM: %v", err)
}
return &dataCipher{key: key, aead: aead}
}
func (sc *dataCipher) Key() [32]byte {
return sc.key
}
func (sc *dataCipher) Encrypt(h Header, data, out []byte) []byte {
const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(out[:s])
sc.aead.Seal(out[s:s], out[:s], data, nil)
return out
}
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = dataHeaderSize
if len(encrypted) < s+dataCipherOverhead {
ok = false
return
}
var err error
data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
ok = err == nil
return
}

141
peer/cipher-data_test.go Normal file
View File

@@ -0,0 +1,141 @@
package peer
import (
"bytes"
"crypto/rand"
mrand "math/rand/v2"
"reflect"
"testing"
)
func TestDataCipher(t *testing.T) {
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := Header{
StreamID: dataStreamID,
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
h2 := Header{}
h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key())
decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatal("not equal")
}
if !reflect.DeepEqual(h1, h2) {
t.Fatalf("%v != %v", h1, h2)
}
}
}
func TestDataCipher_ModifyCiphertext(t *testing.T) {
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := Header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
encrypted[mrand.IntN(len(encrypted))]++
dc2 := newDataCipherFromKey(dc1.Key())
_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if ok {
t.Fatal(ok)
}
}
}
func TestDataCipher_ShortCiphertext(t *testing.T) {
dc1 := newDataCipher()
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
rand.Read(shortText)
_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkDataCipher_Encrypt(b *testing.B) {
h1 := Header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkDataCipher_Decrypt(b *testing.B) {
h1 := Header{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = dc1.Decrypt(encrypted, decrypted)
}
}

46
peer/connreader.go Normal file
View File

@@ -0,0 +1,46 @@
package peer
import (
"log"
"net"
"net/netip"
)
type ConnReader struct {
Globals
conn *net.UDPConn
buf []byte
}
func NewConnReader(g Globals, conn *net.UDPConn) *ConnReader {
return &ConnReader{
Globals: g,
conn: conn,
buf: make([]byte, bufferSize),
}
}
func (r *ConnReader) Run() {
for {
r.handleNextPacket()
}
}
func (r *ConnReader) handleNextPacket() {
buf := r.buf[:bufferSize]
n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf)
if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err)
}
if n < headerSize {
return
}
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
buf = buf[:n]
h := parseHeader(buf)
r.RemotePeers[h.SourceIP].Load().HandlePacket(h, remoteAddr, buf)
}

View File

@@ -1,76 +0,0 @@
// Package control implements the VPN-internal peer control protocol.
// Peers exchange Ping packets over UDP on the VPN control port to maintain
// liveness and discover external endpoints for direct connection attempts.
package control
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
version = 1
Size = 51 // 1 version + 8 PingTS + 6 SrcV4 + 18 SrcV6 + 18 Dst
)
// Ping is the single control packet type exchanged between VPN peers.
//
// In each peer pair, the peer with the lower VPN IP is the client: it sets
// PingTS and sends pings on a timer. The server echoes PingTS back in its
// response, allowing the client to compute RTT = now - PingTS.
//
// Both client and server populate SrcV4, SrcV6, and Dst on every packet so
// endpoint information flows in both directions.
//
// Dst is the recipient's external endpoint as observed by the sender from the
// WireGuard handshake source. Zero if the sender has not observed a handshake
// from the recipient.
type Ping struct {
PingTS int64 // Client ping send time in nanoseconds.
SrcV4 netip.AddrPort // Sender's discovered IPv4 address and port.
SrcV6 netip.AddrPort // Sender's discovered IPv6 address and port.
Dst netip.AddrPort
}
// Marshal encodes p into buf (which must be at least Size bytes) and returns
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
// field is written unconditionally so a reused buffer needs no pre-zeroing.
func (p Ping) Marshal(buf []byte) []byte {
buf[0] = version
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
if p.SrcV4.IsValid() {
a4 := p.SrcV4.Addr().As4()
copy(buf[9:13], a4[:])
binary.BigEndian.PutUint16(buf[13:15], p.SrcV4.Port())
} else {
clear(buf[9:15])
}
a16 := p.SrcV6.Addr().As16()
copy(buf[15:31], a16[:])
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
a16 = p.Dst.Addr().As16()
copy(buf[33:49], a16[:])
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
return buf[:Size]
}
// Unmarshal decodes a Ping from a fixed-size 51-byte array.
func Unmarshal(buf [Size]byte) (Ping, error) {
if buf[0] != version {
return Ping{}, fmt.Errorf("unknown ping version %d", buf[0])
}
p := Ping{
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
}
if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() {
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
}
if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() {
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
}
if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() {
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
}
return p, nil
}

View File

@@ -1,106 +0,0 @@
package control_test
import (
"net/netip"
"testing"
"vppn/peer/control"
)
func TestRoundTrip(t *testing.T) {
cases := []struct {
name string
ping control.Ping
}{
{
name: "zero",
ping: control.Ping{},
},
{
name: "client ping",
ping: control.Ping{
PingTS: 1234567890,
SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"),
Dst: netip.MustParseAddrPort("5.6.7.8:51820"),
},
},
{
name: "server response",
ping: control.Ping{
PingTS: 1234567890,
SrcV4: netip.MustParseAddrPort("5.6.7.8:51820"),
Dst: netip.MustParseAddrPort("1.2.3.4:9999"),
},
},
{
name: "IPv6 only",
ping: control.Ping{
PingTS: 999,
SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"),
Dst: netip.MustParseAddrPort("[2001:db8::2]:51820"),
},
},
{
name: "dual stack",
ping: control.Ping{
PingTS: 555,
SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"),
SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"),
Dst: netip.MustParseAddrPort("5.6.7.8:9999"),
},
},
{
name: "no src known",
ping: control.Ping{
Dst: netip.MustParseAddrPort("5.6.7.8:51820"),
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var buf [control.Size]byte
tc.ping.Marshal(buf[:])
got, err := control.Unmarshal(buf)
if err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if got != tc.ping {
t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, tc.ping)
}
})
}
}
func TestUnmarshalBadVersion(t *testing.T) {
var buf [control.Size]byte
buf[0] = 99
if _, err := control.Unmarshal(buf); err == nil {
t.Fatal("expected error for unknown version, got nil")
}
}
func TestZeroEncoding(t *testing.T) {
var buf [control.Size]byte
(control.Ping{}).Marshal(buf[:])
for i, b := range buf {
if i == 0 {
continue // version byte
}
if b != 0 {
t.Fatalf("expected zero encoding at byte %d, got %d", i, b)
}
}
}
func TestRoleFor(t *testing.T) {
lo := netip.MustParseAddr("10.0.0.1")
hi := netip.MustParseAddr("10.0.0.2")
if control.RoleFor(lo, hi) != control.Client {
t.Error("lower IP should be client")
}
if control.RoleFor(hi, lo) != control.Server {
t.Error("higher IP should be server")
}
}

View File

@@ -1,22 +0,0 @@
package control
import "net/netip"
// Role identifies a peer's role in a ping exchange with a specific remote peer.
type Role string
const (
// Client initiates pings and measures RTT.
Client Role = "CLIENT"
// Server responds to pings.
Server Role = "SERVER"
)
// RoleFor returns the Role of local relative to remote.
// The peer with the lower VPN IP is the client.
func RoleFor(local, remote netip.Addr) Role {
if local.Compare(remote) < 0 {
return Client
}
return Server
}

View File

@@ -1,64 +0,0 @@
package peer
import (
"log"
"net"
"net/netip"
"vppn/peer/control"
)
var _ ControlConn = (*udpControlConn)(nil)
type udpControlConn struct {
conn *net.UDPConn
}
// newUDPControlConn opens a UDP socket bound to localIP:port.
func newUDPControlConn(localIP netip.Addr, port uint16) (*udpControlConn, error) {
addr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(localIP, port))
conn, err := net.ListenUDP("udp4", addr)
if err != nil {
return nil, err
}
return &udpControlConn{conn: conn}, nil
}
func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error {
_, err := c.conn.WriteToUDP(ping.Marshal(buf), net.UDPAddrFromAddrPort(dst))
return err
}
// run reads incoming ping packets and forwards them to ch until ctx is done.
// Call this in a goroutine before starting the App event loop.
func (c *udpControlConn) run(ch chan<- PingEvent) {
var buf [control.Size]byte
for {
n, src, err := c.conn.ReadFromUDP(buf[:])
if err != nil {
log.Printf("control read: %v", err)
continue
}
if n != control.Size {
continue
}
ping, err := control.Unmarshal(buf)
if err != nil {
log.Printf("control unmarshal: %v", err)
continue
}
srcIP, ok := netip.AddrFromSlice(src.IP)
if !ok {
continue
}
ch <- PingEvent{srcVPNIP: srcIP.Unmap(), ping: ping}
}
}
func (c *udpControlConn) Close() error {
return c.conn.Close()
}

64
peer/controlmessage.go Normal file
View File

@@ -0,0 +1,64 @@
package peer
import (
"net/netip"
"vppn/m"
)
// ----------------------------------------------------------------------------
type controlMsg[T any] struct {
SrcIP byte
SrcAddr netip.AddrPort
Packet T
}
func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) {
switch buf[0] {
case packetTypeInit:
packet, err := parsePacketInit(buf)
return controlMsg[packetInit]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
case packetTypeSyn:
packet, err := parsePacketSyn(buf)
return controlMsg[packetSyn]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
case packetTypeAck:
packet, err := parsePacketAck(buf)
return controlMsg[packetAck]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
case packetTypeProbe:
packet, err := parsePacketProbe(buf)
return controlMsg[packetProbe]{
SrcIP: srcIP,
SrcAddr: srcAddr,
Packet: packet,
}, err
default:
return nil, errUnknownPacketType
}
}
// ----------------------------------------------------------------------------
type peerUpdateMsg struct {
Peer *m.Peer
}
// ----------------------------------------------------------------------------
type pingTimerMsg struct{}

30
peer/crypto.go Normal file
View File

@@ -0,0 +1,30 @@
package peer
import (
"crypto/rand"
"log"
"golang.org/x/crypto/nacl/box"
"golang.org/x/crypto/nacl/sign"
)
type cryptoKeys struct {
PubKey []byte
PrivKey []byte
PubSignKey []byte
PrivSignKey []byte
}
func generateKeys() cryptoKeys {
pubKey, privKey, err := box.GenerateKey(rand.Reader)
if err != nil {
log.Fatalf("Failed to generate encryption keys: %v", err)
}
pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader)
if err != nil {
log.Fatalf("Failed to generate signing keys: %v", err)
}
return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]}
}

View File

@@ -1,79 +0,0 @@
package peer
import (
"errors"
"log"
"net/netip"
"syscall"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// devRetry calls fn up to 6 times with exponential backoff, retrying on EBUSY
// (transient netlink contention during WireGuard handshake/rekey). Fatal on
// any other error.
func devRetry(vpnIP netip.Addr, op string, fn func() error) {
const attempts = 6
timeout := 10 * time.Millisecond
for i := range attempts {
err := fn()
if err == nil {
return
}
if errors.Is(err, syscall.EBUSY) && i < attempts-1 {
time.Sleep(timeout)
timeout *= 2
continue
}
log.Fatalf("%s %v: %v", op, vpnIP, err)
}
}
func (a *App) devPeers() []wgtypes.Peer {
peers, err := a.dev.Peers()
if err != nil {
log.Fatalf("Failed to get peers %v: %v", a.vpnIP, err)
}
return peers
}
func (a *App) devAddPeer(p *Peer) {
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
p.State = StateRelayed
}
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
p.State = StateDirect
}
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
p.State = StateDirect // Dirrect connection. The app marks peer as relay.
}
func (a *App) devPromote(p *Peer) {
ep := p.WGEndpoint()
if ep.IsValid() {
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
} else {
log.Printf("PROMOTED: %s - %s (no IP)", p.Name, p.VPNIP.String())
}
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
p.State = StateDirect
}
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
p.State = StateProbing
}
func (a *App) devRemove(p *Peer) {
log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String())
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
}

76
peer/dupcheck.go Normal file
View File

@@ -0,0 +1,76 @@
package peer
type dupCheck struct {
bitSet
head int
tail int
headCounter uint64
tailCounter uint64 // Also next expected counter value.
}
func newDupCheck(headCounter uint64) *dupCheck {
return &dupCheck{
headCounter: headCounter,
tailCounter: headCounter + 1,
tail: 1,
}
}
func (dc *dupCheck) IsDup(counter uint64) bool {
// Before head => it's late, say it's a dup.
if counter < dc.headCounter {
return true
}
// It's within the counter bounds.
if counter < dc.tailCounter {
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
if dc.Get(index) {
return true
}
dc.Set(index)
return false
}
// It's more than 1 beyond the tail.
delta := counter - dc.tailCounter
// Full clear.
if delta >= bitSetSize-1 {
dc.ClearAll()
dc.Set(0)
dc.tail = 1
dc.head = 2
dc.tailCounter = counter + 1
dc.headCounter = dc.tailCounter - bitSetSize + 1
return false
}
// Clear if necessary.
for range delta {
dc.put(false)
}
dc.put(true)
return false
}
func (dc *dupCheck) put(set bool) {
if set {
dc.Set(dc.tail)
} else {
dc.Clear(dc.tail)
}
dc.tail = (dc.tail + 1) % bitSetSize
dc.tailCounter++
if dc.head == dc.tail {
dc.head = (dc.head + 1) % bitSetSize
dc.headCounter++
}
}

57
peer/dupcheck_test.go Normal file
View File

@@ -0,0 +1,57 @@
package peer
import (
"testing"
)
func TestDupCheck(t *testing.T) {
dc := newDupCheck(0)
for i := range bitSetSize {
if dc.IsDup(uint64(i)) {
t.Fatal("!")
}
}
type TestCase struct {
Counter uint64
Dup bool
}
testCases := []TestCase{
{511, true},
{0, true},
{1, true},
{2, true},
{3, true},
{63, true},
{256, true},
{510, true},
{511, true},
{512, false},
{0, true},
{512, true},
{513, false},
{517, false},
{512, true},
{513, true},
{514, false},
{515, false},
{516, false},
{517, true},
{2512, false},
{2512, true},
{2001, true},
{2002, false},
{2002, true},
{4000, false},
{4000 - 511, true}, // Too old.
{4000 - 510, false}, // Just in the window.
}
for i, tc := range testCases {
if ok := dc.IsDup(tc.Counter); ok != tc.Dup {
t.Fatal(i, ok, tc)
}
}
}

8
peer/errors.go Normal file
View File

@@ -0,0 +1,8 @@
package peer
import "errors"
var (
errMalformedPacket = errors.New("malformed packet")
errUnknownPacketType = errors.New("unknown packet type")
)

View File

@@ -1,43 +0,0 @@
package peer
import (
"net/netip"
"testing"
"vppn/peer/control"
)
type sentPing struct {
Dst netip.AddrPort
Ping control.Ping
}
type fakeControlConn struct {
Sent []sentPing
}
func (f *fakeControlConn) SendPing(dst netip.AddrPort, ping control.Ping, _ []byte) error {
f.Sent = append(f.Sent, sentPing{Dst: dst, Ping: ping})
return nil
}
func (f *fakeControlConn) AssertNone(t *testing.T) {
t.Helper()
if len(f.Sent) != 0 {
t.Fatalf("expected no pings sent, got %d: %v", len(f.Sent), f.Sent)
}
}
func (f *fakeControlConn) AssertSent(t *testing.T, i int, dst netip.AddrPort, ping control.Ping) {
t.Helper()
if i >= len(f.Sent) {
t.Fatalf("no ping at index %d (have %d)", i, len(f.Sent))
}
got := f.Sent[i]
if got.Dst != dst {
t.Errorf("ping[%d].Dst = %v, want %v", i, got.Dst, dst)
}
if got.Ping != ping {
t.Errorf("ping[%d].Ping = %+v, want %+v", i, got.Ping, ping)
}
}

View File

@@ -1,123 +0,0 @@
package peer
import (
"net/netip"
"sync"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// fakeWGDevice records every call made to it. It is safe to read Calls after
// the event loop has processed the event under test (single-threaded loop
// means no extra synchronisation needed, but the mutex guards concurrent test
// helpers if needed).
type fakeWGDevice struct {
mu sync.Mutex
Calls []fakeCall
peers []wgtypes.Peer
}
type fakeCall struct {
Method string
PubKey wgtypes.Key
Endpoint netip.AddrPort
VPNiP netip.Addr
Network netip.Prefix
}
func (f *fakeWGDevice) record(c fakeCall) {
f.mu.Lock()
f.Calls = append(f.Calls, c)
f.mu.Unlock()
}
func (f *fakeWGDevice) Name() string { return "wg-test" }
func (f *fakeWGDevice) Peers() ([]wgtypes.Peer, error) {
f.mu.Lock()
defer f.mu.Unlock()
out := make([]wgtypes.Peer, len(f.peers))
copy(out, f.peers)
return out, nil
}
func (f *fakeWGDevice) AddPeer(pubKey wgtypes.Key) error {
f.record(fakeCall{Method: "AddPeer", PubKey: pubKey})
return nil
}
func (f *fakeWGDevice) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error {
f.record(fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
return nil
}
func (f *fakeWGDevice) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error {
f.record(fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
return nil
}
func (f *fakeWGDevice) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
f.record(fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
return nil
}
func (f *fakeWGDevice) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error {
f.record(fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
return nil
}
func (f *fakeWGDevice) RemovePeer(pubKey wgtypes.Key) error {
f.record(fakeCall{Method: "RemovePeer", PubKey: pubKey})
return nil
}
// AssertNoCalls fails the test if any dev calls were recorded.
func (f *fakeWGDevice) AssertNoCalls(t *testing.T) {
t.Helper()
f.mu.Lock()
defer f.mu.Unlock()
if len(f.Calls) != 0 {
t.Fatalf("unexpected dev calls: %v", f.Calls)
}
}
func (f *fakeWGDevice) AssertAddPeer(t *testing.T, i int, pubKey wgtypes.Key) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "AddPeer", PubKey: pubKey})
}
func (f *fakeWGDevice) AssertAddDirect(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
}
func (f *fakeWGDevice) AssertSetRelay(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
}
func (f *fakeWGDevice) AssertAddProbe(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
}
func (f *fakeWGDevice) AssertPromote(t *testing.T, i int, pubKey wgtypes.Key, vpnIP netip.Addr) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
}
func (f *fakeWGDevice) AssertRemovePeer(t *testing.T, i int, pubKey wgtypes.Key) {
t.Helper()
f.assertCall(t, i, fakeCall{Method: "RemovePeer", PubKey: pubKey})
}
func (f *fakeWGDevice) assertCall(t *testing.T, i int, c fakeCall) {
t.Helper()
if len(f.Calls) <= i {
t.Fatalf("no call at index %d: %v", i, c)
}
if c != f.Calls[i] {
t.Fatalf("call[%d]: got %v, want %v", i, f.Calls[i], c)
}
}

115
peer/files.go Normal file
View File

@@ -0,0 +1,115 @@
package peer
import (
"encoding/json"
"log"
"os"
"path/filepath"
"vppn/m"
)
type LocalConfig struct {
LocalPeerIP byte
Network []byte
PubKey []byte
PrivKey []byte
PubSignKey []byte
PrivSignKey []byte
}
type startupCount struct {
Count uint16
}
func configDir(netName string) string {
d, err := os.UserHomeDir()
if err != nil {
log.Fatalf("Failed to get user home directory: %v", err)
}
return filepath.Join(d, ".vppn", netName)
}
func lockFilePath(netName string) string {
return filepath.Join(configDir(netName), "__lock__")
}
func peerConfigPath(netName string) string {
return filepath.Join(configDir(netName), "config.json")
}
func peerStatePath(netName string) string {
return filepath.Join(configDir(netName), "state.json")
}
func startupCountPath(netName string) string {
return filepath.Join(configDir(netName), "startup_count.json")
}
func statusSocketPath(netName string) string {
return filepath.Join(configDir(netName), "status.sock")
}
func storeJson(x any, outPath string) error {
outDir := filepath.Dir(outPath)
_ = os.MkdirAll(outDir, 0700)
tmpPath := outPath + ".tmp"
buf, err := json.Marshal(x)
if err != nil {
return err
}
f, err := os.Create(tmpPath)
if err != nil {
return err
}
if _, err := f.Write(buf); err != nil {
f.Close()
return err
}
if err := f.Sync(); err != nil {
f.Close()
return err
}
if err := f.Close(); err != nil {
return err
}
return os.Rename(tmpPath, outPath)
}
func storePeerConfig(netName string, pc LocalConfig) error {
return storeJson(pc, peerConfigPath(netName))
}
func storeNetworkState(netName string, ps m.NetworkState) error {
return storeJson(ps, peerStatePath(netName))
}
func loadJson(dataPath string, ptr any) error {
data, err := os.ReadFile(dataPath)
if err != nil {
return err
}
return json.Unmarshal(data, ptr)
}
func loadPeerConfig(netName string) (pc LocalConfig, err error) {
return pc, loadJson(peerConfigPath(netName), &pc)
}
func loadNetworkState(netName string) (ps m.NetworkState, err error) {
return ps, loadJson(peerStatePath(netName), &ps)
}
func loadStartupCount(netName string) (c startupCount, err error) {
return c, loadJson(startupCountPath(netName), &c)
}
func storeStartupCount(netName string, c startupCount) error {
return storeJson(c, startupCountPath(netName))
}

57
peer/files_test.go Normal file
View File

@@ -0,0 +1,57 @@
package peer
import (
"path/filepath"
"reflect"
"testing"
)
func TestFilePaths(t *testing.T) {
confDir := configDir("netName")
if filepath.Base(confDir) != "netName" {
t.Fatal(confDir)
}
if filepath.Base(filepath.Dir(confDir)) != ".vppn" {
t.Fatal(confDir)
}
path := peerConfigPath("netName")
if path != filepath.Join(confDir, "config.json") {
t.Fatal(path)
}
path = peerStatePath("netName")
if path != filepath.Join(confDir, "state.json") {
t.Fatal(path)
}
}
func TestStoreLoadJson(t *testing.T) {
type Object struct {
Name string
Age int
Price float64
}
tmpDir := t.TempDir()
outPath := filepath.Join(tmpDir, "object.json")
obj := Object{
Name: "Jason",
Age: 22,
Price: 123.534,
}
if err := storeJson(obj, outPath); err != nil {
t.Fatal(err)
}
obj2 := Object{}
if err := loadJson(outPath, &obj2); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(obj, obj2) {
t.Fatal(obj, obj2)
}
}

109
peer/globals.go Normal file
View File

@@ -0,0 +1,109 @@
package peer
import (
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
)
const (
version = 1
bufferSize = 8192 // Enough for data packets and encryption buffers.
if_mtu = 1200
if_queue_len = 2048
controlCipherOverhead = 16
dataCipherOverhead = 16
signingOverhead = 64
pingInterval = 8 * time.Second
timeoutInterval = 30 * time.Second
broadcastInterval = 16 * time.Second
broadcastErrorTimeoutInterval = 8 * time.Second
)
var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
4560))
// ----------------------------------------------------------------------------
type Globals struct {
LocalConfig // Embed, immutable.
// The number of startups
StartupCount uint16
// Local public address (if available). Immutable.
LocalAddr netip.AddrPort
// True if local public address is valid. Immutable.
LocalAddrValid bool
// All remote peers by VPN IP.
RemotePeers [256]*atomic.Pointer[Remote]
// Discovered public addresses.
PubAddrs *pubAddrStore
// Attempts to ensure that we have a relay available.
RelayHandler *relayHandler
// Send UDP - Global function to write UDP packets.
SendUDP func(b []byte, addr netip.AddrPort) (n int, err error)
// Global TUN interface.
IFace io.ReadWriteCloser
// For trace ID.
NewTraceID func() uint64
}
func NewGlobals(
localConfig LocalConfig,
startupCount startupCount,
localAddr netip.AddrPort,
conn *net.UDPConn,
iface io.ReadWriteCloser,
) (g Globals) {
g.LocalConfig = localConfig
g.StartupCount = startupCount.Count
g.LocalAddr = localAddr
g.LocalAddrValid = localAddr.IsValid()
g.PubAddrs = newPubAddrStore(localAddr)
g.RelayHandler = newRelayHandler()
// Use a lock here avoids starvation, at least on my Linux machine.
sendLock := sync.Mutex{}
g.SendUDP = func(b []byte, addr netip.AddrPort) (int, error) {
sendLock.Lock()
n, err := conn.WriteToUDPAddrPort(b, addr)
sendLock.Unlock()
return n, err
}
g.IFace = iface
traceID := (uint64(g.StartupCount) << 48) + 1
g.NewTraceID = func() uint64 {
return atomic.AddUint64(&traceID, 1)
}
for i := range g.RemotePeers {
g.RemotePeers[i] = &atomic.Pointer[Remote]{}
}
for i := range g.RemotePeers {
g.RemotePeers[i].Store(newRemote(g, byte(i)))
}
return g
}

47
peer/header.go Normal file
View File

@@ -0,0 +1,47 @@
package peer
import "unsafe"
// ----------------------------------------------------------------------------
const (
headerSize = 12
controlHeaderSize = 24
dataHeaderSize = 12
dataStreamID = 1
controlStreamID = 2
)
type Header struct {
Version byte
StreamID byte
SourceIP byte
DestIP byte
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
}
func parseHeader(b []byte) (h Header) {
h.Version = b[0]
h.StreamID = b[1]
h.SourceIP = b[2]
h.DestIP = b[3]
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
return h
}
func (h *Header) Parse(b []byte) {
h.Version = b[0]
h.StreamID = b[1]
h.SourceIP = b[2]
h.DestIP = b[3]
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
}
func (h *Header) Marshal(buf []byte) {
buf[0] = h.Version
buf[1] = h.StreamID
buf[2] = h.SourceIP
buf[3] = h.DestIP
*(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter
}

21
peer/header_test.go Normal file
View File

@@ -0,0 +1,21 @@
package peer
import "testing"
func TestHeaderMarshalParse(t *testing.T) {
nIn := Header{
StreamID: 23,
Counter: 3212,
SourceIP: 34,
DestIP: 200,
}
buf := make([]byte, headerSize)
nIn.Marshal(buf)
nOut := Header{}
nOut.Parse(buf)
if nIn != nOut {
t.Fatal(nIn, nOut)
}
}

View File

@@ -1,128 +0,0 @@
package peer
import (
"fmt"
"log"
"net/netip"
"os"
"sort"
"strings"
"syscall"
"git.crumpington.com/lib/go/flock"
)
const (
hostsFile = "/etc/hosts"
hostsBegin = "# BEGIN vppn"
hostsEnd = "# END vppn"
)
// hostMarkers returns the begin/end marker lines that delimit the managed
// section for localDomain. The domain is wrapped in parentheses so one domain's
// marker can never be a prefix of another's (e.g. "net" vs "net2") when
// multiple vppn instances share /etc/hosts.
func hostMarkers(localDomain string) (begin, end string) {
return hostsBegin + "(" + localDomain + ")", hostsEnd + "(" + localDomain + ")"
}
// updateHosts rewrites the managed vppn section in /etc/hosts using the
// current peersByIP map. Peers without a Name are skipped.
func (a *App) updateHosts() {
if a.localDomain == "" {
return
}
if err := updateHosts(hostsFile, a.localDomain, a.peersByIP); err != nil {
log.Printf("Failed to update hosts file: %v", err)
}
}
func updateHosts(hostsPath, localDomain string, peers map[netip.Addr]*Peer) error {
lockFile, err := flock.Lock(hostsPath + ".vppn.lock")
if err != nil {
return err
}
defer lockFile.Close()
begin, end := hostMarkers(localDomain)
info, err := os.Stat(hostsPath)
if err != nil {
return err
}
raw, err := os.ReadFile(hostsPath)
if err != nil {
return err
}
data := string(raw)
before := strings.TrimSpace(data)
after := ""
if idxBegin := strings.Index(data, begin); idxBegin != -1 {
idxEnd := strings.Index(data[idxBegin:], end)
if idxEnd != -1 {
after = strings.TrimSpace(data[idxBegin+idxEnd+len(end):])
}
before = strings.TrimSpace(data[:idxBegin])
}
b := strings.Builder{}
b.WriteString(before)
b.WriteRune('\n')
b.WriteString(after)
b.WriteRune('\n')
b.WriteRune('\n')
b.WriteString(begin)
b.WriteRune('\n')
// Collect entries so we can sort by IP for stable output. Pad the IP
// column to the width of the widest possible address ("255.255.255.255")
// for readability.
type entry struct {
ip netip.Addr
host string
}
var entries []entry
for ip, p := range peers {
if p.Name == "" {
continue
}
entries = append(entries, entry{ip: ip, host: p.Name + "." + localDomain})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].ip.Less(entries[j].ip)
})
for _, e := range entries {
b.WriteString(fmt.Sprintf("%-15s %s\n", e.ip.String(), e.host))
}
b.WriteString(end)
b.WriteRune('\n')
// Write to a temp file in the same directory, then rename over the
// original so readers never observe a partial file. Preserve the
// original's mode and ownership, since rename replaces the inode.
tmpPath := hostsPath + ".vppn.tmp"
if err := os.WriteFile(tmpPath, []byte(b.String()), info.Mode().Perm()); err != nil {
return err
}
if st, ok := info.Sys().(*syscall.Stat_t); ok {
if err := os.Chown(tmpPath, int(st.Uid), int(st.Gid)); err != nil {
os.Remove(tmpPath)
return err
}
}
if err := os.Rename(tmpPath, hostsPath); err != nil {
os.Remove(tmpPath)
return err
}
return nil
}

View File

@@ -1,205 +0,0 @@
package peer
import (
"net/netip"
"os"
"path/filepath"
"sort"
"strings"
"testing"
)
// writeTempHosts creates a temp hosts file with the given content and returns
// its path.
func writeTempHosts(t *testing.T, content string) string {
t.Helper()
path := filepath.Join(t.TempDir(), "hosts")
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
t.Fatal(err)
}
return path
}
// readManagedSection returns the lines between the begin/end markers for the
// given localDomain, plus everything outside the section ("outside").
func readManagedSection(t *testing.T, path, localDomain string) (inside, outside []string) {
t.Helper()
raw, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
begin, end := hostMarkers(localDomain)
inSection := false
for _, line := range strings.Split(string(raw), "\n") {
switch {
case strings.HasPrefix(line, begin):
inSection = true
case strings.HasPrefix(line, end):
inSection = false
case inSection:
if f := strings.Join(strings.Fields(line), " "); f != "" {
inside = append(inside, f)
}
default:
if f := strings.Join(strings.Fields(line), " "); f != "" {
outside = append(outside, f)
}
}
}
return inside, outside
}
func peer(name string) *Peer {
return &Peer{Name: name}
}
func TestUpdateHosts_AddsSection(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
peers := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
netip.MustParseAddr("10.11.12.10"): peer("laptop"),
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
inside, outside := readManagedSection(t, path, "mynet.local")
sort.Strings(inside)
want := []string{
"10.11.12.1 hub.mynet.local",
"10.11.12.10 laptop.mynet.local",
}
if strings.Join(inside, "\n") != strings.Join(want, "\n") {
t.Errorf("managed section = %v, want %v", inside, want)
}
if !contains(outside, "127.0.0.1 localhost") {
t.Errorf("original content lost; outside = %v", outside)
}
}
func TestUpdateHosts_ReplacesExistingSection(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
// First write.
first := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
}
if err := updateHosts(path, "mynet.local", first); err != nil {
t.Fatal(err)
}
// Second write with a different set of peers.
second := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.20"): peer("phone"),
}
if err := updateHosts(path, "mynet.local", second); err != nil {
t.Fatal(err)
}
inside, outside := readManagedSection(t, path, "mynet.local")
if len(inside) != 1 || inside[0] != "10.11.12.20 phone.mynet.local" {
t.Errorf("section not replaced; inside = %v", inside)
}
if contains(inside, "10.11.12.1 hub.mynet.local") {
t.Errorf("stale entry remained; inside = %v", inside)
}
if !contains(outside, "127.0.0.1 localhost") {
t.Errorf("original content lost; outside = %v", outside)
}
}
func TestUpdateHosts_SkipsEmptyNames(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
peers := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
netip.MustParseAddr("10.11.12.99"): peer(""), // no name
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
inside, _ := readManagedSection(t, path, "mynet.local")
if len(inside) != 1 || inside[0] != "10.11.12.1 hub.mynet.local" {
t.Errorf("expected only named peer; inside = %v", inside)
}
}
func TestUpdateHosts_Idempotent(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
peers := map[netip.Addr]*Peer{
netip.MustParseAddr("10.11.12.1"): peer("hub"),
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
first, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
if err := updateHosts(path, "mynet.local", peers); err != nil {
t.Fatal(err)
}
second, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
if string(first) != string(second) {
t.Errorf("repeated update changed file:\nfirst:\n%s\nsecond:\n%s", first, second)
}
}
// TestUpdateHosts_PrefixDomainsCoexist guards finding 4.4: two domains where
// one label is a prefix of the other ("net" vs "net2") must each manage their
// own section without clobbering the other's, even sharing one hosts file.
func TestUpdateHosts_PrefixDomainsCoexist(t *testing.T) {
path := writeTempHosts(t, "127.0.0.1 localhost\n")
if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{
netip.MustParseAddr("10.0.2.1"): peer("a"),
}); err != nil {
t.Fatal(err)
}
if err := updateHosts(path, "net.local", map[netip.Addr]*Peer{
netip.MustParseAddr("10.0.1.1"): peer("b"),
}); err != nil {
t.Fatal(err)
}
// Both sections coexist after writing the prefix domain.
if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.1 a.net2.local" {
t.Errorf("net2 section clobbered: %v", in)
}
if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" {
t.Errorf("net section wrong: %v", in)
}
// Re-updating net2 must not disturb the net section.
if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{
netip.MustParseAddr("10.0.2.2"): peer("c"),
}); err != nil {
t.Fatal(err)
}
if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" {
t.Errorf("net section disturbed by net2 update: %v", in)
}
if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.2 c.net2.local" {
t.Errorf("net2 section not updated: %v", in)
}
}
func contains(ss []string, s string) bool {
for _, x := range ss {
if x == s {
return true
}
}
return false
}

View File

@@ -1,153 +0,0 @@
package peer
import (
"encoding/json"
"io"
"log"
"net/http"
"net/netip"
"net/url"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
const hubPollInterval = 64 * time.Second
type HubPoller struct {
selfVPNIP netip.Addr
vpnNet netip.Prefix
hubURL string
apiKey string
statePath string // where the network state cache is persisted
addCh chan<- m.Peer
removeCh chan<- wgtypes.Key
known map[wgtypes.Key]struct{} // pubKeys currently configured
}
func NewHubPoller(
selfVPNIP netip.Addr,
vpnNet netip.Prefix,
hubURL, apiKey string,
statePath string,
addCh chan<- m.Peer,
removeCh chan<- wgtypes.Key,
) (*HubPoller, error) {
u, err := url.Parse(hubURL)
if err != nil {
return nil, err
}
u.Path = "/peer/fetch-state/"
return &HubPoller{
selfVPNIP: selfVPNIP,
vpnNet: vpnNet,
hubURL: u.String(),
apiKey: apiKey,
statePath: statePath,
addCh: addCh,
removeCh: removeCh,
known: make(map[wgtypes.Key]struct{}),
}, nil
}
func (hp *HubPoller) Run() {
// Prime from the on-disk cache before reaching the hub, so the peer
// configures WireGuard from its last known state even if the hub is down.
// known starts empty, so this emits every cached peer as an add; the first
// real poll then emits only deltas (adds for new peers, removes for gone).
if state, err := loadNetworkState(hp.statePath); err == nil {
hp.apply(state)
}
hp.poll()
for range time.Tick(hubPollInterval) {
hp.poll()
}
}
func (hp *HubPoller) poll() {
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
if err != nil {
log.Printf("[HubPoller] build request: %v", err)
return
}
req.SetBasicAuth("", hp.apiKey)
client := &http.Client{Timeout: 32 * time.Second}
resp, err := client.Do(req)
if err != nil {
log.Printf("[HubPoller] fetch: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("[HubPoller] unexpected status %d", resp.StatusCode)
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("[HubPoller] read body: %v", err)
return
}
var state m.NetworkState
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("[HubPoller] unmarshal: %v", err)
return
}
// Persist only when the state actually changed, to avoid needless writes
// on every poll.
if hp.apply(state) {
if err := saveNetworkState(hp.statePath, state); err != nil {
log.Printf("[HubPoller] save state: %v", err)
}
}
}
// apply diffs state against the set of known peers, emitting an add for each
// newly-seen peer and a remove for each that disappeared. It returns true if
// anything changed. A peer's config is immutable under a stable WG key (the hub
// has no peer-edit path), so a key already in known needs no re-emit.
func (hp *HubPoller) apply(state m.NetworkState) (changed bool) {
seen := make(map[wgtypes.Key]struct{}, len(hp.known))
netAddr := hp.vpnNet.Addr().As4()
for _, p := range state.Peers {
if p.WGPubKey == (wgtypes.Key{}) {
continue
}
octets := netAddr
octets[3] = p.PeerIP
vpnIP := netip.AddrFrom4(octets)
if vpnIP == hp.selfVPNIP {
continue
}
seen[p.WGPubKey] = struct{}{}
if _, ok := hp.known[p.WGPubKey]; ok {
continue
}
hp.known[p.WGPubKey] = struct{}{}
hp.addCh <- p
changed = true
}
for key := range hp.known {
if _, ok := seen[key]; !ok {
delete(hp.known, key)
hp.removeCh <- key
changed = true
}
}
return changed
}

View File

@@ -1,80 +0,0 @@
package peer
import (
"net/netip"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
func testPoller(t *testing.T) (*HubPoller, chan m.Peer, chan wgtypes.Key) {
t.Helper()
addCh := make(chan m.Peer, 8)
removeCh := make(chan wgtypes.Key, 8)
hp := &HubPoller{
selfVPNIP: netip.MustParseAddr("10.0.0.1"),
vpnNet: netip.MustParsePrefix("10.0.0.0/24"),
addCh: addCh,
removeCh: removeCh,
known: make(map[wgtypes.Key]struct{}),
}
return hp, addCh, removeCh
}
func stateWith(key wgtypes.Key, peerIP byte) m.NetworkState {
return m.NetworkState{Peers: []m.Peer{{
PeerIP: peerIP,
WGPubKey: key,
}}}
}
func TestApply_EmitsAddsAndReportsChange(t *testing.T) {
hp, addCh, _ := testPoller(t)
key := mustKey(t)
if changed := hp.apply(stateWith(key, 2)); !changed {
t.Fatal("expected changed=true on first apply")
}
if len(addCh) != 1 {
t.Fatalf("expected 1 add, got %d", len(addCh))
}
if got := <-addCh; got.WGPubKey != key {
t.Errorf("add pubkey mismatch")
}
}
func TestApply_NoChangeWhenKnown(t *testing.T) {
hp, addCh, _ := testPoller(t)
key := mustKey(t)
hp.apply(stateWith(key, 2))
<-addCh // drain initial add
if changed := hp.apply(stateWith(key, 2)); changed {
t.Fatal("expected changed=false when peer already known")
}
if len(addCh) != 0 {
t.Fatalf("expected no re-emit, got %d adds", len(addCh))
}
}
func TestApply_RemovesVanishedPeer(t *testing.T) {
hp, addCh, removeCh := testPoller(t)
key := mustKey(t)
hp.apply(stateWith(key, 2))
<-addCh
// Empty state: the peer is gone.
if changed := hp.apply(m.NetworkState{}); !changed {
t.Fatal("expected changed=true when peer vanishes")
}
if len(removeCh) != 1 {
t.Fatalf("expected 1 remove, got %d", len(removeCh))
}
if got := <-removeCh; got != key {
t.Errorf("remove key mismatch")
}
}

111
peer/hubpoller.go Normal file
View File

@@ -0,0 +1,111 @@
package peer
import (
"encoding/json"
"io"
"log"
"net/http"
"net/url"
"time"
"vppn/m"
)
type HubPoller struct {
Globals
client *http.Client
req *http.Request
versions [256]int64
netName string
}
func NewHubPoller(
g Globals,
netName,
hubURL,
apiKey string,
) (*HubPoller, error) {
u, err := url.Parse(hubURL)
if err != nil {
return nil, err
}
u.Path = "/peer/fetch-state/"
client := &http.Client{Timeout: 8 * time.Second}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
req.SetBasicAuth("", apiKey)
return &HubPoller{
Globals: g,
client: client,
req: req,
netName: netName,
}, nil
}
func (hp *HubPoller) logf(s string, args ...any) {
log.Printf("[HubPoller] "+s, args...)
}
func (hp *HubPoller) Run() {
state, err := loadNetworkState(hp.netName)
if err != nil {
hp.logf("Failed to load network state: %v", err)
hp.logf("Polling hub...")
hp.pollHub()
} else {
hp.applyNetworkState(state)
}
for range time.Tick(64 * time.Second) {
hp.pollHub()
}
}
func (hp *HubPoller) pollHub() {
var state m.NetworkState
resp, err := hp.client.Do(hp.req)
if err != nil {
hp.logf("Failed to fetch peer state: %v", err)
return
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
hp.logf("Failed to read body from hub: %v", err)
return
}
if err := json.Unmarshal(body, &state); err != nil {
hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body)
return
}
if err := storeNetworkState(hp.netName, state); err != nil {
hp.logf("Failed to store network state: %v", err)
}
hp.applyNetworkState(state)
}
func (hp *HubPoller) applyNetworkState(state m.NetworkState) {
for i, peer := range state.Peers {
if i == int(hp.LocalPeerIP) {
continue
}
if peer != nil && peer.Version == hp.versions[i] {
continue
}
hp.RemotePeers[i].Load().HandlePeerUpdate(peerUpdateMsg{Peer: state.Peers[i]})
if peer != nil {
hp.versions[i] = peer.Version
}
}
}

73
peer/ifreader.go Normal file
View File

@@ -0,0 +1,73 @@
package peer
import (
"log"
)
type IFReader struct {
Globals
}
func NewIFReader(g Globals) *IFReader {
return &IFReader{Globals: g}
}
func (r *IFReader) Run() {
packet := make([]byte, bufferSize)
for {
r.handleNextPacket(packet)
}
}
func (r *IFReader) handleNextPacket(packet []byte) {
packet = r.readNextPacket(packet)
remoteIP, ok := r.parsePacket(packet)
if !ok {
return
}
r.RemotePeers[remoteIP].Load().SendDataTo(packet)
}
func (r *IFReader) readNextPacket(buf []byte) []byte {
n, err := r.IFace.Read(buf[:cap(buf)])
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
return buf[:n]
}
// parsePacket returns the VPN ip for the packet, and a boolean indicating
// success.
func (r *IFReader) parsePacket(buf []byte) (byte, bool) {
n := len(buf)
if n == 0 {
return 0, false
}
version := buf[0] >> 4
switch version {
case 4:
if n < 20 {
r.logf("Short IPv4 packet: %d", len(buf))
return 0, false
}
return buf[19], true
case 6:
if len(buf) < 40 {
r.logf("Short IPv6 packet: %d", len(buf))
return 0, false
}
return buf[39], true
default:
r.logf("Invalid IP packet version: %v", version)
return 0, false
}
}
func (*IFReader) logf(s string, args ...any) {
log.Printf("[IFReader] "+s, args...)
}

81
peer/ifreader_test.go Normal file
View File

@@ -0,0 +1,81 @@
package peer
/*
func TestIFReader_IPv4(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
pkt := make([]byte, 1234)
pkt[0] = 4 << 4
pkt[19] = 2 // IP.
p1.IFace.UserWrite(pkt)
p1.IFReader.handleNextPacket(newBuf())
packets := p2.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestIFReader_IPv6(t *testing.T) {
p1, p2, _ := NewPeersForTesting()
pkt := make([]byte, 1234)
pkt[0] = 6 << 4
pkt[39] = 2 // IP.
p1.IFace.UserWrite(pkt)
p1.IFReader.handleNextPacket(newBuf())
packets := p2.Conn.Packets()
if len(packets) != 1 {
t.Fatal(packets)
}
}
func TestIFReader_parsePacket_emptyPacket(t *testing.T) {
r := NewIFReader(nil, nil)
pkt := make([]byte, 0)
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(ip, ok)
}
}
func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) {
r := NewIFReader(nil, nil)
for i := byte(1); i < 16; i++ {
if i == 4 || i == 6 {
continue
}
pkt := make([]byte, 1234)
pkt[0] = i << 4
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(i, ip, ok)
}
}
}
func TestIFReader_parsePacket_shortIPv4(t *testing.T) {
r := NewIFReader(nil, nil)
pkt := make([]byte, 19)
pkt[0] = 4 << 4
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(ip, ok)
}
}
func TestIFReader_parsePacket_shortIPv6(t *testing.T) {
r := NewIFReader(nil, nil)
pkt := make([]byte, 39)
pkt[0] = 6 << 4
if ip, ok := r.parsePacket(pkt); ok {
t.Fatal(ip, ok)
}
}
*/

View File

@@ -1,190 +0,0 @@
package peer
import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"os"
"golang.org/x/crypto/nacl/sign"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
// LocalState is the persisted identity for this peer, written on first run and
// loaded on every subsequent run.
type LocalState struct {
PrivKey wgtypes.Key
SignKey [64]byte // nacl/sign Ed25519 private key
VPNIP netip.Addr
VPNNet netip.Prefix
WGPort uint16
IsRelay bool
IsPublic bool
LocalDomain string
}
// localStateJSON is the on-disk representation.
type localStateJSON struct {
PrivKey string
SignKey string
VPNIP netip.Addr
VPNNet netip.Prefix
WGPort uint16
IsRelay bool
IsPublic bool
LocalDomain string
}
// LoadOrInit loads LocalState from path, or registers with the hub and creates
// the file if it doesn't exist.
func LoadOrInit(statePath, hubURL, apiKey string) (LocalState, error) {
var state LocalState
switch err := loadJSON(statePath, &state); {
case err == nil:
return state, nil
case !os.IsNotExist(err):
// File exists but is unreadable/corrupt: surface it rather than
// silently regenerating a new identity and re-registering.
return LocalState{}, fmt.Errorf("load state: %w", err)
}
privKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return LocalState{}, fmt.Errorf("generate key: %w", err)
}
state, err = initFromHub(hubURL, apiKey, privKey)
if err != nil {
return LocalState{}, err
}
if err := storeJSON(statePath, state); err != nil {
return LocalState{}, fmt.Errorf("save state: %w", err)
}
return state, nil
}
func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) {
wgPubKey := privKey.PublicKey()
signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader)
if err != nil {
return LocalState{}, fmt.Errorf("generate sign key: %w", err)
}
body, err := json.Marshal(m.PeerInitArgs{
WGPubKey: wgPubKey[:],
SignPubKey: signPubKey[:],
})
if err != nil {
return LocalState{}, fmt.Errorf("json error: %w", err)
}
req, err := http.NewRequest(http.MethodPost, hubURL+"/peer/init/", bytes.NewReader(body))
if err != nil {
return LocalState{}, err
}
req.SetBasicAuth("", apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return LocalState{}, fmt.Errorf("hub init: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return LocalState{}, fmt.Errorf("hub init: HTTP %d", resp.StatusCode)
}
var r m.PeerInitResp
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
return LocalState{}, fmt.Errorf("hub init decode: %w", err)
}
if len(r.Network) != 4 {
return LocalState{}, fmt.Errorf("hub init: invalid network %v", r.Network)
}
netAddr := netip.AddrFrom4([4]byte(r.Network))
octets := netAddr.As4()
octets[3] = r.PeerIP
vpnIP := netip.AddrFrom4(octets)
vpnNet := netip.PrefixFrom(netAddr, 24)
var self *m.Peer
for i := range r.NetworkState.Peers {
if r.NetworkState.Peers[i].PeerIP == r.PeerIP {
self = &r.NetworkState.Peers[i]
break
}
}
if self == nil {
return LocalState{}, fmt.Errorf("hub init: no peer for own IP: %d", r.PeerIP)
}
public := self.IsPublic()
return LocalState{
PrivKey: privKey,
SignKey: *signPrivKey,
VPNIP: vpnIP,
VPNNet: vpnNet,
WGPort: self.Port,
IsRelay: self.Relay && public,
IsPublic: public,
LocalDomain: r.LocalDomain,
}, nil
}
func (s LocalState) MarshalJSON() ([]byte, error) {
return json.Marshal(localStateJSON{
PrivKey: base64.StdEncoding.EncodeToString(s.PrivKey[:]),
SignKey: base64.StdEncoding.EncodeToString(s.SignKey[:]),
VPNIP: s.VPNIP,
VPNNet: s.VPNNet,
WGPort: s.WGPort,
IsRelay: s.IsRelay,
IsPublic: s.IsPublic,
LocalDomain: s.LocalDomain,
})
}
func (s *LocalState) UnmarshalJSON(data []byte) error {
var j localStateJSON
if err := json.Unmarshal(data, &j); err != nil {
return err
}
keyBytes, err := base64.StdEncoding.DecodeString(j.PrivKey)
if err != nil {
return fmt.Errorf("decode key: %w", err)
}
key, err := wgtypes.NewKey(keyBytes)
if err != nil {
return fmt.Errorf("invalid key: %w", err)
}
signKeyBytes, err := base64.StdEncoding.DecodeString(j.SignKey)
if err != nil {
return fmt.Errorf("decode sign key: %w", err)
}
if len(signKeyBytes) != 64 {
return fmt.Errorf("invalid sign key length: %d", len(signKeyBytes))
}
*s = LocalState{
PrivKey: key,
SignKey: [64]byte(signKeyBytes),
VPNIP: j.VPNIP,
VPNNet: j.VPNNet,
WGPort: j.WGPort,
IsRelay: j.IsRelay,
IsPublic: j.IsPublic,
LocalDomain: j.LocalDomain,
}
return nil
}

137
peer/interface.go Normal file
View File

@@ -0,0 +1,137 @@
package peer
import (
"fmt"
"io"
"net"
"os"
"syscall"
"golang.org/x/sys/unix"
)
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
if len(network) != 4 {
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))
}
ip := net.IPv4(network[0], network[1], network[2], localIP)
//////////////////////////
// Create TUN Interface //
//////////////////////////
tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600)
if err != nil {
return nil, fmt.Errorf("failed to open TUN device: %w", err)
}
// New interface request.
req, err := unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create new TUN interface request: %w", err)
}
// Flags:
//
// IFF_NO_PI => don't add packet info data to packets sent to the interface.
// IFF_TUN => create a TUN device handling IP packets.
req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN)
err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req)
if err != nil {
return nil, fmt.Errorf("failed to set TUN device settings: %w", err)
}
// Name may not be exactly the same?
name = req.Name()
/////////////
// Set MTU //
/////////////
// We need a socket file descriptor to set other options for some reason.
sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return nil, fmt.Errorf("failed to open socket: %w", err)
}
defer unix.Close(sockFD)
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create MTU interface request: %w", err)
}
req.SetUint32(if_mtu)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil {
return nil, fmt.Errorf("failed to set interface MTU: %w", err)
}
//////////////////////
// Set Queue Length //
//////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
}
req.SetUint16(if_queue_len)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil {
return nil, fmt.Errorf("failed to set interface queue length: %w", err)
}
/////////////////////
// Set IP and Mask //
/////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
}
if err := req.SetInet4Addr(ip.To4()); err != nil {
return nil, fmt.Errorf("failed to set interface request IP: %w", err)
}
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil {
return nil, fmt.Errorf("failed to set interface IP: %w", err)
}
// SET MASK - must happen after setting address.
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create mask interface request: %w", err)
}
if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil {
return nil, fmt.Errorf("failed to set interface request mask: %w", err)
}
if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil {
return nil, fmt.Errorf("failed to set interface mask: %w", err)
}
////////////////////////
// Bring Interface Up //
////////////////////////
req, err = unix.NewIfreq(name)
if err != nil {
return nil, fmt.Errorf("failed to create up interface request: %w", err)
}
// Get current flags.
if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil {
return nil, fmt.Errorf("failed to get interface flags: %w", err)
}
flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING
// Set UP flag / broadcast flags.
req.SetUint16(flags)
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil {
return nil, fmt.Errorf("failed to set interface up: %w", err)
}
return os.NewFile(uintptr(tunFD), "tun"), nil
}

View File

@@ -1,28 +0,0 @@
package peer
import (
"net/netip"
"vppn/peer/control"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// WGDevice is the subset of wginterface.Device used by App.
type WGDevice interface {
Name() string
Peers() ([]wgtypes.Peer, error)
AddPeer(pubKey wgtypes.Key) error
AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error
SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error
AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error
Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error
RemovePeer(pubKey wgtypes.Key) error
}
// ControlConn sends pings to peers over the VPN control port.
// Reading is handled separately via run, which feeds the App's pingCh.
// buf is a caller-provided scratch buffer (at least control.Size bytes) used to
// marshal the ping; the caller reuses one across sends.
type ControlConn interface {
SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error
}

View File

@@ -1,36 +0,0 @@
package peer
import (
"encoding/json"
"os"
"path/filepath"
)
func loadJSON(path string, target any) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return json.Unmarshal(data, target)
}
func storeJSON(path string, obj any) error {
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
data, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tmpPath := path + ".tmp"
if err := os.WriteFile(tmpPath, data, 0600); err != nil {
return err
}
if err := os.Rename(tmpPath, path); err != nil {
os.Remove(tmpPath)
return err
}
return nil
}

207
peer/main.go Normal file
View File

@@ -0,0 +1,207 @@
package peer
import (
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"net/netip"
"os"
"time"
)
// Usage:
//
// vppn netName run
// vppn netName status
func Main2() {
printUsage := func() {
fmt.Fprintf(os.Stderr, `%s COMMAND [ARGUMENTS...]
Available commands:
run
status
hosts
`, os.Args[0])
os.Exit(1)
}
if len(os.Args) < 2 {
printUsage()
}
command := os.Args[1]
switch command {
case "run":
main_run()
case "status":
main_status()
case "hosts":
main_hosts()
default:
printUsage()
}
}
// ----------------------------------------------------------------------------
type mainArgs struct {
NetName string
HubAddress string
APIKey string
}
func main_run() {
printUsage := func() {
fmt.Fprintf(os.Stderr, `Usage: %s run NETWORK_NAME HUB_ADDRESS API_KEY
NETWORK_NAME
Unique name of the network interface created. The network name
shouldn't change between invocations of the application.
HUB_ADDRESS
The address of the hub server. This should also contain the scheme, for
example https://hub.domain.com/.
API_KEY
The API key assigned to this peer by the hub.
`, os.Args[0])
os.Exit(1)
}
if len(os.Args) != 5 {
printUsage()
}
args := mainArgs{
NetName: os.Args[2],
HubAddress: os.Args[3],
APIKey: os.Args[4],
}
newPeerMain(args).Run()
}
// ----------------------------------------------------------------------------
func main_status() {
printUsage := func() {
fmt.Fprintf(os.Stderr, `Usage: %s status NETWORK_NAME
NETWORK_NAME
Unique name of the network interface created.
`, os.Args[0])
os.Exit(1)
}
if len(os.Args) != 3 {
printUsage()
}
netName := os.Args[2]
report := fetchStatusReport(netName)
fmt.Printf("\n%s Status\n\n", netName)
if len(report.Network) != 4 {
fmt.Printf("Network: %v\n\n", report.Network)
} else {
nw := report.Network
fmt.Printf("%-8s %d.%d.%d.%d/24\n", "Network", nw[0], nw[1], nw[2], nw[3])
}
if report.RelayPeerIP != 0 {
fmt.Printf("%-8s %d\n\n", "Relay", report.RelayPeerIP)
} else {
fmt.Printf("%-8s -\n\n", "Relay")
}
for _, status := range report.Remotes {
fmt.Printf("%3d %s\n", status.PeerIP, status.Name)
fmt.Printf(" %-11s %v\n", "Up", status.Up)
pubIP, ok := netip.AddrFromSlice(status.PublicIP)
if ok {
fmt.Printf(" %-11s %v\n", "Public IP", pubIP)
} else {
fmt.Printf(" %-11s\n", "Public IP")
}
fmt.Printf(" %-11s %d\n", "Port", status.Port)
fmt.Printf(" %-11s %v\n", "Relay", status.Relay)
fmt.Printf(" %-11s %v\n", "Server", status.Server)
fmt.Printf(" %-11s %v\n", "Direct", status.Direct)
if status.DirectAddr.IsValid() {
fmt.Printf(" %-11s %v\n", "Address", status.DirectAddr)
}
fmt.Println("")
}
}
// ----------------------------------------------------------------------------
func main_hosts() {
printUsage := func() {
fmt.Fprintf(os.Stderr, `Usage: %s hosts NETWORK_NAME
NETWORK_NAME
Unique name of the network interface created.
`, os.Args[0])
os.Exit(1)
}
if len(os.Args) != 3 {
printUsage()
}
netName := os.Args[2]
state, err := loadNetworkState(netName)
if err != nil {
log.Fatalf("Failed to load network state: %v", err)
}
config, err := loadPeerConfig(netName)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
nw := config.Network
for _, peer := range state.Peers {
if peer == nil {
continue
}
fmt.Printf("%d.%d.%d.%d %s\n",
nw[0], nw[1], nw[2], peer.PeerIP, peer.Name)
}
fmt.Println("")
}
// ----------------------------------------------------------------------------
func fetchStatusReport(netName string) StatusReport {
client := http.Client{
Transport: &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return net.Dial("unix", statusSocketPath(netName))
},
},
Timeout: 8 * time.Second,
}
getURL := "http://unix" + statusSocketPath(netName)
resp, err := client.Get(getURL)
if err != nil {
log.Fatalf("Failed to get response: %v", err)
}
report := StatusReport{}
if err := json.NewDecoder(resp.Body).Decode(&report); err != nil {
log.Fatalf("Failed to decode status report: %v", err)
}
return report
}

5
peer/main_test.go Normal file
View File

@@ -0,0 +1,5 @@
package peer
func newBuf() []byte {
return make([]byte, bufferSize)
}

47
peer/mcreader.go Normal file
View File

@@ -0,0 +1,47 @@
package peer
import (
"log"
"net"
"time"
)
func RunMCReader(g Globals) {
for {
runMCReaderInner(g)
time.Sleep(broadcastErrorTimeoutInterval)
}
}
func runMCReaderInner(g Globals) {
var (
buf = make([]byte, bufferSize)
logf = func(s string, args ...any) {
log.Printf("[MCReader] "+s, args...)
}
)
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
if err != nil {
logf("Failed to bind to multicast address: %v", err)
return
}
for {
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
n, remoteAddr, err := conn.ReadFromUDPAddrPort(buf[:bufferSize])
if err != nil {
logf("Failed to read from UDP port): %v", err)
return
}
buf = buf[:n]
h, ok := headerFromLocalDiscoveryPacket(buf)
if !ok {
logf("Failed to open discovery packet?")
continue
}
g.RemotePeers[h.SourceIP].Load().HandleLocalDiscoveryPacket(h, remoteAddr, buf)
}
}

132
peer/mcreader_test.go Normal file
View File

@@ -0,0 +1,132 @@
package peer
/*
type mcMockConn struct {
packets chan []byte
}
func newMCMockConn() *mcMockConn {
return &mcMockConn{make(chan []byte, 32)}
}
func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) {
c.packets <- bytes.Clone(in)
return len(in), nil
}
func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
buf := <-c.packets
b = b[:len(buf)]
copy(b, buf)
return len(b), netip.AddrPort{}, nil
}
func TestMCReader(t *testing.T) {
keys := generateKeys()
super := &mockControlMsgHandler{}
conn := newMCMockConn()
peers := [256]*atomic.Pointer[RemotePeer]{}
peer := &RemotePeer{
IP: 1,
Up: true,
PubSignKey: keys.PubSignKey,
}
peers[1] = &atomic.Pointer[RemotePeer]{}
peers[1].Store(peer)
w := newMCWriter(conn, 1, keys.PrivSignKey)
r := newMCReader(conn, super, peers)
w.SendLocalDiscovery()
r.handleNextPacket()
if len(super.Messages) != 1 {
t.Fatal(super.Messages)
}
msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery])
if !ok || msg.SrcIP != 1 {
t.Fatal(ok, msg)
}
}
func TestMCReader_noHeader(t *testing.T) {
keys := generateKeys()
super := &mockControlMsgHandler{}
conn := newMCMockConn()
peers := [256]*atomic.Pointer[RemotePeer]{}
peer := &RemotePeer{
IP: 1,
Up: true,
PubSignKey: keys.PubSignKey,
}
peers[1] = &atomic.Pointer[RemotePeer]{}
peers[1].Store(peer)
r := newMCReader(conn, super, peers)
conn.WriteToUDP([]byte("0123546789"), nil)
r.handleNextPacket()
if len(super.Messages) != 0 {
t.Fatal(super.Messages)
}
}
func TestMCReader_noPeer(t *testing.T) {
keys := generateKeys()
super := &mockControlMsgHandler{}
conn := newMCMockConn()
peers := [256]*atomic.Pointer[RemotePeer]{}
peer := &RemotePeer{
IP: 1,
Up: true,
PubSignKey: keys.PubSignKey,
}
peers[1] = &atomic.Pointer[RemotePeer]{}
peers[2] = &atomic.Pointer[RemotePeer]{}
peers[1].Store(peer)
w := newMCWriter(conn, 2, keys.PrivSignKey)
r := newMCReader(conn, super, peers)
w.SendLocalDiscovery()
r.handleNextPacket()
if len(super.Messages) != 0 {
t.Fatal(super.Messages)
}
}
func TestMCReader_badSignature(t *testing.T) {
keys := generateKeys()
super := &mockControlMsgHandler{}
conn := newMCMockConn()
peers := [256]*atomic.Pointer[RemotePeer]{}
peer := &RemotePeer{
IP: 1,
Up: true,
PubSignKey: keys.PubSignKey,
}
peers[1] = &atomic.Pointer[RemotePeer]{}
peers[1].Store(peer)
w := newMCWriter(conn, 1, keys.PrivSignKey)
w.SendLocalDiscovery()
// Break signing.
packet := <-conn.packets
packet[0]++
conn.packets <- packet
r := newMCReader(conn, super, peers)
r.handleNextPacket()
if len(super.Messages) != 0 {
t.Fatal(super.Messages)
}
}
*/

54
peer/mcwriter.go Normal file
View File

@@ -0,0 +1,54 @@
package peer
import (
"log"
"net"
"time"
"golang.org/x/crypto/nacl/sign"
)
func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte {
h := Header{
SourceIP: localIP,
DestIP: 255,
}
buf := make([]byte, headerSize)
h.Marshal(buf)
out := make([]byte, headerSize+signingOverhead)
return sign.Sign(out[:0], buf, (*[64]byte)(signingKey))
}
func headerFromLocalDiscoveryPacket(pkt []byte) (h Header, ok bool) {
if len(pkt) != headerSize+signingOverhead {
return
}
h.Parse(pkt[signingOverhead:])
ok = true
return
}
func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool {
_, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey))
return ok
}
// ----------------------------------------------------------------------------
func RunMCWriter(localIP byte, signingKey []byte) {
discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey)
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
if err != nil {
log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err)
}
for range time.Tick(broadcastInterval) {
log.Printf("[MCWriter] Broadcasting on %v...", multicastAddr)
_, err := conn.WriteToUDP(discoveryPacket, multicastAddr)
if err != nil {
log.Printf("[MCWriter] Failed to write multicast: %v", err)
}
}
}

98
peer/mcwriter_test.go Normal file
View File

@@ -0,0 +1,98 @@
package peer
/*
// ----------------------------------------------------------------------------
// Testing that we can create and verify a local discovery packet.
func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) {
keys := generateKeys()
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
header, ok := headerFromLocalDiscoveryPacket(created)
if !ok {
t.Fatal(ok)
}
if header.SourceIP != 55 || header.DestIP != 255 {
t.Fatal(header)
}
if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) {
t.Fatal("Not valid")
}
}
// Testing that we don't try to parse short packets.
func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) {
keys := generateKeys()
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
_, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1])
if ok {
t.Fatal(ok)
}
}
// Testing that modifying a packet makes it invalid.
func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) {
keys := generateKeys()
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
buf := make([]byte, 1024)
for i := range created {
modified := bytes.Clone(created)
modified[i]++
if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) {
t.Fatal("Verification should have failed.")
}
}
}
// ----------------------------------------------------------------------------
type testUDPWriter struct {
written [][]byte
}
func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
w.written = append(w.written, bytes.Clone(b))
return len(b), nil
}
func (w *testUDPWriter) Written() [][]byte {
out := w.written
w.written = [][]byte{}
return out
}
// ----------------------------------------------------------------------------
// Testing that the mcWriter sends local discovery packets as expected.
func TestMCWriter_SendLocalDiscovery(t *testing.T) {
keys := generateKeys()
writer := &testUDPWriter{}
mcw := newMCWriter(writer, 42, keys.PrivSignKey)
mcw.SendLocalDiscovery()
out := writer.Written()
if len(out) != 1 {
t.Fatal(out)
}
pkt := out[0]
header, ok := headerFromLocalDiscoveryPacket(pkt)
if !ok {
t.Fatal(ok)
}
if header.SourceIP != 42 || header.DestIP != 255 {
t.Fatal(header)
}
if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) {
t.Fatal("Verification should succeed.")
}
}
*/

31
peer/mock-iface_test.go Normal file
View File

@@ -0,0 +1,31 @@
package peer
import "bytes"
type TestIFace struct {
out *bytes.Buffer // Toward the network.
in *bytes.Buffer // From the network
}
func NewTestIFace() *TestIFace {
return &TestIFace{
out: &bytes.Buffer{},
in: &bytes.Buffer{},
}
}
func (iface *TestIFace) Write(b []byte) (int, error) {
return iface.in.Write(b)
}
func (iface *TestIFace) Read(b []byte) (int, error) {
return iface.out.Read(b)
}
func (iface *TestIFace) UserWrite(b []byte) (int, error) {
return iface.out.Write(b)
}
func (iface *TestIFace) UserRead(b []byte) (int, error) {
return iface.in.Read(b)
}

80
peer/mock-network_test.go Normal file
View File

@@ -0,0 +1,80 @@
package peer
import (
"bytes"
"net"
"net/netip"
"sync"
)
type TestPacket struct {
Addr netip.AddrPort
Data []byte
}
type TestNetwork struct {
lock sync.Mutex
packets map[netip.AddrPort]chan TestPacket
}
func NewTestNetwork() *TestNetwork {
return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}}
}
func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn {
n.lock.Lock()
defer n.lock.Unlock()
if _, ok := n.packets[localAddr]; !ok {
n.packets[localAddr] = make(chan TestPacket, 1024)
}
return &TestUDPConn{
addr: localAddr,
n: n,
packets: n.packets[localAddr],
}
}
func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) {
n.lock.Lock()
defer n.lock.Unlock()
if _, ok := n.packets[to]; !ok {
n.packets[to] = make(chan TestPacket, 1024)
}
n.packets[to] <- TestPacket{
Addr: from,
Data: bytes.Clone(b),
}
}
type TestUDPConn struct {
addr netip.AddrPort
n *TestNetwork
packets chan TestPacket
}
func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
c.n.write(b, c.addr, addr)
return len(b), nil
}
func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
return c.WriteToUDPAddrPort(b, addr.AddrPort())
}
func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
pkt := <-c.packets
b = b[:len(pkt.Data)]
copy(b, pkt.Data)
return len(b), pkt.Addr, nil
}
func (c *TestUDPConn) Packets() (out []TestPacket) {
for {
select {
case pkt := <-c.packets:
out = append(out, pkt)
default:
return
}
}
}

View File

@@ -1,62 +0,0 @@
package multicast
import (
"log"
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
4560))
func Broadcast(
selfVPNIP netip.Addr,
pubKey wgtypes.Key,
wgPort uint16,
signKey *[64]byte,
) {
for {
broadcastInner(selfVPNIP, pubKey, wgPort, signKey)
time.Sleep(errorTimeout)
}
}
func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
conn, err := net.ListenMulticastUDP("udp", nil, addr)
if err != nil {
log.Printf("[MCBroadcast] bind: %v", err)
return
}
defer conn.Close()
buf := make([]byte, BufferSize)
packet := Packet{
PeerIP: selfVPNIP.As4()[3],
WGPubKey: pubKey,
WGPort: wgPort,
}
// Re-sign on each send so the timestamp is fresh; a stale timestamp would be
// dropped by receivers' freshness gate.
send := func() error {
packet.Timestamp = time.Now().Unix()
payload := packet.Marshal(buf, signKey)
_, err := conn.WriteToUDP(payload, addr)
return err
}
if err := send(); err != nil {
log.Printf("[MCBroadcast] write: %v", err)
}
for range time.Tick(broadcastInterval) {
if err := send(); err != nil {
log.Printf("[MCBroadcast] write: %v", err)
return
}
}
}

View File

@@ -1,9 +0,0 @@
package multicast
import "time"
const (
errorTimeout = 16 * time.Second
broadcastInterval = 16 * time.Second
maxPacketAge = time.Minute
)

View File

@@ -1,54 +0,0 @@
package multicast
import (
"encoding/binary"
"net/netip"
"golang.org/x/crypto/nacl/sign"
)
const (
BufferSize = packetSize + SignedPacketSize
SignedPacketSize = packetSize + signSize
packetSize = 43
signSize = 64
)
// Layout:
//
// [0] final octet of the sender's VPN IP
// [1:33] WG public key
// [33:35] WG listen port (big-endian uint16)
// [35:43] send time, Unix seconds (big-endian int64) — freshness/replay gate
type Packet struct {
PeerIP byte // Final octet of the sender's VPN IP.
WGPubKey [32]byte // WG public key.
WGPort uint16 // WG listen port.
Timestamp int64 // Unix timestamp.
Src netip.Addr // Source of packet.
Signed []byte // Raw signed message for verification (incoming packet).
}
// Marshal the packet into a buffer with prefixed signature.
func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte {
buf[0] = p.PeerIP
copy(buf[1:33], p.WGPubKey[:])
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
binary.BigEndian.PutUint64(buf[35:43], uint64(p.Timestamp))
return sign.Sign(buf[packetSize:packetSize], buf[:packetSize], signKey)
}
func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool {
_, ok := sign.Open(buf, p.Signed, pubKey)
return ok
}
func Unmarshal(signed []byte) (p Packet) {
buf := signed[signSize:]
p.PeerIP = buf[0]
copy(p.WGPubKey[:], buf[1:33])
p.WGPort = binary.BigEndian.Uint16(buf[33:35])
p.Timestamp = int64(binary.BigEndian.Uint64(buf[35:43]))
p.Signed = signed
return
}

View File

@@ -1,38 +0,0 @@
package multicast
import (
"crypto/rand"
"testing"
"golang.org/x/crypto/nacl/sign"
)
func TestPacket(t *testing.T) {
pub, priv, err := sign.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
p := Packet{
PeerIP: 10,
WGPubKey: [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
WGPort: 44,
Timestamp: 12948893,
}
buf := make([]byte, BufferSize)
signed := p.Marshal(buf, priv)
if len(signed) != SignedPacketSize {
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
}
got := Unmarshal(signed)
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)
}
if !got.Verify(nil, pub) {
t.Error("signature did not verify")
}
}

View File

@@ -1,61 +0,0 @@
package multicast
import (
"bytes"
"fmt"
"log"
"net"
"net/netip"
"time"
)
func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) {
for {
if err := receiver(vpnNet, selfVPNIP, ch); err != nil {
log.Printf("[MCReader] %v", err)
}
time.Sleep(errorTimeout)
}
}
func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error {
selfIP := selfVPNIP.As4()[3]
conn, err := net.ListenMulticastUDP("udp", nil, addr)
if err != nil {
return fmt.Errorf("bind: %w", err)
}
defer conn.Close()
buf := make([]byte, BufferSize+1) // +1 to detect oversized packets
for {
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
n, src, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
return fmt.Errorf("read: %w", err)
}
if n != SignedPacketSize {
continue
}
packet := Unmarshal(buf[:n])
if packet.PeerIP == selfIP {
continue
}
age := time.Since(time.Unix(packet.Timestamp, 0))
if age > maxPacketAge || age < -maxPacketAge {
continue
}
packet.Signed = bytes.Clone(packet.Signed)
packet.Src = src.Addr().Unmap()
ch <- packet
}
}

View File

@@ -1,18 +0,0 @@
package peer
import "vppn/m"
// loadNetworkState reads a cached network state from disk. Any error (most
// commonly a missing file on first run) is returned to the caller, which
// treats it as "no cache available".
func loadNetworkState(path string) (m.NetworkState, error) {
var state m.NetworkState
err := loadJSON(path, &state)
return state, err
}
// saveNetworkState writes state to path atomically (see storeJSON), so a crash
// mid-write cannot leave a corrupt cache.
func saveNetworkState(path string, state m.NetworkState) error {
return storeJSON(path, state)
}

View File

@@ -1,56 +0,0 @@
package peer
import (
"net/netip"
"path/filepath"
"reflect"
"testing"
"vppn/m"
)
func TestNetworkState_RoundTrip(t *testing.T) {
path := filepath.Join(t.TempDir(), "network.json")
var sign1 [32]byte
copy(sign1[:], []byte("0123456789abcdef0123456789abcdef"))
state := m.NetworkState{Peers: []m.Peer{
{
PeerIP: 1,
Name: "hub",
Addr4: netip.MustParseAddr("10.11.12.1"),
Port: 51820,
Relay: true,
WGPubKey: mustKey(t),
SignPubKey: sign1,
},
{
PeerIP: 10,
Name: "laptop",
Addr4: netip.MustParseAddr("10.11.12.10"),
Port: 51820,
WGPubKey: mustKey(t),
},
}}
if err := saveNetworkState(path, state); err != nil {
t.Fatal(err)
}
got, err := loadNetworkState(path)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, state) {
t.Errorf("round-trip mismatch:\n got: %+v\nwant: %+v", got.Peers[1], state.Peers[1])
}
}
func TestNetworkState_LoadMissing(t *testing.T) {
path := filepath.Join(t.TempDir(), "does-not-exist.json")
if _, err := loadNetworkState(path); err == nil {
t.Fatal("expected error loading missing cache, got nil")
}
}

View File

@@ -1,109 +0,0 @@
package peer
import (
"fmt"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/multicast"
"vppn/peer/wginterface"
)
// New constructs an App, creates the WireGuard interface, and starts the
// background goroutines (hub poller, multicast, control conn reader).
// The caller should invoke Run() to start the event loop.
func New(
state LocalState,
hubURL, apiKey string,
ifaceName string,
localDomain string,
networkStatePath string,
) (*App, error) {
a4 := state.VPNIP.As4()
if err := wginterface.Create(ifaceName, a4[:], 24); err != nil {
return nil, fmt.Errorf("create WG interface: %w", err)
}
dev, err := wginterface.Open(ifaceName)
if err != nil {
_ = wginterface.Delete(ifaceName)
return nil, fmt.Errorf("open WG device: %w", err)
}
cc, err := newUDPControlConn(state.VPNIP, ControlPort)
if err != nil {
_ = dev.Close()
_ = wginterface.Delete(ifaceName)
return nil, fmt.Errorf("control conn: %w", err)
}
cleanup := func() {
_ = cc.Close()
_ = dev.Close()
_ = wginterface.Delete(ifaceName)
}
if err := dev.Configure(state.PrivKey, int(state.WGPort)); err != nil {
cleanup()
return nil, fmt.Errorf("configure WG device: %w", err)
}
if state.IsRelay {
if err := dev.EnableForwarding(); err != nil {
cleanup()
return nil, fmt.Errorf("enable forwarding: %w", err)
}
}
pingCh := make(chan PingEvent)
hubAddCh := make(chan m.Peer)
hubRemoveCh := make(chan wgtypes.Key)
multicastCh := make(chan multicast.Packet)
poller, err := NewHubPoller(
state.VPNIP,
state.VPNNet,
hubURL,
apiKey,
networkStatePath,
hubAddCh,
hubRemoveCh)
if err != nil {
cleanup()
return nil, fmt.Errorf("hub poller: %w", err)
}
go cc.run(pingCh)
go poller.Run()
if !state.IsPublic {
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh)
}
return &App{
vpnIP: state.VPNIP,
vpnNet: state.VPNNet,
privKey: state.PrivKey,
pubKey: state.PrivKey.PublicKey(),
isRelay: state.IsRelay,
isPublic: state.IsPublic,
localDomain: localDomain,
dev: dev,
controlConn: cc,
peersByKey: make(map[wgtypes.Key]*Peer),
peersByIP: make(map[netip.Addr]*Peer),
scratch: make([]byte, scratchSize),
hubAddCh: hubAddCh,
hubRemoveCh: hubRemoveCh,
pingCh: pingCh,
multicastCh: multicastCh,
}, nil
}

View File

@@ -1,114 +0,0 @@
package peer
import (
"log"
"math"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
"vppn/peer/control"
)
func (a *App) onAddPeer(p m.Peer) {
a.onRemovePeer(p.WGPubKey)
octets := a.vpnNet.Addr().As4()
octets[3] = p.PeerIP
vpnIP := netip.AddrFrom4(octets)
peer := &Peer{
wgPeer: wgtypes.Peer{PublicKey: p.WGPubKey},
VPNIP: vpnIP,
Name: p.Name,
IsRelay: p.Relay,
IsPublic: p.IsPublic(),
EndpointV4: p.Endpoint4(),
EndpointV6: p.Endpoint6(),
RTT: time.Duration(math.MaxInt64) * time.Nanosecond,
Role: roleFor(a.isPublic, a.vpnIP, p.IsPublic(), vpnIP),
SignPubKey: p.SignPubKey,
}
a.peersByKey[p.WGPubKey] = peer
a.peersByIP[peer.VPNIP] = peer
defer a.updateHosts()
if !peer.IsPublic {
if a.isPublic {
// Public nodes accept traffic from non-public peers as soon as they
// initiate a handshake. Set /32 AllowedIPs now; WireGuard learns the
// endpoint from the incoming handshake automatically.
a.devPromote(peer)
} else {
a.devAddPeer(peer)
}
return
}
a.devAddDirect(peer, peer.PreferredEndpoint())
}
func (a *App) onRemovePeer(key wgtypes.Key) {
peer, exists := a.peersByKey[key]
if !exists {
return
}
a.devRemove(peer)
delete(a.peersByKey, key)
delete(a.peersByIP, peer.VPNIP)
a.updateHosts()
if peer == a.relay {
a.relay = nil
a.switchActiveRelay()
}
}
// switchActiveRelay promotes the lowest-latency relay peer to active.
func (a *App) switchActiveRelay() {
if a.relay != nil {
// If we have a relay, it's public, so should go back to being a direct
// peer - this will convert it's /24 to a /32.
a.devAddDirect(a.relay, a.relay.PreferredEndpoint())
a.relay = nil
}
var best *Peer
for _, p := range a.peersByKey {
if !p.CanRelay() {
continue
}
if best == nil || p.RTT < best.RTT {
best = p
}
}
if best == nil {
log.Printf("no relay available")
return
}
a.devSetRelay(best, best.PreferredEndpoint())
a.relay = best
}
func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort {
// We always prefer v4 since all peers can connect to IPv4 addresses.
if v4.IsValid() {
return v4
}
return v6
}
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
if !selfIsPublic && peerIsPublic {
return control.Client
}
if selfIsPublic && !peerIsPublic {
return control.Server
}
return control.RoleFor(selfIP, peerVPNIP)
}

View File

@@ -1,299 +0,0 @@
package peer
import (
"net/netip"
"testing"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/m"
)
func mustKey(t *testing.T) wgtypes.Key {
t.Helper()
k, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
return k.PublicKey()
}
func TestOnAddPeer(t *testing.T) {
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
peerVPNIP := netip.MustParseAddr("10.0.0.2")
testCases := []struct {
name string
setup func(a *App, key wgtypes.Key)
peer func(key wgtypes.Key) m.Peer
check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
}{
{
name: "non-public peer registered in WG via AddPeer",
peer: func(k wgtypes.Key) m.Peer {
return m.Peer{WGPubKey: k, PeerIP: 2}
},
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
p := a.peersByKey[key]
if p == nil {
t.Fatal("not in peersByKey")
}
if a.peersByIP[peerVPNIP] == nil {
t.Fatal("not in peersByIP")
}
if p.State != StateRelayed {
t.Fatalf("state = %v, want StateRelayed", p.State)
}
dev.AssertAddPeer(t, 0, key)
},
},
{
name: "public peer with endpoint registered via AddDirect",
peer: func(k wgtypes.Key) m.Peer {
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}
},
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
p := a.peersByKey[key]
if p == nil {
t.Fatal("not in peersByKey")
}
dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP)
},
},
{
name: "re-add removes old WG entry before adding new one",
setup: func(a *App, key wgtypes.Key) {
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
},
peer: func(k wgtypes.Key) m.Peer {
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()}
},
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
if len(dev.Calls) != 2 {
t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, key)
dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP)
if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 {
t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP))
}
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
key := mustKey(t)
if tc.setup != nil {
tc.setup(a, key)
dev.Calls = nil
}
a.onAddPeer(tc.peer(key))
tc.check(t, a, dev, key)
})
}
}
func TestOnRemovePeer(t *testing.T) {
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
testCases := []struct {
name string
setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove
check func(t *testing.T, a *App, dev *fakeWGDevice)
}{
{
name: "unknown key is a no-op",
setup: func(t *testing.T, a *App) wgtypes.Key {
return mustKey(t)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
dev.AssertNoCalls(t)
if len(a.peersByKey) != 0 {
t.Errorf("peersByKey should be empty")
}
},
},
{
name: "StateRelayed peer removed from maps with RemovePeer",
setup: func(t *testing.T, a *App) wgtypes.Key {
key := mustKey(t)
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2})
return key
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
t.Errorf("maps should be empty after remove")
}
},
},
{
name: "StateDirect peer removed from maps with RemovePeer",
setup: func(t *testing.T, a *App) wgtypes.Key {
key := mustKey(t)
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
return key
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
t.Errorf("maps should be empty after remove")
}
},
},
{
name: "removing active relay with no backup clears relay field",
setup: func(t *testing.T, a *App) wgtypes.Key {
relay := addRelayPeer(t, a, "10.0.0.10", ep1)
a.relay = relay
return relay.PubKey()
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
if a.relay != nil {
t.Errorf("relay should be nil after removing only relay")
}
},
},
{
name: "removing active relay elects backup via SetRelay",
setup: func(t *testing.T, a *App) wgtypes.Key {
relay1 := addRelayPeer(t, a, "10.0.0.10", ep1)
addRelayPeer(t, a, "10.0.0.11", ep2)
a.relay = relay1
return relay1.PubKey()
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 2 {
t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls)
}
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
if a.relay == nil {
t.Errorf("relay should be set to backup after failover")
}
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
key := tc.setup(t, a)
dev.Calls = nil
a.onRemovePeer(key)
tc.check(t, a, dev)
})
}
}
func TestSwitchActiveRelay(t *testing.T) {
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
testCases := []struct {
name string
setup func(t *testing.T, a *App)
check func(t *testing.T, a *App, dev *fakeWGDevice)
}{
{
name: "no candidates leaves relay nil",
setup: func(t *testing.T, a *App) {},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
dev.AssertNoCalls(t)
if a.relay != nil {
t.Error("relay should be nil")
}
},
},
{
name: "single candidate elected via SetRelay",
setup: func(t *testing.T, a *App) {
addRelayPeer(t, a, "10.0.0.10", ep1)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
}
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
if a.relay == nil {
t.Error("relay should be set")
}
},
},
{
name: "measured RTT beats zero RTT",
setup: func(t *testing.T, a *App) {
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
r1.RTT = 10 * time.Millisecond
addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeaured)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
}
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
},
},
{
name: "lower RTT wins",
setup: func(t *testing.T, a *App) {
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
r1.RTT = 5 * time.Millisecond
r2 := addRelayPeer(t, a, "10.0.0.11", ep2)
r2.RTT = 20 * time.Millisecond
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 1 {
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
}
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
},
},
{
name: "stale relay demoted to direct before backup elected",
setup: func(t *testing.T, a *App) {
old := addRelayPeer(t, a, "10.0.0.10", ep1)
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
a.relay = old
addRelayPeer(t, a, "10.0.0.11", ep2)
},
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
if len(dev.Calls) != 2 {
t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls)
}
if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 {
t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0])
}
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
if a.relay == nil || a.relay.EndpointV4 != ep2 {
t.Error("relay should be the backup peer")
}
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
tc.setup(t, a)
dev.Calls = nil
a.switchActiveRelay()
tc.check(t, a, dev)
})
}
}

View File

@@ -1,51 +0,0 @@
package peer
import (
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"vppn/peer/multicast"
)
func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
if a.isPublic {
return
}
// Locate the sender peer by its VPN IP (final octet carried in the beacon).
octets := a.vpnNet.Addr().As4()
octets[3] = pkt.PeerIP
vpnIP := netip.AddrFrom4(octets)
peer, ok := a.peersByIP[vpnIP]
if !ok || peer.IsPublic || peer.State == StateDirect {
return
}
// Authenticate the beacon against the peer's known sign key. scratch[:0]
// gives sign.Open an empty-but-capacity buffer to decode into.
if !pkt.Verify(a.scratch[:0], &peer.SignPubKey) {
return
}
// The beacon is authentic but must also advertise the WG key the hub gave
// us for this peer; otherwise it's inconsistent — drop it.
if wgtypes.Key(pkt.WGPubKey) != peer.PubKey() {
return
}
endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort)
if !endpoint.IsValid() {
return
}
var v4, v6 netip.AddrPort
if pkt.Src.Is4() {
v4 = endpoint
} else {
v6 = endpoint
}
a.addProbe(peer, v4, v6)
}

View File

@@ -1,58 +0,0 @@
package peer
import (
"net/netip"
"time"
"vppn/peer/control"
)
func (a *App) onPing(e PingEvent) {
peer, ok := a.peersByIP[e.srcVPNIP]
if !ok {
// TODO: Log here.
return
}
now := time.Now()
// If we're the server, respond - this is always necessary as it's used to
// know if peers are up or down.
if peer.Role == control.Server {
a.sendPing(peer, e.ping.PingTS)
}
// Compute RTT from server echo.
if peer.Role == control.Client {
peer.RTT = now.Sub(time.Unix(0, e.ping.PingTS))
}
// If we're public, nothing more to do.
if a.isPublic {
return
}
// We can only learn our own endpoint from directly-connected peers — Dst
// is the sender's observation of our WG handshake source.
if peer.State == StateDirect {
if dst := e.ping.Dst; dst.IsValid() {
if dst.Addr().Is4() {
a.selfV4 = dst
} else {
a.selfV6 = dst
}
}
return
}
a.addProbe(peer, e.ping.SrcV4, e.ping.SrcV6)
}
func (a *App) addProbe(peer *Peer, v4, v6 netip.AddrPort) {
endpoint := preferredEndpoint(v4, v6)
if !endpoint.IsValid() || endpoint == peer.PreferredEndpoint() {
return
}
peer.UpdateEndpoints(v4, v6)
a.devAddProbe(peer, endpoint)
}

View File

@@ -1,52 +0,0 @@
package peer
import (
"log"
"time"
"vppn/peer/control"
"vppn/peer/wginterface"
)
func (a *App) onTick() {
wgPeers := a.devPeers()
now := time.Now().UnixNano()
for _, wgPeer := range wgPeers {
p, ok := a.peersByKey[wgPeer.PublicKey]
if !ok {
log.Printf("Wireguard peer not in index, removing: %v", wgPeer)
a.devRemove(&Peer{wgPeer: wgPeer})
continue
}
p.wgPeer = wgPeer
// Send pings to peers where we're the client.
if p.Role == control.Client {
a.sendPing(p, now)
}
switch p.State {
case StateProbing:
// Promote probing peers to direct once alive (direct path confirmed
// working).
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
a.devPromote(p)
}
case StateDirect:
if p.IsPublic || a.isPublic || p.Up() {
break
}
// Stale non-public direct peer: demote to probing so WireGuard
// resumes handshake attempts on the direct path.
a.devAddProbe(p, p.WGEndpoint())
}
}
// Ensure we have a live relay (if we're not public).
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
a.switchActiveRelay()
}
}

182
peer/packets-util.go Normal file
View File

@@ -0,0 +1,182 @@
package peer
import (
"net/netip"
"unsafe"
)
// ----------------------------------------------------------------------------
type binWriter struct {
b []byte
i int
}
func newBinWriter(buf []byte) *binWriter {
buf = buf[:cap(buf)]
return &binWriter{buf, 0}
}
func (w *binWriter) Bool(b bool) *binWriter {
if b {
return w.Byte(1)
}
return w.Byte(0)
}
func (w *binWriter) Byte(b byte) *binWriter {
w.b[w.i] = b
w.i++
return w
}
func (w *binWriter) SharedKey(key [32]byte) *binWriter {
copy(w.b[w.i:w.i+32], key[:])
w.i += 32
return w
}
func (w *binWriter) Uint16(x uint16) *binWriter {
*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 2
return w
}
func (w *binWriter) Uint64(x uint64) *binWriter {
*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 8
return w
}
func (w *binWriter) Int64(x int64) *binWriter {
*(*int64)(unsafe.Pointer(&w.b[w.i])) = x
w.i += 8
return w
}
func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
w.Bool(addrPort.IsValid())
addr := addrPort.Addr().As16()
copy(w.b[w.i:w.i+16], addr[:])
w.i += 16
return w.Uint16(addrPort.Port())
}
func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter {
for _, addrPort := range l {
w.AddrPort(addrPort)
}
return w
}
func (w *binWriter) Build() []byte {
return w.b[:w.i]
}
// ----------------------------------------------------------------------------
type binReader struct {
b []byte
i int
err error
}
func newBinReader(buf []byte) *binReader {
return &binReader{b: buf}
}
func (r *binReader) hasBytes(n int) bool {
if r.err != nil || (len(r.b)-r.i) < n {
r.err = errMalformedPacket
return false
}
return true
}
func (r *binReader) Bool(b *bool) *binReader {
var bb byte
r.Byte(&bb)
*b = bb != 0
return r
}
func (r *binReader) Byte(b *byte) *binReader {
if !r.hasBytes(1) {
return r
}
*b = r.b[r.i]
r.i++
return r
}
func (r *binReader) SharedKey(x *[32]byte) *binReader {
if !r.hasBytes(32) {
return r
}
*x = ([32]byte)(r.b[r.i : r.i+32])
r.i += 32
return r
}
func (r *binReader) Uint16(x *uint16) *binReader {
if !r.hasBytes(2) {
return r
}
*x = *(*uint16)(unsafe.Pointer(&r.b[r.i]))
r.i += 2
return r
}
func (r *binReader) Uint64(x *uint64) *binReader {
if !r.hasBytes(8) {
return r
}
*x = *(*uint64)(unsafe.Pointer(&r.b[r.i]))
r.i += 8
return r
}
func (r *binReader) Int64(x *int64) *binReader {
if !r.hasBytes(8) {
return r
}
*x = *(*int64)(unsafe.Pointer(&r.b[r.i]))
r.i += 8
return r
}
func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
if !r.hasBytes(19) {
return r
}
var (
valid bool
port uint16
)
r.Bool(&valid)
addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
r.i += 16
r.Uint16(&port)
if valid {
*x = netip.AddrPortFrom(addr, port)
} else {
*x = netip.AddrPort{}
}
return r
}
func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader {
for i := range x {
r.AddrPort(&x[i])
}
return r
}
func (r *binReader) Error() error {
return r.err
}

76
peer/packets-util_test.go Normal file
View File

@@ -0,0 +1,76 @@
package peer
import (
"net/netip"
"reflect"
"testing"
)
func TestBinWriteRead_invalidAddrPort(t *testing.T) {
addr := netip.AddrPort{}
buf := make([]byte, 1024)
buf = newBinWriter(buf).
AddrPort(addr).
Build()
var addr2 netip.AddrPort
err := newBinReader(buf).
AddrPort(&addr2).
Error()
if err != nil {
t.Fatal(err)
}
if addr2.IsValid() {
t.Fatal(addr, addr2)
}
}
func TestBinWriteRead(t *testing.T) {
buf := make([]byte, 1024)
type Item struct {
Type byte
TraceID uint64
Addrs [8]netip.AddrPort
DestAddr netip.AddrPort
}
in := Item{
1,
2,
[8]netip.AddrPort{},
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22),
}
in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20)
in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22)
in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23)
in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24)
in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25)
in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26)
in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27)
buf = newBinWriter(buf).
Byte(in.Type).
Uint64(in.TraceID).
AddrPort(in.DestAddr).
AddrPort8(in.Addrs).
Build()
out := Item{}
err := newBinReader(buf).
Byte(&out.Type).
Uint64(&out.TraceID).
AddrPort(&out.DestAddr).
AddrPort8(&out.Addrs).
Error()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(in, out) {
t.Fatal(in, out)
}
}

120
peer/packets.go Normal file
View File

@@ -0,0 +1,120 @@
package peer
import (
"net/netip"
)
const (
packetTypeSyn = 1
packetTypeInit = 2
packetTypeAck = 3
packetTypeProbe = 4
packetTypeAddrDiscovery = 5
)
// ----------------------------------------------------------------------------
type packetInit struct {
TraceID uint64
Direct bool
Version uint64
}
func (p packetInit) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeInit).
Uint64(p.TraceID).
Bool(p.Direct).
Uint64(p.Version).
Build()
}
func parsePacketInit(buf []byte) (p packetInit, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
Bool(&p.Direct).
Uint64(&p.Version).
Error()
return
}
// ----------------------------------------------------------------------------
type packetSyn struct {
TraceID uint64 // TraceID to match response w/ request.
SharedKey [32]byte // Our shared key.
Direct bool
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
}
func (p packetSyn) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeSyn).
Uint64(p.TraceID).
SharedKey(p.SharedKey).
Bool(p.Direct).
AddrPort8(p.PossibleAddrs).
Build()
}
func parsePacketSyn(buf []byte) (p packetSyn, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
SharedKey(&p.SharedKey).
Bool(&p.Direct).
AddrPort8(&p.PossibleAddrs).
Error()
return
}
// ----------------------------------------------------------------------------
type packetAck struct {
TraceID uint64
ToAddr netip.AddrPort
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
}
func (p packetAck) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeAck).
Uint64(p.TraceID).
AddrPort(p.ToAddr).
AddrPort8(p.PossibleAddrs).
Build()
}
func parsePacketAck(buf []byte) (p packetAck, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
AddrPort(&p.ToAddr).
AddrPort8(&p.PossibleAddrs).
Error()
return
}
// ----------------------------------------------------------------------------
// A probeReqPacket is sent from a client to a server to determine if direct
// UDP communication can be used.
type packetProbe struct {
TraceID uint64
}
func (p packetProbe) Marshal(buf []byte) []byte {
return newBinWriter(buf).
Byte(packetTypeProbe).
Uint64(p.TraceID).
Build()
}
func parsePacketProbe(buf []byte) (p packetProbe, err error) {
err = newBinReader(buf[1:]).
Uint64(&p.TraceID).
Error()
return
}
// ----------------------------------------------------------------------------
type packetLocalDiscovery struct{}

64
peer/packets_test.go Normal file
View File

@@ -0,0 +1,64 @@
package peer
import (
"crypto/rand"
"net/netip"
"reflect"
"testing"
)
func TestSynPacket(t *testing.T) {
p := packetSyn{
TraceID: 2342342345,
Direct: true,
}
rand.Read(p.SharedKey[:])
p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234)
p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399)
p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000)
buf := p.Marshal(newBuf())
p2, err := parsePacketSyn(buf)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, p2) {
t.Fatal(p2)
}
}
func TestAckPacket(t *testing.T) {
p := packetAck{
TraceID: 123213,
ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234),
}
p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100)
p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399)
p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000)
buf := p.Marshal(newBuf())
p2, err := parsePacketAck(buf)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, p2) {
t.Fatal(p2)
}
}
func TestProbePacket(t *testing.T) {
p := packetProbe{
TraceID: 12345,
}
buf := p.Marshal(newBuf())
p2, err := parsePacketProbe(buf)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, p2) {
t.Fatal(p2)
}
}

197
peer/peer.go Normal file
View File

@@ -0,0 +1,197 @@
package peer
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"math"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"vppn/m"
"git.crumpington.com/lib/go/flock"
)
type peerMain struct {
Globals
ifReader *IFReader
connReader *ConnReader
hubPoller *HubPoller
lockFile *os.File
}
func newPeerMain(args mainArgs) *peerMain {
logf := func(s string, args ...any) {
log.Printf("[Main] "+s, args...)
}
lockFile, err := flock.TryLock(lockFilePath(args.NetName))
if err != nil {
log.Fatalf("Failed to open lock file: %v", err)
}
if lockFile == nil {
log.Fatalf("Failed to obtain file lock.")
}
config, err := loadPeerConfig(args.NetName)
if err != nil {
logf("Failed to load configuration: %v", err)
logf("Initializing...")
initPeerWithHub(args)
config, err = loadPeerConfig(args.NetName)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
}
state, err := loadNetworkState(args.NetName)
if err != nil {
log.Fatalf("Failed to load network state: %v", err)
}
startupCount, err := loadStartupCount(args.NetName)
if err != nil {
if !os.IsNotExist(err) {
log.Fatalf("Failed to load startup count: %v", err)
}
}
if startupCount.Count == math.MaxUint16 {
log.Fatalf("Startup counter overflow.")
}
startupCount.Count += 1
if err := storeStartupCount(args.NetName, startupCount); err != nil {
log.Fatalf("Failed to write startup count: %v", err)
}
iface, err := openInterface(config.Network, config.LocalPeerIP, args.NetName)
if err != nil {
log.Fatalf("Failed to open interface: %v", err)
}
localPeer := state.Peers[config.LocalPeerIP]
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", localPeer.Port))
if err != nil {
log.Fatalf("Failed to resolve UDP address: %v", err)
}
logf("Listening on %v...", myAddr)
conn, err := net.ListenUDP("udp", myAddr)
if err != nil {
log.Fatalf("Failed to open UDP port: %v", err)
}
conn.SetReadBuffer(1024 * 1024 * 8)
conn.SetWriteBuffer(1024 * 1024 * 8)
var localAddr netip.AddrPort
ip, localAddrValid := netip.AddrFromSlice(localPeer.PublicIP)
if localAddrValid {
localAddr = netip.AddrPortFrom(ip, localPeer.Port)
}
g := NewGlobals(config, startupCount, localAddr, conn, iface)
hubPoller, err := NewHubPoller(g, args.NetName, args.HubAddress, args.APIKey)
if err != nil {
log.Fatalf("Failed to create hub poller: %v", err)
}
// Start status server.
go runStatusServer(g, statusSocketPath(args.NetName))
return &peerMain{
Globals: g,
ifReader: NewIFReader(g),
connReader: NewConnReader(g, conn),
hubPoller: hubPoller,
lockFile: lockFile,
}
}
func (p *peerMain) Run() {
for i := range p.RemotePeers {
remote := p.RemotePeers[i].Load()
go newRemoteFSM(remote).Run()
}
go p.ifReader.Run()
go p.connReader.Run()
if !p.LocalAddrValid {
go RunMCWriter(p.LocalPeerIP, p.PrivSignKey)
go RunMCReader(p.Globals)
}
go p.hubPoller.Run()
select {}
}
func initPeerWithHub(args mainArgs) {
keys := generateKeys()
initURL, err := url.Parse(args.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub URL: %v", err)
}
initURL.Path = "/peer/init/"
initArgs := m.PeerInitArgs{
EncPubKey: keys.PubKey,
PubSignKey: keys.PubSignKey,
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(initArgs); err != nil {
log.Fatalf("Failed to encode init args: %v", err)
}
req, err := http.NewRequest(http.MethodPost, initURL.String(), buf)
if err != nil {
log.Fatalf("Failed to construct request: %v", err)
}
req.SetBasicAuth("", args.APIKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Fatalf("Failed to init with hub: %v", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed to read response body: %v", err)
}
initResp := m.PeerInitResp{}
if err := json.Unmarshal(data, &initResp); err != nil {
log.Fatalf("Failed to parse configuration: %v\n%s", err, data)
}
config := LocalConfig{}
config.LocalPeerIP = initResp.PeerIP
config.Network = initResp.Network
config.PubKey = keys.PubKey
config.PrivKey = keys.PrivKey
config.PubSignKey = keys.PubSignKey
config.PrivSignKey = keys.PrivSignKey
if err := storeNetworkState(args.NetName, initResp.NetworkState); err != nil {
log.Fatalf("Failed to store network state: %v", err)
}
if err := storePeerConfig(args.NetName, config); err != nil {
log.Fatalf("Failed to store configuration: %v", err)
}
log.Print("Initialization successful.")
}

View File

@@ -1,21 +0,0 @@
package peer
import (
"log"
"net/netip"
"vppn/peer/control"
)
func (a *App) sendPing(p *Peer, ts int64) {
ping := control.Ping{
PingTS: ts,
SrcV4: a.selfV4,
SrcV6: a.selfV6,
Dst: p.WGEndpoint(),
}
dst := netip.AddrPortFrom(p.VPNIP, ControlPort)
if err := a.controlConn.SendPing(dst, ping, a.scratch); err != nil {
log.Printf("sendPing %v: %v", p.VPNIP, err)
}
}

86
peer/pubaddrs.go Normal file
View File

@@ -0,0 +1,86 @@
package peer
import (
"net/netip"
"sort"
"sync"
"time"
)
type pubAddrStore struct {
lock sync.Mutex
localPub bool
localAddr netip.AddrPort
lastSeen map[netip.AddrPort]time.Time
addrList []netip.AddrPort
}
func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore {
return &pubAddrStore{
localPub: localAddr.IsValid(),
localAddr: localAddr,
lastSeen: map[netip.AddrPort]time.Time{},
addrList: make([]netip.AddrPort, 0, 32),
}
}
func (store *pubAddrStore) Store(addr netip.AddrPort) {
if store.localPub {
return
}
if !addr.IsValid() {
return
}
if addr.Addr().IsPrivate() {
return
}
store.lock.Lock()
defer store.lock.Unlock()
if _, exists := store.lastSeen[addr]; !exists {
store.addrList = append(store.addrList, addr)
}
store.lastSeen[addr] = time.Now()
store.sort()
}
func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) {
store.lock.Lock()
defer store.lock.Unlock()
store.clean()
if store.localPub {
addrs[0] = store.localAddr
return
}
copy(addrs[:], store.addrList)
return
}
func (store *pubAddrStore) clean() {
if store.localPub {
return
}
for ip, lastSeen := range store.lastSeen {
if time.Since(lastSeen) > timeoutInterval {
delete(store.lastSeen, ip)
}
}
store.addrList = store.addrList[:0]
for ip := range store.lastSeen {
store.addrList = append(store.addrList, ip)
}
store.sort()
}
func (store *pubAddrStore) sort() {
sort.Slice(store.addrList, func(i, j int) bool {
return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]])
})
}

29
peer/pubaddrs_test.go Normal file
View File

@@ -0,0 +1,29 @@
package peer
import (
"net/netip"
"testing"
"time"
)
func TestPubAddrStore(t *testing.T) {
s := newPubAddrStore(netip.AddrPort{})
l := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22),
}
for i := range l {
s.Store(l[i])
time.Sleep(time.Millisecond)
}
s.clean()
l2 := s.Get()
if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] {
t.Fatal(l, l2)
}
}

54
peer/relayhandler.go Normal file
View File

@@ -0,0 +1,54 @@
package peer
import (
"log"
"sync"
"sync/atomic"
)
type relayHandler struct {
lock sync.Mutex
relays map[byte]*Remote
relay atomic.Pointer[Remote]
}
func newRelayHandler() *relayHandler {
return &relayHandler{
relays: make(map[byte]*Remote, 256),
}
}
func (h *relayHandler) Add(r *Remote) {
h.lock.Lock()
defer h.lock.Unlock()
h.relays[r.RemotePeerIP] = r
if h.relay.Load() == nil {
log.Printf("Setting Relay: %v", r.conf().Peer.Name)
h.relay.Store(r)
}
}
func (h *relayHandler) Remove(r *Remote) {
h.lock.Lock()
defer h.lock.Unlock()
log.Printf("Removing relay %d...", r.RemotePeerIP)
delete(h.relays, r.RemotePeerIP)
if h.relay.Load() == r {
// Remove current relay.
h.relay.Store(nil)
// Find new relay.
for _, r := range h.relays {
h.relay.Store(r)
break
}
}
}
func (h *relayHandler) Load() *Remote {
return h.relay.Load()
}

Some files were not shown because too many files have changed in this diff Show More