This commit is contained in:
jdl 2024-11-19 16:30:42 +01:00
parent 0d8cc762c0
commit 9070d8cfc0
6 changed files with 43 additions and 24 deletions

View File

@ -5,14 +5,19 @@ import (
"os"
)
const (
DRIVER_SQLITE = "sqlite"
DRIVER_POSTGRES = "postgres"
)
func Main() {
usage := func() {
fmt.Fprintf(os.Stderr, `
%s DRIVER DEFS_PATH OUTPUT_PATH
Drivers are one of: sqlite, postgres
Drivers are one of: %s %s
`,
os.Args[0])
os.Args[0], DRIVER_SQLITE, DRIVER_POSTGRES)
os.Exit(1)
}
@ -21,21 +26,20 @@ Drivers are one of: sqlite, postgres
}
var (
template string
driver = os.Args[1]
defsPath = os.Args[2]
outputPath = os.Args[3]
)
switch driver {
case "sqlite":
template = sqliteTemplate
case DRIVER_SQLITE, DRIVER_POSTGRES:
// OK
default:
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
usage()
}
err := render(template, defsPath, outputPath)
err := render(driver, defsPath, outputPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v", err)
os.Exit(1)

View File

@ -6,15 +6,15 @@ import (
"strings"
)
func parsePath(filePath string) (*schema, error) {
func parsePath(driver, filePath string) (*schema, error) {
fileBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
return parseBytes(fileBytes)
return parseBytes(driver, fileBytes)
}
func parseBytes(fileBytes []byte) (*schema, error) {
func parseBytes(driver string, fileBytes []byte) (*schema, error) {
s := string(fileBytes)
for _, c := range []string{",", "(", ")", ";"} {
s = strings.ReplaceAll(s, c, " "+c+" ")
@ -29,7 +29,7 @@ func parseBytes(fileBytes []byte) (*schema, error) {
for len(tokens) > 0 {
switch tokens[0] {
case "TABLE":
tokens, err = parseTable(schema, tokens)
tokens, err = parseTable(driver, schema, tokens)
if err != nil {
return nil, err
}
@ -42,7 +42,7 @@ func parseBytes(fileBytes []byte) (*schema, error) {
return schema, nil
}
func parseTable(schema *schema, tokens []string) ([]string, error) {
func parseTable(driver string, schema *schema, tokens []string) ([]string, error) {
tokens = tokens[1:]
if len(tokens) < 3 {
return tokens, errors.New("incomplete table definition")
@ -52,8 +52,9 @@ func parseTable(schema *schema, tokens []string) ([]string, error) {
}
table := &table{
Name: tokens[0],
Type: tokens[2],
driver: driver,
Name: tokens[0],
Type: tokens[2],
}
schema.Tables = append(schema.Tables, table)

View File

@ -22,7 +22,7 @@ func TestParse(t *testing.T) {
for _, defPath := range paths {
t.Run(filepath.Base(defPath), func(t *testing.T) {
parsed, err := parsePath(defPath)
parsed, err := parsePath("", defPath)
if err != nil {
t.Fatal(err)
}

View File

@ -10,6 +10,7 @@ type schema struct {
}
type table struct {
driver string //
Name string // Name in SQL
Type string // Go type
NoInsert bool
@ -29,6 +30,19 @@ type column struct {
// ----------------------------------------------------------------------------
func (t *table) translateQuery(in string) string {
if t.driver == DRIVER_SQLITE {
return in
}
i := 1
for strings.Contains(in, "?") {
in = strings.Replace(in, "?", fmt.Sprintf("$%d", i), 1)
i++
}
return in
}
func (t *table) colSQLNames() []string {
names := make([]string, len(t.Columns))
for i := range names {
@ -72,7 +86,7 @@ func (t *table) InsertQuery() string {
b.WriteString(`?`)
}
b.WriteString(`)`)
return b.String()
return t.translateQuery(b.String())
}
func (t *table) InsertArgs() string {
@ -116,7 +130,7 @@ func (t *table) UpdateQuery() string {
b.WriteString(` ` + c.SqlName + `=?`)
}
return b.String()
return t.translateQuery(b.String())
}
func (t *table) UpdateArgs() string {
@ -166,7 +180,7 @@ func (t *table) UpdateFullQuery() string {
b.WriteString(` ` + c.SqlName + `=?`)
}
return b.String()
return t.translateQuery(b.String())
}
func (t *table) UpdateFullArgs() string {
@ -222,7 +236,7 @@ func (t *table) DeleteQuery() string {
b.WriteString(col.SqlName)
b.WriteString(`=?`)
}
return b.String()
return t.translateQuery(b.String())
}
func (t *table) DeleteArgs() string {
@ -248,7 +262,7 @@ func (t *table) GetQuery() string {
}
b.WriteString(col.SqlName + `=?`)
}
return b.String()
return t.translateQuery(b.String())
}
func (t *table) ScanArgs() string {

View File

@ -8,16 +8,16 @@ import (
"text/template"
)
//go:embed sqlite.go.tmpl
var sqliteTemplate string
//go:embed gen.go.tmpl
var fileTemplate string
func render(templateStr, schemaPath, outputPath string) error {
sch, err := parsePath(schemaPath)
func render(driver, schemaPath, outputPath string) error {
sch, err := parsePath(driver, schemaPath)
if err != nil {
return err
}
tmpl := template.Must(template.New("").Parse(templateStr))
tmpl := template.Must(template.New("").Parse(fileTemplate))
fOut, err := os.Create(outputPath)
if err != nil {
return err