diff --git a/hub/api/api.go b/hub/api/api.go index 19838b8..1790932 100644 --- a/hub/api/api.go +++ b/hub/api/api.go @@ -8,6 +8,7 @@ import ( "sync" "time" "vppn/hub/api/db" + "vppn/hub/errs" "vppn/m" "git.crumpington.com/lib/go/idgen" @@ -64,31 +65,32 @@ func (a *API) ensurePassword() error { hashed, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost) if err != nil { - return err + log.Printf("Failed to generate password: %v", err) + return errs.ErrUnexpected } conf := &Config{ConfigID: 1, Password: hashed} - return db.Config_Insert(a.db, conf) + return errs.DB(db.Config_Insert(a.db, conf)) } func (a *API) Config_Get() (*Config, error) { - return db.Config_Get(a.db, 1) + conf, err := db.Config_Get(a.db, 1) + return conf, errs.DB(err) } func (a *API) Config_Update(conf *Config) error { - return db.Config_Update(a.db, conf) + return errs.DB(db.Config_Update(a.db, conf)) } -func (a *API) Session_Delete(sessionID string) error { +func (a *API) Session_Delete(sessionID string) { a.sessionsMu.Lock() defer a.sessionsMu.Unlock() delete(a.sessions, sessionID) - return nil } const ( - sessionTTLSecs = 86400 * 21 // sessions expire 21 days after last use - sessionSweepEvery = time.Hour // cadence of expired-session eviction + sessionTTLSecs = 24 * 21 * time.Hour // sessions expire 21 days after last use + sessionSweepEvery = time.Hour // cadence of expired-session eviction ) // Session_Get returns a snapshot copy of the signed-in session for sessionID, @@ -96,23 +98,23 @@ const ( // creates a session, so anonymous requests cost no memory — a session is minted // only by Session_SignIn. Returning a value (not the stored pointer) keeps // callers from racing on the shared struct. -func (a *API) Session_Get(sessionID string) (Session, error) { +func (a *API) Session_Get(sessionID string) Session { a.sessionsMu.Lock() defer a.sessionsMu.Unlock() s, ok := a.sessions[sessionID] if sessionID == "" || !ok { - return Session{}, nil + return Session{} } - if timeSince(s.LastSeenAt) > sessionTTLSecs { + if time.Since(s.LastSeenAt) > sessionTTLSecs { delete(a.sessions, sessionID) - return Session{}, nil + return Session{} } - s.LastSeenAt = time.Now().Unix() - return *s, nil + s.LastSeenAt = time.Now() + return *s } // Session_SignIn verifies pwd and, on success, mints a fresh signed-in session, @@ -121,10 +123,11 @@ func (a *API) Session_Get(sessionID string) (Session, error) { func (a *API) Session_SignIn(pwd string) (Session, error) { conf, err := a.Config_Get() if err != nil { - return Session{}, err + log.Printf("Failed to get config: %v", err) + return Session{}, errs.ErrUnexpected } if err := bcrypt.CompareHashAndPassword(conf.Password, []byte(pwd)); err != nil { - return Session{}, ErrNotAuthorized + return Session{}, errs.NotAuthorized.WithMsg("Not authorized.") } a.sessionsMu.Lock() @@ -132,8 +135,8 @@ func (a *API) Session_SignIn(pwd string) (Session, error) { s := &Session{ SessionID: idgen.NewToken(), SignedIn: true, - CreatedAt: time.Now().Unix(), - LastSeenAt: time.Now().Unix(), + CreatedAt: time.Now(), + LastSeenAt: time.Now(), } a.sessions[s.SessionID] = s return *s, nil @@ -146,7 +149,7 @@ func (a *API) sweepSessions() { for range time.Tick(sessionSweepEvery) { a.sessionsMu.Lock() for id, s := range a.sessions { - if timeSince(s.LastSeenAt) > sessionTTLSecs { + if time.Since(s.LastSeenAt) > sessionTTLSecs { delete(a.sessions, id) } } @@ -156,20 +159,22 @@ func (a *API) sweepSessions() { func (a *API) Network_Create(n *Network) error { n.NetworkID = idgen.NextID(0) - return db.Network_Insert(a.db, n) + return errs.DB(db.Network_Insert(a.db, n)) } func (a *API) Network_Delete(n *Network) error { - return db.Network_Delete(a.db, n.NetworkID) + return errs.DB(db.Network_Delete(a.db, n.NetworkID)) } func (a *API) Network_Get(id int64) (*Network, error) { - return db.Network_Get(a.db, id) + n, err := db.Network_Get(a.db, id) + return n, errs.DB(err) } func (a *API) Network_List() ([]*Network, error) { const query = db.Network_SelectQuery + ` ORDER BY LocalDomain ASC` - return db.Network_List(a.db, query) + n, err := db.Network_List(a.db, query) + return n, errs.DB(err) } func (a *API) Peer_CreateNew(p *Peer) error { @@ -177,7 +182,7 @@ func (a *API) Peer_CreateNew(p *Peer) error { p.SignPubKey = []byte{} p.APIKey = idgen.NewToken() - return db.Peer_Insert(a.db, p) + return errs.DB(db.Peer_Insert(a.db, p)) } func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error { @@ -191,27 +196,30 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error { return err } if len(current.WGPubKey) != 0 { - return errors.New("peer already initialized") + return errs.ErrAlreadyExists } peer.WGPubKey = args.WGPubKey peer.SignPubKey = args.SignPubKey - return db.Peer_UpdateFull(a.db, peer) + return errs.DB(db.Peer_UpdateFull(a.db, peer)) } func (a *API) Peer_Delete(networkID int64, peerIP byte) error { - return db.Peer_Delete(a.db, networkID, peerIP) + return errs.DB(db.Peer_Delete(a.db, networkID, peerIP)) } func (a *API) Peer_List(networkID int64) ([]*Peer, error) { - return db.Peer_ListAll(a.db, networkID) + p, err := db.Peer_ListAll(a.db, networkID) + return p, errs.DB(err) } func (a *API) Peer_Get(networkID int64, ip byte) (*Peer, error) { - return db.Peer_Get(a.db, networkID, ip) + p, err := db.Peer_Get(a.db, networkID, ip) + return p, errs.DB(err) } func (a *API) Peer_GetByAPIKey(key string) (*Peer, error) { - return db.Peer_GetByAPIKey(a.db, key) + p, err := db.Peer_GetByAPIKey(a.db, key) + return p, errs.DB(err) } diff --git a/hub/api/db/sanitize-validate.go b/hub/api/db/sanitize-validate.go index ffe1f7d..6ca1d65 100644 --- a/hub/api/db/sanitize-validate.go +++ b/hub/api/db/sanitize-validate.go @@ -1,19 +1,9 @@ package db import ( - "errors" "net/netip" "strings" -) - -var ( - ErrInvalidIP = errors.New("invalid IP") - ErrInvalidPeerIP = errors.New("invalid peer IP") - ErrNonPrivateIP = errors.New("non-private IP") - ErrInvalidPort = errors.New("invalid port") - ErrInvalidNetName = errors.New("invalid network name") - ErrNetNameNotLocal = errors.New("network name must end with .local") - ErrInvalidPeerName = errors.New("invalid peer name") + "vppn/hub/errs" ) func Config_Sanitize(c *Config) { @@ -35,11 +25,11 @@ func Network_Validate(c *Network) error { // 15 bytes is linux limit for network interface names. With ending .local, // max length is 21. if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 { - return ErrInvalidNetName + return errs.ErrInvalidNetName } if !strings.HasSuffix(c.LocalDomain, ".local") { - return ErrNetNameNotLocal + return errs.ErrNetNameNotLocal } for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") { @@ -49,16 +39,16 @@ func Network_Validate(c *Network) error { if c >= '0' && c <= '9' { continue } - return ErrInvalidNetName + return errs.ErrInvalidNetName } addr, ok := netip.AddrFromSlice(c.Network) if !ok || !addr.Is4() || addr.As4()[3] != 0 || addr.As4()[0] == 0 { - return ErrInvalidIP + return errs.ErrInvalidIP } if !addr.IsPrivate() { - return ErrNonPrivateIP + return errs.ErrNonPrivateIP } return nil @@ -84,26 +74,26 @@ func Peer_Sanitize(p *Peer) { func Peer_Validate(p *Peer) error { if p.PeerIP < 1 || p.PeerIP > 254 { - return ErrInvalidPeerIP + return errs.ErrInvalidPeerIP } if len(p.Addr4) > 0 { // Must be a genuine IPv4 address (reject an IPv6 in the v4 field). if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() { - return ErrInvalidIP + return errs.ErrInvalidIP } } if len(p.Addr6) > 0 { // Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field). if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() { - return ErrInvalidIP + return errs.ErrInvalidIP } } if p.Port == 0 { - return ErrInvalidPort + return errs.ErrInvalidPort } if len(p.Name) == 0 { - return ErrInvalidPeerName + return errs.ErrInvalidPeerName } for _, c := range p.Name { if c >= 'a' && c <= 'z' { @@ -115,7 +105,7 @@ func Peer_Validate(p *Peer) error { if c == '-' { continue } - return ErrInvalidPeerName + return errs.ErrInvalidPeerName } return nil diff --git a/hub/api/errors.go b/hub/api/errors.go deleted file mode 100644 index 27d241f..0000000 --- a/hub/api/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package api - -import ( - "errors" - "vppn/hub/api/db" -) - -var ( - ErrNotAuthorized = errors.New("not authorized") - ErrInvalidIP = db.ErrInvalidIP - ErrInvalidPort = db.ErrInvalidPort -) diff --git a/hub/api/time.go b/hub/api/time.go deleted file mode 100644 index eb1342b..0000000 --- a/hub/api/time.go +++ /dev/null @@ -1,7 +0,0 @@ -package api - -import "time" - -func timeSince(ts int64) int64 { - return time.Now().Unix() - ts -} diff --git a/hub/api/types.go b/hub/api/types.go index d1a3b35..f74a2c1 100644 --- a/hub/api/types.go +++ b/hub/api/types.go @@ -1,6 +1,9 @@ package api -import "vppn/hub/api/db" +import ( + "time" + "vppn/hub/api/db" +) type Config = db.Config type Network = db.Network @@ -9,6 +12,6 @@ type Peer = db.Peer type Session struct { SessionID string SignedIn bool - CreatedAt int64 - LastSeenAt int64 + CreatedAt time.Time + LastSeenAt time.Time }