diff --git a/hub/api/api.go b/hub/api/api.go index 844e59e..7c6ac0d 100644 --- a/hub/api/api.go +++ b/hub/api/api.go @@ -169,11 +169,25 @@ func (a *API) sweepSessions() { } func (a *API) Network_Create(n *Network) error { + a.lock.Lock() + defer a.lock.Unlock() + n.NetworkID = idgen.NextID(0) return errs.DB(db.Network_Insert(a.db, n)) } func (a *API) Network_Delete(n *Network) error { + a.lock.Lock() + defer a.lock.Unlock() + + exists, err := db.Network_HasPeers(a.db, n.NetworkID) + if err != nil { + return errs.DB(err) + } + if exists { + return errs.Conflict.WithMsg("Delete all peers before deleting network.") + } + return errs.DB(db.Network_Delete(a.db, n.NetworkID)) } @@ -189,6 +203,9 @@ func (a *API) Network_List() ([]*Network, error) { } func (a *API) Peer_CreateNew(p *Peer) error { + a.lock.Lock() + defer a.lock.Unlock() + p.WGPubKey = []byte{} p.SignPubKey = []byte{} p.APIKey = idgen.NewToken() @@ -217,6 +234,9 @@ func (a *API) Peer_Init(peer *Peer, args m.PeerInitArgs) error { } func (a *API) Peer_Delete(networkID int64, peerIP byte) error { + a.lock.Lock() + defer a.lock.Unlock() + return errs.DB(db.Peer_Delete(a.db, networkID, peerIP)) } diff --git a/hub/api/db/written.go b/hub/api/db/written.go index 11251f2..da3d4c9 100644 --- a/hub/api/db/written.go +++ b/hub/api/db/written.go @@ -1,5 +1,7 @@ package db +import "database/sql" + func Peer_ListAll(tx TX, networkID int64) ([]*Peer, error) { const query = Peer_SelectQuery + ` WHERE NetworkID=? ORDER BY PeerIP ASC` return Peer_List(tx, query, networkID) @@ -11,3 +13,9 @@ func Peer_GetByAPIKey(tx TX, apiKey string) (*Peer, error) { Peer_SelectQuery+` WHERE APIKey=?`, apiKey) } + +func Network_HasPeers(db *sql.DB, networkID int64) (exists bool, err error) { + const query = "SELECT EXISTS(SELECT 1 FROM peers WHERE NetworkID=?)" + err = db.QueryRow(query, networkID).Scan(&exists) + return exists, err +} diff --git a/hub/handlers.go b/hub/handlers.go index 2db9157..1f8c8d9 100644 --- a/hub/handlers.go +++ b/hub/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/json" "log" "math/rand/v2" + "net" "net/http" "time" "vppn/hub/api" @@ -28,11 +29,12 @@ func (a *App) _signin(s *api.Session, w http.ResponseWriter, r *http.Request) er } func (a *App) _signinSubmit(s *api.Session, w http.ResponseWriter, r *http.Request) error { - if !a.signInLock.TryLock(r.RemoteAddr) { + host, _, _ := net.SplitHostPort(r.RemoteAddr) + if !a.signInLock.TryLock(host) { time.Sleep(time.Duration(rand.Int64N(int64(4 * time.Second)))) return errs.ErrNotAuthorized } - defer a.signInLock.Unlock(r.RemoteAddr) + defer a.signInLock.Unlock(host) var pwd string err := webutil.NewFormScanner(r.Form). @@ -244,7 +246,7 @@ func (a *App) _adminPasswordSubmit(s *api.Session, w http.ResponseWriter, r *htt return err } - if len(newPwd) < 8 { + if len(newPwd) < 8 || len(newPwd) > 72 { return errs.ErrInvalidPassword } @@ -342,9 +344,10 @@ func (a *App) peersList(networkID int64) (peers []m.Peer, err error) { } wgKey, err := wgtypes.NewKey(p.WGPubKey) if err != nil { - log.Printf("Bad WG key in DB for peer %v", p) + log.Printf("Bad WG key in DB for peer %d/%d", p.NetworkID, p.PeerIP) continue // malformed key; skip rather than serve garbage } + var signKey [32]byte copy(signKey[:], p.SignPubKey) peers = append(peers, m.Peer{ diff --git a/peer/init.go b/peer/init.go index e3a1699..029a513 100644 --- a/peer/init.go +++ b/peer/init.go @@ -9,6 +9,7 @@ import ( "net/http" "net/netip" "os" + "time" "golang.org/x/crypto/nacl/sign" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -93,7 +94,7 @@ func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) req.SetBasicAuth("", apiKey) req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := (&http.Client{Timeout: time.Minute}).Do(req) if err != nil { return LocalState{}, fmt.Errorf("hub init: %w", err) } diff --git a/peer/multicast/receiver.go b/peer/multicast/receiver.go index c8d3a10..0ceae23 100644 --- a/peer/multicast/receiver.go +++ b/peer/multicast/receiver.go @@ -63,13 +63,14 @@ func receiver(selfVPNIP netip.Addr, limiters []*ratelimiter.Limiter, ch chan<- P continue } - if err := limiters[packet.PeerIP].Limit(); err != nil { - log.Printf("Rate limited packet from peer IP %d.", packet.PeerIP) + // Slightly cheaper than limiting. + age := time.Since(time.Unix(packet.Timestamp, 0)) + if age > maxPacketAge || age < -maxPacketAge { continue } - age := time.Since(time.Unix(packet.Timestamp, 0)) - if age > maxPacketAge || age < -maxPacketAge { + if err := limiters[packet.PeerIP].Limit(); err != nil { + log.Printf("Rate limited packet from peer IP %d.", packet.PeerIP) continue }