This commit is contained in:
jdl
2024-11-11 06:36:55 +01:00
parent d0587cc585
commit c5419d662e
102 changed files with 4181 additions and 0 deletions

22
sqlgen/README.md Normal file
View File

@@ -0,0 +1,22 @@
# sqlgen
## Installing
```
go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest
```
## Usage
```
sqlgen [driver] [defs-path] [output-path]
```
## File Format
```
TABLE [sql-name] OF [go-type] <NoInsert> <NoUpdate> <NoDelete> (
[sql-column] [go-type] <AS go-name> <PK> <NoInsert> <NoUpdate>,
...
);
```

View File

@@ -0,0 +1,7 @@
package main
import "git.crumpington.com/lib/sqlgen"
func main() {
sqlgen.Main()
}

3
sqlgen/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module git.crumpington.com/lib/sqlgen
go 1.23.2

43
sqlgen/main.go Normal file
View File

@@ -0,0 +1,43 @@
package sqlgen
import (
"fmt"
"os"
)
func Main() {
usage := func() {
fmt.Fprintf(os.Stderr, `
%s DRIVER DEFS_PATH OUTPUT_PATH
Drivers are one of: sqlite, postgres
`,
os.Args[0])
os.Exit(1)
}
if len(os.Args) != 4 {
usage()
}
var (
template string
driver = os.Args[1]
defsPath = os.Args[2]
outputPath = os.Args[3]
)
switch driver {
case "sqlite":
template = sqliteTemplate
default:
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
usage()
}
err := render(template, defsPath, outputPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v", err)
os.Exit(1)
}
}

143
sqlgen/parse.go Normal file
View File

@@ -0,0 +1,143 @@
package sqlgen
import (
"errors"
"os"
"strings"
)
func parsePath(filePath string) (*schema, error) {
fileBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
return parseBytes(fileBytes)
}
func parseBytes(fileBytes []byte) (*schema, error) {
s := string(fileBytes)
for _, c := range []string{",", "(", ")", ";"} {
s = strings.ReplaceAll(s, c, " "+c+" ")
}
var (
tokens = strings.Fields(s)
schema = &schema{}
err error
)
for len(tokens) > 0 {
switch tokens[0] {
case "TABLE":
tokens, err = parseTable(schema, tokens)
if err != nil {
return nil, err
}
default:
return nil, errors.New("invalid token: " + tokens[0])
}
}
return schema, nil
}
func parseTable(schema *schema, tokens []string) ([]string, error) {
tokens = tokens[1:]
if len(tokens) < 3 {
return tokens, errors.New("incomplete table definition")
}
if tokens[1] != "OF" {
return tokens, errors.New("expected OF in table definition")
}
table := &table{
Name: tokens[0],
Type: tokens[2],
}
schema.Tables = append(schema.Tables, table)
tokens = tokens[3:]
if len(tokens) == 0 {
return tokens, errors.New("missing table definition body")
}
for len(tokens) > 0 {
switch tokens[0] {
case "NoInsert":
table.NoInsert = true
tokens = tokens[1:]
case "NoUpdate":
table.NoUpdate = true
tokens = tokens[1:]
case "NoDelete":
table.NoDelete = true
tokens = tokens[1:]
case "(":
return parseTableBody(table, tokens[1:])
default:
return tokens, errors.New("unexpected token in table definition: " + tokens[0])
}
}
return tokens, errors.New("incomplete table definition")
}
func parseTableBody(table *table, tokens []string) ([]string, error) {
var err error
for len(tokens) > 0 && tokens[0] != ";" {
tokens, err = parseTableColumn(table, tokens)
if err != nil {
return tokens, err
}
}
if len(tokens) < 1 || tokens[0] != ";" {
return tokens, errors.New("incomplete table column definitions")
}
return tokens[1:], nil
}
func parseTableColumn(table *table, tokens []string) ([]string, error) {
if len(tokens) < 2 {
return tokens, errors.New("incomplete column definition")
}
column := &column{
Name: tokens[0],
Type: tokens[1],
SqlName: tokens[0],
}
table.Columns = append(table.Columns, column)
tokens = tokens[2:]
for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" {
switch tokens[0] {
case "AS":
if len(tokens) < 2 {
return tokens, errors.New("incomplete AS clause in column definition")
}
column.Name = tokens[1]
tokens = tokens[2:]
case "PK":
column.PK = true
tokens = tokens[1:]
case "NoInsert":
column.NoInsert = true
tokens = tokens[1:]
case "NoUpdate":
column.NoUpdate = true
tokens = tokens[1:]
default:
return tokens, errors.New("unexpected token in column definition: " + tokens[0])
}
}
if len(tokens) == 0 {
return tokens, errors.New("incomplete column definition")
}
return tokens[1:], nil
}

