Files
vppn/hub/api/api.go
2026-06-14 05:53:12 +02:00

237 lines
5.6 KiB
Go

package api
import (
"database/sql"
"embed"
"errors"
"log"
"sync"
"time"
"vppn/hub/api/db"
"vppn/hub/errs"
"vppn/m"
"git.crumpington.com/lib/go/idgen"
"git.crumpington.com/lib/go/sqliteutil"
"golang.org/x/crypto/bcrypt"
)
// migrations holds the "migrations" directory, embedded at build time. New
// applies it to the database with sqliteutil.Migrate.
//
//go:embed migrations
var migrations embed.FS
// API wraps the hub's SQLite database handle and an in-memory session store.
// The zero value has a nil sessions map, so always construct it with New.
type API struct {
db *sql.DB
// lock serializes Peer_Init's re-read/check/update sequence. Peer_Init is
// its only user in this file.
lock sync.Mutex
// sessionsMu guards every read and write of sessions.
sessionsMu sync.Mutex
// sessions maps session ID to signed-in session. Entries are added by
// Session_SignIn and Session_InvalidateAll. They are removed by
// Session_Delete, by expiry checks in Session_Get, and by sweepSessions.
sessions map[string]*Session
}
// New opens the SQLite database at dbPath with "?_journal=WAL" appended to the
// DSN. It then applies the embedded migrations, ensures an admin password
// exists, and starts the background session sweeper.
//
// The "sqlite3" driver is not imported in this file; it must be registered
// elsewhere. NOTE(review): "_journal=WAL" is presumably a DSN option of that
// driver; confirm.
//
// If migration or password setup fails, the database handle is closed before
// returning so a failed construction does not leak it. The sweeper goroutine
// is started only on success and runs for the life of the process.
func New(dbPath string) (*API, error) {
	sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal=WAL")
	if err != nil {
		return nil, err
	}

	if err := sqliteutil.Migrate(sqlDB, migrations); err != nil {
		// Best-effort cleanup; the migration error is what the caller needs.
		_ = sqlDB.Close()
		return nil, err
	}

	a := &API{
		db:       sqlDB,
		sessions: make(map[string]*Session),
	}
	if err := a.ensurePassword(); err != nil {
		// Best-effort cleanup; the setup error is what the caller needs.
		_ = sqlDB.Close()
		return nil, err
	}

	go a.sweepSessions()
	return a, nil
}
// ensurePassword makes sure config row 1 exists. If it already exists, nothing
// happens. If it is missing, a random password is generated, bcrypt-hashed,
// and inserted, and only then is the plaintext written to the log so the
// operator can sign in. Logging after the insert guarantees the printed
// password is one that was actually stored.
//
// Lookup errors other than sql.ErrNoRows are returned unchanged. A hashing
// failure is logged and reported as errs.ErrUnexpected. Insert errors go
// through errs.DB.
func (a *API) ensurePassword() error {
	_, err := db.Config_Get(a.db, 1)
	if err == nil {
		return nil // already configured
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	pwd := idgen.NewToken()
	hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Failed to generate password: %v", err)
		return errs.ErrUnexpected
	}

	conf := &Config{ConfigID: 1, Password: hashed}
	if err := errs.DB(db.Config_Insert(a.db, conf)); err != nil {
		return err
	}

	// The plaintext is logged intentionally: it is the only way the operator
	// learns the initial password.
	log.Printf("Setting password: %s", pwd)
	return nil
}
// Config_Get loads config row 1. Any database error is translated through
// errs.DB.
func (a *API) Config_Get() (cfg *Config, err error) {
	cfg, err = db.Config_Get(a.db, 1)
	err = errs.DB(err)
	return cfg, err
}
// Config_Update writes conf back to the database. Any error is translated
// through errs.DB.
func (a *API) Config_Update(conf *Config) error {
	updateErr := db.Config_Update(a.db, conf)
	return errs.DB(updateErr)
}
// Session_Delete removes sessionID from the session store. Unknown IDs are a
// no-op.
func (a *API) Session_Delete(sessionID string) {
	a.sessionsMu.Lock()
	delete(a.sessions, sessionID)
	a.sessionsMu.Unlock()
}
// Session lifetime tuning.
const (
	// sessionTTL is how long a session may go unused before it expires
	// (21 days, measured from LastSeenAt).
	sessionTTL = 21 * 24 * time.Hour

	// sessionSweepEvery is how often sweepSessions scans for expired entries.
	sessionSweepEvery = 1 * time.Hour
)
// Session_Get returns a copy of the signed-in session identified by sessionID.
//
// An empty, unknown, or expired ID yields the zero Session. An expired entry is
// also removed from the store. A successful lookup refreshes LastSeenAt, so
// expiry slides forward with use. Session_Get never creates a session, so
// anonymous requests cost no memory; sessions are only added by
// Session_SignIn and Session_InvalidateAll. The caller receives a value copy,
// not the stored pointer, so it cannot race with other goroutines on the
// shared struct.
func (a *API) Session_Get(sessionID string) Session {
	if sessionID == "" {
		return Session{}
	}

	a.sessionsMu.Lock()
	defer a.sessionsMu.Unlock()

	stored, found := a.sessions[sessionID]
	switch {
	case !found:
		return Session{}
	case time.Since(stored.LastSeenAt) > sessionTTL:
		delete(a.sessions, sessionID)
		return Session{}
	}

	stored.LastSeenAt = time.Now()
	return *stored
}
// Session_SignIn checks pwd against the stored bcrypt hash. On success it
// creates and stores a new session and returns a copy of it so the caller can
// set the cookie. Every sign-in gets a brand-new ID, which rotates the session
// at the privilege boundary and resists session fixation.
//
// Returns errs.ErrUnexpected if the config cannot be loaded, and
// errs.ErrNotAuthorized if the password does not match.
func (a *API) Session_SignIn(pwd string) (Session, error) {
	var none Session

	conf, err := a.Config_Get()
	if err != nil {
		log.Printf("Failed to get config: %v", err)
		return none, errs.ErrUnexpected
	}
	if bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)) != nil {
		return none, errs.ErrNotAuthorized
	}

	fresh := &Session{
		SessionID:  idgen.NewToken(),
		LastSeenAt: time.Now(),
	}
	// Copy before publishing; once stored, the struct may be mutated by
	// Session_Get under sessionsMu.
	snapshot := *fresh

	a.sessionsMu.Lock()
	a.sessions[fresh.SessionID] = fresh
	a.sessionsMu.Unlock()

	return snapshot, nil
}
// Session_InvalidateAll signs out every session and, in the same critical
// section, stores one freshly created session. It returns a copy of that
// session, which callers can presumably use to keep the requester signed in.
// NOTE(review): confirm against callers.
func (a *API) Session_InvalidateAll() Session {
	replacement := &Session{
		SessionID:  idgen.NewToken(),
		LastSeenAt: time.Now(),
	}
	snapshot := *replacement

	a.sessionsMu.Lock()
	defer a.sessionsMu.Unlock()
	clear(a.sessions)
	a.sessions[replacement.SessionID] = replacement

	return snapshot
}
// sweepSessions evicts expired sessions every sessionSweepEvery and never
// returns. It is needed because Session_Get only evicts an entry when that
// same ID is looked up, so a session whose cookie is never presented again
// would otherwise stay in memory forever.
func (a *API) sweepSessions() {
	ticks := time.Tick(sessionSweepEvery)
	for {
		<-ticks
		a.evictExpiredSessions()
	}
}

