diff --git a/sqlgen/sqlite.go.tmpl b/sqlgen/sqlite.go.tmpl index d7db01c..bffc078 100644 --- a/sqlgen/sqlite.go.tmpl +++ b/sqlgen/sqlite.go.tmpl @@ -5,6 +5,24 @@ import ( "iter" ) +var ( + ErrConstraint = errors.New("constraint violation") + ErrNotFound = errors.New("not found") +) + +func translateError(err error) error { + if err == nil { + return nil + } + if e, ok := err.(sqlite3.Error); ok && e.Code == 19 { + return ErrConstraint + } + if errors.Is(err, sql.ErrNoRows) { + return ErrNotFound + } + return err +} + type TX interface { Exec(query string, args ...any) (sql.Result, error) Query(query string, args ...any) (*sql.Rows, error) @@ -36,7 +54,7 @@ func {{.Type}}_Insert( } _, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) - return err + return translateError(err) } {{- end}} {{/* if not .NoInsert */}} @@ -48,27 +66,29 @@ func {{.Type}}_Insert( func {{.Type}}_Update( tx TX, row *{{.Type}}, -) (found bool, err error) { +) (err error) { {{.Type}}_Sanitize(row) if err = {{.Type}}_Validate(row); err != nil { - return false, err + return err } result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) if err != nil { - return false, err + return translateError(err) } n, err := result.RowsAffected() if err != nil { panic(err) } - - if n > 1 { + switch n { + case 0: + return ErrNotFound + case 1: + return nil + default: panic("multiple rows updated") } - - return n != 0, nil } {{- end}} @@ -77,27 +97,29 @@ func {{.Type}}_Update( func {{.Type}}_UpdateFull( tx TX, row *{{.Type}}, -) (found bool, err error) { +) (err error) { {{.Type}}_Sanitize(row) if err = {{.Type}}_Validate(row); err != nil { - return false, err + return err } result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) if err != nil { - return false, err + return translateError(err) } n, err := result.RowsAffected() if err != nil { panic(err) } - - if n > 1 { + switch n { + case 0: + return ErrNotFound + case 1: + return nil + default: panic("multiple rows updated") } - - return n != 0, nil } {{- end}} @@ -109,22 +131,24 @@ func {{.Type}}_UpdateFull( func {{.Type}}_Delete( tx TX, {{.PKFunctionArgs -}} -) (found bool, err error) { +) (err error) { result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) - if err != nil { - return false, err + if err != nil { + return translateError(err) } n, err := result.RowsAffected() if err != nil { panic(err) } - - if n > 1 { + switch n { + case 0: + return ErrNotFound + case 1: + return nil + default: panic("multiple rows deleted") } - - return n != 0, nil } {{- end}} @@ -138,7 +162,7 @@ func {{.Type}}_Get( ) { row = &{{.Type}}{} r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) - err = r.Scan({{.ScanArgs}}) + err = translateError(r.Scan({{.ScanArgs}})) return } @@ -153,7 +177,7 @@ func {{.Type}}_GetWhere( ) { row = &{{.Type}}{} r := tx.QueryRow(query, args...) - err = r.Scan({{.ScanArgs}}) + err = translateError(r.Scan({{.ScanArgs}})) return } @@ -177,7 +201,7 @@ func {{.Type}}_Iterate( defer rows.Close() for rows.Next() { row := &{{.Type}}{} - err := rows.Scan({{.ScanArgs}}) + err := translateError(rows.Scan({{.ScanArgs}})) if !yield(row, err) { return } diff --git a/sqliteutil/errors.go b/sqliteutil/errors.go index b0db407..9bdf060 100644 --- a/sqliteutil/errors.go +++ b/sqliteutil/errors.go @@ -2,7 +2,7 @@ package sqliteutil import "github.com/mattn/go-sqlite3" -func ErrIsDuplicate(err error) bool { +func ErrIsConstraint(err error) bool { e, ok := err.(sqlite3.Error) return ok && e.Code == 19 }