package am import ( "database/sql" "errors" "log" "strconv" "time" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" ) type dbal struct { sql.DB } func (*dbal) logf(s string, args ...interface{}) { log.Printf("db: "+s, args...) } func newDB(dbPath string) (*dbal, error) { _db, err := sql.Open("sqlite3", dbPath) if err != nil { log.Printf("Failed to open database %s: %v", dbPath, err) return nil, err } db := &dbal{*_db} if _, err := db.Exec(migration); err != nil { log.Printf("Failed to apply database migration: %v", err) return nil, err } _ = db.UserInsert(User{ Username: `root`, Admin: true, }, `root1234`) return db, nil } /********* * Users * *********/ func (db *dbal) UserInsert(u User, pwd string) error { if err := validateName(u.Username); err != nil { return err } if err := validatePwd(pwd); err != nil { return err } bHash, err := bcrypt.GenerateFromPassword( []byte(pwd), bcrypt.DefaultCost) if err != nil { db.logf("Failed to hash password: %v", err) return err } _, err = db.Exec( `INSERT INTO users(Username,Password,Admin)VALUES(?,?,?)`, u.Username, string(bHash), u.Admin) if err != nil { db.logf("Failed to insert user %s: %v", u.Username, err) } return err } func (db *dbal) UserGet(username string) (u User, err error) { u.Username = username err = db.QueryRow( `SELECT Admin FROM users WHERE Username=?`, username).Scan(&u.Admin) if err != nil { db.logf("Failed to get user %s: %v", username, err) } return } func (db *dbal) UserGetWithPwd(username, pwd string) (u User, err error) { u.Username = username pwdHash := "" err = db.QueryRow( `SELECT Password,Admin FROM users WHERE Username=?`, username).Scan(&pwdHash, &u.Admin) if err != nil { db.logf("Failed to get user %s: %v", username, err) return u, err } err = bcrypt.CompareHashAndPassword([]byte(pwdHash), []byte(pwd)) if err != nil { db.logf("Password comparison failed for user %s: %v", username, err) } return u, err } func (db *dbal) UserList() (l []User, err error) { rows, err := db.Query(`SELECT Username, Admin FROM users ORDER BY Username`) if err != nil { db.logf("Failed to list users: %v", err) return nil, err } for rows.Next() { u := User{} if err := rows.Scan(&u.Username, &u.Admin); err != nil { db.logf("Failed to scan user: %v", err) return nil, err } l = append(l, u) } return l, nil } func (db *dbal) UserUpdatePwd(username, pwd string) error { if err := validatePwd(pwd); err != nil { return err } bHash, err := bcrypt.GenerateFromPassword( []byte(pwd), bcrypt.DefaultCost) if err != nil { db.logf("Failed to hash password: %v", err) return err } _, err = db.Exec( `UPDATE users SET Password=? WHERE Username=?`, string(bHash), username) if err != nil { db.logf("Failed to update password for user %s: %v", username, err) } return err } func (db *dbal) UserUpdateAdmin(username string, b bool) error { if username == `root` { return errors.New("Root user is always an admin.") } _, err := db.Exec( `UPDATE users SET Admin=? WHERE Username=?`, b, username) if err != nil { db.logf("Failed to update admin for user %s: %v", username, err) } return err } func (db *dbal) UserDelete(username string) error { if username == `root` { return errors.New("Root user cannot be deleted.") } _, err := db.Exec(`DELETE FROM users WHERE username=?`, username) if err != nil { db.logf("Failed to delete user %s: %v", username, err) } return err } /*********** * Sources * ***********/ func (db *dbal) SourceInsert(s *Source) error { if err := validateName(s.Name); err != nil { return err } s.SourceID = newUUID() s.APIKey = newUUID() s.LastSeenAt = time.Now().UTC() s.AlertedAt = s.LastSeenAt _, err := db.Exec(`INSERT INTO sources(`+ ` SourceID,Name,APIKey,Description,`+ ` LastSeenAt,AlertTimeout,AlertedAt`+ `)VALUES(?,?,?,?,?,?,?)`, s.SourceID, s.Name, s.APIKey, s.Description, s.LastSeenAt, s.AlertTimeout, s.AlertedAt) if err != nil { db.logf("Failed to insert source: %v", err) } return err } const sourceCols = `SourceID,Name,APIKey,Description,` + `LastSeenAt,AlertTimeout,AlertedAt` func (db *dbal) scanSource( row interface{ Scan(...interface{}) error }, ) (s Source, err error) { err = row.Scan( &s.SourceID, &s.Name, &s.APIKey, &s.Description, &s.LastSeenAt, &s.AlertTimeout, &s.AlertedAt) if err != nil { db.logf("Failed to scan source: %v", err) } return } func (db *dbal) SourceGet(id string) (s Source, err error) { return db.scanSource( db.QueryRow( `SELECT `+sourceCols+` FROM sources WHERE SourceID=?`, id)) } func (db *dbal) SourceGetByKey(key string) (s Source, err error) { return db.scanSource( db.QueryRow( `SELECT `+sourceCols+` FROM sources WHERE APIKey=?`, key)) } func (db *dbal) SourceList() (l []Source, err error) { rows, err := db.Query(`SELECT ` + sourceCols + ` FROM sources ORDER BY Name`) if err != nil { db.logf("Failed to list sources: %v", err) return nil, err } for rows.Next() { u, err := db.scanSource(rows) if err != nil { return nil, err } l = append(l, u) } return l, nil } // Updates Description and AlertTimeout. func (db *dbal) SourceUpdate(s Source) error { _, err := db.Exec(`UPDATE sources `+ `SET Description=?,AlertTimeout=? `+ `WHERE SourceID=?`, s.Description, s.AlertTimeout, s.SourceID) if err != nil { db.logf("Failed to update source %s: %v", s.Name, err) } return err } func (db *dbal) SourceUpdateLastSeenAt(id string) error { t := time.Now().UTC() _, err := db.Exec(`UPDATE sources SET LastSeenAt=? WHERE SourceID=?`, t, id) if err != nil { db.logf("Failed to update source last seen at %s: %v", id, err) } return err } func (db *dbal) SourceUpdateAlertedAt(id string) error { t := time.Now().UTC() _, err := db.Exec(`UPDATE sources SET AlertedAt=? WHERE SourceID=?`, t, id) if err != nil { db.logf("Failed to update source alerted at %s: %v", id, err) } return err } func (db *dbal) SourceDelete(id string) error { _, err := db.Exec(`DELETE FROM sources WHERE SourceID=?`, id) if err != nil { db.logf("Failed to delete source %s: %v", id, err) } return err } /******* * Log * *******/ func (db *dbal) LogInsert(e Entry) error { e.TS = time.Now().UTC() _, err := db.Exec(`INSERT INTO log`+ `(SourceID,TS,Alert,Text)VALUES(?,?,?,?)`, e.SourceID, e.TS, e.Alert, e.Text) if err != nil { db.logf("Failed to insert log entry: %v", err) } return err } const logCols = `LogID,SourceID,TS,Alert,Text` func (db *dbal) scanLog( row interface{ Scan(...interface{}) error }, ) (e Entry, err error) { err = row.Scan(&e.LogID, &e.SourceID, &e.TS, &e.Alert, &e.Text) if err != nil { db.logf("Failed to scan log entry: %v", err) } return } type LogListArgs struct { BeforeID int64 Alert *bool SourceID string Limit int64 } func (db *dbal) LogList(args LogListArgs) (l []EntryListRow, err error) { if args.Limit <= 0 { args.Limit = 100 } qArgs := []interface{}{} query := `SELECT ` + `l.LogID,s.SourceID,s.Name,l.TS,l.Alert,l.Text ` + `FROM log l JOIN sources s ON l.SourceID=s.SourceID WHERE 1` if args.BeforeID != 0 { query += " AND LogID