// evictExpiredSessions deletes every session unused for longer than
// sessionTTL, holding sessionsMu for the whole scan.
func (a *API) evictExpiredSessions() {
	a.sessionsMu.Lock()
	defer a.sessionsMu.Unlock()
	for sessionID, sess := range a.sessions {
		if time.Since(sess.LastSeenAt) > sessionTTL {
			delete(a.sessions, sessionID)
		}
	}
}
// Network_Create assigns n a new ID from idgen.NextID(0) and inserts it.
// n.NetworkID is overwritten even if the insert fails.
func (a *API) Network_Create(n *Network) error {
	newID := idgen.NextID(0)
	n.NetworkID = newID
	insertErr := db.Network_Insert(a.db, n)
	return errs.DB(insertErr)
}
// Network_Delete removes the network whose ID is n.NetworkID. Other fields of
// n are ignored.
func (a *API) Network_Delete(n *Network) error {
	target := n.NetworkID
	err := db.Network_Delete(a.db, target)
	return errs.DB(err)
}
// Network_Get loads the network with the given ID. Any error is translated
// through errs.DB.
func (a *API) Network_Get(id int64) (*Network, error) {
	network, err := db.Network_Get(a.db, id)
	err = errs.DB(err)
	return network, err
}
// Network_List returns every network, ordered by LocalDomain ascending.
func (a *API) Network_List() (networks []*Network, err error) {
	const byLocalDomain = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
	networks, err = db.Network_List(a.db, byLocalDomain)
	return networks, errs.DB(err)
}
// Peer_CreateNew inserts p as an uninitialized peer. It sets empty public keys
// (Peer_Init fills them in later) and a freshly generated API key, overwriting
// whatever p held in those fields.
func (a *API) Peer_CreateNew(p *Peer) error {
	p.APIKey = idgen.NewToken()
	// Non-nil empty slices. NOTE(review): presumably so the driver stores
	// empty blobs rather than NULL; confirm against the schema.
	p.WGPubKey, p.SignPubKey = []byte{}, []byte{}
	insertErr := db.Peer_Insert(a.db, p)
	return errs.DB(insertErr)
}
// Peer_Init stores the WireGuard and signing public keys from args for a peer
// that has not been initialized yet. On success it copies the written row into
// *peer.
//
// Only peer.NetworkID and peer.PeerIP are used to identify the row. The
// caller's copy was loaded before a.lock was held and may be stale, so the
// row is re-read under the lock and that fresh copy is the one updated and
// written. Writing the caller's copy instead would overwrite any fields
// changed in the meantime.
//
// Returns errs.ErrAlreadyExists if the stored peer already has a WireGuard
// key. Database errors are translated through errs.DB. On error, *peer is
// left unchanged.
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
	// a.lock serializes concurrent Peer_Init calls, so the emptiness check
	// and the write below act as one step relative to each other.
	a.lock.Lock()
	defer a.lock.Unlock()

	current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
	if err != nil {
		return errs.DB(err)
	}
	if len(current.WGPubKey) != 0 {
		return errs.ErrAlreadyExists
	}

	current.WGPubKey = args.WGPubKey
	current.SignPubKey = args.SignPubKey
	if err := db.Peer_UpdateFull(a.db, current); err != nil {
		return errs.DB(err)
	}

	// Give the caller the row exactly as written, including the new keys.
	*peer = *current
	return nil
}
// Peer_Delete removes the peer at peerIP in the given network. Any error is
// translated through errs.DB.
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
	deleteErr := db.Peer_Delete(a.db, networkID, peerIP)
	return errs.DB(deleteErr)
}
// Peer_List returns every peer in the given network, as produced by
// db.Peer_ListAll.
func (a *API) Peer_List(networkID int64) (peers []*Peer, err error) {
	peers, err = db.Peer_ListAll(a.db, networkID)
	err = errs.DB(err)
	return peers, err
}
// Peer_Get loads the peer at ip in the given network. Any error is translated
// through errs.DB.
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
	found, lookupErr := db.Peer_Get(a.db, networkID, ip)
	return found, errs.DB(lookupErr)
}
// Peer_GetByAPIKey loads the peer whose API key equals key. Any error is
// translated through errs.DB.
func (a *API) Peer_GetByAPIKey(key string) (peer *Peer, err error) {
	peer, err = db.Peer_GetByAPIKey(a.db, key)
	return peer, errs.DB(err)
}