45
sqlgen/parse_test.go Normal file
View File

@@ -0,0 +1,45 @@
package sqlgen
import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
func TestParse(t *testing.T) {
toString := func(v any) string {
txt, _ := json.MarshalIndent(v, "", " ")
return string(txt)
}
paths, err := filepath.Glob("test-files/TestParse/*.def")
if err != nil {
t.Fatal(err)
}
for _, defPath := range paths {
t.Run(filepath.Base(defPath), func(t *testing.T) {
parsed, err := parsePath(defPath)
if err != nil {
t.Fatal(err)
}
b, err := os.ReadFile(strings.TrimSuffix(defPath, "def") + "json")
if err != nil {
t.Fatal(err)
}
expected := &schema{}
if err := json.Unmarshal(b, expected); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(parsed, expected) {
t.Fatalf("%s != %s", toString(parsed), toString(expected))
}
})
}
}

263
sqlgen/schema.go Normal file
View File

@@ -0,0 +1,263 @@
package sqlgen
import (
"fmt"
"strings"
)
type schema struct {
Tables []*table
}
type table struct {
Name string // Name in SQL
Type string // Go type
NoInsert bool
NoUpdate bool
NoDelete bool
Columns []*column
}
type column struct {
Name string
Type string
SqlName string // Defaults to Name
PK bool // PK won't be updated
NoInsert bool
NoUpdate bool // Don't update column in update function
}
// ----------------------------------------------------------------------------
func (t *table) colSQLNames() []string {
names := make([]string, len(t.Columns))
for i := range names {
names[i] = t.Columns[i].SqlName
}
return names
}
func (t *table) SelectQuery() string {
return fmt.Sprintf(`SELECT %s FROM %s`,
strings.Join(t.colSQLNames(), ","),
t.Name)
}
func (t *table) insertCols() (cols []*column) {
for _, c := range t.Columns {
if !c.NoInsert {
cols = append(cols, c)
}
}
return cols
}
func (t *table) InsertQuery() string {
cols := t.insertCols()
b := &strings.Builder{}
b.WriteString(`INSERT INTO `)
b.WriteString(t.Name)
b.WriteString(`(`)
for i, c := range cols {
if i != 0 {
b.WriteString(`,`)
}
b.WriteString(c.SqlName)
}
b.WriteString(`) VALUES(`)
for i := range cols {
if i != 0 {
b.WriteString(`,`)
}
b.WriteString(`?`)
}
b.WriteString(`)`)
return b.String()
}
func (t *table) InsertArgs() string {
args := []string{}
for i, col := range t.Columns {
if !col.NoInsert {
args = append(args, "row."+t.Columns[i].Name)
}
}
return strings.Join(args, ", ")
}
func (t *table) UpdateCols() (cols []*column) {
for _, c := range t.Columns {
if !(c.PK || c.NoUpdate) {
cols = append(cols, c)
}
}
return cols
}
func (t *table) UpdateQuery() string {
cols := t.UpdateCols()
b := &strings.Builder{}
b.WriteString(`UPDATE `)
b.WriteString(t.Name + ` SET `)
for i, col := range cols {
if i != 0 {
b.WriteByte(',')
}
b.WriteString(col.SqlName + `=?`)
}
b.WriteString(` WHERE`)
for i, c := range t.pkCols() {
if i != 0 {
b.WriteString(` AND`)
}
b.WriteString(` ` + c.SqlName + `=?`)
}
return b.String()
}
func (t *table) UpdateArgs() string {
cols := t.UpdateCols()
b := &strings.Builder{}
for i, col := range cols {
if i != 0 {
b.WriteString(`, `)
}
b.WriteString("row." + col.Name)
}
for _, col := range t.pkCols() {
b.WriteString(", row." + col.Name)
}
return b.String()
}
func (t *table) UpdateFullCols() (cols []*column) {
for _, c := range t.Columns {
if !c.PK {
cols = append(cols, c)
}
}
return cols
}
func (t *table) UpdateFullQuery() string {
cols := t.UpdateFullCols()
b := &strings.Builder{}
b.WriteString(`UPDATE `)
b.WriteString(t.Name + ` SET `)
for i, col := range cols {
if i != 0 {
b.WriteByte(',')
}
b.WriteString(col.SqlName + `=?`)
}
b.WriteString(` WHERE`)
for i, c := range t.pkCols() {
if i != 0 {
b.WriteString(` AND`)
}
b.WriteString(` ` + c.SqlName + `=?`)
}
return b.String()
}
func (t *table) UpdateFullArgs() string {
cols := t.UpdateFullCols()
b := &strings.Builder{}
for i, col := range cols {
if i != 0 {
b.WriteString(`, `)
}
b.WriteString("row." + col.Name)
}
for _, col := range t.pkCols() {
b.WriteString(", row." + col.Name)
}
return b.String()
}
func (t *table) pkCols() (cols []*column) {
for _, c := range t.Columns {
if c.PK {
cols = append(cols, c)
}
}
return cols
}
func (t *table) PKFunctionArgs() string {
b := &strings.Builder{}
for _, col := range t.pkCols() {
b.WriteString(col.Name)
b.WriteString(` `)
b.WriteString(col.Type)
b.WriteString(",\n")
}
return b.String()
}
func (t *table) DeleteQuery() string {
cols := t.pkCols()
b := &strings.Builder{}
b.WriteString(`DELETE FROM `)
b.WriteString(t.Name)
b.WriteString(` WHERE `)
for i, col := range cols {
if i != 0 {
b.WriteString(` AND `)
}
b.WriteString(col.SqlName)
b.WriteString(`=?`)
}
return b.String()
}
func (t *table) DeleteArgs() string {
cols := t.pkCols()
b := &strings.Builder{}
for i, col := range cols {
if i != 0 {
b.WriteString(`,`)
}
b.WriteString(col.Name)
}
return b.String()
}
func (t *table) GetQuery() string {
b := &strings.Builder{}
b.WriteString(t.SelectQuery())
b.WriteString(` WHERE `)
for i, col := range t.pkCols() {
if i != 0 {
b.WriteString(` AND `)
}
b.WriteString(col.SqlName + `=?`)
}
return b.String()
}
func (t *table) ScanArgs() string {
b := &strings.Builder{}
for i, col := range t.Columns {
if i != 0 {
b.WriteString(`, `)
}
b.WriteString(`&row.` + col.Name)
}
return b.String()
}

