diff --git a/cmd/vppn/main.go b/cmd/vppn/main.go index dada4cf..d0e5be0 100644 --- a/cmd/vppn/main.go +++ b/cmd/vppn/main.go @@ -1,11 +1,72 @@ package main import ( + "flag" "log" + "os" + "path/filepath" + "strings" + "vppn/peer" + + "git.crumpington.com/lib/go/flock" ) func main() { log.SetFlags(0) - peer.Main2() + + name := flag.String("name", "", "network name (required)") + hub := flag.String("hub", "", "hub base URL (required)") + flag.Parse() + + if *name == "" || *hub == "" { + flag.Usage() + os.Exit(1) + } + + apiKey, err := loadAPIKey(*name) + if err != nil { + log.Fatalf("api key: %v", err) + } + + // Directory existence is guaranteed by the apikey file read above. + lockFile, err := flock.TryLock(vppnPath(*name, "lock")) + if err != nil { + log.Fatalf("lock: %v", err) + } + if lockFile == nil { + log.Fatalf("already running for network %q", *name) + } + defer flock.Unlock(lockFile) + + state, err := peer.LoadOrInit(vppnPath(*name, "state.json"), *hub, apiKey) + if err != nil { + log.Fatalf("init: %v", err) + } + + ifaceName := strings.TrimSuffix(state.LocalDomain, ".local") + app, err := peer.New(state, *hub, apiKey, ifaceName, state.LocalDomain, vppnPath(*name, "network.json")) + if err != nil { + log.Fatalf("start: %v", err) + } + + if err := app.Run(); err != nil { + log.Fatalf("run: %v", err) + } +} + +func loadAPIKey(name string) (string, error) { + data, err := os.ReadFile(vppnPath(name, "apikey")) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +func vppnPath(name, file string) string { + home, err := os.UserHomeDir() + if err != nil { + return filepath.Join(".vppn", name, file) + } + return filepath.Join(home, ".vppn", name, file) } diff --git a/go.mod b/go.mod index ad3add6..e35d73d 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,18 @@ require ( git.crumpington.com/lib/go v0.9.1 golang.org/x/crypto v0.42.0 golang.org/x/sys v0.36.0 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 ) require ( + github.com/google/go-cmp v0.6.0 // indirect + github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect golang.org/x/net v0.44.0 // indirect + golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.29.0 // indirect + golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect ) diff --git a/go.sum b/go.sum index 0f444e0..9d7f776 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,30 @@ git.crumpington.com/lib/go v0.9.1 h1:xLBzcgiZRB6Ky3Ce9hKE+Ko0YbkA4USF4eJk5i5RJF4= git.crumpington.com/lib/go v0.9.1/go.mod h1:5nnfjdnUnj/FHhakaliKQKsKeSkUb0GEUKF3PqRgUXg= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= diff --git a/hub/api/api.go b/hub/api/api.go index 7a534ca..19838b8 100644 --- a/hub/api/api.go +++ b/hub/api/api.go @@ -19,8 +19,10 @@ import ( var migrations embed.FS type API struct { - db *sql.DB - lock sync.Mutex + db *sql.DB + lock sync.Mutex + sessionsMu sync.Mutex + sessions map[string]*Session } func New(dbPath string) (*API, error) { @@ -34,10 +36,17 @@ func New(dbPath string) (*API, error) { } a := &API{ - db: sqlDB, + db: sqlDB, + sessions: make(map[string]*Session), } - return a, a.ensurePassword() + if err := a.ensurePassword(); err != nil { + return nil, err + } + + go a.sweepSessions() + + return a, nil } func (a *API) ensurePassword() error { @@ -62,12 +71,8 @@ func (a *API) ensurePassword() error { return db.Config_Insert(a.db, conf) } -func (a *API) Config_Get() *Config { - conf, err := db.Config_Get(a.db, 1) - if err != nil { - panic(err) - } - return conf +func (a *API) Config_Get() (*Config, error) { + return db.Config_Get(a.db, 1) } func (a *API) Config_Update(conf *Config) error { @@ -75,56 +80,78 @@ func (a *API) Config_Update(conf *Config) error { } func (a *API) Session_Delete(sessionID string) error { - return db.Session_Delete(a.db, sessionID) + a.sessionsMu.Lock() + defer a.sessionsMu.Unlock() + delete(a.sessions, sessionID) + return nil } -func (a *API) Session_Get(sessionID string) (*Session, error) { - if sessionID == "" { - return a.session_CreatePub() +const ( + sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use + sessionSweepEvery = time.Hour // cadence of expired-session eviction +) + +// Session_Get returns a snapshot copy of the signed-in session for sessionID, +// or the zero Session if the cookie is missing/unknown/expired. It never +// creates a session, so anonymous requests cost no memory — a session is minted +// only by Session_SignIn. Returning a value (not the stored pointer) keeps +// callers from racing on the shared struct. +func (a *API) Session_Get(sessionID string) (Session, error) { + a.sessionsMu.Lock() + defer a.sessionsMu.Unlock() + + s, ok := a.sessions[sessionID] + + if sessionID == "" || !ok { + return Session{}, nil } - session, err := db.Session_Get(a.db, sessionID) + if timeSince(s.LastSeenAt) > sessionTTLSecs { + delete(a.sessions, sessionID) + return Session{}, nil + } + + s.LastSeenAt = time.Now().Unix() + return *s, nil +} + +// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session, +// returning it so the caller can set the cookie. A new ID per sign-in rotates +// the session at the privilege boundary (session-fixation resistance). +func (a *API) Session_SignIn(pwd string) (Session, error) { + conf, err := a.Config_Get() if err != nil { - return a.session_CreatePub() + return Session{}, err + } + if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { + return Session{}, ErrNotAuthorized } - if timeSince(session.LastSeenAt) > 86400*21 { - return a.session_CreatePub() - } - - if timeSince(session.LastSeenAt) > 86400*7 { - session.LastSeenAt = time.Now().Unix() - if err := db.Session_UpdateLastSeenAt(a.db, session.SessionID); err != nil { - log.Printf("Failed to update session: %v", err) - } - } - - return session, nil -} - -func (a *API) session_CreatePub() (*Session, error) { + a.sessionsMu.Lock() + defer a.sessionsMu.Unlock() s := &Session{ SessionID: idgen.NewToken(), - CSRF: idgen.NewToken(), - SignedIn: false, + SignedIn: true, CreatedAt: time.Now().Unix(), LastSeenAt: time.Now().Unix(), } - err := db.Session_Insert(a.db, s) - return s, err + a.sessions[s.SessionID] = s + return *s, nil } -func (a *API) Session_DeleteBefore(timestamp int64) error { - return db.Session_DeleteBefore(a.db, timestamp) -} - -func (a *API) Session_SignIn(s *Session, pwd string) error { - conf := a.Config_Get() - if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { - return ErrNotAuthorized +// sweepSessions periodically evicts sessions past their TTL. Without it, a +// signed-in session whose ID is never presented again would linger forever +// (Session_Get only evicts on a lookup of that same ID). +func (a *API) sweepSessions() { + for range time.Tick(sessionSweepEvery) { + a.sessionsMu.Lock() + for id, s := range a.sessions { + if timeSince(s.LastSeenAt) > sessionTTLSecs { + delete(a.sessions, id) + } + } + a.sessionsMu.Unlock() } - - return db.Session_SetSignedIn(a.db, s.SessionID) } func (a *API) Network_Create(n *Network) error { @@ -141,14 +168,13 @@ func (a *API) Network_Get(id int64) (*Network, error) { } func (a *API) Network_List() ([]*Network, error) { - const query = db.Network_SelectQuery + ` ORDER BY Name ASC` + const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC` return db.Network_List(a.db, query) } func (a *API) Peer_CreateNew(p *Peer) error { - p.Version = idgen.NextID(0) - p.PubKey = []byte{} - p.PubSignKey = []byte{} + p.WGPubKey = []byte{} + p.SignPubKey = []byte{} p.APIKey = idgen.NewToken() return db.Peer_Insert(a.db, p) @@ -158,21 +184,22 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error { a.lock.Lock() defer a.lock.Unlock() - peer.Version = idgen.NextID(0) - peer.PubKey = args.EncPubKey - peer.PubSignKey = args.PubSignKey + // Re-read from DB inside the lock — the caller's copy was fetched before + // we held the lock, so it may be stale under concurrent requests. + current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP) + if err != nil { + return err + } + if len(current.WGPubKey) != 0 { + return errors.New("peer already initialized") + } + + peer.WGPubKey = args.WGPubKey + peer.SignPubKey = args.SignPubKey return db.Peer_UpdateFull(a.db, peer) } -func (a *API) Peer_Update(p *Peer) error { - a.lock.Lock() - defer a.lock.Unlock() - - p.Version = idgen.NextID(0) - return db.Peer_Update(a.db, p) -} - func (a *API) Peer_Delete(networkID int64, peerIP byte) error { return db.Peer_Delete(a.db, networkID, peerIP) } diff --git a/hub/api/db/generated.go b/hub/api/db/generated.go index 88aec6c..ab8e095 100644 --- a/hub/api/db/generated.go +++ b/hub/api/db/generated.go @@ -123,7 +123,9 @@ func Config_Get( ) { row = &Config{} r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID) - err = r.Scan(&row.ConfigID, &row.Password) + if err = r.Scan(&row.ConfigID, &row.Password); err != nil { + row = nil + } return } @@ -137,7 +139,9 @@ func Config_GetWhere( ) { row = &Config{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.ConfigID, &row.Password) + if err = r.Scan(&row.ConfigID, &row.Password); err != nil { + row = nil + } return } @@ -182,135 +186,17 @@ func Config_List( return l, nil } -// ---------------------------------------------------------------------------- -// Table: sessions -// ---------------------------------------------------------------------------- - -type Session struct { - SessionID string - CSRF string - SignedIn bool - CreatedAt int64 - LastSeenAt int64 -} - -const Session_SelectQuery = "SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions" - -func Session_Insert( - tx TX, - row *Session, -) (err error) { - Session_Sanitize(row) - if err = Session_Validate(row); err != nil { - return err - } - - _, err = tx.Exec("INSERT INTO sessions(SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt) VALUES(?,?,?,?,?)", row.SessionID, row.CSRF, row.SignedIn, row.CreatedAt, row.LastSeenAt) - return err -} - -func Session_Delete( - tx TX, - SessionID string, -) (err error) { - result, err := tx.Exec("DELETE FROM sessions WHERE SessionID=?", SessionID) - if err != nil { - return err - } - - n, err := result.RowsAffected() - if err != nil { - panic(err) - } - switch n { - case 0: - return sql.ErrNoRows - case 1: - return nil - default: - panic("multiple rows deleted") - } -} - -func Session_Get( - tx TX, - SessionID string, -) ( - row *Session, - err error, -) { - row = &Session{} - r := tx.QueryRow("SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions WHERE SessionID=?", SessionID) - err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt) - return -} - -func Session_GetWhere( - tx TX, - query string, - args ...any, -) ( - row *Session, - err error, -) { - row = &Session{} - r := tx.QueryRow(query, args...) - err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt) - return -} - -func Session_Iterate( - tx TX, - query string, - args ...any, -) iter.Seq2[*Session, error] { - rows, err := tx.Query(query, args...) - if err != nil { - return func(yield func(*Session, error) bool) { - yield(nil, err) - } - } - - return func(yield func(*Session, error) bool) { - defer rows.Close() - for rows.Next() { - row := &Session{} - err := rows.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt) - if !yield(row, err) { - return - } - } - } -} - -func Session_List( - tx TX, - query string, - args ...any, -) ( - l []*Session, - err error, -) { - for row, err := range Session_Iterate(tx, query, args...) { - if err != nil { - return nil, err - } - l = append(l, row) - } - return l, nil -} - // ---------------------------------------------------------------------------- // Table: networks // ---------------------------------------------------------------------------- type Network struct { - NetworkID int64 - Name string - Network []byte + NetworkID int64 + LocalDomain string + Network []byte } -const Network_SelectQuery = "SELECT NetworkID,Name,Network FROM networks" +const Network_SelectQuery = "SELECT NetworkID,LocalDomain,Network FROM networks" func Network_Insert( tx TX, @@ -321,7 +207,7 @@ func Network_Insert( return err } - _, err = tx.Exec("INSERT INTO networks(NetworkID,Name,Network) VALUES(?,?,?)", row.NetworkID, row.Name, row.Network) + _, err = tx.Exec("INSERT INTO networks(NetworkID,LocalDomain,Network) VALUES(?,?,?)", row.NetworkID, row.LocalDomain, row.Network) return err } @@ -334,7 +220,7 @@ func Network_UpdateFull( return err } - result, err := tx.Exec("UPDATE networks SET Name=?,Network=? WHERE NetworkID=?", row.Name, row.Network, row.NetworkID) + result, err := tx.Exec("UPDATE networks SET LocalDomain=?,Network=? WHERE NetworkID=?", row.LocalDomain, row.Network, row.NetworkID) if err != nil { return err } @@ -384,8 +270,10 @@ func Network_Get( err error, ) { row = &Network{} - r := tx.QueryRow("SELECT NetworkID,Name,Network FROM networks WHERE NetworkID=?", NetworkID) - err = r.Scan(&row.NetworkID, &row.Name, &row.Network) + r := tx.QueryRow("SELECT NetworkID,LocalDomain,Network FROM networks WHERE NetworkID=?", NetworkID) + if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil { + row = nil + } return } @@ -399,7 +287,9 @@ func Network_GetWhere( ) { row = &Network{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.NetworkID, &row.Name, &row.Network) + if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil { + row = nil + } return } @@ -419,7 +309,7 @@ func Network_Iterate( defer rows.Close() for rows.Next() { row := &Network{} - err := rows.Scan(&row.NetworkID, &row.Name, &row.Network) + err := rows.Scan(&row.NetworkID, &row.LocalDomain, &row.Network) if !yield(row, err) { return } @@ -451,17 +341,17 @@ func Network_List( type Peer struct { NetworkID int64 PeerIP byte - Version int64 APIKey string Name string - PublicIP []byte + Addr4 []byte + Addr6 []byte Port uint16 Relay bool - PubKey []byte - PubSignKey []byte + WGPubKey []byte + SignPubKey []byte } -const Peer_SelectQuery = "SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers" +const Peer_SelectQuery = "SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers" func Peer_Insert( tx TX, @@ -472,38 +362,10 @@ func Peer_Insert( return err } - _, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey) + _, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey) return err } -func Peer_Update( - tx TX, - row *Peer, -) (err error) { - Peer_Sanitize(row) - if err = Peer_Validate(row); err != nil { - return err - } - - result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.NetworkID, row.PeerIP) - if err != nil { - return err - } - - n, err := result.RowsAffected() - if err != nil { - panic(err) - } - switch n { - case 0: - return sql.ErrNoRows - case 1: - return nil - default: - panic("multiple rows updated") - } -} - func Peer_UpdateFull( tx TX, row *Peer, @@ -513,7 +375,7 @@ func Peer_UpdateFull( return err } - result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=?,PubSignKey=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey, row.NetworkID, row.PeerIP) + result, err := tx.Exec("UPDATE peers SET APIKey=?,Name=?,Addr4=?,Addr6=?,Port=?,Relay=?,WGPubKey=?,SignPubKey=? WHERE NetworkID=? AND PeerIP=?", row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey, row.NetworkID, row.PeerIP) if err != nil { return err } @@ -565,8 +427,10 @@ func Peer_Get( err error, ) { row = &Peer{} - r := tx.QueryRow("SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP) - err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) + r := tx.QueryRow("SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP) + if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil { + row = nil + } return } @@ -580,7 +444,9 @@ func Peer_GetWhere( ) { row = &Peer{} r := tx.QueryRow(query, args...) - err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) + if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil { + row = nil + } return } @@ -600,7 +466,7 @@ func Peer_Iterate( defer rows.Close() for rows.Next() { row := &Peer{} - err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey) + err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey) if !yield(row, err) { return } diff --git a/hub/api/db/sanitize-validate.go b/hub/api/db/sanitize-validate.go index 71785e9..ffe1f7d 100644 --- a/hub/api/db/sanitize-validate.go +++ b/hub/api/db/sanitize-validate.go @@ -8,9 +8,11 @@ import ( var ( ErrInvalidIP = errors.New("invalid IP") + ErrInvalidPeerIP = errors.New("invalid peer IP") ErrNonPrivateIP = errors.New("non-private IP") ErrInvalidPort = errors.New("invalid port") ErrInvalidNetName = errors.New("invalid network name") + ErrNetNameNotLocal = errors.New("network name must end with .local") ErrInvalidPeerName = errors.New("invalid peer name") ) @@ -21,15 +23,8 @@ func Config_Validate(c *Config) error { return nil } -func Session_Sanitize(s *Session) { -} - -func Session_Validate(s *Session) error { - return nil -} - func Network_Sanitize(n *Network) { - n.Name = strings.TrimSpace(n.Name) + n.LocalDomain = strings.TrimSpace(n.LocalDomain) if addr, ok := netip.AddrFromSlice(n.Network); ok { n.Network = addr.AsSlice() @@ -37,12 +32,17 @@ func Network_Sanitize(n *Network) { } func Network_Validate(c *Network) error { - // 16 bytes is linux limit for network interface names. - if len(c.Name) == 0 || len(c.Name) > 16 { + // 15 bytes is linux limit for network interface names. With ending .local, + // max length is 21. + if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 { return ErrInvalidNetName } - for _, c := range c.Name { + if !strings.HasSuffix(c.LocalDomain, ".local") { + return ErrNetNameNotLocal + } + + for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") { if c >= 'a' && c <= 'z' { continue } @@ -66,21 +66,35 @@ func Network_Validate(c *Network) error { func Peer_Sanitize(p *Peer) { p.Name = strings.TrimSpace(p.Name) - if len(p.PublicIP) != 0 { - addr, ok := netip.AddrFromSlice(p.PublicIP) - if ok && addr.Is4() { - p.PublicIP = addr.AsSlice() + if len(p.Addr4) != 0 { + if addr, ok := netip.AddrFromSlice(p.Addr4); ok { + // Unmap so an IPv4-mapped form is stored canonically as 4 bytes. + p.Addr4 = addr.Unmap().AsSlice() + } + } + if len(p.Addr6) != 0 { + if addr, ok := netip.AddrFromSlice(p.Addr6); ok { + p.Addr6 = addr.AsSlice() } } if p.Port == 0 { - p.Port = 456 + p.Port = 51820 } } func Peer_Validate(p *Peer) error { - if len(p.PublicIP) > 0 { - _, ok := netip.AddrFromSlice(p.PublicIP) - if !ok { + if p.PeerIP < 1 || p.PeerIP > 254 { + return ErrInvalidPeerIP + } + if len(p.Addr4) > 0 { + // Must be a genuine IPv4 address (reject an IPv6 in the v4 field). + if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() { + return ErrInvalidIP + } + } + if len(p.Addr6) > 0 { + // Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field). + if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() { return ErrInvalidIP } } @@ -88,6 +102,9 @@ func Peer_Validate(p *Peer) error { return ErrInvalidPort } + if len(p.Name) == 0 { + return ErrInvalidPeerName + } for _, c := range p.Name { if c >= 'a' && c <= 'z' { continue @@ -95,10 +112,9 @@ func Peer_Validate(p *Peer) error { if c >= '0' && c <= '9' { continue } - if c == '.' || c == '-' || c == '_' { + if c == '-' { continue } - return ErrInvalidPeerName } diff --git a/hub/api/db/tables.defs b/hub/api/db/tables.defs index d6dc338..39665fa 100644 --- a/hub/api/db/tables.defs +++ b/hub/api/db/tables.defs @@ -3,29 +3,21 @@ TABLE config OF Config ( Password []byte ); -TABLE sessions OF Session NoUpdate ( - SessionID string PK, - CSRF string, - SignedIn bool, - CreatedAt int64, - LastSeenAt int64 -); - TABLE networks OF Network ( - NetworkID int64 PK, - Name string NoUpdate, - Network []byte NoUpdate + NetworkID int64 PK, + LocalDomain string NoUpdate, + Network []byte NoUpdate ); TABLE peers OF Peer ( NetworkID int64 PK, PeerIP byte PK, - Version int64, APIKey string NoUpdate, - Name string, - PublicIP []byte, - Port uint16, - Relay bool, - PubKey []byte NoUpdate, - PubSignKey []byte NoUpdate + Name string NoUpdate, + Addr4 []byte NoUpdate, + Addr6 []byte NoUpdate, + Port uint16 NoUpdate, + Relay bool NoUpdate, + WGPubKey []byte NoUpdate, + SignPubKey []byte NoUpdate ); diff --git a/hub/api/db/written.go b/hub/api/db/written.go index 6d61bb5..11251f2 100644 --- a/hub/api/db/written.go +++ b/hub/api/db/written.go @@ -1,31 +1,5 @@ package db -import "time" - -func Session_UpdateLastSeenAt( - tx TX, - id string, -) (err error) { - _, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id) - return err -} - -func Session_SetSignedIn( - tx TX, - id string, -) (err error) { - _, err = tx.Exec("UPDATE sessions SET SignedIn=1 WHERE SessionID=?", id) - return err -} - -func Session_DeleteBefore( - tx TX, - timestamp int64, -) (err error) { - _, err = tx.Exec("DELETE FROM sessions WHERE LastSeenAt", timestamp) - return err -} - func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) { const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC` return Peer_List(tx, query, networkID) @@ -37,9 +11,3 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) { Peer_SelectQuery+` WHERE APIKey=?`, apiKey) } - -func Peer_Exists(tx TX, networkID int64, ip byte) (exists bool, err error) { - const query = `SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=? AND PeerIP=?)` - err = tx.QueryRow(query, networkID, ip).Scan(&exists) - return -} diff --git a/hub/api/errors.go b/hub/api/errors.go index a7da6f6..27d241f 100644 --- a/hub/api/errors.go +++ b/hub/api/errors.go @@ -7,7 +7,6 @@ import ( var ( ErrNotAuthorized = errors.New("not authorized") - ErrNoIPAvailable = errors.New("no IP address available") ErrInvalidIP = db.ErrInvalidIP ErrInvalidPort = db.ErrInvalidPort ) diff --git a/hub/api/migrations/2024-11-30-init.sql b/hub/api/migrations/2024-11-30-init.sql index f60aa77..bce964b 100644 --- a/hub/api/migrations/2024-11-30-init.sql +++ b/hub/api/migrations/2024-11-30-init.sql @@ -3,32 +3,23 @@ CREATE TABLE config ( Password BLOB NOT NULL -- bcrypt password for web interface ) WITHOUT ROWID; -CREATE TABLE sessions ( - SessionID TEXT NOT NULL PRIMARY KEY, - CSRF TEXT NOT NULL, - SignedIn INTEGER NOT NULL, - CreatedAt INTEGER NOT NULL, - LastSeenAt INTEGER NOT NULL -) WITHOUT ROWID; - -CREATE INDEX sessions_last_seen_index ON sessions(LastSeenAt); - CREATE TABLE networks ( NetworkID INTEGER NOT NULL PRIMARY KEY, - Name TEXT NOT NULL UNIQUE, -- Network/interface name. + LocalDomain TEXT NOT NULL UNIQUE, -- Network/interface name. Network BLOB NOT NULL UNIQUE -- Network (/24), example 10.51.50.0 ) WITHOUT ROWID; CREATE TABLE peers ( NetworkID INTEGER NOT NULL, - PeerIP INTEGER NOT NULL, -- Final byte of IP. - Version INTEGER NOT NULL, -- Changes when updated. - APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key. - Name TEXT NOT NULL UNIQUE, -- For humans. - PublicIP BLOB NOT NULL, + PeerIP INTEGER NOT NULL, -- Final byte of IP. + APIKey TEXT NOT NULL UNIQUE, -- Peer's secret API key. + Name TEXT NOT NULL, -- For humans. + Addr4 BLOB NOT NULL, + Addr6 BLOB NOT NULL, Port INTEGER NOT NULL, - Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. Must also have public address. - PubKey BLOB NOT NULL, - PubSignKey BLOB NOT NULL, + Relay INTEGER NOT NULL DEFAULT 0, -- Boolean if peer will forward packets. + WGPubKey BLOB NOT NULL, + SignPubKey BLOB NOT NULL, + UNIQUE(NetworkID, Name), PRIMARY KEY(NetworkID, PeerIP) ) WITHOUT ROWID; diff --git a/hub/api/types.go b/hub/api/types.go index bfcfc04..d1a3b35 100644 --- a/hub/api/types.go +++ b/hub/api/types.go @@ -3,6 +3,12 @@ package api import "vppn/hub/api/db" type Config = db.Config -type Session = db.Session type Network = db.Network type Peer = db.Peer + +type Session struct { + SessionID string + SignedIn bool + CreatedAt int64 + LastSeenAt int64 +} diff --git a/hub/app.go b/hub/app.go index 7e07a34..2695d1a 100644 --- a/hub/app.go +++ b/hub/app.go @@ -2,6 +2,7 @@ package hub import ( "embed" + "encoding/base64" "html/template" "net/http" "path/filepath" @@ -47,6 +48,19 @@ func NewApp(conf Config) (*App, error) { return app, nil } -var templateFuncs = template.FuncMap{ - "ipToString": ipBytesTostring, +func (app *App) Handler() http.Handler { + cop := http.NewCrossOriginProtection() + return cop.Handler(app.mux) +} + +var templateFuncs = template.FuncMap{ + "ipToString": ipBytesTostring, + "wgKeyString": wgKeyString, +} + +func wgKeyString(key []byte) string { + if len(key) == 0 { + return "not set" + } + return base64.StdEncoding.EncodeToString(key) } diff --git a/hub/cookie.go b/hub/cookie.go index 2048d6b..c2bacc9 100644 --- a/hub/cookie.go +++ b/hub/cookie.go @@ -2,7 +2,6 @@ package hub import ( "net/http" - "time" ) func (a *App) getCookie(r *http.Request, name string) string { @@ -26,9 +25,12 @@ func (a *App) setCookie(w http.ResponseWriter, name, value string) { func (a *App) deleteCookie(w http.ResponseWriter, name string) { http.SetCookie(w, &http.Cookie{ - Name: name, - Value: "", - Path: "/", - Expires: time.Unix(0, 0), + Name: name, + Value: "", + Path: "/", + Secure: !a.insecure, + SameSite: http.SameSiteStrictMode, + HttpOnly: true, + MaxAge: -1, // delete now }) } diff --git a/hub/global.go b/hub/global.go index 9f9a308..525b63a 100644 --- a/hub/global.go +++ b/hub/global.go @@ -1,5 +1,5 @@ package hub const ( - SESSION_ID_COOKIE_NAME = "SessionID" + sessionIDCookieName = "SessionID" ) diff --git a/hub/handler.go b/hub/handler.go index 25c87ae..13bfe50 100644 --- a/hub/handler.go +++ b/hub/handler.go @@ -12,7 +12,7 @@ type handlerFunc func(s *api.Session, w http.ResponseWriter, r *http.Request) er func (app *App) handlePub(pattern string, fn handlerFunc) { wrapped := func(w http.ResponseWriter, r *http.Request) { - sessionID := app.getCookie(r, SESSION_ID_COOKIE_NAME) + sessionID := app.getCookie(r, sessionIDCookieName) s, err := app.api.Session_Get(sessionID) if err != nil { log.Printf("Failed to get session: %v", err) @@ -20,22 +20,13 @@ func (app *App) handlePub(pattern string, fn handlerFunc) { return } - if s.SessionID != sessionID { - app.setCookie(w, SESSION_ID_COOKIE_NAME, s.SessionID) - } - if r.Method == http.MethodPost { r.ParseMultipartForm(64 * 1024) - if r.FormValue("CSRF") != s.CSRF { - log.Printf("%s != %s", r.FormValue("CSRF"), s.CSRF) - http.Error(w, "CSRF mismatch", http.StatusBadRequest) - return - } } else { r.ParseForm() } - if err := fn(s, w, r); err != nil { + if err := fn(&s, w, r); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } diff --git a/hub/handlers.go b/hub/handlers.go index b2c6f3f..bcbe82b 100644 --- a/hub/handlers.go +++ b/hub/handlers.go @@ -10,6 +10,7 @@ import ( "git.crumpington.com/lib/go/webutil" "golang.org/x/crypto/bcrypt" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) func (a *App) _root(s *api.Session, w http.ResponseWriter, r *http.Request) error { @@ -33,9 +34,11 @@ func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Reque return err } - if err := a.api.Session_SignIn(s, pwd); err != nil { + sess, err := a.api.Session_SignIn(pwd) + if err != nil { return err } + a.setCookie(w, sessionIDCookieName, sess.SessionID) return a.redirect(w, r, "/") } @@ -48,7 +51,7 @@ func (a *App) _adminSignOutSubmit(s *api.Session, w http.ResponseWriter, r *http if err := a.api.Session_Delete(s.SessionID); err != nil { log.Printf("Failed to delete session cookie %s: %v", s.SessionID, err) } - a.deleteCookie(w, SESSION_ID_COOKIE_NAME) + a.deleteCookie(w, sessionIDCookieName) return a.redirect(w, r, "/") } @@ -72,7 +75,7 @@ func (a *App) _adminNetworkCreateSubmit(s *api.Session, w http.ResponseWriter, r var netStr string err := webutil.NewFormScanner(r.Form). - Scan("Name", &n.Name). + Scan("LocalDomain", &n.LocalDomain). Scan("Network", &netStr). Error() if err != nil { @@ -142,14 +145,15 @@ func (a *App) _adminPeerCreate(s *api.Session, w http.ResponseWriter, r *http.Re } func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { - var ipStr string + var addr4Str, addr6Str string p := &api.Peer{} err := webutil.NewFormScanner(r.Form). Scan("NetworkID", &p.NetworkID). Scan("IP", &p.PeerIP). Scan("Name", &p.Name). - Scan("PublicIP", &ipStr). + Scan("Addr4", &addr4Str). + Scan("Addr6", &addr6Str). Scan("Port", &p.Port). Scan("Relay", &p.Relay). Error() @@ -157,7 +161,10 @@ func (a *App) _adminPeerCreateSubmit(s *api.Session, w http.ResponseWriter, r *h return err } - if p.PublicIP, err = stringToIP(ipStr); err != nil { + if p.Addr4, err = stringToIP(addr4Str); err != nil { + return err + } + if p.Addr6, err = stringToIP(addr6Str); err != nil { return err } @@ -180,48 +187,6 @@ func (a *App) _adminPeerView(s *api.Session, w http.ResponseWriter, r *http.Requ }{s, net, peer}) } -func (a *App) _adminPeerEdit(s *api.Session, w http.ResponseWriter, r *http.Request) error { - net, peer, err := a.formGetPeer(r.Form) - if err != nil { - return err - } - - return a.render("/network/peer-edit.html", w, struct { - Session *api.Session - Network *api.Network - Peer *api.Peer - }{s, net, peer}) -} - -func (a *App) _adminPeerEditSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { - _, peer, err := a.formGetPeer(r.Form) - if err != nil { - return err - } - - var ipStr string - - err = webutil.NewFormScanner(r.Form). - Scan("Name", &peer.Name). - Scan("PublicIP", &ipStr). - Scan("Port", &peer.Port). - Scan("Relay", &peer.Relay). - Error() - if err != nil { - return err - } - - if peer.PublicIP, err = stringToIP(ipStr); err != nil { - return err - } - - if err = a.api.Peer_Update(peer); err != nil { - return err - } - - return a.redirect(w, r, "/admin/peer/view/?NetworkID=%d&PeerIP=%d", peer.NetworkID, peer.PeerIP) -} - func (a *App) _adminPeerDelete(s *api.Session, w http.ResponseWriter, r *http.Request) error { n, peer, err := a.formGetPeer(r.Form) if err != nil { @@ -252,13 +217,17 @@ func (a *App) _adminPasswordEdit(s *api.Session, w http.ResponseWriter, r *http. func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { var ( - conf = a.api.Config_Get() curPwd string newPwd string newPwd2 string ) - err := webutil.NewFormScanner(r.Form). + conf, err := a.api.Config_Get() + if err != nil { + return err + } + + err = webutil.NewFormScanner(r.Form). Scan("CurrentPassword", &curPwd). Scan("NewPassword", &newPwd). Scan("NewPassword2", &newPwd2). @@ -295,11 +264,25 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt } func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) error { + if len(peer.WGPubKey) != 0 { + http.Error(w, "Already initialized", http.StatusConflict) + return nil + } + args := m.PeerInitArgs{} if err := json.NewDecoder(r.Body).Decode(&args); err != nil { return err } + if len(args.WGPubKey) != 32 { + http.Error(w, "invalid WGPubKey", http.StatusBadRequest) + return nil + } + if len(args.SignPubKey) != 32 { + http.Error(w, "invalid SignPubKey", http.StatusBadRequest) + return nil + } + net, err := a.api.Network_Get(peer.NetworkID) if err != nil { return err @@ -310,11 +293,12 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) } resp := m.PeerInitResp{ - PeerIP: peer.PeerIP, - Network: net.Network, + PeerIP: peer.PeerIP, + Network: net.Network, + LocalDomain: net.LocalDomain, } - resp.NetworkState.Peers, err = a.peersArray(net.NetworkID) + resp.NetworkState.Peers, err = a.peersList(net.NetworkID) if err != nil { return err } @@ -323,34 +307,42 @@ func (a *App) _peerInit(peer *api.Peer, w http.ResponseWriter, r *http.Request) } func (a *App) _peerFetchState(peer *api.Peer, w http.ResponseWriter, r *http.Request) error { - - peers, err := a.peersArray(peer.NetworkID) + peers, err := a.peersList(peer.NetworkID) if err != nil { return err } return a.sendJSON(w, m.NetworkState{Peers: peers}) } -func (a *App) peersArray(networkID int64) (peers [256]*m.Peer, err error) { +func (a *App) peersList(networkID int64) (peers []m.Peer, err error) { l, err := a.api.Peer_List(networkID) if err != nil { - return peers, err + return nil, err } + peers = make([]m.Peer, 0, len(l)) + for _, p := range l { - if len(p.PubKey) != 0 { - peers[p.PeerIP] = &m.Peer{ - PeerIP: p.PeerIP, - Version: p.Version, - Name: p.Name, - PublicIP: p.PublicIP, - Port: p.Port, - Relay: p.Relay, - PubKey: p.PubKey, - PubSignKey: p.PubSignKey, - } + if len(p.WGPubKey) == 0 { + continue } + wgKey, err := wgtypes.NewKey(p.WGPubKey) + if err != nil { + continue // malformed key; skip rather than serve garbage + } + var signKey [32]byte + copy(signKey[:], p.SignPubKey) + peers = append(peers, m.Peer{ + PeerIP: p.PeerIP, + Name: p.Name, + Addr4: addrFromBytes(p.Addr4), + Addr6: addrFromBytes(p.Addr6), + Port: p.Port, + Relay: p.Relay, + WGPubKey: wgKey, + SignPubKey: signKey, + }) } - return + return peers, nil } diff --git a/hub/main.go b/hub/main.go index 5b1d951..dfa6f2e 100644 --- a/hub/main.go +++ b/hub/main.go @@ -31,7 +31,7 @@ func Main() { srv := &http.Server{ Addr: conf.ListenAddr, - Handler: app.mux, + Handler: app.Handler(), } log.Fatal(webutil.ListenAndServe(srv)) diff --git a/hub/routes.go b/hub/routes.go index bb2d555..3ba5fe3 100644 --- a/hub/routes.go +++ b/hub/routes.go @@ -22,8 +22,6 @@ func (a *App) registerRoutes() { a.handleSignedIn("GET /admin/peer/create/", a._adminPeerCreate) a.handleSignedIn("POST /admin/peer/create/", a._adminPeerCreateSubmit) a.handleSignedIn("GET /admin/peer/view/", a._adminPeerView) - a.handleSignedIn("GET /admin/peer/edit/", a._adminPeerEdit) - a.handleSignedIn("POST /admin/peer/edit/", a._adminPeerEditSubmit) a.handleSignedIn("GET /admin/peer/delete/", a._adminPeerDelete) a.handleSignedIn("POST /admin/peer/delete/", a._adminPeerDeleteSubmit) diff --git a/hub/templates/admin-network-create.html b/hub/templates/admin-network-create.html index 786ae8b..de06e4f 100644 --- a/hub/templates/admin-network-create.html +++ b/hub/templates/admin-network-create.html @@ -2,10 +2,9 @@