This commit is contained in:
jdl
2026-06-13 14:44:25 +02:00
parent a730211167
commit 232b68310c
5 changed files with 56 additions and 74 deletions

View File

@@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
"vppn/hub/api/db" "vppn/hub/api/db"
"vppn/hub/errs"
"vppn/m" "vppn/m"
"git.crumpington.com/lib/go/idgen" "git.crumpington.com/lib/go/idgen"
@@ -64,30 +65,31 @@ func (a *API) ensurePassword() error {
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost) hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
if err != nil { if err != nil {
return err log.Printf("Failed to generate password: %v", err)
return errs.ErrUnexpected
} }
conf := &Config{ConfigID: 1, Password: hashed} conf := &Config{ConfigID: 1, Password: hashed}
return db.Config_Insert(a.db, conf) return errs.DB(db.Config_Insert(a.db, conf))
} }
func (a *API) Config_Get() (*Config, error) { func (a *API) Config_Get() (*Config, error) {
return db.Config_Get(a.db, 1) conf, err := db.Config_Get(a.db, 1)
return conf, errs.DB(err)
} }
func (a *API) Config_Update(conf *Config) error { func (a *API) Config_Update(conf *Config) error {
return db.Config_Update(a.db, conf) return errs.DB(db.Config_Update(a.db, conf))
} }
func (a *API) Session_Delete(sessionID string) error { func (a *API) Session_Delete(sessionID string) {
a.sessionsMu.Lock() a.sessionsMu.Lock()
defer a.sessionsMu.Unlock() defer a.sessionsMu.Unlock()
delete(a.sessions, sessionID) delete(a.sessions, sessionID)
return nil
} }
const ( const (
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use sessionTTLSecs = 24 * 21 * time.Hour // sessions expire 21 days after last use
sessionSweepEvery = time.Hour // cadence of expired-session eviction sessionSweepEvery = time.Hour // cadence of expired-session eviction
) )
@@ -96,23 +98,23 @@ const (
// creates a session, so anonymous requests cost no memory — a session is minted // creates a session, so anonymous requests cost no memory — a session is minted
// only by Session_SignIn. Returning a value (not the stored pointer) keeps // only by Session_SignIn. Returning a value (not the stored pointer) keeps
// callers from racing on the shared struct. // callers from racing on the shared struct.
func (a *API) Session_Get(sessionID string) (Session, error) { func (a *API) Session_Get(sessionID string) Session {
a.sessionsMu.Lock() a.sessionsMu.Lock()
defer a.sessionsMu.Unlock() defer a.sessionsMu.Unlock()
s, ok := a.sessions[sessionID] s, ok := a.sessions[sessionID]
if sessionID == "" || !ok { if sessionID == "" || !ok {
return Session{}, nil return Session{}
} }
if timeSince(s.LastSeenAt) > sessionTTLSecs { if time.Since(s.LastSeenAt) > sessionTTLSecs {
delete(a.sessions, sessionID) delete(a.sessions, sessionID)
return Session{}, nil return Session{}
} }
s.LastSeenAt = time.Now().Unix() s.LastSeenAt = time.Now()
return *s, nil return *s
} }
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session, // Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
@@ -121,10 +123,11 @@ func (a *API) Session_Get(sessionID string) (Session, error) {
func (a *API) Session_SignIn(pwd string) (Session, error) { func (a *API) Session_SignIn(pwd string) (Session, error) {
conf, err := a.Config_Get() conf, err := a.Config_Get()
if err != nil { if err != nil {
return Session{}, err log.Printf("Failed to get config: %v", err)
return Session{}, errs.ErrUnexpected
} }
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
return Session{}, ErrNotAuthorized return Session{}, errs.NotAuthorized.WithMsg("Not authorized.")
} }
a.sessionsMu.Lock() a.sessionsMu.Lock()
@@ -132,8 +135,8 @@ func (a *API) Session_SignIn(pwd string) (Session, error) {
s := &Session{ s := &Session{
SessionID: idgen.NewToken(), SessionID: idgen.NewToken(),
SignedIn: true, SignedIn: true,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now(),
LastSeenAt: time.Now().Unix(), LastSeenAt: time.Now(),
} }
a.sessions[s.SessionID] = s a.sessions[s.SessionID] = s
return *s, nil return *s, nil
@@ -146,7 +149,7 @@ func (a *API) sweepSessions() {
for range time.Tick(sessionSweepEvery) { for range time.Tick(sessionSweepEvery) {
a.sessionsMu.Lock() a.sessionsMu.Lock()
for id, s := range a.sessions { for id, s := range a.sessions {
if timeSince(s.LastSeenAt) > sessionTTLSecs { if time.Since(s.LastSeenAt) > sessionTTLSecs {
delete(a.sessions, id) delete(a.sessions, id)
} }
} }
@@ -156,20 +159,22 @@ func (a *API) sweepSessions() {
func (a *API) Network_Create(n *Network) error { func (a *API) Network_Create(n *Network) error {
n.NetworkID = idgen.NextID(0) n.NetworkID = idgen.NextID(0)
return db.Network_Insert(a.db, n) return errs.DB(db.Network_Insert(a.db, n))
} }
func (a *API) Network_Delete(n *Network) error { func (a *API) Network_Delete(n *Network) error {
return db.Network_Delete(a.db, n.NetworkID) return errs.DB(db.Network_Delete(a.db, n.NetworkID))
} }
func (a *API) Network_Get(id int64) (*Network, error) { func (a *API) Network_Get(id int64) (*Network, error) {
return db.Network_Get(a.db, id) n, err := db.Network_Get(a.db, id)
return n, errs.DB(err)
} }
func (a *API) Network_List() ([]*Network, error) { func (a *API) Network_List() ([]*Network, error) {
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC` const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
return db.Network_List(a.db, query) n, err := db.Network_List(a.db, query)
return n, errs.DB(err)
} }
func (a *API) Peer_CreateNew(p *Peer) error { func (a *API) Peer_CreateNew(p *Peer) error {
@@ -177,7 +182,7 @@ func (a *API) Peer_CreateNew(p *Peer) error {
p.SignPubKey = []byte{} p.SignPubKey = []byte{}
p.APIKey = idgen.NewToken() p.APIKey = idgen.NewToken()
return db.Peer_Insert(a.db, p) return errs.DB(db.Peer_Insert(a.db, p))
} }
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error { func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
@@ -191,27 +196,30 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
return err return err
} }
if len(current.WGPubKey) != 0 { if len(current.WGPubKey) != 0 {
return errors.New("peer already initialized") return errs.ErrAlreadyExists
} }
peer.WGPubKey = args.WGPubKey peer.WGPubKey = args.WGPubKey
peer.SignPubKey = args.SignPubKey peer.SignPubKey = args.SignPubKey
return db.Peer_UpdateFull(a.db, peer) return errs.DB(db.Peer_UpdateFull(a.db, peer))
} }
func (a *API) Peer_Delete(networkID int64, peerIP byte) error { func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
return db.Peer_Delete(a.db, networkID, peerIP) return errs.DB(db.Peer_Delete(a.db, networkID, peerIP))
} }
func (a *API) Peer_List(networkID int64) ([]*Peer, error) { func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
return db.Peer_ListAll(a.db, networkID) p, err := db.Peer_ListAll(a.db, networkID)
return p, errs.DB(err)
} }
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) { func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
return db.Peer_Get(a.db, networkID, ip) p, err := db.Peer_Get(a.db, networkID, ip)
return p, errs.DB(err)
} }
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) { func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
return db.Peer_GetByAPIKey(a.db, key) p, err := db.Peer_GetByAPIKey(a.db, key)
return p, errs.DB(err)
} }

View File

@@ -1,19 +1,9 @@
package db package db
import ( import (
"errors"
"net/netip" "net/netip"
"strings" "strings"
) "vppn/hub/errs"
var (
ErrInvalidIP = errors.New("invalid IP")
ErrInvalidPeerIP = errors.New("invalid peer IP")
ErrNonPrivateIP = errors.New("non-private IP")
ErrInvalidPort = errors.New("invalid port")
ErrInvalidNetName = errors.New("invalid network name")
ErrNetNameNotLocal = errors.New("network name must end with .local")
ErrInvalidPeerName = errors.New("invalid peer name")
) )
func Config_Sanitize(c *Config) { func Config_Sanitize(c *Config) {
@@ -35,11 +25,11 @@ func Network_Validate(c *Network) error {
// 15 bytes is linux limit for network interface names. With ending .local, // 15 bytes is linux limit for network interface names. With ending .local,
// max length is 21. // max length is 21.
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 { if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
return ErrInvalidNetName return errs.ErrInvalidNetName
} }
if !strings.HasSuffix(c.LocalDomain, ".local") { if !strings.HasSuffix(c.LocalDomain, ".local") {
return ErrNetNameNotLocal return errs.ErrNetNameNotLocal
} }
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") { for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
@@ -49,16 +39,16 @@ func Network_Validate(c *Network) error {
if c >= '0' && c <= '9' { if c >= '0' && c <= '9' {
continue continue
} }
return ErrInvalidNetName return errs.ErrInvalidNetName
} }
addr, ok := netip.AddrFromSlice(c.Network) addr, ok := netip.AddrFromSlice(c.Network)
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 { if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
return ErrInvalidIP return errs.ErrInvalidIP
} }
if !addr.IsPrivate() { if !addr.IsPrivate() {
return ErrNonPrivateIP return errs.ErrNonPrivateIP
} }
return nil return nil
@@ -84,26 +74,26 @@ func Peer_Sanitize(p *Peer) {
func Peer_Validate(p *Peer) error { func Peer_Validate(p *Peer) error {
if p.PeerIP < 1 || p.PeerIP > 254 { if p.PeerIP < 1 || p.PeerIP > 254 {
return ErrInvalidPeerIP return errs.ErrInvalidPeerIP
} }
if len(p.Addr4) > 0 { if len(p.Addr4) > 0 {
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field). // Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() { if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
return ErrInvalidIP return errs.ErrInvalidIP
} }
} }
if len(p.Addr6) > 0 { if len(p.Addr6) > 0 {
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field). // Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() { if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
return ErrInvalidIP return errs.ErrInvalidIP
} }
} }
if p.Port == 0 { if p.Port == 0 {
return ErrInvalidPort return errs.ErrInvalidPort
} }
if len(p.Name) == 0 { if len(p.Name) == 0 {
return ErrInvalidPeerName return errs.ErrInvalidPeerName
} }
for _, c := range p.Name { for _, c := range p.Name {
if c >= 'a' && c <= 'z' { if c >= 'a' && c <= 'z' {
@@ -115,7 +105,7 @@ func Peer_Validate(p *Peer) error {
if c == '-' { if c == '-' {
continue continue
} }
return ErrInvalidPeerName return errs.ErrInvalidPeerName
} }
return nil return nil

View File

@@ -1,12 +0,0 @@
package api
import (
"errors"
"vppn/hub/api/db"
)
var (
ErrNotAuthorized = errors.New("not authorized")
ErrInvalidIP = db.ErrInvalidIP
ErrInvalidPort = db.ErrInvalidPort
)

View File

@@ -1,7 +0,0 @@
package api
import "time"
func timeSince(ts int64) int64 {
return time.Now().Unix() - ts
}

View File

@@ -1,6 +1,9 @@
package api package api
import "vppn/hub/api/db" import (
"time"
"vppn/hub/api/db"
)
type Config = db.Config type Config = db.Config
type Network = db.Network type Network = db.Network
@@ -9,6 +12,6 @@ type Peer = db.Peer
type Session struct { type Session struct {
SessionID string SessionID string
SignedIn bool SignedIn bool
CreatedAt int64 CreatedAt time.Time
LastSeenAt int64 LastSeenAt time.Time
} }