206
sqlgen/sqlite.go.tmpl Normal file
View File

@@ -0,0 +1,206 @@
package {{.PackageName}}
import (
"database/sql"
"iter"
)
type TX interface {
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
QueryRow(query string, args ...any) *sql.Row
}
{{range .Schema.Tables}}
// ----------------------------------------------------------------------------
// Table: {{.Name}}
// ----------------------------------------------------------------------------
type {{.Type}} struct {
{{- range .Columns}}
{{.Name}} {{.Type}}{{end}}
}
const {{.Type}}_SelectQuery = "{{.SelectQuery}}"
{{if not .NoInsert -}}
func {{.Type}}_Insert(
tx TX,
row *{{.Type}},
) (err error) {
{{.Type}}_Sanitize(row)
if err = {{.Type}}_Validate(row); err != nil {
return err
}
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
return err
}
{{- end}} {{/* if not .NoInsert */}}
{{if not .NoUpdate -}}
{{if .UpdateCols -}}
func {{.Type}}_Update(
tx TX,
row *{{.Type}},
) (found bool, err error) {
{{.Type}}_Sanitize(row)
if err = {{.Type}}_Validate(row); err != nil {
return false, err
}
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
if err != nil {
return false, err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
if n > 1 {
panic("multiple rows updated")
}
return n != 0, nil
}
{{- end}}
{{if .UpdateFullCols -}}
func {{.Type}}_UpdateFull(
tx TX,
row *{{.Type}},
) (found bool, err error) {
{{.Type}}_Sanitize(row)
if err = {{.Type}}_Validate(row); err != nil {
return false, err
}
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
if err != nil {
return false, err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
if n > 1 {
panic("multiple rows updated")
}
return n != 0, nil
}
{{- end}}
{{- end}} {{/* if not .NoUpdate */}}
{{if not .NoDelete -}}
func {{.Type}}_Delete(
tx TX,
{{.PKFunctionArgs -}}
) (found bool, err error) {
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
if err != nil {
return false, err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
if n > 1 {
panic("multiple rows deleted")
}
return n != 0, nil
}
{{- end}}
func {{.Type}}_Get(
tx TX,
{{.PKFunctionArgs -}}
) (
row *{{.Type}},
err error,
) {
row = &{{.Type}}{}
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
err = r.Scan({{.ScanArgs}})
return
}
func {{.Type}}_GetWhere(
tx TX,
query string,
args ...any,
) (
row *{{.Type}},
err error,
) {
row = &{{.Type}}{}
r := tx.QueryRow(query, args...)
err = r.Scan({{.ScanArgs}})
return
}
func {{.Type}}_Iterate(
tx TX,
query string,
args ...any,
) (
iter.Seq2[*{{.Type}}, error],
) {
rows, err := tx.Query(query, args...)
if err != nil {
return func(yield func(*{{.Type}}, error) bool) {
yield(nil, err)
}
}
return func(yield func(*{{.Type}}, error) bool) {
defer rows.Close()
for rows.Next() {
row := &{{.Type}}{}
err := rows.Scan({{.ScanArgs}})
if !yield(row, err) {
return
}
}
}
}
func {{.Type}}_List(
tx TX,
query string,
args ...any,
) (
l []*{{.Type}},
err error,
) {
for row, err := range {{.Type}}_Iterate(tx, query, args...) {
if err != nil {
return nil, err
}
l = append(l, row)
}
return l, nil
}
{{end}} {{/* range .Schema.Tables */}}

36
sqlgen/template.go Normal file
View File

@@ -0,0 +1,36 @@
package sqlgen
import (
_ "embed"
"os"
"os/exec"
"path/filepath"
"text/template"
)
//go:embed sqlite.go.tmpl
var sqliteTemplate string
func render(templateStr, schemaPath, outputPath string) error {
sch, err := parsePath(schemaPath)
if err != nil {
return err
}
tmpl := template.Must(template.New("").Parse(templateStr))
fOut, err := os.Create(outputPath)
if err != nil {
return err
}
defer fOut.Close()
err = tmpl.Execute(fOut, struct {
PackageName string
Schema *schema
}{filepath.Base(filepath.Dir(outputPath)), sch})
if err != nil {
return err
}
return exec.Command("gofmt", "-w", outputPath).Run()
}

View File

@@ -0,0 +1,3 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK
);

View File

@@ -0,0 +1,17 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
}
]
}
]
}

