go/fts5/fts5.go

152 lines
2.8 KiB
Go
Raw Normal View History

2024-11-19 15:41:41 +00:00
package fts5
import (
"database/sql"
"fmt"
"iter"
"strings"
_ "github.com/mattn/go-sqlite3"
)
// FTS5 is a full-text search index stored in an SQLite FTS5 virtual table
// named "search". Obtain one from Open or OpenMem and release it with Close.
type FTS5 struct {
	// colNames are the column names the table was set up with; they are
	// handed to the insert helpers on every write.
	colNames []string

	// db is the database holding the "search" table; Close closes it.
	db *sql.DB

	// Statements prepared once at construction: insertStmt backs Upsert,
	// matchStmt backs Match.
	insertStmt, matchStmt *sql.Stmt
}
// OpenMem creates an FTS5 index in a fresh in-memory SQLite database; nothing
// is persisted to disk. columnNames name the indexed columns (see newInternal).
//
// The underlying pool is restricted to a single connection (see below), so
// operations on the returned FTS5 are serialized. In particular, do not call
// other methods on it from inside the source iterator passed to UpsertBulk:
// that transaction holds the only connection until it finishes.
func OpenMem(columnNames ...string) (*FTS5, error) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		return nil, err
	}
	// Each SQLite connection to ":memory:" gets its own separate, empty
	// database, but *sql.DB is a pool that opens extra connections on demand
	// (concurrent calls, or a call made while UpsertBulk's transaction holds a
	// connection). Such a connection would not contain the "search" table.
	// Limiting the pool to one connection keeps all work on the single
	// database that newInternal initializes.
	db.SetMaxOpenConns(1)
	return newInternal(db, columnNames)
}
// Open opens the SQLite database at path through the "sqlite3" driver and
// sets up the full-text index in it (see newInternal); columnNames name the
// indexed columns. The returned FTS5 must be released with Close.
func Open(path string, columnNames ...string) (*FTS5, error) {
	var (
		db  *sql.DB
		err error
	)
	if db, err = sql.Open("sqlite3", path); err != nil {
		return nil, err
	}
	return newInternal(db, columnNames)
}
// newInternal initializes an FTS5 index on an already-opened database. It
// creates the "search" FTS5 virtual table (if it does not exist yet) with the
// given columns and the porter tokenizer. It then prepares the insert and
// match statements that the other methods reuse.
//
// newInternal takes ownership of db: on any error it closes db, and any
// statement it already prepared, before returning.
//
// NOTE(review): columnNames are spliced into the DDL unquoted, so they must be
// trusted, well-formed FTS5 column definitions — never user input.
func newInternal(db *sql.DB, columnNames []string) (*FTS5, error) {
	// Copy the names. Callers of the variadic Open/OpenMem may pass their own
	// slice (cols...), and mutating it later must not desynchronize colNames
	// from the table and the insert statement prepared below.
	colNames := append([]string(nil), columnNames...)

	ddl := fmt.Sprintf(
		`CREATE VIRTUAL TABLE IF NOT EXISTS search USING fts5(%s, tokenize='porter')`,
		strings.Join(colNames, ","))
	if _, err := db.Exec(ddl); err != nil {
		db.Close()
		return nil, err
	}

	insertStmt, err := prepareInsertStmt(db, colNames)
	if err != nil {
		db.Close()
		return nil, err
	}

	matchStmt, err := db.Prepare(`SELECT rowid,rank FROM search WHERE search=? ` +
		`ORDER BY rank ` +
		`LIMIT ? OFFSET ?`)
	if err != nil {
		// Release the statement prepared above before tearing down db.
		insertStmt.Close()
		db.Close()
		return nil, err
	}

	return &FTS5{
		colNames:   colNames,
		db:         db,
		insertStmt: insertStmt,
		matchStmt:  matchStmt,
	}, nil
}
// Close closes the prepared insert and match statements and then the
// underlying database. The database is closed even if a statement fails to
// close; the first error encountered is returned.
func (fts *FTS5) Close() error {
	var firstErr error
	keep := func(err error) {
		if firstErr == nil {
			firstErr = err
		}
	}
	// database/sql requires prepared statements to be closed by their owner
	// once they are no longer needed.
	for _, stmt := range []*sql.Stmt{fts.insertStmt, fts.matchStmt} {
		if stmt != nil {
			keep(stmt.Close())
		}
	}
	keep(fts.db.Close())
	return firstErr
}
// Upsert writes one document, identified by id, through the statement
// prepared at construction time. data holds column values keyed by column
// name; how missing or unknown keys are treated is decided by execInsertStmt.
func (fts *FTS5) Upsert(id int64, data map[string]string) error {
	stmt, cols := fts.insertStmt, fts.colNames
	return execInsertStmt(stmt, cols, id, data)
}
// Item is one document supplied to UpsertBulk.
type Item struct {
	// ID identifies the document; it is passed to execInsertStmt as the id,
	// presumably the row's rowid (the same value Delete and Match use).
	ID int64
	// Data holds the document's column values keyed by column name.
	Data map[string]string
}
// UpsertBulk writes every item yielded by src inside a single transaction,
// using an insert statement prepared on that transaction.
//
// If src yields a non-nil error, or an insert fails, the transaction is rolled
// back and that error is returned, so none of this call's items are kept.
// Otherwise the transaction is committed once src is exhausted and Commit's
// result is returned.
func (fts *FTS5) UpsertBulk(src iter.Seq2[Item, error]) error {
	tx, err := fts.db.Begin()
	if err != nil {
		return err
	}
	// The deferred rollback covers every early return. It also covers a panic
	// raised inside src, which would otherwise leave the transaction (and the
	// pooled connection it holds) open forever. After a successful Commit it
	// is a no-op returning sql.ErrTxDone, so its error is deliberately
	// ignored. Statements prepared on tx are closed by Commit/Rollback.
	defer func() { _ = tx.Rollback() }()

	insertStmt, err := prepareInsertStmt(tx, fts.colNames)
	if err != nil {
		return err
	}
	for item, srcErr := range src {
		if srcErr != nil {
			return srcErr
		}
		if err := execInsertStmt(insertStmt, fts.colNames, item.ID, item.Data); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// Delete removes the row whose rowid is id. It reports whether a row was
// actually deleted; deleting an id that is not present is not an error.
func (fts *FTS5) Delete(id int64) (bool, error) {
	result, err := fts.db.Exec(`DELETE FROM search WHERE rowid=?`, id)
	if err != nil {
		return false, err
	}
	count, err := result.RowsAffected()
	if err != nil {
		// Library code should report driver failures, not crash the caller.
		return false, err
	}
	// At most one row should ever match a rowid. More than one means an
	// invariant this package relies on is broken; surface it as an error
	// instead of panicking.
	if count > 1 {
		return false, fmt.Errorf("fts5: multiple rows deleted for rowid %d (%d rows)", id, count)
	}
	return count > 0, nil
}
// DeleteAll removes every row from the "search" table. The table itself is
// kept, so the index remains usable afterwards.
func (fts *FTS5) DeleteAll() error {
	const deleteEverything = `DELETE FROM search`
	if _, err := fts.db.Exec(deleteEverything); err != nil {
		return err
	}
	return nil
}
// Result is one hit returned by Match.
type Result struct {
	// ID is the rowid of the matching row.
	ID int64
	// Rank is the FTS5 rank value of the row; Match orders results by it in
	// ascending order.
	Rank float64
}
// Match runs the FTS5 full-text query against the "search" table. It returns
// at most limit results after skipping offset, ordered by FTS5's rank column
// in ascending order. The slice is non-nil (possibly empty) on success.
func (fts *FTS5) Match(query string, limit, offset int) ([]Result, error) {
	rows, err := fts.matchStmt.Query(query, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// limit is only used as a capacity hint. SQLite treats a negative LIMIT as
	// "no limit", for which make would panic. A very large limit would reserve
	// memory for rows that may never come back. Clamp the hint.
	const maxPrealloc = 256
	results := make([]Result, 0, min(max(limit, 0), maxPrealloc))
	for rows.Next() {
		var r Result
		if err := rows.Scan(&r.ID, &r.Rank); err != nil {
			return nil, err
		}
		results = append(results, r)
	}
	// rows.Next returns false on failure as well as on exhaustion. Without
	// this check, a mid-iteration error would look like a short, successful
	// page.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}