Compare commits

..

3 Commits
v0.7.2 ... main

Author SHA1 Message Date
jdl
9061198e7f fts5 2024-11-19 16:41:41 +01:00
jdl
9070d8cfc0 wip 2024-11-19 16:30:42 +01:00
jdl
0d8cc762c0 wip 2024-11-19 16:20:18 +01:00
10 changed files with 308 additions and 55 deletions

151
fts5/fts5.go Normal file
View File

@ -0,0 +1,151 @@
package fts5
import (
"database/sql"
"fmt"
"iter"
"strings"
_ "github.com/mattn/go-sqlite3"
)
type FTS5 struct {
colNames []string
db *sql.DB
insertStmt *sql.Stmt
matchStmt *sql.Stmt
}
func OpenMem(columnNames ...string) (*FTS5, error) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
return nil, err
}
return newInternal(db, columnNames)
}
func Open(path string, columnNames ...string) (*FTS5, error) {
db, err := sql.Open("sqlite3", path)
if err != nil {
return nil, err
}
return newInternal(db, columnNames)
}
func newInternal(db *sql.DB, columnNames []string) (*FTS5, error) {
_, err := db.Exec(fmt.Sprintf(
`CREATE VIRTUAL TABLE IF NOT EXISTS search USING fts5(%s, tokenize='porter')`,
strings.Join(columnNames, ",")))
if err != nil {
db.Close()
return nil, err
}
insertStmt, err := prepareInsertStmt(db, columnNames)
if err != nil {
db.Close()
return nil, err
}
matchStmt, err := db.Prepare(`SELECT rowid,rank FROM search WHERE search=? ` +
`ORDER BY rank ` +
`LIMIT ? OFFSET ?`)
if err != nil {
db.Close()
return nil, err
}
return &FTS5{
colNames: columnNames,
db: db,
insertStmt: insertStmt,
matchStmt: matchStmt,
}, nil
}
func (fts *FTS5) Close() error {
return fts.db.Close()
}
func (fts *FTS5) Upsert(id int64, data map[string]string) error {
return execInsertStmt(fts.insertStmt, fts.colNames, id, data)
}
type Item struct {
ID int64
Data map[string]string
}
func (fts *FTS5) UpsertBulk(src iter.Seq2[Item, error]) error {
tx, err := fts.db.Begin()
if err != nil {
return err
}
insertStmt, err := prepareInsertStmt(tx, fts.colNames)
if err != nil {
tx.Rollback()
return err
}
for item, err := range src {
if err != nil {
tx.Rollback()
return err
}
if err := execInsertStmt(insertStmt, fts.colNames, item.ID, item.Data); err != nil {
tx.Rollback()
return err
}
}
return tx.Commit()
}
func (fts *FTS5) Delete(id int64) (bool, error) {
result, err := fts.db.Exec(`DELETE FROM search WHERE rowid=?`, id)
if err != nil {
return false, err
}
count, err := result.RowsAffected()
if err != nil {
panic(err)
}
if count > 1 {
panic("multiple rows deleted")
}
return count > 0, nil
}
func (fts *FTS5) DeleteAll() error {
_, err := fts.db.Exec(`DELETE FROM search`)
return err
}
type Result struct {
ID int64
Rank float64
}
func (fts *FTS5) Match(query string, limit, offset int) ([]Result, error) {
results := make([]Result, 0, limit)
rows, err := fts.matchStmt.Query(query, limit, offset)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
result := Result{}
if err := rows.Scan(&result.ID, &result.Rank); err != nil {
return nil, err
}
results = append(results, result)
}
return results, nil
}

59
fts5/fts5_test.go Normal file
View File

