wip
This commit is contained in:
parent
0d8cc762c0
commit
9070d8cfc0
@ -5,14 +5,19 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DRIVER_SQLITE = "sqlite"
|
||||||
|
DRIVER_POSTGRES = "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
func Main() {
|
func Main() {
|
||||||
usage := func() {
|
usage := func() {
|
||||||
fmt.Fprintf(os.Stderr, `
|
fmt.Fprintf(os.Stderr, `
|
||||||
%s DRIVER DEFS_PATH OUTPUT_PATH
|
%s DRIVER DEFS_PATH OUTPUT_PATH
|
||||||
|
|
||||||
Drivers are one of: sqlite, postgres
|
Drivers are one of: %s %s
|
||||||
`,
|
`,
|
||||||
os.Args[0])
|
os.Args[0], DRIVER_SQLITE, DRIVER_POSTGRES)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,21 +26,20 @@ Drivers are one of: sqlite, postgres
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
template string
|
|
||||||
driver = os.Args[1]
|
driver = os.Args[1]
|
||||||
defsPath = os.Args[2]
|
defsPath = os.Args[2]
|
||||||
outputPath = os.Args[3]
|
outputPath = os.Args[3]
|
||||||
)
|
)
|
||||||
|
|
||||||
switch driver {
|
switch driver {
|
||||||
case "sqlite":
|
case DRIVER_SQLITE, DRIVER_POSTGRES:
|
||||||
template = sqliteTemplate
|
// OK
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
|
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
|
||||||
usage()
|
usage()
|
||||||
}
|
}
|
||||||
|
|
||||||
err := render(template, defsPath, outputPath)
|
err := render(driver, defsPath, outputPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v", err)
|
fmt.Fprintf(os.Stderr, "Error: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
@ -6,15 +6,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parsePath(filePath string) (*schema, error) {
|
func parsePath(driver, filePath string) (*schema, error) {
|
||||||
fileBytes, err := os.ReadFile(filePath)
|
fileBytes, err := os.ReadFile(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return parseBytes(fileBytes)
|
return parseBytes(driver, fileBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseBytes(fileBytes []byte) (*schema, error) {
|
func parseBytes(driver string, fileBytes []byte) (*schema, error) {
|
||||||
s := string(fileBytes)
|
s := string(fileBytes)
|
||||||
for _, c := range []string{",", "(", ")", ";"} {
|
for _, c := range []string{",", "(", ")", ";"} {
|
||||||
s = strings.ReplaceAll(s, c, " "+c+" ")
|
s = strings.ReplaceAll(s, c, " "+c+" ")
|
||||||
@ -29,7 +29,7 @@ func parseBytes(fileBytes []byte) (*schema, error) {
|
|||||||
for len(tokens) > 0 {
|
for len(tokens) > 0 {
|
||||||
switch tokens[0] {
|
switch tokens[0] {
|
||||||
case "TABLE":
|
case "TABLE":
|
||||||
tokens, err = parseTable(schema, tokens)
|
tokens, err = parseTable(driver, schema, tokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -42,7 +42,7 @@ func parseBytes(fileBytes []byte) (*schema, error) {
|
|||||||
return schema, nil
|
return schema, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTable(schema *schema, tokens []string) ([]string, error) {
|
func parseTable(driver string, schema *schema, tokens []string) ([]string, error) {
|
||||||
tokens = tokens[1:]
|
tokens = tokens[1:]
|
||||||
if len(tokens) < 3 {
|
if len(tokens) < 3 {
|
||||||
return tokens, errors.New("incomplete table definition")
|
return tokens, errors.New("incomplete table definition")
|
||||||
@ -52,6 +52,7 @@ func parseTable(schema *schema, tokens []string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
table := &table{
|
table := &table{
|
||||||
|
driver: driver,
|
||||||
Name: tokens[0],
|
Name: tokens[0],
|
||||||
Type: tokens[2],
|
Type: tokens[2],
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ func TestParse(t *testing.T) {
|
|||||||
|
|
||||||
for _, defPath := range paths {
|
for _, defPath := range paths {
|
||||||
t.Run(filepath.Base(defPath), func(t *testing.T) {
|
t.Run(filepath.Base(defPath), func(t *testing.T) {
|
||||||
parsed, err := parsePath(defPath)
|
parsed, err := parsePath("", defPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ type schema struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type table struct {
|
type table struct {
|
||||||
|
driver string //
|
||||||
Name string // Name in SQL
|
Name string // Name in SQL
|
||||||
Type string // Go type
|
Type string // Go type
|
||||||
NoInsert bool
|
NoInsert bool
|
||||||
@ -29,6 +30,19 @@ type column struct {
|
|||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func (t *table) translateQuery(in string) string {
|
||||||
|
if t.driver == DRIVER_SQLITE {
|
||||||
|
return in
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 1
|
||||||
|
for strings.Contains(in, "?") {
|
||||||
|
in = strings.Replace(in, "?", fmt.Sprintf("$%d", i), 1)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return in
|
||||||
|
}
|
||||||
|
|
||||||
func (t *table) colSQLNames() []string {
|
func (t *table) colSQLNames() []string {
|
||||||
names := make([]string, len(t.Columns))
|
names := make([]string, len(t.Columns))
|
||||||
for i := range names {
|
for i := range names {
|
||||||
@ -72,7 +86,7 @@ func (t *table) InsertQuery() string {
|
|||||||
b.WriteString(`?`)
|
b.WriteString(`?`)
|
||||||
}
|
}
|
||||||
b.WriteString(`)`)
|
b.WriteString(`)`)
|
||||||
return b.String()
|
return t.translateQuery(b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) InsertArgs() string {
|
func (t *table) InsertArgs() string {
|
||||||
@ -116,7 +130,7 @@ func (t *table) UpdateQuery() string {
|
|||||||
b.WriteString(` ` + c.SqlName + `=?`)
|
b.WriteString(` ` + c.SqlName + `=?`)
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.String()
|
return t.translateQuery(b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) UpdateArgs() string {
|
func (t *table) UpdateArgs() string {
|
||||||
@ -166,7 +180,7 @@ func (t *table) UpdateFullQuery() string {
|
|||||||
b.WriteString(` ` + c.SqlName + `=?`)
|
b.WriteString(` ` + c.SqlName + `=?`)
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.String()
|
return t.translateQuery(b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) UpdateFullArgs() string {
|
func (t *table) UpdateFullArgs() string {
|
||||||
@ -222,7 +236,7 @@ func (t *table) DeleteQuery() string {
|
|||||||
b.WriteString(col.SqlName)
|
b.WriteString(col.SqlName)
|
||||||
b.WriteString(`=?`)
|
b.WriteString(`=?`)
|
||||||
}
|
}
|
||||||
return b.String()
|
return t.translateQuery(b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) DeleteArgs() string {
|
func (t *table) DeleteArgs() string {
|
||||||
@ -248,7 +262,7 @@ func (t *table) GetQuery() string {
|
|||||||
}
|
}
|
||||||
b.WriteString(col.SqlName + `=?`)
|
b.WriteString(col.SqlName + `=?`)
|
||||||
}
|
}
|
||||||
return b.String()
|
return t.translateQuery(b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) ScanArgs() string {
|
func (t *table) ScanArgs() string {
|
||||||
|
@ -8,16 +8,16 @@ import (
|
|||||||
"text/template"
|
"text/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed sqlite.go.tmpl
|
//go:embed gen.go.tmpl
|
||||||
var sqliteTemplate string
|
var fileTemplate string
|
||||||
|
|
||||||
func render(templateStr, schemaPath, outputPath string) error {
|
func render(driver, schemaPath, outputPath string) error {
|
||||||
sch, err := parsePath(schemaPath)
|
sch, err := parsePath(driver, schemaPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl := template.Must(template.New("").Parse(templateStr))
|
tmpl := template.Must(template.New("").Parse(fileTemplate))
|
||||||
fOut, err := os.Create(outputPath)
|
fOut, err := os.Create(outputPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
Loading…
Reference in New Issue
Block a user