Compare commits
14 Commits
v0.11.0
...
client-int
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ebfe754e7 | ||
|
|
069243e5d4 | ||
|
|
bd78ffd669 | ||
|
|
a90ab3f5d6 | ||
|
|
650c74c013 | ||
|
|
b308150d21 | ||
|
|
a0b7ecbfe0 | ||
|
|
69dff24344 | ||
|
|
257fac67ce | ||
|
|
2ff8aaf5c4 | ||
|
|
fccc4f7d57 | ||
|
|
c6d35856bc | ||
|
|
5844584219 | ||
|
|
e458e43d83 |
10
README.md
10
README.md
@@ -1,5 +1,9 @@
|
|||||||
# vppn: Virtual Potentially Private Network
|
# vppn: Virtual Potentially Private Network
|
||||||
|
|
||||||
|
## TO DO
|
||||||
|
|
||||||
|
* Double buffering in IFReader and ConnReader ?
|
||||||
|
|
||||||
## Hub Server Configuration
|
## Hub Server Configuration
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -53,17 +57,15 @@ Sign-in and configure.
|
|||||||
|
|
||||||
Install the binary somewhere, for example `~/bin/vppn`.
|
Install the binary somewhere, for example `~/bin/vppn`.
|
||||||
|
|
||||||
Add the API key for your network name in `~/.vppn/<netname>/apikey`.
|
|
||||||
|
|
||||||
Create systemd file in `/etc/systemd/system/vppn.service`.
|
Create systemd file in `/etc/systemd/system/vppn.service`.
|
||||||
|
|
||||||
```
|
```
|
||||||
[Service]
|
[Service]
|
||||||
AmbientCapabilities=AP_NET_ADMIN CAP_DAC_OVERRIDE CAP_CHOWN
|
AmbientCapabilities=CAP_NET_BIND_SERVICE CAP_NET_ADMIN
|
||||||
Type=simple
|
Type=simple
|
||||||
User=user
|
User=user
|
||||||
WorkingDirectory=/home/user/
|
WorkingDirectory=/home/user/
|
||||||
ExecStart=/home/user/bin/vppn -name my_net_name -hub https://my.hub
|
ExecStart=/home/user/vppn run my_net_name https://my.hub my_api_key
|
||||||
Restart=always
|
Restart=always
|
||||||
RestartSec=8
|
RestartSec=8
|
||||||
TimeoutStopSec=24
|
TimeoutStopSec=24
|
||||||
|
|||||||
@@ -1,72 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"vppn/peer"
|
"vppn/peer"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/flock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
log.SetFlags(0)
|
log.SetFlags(0)
|
||||||
|
peer.Main2()
|
||||||
name := flag.String("name", "", "network name (required)")
|
|
||||||
hub := flag.String("hub", "", "hub base URL (required)")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if *name == "" || *hub == "" {
|
|
||||||
flag.Usage()
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey, err := loadAPIKey(*name)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("api key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Directory existence is guaranteed by the apikey file read above.
|
|
||||||
lockFile, err := flock.TryLock(vppnPath(*name, "lock"))
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("lock: %v", err)
|
|
||||||
}
|
|
||||||
if lockFile == nil {
|
|
||||||
log.Fatalf("already running for network %q", *name)
|
|
||||||
}
|
|
||||||
defer flock.Unlock(lockFile)
|
|
||||||
|
|
||||||
state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("init: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceName := strings.TrimSuffix(state.LocalDomain, ".local")
|
|
||||||
app, err := peer.New(state, *hub, apiKey, ifaceName, state.LocalDomain, vppnPath(*name, "network.json"))
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("start: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := app.Run(); err != nil {
|
|
||||||
log.Fatalf("run: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadAPIKey(name string) (string, error) {
|
|
||||||
data, err := os.ReadFile(vppnPath(name, "apikey"))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(string(data)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func vppnPath(name, file string) string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return filepath.Join(".vppn", name, file)
|
|
||||||
}
|
|
||||||
return filepath.Join(home, ".vppn", name, file)
|
|
||||||
}
|
}
|
||||||
|
|||||||
20
go.mod
20
go.mod
@@ -3,21 +3,13 @@ module vppn
|
|||||||
go 1.25.1
|
go 1.25.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
git.crumpington.com/lib/go v0.10.0
|
git.crumpington.com/lib/go v0.9.1
|
||||||
golang.org/x/crypto v0.53.0
|
golang.org/x/crypto v0.42.0
|
||||||
golang.org/x/sys v0.46.0
|
golang.org/x/sys v0.36.0
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||||
github.com/josharian/native v1.1.0 // indirect
|
golang.org/x/net v0.44.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.45 // indirect
|
golang.org/x/text v0.29.0 // indirect
|
||||||
github.com/mdlayher/genetlink v1.4.0 // indirect
|
|
||||||
github.com/mdlayher/netlink v1.11.2 // indirect
|
|
||||||
github.com/mdlayher/socket v0.6.1 // indirect
|
|
||||||
golang.org/x/net v0.56.0 // indirect
|
|
||||||
golang.org/x/sync v0.21.0 // indirect
|
|
||||||
golang.org/x/text v0.38.0 // indirect
|
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
44
go.sum
44
go.sum
@@ -1,56 +1,12 @@
|
|||||||
git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
|
git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4=
|
||||||
git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
|
git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
|
||||||
git.crumpington.com/lib/go v0.9.2 h1:DZ7tzFM/S+zL5hexNo8zKbH7Ryi+VtvSMRzCMnlz+c4=
|
|
||||||
git.crumpington.com/lib/go v0.9.2/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg=
|
|
||||||
git.crumpington.com/lib/go v0.10.0 h1:4O+o9QBVcre8RYESAXhxJ1kT0w1tIakUdt/rV4v4riw=
|
|
||||||
git.crumpington.com/lib/go v0.10.0/go.mod h1:8y838PnV7dM6QT0XwLMuG2ulDNtCv4NmdSJIEqGViKg=
|
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
|
||||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
|
||||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
|
|
||||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
|
||||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
|
||||||
github.com/mdlayher/genetlink v1.4.0 h1:f/Xs7Y2T+GyX9b3dbiUhnLE9InGs5F9RxJ2JwBMl71o=
|
|
||||||
github.com/mdlayher/genetlink v1.4.0/go.mod h1:d1hrKr8fwZU2JkcAtQUAzeTrI7nbgQSl+5k1cC0biSA=
|
|
||||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
|
||||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
|
||||||
github.com/mdlayher/netlink v1.11.2 h1:HKh2jqe+omdSWcQ88nrT7INE61B0NXfiSPFdgL4YbNI=
|
|
||||||
github.com/mdlayher/netlink v1.11.2/go.mod h1:uT2Yc/QLaZubzDpZIBi9d4GoeLwtp3x1AMeqSRrK2sA=
|
|
||||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
|
||||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
|
||||||
github.com/mdlayher/socket v0.6.1 h1:M7uj2NtuujUY4mYr1C57NmfNiRHbkKpnBxO856lsc3A=
|
|
||||||
github.com/mdlayher/socket v0.6.1/go.mod h1:+/SGtqc9V+5dAuRgQsU0fGBI+oRDiW7O2Obx10OIWfg=
|
|
||||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
|
||||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
|
||||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||||
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
|
||||||
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
|
||||||
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
|
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
|
||||||
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
|
||||||
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
|
||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
|
||||||
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
|
||||||
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
|
||||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
|
||||||
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
|
||||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||||
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
|
||||||
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446 h1:cqHQ3AycTHvM2R7ikgyX57D+XvtcSnGylsLkOVhta/w=
|
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20260522210424-ecfc5a8d5446/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
|
||||||
|
|||||||
147
hub/api/api.go
147
hub/api/api.go
@@ -19,10 +19,8 @@ import (
|
|||||||
var migrations embed.FS
|
var migrations embed.FS
|
||||||
|
|
||||||
type API struct {
|
type API struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
sessionsMu sync.Mutex
|
|
||||||
sessions map[string]*Session
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(dbPath string) (*API, error) {
|
func New(dbPath string) (*API, error) {
|
||||||
@@ -36,17 +34,10 @@ func New(dbPath string) (*API, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
a := &API{
|
a := &API{
|
||||||
db: sqlDB,
|
db: sqlDB,
|
||||||
sessions: make(map[string]*Session),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.ensurePassword(); err != nil {
|
return a, a.ensurePassword()
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go a.sweepSessions()
|
|
||||||
|
|
||||||
return a, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) ensurePassword() error {
|
func (a *API) ensurePassword() error {
|
||||||
@@ -71,8 +62,12 @@ func (a *API) ensurePassword() error {
|
|||||||
return db.Config_Insert(a.db, conf)
|
return db.Config_Insert(a.db, conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Config_Get() (*Config, error) {
|
func (a *API) Config_Get() *Config {
|
||||||
return db.Config_Get(a.db, 1)
|
conf, err := db.Config_Get(a.db, 1)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return conf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Config_Update(conf *Config) error {
|
func (a *API) Config_Update(conf *Config) error {
|
||||||
@@ -80,78 +75,56 @@ func (a *API) Config_Update(conf *Config) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Session_Delete(sessionID string) error {
|
func (a *API) Session_Delete(sessionID string) error {
|
||||||
a.sessionsMu.Lock()
|
return db.Session_Delete(a.db, sessionID)
|
||||||
defer a.sessionsMu.Unlock()
|
|
||||||
delete(a.sessions, sessionID)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
func (a *API) Session_Get(sessionID string) (*Session, error) {
|
||||||
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use
|
if sessionID == "" {
|
||||||
sessionSweepEvery = time.Hour // cadence of expired-session eviction
|
return a.session_CreatePub()
|
||||||
)
|
|
||||||
|
|
||||||
// Session_Get returns a snapshot copy of the signed-in session for sessionID,
|
|
||||||
// or the zero Session if the cookie is missing/unknown/expired. It never
|
|
||||||
// creates a session, so anonymous requests cost no memory — a session is minted
|
|
||||||
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
|
|
||||||
// callers from racing on the shared struct.
|
|
||||||
func (a *API) Session_Get(sessionID string) (Session, error) {
|
|
||||||
a.sessionsMu.Lock()
|
|
||||||
defer a.sessionsMu.Unlock()
|
|
||||||
|
|
||||||
s, ok := a.sessions[sessionID]
|
|
||||||
|
|
||||||
if sessionID == "" || !ok {
|
|
||||||
return Session{}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if timeSince(s.LastSeenAt) > sessionTTLSecs {
|
session, err := db.Session_Get(a.db, sessionID)
|
||||||
delete(a.sessions, sessionID)
|
|
||||||
return Session{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s.LastSeenAt = time.Now().Unix()
|
|
||||||
return *s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
|
|
||||||
// returning it so the caller can set the cookie. A new ID per sign-in rotates
|
|
||||||
// the session at the privilege boundary (session-fixation resistance).
|
|
||||||
func (a *API) Session_SignIn(pwd string) (Session, error) {
|
|
||||||
conf, err := a.Config_Get()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Session{}, err
|
return a.session_CreatePub()
|
||||||
}
|
|
||||||
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
|
|
||||||
return Session{}, ErrNotAuthorized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
a.sessionsMu.Lock()
|
if timeSince(session.LastSeenAt) > 86400*21 {
|
||||||
defer a.sessionsMu.Unlock()
|
return a.session_CreatePub()
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeSince(session.LastSeenAt) > 86400*7 {
|
||||||
|
session.LastSeenAt = time.Now().Unix()
|
||||||
|
if err := db.Session_UpdateLastSeenAt(a.db, session.SessionID); err != nil {
|
||||||
|
log.Printf("Failed to update session: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) session_CreatePub() (*Session, error) {
|
||||||
s := &Session{
|
s := &Session{
|
||||||
SessionID: idgen.NewToken(),
|
SessionID: idgen.NewToken(),
|
||||||
SignedIn: true,
|
CSRF: idgen.NewToken(),
|
||||||
|
SignedIn: false,
|
||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
LastSeenAt: time.Now().Unix(),
|
LastSeenAt: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
a.sessions[s.SessionID] = s
|
err := db.Session_Insert(a.db, s)
|
||||||
return *s, nil
|
return s, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// sweepSessions periodically evicts sessions past their TTL. Without it, a
|
func (a *API) Session_DeleteBefore(timestamp int64) error {
|
||||||
// signed-in session whose ID is never presented again would linger forever
|
return db.Session_DeleteBefore(a.db, timestamp)
|
||||||
// (Session_Get only evicts on a lookup of that same ID).
|
}
|
||||||
func (a *API) sweepSessions() {
|
|
||||||
for range time.Tick(sessionSweepEvery) {
|
func (a *API) Session_SignIn(s *Session, pwd string) error {
|
||||||
a.sessionsMu.Lock()
|
conf := a.Config_Get()
|
||||||
for id, s := range a.sessions {
|
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
|
||||||
if timeSince(s.LastSeenAt) > sessionTTLSecs {
|
return ErrNotAuthorized
|
||||||
delete(a.sessions, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
a.sessionsMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return db.Session_SetSignedIn(a.db, s.SessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Network_Create(n *Network) error {
|
func (a *API) Network_Create(n *Network) error {
|
||||||
@@ -168,13 +141,14 @@ func (a *API) Network_Get(id int64) (*Network, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Network_List() ([]*Network, error) {
|
func (a *API) Network_List() ([]*Network, error) {
|
||||||
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
|
const query = db.Network_SelectQuery + ` ORDER BY Name ASC`
|
||||||
return db.Network_List(a.db, query)
|
return db.Network_List(a.db, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) Peer_CreateNew(p *Peer) error {
|
func (a *API) Peer_CreateNew(p *Peer) error {
|
||||||
p.WGPubKey = []byte{}
|
p.Version = idgen.NextID(0)
|
||||||
p.SignPubKey = []byte{}
|
p.PubKey = []byte{}
|
||||||
|
p.PubSignKey = []byte{}
|
||||||
p.APIKey = idgen.NewToken()
|
p.APIKey = idgen.NewToken()
|
||||||
|
|
||||||
return db.Peer_Insert(a.db, p)
|
return db.Peer_Insert(a.db, p)
|
||||||
@@ -184,22 +158,21 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
|||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
defer a.lock.Unlock()
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
// Re-read from DB inside the lock — the caller's copy was fetched before
|
peer.Version = idgen.NextID(0)
|
||||||
// we held the lock, so it may be stale under concurrent requests.
|
peer.PubKey = args.EncPubKey
|
||||||
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
|
peer.PubSignKey = args.PubSignKey
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(current.WGPubKey) != 0 {
|
|
||||||
return errors.New("peer already initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.WGPubKey = args.WGPubKey
|
|
||||||
peer.SignPubKey = args.SignPubKey
|
|
||||||
|
|
||||||
return db.Peer_UpdateFull(a.db, peer)
|
return db.Peer_UpdateFull(a.db, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *API) Peer_Update(p *Peer) error {
|
||||||
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
|
p.Version = idgen.NextID(0)
|
||||||
|
return db.Peer_Update(a.db, p)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
|
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
|
||||||
return db.Peer_Delete(a.db, networkID, peerIP)
|
return db.Peer_Delete(a.db, networkID, peerIP)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -123,9 +123,7 @@ func Config_Get(
|
|||||||
) {
|
) {
|
||||||
row = &Config{}
|
row = &Config{}
|
||||||
r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID)
|
r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID)
|
||||||
if err = r.Scan(&row.ConfigID, &row.Password); err != nil {
|
err = r.Scan(&row.ConfigID, &row.Password)
|
||||||
row = nil
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,9 +137,7 @@ func Config_GetWhere(
|
|||||||
) {
|
) {
|
||||||
row = &Config{}
|
row = &Config{}
|
||||||
r := tx.QueryRow(query, args...)
|
r := tx.QueryRow(query, args...)
|
||||||
if err = r.Scan(&row.ConfigID, &row.Password); err != nil {
|
err = r.Scan(&row.ConfigID, &row.Password)
|
||||||
row = nil
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,17 +182,135 @@ func Config_List(
|
|||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Table: sessions
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
SessionID string
|
||||||
|
CSRF string
|
||||||
|
SignedIn bool
|
||||||
|
CreatedAt int64
|
||||||
|
LastSeenAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
const Session_SelectQuery = "SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions"
|
||||||
|
|
||||||
|
func Session_Insert(
|
||||||
|
tx TX,
|
||||||
|
row *Session,
|
||||||
|
) (err error) {
|
||||||
|
Session_Sanitize(row)
|
||||||
|
if err = Session_Validate(row); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("INSERT INTO sessions(SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt) VALUES(?,?,?,?,?)", row.SessionID, row.CSRF, row.SignedIn, row.CreatedAt, row.LastSeenAt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_Delete(
|
||||||
|
tx TX,
|
||||||
|
SessionID string,
|
||||||
|
) (err error) {
|
||||||
|
result, err := tx.Exec("DELETE FROM sessions WHERE SessionID=?", SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
switch n {
|
||||||
|
case 0:
|
||||||
|
return sql.ErrNoRows
|
||||||
|
case 1:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
panic("multiple rows deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_Get(
|
||||||
|
tx TX,
|
||||||
|
SessionID string,
|
||||||
|
) (
|
||||||
|
row *Session,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
row = &Session{}
|
||||||
|
r := tx.QueryRow("SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions WHERE SessionID=?", SessionID)
|
||||||
|
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_GetWhere(
|
||||||
|
tx TX,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (
|
||||||
|
row *Session,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
row = &Session{}
|
||||||
|
r := tx.QueryRow(query, args...)
|
||||||
|
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_Iterate(
|
||||||
|
tx TX,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) iter.Seq2[*Session, error] {
|
||||||
|
rows, err := tx.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return func(yield func(*Session, error) bool) {
|
||||||
|
yield(nil, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(yield func(*Session, error) bool) {
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
row := &Session{}
|
||||||
|
err := rows.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
|
||||||
|
if !yield(row, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_List(
|
||||||
|
tx TX,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (
|
||||||
|
l []*Session,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
for row, err := range Session_Iterate(tx, query, args...) {
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l = append(l, row)
|
||||||
|
}
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Table: networks
|
// Table: networks
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
type Network struct {
|
type Network struct {
|
||||||
NetworkID int64
|
NetworkID int64
|
||||||
LocalDomain string
|
Name string
|
||||||
Network []byte
|
Network []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
const Network_SelectQuery = "SELECT NetworkID,LocalDomain,Network FROM networks"
|
const Network_SelectQuery = "SELECT NetworkID,Name,Network FROM networks"
|
||||||
|
|
||||||
func Network_Insert(
|
func Network_Insert(
|
||||||
tx TX,
|
tx TX,
|
||||||
@@ -207,7 +321,7 @@ func Network_Insert(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("INSERT INTO networks(NetworkID,LocalDomain,Network) VALUES(?,?,?)", row.NetworkID, row.LocalDomain, row.Network)
|
_, err = tx.Exec("INSERT INTO networks(NetworkID,Name,Network) VALUES(?,?,?)", row.NetworkID, row.Name, row.Network)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +334,7 @@ func Network_UpdateFull(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := tx.Exec("UPDATE networks SET LocalDomain=?,Network=? WHERE NetworkID=?", row.LocalDomain, row.Network, row.NetworkID)
|
result, err := tx.Exec("UPDATE networks SET Name=?,Network=? WHERE NetworkID=?", row.Name, row.Network, row.NetworkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -270,10 +384,8 @@ func Network_Get(
|
|||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
row = &Network{}
|
row = &Network{}
|
||||||
r := tx.QueryRow("SELECT NetworkID,LocalDomain,Network FROM networks WHERE NetworkID=?", NetworkID)
|
r := tx.QueryRow("SELECT NetworkID,Name,Network FROM networks WHERE NetworkID=?", NetworkID)
|
||||||
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil {
|
err = r.Scan(&row.NetworkID, &row.Name, &row.Network)
|
||||||
row = nil
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,9 +399,7 @@ func Network_GetWhere(
|
|||||||
) {
|
) {
|
||||||
row = &Network{}
|
row = &Network{}
|
||||||
r := tx.QueryRow(query, args...)
|
r := tx.QueryRow(query, args...)
|
||||||
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil {
|
err = r.Scan(&row.NetworkID, &row.Name, &row.Network)
|
||||||
row = nil
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,7 +419,7 @@ func Network_Iterate(
|
|||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row := &Network{}
|
row := &Network{}
|
||||||
err := rows.Scan(&row.NetworkID, &row.LocalDomain, &row.Network)
|
err := rows.Scan(&row.NetworkID, &row.Name, &row.Network)
|
||||||
if !yield(row, err) {
|
if !yield(row, err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -341,17 +451,17 @@ func Network_List(
|
|||||||
type Peer struct {
|
type Peer struct {
|
||||||
NetworkID int64
|
NetworkID int64
|
||||||
PeerIP byte
|
PeerIP byte
|
||||||
|
Version int64
|
||||||
APIKey string
|
APIKey string
|
||||||
Name string
|
Name string
|
||||||
Addr4 []byte
|
PublicIP []byte
|
||||||
Addr6 []byte
|
|
||||||
Port uint16
|
Port uint16
|
||||||
Relay bool
|
Relay bool
|
||||||
WGPubKey []byte
|
PubKey []byte
|
||||||
SignPubKey []byte
|
PubSignKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
const Peer_SelectQuery = "SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers"
|
const Peer_SelectQuery = "SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers"
|
||||||
|
|
||||||
func Peer_Insert(
|
func Peer_Insert(
|
||||||
tx TX,
|
tx TX,
|
||||||
@@ -362,10 +472,38 @@ func Peer_Insert(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey)
|
_, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Peer_Update(
|
||||||
|
tx TX,
|
||||||
|
row *Peer,
|
||||||
|
) (err error) {
|
||||||
|
Peer_Sanitize(row)
|
||||||
|
if err = Peer_Validate(row); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.NetworkID, row.PeerIP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
switch n {
|
||||||
|
case 0:
|
||||||
|
return sql.ErrNoRows
|
||||||
|
case 1:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
panic("multiple rows updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Peer_UpdateFull(
|
func Peer_UpdateFull(
|
||||||
tx TX,
|
tx TX,
|
||||||
row *Peer,
|
row *Peer,
|
||||||
@@ -375,7 +513,7 @@ func Peer_UpdateFull(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := tx.Exec("UPDATE peers SET APIKey=?,Name=?,Addr4=?,Addr6=?,Port=?,Relay=?,WGPubKey=?,SignPubKey=? WHERE NetworkID=? AND PeerIP=?", row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey, row.NetworkID, row.PeerIP)
|
result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=?,PubSignKey=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey, row.NetworkID, row.PeerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -427,10 +565,8 @@ func Peer_Get(
|
|||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
row = &Peer{}
|
row = &Peer{}
|
||||||
r := tx.QueryRow("SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
|
r := tx.QueryRow("SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
|
||||||
if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil {
|
err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
|
||||||
row = nil
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -444,9 +580,7 @@ func Peer_GetWhere(
|
|||||||
) {
|
) {
|
||||||
row = &Peer{}
|
row = &Peer{}
|
||||||
r := tx.QueryRow(query, args...)
|
r := tx.QueryRow(query, args...)
|
||||||
if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil {
|
err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
|
||||||
row = nil
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -466,7 +600,7 @@ func Peer_Iterate(
|
|||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row := &Peer{}
|
row := &Peer{}
|
||||||
err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey)
|
err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
|
||||||
if !yield(row, err) {
|
if !yield(row, err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,11 +8,9 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidIP = errors.New("invalid IP")
|
ErrInvalidIP = errors.New("invalid IP")
|
||||||
ErrInvalidPeerIP = errors.New("invalid peer IP")
|
|
||||||
ErrNonPrivateIP = errors.New("non-private IP")
|
ErrNonPrivateIP = errors.New("non-private IP")
|
||||||
ErrInvalidPort = errors.New("invalid port")
|
ErrInvalidPort = errors.New("invalid port")
|
||||||
ErrInvalidNetName = errors.New("invalid network name")
|
ErrInvalidNetName = errors.New("invalid network name")
|
||||||
ErrNetNameNotLocal = errors.New("network name must end with .local")
|
|
||||||
ErrInvalidPeerName = errors.New("invalid peer name")
|
ErrInvalidPeerName = errors.New("invalid peer name")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,8 +21,15 @@ func Config_Validate(c *Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Session_Sanitize(s *Session) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_Validate(s *Session) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func Network_Sanitize(n *Network) {
|
func Network_Sanitize(n *Network) {
|
||||||
n.LocalDomain = strings.TrimSpace(n.LocalDomain)
|
n.Name = strings.TrimSpace(n.Name)
|
||||||
|
|
||||||
if addr, ok := netip.AddrFromSlice(n.Network); ok {
|
if addr, ok := netip.AddrFromSlice(n.Network); ok {
|
||||||
n.Network = addr.AsSlice()
|
n.Network = addr.AsSlice()
|
||||||
@@ -32,17 +37,12 @@ func Network_Sanitize(n *Network) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Network_Validate(c *Network) error {
|
func Network_Validate(c *Network) error {
|
||||||
// 15 bytes is linux limit for network interface names. With ending .local,
|
// 16 bytes is linux limit for network interface names.
|
||||||
// max length is 21.
|
if len(c.Name) == 0 || len(c.Name) > 16 {
|
||||||
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
|
|
||||||
return ErrInvalidNetName
|
return ErrInvalidNetName
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasSuffix(c.LocalDomain, ".local") {
|
for _, c := range c.Name {
|
||||||
return ErrNetNameNotLocal
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
|
|
||||||
if c >= 'a' && c <= 'z' {
|
if c >= 'a' && c <= 'z' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -66,35 +66,21 @@ func Network_Validate(c *Network) error {
|
|||||||
|
|
||||||
func Peer_Sanitize(p *Peer) {
|
func Peer_Sanitize(p *Peer) {
|
||||||
p.Name = strings.TrimSpace(p.Name)
|
p.Name = strings.TrimSpace(p.Name)
|
||||||
if len(p.Addr4) != 0 {
|
if len(p.PublicIP) != 0 {
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
|
addr, ok := netip.AddrFromSlice(p.PublicIP)
|
||||||
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
|
if ok && addr.Is4() {
|
||||||
p.Addr4 = addr.Unmap().AsSlice()
|
p.PublicIP = addr.AsSlice()
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(p.Addr6) != 0 {
|
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr6); ok {
|
|
||||||
p.Addr6 = addr.AsSlice()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if p.Port == 0 {
|
if p.Port == 0 {
|
||||||
p.Port = 51820
|
p.Port = 456
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Peer_Validate(p *Peer) error {
|
func Peer_Validate(p *Peer) error {
|
||||||
if p.PeerIP < 1 || p.PeerIP > 254 {
|
if len(p.PublicIP) > 0 {
|
||||||
return ErrInvalidPeerIP
|
_, ok := netip.AddrFromSlice(p.PublicIP)
|
||||||
}
|
if !ok {
|
||||||
if len(p.Addr4) > 0 {
|
|
||||||
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
|
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
|
|
||||||
return ErrInvalidIP
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(p.Addr6) > 0 {
|
|
||||||
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
|
|
||||||
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
|
|
||||||
return ErrInvalidIP
|
return ErrInvalidIP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,9 +88,6 @@ func Peer_Validate(p *Peer) error {
|
|||||||
return ErrInvalidPort
|
return ErrInvalidPort
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p.Name) == 0 {
|
|
||||||
return ErrInvalidPeerName
|
|
||||||
}
|
|
||||||
for _, c := range p.Name {
|
for _, c := range p.Name {
|
||||||
if c >= 'a' && c <= 'z' {
|
if c >= 'a' && c <= 'z' {
|
||||||
continue
|
continue
|
||||||
@@ -112,9 +95,10 @@ func Peer_Validate(p *Peer) error {
|
|||||||
if c >= '0' && c <= '9' {
|
if c >= '0' && c <= '9' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if c == '-' {
|
if c == '.' || c == '-' || c == '_' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrInvalidPeerName
|
return ErrInvalidPeerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,21 +3,29 @@ TABLE config OF Config (
|
|||||||
Password []byte
|
Password []byte
|
||||||
);
|
);
|
||||||
|
|
||||||
|
TABLE sessions OF Session NoUpdate (
|
||||||
|
SessionID string PK,
|
||||||
|
CSRF string,
|
||||||
|
SignedIn bool,
|
||||||
|
CreatedAt int64,
|
||||||
|
LastSeenAt int64
|
||||||
|
);
|
||||||
|
|
||||||
TABLE networks OF Network (
|
TABLE networks OF Network (
|
||||||
NetworkID int64 PK,
|
NetworkID int64 PK,
|
||||||
LocalDomain string NoUpdate,
|
Name string NoUpdate,
|
||||||
Network []byte NoUpdate
|
Network []byte NoUpdate
|
||||||
);
|
);
|
||||||
|
|
||||||
TABLE peers OF Peer (
|
TABLE peers OF Peer (
|
||||||
NetworkID int64 PK,
|
NetworkID int64 PK,
|
||||||
PeerIP byte PK,
|
PeerIP byte PK,
|
||||||
|
Version int64,
|
||||||
APIKey string NoUpdate,
|
APIKey string NoUpdate,
|
||||||
Name string NoUpdate,
|
Name string,
|
||||||
Addr4 []byte NoUpdate,
|
PublicIP []byte,
|
||||||
Addr6 []byte NoUpdate,
|
Port uint16,
|
||||||
Port uint16 NoUpdate,
|
Relay bool,
|
||||||
Relay bool NoUpdate,
|
PubKey []byte NoUpdate,
|
||||||
WGPubKey []byte NoUpdate,
|
PubSignKey []byte NoUpdate
|
||||||
SignPubKey []byte NoUpdate
|
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,5 +1,31 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
func Session_UpdateLastSeenAt(
|
||||||
|
tx TX,
|
||||||
|
id string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_SetSignedIn(
|
||||||
|
tx TX,
|
||||||
|
id string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = tx.Exec("UPDATE sessions SET SignedIn=1 WHERE SessionID=?", id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Session_DeleteBefore(
|
||||||
|
tx TX,
|
||||||
|
timestamp int64,
|
||||||
|
) (err error) {
|
||||||
|
_, err = tx.Exec("DELETE FROM sessions WHERE LastSeenAt<?", timestamp)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) {
|
func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) {
|
||||||
const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC`
|
const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC`
|
||||||
return Peer_List(tx, query, networkID)
|
return Peer_List(tx, query, networkID)
|
||||||
@@ -11,3 +37,9 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
|
|||||||
Peer_SelectQuery+` WHERE APIKey=?`,
|
Peer_SelectQuery+` WHERE APIKey=?`,
|
||||||
apiKey)
|
apiKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Peer_Exists(tx TX, networkID int64, ip byte) (exists bool, err error) {
|
||||||
|
const query = `SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=? AND PeerIP=?)`
|
||||||
|
err = tx.QueryRow(query, networkID, ip).Scan(&exists)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNotAuthorized = errors.New("not authorized")
|
ErrNotAuthorized = errors.New("not authorized")
|
||||||
|
ErrNoIPAvailable = errors.New("no IP address available")
|
||||||
ErrInvalidIP = db.ErrInvalidIP
|
ErrInvalidIP = db.ErrInvalidIP
|
||||||
ErrInvalidPort = db.ErrInvalidPort
|
ErrInvalidPort = db.ErrInvalidPort
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,23 +3,32 @@ CREATE TABLE config (
|
|||||||
Password BLOB NOT NULL -- bcrypt password for web interface
|
Password BLOB NOT NULL -- bcrypt password for web interface
|
||||||
) WITHOUT ROWID;
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE TABLE sessions (
|
||||||
|
SessionID TEXT NOT NULL PRIMARY KEY,
|
||||||
|
CSRF TEXT NOT NULL,
|
||||||
|
SignedIn INTEGER NOT NULL,
|
||||||
|
CreatedAt INTEGER NOT NULL,
|
||||||
|
LastSeenAt INTEGER NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
CREATE INDEX sessions_last_seen_index ON sessions(LastSeenAt);
|
||||||
|
|
||||||
CREATE TABLE networks (
|
CREATE TABLE networks (
|
||||||
NetworkID INTEGER NOT NULL PRIMARY KEY,
|
NetworkID INTEGER NOT NULL PRIMARY KEY,
|
||||||
LocalDomain TEXT NOT NULL UNIQUE, -- Network/interface name.
|
Name TEXT NOT NULL UNIQUE, -- Network/interface name.
|
||||||
Network BLOB NOT NULL UNIQUE -- Network (/24), example 10.51.50.0
|
Network BLOB NOT NULL UNIQUE -- Network (/24), example 10.51.50.0
|
||||||
) WITHOUT ROWID;
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
CREATE TABLE peers (
|
CREATE TABLE peers (
|
||||||
NetworkID INTEGER NOT NULL,
|
NetworkID INTEGER NOT NULL,
|
||||||
PeerIP INTEGER NOT NULL, -- Final byte of IP.
|
PeerIP INTEGER NOT NULL, -- Final byte of IP.
|
||||||
APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key.
|
Version INTEGER NOT NULL, -- Changes when updated.
|
||||||
Name TEXT NOT NULL, -- For humans.
|
APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key.
|
||||||
Addr4 BLOB NOT NULL,
|
Name TEXT NOT NULL UNIQUE, -- For humans.
|
||||||
Addr6 BLOB NOT NULL,
|
PublicIP BLOB NOT NULL,
|
||||||
Port INTEGER NOT NULL,
|
Port INTEGER NOT NULL,
|
||||||
Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets.
|
Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address.
|
||||||
WGPubKey BLOB NOT NULL,
|
PubKey BLOB NOT NULL,
|
||||||
SignPubKey BLOB NOT NULL,
|
PubSignKey BLOB NOT NULL,
|
||||||
UNIQUE(NetworkID, Name),
|
|
||||||
PRIMARY KEY(NetworkID, PeerIP)
|
PRIMARY KEY(NetworkID, PeerIP)
|
||||||
) WITHOUT ROWID;
|
) WITHOUT ROWID;
|
||||||
|
|||||||
@@ -3,12 +3,6 @@ package api
|
|||||||
import "vppn/hub/api/db"
|
import "vppn/hub/api/db"
|
||||||
|
|
||||||
type Config = db.Config
|
type Config = db.Config
|
||||||
|
type Session = db.Session
|
||||||
type Network = db.Network
|
type Network = db.Network
|
||||||
type Peer = db.Peer
|
type Peer = db.Peer
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
SessionID string
|
|
||||||
SignedIn bool
|
|
||||||
CreatedAt int64
|
|
||||||
LastSeenAt int64
|
|
||||||
}
|
|
||||||
|
|||||||
16
hub/app.go
16
hub/app.go
@@ -2,7 +2,6 @@ package hub
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -48,19 +47,6 @@ func NewApp(conf Config) (*App, error) {
|
|||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) Handler() http.Handler {
|
|
||||||
cop := http.NewCrossOriginProtection()
|
|
||||||
return cop.Handler(app.mux)
|
|
||||||
}
|
|
||||||
|
|
||||||
var templateFuncs = template.FuncMap{
|
var templateFuncs = template.FuncMap{
|
||||||
"ipToString": ipBytesTostring,
|
"ipToString": ipBytesTostring,
|
||||||
"wgKeyString": wgKeyString,
|
|
||||||
}
|
|
||||||
|
|
||||||
func wgKeyString(key []byte) string {
|
|
||||||
if len(key) == 0 {
|
|
||||||
return "not set"
|
|
||||||
}
|
|
||||||
return base64.StdEncoding.EncodeToString(key)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package hub
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *App) getCookie(r *http.Request, name string) string {
|
func (a *App) getCookie(r *http.Request, name string) string {
|
||||||
@@ -25,12 +26,9 @@ func (a *App) setCookie(w http.ResponseWriter, name, value string) {
|
|||||||
|
|
||||||
func (a *App) deleteCookie(w http.ResponseWriter, name string) {
|
func (a *App) deleteCookie(w http.ResponseWriter, name string) {
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: name,
|
Name: name,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Secure: !a.insecure,
|
Expires: time.Unix(0, 0),
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
HttpOnly: true,
|
|
||||||
MaxAge: -1, // delete now
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
package hub
|
package hub
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sessionIDCookieName = "SessionID"
|
SESSION_ID_COOKIE_NAME = "SessionID"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) er
|
|||||||
|
|
||||||
func (app *App) handlePub(pattern string, fn handlerFunc) {
|
func (app *App) handlePub(pattern string, fn handlerFunc) {
|
||||||
wrapped := func(w http.ResponseWriter, r *http.Request) {
|
wrapped := func(w http.ResponseWriter, r *http.Request) {
|
||||||
sessionID := app.getCookie(r, sessionIDCookieName)
|
sessionID := app.getCookie(r, SESSION_ID_COOKIE_NAME)
|
||||||
s, err := app.api.Session_Get(sessionID)
|
s, err := app.api.Session_Get(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get session: %v", err)
|
log.Printf("Failed to get session: %v", err)
|
||||||
@@ -20,13 +20,22 @@ func (app *App) handlePub(pattern string, fn handlerFunc) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.SessionID != sessionID {
|
||||||
|
app.setCookie(w, SESSION_ID_COOKIE_NAME, s.SessionID)
|
||||||
|
}
|
||||||
|
|
||||||
if r.Method == http.MethodPost {
|
if r.Method == http.MethodPost {
|
||||||
r.ParseMultipartForm(64 * 1024)
|
r.ParseMultipartForm(64 * 1024)
|
||||||
|
if r.FormValue("CSRF") != s.CSRF {
|
||||||
|
log.Printf("%s != %s", r.FormValue("CSRF"), s.CSRF)
|
||||||
|
http.Error(w, "CSRF mismatch", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fn(&s, w, r); err != nil {
|
if err := fn(s, w, r); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
153
hub/handlers.go
153
hub/handlers.go
@@ -5,12 +5,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"vppn/hub/api"
|
"vppn/hub/api"
|
||||||
"vppn/m"
|
"vppn/m"
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/webutil"
|
"git.crumpington.com/lib/go/webutil"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
@@ -34,11 +35,9 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sess, err := a.api.Session_SignIn(pwd)
|
if err := a.api.Session_SignIn(s, pwd); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
a.setCookie(w, sessionIDCookieName, sess.SessionID)
|
|
||||||
|
|
||||||
return a.redirect(w, r, "/")
|
return a.redirect(w, r, "/")
|
||||||
}
|
}
|
||||||
@@ -51,7 +50,7 @@ func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http
|
|||||||
if err := a.api.Session_Delete(s.SessionID); err != nil {
|
if err := a.api.Session_Delete(s.SessionID); err != nil {
|
||||||
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
|
log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err)
|
||||||
}
|
}
|
||||||
a.deleteCookie(w, sessionIDCookieName)
|
a.deleteCookie(w, SESSION_ID_COOKIE_NAME)
|
||||||
return a.redirect(w, r, "/")
|
return a.redirect(w, r, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,7 +74,7 @@ func (a *App) _adminNetworkCreateSubmit(s *api.Session, w http.ResponseWriter, r
|
|||||||
var netStr string
|
var netStr string
|
||||||
|
|
||||||
err := webutil.NewFormScanner(r.Form).
|
err := webutil.NewFormScanner(r.Form).
|
||||||
Scan("LocalDomain", &n.LocalDomain).
|
Scan("Name", &n.Name).
|
||||||
Scan("Network", &netStr).
|
Scan("Network", &netStr).
|
||||||
Error()
|
Error()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -145,15 +144,14 @@ func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
var addr4Str, addr6Str string
|
var ipStr string
|
||||||
|
|
||||||
p := &api.Peer{}
|
p := &api.Peer{}
|
||||||
err := webutil.NewFormScanner(r.Form).
|
err := webutil.NewFormScanner(r.Form).
|
||||||
Scan("NetworkID", &p.NetworkID).
|
Scan("NetworkID", &p.NetworkID).
|
||||||
Scan("IP", &p.PeerIP).
|
Scan("IP", &p.PeerIP).
|
||||||
Scan("Name", &p.Name).
|
Scan("Name", &p.Name).
|
||||||
Scan("Addr4", &addr4Str).
|
Scan("PublicIP", &ipStr).
|
||||||
Scan("Addr6", &addr6Str).
|
|
||||||
Scan("Port", &p.Port).
|
Scan("Port", &p.Port).
|
||||||
Scan("Relay", &p.Relay).
|
Scan("Relay", &p.Relay).
|
||||||
Error()
|
Error()
|
||||||
@@ -161,10 +159,7 @@ func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *h
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Addr4, err = stringToIP(addr4Str); err != nil {
|
if p.PublicIP, err = stringToIP(ipStr); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
if p.Addr6, err = stringToIP(addr6Str); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,6 +182,48 @@ func (a *App) _adminPeerView(s *api.Session, w http.ResponseWriter, r *http.Requ
|
|||||||
}{s, net, peer})
|
}{s, net, peer})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) _adminPeerEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
net, peer, err := a.formGetPeer(r.Form)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.render("/network/peer-edit.html", w, struct {
|
||||||
|
Session *api.Session
|
||||||
|
Network *api.Network
|
||||||
|
Peer *api.Peer
|
||||||
|
}{s, net, peer})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *App) _adminPeerEditSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
_, peer, err := a.formGetPeer(r.Form)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipStr string
|
||||||
|
|
||||||
|
err = webutil.NewFormScanner(r.Form).
|
||||||
|
Scan("Name", &peer.Name).
|
||||||
|
Scan("PublicIP", &ipStr).
|
||||||
|
Scan("Port", &peer.Port).
|
||||||
|
Scan("Relay", &peer.Relay).
|
||||||
|
Error()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.PublicIP, err = stringToIP(ipStr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = a.api.Peer_Update(peer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.redirect(w, r, "/admin/peer/view/?NetworkID=%d&PeerIP=%d", peer.NetworkID, peer.PeerIP)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
n, peer, err := a.formGetPeer(r.Form)
|
n, peer, err := a.formGetPeer(r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,23 +248,40 @@ func (a *App) _adminPeerDeleteSubmit(s *api.Session, w http.ResponseWriter, r *h
|
|||||||
return a.redirect(w, r, "/admin/network/view/?NetworkID=%d", n.NetworkID)
|
return a.redirect(w, r, "/admin/network/view/?NetworkID=%d", n.NetworkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) _adminNetworkHosts(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
n, peers, err := a.formGetNetworkPeers(r.Form)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b := strings.Builder{}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
ip := n.Network
|
||||||
|
ip[3] = peer.PeerIP
|
||||||
|
b.WriteString(netip.AddrFrom4([4]byte(ip)).String())
|
||||||
|
b.WriteString(" ")
|
||||||
|
b.WriteString(peer.Name)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte(b.String()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
return a.render("/admin-password-edit.html", w, struct{ Session *api.Session }{s})
|
return a.render("/admin-password-edit.html", w, struct{ Session *api.Session }{s})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error {
|
||||||
var (
|
var (
|
||||||
|
conf = a.api.Config_Get()
|
||||||
curPwd string
|
curPwd string
|
||||||
newPwd string
|
newPwd string
|
||||||
newPwd2 string
|
newPwd2 string
|
||||||
)
|
)
|
||||||
|
|
||||||
conf, err := a.api.Config_Get()
|
err := webutil.NewFormScanner(r.Form).
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = webutil.NewFormScanner(r.Form).
|
|
||||||
Scan("CurrentPassword", &curPwd).
|
Scan("CurrentPassword", &curPwd).
|
||||||
Scan("NewPassword", &newPwd).
|
Scan("NewPassword", &newPwd).
|
||||||
Scan("NewPassword2", &newPwd2).
|
Scan("NewPassword2", &newPwd2).
|
||||||
@@ -264,25 +318,11 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
||||||
if len(peer.WGPubKey) != 0 {
|
|
||||||
http.Error(w, "Already initialized", http.StatusConflict)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
args := m.PeerInitArgs{}
|
args := m.PeerInitArgs{}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args.WGPubKey) != 32 {
|
|
||||||
http.Error(w, "invalid WGPubKey", http.StatusBadRequest)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if len(args.SignPubKey) != 32 {
|
|
||||||
http.Error(w, "invalid SignPubKey", http.StatusBadRequest)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
net, err := a.api.Network_Get(peer.NetworkID)
|
net, err := a.api.Network_Get(peer.NetworkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -293,12 +333,11 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
resp := m.PeerInitResp{
|
resp := m.PeerInitResp{
|
||||||
PeerIP: peer.PeerIP,
|
PeerIP: peer.PeerIP,
|
||||||
Network: net.Network,
|
Network: net.Network,
|
||||||
LocalDomain: net.LocalDomain,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.NetworkState.Peers, err = a.peersList(net.NetworkID)
|
resp.NetworkState.Peers, err = a.peersArray(net.NetworkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -307,42 +346,34 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error {
|
||||||
peers, err := a.peersList(peer.NetworkID)
|
|
||||||
|
peers, err := a.peersArray(peer.NetworkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return a.sendJSON(w, m.NetworkState{Peers: peers})
|
return a.sendJSON(w, m.NetworkState{Peers: peers})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) peersList(networkID int64) (peers []m.Peer, err error) {
|
func (a *App) peersArray(networkID int64) (peers [256]*m.Peer, err error) {
|
||||||
l, err := a.api.Peer_List(networkID)
|
l, err := a.api.Peer_List(networkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return peers, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers = make([]m.Peer, 0, len(l))
|
|
||||||
|
|
||||||
for _, p := range l {
|
for _, p := range l {
|
||||||
if len(p.WGPubKey) == 0 {
|
if len(p.PubKey) != 0 {
|
||||||
continue
|
peers[p.PeerIP] = &m.Peer{
|
||||||
|
PeerIP: p.PeerIP,
|
||||||
|
Version: p.Version,
|
||||||
|
Name: p.Name,
|
||||||
|
PublicIP: p.PublicIP,
|
||||||
|
Port: p.Port,
|
||||||
|
Relay: p.Relay,
|
||||||
|
PubKey: p.PubKey,
|
||||||
|
PubSignKey: p.PubSignKey,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
wgKey, err := wgtypes.NewKey(p.WGPubKey)
|
|
||||||
if err != nil {
|
|
||||||
continue // malformed key; skip rather than serve garbage
|
|
||||||
}
|
|
||||||
var signKey [32]byte
|
|
||||||
copy(signKey[:], p.SignPubKey)
|
|
||||||
peers = append(peers, m.Peer{
|
|
||||||
PeerIP: p.PeerIP,
|
|
||||||
Name: p.Name,
|
|
||||||
Addr4: addrFromBytes(p.Addr4),
|
|
||||||
Addr6: addrFromBytes(p.Addr6),
|
|
||||||
Port: p.Port,
|
|
||||||
Relay: p.Relay,
|
|
||||||
WGPubKey: wgKey,
|
|
||||||
SignPubKey: signKey,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return peers, nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func Main() {
|
|||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: conf.ListenAddr,
|
Addr: conf.ListenAddr,
|
||||||
Handler: app.Handler(),
|
Handler: app.mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Fatal(webutil.ListenAndServe(srv))
|
log.Fatal(webutil.ListenAndServe(srv))
|
||||||
|
|||||||
@@ -19,9 +19,12 @@ func (a *App) registerRoutes() {
|
|||||||
a.handleSignedIn("POST /admin/network/delete/", a._adminNetworkDeleteSubmit)
|
a.handleSignedIn("POST /admin/network/delete/", a._adminNetworkDeleteSubmit)
|
||||||
|
|
||||||
a.handleSignedIn("GET /admin/network/view/", a._adminNetworkView)
|
a.handleSignedIn("GET /admin/network/view/", a._adminNetworkView)
|
||||||
|
a.handleSignedIn("GET /admin/network/hosts/", a._adminNetworkHosts)
|
||||||
a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate)
|
a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate)
|
||||||
a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit)
|
a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit)
|
||||||
a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView)
|
a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView)
|
||||||
|
a.handleSignedIn("GET /admin/peer/edit/", a._adminPeerEdit)
|
||||||
|
a.handleSignedIn("POST /admin/peer/edit/", a._adminPeerEditSubmit)
|
||||||
a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete)
|
a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete)
|
||||||
a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit)
|
a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit)
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,10 @@
|
|||||||
<h2>Create Network</h2>
|
<h2>Create Network</h2>
|
||||||
|
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
<p>
|
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
|
||||||
<label>Local Domain (ending with .local)</label><br>
|
<p>
|
||||||
<input type="text" name="LocalDomain">
|
<label>Name</label><br>
|
||||||
|
<input type="text" name="Name">
|
||||||
</p>
|
</p>
|
||||||
<p>
|
<p>
|
||||||
<label>Network /24</label><br>
|
<label>Network /24</label><br>
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
<table>
|
<table>
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<th>Local Domain</th>
|
<th>Name</th>
|
||||||
<th>Network</th>
|
<th>Network</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
@@ -18,7 +18,7 @@
|
|||||||
<tr>
|
<tr>
|
||||||
<td>
|
<td>
|
||||||
<a href="/admin/network/view/?NetworkID={{.NetworkID}}">
|
<a href="/admin/network/view/?NetworkID={{.NetworkID}}">
|
||||||
{{.LocalDomain}}
|
{{.Name}}
|
||||||
</a>
|
</a>
|
||||||
</td>
|
</td>
|
||||||
<td>{{ipToString .Network}}</td>
|
<td>{{ipToString .Network}}</td>
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
<h2>Change Password</h2>
|
<h2>Change Password</h2>
|
||||||
|
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
<p>
|
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
|
||||||
|
<p>
|
||||||
<label>Current Password</label><br>
|
<label>Current Password</label><br>
|
||||||
<input type="password" name="CurrentPassword">
|
<input type="password" name="CurrentPassword">
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
<h2>Sign Out</h2>
|
<h2>Sign Out</h2>
|
||||||
|
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
<p>
|
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
|
||||||
|
<p>
|
||||||
<button type="submit">Sign Out</button>
|
<button type="submit">Sign Out</button>
|
||||||
<a href="/">Cancel</a>
|
<a href="/">Cancel</a>
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
</header>
|
</header>
|
||||||
<h2>
|
<h2>
|
||||||
Network:
|
Network:
|
||||||
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">{{.Network.LocalDomain}}</a>
|
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">{{.Network.Name}}</a>
|
||||||
</h2>
|
</h2>
|
||||||
|
|
||||||
{{block "body" .}}There's nothing here.{{end}}
|
{{block "body" .}}There's nothing here.{{end}}
|
||||||
|
|||||||
@@ -5,7 +5,8 @@
|
|||||||
<p>You must first delete all peers.</p>
|
<p>You must first delete all peers.</p>
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
|
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
|
||||||
|
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
|
||||||
<p>
|
<p>
|
||||||
<button type="submit">Delete</button>
|
<button type="submit">Delete</button>
|
||||||
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">Cancel</a>
|
<a href="/admin/network/view/?NetworkID={{.Network.NetworkID}}">Cancel</a>
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
{{define "body" -}}
|
{{define "body" -}}
|
||||||
<p>
|
<p>
|
||||||
<a href="/admin/network/delete/?NetworkID={{.Network.NetworkID}}">Delete</a>
|
<a href="/admin/network/delete/?NetworkID={{.Network.NetworkID}}">Delete</a> /
|
||||||
|
<a href="/admin/network/hosts/?NetworkID={{.Network.NetworkID}}">Hosts</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<table class="def-list">
|
<table class="def-list">
|
||||||
@@ -22,8 +23,7 @@
|
|||||||
<tr>
|
<tr>
|
||||||
<th>PeerIP</th>
|
<th>PeerIP</th>
|
||||||
<th>Name</th>
|
<th>Name</th>
|
||||||
<th>IPv4</th>
|
<th>Public IP</th>
|
||||||
<th>IPv6</th>
|
|
||||||
<th>Port</th>
|
<th>Port</th>
|
||||||
<th>Relay</th>
|
<th>Relay</th>
|
||||||
</tr>
|
</tr>
|
||||||
@@ -37,8 +37,7 @@
|
|||||||
</a>
|
</a>
|
||||||
</td>
|
</td>
|
||||||
<td>{{.Name}}</td>
|
<td>{{.Name}}</td>
|
||||||
<td>{{ipToString .Addr4}}</td>
|
<td>{{ipToString .PublicIP}}</td>
|
||||||
<td>{{ipToString .Addr6}}</td>
|
|
||||||
<td>{{.Port}}</td>
|
<td>{{.Port}}</td>
|
||||||
<td>{{if .Relay}}T{{else}}F{{end}}</td>
|
<td>{{if .Relay}}T{{else}}F{{end}}</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
<h3>New Peer</h3>
|
<h3>New Peer</h3>
|
||||||
|
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
|
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
|
||||||
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
|
<input type="hidden" name="NetworkID" value="{{.Network.NetworkID}}">
|
||||||
<p>
|
<p>
|
||||||
<label>IP</label><br>
|
<label>IP</label><br>
|
||||||
@@ -12,16 +13,12 @@
|
|||||||
<input type="text" name="Name">
|
<input type="text" name="Name">
|
||||||
</p>
|
</p>
|
||||||
<p>
|
<p>
|
||||||
<label>IPv4 Address (optional)</label><br>
|
<label>Public IP</label><br>
|
||||||
<input type="text" name="Addr4">
|
<input type="text" name="PublicIP">
|
||||||
</p>
|
</p>
|
||||||
<p>
|
<p>
|
||||||
<label>IPv6 Address (optional)</label><br>
|
<label>Port</label><br>
|
||||||
<input type="text" name="Addr6">
|
<input type="number" name="Port" value="456">
|
||||||
</p>
|
|
||||||
<p>
|
|
||||||
<label>WireGuard Port</label><br>
|
|
||||||
<input type="number" name="Port" value="51820">
|
|
||||||
</p>
|
</p>
|
||||||
<p>
|
<p>
|
||||||
<label>
|
<label>
|
||||||
|
|||||||
@@ -3,8 +3,9 @@
|
|||||||
|
|
||||||
{{with .Peer -}}
|
{{with .Peer -}}
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
<input type="hidden" name="NetworkID" value="{{.NetworkID}}">
|
<input type="hidden" name="CSRF" value="{{$.Session.CSRF}}">
|
||||||
<input type="hidden" name="PeerIP" value="{{.PeerIP}}">
|
<input type="hidden" name="NetworkID" value="{{.NetworkID}}">
|
||||||
|
<input type="hidden" name="NetworkID" value="{{.PeerIP}}">
|
||||||
<p>
|
<p>
|
||||||
<button type="submit">Delete</button>
|
<button type="submit">Delete</button>
|
||||||
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}&NetworkID={{.NetworkID}}">Cancel</a>
|
<a href="/admin/peer/view/?PeerIP={{.PeerIP}}&NetworkID={{.NetworkID}}">Cancel</a>
|
||||||
|
|||||||
35
hub/templates/network/peer-edit.html
Normal file
35
hub/templates/network/peer-edit.html
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
{{define "body" -}}
|
||||||
|
<h2>Edit Peer</h2>
|
||||||
|
|
||||||
|
{{with .Peer -}}
|
||||||
|
<form method="POST">
|
||||||
|
<input type="hidden" name="CSRF" value="{{$.Session.CSRF}}">
|
||||||
|
<p>
|
||||||
|
<label>Peer IP</label><br>
|
||||||
|
<input type="text" value="{{.PeerIP}}" disabled>
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<label>Name</label><br>
|
||||||
|
<input type="text" name="Name" value="{{.Name}}">
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<label>Public IP</label><br>
|
||||||
|
<input type="text" name="PublicIP" value="{{ipToString .PublicIP}}">
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<label>Port</label><br>
|
||||||
|
<input type="number" name="Port" value="{{.Port}}">
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<label>
|
||||||
|
<input type="checkbox" name="Relay" {{if .Relay}}checked{{end}}>
|
||||||
|
Relay
|
||||||
|
</label>
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<button type="submit">Save</button>
|
||||||
|
<a href="/admin/peer/view/?NetworkID={{$.Network.NetworkID}}&PeerIP={{.PeerIP}}">Cancel</a>
|
||||||
|
</p>
|
||||||
|
</form>
|
||||||
|
{{- end}}
|
||||||
|
{{- end}}
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
{{define "body" -}}
|
{{define "body" -}}
|
||||||
<h3>{{.Peer.Name}}</h3>
|
<h3>{{.Peer.Name}}</h3>
|
||||||
<p>
|
<p>
|
||||||
|
<a href="/admin/peer/edit/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Edit</a> /
|
||||||
<a href="/admin/peer/delete/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Delete</a>
|
<a href="/admin/peer/delete/?NetworkID={{.Network.NetworkID}}&PeerIP={{.Peer.PeerIP}}">Delete</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
{{with .Peer -}}
|
{{with .Peer -}}
|
||||||
<table class="def-list">
|
<table class="def-list">
|
||||||
<tr><td>Peer IP</td><td>{{.PeerIP}}</td></tr>
|
<tr><td>Peer IP</td><td>{{.PeerIP}}</td></tr>
|
||||||
<tr><td>IPv4 Address</td><td>{{ipToString .Addr4}}</td></tr>
|
<tr><td>Public IP</td><td>{{ipToString .PublicIP}}</td></tr>
|
||||||
<tr><td>IPv6 Address</td><td>{{ipToString .Addr6}}</td></tr>
|
<tr><td>Port</td><td>{{.Port}}</td></tr>
|
||||||
<tr><td>WireGuard Port</td><td>{{.Port}}</td></tr>
|
|
||||||
<tr><td>Relay</td><td>{{if .Relay}}T{{else}}F{{end}}</td></tr>
|
<tr><td>Relay</td><td>{{if .Relay}}T{{else}}F{{end}}</td></tr>
|
||||||
<tr><td>WG Public Key</td><td>{{wgKeyString .WGPubKey}}</td></tr>
|
</td></tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -19,6 +19,7 @@
|
|||||||
<p>{{.APIKey}}</p>
|
<p>{{.APIKey}}</p>
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
<h2>Sign In</h2>
|
<h2>Sign In</h2>
|
||||||
|
|
||||||
<form method="POST">
|
<form method="POST">
|
||||||
<p>
|
<input type="hidden" name="CSRF" value="{{.Session.CSRF}}">
|
||||||
|
<p>
|
||||||
<label>Password</label><br>
|
<label>Password</label><br>
|
||||||
<input type="password" name="Password">
|
<input type="password" name="Password">
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
13
hub/util.go
13
hub/util.go
@@ -38,19 +38,6 @@ func (app *App) sendJSON(w http.ResponseWriter, data any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addrFromBytes parses raw IP bytes (4 or 16) into a netip.Addr, unmapping
|
|
||||||
// IPv4-in-IPv6, returning the zero Addr for empty/invalid input.
|
|
||||||
func addrFromBytes(b []byte) netip.Addr {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
addr, ok := netip.AddrFromSlice(b)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
return addr.Unmap()
|
|
||||||
}
|
|
||||||
|
|
||||||
func stringToIP(in string) ([]byte, error) {
|
func stringToIP(in string) ([]byte, error) {
|
||||||
in = strings.TrimSpace(in)
|
in = strings.TrimSpace(in)
|
||||||
if len(in) == 0 {
|
if len(in) == 0 {
|
||||||
|
|||||||
119
m/models.go
119
m/models.go
@@ -1,133 +1,28 @@
|
|||||||
// The package `m` contains models shared between the hub and peer programs.
|
// The package `m` contains models shared between the hub and peer programs.
|
||||||
package m
|
package m
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PeerInitArgs struct {
|
type PeerInitArgs struct {
|
||||||
WGPubKey []byte
|
EncPubKey []byte
|
||||||
SignPubKey []byte
|
PubSignKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeerInitResp struct {
|
type PeerInitResp struct {
|
||||||
PeerIP byte
|
PeerIP byte
|
||||||
Network []byte
|
Network []byte
|
||||||
LocalDomain string
|
|
||||||
NetworkState NetworkState
|
NetworkState NetworkState
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is the network membership record for a single peer, exchanged between
|
|
||||||
// the hub and peers. Addr4/Addr6 are the peer's public endpoint addresses (zero
|
|
||||||
// if it has none); Port is its WireGuard listen port, meaningful even for a
|
|
||||||
// non-public peer (it is the peer's own bind/beacon port).
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
PeerIP byte
|
PeerIP byte
|
||||||
|
Version int64
|
||||||
Name string
|
Name string
|
||||||
Addr4 netip.Addr // zero if none
|
PublicIP []byte
|
||||||
Addr6 netip.Addr // zero if none
|
|
||||||
Port uint16
|
Port uint16
|
||||||
Relay bool
|
Relay bool
|
||||||
WGPubKey wgtypes.Key
|
PubKey []byte
|
||||||
SignPubKey [32]byte
|
PubSignKey []byte
|
||||||
}
|
|
||||||
|
|
||||||
// IsPublic reports whether the peer advertises at least one reachable endpoint.
|
|
||||||
func (p Peer) IsPublic() bool {
|
|
||||||
return p.Addr4.IsValid() || p.Addr6.IsValid()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Endpoint4 returns the IPv4 endpoint (addr+port), or the zero AddrPort if the
|
|
||||||
// peer has no IPv4 address.
|
|
||||||
func (p Peer) Endpoint4() netip.AddrPort {
|
|
||||||
if !p.Addr4.IsValid() {
|
|
||||||
return netip.AddrPort{}
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(p.Addr4, p.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Endpoint6 returns the IPv6 endpoint (addr+port), or the zero AddrPort if the
|
|
||||||
// peer has no IPv6 address.
|
|
||||||
func (p Peer) Endpoint6() netip.AddrPort {
|
|
||||||
if !p.Addr6.IsValid() {
|
|
||||||
return netip.AddrPort{}
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(p.Addr6, p.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreferredEndpoint returns the IPv4 endpoint if present, else IPv6.
|
|
||||||
func (p Peer) PreferredEndpoint() netip.AddrPort {
|
|
||||||
if ep := p.Endpoint4(); ep.IsValid() {
|
|
||||||
return ep
|
|
||||||
}
|
|
||||||
return p.Endpoint6()
|
|
||||||
}
|
|
||||||
|
|
||||||
// peerJSON is the wire representation. netip.Addr fields round-trip as text
|
|
||||||
// strings automatically; only the fixed-size key arrays need base64 (otherwise
|
|
||||||
// encoding/json would emit them as arrays of numbers).
|
|
||||||
type peerJSON struct {
|
|
||||||
PeerIP byte
|
|
||||||
Name string
|
|
||||||
Addr4 netip.Addr
|
|
||||||
Addr6 netip.Addr
|
|
||||||
Port uint16
|
|
||||||
Relay bool
|
|
||||||
WGPubKey string
|
|
||||||
SignPubKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Peer) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(peerJSON{
|
|
||||||
PeerIP: p.PeerIP,
|
|
||||||
Name: p.Name,
|
|
||||||
Addr4: p.Addr4,
|
|
||||||
Addr6: p.Addr6,
|
|
||||||
Port: p.Port,
|
|
||||||
Relay: p.Relay,
|
|
||||||
WGPubKey: base64.StdEncoding.EncodeToString(p.WGPubKey[:]),
|
|
||||||
SignPubKey: base64.StdEncoding.EncodeToString(p.SignPubKey[:]),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Peer) UnmarshalJSON(data []byte) error {
|
|
||||||
var j peerJSON
|
|
||||||
if err := json.Unmarshal(data, &j); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
wg, err := base64.StdEncoding.DecodeString(j.WGPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("decode WGPubKey: %w", err)
|
|
||||||
}
|
|
||||||
key, err := wgtypes.NewKey(wg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid WGPubKey: %w", err)
|
|
||||||
}
|
|
||||||
sign, err := base64.StdEncoding.DecodeString(j.SignPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("decode SignPubKey: %w", err)
|
|
||||||
}
|
|
||||||
if len(sign) != 32 {
|
|
||||||
return fmt.Errorf("invalid SignPubKey length: %d", len(sign))
|
|
||||||
}
|
|
||||||
*p = Peer{
|
|
||||||
PeerIP: j.PeerIP,
|
|
||||||
Name: j.Name,
|
|
||||||
Addr4: j.Addr4,
|
|
||||||
Addr6: j.Addr6,
|
|
||||||
Port: j.Port,
|
|
||||||
Relay: j.Relay,
|
|
||||||
WGPubKey: key,
|
|
||||||
SignPubKey: [32]byte(sign),
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetworkState struct {
|
type NetworkState struct {
|
||||||
Peers []Peer
|
Peers [256]*Peer
|
||||||
}
|
}
|
||||||
|
|||||||
105
peer/app.go
105
peer/app.go
@@ -1,105 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
"vppn/peer/control"
|
|
||||||
"vppn/peer/multicast"
|
|
||||||
"vppn/peer/wginterface"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ WGDevice = (*wginterface.Device)(nil) // compile-time check: Device satisfies WGDevice
|
|
||||||
|
|
||||||
const (
|
|
||||||
ControlPort = 4561
|
|
||||||
PingInterval = 8 * time.Second
|
|
||||||
TimeoutInterval = 30 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// scratchSize is large enough for the biggest buffer either the ping or the
|
|
||||||
// multicast path serializes through the shared App scratch.
|
|
||||||
const scratchSize = max(control.Size, multicast.SignedPacketSize)
|
|
||||||
|
|
||||||
type PingEvent struct {
|
|
||||||
srcVPNIP netip.Addr
|
|
||||||
ping control.Ping
|
|
||||||
}
|
|
||||||
|
|
||||||
// App is the peer application. All mutable state lives here and is
|
|
||||||
// accessed only from the Run goroutine.
|
|
||||||
type App struct {
|
|
||||||
// Identity
|
|
||||||
vpnIP netip.Addr
|
|
||||||
vpnNet netip.Prefix
|
|
||||||
privKey wgtypes.Key
|
|
||||||
pubKey wgtypes.Key
|
|
||||||
isRelay bool
|
|
||||||
isPublic bool
|
|
||||||
localDomain string
|
|
||||||
|
|
||||||
// Infrastructure
|
|
||||||
dev WGDevice
|
|
||||||
controlConn ControlConn
|
|
||||||
|
|
||||||
// Peer state
|
|
||||||
relay *Peer
|
|
||||||
peersByKey map[wgtypes.Key]*Peer
|
|
||||||
peersByIP map[netip.Addr]*Peer
|
|
||||||
|
|
||||||
// Our own external endpoints, learned from Dst fields in incoming pings
|
|
||||||
selfV4 netip.AddrPort
|
|
||||||
selfV6 netip.AddrPort
|
|
||||||
|
|
||||||
// Reusable serialization scratch for outgoing pings and multicast signature
|
|
||||||
// verification. Only touched from the Run goroutine.
|
|
||||||
scratch []byte
|
|
||||||
|
|
||||||
// Event channels fed by background goroutines
|
|
||||||
hubAddCh <-chan m.Peer
|
|
||||||
hubRemoveCh <-chan wgtypes.Key
|
|
||||||
pingCh <-chan PingEvent
|
|
||||||
multicastCh <-chan multicast.Packet
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run is the main event loop. It runs until SIGTERM/SIGINT.
|
|
||||||
func (a *App) Run() error {
|
|
||||||
// Establish a clean hosts section before the first poll lands, clearing
|
|
||||||
// any stale entries left by a prior run (e.g. crash, or peers removed
|
|
||||||
// while we were down).
|
|
||||||
a.updateHosts()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(PingInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
|
||||||
defer signal.Stop(sig)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case p := <-a.hubAddCh:
|
|
||||||
a.onAddPeer(p)
|
|
||||||
case key := <-a.hubRemoveCh:
|
|
||||||
a.onRemovePeer(key)
|
|
||||||
case e := <-a.pingCh:
|
|
||||||
a.onPing(e)
|
|
||||||
case e := <-a.multicastCh:
|
|
||||||
a.onMulticastDiscovery(e)
|
|
||||||
case <-ticker.C:
|
|
||||||
a.onTick()
|
|
||||||
case <-sig:
|
|
||||||
return a.onShutdown()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) onShutdown() error {
|
|
||||||
return wginterface.Delete(a.dev.Name())
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
"vppn/peer/multicast"
|
|
||||||
)
|
|
||||||
|
|
||||||
// addRelayPeer adds a public relay peer and marks it Up so it satisfies
|
|
||||||
// CanRelay. It does not set a.relay — callers do that explicitly.
|
|
||||||
func addRelayPeer(t *testing.T, a *App, vpnIP string, ep netip.AddrPort) *Peer {
|
|
||||||
t.Helper()
|
|
||||||
key := mustKey(t)
|
|
||||||
ip := netip.MustParseAddr(vpnIP)
|
|
||||||
a.onAddPeer(m.Peer{
|
|
||||||
WGPubKey: key,
|
|
||||||
PeerIP: ip.As4()[3],
|
|
||||||
Addr4: ep.Addr(),
|
|
||||||
Port: ep.Port(),
|
|
||||||
Relay: true,
|
|
||||||
})
|
|
||||||
p := a.peersByKey[key]
|
|
||||||
p.wgPeer.LastHandshakeTime = time.Now()
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTestApp returns a minimal App wired to a fakeWGDevice and fakeControlConn.
|
|
||||||
// vpnIP is the local VPN address (e.g. "10.0.0.1").
|
|
||||||
// isPublic / isRelay describe the local node's role.
|
|
||||||
func newTestApp(t *testing.T, vpnIP string, isPublic, isRelay bool) (*App, *fakeWGDevice, *fakeControlConn) {
|
|
||||||
t.Helper()
|
|
||||||
privKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("generate key: %v", err)
|
|
||||||
}
|
|
||||||
ip := netip.MustParseAddr(vpnIP)
|
|
||||||
dev := &fakeWGDevice{}
|
|
||||||
cc := &fakeControlConn{}
|
|
||||||
a := &App{
|
|
||||||
vpnIP: ip,
|
|
||||||
vpnNet: netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
privKey: privKey,
|
|
||||||
pubKey: privKey.PublicKey(),
|
|
||||||
isPublic: isPublic,
|
|
||||||
isRelay: isRelay,
|
|
||||||
dev: dev,
|
|
||||||
controlConn: cc,
|
|
||||||
peersByKey: make(map[wgtypes.Key]*Peer),
|
|
||||||
peersByIP: make(map[netip.Addr]*Peer),
|
|
||||||
scratch: make([]byte, scratchSize),
|
|
||||||
hubAddCh: make(chan m.Peer),
|
|
||||||
hubRemoveCh: make(chan wgtypes.Key),
|
|
||||||
pingCh: make(chan PingEvent),
|
|
||||||
multicastCh: make(chan multicast.Packet),
|
|
||||||
}
|
|
||||||
return a, dev, cc
|
|
||||||
}
|
|
||||||
21
peer/bitset.go
Normal file
21
peer/bitset.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
const bitSetSize = 512 // Multiple of 64.
|
||||||
|
|
||||||
|
type bitSet [bitSetSize / 64]uint64
|
||||||
|
|
||||||
|
func (bs *bitSet) Set(i int) {
|
||||||
|
bs[i/64] |= 1 << (i % 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bs *bitSet) Clear(i int) {
|
||||||
|
bs[i/64] &= ^(1 << (i % 64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bs *bitSet) ClearAll() {
|
||||||
|
clear(bs[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bs *bitSet) Get(i int) bool {
|
||||||
|
return bs[i/64]&(1<<(i%64)) != 0
|
||||||
|
}
|
||||||
48
peer/bitset_test.go
Normal file
48
peer/bitset_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBitSet(t *testing.T) {
|
||||||
|
state := make([]bool, bitSetSize)
|
||||||
|
for i := range state {
|
||||||
|
state[i] = rand.Float32() > 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
bs := bitSet{}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if state[i] {
|
||||||
|
bs.Set(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if bs.Get(i) != state[i] {
|
||||||
|
t.Fatal(i, state[i], bs.Get(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if rand.Float32() > 0.5 {
|
||||||
|
state[i] = false
|
||||||
|
bs.Clear(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if bs.Get(i) != state[i] {
|
||||||
|
t.Fatal(i, state[i], bs.Get(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bs.ClearAll()
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if bs.Get(i) {
|
||||||
|
t.Fatal(i, bs.Get(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
26
peer/cipher-control.go
Normal file
26
peer/cipher-control.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "golang.org/x/crypto/nacl/box"
|
||||||
|
|
||||||
|
type controlCipher struct {
|
||||||
|
sharedKey [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newControlCipher(privKey, pubKey []byte) *controlCipher {
|
||||||
|
shared := [32]byte{}
|
||||||
|
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
|
||||||
|
return &controlCipher{shared}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *controlCipher) Encrypt(h Header, data, out []byte) []byte {
|
||||||
|
const s = controlHeaderSize
|
||||||
|
out = out[:s+controlCipherOverhead+len(data)]
|
||||||
|
h.Marshal(out[:s])
|
||||||
|
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
|
||||||
|
const s = controlHeaderSize
|
||||||
|
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
|
||||||
|
}
|
||||||
122
peer/cipher-control_test.go
Normal file
122
peer/cipher-control_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/nacl/box"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newControlCipherForTesting() (c1, c2 *controlCipher) {
|
||||||
|
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newControlCipher(privKey1[:], pubKey2[:]),
|
||||||
|
newControlCipher(privKey2[:], pubKey1[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlCipher(t *testing.T) {
|
||||||
|
c1, c2 := newControlCipherForTesting()
|
||||||
|
|
||||||
|
maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
|
||||||
|
rand.Read(maxSizePlaintext)
|
||||||
|
|
||||||
|
testCases := [][]byte{
|
||||||
|
make([]byte, 0),
|
||||||
|
{1},
|
||||||
|
{255},
|
||||||
|
{1, 2, 3, 4, 5},
|
||||||
|
[]byte("Hello world"),
|
||||||
|
maxSizePlaintext,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, plaintext := range testCases {
|
||||||
|
h1 := Header{
|
||||||
|
StreamID: controlStreamID,
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
encrypted = c1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
|
||||||
|
h2 := Header{}
|
||||||
|
h2.Parse(encrypted)
|
||||||
|
if !reflect.DeepEqual(h1, h2) {
|
||||||
|
t.Fatal(h1, h2)
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted, plaintext) {
|
||||||
|
t.Fatal("not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlCipher_ShortCiphertext(t *testing.T) {
|
||||||
|
c1, _ := newControlCipherForTesting()
|
||||||
|
shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
|
||||||
|
rand.Read(shortText)
|
||||||
|
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkControlCipher_Encrypt(b *testing.B) {
|
||||||
|
c1, _ := newControlCipherForTesting()
|
||||||
|
h1 := Header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
encrypted = c1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkControlCipher_Decrypt(b *testing.B) {
|
||||||
|
c1, c2 := newControlCipherForTesting()
|
||||||
|
|
||||||
|
h1 := Header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
encrypted = c1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
|
||||||
|
decrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
decrypted, _ = c2.Decrypt(encrypted, decrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
61
peer/cipher-data.go
Normal file
61
peer/cipher-data.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dataCipher struct {
|
||||||
|
key [32]byte
|
||||||
|
aead cipher.AEAD
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDataCipher() *dataCipher {
|
||||||
|
key := [32]byte{}
|
||||||
|
if _, err := rand.Read(key[:]); err != nil {
|
||||||
|
log.Fatalf("Failed to read random data: %v", err)
|
||||||
|
}
|
||||||
|
return newDataCipherFromKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDataCipherFromKey(key [32]byte) *dataCipher {
|
||||||
|
block, err := aes.NewCipher(key[:])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create new cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
aead, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create new GCM: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dataCipher{key: key, aead: aead}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *dataCipher) Key() [32]byte {
|
||||||
|
return sc.key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *dataCipher) Encrypt(h Header, data, out []byte) []byte {
|
||||||
|
const s = dataHeaderSize
|
||||||
|
out = out[:s+dataCipherOverhead+len(data)]
|
||||||
|
h.Marshal(out[:s])
|
||||||
|
sc.aead.Seal(out[s:s], out[:s], data, nil)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
|
||||||
|
const s = dataHeaderSize
|
||||||
|
if len(encrypted) < s+dataCipherOverhead {
|
||||||
|
ok = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
|
||||||
|
ok = err == nil
|
||||||
|
return
|
||||||
|
}
|
||||||
141
peer/cipher-data_test.go
Normal file
141
peer/cipher-data_test.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
mrand "math/rand/v2"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDataCipher(t *testing.T) {
|
||||||
|
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(maxSizePlaintext)
|
||||||
|
|
||||||
|
testCases := [][]byte{
|
||||||
|
make([]byte, 0),
|
||||||
|
{1},
|
||||||
|
{255},
|
||||||
|
{1, 2, 3, 4, 5},
|
||||||
|
[]byte("Hello world"),
|
||||||
|
maxSizePlaintext,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, plaintext := range testCases {
|
||||||
|
h1 := Header{
|
||||||
|
StreamID: dataStreamID,
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
h2 := Header{}
|
||||||
|
h2.Parse(encrypted)
|
||||||
|
|
||||||
|
dc2 := newDataCipherFromKey(dc1.Key())
|
||||||
|
|
||||||
|
decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(plaintext, decrypted) {
|
||||||
|
t.Fatal("not equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(h1, h2) {
|
||||||
|
t.Fatalf("%v != %v", h1, h2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataCipher_ModifyCiphertext(t *testing.T) {
|
||||||
|
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(maxSizePlaintext)
|
||||||
|
|
||||||
|
testCases := [][]byte{
|
||||||
|
make([]byte, 0),
|
||||||
|
{1},
|
||||||
|
{255},
|
||||||
|
{1, 2, 3, 4, 5},
|
||||||
|
[]byte("Hello world"),
|
||||||
|
maxSizePlaintext,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, plaintext := range testCases {
|
||||||
|
h1 := Header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
encrypted[mrand.IntN(len(encrypted))]++
|
||||||
|
|
||||||
|
dc2 := newDataCipherFromKey(dc1.Key())
|
||||||
|
|
||||||
|
_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataCipher_ShortCiphertext(t *testing.T) {
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
|
||||||
|
rand.Read(shortText)
|
||||||
|
_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDataCipher_Encrypt(b *testing.B) {
|
||||||
|
h1 := Header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDataCipher_Decrypt(b *testing.B) {
|
||||||
|
h1 := Header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
|
||||||
|
decrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
decrypted, _ = dc1.Decrypt(encrypted, decrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
46
peer/connreader.go
Normal file
46
peer/connreader.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnReader struct {
|
||||||
|
Globals
|
||||||
|
conn *net.UDPConn
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnReader(g Globals, conn *net.UDPConn) *ConnReader {
|
||||||
|
return &ConnReader{
|
||||||
|
Globals: g,
|
||||||
|
conn: conn,
|
||||||
|
buf: make([]byte, bufferSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ConnReader) Run() {
|
||||||
|
for {
|
||||||
|
r.handleNextPacket()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ConnReader) handleNextPacket() {
|
||||||
|
buf := r.buf[:bufferSize]
|
||||||
|
n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read from UDP port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < headerSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
|
||||||
|
|
||||||
|
buf = buf[:n]
|
||||||
|
h := parseHeader(buf)
|
||||||
|
|
||||||
|
r.RemotePeers[h.SourceIP].Load().HandlePacket(h, remoteAddr, buf)
|
||||||
|
}
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
// Package control implements the VPN-internal peer control protocol.
|
|
||||||
// Peers exchange Ping packets over UDP on the VPN control port to maintain
|
|
||||||
// liveness and discover external endpoints for direct connection attempts.
|
|
||||||
package control
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
version = 1
|
|
||||||
Size = 51 // 1 version + 8 PingTS + 6 SrcV4 + 18 SrcV6 + 18 Dst
|
|
||||||
)
|
|
||||||
|
|
||||||
// Ping is the single control packet type exchanged between VPN peers.
|
|
||||||
//
|
|
||||||
// In each peer pair, the peer with the lower VPN IP is the client: it sets
|
|
||||||
// PingTS and sends pings on a timer. The server echoes PingTS back in its
|
|
||||||
// response, allowing the client to compute RTT = now - PingTS.
|
|
||||||
//
|
|
||||||
// Both client and server populate SrcV4, SrcV6, and Dst on every packet so
|
|
||||||
// endpoint information flows in both directions.
|
|
||||||
//
|
|
||||||
// Dst is the recipient's external endpoint as observed by the sender from the
|
|
||||||
// WireGuard handshake source. Zero if the sender has not observed a handshake
|
|
||||||
// from the recipient.
|
|
||||||
type Ping struct {
|
|
||||||
PingTS int64 // Client ping send time in nanoseconds.
|
|
||||||
SrcV4 netip.AddrPort // Sender's discovered IPv4 address and port.
|
|
||||||
SrcV6 netip.AddrPort // Sender's discovered IPv6 address and port.
|
|
||||||
Dst netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal encodes p into buf (which must be at least Size bytes) and returns
|
|
||||||
// buf[:Size]. Taking the buffer lets callers reuse one across sends; every
|
|
||||||
// field is written unconditionally so a reused buffer needs no pre-zeroing.
|
|
||||||
func (p Ping) Marshal(buf []byte) []byte {
|
|
||||||
buf[0] = version
|
|
||||||
binary.BigEndian.PutUint64(buf[1:9], uint64(p.PingTS))
|
|
||||||
if p.SrcV4.IsValid() {
|
|
||||||
a4 := p.SrcV4.Addr().As4()
|
|
||||||
copy(buf[9:13], a4[:])
|
|
||||||
binary.BigEndian.PutUint16(buf[13:15], p.SrcV4.Port())
|
|
||||||
} else {
|
|
||||||
clear(buf[9:15])
|
|
||||||
}
|
|
||||||
a16 := p.SrcV6.Addr().As16()
|
|
||||||
copy(buf[15:31], a16[:])
|
|
||||||
binary.BigEndian.PutUint16(buf[31:33], p.SrcV6.Port())
|
|
||||||
a16 = p.Dst.Addr().As16()
|
|
||||||
copy(buf[33:49], a16[:])
|
|
||||||
binary.BigEndian.PutUint16(buf[49:51], p.Dst.Port())
|
|
||||||
return buf[:Size]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal decodes a Ping from a fixed-size 51-byte array.
|
|
||||||
func Unmarshal(buf [Size]byte) (Ping, error) {
|
|
||||||
if buf[0] != version {
|
|
||||||
return Ping{}, fmt.Errorf("unknown ping version %d", buf[0])
|
|
||||||
}
|
|
||||||
p := Ping{
|
|
||||||
PingTS: int64(binary.BigEndian.Uint64(buf[1:9])),
|
|
||||||
}
|
|
||||||
if addr := netip.AddrFrom4([4]byte(buf[9:13])); !addr.IsUnspecified() {
|
|
||||||
p.SrcV4 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[13:15]))
|
|
||||||
}
|
|
||||||
if addr := netip.AddrFrom16([16]byte(buf[15:31])); !addr.IsUnspecified() {
|
|
||||||
p.SrcV6 = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[31:33]))
|
|
||||||
}
|
|
||||||
if addr := netip.AddrFrom16([16]byte(buf[33:49])).Unmap(); !addr.IsUnspecified() {
|
|
||||||
p.Dst = netip.AddrPortFrom(addr, binary.BigEndian.Uint16(buf[49:51]))
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
package control_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRoundTrip(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
ping control.Ping
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "zero",
|
|
||||||
ping: control.Ping{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "client ping",
|
|
||||||
ping: control.Ping{
|
|
||||||
PingTS: 1234567890,
|
|
||||||
SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"),
|
|
||||||
Dst: netip.MustParseAddrPort("5.6.7.8:51820"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "server response",
|
|
||||||
ping: control.Ping{
|
|
||||||
PingTS: 1234567890,
|
|
||||||
SrcV4: netip.MustParseAddrPort("5.6.7.8:51820"),
|
|
||||||
Dst: netip.MustParseAddrPort("1.2.3.4:9999"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 only",
|
|
||||||
ping: control.Ping{
|
|
||||||
PingTS: 999,
|
|
||||||
SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"),
|
|
||||||
Dst: netip.MustParseAddrPort("[2001:db8::2]:51820"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "dual stack",
|
|
||||||
ping: control.Ping{
|
|
||||||
PingTS: 555,
|
|
||||||
SrcV4: netip.MustParseAddrPort("1.2.3.4:51820"),
|
|
||||||
SrcV6: netip.MustParseAddrPort("[2001:db8::1]:51820"),
|
|
||||||
Dst: netip.MustParseAddrPort("5.6.7.8:9999"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no src known",
|
|
||||||
ping: control.Ping{
|
|
||||||
Dst: netip.MustParseAddrPort("5.6.7.8:51820"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
var buf [control.Size]byte
|
|
||||||
tc.ping.Marshal(buf[:])
|
|
||||||
got, err := control.Unmarshal(buf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unmarshal: %v", err)
|
|
||||||
}
|
|
||||||
if got != tc.ping {
|
|
||||||
t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, tc.ping)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnmarshalBadVersion(t *testing.T) {
|
|
||||||
var buf [control.Size]byte
|
|
||||||
buf[0] = 99
|
|
||||||
if _, err := control.Unmarshal(buf); err == nil {
|
|
||||||
t.Fatal("expected error for unknown version, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestZeroEncoding(t *testing.T) {
|
|
||||||
var buf [control.Size]byte
|
|
||||||
(control.Ping{}).Marshal(buf[:])
|
|
||||||
for i, b := range buf {
|
|
||||||
if i == 0 {
|
|
||||||
continue // version byte
|
|
||||||
}
|
|
||||||
if b != 0 {
|
|
||||||
t.Fatalf("expected zero encoding at byte %d, got %d", i, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoleFor(t *testing.T) {
|
|
||||||
lo := netip.MustParseAddr("10.0.0.1")
|
|
||||||
hi := netip.MustParseAddr("10.0.0.2")
|
|
||||||
|
|
||||||
if control.RoleFor(lo, hi) != control.Client {
|
|
||||||
t.Error("lower IP should be client")
|
|
||||||
}
|
|
||||||
if control.RoleFor(hi, lo) != control.Server {
|
|
||||||
t.Error("higher IP should be server")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
package control
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
// Role identifies a peer's role in a ping exchange with a specific remote peer.
|
|
||||||
type Role string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Client initiates pings and measures RTT.
|
|
||||||
Client Role = "CLIENT"
|
|
||||||
// Server responds to pings.
|
|
||||||
Server Role = "SERVER"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RoleFor returns the Role of local relative to remote.
|
|
||||||
// The peer with the lower VPN IP is the client.
|
|
||||||
func RoleFor(local, remote netip.Addr) Role {
|
|
||||||
if local.Compare(remote) < 0 {
|
|
||||||
return Client
|
|
||||||
}
|
|
||||||
return Server
|
|
||||||
}
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ ControlConn = (*udpControlConn)(nil)
|
|
||||||
|
|
||||||
type udpControlConn struct {
|
|
||||||
conn *net.UDPConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUDPControlConn opens a UDP socket bound to localIP:port.
|
|
||||||
func newUDPControlConn(localIP netip.Addr, port uint16) (*udpControlConn, error) {
|
|
||||||
addr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(localIP, port))
|
|
||||||
conn, err := net.ListenUDP("udp4", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &udpControlConn{conn: conn}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpControlConn) SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error {
|
|
||||||
_, err := c.conn.WriteToUDP(ping.Marshal(buf), net.UDPAddrFromAddrPort(dst))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// run reads incoming ping packets and forwards them to ch until ctx is done.
|
|
||||||
// Call this in a goroutine before starting the App event loop.
|
|
||||||
func (c *udpControlConn) run(ch chan<- PingEvent) {
|
|
||||||
var buf [control.Size]byte
|
|
||||||
for {
|
|
||||||
n, src, err := c.conn.ReadFromUDP(buf[:])
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("control read: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if n != control.Size {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ping, err := control.Unmarshal(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("control unmarshal: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP, ok := netip.AddrFromSlice(src.IP)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ch <- PingEvent{srcVPNIP: srcIP.Unmap(), ping: ping}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *udpControlConn) Close() error {
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
64
peer/controlmessage.go
Normal file
64
peer/controlmessage.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"vppn/m"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type controlMsg[T any] struct {
|
||||||
|
SrcIP byte
|
||||||
|
SrcAddr netip.AddrPort
|
||||||
|
Packet T
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) {
|
||||||
|
switch buf[0] {
|
||||||
|
|
||||||
|
case packetTypeInit:
|
||||||
|
packet, err := parsePacketInit(buf)
|
||||||
|
return controlMsg[packetInit]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
case packetTypeSyn:
|
||||||
|
packet, err := parsePacketSyn(buf)
|
||||||
|
return controlMsg[packetSyn]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
case packetTypeAck:
|
||||||
|
packet, err := parsePacketAck(buf)
|
||||||
|
return controlMsg[packetAck]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
case packetTypeProbe:
|
||||||
|
packet, err := parsePacketProbe(buf)
|
||||||
|
return controlMsg[packetProbe]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errUnknownPacketType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type peerUpdateMsg struct {
|
||||||
|
Peer *m.Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type pingTimerMsg struct{}
|
||||||
30
peer/crypto.go
Normal file
30
peer/crypto.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/nacl/box"
|
||||||
|
"golang.org/x/crypto/nacl/sign"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cryptoKeys struct {
|
||||||
|
PubKey []byte
|
||||||
|
PrivKey []byte
|
||||||
|
PubSignKey []byte
|
||||||
|
PrivSignKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateKeys() cryptoKeys {
|
||||||
|
pubKey, privKey, err := box.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate encryption keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate signing keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]}
|
||||||
|
}
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"net/netip"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// devRetry calls fn up to 6 times with exponential backoff, retrying on EBUSY
|
|
||||||
// (transient netlink contention during WireGuard handshake/rekey). Fatal on
|
|
||||||
// any other error.
|
|
||||||
func devRetry(vpnIP netip.Addr, op string, fn func() error) {
|
|
||||||
const attempts = 6
|
|
||||||
timeout := 10 * time.Millisecond
|
|
||||||
for i := range attempts {
|
|
||||||
err := fn()
|
|
||||||
if err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if errors.Is(err, syscall.EBUSY) && i < attempts-1 {
|
|
||||||
time.Sleep(timeout)
|
|
||||||
timeout *= 2
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Fatalf("%s %v: %v", op, vpnIP, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devPeers() []wgtypes.Peer {
|
|
||||||
peers, err := a.dev.Peers()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get peers %v: %v", a.vpnIP, err)
|
|
||||||
}
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devAddPeer(p *Peer) {
|
|
||||||
log.Printf("RELAYED: %s - %s ", p.Name, p.VPNIP.String())
|
|
||||||
devRetry(p.VPNIP, "AddPeer", func() error { return a.dev.AddPeer(p.PubKey()) })
|
|
||||||
p.State = StateRelayed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devAddDirect(p *Peer, endpoint netip.AddrPort) {
|
|
||||||
log.Printf("DIRECT: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
|
||||||
devRetry(p.VPNIP, "AddDirect", func() error { return a.dev.AddDirect(p.PubKey(), endpoint, p.VPNIP) })
|
|
||||||
p.State = StateDirect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devSetRelay(p *Peer, endpoint netip.AddrPort) {
|
|
||||||
log.Printf("RELAY: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
|
||||||
devRetry(p.VPNIP, "SetRelay", func() error { return a.dev.SetRelay(p.PubKey(), endpoint, a.vpnNet) })
|
|
||||||
p.State = StateDirect // Dirrect connection. The app marks peer as relay.
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devPromote(p *Peer) {
|
|
||||||
ep := p.WGEndpoint()
|
|
||||||
if ep.IsValid() {
|
|
||||||
log.Printf("PROMOTED: %s - %s @ %s", p.Name, p.VPNIP.String(), p.WGEndpoint().String())
|
|
||||||
} else {
|
|
||||||
log.Printf("PROMOTED: %s - %s (no IP)", p.Name, p.VPNIP.String())
|
|
||||||
}
|
|
||||||
devRetry(p.VPNIP, "Promote", func() error { return a.dev.Promote(p.PubKey(), p.VPNIP) })
|
|
||||||
p.State = StateDirect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devAddProbe(p *Peer, endpoint netip.AddrPort) {
|
|
||||||
log.Printf("PROBE: %s - %s @ %s", p.Name, p.VPNIP.String(), endpoint.String())
|
|
||||||
devRetry(p.VPNIP, "AddProbe", func() error { return a.dev.AddProbe(p.PubKey(), endpoint) })
|
|
||||||
p.State = StateProbing
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) devRemove(p *Peer) {
|
|
||||||
log.Printf("REMOVED: %s - %s", p.Name, p.VPNIP.String())
|
|
||||||
devRetry(p.VPNIP, "RemovePeer", func() error { return a.dev.RemovePeer(p.PubKey()) })
|
|
||||||
}
|
|
||||||
76
peer/dupcheck.go
Normal file
76
peer/dupcheck.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
type dupCheck struct {
|
||||||
|
bitSet
|
||||||
|
head int
|
||||||
|
tail int
|
||||||
|
headCounter uint64
|
||||||
|
tailCounter uint64 // Also next expected counter value.
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDupCheck(headCounter uint64) *dupCheck {
|
||||||
|
return &dupCheck{
|
||||||
|
headCounter: headCounter,
|
||||||
|
tailCounter: headCounter + 1,
|
||||||
|
tail: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dc *dupCheck) IsDup(counter uint64) bool {
|
||||||
|
|
||||||
|
// Before head => it's late, say it's a dup.
|
||||||
|
if counter < dc.headCounter {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's within the counter bounds.
|
||||||
|
if counter < dc.tailCounter {
|
||||||
|
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
|
||||||
|
if dc.Get(index) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.Set(index)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's more than 1 beyond the tail.
|
||||||
|
delta := counter - dc.tailCounter
|
||||||
|
|
||||||
|
// Full clear.
|
||||||
|
if delta >= bitSetSize-1 {
|
||||||
|
dc.ClearAll()
|
||||||
|
dc.Set(0)
|
||||||
|
|
||||||
|
dc.tail = 1
|
||||||
|
dc.head = 2
|
||||||
|
dc.tailCounter = counter + 1
|
||||||
|
dc.headCounter = dc.tailCounter - bitSetSize + 1
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear if necessary.
|
||||||
|
for range delta {
|
||||||
|
dc.put(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.put(true)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dc *dupCheck) put(set bool) {
|
||||||
|
if set {
|
||||||
|
dc.Set(dc.tail)
|
||||||
|
} else {
|
||||||
|
dc.Clear(dc.tail)
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.tail = (dc.tail + 1) % bitSetSize
|
||||||
|
dc.tailCounter++
|
||||||
|
|
||||||
|
if dc.head == dc.tail {
|
||||||
|
dc.head = (dc.head + 1) % bitSetSize
|
||||||
|
dc.headCounter++
|
||||||
|
}
|
||||||
|
}
|
||||||
57
peer/dupcheck_test.go
Normal file
57
peer/dupcheck_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDupCheck(t *testing.T) {
|
||||||
|
dc := newDupCheck(0)
|
||||||
|
|
||||||
|
for i := range bitSetSize {
|
||||||
|
if dc.IsDup(uint64(i)) {
|
||||||
|
t.Fatal("!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
Counter uint64
|
||||||
|
Dup bool
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []TestCase{
|
||||||
|
{511, true},
|
||||||
|
{0, true},
|
||||||
|
{1, true},
|
||||||
|
{2, true},
|
||||||
|
{3, true},
|
||||||
|
{63, true},
|
||||||
|
{256, true},
|
||||||
|
{510, true},
|
||||||
|
{511, true},
|
||||||
|
{512, false},
|
||||||
|
{0, true},
|
||||||
|
{512, true},
|
||||||
|
{513, false},
|
||||||
|
{517, false},
|
||||||
|
{512, true},
|
||||||
|
{513, true},
|
||||||
|
{514, false},
|
||||||
|
{515, false},
|
||||||
|
{516, false},
|
||||||
|
{517, true},
|
||||||
|
{2512, false},
|
||||||
|
{2512, true},
|
||||||
|
{2001, true},
|
||||||
|
{2002, false},
|
||||||
|
{2002, true},
|
||||||
|
{4000, false},
|
||||||
|
{4000 - 511, true}, // Too old.
|
||||||
|
{4000 - 510, false}, // Just in the window.
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
if ok := dc.IsDup(tc.Counter); ok != tc.Dup {
|
||||||
|
t.Fatal(i, ok, tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
8
peer/errors.go
Normal file
8
peer/errors.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMalformedPacket = errors.New("malformed packet")
|
||||||
|
errUnknownPacketType = errors.New("unknown packet type")
|
||||||
|
)
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
)
|
|
||||||
|
|
||||||
type sentPing struct {
|
|
||||||
Dst netip.AddrPort
|
|
||||||
Ping control.Ping
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeControlConn struct {
|
|
||||||
Sent []sentPing
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeControlConn) SendPing(dst netip.AddrPort, ping control.Ping, _ []byte) error {
|
|
||||||
f.Sent = append(f.Sent, sentPing{Dst: dst, Ping: ping})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeControlConn) AssertNone(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
if len(f.Sent) != 0 {
|
|
||||||
t.Fatalf("expected no pings sent, got %d: %v", len(f.Sent), f.Sent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeControlConn) AssertSent(t *testing.T, i int, dst netip.AddrPort, ping control.Ping) {
|
|
||||||
t.Helper()
|
|
||||||
if i >= len(f.Sent) {
|
|
||||||
t.Fatalf("no ping at index %d (have %d)", i, len(f.Sent))
|
|
||||||
}
|
|
||||||
got := f.Sent[i]
|
|
||||||
if got.Dst != dst {
|
|
||||||
t.Errorf("ping[%d].Dst = %v, want %v", i, got.Dst, dst)
|
|
||||||
}
|
|
||||||
if got.Ping != ping {
|
|
||||||
t.Errorf("ping[%d].Ping = %+v, want %+v", i, got.Ping, ping)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// fakeWGDevice records every call made to it. It is safe to read Calls after
|
|
||||||
// the event loop has processed the event under test (single-threaded loop
|
|
||||||
// means no extra synchronisation needed, but the mutex guards concurrent test
|
|
||||||
// helpers if needed).
|
|
||||||
type fakeWGDevice struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
Calls []fakeCall
|
|
||||||
peers []wgtypes.Peer
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeCall struct {
|
|
||||||
Method string
|
|
||||||
PubKey wgtypes.Key
|
|
||||||
Endpoint netip.AddrPort
|
|
||||||
VPNiP netip.Addr
|
|
||||||
Network netip.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) record(c fakeCall) {
|
|
||||||
f.mu.Lock()
|
|
||||||
f.Calls = append(f.Calls, c)
|
|
||||||
f.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) Name() string { return "wg-test" }
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) Peers() ([]wgtypes.Peer, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
out := make([]wgtypes.Peer, len(f.peers))
|
|
||||||
copy(out, f.peers)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AddPeer(pubKey wgtypes.Key) error {
|
|
||||||
f.record(fakeCall{Method: "AddPeer", PubKey: pubKey})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error {
|
|
||||||
f.record(fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error {
|
|
||||||
f.record(fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error {
|
|
||||||
f.record(fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error {
|
|
||||||
f.record(fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) RemovePeer(pubKey wgtypes.Key) error {
|
|
||||||
f.record(fakeCall{Method: "RemovePeer", PubKey: pubKey})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssertNoCalls fails the test if any dev calls were recorded.
|
|
||||||
func (f *fakeWGDevice) AssertNoCalls(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
if len(f.Calls) != 0 {
|
|
||||||
t.Fatalf("unexpected dev calls: %v", f.Calls)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AssertAddPeer(t *testing.T, i int, pubKey wgtypes.Key) {
|
|
||||||
t.Helper()
|
|
||||||
f.assertCall(t, i, fakeCall{Method: "AddPeer", PubKey: pubKey})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AssertAddDirect(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) {
|
|
||||||
t.Helper()
|
|
||||||
f.assertCall(t, i, fakeCall{Method: "AddDirect", PubKey: pubKey, Endpoint: endpoint, VPNiP: vpnIP})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AssertSetRelay(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) {
|
|
||||||
t.Helper()
|
|
||||||
f.assertCall(t, i, fakeCall{Method: "SetRelay", PubKey: pubKey, Endpoint: endpoint, Network: network})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AssertAddProbe(t *testing.T, i int, pubKey wgtypes.Key, endpoint netip.AddrPort) {
|
|
||||||
t.Helper()
|
|
||||||
f.assertCall(t, i, fakeCall{Method: "AddProbe", PubKey: pubKey, Endpoint: endpoint})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AssertPromote(t *testing.T, i int, pubKey wgtypes.Key, vpnIP netip.Addr) {
|
|
||||||
t.Helper()
|
|
||||||
f.assertCall(t, i, fakeCall{Method: "Promote", PubKey: pubKey, VPNiP: vpnIP})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) AssertRemovePeer(t *testing.T, i int, pubKey wgtypes.Key) {
|
|
||||||
t.Helper()
|
|
||||||
f.assertCall(t, i, fakeCall{Method: "RemovePeer", PubKey: pubKey})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeWGDevice) assertCall(t *testing.T, i int, c fakeCall) {
|
|
||||||
t.Helper()
|
|
||||||
if len(f.Calls) <= i {
|
|
||||||
t.Fatalf("no call at index %d: %v", i, c)
|
|
||||||
}
|
|
||||||
if c != f.Calls[i] {
|
|
||||||
t.Fatalf("call[%d]: got %v, want %v", i, f.Calls[i], c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
115
peer/files.go
Normal file
115
peer/files.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"vppn/m"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LocalConfig struct {
|
||||||
|
LocalPeerIP byte
|
||||||
|
Network []byte
|
||||||
|
PubKey []byte
|
||||||
|
PrivKey []byte
|
||||||
|
PubSignKey []byte
|
||||||
|
PrivSignKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type startupCount struct {
|
||||||
|
Count uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func configDir(netName string) string {
|
||||||
|
d, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get user home directory: %v", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(d, ".vppn", netName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockFilePath(netName string) string {
|
||||||
|
return filepath.Join(configDir(netName), "__lock__")
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerConfigPath(netName string) string {
|
||||||
|
return filepath.Join(configDir(netName), "config.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerStatePath(netName string) string {
|
||||||
|
return filepath.Join(configDir(netName), "state.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func startupCountPath(netName string) string {
|
||||||
|
return filepath.Join(configDir(netName), "startup_count.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func statusSocketPath(netName string) string {
|
||||||
|
return filepath.Join(configDir(netName), "status.sock")
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeJson(x any, outPath string) error {
|
||||||
|
outDir := filepath.Dir(outPath)
|
||||||
|
_ = os.MkdirAll(outDir, 0700)
|
||||||
|
|
||||||
|
tmpPath := outPath + ".tmp"
|
||||||
|
buf, err := json.Marshal(x)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(tmpPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := f.Write(buf); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := f.Sync(); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Rename(tmpPath, outPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func storePeerConfig(netName string, pc LocalConfig) error {
|
||||||
|
return storeJson(pc, peerConfigPath(netName))
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeNetworkState(netName string, ps m.NetworkState) error {
|
||||||
|
return storeJson(ps, peerStatePath(netName))
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadJson(dataPath string, ptr any) error {
|
||||||
|
data, err := os.ReadFile(dataPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(data, ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadPeerConfig(netName string) (pc LocalConfig, err error) {
|
||||||
|
return pc, loadJson(peerConfigPath(netName), &pc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadNetworkState(netName string) (ps m.NetworkState, err error) {
|
||||||
|
return ps, loadJson(peerStatePath(netName), &ps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadStartupCount(netName string) (c startupCount, err error) {
|
||||||
|
return c, loadJson(startupCountPath(netName), &c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeStartupCount(netName string, c startupCount) error {
|
||||||
|
return storeJson(c, startupCountPath(netName))
|
||||||
|
}
|
||||||
57
peer/files_test.go
Normal file
57
peer/files_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilePaths(t *testing.T) {
|
||||||
|
confDir := configDir("netName")
|
||||||
|
if filepath.Base(confDir) != "netName" {
|
||||||
|
t.Fatal(confDir)
|
||||||
|
}
|
||||||
|
if filepath.Base(filepath.Dir(confDir)) != ".vppn" {
|
||||||
|
t.Fatal(confDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := peerConfigPath("netName")
|
||||||
|
if path != filepath.Join(confDir, "config.json") {
|
||||||
|
t.Fatal(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
path = peerStatePath("netName")
|
||||||
|
if path != filepath.Join(confDir, "state.json") {
|
||||||
|
t.Fatal(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreLoadJson(t *testing.T) {
|
||||||
|
type Object struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
Price float64
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
outPath := filepath.Join(tmpDir, "object.json")
|
||||||
|
|
||||||
|
obj := Object{
|
||||||
|
Name: "Jason",
|
||||||
|
Age: 22,
|
||||||
|
Price: 123.534,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := storeJson(obj, outPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
obj2 := Object{}
|
||||||
|
if err := loadJson(outPath, &obj2); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(obj, obj2) {
|
||||||
|
t.Fatal(obj, obj2)
|
||||||
|
}
|
||||||
|
}
|
||||||
109
peer/globals.go
Normal file
109
peer/globals.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
version = 1
|
||||||
|
|
||||||
|
bufferSize = 8192 // Enough for data packets and encryption buffers.
|
||||||
|
|
||||||
|
if_mtu = 1200
|
||||||
|
if_queue_len = 2048
|
||||||
|
|
||||||
|
controlCipherOverhead = 16
|
||||||
|
dataCipherOverhead = 16
|
||||||
|
signingOverhead = 64
|
||||||
|
|
||||||
|
pingInterval = 8 * time.Second
|
||||||
|
timeoutInterval = 30 * time.Second
|
||||||
|
broadcastInterval = 16 * time.Second
|
||||||
|
broadcastErrorTimeoutInterval = 8 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
||||||
|
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
|
||||||
|
4560))
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type Globals struct {
|
||||||
|
LocalConfig // Embed, immutable.
|
||||||
|
|
||||||
|
// The number of startups
|
||||||
|
StartupCount uint16
|
||||||
|
|
||||||
|
// Local public address (if available). Immutable.
|
||||||
|
LocalAddr netip.AddrPort
|
||||||
|
|
||||||
|
// True if local public address is valid. Immutable.
|
||||||
|
LocalAddrValid bool
|
||||||
|
|
||||||
|
// All remote peers by VPN IP.
|
||||||
|
RemotePeers [256]*atomic.Pointer[Remote]
|
||||||
|
|
||||||
|
// Discovered public addresses.
|
||||||
|
PubAddrs *pubAddrStore
|
||||||
|
|
||||||
|
// Attempts to ensure that we have a relay available.
|
||||||
|
RelayHandler *relayHandler
|
||||||
|
|
||||||
|
// Send UDP - Global function to write UDP packets.
|
||||||
|
SendUDP func(b []byte, addr netip.AddrPort) (n int, err error)
|
||||||
|
|
||||||
|
// Global TUN interface.
|
||||||
|
IFace io.ReadWriteCloser
|
||||||
|
|
||||||
|
// For trace ID.
|
||||||
|
NewTraceID func() uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGlobals(
|
||||||
|
localConfig LocalConfig,
|
||||||
|
startupCount startupCount,
|
||||||
|
localAddr netip.AddrPort,
|
||||||
|
conn *net.UDPConn,
|
||||||
|
iface io.ReadWriteCloser,
|
||||||
|
) (g Globals) {
|
||||||
|
g.LocalConfig = localConfig
|
||||||
|
g.StartupCount = startupCount.Count
|
||||||
|
|
||||||
|
g.LocalAddr = localAddr
|
||||||
|
g.LocalAddrValid = localAddr.IsValid()
|
||||||
|
|
||||||
|
g.PubAddrs = newPubAddrStore(localAddr)
|
||||||
|
|
||||||
|
g.RelayHandler = newRelayHandler()
|
||||||
|
|
||||||
|
// Use a lock here avoids starvation, at least on my Linux machine.
|
||||||
|
sendLock := sync.Mutex{}
|
||||||
|
g.SendUDP = func(b []byte, addr netip.AddrPort) (int, error) {
|
||||||
|
sendLock.Lock()
|
||||||
|
n, err := conn.WriteToUDPAddrPort(b, addr)
|
||||||
|
sendLock.Unlock()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
g.IFace = iface
|
||||||
|
|
||||||
|
traceID := (uint64(g.StartupCount) << 48) + 1
|
||||||
|
g.NewTraceID = func() uint64 {
|
||||||
|
return atomic.AddUint64(&traceID, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range g.RemotePeers {
|
||||||
|
g.RemotePeers[i] = &atomic.Pointer[Remote]{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range g.RemotePeers {
|
||||||
|
g.RemotePeers[i].Store(newRemote(g, byte(i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return g
|
||||||
|
}
|
||||||
47
peer/header.go
Normal file
47
peer/header.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
const (
|
||||||
|
headerSize = 12
|
||||||
|
controlHeaderSize = 24
|
||||||
|
dataHeaderSize = 12
|
||||||
|
|
||||||
|
dataStreamID = 1
|
||||||
|
controlStreamID = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Header struct {
|
||||||
|
Version byte
|
||||||
|
StreamID byte
|
||||||
|
SourceIP byte
|
||||||
|
DestIP byte
|
||||||
|
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeader(b []byte) (h Header) {
|
||||||
|
h.Version = b[0]
|
||||||
|
h.StreamID = b[1]
|
||||||
|
h.SourceIP = b[2]
|
||||||
|
h.DestIP = b[3]
|
||||||
|
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Parse(b []byte) {
|
||||||
|
h.Version = b[0]
|
||||||
|
h.StreamID = b[1]
|
||||||
|
h.SourceIP = b[2]
|
||||||
|
h.DestIP = b[3]
|
||||||
|
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Marshal(buf []byte) {
|
||||||
|
buf[0] = h.Version
|
||||||
|
buf[1] = h.StreamID
|
||||||
|
buf[2] = h.SourceIP
|
||||||
|
buf[3] = h.DestIP
|
||||||
|
*(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter
|
||||||
|
}
|
||||||
21
peer/header_test.go
Normal file
21
peer/header_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestHeaderMarshalParse(t *testing.T) {
|
||||||
|
nIn := Header{
|
||||||
|
StreamID: 23,
|
||||||
|
Counter: 3212,
|
||||||
|
SourceIP: 34,
|
||||||
|
DestIP: 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
nIn.Marshal(buf)
|
||||||
|
|
||||||
|
nOut := Header{}
|
||||||
|
nOut.Parse(buf)
|
||||||
|
if nIn != nOut {
|
||||||
|
t.Fatal(nIn, nOut)
|
||||||
|
}
|
||||||
|
}
|
||||||
128
peer/hosts.go
128
peer/hosts.go
@@ -1,128 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"git.crumpington.com/lib/go/flock"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
hostsFile = "/etc/hosts"
|
|
||||||
hostsBegin = "# BEGIN vppn"
|
|
||||||
hostsEnd = "# END vppn"
|
|
||||||
)
|
|
||||||
|
|
||||||
// hostMarkers returns the begin/end marker lines that delimit the managed
|
|
||||||
// section for localDomain. The domain is wrapped in parentheses so one domain's
|
|
||||||
// marker can never be a prefix of another's (e.g. "net" vs "net2") when
|
|
||||||
// multiple vppn instances share /etc/hosts.
|
|
||||||
func hostMarkers(localDomain string) (begin, end string) {
|
|
||||||
return hostsBegin + "(" + localDomain + ")", hostsEnd + "(" + localDomain + ")"
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateHosts rewrites the managed vppn section in /etc/hosts using the
|
|
||||||
// current peersByIP map. Peers without a Name are skipped.
|
|
||||||
func (a *App) updateHosts() {
|
|
||||||
if a.localDomain == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := updateHosts(hostsFile, a.localDomain, a.peersByIP); err != nil {
|
|
||||||
log.Printf("Failed to update hosts file: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateHosts(hostsPath, localDomain string, peers map[netip.Addr]*Peer) error {
|
|
||||||
lockFile, err := flock.Lock(hostsPath + ".vppn.lock")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer lockFile.Close()
|
|
||||||
|
|
||||||
begin, end := hostMarkers(localDomain)
|
|
||||||
|
|
||||||
info, err := os.Stat(hostsPath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
raw, err := os.ReadFile(hostsPath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
data := string(raw)
|
|
||||||
|
|
||||||
before := strings.TrimSpace(data)
|
|
||||||
after := ""
|
|
||||||
|
|
||||||
if idxBegin := strings.Index(data, begin); idxBegin != -1 {
|
|
||||||
idxEnd := strings.Index(data[idxBegin:], end)
|
|
||||||
if idxEnd != -1 {
|
|
||||||
after = strings.TrimSpace(data[idxBegin+idxEnd+len(end):])
|
|
||||||
}
|
|
||||||
before = strings.TrimSpace(data[:idxBegin])
|
|
||||||
}
|
|
||||||
|
|
||||||
b := strings.Builder{}
|
|
||||||
b.WriteString(before)
|
|
||||||
b.WriteRune('\n')
|
|
||||||
b.WriteString(after)
|
|
||||||
b.WriteRune('\n')
|
|
||||||
b.WriteRune('\n')
|
|
||||||
|
|
||||||
b.WriteString(begin)
|
|
||||||
b.WriteRune('\n')
|
|
||||||
|
|
||||||
// Collect entries so we can sort by IP for stable output. Pad the IP
|
|
||||||
// column to the width of the widest possible address ("255.255.255.255")
|
|
||||||
// for readability.
|
|
||||||
type entry struct {
|
|
||||||
ip netip.Addr
|
|
||||||
host string
|
|
||||||
}
|
|
||||||
var entries []entry
|
|
||||||
for ip, p := range peers {
|
|
||||||
if p.Name == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entries = append(entries, entry{ip: ip, host: p.Name + "." + localDomain})
|
|
||||||
}
|
|
||||||
sort.Slice(entries, func(i, j int) bool {
|
|
||||||
return entries[i].ip.Less(entries[j].ip)
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, e := range entries {
|
|
||||||
b.WriteString(fmt.Sprintf("%-15s %s\n", e.ip.String(), e.host))
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteString(end)
|
|
||||||
b.WriteRune('\n')
|
|
||||||
|
|
||||||
// Write to a temp file in the same directory, then rename over the
|
|
||||||
// original so readers never observe a partial file. Preserve the
|
|
||||||
// original's mode and ownership, since rename replaces the inode.
|
|
||||||
tmpPath := hostsPath + ".vppn.tmp"
|
|
||||||
if err := os.WriteFile(tmpPath, []byte(b.String()), info.Mode().Perm()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if st, ok := info.Sys().(*syscall.Stat_t); ok {
|
|
||||||
if err := os.Chown(tmpPath, int(st.Uid), int(st.Gid)); err != nil {
|
|
||||||
os.Remove(tmpPath)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.Rename(tmpPath, hostsPath); err != nil {
|
|
||||||
os.Remove(tmpPath)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeTempHosts creates a temp hosts file with the given content and returns
|
|
||||||
// its path.
|
|
||||||
func writeTempHosts(t *testing.T, content string) string {
|
|
||||||
t.Helper()
|
|
||||||
path := filepath.Join(t.TempDir(), "hosts")
|
|
||||||
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
|
|
||||||
// readManagedSection returns the lines between the begin/end markers for the
|
|
||||||
// given localDomain, plus everything outside the section ("outside").
|
|
||||||
func readManagedSection(t *testing.T, path, localDomain string) (inside, outside []string) {
|
|
||||||
t.Helper()
|
|
||||||
raw, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
begin, end := hostMarkers(localDomain)
|
|
||||||
|
|
||||||
inSection := false
|
|
||||||
for _, line := range strings.Split(string(raw), "\n") {
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(line, begin):
|
|
||||||
inSection = true
|
|
||||||
case strings.HasPrefix(line, end):
|
|
||||||
inSection = false
|
|
||||||
case inSection:
|
|
||||||
if f := strings.Join(strings.Fields(line), " "); f != "" {
|
|
||||||
inside = append(inside, f)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if f := strings.Join(strings.Fields(line), " "); f != "" {
|
|
||||||
outside = append(outside, f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return inside, outside
|
|
||||||
}
|
|
||||||
|
|
||||||
func peer(name string) *Peer {
|
|
||||||
return &Peer{Name: name}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateHosts_AddsSection(t *testing.T) {
|
|
||||||
path := writeTempHosts(t, "127.0.0.1 localhost\n")
|
|
||||||
|
|
||||||
peers := map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.11.12.1"): peer("hub"),
|
|
||||||
netip.MustParseAddr("10.11.12.10"): peer("laptop"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := updateHosts(path, "mynet.local", peers); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
inside, outside := readManagedSection(t, path, "mynet.local")
|
|
||||||
|
|
||||||
sort.Strings(inside)
|
|
||||||
want := []string{
|
|
||||||
"10.11.12.1 hub.mynet.local",
|
|
||||||
"10.11.12.10 laptop.mynet.local",
|
|
||||||
}
|
|
||||||
if strings.Join(inside, "\n") != strings.Join(want, "\n") {
|
|
||||||
t.Errorf("managed section = %v, want %v", inside, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !contains(outside, "127.0.0.1 localhost") {
|
|
||||||
t.Errorf("original content lost; outside = %v", outside)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateHosts_ReplacesExistingSection(t *testing.T) {
|
|
||||||
path := writeTempHosts(t, "127.0.0.1 localhost\n")
|
|
||||||
|
|
||||||
// First write.
|
|
||||||
first := map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.11.12.1"): peer("hub"),
|
|
||||||
}
|
|
||||||
if err := updateHosts(path, "mynet.local", first); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second write with a different set of peers.
|
|
||||||
second := map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.11.12.20"): peer("phone"),
|
|
||||||
}
|
|
||||||
if err := updateHosts(path, "mynet.local", second); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
inside, outside := readManagedSection(t, path, "mynet.local")
|
|
||||||
|
|
||||||
if len(inside) != 1 || inside[0] != "10.11.12.20 phone.mynet.local" {
|
|
||||||
t.Errorf("section not replaced; inside = %v", inside)
|
|
||||||
}
|
|
||||||
if contains(inside, "10.11.12.1 hub.mynet.local") {
|
|
||||||
t.Errorf("stale entry remained; inside = %v", inside)
|
|
||||||
}
|
|
||||||
if !contains(outside, "127.0.0.1 localhost") {
|
|
||||||
t.Errorf("original content lost; outside = %v", outside)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateHosts_SkipsEmptyNames(t *testing.T) {
|
|
||||||
path := writeTempHosts(t, "127.0.0.1 localhost\n")
|
|
||||||
|
|
||||||
peers := map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.11.12.1"): peer("hub"),
|
|
||||||
netip.MustParseAddr("10.11.12.99"): peer(""), // no name
|
|
||||||
}
|
|
||||||
if err := updateHosts(path, "mynet.local", peers); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
inside, _ := readManagedSection(t, path, "mynet.local")
|
|
||||||
if len(inside) != 1 || inside[0] != "10.11.12.1 hub.mynet.local" {
|
|
||||||
t.Errorf("expected only named peer; inside = %v", inside)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateHosts_Idempotent(t *testing.T) {
|
|
||||||
path := writeTempHosts(t, "127.0.0.1 localhost\n")
|
|
||||||
|
|
||||||
peers := map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.11.12.1"): peer("hub"),
|
|
||||||
}
|
|
||||||
if err := updateHosts(path, "mynet.local", peers); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
first, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := updateHosts(path, "mynet.local", peers); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
second, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if string(first) != string(second) {
|
|
||||||
t.Errorf("repeated update changed file:\nfirst:\n%s\nsecond:\n%s", first, second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestUpdateHosts_PrefixDomainsCoexist guards finding 4.4: two domains where
|
|
||||||
// one label is a prefix of the other ("net" vs "net2") must each manage their
|
|
||||||
// own section without clobbering the other's, even sharing one hosts file.
|
|
||||||
func TestUpdateHosts_PrefixDomainsCoexist(t *testing.T) {
|
|
||||||
path := writeTempHosts(t, "127.0.0.1 localhost\n")
|
|
||||||
|
|
||||||
if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.0.2.1"): peer("a"),
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := updateHosts(path, "net.local", map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.0.1.1"): peer("b"),
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Both sections coexist after writing the prefix domain.
|
|
||||||
if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.1 a.net2.local" {
|
|
||||||
t.Errorf("net2 section clobbered: %v", in)
|
|
||||||
}
|
|
||||||
if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" {
|
|
||||||
t.Errorf("net section wrong: %v", in)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-updating net2 must not disturb the net section.
|
|
||||||
if err := updateHosts(path, "net2.local", map[netip.Addr]*Peer{
|
|
||||||
netip.MustParseAddr("10.0.2.2"): peer("c"),
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if in, _ := readManagedSection(t, path, "net.local"); len(in) != 1 || in[0] != "10.0.1.1 b.net.local" {
|
|
||||||
t.Errorf("net section disturbed by net2 update: %v", in)
|
|
||||||
}
|
|
||||||
if in, _ := readManagedSection(t, path, "net2.local"); len(in) != 1 || in[0] != "10.0.2.2 c.net2.local" {
|
|
||||||
t.Errorf("net2 section not updated: %v", in)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func contains(ss []string, s string) bool {
|
|
||||||
for _, x := range ss {
|
|
||||||
if x == s {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,153 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/netip"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
)
|
|
||||||
|
|
||||||
const hubPollInterval = 64 * time.Second
|
|
||||||
|
|
||||||
type HubPoller struct {
|
|
||||||
selfVPNIP netip.Addr
|
|
||||||
vpnNet netip.Prefix
|
|
||||||
hubURL string
|
|
||||||
apiKey string
|
|
||||||
statePath string // where the network state cache is persisted
|
|
||||||
addCh chan<- m.Peer
|
|
||||||
removeCh chan<- wgtypes.Key
|
|
||||||
known map[wgtypes.Key]struct{} // pubKeys currently configured
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHubPoller(
|
|
||||||
selfVPNIP netip.Addr,
|
|
||||||
vpnNet netip.Prefix,
|
|
||||||
hubURL, apiKey string,
|
|
||||||
statePath string,
|
|
||||||
addCh chan<- m.Peer,
|
|
||||||
removeCh chan<- wgtypes.Key,
|
|
||||||
) (*HubPoller, error) {
|
|
||||||
u, err := url.Parse(hubURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
u.Path = "/peer/fetch-state/"
|
|
||||||
|
|
||||||
return &HubPoller{
|
|
||||||
selfVPNIP: selfVPNIP,
|
|
||||||
vpnNet: vpnNet,
|
|
||||||
hubURL: u.String(),
|
|
||||||
apiKey: apiKey,
|
|
||||||
statePath: statePath,
|
|
||||||
addCh: addCh,
|
|
||||||
removeCh: removeCh,
|
|
||||||
known: make(map[wgtypes.Key]struct{}),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hp *HubPoller) Run() {
|
|
||||||
// Prime from the on-disk cache before reaching the hub, so the peer
|
|
||||||
// configures WireGuard from its last known state even if the hub is down.
|
|
||||||
// known starts empty, so this emits every cached peer as an add; the first
|
|
||||||
// real poll then emits only deltas (adds for new peers, removes for gone).
|
|
||||||
if state, err := loadNetworkState(hp.statePath); err == nil {
|
|
||||||
hp.apply(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
hp.poll()
|
|
||||||
for range time.Tick(hubPollInterval) {
|
|
||||||
hp.poll()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hp *HubPoller) poll() {
|
|
||||||
req, err := http.NewRequest(http.MethodGet, hp.hubURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[HubPoller] build request: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.SetBasicAuth("", hp.apiKey)
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 32 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[HubPoller] fetch: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
log.Printf("[HubPoller] unexpected status %d", resp.StatusCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[HubPoller] read body: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var state m.NetworkState
|
|
||||||
if err := json.Unmarshal(body, &state); err != nil {
|
|
||||||
log.Printf("[HubPoller] unmarshal: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Persist only when the state actually changed, to avoid needless writes
|
|
||||||
// on every poll.
|
|
||||||
if hp.apply(state) {
|
|
||||||
if err := saveNetworkState(hp.statePath, state); err != nil {
|
|
||||||
log.Printf("[HubPoller] save state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply diffs state against the set of known peers, emitting an add for each
|
|
||||||
// newly-seen peer and a remove for each that disappeared. It returns true if
|
|
||||||
// anything changed. A peer's config is immutable under a stable WG key (the hub
|
|
||||||
// has no peer-edit path), so a key already in known needs no re-emit.
|
|
||||||
func (hp *HubPoller) apply(state m.NetworkState) (changed bool) {
|
|
||||||
seen := make(map[wgtypes.Key]struct{}, len(hp.known))
|
|
||||||
|
|
||||||
netAddr := hp.vpnNet.Addr().As4()
|
|
||||||
|
|
||||||
for _, p := range state.Peers {
|
|
||||||
if p.WGPubKey == (wgtypes.Key{}) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
octets := netAddr
|
|
||||||
octets[3] = p.PeerIP
|
|
||||||
vpnIP := netip.AddrFrom4(octets)
|
|
||||||
if vpnIP == hp.selfVPNIP {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
seen[p.WGPubKey] = struct{}{}
|
|
||||||
|
|
||||||
if _, ok := hp.known[p.WGPubKey]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
hp.known[p.WGPubKey] = struct{}{}
|
|
||||||
hp.addCh <- p
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for key := range hp.known {
|
|
||||||
if _, ok := seen[key]; !ok {
|
|
||||||
delete(hp.known, key)
|
|
||||||
hp.removeCh <- key
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return changed
|
|
||||||
}
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testPoller(t *testing.T) (*HubPoller, chan m.Peer, chan wgtypes.Key) {
|
|
||||||
t.Helper()
|
|
||||||
addCh := make(chan m.Peer, 8)
|
|
||||||
removeCh := make(chan wgtypes.Key, 8)
|
|
||||||
hp := &HubPoller{
|
|
||||||
selfVPNIP: netip.MustParseAddr("10.0.0.1"),
|
|
||||||
vpnNet: netip.MustParsePrefix("10.0.0.0/24"),
|
|
||||||
addCh: addCh,
|
|
||||||
removeCh: removeCh,
|
|
||||||
known: make(map[wgtypes.Key]struct{}),
|
|
||||||
}
|
|
||||||
return hp, addCh, removeCh
|
|
||||||
}
|
|
||||||
|
|
||||||
func stateWith(key wgtypes.Key, peerIP byte) m.NetworkState {
|
|
||||||
return m.NetworkState{Peers: []m.Peer{{
|
|
||||||
PeerIP: peerIP,
|
|
||||||
WGPubKey: key,
|
|
||||||
}}}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestApply_EmitsAddsAndReportsChange(t *testing.T) {
|
|
||||||
hp, addCh, _ := testPoller(t)
|
|
||||||
key := mustKey(t)
|
|
||||||
|
|
||||||
if changed := hp.apply(stateWith(key, 2)); !changed {
|
|
||||||
t.Fatal("expected changed=true on first apply")
|
|
||||||
}
|
|
||||||
if len(addCh) != 1 {
|
|
||||||
t.Fatalf("expected 1 add, got %d", len(addCh))
|
|
||||||
}
|
|
||||||
if got := <-addCh; got.WGPubKey != key {
|
|
||||||
t.Errorf("add pubkey mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestApply_NoChangeWhenKnown(t *testing.T) {
|
|
||||||
hp, addCh, _ := testPoller(t)
|
|
||||||
key := mustKey(t)
|
|
||||||
|
|
||||||
hp.apply(stateWith(key, 2))
|
|
||||||
<-addCh // drain initial add
|
|
||||||
|
|
||||||
if changed := hp.apply(stateWith(key, 2)); changed {
|
|
||||||
t.Fatal("expected changed=false when peer already known")
|
|
||||||
}
|
|
||||||
if len(addCh) != 0 {
|
|
||||||
t.Fatalf("expected no re-emit, got %d adds", len(addCh))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestApply_RemovesVanishedPeer(t *testing.T) {
|
|
||||||
hp, addCh, removeCh := testPoller(t)
|
|
||||||
key := mustKey(t)
|
|
||||||
|
|
||||||
hp.apply(stateWith(key, 2))
|
|
||||||
<-addCh
|
|
||||||
|
|
||||||
// Empty state: the peer is gone.
|
|
||||||
if changed := hp.apply(m.NetworkState{}); !changed {
|
|
||||||
t.Fatal("expected changed=true when peer vanishes")
|
|
||||||
}
|
|
||||||
if len(removeCh) != 1 {
|
|
||||||
t.Fatalf("expected 1 remove, got %d", len(removeCh))
|
|
||||||
}
|
|
||||||
if got := <-removeCh; got != key {
|
|
||||||
t.Errorf("remove key mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
111
peer/hubpoller.go
Normal file
111
peer/hubpoller.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
"vppn/m"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HubPoller struct {
|
||||||
|
Globals
|
||||||
|
client *http.Client
|
||||||
|
req *http.Request
|
||||||
|
versions [256]int64
|
||||||
|
netName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHubPoller(
|
||||||
|
g Globals,
|
||||||
|
netName,
|
||||||
|
hubURL,
|
||||||
|
apiKey string,
|
||||||
|
) (*HubPoller, error) {
|
||||||
|
u, err := url.Parse(hubURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
u.Path = "/peer/fetch-state/"
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 8 * time.Second}
|
||||||
|
|
||||||
|
req := &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: u,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
req.SetBasicAuth("", apiKey)
|
||||||
|
|
||||||
|
return &HubPoller{
|
||||||
|
Globals: g,
|
||||||
|
client: client,
|
||||||
|
req: req,
|
||||||
|
netName: netName,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hp *HubPoller) logf(s string, args ...any) {
|
||||||
|
log.Printf("[HubPoller] "+s, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hp *HubPoller) Run() {
|
||||||
|
state, err := loadNetworkState(hp.netName)
|
||||||
|
if err != nil {
|
||||||
|
hp.logf("Failed to load network state: %v", err)
|
||||||
|
hp.logf("Polling hub...")
|
||||||
|
hp.pollHub()
|
||||||
|
} else {
|
||||||
|
hp.applyNetworkState(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range time.Tick(64 * time.Second) {
|
||||||
|
hp.pollHub()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hp *HubPoller) pollHub() {
|
||||||
|
var state m.NetworkState
|
||||||
|
|
||||||
|
resp, err := hp.client.Do(hp.req)
|
||||||
|
if err != nil {
|
||||||
|
hp.logf("Failed to fetch peer state: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
hp.logf("Failed to read body from hub: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &state); err != nil {
|
||||||
|
hp.logf("Failed to unmarshal response from hub: %v\n%s", err, body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := storeNetworkState(hp.netName, state); err != nil {
|
||||||
|
hp.logf("Failed to store network state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hp.applyNetworkState(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hp *HubPoller) applyNetworkState(state m.NetworkState) {
|
||||||
|
for i, peer := range state.Peers {
|
||||||
|
if i == int(hp.LocalPeerIP) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer != nil && peer.Version == hp.versions[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hp.RemotePeers[i].Load().HandlePeerUpdate(peerUpdateMsg{Peer: state.Peers[i]})
|
||||||
|
if peer != nil {
|
||||||
|
hp.versions[i] = peer.Version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
73
peer/ifreader.go
Normal file
73
peer/ifreader.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IFReader struct {
|
||||||
|
Globals
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIFReader(g Globals) *IFReader {
|
||||||
|
return &IFReader{Globals: g}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *IFReader) Run() {
|
||||||
|
packet := make([]byte, bufferSize)
|
||||||
|
for {
|
||||||
|
r.handleNextPacket(packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *IFReader) handleNextPacket(packet []byte) {
|
||||||
|
packet = r.readNextPacket(packet)
|
||||||
|
remoteIP, ok := r.parsePacket(packet)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.RemotePeers[remoteIP].Load().SendDataTo(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *IFReader) readNextPacket(buf []byte) []byte {
|
||||||
|
n, err := r.IFace.Read(buf[:cap(buf)])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read from interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePacket returns the VPN ip for the packet, and a boolean indicating
|
||||||
|
// success.
|
||||||
|
func (r *IFReader) parsePacket(buf []byte) (byte, bool) {
|
||||||
|
n := len(buf)
|
||||||
|
if n == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
version := buf[0] >> 4
|
||||||
|
|
||||||
|
switch version {
|
||||||
|
case 4:
|
||||||
|
if n < 20 {
|
||||||
|
r.logf("Short IPv4 packet: %d", len(buf))
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return buf[19], true
|
||||||
|
|
||||||
|
case 6:
|
||||||
|
if len(buf) < 40 {
|
||||||
|
r.logf("Short IPv6 packet: %d", len(buf))
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return buf[39], true
|
||||||
|
|
||||||
|
default:
|
||||||
|
r.logf("Invalid IP packet version: %v", version)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*IFReader) logf(s string, args ...any) {
|
||||||
|
log.Printf("[IFReader] "+s, args...)
|
||||||
|
}
|
||||||
81
peer/ifreader_test.go
Normal file
81
peer/ifreader_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
/*
|
||||||
|
func TestIFReader_IPv4(t *testing.T) {
|
||||||
|
p1, p2, _ := NewPeersForTesting()
|
||||||
|
|
||||||
|
pkt := make([]byte, 1234)
|
||||||
|
pkt[0] = 4 << 4
|
||||||
|
pkt[19] = 2 // IP.
|
||||||
|
|
||||||
|
p1.IFace.UserWrite(pkt)
|
||||||
|
p1.IFReader.handleNextPacket(newBuf())
|
||||||
|
|
||||||
|
packets := p2.Conn.Packets()
|
||||||
|
if len(packets) != 1 {
|
||||||
|
t.Fatal(packets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIFReader_IPv6(t *testing.T) {
|
||||||
|
p1, p2, _ := NewPeersForTesting()
|
||||||
|
|
||||||
|
pkt := make([]byte, 1234)
|
||||||
|
pkt[0] = 6 << 4
|
||||||
|
pkt[39] = 2 // IP.
|
||||||
|
|
||||||
|
p1.IFace.UserWrite(pkt)
|
||||||
|
p1.IFReader.handleNextPacket(newBuf())
|
||||||
|
|
||||||
|
packets := p2.Conn.Packets()
|
||||||
|
if len(packets) != 1 {
|
||||||
|
t.Fatal(packets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIFReader_parsePacket_emptyPacket(t *testing.T) {
|
||||||
|
r := NewIFReader(nil, nil)
|
||||||
|
pkt := make([]byte, 0)
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) {
|
||||||
|
r := NewIFReader(nil, nil)
|
||||||
|
|
||||||
|
for i := byte(1); i < 16; i++ {
|
||||||
|
if i == 4 || i == 6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pkt := make([]byte, 1234)
|
||||||
|
pkt[0] = i << 4
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(i, ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIFReader_parsePacket_shortIPv4(t *testing.T) {
|
||||||
|
r := NewIFReader(nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 19)
|
||||||
|
pkt[0] = 4 << 4
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIFReader_parsePacket_shortIPv6(t *testing.T) {
|
||||||
|
r := NewIFReader(nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 39)
|
||||||
|
pkt[0] = 6 << 4
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
190
peer/init.go
190
peer/init.go
@@ -1,190 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/sign"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LocalState is the persisted identity for this peer, written on first run and
|
|
||||||
// loaded on every subsequent run.
|
|
||||||
type LocalState struct {
|
|
||||||
PrivKey wgtypes.Key
|
|
||||||
SignKey [64]byte // nacl/sign Ed25519 private key
|
|
||||||
VPNIP netip.Addr
|
|
||||||
VPNNet netip.Prefix
|
|
||||||
WGPort uint16
|
|
||||||
IsRelay bool
|
|
||||||
IsPublic bool
|
|
||||||
LocalDomain string
|
|
||||||
}
|
|
||||||
|
|
||||||
// localStateJSON is the on-disk representation.
|
|
||||||
type localStateJSON struct {
|
|
||||||
PrivKey string
|
|
||||||
SignKey string
|
|
||||||
VPNIP netip.Addr
|
|
||||||
VPNNet netip.Prefix
|
|
||||||
WGPort uint16
|
|
||||||
IsRelay bool
|
|
||||||
IsPublic bool
|
|
||||||
LocalDomain string
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadOrInit loads LocalState from path, or registers with the hub and creates
|
|
||||||
// the file if it doesn't exist.
|
|
||||||
func LoadOrInit(statePath, hubURL, apiKey string) (LocalState, error) {
|
|
||||||
var state LocalState
|
|
||||||
switch err := loadJSON(statePath, &state); {
|
|
||||||
case err == nil:
|
|
||||||
return state, nil
|
|
||||||
case !os.IsNotExist(err):
|
|
||||||
// File exists but is unreadable/corrupt: surface it rather than
|
|
||||||
// silently regenerating a new identity and re-registering.
|
|
||||||
return LocalState{}, fmt.Errorf("load state: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
privKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return LocalState{}, fmt.Errorf("generate key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
state, err = initFromHub(hubURL, apiKey, privKey)
|
|
||||||
if err != nil {
|
|
||||||
return LocalState{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := storeJSON(statePath, state); err != nil {
|
|
||||||
return LocalState{}, fmt.Errorf("save state: %w", err)
|
|
||||||
}
|
|
||||||
return state, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) {
|
|
||||||
wgPubKey := privKey.PublicKey()
|
|
||||||
|
|
||||||
signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
return LocalState{}, fmt.Errorf("generate sign key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.Marshal(m.PeerInitArgs{
|
|
||||||
WGPubKey: wgPubKey[:],
|
|
||||||
SignPubKey: signPubKey[:],
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return LocalState{}, fmt.Errorf("json error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, hubURL+"/peer/init/", bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return LocalState{}, err
|
|
||||||
}
|
|
||||||
req.SetBasicAuth("", apiKey)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return LocalState{}, fmt.Errorf("hub init: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return LocalState{}, fmt.Errorf("hub init: HTTP %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var r m.PeerInitResp
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
|
|
||||||
return LocalState{}, fmt.Errorf("hub init decode: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(r.Network) != 4 {
|
|
||||||
return LocalState{}, fmt.Errorf("hub init: invalid network %v", r.Network)
|
|
||||||
}
|
|
||||||
|
|
||||||
netAddr := netip.AddrFrom4([4]byte(r.Network))
|
|
||||||
octets := netAddr.As4()
|
|
||||||
octets[3] = r.PeerIP
|
|
||||||
vpnIP := netip.AddrFrom4(octets)
|
|
||||||
vpnNet := netip.PrefixFrom(netAddr, 24)
|
|
||||||
|
|
||||||
var self *m.Peer
|
|
||||||
for i := range r.NetworkState.Peers {
|
|
||||||
if r.NetworkState.Peers[i].PeerIP == r.PeerIP {
|
|
||||||
self = &r.NetworkState.Peers[i]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if self == nil {
|
|
||||||
return LocalState{}, fmt.Errorf("hub init: no peer for own IP: %d", r.PeerIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
public := self.IsPublic()
|
|
||||||
|
|
||||||
return LocalState{
|
|
||||||
PrivKey: privKey,
|
|
||||||
SignKey: *signPrivKey,
|
|
||||||
VPNIP: vpnIP,
|
|
||||||
VPNNet: vpnNet,
|
|
||||||
WGPort: self.Port,
|
|
||||||
IsRelay: self.Relay && public,
|
|
||||||
IsPublic: public,
|
|
||||||
LocalDomain: r.LocalDomain,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s LocalState) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(localStateJSON{
|
|
||||||
PrivKey: base64.StdEncoding.EncodeToString(s.PrivKey[:]),
|
|
||||||
SignKey: base64.StdEncoding.EncodeToString(s.SignKey[:]),
|
|
||||||
VPNIP: s.VPNIP,
|
|
||||||
VPNNet: s.VPNNet,
|
|
||||||
WGPort: s.WGPort,
|
|
||||||
IsRelay: s.IsRelay,
|
|
||||||
IsPublic: s.IsPublic,
|
|
||||||
LocalDomain: s.LocalDomain,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LocalState) UnmarshalJSON(data []byte) error {
|
|
||||||
var j localStateJSON
|
|
||||||
if err := json.Unmarshal(data, &j); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
keyBytes, err := base64.StdEncoding.DecodeString(j.PrivKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("decode key: %w", err)
|
|
||||||
}
|
|
||||||
key, err := wgtypes.NewKey(keyBytes)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid key: %w", err)
|
|
||||||
}
|
|
||||||
signKeyBytes, err := base64.StdEncoding.DecodeString(j.SignKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("decode sign key: %w", err)
|
|
||||||
}
|
|
||||||
if len(signKeyBytes) != 64 {
|
|
||||||
return fmt.Errorf("invalid sign key length: %d", len(signKeyBytes))
|
|
||||||
}
|
|
||||||
*s = LocalState{
|
|
||||||
PrivKey: key,
|
|
||||||
SignKey: [64]byte(signKeyBytes),
|
|
||||||
VPNIP: j.VPNIP,
|
|
||||||
VPNNet: j.VPNNet,
|
|
||||||
WGPort: j.WGPort,
|
|
||||||
IsRelay: j.IsRelay,
|
|
||||||
IsPublic: j.IsPublic,
|
|
||||||
LocalDomain: j.LocalDomain,
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
137
peer/interface.go
Normal file
137
peer/interface.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
|
||||||
|
if len(network) != 4 {
|
||||||
|
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))
|
||||||
|
}
|
||||||
|
ip := net.IPv4(network[0], network[1], network[2], localIP)
|
||||||
|
|
||||||
|
//////////////////////////
|
||||||
|
// Create TUN Interface //
|
||||||
|
//////////////////////////
|
||||||
|
|
||||||
|
tunFD, err := syscall.Open("/dev/net/tun", syscall.O_RDWR|unix.O_CLOEXEC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open TUN device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New interface request.
|
||||||
|
req, err := unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create new TUN interface request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flags:
|
||||||
|
//
|
||||||
|
// IFF_NO_PI => don't add packet info data to packets sent to the interface.
|
||||||
|
// IFF_TUN => create a TUN device handling IP packets.
|
||||||
|
req.SetUint16(unix.IFF_NO_PI | unix.IFF_TUN)
|
||||||
|
|
||||||
|
err = unix.IoctlIfreq(tunFD, unix.TUNSETIFF, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set TUN device settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name may not be exactly the same?
|
||||||
|
name = req.Name()
|
||||||
|
|
||||||
|
/////////////
|
||||||
|
// Set MTU //
|
||||||
|
/////////////
|
||||||
|
|
||||||
|
// We need a socket file descriptor to set other options for some reason.
|
||||||
|
sockFD, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open socket: %w", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sockFD)
|
||||||
|
|
||||||
|
req, err = unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create MTU interface request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetUint32(if_mtu)
|
||||||
|
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFMTU, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface MTU: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////
|
||||||
|
// Set Queue Length //
|
||||||
|
//////////////////////
|
||||||
|
|
||||||
|
req, err = unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetUint16(if_queue_len)
|
||||||
|
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFTXQLEN, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface queue length: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////
|
||||||
|
// Set IP and Mask //
|
||||||
|
/////////////////////
|
||||||
|
|
||||||
|
req, err = unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create IP interface request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.SetInet4Addr(ip.To4()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface request IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFADDR, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SET MASK - must happen after setting address.
|
||||||
|
req, err = unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create mask interface request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.SetInet4Addr(net.IPv4(255, 255, 255, 0).To4()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface request mask: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.IoctlIfreq(sockFD, unix.SIOCSIFNETMASK, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface mask: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////
|
||||||
|
// Bring Interface Up //
|
||||||
|
////////////////////////
|
||||||
|
|
||||||
|
req, err = unix.NewIfreq(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create up interface request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current flags.
|
||||||
|
if err = unix.IoctlIfreq(sockFD, unix.SIOCGIFFLAGS, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get interface flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
flags := req.Uint16() | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
|
|
||||||
|
// Set UP flag / broadcast flags.
|
||||||
|
req.SetUint16(flags)
|
||||||
|
if err = unix.IoctlIfreq(sockFD, unix.SIOCSIFFLAGS, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set interface up: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.NewFile(uintptr(tunFD), "tun"), nil
|
||||||
|
}
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"vppn/peer/control"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WGDevice is the subset of wginterface.Device used by App.
|
|
||||||
type WGDevice interface {
|
|
||||||
Name() string
|
|
||||||
Peers() ([]wgtypes.Peer, error)
|
|
||||||
AddPeer(pubKey wgtypes.Key) error
|
|
||||||
AddDirect(pubKey wgtypes.Key, endpoint netip.AddrPort, vpnIP netip.Addr) error
|
|
||||||
SetRelay(pubKey wgtypes.Key, endpoint netip.AddrPort, network netip.Prefix) error
|
|
||||||
AddProbe(pubKey wgtypes.Key, endpoint netip.AddrPort) error
|
|
||||||
Promote(pubKey wgtypes.Key, vpnIP netip.Addr) error
|
|
||||||
RemovePeer(pubKey wgtypes.Key) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ControlConn sends pings to peers over the VPN control port.
|
|
||||||
// Reading is handled separately via run, which feeds the App's pingCh.
|
|
||||||
// buf is a caller-provided scratch buffer (at least control.Size bytes) used to
|
|
||||||
// marshal the ping; the caller reuses one across sends.
|
|
||||||
type ControlConn interface {
|
|
||||||
SendPing(dst netip.AddrPort, ping control.Ping, buf []byte) error
|
|
||||||
}
|
|
||||||
36
peer/json.go
36
peer/json.go
@@ -1,36 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
|
||||||
|
|
||||||
func loadJSON(path string, target any) error {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return json.Unmarshal(data, target)
|
|
||||||
}
|
|
||||||
|
|
||||||
func storeJSON(path string, obj any) error {
|
|
||||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
data, err := json.MarshalIndent(obj, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpPath := path + ".tmp"
|
|
||||||
if err := os.WriteFile(tmpPath, data, 0600); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.Rename(tmpPath, path); err != nil {
|
|
||||||
os.Remove(tmpPath)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
207
peer/main.go
Normal file
207
peer/main.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// vppn netName run
|
||||||
|
// vppn netName status
|
||||||
|
func Main2() {
|
||||||
|
printUsage := func() {
|
||||||
|
fmt.Fprintf(os.Stderr, `%s COMMAND [ARGUMENTS...]
|
||||||
|
|
||||||
|
Available commands:
|
||||||
|
run
|
||||||
|
status
|
||||||
|
hosts
|
||||||
|
`, os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(os.Args) < 2 {
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
command := os.Args[1]
|
||||||
|
|
||||||
|
switch command {
|
||||||
|
case "run":
|
||||||
|
main_run()
|
||||||
|
case "status":
|
||||||
|
main_status()
|
||||||
|
case "hosts":
|
||||||
|
main_hosts()
|
||||||
|
default:
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type mainArgs struct {
|
||||||
|
NetName string
|
||||||
|
HubAddress string
|
||||||
|
APIKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main_run() {
|
||||||
|
printUsage := func() {
|
||||||
|
fmt.Fprintf(os.Stderr, `Usage: %s run NETWORK_NAME HUB_ADDRESS API_KEY
|
||||||
|
|
||||||
|
NETWORK_NAME
|
||||||
|
Unique name of the network interface created. The network name
|
||||||
|
shouldn't change between invocations of the application.
|
||||||
|
|
||||||
|
HUB_ADDRESS
|
||||||
|
The address of the hub server. This should also contain the scheme, for
|
||||||
|
example https://hub.domain.com/.
|
||||||
|
|
||||||
|
API_KEY
|
||||||
|
The API key assigned to this peer by the hub.
|
||||||
|
|
||||||
|
`, os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(os.Args) != 5 {
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
args := mainArgs{
|
||||||
|
NetName: os.Args[2],
|
||||||
|
HubAddress: os.Args[3],
|
||||||
|
APIKey: os.Args[4],
|
||||||
|
}
|
||||||
|
|
||||||
|
newPeerMain(args).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func main_status() {
|
||||||
|
printUsage := func() {
|
||||||
|
fmt.Fprintf(os.Stderr, `Usage: %s status NETWORK_NAME
|
||||||
|
|
||||||
|
NETWORK_NAME
|
||||||
|
Unique name of the network interface created.
|
||||||
|
|
||||||
|
`, os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(os.Args) != 3 {
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
netName := os.Args[2]
|
||||||
|
report := fetchStatusReport(netName)
|
||||||
|
|
||||||
|
fmt.Printf("\n%s Status\n\n", netName)
|
||||||
|
|
||||||
|
if len(report.Network) != 4 {
|
||||||
|
fmt.Printf("Network: %v\n\n", report.Network)
|
||||||
|
} else {
|
||||||
|
nw := report.Network
|
||||||
|
fmt.Printf("%-8s %d.%d.%d.%d/24\n", "Network", nw[0], nw[1], nw[2], nw[3])
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.RelayPeerIP != 0 {
|
||||||
|
fmt.Printf("%-8s %d\n\n", "Relay", report.RelayPeerIP)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%-8s -\n\n", "Relay")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, status := range report.Remotes {
|
||||||
|
fmt.Printf("%3d %s\n", status.PeerIP, status.Name)
|
||||||
|
fmt.Printf(" %-11s %v\n", "Up", status.Up)
|
||||||
|
|
||||||
|
pubIP, ok := netip.AddrFromSlice(status.PublicIP)
|
||||||
|
if ok {
|
||||||
|
fmt.Printf(" %-11s %v\n", "Public IP", pubIP)
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" %-11s\n", "Public IP")
|
||||||
|
}
|
||||||
|
fmt.Printf(" %-11s %d\n", "Port", status.Port)
|
||||||
|
fmt.Printf(" %-11s %v\n", "Relay", status.Relay)
|
||||||
|
fmt.Printf(" %-11s %v\n", "Server", status.Server)
|
||||||
|
fmt.Printf(" %-11s %v\n", "Direct", status.Direct)
|
||||||
|
if status.DirectAddr.IsValid() {
|
||||||
|
fmt.Printf(" %-11s %v\n", "Address", status.DirectAddr)
|
||||||
|
}
|
||||||
|
fmt.Println("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func main_hosts() {
|
||||||
|
printUsage := func() {
|
||||||
|
fmt.Fprintf(os.Stderr, `Usage: %s hosts NETWORK_NAME
|
||||||
|
|
||||||
|
NETWORK_NAME
|
||||||
|
Unique name of the network interface created.
|
||||||
|
|
||||||
|
`, os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(os.Args) != 3 {
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
netName := os.Args[2]
|
||||||
|
state, err := loadNetworkState(netName)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load network state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := loadPeerConfig(netName)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nw := config.Network
|
||||||
|
for _, peer := range state.Peers {
|
||||||
|
if peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("%d.%d.%d.%d %s\n",
|
||||||
|
nw[0], nw[1], nw[2], peer.PeerIP, peer.Name)
|
||||||
|
}
|
||||||
|
fmt.Println("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func fetchStatusReport(netName string) StatusReport {
|
||||||
|
client := http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Dial: func(_, _ string) (net.Conn, error) {
|
||||||
|
return net.Dial("unix", statusSocketPath(netName))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Timeout: 8 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
getURL := "http://unix" + statusSocketPath(netName)
|
||||||
|
resp, err := client.Get(getURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
report := StatusReport{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&report); err != nil {
|
||||||
|
log.Fatalf("Failed to decode status report: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return report
|
||||||
|
}
|
||||||
5
peer/main_test.go
Normal file
5
peer/main_test.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
func newBuf() []byte {
|
||||||
|
return make([]byte, bufferSize)
|
||||||
|
}
|
||||||
47
peer/mcreader.go
Normal file
47
peer/mcreader.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RunMCReader(g Globals) {
|
||||||
|
for {
|
||||||
|
runMCReaderInner(g)
|
||||||
|
time.Sleep(broadcastErrorTimeoutInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runMCReaderInner(g Globals) {
|
||||||
|
var (
|
||||||
|
buf = make([]byte, bufferSize)
|
||||||
|
logf = func(s string, args ...any) {
|
||||||
|
log.Printf("[MCReader] "+s, args...)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
|
||||||
|
if err != nil {
|
||||||
|
logf("Failed to bind to multicast address: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
|
||||||
|
n, remoteAddr, err := conn.ReadFromUDPAddrPort(buf[:bufferSize])
|
||||||
|
if err != nil {
|
||||||
|
logf("Failed to read from UDP port): %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = buf[:n]
|
||||||
|
h, ok := headerFromLocalDiscoveryPacket(buf)
|
||||||
|
if !ok {
|
||||||
|
logf("Failed to open discovery packet?")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
g.RemotePeers[h.SourceIP].Load().HandleLocalDiscoveryPacket(h, remoteAddr, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
132
peer/mcreader_test.go
Normal file
132
peer/mcreader_test.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
/*
|
||||||
|
type mcMockConn struct {
|
||||||
|
packets chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMCMockConn() *mcMockConn {
|
||||||
|
return &mcMockConn{make(chan []byte, 32)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mcMockConn) WriteToUDP(in []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
c.packets <- bytes.Clone(in)
|
||||||
|
return len(in), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mcMockConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
|
||||||
|
buf := <-c.packets
|
||||||
|
b = b[:len(buf)]
|
||||||
|
copy(b, buf)
|
||||||
|
return len(b), netip.AddrPort{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMCReader(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
super := &mockControlMsgHandler{}
|
||||||
|
conn := newMCMockConn()
|
||||||
|
|
||||||
|
peers := [256]*atomic.Pointer[RemotePeer]{}
|
||||||
|
peer := &RemotePeer{
|
||||||
|
IP: 1,
|
||||||
|
Up: true,
|
||||||
|
PubSignKey: keys.PubSignKey,
|
||||||
|
}
|
||||||
|
peers[1] = &atomic.Pointer[RemotePeer]{}
|
||||||
|
peers[1].Store(peer)
|
||||||
|
|
||||||
|
w := newMCWriter(conn, 1, keys.PrivSignKey)
|
||||||
|
r := newMCReader(conn, super, peers)
|
||||||
|
|
||||||
|
w.SendLocalDiscovery()
|
||||||
|
r.handleNextPacket()
|
||||||
|
|
||||||
|
if len(super.Messages) != 1 {
|
||||||
|
t.Fatal(super.Messages)
|
||||||
|
}
|
||||||
|
msg, ok := super.Messages[0].(controlMsg[PacketLocalDiscovery])
|
||||||
|
if !ok || msg.SrcIP != 1 {
|
||||||
|
t.Fatal(ok, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMCReader_noHeader(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
super := &mockControlMsgHandler{}
|
||||||
|
conn := newMCMockConn()
|
||||||
|
|
||||||
|
peers := [256]*atomic.Pointer[RemotePeer]{}
|
||||||
|
peer := &RemotePeer{
|
||||||
|
IP: 1,
|
||||||
|
Up: true,
|
||||||
|
PubSignKey: keys.PubSignKey,
|
||||||
|
}
|
||||||
|
peers[1] = &atomic.Pointer[RemotePeer]{}
|
||||||
|
peers[1].Store(peer)
|
||||||
|
|
||||||
|
r := newMCReader(conn, super, peers)
|
||||||
|
conn.WriteToUDP([]byte("0123546789"), nil)
|
||||||
|
r.handleNextPacket()
|
||||||
|
|
||||||
|
if len(super.Messages) != 0 {
|
||||||
|
t.Fatal(super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMCReader_noPeer(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
super := &mockControlMsgHandler{}
|
||||||
|
conn := newMCMockConn()
|
||||||
|
|
||||||
|
peers := [256]*atomic.Pointer[RemotePeer]{}
|
||||||
|
peer := &RemotePeer{
|
||||||
|
IP: 1,
|
||||||
|
Up: true,
|
||||||
|
PubSignKey: keys.PubSignKey,
|
||||||
|
}
|
||||||
|
peers[1] = &atomic.Pointer[RemotePeer]{}
|
||||||
|
peers[2] = &atomic.Pointer[RemotePeer]{}
|
||||||
|
peers[1].Store(peer)
|
||||||
|
|
||||||
|
w := newMCWriter(conn, 2, keys.PrivSignKey)
|
||||||
|
r := newMCReader(conn, super, peers)
|
||||||
|
|
||||||
|
w.SendLocalDiscovery()
|
||||||
|
r.handleNextPacket()
|
||||||
|
|
||||||
|
if len(super.Messages) != 0 {
|
||||||
|
t.Fatal(super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMCReader_badSignature(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
super := &mockControlMsgHandler{}
|
||||||
|
conn := newMCMockConn()
|
||||||
|
|
||||||
|
peers := [256]*atomic.Pointer[RemotePeer]{}
|
||||||
|
peer := &RemotePeer{
|
||||||
|
IP: 1,
|
||||||
|
Up: true,
|
||||||
|
PubSignKey: keys.PubSignKey,
|
||||||
|
}
|
||||||
|
peers[1] = &atomic.Pointer[RemotePeer]{}
|
||||||
|
peers[1].Store(peer)
|
||||||
|
|
||||||
|
w := newMCWriter(conn, 1, keys.PrivSignKey)
|
||||||
|
w.SendLocalDiscovery()
|
||||||
|
|
||||||
|
// Break signing.
|
||||||
|
packet := <-conn.packets
|
||||||
|
packet[0]++
|
||||||
|
conn.packets <- packet
|
||||||
|
|
||||||
|
r := newMCReader(conn, super, peers)
|
||||||
|
|
||||||
|
r.handleNextPacket()
|
||||||
|
|
||||||
|
if len(super.Messages) != 0 {
|
||||||
|
t.Fatal(super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
54
peer/mcwriter.go
Normal file
54
peer/mcwriter.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/nacl/sign"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte {
|
||||||
|
h := Header{
|
||||||
|
SourceIP: localIP,
|
||||||
|
DestIP: 255,
|
||||||
|
}
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
h.Marshal(buf)
|
||||||
|
out := make([]byte, headerSize+signingOverhead)
|
||||||
|
return sign.Sign(out[:0], buf, (*[64]byte)(signingKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerFromLocalDiscoveryPacket(pkt []byte) (h Header, ok bool) {
|
||||||
|
if len(pkt) != headerSize+signingOverhead {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Parse(pkt[signingOverhead:])
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool {
|
||||||
|
_, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey))
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func RunMCWriter(localIP byte, signingKey []byte) {
|
||||||
|
discoveryPacket := createLocalDiscoveryPacket(localIP, signingKey)
|
||||||
|
|
||||||
|
conn, err := net.ListenMulticastUDP("udp", nil, multicastAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("[MCWriter] Failed to bind to multicast address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range time.Tick(broadcastInterval) {
|
||||||
|
log.Printf("[MCWriter] Broadcasting on %v...", multicastAddr)
|
||||||
|
_, err := conn.WriteToUDP(discoveryPacket, multicastAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[MCWriter] Failed to write multicast: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
98
peer/mcwriter_test.go
Normal file
98
peer/mcwriter_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
/*
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Testing that we can create and verify a local discovery packet.
|
||||||
|
func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
|
||||||
|
|
||||||
|
header, ok := headerFromLocalDiscoveryPacket(created)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
if header.SourceIP != 55 || header.DestIP != 255 {
|
||||||
|
t.Fatal(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) {
|
||||||
|
t.Fatal("Not valid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we don't try to parse short packets.
|
||||||
|
func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
|
||||||
|
|
||||||
|
_, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1])
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that modifying a packet makes it invalid.
|
||||||
|
func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for i := range created {
|
||||||
|
modified := bytes.Clone(created)
|
||||||
|
modified[i]++
|
||||||
|
if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) {
|
||||||
|
t.Fatal("Verification should have failed.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type testUDPWriter struct {
|
||||||
|
written [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
w.written = append(w.written, bytes.Clone(b))
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testUDPWriter) Written() [][]byte {
|
||||||
|
out := w.written
|
||||||
|
w.written = [][]byte{}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Testing that the mcWriter sends local discovery packets as expected.
|
||||||
|
func TestMCWriter_SendLocalDiscovery(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
writer := &testUDPWriter{}
|
||||||
|
|
||||||
|
mcw := newMCWriter(writer, 42, keys.PrivSignKey)
|
||||||
|
mcw.SendLocalDiscovery()
|
||||||
|
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := out[0]
|
||||||
|
|
||||||
|
header, ok := headerFromLocalDiscoveryPacket(pkt)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
if header.SourceIP != 42 || header.DestIP != 255 {
|
||||||
|
t.Fatal(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) {
|
||||||
|
t.Fatal("Verification should succeed.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
31
peer/mock-iface_test.go
Normal file
31
peer/mock-iface_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
|
type TestIFace struct {
|
||||||
|
out *bytes.Buffer // Toward the network.
|
||||||
|
in *bytes.Buffer // From the network
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestIFace() *TestIFace {
|
||||||
|
return &TestIFace{
|
||||||
|
out: &bytes.Buffer{},
|
||||||
|
in: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (iface *TestIFace) Write(b []byte) (int, error) {
|
||||||
|
return iface.in.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (iface *TestIFace) Read(b []byte) (int, error) {
|
||||||
|
return iface.out.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (iface *TestIFace) UserWrite(b []byte) (int, error) {
|
||||||
|
return iface.out.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (iface *TestIFace) UserRead(b []byte) (int, error) {
|
||||||
|
return iface.in.Read(b)
|
||||||
|
}
|
||||||
80
peer/mock-network_test.go
Normal file
80
peer/mock-network_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestPacket struct {
|
||||||
|
Addr netip.AddrPort
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestNetwork struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
packets map[netip.AddrPort]chan TestPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestNetwork() *TestNetwork {
|
||||||
|
return &TestNetwork{packets: map[netip.AddrPort]chan TestPacket{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *TestNetwork) NewUDPConn(localAddr netip.AddrPort) *TestUDPConn {
|
||||||
|
n.lock.Lock()
|
||||||
|
defer n.lock.Unlock()
|
||||||
|
if _, ok := n.packets[localAddr]; !ok {
|
||||||
|
n.packets[localAddr] = make(chan TestPacket, 1024)
|
||||||
|
}
|
||||||
|
return &TestUDPConn{
|
||||||
|
addr: localAddr,
|
||||||
|
n: n,
|
||||||
|
packets: n.packets[localAddr],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *TestNetwork) write(b []byte, from, to netip.AddrPort) {
|
||||||
|
n.lock.Lock()
|
||||||
|
defer n.lock.Unlock()
|
||||||
|
if _, ok := n.packets[to]; !ok {
|
||||||
|
n.packets[to] = make(chan TestPacket, 1024)
|
||||||
|
}
|
||||||
|
n.packets[to] <- TestPacket{
|
||||||
|
Addr: from,
|
||||||
|
Data: bytes.Clone(b),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestUDPConn struct {
|
||||||
|
addr netip.AddrPort
|
||||||
|
n *TestNetwork
|
||||||
|
packets chan TestPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
||||||
|
c.n.write(b, c.addr, addr)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
return c.WriteToUDPAddrPort(b, addr.AddrPort())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
|
||||||
|
pkt := <-c.packets
|
||||||
|
b = b[:len(pkt.Data)]
|
||||||
|
copy(b, pkt.Data)
|
||||||
|
return len(b), pkt.Addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestUDPConn) Packets() (out []TestPacket) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case pkt := <-c.packets:
|
||||||
|
out = append(out, pkt)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
package multicast
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
var addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
|
||||||
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
|
|
||||||
4560))
|
|
||||||
|
|
||||||
func Broadcast(
|
|
||||||
selfVPNIP netip.Addr,
|
|
||||||
pubKey wgtypes.Key,
|
|
||||||
wgPort uint16,
|
|
||||||
signKey *[64]byte,
|
|
||||||
) {
|
|
||||||
for {
|
|
||||||
broadcastInner(selfVPNIP, pubKey, wgPort, signKey)
|
|
||||||
time.Sleep(errorTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func broadcastInner(selfVPNIP netip.Addr, pubKey wgtypes.Key, wgPort uint16, signKey *[64]byte) {
|
|
||||||
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[MCBroadcast] bind: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
buf := make([]byte, BufferSize)
|
|
||||||
packet := Packet{
|
|
||||||
PeerIP: selfVPNIP.As4()[3],
|
|
||||||
WGPubKey: pubKey,
|
|
||||||
WGPort: wgPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-sign on each send so the timestamp is fresh; a stale timestamp would be
|
|
||||||
// dropped by receivers' freshness gate.
|
|
||||||
send := func() error {
|
|
||||||
packet.Timestamp = time.Now().Unix()
|
|
||||||
payload := packet.Marshal(buf, signKey)
|
|
||||||
_, err := conn.WriteToUDP(payload, addr)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := send(); err != nil {
|
|
||||||
log.Printf("[MCBroadcast] write: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for range time.Tick(broadcastInterval) {
|
|
||||||
if err := send(); err != nil {
|
|
||||||
log.Printf("[MCBroadcast] write: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
package multicast
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
const (
|
|
||||||
errorTimeout = 16 * time.Second
|
|
||||||
broadcastInterval = 16 * time.Second
|
|
||||||
maxPacketAge = time.Minute
|
|
||||||
)
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
package multicast
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/sign"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
BufferSize = packetSize + SignedPacketSize
|
|
||||||
SignedPacketSize = packetSize + signSize
|
|
||||||
packetSize = 43
|
|
||||||
signSize = 64
|
|
||||||
)
|
|
||||||
|
|
||||||
// Layout:
|
|
||||||
//
|
|
||||||
// [0] final octet of the sender's VPN IP
|
|
||||||
// [1:33] WG public key
|
|
||||||
// [33:35] WG listen port (big-endian uint16)
|
|
||||||
// [35:43] send time, Unix seconds (big-endian int64) — freshness/replay gate
|
|
||||||
type Packet struct {
|
|
||||||
PeerIP byte // Final octet of the sender's VPN IP.
|
|
||||||
WGPubKey [32]byte // WG public key.
|
|
||||||
WGPort uint16 // WG listen port.
|
|
||||||
Timestamp int64 // Unix timestamp.
|
|
||||||
Src netip.Addr // Source of packet.
|
|
||||||
Signed []byte // Raw signed message for verification (incoming packet).
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal the packet into a buffer with prefixed signature.
|
|
||||||
func (p Packet) Marshal(buf []byte, signKey *[64]byte) []byte {
|
|
||||||
buf[0] = p.PeerIP
|
|
||||||
copy(buf[1:33], p.WGPubKey[:])
|
|
||||||
binary.BigEndian.PutUint16(buf[33:35], p.WGPort)
|
|
||||||
binary.BigEndian.PutUint64(buf[35:43], uint64(p.Timestamp))
|
|
||||||
return sign.Sign(buf[packetSize:packetSize], buf[:packetSize], signKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Packet) Verify(buf []byte, pubKey *[32]byte) bool {
|
|
||||||
_, ok := sign.Open(buf, p.Signed, pubKey)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func Unmarshal(signed []byte) (p Packet) {
|
|
||||||
buf := signed[signSize:]
|
|
||||||
p.PeerIP = buf[0]
|
|
||||||
copy(p.WGPubKey[:], buf[1:33])
|
|
||||||
p.WGPort = binary.BigEndian.Uint16(buf[33:35])
|
|
||||||
p.Timestamp = int64(binary.BigEndian.Uint64(buf[35:43]))
|
|
||||||
p.Signed = signed
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
package multicast
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/sign"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPacket(t *testing.T) {
|
|
||||||
pub, priv, err := sign.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
p := Packet{
|
|
||||||
PeerIP: 10,
|
|
||||||
WGPubKey: [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
|
|
||||||
WGPort: 44,
|
|
||||||
Timestamp: 12948893,
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, BufferSize)
|
|
||||||
signed := p.Marshal(buf, priv)
|
|
||||||
if len(signed) != SignedPacketSize {
|
|
||||||
t.Fatalf("signed length = %d, want %d", len(signed), SignedPacketSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
got := Unmarshal(signed)
|
|
||||||
if got.PeerIP != p.PeerIP || got.WGPubKey != p.WGPubKey ||
|
|
||||||
got.WGPort != p.WGPort || got.Timestamp != p.Timestamp {
|
|
||||||
t.Fatalf("round-trip mismatch:\n got %+v\nwant %+v", got, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !got.Verify(nil, pub) {
|
|
||||||
t.Error("signature did not verify")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
package multicast
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) {
|
|
||||||
for {
|
|
||||||
if err := receiver(vpnNet, selfVPNIP, ch); err != nil {
|
|
||||||
log.Printf("[MCReader] %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(errorTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func receiver(vpnNet netip.Prefix, selfVPNIP netip.Addr, ch chan<- Packet) error {
|
|
||||||
selfIP := selfVPNIP.As4()[3]
|
|
||||||
|
|
||||||
conn, err := net.ListenMulticastUDP("udp", nil, addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("bind: %w", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
buf := make([]byte, BufferSize+1) // +1 to detect oversized packets
|
|
||||||
|
|
||||||
for {
|
|
||||||
conn.SetReadDeadline(time.Now().Add(32 * time.Second))
|
|
||||||
n, src, err := conn.ReadFromUDPAddrPort(buf)
|
|
||||||
if err != nil {
|
|
||||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fmt.Errorf("read: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n != SignedPacketSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
packet := Unmarshal(buf[:n])
|
|
||||||
|
|
||||||
if packet.PeerIP == selfIP {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
age := time.Since(time.Unix(packet.Timestamp, 0))
|
|
||||||
if age > maxPacketAge || age < -maxPacketAge {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
packet.Signed = bytes.Clone(packet.Signed)
|
|
||||||
packet.Src = src.Addr().Unmap()
|
|
||||||
ch <- packet
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import "vppn/m"
|
|
||||||
|
|
||||||
// loadNetworkState reads a cached network state from disk. Any error (most
|
|
||||||
// commonly a missing file on first run) is returned to the caller, which
|
|
||||||
// treats it as "no cache available".
|
|
||||||
func loadNetworkState(path string) (m.NetworkState, error) {
|
|
||||||
var state m.NetworkState
|
|
||||||
err := loadJSON(path, &state)
|
|
||||||
return state, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveNetworkState writes state to path atomically (see storeJSON), so a crash
|
|
||||||
// mid-write cannot leave a corrupt cache.
|
|
||||||
func saveNetworkState(path string, state m.NetworkState) error {
|
|
||||||
return storeJSON(path, state)
|
|
||||||
}
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNetworkState_RoundTrip(t *testing.T) {
|
|
||||||
path := filepath.Join(t.TempDir(), "network.json")
|
|
||||||
|
|
||||||
var sign1 [32]byte
|
|
||||||
copy(sign1[:], []byte("0123456789abcdef0123456789abcdef"))
|
|
||||||
|
|
||||||
state := m.NetworkState{Peers: []m.Peer{
|
|
||||||
{
|
|
||||||
PeerIP: 1,
|
|
||||||
Name: "hub",
|
|
||||||
Addr4: netip.MustParseAddr("10.11.12.1"),
|
|
||||||
Port: 51820,
|
|
||||||
Relay: true,
|
|
||||||
WGPubKey: mustKey(t),
|
|
||||||
SignPubKey: sign1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: 10,
|
|
||||||
Name: "laptop",
|
|
||||||
Addr4: netip.MustParseAddr("10.11.12.10"),
|
|
||||||
Port: 51820,
|
|
||||||
WGPubKey: mustKey(t),
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
|
|
||||||
if err := saveNetworkState(path, state); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := loadNetworkState(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got, state) {
|
|
||||||
t.Errorf("round-trip mismatch:\n got: %+v\nwant: %+v", got.Peers[1], state.Peers[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNetworkState_LoadMissing(t *testing.T) {
|
|
||||||
path := filepath.Join(t.TempDir(), "does-not-exist.json")
|
|
||||||
if _, err := loadNetworkState(path); err == nil {
|
|
||||||
t.Fatal("expected error loading missing cache, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
109
peer/new.go
109
peer/new.go
@@ -1,109 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
"vppn/peer/multicast"
|
|
||||||
"vppn/peer/wginterface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// New constructs an App, creates the WireGuard interface, and starts the
|
|
||||||
// background goroutines (hub poller, multicast, control conn reader).
|
|
||||||
// The caller should invoke Run() to start the event loop.
|
|
||||||
func New(
|
|
||||||
state LocalState,
|
|
||||||
hubURL, apiKey string,
|
|
||||||
ifaceName string,
|
|
||||||
localDomain string,
|
|
||||||
networkStatePath string,
|
|
||||||
) (*App, error) {
|
|
||||||
|
|
||||||
a4 := state.VPNIP.As4()
|
|
||||||
if err := wginterface.Create(ifaceName, a4[:], 24); err != nil {
|
|
||||||
return nil, fmt.Errorf("create WG interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dev, err := wginterface.Open(ifaceName)
|
|
||||||
if err != nil {
|
|
||||||
_ = wginterface.Delete(ifaceName)
|
|
||||||
return nil, fmt.Errorf("open WG device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cc, err := newUDPControlConn(state.VPNIP, ControlPort)
|
|
||||||
if err != nil {
|
|
||||||
_ = dev.Close()
|
|
||||||
_ = wginterface.Delete(ifaceName)
|
|
||||||
return nil, fmt.Errorf("control conn: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanup := func() {
|
|
||||||
_ = cc.Close()
|
|
||||||
_ = dev.Close()
|
|
||||||
_ = wginterface.Delete(ifaceName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := dev.Configure(state.PrivKey, int(state.WGPort)); err != nil {
|
|
||||||
cleanup()
|
|
||||||
return nil, fmt.Errorf("configure WG device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if state.IsRelay {
|
|
||||||
if err := dev.EnableForwarding(); err != nil {
|
|
||||||
cleanup()
|
|
||||||
return nil, fmt.Errorf("enable forwarding: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pingCh := make(chan PingEvent)
|
|
||||||
hubAddCh := make(chan m.Peer)
|
|
||||||
hubRemoveCh := make(chan wgtypes.Key)
|
|
||||||
multicastCh := make(chan multicast.Packet)
|
|
||||||
|
|
||||||
poller, err := NewHubPoller(
|
|
||||||
state.VPNIP,
|
|
||||||
state.VPNNet,
|
|
||||||
hubURL,
|
|
||||||
apiKey,
|
|
||||||
networkStatePath,
|
|
||||||
hubAddCh,
|
|
||||||
hubRemoveCh)
|
|
||||||
if err != nil {
|
|
||||||
cleanup()
|
|
||||||
return nil, fmt.Errorf("hub poller: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go cc.run(pingCh)
|
|
||||||
go poller.Run()
|
|
||||||
|
|
||||||
if !state.IsPublic {
|
|
||||||
go multicast.Broadcast(state.VPNIP, state.PrivKey.PublicKey(), state.WGPort, &state.SignKey)
|
|
||||||
go multicast.Receiver(state.VPNNet, state.VPNIP, multicastCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &App{
|
|
||||||
vpnIP: state.VPNIP,
|
|
||||||
vpnNet: state.VPNNet,
|
|
||||||
privKey: state.PrivKey,
|
|
||||||
pubKey: state.PrivKey.PublicKey(),
|
|
||||||
isRelay: state.IsRelay,
|
|
||||||
isPublic: state.IsPublic,
|
|
||||||
localDomain: localDomain,
|
|
||||||
|
|
||||||
dev: dev,
|
|
||||||
controlConn: cc,
|
|
||||||
|
|
||||||
peersByKey: make(map[wgtypes.Key]*Peer),
|
|
||||||
peersByIP: make(map[netip.Addr]*Peer),
|
|
||||||
|
|
||||||
scratch: make([]byte, scratchSize),
|
|
||||||
|
|
||||||
hubAddCh: hubAddCh,
|
|
||||||
hubRemoveCh: hubRemoveCh,
|
|
||||||
pingCh: pingCh,
|
|
||||||
multicastCh: multicastCh,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
114
peer/on_hub.go
114
peer/on_hub.go
@@ -1,114 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"math"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
"vppn/peer/control"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *App) onAddPeer(p m.Peer) {
|
|
||||||
a.onRemovePeer(p.WGPubKey)
|
|
||||||
|
|
||||||
octets := a.vpnNet.Addr().As4()
|
|
||||||
octets[3] = p.PeerIP
|
|
||||||
vpnIP := netip.AddrFrom4(octets)
|
|
||||||
|
|
||||||
peer := &Peer{
|
|
||||||
wgPeer: wgtypes.Peer{PublicKey: p.WGPubKey},
|
|
||||||
VPNIP: vpnIP,
|
|
||||||
Name: p.Name,
|
|
||||||
IsRelay: p.Relay,
|
|
||||||
IsPublic: p.IsPublic(),
|
|
||||||
EndpointV4: p.Endpoint4(),
|
|
||||||
EndpointV6: p.Endpoint6(),
|
|
||||||
RTT: time.Duration(math.MaxInt64) * time.Nanosecond,
|
|
||||||
Role: roleFor(a.isPublic, a.vpnIP, p.IsPublic(), vpnIP),
|
|
||||||
SignPubKey: p.SignPubKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
a.peersByKey[p.WGPubKey] = peer
|
|
||||||
a.peersByIP[peer.VPNIP] = peer
|
|
||||||
defer a.updateHosts()
|
|
||||||
|
|
||||||
if !peer.IsPublic {
|
|
||||||
if a.isPublic {
|
|
||||||
// Public nodes accept traffic from non-public peers as soon as they
|
|
||||||
// initiate a handshake. Set /32 AllowedIPs now; WireGuard learns the
|
|
||||||
// endpoint from the incoming handshake automatically.
|
|
||||||
a.devPromote(peer)
|
|
||||||
} else {
|
|
||||||
a.devAddPeer(peer)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a.devAddDirect(peer, peer.PreferredEndpoint())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) onRemovePeer(key wgtypes.Key) {
|
|
||||||
peer, exists := a.peersByKey[key]
|
|
||||||
if !exists {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
a.devRemove(peer)
|
|
||||||
delete(a.peersByKey, key)
|
|
||||||
delete(a.peersByIP, peer.VPNIP)
|
|
||||||
a.updateHosts()
|
|
||||||
|
|
||||||
if peer == a.relay {
|
|
||||||
a.relay = nil
|
|
||||||
a.switchActiveRelay()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// switchActiveRelay promotes the lowest-latency relay peer to active.
|
|
||||||
func (a *App) switchActiveRelay() {
|
|
||||||
if a.relay != nil {
|
|
||||||
// If we have a relay, it's public, so should go back to being a direct
|
|
||||||
// peer - this will convert it's /24 to a /32.
|
|
||||||
a.devAddDirect(a.relay, a.relay.PreferredEndpoint())
|
|
||||||
a.relay = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var best *Peer
|
|
||||||
for _, p := range a.peersByKey {
|
|
||||||
if !p.CanRelay() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if best == nil || p.RTT < best.RTT {
|
|
||||||
best = p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if best == nil {
|
|
||||||
log.Printf("no relay available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a.devSetRelay(best, best.PreferredEndpoint())
|
|
||||||
a.relay = best
|
|
||||||
}
|
|
||||||
|
|
||||||
func preferredEndpoint(v4, v6 netip.AddrPort) netip.AddrPort {
|
|
||||||
// We always prefer v4 since all peers can connect to IPv4 addresses.
|
|
||||||
if v4.IsValid() {
|
|
||||||
return v4
|
|
||||||
}
|
|
||||||
return v6
|
|
||||||
}
|
|
||||||
|
|
||||||
func roleFor(selfIsPublic bool, selfIP netip.Addr, peerIsPublic bool, peerVPNIP netip.Addr) control.Role {
|
|
||||||
if !selfIsPublic && peerIsPublic {
|
|
||||||
return control.Client
|
|
||||||
}
|
|
||||||
if selfIsPublic && !peerIsPublic {
|
|
||||||
return control.Server
|
|
||||||
}
|
|
||||||
return control.RoleFor(selfIP, peerVPNIP)
|
|
||||||
}
|
|
||||||
@@ -1,299 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/m"
|
|
||||||
)
|
|
||||||
|
|
||||||
func mustKey(t *testing.T) wgtypes.Key {
|
|
||||||
t.Helper()
|
|
||||||
k, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("generate key: %v", err)
|
|
||||||
}
|
|
||||||
return k.PublicKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnAddPeer(t *testing.T) {
|
|
||||||
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
|
||||||
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
|
||||||
peerVPNIP := netip.MustParseAddr("10.0.0.2")
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
setup func(a *App, key wgtypes.Key)
|
|
||||||
peer func(key wgtypes.Key) m.Peer
|
|
||||||
check func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "non-public peer registered in WG via AddPeer",
|
|
||||||
peer: func(k wgtypes.Key) m.Peer {
|
|
||||||
return m.Peer{WGPubKey: k, PeerIP: 2}
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
|
||||||
p := a.peersByKey[key]
|
|
||||||
if p == nil {
|
|
||||||
t.Fatal("not in peersByKey")
|
|
||||||
}
|
|
||||||
if a.peersByIP[peerVPNIP] == nil {
|
|
||||||
t.Fatal("not in peersByIP")
|
|
||||||
}
|
|
||||||
if p.State != StateRelayed {
|
|
||||||
t.Fatalf("state = %v, want StateRelayed", p.State)
|
|
||||||
}
|
|
||||||
dev.AssertAddPeer(t, 0, key)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "public peer with endpoint registered via AddDirect",
|
|
||||||
peer: func(k wgtypes.Key) m.Peer {
|
|
||||||
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()}
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
|
||||||
p := a.peersByKey[key]
|
|
||||||
if p == nil {
|
|
||||||
t.Fatal("not in peersByKey")
|
|
||||||
}
|
|
||||||
dev.AssertAddDirect(t, 0, p.PubKey(), ep1, p.VPNIP)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "re-add removes old WG entry before adding new one",
|
|
||||||
setup: func(a *App, key wgtypes.Key) {
|
|
||||||
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
|
|
||||||
},
|
|
||||||
peer: func(k wgtypes.Key) m.Peer {
|
|
||||||
return m.Peer{WGPubKey: k, PeerIP: 2, Addr4: ep2.Addr(), Port: ep2.Port()}
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice, key wgtypes.Key) {
|
|
||||||
if len(dev.Calls) != 2 {
|
|
||||||
t.Fatalf("dev calls = %v, want [RemovePeer, AddDirect]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertRemovePeer(t, 0, key)
|
|
||||||
dev.AssertAddDirect(t, 1, key, ep2, peerVPNIP)
|
|
||||||
if len(a.peersByKey) != 1 || len(a.peersByIP) != 1 {
|
|
||||||
t.Errorf("maps: peersByKey=%d peersByIP=%d, want 1 each", len(a.peersByKey), len(a.peersByIP))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
|
||||||
key := mustKey(t)
|
|
||||||
if tc.setup != nil {
|
|
||||||
tc.setup(a, key)
|
|
||||||
dev.Calls = nil
|
|
||||||
}
|
|
||||||
a.onAddPeer(tc.peer(key))
|
|
||||||
tc.check(t, a, dev, key)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnRemovePeer(t *testing.T) {
|
|
||||||
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
|
||||||
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
setup func(t *testing.T, a *App) wgtypes.Key // returns the key to remove
|
|
||||||
check func(t *testing.T, a *App, dev *fakeWGDevice)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "unknown key is a no-op",
|
|
||||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
||||||
return mustKey(t)
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
dev.AssertNoCalls(t)
|
|
||||||
if len(a.peersByKey) != 0 {
|
|
||||||
t.Errorf("peersByKey should be empty")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "StateRelayed peer removed from maps with RemovePeer",
|
|
||||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
||||||
key := mustKey(t)
|
|
||||||
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2})
|
|
||||||
return key
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 1 {
|
|
||||||
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
||||||
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
|
|
||||||
t.Errorf("maps should be empty after remove")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "StateDirect peer removed from maps with RemovePeer",
|
|
||||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
||||||
key := mustKey(t)
|
|
||||||
a.onAddPeer(m.Peer{WGPubKey: key, PeerIP: 2, Addr4: ep1.Addr(), Port: ep1.Port()})
|
|
||||||
return key
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 1 {
|
|
||||||
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
||||||
if len(a.peersByKey) != 0 || len(a.peersByIP) != 0 {
|
|
||||||
t.Errorf("maps should be empty after remove")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "removing active relay with no backup clears relay field",
|
|
||||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
||||||
relay := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
||||||
a.relay = relay
|
|
||||||
return relay.PubKey()
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 1 {
|
|
||||||
t.Fatalf("dev calls = %v, want [RemovePeer]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
||||||
if a.relay != nil {
|
|
||||||
t.Errorf("relay should be nil after removing only relay")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "removing active relay elects backup via SetRelay",
|
|
||||||
setup: func(t *testing.T, a *App) wgtypes.Key {
|
|
||||||
relay1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
||||||
addRelayPeer(t, a, "10.0.0.11", ep2)
|
|
||||||
a.relay = relay1
|
|
||||||
return relay1.PubKey()
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 2 {
|
|
||||||
t.Fatalf("dev calls = %v, want [RemovePeer, SetRelay]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertRemovePeer(t, 0, dev.Calls[0].PubKey)
|
|
||||||
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
|
|
||||||
if a.relay == nil {
|
|
||||||
t.Errorf("relay should be set to backup after failover")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
|
||||||
key := tc.setup(t, a)
|
|
||||||
dev.Calls = nil
|
|
||||||
a.onRemovePeer(key)
|
|
||||||
tc.check(t, a, dev)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSwitchActiveRelay(t *testing.T) {
|
|
||||||
ep1 := netip.MustParseAddrPort("1.2.3.4:51820")
|
|
||||||
ep2 := netip.MustParseAddrPort("5.6.7.8:51820")
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
setup func(t *testing.T, a *App)
|
|
||||||
check func(t *testing.T, a *App, dev *fakeWGDevice)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no candidates leaves relay nil",
|
|
||||||
setup: func(t *testing.T, a *App) {},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
dev.AssertNoCalls(t)
|
|
||||||
if a.relay != nil {
|
|
||||||
t.Error("relay should be nil")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single candidate elected via SetRelay",
|
|
||||||
setup: func(t *testing.T, a *App) {
|
|
||||||
addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 1 {
|
|
||||||
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
|
||||||
if a.relay == nil {
|
|
||||||
t.Error("relay should be set")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "measured RTT beats zero RTT",
|
|
||||||
setup: func(t *testing.T, a *App) {
|
|
||||||
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
||||||
r1.RTT = 10 * time.Millisecond
|
|
||||||
addRelayPeer(t, a, "10.0.0.11", ep2) // RTT stays MaxInt64 (unmeaured)
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 1 {
|
|
||||||
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "lower RTT wins",
|
|
||||||
setup: func(t *testing.T, a *App) {
|
|
||||||
r1 := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
||||||
r1.RTT = 5 * time.Millisecond
|
|
||||||
r2 := addRelayPeer(t, a, "10.0.0.11", ep2)
|
|
||||||
r2.RTT = 20 * time.Millisecond
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 1 {
|
|
||||||
t.Fatalf("dev calls = %v, want [SetRelay]", dev.Calls)
|
|
||||||
}
|
|
||||||
dev.AssertSetRelay(t, 0, dev.Calls[0].PubKey, ep1, a.vpnNet)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "stale relay demoted to direct before backup elected",
|
|
||||||
setup: func(t *testing.T, a *App) {
|
|
||||||
old := addRelayPeer(t, a, "10.0.0.10", ep1)
|
|
||||||
old.wgPeer.LastHandshakeTime = time.Time{} // stale — triggers switch from onTick
|
|
||||||
a.relay = old
|
|
||||||
addRelayPeer(t, a, "10.0.0.11", ep2)
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, a *App, dev *fakeWGDevice) {
|
|
||||||
if len(dev.Calls) != 2 {
|
|
||||||
t.Fatalf("dev calls = %v, want [AddDirect, SetRelay]", dev.Calls)
|
|
||||||
}
|
|
||||||
if dev.Calls[0].Method != "AddDirect" || dev.Calls[0].Endpoint != ep1 {
|
|
||||||
t.Errorf("call[0]: got %v, want AddDirect with ep1", dev.Calls[0])
|
|
||||||
}
|
|
||||||
dev.AssertSetRelay(t, 1, dev.Calls[1].PubKey, ep2, a.vpnNet)
|
|
||||||
if a.relay == nil || a.relay.EndpointV4 != ep2 {
|
|
||||||
t.Error("relay should be the backup peer")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
a, dev, _ := newTestApp(t, "10.0.0.1", false, false)
|
|
||||||
tc.setup(t, a)
|
|
||||||
dev.Calls = nil
|
|
||||||
a.switchActiveRelay()
|
|
||||||
tc.check(t, a, dev)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"vppn/peer/multicast"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *App) onMulticastDiscovery(pkt multicast.Packet) {
|
|
||||||
if a.isPublic {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Locate the sender peer by its VPN IP (final octet carried in the beacon).
|
|
||||||
octets := a.vpnNet.Addr().As4()
|
|
||||||
octets[3] = pkt.PeerIP
|
|
||||||
vpnIP := netip.AddrFrom4(octets)
|
|
||||||
|
|
||||||
peer, ok := a.peersByIP[vpnIP]
|
|
||||||
if !ok || peer.IsPublic || peer.State == StateDirect {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate the beacon against the peer's known sign key. scratch[:0]
|
|
||||||
// gives sign.Open an empty-but-capacity buffer to decode into.
|
|
||||||
if !pkt.Verify(a.scratch[:0], &peer.SignPubKey) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The beacon is authentic but must also advertise the WG key the hub gave
|
|
||||||
// us for this peer; otherwise it's inconsistent — drop it.
|
|
||||||
if wgtypes.Key(pkt.WGPubKey) != peer.PubKey() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint := netip.AddrPortFrom(pkt.Src, pkt.WGPort)
|
|
||||||
if !endpoint.IsValid() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var v4, v6 netip.AddrPort
|
|
||||||
if pkt.Src.Is4() {
|
|
||||||
v4 = endpoint
|
|
||||||
} else {
|
|
||||||
v6 = endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
a.addProbe(peer, v4, v6)
|
|
||||||
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *App) onPing(e PingEvent) {
|
|
||||||
peer, ok := a.peersByIP[e.srcVPNIP]
|
|
||||||
if !ok {
|
|
||||||
// TODO: Log here.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
// If we're the server, respond - this is always necessary as it's used to
|
|
||||||
// know if peers are up or down.
|
|
||||||
if peer.Role == control.Server {
|
|
||||||
a.sendPing(peer, e.ping.PingTS)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute RTT from server echo.
|
|
||||||
if peer.Role == control.Client {
|
|
||||||
peer.RTT = now.Sub(time.Unix(0, e.ping.PingTS))
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we're public, nothing more to do.
|
|
||||||
if a.isPublic {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// We can only learn our own endpoint from directly-connected peers — Dst
|
|
||||||
// is the sender's observation of our WG handshake source.
|
|
||||||
if peer.State == StateDirect {
|
|
||||||
if dst := e.ping.Dst; dst.IsValid() {
|
|
||||||
if dst.Addr().Is4() {
|
|
||||||
a.selfV4 = dst
|
|
||||||
} else {
|
|
||||||
a.selfV6 = dst
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a.addProbe(peer, e.ping.SrcV4, e.ping.SrcV6)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) addProbe(peer *Peer, v4, v6 netip.AddrPort) {
|
|
||||||
endpoint := preferredEndpoint(v4, v6)
|
|
||||||
if !endpoint.IsValid() || endpoint == peer.PreferredEndpoint() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
peer.UpdateEndpoints(v4, v6)
|
|
||||||
a.devAddProbe(peer, endpoint)
|
|
||||||
}
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
"vppn/peer/wginterface"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *App) onTick() {
|
|
||||||
wgPeers := a.devPeers()
|
|
||||||
|
|
||||||
now := time.Now().UnixNano()
|
|
||||||
|
|
||||||
for _, wgPeer := range wgPeers {
|
|
||||||
p, ok := a.peersByKey[wgPeer.PublicKey]
|
|
||||||
if !ok {
|
|
||||||
log.Printf("Wireguard peer not in index, removing: %v", wgPeer)
|
|
||||||
a.devRemove(&Peer{wgPeer: wgPeer})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.wgPeer = wgPeer
|
|
||||||
|
|
||||||
// Send pings to peers where we're the client.
|
|
||||||
if p.Role == control.Client {
|
|
||||||
a.sendPing(p, now)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch p.State {
|
|
||||||
case StateProbing:
|
|
||||||
// Promote probing peers to direct once alive (direct path confirmed
|
|
||||||
// working).
|
|
||||||
if time.Since(p.LastHandshakeTime()) < 2*wginterface.ProbeKeepalive {
|
|
||||||
a.devPromote(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
case StateDirect:
|
|
||||||
if p.IsPublic || a.isPublic || p.Up() {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// Stale non-public direct peer: demote to probing so WireGuard
|
|
||||||
// resumes handshake attempts on the direct path.
|
|
||||||
a.devAddProbe(p, p.WGEndpoint())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we have a live relay (if we're not public).
|
|
||||||
if !a.isPublic && (a.relay == nil || !a.relay.Up()) {
|
|
||||||
a.switchActiveRelay()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
182
peer/packets-util.go
Normal file
182
peer/packets-util.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type binWriter struct {
|
||||||
|
b []byte
|
||||||
|
i int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinWriter(buf []byte) *binWriter {
|
||||||
|
buf = buf[:cap(buf)]
|
||||||
|
return &binWriter{buf, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Bool(b bool) *binWriter {
|
||||||
|
if b {
|
||||||
|
return w.Byte(1)
|
||||||
|
}
|
||||||
|
return w.Byte(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Byte(b byte) *binWriter {
|
||||||
|
w.b[w.i] = b
|
||||||
|
w.i++
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) SharedKey(key [32]byte) *binWriter {
|
||||||
|
copy(w.b[w.i:w.i+32], key[:])
|
||||||
|
w.i += 32
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Uint16(x uint16) *binWriter {
|
||||||
|
*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x
|
||||||
|
w.i += 2
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Uint64(x uint64) *binWriter {
|
||||||
|
*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x
|
||||||
|
w.i += 8
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Int64(x int64) *binWriter {
|
||||||
|
*(*int64)(unsafe.Pointer(&w.b[w.i])) = x
|
||||||
|
w.i += 8
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
|
||||||
|
w.Bool(addrPort.IsValid())
|
||||||
|
addr := addrPort.Addr().As16()
|
||||||
|
copy(w.b[w.i:w.i+16], addr[:])
|
||||||
|
w.i += 16
|
||||||
|
return w.Uint16(addrPort.Port())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) AddrPort8(l [8]netip.AddrPort) *binWriter {
|
||||||
|
for _, addrPort := range l {
|
||||||
|
w.AddrPort(addrPort)
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Build() []byte {
|
||||||
|
return w.b[:w.i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type binReader struct {
|
||||||
|
b []byte
|
||||||
|
i int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinReader(buf []byte) *binReader {
|
||||||
|
return &binReader{b: buf}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) hasBytes(n int) bool {
|
||||||
|
if r.err != nil || (len(r.b)-r.i) < n {
|
||||||
|
r.err = errMalformedPacket
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Bool(b *bool) *binReader {
|
||||||
|
var bb byte
|
||||||
|
r.Byte(&bb)
|
||||||
|
*b = bb != 0
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Byte(b *byte) *binReader {
|
||||||
|
if !r.hasBytes(1) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*b = r.b[r.i]
|
||||||
|
r.i++
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) SharedKey(x *[32]byte) *binReader {
|
||||||
|
if !r.hasBytes(32) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = ([32]byte)(r.b[r.i : r.i+32])
|
||||||
|
r.i += 32
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Uint16(x *uint16) *binReader {
|
||||||
|
if !r.hasBytes(2) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = *(*uint16)(unsafe.Pointer(&r.b[r.i]))
|
||||||
|
r.i += 2
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Uint64(x *uint64) *binReader {
|
||||||
|
if !r.hasBytes(8) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = *(*uint64)(unsafe.Pointer(&r.b[r.i]))
|
||||||
|
r.i += 8
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Int64(x *int64) *binReader {
|
||||||
|
if !r.hasBytes(8) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = *(*int64)(unsafe.Pointer(&r.b[r.i]))
|
||||||
|
r.i += 8
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
|
||||||
|
if !r.hasBytes(19) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
valid bool
|
||||||
|
port uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
r.Bool(&valid)
|
||||||
|
addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
|
||||||
|
r.i += 16
|
||||||
|
|
||||||
|
r.Uint16(&port)
|
||||||
|
|
||||||
|
if valid {
|
||||||
|
*x = netip.AddrPortFrom(addr, port)
|
||||||
|
} else {
|
||||||
|
*x = netip.AddrPort{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) AddrPort8(x *[8]netip.AddrPort) *binReader {
|
||||||
|
for i := range x {
|
||||||
|
r.AddrPort(&x[i])
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Error() error {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
76
peer/packets-util_test.go
Normal file
76
peer/packets-util_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBinWriteRead_invalidAddrPort(t *testing.T) {
|
||||||
|
addr := netip.AddrPort{}
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
buf = newBinWriter(buf).
|
||||||
|
AddrPort(addr).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
var addr2 netip.AddrPort
|
||||||
|
err := newBinReader(buf).
|
||||||
|
AddrPort(&addr2).
|
||||||
|
Error()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr2.IsValid() {
|
||||||
|
t.Fatal(addr, addr2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBinWriteRead(t *testing.T) {
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
|
||||||
|
type Item struct {
|
||||||
|
Type byte
|
||||||
|
TraceID uint64
|
||||||
|
Addrs [8]netip.AddrPort
|
||||||
|
DestAddr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
in := Item{
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
[8]netip.AddrPort{},
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22),
|
||||||
|
}
|
||||||
|
|
||||||
|
in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20)
|
||||||
|
in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22)
|
||||||
|
in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23)
|
||||||
|
in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24)
|
||||||
|
in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25)
|
||||||
|
in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26)
|
||||||
|
in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27)
|
||||||
|
|
||||||
|
buf = newBinWriter(buf).
|
||||||
|
Byte(in.Type).
|
||||||
|
Uint64(in.TraceID).
|
||||||
|
AddrPort(in.DestAddr).
|
||||||
|
AddrPort8(in.Addrs).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
out := Item{}
|
||||||
|
|
||||||
|
err := newBinReader(buf).
|
||||||
|
Byte(&out.Type).
|
||||||
|
Uint64(&out.TraceID).
|
||||||
|
AddrPort(&out.DestAddr).
|
||||||
|
AddrPort8(&out.Addrs).
|
||||||
|
Error()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Fatal(in, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
120
peer/packets.go
Normal file
120
peer/packets.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
packetTypeSyn = 1
|
||||||
|
packetTypeInit = 2
|
||||||
|
packetTypeAck = 3
|
||||||
|
packetTypeProbe = 4
|
||||||
|
packetTypeAddrDiscovery = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type packetInit struct {
|
||||||
|
TraceID uint64
|
||||||
|
Direct bool
|
||||||
|
Version uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p packetInit) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeInit).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
Bool(p.Direct).
|
||||||
|
Uint64(p.Version).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePacketInit(buf []byte) (p packetInit, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
Bool(&p.Direct).
|
||||||
|
Uint64(&p.Version).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type packetSyn struct {
|
||||||
|
TraceID uint64 // TraceID to match response w/ request.
|
||||||
|
SharedKey [32]byte // Our shared key.
|
||||||
|
Direct bool
|
||||||
|
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p packetSyn) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeSyn).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
SharedKey(p.SharedKey).
|
||||||
|
Bool(p.Direct).
|
||||||
|
AddrPort8(p.PossibleAddrs).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePacketSyn(buf []byte) (p packetSyn, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
SharedKey(&p.SharedKey).
|
||||||
|
Bool(&p.Direct).
|
||||||
|
AddrPort8(&p.PossibleAddrs).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type packetAck struct {
|
||||||
|
TraceID uint64
|
||||||
|
ToAddr netip.AddrPort
|
||||||
|
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p packetAck) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeAck).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
AddrPort(p.ToAddr).
|
||||||
|
AddrPort8(p.PossibleAddrs).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePacketAck(buf []byte) (p packetAck, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
AddrPort(&p.ToAddr).
|
||||||
|
AddrPort8(&p.PossibleAddrs).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// A probeReqPacket is sent from a client to a server to determine if direct
|
||||||
|
// UDP communication can be used.
|
||||||
|
type packetProbe struct {
|
||||||
|
TraceID uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p packetProbe) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeProbe).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePacketProbe(buf []byte) (p packetProbe, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type packetLocalDiscovery struct{}
|
||||||
64
peer/packets_test.go
Normal file
64
peer/packets_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSynPacket(t *testing.T) {
|
||||||
|
p := packetSyn{
|
||||||
|
TraceID: 2342342345,
|
||||||
|
Direct: true,
|
||||||
|
}
|
||||||
|
rand.Read(p.SharedKey[:])
|
||||||
|
|
||||||
|
p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234)
|
||||||
|
p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399)
|
||||||
|
p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000)
|
||||||
|
|
||||||
|
buf := p.Marshal(newBuf())
|
||||||
|
p2, err := parsePacketSyn(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(p, p2) {
|
||||||
|
t.Fatal(p2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAckPacket(t *testing.T) {
|
||||||
|
p := packetAck{
|
||||||
|
TraceID: 123213,
|
||||||
|
ToAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 234),
|
||||||
|
}
|
||||||
|
|
||||||
|
p.PossibleAddrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 2, 3, 4}), 100)
|
||||||
|
p.PossibleAddrs[1] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 2, 3, 4}), 12399)
|
||||||
|
p.PossibleAddrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{3, 2, 3, 4}), 60000)
|
||||||
|
|
||||||
|
buf := p.Marshal(newBuf())
|
||||||
|
p2, err := parsePacketAck(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(p, p2) {
|
||||||
|
t.Fatal(p2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbePacket(t *testing.T) {
|
||||||
|
p := packetProbe{
|
||||||
|
TraceID: 12345,
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := p.Marshal(newBuf())
|
||||||
|
p2, err := parsePacketProbe(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(p, p2) {
|
||||||
|
t.Fatal(p2)
|
||||||
|
}
|
||||||
|
}
|
||||||
197
peer/peer.go
Normal file
197
peer/peer.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"vppn/m"
|
||||||
|
|
||||||
|
"git.crumpington.com/lib/go/flock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type peerMain struct {
|
||||||
|
Globals
|
||||||
|
ifReader *IFReader
|
||||||
|
connReader *ConnReader
|
||||||
|
hubPoller *HubPoller
|
||||||
|
lockFile *os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPeerMain(args mainArgs) *peerMain {
|
||||||
|
logf := func(s string, args ...any) {
|
||||||
|
log.Printf("[Main] "+s, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
lockFile, err := flock.TryLock(lockFilePath(args.NetName))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open lock file: %v", err)
|
||||||
|
}
|
||||||
|
if lockFile == nil {
|
||||||
|
log.Fatalf("Failed to obtain file lock.")
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := loadPeerConfig(args.NetName)
|
||||||
|
if err != nil {
|
||||||
|
logf("Failed to load configuration: %v", err)
|
||||||
|
logf("Initializing...")
|
||||||
|
initPeerWithHub(args)
|
||||||
|
|
||||||
|
config, err = loadPeerConfig(args.NetName)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load configuration: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := loadNetworkState(args.NetName)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load network state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
startupCount, err := loadStartupCount(args.NetName)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
log.Fatalf("Failed to load startup count: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startupCount.Count == math.MaxUint16 {
|
||||||
|
log.Fatalf("Startup counter overflow.")
|
||||||
|
}
|
||||||
|
startupCount.Count += 1
|
||||||
|
|
||||||
|
if err := storeStartupCount(args.NetName, startupCount); err != nil {
|
||||||
|
log.Fatalf("Failed to write startup count: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := openInterface(config.Network, config.LocalPeerIP, args.NetName)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
localPeer := state.Peers[config.LocalPeerIP]
|
||||||
|
|
||||||
|
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", localPeer.Port))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to resolve UDP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logf("Listening on %v...", myAddr)
|
||||||
|
conn, err := net.ListenUDP("udp", myAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open UDP port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetReadBuffer(1024 * 1024 * 8)
|
||||||
|
conn.SetWriteBuffer(1024 * 1024 * 8)
|
||||||
|
|
||||||
|
var localAddr netip.AddrPort
|
||||||
|
ip, localAddrValid := netip.AddrFromSlice(localPeer.PublicIP)
|
||||||
|
if localAddrValid {
|
||||||
|
localAddr = netip.AddrPortFrom(ip, localPeer.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
g := NewGlobals(config, startupCount, localAddr, conn, iface)
|
||||||
|
|
||||||
|
hubPoller, err := NewHubPoller(g, args.NetName, args.HubAddress, args.APIKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create hub poller: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start status server.
|
||||||
|
go runStatusServer(g, statusSocketPath(args.NetName))
|
||||||
|
|
||||||
|
return &peerMain{
|
||||||
|
Globals: g,
|
||||||
|
ifReader: NewIFReader(g),
|
||||||
|
connReader: NewConnReader(g, conn),
|
||||||
|
hubPoller: hubPoller,
|
||||||
|
lockFile: lockFile,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *peerMain) Run() {
|
||||||
|
for i := range p.RemotePeers {
|
||||||
|
remote := p.RemotePeers[i].Load()
|
||||||
|
go newRemoteFSM(remote).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.ifReader.Run()
|
||||||
|
go p.connReader.Run()
|
||||||
|
|
||||||
|
if !p.LocalAddrValid {
|
||||||
|
go RunMCWriter(p.LocalPeerIP, p.PrivSignKey)
|
||||||
|
go RunMCReader(p.Globals)
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.hubPoller.Run()
|
||||||
|
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initPeerWithHub(args mainArgs) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
initURL, err := url.Parse(args.HubAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to parse hub URL: %v", err)
|
||||||
|
}
|
||||||
|
initURL.Path = "/peer/init/"
|
||||||
|
|
||||||
|
initArgs := m.PeerInitArgs{
|
||||||
|
EncPubKey: keys.PubKey,
|
||||||
|
PubSignKey: keys.PubSignKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
if err := json.NewEncoder(buf).Encode(initArgs); err != nil {
|
||||||
|
log.Fatalf("Failed to encode init args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, initURL.String(), buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to construct request: %v", err)
|
||||||
|
}
|
||||||
|
req.SetBasicAuth("", args.APIKey)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to init with hub: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
initResp := m.PeerInitResp{}
|
||||||
|
if err := json.Unmarshal(data, &initResp); err != nil {
|
||||||
|
log.Fatalf("Failed to parse configuration: %v\n%s", err, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := LocalConfig{}
|
||||||
|
config.LocalPeerIP = initResp.PeerIP
|
||||||
|
config.Network = initResp.Network
|
||||||
|
config.PubKey = keys.PubKey
|
||||||
|
config.PrivKey = keys.PrivKey
|
||||||
|
config.PubSignKey = keys.PubSignKey
|
||||||
|
config.PrivSignKey = keys.PrivSignKey
|
||||||
|
|
||||||
|
if err := storeNetworkState(args.NetName, initResp.NetworkState); err != nil {
|
||||||
|
log.Fatalf("Failed to store network state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := storePeerConfig(args.NetName, config); err != nil {
|
||||||
|
log.Fatalf("Failed to store configuration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Print("Initialization successful.")
|
||||||
|
}
|
||||||
21
peer/ping.go
21
peer/ping.go
@@ -1,21 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"vppn/peer/control"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *App) sendPing(p *Peer, ts int64) {
|
|
||||||
ping := control.Ping{
|
|
||||||
PingTS: ts,
|
|
||||||
SrcV4: a.selfV4,
|
|
||||||
SrcV6: a.selfV6,
|
|
||||||
Dst: p.WGEndpoint(),
|
|
||||||
}
|
|
||||||
dst := netip.AddrPortFrom(p.VPNIP, ControlPort)
|
|
||||||
if err := a.controlConn.SendPing(dst, ping, a.scratch); err != nil {
|
|
||||||
log.Printf("sendPing %v: %v", p.VPNIP, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
86
peer/pubaddrs.go
Normal file
86
peer/pubaddrs.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pubAddrStore struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
localPub bool
|
||||||
|
localAddr netip.AddrPort
|
||||||
|
lastSeen map[netip.AddrPort]time.Time
|
||||||
|
addrList []netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPubAddrStore(localAddr netip.AddrPort) *pubAddrStore {
|
||||||
|
return &pubAddrStore{
|
||||||
|
localPub: localAddr.IsValid(),
|
||||||
|
localAddr: localAddr,
|
||||||
|
lastSeen: map[netip.AddrPort]time.Time{},
|
||||||
|
addrList: make([]netip.AddrPort, 0, 32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *pubAddrStore) Store(addr netip.AddrPort) {
|
||||||
|
if store.localPub {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.Addr().IsPrivate() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
store.lock.Lock()
|
||||||
|
defer store.lock.Unlock()
|
||||||
|
|
||||||
|
if _, exists := store.lastSeen[addr]; !exists {
|
||||||
|
store.addrList = append(store.addrList, addr)
|
||||||
|
}
|
||||||
|
store.lastSeen[addr] = time.Now()
|
||||||
|
store.sort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *pubAddrStore) Get() (addrs [8]netip.AddrPort) {
|
||||||
|
store.lock.Lock()
|
||||||
|
defer store.lock.Unlock()
|
||||||
|
|
||||||
|
store.clean()
|
||||||
|
|
||||||
|
if store.localPub {
|
||||||
|
addrs[0] = store.localAddr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(addrs[:], store.addrList)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *pubAddrStore) clean() {
|
||||||
|
if store.localPub {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for ip, lastSeen := range store.lastSeen {
|
||||||
|
if time.Since(lastSeen) > timeoutInterval {
|
||||||
|
delete(store.lastSeen, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
store.addrList = store.addrList[:0]
|
||||||
|
for ip := range store.lastSeen {
|
||||||
|
store.addrList = append(store.addrList, ip)
|
||||||
|
}
|
||||||
|
store.sort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *pubAddrStore) sort() {
|
||||||
|
sort.Slice(store.addrList, func(i, j int) bool {
|
||||||
|
return store.lastSeen[store.addrList[j]].Before(store.lastSeen[store.addrList[i]])
|
||||||
|
})
|
||||||
|
}
|
||||||
29
peer/pubaddrs_test.go
Normal file
29
peer/pubaddrs_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPubAddrStore(t *testing.T) {
|
||||||
|
s := newPubAddrStore(netip.AddrPort{})
|
||||||
|
|
||||||
|
l := []netip.AddrPort{
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20),
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 2, 3}), 21),
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 1, 2, 3}), 22),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range l {
|
||||||
|
s.Store(l[i])
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.clean()
|
||||||
|
|
||||||
|
l2 := s.Get()
|
||||||
|
if l2[0] != l[2] || l2[1] != l[1] || l2[2] != l[0] {
|
||||||
|
t.Fatal(l, l2)
|
||||||
|
}
|
||||||
|
}
|
||||||
54
peer/relayhandler.go
Normal file
54
peer/relayhandler.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayHandler struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
relays map[byte]*Remote
|
||||||
|
relay atomic.Pointer[Remote]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRelayHandler() *relayHandler {
|
||||||
|
return &relayHandler{
|
||||||
|
relays: make(map[byte]*Remote, 256),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *relayHandler) Add(r *Remote) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
h.relays[r.RemotePeerIP] = r
|
||||||
|
|
||||||
|
if h.relay.Load() == nil {
|
||||||
|
log.Printf("Setting Relay: %v", r.conf().Peer.Name)
|
||||||
|
h.relay.Store(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *relayHandler) Remove(r *Remote) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
log.Printf("Removing relay %d...", r.RemotePeerIP)
|
||||||
|
delete(h.relays, r.RemotePeerIP)
|
||||||
|
|
||||||
|
if h.relay.Load() == r {
|
||||||
|
// Remove current relay.
|
||||||
|
h.relay.Store(nil)
|
||||||
|
|
||||||
|
// Find new relay.
|
||||||
|
for _, r := range h.relays {
|
||||||
|
h.relay.Store(r)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *relayHandler) Load() *Remote {
|
||||||
|
return h.relay.Load()
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user