237 lines
5.6 KiB
Go
237 lines
5.6 KiB
Go
package api
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"errors"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
"vppn/hub/api/db"
|
|
"vppn/hub/errs"
|
|
"vppn/m"
|
|
|
|
"git.crumpington.com/lib/go/idgen"
|
|
"git.crumpington.com/lib/go/sqliteutil"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// migrations is the embedded "migrations" directory, applied to the database
// by New through sqliteutil.Migrate.
//
//go:embed migrations
var migrations embed.FS
|
|
|
|
// API wraps the hub's SQLite database handle and holds the in-memory
// session store. Construct it with New.
type API struct {
	// db is the database handle opened by New.
	db *sql.DB

	// lock is held by Peer_Init around its read-check-write of a peer row.
	lock sync.Mutex

	// sessionsMu guards sessions.
	sessionsMu sync.Mutex

	// sessions maps session ID to session; entries are added by
	// Session_SignIn / Session_InvalidateAll and evicted on expiry.
	sessions map[string]*Session
}
|
|
|
|
func New(dbPath string) (*API, error) {
|
|
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal=WAL")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := sqliteutil.Migrate(sqlDB, migrations); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
a := &API{
|
|
db: sqlDB,
|
|
sessions: make(map[string]*Session),
|
|
}
|
|
|
|
if err := a.ensurePassword(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
go a.sweepSessions()
|
|
|
|
return a, nil
|
|
}
|
|
|
|
func (a *API) ensurePassword() error {
|
|
_, err := db.Config_Get(a.db, 1)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
|
|
pwd := idgen.NewToken()
|
|
|
|
log.Printf("Setting password: %s", pwd)
|
|
|
|
hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
log.Printf("Failed to generate password: %v", err)
|
|
return errs.ErrUnexpected
|
|
}
|
|
|
|
conf := &Config{ConfigID: 1, Password: hashed}
|
|
return errs.DB(db.Config_Insert(a.db, conf))
|
|
}
|
|
|
|
func (a *API) Config_Get() (*Config, error) {
|
|
conf, err := db.Config_Get(a.db, 1)
|
|
return conf, errs.DB(err)
|
|
}
|
|
|
|
func (a *API) Config_Update(conf *Config) error {
|
|
return errs.DB(db.Config_Update(a.db, conf))
|
|
}
|
|
|
|
func (a *API) Session_Delete(sessionID string) {
|
|
a.sessionsMu.Lock()
|
|
defer a.sessionsMu.Unlock()
|
|
delete(a.sessions, sessionID)
|
|
}
|
|
|
|
// Session lifetime settings.
const (
	// sessionTTLSecs is a time.Duration despite its name: a session expires
	// once it has gone unused for 21 days.
	sessionTTLSecs = 21 * 24 * time.Hour

	// sessionSweepEvery is how often sweepSessions scans for and evicts
	// expired sessions.
	sessionSweepEvery = 1 * time.Hour
)
|
|
|
|
// Session_Get returns a snapshot copy of the signed-in session for sessionID,
|
|
// or the zero Session if the cookie is missing/unknown/expired. It never
|
|
// creates a session, so anonymous requests cost no memory — a session is minted
|
|
// only by Session_SignIn. Returning a value (not the stored pointer) keeps
|
|
// callers from racing on the shared struct.
|
|
func (a *API) Session_Get(sessionID string) Session {
|
|
a.sessionsMu.Lock()
|
|
defer a.sessionsMu.Unlock()
|
|
|
|
s, ok := a.sessions[sessionID]
|
|
|
|
if sessionID == "" || !ok {
|
|
return Session{}
|
|
}
|
|
|
|
if time.Since(s.LastSeenAt) > sessionTTLSecs {
|
|
delete(a.sessions, sessionID)
|
|
return Session{}
|
|
}
|
|
|
|
s.LastSeenAt = time.Now()
|
|
return *s
|
|
}
|
|
|
|
// Session_SignIn verifies pwd and, on success, mints a fresh signed-in session,
|
|
// returning it so the caller can set the cookie. A new ID per sign-in rotates
|
|
// the session at the privilege boundary (session-fixation resistance).
|
|
func (a *API) Session_SignIn(pwd string) (Session, error) {
|
|
conf, err := a.Config_Get()
|
|
if err != nil {
|
|
log.Printf("Failed to get config: %v", err)
|
|
return Session{}, errs.ErrUnexpected
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil {
|
|
return Session{}, errs.ErrNotAuthorized
|
|
}
|
|
|
|
a.sessionsMu.Lock()
|
|
defer a.sessionsMu.Unlock()
|
|
s := &Session{
|
|
SessionID: idgen.NewToken(),
|
|
LastSeenAt: time.Now(),
|
|
}
|
|
a.sessions[s.SessionID] = s
|
|
return *s, nil
|
|
}
|
|
|
|
func (a *API) Session_InvalidateAll() Session {
|
|
a.sessionsMu.Lock()
|
|
defer a.sessionsMu.Unlock()
|
|
|
|
clear(a.sessions)
|
|
s := &Session{
|
|
SessionID: idgen.NewToken(),
|
|
LastSeenAt: time.Now(),
|
|
}
|
|
a.sessions[s.SessionID] = s
|
|
return *s
|
|
}
|
|
|
|
// sweepSessions periodically evicts sessions past their TTL. Without it, a
|
|
// signed-in session whose ID is never presented again would linger forever
|
|
// (Session_Get only evicts on a lookup of that same ID).
|
|
func (a *API) sweepSessions() {
|
|
for range time.Tick(sessionSweepEvery) {
|
|
a.sessionsMu.Lock()
|
|
for id, s := range a.sessions {
|
|
if time.Since(s.LastSeenAt) > sessionTTLSecs {
|
|
delete(a.sessions, id)
|
|
}
|
|
}
|
|
a.sessionsMu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (a *API) Network_Create(n *Network) error {
|
|
n.NetworkID = idgen.NextID(0)
|
|
return errs.DB(db.Network_Insert(a.db, n))
|
|
}
|
|
|
|
func (a *API) Network_Delete(n *Network) error {
|
|
return errs.DB(db.Network_Delete(a.db, n.NetworkID))
|
|
}
|
|
|
|
func (a *API) Network_Get(id int64) (*Network, error) {
|
|
n, err := db.Network_Get(a.db, id)
|
|
return n, errs.DB(err)
|
|
}
|
|
|
|
func (a *API) Network_List() ([]*Network, error) {
|
|
const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC`
|
|
n, err := db.Network_List(a.db, query)
|
|
return n, errs.DB(err)
|
|
}
|
|
|
|
func (a *API) Peer_CreateNew(p *Peer) error {
|
|
p.WGPubKey = []byte{}
|
|
p.SignPubKey = []byte{}
|
|
p.APIKey = idgen.NewToken()
|
|
|
|
return errs.DB(db.Peer_Insert(a.db, p))
|
|
}
|
|
|
|
func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
// Re-read from DB inside the lock — the caller's copy was fetched before
|
|
// we held the lock, so it may be stale under concurrent requests.
|
|
current, err := db.Peer_Get(a.db, peer.NetworkID, peer.PeerIP)
|
|
if err != nil {
|
|
return errs.DB(err)
|
|
}
|
|
if len(current.WGPubKey) != 0 {
|
|
return errs.ErrAlreadyExists
|
|
}
|
|
|
|
peer.WGPubKey = args.WGPubKey
|
|
peer.SignPubKey = args.SignPubKey
|
|
|
|
return errs.DB(db.Peer_UpdateFull(a.db, peer))
|
|
}
|
|
|
|
func (a *API) Peer_Delete(networkID int64, peerIP byte) error {
|
|
return errs.DB(db.Peer_Delete(a.db, networkID, peerIP))
|
|
}
|
|
|
|
func (a *API) Peer_List(networkID int64) ([]*Peer, error) {
|
|
p, err := db.Peer_ListAll(a.db, networkID)
|
|
return p, errs.DB(err)
|
|
}
|
|
|
|
func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) {
|
|
p, err := db.Peer_Get(a.db, networkID, ip)
|
|
return p, errs.DB(err)
|
|
}
|
|
|
|
func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) {
|
|
p, err := db.Peer_GetByAPIKey(a.db, key)
|
|
return p, errs.DB(err)
|
|
}
|