wip
This commit is contained in:
		| @@ -5,14 +5,19 @@ import ( | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	DRIVER_SQLITE   = "sqlite" | ||||
| 	DRIVER_POSTGRES = "postgres" | ||||
| ) | ||||
|  | ||||
| func Main() { | ||||
| 	usage := func() { | ||||
| 		fmt.Fprintf(os.Stderr, ` | ||||
| %s DRIVER DEFS_PATH OUTPUT_PATH | ||||
|  | ||||
| Drivers are one of: sqlite, postgres | ||||
| Drivers are one of: %s %s | ||||
| `, | ||||
| 			os.Args[0]) | ||||
| 			os.Args[0], DRIVER_SQLITE, DRIVER_POSTGRES) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| @@ -21,21 +26,20 @@ Drivers are one of: sqlite, postgres | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		template   string | ||||
| 		driver     = os.Args[1] | ||||
| 		defsPath   = os.Args[2] | ||||
| 		outputPath = os.Args[3] | ||||
| 	) | ||||
|  | ||||
| 	switch driver { | ||||
| 	case "sqlite": | ||||
| 		template = sqliteTemplate | ||||
| 	case DRIVER_SQLITE, DRIVER_POSTGRES: | ||||
| 		// OK | ||||
| 	default: | ||||
| 		fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver) | ||||
| 		usage() | ||||
| 	} | ||||
|  | ||||
| 	err := render(template, defsPath, outputPath) | ||||
| 	err := render(driver, defsPath, outputPath) | ||||
| 	if err != nil { | ||||
| 		fmt.Fprintf(os.Stderr, "Error: %v", err) | ||||
| 		os.Exit(1) | ||||
|   | ||||
| @@ -6,15 +6,15 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func parsePath(filePath string) (*schema, error) { | ||||
| func parsePath(driver, filePath string) (*schema, error) { | ||||
| 	fileBytes, err := os.ReadFile(filePath) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return parseBytes(fileBytes) | ||||
| 	return parseBytes(driver, fileBytes) | ||||
| } | ||||
|  | ||||
| func parseBytes(fileBytes []byte) (*schema, error) { | ||||
| func parseBytes(driver string, fileBytes []byte) (*schema, error) { | ||||
| 	s := string(fileBytes) | ||||
| 	for _, c := range []string{",", "(", ")", ";"} { | ||||
| 		s = strings.ReplaceAll(s, c, " "+c+" ") | ||||
| @@ -29,7 +29,7 @@ func parseBytes(fileBytes []byte) (*schema, error) { | ||||
| 	for len(tokens) > 0 { | ||||
| 		switch tokens[0] { | ||||
| 		case "TABLE": | ||||
| 			tokens, err = parseTable(schema, tokens) | ||||
| 			tokens, err = parseTable(driver, schema, tokens) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| @@ -42,7 +42,7 @@ func parseBytes(fileBytes []byte) (*schema, error) { | ||||
| 	return schema, nil | ||||
| } | ||||
|  | ||||
| func parseTable(schema *schema, tokens []string) ([]string, error) { | ||||
| func parseTable(driver string, schema *schema, tokens []string) ([]string, error) { | ||||
| 	tokens = tokens[1:] | ||||
| 	if len(tokens) < 3 { | ||||
| 		return tokens, errors.New("incomplete table definition") | ||||
| @@ -52,8 +52,9 @@ func parseTable(schema *schema, tokens []string) ([]string, error) { | ||||
| 	} | ||||
|  | ||||
| 	table := &table{ | ||||
| 		Name: tokens[0], | ||||
| 		Type: tokens[2], | ||||
| 		driver: driver, | ||||
| 		Name:   tokens[0], | ||||
| 		Type:   tokens[2], | ||||
| 	} | ||||
| 	schema.Tables = append(schema.Tables, table) | ||||
|  | ||||
|   | ||||
| @@ -22,7 +22,7 @@ func TestParse(t *testing.T) { | ||||
|  | ||||
| 	for _, defPath := range paths { | ||||
| 		t.Run(filepath.Base(defPath), func(t *testing.T) { | ||||
| 			parsed, err := parsePath(defPath) | ||||
| 			parsed, err := parsePath("", defPath) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
|   | ||||
| @@ -10,6 +10,7 @@ type schema struct { | ||||
| } | ||||
|  | ||||
| type table struct { | ||||
| 	driver   string // | ||||
| 	Name     string // Name in SQL | ||||
| 	Type     string // Go type | ||||
| 	NoInsert bool | ||||
| @@ -29,6 +30,19 @@ type column struct { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (t *table) translateQuery(in string) string { | ||||
| 	if t.driver == DRIVER_SQLITE { | ||||
| 		return in | ||||
| 	} | ||||
|  | ||||
| 	i := 1 | ||||
| 	for strings.Contains(in, "?") { | ||||
| 		in = strings.Replace(in, "?", fmt.Sprintf("$%d", i), 1) | ||||
| 		i++ | ||||
| 	} | ||||
| 	return in | ||||
| } | ||||
|  | ||||
| func (t *table) colSQLNames() []string { | ||||
| 	names := make([]string, len(t.Columns)) | ||||
| 	for i := range names { | ||||
| @@ -72,7 +86,7 @@ func (t *table) InsertQuery() string { | ||||
| 		b.WriteString(`?`) | ||||
| 	} | ||||
| 	b.WriteString(`)`) | ||||
| 	return b.String() | ||||
| 	return t.translateQuery(b.String()) | ||||
| } | ||||
|  | ||||
| func (t *table) InsertArgs() string { | ||||
| @@ -116,7 +130,7 @@ func (t *table) UpdateQuery() string { | ||||
| 		b.WriteString(` ` + c.SqlName + `=?`) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| 	return t.translateQuery(b.String()) | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateArgs() string { | ||||
| @@ -166,7 +180,7 @@ func (t *table) UpdateFullQuery() string { | ||||
| 		b.WriteString(` ` + c.SqlName + `=?`) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| 	return t.translateQuery(b.String()) | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateFullArgs() string { | ||||
| @@ -222,7 +236,7 @@ func (t *table) DeleteQuery() string { | ||||
| 		b.WriteString(col.SqlName) | ||||
| 		b.WriteString(`=?`) | ||||
| 	} | ||||
| 	return b.String() | ||||
| 	return t.translateQuery(b.String()) | ||||
| } | ||||
|  | ||||
| func (t *table) DeleteArgs() string { | ||||
| @@ -248,7 +262,7 @@ func (t *table) GetQuery() string { | ||||
| 		} | ||||
| 		b.WriteString(col.SqlName + `=?`) | ||||
| 	} | ||||
| 	return b.String() | ||||
| 	return t.translateQuery(b.String()) | ||||
| } | ||||
|  | ||||
| func (t *table) ScanArgs() string { | ||||
|   | ||||
| @@ -8,16 +8,16 @@ import ( | ||||
| 	"text/template" | ||||
| ) | ||||
|  | ||||
| //go:embed sqlite.go.tmpl | ||||
| var sqliteTemplate string | ||||
| //go:embed gen.go.tmpl | ||||
| var fileTemplate string | ||||
|  | ||||
| func render(templateStr, schemaPath, outputPath string) error { | ||||
| 	sch, err := parsePath(schemaPath) | ||||
| func render(driver, schemaPath, outputPath string) error { | ||||
| 	sch, err := parsePath(driver, schemaPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	tmpl := template.Must(template.New("").Parse(templateStr)) | ||||
| 	tmpl := template.Must(template.New("").Parse(fileTemplate)) | ||||
| 	fOut, err := os.Create(outputPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
|   | ||||
		Reference in New Issue
	
	Block a user