diff --git a/fts5/fts5.go b/fts5/fts5.go new file mode 100644 index 0000000..cb9e7df --- /dev/null +++ b/fts5/fts5.go @@ -0,0 +1,151 @@ +package fts5 + +import ( + "database/sql" + "fmt" + "iter" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +type FTS5 struct { + colNames []string + db *sql.DB + + insertStmt *sql.Stmt + matchStmt *sql.Stmt +} + +func OpenMem(columnNames ...string) (*FTS5, error) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + return nil, err + } + return newInternal(db, columnNames) +} + +func Open(path string, columnNames ...string) (*FTS5, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, err + } + return newInternal(db, columnNames) +} + +func newInternal(db *sql.DB, columnNames []string) (*FTS5, error) { + _, err := db.Exec(fmt.Sprintf( + `CREATE VIRTUAL TABLE IF NOT EXISTS search USING fts5(%s, tokenize='porter')`, + strings.Join(columnNames, ","))) + if err != nil { + db.Close() + return nil, err + } + + insertStmt, err := prepareInsertStmt(db, columnNames) + if err != nil { + db.Close() + return nil, err + } + + matchStmt, err := db.Prepare(`SELECT rowid,rank FROM search WHERE search=? ` + + `ORDER BY rank ` + + `LIMIT ? OFFSET ?`) + if err != nil { + db.Close() + return nil, err + } + + return &FTS5{ + colNames: columnNames, + db: db, + insertStmt: insertStmt, + matchStmt: matchStmt, + }, nil +} + +func (fts *FTS5) Close() error { + return fts.db.Close() +} + +func (fts *FTS5) Upsert(id int64, data map[string]string) error { + return execInsertStmt(fts.insertStmt, fts.colNames, id, data) +} + +type Item struct { + ID int64 + Data map[string]string +} + +func (fts *FTS5) UpsertBulk(src iter.Seq2[Item, error]) error { + tx, err := fts.db.Begin() + if err != nil { + return err + } + + insertStmt, err := prepareInsertStmt(tx, fts.colNames) + if err != nil { + tx.Rollback() + return err + } + + for item, err := range src { + if err != nil { + tx.Rollback() + return err + } + + if err := execInsertStmt(insertStmt, fts.colNames, item.ID, item.Data); err != nil { + tx.Rollback() + return err + } + } + + return tx.Commit() +} + +func (fts *FTS5) Delete(id int64) (bool, error) { + result, err := fts.db.Exec(`DELETE FROM search WHERE rowid=?`, id) + if err != nil { + return false, err + } + + count, err := result.RowsAffected() + if err != nil { + panic(err) + } + if count > 1 { + panic("multiple rows deleted") + } + return count > 0, nil +} + +func (fts *FTS5) DeleteAll() error { + _, err := fts.db.Exec(`DELETE FROM search`) + return err +} + +type Result struct { + ID int64 + Rank float64 +} + +func (fts *FTS5) Match(query string, limit, offset int) ([]Result, error) { + results := make([]Result, 0, limit) + + rows, err := fts.matchStmt.Query(query, limit, offset) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + result := Result{} + if err := rows.Scan(&result.ID, &result.Rank); err != nil { + return nil, err + } + results = append(results, result) + } + + return results, nil +} diff --git a/fts5/fts5_test.go b/fts5/fts5_test.go new file mode 100644 index 0000000..e73fefb --- /dev/null +++ b/fts5/fts5_test.go @@ -0,0 +1,59 @@ +package fts5 + +import ( + "testing" +) + +func TestMatch(t *testing.T) { + fts, err := OpenMem("Title", "Text") + if err != nil { + t.Fatal(err) + } + + fts.Upsert(100, map[string]string{ + "Title": "Peter Rabbit", + "Text": "Peter Rabbit is a fictional animal character in various children's stories by English author Beatrix Potter.", + }) + + fts.Upsert(200, map[string]string{ + "Title": "Baloo", + "Text": "Fictional bear from the Jungle Book.", + }) + + fts.Upsert(300, map[string]string{ + "Title": "Barney Bear", + "Text": "A grumpy brown bear.", + }) + + fts.Upsert(400, map[string]string{ + "Title": "Peter the Panda", + "Text": "A secret agent panda from Seattle.", + }) + + // Should have 2 results for "peter". + results, err := fts.Match("peter", 3, 0) + if err != nil { + t.Fatal(err) + } + if len(results) != 2 { + t.Fatal(results) + } + + // Try limit. + results, err = fts.Match("peter", 1, 0) + if err != nil { + t.Fatal(err) + } + if len(results) != 1 { + t.Fatal(results) + } + + // Try offset. + results, err = fts.Match("peter", 3, 1) + if err != nil { + t.Fatal(err) + } + if len(results) != 1 { + t.Fatal(results) + } +} diff --git a/fts5/query.go b/fts5/query.go new file mode 100644 index 0000000..3c73cbd --- /dev/null +++ b/fts5/query.go @@ -0,0 +1,36 @@ +package fts5 + +import ( + "database/sql" + "fmt" + "strings" +) + +type preparer interface { + Prepare(string) (*sql.Stmt, error) +} + +func prepareInsertStmt( + db preparer, + columnNames []string, +) (*sql.Stmt, error) { + insertQuery := fmt.Sprintf(`INSERT INTO search(rowid,%s) VALUES (?%s)`, + strings.Join(columnNames, ","), + strings.Repeat(",?", len(columnNames))) + return db.Prepare(insertQuery) +} + +func execInsertStmt( + stmt *sql.Stmt, + colNames []string, + id int64, + data map[string]string, +) error { + values := make([]any, len(colNames)+1) + values[0] = id + for i, col := range colNames { + values[i+1] = data[col] + } + _, err := stmt.Exec(values...) + return err +} diff --git a/fts5/requiretag.go b/fts5/requiretag.go new file mode 100644 index 0000000..40de161 --- /dev/null +++ b/fts5/requiretag.go @@ -0,0 +1,9 @@ +//go:build !fts5 + +package fts5 + +import "log" + +func init() { + log.Fatal("You must provide build tag `-tags fts5`.") +}