This commit is contained in:
jdl
2026-06-13 14:44:25 +02:00
parent a730211167
commit 232b68310c
5 changed files with 56 additions and 74 deletions

View File

@@ -8,6 +8,7 @@ import (
"sync"
"time"
"vppn/hub/api/db"
"vppn/hub/errs"
"vppn/m"
"git.crumpington.com/lib/go/idgen"
@@ -64,31 +65,32 @@ func (a *API) ensurePassword() error {
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
if err != nil {
return err
log.Printf("Failed to generate password: %v", err)
return errs.ErrUnexpected
}
conf := &Config{ConfigID: 1, Password: hashed}
return db.Config_Insert(a.db, conf)
return errs.DB(db.Config_Insert(a.db, conf))
}
func (a *API) Config_Get() (*Config, error) {
return db.Config_Get(a.db, 1)
conf, err := db.Config_Get(a.db, 1)
return conf, errs.DB(err)
}
func (a *API) Config_Update(conf *Config) error {
return db.Config_Update(a.db, conf)
return errs.DB(db.Config_Update(a.db, conf))
}
func (a *API) Session_Delete(sessionID string) error {
func (a *API) Session_Delete(sessionID string) {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
delete(a.sessions, sessionID)
return nil
}
const (
sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use
sessionSweepEvery = time.Hour // cadence of expired-session eviction
sessionTTLSecs = 24 * 21 * time.Hour // sessions expire 21 days after last use
sessionSweepEvery = time.Hour // cadence of expired-session eviction
)
// Session_Get returns a snapshot copy of the signed-in session for sessionID,
@@ -96,23 +98,23 @@ const (
// creates a session, so anonymous requests cost no memory — a session is minted
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
// callers from racing on the shared struct.
func (a *API) Session_Get(sessionID string) (Session, error) {
func (a *API) Session_Get(sessionID string) Session {
a.sessionsMu.Lock()
defer a.sessionsMu.Unlock()
s, ok := a.sessions[sessionID]
if sessionID == "" || !ok {
return Session{}, nil
return Session{}
}
if timeSince(s.LastSeenAt) > sessionTTLSecs {
if time.Since(s.LastSeenAt) > sessionTTLSecs {
delete(a.sessions, sessionID)
return Session{}, nil
return Session{}
}
s.LastSeenAt = time.Now().Unix()
return *s, nil
s.LastSeenAt = time.Now()
return *s
}
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
@@ -121,10 +123,11 @@ func (a *API) Session_Get(sessionID string) (Session, error) {
func (a *API) Session_SignIn(pwd string) (Session, error) {
conf, err := a.Config_Get()
if err != nil {
return Session{}, err
log.Printf("Failed to get config: %v", err)
return Session{}, errs.ErrUnexpected
}
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
return Session{}, ErrNotAuthorized
return Session{}, errs.NotAuthorized.WithMsg("Not authorized.")
}
a.sessionsMu.Lock()
@@ -132,8 +135,8 @@ func (a *API) Session_SignIn(pwd string) (Session, error) {
s := &Session{
SessionID: idgen.NewToken(),
SignedIn: true,
CreatedAt: time.Now().Unix(),
LastSeenAt: time.Now().Unix(),
CreatedAt: time.Now(),
LastSeenAt: time.Now(),
}
a.sessions[s.SessionID] = s
return *s, nil
@@ -146,7 +149,7 @@ func (a *API) sweepSessions() {
for range time.Tick(sessionSweepEvery) {
a.sessionsMu.Lock()
for id, s := range a.sessions {
if timeSince(s.LastSeenAt) > sessionTTLSecs {
if time.Since(s.LastSeenAt) > sessionTTLSecs {
delete(a.sessions, id)
}
}
@@ -156,20 +159,22 @@ func (a *API) sweepSessions() {
func (a *API) Network_Create(n *Network) error {
n.NetworkID = idgen.NextID(0)
return db.Network_Insert(a.db, n)
return errs.DB(db.Network_Insert(a.db, n))
}
func (a *API) Network_Delete(n *Network) error {
return db.Network_Delete(a.db, n.NetworkID)
return errs.DB(db.Network_Delete(a.db, n.NetworkID))
}
func (a *API) Network_Get(id int64) (*Network, error) {
return db.Network_Get(a.db, id)
n, err := db.Network_Get(a.db, id)
return n, errs.DB(err)
}
func (a *API) Network_List() ([]*Network, error) {
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
return db.Network_List(a.db, query)
n, err := db.Network_List(a.db, query)
return n, errs.DB(err)
}
func (a *API) Peer_CreateNew(p *Peer) error {
@@ -177,7 +182,7 @@ func (a *API) Peer_CreateNew(p *Peer) error {
p.SignPubKey = []byte{}
p.APIKey = idgen.NewToken()
return db.Peer_Insert(a.db, p)
return errs.DB(db.Peer_Insert(a.db, p))
}
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
@@ -191,27 +196,30 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
return err
}
if len(current.WGPubKey) != 0 {
return errors.New("peer already initialized")
return errs.ErrAlreadyExists
}
peer.WGPubKey = args.WGPubKey
peer.SignPubKey = args.SignPubKey
return db.Peer_UpdateFull(a.db, peer)
return errs.DB(db.Peer_UpdateFull(a.db, peer))
}
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
return db.Peer_Delete(a.db, networkID, peerIP)
return errs.DB(db.Peer_Delete(a.db, networkID, peerIP))
}
func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
return db.Peer_ListAll(a.db, networkID)
p, err := db.Peer_ListAll(a.db, networkID)
return p, errs.DB(err)
}
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
return db.Peer_Get(a.db, networkID, ip)
p, err := db.Peer_Get(a.db, networkID, ip)
return p, errs.DB(err)
}
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
return db.Peer_GetByAPIKey(a.db, key)
p, err := db.Peer_GetByAPIKey(a.db, key)
return p, errs.DB(err)
}

View File

@@ -1,19 +1,9 @@
package db
import (
"errors"
"net/netip"
"strings"
)
var (
ErrInvalidIP = errors.New("invalid IP")
ErrInvalidPeerIP = errors.New("invalid peer IP")
ErrNonPrivateIP = errors.New("non-private IP")
ErrInvalidPort = errors.New("invalid port")
ErrInvalidNetName = errors.New("invalid network name")
ErrNetNameNotLocal = errors.New("network name must end with .local")
ErrInvalidPeerName = errors.New("invalid peer name")
"vppn/hub/errs"
)
func Config_Sanitize(c *Config) {
@@ -35,11 +25,11 @@ func Network_Validate(c *Network) error {
// 15 bytes is linux limit for network interface names. With ending .local,
// max length is 21.
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
return ErrInvalidNetName
return errs.ErrInvalidNetName
}
if !strings.HasSuffix(c.LocalDomain, ".local") {
return ErrNetNameNotLocal
return errs.ErrNetNameNotLocal
}
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
@@ -49,16 +39,16 @@ func Network_Validate(c *Network) error {
if c >= '0' && c <= '9' {
continue
}
return ErrInvalidNetName
return errs.ErrInvalidNetName
}
addr, ok := netip.AddrFromSlice(c.Network)
if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 {
return ErrInvalidIP
return errs.ErrInvalidIP
}
if !addr.IsPrivate() {
return ErrNonPrivateIP
return errs.ErrNonPrivateIP
}
return nil
@@ -84,26 +74,26 @@ func Peer_Sanitize(p *Peer) {
func Peer_Validate(p *Peer) error {
if p.PeerIP < 1 || p.PeerIP > 254 {
return ErrInvalidPeerIP
return errs.ErrInvalidPeerIP
}
if len(p.Addr4) > 0 {
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
return ErrInvalidIP
return errs.ErrInvalidIP
}
}
if len(p.Addr6) > 0 {
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
return ErrInvalidIP
return errs.ErrInvalidIP
}
}
if p.Port == 0 {
return ErrInvalidPort
return errs.ErrInvalidPort
}
if len(p.Name) == 0 {
return ErrInvalidPeerName
return errs.ErrInvalidPeerName
}
for _, c := range p.Name {
if c >= 'a' && c <= 'z' {
@@ -115,7 +105,7 @@ func Peer_Validate(p *Peer) error {
if c == '-' {
continue
}
return ErrInvalidPeerName
return errs.ErrInvalidPeerName
}
return nil

View File

@@ -1,12 +0,0 @@
package api
import (
"errors"
"vppn/hub/api/db"
)
var (
ErrNotAuthorized = errors.New("not authorized")
ErrInvalidIP = db.ErrInvalidIP
ErrInvalidPort = db.ErrInvalidPort
)

View File

@@ -1,7 +0,0 @@
package api
import "time"
func timeSince(ts int64) int64 {
return time.Now().Unix() - ts
}

View File

@@ -1,6 +1,9 @@
package api
import "vppn/hub/api/db"
import (
"time"
"vppn/hub/api/db"
)
type Config = db.Config
type Network = db.Network
@@ -9,6 +12,6 @@ type Peer = db.Peer
type Session struct {
SessionID string
SignedIn bool
CreatedAt int64
LastSeenAt int64
CreatedAt time.Time
LastSeenAt time.Time
}