This commit is contained in:
jdl 2024-11-19 16:20:18 +01:00
parent 5577f90f95
commit 0d8cc762c0

View File

@ -2,30 +2,9 @@ package {{.PackageName}}
import (
"database/sql"
"errors"
"iter"
"github.com/mattn/go-sqlite3"
)
var (
ErrConstraint = errors.New("constraint violation")
ErrNotFound = errors.New("not found")
)
func translateError(err error) error {
if err == nil {
return nil
}
if e, ok := err.(sqlite3.Error); ok && e.Code == 19 {
return errors.Join(ErrConstraint, err)
}
if errors.Is(err, sql.ErrNoRows) {
return errors.Join(ErrNotFound, err)
}
return err
}
type TX interface {
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
@ -57,7 +36,7 @@ func {{.Type}}_Insert(
}
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
return translateError(err)
return err
}
{{- end}} {{/* if not .NoInsert */}}
@ -77,7 +56,7 @@ func {{.Type}}_Update(
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
if err != nil {
return translateError(err)
return err
}
n, err := result.RowsAffected()
@ -86,7 +65,7 @@ func {{.Type}}_Update(
}
switch n {
case 0:
return ErrNotFound
return sql.ErrNoRows
case 1:
return nil
default:
@ -108,7 +87,7 @@ func {{.Type}}_UpdateFull(
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
if err != nil {
return translateError(err)
return err
}
n, err := result.RowsAffected()
@ -117,7 +96,7 @@ func {{.Type}}_UpdateFull(
}
switch n {
case 0:
return ErrNotFound
return sql.ErrNoRows
case 1:
return nil
default:
@ -137,7 +116,7 @@ func {{.Type}}_Delete(
) (err error) {
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
if err != nil {
return translateError(err)
return err
}
n, err := result.RowsAffected()
@ -146,7 +125,7 @@ func {{.Type}}_Delete(
}
switch n {
case 0:
return ErrNotFound
return sql.ErrNoRows
case 1:
return nil
default:
@ -165,7 +144,7 @@ func {{.Type}}_Get(
) {
row = &{{.Type}}{}
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
err = translateError(r.Scan({{.ScanArgs}}))
err = r.Scan({{.ScanArgs}})
return
}
@ -180,7 +159,7 @@ func {{.Type}}_GetWhere(
) {
row = &{{.Type}}{}
r := tx.QueryRow(query, args...)
err = translateError(r.Scan({{.ScanArgs}}))
err = r.Scan({{.ScanArgs}})
return
}
@ -204,7 +183,7 @@ func {{.Type}}_Iterate(
defer rows.Close()
for rows.Next() {
row := &{{.Type}}{}
err := translateError(rows.Scan({{.ScanArgs}}))
err := rows.Scan({{.ScanArgs}})
if !yield(row, err) {
return
}