Compare commits
65 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2fe8dc79d | ||
|
|
2b8cc86077 | ||
|
|
cb7c07ac96 | ||
|
|
d8c2990ffd | ||
|
|
32b8b0dc89 | ||
|
|
302e5d00d0 | ||
|
|
85d6d577e3 | ||
|
|
691eb49009 | ||
|
|
b479b37479 | ||
|
|
fe9f15bec9 | ||
|
|
b86b43f1de | ||
|
|
c47d00e694 | ||
|
|
458e1ac603 | ||
|
|
9d57f45aea | ||
|
|
36e9f6149d | ||
|
|
eaa101f976 | ||
|
|
802ca9aba4 | ||
|
|
fa933ae029 | ||
|
|
b4320c9330 | ||
|
|
d02f47cce6 | ||
|
|
797ab8bdef | ||
|
|
f765303daf | ||
|
|
0e3d4ec3a5 | ||
|
|
b875313f7d | ||
|
|
71328eb67e | ||
|
|
3e630ee0ad | ||
|
|
164d1f9d95 | ||
|
|
fa182eca76 | ||
|
|
52ea1a8d42 | ||
|
|
cc21bee798 | ||
|
|
353ef07f92 | ||
|
|
c12ef3341f | ||
|
|
992eabc0e9 | ||
|
|
c45ac83eb0 | ||
|
|
3fe9f63901 | ||
|
|
cfb2a29082 | ||
|
|
9bdb836eaa | ||
|
|
393f79e1d3 | ||
|
|
c911e3e865 | ||
|
|
c356347cf6 | ||
|
|
111f3f4d20 | ||
|
|
b6052ee7b8 | ||
|
|
2e19f0945f | ||
|
|
c325180a1b | ||
|
|
c4a81cf553 | ||
|
|
0f117e5e66 | ||
|
|
68f01f9823 | ||
|
|
8983c0d651 | ||
|
|
0cd5982a3f | ||
|
|
243e75dd09 | ||
|
|
1f7d3151b5 | ||
|
|
0709c4dac0 | ||
|
|
c0126c2036 | ||
|
|
528e67ea61 | ||
|
|
fe5f26ed70 | ||
|
|
75782c4efd | ||
|
|
867b3b5949 | ||
|
|
76fce15e32 | ||
|
|
3dcd1c1080 | ||
|
|
232b68310c | ||
|
|
a730211167 | ||
|
|
f3d8a9ff75 | ||
|
|
98f07457b9 | ||
|
|
cbc901496c | ||
|
|
cd5442f3bf |
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
"vppn/peer"
|
||||
|
||||
"git.crumpington.com/lib/go/flock"
|
||||
"git.crumpington.com/lib/flock"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -34,9 +34,6 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatalf("lock: %v", err)
|
||||
}
|
||||
if lockFile == nil {
|
||||
log.Fatalf("already running for network %q", *name)
|
||||
}
|
||||
defer flock.Unlock(lockFile)
|
||||
|
||||
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)
|
||||
|
||||
10
go.mod
10
go.mod
@@ -3,7 +3,13 @@ module vppn
|
||||
go 1.25.1
|
||||
|
||||
require (
|
||||
git.crumpington.com/lib/go v0.10.0
|
||||
git.crumpington.com/lib/flock v1.1.0
|
||||
git.crumpington.com/lib/idgen v1.0.0
|
||||
git.crumpington.com/lib/keyedmutex v1.1.0
|
||||
git.crumpington.com/lib/ratelimiter v1.1.1
|
||||
git.crumpington.com/lib/sqliteutil v1.1.1
|
||||
git.crumpington.com/lib/webutil v1.1.0
|
||||
github.com/mattn/go-sqlite3 v1.14.45
|
||||
golang.org/x/crypto v0.53.0
|
||||
golang.org/x/sys v0.46.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
@@ -11,8 +17,6 @@ require (
|
||||
|
||||
require (
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.45 // indirect
|
||||
github.com/mdlayher/genetlink v1.4.0 // indirect
|
||||
github.com/mdlayher/netlink v1.11.2 // indirect
|
||||
github.com/mdlayher/socket v0.6.1 // indirect
|
||||
|
||||
42
go.sum
42
go.sum
@@ -1,55 +1,37 @@
|
||||
git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
|
||||
git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
|
||||
git.crumpington.com/lib/go v0.9.2 h1:DZ7tzFM/S+zL5hexNo8zKbH7Ryi+VtvSMRzCMnlz+c4=
|
||||
git.crumpington.com/lib/go v0.9.2/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
|
||||
git.crumpington.com/lib/go v0.10.0 h1:4O+o9QBVcre8RYESAXhxJ1kT0w1tIakUdt/rV4v4riw=
|
||||
git.crumpington.com/lib/go v0.10.0/go.mod h1:8y838PnV7dM6QT0XwLMuG2ulDNtCv4NmdSJIEqGViKg=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
git.crumpington.com/lib/flock v1.1.0 h1:NzPUAXnywikN+ZPabzQw9eXAwvZolGUE3pjnSxnDwFk=
|
||||
git.crumpington.com/lib/flock v1.1.0/go.mod h1:prUmtkjpGDUakQh6TiEAylrgDTPG0HuBOUe8Lq4HKsc=
|
||||
git.crumpington.com/lib/idgen v1.0.0 h1:0Jre8R3B+RaMOKmCgagBT659wGM93QNpamuGF2e9SII=
|
||||
git.crumpington.com/lib/idgen v1.0.0/go.mod h1:Q8kV11Zta4P5WKDpBwsekEsnOe9IysVLsW+gPhbzFTc=
|
||||
git.crumpington.com/lib/keyedmutex v1.1.0 h1:XOlk9f0rnwmr5yNoIvPteM2W2uakZqT4tnZKficrXho=
|
||||
git.crumpington.com/lib/keyedmutex v1.1.0/go.mod h1:ova6v/794UCZJ5FKKrLpaol0wfNZZTB3plLObSWaGk4=
|
||||
git.crumpington.com/lib/ratelimiter v1.1.1 h1:8jVDVK/I0zzE3EHCu+sUeZN8a9Aqzm+PG4WrlnEvLes=
|
||||
git.crumpington.com/lib/ratelimiter v1.1.1/go.mod h1:TycyPTi/aBfnWW8F51yfo/5fSP/qKywDREqsph7TEns=
|
||||
git.crumpington.com/lib/sqliteutil v1.1.1 h1:xwfp/l2BL4nfw8Ye0Cex2HdGJQKQ1YBCFtDiMeUhnzk=
|
||||
git.crumpington.com/lib/sqliteutil v1.1.1/go.mod h1:K8OelqOwhSYAZK42v8hKK6UmafItGf2WcMfNlq9Gfeo=
|
||||
git.crumpington.com/lib/webutil v1.1.0 h1:S9CaRBbVgYOUsgZ5AU1gAJxkxzr8Zjn2v84MoMOy1+I=
|
||||
git.crumpington.com/lib/webutil v1.1.0/go.mod h1:+LNLGApoe9InAJ7DCeLfiDmYov87XU3crYRHr/RYv2E=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk=
|
||||
github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/genetlink v1.4.0 h1:f/Xs7Y2T+GyX9b3dbiUhnLE9InGs5F9RxJ2JwBMl71o=
|
||||
github.com/mdlayher/genetlink v1.4.0/go.mod h1:d1hrKr8fwZU2JkcAtQUAzeTrI7nbgQSl+5k1cC0biSA=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/netlink v1.11.2 h1:HKh2jqe+omdSWcQ88nrT7INE61B0NXfiSPFdgL4YbNI=
|
||||
github.com/mdlayher/netlink v1.11.2/go.mod h1:uT2Yc/QLaZubzDpZIBi9d4GoeLwtp3x1AMeqSRrK2sA=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mdlayher/socket v0.6.1 h1:M7uj2NtuujUY4mYr1C57NmfNiRHbkKpnBxO856lsc3A=
|
||||
github.com/mdlayher/socket v0.6.1/go.mod h1:+/SGtqc9V+5dAuRgQsU0fGBI+oRDiW7O2Obx10OIWfg=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
|
||||
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 h1:cqHQ3AycTHvM2R7ikgyX57D+XvtcSnGylsLkOVhta/w=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
|
||||
110
hub/api/api.go
110
hub/api/api.go
@@ -8,10 +8,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
"vppn/hub/api/db"
|
||||
"vppn/hub/errs"
|
||||
"vppn/m"
|
||||
|
||||
"git.crumpington.com/lib/go/idgen"
|
||||
"git.crumpington.com/lib/go/sqliteutil"
|
||||
"git.crumpington.com/lib/idgen"
|
||||
"git.crumpington.com/lib/sqliteutil"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
@@ -26,7 +27,8 @@ type API struct {
|
||||
}
|
||||
|
||||
func New(dbPath string) (*API, error) {
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal=WAL")
|
||||
dbPath += "?_journal=WAL&_foreign_keys=on&_busy_timeout=5000&_txlock=immediate"
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -64,30 +66,33 @@ func (a *API) ensurePassword() error {
|
||||
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Printf("Failed to generate password: %v", err)
|
||||
return errs.ErrUnexpected
|
||||
}
|
||||
|
||||
conf := &Config{ConfigID: 1, Password: hashed}
|
||||
return db.Config_Insert(a.db, conf)
|
||||
return errs.DB(db.Config_Insert(a.db, conf))
|
||||
}
|
||||
|
||||
func (a *API) Config_Get() (*Config, error) {
|
||||
return db.Config_Get(a.db, 1)
|
||||
conf, err := db.Config_Get(a.db, 1)
|
||||
return conf, errs.DB(err)
|
||||
}
|
||||
|
||||
func (a *API) Config_Update(conf *Config) error {
|
||||
return db.Config_Update(a.db, conf)
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
return errs.DB(db.Config_Update(a.db, conf))
|
||||
}
|
||||
|
||||
func (a *API) Session_Delete(sessionID string) error {
|
||||
func (a *API) Session_Delete(sessionID string) {
|
||||
a.sessionsMu.Lock()
|
||||
defer a.sessionsMu.Unlock()
|
||||
delete(a.sessions, sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use
|
||||
sessionTTL = 24 * 21 * time.Hour // sessions expire 21 days after last use
|
||||
sessionSweepEvery = time.Hour // cadence of expired-session eviction
|
||||
)
|
||||
|
||||
@@ -96,23 +101,23 @@ const (
|
||||
// creates a session, so anonymous requests cost no memory — a session is minted
|
||||
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
|
||||
// callers from racing on the shared struct.
|
||||
func (a *API) Session_Get(sessionID string) (Session, error) {
|
||||
func (a *API) Session_Get(sessionID string) Session {
|
||||
a.sessionsMu.Lock()
|
||||
defer a.sessionsMu.Unlock()
|
||||
|
||||
s, ok := a.sessions[sessionID]
|
||||
|
||||
if sessionID == "" || !ok {
|
||||
return Session{}, nil
|
||||
return Session{}
|
||||
}
|
||||
|
||||
if timeSince(s.LastSeenAt) > sessionTTLSecs {
|
||||
if time.Since(s.LastSeenAt) > sessionTTL {
|
||||
delete(a.sessions, sessionID)
|
||||
return Session{}, nil
|
||||
return Session{}
|
||||
}
|
||||
|
||||
s.LastSeenAt = time.Now().Unix()
|
||||
return *s, nil
|
||||
s.LastSeenAt = time.Now()
|
||||
return *s
|
||||
}
|
||||
|
||||
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
|
||||
@@ -121,24 +126,36 @@ func (a *API) Session_Get(sessionID string) (Session, error) {
|
||||
func (a *API) Session_SignIn(pwd string) (Session, error) {
|
||||
conf, err := a.Config_Get()
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
log.Printf("Failed to get config: %v", err)
|
||||
return Session{}, errs.ErrUnexpected
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
|
||||
return Session{}, ErrNotAuthorized
|
||||
return Session{}, errs.ErrNotAuthorized
|
||||
}
|
||||
|
||||
a.sessionsMu.Lock()
|
||||
defer a.sessionsMu.Unlock()
|
||||
s := &Session{
|
||||
SessionID: idgen.NewToken(),
|
||||
SignedIn: true,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
LastSeenAt: time.Now().Unix(),
|
||||
LastSeenAt: time.Now(),
|
||||
}
|
||||
a.sessions[s.SessionID] = s
|
||||
return *s, nil
|
||||
}
|
||||
|
||||
func (a *API) Session_InvalidateAll() Session {
|
||||
a.sessionsMu.Lock()
|
||||
defer a.sessionsMu.Unlock()
|
||||
|
||||
clear(a.sessions)
|
||||
s := &Session{
|
||||
SessionID: idgen.NewToken(),
|
||||
LastSeenAt: time.Now(),
|
||||
}
|
||||
a.sessions[s.SessionID] = s
|
||||
return *s
|
||||
}
|
||||
|
||||
// sweepSessions periodically evicts sessions past their TTL. Without it, a
|
||||
// signed-in session whose ID is never presented again would linger forever
|
||||
// (Session_Get only evicts on a lookup of that same ID).
|
||||
@@ -146,7 +163,7 @@ func (a *API) sweepSessions() {
|
||||
for range time.Tick(sessionSweepEvery) {
|
||||
a.sessionsMu.Lock()
|
||||
for id, s := range a.sessions {
|
||||
if timeSince(s.LastSeenAt) > sessionTTLSecs {
|
||||
if time.Since(s.LastSeenAt) > sessionTTL {
|
||||
delete(a.sessions, id)
|
||||
}
|
||||
}
|
||||
@@ -155,29 +172,48 @@ func (a *API) sweepSessions() {
|
||||
}
|
||||
|
||||
func (a *API) Network_Create(n *Network) error {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
n.NetworkID = idgen.NextID(0)
|
||||
return db.Network_Insert(a.db, n)
|
||||
return errs.DB(db.Network_Insert(a.db, n))
|
||||
}
|
||||
|
||||
func (a *API) Network_Delete(n *Network) error {
|
||||
return db.Network_Delete(a.db, n.NetworkID)
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
exists, err := db.Network_HasPeers(a.db, n.NetworkID)
|
||||
if err != nil {
|
||||
return errs.DB(err)
|
||||
}
|
||||
if exists {
|
||||
return errs.Conflict.WithMsg("Delete all peers before deleting network.")
|
||||
}
|
||||
|
||||
return errs.DB(db.Network_Delete(a.db, n.NetworkID))
|
||||
}
|
||||
|
||||
func (a *API) Network_Get(id int64) (*Network, error) {
|
||||
return db.Network_Get(a.db, id)
|
||||
n, err := db.Network_Get(a.db, id)
|
||||
return n, errs.DB(err)
|
||||
}
|
||||
|
||||
func (a *API) Network_List() ([]*Network, error) {
|
||||
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
|
||||
return db.Network_List(a.db, query)
|
||||
n, err := db.Network_List(a.db, query)
|
||||
return n, errs.DB(err)
|
||||
}
|
||||
|
||||
func (a *API) Peer_CreateNew(p *Peer) error {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
p.WGPubKey = []byte{}
|
||||
p.SignPubKey = []byte{}
|
||||
p.APIKey = idgen.NewToken()
|
||||
|
||||
return db.Peer_Insert(a.db, p)
|
||||
return errs.DB(db.Peer_Insert(a.db, p))
|
||||
}
|
||||
|
||||
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
||||
@@ -188,30 +224,36 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
||||
// we held the lock, so it may be stale under concurrent requests.
|
||||
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
|
||||
if err != nil {
|
||||
return err
|
||||
return errs.DB(err)
|
||||
}
|
||||
if len(current.WGPubKey) != 0 {
|
||||
return errors.New("peer already initialized")
|
||||
return errs.ErrAlreadyExists
|
||||
}
|
||||
|
||||
peer.WGPubKey = args.WGPubKey
|
||||
peer.SignPubKey = args.SignPubKey
|
||||
|
||||
return db.Peer_UpdateFull(a.db, peer)
|
||||
return errs.DB(db.Peer_UpdateFull(a.db, peer))
|
||||
}
|
||||
|
||||
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
|
||||
return db.Peer_Delete(a.db, networkID, peerIP)
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
return errs.DB(db.Peer_Delete(a.db, networkID, peerIP))
|
||||
}
|
||||
|
||||
func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
|
||||
return db.Peer_ListAll(a.db, networkID)
|
||||
p, err := db.Peer_ListAll(a.db, networkID)
|
||||
return p, errs.DB(err)
|
||||
}
|
||||
|
||||
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
|
||||
return db.Peer_Get(a.db, networkID, ip)
|
||||
p, err := db.Peer_Get(a.db, networkID, ip)
|
||||
return p, errs.DB(err)
|
||||
}
|
||||
|
||||
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
|
||||
return db.Peer_GetByAPIKey(a.db, key)
|
||||
p, err := db.Peer_GetByAPIKey(a.db, key)
|
||||
return p, errs.DB(err)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func Config_Update(
|
||||
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
switch n {
|
||||
case 0:
|
||||
|
||||
@@ -1,19 +1,9 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidIP = errors.New("invalid IP")
|
||||
ErrInvalidPeerIP = errors.New("invalid peer IP")
|
||||
ErrNonPrivateIP = errors.New("non-private IP")
|
||||
ErrInvalidPort = errors.New("invalid port")
|
||||
ErrInvalidNetName = errors.New("invalid network name")
|
||||
ErrNetNameNotLocal = errors.New("network name must end with .local")
|
||||
ErrInvalidPeerName = errors.New("invalid peer name")
|
||||
"vppn/hub/errs"
|
||||
)
|
||||
|
||||
func Config_Sanitize(c *Config) {
|
||||
@@ -35,11 +25,11 @@ func Network_Validate(c *Network) error {
|
||||
// 15 bytes is linux limit for network interface names. With ending .local,
|
||||
// max length is 21.
|
||||
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
|
||||
return ErrInvalidNetName
|
||||
return errs.ErrInvalidNetName
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(c.LocalDomain, ".local") {
|
||||
return ErrNetNameNotLocal
|
||||
return errs.ErrNetNameNotLocal
|
||||
}
|
||||
|
||||
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
|
||||
@@ -49,23 +39,23 @@ func Network_Validate(c *Network) error {
|
||||
if c >= '0' && c <= '9' {
|
||||
continue
|
||||
}
|
||||
return ErrInvalidNetName
|
||||
return errs.ErrInvalidNetName
|
||||
}
|
||||
|
||||
addr, ok := netip.AddrFromSlice(c.Network)
|
||||
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
|
||||
return ErrInvalidIP
|
||||
return errs.ErrInvalidIP
|
||||
}
|
||||
|
||||
if !addr.IsPrivate() {
|
||||
return ErrNonPrivateIP
|
||||
return errs.ErrNonPrivateIP
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Peer_Sanitize(p *Peer) {
|
||||
p.Name = strings.TrimSpace(p.Name)
|
||||
p.Name = strings.TrimSpace(strings.ToLower(p.Name))
|
||||
if len(p.Addr4) != 0 {
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
|
||||
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
|
||||
@@ -84,26 +74,26 @@ func Peer_Sanitize(p *Peer) {
|
||||
|
||||
func Peer_Validate(p *Peer) error {
|
||||
if p.PeerIP < 1 || p.PeerIP > 254 {
|
||||
return ErrInvalidPeerIP
|
||||
return errs.ErrInvalidPeerIP
|
||||
}
|
||||
if len(p.Addr4) > 0 {
|
||||
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
|
||||
return ErrInvalidIP
|
||||
return errs.ErrInvalidIP
|
||||
}
|
||||
}
|
||||
if len(p.Addr6) > 0 {
|
||||
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
|
||||
return ErrInvalidIP
|
||||
return errs.ErrInvalidIP
|
||||
}
|
||||
}
|
||||
if p.Port == 0 {
|
||||
return ErrInvalidPort
|
||||
return errs.ErrInvalidPort
|
||||
}
|
||||
|
||||
if len(p.Name) == 0 {
|
||||
return ErrInvalidPeerName
|
||||
if len(p.Name) == 0 || len(p.Name) > 63 {
|
||||
return errs.ErrInvalidPeerName
|
||||
}
|
||||
for _, c := range p.Name {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
@@ -115,7 +105,7 @@ func Peer_Validate(p *Peer) error {
|
||||
if c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrInvalidPeerName
|
||||
return errs.ErrInvalidPeerName
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -11,3 +11,9 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
|
||||
Peer_SelectQuery+` WHERE APIKey=?`,
|
||||
apiKey)
|
||||
}
|
||||
|
||||
func Network_HasPeers(tx TX, networkID int64) (exists bool, err error) {
|
||||
const query = "SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=?)"
|
||||
err = tx.QueryRow(query, networkID).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"vppn/hub/api/db"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotAuthorized = errors.New("not authorized")
|
||||
ErrInvalidIP = db.ErrInvalidIP
|
||||
ErrInvalidPort = db.ErrInvalidPort
|
||||
)
|
||||
@@ -21,5 +21,6 @@ CREATE TABLE peers (
|
||||
WGPubKey BLOB NOT NULL,
|
||||
SignPubKey BLOB NOT NULL,
|
||||
UNIQUE(NetworkID, Name),
|
||||
PRIMARY KEY(NetworkID, PeerIP)
|
||||
PRIMARY KEY(NetworkID, PeerIP),
|
||||
FOREIGN KEY(NetworkID) REFERENCES networks(NetworkID)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package api
|
||||
|
||||
import "time"
|
||||
|
||||
func timeSince(ts int64) int64 {
|
||||
return time.Now().Unix() - ts
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package api
|
||||
|
||||
import "vppn/hub/api/db"
|
||||
import (
|
||||
"time"
|
||||
"vppn/hub/api/db"
|
||||
)
|
||||
|
||||
type Config = db.Config
|
||||
type Network = db.Network
|
||||
@@ -8,7 +11,5 @@ type Peer = db.Peer
|
||||
|
||||
type Session struct {
|
||||
SessionID string
|
||||
SignedIn bool
|
||||
CreatedAt int64
|
||||
LastSeenAt int64
|
||||
LastSeenAt time.Time
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
"path/filepath"
|
||||
"vppn/hub/api"
|
||||
|
||||
"git.crumpington.com/lib/go/webutil"
|
||||
"git.crumpington.com/lib/keyedmutex"
|
||||
"git.crumpington.com/lib/webutil"
|
||||
)
|
||||
|
||||
//go:embed static
|
||||
@@ -28,6 +29,9 @@ type App struct {
|
||||
mux *http.ServeMux
|
||||
tmpl map[string]*template.Template
|
||||
insecure bool
|
||||
|
||||
// Per-remote address sign-in serialization lock.
|
||||
signInLock *keyedmutex.KeyedMutex[string]
|
||||
}
|
||||
|
||||
func NewApp(conf Config) (*App, error) {
|
||||
@@ -41,6 +45,7 @@ func NewApp(conf Config) (*App, error) {
|
||||
mux: http.NewServeMux(),
|
||||
tmpl: webutil.ParseTemplateSet(templateFuncs, templateFS),
|
||||
insecure: conf.Insecure,
|
||||
signInLock: keyedmutex.New[string](),
|
||||
}
|
||||
|
||||
app.registerRoutes()
|
||||
|
||||
37
hub/errs/db.go
Normal file
37
hub/errs/db.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func DB(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *Error
|
||||
if errors.As(err, &e) {
|
||||
return err
|
||||
}
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
var se sqlite3.Error
|
||||
if errors.As(err, &se) {
|
||||
switch se.ExtendedCode {
|
||||
case sqlite3.ErrConstraintUnique, sqlite3.ErrConstraintPrimaryKey:
|
||||
return ErrAlreadyExists
|
||||
case sqlite3.ErrConstraintForeignKey, sqlite3.ErrConstraintCheck:
|
||||
return ErrConstraint
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Unexpected error: %v", err)
|
||||
return ErrUnexpected
|
||||
}
|
||||
61
hub/errs/types.go
Normal file
61
hub/errs/types.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Code int
|
||||
Msg string
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("[%d] %s", e.Code, e.Msg)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNotAuthorized = NotAuthorized.WithMsg("Not authorized.")
|
||||
ErrInvalidPassword = BadRequest.WithMsg("Invalid password.")
|
||||
ErrPasswordMismatch = BadRequest.WithMsg("Passwords don't match.")
|
||||
ErrUnexpected = Internal.WithMsg("Unexpected internal error.")
|
||||
ErrNotFound = NotFound.WithMsg("Not found.")
|
||||
ErrAlreadyExists = Conflict.WithMsg("Already exists.")
|
||||
|
||||
// Validation errors.
|
||||
ErrInvalidIP = BadRequest.WithMsg("Invalid IP.")
|
||||
ErrInvalidPeerIP = BadRequest.WithMsg("Invalid peer IP.")
|
||||
ErrNonPrivateIP = BadRequest.WithMsg("Non-private IP.")
|
||||
ErrInvalidPort = BadRequest.WithMsg("Invalid port.")
|
||||
ErrInvalidNetName = BadRequest.WithMsg("Invalid network name.")
|
||||
ErrNetNameNotLocal = BadRequest.WithMsg("Network name must end with .local.")
|
||||
ErrInvalidPeerName = BadRequest.WithMsg("Invalid peer name.")
|
||||
ErrConstraint = BadRequest.WithMsg("Constraint error.")
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Type struct {
|
||||
Code int
|
||||
Msg string
|
||||
}
|
||||
|
||||
func (t Type) WithErr(err error) *Error {
|
||||
return &Error{Code: t.Code, Msg: err.Error()}
|
||||
}
|
||||
|
||||
func (t Type) WithMsg(msg string) *Error {
|
||||
return &Error{Code: t.Code, Msg: msg}
|
||||
}
|
||||
|
||||
func (t Type) WithMsgf(msg string, args ...any) *Error {
|
||||
return &Error{Code: t.Code, Msg: fmt.Sprintf(msg, args...)}
|
||||
}
|
||||
|
||||
var (
|
||||
Internal = Type{Code: http.StatusInternalServerError}
|
||||
NotAuthorized = Type{Code: http.StatusUnauthorized}
|
||||
NotFound = Type{Code: http.StatusNotFound}
|
||||
BadRequest = Type{Code: http.StatusBadRequest}
|
||||
Conflict = Type{Code: http.StatusConflict}
|
||||
)
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"net/url"
|
||||
"vppn/hub/api"
|
||||
|
||||
"git.crumpington.com/lib/go/webutil"
|
||||
"git.crumpington.com/lib/webutil"
|
||||
)
|
||||
|
||||
func (app *App) formGetNetwork(form url.Values) (*api.Network, error) {
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"vppn/hub/api"
|
||||
|
||||
"git.crumpington.com/lib/go/webutil"
|
||||
"vppn/hub/errs"
|
||||
)
|
||||
|
||||
type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) error
|
||||
@@ -13,32 +13,26 @@ type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) er
|
||||
func (app *App) handlePub(pattern string, fn handlerFunc) {
|
||||
wrapped := func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := app.getCookie(r, sessionIDCookieName)
|
||||
s, err := app.api.Session_Get(sessionID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get session: %v", err)
|
||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s := app.api.Session_Get(sessionID)
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 128*1024)
|
||||
r.ParseMultipartForm(64 * 1024)
|
||||
} else {
|
||||
r.ParseForm()
|
||||
}
|
||||
|
||||
if err := fn(&s, w, r); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
handleError(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
app.mux.HandleFunc(pattern,
|
||||
webutil.WithLogging(
|
||||
wrapped))
|
||||
app.mux.HandleFunc(pattern, withLogging(wrapped))
|
||||
}
|
||||
|
||||
func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
|
||||
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||
if s.SignedIn {
|
||||
if s.SessionID != "" {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return nil
|
||||
}
|
||||
@@ -48,7 +42,7 @@ func (app *App) handleNotSignedIn(pattern string, fn handlerFunc) {
|
||||
|
||||
func (app *App) handleSignedIn(pattern string, fn handlerFunc) {
|
||||
app.handlePub(pattern, func(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||
if !s.SignedIn {
|
||||
if s.SessionID == "" {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return nil
|
||||
}
|
||||
@@ -66,6 +60,7 @@ func (app *App) handlePeer(pattern string, fn peerHandlerFunc) {
|
||||
return
|
||||
}
|
||||
|
||||
// Not doing constant time compare because index lookup time dominates.
|
||||
peer, err := app.api.Peer_GetByAPIKey(apiKey)
|
||||
if err != nil {
|
||||
http.Error(w, "Not authorized", http.StatusUnauthorized)
|
||||
@@ -74,12 +69,19 @@ func (app *App) handlePeer(pattern string, fn peerHandlerFunc) {
|
||||
|
||||
r.ParseForm()
|
||||
if err := fn(peer, w, r); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
handleError(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
app.mux.HandleFunc(pattern,
|
||||
webutil.WithLogging(
|
||||
wrapped))
|
||||
app.mux.HandleFunc(pattern, withLogging(wrapped))
|
||||
}
|
||||
|
||||
func handleError(w http.ResponseWriter, err error) {
|
||||
var e *errs.Error
|
||||
if errors.As(err, &e) {
|
||||
http.Error(w, e.Msg, e.Code)
|
||||
} else {
|
||||
log.Printf("Unexpected error: %v", err)
|
||||
http.Error(w, "Internal server error.", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,19 +2,22 @@ package hub
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
"vppn/hub/api"
|
||||
"vppn/hub/errs"
|
||||
"vppn/m"
|
||||
|
||||
"git.crumpington.com/lib/go/webutil"
|
||||
"git.crumpington.com/lib/webutil"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||
if s.SignedIn {
|
||||
if s.SessionID != "" {
|
||||
return a.redirect(w, r, "/admin/network/list/")
|
||||
} else {
|
||||
return a.redirect(w, r, "/sign-in/")
|
||||
@@ -26,6 +29,15 @@ func (a *App) _signin(s *api.Session, w http.ResponseWriter, r *http.Request) er
|
||||
}
|
||||
|
||||
func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||
// Ignoring error here - if host is the empty string, it will contend for the
|
||||
// lock anyway.
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if !a.signInLock.TryLock(host) {
|
||||
time.Sleep(time.Second + time.Duration(rand.Int64N(int64(3*time.Second))))
|
||||
return errs.ErrNotAuthorized
|
||||
}
|
||||
defer a.signInLock.Unlock(host)
|
||||
|
||||
var pwd string
|
||||
err := webutil.NewFormScanner(r.Form).
|
||||
Scan("Password", &pwd).
|
||||
@@ -36,8 +48,10 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque
|
||||
|
||||
sess, err := a.api.Session_SignIn(pwd)
|
||||
if err != nil {
|
||||
time.Sleep(time.Second + time.Duration(rand.Int64N(int64(3*time.Second))))
|
||||
return err
|
||||
}
|
||||
|
||||
a.setCookie(w, sessionIDCookieName, sess.SessionID)
|
||||
|
||||
return a.redirect(w, r, "/")
|
||||
@@ -48,9 +62,7 @@ func (a *App) _adminSignOut(s *api.Session, w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||
if err := a.api.Session_Delete(s.SessionID); err != nil {
|
||||
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
|
||||
}
|
||||
a.api.Session_Delete(s.SessionID)
|
||||
a.deleteCookie(w, sessionIDCookieName)
|
||||
return a.redirect(w, r, "/")
|
||||
}
|
||||
@@ -236,22 +248,23 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
|
||||
return err
|
||||
}
|
||||
|
||||
if len(newPwd) < 8 {
|
||||
return errors.New("password is too short")
|
||||
if len(newPwd) < 8 || len(newPwd) > 72 {
|
||||
return errs.ErrInvalidPassword
|
||||
}
|
||||
|
||||
if newPwd != newPwd2 {
|
||||
return errors.New("passwords don't match")
|
||||
return errs.ErrPasswordMismatch
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword(conf.Password, []byte(curPwd))
|
||||
if err != nil {
|
||||
return err
|
||||
return errs.ErrNotAuthorized
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(newPwd), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Printf("Failed to hash password with bcrypt: %v", err)
|
||||
return errs.ErrUnexpected
|
||||
}
|
||||
|
||||
conf.Password = hash
|
||||
@@ -260,7 +273,10 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
|
||||
return err
|
||||
}
|
||||
|
||||
return a.redirect(w, r, "/admin/config/")
|
||||
*s = a.api.Session_InvalidateAll()
|
||||
a.setCookie(w, sessionIDCookieName, s.SessionID)
|
||||
|
||||
return a.redirect(w, r, "/admin/network/list/")
|
||||
}
|
||||
|
||||
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -269,9 +285,11 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
|
||||
return nil
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 2048)
|
||||
|
||||
args := m.PeerInitArgs{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
|
||||
return err
|
||||
return errs.BadRequest.WithMsg("Invalid request body.")
|
||||
}
|
||||
|
||||
if len(args.WGPubKey) != 32 {
|
||||
@@ -328,8 +346,10 @@ func (a *App) peersList(networkID int64) (peers []m.Peer, err error) {
|
||||
}
|
||||
wgKey, err := wgtypes.NewKey(p.WGPubKey)
|
||||
if err != nil {
|
||||
log.Printf("Bad WG key in DB for peer %d/%d", p.NetworkID, p.PeerIP)
|
||||
continue // malformed key; skip rather than serve garbage
|
||||
}
|
||||
|
||||
var signKey [32]byte
|
||||
copy(signKey[:], p.SignPubKey)
|
||||
peers = append(peers, m.Peer{
|
||||
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.crumpington.com/lib/go/webutil"
|
||||
"git.crumpington.com/lib/webutil"
|
||||
)
|
||||
|
||||
func Main() {
|
||||
@@ -32,6 +33,10 @@ func Main() {
|
||||
srv := &http.Server{
|
||||
Addr: conf.ListenAddr,
|
||||
Handler: app.Handler(),
|
||||
ReadHeaderTimeout: 30 * time.Second,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 120 * time.Second,
|
||||
IdleTimeout: 180 * time.Second,
|
||||
}
|
||||
|
||||
log.Fatal(webutil.ListenAndServe(srv))
|
||||
|
||||
47
hub/middleware.go
Normal file
47
hub/middleware.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _log = log.New(os.Stderr, "", 0)
|
||||
|
||||
type responseWriterWrapper struct {
|
||||
http.ResponseWriter
|
||||
httpStatus int
|
||||
responseSize int
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) WriteHeader(status int) {
|
||||
w.httpStatus = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||
if w.httpStatus == 0 {
|
||||
w.httpStatus = 200
|
||||
}
|
||||
w.responseSize += len(b)
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func withLogging(inner http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
t := time.Now()
|
||||
wrapper := responseWriterWrapper{w, 0, 0}
|
||||
|
||||
inner(&wrapper, r)
|
||||
_log.Printf("%s \"%s %s %s\" %d %d %v",
|
||||
r.RemoteAddr,
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
r.Proto,
|
||||
wrapper.httpStatus,
|
||||
wrapper.responseSize,
|
||||
time.Since(t),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package hub
|
||||
import "net/http"
|
||||
|
||||
func (a *App) registerRoutes() {
|
||||
a.mux.Handle("GET /static/", http.FileServerFS(staticFS))
|
||||
a.mux.Handle("GET /static/", withLogging(http.FileServerFS(staticFS).ServeHTTP))
|
||||
a.handlePub("GET /", a._root)
|
||||
|
||||
a.handleNotSignedIn("GET /sign-in/", a._signin)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
<header>
|
||||
<h1>VPPN</h1>
|
||||
<nav>
|
||||
{{if .Session.SignedIn -}}
|
||||
{{if .Session.SessionID -}}
|
||||
<a href="/admin/networks/list/">Home</a> /
|
||||
<a href="/admin/sign-out/">Sign out</a>
|
||||
{{- end}}
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
<header>
|
||||
<h1>VPPN</h1>
|
||||
<nav>
|
||||
{{if .Session.SignedIn -}}
|
||||
{{if .Session.SessionID -}}
|
||||
<a href="/admin/networks/list/">Home</a> /
|
||||
<a href="/admin/sign-out/">Sign out</a>
|
||||
{{- end}}
|
||||
|
||||
67
peer/app.go
67
peer/app.go
@@ -1,9 +1,13 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +24,7 @@ var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisf
|
||||
const (
|
||||
ControlPort = 4561
|
||||
PingInterval = 8 * time.Second
|
||||
TickInterval = 2 * time.Second
|
||||
TimeoutInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
@@ -40,7 +45,6 @@ type App struct {
|
||||
vpnNet netip.Prefix
|
||||
privKey wgtypes.Key
|
||||
pubKey wgtypes.Key
|
||||
isRelay bool
|
||||
isPublic bool
|
||||
localDomain string
|
||||
|
||||
@@ -75,13 +79,15 @@ func (a *App) Run() error {
|
||||
// while we were down).
|
||||
a.updateHosts()
|
||||
|
||||
ticker := time.NewTicker(PingInterval)
|
||||
defer ticker.Stop()
|
||||
stateTicker := time.NewTicker(TickInterval)
|
||||
pingTicker := time.NewTicker(PingInterval)
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
defer signal.Stop(sig)
|
||||
|
||||
tickCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case p := <-a.hubAddCh:
|
||||
@@ -92,8 +98,15 @@ func (a *App) Run() error {
|
||||
a.onPing(e)
|
||||
case e := <-a.multicastCh:
|
||||
a.onMulticastDiscovery(e)
|
||||
case <-ticker.C:
|
||||
a.onTick()
|
||||
case <-stateTicker.C:
|
||||
a.onStateTick()
|
||||
case <-pingTicker.C:
|
||||
a.onPingTick()
|
||||
tickCount++
|
||||
if tickCount%8 == 0 {
|
||||
a.logNetworkState()
|
||||
}
|
||||
|
||||
case <-sig:
|
||||
return a.onShutdown()
|
||||
}
|
||||
@@ -103,3 +116,47 @@ func (a *App) Run() error {
|
||||
func (a *App) onShutdown() error {
|
||||
return wginterface.Delete(a.dev.Name())
|
||||
}
|
||||
|
||||
func (a *App) logNetworkState() {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "Network state (self: %s public=%v):\n", a.vpnIP, a.isPublic)
|
||||
fmt.Fprintf(&b, " Network: %v\n", a.vpnNet)
|
||||
fmt.Fprintf(&b, " IPv4: %v\n", a.selfV4)
|
||||
fmt.Fprintf(&b, " IPv6: %v\n", a.selfV6)
|
||||
|
||||
if a.relay != nil {
|
||||
fmt.Fprintf(&b, " Relay: %s\n", a.relay.Name)
|
||||
} else {
|
||||
fmt.Fprint(&b, " Relay: -\n")
|
||||
}
|
||||
|
||||
b.WriteString("Peers:\n")
|
||||
//
|
||||
peers := make([]*Peer, 0, len(a.peersByIP))
|
||||
for _, p := range a.peersByIP {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
|
||||
sort.Slice(peers, func(i, j int) bool {
|
||||
return peers[i].VPNIP.As4()[3] < peers[j].VPNIP.As4()[3]
|
||||
})
|
||||
|
||||
for _, p := range peers {
|
||||
ip := p.VPNIP.As4()[3]
|
||||
up := "DOWN"
|
||||
if p.Up() {
|
||||
up = "UP "
|
||||
}
|
||||
|
||||
endpoint := p.WGEndpoint()
|
||||
if endpoint.IsValid() {
|
||||
fmt.Fprintf(&b, " %24s %03d %s %s seen=%s @ %s\n",
|
||||
p.Name, ip, p.State, up, time.Since(p.LastPing).Round(time.Millisecond), endpoint)
|
||||
} else {
|
||||
fmt.Fprintf(&b, " %24s %03d %s %s seen=%s\n",
|
||||
p.Name, ip, p.State, up, time.Since(p.LastPing).Round(time.Millisecond))
|
||||
}
|
||||
}
|
||||
|
||||
log.Print(b.String())
|
||||
}
|
||||
|
||||
@@ -26,13 +26,14 @@ func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer {
|
||||
})
|
||||
p := a.peersByKey[key]
|
||||
p.wgPeer.LastHandshakeTime = time.Now()
|
||||
p.LastPing = time.Now()
|
||||
return p
|
||||
}
|
||||
|
||||
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
|
||||
// vpnIP is the local VPN address (e.g. "10.0.0.1").
|
||||
// isPublic / isRelay describe the local node's role.
|
||||
func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) {
|
||||
func newTestApp(t *testing.T, vpnIP string, isPublic bool) (*App, *fakeWGDevice, *fakeControlConn) {
|
||||
t.Helper()
|
||||
privKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -47,7 +48,6 @@ func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fake
|
||||
privKey: privKey,
|
||||
pubKey: privKey.PublicKey(),
|
||||
isPublic: isPublic,
|
||||
isRelay: isRelay,
|
||||
dev: dev,
|
||||
controlConn: cc,
|
||||
peersByKey: make(map[wgtypes.Key]*Peer),
|
||||
|
||||
@@ -37,8 +37,12 @@ type Ping struct {
|
||||
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
|
||||
// field is written unconditionally so a reused buffer needs no pre-zeroing.
|
||||
func (p Ping) Marshal(buf []byte) []byte {
|
||||
_ = buf[Size-1] // Panic if buffer is too small.
|
||||
|
||||
buf[0] = version
|
||||
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
|
||||
|
||||
// SrcV4.
|
||||
if p.SrcV4.IsValid() {
|
||||
a4 := p.SrcV4.Addr().As4()
|
||||
copy(buf[9:13], a4[:])
|
||||
@@ -46,9 +50,13 @@ func (p Ping) Marshal(buf []byte) []byte {
|
||||
} else {
|
||||
clear(buf[9:15])
|
||||
}
|
||||
|
||||
// SrcV6.
|
||||
a16 := p.SrcV6.Addr().As16()
|
||||
copy(buf[15:31], a16[:])
|
||||
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
|
||||
|
||||
// Dst.
|
||||
a16 = p.Dst.Addr().As16()
|
||||
copy(buf[33:49], a16[:])
|
||||
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
|
||||
@@ -63,14 +71,21 @@ func Unmarshal(buf [Size]byte) (Ping, error) {
|
||||
p := Ping{
|
||||
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
|
||||
}
|
||||
if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() {
|
||||
|
||||
addr := netip.AddrFrom4([4]byte(buf[9:13]))
|
||||
if !addr.IsUnspecified() {
|
||||
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
|
||||
}
|
||||
if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() {
|
||||
|
||||
addr = netip.AddrFrom16([16]byte(buf[15:31])).Unmap()
|
||||
if !addr.IsUnspecified() {
|
||||
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
|
||||
}
|
||||
if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() {
|
||||
|
||||
addr = netip.AddrFrom16([16]byte(buf[33:49])).Unmap()
|
||||
if !addr.IsUnspecified() {
|
||||
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"vppn/peer/control"
|
||||
)
|
||||
@@ -32,11 +33,14 @@ func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []b
|
||||
// run reads incoming ping packets and forwards them to ch until ctx is done.
|
||||
// Call this in a goroutine before starting the App event loop.
|
||||
func (c *udpControlConn) run(ch chan<- PingEvent) {
|
||||
const errorTimeout = 8 * time.Second
|
||||
|
||||
var buf [control.Size]byte
|
||||
for {
|
||||
n, src, err := c.conn.ReadFromUDP(buf[:])
|
||||
if err != nil {
|
||||
log.Printf("control read: %v", err)
|
||||
time.Sleep(errorTimeout)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -38,22 +38,28 @@ func (a *App) devPeers() []wgtypes.Peer {
|
||||
return peers
|
||||
}
|
||||
|
||||
func (a *App) devAddPeer(p *Peer) {
|
||||
func (a *App) devAddRelayed(p *Peer) {
|
||||
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
|
||||
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
|
||||
|
||||
p.State = StateRelayed
|
||||
p.EndpointV4 = netip.AddrPort{}
|
||||
p.EndpointV6 = netip.AddrPort{}
|
||||
p.EndpointLAN = netip.AddrPort{}
|
||||
}
|
||||
|
||||
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
|
||||
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
||||
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
|
||||
|
||||
p.State = StateDirect
|
||||
}
|
||||
|
||||
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
|
||||
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
||||
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
|
||||
p.State = StateDirect // Dirrect connection. The app marks peer as relay.
|
||||
|
||||
p.State = StateDirect // Direct connection. The app marks peer as relay.
|
||||
}
|
||||
|
||||
func (a *App) devPromote(p *Peer) {
|
||||
@@ -61,19 +67,24 @@ func (a *App) devPromote(p *Peer) {
|
||||
if ep.IsValid() {
|
||||
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
|
||||
} else {
|
||||
log.Printf("PROMOTED: %s - %s (no IP)", p.Name, p.VPNIP.String())
|
||||
log.Printf("DIRECT: %s - %s (waiting for handshake)", p.Name, p.VPNIP.String())
|
||||
}
|
||||
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
|
||||
|
||||
p.State = StateDirect
|
||||
p.LastPing = time.Now() // Assume the peer is up after being promoted.
|
||||
}
|
||||
|
||||
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
|
||||
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
||||
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
|
||||
|
||||
p.State = StateProbing
|
||||
p.ProbeStart = time.Now()
|
||||
p.ProbeEndpoint = endpoint
|
||||
}
|
||||
|
||||
func (a *App) devRemove(p *Peer) {
|
||||
log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String())
|
||||
log.Printf("REMOVED: %s", p.PubKey())
|
||||
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"git.crumpington.com/lib/go/flock"
|
||||
"git.crumpington.com/lib/flock"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -62,13 +62,15 @@ func (hp *HubPoller) Run() {
|
||||
hp.apply(state)
|
||||
}
|
||||
|
||||
hp.poll()
|
||||
client := &http.Client{Timeout: 32 * time.Second}
|
||||
|
||||
hp.poll(client)
|
||||
for range time.Tick(hubPollInterval) {
|
||||
hp.poll()
|
||||
hp.poll(client)
|
||||
}
|
||||
}
|
||||
|
||||
func (hp *HubPoller) poll() {
|
||||
func (hp *HubPoller) poll(client *http.Client) {
|
||||
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
|
||||
if err != nil {
|
||||
log.Printf("[HubPoller] build request: %v", err)
|
||||
@@ -76,7 +78,6 @@ func (hp *HubPoller) poll() {
|
||||
}
|
||||
req.SetBasicAuth("", hp.apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 32 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[HubPoller] fetch: %v", err)
|
||||
@@ -89,7 +90,7 @@ func (hp *HubPoller) poll() {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 128*1024))
|
||||
if err != nil {
|
||||
log.Printf("[HubPoller] read body: %v", err)
|
||||
return
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/nacl/sign"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -93,7 +94,7 @@ func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error)
|
||||
req.SetBasicAuth("", apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := (&http.Client{Timeout: time.Minute}).Do(req)
|
||||
if err != nil {
|
||||
return LocalState{}, fmt.Errorf("hub init: %w", err)
|
||||
}
|
||||
|
||||
@@ -9,10 +9,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
||||
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
|
||||
4560))
|
||||
|
||||
func Broadcast(
|
||||
selfVPNIP netip.Addr,
|
||||
pubKey wgtypes.Key,
|
||||
@@ -20,15 +16,19 @@ func Broadcast(
|
||||
signKey *[64]byte,
|
||||
) {
|
||||
for {
|
||||
broadcastInner(selfVPNIP, pubKey, wgPort, signKey)
|
||||
broadcast(selfVPNIP, pubKey, wgPort, signKey)
|
||||
time.Sleep(errorTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
|
||||
func broadcast(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
|
||||
addr := multicastAddr(selfVPNIP)
|
||||
|
||||
log.Printf("[MC Broadcast] Sending on %v.", addr)
|
||||
|
||||
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
||||
if err != nil {
|
||||
log.Printf("[MCBroadcast] bind: %v", err)
|
||||
log.Printf("[MC Broadcast] bind: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@@ -44,18 +44,18 @@ func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, sig
|
||||
// dropped by receivers' freshness gate.
|
||||
send := func() error {
|
||||
packet.Timestamp = time.Now().Unix()
|
||||
payload := packet.Marshal(buf, signKey)
|
||||
payload := packet.marshal(buf, signKey)
|
||||
_, err := conn.WriteToUDP(payload, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := send(); err != nil {
|
||||
log.Printf("[MCBroadcast] write: %v", err)
|
||||
log.Printf("[MC Broadcast] write: %v", err)
|
||||
}
|
||||
|
||||
for range time.Tick(broadcastInterval) {
|
||||
if err := send(); err != nil {
|
||||
log.Printf("[MCBroadcast] write: %v", err)
|
||||
log.Printf("[MC Broadcast] write: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,8 +29,8 @@ type Packet struct {
|
||||
Signed []byte // Raw signed message for verification (incoming packet).
|
||||
}
|
||||
|
||||
// Marshal the packet into a buffer with prefixed signature.
|
||||
func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte {
|
||||
// marshal the packet into a buffer with prefixed signature.
|
||||
func (p Packet) marshal(buf []byte, signKey *[64]byte) []byte {
|
||||
buf[0] = p.PeerIP
|
||||
copy(buf[1:33], p.WGPubKey[:])
|
||||
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
|
||||
@@ -43,7 +43,7 @@ func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func Unmarshal(signed []byte) (p Packet) {
|
||||
func unmarshal(signed []byte) (p Packet) {
|
||||
buf := signed[signSize:]
|
||||
p.PeerIP = buf[0]
|
||||
copy(p.WGPubKey[:], buf[1:33])
|
||||
|
||||
@@ -21,12 +21,12 @@ func TestPacket(t *testing.T) {
|
||||
}
|
||||
|
||||
buf := make([]byte, BufferSize)
|
||||
signed := p.Marshal(buf, priv)
|
||||
signed := p.marshal(buf, priv)
|
||||
if len(signed) != SignedPacketSize {
|
||||
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
|
||||
}
|
||||
|
||||
got := Unmarshal(signed)
|
||||
got := unmarshal(signed)
|
||||
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
|
||||
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
|
||||
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)
|
||||
|
||||
@@ -7,27 +7,34 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"git.crumpington.com/lib/ratelimiter"
|
||||
)
|
||||
|
||||
func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) {
|
||||
func Receiver(selfVPNIP netip.Addr, ch chan<- Packet) {
|
||||
for {
|
||||
if err := receiver(vpnNet, selfVPNIP, ch); err != nil {
|
||||
log.Printf("[MCReader] %v", err)
|
||||
if err := receiver(selfVPNIP, ch); err != nil {
|
||||
log.Printf("[MC Receiver] %v", err)
|
||||
}
|
||||
time.Sleep(errorTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error {
|
||||
func receiver(selfVPNIP netip.Addr, ch chan<- Packet) error {
|
||||
limiters := map[netip.Addr]*ratelimiter.Limiter{}
|
||||
|
||||
selfIP := selfVPNIP.As4()[3]
|
||||
|
||||
addr := multicastAddr(selfVPNIP)
|
||||
|
||||
log.Printf("[MC Receiver] Listening on %v.", addr)
|
||||
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bind: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, BufferSize+1) // +1 to detect oversized packets
|
||||
buf := make([]byte, SignedPacketSize+1) // +1 to detect oversized packets
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
|
||||
@@ -43,19 +50,43 @@ func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error
|
||||
continue
|
||||
}
|
||||
|
||||
packet := Unmarshal(buf[:n])
|
||||
packet := unmarshal(buf[:n])
|
||||
|
||||
if packet.PeerIP == selfIP {
|
||||
continue
|
||||
}
|
||||
|
||||
// Slightly cheaper than limiting.
|
||||
age := time.Since(time.Unix(packet.Timestamp, 0))
|
||||
if age > maxPacketAge || age < -maxPacketAge {
|
||||
continue
|
||||
}
|
||||
|
||||
srcAddr := src.Addr().Unmap()
|
||||
lim, ok := limiters[srcAddr]
|
||||
if !ok {
|
||||
lim = ratelimiter.New(ratelimiter.Config{
|
||||
BurstLimit: 1,
|
||||
FillPeriod: broadcastInterval / 2,
|
||||
MaxWaitCount: 0,
|
||||
})
|
||||
limiters[srcAddr] = lim
|
||||
}
|
||||
|
||||
if err := lim.Limit(); err != nil {
|
||||
log.Printf("[MC Receiver] Rate limited packet from peer IP %d.", packet.PeerIP)
|
||||
continue
|
||||
}
|
||||
|
||||
packet.Signed = bytes.Clone(packet.Signed)
|
||||
packet.Src = src.Addr().Unmap()
|
||||
ch <- packet
|
||||
}
|
||||
}
|
||||
|
||||
func multicastAddr(vpnIP netip.Addr) *net.UDPAddr {
|
||||
b := vpnIP.As4()
|
||||
return net.UDPAddrFromAddrPort(
|
||||
netip.AddrPortFrom(
|
||||
netip.AddrFrom4([4]byte{239, b[0], b[1], b[2]}), 4560))
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func New(
|
||||
|
||||
if !state.IsPublic {
|
||||
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
|
||||
go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh)
|
||||
go multicast.Receiver(state.VPNIP, multicastCh)
|
||||
}
|
||||
|
||||
return &App{
|
||||
@@ -89,7 +89,6 @@ func New(
|
||||
vpnNet: state.VPNNet,
|
||||
privKey: state.PrivKey,
|
||||
pubKey: state.PrivKey.PublicKey(),
|
||||
isRelay: state.IsRelay,
|
||||
isPublic: state.IsPublic,
|
||||
localDomain: localDomain,
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ func (a *App) onAddPeer(p m.Peer) {
|
||||
// endpoint from the incoming handshake automatically.
|
||||
a.devPromote(peer)
|
||||
} else {
|
||||
a.devAddPeer(peer)
|
||||
a.devAddRelayed(peer)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -95,14 +95,6 @@ func (a *App) switchActiveRelay() {
|
||||
a.relay = best
|
||||
}
|
||||
|
||||
func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort {
|
||||
// We always prefer v4 since all peers can connect to IPv4 addresses.
|
||||
if v4.IsValid() {
|
||||
return v4
|
||||
}
|
||||
return v6
|
||||
}
|
||||
|
||||
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
|
||||
if !selfIsPublic && peerIsPublic {
|
||||
return control.Client
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestOnAddPeer(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
||||
key := mustKey(t)
|
||||
if tc.setup != nil {
|
||||
tc.setup(a, key)
|
||||
@@ -192,7 +192,7 @@ func TestOnRemovePeer(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
||||
key := tc.setup(t, a)
|
||||
dev.Calls = nil
|
||||
a.onRemovePeer(key)
|
||||
@@ -268,7 +268,7 @@ func TestSwitchActiveRelay(t *testing.T) {
|
||||
name: "stale relay demoted to direct before backup elected",
|
||||
setup: func(t *testing.T, a *App) {
|
||||
old := addRelayPeer(t, a, "10.0.0.10", ep1)
|
||||
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
|
||||
old.LastPing = time.Time{} // stale — Up() checks LastPing; triggers switch — triggers switch from onTick
|
||||
a.relay = old
|
||||
addRelayPeer(t, a, "10.0.0.11", ep2)
|
||||
},
|
||||
@@ -289,7 +289,7 @@ func TestSwitchActiveRelay(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
||||
a, dev, _ := newTestApp(t, "10.0.0.1", false)
|
||||
tc.setup(t, a)
|
||||
dev.Calls = nil
|
||||
a.switchActiveRelay()
|
||||
|
||||
@@ -9,17 +9,13 @@ import (
|
||||
)
|
||||
|
||||
func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
|
||||
if a.isPublic {
|
||||
return
|
||||
}
|
||||
|
||||
// Locate the sender peer by its VPN IP (final octet carried in the beacon).
|
||||
octets := a.vpnNet.Addr().As4()
|
||||
octets[3] = pkt.PeerIP
|
||||
vpnIP := netip.AddrFrom4(octets)
|
||||
|
||||
peer, ok := a.peersByIP[vpnIP]
|
||||
if !ok || peer.IsPublic || peer.State == StateDirect {
|
||||
if !ok || peer.IsPublic {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,16 +32,9 @@ func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
|
||||
}
|
||||
|
||||
endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort)
|
||||
if !endpoint.IsValid() {
|
||||
if !endpoint.IsValid() || endpoint.Port() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var v4, v6 netip.AddrPort
|
||||
if pkt.Src.Is4() {
|
||||
v4 = endpoint
|
||||
} else {
|
||||
v6 = endpoint
|
||||
}
|
||||
|
||||
a.addProbe(peer, v4, v6)
|
||||
peer.EndpointLAN = endpoint
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +17,8 @@ func (a *App) onPing(e PingEvent) {
|
||||
|
||||
now := time.Now()
|
||||
|
||||
peer.LastPing = now
|
||||
|
||||
// If we're the server, respond - this is always necessary as it's used to
|
||||
// know if peers are up or down.
|
||||
if peer.Role == control.Server {
|
||||
@@ -32,27 +35,37 @@ func (a *App) onPing(e PingEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
// We can only learn our own endpoint from directly-connected peers — Dst
|
||||
// is the sender's observation of our WG handshake source.
|
||||
// We can only learn our own endpoint from directly-connected peers — Dst is
|
||||
// the sender's observation of our WG handshake source.
|
||||
//
|
||||
// We make sure we don't set a private address as our public address since we
|
||||
// may be connected via LAN to some peers.
|
||||
if peer.State == StateDirect {
|
||||
if dst := e.ping.Dst; dst.IsValid() {
|
||||
if dst := e.ping.Dst; addrIsRoutable(dst) {
|
||||
if dst.Addr().Is4() {
|
||||
if dst != a.selfV4 {
|
||||
log.Printf("Local IPv4 updated: %s -> %s", a.selfV4, dst)
|
||||
a.selfV4 = dst
|
||||
}
|
||||
} else {
|
||||
if dst != a.selfV6 {
|
||||
log.Printf("Local IPv6 updated: %s -> %s", a.selfV6, dst)
|
||||
a.selfV6 = dst
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
a.addProbe(peer, e.ping.SrcV4, e.ping.SrcV6)
|
||||
peer.UpdateEndpoints(e.ping.SrcV4, e.ping.SrcV6)
|
||||
}
|
||||
|
||||
func (a *App) addProbe(peer *Peer, v4, v6 netip.AddrPort) {
|
||||
endpoint := preferredEndpoint(v4, v6)
|
||||
if !endpoint.IsValid() || endpoint == peer.PreferredEndpoint() {
|
||||
return
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
func addrIsRoutable(addrPort netip.AddrPort) bool {
|
||||
if addrPort.Port() == 0 {
|
||||
return false
|
||||
}
|
||||
peer.UpdateEndpoints(v4, v6)
|
||||
a.devAddProbe(peer, endpoint)
|
||||
addr := addrPort.Addr()
|
||||
return addr.IsGlobalUnicast() && !addr.IsPrivate() && !cgnatPrefix.Contains(addr)
|
||||
}
|
||||
|
||||
15
peer/on_pingtick.go
Normal file
15
peer/on_pingtick.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"time"
|
||||
"vppn/peer/control"
|
||||
)
|
||||
|
||||
func (a *App) onPingTick() {
|
||||
now := time.Now().UnixNano()
|
||||
for _, p := range a.peersByIP {
|
||||
if p.Role == control.Client {
|
||||
a.sendPing(p, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
68
peer/on_statetick.go
Normal file
68
peer/on_statetick.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"vppn/peer/wginterface"
|
||||
)
|
||||
|
||||
func (a *App) onStateTick() {
|
||||
wgPeers := a.devPeers()
|
||||
|
||||
for _, wgPeer := range wgPeers {
|
||||
p, ok := a.peersByKey[wgPeer.PublicKey]
|
||||
if !ok {
|
||||
log.Printf("Wireguard peer not known. Removing: %v", wgPeer.PublicKey)
|
||||
a.devRemove(&Peer{wgPeer: wgPeer})
|
||||
continue
|
||||
}
|
||||
|
||||
p.wgPeer = wgPeer
|
||||
|
||||
// Log endpoint changes.
|
||||
if ep := p.WGEndpoint(); ep != p.EndpointWG {
|
||||
log.Printf("Client %s %s endpoint: %s -> %s", p.Name, p.VPNIP, p.EndpointWG, ep)
|
||||
p.EndpointWG = ep
|
||||
}
|
||||
|
||||
switch p.State {
|
||||
case StateRelayed:
|
||||
if p.DirectAlive() {
|
||||
// We may already have a valid direct endpoint due to wireguard
|
||||
// roaming.
|
||||
a.devPromote(p)
|
||||
} else if ep := p.PreferredEndpoint(); ep.IsValid() {
|
||||
// If we have an ep to probe, add it.
|
||||
a.devAddProbe(p, ep)
|
||||
}
|
||||
|
||||
case StateProbing:
|
||||
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
|
||||
// Promote probing peers to direct once alive (direct path confirmed
|
||||
// working).
|
||||
a.devPromote(p)
|
||||
} else if ep := p.PreferredEndpoint(); ep.IsValid() && ep != p.ProbeEndpoint {
|
||||
// Re-start probing if we see a new endpoint.
|
||||
a.devAddProbe(p, ep)
|
||||
} else if time.Since(p.ProbeStart) > 8*wginterface.ProbeKeepalive {
|
||||
// Give up probing if we haven't been able to handshake.
|
||||
a.devAddRelayed(p)
|
||||
}
|
||||
|
||||
case StateDirect:
|
||||
if p.IsPublic || a.isPublic || p.Up() {
|
||||
break
|
||||
}
|
||||
|
||||
// Stale non-public direct peer: demote to relayed and wait for new IP
|
||||
// information.
|
||||
a.devAddRelayed(p)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have a live relay (if we're not public).
|
||||
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
|
||||
a.switchActiveRelay()
|
||||
}
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"vppn/peer/control"
|
||||
"vppn/peer/wginterface"
|
||||
)
|
||||
|
||||
func (a *App) onTick() {
|
||||
wgPeers := a.devPeers()
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
for _, wgPeer := range wgPeers {
|
||||
p, ok := a.peersByKey[wgPeer.PublicKey]
|
||||
if !ok {
|
||||
log.Printf("Wireguard peer not in index, removing: %v", wgPeer)
|
||||
a.devRemove(&Peer{wgPeer: wgPeer})
|
||||
continue
|
||||
}
|
||||
p.wgPeer = wgPeer
|
||||
|
||||
// Send pings to peers where we're the client.
|
||||
if p.Role == control.Client {
|
||||
a.sendPing(p, now)
|
||||
}
|
||||
|
||||
switch p.State {
|
||||
case StateProbing:
|
||||
// Promote probing peers to direct once alive (direct path confirmed
|
||||
// working).
|
||||
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
|
||||
a.devPromote(p)
|
||||
}
|
||||
|
||||
case StateDirect:
|
||||
if p.IsPublic || a.isPublic || p.Up() {
|
||||
break
|
||||
}
|
||||
// Stale non-public direct peer: demote to probing so WireGuard
|
||||
// resumes handshake attempts on the direct path.
|
||||
a.devAddProbe(p, p.WGEndpoint())
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have a live relay (if we're not public).
|
||||
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
|
||||
a.switchActiveRelay()
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
type PeerState string
|
||||
|
||||
const (
|
||||
StateRelayed = PeerState("RELAY")
|
||||
StateProbing = PeerState("PROBE")
|
||||
StateRelayed = PeerState("RELAY ")
|
||||
StateProbing = PeerState("PROBE ")
|
||||
StateDirect = PeerState("DIRECT")
|
||||
)
|
||||
|
||||
@@ -26,9 +26,14 @@ type Peer struct {
|
||||
IsPublic bool // Peer has a public IP.
|
||||
EndpointV4 netip.AddrPort // Reported IPv4 endpoint.
|
||||
EndpointV6 netip.AddrPort // Reported IPv6 endpoint.
|
||||
EndpointLAN netip.AddrPort // Discovered via multicast.
|
||||
EndpointWG netip.AddrPort // Current wireguard endpoint.
|
||||
RTT time.Duration // Round-trip time.
|
||||
LastPing time.Time // Last time we had a ping.
|
||||
ProbeStart time.Time // When we started probing.
|
||||
ProbeEndpoint netip.AddrPort
|
||||
State PeerState // Current routing state; updated on each devXxx call.
|
||||
Role control.Role // Client initiates pings; server responds.
|
||||
Role control.Role // Role in relation to the local application.
|
||||
SignPubKey [32]byte // nacl/sign public key for verifying multicast beacons.
|
||||
}
|
||||
|
||||
@@ -54,7 +59,12 @@ func (p *Peer) LastHandshakeTime() time.Time {
|
||||
}
|
||||
|
||||
func (p *Peer) Up() bool {
|
||||
return time.Since(p.wgPeer.LastHandshakeTime) < wginterface.SessionTimeout
|
||||
return time.Since(p.LastPing) < 3*PingInterval
|
||||
}
|
||||
|
||||
func (p *Peer) DirectAlive() bool {
|
||||
return p.WGEndpoint().IsValid() &&
|
||||
time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive
|
||||
}
|
||||
|
||||
func (p *Peer) CanRelay() bool {
|
||||
@@ -62,7 +72,13 @@ func (p *Peer) CanRelay() bool {
|
||||
}
|
||||
|
||||
func (p *Peer) PreferredEndpoint() netip.AddrPort {
|
||||
return preferredEndpoint(p.EndpointV4, p.EndpointV6)
|
||||
if p.EndpointLAN.IsValid() {
|
||||
return p.EndpointLAN
|
||||
} else if p.EndpointV4.IsValid() {
|
||||
return p.EndpointV4
|
||||
} else {
|
||||
return p.EndpointV6
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) UpdateEndpoints(v4, v6 netip.AddrPort) {
|
||||
|
||||
@@ -11,6 +11,7 @@ package wginterface
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
@@ -173,6 +174,10 @@ func nlAttr(attrType uint16, data []byte) []byte {
|
||||
// messages, but the AF_INET ioctl interface is simpler.
|
||||
|
||||
func ioctlSetAddr(name string, ip net.IP, prefixLen int) error {
|
||||
if ip.To4() == nil {
|
||||
return errors.New("attempted to set non-IPv4 address on interface")
|
||||
}
|
||||
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -112,7 +112,7 @@ func (d *Device) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network n
|
||||
})
|
||||
}
|
||||
|
||||
// AddProbe adds a peer with no AllowedIPs and a 5s keepalive. WireGuard will
|
||||
// AddProbe adds a peer with no AllowedIPs and an 8s keepalive. WireGuard will
|
||||
// attempt handshakes without routing any traffic through this peer yet.
|
||||
func (d *Device) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
|
||||
keepalive := ProbeKeepalive
|
||||
|
||||
Reference in New Issue
Block a user