View File

@@ -0,0 +1,4 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate
);

View File

@@ -0,0 +1,22 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
}, {
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
}
]
}
]
}

View File

@@ -0,0 +1,6 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate,
name string AS Name NoInsert,
admin bool AS Admin NoInsert NoUpdate
);

View File

@@ -0,0 +1,33 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
}, {
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
}, {
"Name": "Name",
"Type": "string",
"SqlName": "name",
"NoInsert": true
}, {
"Name": "Admin",
"Type": "bool",
"SqlName": "admin",
"NoInsert": true,
"NoUpdate": true
}
]
}
]
}

View File

@@ -0,0 +1,12 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate,
name string AS Name NoInsert,
admin bool AS Admin NoInsert NoUpdate
);
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
user_id string AS UserID PK,
email string AS Email,
name string AS Name
);

View File

@@ -0,0 +1,61 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name",
"NoInsert": true
},
{
"Name": "Admin",
"Type": "bool",
"SqlName": "admin",
"NoInsert": true,
"NoUpdate": true
}
]
},
{
"Name": "users_view",
"Type": "UserView",
"NoInsert": true,
"NoUpdate": true,
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email"
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name"
}
]
}
]
}

View File

@@ -0,0 +1,13 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate,
name string AS Name NoInsert,
admin bool AS Admin NoInsert NoUpdate,
SSN string NoUpdate
);
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
user_id string AS UserID PK,
email string AS Email,
name string AS Name
);

View File

@@ -0,0 +1,66 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name",
"NoInsert": true
},
{
"Name": "Admin",
"Type": "bool",
"SqlName": "admin",
"NoInsert": true,
"NoUpdate": true
}, {
"Name": "SSN",
"Type": "string",
"SqlName": "SSN",
"NoUpdate": true
}
]
},
{
"Name": "users_view",
"Type": "UserView",
"NoInsert": true,
"NoUpdate": true,
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email"
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name"
}
]
}
]
}