66 Commits

Author SHA1 Message Date
jdl
a2fe8dc79d Audit fixes 2026-06-16 08:10:37 +02:00
jdl
2b8cc86077 Audit fixes 2026-06-16 08:04:33 +02:00
jdl
cb7c07ac96 Audit fixes 2026-06-16 07:58:50 +02:00
jdl
d8c2990ffd Audit fixes 2026-06-16 07:55:46 +02:00
jdl
32b8b0dc89 Audit fixes. 2026-06-16 07:49:18 +02:00
jdl
302e5d00d0 Audit fixes 2026-06-16 07:44:49 +02:00
jdl
85d6d577e3 WIP: refactor incoming 2026-06-16 07:35:53 +02:00
jdl
691eb49009 Bug fix from audit. 2026-06-15 23:16:41 +02:00
jdl
b479b37479 Bug fix from audit. 2026-06-15 20:55:57 +02:00
jdl
fe9f15bec9 Cleanup. 2026-06-15 20:15:06 +02:00
jdl
b86b43f1de Cleanup. 2026-06-15 19:45:07 +02:00
jdl
c47d00e694 Cleanup. 2026-06-15 19:06:12 +02:00
jdl
458e1ac603 Cleanup. 2026-06-15 18:58:08 +02:00
jdl
9d57f45aea Fixed some audit issues. 2026-06-15 18:55:29 +02:00
jdl
36e9f6149d WIP 2026-06-15 18:45:11 +02:00
jdl
eaa101f976 WIP 2026-06-15 18:02:44 +02:00
jdl
802ca9aba4 WIP: improve responsiveness. 2026-06-15 17:59:25 +02:00
jdl
fa933ae029 Logging 2026-06-15 06:25:47 +02:00
jdl
b4320c9330 Bug fix - peer Up calculation #2 2026-06-15 06:19:20 +02:00
jdl
d02f47cce6 Bug fix - peer Up calculation 2026-06-15 06:17:52 +02:00
jdl
797ab8bdef Bug fixes, cleanup 2026-06-14 20:32:46 +02:00
jdl
f765303daf Dependency updates, logging middleware. 2026-06-14 20:31:49 +02:00
jdl
0e3d4ec3a5 Cleaning up dependencies. 2026-06-14 20:26:43 +02:00
jdl
b875313f7d More audit changes. 2026-06-14 09:13:01 +02:00
jdl
71328eb67e Audit changes 2026-06-14 09:07:41 +02:00
jdl
3e630ee0ad Audit changes 2026-06-14 09:04:41 +02:00
jdl
164d1f9d95 Audit changes 2026-06-14 08:57:23 +02:00
jdl
fa182eca76 Audit changes 2026-06-14 08:15:00 +02:00
jdl
52ea1a8d42 Audit changes 2026-06-14 06:17:35 +02:00
jdl
cc21bee798 Audit changes 2026-06-14 05:53:12 +02:00
jdl
353ef07f92 Audit changes 2026-06-14 05:47:17 +02:00
jdl
c12ef3341f Audit changes 2026-06-14 05:42:44 +02:00
jdl
992eabc0e9 Minor bug fixes 2026-06-14 05:30:59 +02:00
jdl
c45ac83eb0 Fixed bug - update probe address if it changes between ticks 2026-06-13 21:21:55 +02:00
jdl
3fe9f63901 Cleanup - audit 2026-06-13 20:25:07 +02:00
jdl
cfb2a29082 Cleanup - audit 2026-06-13 20:23:09 +02:00
jdl
9bdb836eaa Cleanup - audit 2026-06-13 20:18:13 +02:00
jdl
393f79e1d3 Cleanup - audit 2026-06-13 20:15:08 +02:00
jdl
c911e3e865 Cleanup - audit 2026-06-13 20:07:17 +02:00
jdl
c356347cf6 Cleanup - audit 2026-06-13 20:01:08 +02:00
jdl
111f3f4d20 Cleanup - audit 2026-06-13 19:57:52 +02:00
jdl
b6052ee7b8 Cleanup - audit 2026-06-13 19:55:16 +02:00
jdl
2e19f0945f Audit changes. 2026-06-13 19:34:18 +02:00
jdl
c325180a1b Audit changes. 2026-06-13 19:30:03 +02:00
jdl
c4a81cf553 Audit changes. 2026-06-13 19:11:40 +02:00
jdl
0f117e5e66 Added some logging. 2026-06-13 19:02:17 +02:00
jdl
68f01f9823 Audit changes. 2026-06-13 18:07:21 +02:00
jdl
8983c0d651 Audit changes. 2026-06-13 18:03:44 +02:00
jdl
0cd5982a3f Audit changes. 2026-06-13 15:47:28 +02:00
jdl
243e75dd09 Audit changes. 2026-06-13 15:42:40 +02:00
jdl
1f7d3151b5 Audit changes. 2026-06-13 15:39:08 +02:00
jdl
0709c4dac0 Audit changes. 2026-06-13 15:37:35 +02:00
jdl
c0126c2036 Audit changes. 2026-06-13 15:18:17 +02:00
jdl
528e67ea61 Audit changes. 2026-06-13 15:06:14 +02:00
jdl
fe5f26ed70 Audit changes. 2026-06-13 15:01:50 +02:00
jdl
75782c4efd Cleanup 2026-06-13 14:55:46 +02:00
jdl
867b3b5949 Error cleanup 2026-06-13 14:51:51 +02:00
jdl
76fce15e32 WIP 2026-06-13 14:46:55 +02:00
jdl
3dcd1c1080 WIP 2026-06-13 14:45:23 +02:00
jdl
232b68310c WIP 2026-06-13 14:44:25 +02:00
jdl
a730211167 Added timeout to read failure in rdpControlConn to avoid spinning on error. 2026-06-13 09:07:03 +02:00
jdl
f3d8a9ff75 Added guard in ioctlSetAddr for nil IP 2026-06-13 08:48:42 +02:00
jdl
98f07457b9 Added panic on small buffer for Ping.Marshal. 2026-06-13 08:43:35 +02:00
jdl
cbc901496c Cleanjp 2026-06-13 08:35:05 +02:00
jdl
cd5442f3bf AUDIT changes 2026-06-13 00:10:21 +02:00
jdl
11f6f2fc75 Updated deps 2026-06-12 18:59:18 +02:00
45 changed files with 694 additions and 322 deletions

View File

@@ -9,7 +9,7 @@ import (
"vppn/peer"
"git.crumpington.com/lib/go/flock"
"git.crumpington.com/lib/flock"
)
func main() {
@@ -34,9 +34,6 @@ func main() {
if err != nil {
log.Fatalf("lock: %v", err)
}
if lockFile == nil {
log.Fatalf("already running for network %q", *name)
}
defer flock.Unlock(lockFile)
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)

30
go.mod
View File

@@ -3,21 +3,25 @@ module vppn
go 1.25.1
require (
git.crumpington.com/lib/go v0.9.1
golang.org/x/crypto v0.42.0
golang.org/x/sys v0.36.0
git.crumpington.com/lib/flock v1.1.0
git.crumpington.com/lib/idgen v1.0.0
git.crumpington.com/lib/keyedmutex v1.1.0
git.crumpington.com/lib/ratelimiter v1.1.1
git.crumpington.com/lib/sqliteutil v1.1.1
git.crumpington.com/lib/webutil v1.1.0
github.com/mattn/go-sqlite3 v1.14.45
golang.org/x/crypto v0.53.0
golang.org/x/sys v0.46.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
)
require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
golang.org/x/net v0.44.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/mdlayher/genetlink v1.4.0 // indirect
github.com/mdlayher/netlink v1.11.2 // indirect
github.com/mdlayher/socket v0.6.1 // indirect
golang.org/x/net v0.56.0 // indirect
golang.org/x/sync v0.21.0 // indirect
golang.org/x/text v0.38.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 // indirect
)

