Refactor - now wireguard based. (#7)
This commit is contained in:
@@ -123,7 +123,9 @@ func Config_Get(
|
||||
) {
|
||||
row = &Config{}
|
||||
r := tx.QueryRow("SELECT ConfigID,Password FROM config WHERE ConfigID=?", ConfigID)
|
||||
err = r.Scan(&row.ConfigID, &row.Password)
|
||||
if err = r.Scan(&row.ConfigID, &row.Password); err != nil {
|
||||
row = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -137,7 +139,9 @@ func Config_GetWhere(
|
||||
) {
|
||||
row = &Config{}
|
||||
r := tx.QueryRow(query, args...)
|
||||
err = r.Scan(&row.ConfigID, &row.Password)
|
||||
if err = r.Scan(&row.ConfigID, &row.Password); err != nil {
|
||||
row = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -182,135 +186,17 @@ func Config_List(
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Table: sessions
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Session struct {
|
||||
SessionID string
|
||||
CSRF string
|
||||
SignedIn bool
|
||||
CreatedAt int64
|
||||
LastSeenAt int64
|
||||
}
|
||||
|
||||
const Session_SelectQuery = "SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions"
|
||||
|
||||
func Session_Insert(
|
||||
tx TX,
|
||||
row *Session,
|
||||
) (err error) {
|
||||
Session_Sanitize(row)
|
||||
if err = Session_Validate(row); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("INSERT INTO sessions(SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt) VALUES(?,?,?,?,?)", row.SessionID, row.CSRF, row.SignedIn, row.CreatedAt, row.LastSeenAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func Session_Delete(
|
||||
tx TX,
|
||||
SessionID string,
|
||||
) (err error) {
|
||||
result, err := tx.Exec("DELETE FROM sessions WHERE SessionID=?", SessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
switch n {
|
||||
case 0:
|
||||
return sql.ErrNoRows
|
||||
case 1:
|
||||
return nil
|
||||
default:
|
||||
panic("multiple rows deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func Session_Get(
|
||||
tx TX,
|
||||
SessionID string,
|
||||
) (
|
||||
row *Session,
|
||||
err error,
|
||||
) {
|
||||
row = &Session{}
|
||||
r := tx.QueryRow("SELECT SessionID,CSRF,SignedIn,CreatedAt,LastSeenAt FROM sessions WHERE SessionID=?", SessionID)
|
||||
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
|
||||
return
|
||||
}
|
||||
|
||||
func Session_GetWhere(
|
||||
tx TX,
|
||||
query string,
|
||||
args ...any,
|
||||
) (
|
||||
row *Session,
|
||||
err error,
|
||||
) {
|
||||
row = &Session{}
|
||||
r := tx.QueryRow(query, args...)
|
||||
err = r.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
|
||||
return
|
||||
}
|
||||
|
||||
func Session_Iterate(
|
||||
tx TX,
|
||||
query string,
|
||||
args ...any,
|
||||
) iter.Seq2[*Session, error] {
|
||||
rows, err := tx.Query(query, args...)
|
||||
if err != nil {
|
||||
return func(yield func(*Session, error) bool) {
|
||||
yield(nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(yield func(*Session, error) bool) {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
row := &Session{}
|
||||
err := rows.Scan(&row.SessionID, &row.CSRF, &row.SignedIn, &row.CreatedAt, &row.LastSeenAt)
|
||||
if !yield(row, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Session_List(
|
||||
tx TX,
|
||||
query string,
|
||||
args ...any,
|
||||
) (
|
||||
l []*Session,
|
||||
err error,
|
||||
) {
|
||||
for row, err := range Session_Iterate(tx, query, args...) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l = append(l, row)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Table: networks
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Network struct {
|
||||
NetworkID int64
|
||||
Name string
|
||||
Network []byte
|
||||
NetworkID int64
|
||||
LocalDomain string
|
||||
Network []byte
|
||||
}
|
||||
|
||||
const Network_SelectQuery = "SELECT NetworkID,Name,Network FROM networks"
|
||||
const Network_SelectQuery = "SELECT NetworkID,LocalDomain,Network FROM networks"
|
||||
|
||||
func Network_Insert(
|
||||
tx TX,
|
||||
@@ -321,7 +207,7 @@ func Network_Insert(
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("INSERT INTO networks(NetworkID,Name,Network) VALUES(?,?,?)", row.NetworkID, row.Name, row.Network)
|
||||
_, err = tx.Exec("INSERT INTO networks(NetworkID,LocalDomain,Network) VALUES(?,?,?)", row.NetworkID, row.LocalDomain, row.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -334,7 +220,7 @@ func Network_UpdateFull(
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := tx.Exec("UPDATE networks SET Name=?,Network=? WHERE NetworkID=?", row.Name, row.Network, row.NetworkID)
|
||||
result, err := tx.Exec("UPDATE networks SET LocalDomain=?,Network=? WHERE NetworkID=?", row.LocalDomain, row.Network, row.NetworkID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -384,8 +270,10 @@ func Network_Get(
|
||||
err error,
|
||||
) {
|
||||
row = &Network{}
|
||||
r := tx.QueryRow("SELECT NetworkID,Name,Network FROM networks WHERE NetworkID=?", NetworkID)
|
||||
err = r.Scan(&row.NetworkID, &row.Name, &row.Network)
|
||||
r := tx.QueryRow("SELECT NetworkID,LocalDomain,Network FROM networks WHERE NetworkID=?", NetworkID)
|
||||
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil {
|
||||
row = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -399,7 +287,9 @@ func Network_GetWhere(
|
||||
) {
|
||||
row = &Network{}
|
||||
r := tx.QueryRow(query, args...)
|
||||
err = r.Scan(&row.NetworkID, &row.Name, &row.Network)
|
||||
if err = r.Scan(&row.NetworkID, &row.LocalDomain, &row.Network); err != nil {
|
||||
row = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -419,7 +309,7 @@ func Network_Iterate(
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
row := &Network{}
|
||||
err := rows.Scan(&row.NetworkID, &row.Name, &row.Network)
|
||||
err := rows.Scan(&row.NetworkID, &row.LocalDomain, &row.Network)
|
||||
if !yield(row, err) {
|
||||
return
|
||||
}
|
||||
@@ -451,17 +341,17 @@ func Network_List(
|
||||
type Peer struct {
|
||||
NetworkID int64
|
||||
PeerIP byte
|
||||
Version int64
|
||||
APIKey string
|
||||
Name string
|
||||
PublicIP []byte
|
||||
Addr4 []byte
|
||||
Addr6 []byte
|
||||
Port uint16
|
||||
Relay bool
|
||||
PubKey []byte
|
||||
PubSignKey []byte
|
||||
WGPubKey []byte
|
||||
SignPubKey []byte
|
||||
}
|
||||
|
||||
const Peer_SelectQuery = "SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers"
|
||||
const Peer_SelectQuery = "SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers"
|
||||
|
||||
func Peer_Insert(
|
||||
tx TX,
|
||||
@@ -472,38 +362,10 @@ func Peer_Insert(
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey)
|
||||
_, err = tx.Exec("INSERT INTO peers(NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey) VALUES(?,?,?,?,?,?,?,?,?,?)", row.NetworkID, row.PeerIP, row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey)
|
||||
return err
|
||||
}
|
||||
|
||||
func Peer_Update(
|
||||
tx TX,
|
||||
row *Peer,
|
||||
) (err error) {
|
||||
Peer_Sanitize(row)
|
||||
if err = Peer_Validate(row); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := tx.Exec("UPDATE peers SET Version=?,Name=?,PublicIP=?,Port=?,Relay=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.Name, row.PublicIP, row.Port, row.Relay, row.NetworkID, row.PeerIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
switch n {
|
||||
case 0:
|
||||
return sql.ErrNoRows
|
||||
case 1:
|
||||
return nil
|
||||
default:
|
||||
panic("multiple rows updated")
|
||||
}
|
||||
}
|
||||
|
||||
func Peer_UpdateFull(
|
||||
tx TX,
|
||||
row *Peer,
|
||||
@@ -513,7 +375,7 @@ func Peer_UpdateFull(
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := tx.Exec("UPDATE peers SET Version=?,APIKey=?,Name=?,PublicIP=?,Port=?,Relay=?,PubKey=?,PubSignKey=? WHERE NetworkID=? AND PeerIP=?", row.Version, row.APIKey, row.Name, row.PublicIP, row.Port, row.Relay, row.PubKey, row.PubSignKey, row.NetworkID, row.PeerIP)
|
||||
result, err := tx.Exec("UPDATE peers SET APIKey=?,Name=?,Addr4=?,Addr6=?,Port=?,Relay=?,WGPubKey=?,SignPubKey=? WHERE NetworkID=? AND PeerIP=?", row.APIKey, row.Name, row.Addr4, row.Addr6, row.Port, row.Relay, row.WGPubKey, row.SignPubKey, row.NetworkID, row.PeerIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -565,8 +427,10 @@ func Peer_Get(
|
||||
err error,
|
||||
) {
|
||||
row = &Peer{}
|
||||
r := tx.QueryRow("SELECT NetworkID,PeerIP,Version,APIKey,Name,PublicIP,Port,Relay,PubKey,PubSignKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
|
||||
err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
|
||||
r := tx.QueryRow("SELECT NetworkID,PeerIP,APIKey,Name,Addr4,Addr6,Port,Relay,WGPubKey,SignPubKey FROM peers WHERE NetworkID=? AND PeerIP=?", NetworkID, PeerIP)
|
||||
if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil {
|
||||
row = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -580,7 +444,9 @@ func Peer_GetWhere(
|
||||
) {
|
||||
row = &Peer{}
|
||||
r := tx.QueryRow(query, args...)
|
||||
err = r.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
|
||||
if err = r.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey); err != nil {
|
||||
row = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -600,7 +466,7 @@ func Peer_Iterate(
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
row := &Peer{}
|
||||
err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.Version, &row.APIKey, &row.Name, &row.PublicIP, &row.Port, &row.Relay, &row.PubKey, &row.PubSignKey)
|
||||
err := rows.Scan(&row.NetworkID, &row.PeerIP, &row.APIKey, &row.Name, &row.Addr4, &row.Addr6, &row.Port, &row.Relay, &row.WGPubKey, &row.SignPubKey)
|
||||
if !yield(row, err) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
|
||||
var (
|
||||
ErrInvalidIP = errors.New("invalid IP")
|
||||
ErrInvalidPeerIP = errors.New("invalid peer IP")
|
||||
ErrNonPrivateIP = errors.New("non-private IP")
|
||||
ErrInvalidPort = errors.New("invalid port")
|
||||
ErrInvalidNetName = errors.New("invalid network name")
|
||||
ErrNetNameNotLocal = errors.New("network name must end with .local")
|
||||
ErrInvalidPeerName = errors.New("invalid peer name")
|
||||
)
|
||||
|
||||
@@ -21,15 +23,8 @@ func Config_Validate(c *Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Session_Sanitize(s *Session) {
|
||||
}
|
||||
|
||||
func Session_Validate(s *Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Network_Sanitize(n *Network) {
|
||||
n.Name = strings.TrimSpace(n.Name)
|
||||
n.LocalDomain = strings.TrimSpace(n.LocalDomain)
|
||||
|
||||
if addr, ok := netip.AddrFromSlice(n.Network); ok {
|
||||
n.Network = addr.AsSlice()
|
||||
@@ -37,12 +32,17 @@ func Network_Sanitize(n *Network) {
|
||||
}
|
||||
|
||||
func Network_Validate(c *Network) error {
|
||||
// 16 bytes is linux limit for network interface names.
|
||||
if len(c.Name) == 0 || len(c.Name) > 16 {
|
||||
// 15 bytes is linux limit for network interface names. With ending .local,
|
||||
// max length is 21.
|
||||
if len(c.LocalDomain) == 0 || len(c.LocalDomain) > 21 {
|
||||
return ErrInvalidNetName
|
||||
}
|
||||
|
||||
for _, c := range c.Name {
|
||||
if !strings.HasSuffix(c.LocalDomain, ".local") {
|
||||
return ErrNetNameNotLocal
|
||||
}
|
||||
|
||||
for _, c := range strings.TrimSuffix(c.LocalDomain, ".local") {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
continue
|
||||
}
|
||||
@@ -66,21 +66,35 @@ func Network_Validate(c *Network) error {
|
||||
|
||||
func Peer_Sanitize(p *Peer) {
|
||||
p.Name = strings.TrimSpace(p.Name)
|
||||
if len(p.PublicIP) != 0 {
|
||||
addr, ok := netip.AddrFromSlice(p.PublicIP)
|
||||
if ok && addr.Is4() {
|
||||
p.PublicIP = addr.AsSlice()
|
||||
if len(p.Addr4) != 0 {
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr4); ok {
|
||||
// Unmap so an IPv4-mapped form is stored canonically as 4 bytes.
|
||||
p.Addr4 = addr.Unmap().AsSlice()
|
||||
}
|
||||
}
|
||||
if len(p.Addr6) != 0 {
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr6); ok {
|
||||
p.Addr6 = addr.AsSlice()
|
||||
}
|
||||
}
|
||||
if p.Port == 0 {
|
||||
p.Port = 456
|
||||
p.Port = 51820
|
||||
}
|
||||
}
|
||||
|
||||
func Peer_Validate(p *Peer) error {
|
||||
if len(p.PublicIP) > 0 {
|
||||
_, ok := netip.AddrFromSlice(p.PublicIP)
|
||||
if !ok {
|
||||
if p.PeerIP < 1 || p.PeerIP > 254 {
|
||||
return ErrInvalidPeerIP
|
||||
}
|
||||
if len(p.Addr4) > 0 {
|
||||
// Must be a genuine IPv4 address (reject an IPv6 in the v4 field).
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr4); !ok || !addr.Is4() {
|
||||
return ErrInvalidIP
|
||||
}
|
||||
}
|
||||
if len(p.Addr6) > 0 {
|
||||
// Must be a genuine IPv6 address (reject IPv4 / IPv4-mapped in the v6 field).
|
||||
if addr, ok := netip.AddrFromSlice(p.Addr6); !ok || !addr.Is6() || addr.Is4In6() {
|
||||
return ErrInvalidIP
|
||||
}
|
||||
}
|
||||
@@ -88,6 +102,9 @@ func Peer_Validate(p *Peer) error {
|
||||
return ErrInvalidPort
|
||||
}
|
||||
|
||||
if len(p.Name) == 0 {
|
||||
return ErrInvalidPeerName
|
||||
}
|
||||
for _, c := range p.Name {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
continue
|
||||
@@ -95,10 +112,9 @@ func Peer_Validate(p *Peer) error {
|
||||
if c >= '0' && c <= '9' {
|
||||
continue
|
||||
}
|
||||
if c == '.' || c == '-' || c == '_' {
|
||||
if c == '-' {
|
||||
continue
|
||||
}
|
||||
|
||||
return ErrInvalidPeerName
|
||||
}
|
||||
|
||||
|
||||
@@ -3,29 +3,21 @@ TABLE config OF Config (
|
||||
Password []byte
|
||||
);
|
||||
|
||||
TABLE sessions OF Session NoUpdate (
|
||||
SessionID string PK,
|
||||
CSRF string,
|
||||
SignedIn bool,
|
||||
CreatedAt int64,
|
||||
LastSeenAt int64
|
||||
);
|
||||
|
||||
TABLE networks OF Network (
|
||||
NetworkID int64 PK,
|
||||
Name string NoUpdate,
|
||||
Network []byte NoUpdate
|
||||
NetworkID int64 PK,
|
||||
LocalDomain string NoUpdate,
|
||||
Network []byte NoUpdate
|
||||
);
|
||||
|
||||
TABLE peers OF Peer (
|
||||
NetworkID int64 PK,
|
||||
PeerIP byte PK,
|
||||
Version int64,
|
||||
APIKey string NoUpdate,
|
||||
Name string,
|
||||
PublicIP []byte,
|
||||
Port uint16,
|
||||
Relay bool,
|
||||
PubKey []byte NoUpdate,
|
||||
PubSignKey []byte NoUpdate
|
||||
Name string NoUpdate,
|
||||
Addr4 []byte NoUpdate,
|
||||
Addr6 []byte NoUpdate,
|
||||
Port uint16 NoUpdate,
|
||||
Relay bool NoUpdate,
|
||||
WGPubKey []byte NoUpdate,
|
||||
SignPubKey []byte NoUpdate
|
||||
);
|
||||
|
||||
@@ -1,31 +1,5 @@
|
||||
package db
|
||||
|
||||
import "time"
|
||||
|
||||
func Session_UpdateLastSeenAt(
|
||||
tx TX,
|
||||
id string,
|
||||
) (err error) {
|
||||
_, err = tx.Exec("UPDATE sessions SET LastSeenAt=? WHERE SessionID=?", time.Now().Unix(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
func Session_SetSignedIn(
|
||||
tx TX,
|
||||
id string,
|
||||
) (err error) {
|
||||
_, err = tx.Exec("UPDATE sessions SET SignedIn=1 WHERE SessionID=?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func Session_DeleteBefore(
|
||||
tx TX,
|
||||
timestamp int64,
|
||||
) (err error) {
|
||||
_, err = tx.Exec("DELETE FROM sessions WHERE LastSeenAt<?", timestamp)
|
||||
return err
|
||||
}
|
||||
|
||||
func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) {
|
||||
const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC`
|
||||
return Peer_List(tx, query, networkID)
|
||||
@@ -37,9 +11,3 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) {
|
||||
Peer_SelectQuery+` WHERE APIKey=?`,
|
||||
apiKey)
|
||||
}
|
||||
|
||||
func Peer_Exists(tx TX, networkID int64, ip byte) (exists bool, err error) {
|
||||
const query = `SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=? AND PeerIP=?)`
|
||||
err = tx.QueryRow(query, networkID, ip).Scan(&exists)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user