@ -0,0 +1,59 @@
package fts5
import (
"testing"
)
func TestMatch(t *testing.T) {
fts, err := OpenMem("Title", "Text")
if err != nil {
t.Fatal(err)
}
fts.Upsert(100, map[string]string{
"Title": "Peter Rabbit",
"Text": "Peter Rabbit is a fictional animal character in various children's stories by English author Beatrix Potter.",
})
fts.Upsert(200, map[string]string{
"Title": "Baloo",
"Text": "Fictional bear from the Jungle Book.",
})
fts.Upsert(300, map[string]string{
"Title": "Barney Bear",
"Text": "A grumpy brown bear.",
})
fts.Upsert(400, map[string]string{
"Title": "Peter the Panda",
"Text": "A secret agent panda from Seattle.",
})
// Should have 2 results for "peter".
results, err := fts.Match("peter", 3, 0)
if err != nil {
t.Fatal(err)
}
if len(results) != 2 {
t.Fatal(results)
}
// Try limit.
results, err = fts.Match("peter", 1, 0)
if err != nil {
t.Fatal(err)
}
if len(results) != 1 {
t.Fatal(results)
}
// Try offset.
results, err = fts.Match("peter", 3, 1)
if err != nil {
t.Fatal(err)
}
if len(results) != 1 {
t.Fatal(results)
}
}

36
fts5/query.go Normal file
View File

@ -0,0 +1,36 @@
package fts5
import (
"database/sql"
"fmt"
"strings"
)
type preparer interface {
Prepare(string) (*sql.Stmt, error)
}
func prepareInsertStmt(
db preparer,
columnNames []string,
) (*sql.Stmt, error) {
insertQuery := fmt.Sprintf(`INSERT INTO search(rowid,%s) VALUES (?%s)`,
strings.Join(columnNames, ","),
strings.Repeat(",?", len(columnNames)))
return db.Prepare(insertQuery)
}
func execInsertStmt(
stmt *sql.Stmt,
colNames []string,
id int64,
data map[string]string,
) error {
values := make([]any, len(colNames)+1)
values[0] = id
for i, col := range colNames {
values[i+1] = data[col]
}
_, err := stmt.Exec(values...)
return err
}

9
fts5/requiretag.go Normal file
View File

@ -0,0 +1,9 @@
//go:build !fts5
package fts5
import "log"
func init() {
log.Fatal("You must provide build tag `-tags fts5`.")
}

View File

