diff --git a/sqlgen/sqlite.go.tmpl b/sqlgen/gen.go.tmpl similarity index 100% rename from sqlgen/sqlite.go.tmpl rename to sqlgen/gen.go.tmpl diff --git a/sqlgen/main.go b/sqlgen/main.go index 170629b..76f9885 100644 --- a/sqlgen/main.go +++ b/sqlgen/main.go @@ -5,14 +5,19 @@ import ( "os" ) +const ( + DRIVER_SQLITE = "sqlite" + DRIVER_POSTGRES = "postgres" +) + func Main() { usage := func() { fmt.Fprintf(os.Stderr, ` %s DRIVER DEFS_PATH OUTPUT_PATH -Drivers are one of: sqlite, postgres +Drivers are one of: %s %s `, - os.Args[0]) + os.Args[0], DRIVER_SQLITE, DRIVER_POSTGRES) os.Exit(1) } @@ -21,21 +26,20 @@ Drivers are one of: sqlite, postgres } var ( - template string driver = os.Args[1] defsPath = os.Args[2] outputPath = os.Args[3] ) switch driver { - case "sqlite": - template = sqliteTemplate + case DRIVER_SQLITE, DRIVER_POSTGRES: + // OK default: fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver) usage() } - err := render(template, defsPath, outputPath) + err := render(driver, defsPath, outputPath) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v", err) os.Exit(1) diff --git a/sqlgen/parse.go b/sqlgen/parse.go index b1e9c64..1d974f6 100644 --- a/sqlgen/parse.go +++ b/sqlgen/parse.go @@ -6,15 +6,15 @@ import ( "strings" ) -func parsePath(filePath string) (*schema, error) { +func parsePath(driver, filePath string) (*schema, error) { fileBytes, err := os.ReadFile(filePath) if err != nil { return nil, err } - return parseBytes(fileBytes) + return parseBytes(driver, fileBytes) } -func parseBytes(fileBytes []byte) (*schema, error) { +func parseBytes(driver string, fileBytes []byte) (*schema, error) { s := string(fileBytes) for _, c := range []string{",", "(", ")", ";"} { s = strings.ReplaceAll(s, c, " "+c+" ") @@ -29,7 +29,7 @@ func parseBytes(fileBytes []byte) (*schema, error) { for len(tokens) > 0 { switch tokens[0] { case "TABLE": - tokens, err = parseTable(schema, tokens) + tokens, err = parseTable(driver, schema, tokens) if err != nil { return nil, err } @@ -42,7 +42,7 @@ func parseBytes(fileBytes []byte) (*schema, error) { return schema, nil } -func parseTable(schema *schema, tokens []string) ([]string, error) { +func parseTable(driver string, schema *schema, tokens []string) ([]string, error) { tokens = tokens[1:] if len(tokens) < 3 { return tokens, errors.New("incomplete table definition") @@ -52,8 +52,9 @@ func parseTable(schema *schema, tokens []string) ([]string, error) { } table := &table{ - Name: tokens[0], - Type: tokens[2], + driver: driver, + Name: tokens[0], + Type: tokens[2], } schema.Tables = append(schema.Tables, table) diff --git a/sqlgen/parse_test.go b/sqlgen/parse_test.go index 9838b69..0e08c22 100644 --- a/sqlgen/parse_test.go +++ b/sqlgen/parse_test.go @@ -22,7 +22,7 @@ func TestParse(t *testing.T) { for _, defPath := range paths { t.Run(filepath.Base(defPath), func(t *testing.T) { - parsed, err := parsePath(defPath) + parsed, err := parsePath("", defPath) if err != nil { t.Fatal(err) } diff --git a/sqlgen/schema.go b/sqlgen/schema.go index a75dc7b..ba8516f 100644 --- a/sqlgen/schema.go +++ b/sqlgen/schema.go @@ -10,6 +10,7 @@ type schema struct { } type table struct { + driver string // Name string // Name in SQL Type string // Go type NoInsert bool @@ -29,6 +30,19 @@ type column struct { // ---------------------------------------------------------------------------- +func (t *table) translateQuery(in string) string { + if t.driver == DRIVER_SQLITE { + return in + } + + i := 1 + for strings.Contains(in, "?") { + in = strings.Replace(in, "?", fmt.Sprintf("$%d", i), 1) + i++ + } + return in +} + func (t *table) colSQLNames() []string { names := make([]string, len(t.Columns)) for i := range names { @@ -72,7 +86,7 @@ func (t *table) InsertQuery() string { b.WriteString(`?`) } b.WriteString(`)`) - return b.String() + return t.translateQuery(b.String()) } func (t *table) InsertArgs() string { @@ -116,7 +130,7 @@ func (t *table) UpdateQuery() string { b.WriteString(` ` + c.SqlName + `=?`) } - return b.String() + return t.translateQuery(b.String()) } func (t *table) UpdateArgs() string { @@ -166,7 +180,7 @@ func (t *table) UpdateFullQuery() string { b.WriteString(` ` + c.SqlName + `=?`) } - return b.String() + return t.translateQuery(b.String()) } func (t *table) UpdateFullArgs() string { @@ -222,7 +236,7 @@ func (t *table) DeleteQuery() string { b.WriteString(col.SqlName) b.WriteString(`=?`) } - return b.String() + return t.translateQuery(b.String()) } func (t *table) DeleteArgs() string { @@ -248,7 +262,7 @@ func (t *table) GetQuery() string { } b.WriteString(col.SqlName + `=?`) } - return b.String() + return t.translateQuery(b.String()) } func (t *table) ScanArgs() string { diff --git a/sqlgen/template.go b/sqlgen/template.go index 5372c84..70069b2 100644 --- a/sqlgen/template.go +++ b/sqlgen/template.go @@ -8,16 +8,16 @@ import ( "text/template" ) -//go:embed sqlite.go.tmpl -var sqliteTemplate string +//go:embed gen.go.tmpl +var fileTemplate string -func render(templateStr, schemaPath, outputPath string) error { - sch, err := parsePath(schemaPath) +func render(driver, schemaPath, outputPath string) error { + sch, err := parsePath(driver, schemaPath) if err != nil { return err } - tmpl := template.Must(template.New("").Parse(templateStr)) + tmpl := template.Must(template.New("").Parse(fileTemplate)) fOut, err := os.Create(outputPath) if err != nil { return err