Compare commits
66 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2fe8dc79d | ||
|
|
2b8cc86077 | ||
|
|
cb7c07ac96 | ||
|
|
d8c2990ffd | ||
|
|
32b8b0dc89 | ||
|
|
302e5d00d0 | ||
|
|
85d6d577e3 | ||
|
|
691eb49009 | ||
|
|
b479b37479 | ||
|
|
fe9f15bec9 | ||
|
|
b86b43f1de | ||
|
|
c47d00e694 | ||
|
|
458e1ac603 | ||
|
|
9d57f45aea | ||
|
|
36e9f6149d | ||
|
|
eaa101f976 | ||
|
|
802ca9aba4 | ||
|
|
fa933ae029 | ||
|
|
b4320c9330 | ||
|
|
d02f47cce6 | ||
|
|
797ab8bdef | ||
|
|
f765303daf | ||
|
|
0e3d4ec3a5 | ||
|
|
b875313f7d | ||
|
|
71328eb67e | ||
|
|
3e630ee0ad | ||
|
|
164d1f9d95 | ||
|
|
fa182eca76 | ||
|
|
52ea1a8d42 | ||
|
|
cc21bee798 | ||
|
|
353ef07f92 | ||
|
|
c12ef3341f | ||
|
|
992eabc0e9 | ||
|
|
c45ac83eb0 | ||
|
|
3fe9f63901 | ||
|
|
cfb2a29082 | ||
|
|
9bdb836eaa | ||
|
|
393f79e1d3 | ||
|
|
c911e3e865 | ||
|
|
c356347cf6 | ||
|
|
111f3f4d20 | ||
|
|
b6052ee7b8 | ||
|
|
2e19f0945f | ||
|
|
c325180a1b | ||
|
|
c4a81cf553 | ||
|
|
0f117e5e66 | ||
|
|
68f01f9823 | ||
|
|
8983c0d651 | ||
|
|
0cd5982a3f | ||
|
|
243e75dd09 | ||
|
|
1f7d3151b5 | ||
|
|
0709c4dac0 | ||
|
|
c0126c2036 | ||
|
|
528e67ea61 | ||
|
|
fe5f26ed70 | ||
|
|
75782c4efd | ||
|
|
867b3b5949 | ||
|
|
76fce15e32 | ||
|
|
3dcd1c1080 | ||
|
|
232b68310c | ||
|
|
a730211167 | ||
|
|
f3d8a9ff75 | ||
|
|
98f07457b9 | ||
|
|
cbc901496c | ||
|
|
cd5442f3bf | ||
|
|
11f6f2fc75 |
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"vppn/peer"
|
"vppn/peer"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/flock"
|
"git.crumpington.com/lib/flock"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -34,9 +34,6 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("lock: %v", err)
|
log.Fatalf("lock: %v", err)
|
||||||
}
|
}
|
||||||
if lockFile == nil {
|
|
||||||
log.Fatalf("already running for network %q", *name)
|
|
||||||
}
|
|
||||||
defer flock.Unlock(lockFile)
|
defer flock.Unlock(lockFile)
|
||||||
|
|
||||||
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)
|
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)
|
||||||
|
|||||||
30
go.mod
30
go.mod
@@ -3,21 +3,25 @@ module vppn
|
|||||||
go 1.25.1
|
go 1.25.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
git.crumpington.com/lib/go v0.9.1
|
git.crumpington.com/lib/flock v1.1.0
|
||||||
golang.org/x/crypto v0.42.0
|
git.crumpington.com/lib/idgen v1.0.0
|
||||||
golang.org/x/sys v0.36.0
|
git.crumpington.com/lib/keyedmutex v1.1.0
|
||||||
|
git.crumpington.com/lib/ratelimiter v1.1.1
|
||||||
|
git.crumpington.com/lib/sqliteutil v1.1.1
|
||||||
|
git.crumpington.com/lib/webutil v1.1.0
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.45
|
||||||
|
golang.org/x/crypto v0.53.0
|
||||||
|
golang.org/x/sys v0.46.0
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/josharian/native v1.1.0 // indirect
|
github.com/mdlayher/genetlink v1.4.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
github.com/mdlayher/netlink v1.11.2 // indirect
|
||||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
github.com/mdlayher/socket v0.6.1 // indirect
|
||||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
golang.org/x/net v0.56.0 // indirect
|
||||||
github.com/mdlayher/socket v0.5.1 // indirect
|
golang.org/x/sync v0.21.0 // indirect
|
||||||
golang.org/x/net v0.44.0 // indirect
|
golang.org/x/text v0.38.0 // indirect
|
||||||
golang.org/x/sync v0.17.0 // indirect
|
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 // indirect
|
||||||
golang.org/x/text v0.29.0 // indirect
|
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
60
go.sum
60
go.sum
@@ -1,30 +1,38 @@
|
|||||||
git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
|
git.crumpington.com/lib/flock v1.1.0 h1:NzPUAXnywikN+ZPabzQw9eXAwvZolGUE3pjnSxnDwFk=
|
||||||
git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
|
git.crumpington.com/lib/flock v1.1.0/go.mod h1:prUmtkjpGDUakQh6TiEAylrgDTPG0HuBOUe8Lq4HKsc=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
git.crumpington.com/lib/idgen v1.0.0 h1:0Jre8R3B+RaMOKmCgagBT659wGM93QNpamuGF2e9SII=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
git.crumpington.com/lib/idgen v1.0.0/go.mod h1:Q8kV11Zta4P5WKDpBwsekEsnOe9IysVLsW+gPhbzFTc=
|
||||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
git.crumpington.com/lib/keyedmutex v1.1.0 h1:XOlk9f0rnwmr5yNoIvPteM2W2uakZqT4tnZKficrXho=
|
||||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
git.crumpington.com/lib/keyedmutex v1.1.0/go.mod h1:ova6v/794UCZJ5FKKrLpaol0wfNZZTB3plLObSWaGk4=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
git.crumpington.com/lib/ratelimiter v1.1.1 h1:8jVDVK/I0zzE3EHCu+sUeZN8a9Aqzm+PG4WrlnEvLes=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
git.crumpington.com/lib/ratelimiter v1.1.1/go.mod h1:TycyPTi/aBfnWW8F51yfo/5fSP/qKywDREqsph7TEns=
|
||||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
git.crumpington.com/lib/sqliteutil v1.1.1 h1:xwfp/l2BL4nfw8Ye0Cex2HdGJQKQ1YBCFtDiMeUhnzk=
|
||||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
git.crumpington.com/lib/sqliteutil v1.1.1/go.mod h1:K8OelqOwhSYAZK42v8hKK6UmafItGf2WcMfNlq9Gfeo=
|
||||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
git.crumpington.com/lib/webutil v1.1.0 h1:S9CaRBbVgYOUsgZ5AU1gAJxkxzr8Zjn2v84MoMOy1+I=
|
||||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
git.crumpington.com/lib/webutil v1.1.0/go.mod h1:+LNLGApoe9InAJ7DCeLfiDmYov87XU3crYRHr/RYv2E=
|
||||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
|
||||||
|
github.com/mdlayher/genetlink v1.4.0 h1:f/Xs7Y2T+GyX9b3dbiUhnLE9InGs5F9RxJ2JwBMl71o=
|
||||||
|
github.com/mdlayher/genetlink v1.4.0/go.mod h1:d1hrKr8fwZU2JkcAtQUAzeTrI7nbgQSl+5k1cC0biSA=
|
||||||
|
github.com/mdlayher/netlink v1.11.2 h1:HKh2jqe+omdSWcQ88nrT7INE61B0NXfiSPFdgL4YbNI=
|
||||||
|
github.com/mdlayher/netlink v1.11.2/go.mod h1:uT2Yc/QLaZubzDpZIBi9d4GoeLwtp3x1AMeqSRrK2sA=
|
||||||
|
github.com/mdlayher/socket v0.6.1 h1:M7uj2NtuujUY4mYr1C57NmfNiRHbkKpnBxO856lsc3A=
|
||||||
|
github.com/mdlayher/socket v0.6.1/go.mod h1:+/SGtqc9V+5dAuRgQsU0fGBI+oRDiW7O2Obx10OIWfg=
|
||||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||||
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
|
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||||
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 h1:cqHQ3AycTHvM2R7ikgyX57D+XvtcSnGylsLkOVhta/w=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
|
|||||||
110
hub/api/api.go
110
hub/api/api.go
@@ -8,10 +8,11 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"vppn/hub/api/db"
|
"vppn/hub/api/db"
|
||||||
|
"vppn/hub/errs"
|
||||||
"vppn/m"
|
"vppn/m"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/idgen"
|
"git.crumpington.com/lib/idgen"
|
||||||
"git.crumpington.com/lib/go/sqliteutil"
|
"git.crumpington.com/lib/sqliteutil"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +27,8 @@ type API struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(dbPath string) (*API, error) {
|
func New(dbPath string) (*API, error) {
|
||||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal=WAL")
|
dbPath += "?_journal=WAL&_foreign_keys=on&_busy_timeout=5000&_txlock=immediate"
|
||||||
|
sqlDB, err := sql.Open("sqlite3", dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -64,30 +66,33 @@ func (a *API) ensurePassword() error {
|
|||||||
|
|
||||||
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
|
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Printf("Failed to generate password: %v", err)
|
||||||
|
return errs.ErrUnexpected
|
||||||
}
|
}
|
||||||
|
|
||||||
conf := &Config{ConfigID: 1, Password: hashed}
|
conf := &Config{ConfigID: 1, Password: hashed}
|
||||||
return db.Config_Insert(a.db, conf)
|
return errs.DB(db.Config_Insert(a.db, conf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Config_Get() (*Config, error) {
|
func (a *API) Config_Get() (*Config, error) {
|
||||||
return db.Config_Get(a.db, 1)
|
conf, err := db.Config_Get(a.db, 1)
|
||||||
|
return conf, errs.DB(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Config_Update(conf *Config) error {
|
func (a *API) Config_Update(conf *Config) error {
|
||||||
return db.Config_Update(a.db, conf)
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
return errs.DB(db.Config_Update(a.db, conf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Session_Delete(sessionID string) error {
|
func (a *API) Session_Delete(sessionID string) {
|
||||||
a.sessionsMu.Lock()
|
a.sessionsMu.Lock()
|
||||||
defer a.sessionsMu.Unlock()
|
defer a.sessionsMu.Unlock()
|
||||||
delete(a.sessions, sessionID)
|
delete(a.sessions, sessionID)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use
|
sessionTTL = 24 * 21 * time.Hour // sessions expire 21 days after last use
|
||||||
sessionSweepEvery = time.Hour // cadence of expired-session eviction
|
sessionSweepEvery = time.Hour // cadence of expired-session eviction
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -96,23 +101,23 @@ const (
|
|||||||
// creates a session, so anonymous requests cost no memory — a session is minted
|
// creates a session, so anonymous requests cost no memory — a session is minted
|
||||||
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
|
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
|
||||||
// callers from racing on the shared struct.
|
// callers from racing on the shared struct.
|
||||||
func (a *API) Session_Get(sessionID string) (Session, error) {
|
func (a *API) Session_Get(sessionID string) Session {
|
||||||
a.sessionsMu.Lock()
|
a.sessionsMu.Lock()
|
||||||
defer a.sessionsMu.Unlock()
|
defer a.sessionsMu.Unlock()
|
||||||
|
|
||||||
s, ok := a.sessions[sessionID]
|
s, ok := a.sessions[sessionID]
|
||||||
|
|
||||||
if sessionID == "" || !ok {
|
if sessionID == "" || !ok {
|
||||||
return Session{}, nil
|
return Session{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if timeSince(s.LastSeenAt) > sessionTTLSecs {
|
if time.Since(s.LastSeenAt) > sessionTTL {
|
||||||
delete(a.sessions, sessionID)
|
delete(a.sessions, sessionID)
|
||||||
return Session{}, nil
|
return Session{}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.LastSeenAt = time.Now().Unix()
|
s.LastSeenAt = time.Now()
|
||||||
return *s, nil
|
return *s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
|
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
|
||||||
@@ -121,24 +126,36 @@ func (a *API) Session_Get(sessionID string) (Session, error) {
|
|||||||
func (a *API) Session_SignIn(pwd string) (Session, error) {
|
func (a *API) Session_SignIn(pwd string) (Session, error) {
|
||||||
conf, err := a.Config_Get()
|
conf, err := a.Config_Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Session{}, err
|
log.Printf("Failed to get config: %v", err)
|
||||||
|
return Session{}, errs.ErrUnexpected
|
||||||
}
|
}
|
||||||
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
|
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
|
||||||
return Session{}, ErrNotAuthorized
|
return Session{}, errs.ErrNotAuthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
a.sessionsMu.Lock()
|
a.sessionsMu.Lock()
|
||||||
defer a.sessionsMu.Unlock()
|
defer a.sessionsMu.Unlock()
|
||||||
s := &Session{
|
s := &Session{
|
||||||
SessionID: idgen.NewToken(),
|
SessionID: idgen.NewToken(),
|
||||||
SignedIn: true,
|
LastSeenAt: time.Now(),
|
||||||
CreatedAt: time.Now().Unix(),
|
|
||||||
LastSeenAt: time.Now().Unix(),
|
|
||||||
}
|
}
|
||||||
a.sessions[s.SessionID] = s
|
a.sessions[s.SessionID] = s
|
||||||
return *s, nil
|
return *s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *API) Session_InvalidateAll() Session {
|
||||||
|
a.sessionsMu.Lock()
|
||||||
|
defer a.sessionsMu.Unlock()
|
||||||
|
|
||||||
|
clear(a.sessions)
|
||||||
|
s := &Session{
|
||||||
|
SessionID: idgen.NewToken(),
|
||||||
|
LastSeenAt: time.Now(),
|
||||||
|
}
|
||||||
|
a.sessions[s.SessionID] = s
|
||||||
|
return *s
|
||||||
|
}
|
||||||
|
|
||||||
// sweepSessions periodically evicts sessions past their TTL. Without it, a
|
// sweepSessions periodically evicts sessions past their TTL. Without it, a
|
||||||
// signed-in session whose ID is never presented again would linger forever
|
// signed-in session whose ID is never presented again would linger forever
|
||||||
// (Session_Get only evicts on a lookup of that same ID).
|
// (Session_Get only evicts on a lookup of that same ID).
|
||||||
@@ -146,7 +163,7 @@ func (a *API) sweepSessions() {
|
|||||||
for range time.Tick(sessionSweepEvery) {
|
for range time.Tick(sessionSweepEvery) {
|
||||||
a.sessionsMu.Lock()
|
a.sessionsMu.Lock()
|
||||||
for id, s := range a.sessions {
|
for id, s := range a.sessions {
|
||||||
if timeSince(s.LastSeenAt) > sessionTTLSecs {
|
if time.Since(s.LastSeenAt) > sessionTTL {
|
||||||
delete(a.sessions, id)
|
delete(a.sessions, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -155,29 +172,48 @@ func (a *API) sweepSessions() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Network_Create(n *Network) error {
|
func (a *API) Network_Create(n *Network) error {
|
||||||
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
n.NetworkID = idgen.NextID(0)
|
n.NetworkID = idgen.NextID(0)
|
||||||
return db.Network_Insert(a.db, n)
|
return errs.DB(db.Network_Insert(a.db, n))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Network_Delete(n *Network) error {
|
func (a *API) Network_Delete(n *Network) error {
|
||||||
return db.Network_Delete(a.db, n.NetworkID)
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
|
exists, err := db.Network_HasPeers(a.db, n.NetworkID)
|
||||||
|
if err != nil {
|
||||||
|
return errs.DB(err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return errs.Conflict.WithMsg("Delete all peers before deleting network.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs.DB(db.Network_Delete(a.db, n.NetworkID))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Network_Get(id int64) (*Network, error) {
|
func (a *API) Network_Get(id int64) (*Network, error) {
|
||||||
return db.Network_Get(a.db, id)
|
n, err := db.Network_Get(a.db, id)
|
||||||
|
return n, errs.DB(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Network_List() ([]*Network, error) {
|
func (a *API) Network_List() ([]*Network, error) {
|
||||||
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
|
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
|
||||||
return db.Network_List(a.db, query)
|
n, err := db.Network_List(a.db, query)
|
||||||
|
return n, errs.DB(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_CreateNew(p *Peer) error {
|
func (a *API) Peer_CreateNew(p *Peer) error {
|
||||||
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
p.WGPubKey = []byte{}
|
p.WGPubKey = []byte{}
|
||||||
p.SignPubKey = []byte{}
|
p.SignPubKey = []byte{}
|
||||||
p.APIKey = idgen.NewToken()
|
p.APIKey = idgen.NewToken()
|
||||||
|
|
||||||
return db.Peer_Insert(a.db, p)
|
return errs.DB(db.Peer_Insert(a.db, p))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
||||||
@@ -188,30 +224,36 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
|||||||
// we held the lock, so it may be stale under concurrent requests.
|
// we held the lock, so it may be stale under concurrent requests.
|
||||||
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
|
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errs.DB(err)
|
||||||
}
|
}
|
||||||
if len(current.WGPubKey) != 0 {
|
if len(current.WGPubKey) != 0 {
|
||||||
return errors.New("peer already initialized")
|
return errs.ErrAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.WGPubKey = args.WGPubKey
|
peer.WGPubKey = args.WGPubKey
|
||||||
peer.SignPubKey = args.SignPubKey
|
peer.SignPubKey = args.SignPubKey
|
||||||
|
|
||||||
return db.Peer_UpdateFull(a.db, peer)
|
return errs.DB(db.Peer_UpdateFull(a.db, peer))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
|
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
|
||||||
return db.Peer_Delete(a.db, networkID, peerIP)
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
|
return errs.DB(db.Peer_Delete(a.db, networkID, peerIP))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
|
func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
|
||||||
return db.Peer_ListAll(a.db, networkID)
|
p, err := db.Peer_ListAll(a.db, networkID)
|
||||||
|
return p, errs.DB(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
|
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
|
||||||
return db.Peer_Get(a.db, networkID, ip)
|
p, err := db.Peer_Get(a.db, networkID, ip)
|
||||||
|
return p, errs.DB(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
|
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
|
||||||
return db.Peer_GetByAPIKey(a.db, key)
|
p, err := db.Peer_GetByAPIKey(a.db, key)
|
||||||
|
return p, errs.DB(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func Config_Update(
|
|||||||
|
|
||||||
n, err := result.RowsAffected()
|
n, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return err
|
||||||
}
|
}
|
||||||
switch n {
|
switch n {
|
||||||
case 0:
|
case 0:
|
||||||
|
|||||||
@@ -1,19 +1,9 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
"vppn/hub/errs"
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidIP = errors.New("invalid IP")
|
|
||||||
ErrInvalidPeerIP = errors.New("invalid peer IP")
|
|
||||||
ErrNonPrivateIP = errors.New("non-private IP")
|
|
||||||
ErrInvalidPort = errors.New("invalid port")
|
|
||||||
ErrInvalidNetName = errors.New("invalid network name")
|
|
||||||
ErrNetNameNotLocal = errors.New("network name must end with .local")
|
|
||||||
ErrInvalidPeerName = errors.New("invalid peer name")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Config_Sanitize(c *Config) {
|
func Config_Sanitize(c *Config) {
|
||||||
@@ -35,11 +25,11 @@ func Network_Validate(c *Network) error {
|
|||||||
// 15 bytes is linux limit for network interface names. With ending .local,
|
// 15 bytes is linux limit for network interface names. With ending .local,
|
||||||
// max length is 21.
|
// max length is 21.
|
||||||
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
|
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
|
||||||
return ErrInvalidNetName
|
return errs.ErrInvalidNetName
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasSuffix(c.LocalDomain, ".local") {
|
if !strings.HasSuffix(c.LocalDomain, ".local") {
|
||||||
return ErrNetNameNotLocal
|
return errs.ErrNetNameNotLocal
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
|
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
|
||||||
@@ -49,23 +39,23 @@ func Network_Validate(c *Network) error {
|
|||||||
if c >= '0' && c <= '9' {
|
if c >= '0' && c <= '9' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return ErrInvalidNetName
|
return errs.ErrInvalidNetName
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(c.Network)
|
addr, ok := netip.AddrFromSlice(c.Network)
|
||||||
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
|
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
|
||||||
return ErrInvalidIP
|
return errs.ErrInvalidIP
|
||||||
}
|
}
|
||||||
|
|
||||||
if !addr.IsPrivate() {
|
if !addr.IsPrivate() {
|
||||||
return ErrNonPrivateIP
|
return errs.ErrNonPrivateIP
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Peer_Sanitize(p *Peer) {
|
func Peer_Sanitize(p *Peer) {
|
||||||
p.Name = strings.TrimSpace(p.Name)
|
p.Name = strings.TrimSpace(strings.ToLower(p.Name))
|
||||||
if len(p.Addr4) != 0 {
|
if len(p.Addr4) != 0 {
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
|
if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
|
||||||
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
|
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
|
||||||
@@ -84,26 +74,26 @@ func Peer_Sanitize(p *Peer) {
|
|||||||
|
|
||||||
func Peer_Validate(p *Peer) error {
|
func Peer_Validate(p *Peer) error {
|
||||||
if p.PeerIP < 1 || p.PeerIP > 254 {
|
if p.PeerIP < 1 || p.PeerIP > 254 {
|
||||||
return ErrInvalidPeerIP
|
return errs.ErrInvalidPeerIP
|
||||||
}
|
}
|
||||||
if len(p.Addr4) > 0 {
|
if len(p.Addr4) > 0 {
|
||||||
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
|
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
|
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
|
||||||
return ErrInvalidIP
|
return errs.ErrInvalidIP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(p.Addr6) > 0 {
|
if len(p.Addr6) > 0 {
|
||||||
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
|
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
|
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
|
||||||
return ErrInvalidIP
|
return errs.ErrInvalidIP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if p.Port == 0 {
|
if p.Port == 0 {
|
||||||
return ErrInvalidPort
|
return errs.ErrInvalidPort
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p.Name) == 0 {
|
if len(p.Name) == 0 || len(p.Name) > 63 {
|
||||||
return ErrInvalidPeerName
|
return errs.ErrInvalidPeerName
|
||||||
}
|
}
|
||||||
for _, c := range p.Name {
|
for _, c := range p.Name {
|
||||||
if c >= 'a' && c <= 'z' {
|
if c >= 'a' && c <= 'z' {
|
||||||
@@ -115,7 +105,7 @@ func Peer_Validate(p *Peer) error {
|
|||||||
if c == '-' {
|
if c == '-' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return ErrInvalidPeerName
|
return errs.ErrInvalidPeerName
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -11,3 +11,9 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
|
|||||||
Peer_SelectQuery+` WHERE APIKey=?`,
|
Peer_SelectQuery+` WHERE APIKey=?`,
|
||||||
apiKey)
|
apiKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Network_HasPeers(tx TX, networkID int64) (exists bool, err error) {
|
||||||
|
const query = "SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=?)"
|
||||||
|
err = tx.QueryRow(query, networkID).Scan(&exists)
|
||||||
|
return exists, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"vppn/hub/api/db"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrNotAuthorized = errors.New("not authorized")
|
|
||||||
ErrInvalidIP = db.ErrInvalidIP
|
|
||||||
ErrInvalidPort = db.ErrInvalidPort
|
|
||||||
)
|
|
||||||
@@ -21,5 +21,6 @@ CREATE TABLE peers (
|
|||||||
WGPubKey BLOB NOT NULL,
|
WGPubKey BLOB NOT NULL,
|
||||||
SignPubKey BLOB NOT NULL,
|
SignPubKey BLOB NOT NULL,
|
||||||
UNIQUE(NetworkID, Name),
|
UNIQUE(NetworkID, Name),
|
||||||
PRIMARY KEY(NetworkID, PeerIP)
|
PRIMARY KEY(NetworkID, PeerIP),
|
||||||
|
FOREIGN KEY(NetworkID) REFERENCES networks(NetworkID)
|
||||||
) WITHOUT ROWID;
|
) WITHOUT ROWID;
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
func timeSince(ts int64) int64 {
|
|
||||||
return time.Now().Unix() - ts
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import "vppn/hub/api/db"
|
import (
|
||||||
|
"time"
|
||||||
|
"vppn/hub/api/db"
|
||||||
|
)
|
||||||
|
|
||||||
type Config = db.Config
|
type Config = db.Config
|
||||||
type Network = db.Network
|
type Network = db.Network
|
||||||
@@ -8,7 +11,5 @@ type Peer = db.Peer
|
|||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
SessionID string
|
SessionID string
|
||||||
SignedIn bool
|
LastSeenAt time.Time
|
||||||
CreatedAt int64
|
|
||||||
LastSeenAt int64
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"vppn/hub/api"
|
"vppn/hub/api"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/webutil"
|
"git.crumpington.com/lib/keyedmutex"
|
||||||
|
"git.crumpington.com/lib/webutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed static
|
//go:embed static
|
||||||
@@ -28,6 +29,9 @@ type App struct {
|
|||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
tmpl map[string]*template.Template
|
tmpl map[string]*template.Template
|
||||||
insecure bool
|
insecure bool
|
||||||
|
|
||||||
|
// Per-remote address sign-in serialization lock.
|
||||||
|
signInLock *keyedmutex.KeyedMutex[string]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewApp(conf Config) (*App, error) {
|
func NewApp(conf Config) (*App, error) {
|
||||||
@@ -41,6 +45,7 @@ func NewApp(conf Config) (*App, error) {
|
|||||||
mux: http.NewServeMux(),
|
mux: http.NewServeMux(),
|
||||||
tmpl: webutil.ParseTemplateSet(templateFuncs, templateFS),
|
tmpl: webutil.ParseTemplateSet(templateFuncs, templateFS),
|
||||||
insecure: conf.Insecure,
|
insecure: conf.Insecure,
|
||||||
|
signInLock: keyedmutex.New[string](),
|
||||||
}
|
}
|
||||||
|
|
||||||
app.registerRoutes()
|
app.registerRoutes()
|
||||||
|
|||||||
37
hub/errs/db.go
Normal file
37
hub/errs/db.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package errs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
sqlite3 "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DB(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var e *Error
|
||||||
|
if errors.As(err, &e) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var se sqlite3.Error
|
||||||
|
if errors.As(err, &se) {
|
||||||
|
switch se.ExtendedCode {
|
||||||
|
case sqlite3.ErrConstraintUnique, sqlite3.ErrConstraintPrimaryKey:
|
||||||
|
return ErrAlreadyExists
|
||||||
|
case sqlite3.ErrConstraintForeignKey, sqlite3.ErrConstraintCheck:
|
||||||
|
return ErrConstraint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Unexpected error: %v", err)
|
||||||
|
return ErrUnexpected
|
||||||
|
}
|
||||||
61
hub/errs/types.go
Normal file
61
hub/errs/types.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package errs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Code int
|
||||||
|
Msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
return fmt.Sprintf("[%d] %s", e.Code, e.Msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotAuthorized = NotAuthorized.WithMsg("Not authorized.")
|
||||||
|
ErrInvalidPassword = BadRequest.WithMsg("Invalid password.")
|
||||||
|
ErrPasswordMismatch = BadRequest.WithMsg("Passwords don't match.")
|
||||||
|
ErrUnexpected = Internal.WithMsg("Unexpected internal error.")
|
||||||
|
ErrNotFound = NotFound.WithMsg("Not found.")
|
||||||
|
ErrAlreadyExists = Conflict.WithMsg("Already exists.")
|
||||||
|
|
||||||
|
// Validation errors.
|
||||||
|
ErrInvalidIP = BadRequest.WithMsg("Invalid IP.")
|
||||||
|
ErrInvalidPeerIP = BadRequest.WithMsg("Invalid peer IP.")
|
||||||
|
ErrNonPrivateIP = BadRequest.WithMsg("Non-private IP.")
|
||||||
|
ErrInvalidPort = BadRequest.WithMsg("Invalid port.")
|
||||||
|
ErrInvalidNetName = BadRequest.WithMsg("Invalid network name.")
|
||||||
|
ErrNetNameNotLocal = BadRequest.WithMsg("Network name must end with .local.")
|
||||||
|
ErrInvalidPeerName = BadRequest.WithMsg("Invalid peer name.")
|
||||||
|
ErrConstraint = BadRequest.WithMsg("Constraint error.")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type Type struct {
|
||||||
|
Code int
|
||||||
|
Msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Type) WithErr(err error) *Error {
|
||||||
|
return &Error{Code: t.Code, Msg: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Type) WithMsg(msg string) *Error {
|
||||||
|
return &Error{Code: t.Code, Msg: msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Type) WithMsgf(msg string, args ...any) *Error {
|
||||||
|
return &Error{Code: t.Code, Msg: fmt.Sprintf(msg, args...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
Internal = Type{Code: http.StatusInternalServerError}
|
||||||
|
NotAuthorized = Type{Code: http.StatusUnauthorized}
|
||||||
|
NotFound = Type{Code: http.StatusNotFound}
|
||||||
|
BadRequest = Type{Code: http.StatusBadRequest}
|
||||||
|
Conflict = Type{Code: http.StatusConflict}
|
||||||
|
)
|
||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"vppn/hub/api"
|
"vppn/hub/api"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/webutil"
|
"git.crumpington.com/lib/webutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *App) formGetNetwork(form url.Values) (*api.Network, error) {
|
func (app *App) formGetNetwork(form url.Values) (*api.Network, error) {
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package hub
|
package hub
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"vppn/hub/api"
|
"vppn/hub/api"
|
||||||
|
"vppn/hub/errs"
|
||||||
"git.crumpington.com/lib/go/webutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) error
|
type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) error
|
||||||
@@ -13,32 +13,26 @@ type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) er
|
|||||||
func (app *App) handlePub(pattern string, fn handlerFunc) {
|
func (app *App) handlePub(pattern string, fn handlerFunc) {
|
||||||
wrapped := func(w http.ResponseWriter, r *http.Request) {
|
wrapped := func(w http.ResponseWriter, r *http.Request) {
|
||||||
sessionID := app.getCookie(r, sessionIDCookieName)
|
sessionID := app.getCookie(r, sessionIDCookieName)
|
||||||
s, err := app.api.Session_Get(sessionID)
|
s := app.api.Session_Get(sessionID)
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get session: %v", err)
|
|
||||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Method == http.MethodPost {
|
if r.Method == http.MethodPost {
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, 128*1024)
|
||||||
r.ParseMultipartForm(64 * 1024)
|
r.ParseMultipartForm(64 * 1024)
|
||||||
} else {
|
} else {
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fn(&s, w, r); err != nil {
|
if err := fn(&s, w, r); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
handleError(w, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
app.mux.HandleFunc(pattern,
|
app.mux.HandleFunc(pattern, withLogging(wrapped))
|
||||||
webutil.WithLogging(
|
|
||||||
wrapped))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
|
func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
|
||||||
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
if s.SignedIn {
|
if s.SessionID != "" {
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -48,7 +42,7 @@ func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
|
|||||||
|
|
||||||
func (app *App) handleSignedIn(pattern string, fn handlerFunc) {
|
func (app *App) handleSignedIn(pattern string, fn handlerFunc) {
|
||||||
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
if !s.SignedIn {
|
if s.SessionID == "" {
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -66,6 +60,7 @@ func (app *App) handlePeer(pattern string, fn peerHandlerFunc) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not doing constant time compare because index lookup time dominates.
|
||||||
peer, err := app.api.Peer_GetByAPIKey(apiKey)
|
peer, err := app.api.Peer_GetByAPIKey(apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Not authorized", http.StatusUnauthorized)
|
http.Error(w, "Not authorized", http.StatusUnauthorized)
|
||||||
@@ -74,12 +69,19 @@ func (app *App) handlePeer(pattern string, fn peerHandlerFunc) {
|
|||||||
|
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
if err := fn(peer, w, r); err != nil {
|
if err := fn(peer, w, r); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
handleError(w, err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
app.mux.HandleFunc(pattern,
|
app.mux.HandleFunc(pattern, withLogging(wrapped))
|
||||||
webutil.WithLogging(
|
}
|
||||||
wrapped))
|
|
||||||
|
func handleError(w http.ResponseWriter, err error) {
|
||||||
|
var e *errs.Error
|
||||||
|
if errors.As(err, &e) {
|
||||||
|
http.Error(w, e.Msg, e.Code)
|
||||||
|
} else {
|
||||||
|
log.Printf("Unexpected error: %v", err)
|
||||||
|
http.Error(w, "Internal server error.", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,19 +2,22 @@ package hub
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"log"
|
"log"
|
||||||
|
"math/rand/v2"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
"vppn/hub/api"
|
"vppn/hub/api"
|
||||||
|
"vppn/hub/errs"
|
||||||
"vppn/m"
|
"vppn/m"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/webutil"
|
"git.crumpington.com/lib/webutil"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
if s.SignedIn {
|
if s.SessionID != "" {
|
||||||
return a.redirect(w, r, "/admin/network/list/")
|
return a.redirect(w, r, "/admin/network/list/")
|
||||||
} else {
|
} else {
|
||||||
return a.redirect(w, r, "/sign-in/")
|
return a.redirect(w, r, "/sign-in/")
|
||||||
@@ -26,6 +29,15 @@ func (a *App) _signin(s *api.Session, w http.ResponseWriter, r *http.Request) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
// Ignoring error here - if host is the empty string, it will contend for the
|
||||||
|
// lock anyway.
|
||||||
|
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if !a.signInLock.TryLock(host) {
|
||||||
|
time.Sleep(time.Second + time.Duration(rand.Int64N(int64(3*time.Second))))
|
||||||
|
return errs.ErrNotAuthorized
|
||||||
|
}
|
||||||
|
defer a.signInLock.Unlock(host)
|
||||||
|
|
||||||
var pwd string
|
var pwd string
|
||||||
err := webutil.NewFormScanner(r.Form).
|
err := webutil.NewFormScanner(r.Form).
|
||||||
Scan("Password", &pwd).
|
Scan("Password", &pwd).
|
||||||
@@ -36,8 +48,10 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
sess, err := a.api.Session_SignIn(pwd)
|
sess, err := a.api.Session_SignIn(pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
time.Sleep(time.Second + time.Duration(rand.Int64N(int64(3*time.Second))))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
a.setCookie(w, sessionIDCookieName, sess.SessionID)
|
a.setCookie(w, sessionIDCookieName, sess.SessionID)
|
||||||
|
|
||||||
return a.redirect(w, r, "/")
|
return a.redirect(w, r, "/")
|
||||||
@@ -48,9 +62,7 @@ func (a *App) _adminSignOut(s *api.Session, w http.ResponseWriter, r *http.Reque
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
if err := a.api.Session_Delete(s.SessionID); err != nil {
|
a.api.Session_Delete(s.SessionID)
|
||||||
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
|
|
||||||
}
|
|
||||||
a.deleteCookie(w, sessionIDCookieName)
|
a.deleteCookie(w, sessionIDCookieName)
|
||||||
return a.redirect(w, r, "/")
|
return a.redirect(w, r, "/")
|
||||||
}
|
}
|
||||||
@@ -236,22 +248,23 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(newPwd) < 8 {
|
if len(newPwd) < 8 || len(newPwd) > 72 {
|
||||||
return errors.New("password is too short")
|
return errs.ErrInvalidPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
if newPwd != newPwd2 {
|
if newPwd != newPwd2 {
|
||||||
return errors.New("passwords don't match")
|
return errs.ErrPasswordMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
err = bcrypt.CompareHashAndPassword(conf.Password, []byte(curPwd))
|
err = bcrypt.CompareHashAndPassword(conf.Password, []byte(curPwd))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errs.ErrNotAuthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(newPwd), bcrypt.DefaultCost)
|
hash, err := bcrypt.GenerateFromPassword([]byte(newPwd), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Printf("Failed to hash password with bcrypt: %v", err)
|
||||||
|
return errs.ErrUnexpected
|
||||||
}
|
}
|
||||||
|
|
||||||
conf.Password = hash
|
conf.Password = hash
|
||||||
@@ -260,7 +273,10 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.redirect(w, r, "/admin/config/")
|
*s = a.api.Session_InvalidateAll()
|
||||||
|
a.setCookie(w, sessionIDCookieName, s.SessionID)
|
||||||
|
|
||||||
|
return a.redirect(w, r, "/admin/network/list/")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
||||||
@@ -269,9 +285,11 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, 2048)
|
||||||
|
|
||||||
args := m.PeerInitArgs{}
|
args := m.PeerInitArgs{}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
|
||||||
return err
|
return errs.BadRequest.WithMsg("Invalid request body.")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args.WGPubKey) != 32 {
|
if len(args.WGPubKey) != 32 {
|
||||||
@@ -328,8 +346,10 @@ func (a *App) peersList(networkID int64) (peers []m.Peer, err error) {
|
|||||||
}
|
}
|
||||||
wgKey, err := wgtypes.NewKey(p.WGPubKey)
|
wgKey, err := wgtypes.NewKey(p.WGPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Bad WG key in DB for peer %d/%d", p.NetworkID, p.PeerIP)
|
||||||
continue // malformed key; skip rather than serve garbage
|
continue // malformed key; skip rather than serve garbage
|
||||||
}
|
}
|
||||||
|
|
||||||
var signKey [32]byte
|
var signKey [32]byte
|
||||||
copy(signKey[:], p.SignPubKey)
|
copy(signKey[:], p.SignPubKey)
|
||||||
peers = append(peers, m.Peer{
|
peers = append(peers, m.Peer{
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/webutil"
|
"git.crumpington.com/lib/webutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Main() {
|
func Main() {
|
||||||
@@ -32,6 +33,10 @@ func Main() {
|
|||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: conf.ListenAddr,
|
Addr: conf.ListenAddr,
|
||||||
Handler: app.Handler(),
|
Handler: app.Handler(),
|
||||||
|
ReadHeaderTimeout: 30 * time.Second,
|
||||||
|
ReadTimeout: 60 * time.Second,
|
||||||
|
WriteTimeout: 120 * time.Second,
|
||||||
|
IdleTimeout: 180 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Fatal(webutil.ListenAndServe(srv))
|
log.Fatal(webutil.ListenAndServe(srv))
|
||||||
|
|||||||
47
hub/middleware.go
Normal file
47
hub/middleware.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package hub
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _log = log.New(os.Stderr, "", 0)
|
||||||
|
|
||||||
|
type responseWriterWrapper struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
httpStatus int
|
||||||
|
responseSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseWriterWrapper) WriteHeader(status int) {
|
||||||
|
w.httpStatus = status
|
||||||
|
w.ResponseWriter.WriteHeader(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||||
|
if w.httpStatus == 0 {
|
||||||
|
w.httpStatus = 200
|
||||||
|
}
|
||||||
|
w.responseSize += len(b)
|
||||||
|
return w.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func withLogging(inner http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t := time.Now()
|
||||||
|
wrapper := responseWriterWrapper{w, 0, 0}
|
||||||
|
|
||||||
|
inner(&wrapper, r)
|
||||||
|
_log.Printf("%s \"%s %s %s\" %d %d %v",
|
||||||
|
r.RemoteAddr,
|
||||||
|
r.Method,
|
||||||
|
r.URL.Path,
|
||||||
|
r.Proto,
|
||||||
|
wrapper.httpStatus,
|
||||||
|
wrapper.responseSize,
|
||||||
|
time.Since(t),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ package hub
|
|||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
||||||
func (a *App) registerRoutes() {
|
func (a *App) registerRoutes() {
|
||||||
a.mux.Handle("GET /static/", http.FileServerFS(staticFS))
|
a.mux.Handle("GET /static/", withLogging(http.FileServerFS(staticFS).ServeHTTP))
|
||||||
a.handlePub("GET /", a._root)
|
a.handlePub("GET /", a._root)
|
||||||
|
|
||||||
a.handleNotSignedIn("GET /sign-in/", a._signin)
|
a.handleNotSignedIn("GET /sign-in/", a._signin)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
<header>
|
<header>
|
||||||
<h1>VPPN</h1>
|
<h1>VPPN</h1>
|
||||||
<nav>
|
<nav>
|
||||||
{{if .Session.SignedIn -}}
|
{{if .Session.SessionID -}}
|
||||||
<a href="/admin/networks/list/">Home</a> /
|
<a href="/admin/networks/list/">Home</a> /
|
||||||
<a href="/admin/sign-out/">Sign out</a>
|
<a href="/admin/sign-out/">Sign out</a>
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
<header>
|
<header>
|
||||||
<h1>VPPN</h1>
|
<h1>VPPN</h1>
|
||||||
<nav>
|
<nav>
|
||||||
{{if .Session.SignedIn -}}
|
{{if .Session.SessionID -}}
|
||||||
<a href="/admin/networks/list/">Home</a> /
|
<a href="/admin/networks/list/">Home</a> /
|
||||||
<a href="/admin/sign-out/">Sign out</a>
|
<a href="/admin/sign-out/">Sign out</a>
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|||||||
67
peer/app.go
67
peer/app.go
@@ -1,9 +1,13 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,6 +24,7 @@ var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisf
|
|||||||
const (
|
const (
|
||||||
ControlPort = 4561
|
ControlPort = 4561
|
||||||
PingInterval = 8 * time.Second
|
PingInterval = 8 * time.Second
|
||||||
|
TickInterval = 2 * time.Second
|
||||||
TimeoutInterval = 30 * time.Second
|
TimeoutInterval = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,7 +45,6 @@ type App struct {
|
|||||||
vpnNet netip.Prefix
|
vpnNet netip.Prefix
|
||||||
privKey wgtypes.Key
|
privKey wgtypes.Key
|
||||||
pubKey wgtypes.Key
|
pubKey wgtypes.Key
|
||||||
isRelay bool
|
|
||||||
isPublic bool
|
isPublic bool
|
||||||
localDomain string
|
localDomain string
|
||||||
|
|
||||||
@@ -75,13 +79,15 @@ func (a *App) Run() error {
|
|||||||
// while we were down).
|
// while we were down).
|
||||||
a.updateHosts()
|
a.updateHosts()
|
||||||
|
|
||||||
ticker := time.NewTicker(PingInterval)
|
stateTicker := time.NewTicker(TickInterval)
|
||||||
defer ticker.Stop()
|
pingTicker := time.NewTicker(PingInterval)
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
sig := make(chan os.Signal, 1)
|
||||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
defer signal.Stop(sig)
|
defer signal.Stop(sig)
|
||||||
|
|
||||||
|
tickCount := 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case p := <-a.hubAddCh:
|
case p := <-a.hubAddCh:
|
||||||
@@ -92,8 +98,15 @@ func (a *App) Run() error {
|
|||||||
a.onPing(e)
|
a.onPing(e)
|
||||||
case e := <-a.multicastCh:
|
case e := <-a.multicastCh:
|
||||||
a.onMulticastDiscovery(e)
|
a.onMulticastDiscovery(e)
|
||||||
case <-ticker.C:
|
case <-stateTicker.C:
|
||||||
a.onTick()
|
a.onStateTick()
|
||||||
|
case <-pingTicker.C:
|
||||||
|
a.onPingTick()
|
||||||
|
tickCount++
|
||||||
|
if tickCount%8 == 0 {
|
||||||
|
a.logNetworkState()
|
||||||
|
}
|
||||||
|
|
||||||
case <-sig:
|
case <-sig:
|
||||||
return a.onShutdown()
|
return a.onShutdown()
|
||||||
}
|
}
|
||||||
@@ -103,3 +116,47 @@ func (a *App) Run() error {
|
|||||||
func (a *App) onShutdown() error {
|
func (a *App) onShutdown() error {
|
||||||
return wginterface.Delete(a.dev.Name())
|
return wginterface.Delete(a.dev.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) logNetworkState() {
|
||||||
|
var b strings.Builder
|
||||||
|
fmt.Fprintf(&b, "Network state (self: %s public=%v):\n", a.vpnIP, a.isPublic)
|
||||||
|
fmt.Fprintf(&b, " Network: %v\n", a.vpnNet)
|
||||||
|
fmt.Fprintf(&b, " IPv4: %v\n", a.selfV4)
|
||||||
|
fmt.Fprintf(&b, " IPv6: %v\n", a.selfV6)
|
||||||
|
|
||||||
|
if a.relay != nil {
|
||||||
|
fmt.Fprintf(&b, " Relay: %s\n", a.relay.Name)
|
||||||
|
} else {
|
||||||
|
fmt.Fprint(&b, " Relay: -\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("Peers:\n")
|
||||||
|
//
|
||||||
|
peers := make([]*Peer, 0, len(a.peersByIP))
|
||||||
|
for _, p := range a.peersByIP {
|
||||||
|
peers = append(peers, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(peers, func(i, j int) bool {
|
||||||
|
return peers[i].VPNIP.As4()[3] < peers[j].VPNIP.As4()[3]
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, p := range peers {
|
||||||
|
ip := p.VPNIP.As4()[3]
|
||||||
|
up := "DOWN"
|
||||||
|
if p.Up() {
|
||||||
|
up = "UP "
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := p.WGEndpoint()
|
||||||
|
if endpoint.IsValid() {
|
||||||
|
fmt.Fprintf(&b, " %24s %03d %s %s seen=%s @ %s\n",
|
||||||
|
p.Name, ip, p.State, up, time.Since(p.LastPing).Round(time.Millisecond), endpoint)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(&b, " %24s %03d %s %s seen=%s\n",
|
||||||
|
p.Name, ip, p.State, up, time.Since(p.LastPing).Round(time.Millisecond))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Print(b.String())
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,13 +26,14 @@ func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer {
|
|||||||
})
|
})
|
||||||
p := a.peersByKey[key]
|
p := a.peersByKey[key]
|
||||||
p.wgPeer.LastHandshakeTime = time.Now()
|
p.wgPeer.LastHandshakeTime = time.Now()
|
||||||
|
p.LastPing = time.Now()
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
|
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
|
||||||
// vpnIP is the local VPN address (e.g. "10.0.0.1").
|
// vpnIP is the local VPN address (e.g. "10.0.0.1").
|
||||||
// isPublic / isRelay describe the local node's role.
|
// isPublic / isRelay describe the local node's role.
|
||||||
func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) {
|
func newTestApp(t *testing.T, vpnIP string, isPublic bool) (*App, *fakeWGDevice, *fakeControlConn) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
privKey, err := wgtypes.GeneratePrivateKey()
|
privKey, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -47,7 +48,6 @@ func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fake
|
|||||||
privKey: privKey,
|
privKey: privKey,
|
||||||
pubKey: privKey.PublicKey(),
|
pubKey: privKey.PublicKey(),
|
||||||
isPublic: isPublic,
|
isPublic: isPublic,
|
||||||
isRelay: isRelay,
|
|
||||||
dev: dev,
|
dev: dev,
|
||||||
controlConn: cc,
|
controlConn: cc,
|
||||||
peersByKey: make(map[wgtypes.Key]*Peer),
|
peersByKey: make(map[wgtypes.Key]*Peer),
|
||||||
|
|||||||
@@ -37,8 +37,12 @@ type Ping struct {
|
|||||||
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
|
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
|
||||||
// field is written unconditionally so a reused buffer needs no pre-zeroing.
|
// field is written unconditionally so a reused buffer needs no pre-zeroing.
|
||||||
func (p Ping) Marshal(buf []byte) []byte {
|
func (p Ping) Marshal(buf []byte) []byte {
|
||||||
|
_ = buf[Size-1] // Panic if buffer is too small.
|
||||||
|
|
||||||
buf[0] = version
|
buf[0] = version
|
||||||
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
|
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
|
||||||
|
|
||||||
|
// SrcV4.
|
||||||
if p.SrcV4.IsValid() {
|
if p.SrcV4.IsValid() {
|
||||||
a4 := p.SrcV4.Addr().As4()
|
a4 := p.SrcV4.Addr().As4()
|
||||||
copy(buf[9:13], a4[:])
|
copy(buf[9:13], a4[:])
|
||||||
@@ -46,9 +50,13 @@ func (p Ping) Marshal(buf []byte) []byte {
|
|||||||
} else {
|
} else {
|
||||||
clear(buf[9:15])
|
clear(buf[9:15])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SrcV6.
|
||||||
a16 := p.SrcV6.Addr().As16()
|
a16 := p.SrcV6.Addr().As16()
|
||||||
copy(buf[15:31], a16[:])
|
copy(buf[15:31], a16[:])
|
||||||
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
|
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
|
||||||
|
|
||||||
|
// Dst.
|
||||||
a16 = p.Dst.Addr().As16()
|
a16 = p.Dst.Addr().As16()
|
||||||
copy(buf[33:49], a16[:])
|
copy(buf[33:49], a16[:])
|
||||||
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
|
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
|
||||||
@@ -63,14 +71,21 @@ func Unmarshal(buf [Size]byte) (Ping, error) {
|
|||||||
p := Ping{
|
p := Ping{
|
||||||
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
|
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
|
||||||
}
|
}
|
||||||
if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() {
|
|
||||||
|
addr := netip.AddrFrom4([4]byte(buf[9:13]))
|
||||||
|
if !addr.IsUnspecified() {
|
||||||
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
|
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
|
||||||
}
|
}
|
||||||
if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() {
|
|
||||||
|
addr = netip.AddrFrom16([16]byte(buf[15:31])).Unmap()
|
||||||
|
if !addr.IsUnspecified() {
|
||||||
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
|
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
|
||||||
}
|
}
|
||||||
if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() {
|
|
||||||
|
addr = netip.AddrFrom16([16]byte(buf[33:49])).Unmap()
|
||||||
|
if !addr.IsUnspecified() {
|
||||||
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
|
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
|
||||||
}
|
}
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"vppn/peer/control"
|
"vppn/peer/control"
|
||||||
)
|
)
|
||||||
@@ -32,11 +33,14 @@ func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []b
|
|||||||
// run reads incoming ping packets and forwards them to ch until ctx is done.
|
// run reads incoming ping packets and forwards them to ch until ctx is done.
|
||||||
// Call this in a goroutine before starting the App event loop.
|
// Call this in a goroutine before starting the App event loop.
|
||||||
func (c *udpControlConn) run(ch chan<- PingEvent) {
|
func (c *udpControlConn) run(ch chan<- PingEvent) {
|
||||||
|
const errorTimeout = 8 * time.Second
|
||||||
|
|
||||||
var buf [control.Size]byte
|
var buf [control.Size]byte
|
||||||
for {
|
for {
|
||||||
n, src, err := c.conn.ReadFromUDP(buf[:])
|
n, src, err := c.conn.ReadFromUDP(buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("control read: %v", err)
|
log.Printf("control read: %v", err)
|
||||||
|
time.Sleep(errorTimeout)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,22 +38,28 @@ func (a *App) devPeers() []wgtypes.Peer {
|
|||||||
return peers
|
return peers
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) devAddPeer(p *Peer) {
|
func (a *App) devAddRelayed(p *Peer) {
|
||||||
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
|
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
|
||||||
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
|
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
|
||||||
|
|
||||||
p.State = StateRelayed
|
p.State = StateRelayed
|
||||||
|
p.EndpointV4 = netip.AddrPort{}
|
||||||
|
p.EndpointV6 = netip.AddrPort{}
|
||||||
|
p.EndpointLAN = netip.AddrPort{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
|
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
|
||||||
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
||||||
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
|
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
|
||||||
|
|
||||||
p.State = StateDirect
|
p.State = StateDirect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
|
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
|
||||||
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
||||||
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
|
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
|
||||||
p.State = StateDirect // Dirrect connection. The app marks peer as relay.
|
|
||||||
|
p.State = StateDirect // Direct connection. The app marks peer as relay.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) devPromote(p *Peer) {
|
func (a *App) devPromote(p *Peer) {
|
||||||
@@ -61,19 +67,24 @@ func (a *App) devPromote(p *Peer) {
|
|||||||
if ep.IsValid() {
|
if ep.IsValid() {
|
||||||
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
|
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
|
||||||
} else {
|
} else {
|
||||||
log.Printf("PROMOTED: %s - %s (no IP)", p.Name, p.VPNIP.String())
|
log.Printf("DIRECT: %s - %s (waiting for handshake)", p.Name, p.VPNIP.String())
|
||||||
}
|
}
|
||||||
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
|
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
|
||||||
|
|
||||||
p.State = StateDirect
|
p.State = StateDirect
|
||||||
|
p.LastPing = time.Now() // Assume the peer is up after being promoted.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
|
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
|
||||||
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
||||||
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
|
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
|
||||||
|
|
||||||
p.State = StateProbing
|
p.State = StateProbing
|
||||||
|
p.ProbeStart = time.Now()
|
||||||
|
p.ProbeEndpoint = endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) devRemove(p *Peer) {
|
func (a *App) devRemove(p *Peer) {
|
||||||
log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String())
|
log.Printf("REMOVED: %s", p.PubKey())
|
||||||
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
|
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/flock"
|
"git.crumpington.com/lib/flock"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -62,13 +62,15 @@ func (hp *HubPoller) Run() {
|
|||||||
hp.apply(state)
|
hp.apply(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
hp.poll()
|
client := &http.Client{Timeout: 32 * time.Second}
|
||||||
|
|
||||||
|
hp.poll(client)
|
||||||
for range time.Tick(hubPollInterval) {
|
for range time.Tick(hubPollInterval) {
|
||||||
hp.poll()
|
hp.poll(client)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hp *HubPoller) poll() {
|
func (hp *HubPoller) poll(client *http.Client) {
|
||||||
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
|
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[HubPoller] build request: %v", err)
|
log.Printf("[HubPoller] build request: %v", err)
|
||||||
@@ -76,7 +78,6 @@ func (hp *HubPoller) poll() {
|
|||||||
}
|
}
|
||||||
req.SetBasicAuth("", hp.apiKey)
|
req.SetBasicAuth("", hp.apiKey)
|
||||||
|
|
||||||
client := &http.Client{Timeout: 32 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[HubPoller] fetch: %v", err)
|
log.Printf("[HubPoller] fetch: %v", err)
|
||||||
@@ -89,7 +90,7 @@ func (hp *HubPoller) poll() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[HubPoller] read body: %v", err)
|
log.Printf("[HubPoller] read body: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/sign"
|
"golang.org/x/crypto/nacl/sign"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -93,7 +94,7 @@ func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error)
|
|||||||
req.SetBasicAuth("", apiKey)
|
req.SetBasicAuth("", apiKey)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := (&http.Client{Timeout: time.Minute}).Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return LocalState{}, fmt.Errorf("hub init: %w", err)
|
return LocalState{}, fmt.Errorf("hub init: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,10 +9,6 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
|
||||||
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
|
|
||||||
4560))
|
|
||||||
|
|
||||||
func Broadcast(
|
func Broadcast(
|
||||||
selfVPNIP netip.Addr,
|
selfVPNIP netip.Addr,
|
||||||
pubKey wgtypes.Key,
|
pubKey wgtypes.Key,
|
||||||
@@ -20,15 +16,19 @@ func Broadcast(
|
|||||||
signKey *[64]byte,
|
signKey *[64]byte,
|
||||||
) {
|
) {
|
||||||
for {
|
for {
|
||||||
broadcastInner(selfVPNIP, pubKey, wgPort, signKey)
|
broadcast(selfVPNIP, pubKey, wgPort, signKey)
|
||||||
time.Sleep(errorTimeout)
|
time.Sleep(errorTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
|
func broadcast(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
|
||||||
|
addr := multicastAddr(selfVPNIP)
|
||||||
|
|
||||||
|
log.Printf("[MC Broadcast] Sending on %v.", addr)
|
||||||
|
|
||||||
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[MCBroadcast] bind: %v", err)
|
log.Printf("[MC Broadcast] bind: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@@ -44,18 +44,18 @@ func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, sig
|
|||||||
// dropped by receivers' freshness gate.
|
// dropped by receivers' freshness gate.
|
||||||
send := func() error {
|
send := func() error {
|
||||||
packet.Timestamp = time.Now().Unix()
|
packet.Timestamp = time.Now().Unix()
|
||||||
payload := packet.Marshal(buf, signKey)
|
payload := packet.marshal(buf, signKey)
|
||||||
_, err := conn.WriteToUDP(payload, addr)
|
_, err := conn.WriteToUDP(payload, addr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := send(); err != nil {
|
if err := send(); err != nil {
|
||||||
log.Printf("[MCBroadcast] write: %v", err)
|
log.Printf("[MC Broadcast] write: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for range time.Tick(broadcastInterval) {
|
for range time.Tick(broadcastInterval) {
|
||||||
if err := send(); err != nil {
|
if err := send(); err != nil {
|
||||||
log.Printf("[MCBroadcast] write: %v", err)
|
log.Printf("[MC Broadcast] write: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ type Packet struct {
|
|||||||
Signed []byte // Raw signed message for verification (incoming packet).
|
Signed []byte // Raw signed message for verification (incoming packet).
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal the packet into a buffer with prefixed signature.
|
// marshal the packet into a buffer with prefixed signature.
|
||||||
func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte {
|
func (p Packet) marshal(buf []byte, signKey *[64]byte) []byte {
|
||||||
buf[0] = p.PeerIP
|
buf[0] = p.PeerIP
|
||||||
copy(buf[1:33], p.WGPubKey[:])
|
copy(buf[1:33], p.WGPubKey[:])
|
||||||
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
|
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
|
||||||
@@ -43,7 +43,7 @@ func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func Unmarshal(signed []byte) (p Packet) {
|
func unmarshal(signed []byte) (p Packet) {
|
||||||
buf := signed[signSize:]
|
buf := signed[signSize:]
|
||||||
p.PeerIP = buf[0]
|
p.PeerIP = buf[0]
|
||||||
copy(p.WGPubKey[:], buf[1:33])
|
copy(p.WGPubKey[:], buf[1:33])
|
||||||
|
|||||||
@@ -21,12 +21,12 @@ func TestPacket(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buf := make([]byte, BufferSize)
|
buf := make([]byte, BufferSize)
|
||||||
signed := p.Marshal(buf, priv)
|
signed := p.marshal(buf, priv)
|
||||||
if len(signed) != SignedPacketSize {
|
if len(signed) != SignedPacketSize {
|
||||||
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
|
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
got := Unmarshal(signed)
|
got := unmarshal(signed)
|
||||||
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
|
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
|
||||||
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
|
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
|
||||||
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)
|
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)
|
||||||
|
|||||||
@@ -7,27 +7,34 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.crumpington.com/lib/ratelimiter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) {
|
func Receiver(selfVPNIP netip.Addr, ch chan<- Packet) {
|
||||||
for {
|
for {
|
||||||
if err := receiver(vpnNet, selfVPNIP, ch); err != nil {
|
if err := receiver(selfVPNIP, ch); err != nil {
|
||||||
log.Printf("[MCReader] %v", err)
|
log.Printf("[MC Receiver] %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(errorTimeout)
|
time.Sleep(errorTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error {
|
func receiver(selfVPNIP netip.Addr, ch chan<- Packet) error {
|
||||||
|
limiters := map[netip.Addr]*ratelimiter.Limiter{}
|
||||||
|
|
||||||
selfIP := selfVPNIP.As4()[3]
|
selfIP := selfVPNIP.As4()[3]
|
||||||
|
|
||||||
|
addr := multicastAddr(selfVPNIP)
|
||||||
|
|
||||||
|
log.Printf("[MC Receiver] Listening on %v.", addr)
|
||||||
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("bind: %w", err)
|
return fmt.Errorf("bind: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
buf := make([]byte, BufferSize+1) // +1 to detect oversized packets
|
buf := make([]byte, SignedPacketSize+1) // +1 to detect oversized packets
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
|
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
|
||||||
@@ -43,19 +50,43 @@ func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
packet := Unmarshal(buf[:n])
|
packet := unmarshal(buf[:n])
|
||||||
|
|
||||||
if packet.PeerIP == selfIP {
|
if packet.PeerIP == selfIP {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Slightly cheaper than limiting.
|
||||||
age := time.Since(time.Unix(packet.Timestamp, 0))
|
age := time.Since(time.Unix(packet.Timestamp, 0))
|
||||||
if age > maxPacketAge || age < -maxPacketAge {
|
if age > maxPacketAge || age < -maxPacketAge {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srcAddr := src.Addr().Unmap()
|
||||||
|
lim, ok := limiters[srcAddr]
|
||||||
|
if !ok {
|
||||||
|
lim = ratelimiter.New(ratelimiter.Config{
|
||||||
|
BurstLimit: 1,
|
||||||
|
FillPeriod: broadcastInterval / 2,
|
||||||
|
MaxWaitCount: 0,
|
||||||
|
})
|
||||||
|
limiters[srcAddr] = lim
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := lim.Limit(); err != nil {
|
||||||
|
log.Printf("[MC Receiver] Rate limited packet from peer IP %d.", packet.PeerIP)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
packet.Signed = bytes.Clone(packet.Signed)
|
packet.Signed = bytes.Clone(packet.Signed)
|
||||||
packet.Src = src.Addr().Unmap()
|
packet.Src = src.Addr().Unmap()
|
||||||
ch <- packet
|
ch <- packet
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func multicastAddr(vpnIP netip.Addr) *net.UDPAddr {
|
||||||
|
b := vpnIP.As4()
|
||||||
|
return net.UDPAddrFromAddrPort(
|
||||||
|
netip.AddrPortFrom(
|
||||||
|
netip.AddrFrom4([4]byte{239, b[0], b[1], b[2]}), 4560))
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func New(
|
|||||||
|
|
||||||
if !state.IsPublic {
|
if !state.IsPublic {
|
||||||
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
|
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
|
||||||
go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh)
|
go multicast.Receiver(state.VPNIP, multicastCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &App{
|
return &App{
|
||||||
@@ -89,7 +89,6 @@ func New(
|
|||||||
vpnNet: state.VPNNet,
|
vpnNet: state.VPNNet,
|
||||||
privKey: state.PrivKey,
|
privKey: state.PrivKey,
|
||||||
pubKey: state.PrivKey.PublicKey(),
|
pubKey: state.PrivKey.PublicKey(),
|
||||||
isRelay: state.IsRelay,
|
|
||||||
isPublic: state.IsPublic,
|
isPublic: state.IsPublic,
|
||||||
localDomain: localDomain,
|
localDomain: localDomain,
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (a *App) onAddPeer(p m.Peer) {
|
|||||||
// endpoint from the incoming handshake automatically.
|
// endpoint from the incoming handshake automatically.
|
||||||
a.devPromote(peer)
|
a.devPromote(peer)
|
||||||
} else {
|
} else {
|
||||||
a.devAddPeer(peer)
|
a.devAddRelayed(peer)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -95,14 +95,6 @@ func (a *App) switchActiveRelay() {
|
|||||||
a.relay = best
|
a.relay = best
|
||||||
}
|
}
|
||||||
|
|
||||||
func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort {
|
|
||||||
// We always prefer v4 since all peers can connect to IPv4 addresses.
|
|
||||||
if v4.IsValid() {
|
|
||||||
return v4
|
|
||||||
}
|
|
||||||
return v6
|
|
||||||
}
|
|
||||||
|
|
||||||
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
|
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
|
||||||
if !selfIsPublic && peerIsPublic {
|
if !selfIsPublic && peerIsPublic {
|
||||||
return control.Client
|
return control.Client
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func TestOnAddPeer(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
||||||
key := mustKey(t)
|
key := mustKey(t)
|
||||||
if tc.setup != nil {
|
if tc.setup != nil {
|
||||||
tc.setup(a, key)
|
tc.setup(a, key)
|
||||||
@@ -192,7 +192,7 @@ func TestOnRemovePeer(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
||||||
key := tc.setup(t, a)
|
key := tc.setup(t, a)
|
||||||
dev.Calls = nil
|
dev.Calls = nil
|
||||||
a.onRemovePeer(key)
|
a.onRemovePeer(key)
|
||||||
@@ -268,7 +268,7 @@ func TestSwitchActiveRelay(t *testing.T) {
|
|||||||
name: "stale relay demoted to direct before backup elected",
|
name: "stale relay demoted to direct before backup elected",
|
||||||
setup: func(t *testing.T, a *App) {
|
setup: func(t *testing.T, a *App) {
|
||||||
old := addRelayPeer(t, a, "10.0.0.10", ep1)
|
old := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||||
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
|
old.LastPing = time.Time{} // stale — Up() checks LastPing; triggers switch — triggers switch from onTick
|
||||||
a.relay = old
|
a.relay = old
|
||||||
addRelayPeer(t, a, "10.0.0.11", ep2)
|
addRelayPeer(t, a, "10.0.0.11", ep2)
|
||||||
},
|
},
|
||||||
@@ -289,7 +289,7 @@ func TestSwitchActiveRelay(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
||||||
tc.setup(t, a)
|
tc.setup(t, a)
|
||||||
dev.Calls = nil
|
dev.Calls = nil
|
||||||
a.switchActiveRelay()
|
a.switchActiveRelay()
|
||||||
|
|||||||
@@ -9,17 +9,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
|
func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
|
||||||
if a.isPublic {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Locate the sender peer by its VPN IP (final octet carried in the beacon).
|
// Locate the sender peer by its VPN IP (final octet carried in the beacon).
|
||||||
octets := a.vpnNet.Addr().As4()
|
octets := a.vpnNet.Addr().As4()
|
||||||
octets[3] = pkt.PeerIP
|
octets[3] = pkt.PeerIP
|
||||||
vpnIP := netip.AddrFrom4(octets)
|
vpnIP := netip.AddrFrom4(octets)
|
||||||
|
|
||||||
peer, ok := a.peersByIP[vpnIP]
|
peer, ok := a.peersByIP[vpnIP]
|
||||||
if !ok || peer.IsPublic || peer.State == StateDirect {
|
if !ok || peer.IsPublic {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,16 +32,9 @@ func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort)
|
endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort)
|
||||||
if !endpoint.IsValid() {
|
if !endpoint.IsValid() || endpoint.Port() == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var v4, v6 netip.AddrPort
|
peer.EndpointLAN = endpoint
|
||||||
if pkt.Src.Is4() {
|
|
||||||
v4 = endpoint
|
|
||||||
} else {
|
|
||||||
v6 = endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
a.addProbe(peer, v4, v6)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,6 +17,8 @@ func (a *App) onPing(e PingEvent) {
|
|||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
|
peer.LastPing = now
|
||||||
|
|
||||||
// If we're the server, respond - this is always necessary as it's used to
|
// If we're the server, respond - this is always necessary as it's used to
|
||||||
// know if peers are up or down.
|
// know if peers are up or down.
|
||||||
if peer.Role == control.Server {
|
if peer.Role == control.Server {
|
||||||
@@ -32,27 +35,37 @@ func (a *App) onPing(e PingEvent) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can only learn our own endpoint from directly-connected peers — Dst
|
// We can only learn our own endpoint from directly-connected peers — Dst is
|
||||||
// is the sender's observation of our WG handshake source.
|
// the sender's observation of our WG handshake source.
|
||||||
|
//
|
||||||
|
// We make sure we don't set a private address as our public address since we
|
||||||
|
// may be connected via LAN to some peers.
|
||||||
if peer.State == StateDirect {
|
if peer.State == StateDirect {
|
||||||
if dst := e.ping.Dst; dst.IsValid() {
|
if dst := e.ping.Dst; addrIsRoutable(dst) {
|
||||||
if dst.Addr().Is4() {
|
if dst.Addr().Is4() {
|
||||||
|
if dst != a.selfV4 {
|
||||||
|
log.Printf("Local IPv4 updated: %s -> %s", a.selfV4, dst)
|
||||||
a.selfV4 = dst
|
a.selfV4 = dst
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if dst != a.selfV6 {
|
||||||
|
log.Printf("Local IPv6 updated: %s -> %s", a.selfV6, dst)
|
||||||
a.selfV6 = dst
|
a.selfV6 = dst
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
a.addProbe(peer, e.ping.SrcV4, e.ping.SrcV6)
|
peer.UpdateEndpoints(e.ping.SrcV4, e.ping.SrcV6)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) addProbe(peer *Peer, v4, v6 netip.AddrPort) {
|
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||||
endpoint := preferredEndpoint(v4, v6)
|
|
||||||
if !endpoint.IsValid() || endpoint == peer.PreferredEndpoint() {
|
func addrIsRoutable(addrPort netip.AddrPort) bool {
|
||||||
return
|
if addrPort.Port() == 0 {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
peer.UpdateEndpoints(v4, v6)
|
addr := addrPort.Addr()
|
||||||
a.devAddProbe(peer, endpoint)
|
return addr.IsGlobalUnicast() && !addr.IsPrivate() && !cgnatPrefix.Contains(addr)
|
||||||
}
|
}
|
||||||
|
|||||||
15
peer/on_pingtick.go
Normal file
15
peer/on_pingtick.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
"vppn/peer/control"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *App) onPingTick() {
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
for _, p := range a.peersByIP {
|
||||||
|
if p.Role == control.Client {
|
||||||
|
a.sendPing(p, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
68
peer/on_statetick.go
Normal file
68
peer/on_statetick.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"vppn/peer/wginterface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *App) onStateTick() {
|
||||||
|
wgPeers := a.devPeers()
|
||||||
|
|
||||||
|
for _, wgPeer := range wgPeers {
|
||||||
|
p, ok := a.peersByKey[wgPeer.PublicKey]
|
||||||
|
if !ok {
|
||||||
|
log.Printf("Wireguard peer not known. Removing: %v", wgPeer.PublicKey)
|
||||||
|
a.devRemove(&Peer{wgPeer: wgPeer})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.wgPeer = wgPeer
|
||||||
|
|
||||||
|
// Log endpoint changes.
|
||||||
|
if ep := p.WGEndpoint(); ep != p.EndpointWG {
|
||||||
|
log.Printf("Client %s %s endpoint: %s -> %s", p.Name, p.VPNIP, p.EndpointWG, ep)
|
||||||
|
p.EndpointWG = ep
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p.State {
|
||||||
|
case StateRelayed:
|
||||||
|
if p.DirectAlive() {
|
||||||
|
// We may already have a valid direct endpoint due to wireguard
|
||||||
|
// roaming.
|
||||||
|
a.devPromote(p)
|
||||||
|
} else if ep := p.PreferredEndpoint(); ep.IsValid() {
|
||||||
|
// If we have an ep to probe, add it.
|
||||||
|
a.devAddProbe(p, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
case StateProbing:
|
||||||
|
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
|
||||||
|
// Promote probing peers to direct once alive (direct path confirmed
|
||||||
|
// working).
|
||||||
|
a.devPromote(p)
|
||||||
|
} else if ep := p.PreferredEndpoint(); ep.IsValid() && ep != p.ProbeEndpoint {
|
||||||
|
// Re-start probing if we see a new endpoint.
|
||||||
|
a.devAddProbe(p, ep)
|
||||||
|
} else if time.Since(p.ProbeStart) > 8*wginterface.ProbeKeepalive {
|
||||||
|
// Give up probing if we haven't been able to handshake.
|
||||||
|
a.devAddRelayed(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
case StateDirect:
|
||||||
|
if p.IsPublic || a.isPublic || p.Up() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stale non-public direct peer: demote to relayed and wait for new IP
|
||||||
|
// information.
|
||||||
|
a.devAddRelayed(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have a live relay (if we're not public).
|
||||||
|
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
|
||||||
|
a.switchActiveRelay()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
"vppn/peer/wginterface"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *App) onTick() {
|
|
||||||
wgPeers := a.devPeers()
|
|
||||||
|
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
for _, wgPeer := range wgPeers {
|
|
||||||
p, ok := a.peersByKey[wgPeer.PublicKey]
|
|
||||||
if !ok {
|
|
||||||
log.Printf("Wireguard peer not in index, removing: %v", wgPeer)
|
|
||||||
a.devRemove(&Peer{wgPeer: wgPeer})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.wgPeer = wgPeer
|
|
||||||
|
|
||||||
// Send pings to peers where we're the client.
|
|
||||||
if p.Role == control.Client {
|
|
||||||
a.sendPing(p, now)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch p.State {
|
|
||||||
case StateProbing:
|
|
||||||
// Promote probing peers to direct once alive (direct path confirmed
|
|
||||||
// working).
|
|
||||||
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
|
|
||||||
a.devPromote(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
case StateDirect:
|
|
||||||
if p.IsPublic || a.isPublic || p.Up() {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// Stale non-public direct peer: demote to probing so WireGuard
|
|
||||||
// resumes handshake attempts on the direct path.
|
|
||||||
a.devAddProbe(p, p.WGEndpoint())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we have a live relay (if we're not public).
|
|
||||||
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
|
|
||||||
a.switchActiveRelay()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -13,8 +13,8 @@ import (
|
|||||||
type PeerState string
|
type PeerState string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
StateRelayed = PeerState("RELAY")
|
StateRelayed = PeerState("RELAY ")
|
||||||
StateProbing = PeerState("PROBE")
|
StateProbing = PeerState("PROBE ")
|
||||||
StateDirect = PeerState("DIRECT")
|
StateDirect = PeerState("DIRECT")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,9 +26,14 @@ type Peer struct {
|
|||||||
IsPublic bool // Peer has a public IP.
|
IsPublic bool // Peer has a public IP.
|
||||||
EndpointV4 netip.AddrPort // Reported IPv4 endpoint.
|
EndpointV4 netip.AddrPort // Reported IPv4 endpoint.
|
||||||
EndpointV6 netip.AddrPort // Reported IPv6 endpoint.
|
EndpointV6 netip.AddrPort // Reported IPv6 endpoint.
|
||||||
|
EndpointLAN netip.AddrPort // Discovered via multicast.
|
||||||
|
EndpointWG netip.AddrPort // Current wireguard endpoint.
|
||||||
RTT time.Duration // Round-trip time.
|
RTT time.Duration // Round-trip time.
|
||||||
|
LastPing time.Time // Last time we had a ping.
|
||||||
|
ProbeStart time.Time // When we started probing.
|
||||||
|
ProbeEndpoint netip.AddrPort
|
||||||
State PeerState // Current routing state; updated on each devXxx call.
|
State PeerState // Current routing state; updated on each devXxx call.
|
||||||
Role control.Role // Client initiates pings; server responds.
|
Role control.Role // Role in relation to the local application.
|
||||||
SignPubKey [32]byte // nacl/sign public key for verifying multicast beacons.
|
SignPubKey [32]byte // nacl/sign public key for verifying multicast beacons.
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +59,12 @@ func (p *Peer) LastHandshakeTime() time.Time {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) Up() bool {
|
func (p *Peer) Up() bool {
|
||||||
return time.Since(p.wgPeer.LastHandshakeTime) < wginterface.SessionTimeout
|
return time.Since(p.LastPing) < 3*PingInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) DirectAlive() bool {
|
||||||
|
return p.WGEndpoint().IsValid() &&
|
||||||
|
time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) CanRelay() bool {
|
func (p *Peer) CanRelay() bool {
|
||||||
@@ -62,7 +72,13 @@ func (p *Peer) CanRelay() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) PreferredEndpoint() netip.AddrPort {
|
func (p *Peer) PreferredEndpoint() netip.AddrPort {
|
||||||
return preferredEndpoint(p.EndpointV4, p.EndpointV6)
|
if p.EndpointLAN.IsValid() {
|
||||||
|
return p.EndpointLAN
|
||||||
|
} else if p.EndpointV4.IsValid() {
|
||||||
|
return p.EndpointV4
|
||||||
|
} else {
|
||||||
|
return p.EndpointV6
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) UpdateEndpoints(v4, v6 netip.AddrPort) {
|
func (p *Peer) UpdateEndpoints(v4, v6 netip.AddrPort) {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ package wginterface
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -173,6 +174,10 @@ func nlAttr(attrType uint16, data []byte) []byte {
|
|||||||
// messages, but the AF_INET ioctl interface is simpler.
|
// messages, but the AF_INET ioctl interface is simpler.
|
||||||
|
|
||||||
func ioctlSetAddr(name string, ip net.IP, prefixLen int) error {
|
func ioctlSetAddr(name string, ip net.IP, prefixLen int) error {
|
||||||
|
if ip.To4() == nil {
|
||||||
|
return errors.New("attempted to set non-IPv4 address on interface")
|
||||||
|
}
|
||||||
|
|
||||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ func (d *Device) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network n
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddProbe adds a peer with no AllowedIPs and a 5s keepalive. WireGuard will
|
// AddProbe adds a peer with no AllowedIPs and an 8s keepalive. WireGuard will
|
||||||
// attempt handshakes without routing any traffic through this peer yet.
|
// attempt handshakes without routing any traffic through this peer yet.
|
||||||
func (d *Device) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
|
func (d *Device) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
|
||||||
keepalive := ProbeKeepalive
|
keepalive := ProbeKeepalive
|
||||||
|
|||||||
Reference in New Issue
Block a user