This commit is contained in:
jdl 2024-11-19 16:20:18 +01:00
parent 5577f90f95
commit 0d8cc762c0

View File

@ -2,30 +2,9 @@ package {{.PackageName}}
import ( import (
"database/sql" "database/sql"
"errors"
"iter" "iter"
"github.com/mattn/go-sqlite3"
) )
var (
ErrConstraint = errors.New("constraint violation")
ErrNotFound = errors.New("not found")
)
func translateError(err error) error {
if err == nil {
return nil
}
if e, ok := err.(sqlite3.Error); ok && e.Code == 19 {
return errors.Join(ErrConstraint, err)
}
if errors.Is(err, sql.ErrNoRows) {
return errors.Join(ErrNotFound, err)
}
return err
}
type TX interface { type TX interface {
Exec(query string, args ...any) (sql.Result, error) Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error) Query(query string, args ...any) (*sql.Rows, error)
@ -57,7 +36,7 @@ func {{.Type}}_Insert(
} }
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) _, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
return translateError(err) return err
} }
{{- end}} {{/* if not .NoInsert */}} {{- end}} {{/* if not .NoInsert */}}
@ -77,7 +56,7 @@ func {{.Type}}_Update(
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
if err != nil { if err != nil {
return translateError(err) return err
} }
n, err := result.RowsAffected() n, err := result.RowsAffected()
@ -86,7 +65,7 @@ func {{.Type}}_Update(
} }
switch n { switch n {
case 0: case 0:
return ErrNotFound return sql.ErrNoRows
case 1: case 1:
return nil return nil
default: default:
@ -108,7 +87,7 @@ func {{.Type}}_UpdateFull(
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
if err != nil { if err != nil {
return translateError(err) return err
} }
n, err := result.RowsAffected() n, err := result.RowsAffected()
@ -117,7 +96,7 @@ func {{.Type}}_UpdateFull(
} }
switch n { switch n {
case 0: case 0:
return ErrNotFound return sql.ErrNoRows
case 1: case 1:
return nil return nil
default: default:
@ -137,7 +116,7 @@ func {{.Type}}_Delete(
) (err error) { ) (err error) {
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
if err != nil { if err != nil {
return translateError(err) return err
} }
n, err := result.RowsAffected() n, err := result.RowsAffected()
@ -146,7 +125,7 @@ func {{.Type}}_Delete(
} }
switch n { switch n {
case 0: case 0:
return ErrNotFound return sql.ErrNoRows
case 1: case 1:
return nil return nil
default: default:
@ -165,7 +144,7 @@ func {{.Type}}_Get(
) { ) {
row = &{{.Type}}{} row = &{{.Type}}{}
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
err = translateError(r.Scan({{.ScanArgs}})) err = r.Scan({{.ScanArgs}})
return return
} }
@ -180,7 +159,7 @@ func {{.Type}}_GetWhere(
) { ) {
row = &{{.Type}}{} row = &{{.Type}}{}
r := tx.QueryRow(query, args...) r := tx.QueryRow(query, args...)
err = translateError(r.Scan({{.ScanArgs}})) err = r.Scan({{.ScanArgs}})
return return
} }
@ -204,7 +183,7 @@ func {{.Type}}_Iterate(
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
row := &{{.Type}}{} row := &{{.Type}}{}
err := translateError(rows.Scan({{.ScanArgs}})) err := rows.Scan({{.ScanArgs}})
if !yield(row, err) { if !yield(row, err) {
return return
} }