@ -2,30 +2,9 @@ package {{.PackageName}}
import ( import (
"database/sql" "database/sql"
"errors"
"iter" "iter"
"github.com/mattn/go-sqlite3"
) )
var (
ErrConstraint = errors.New("constraint violation")
ErrNotFound = errors.New("not found")
)
func translateError(err error) error {
if err == nil {
return nil
}
if e, ok := err.(sqlite3.Error); ok && e.Code == 19 {
return errors.Join(ErrConstraint, err)
}
if errors.Is(err, sql.ErrNoRows) {
return errors.Join(ErrNotFound, err)
}
return err
}
type TX interface { type TX interface {
Exec(query string, args ...any) (sql.Result, error) Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error) Query(query string, args ...any) (*sql.Rows, error)
@ -57,7 +36,7 @@ func {{.Type}}_Insert(
} }
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) _, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
return translateError(err) return err
} }
{{- end}} {{/* if not .NoInsert */}} {{- end}} {{/* if not .NoInsert */}}
@ -77,7 +56,7 @@ func {{.Type}}_Update(
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
if err != nil { if err != nil {
return translateError(err) return err
} }
n, err := result.RowsAffected() n, err := result.RowsAffected()
@ -86,7 +65,7 @@ func {{.Type}}_Update(
} }
switch n { switch n {
case 0: case 0:
return ErrNotFound return sql.ErrNoRows
case 1: case 1:
return nil return nil
default: default:
@ -108,7 +87,7 @@ func {{.Type}}_UpdateFull(
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
if err != nil { if err != nil {
return translateError(err) return err
} }
n, err := result.RowsAffected() n, err := result.RowsAffected()
@ -117,7 +96,7 @@ func {{.Type}}_UpdateFull(
} }
switch n { switch n {
case 0: case 0:
return ErrNotFound return sql.ErrNoRows
case 1: case 1:
return nil return nil
default: default:
@ -137,7 +116,7 @@ func {{.Type}}_Delete(
) (err error) { ) (err error) {
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
if err != nil { if err != nil {
return translateError(err) return err
} }
n, err := result.RowsAffected() n, err := result.RowsAffected()
@ -146,7 +125,7 @@ func {{.Type}}_Delete(
} }
switch n { switch n {
case 0: case 0:
return ErrNotFound return sql.ErrNoRows
case 1: case 1:
return nil return nil
default: default:
@ -165,7 +144,7 @@ func {{.Type}}_Get(
) { ) {
row = &{{.Type}}{} row = &{{.Type}}{}
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
err = translateError(r.Scan({{.ScanArgs}})) err = r.Scan({{.ScanArgs}})
return return
} }
@ -180,7 +159,7 @@ func {{.Type}}_GetWhere(
) { ) {
row = &{{.Type}}{} row = &{{.Type}}{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
err = translateError(r.Scan({{.ScanArgs}})) err = r.Scan({{.ScanArgs}})
return return
} }
@ -204,7 +183,7 @@ func {{.Type}}_Iterate(
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
row := &{{.Type}}{} row := &{{.Type}}{}
err := translateError(rows.Scan({{.ScanArgs}})) err := rows.Scan({{.ScanArgs}})
if !yield(row, err) { if !yield(row, err) {
return return
} }

View File

@ -5,14 +5,19 @@ import (
"os" "os"
) )
const (
DRIVER_SQLITE = "sqlite"
DRIVER_POSTGRES = "postgres"
)
func Main() { func Main() {
usage := func() { usage := func() {
fmt.Fprintf(os.Stderr, ` fmt.Fprintf(os.Stderr, `
%s DRIVER DEFS_PATH OUTPUT_PATH %s DRIVER DEFS_PATH OUTPUT_PATH
Drivers are one of: sqlite, postgres Drivers are one of: %s %s
`, `,
os.Args[0]) os.Args[0], DRIVER_SQLITE, DRIVER_POSTGRES)
os.Exit(1) os.Exit(1)
} }
@ -21,21 +26,20 @@ Drivers are one of: sqlite, postgres
} }
var ( var (
template string
driver = os.Args[1] driver = os.Args[1]
defsPath = os.Args[2] defsPath = os.Args[2]
outputPath = os.Args[3] outputPath = os.Args[3]
) )
switch driver { switch driver {
case "sqlite": case DRIVER_SQLITE, DRIVER_POSTGRES:
template = sqliteTemplate // OK
default: default:
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver) fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
usage() usage()
} }
err := render(template, defsPath, outputPath) err := render(driver, defsPath, outputPath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v", err) fmt.Fprintf(os.Stderr, "Error: %v", err)
os.Exit(1) os.Exit(1)

View File

@ -6,15 +6,15 @@ import (
"strings" "strings"
) )
func parsePath(filePath string) (*schema, error) { func parsePath(driver, filePath string) (*schema, error) {
fileBytes, err := os.ReadFile(filePath) fileBytes, err := os.ReadFile(filePath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return parseBytes(fileBytes) return parseBytes(driver, fileBytes)
} }
func parseBytes(fileBytes []byte) (*schema, error) { func parseBytes(driver string, fileBytes []byte) (*schema, error) {
s := string(fileBytes) s := string(fileBytes)
for _, c := range []string{",", "(", ")", ";"} { for _, c := range []string{",", "(", ")", ";"} {
s = strings.ReplaceAll(s, c, " "+c+" ") s = strings.ReplaceAll(s, c, " "+c+" ")
@ -29,7 +29,7 @@ func parseBytes(fileBytes []byte) (*schema, error) {
for len(tokens) > 0 { for len(tokens) > 0 {
switch tokens[0] { switch tokens[0] {
case "TABLE": case "TABLE":
tokens, err = parseTable(schema, tokens) tokens, err = parseTable(driver, schema, tokens)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -42,7 +42,7 @@ func parseBytes(fileBytes []byte) (*schema, error) {
return schema, nil return schema, nil
} }
func parseTable(schema *schema, tokens []string) ([]string, error) { func parseTable(driver string, schema *schema, tokens []string) ([]string, error) {
tokens = tokens[1:] tokens = tokens[1:]
if len(tokens) < 3 { if len(tokens) < 3 {
return tokens, errors.New("incomplete table definition") return tokens, errors.New("incomplete table definition")
@ -52,8 +52,9 @@ func parseTable(schema *schema, tokens []string) ([]string, error) {
} }
table := &table{ table := &table{
Name: tokens[0], driver: driver,
Type: tokens[2], Name: tokens[0],
Type: tokens[2],
} }
schema.Tables = append(schema.Tables, table) schema.Tables = append(schema.Tables, table)

View File

@ -22,7 +22,7 @@ func TestParse(t *testing.T) {
for _, defPath := range paths { for _, defPath := range paths {
t.Run(filepath.Base(defPath), func(t *testing.T) { t.Run(filepath.Base(defPath), func(t *testing.T) {
parsed, err := parsePath(defPath) parsed, err := parsePath("", defPath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -10,6 +10,7 @@ type schema struct {
} }
type table struct { type table struct {
driver string //
Name string // Name in SQL Name string // Name in SQL
Type string // Go type Type string // Go type
NoInsert bool NoInsert bool
@ -29,6 +30,19 @@ type column struct {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
func (t *table) translateQuery(in string) string {
if t.driver == DRIVER_SQLITE {
return in
}
i := 1
for strings.Contains(in, "?") {
in = strings.Replace(in, "?", fmt.Sprintf("$%d", i), 1)
i++
}
return in
}
func (t *table) colSQLNames() []string { func (t *table) colSQLNames() []string {
names := make([]string, len(t.Columns)) names := make([]string, len(t.Columns))
for i := range names { for i := range names {
@ -72,7 +86,7 @@ func (t *table) InsertQuery() string {
b.WriteString(`?`) b.WriteString(`?`)
} }
b.WriteString(`)`) b.WriteString(`)`)
return b.String() return t.translateQuery(b.String())
} }
func (t *table) InsertArgs() string { func (t *table) InsertArgs() string {
@ -116,7 +130,7 @@ func (t *table) UpdateQuery() string {
b.WriteString(` ` + c.SqlName + `=?`) b.WriteString(` ` + c.SqlName + `=?`)
} }
return b.String() return t.translateQuery(b.String())
} }
func (t *table) UpdateArgs() string { func (t *table) UpdateArgs() string {
@ -166,7 +180,7 @@ func (t *table) UpdateFullQuery() string {
b.WriteString(` ` + c.SqlName + `=?`) b.WriteString(` ` + c.SqlName + `=?`)
} }
return b.String() return t.translateQuery(b.String())
} }
func (t *table) UpdateFullArgs() string { func (t *table) UpdateFullArgs() string {
@ -222,7 +236,7 @@ func (t *table) DeleteQuery() string {
b.WriteString(col.SqlName) b.WriteString(col.SqlName)
b.WriteString(`=?`) b.WriteString(`=?`)
} }
return b.String() return t.translateQuery(b.String())
} }
func (t *table) DeleteArgs() string { func (t *table) DeleteArgs() string {
@ -248,7 +262,7 @@ func (t *table) GetQuery() string {
} }
b.WriteString(col.SqlName + `=?`) b.WriteString(col.SqlName + `=?`)
} }
return b.String() return t.translateQuery(b.String())
} }
func (t *table) ScanArgs() string { func (t *table) ScanArgs() string {

View File

@ -8,16 +8,16 @@ import (
"text/template" "text/template"
) )
//go:embed sqlite.go.tmpl //go:embed gen.go.tmpl
var sqliteTemplate string var fileTemplate string
func render(templateStr, schemaPath, outputPath string) error { func render(driver, schemaPath, outputPath string) error {
sch, err := parsePath(schemaPath) sch, err := parsePath(driver, schemaPath)
if err != nil { if err != nil {
return err return err
} }
tmpl := template.Must(template.New("").Parse(templateStr)) tmpl := template.Must(template.New("").Parse(fileTemplate))
fOut, err := os.Create(outputPath) fOut, err := os.Create(outputPath)
if err != nil { if err != nil {
return err return err