diff --git a/sqlgen/sqlite.go.tmpl b/sqlgen/sqlite.go.tmpl index dd8b9c2..cef1f29 100644 --- a/sqlgen/sqlite.go.tmpl +++ b/sqlgen/sqlite.go.tmpl @@ -2,30 +2,9 @@ package {{.PackageName}} import ( "database/sql" - "errors" "iter" - - "github.com/mattn/go-sqlite3" ) -var ( - ErrConstraint = errors.New("constraint violation") - ErrNotFound = errors.New("not found") -) - -func translateError(err error) error { - if err == nil { - return nil - } - if e, ok := err.(sqlite3.Error); ok && e.Code == 19 { - return errors.Join(ErrConstraint, err) - } - if errors.Is(err, sql.ErrNoRows) { - return errors.Join(ErrNotFound, err) - } - return err -} - type TX interface { Exec(query string, args ...any) (sql.Result, error) Query(query string, args ...any) (*sql.Rows, error) @@ -57,7 +36,7 @@ func {{.Type}}_Insert( } _, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) - return translateError(err) + return err } {{- end}} {{/* if not .NoInsert */}} @@ -77,7 +56,7 @@ func {{.Type}}_Update( result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) if err != nil { - return translateError(err) + return err } n, err := result.RowsAffected() @@ -86,7 +65,7 @@ func {{.Type}}_Update( } switch n { case 0: - return ErrNotFound + return sql.ErrNoRows case 1: return nil default: @@ -108,7 +87,7 @@ func {{.Type}}_UpdateFull( result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) if err != nil { - return translateError(err) + return err } n, err := result.RowsAffected() @@ -117,7 +96,7 @@ func {{.Type}}_UpdateFull( } switch n { case 0: - return ErrNotFound + return sql.ErrNoRows case 1: return nil default: @@ -137,7 +116,7 @@ func {{.Type}}_Delete( ) (err error) { result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) if err != nil { - return translateError(err) + return err } n, err := result.RowsAffected() @@ -146,7 +125,7 @@ func {{.Type}}_Delete( } switch n { case 0: - return ErrNotFound + return sql.ErrNoRows case 1: return nil default: @@ -165,7 +144,7 @@ func {{.Type}}_Get( ) { row = &{{.Type}}{} r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) - err = translateError(r.Scan({{.ScanArgs}})) + err = r.Scan({{.ScanArgs}}) return } @@ -180,7 +159,7 @@ func {{.Type}}_GetWhere( ) { row = &{{.Type}}{} r := tx.QueryRow(query, args...) - err = translateError(r.Scan({{.ScanArgs}})) + err = r.Scan({{.ScanArgs}}) return } @@ -204,7 +183,7 @@ func {{.Type}}_Iterate( defer rows.Close() for rows.Next() { row := &{{.Type}}{} - err := translateError(rows.Scan({{.ScanArgs}})) + err := rows.Scan({{.ScanArgs}}) if !yield(row, err) { return }