60
go.sum
View File

@@ -1,30 +1,38 @@
git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
git.crumpington.com/lib/flock v1.1.0 h1:NzPUAXnywikN+ZPabzQw9eXAwvZolGUE3pjnSxnDwFk=
git.crumpington.com/lib/flock v1.1.0/go.mod h1:prUmtkjpGDUakQh6TiEAylrgDTPG0HuBOUe8Lq4HKsc=
git.crumpington.com/lib/idgen v1.0.0 h1:0Jre8R3B+RaMOKmCgagBT659wGM93QNpamuGF2e9SII=
git.crumpington.com/lib/idgen v1.0.0/go.mod h1:Q8kV11Zta4P5WKDpBwsekEsnOe9IysVLsW+gPhbzFTc=
git.crumpington.com/lib/keyedmutex v1.1.0 h1:XOlk9f0rnwmr5yNoIvPteM2W2uakZqT4tnZKficrXho=
git.crumpington.com/lib/keyedmutex v1.1.0/go.mod h1:ova6v/794UCZJ5FKKrLpaol0wfNZZTB3plLObSWaGk4=
git.crumpington.com/lib/ratelimiter v1.1.1 h1:8jVDVK/I0zzE3EHCu+sUeZN8a9Aqzm+PG4WrlnEvLes=
git.crumpington.com/lib/ratelimiter v1.1.1/go.mod h1:TycyPTi/aBfnWW8F51yfo/5fSP/qKywDREqsph7TEns=
git.crumpington.com/lib/sqliteutil v1.1.1 h1:xwfp/l2BL4nfw8Ye0Cex2HdGJQKQ1YBCFtDiMeUhnzk=
git.crumpington.com/lib/sqliteutil v1.1.1/go.mod h1:K8OelqOwhSYAZK42v8hKK6UmafItGf2WcMfNlq9Gfeo=
git.crumpington.com/lib/webutil v1.1.0 h1:S9CaRBbVgYOUsgZ5AU1gAJxkxzr8Zjn2v84MoMOy1+I=
git.crumpington.com/lib/webutil v1.1.0/go.mod h1:+LNLGApoe9InAJ7DCeLfiDmYov87XU3crYRHr/RYv2E=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk=
github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
github.com/mdlayher/genetlink v1.4.0 h1:f/Xs7Y2T+GyX9b3dbiUhnLE9InGs5F9RxJ2JwBMl71o=
github.com/mdlayher/genetlink v1.4.0/go.mod h1:d1hrKr8fwZU2JkcAtQUAzeTrI7nbgQSl+5k1cC0biSA=
github.com/mdlayher/netlink v1.11.2 h1:HKh2jqe+omdSWcQ88nrT7INE61B0NXfiSPFdgL4YbNI=
github.com/mdlayher/netlink v1.11.2/go.mod h1:uT2Yc/QLaZubzDpZIBi9d4GoeLwtp3x1AMeqSRrK2sA=
github.com/mdlayher/socket v0.6.1 h1:M7uj2NtuujUY4mYr1C57NmfNiRHbkKpnBxO856lsc3A=
github.com/mdlayher/socket v0.6.1/go.mod h1:+/SGtqc9V+5dAuRgQsU0fGBI+oRDiW7O2Obx10OIWfg=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 h1:cqHQ3AycTHvM2R7ikgyX57D+XvtcSnGylsLkOVhta/w=
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=

View File

@@ -8,10 +8,11 @@ import (
"sync"
"time"
"vppn/hub/api/db"
"vppn/hub/errs"
"vppn/m"
"git.crumpington.com/lib/go/idgen"
"git.crumpington.com/lib/go/sqliteutil"
"git.crumpington.com/lib/idgen"
"git.crumpington.com/lib/sqliteutil"
"golang.org/x/crypto/bcrypt"
)
@@ -26,7 +27,8 @@ type API struct {
}
func New(dbPath string) (*API, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal=WAL")
dbPath += "?_journal=WAL&_foreign_keys=on&_busy_timeout=5000&_txlock=immediate"
sqlDB, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
}
@@ -64,30 +66,33 @@ func (a *API) ensurePassword() error {
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
if err != nil {
return err
log.Printf("Failed to generate password: %v", err)
return errs.ErrUnexpected
}
conf := &Config{ConfigID: 1, Password: hashed}
return db.Config_Insert(a.db, conf)
return errs.DB(db.Config_Insert(a.db, conf))
}
func (a *API) Config_Get() (*Config, error) {
return db.Config_Get(a.db, 1)
conf, err := db.Config_Get(a.db, 1)
return conf, errs.DB(err)
}
func (a *API) Config_Update(conf *Config) error {
return db.Config_Update(a.db, conf)
a.lock.Lock()
defer a.lock.Unlock()
return errs.DB(db.Config_Update(a.db, conf))
}
func (a *API) Session_Delete(sessionID string) error {
func (a *API) Session_Delete(sessionID string) {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
delete(a.sessions, sessionID)
return nil
}
const (
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use
sessionTTL = 24 * 21 * time.Hour // sessions expire 21 days after last use
sessionSweepEvery = time.Hour // cadence of expired-session eviction
)
@@ -96,23 +101,23 @@ const (
// creates a session, so anonymous requests cost no memory — a session is minted
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
// callers from racing on the shared struct.
func (a *API) Session_Get(sessionID string) (Session, error) {
func (a *API) Session_Get(sessionID string) Session {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
s, ok := a.sessions[sessionID]
if sessionID == "" || !ok {
return Session{}, nil
return Session{}
}
if timeSince(s.LastSeenAt) > sessionTTLSecs {
if time.Since(s.LastSeenAt) > sessionTTL {
delete(a.sessions, sessionID)
return Session{}, nil
return Session{}
}
s.LastSeenAt = time.Now().Unix()
return *s, nil
s.LastSeenAt = time.Now()
return *s
}
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
@@ -121,24 +126,36 @@ func (a *API) Session_Get(sessionID string) (Session, error) {
func (a *API) Session_SignIn(pwd string) (Session, error) {
conf, err := a.Config_Get()
if err != nil {
return Session{}, err
log.Printf("Failed to get config: %v", err)
return Session{}, errs.ErrUnexpected
}
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
return Session{}, ErrNotAuthorized
return Session{}, errs.ErrNotAuthorized
}
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
s := &Session{
SessionID: idgen.NewToken(),
SignedIn: true,
CreatedAt: time.Now().Unix(),
LastSeenAt: time.Now().Unix(),
LastSeenAt: time.Now(),
}
a.sessions[s.SessionID] = s
return *s, nil
}
func (a *API) Session_InvalidateAll() Session {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
clear(a.sessions)
s := &Session{
SessionID: idgen.NewToken(),
LastSeenAt: time.Now(),
}
a.sessions[s.SessionID] = s
return *s
}
// sweepSessions periodically evicts sessions past their TTL. Without it, a
// signed-in session whose ID is never presented again would linger forever
// (Session_Get only evicts on a lookup of that same ID).
@@ -146,7 +163,7 @@ func (a *API) sweepSessions() {
for range time.Tick(sessionSweepEvery) {
a.sessionsMu.Lock()
for id, s := range a.sessions {
if timeSince(s.LastSeenAt) > sessionTTLSecs {
if time.Since(s.LastSeenAt) > sessionTTL {
delete(a.sessions, id)
}
}
@@ -155,29 +172,48 @@ func (a *API) sweepSessions() {
}
func (a *API) Network_Create(n *Network) error {
a.lock.Lock()
defer a.lock.Unlock()
n.NetworkID = idgen.NextID(0)
return db.Network_Insert(a.db, n)
return errs.DB(db.Network_Insert(a.db, n))
}
func (a *API) Network_Delete(n *Network) error {
return db.Network_Delete(a.db, n.NetworkID)
a.lock.Lock()
defer a.lock.Unlock()
exists, err := db.Network_HasPeers(a.db, n.NetworkID)
if err != nil {
return errs.DB(err)
}
if exists {
return errs.Conflict.WithMsg("Delete all peers before deleting network.")
}
return errs.DB(db.Network_Delete(a.db, n.NetworkID))
}
func (a *API) Network_Get(id int64) (*Network, error) {
return db.Network_Get(a.db, id)
n, err := db.Network_Get(a.db, id)
return n, errs.DB(err)
}
func (a *API) Network_List() ([]*Network, error) {
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
return db.Network_List(a.db, query)
n, err := db.Network_List(a.db, query)
return n, errs.DB(err)
}
func (a *API) Peer_CreateNew(p *Peer) error {
a.lock.Lock()
defer a.lock.Unlock()
p.WGPubKey = []byte{}
p.SignPubKey = []byte{}
p.APIKey = idgen.NewToken()
return db.Peer_Insert(a.db, p)
return errs.DB(db.Peer_Insert(a.db, p))
}
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
@@ -188,30 +224,36 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
// we held the lock, so it may be stale under concurrent requests.
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
if err != nil {
return err
return errs.DB(err)
}
if len(current.WGPubKey) != 0 {
return errors.New("peer already initialized")
return errs.ErrAlreadyExists
}
peer.WGPubKey = args.WGPubKey
peer.SignPubKey = args.SignPubKey
return db.Peer_UpdateFull(a.db, peer)
return errs.DB(db.Peer_UpdateFull(a.db, peer))
}
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
return db.Peer_Delete(a.db, networkID, peerIP)
a.lock.Lock()
defer a.lock.Unlock()
return errs.DB(db.Peer_Delete(a.db, networkID, peerIP))
}
func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
return db.Peer_ListAll(a.db, networkID)
p, err := db.Peer_ListAll(a.db, networkID)
return p, errs.DB(err)
}
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
return db.Peer_Get(a.db, networkID, ip)
p, err := db.Peer_Get(a.db, networkID, ip)
return p, errs.DB(err)
}
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
return db.Peer_GetByAPIKey(a.db, key)
p, err := db.Peer_GetByAPIKey(a.db, key)
return p, errs.DB(err)
}

View File

@@ -51,7 +51,7 @@ func Config_Update(
n, err := result.RowsAffected()
if err != nil {
panic(err)
return err
}
switch n {
case 0:

View File

@@ -1,19 +1,9 @@
package db
import (
"errors"
"net/netip"
"strings"
)
var (
ErrInvalidIP = errors.New("invalid IP")
ErrInvalidPeerIP = errors.New("invalid peer IP")
ErrNonPrivateIP = errors.New("non-private IP")
ErrInvalidPort = errors.New("invalid port")
ErrInvalidNetName = errors.New("invalid network name")
ErrNetNameNotLocal = errors.New("network name must end with .local")
ErrInvalidPeerName = errors.New("invalid peer name")
"vppn/hub/errs"
)
func Config_Sanitize(c *Config) {
@@ -35,11 +25,11 @@ func Network_Validate(c *Network) error {
// 15 bytes is linux limit for network interface names. With ending .local,
// max length is 21.
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
return ErrInvalidNetName
return errs.ErrInvalidNetName
}
if !strings.HasSuffix(c.LocalDomain, ".local") {
return ErrNetNameNotLocal
return errs.ErrNetNameNotLocal
}
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
@@ -49,23 +39,23 @@ func Network_Validate(c *Network) error {
if c >= '0' && c <= '9' {
continue
}
return ErrInvalidNetName
return errs.ErrInvalidNetName
}
addr, ok := netip.AddrFromSlice(c.Network)
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
return ErrInvalidIP
return errs.ErrInvalidIP
}
if !addr.IsPrivate() {
return ErrNonPrivateIP
return errs.ErrNonPrivateIP
}
return nil
}
func Peer_Sanitize(p *Peer) {
p.Name = strings.TrimSpace(p.Name)
p.Name = strings.TrimSpace(strings.ToLower(p.Name))
if len(p.Addr4) != 0 {
if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
@@ -84,26 +74,26 @@ func Peer_Sanitize(p *Peer) {
func Peer_Validate(p *Peer) error {
if p.PeerIP < 1 || p.PeerIP > 254 {
return ErrInvalidPeerIP
return errs.ErrInvalidPeerIP
}
if len(p.Addr4) > 0 {
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
return ErrInvalidIP
return errs.ErrInvalidIP
}
}
if len(p.Addr6) > 0 {
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
return ErrInvalidIP
return errs.ErrInvalidIP
}
}
if p.Port == 0 {
return ErrInvalidPort
return errs.ErrInvalidPort
}
if len(p.Name) == 0 {
return ErrInvalidPeerName
if len(p.Name) == 0 || len(p.Name) > 63 {
return errs.ErrInvalidPeerName
}
for _, c := range p.Name {
if c >= 'a' && c <= 'z' {
@@ -115,7 +105,7 @@ func Peer_Validate(p *Peer) error {
if c == '-' {
continue
}
return ErrInvalidPeerName
return errs.ErrInvalidPeerName
}
return nil

View File

@@ -11,3 +11,9 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
Peer_SelectQuery+` WHERE APIKey=?`,
apiKey)
}
func Network_HasPeers(tx TX, networkID int64) (exists bool, err error) {
const query = "SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=?)"
err = tx.QueryRow(query, networkID).Scan(&exists)
return exists, err
}

View File

@@ -1,12 +0,0 @@
package api
import (
"errors"
"vppn/hub/api/db"
)
var (
ErrNotAuthorized = errors.New("not authorized")
ErrInvalidIP = db.ErrInvalidIP
ErrInvalidPort = db.ErrInvalidPort
)

View File

@@ -21,5 +21,6 @@ CREATE TABLE peers (
WGPubKey BLOB NOT NULL,
SignPubKey BLOB NOT NULL,
UNIQUE(NetworkID, Name),
PRIMARY KEY(NetworkID, PeerIP)
PRIMARY KEY(NetworkID, PeerIP),
FOREIGN KEY(NetworkID) REFERENCES networks(NetworkID)
) WITHOUT ROWID;

View File

@@ -1,7 +0,0 @@
package api
import "time"
func timeSince(ts int64) int64 {
return time.Now().Unix() - ts
}

View File

@@ -1,6 +1,9 @@
package api
import "vppn/hub/api/db"
import (
"time"
"vppn/hub/api/db"
)
type Config = db.Config
type Network = db.Network
@@ -8,7 +11,5 @@ type Peer = db.Peer
type Session struct {
SessionID string
SignedIn bool
CreatedAt int64
LastSeenAt int64
LastSeenAt time.Time
}

View File

@@ -8,7 +8,8 @@ import (
"path/filepath"
"vppn/hub/api"
"git.crumpington.com/lib/go/webutil"
"git.crumpington.com/lib/keyedmutex"
"git.crumpington.com/lib/webutil"
)
//go:embed static
@@ -28,6 +29,9 @@ type App struct {
mux *http.ServeMux
tmpl map[string]*template.Template
insecure bool
// Per-remote address sign-in serialization lock.
signInLock *keyedmutex.KeyedMutex[string]
}
func NewApp(conf Config) (*App, error) {
@@ -41,6 +45,7 @@ func NewApp(conf Config) (*App, error) {
mux: http.NewServeMux(),
tmpl: webutil.ParseTemplateSet(templateFuncs, templateFS),
insecure: conf.Insecure,
signInLock: keyedmutex.New[string](),
}
app.registerRoutes()

37
hub/errs/db.go Normal file
View File

@@ -0,0 +1,37 @@
package errs
import (
"database/sql"
"errors"
"log"
sqlite3 "github.com/mattn/go-sqlite3"
)
func DB(err error) error {
if err == nil {
return nil
}
var e *Error
if errors.As(err, &e) {
return err
}
if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound
}
var se sqlite3.Error
if errors.As(err, &se) {
switch se.ExtendedCode {
case sqlite3.ErrConstraintUnique, sqlite3.ErrConstraintPrimaryKey:
return ErrAlreadyExists
case sqlite3.ErrConstraintForeignKey, sqlite3.ErrConstraintCheck:
return ErrConstraint
}
}
log.Printf("Unexpected error: %v", err)
return ErrUnexpected
}

61
hub/errs/types.go Normal file
View File

@@ -0,0 +1,61 @@
package errs
import (
"fmt"
"net/http"
)
type Error struct {
Code int
Msg string
}
func (e *Error) Error() string {
return fmt.Sprintf("[%d] %s", e.Code, e.Msg)
}
var (
ErrNotAuthorized = NotAuthorized.WithMsg("Not authorized.")
ErrInvalidPassword = BadRequest.WithMsg("Invalid password.")
ErrPasswordMismatch = BadRequest.WithMsg("Passwords don't match.")
ErrUnexpected = Internal.WithMsg("Unexpected internal error.")
ErrNotFound = NotFound.WithMsg("Not found.")
ErrAlreadyExists = Conflict.WithMsg("Already exists.")
// Validation errors.
ErrInvalidIP = BadRequest.WithMsg("Invalid IP.")
ErrInvalidPeerIP = BadRequest.WithMsg("Invalid peer IP.")
ErrNonPrivateIP = BadRequest.WithMsg("Non-private IP.")
ErrInvalidPort = BadRequest.WithMsg("Invalid port.")
ErrInvalidNetName = BadRequest.WithMsg("Invalid network name.")
ErrNetNameNotLocal = BadRequest.WithMsg("Network name must end with .local.")
ErrInvalidPeerName = BadRequest.WithMsg("Invalid peer name.")
ErrConstraint = BadRequest.WithMsg("Constraint error.")
)
// ----------------------------------------------------------------------------
type Type struct {
Code int
Msg string
}
func (t Type) WithErr(err error) *Error {
return &Error{Code: t.Code, Msg: err.Error()}
}
func (t Type) WithMsg(msg string) *Error {
return &Error{Code: t.Code, Msg: msg}
}
func (t Type) WithMsgf(msg string, args ...any) *Error {
return &Error{Code: t.Code, Msg: fmt.Sprintf(msg, args...)}
}
var (
Internal = Type{Code: http.StatusInternalServerError}
NotAuthorized = Type{Code: http.StatusUnauthorized}
NotFound = Type{Code: http.StatusNotFound}
BadRequest = Type{Code: http.StatusBadRequest}
Conflict = Type{Code: http.StatusConflict}
)

View File

@@ -4,7 +4,7 @@ import (
"net/url"
"vppn/hub/api"
"git.crumpington.com/lib/go/webutil"
"git.crumpington.com/lib/webutil"
)
func (app *App) formGetNetwork(form url.Values) (*api.Network, error) {

View File

@@ -1,11 +1,11 @@
package hub
import (
"errors"
"log"
"net/http"
"vppn/hub/api"
"git.crumpington.com/lib/go/webutil"
"vppn/hub/errs"
)
type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) error
@@ -13,32 +13,26 @@ type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) er
func (app *App) handlePub(pattern string, fn handlerFunc) {
wrapped := func(w http.ResponseWriter, r *http.Request) {
sessionID := app.getCookie(r, sessionIDCookieName)
s, err := app.api.Session_Get(sessionID)
if err != nil {
log.Printf("Failed to get session: %v", err)
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
s := app.api.Session_Get(sessionID)
if r.Method == http.MethodPost {
r.Body = http.MaxBytesReader(w, r.Body, 128*1024)
r.ParseMultipartForm(64 * 1024)
} else {
r.ParseForm()
}
if err := fn(&s, w, r); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
handleError(w, err)
}
}
app.mux.HandleFunc(pattern,
webutil.WithLogging(
wrapped))
app.mux.HandleFunc(pattern, withLogging(wrapped))
}
func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
if s.SignedIn {
if s.SessionID != "" {
http.Redirect(w, r, "/", http.StatusSeeOther)
return nil
}
@@ -48,7 +42,7 @@ func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
func (app *App) handleSignedIn(pattern string, fn handlerFunc) {
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
if !s.SignedIn {
if s.SessionID == "" {
http.Redirect(w, r, "/", http.StatusSeeOther)
return nil
}
@@ -66,6 +60,7 @@ func (app *App) handlePeer(pattern string, fn peerHandlerFunc) {
return
}
// Not doing constant time compare because index lookup time dominates.
peer, err := app.api.Peer_GetByAPIKey(apiKey)
if err != nil {
http.Error(w, "Not authorized", http.StatusUnauthorized)
@@ -74,12 +69,19 @@ func (app *App) handlePeer(pattern string, fn peerHandlerFunc) {
r.ParseForm()
if err := fn(peer, w, r); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
handleError(w, err)
}
}
app.mux.HandleFunc(pattern,
webutil.WithLogging(
wrapped))
app.mux.HandleFunc(pattern, withLogging(wrapped))
}
func handleError(w http.ResponseWriter, err error) {
var e *errs.Error
if errors.As(err, &e) {
http.Error(w, e.Msg, e.Code)
} else {
log.Printf("Unexpected error: %v", err)
http.Error(w, "Internal server error.", http.StatusInternalServerError)
}
}

View File

@@ -2,19 +2,22 @@ package hub
import (
"encoding/json"
"errors"
"log"
"math/rand/v2"
"net"
"net/http"
"time"
"vppn/hub/api"
"vppn/hub/errs"
"vppn/m"
"git.crumpington.com/lib/go/webutil"
"git.crumpington.com/lib/webutil"
"golang.org/x/crypto/bcrypt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
if s.SignedIn {
if s.SessionID != "" {
return a.redirect(w, r, "/admin/network/list/")
} else {
return a.redirect(w, r, "/sign-in/")
@@ -26,6 +29,15 @@ func (a *App) _signin(s *api.Session, w http.ResponseWriter, r *http.Request) er
}
func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
// Ignoring error here - if host is the empty string, it will contend for the
// lock anyway.
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if !a.signInLock.TryLock(host) {
time.Sleep(time.Second + time.Duration(rand.Int64N(int64(3*time.Second))))
return errs.ErrNotAuthorized
}
defer a.signInLock.Unlock(host)
var pwd string
err := webutil.NewFormScanner(r.Form).
Scan("Password", &pwd).
@@ -36,8 +48,10 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque
sess, err := a.api.Session_SignIn(pwd)
if err != nil {
time.Sleep(time.Second + time.Duration(rand.Int64N(int64(3*time.Second))))
return err
}
a.setCookie(w, sessionIDCookieName, sess.SessionID)
return a.redirect(w, r, "/")
@@ -48,9 +62,7 @@ func (a *App) _adminSignOut(s *api.Session, w http.ResponseWriter, r *http.Reque
}
func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
if err := a.api.Session_Delete(s.SessionID); err != nil {
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
}
a.api.Session_Delete(s.SessionID)
a.deleteCookie(w, sessionIDCookieName)
return a.redirect(w, r, "/")
}
@@ -236,22 +248,23 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
return err
}
if len(newPwd) < 8 {
return errors.New("password is too short")
if len(newPwd) < 8 || len(newPwd) > 72 {
return errs.ErrInvalidPassword
}
if newPwd != newPwd2 {
return errors.New("passwords don't match")
return errs.ErrPasswordMismatch
}
err = bcrypt.CompareHashAndPassword(conf.Password, []byte(curPwd))
if err != nil {
return err
return errs.ErrNotAuthorized
}
hash, err := bcrypt.GenerateFromPassword([]byte(newPwd), bcrypt.DefaultCost)
if err != nil {
return err
log.Printf("Failed to hash password with bcrypt: %v", err)
return errs.ErrUnexpected
}
conf.Password = hash
@@ -260,7 +273,10 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
return err
}
return a.redirect(w, r, "/admin/config/")
*s = a.api.Session_InvalidateAll()
a.setCookie(w, sessionIDCookieName, s.SessionID)
return a.redirect(w, r, "/admin/network/list/")
}
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
@@ -269,9 +285,11 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
return nil
}
r.Body = http.MaxBytesReader(w, r.Body, 2048)
args := m.PeerInitArgs{}
if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
return err
return errs.BadRequest.WithMsg("Invalid request body.")
}
if len(args.WGPubKey) != 32 {
@@ -328,8 +346,10 @@ func (a *App) peersList(networkID int64) (peers []m.Peer, err error) {
}
wgKey, err := wgtypes.NewKey(p.WGPubKey)
if err != nil {
log.Printf("Bad WG key in DB for peer %d/%d", p.NetworkID, p.PeerIP)
continue // malformed key; skip rather than serve garbage
}
var signKey [32]byte
copy(signKey[:], p.SignPubKey)
peers = append(peers, m.Peer{

View File

@@ -5,8 +5,9 @@ import (
"log"
"net/http"
"os"
"time"
"git.crumpington.com/lib/go/webutil"
"git.crumpington.com/lib/webutil"
)
func Main() {
@@ -32,6 +33,10 @@ func Main() {
srv := &http.Server{
Addr: conf.ListenAddr,
Handler: app.Handler(),
ReadHeaderTimeout: 30 * time.Second,
ReadTimeout: 60 * time.Second,
WriteTimeout: 120 * time.Second,
IdleTimeout: 180 * time.Second,
}
log.Fatal(webutil.ListenAndServe(srv))

47
hub/middleware.go Normal file
View File

@@ -0,0 +1,47 @@
package hub
import (
"log"
"net/http"
"os"
"time"
)
var _log = log.New(os.Stderr, "", 0)
type responseWriterWrapper struct {
http.ResponseWriter
httpStatus int
responseSize int
}
func (w *responseWriterWrapper) WriteHeader(status int) {
w.httpStatus = status
w.ResponseWriter.WriteHeader(status)
}
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
if w.httpStatus == 0 {
w.httpStatus = 200
}
w.responseSize += len(b)
return w.ResponseWriter.Write(b)
}
func withLogging(inner http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
t := time.Now()
wrapper := responseWriterWrapper{w, 0, 0}
inner(&wrapper, r)
_log.Printf("%s \"%s %s %s\" %d %d %v",
r.RemoteAddr,
r.Method,
r.URL.Path,
r.Proto,
wrapper.httpStatus,
wrapper.responseSize,
time.Since(t),
)
}
}

View File

@@ -3,7 +3,7 @@ package hub
import "net/http"
func (a *App) registerRoutes() {
a.mux.Handle("GET /static/", http.FileServerFS(staticFS))
a.mux.Handle("GET /static/", withLogging(http.FileServerFS(staticFS).ServeHTTP))
a.handlePub("GET /", a._root)
a.handleNotSignedIn("GET /sign-in/", a._signin)

View File

@@ -9,7 +9,7 @@
<header>
<h1>VPPN</h1>
<nav>
{{if .Session.SignedIn -}}
{{if .Session.SessionID -}}
<a href="/admin/networks/list/">Home</a> /
<a href="/admin/sign-out/">Sign out</a>
{{- end}}

View File

@@ -9,7 +9,7 @@
<header>
<h1>VPPN</h1>
<nav>
{{if .Session.SignedIn -}}
{{if .Session.SessionID -}}
<a href="/admin/networks/list/">Home</a> /
<a href="/admin/sign-out/">Sign out</a>
{{- end}}

View File

@@ -1,9 +1,13 @@
package peer
import (
"fmt"
"log"
"net/netip"
"os"
"os/signal"
"sort"
"strings"
"syscall"
"time"
@@ -20,6 +24,7 @@ var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisf
const (
ControlPort = 4561
PingInterval = 8 * time.Second
TickInterval = 2 * time.Second
TimeoutInterval = 30 * time.Second
)
@@ -40,7 +45,6 @@ type App struct {
vpnNet netip.Prefix
privKey wgtypes.Key
pubKey wgtypes.Key
isRelay bool
isPublic bool
localDomain string
@@ -75,13 +79,15 @@ func (a *App) Run() error {
// while we were down).
a.updateHosts()
ticker := time.NewTicker(PingInterval)
defer ticker.Stop()
stateTicker := time.NewTicker(TickInterval)
pingTicker := time.NewTicker(PingInterval)
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(sig)
tickCount := 0
for {
select {
case p := <-a.hubAddCh:
@@ -92,8 +98,15 @@ func (a *App) Run() error {
a.onPing(e)
case e := <-a.multicastCh:
a.onMulticastDiscovery(e)
case <-ticker.C:
a.onTick()
case <-stateTicker.C:
a.onStateTick()
case <-pingTicker.C:
a.onPingTick()
tickCount++
if tickCount%8 == 0 {
a.logNetworkState()
}
case <-sig:
return a.onShutdown()
}
@@ -103,3 +116,47 @@ func (a *App) Run() error {
func (a *App) onShutdown() error {
return wginterface.Delete(a.dev.Name())
}
func (a *App) logNetworkState() {
var b strings.Builder
fmt.Fprintf(&b, "Network state (self: %s public=%v):\n", a.vpnIP, a.isPublic)
fmt.Fprintf(&b, " Network: %v\n", a.vpnNet)
fmt.Fprintf(&b, " IPv4: %v\n", a.selfV4)
fmt.Fprintf(&b, " IPv6: %v\n", a.selfV6)
if a.relay != nil {
fmt.Fprintf(&b, " Relay: %s\n", a.relay.Name)
} else {
fmt.Fprint(&b, " Relay: -\n")
}
b.WriteString("Peers:\n")
//
peers := make([]*Peer, 0, len(a.peersByIP))
for _, p := range a.peersByIP {
peers = append(peers, p)
}
sort.Slice(peers, func(i, j int) bool {
return peers[i].VPNIP.As4()[3] < peers[j].VPNIP.As4()[3]
})
for _, p := range peers {
ip := p.VPNIP.As4()[3]
up := "DOWN"
if p.Up() {
up = "UP "
}
endpoint := p.WGEndpoint()
if endpoint.IsValid() {
fmt.Fprintf(&b, " %24s %03d %s %s seen=%s @ %s\n",
p.Name, ip, p.State, up, time.Since(p.LastPing).Round(time.Millisecond), endpoint)
} else {
fmt.Fprintf(&b, " %24s %03d %s %s seen=%s\n",
p.Name, ip, p.State, up, time.Since(p.LastPing).Round(time.Millisecond))
}
}
log.Print(b.String())
}

View File

@@ -26,13 +26,14 @@ func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer {
})
p := a.peersByKey[key]
p.wgPeer.LastHandshakeTime = time.Now()
p.LastPing = time.Now()
return p
}
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
// vpnIP is the local VPN address (e.g. "10.0.0.1").
// isPublic / isRelay describe the local node's role.
func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) {
func newTestApp(t *testing.T, vpnIP string, isPublic bool) (*App, *fakeWGDevice, *fakeControlConn) {
t.Helper()
privKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
@@ -47,7 +48,6 @@ func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fake
privKey: privKey,
pubKey: privKey.PublicKey(),
isPublic: isPublic,
isRelay: isRelay,
dev: dev,
controlConn: cc,
peersByKey: make(map[wgtypes.Key]*Peer),

View File

@@ -37,8 +37,12 @@ type Ping struct {
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
// field is written unconditionally so a reused buffer needs no pre-zeroing.
func (p Ping) Marshal(buf []byte) []byte {
_ = buf[Size-1] // Panic if buffer is too small.
buf[0] = version
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
// SrcV4.
if p.SrcV4.IsValid() {
a4 := p.SrcV4.Addr().As4()
copy(buf[9:13], a4[:])
@@ -46,9 +50,13 @@ func (p Ping) Marshal(buf []byte) []byte {
} else {
clear(buf[9:15])
}
// SrcV6.
a16 := p.SrcV6.Addr().As16()
copy(buf[15:31], a16[:])
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
// Dst.
a16 = p.Dst.Addr().As16()
copy(buf[33:49], a16[:])
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
@@ -63,14 +71,21 @@ func Unmarshal(buf [Size]byte) (Ping, error) {
p := Ping{
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
}
if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() {
addr := netip.AddrFrom4([4]byte(buf[9:13]))
if !addr.IsUnspecified() {
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
}
if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() {
addr = netip.AddrFrom16([16]byte(buf[15:31])).Unmap()
if !addr.IsUnspecified() {
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
}
if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() {
addr = netip.AddrFrom16([16]byte(buf[33:49])).Unmap()
if !addr.IsUnspecified() {
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
}
return p, nil
}

View File

@@ -4,6 +4,7 @@ import (
"log"
"net"
"net/netip"
"time"
"vppn/peer/control"
)
@@ -32,11 +33,14 @@ func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []b
// run reads incoming ping packets and forwards them to ch until ctx is done.
// Call this in a goroutine before starting the App event loop.
func (c *udpControlConn) run(ch chan<- PingEvent) {
const errorTimeout = 8 * time.Second
var buf [control.Size]byte
for {
n, src, err := c.conn.ReadFromUDP(buf[:])
if err != nil {
log.Printf("control read: %v", err)
time.Sleep(errorTimeout)
continue
}

View File

@@ -38,22 +38,28 @@ func (a *App) devPeers() []wgtypes.Peer {
return peers
}
func (a *App) devAddPeer(p *Peer) {
func (a *App) devAddRelayed(p *Peer) {
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
p.State = StateRelayed
p.EndpointV4 = netip.AddrPort{}
p.EndpointV6 = netip.AddrPort{}
p.EndpointLAN = netip.AddrPort{}
}
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
p.State = StateDirect
}
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
p.State = StateDirect // Dirrect connection. The app marks peer as relay.
p.State = StateDirect // Direct connection. The app marks peer as relay.
}
func (a *App) devPromote(p *Peer) {
@@ -61,19 +67,24 @@ func (a *App) devPromote(p *Peer) {
if ep.IsValid() {
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
} else {
log.Printf("PROMOTED: %s - %s (no IP)", p.Name, p.VPNIP.String())
log.Printf("DIRECT: %s - %s (waiting for handshake)", p.Name, p.VPNIP.String())
}
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
p.State = StateDirect
p.LastPing = time.Now() // Assume the peer is up after being promoted.
}
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
p.State = StateProbing
p.ProbeStart = time.Now()
p.ProbeEndpoint = endpoint
}
func (a *App) devRemove(p *Peer) {
log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String())
log.Printf("REMOVED: %s", p.PubKey())
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
}

View File

@@ -9,7 +9,7 @@ import (
"strings"
"syscall"
"git.crumpington.com/lib/go/flock"
"git.crumpington.com/lib/flock"
)
const (

View File

@@ -62,13 +62,15 @@ func (hp *HubPoller) Run() {
hp.apply(state)
}
hp.poll()
client := &http.Client{Timeout: 32 * time.Second}
hp.poll(client)
for range time.Tick(hubPollInterval) {
hp.poll()
hp.poll(client)
}
}
func (hp *HubPoller) poll() {
func (hp *HubPoller) poll(client *http.Client) {
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
if err != nil {
log.Printf("[HubPoller] build request: %v", err)
@@ -76,7 +78,6 @@ func (hp *HubPoller) poll() {
}
req.SetBasicAuth("", hp.apiKey)
client := &http.Client{Timeout: 32 * time.Second}
resp, err := client.Do(req)
if err != nil {
log.Printf("[HubPoller] fetch: %v", err)
@@ -89,7 +90,7 @@ func (hp *HubPoller) poll() {
return
}
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
if err != nil {
log.Printf("[HubPoller] read body: %v", err)
return

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/netip"
"os"
"time"
"golang.org/x/crypto/nacl/sign"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -93,7 +94,7 @@ func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error)
req.SetBasicAuth("", apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := (&http.Client{Timeout: time.Minute}).Do(req)
if err != nil {
return LocalState{}, fmt.Errorf("hub init: %w", err)
}

View File

@@ -9,10 +9,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
4560))
func Broadcast(
selfVPNIP netip.Addr,
pubKey wgtypes.Key,
@@ -20,12 +16,16 @@ func Broadcast(
signKey *[64]byte,
) {
for {
broadcastInner(selfVPNIP, pubKey, wgPort, signKey)
broadcast(selfVPNIP, pubKey, wgPort, signKey)
time.Sleep(errorTimeout)
}
}
func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
func broadcast(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
addr := multicastAddr(selfVPNIP)
log.Printf("[MC Broadcast] Sending on %v.", addr)
conn, err := net.ListenMulticastUDP("udp", nil, addr)
if err != nil {
log.Printf("[MC Broadcast] bind: %v", err)
@@ -44,7 +44,7 @@ func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, sig
// dropped by receivers' freshness gate.
send := func() error {
packet.Timestamp = time.Now().Unix()
payload := packet.Marshal(buf, signKey)
payload := packet.marshal(buf, signKey)
_, err := conn.WriteToUDP(payload, addr)
return err
}

View File

@@ -29,8 +29,8 @@ type Packet struct {
Signed []byte // Raw signed message for verification (incoming packet).
}
// Marshal the packet into a buffer with prefixed signature.
func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte {
// marshal the packet into a buffer with prefixed signature.
func (p Packet) marshal(buf []byte, signKey *[64]byte) []byte {
buf[0] = p.PeerIP
copy(buf[1:33], p.WGPubKey[:])
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
@@ -43,7 +43,7 @@ func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool {
return ok
}
func Unmarshal(signed []byte) (p Packet) {
func unmarshal(signed []byte) (p Packet) {
buf := signed[signSize:]
p.PeerIP = buf[0]
copy(p.WGPubKey[:], buf[1:33])

View File

@@ -21,12 +21,12 @@ func TestPacket(t *testing.T) {
}
buf := make([]byte, BufferSize)
signed := p.Marshal(buf, priv)
signed := p.marshal(buf, priv)
if len(signed) != SignedPacketSize {
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
}
got := Unmarshal(signed)
got := unmarshal(signed)
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)

View File

@@ -7,27 +7,34 @@ import (
"net"
"net/netip"
"time"
"git.crumpington.com/lib/ratelimiter"
)
func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) {
func Receiver(selfVPNIP netip.Addr, ch chan<- Packet) {
for {
if err := receiver(vpnNet, selfVPNIP, ch); err != nil {
log.Printf("[MCReader] %v", err)
if err := receiver(selfVPNIP, ch); err != nil {
log.Printf("[MC Receiver] %v", err)
}
time.Sleep(errorTimeout)
}
}
func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error {
func receiver(selfVPNIP netip.Addr, ch chan<- Packet) error {
limiters := map[netip.Addr]*ratelimiter.Limiter{}
selfIP := selfVPNIP.As4()[3]
addr := multicastAddr(selfVPNIP)
log.Printf("[MC Receiver] Listening on %v.", addr)
conn, err := net.ListenMulticastUDP("udp", nil, addr)
if err != nil {
return fmt.Errorf("bind: %w", err)
}
defer conn.Close()
buf := make([]byte, BufferSize+1) // +1 to detect oversized packets
buf := make([]byte, SignedPacketSize+1) // +1 to detect oversized packets
for {
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
@@ -43,19 +50,43 @@ func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error
continue
}
packet := Unmarshal(buf[:n])
packet := unmarshal(buf[:n])
if packet.PeerIP == selfIP {
continue
}
// Slightly cheaper than limiting.
age := time.Since(time.Unix(packet.Timestamp, 0))
if age > maxPacketAge || age < -maxPacketAge {
continue
}
srcAddr := src.Addr().Unmap()
lim, ok := limiters[srcAddr]
if !ok {
lim = ratelimiter.New(ratelimiter.Config{
BurstLimit: 1,
FillPeriod: broadcastInterval / 2,
MaxWaitCount: 0,
})
limiters[srcAddr] = lim
}
if err := lim.Limit(); err != nil {
log.Printf("[MC Receiver] Rate limited packet from peer IP %d.", packet.PeerIP)
continue
}
packet.Signed = bytes.Clone(packet.Signed)
packet.Src = src.Addr().Unmap()
ch <- packet
}
}
func multicastAddr(vpnIP netip.Addr) *net.UDPAddr {
b := vpnIP.As4()
return net.UDPAddrFromAddrPort(
netip.AddrPortFrom(
netip.AddrFrom4([4]byte{239, b[0], b[1], b[2]}), 4560))
}

View File

@@ -81,7 +81,7 @@ func New(
if !state.IsPublic {
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh)
go multicast.Receiver(state.VPNIP, multicastCh)
}
return &App{
@@ -89,7 +89,6 @@ func New(
vpnNet: state.VPNNet,
privKey: state.PrivKey,
pubKey: state.PrivKey.PublicKey(),
isRelay: state.IsRelay,
isPublic: state.IsPublic,
localDomain: localDomain,

View File

@@ -43,7 +43,7 @@ func (a *App) onAddPeer(p m.Peer) {
// endpoint from the incoming handshake automatically.
a.devPromote(peer)
} else {
a.devAddPeer(peer)
a.devAddRelayed(peer)
}
return
}
@@ -95,14 +95,6 @@ func (a *App) switchActiveRelay() {
a.relay = best
}
func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort {
// We always prefer v4 since all peers can connect to IPv4 addresses.
if v4.IsValid() {
return v4
}
return v6
}
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
if !selfIsPublic && peerIsPublic {
return control.Client

View File

@@ -85,7 +85,7 @@ func TestOnAddPeer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
a, dev, _ := newTestApp(t, "10.0.0.1", false)
key := mustKey(t)
if tc.setup != nil {
tc.setup(a, key)
@@ -192,7 +192,7 @@ func TestOnRemovePeer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
a, dev, _ := newTestApp(t, "10.0.0.1", false)
key := tc.setup(t, a)
dev.Calls = nil
a.onRemovePeer(key)
@@ -268,7 +268,7 @@ func TestSwitchActiveRelay(t *testing.T) {
name: "stale relay demoted to direct before backup elected",
setup: func(t *testing.T, a *App) {
old := addRelayPeer(t, a, "10.0.0.10", ep1)
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
old.LastPing = time.Time{} // stale — Up() checks LastPing; triggers switch — triggers switch from onTick
a.relay = old
addRelayPeer(t, a, "10.0.0.11", ep2)
},
@@ -289,7 +289,7 @@ func TestSwitchActiveRelay(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
a, dev, _ := newTestApp(t, "10.0.0.1", false)
tc.setup(t, a)
dev.Calls = nil
a.switchActiveRelay()

View File

@@ -9,17 +9,13 @@ import (
)
func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
if a.isPublic {
return
}
// Locate the sender peer by its VPN IP (final octet carried in the beacon).
octets := a.vpnNet.Addr().As4()
octets[3] = pkt.PeerIP
vpnIP := netip.AddrFrom4(octets)
peer, ok := a.peersByIP[vpnIP]
if !ok || peer.IsPublic || peer.State == StateDirect {
if !ok || peer.IsPublic {
return
}
@@ -36,16 +32,9 @@ func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
}
endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort)
if !endpoint.IsValid() {
if !endpoint.IsValid() || endpoint.Port() == 0 {
return
}
var v4, v6 netip.AddrPort
if pkt.Src.Is4() {
v4 = endpoint
} else {
v6 = endpoint
}
a.addProbe(peer, v4, v6)
peer.EndpointLAN = endpoint
}

View File

@@ -1,6 +1,7 @@
package peer
import (
"log"
"net/netip"
"time"
@@ -16,6 +17,8 @@ func (a *App) onPing(e PingEvent) {
now := time.Now()
peer.LastPing = now
// If we're the server, respond - this is always necessary as it's used to
// know if peers are up or down.
if peer.Role == control.Server {
@@ -32,27 +35,37 @@ func (a *App) onPing(e PingEvent) {
return
}
// We can only learn our own endpoint from directly-connected peers — Dst
// is the sender's observation of our WG handshake source.
// We can only learn our own endpoint from directly-connected peers — Dst is
// the sender's observation of our WG handshake source.
//
// We make sure we don't set a private address as our public address since we
// may be connected via LAN to some peers.
if peer.State == StateDirect {
if dst := e.ping.Dst; dst.IsValid() {
if dst := e.ping.Dst; addrIsRoutable(dst) {
if dst.Addr().Is4() {
if dst != a.selfV4 {
log.Printf("Local IPv4 updated: %s -> %s", a.selfV4, dst)
a.selfV4 = dst
}
} else {
if dst != a.selfV6 {
log.Printf("Local IPv6 updated: %s -> %s", a.selfV6, dst)
a.selfV6 = dst
}
}
}
return
}
a.addProbe(peer, e.ping.SrcV4, e.ping.SrcV6)
peer.UpdateEndpoints(e.ping.SrcV4, e.ping.SrcV6)
}
func (a *App) addProbe(peer *Peer, v4, v6 netip.AddrPort) {
endpoint := preferredEndpoint(v4, v6)
if !endpoint.IsValid() || endpoint == peer.PreferredEndpoint() {
return
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
func addrIsRoutable(addrPort netip.AddrPort) bool {
if addrPort.Port() == 0 {
return false
}
peer.UpdateEndpoints(v4, v6)
a.devAddProbe(peer, endpoint)
addr := addrPort.Addr()
return addr.IsGlobalUnicast() && !addr.IsPrivate() && !cgnatPrefix.Contains(addr)
}

15
peer/on_pingtick.go Normal file
View File

@@ -0,0 +1,15 @@
package peer
import (
"time"
"vppn/peer/control"
)
func (a *App) onPingTick() {
now := time.Now().UnixNano()
for _, p := range a.peersByIP {
if p.Role == control.Client {
a.sendPing(p, now)
}
}
}

68
peer/on_statetick.go Normal file
View File

@@ -0,0 +1,68 @@
package peer
import (
"log"
"time"
"vppn/peer/wginterface"
)
func (a *App) onStateTick() {
wgPeers := a.devPeers()
for _, wgPeer := range wgPeers {
p, ok := a.peersByKey[wgPeer.PublicKey]
if !ok {
log.Printf("Wireguard peer not known. Removing: %v", wgPeer.PublicKey)
a.devRemove(&Peer{wgPeer: wgPeer})
continue
}
p.wgPeer = wgPeer
// Log endpoint changes.
if ep := p.WGEndpoint(); ep != p.EndpointWG {
log.Printf("Client %s %s endpoint: %s -> %s", p.Name, p.VPNIP, p.EndpointWG, ep)
p.EndpointWG = ep
}
switch p.State {
case StateRelayed:
if p.DirectAlive() {
// We may already have a valid direct endpoint due to wireguard
// roaming.
a.devPromote(p)
} else if ep := p.PreferredEndpoint(); ep.IsValid() {
// If we have an ep to probe, add it.
a.devAddProbe(p, ep)
}
case StateProbing:
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
// Promote probing peers to direct once alive (direct path confirmed
// working).
a.devPromote(p)
} else if ep := p.PreferredEndpoint(); ep.IsValid() && ep != p.ProbeEndpoint {
// Re-start probing if we see a new endpoint.
a.devAddProbe(p, ep)
} else if time.Since(p.ProbeStart) > 8*wginterface.ProbeKeepalive {
// Give up probing if we haven't been able to handshake.
a.devAddRelayed(p)
}
case StateDirect:
if p.IsPublic || a.isPublic || p.Up() {
break
}
// Stale non-public direct peer: demote to relayed and wait for new IP
// information.
a.devAddRelayed(p)
}
}
// Ensure we have a live relay (if we're not public).
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
a.switchActiveRelay()
}
}

View File

@@ -1,52 +0,0 @@
package peer
import (
"log"
"time"
"vppn/peer/control"
"vppn/peer/wginterface"
)
func (a *App) onTick() {
wgPeers := a.devPeers()
now := time.Now().UnixNano()
for _, wgPeer := range wgPeers {
p, ok := a.peersByKey[wgPeer.PublicKey]
if !ok {
log.Printf("Wireguard peer not in index, removing: %v", wgPeer)
a.devRemove(&Peer{wgPeer: wgPeer})
continue
}
p.wgPeer = wgPeer
// Send pings to peers where we're the client.
if p.Role == control.Client {
a.sendPing(p, now)
}
switch p.State {
case StateProbing:
// Promote probing peers to direct once alive (direct path confirmed
// working).
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
a.devPromote(p)
}
case StateDirect:
if p.IsPublic || a.isPublic || p.Up() {
break
}
// Stale non-public direct peer: demote to probing so WireGuard
// resumes handshake attempts on the direct path.
a.devAddProbe(p, p.WGEndpoint())
}
}
// Ensure we have a live relay (if we're not public).
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
a.switchActiveRelay()
}
}

View File

@@ -26,9 +26,14 @@ type Peer struct {
IsPublic bool // Peer has a public IP.
EndpointV4 netip.AddrPort // Reported IPv4 endpoint.
EndpointV6 netip.AddrPort // Reported IPv6 endpoint.
EndpointLAN netip.AddrPort // Discovered via multicast.
EndpointWG netip.AddrPort // Current wireguard endpoint.
RTT time.Duration // Round-trip time.
LastPing time.Time // Last time we had a ping.
ProbeStart time.Time // When we started probing.
ProbeEndpoint netip.AddrPort
State PeerState // Current routing state; updated on each devXxx call.
Role control.Role // Client initiates pings; server responds.
Role control.Role // Role in relation to the local application.
SignPubKey [32]byte // nacl/sign public key for verifying multicast beacons.
}
@@ -54,7 +59,12 @@ func (p *Peer) LastHandshakeTime() time.Time {
}
func (p *Peer) Up() bool {
return time.Since(p.wgPeer.LastHandshakeTime) < wginterface.SessionTimeout
return time.Since(p.LastPing) < 3*PingInterval
}
func (p *Peer) DirectAlive() bool {
return p.WGEndpoint().IsValid() &&
time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive
}
func (p *Peer) CanRelay() bool {
@@ -62,7 +72,13 @@ func (p *Peer) CanRelay() bool {
}
func (p *Peer) PreferredEndpoint() netip.AddrPort {
return preferredEndpoint(p.EndpointV4, p.EndpointV6)
if p.EndpointLAN.IsValid() {
return p.EndpointLAN
} else if p.EndpointV4.IsValid() {
return p.EndpointV4
} else {
return p.EndpointV6
}
}
func (p *Peer) UpdateEndpoints(v4, v6 netip.AddrPort) {

View File

@@ -11,6 +11,7 @@ package wginterface
import (
"encoding/binary"
"errors"
"fmt"
"net"
"slices"
@@ -173,6 +174,10 @@ func nlAttr(attrType uint16, data []byte) []byte {
// messages, but the AF_INET ioctl interface is simpler.
func ioctlSetAddr(name string, ip net.IP, prefixLen int) error {
if ip.To4() == nil {
return errors.New("attempted to set non-IPv4 address on interface")
}
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err

View File

@@ -112,7 +112,7 @@ func (d *Device) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network n
})
}
// AddProbe adds a peer with no AllowedIPs and a 5s keepalive. WireGuard will
// AddProbe adds a peer with no AllowedIPs and an 8s keepalive. WireGuard will
// attempt handshakes without routing any traffic through this peer yet.
func (d *Device) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
keepalive := ProbeKeepalive