go/sqlgen/schema.go
2024-11-19 16:30:42 +01:00

278 lines
4.9 KiB
Go

package sqlgen
import (
"fmt"
"strings"
)
// schema groups the tables handled by the generator.
type schema struct {
	// Tables holds one entry per table; the table methods in this package
	// dereference the pointer without nil checks, so entries must be non-nil.
	Tables []*table
}
// table describes one SQL table, the Go row type it maps to, and its columns.
type table struct {
	// driver selects the placeholder syntax applied by translateQuery: the
	// SQLite driver keeps `?`, every other value gets `$1, $2, ...`.
	driver string
	Name   string // table name as written in SQL
	Type   string // name of the Go row type

	// Generation switches. None of the query/argument methods read them;
	// presumably the templates use them to skip emitting the insert, update
	// or delete helpers for this table — confirm against the templates.
	NoInsert, NoUpdate, NoDelete bool

	// Columns in declaration order. This order fixes the SELECT column list,
	// and ScanArgs walks the same slice, so the two stay aligned.
	Columns []*column
}
// column describes one column of a table and the role it plays in the
// generated statements and argument lists.
type column struct {
	// Name is the Go field name on the row struct; the argument helpers emit
	// it as row.<Name> (or &row.<Name> when scanning).
	Name string
	// Type is the Go type of the field; PKFunctionArgs emits it for
	// primary-key parameters.
	Type string
	// SqlName is the column name used in SQL. It is meant to default to Name,
	// but that defaulting is not done in this section — confirm where it is
	// applied before relying on an empty SqlName.
	SqlName string

	// Role flags:
	//   PK       - primary-key column: never assigned by an UPDATE, and used
	//              in the WHERE clause of the update, delete and get queries.
	//   NoInsert - omitted from the INSERT column list and from InsertArgs.
	//   NoUpdate - omitted from UpdateQuery/UpdateArgs; UpdateFullQuery still
	//              assigns it.
	PK, NoInsert, NoUpdate bool
}
// ----------------------------------------------------------------------------
// translateQuery rewrites the generic `?` placeholders of a query into the
// positional `$1, $2, ...` form for every driver except SQLite. For SQLite,
// which accepts `?` directly, the query is returned unchanged.
//
// Placeholders are numbered left to right. Every `?` byte is treated as a
// placeholder, including one inside an SQL string literal. The queries built
// in this file never contain such literals.
func (t *table) translateQuery(in string) string {
	if t.driver == DRIVER_SQLITE {
		return in
	}
	// Single pass over the input. The previous Contains/Replace loop rescanned
	// and reallocated the whole string once per placeholder (quadratic).
	var b strings.Builder
	b.Grow(len(in) + 2*strings.Count(in, "?"))
	n := 0
	for i := 0; i < len(in); i++ {
		// '?' is ASCII, so byte-wise scanning cannot split a UTF-8 sequence.
		if in[i] != '?' {
			b.WriteByte(in[i])
			continue
		}
		n++
		fmt.Fprintf(&b, "$%d", n)
	}
	return b.String()
}
// colSQLNames returns the SqlName of every column in declaration order. The
// result is never nil, even for a table without columns.
func (t *table) colSQLNames() []string {
	out := make([]string, 0, len(t.Columns))
	for _, col := range t.Columns {
		out = append(out, col.SqlName)
	}
	return out
}
// SelectQuery returns "SELECT c1,c2,... FROM <Name>", listing every column
// in declaration order. It has no placeholders, so it is not passed through
// translateQuery; GetQuery appends a WHERE clause to it.
func (t *table) SelectQuery() string {
	columnList := strings.Join(t.colSQLNames(), ",")
	return fmt.Sprintf("SELECT %s FROM %s", columnList, t.Name)
}
// insertCols returns the columns written by INSERT: every column not marked
// NoInsert, in declaration order. The result is nil when no column qualifies.
func (t *table) insertCols() (cols []*column) {
	for i := range t.Columns {
		c := t.Columns[i]
		if c.NoInsert {
			continue
		}
		cols = append(cols, c)
	}
	return cols
}
// InsertQuery builds the INSERT statement for the table. It lists every
// column returned by insertCols with one placeholder each, then converts the
// placeholders to the driver's syntax via translateQuery.
//
// Example (before translation): INSERT INTO users(email,created_at) VALUES(?,?)
//
// If no column is insertable, the result is "INSERT INTO <Name>() VALUES()".
func (t *table) InsertQuery() string {
	cols := t.insertCols()
	names := make([]string, len(cols))
	marks := make([]string, len(cols))
	for i, c := range cols {
		names[i] = c.SqlName
		marks[i] = "?"
	}
	query := "INSERT INTO " + t.Name +
		"(" + strings.Join(names, ",") + ")" +
		" VALUES(" + strings.Join(marks, ",") + ")"
	return t.translateQuery(query)
}
// InsertArgs returns the Go argument list that accompanies InsertQuery:
// "row.<Name>" for each column returned by insertCols, separated by ", ".
//
// It is derived from insertCols, the same helper InsertQuery uses, so the
// arguments always line up one-to-one with the INSERT placeholders. An empty
// string is returned when no column is insertable.
func (t *table) InsertArgs() string {
	cols := t.insertCols()
	args := make([]string, len(cols))
	for i, c := range cols {
		args[i] = "row." + c.Name
	}
	return strings.Join(args, ", ")
}
// UpdateCols returns the columns assigned by UpdateQuery: every column that
// is neither a primary key nor marked NoUpdate, in declaration order. The
// result is nil when no column qualifies.
func (t *table) UpdateCols() (cols []*column) {
	for i := range t.Columns {
		c := t.Columns[i]
		if c.PK || c.NoUpdate {
			continue // PKs identify the row; NoUpdate columns stay untouched
		}
		cols = append(cols, c)
	}
	return cols
}
// UpdateQuery builds the partial UPDATE statement for the table:
//
//	UPDATE <Name> SET c1=?,c2=? WHERE pk1=? AND pk2=?
//
// The SET list comes from UpdateCols (columns that are neither PK nor
// NoUpdate), and the WHERE clause from pkCols. Placeholders are converted
// with translateQuery.
//
// There is no guard for a table without primary-key columns. The query then
// ends in a bare " WHERE" and is not valid SQL.
func (t *table) UpdateQuery() string {
	assignments := make([]string, 0, len(t.Columns))
	for _, c := range t.UpdateCols() {
		assignments = append(assignments, c.SqlName+"=?")
	}
	// Each condition carries its own leading space, so joining with " AND"
	// yields "WHERE a=? AND b=?".
	var conditions []string
	for _, c := range t.pkCols() {
		conditions = append(conditions, " "+c.SqlName+"=?")
	}
	query := "UPDATE " + t.Name + " SET " +
		strings.Join(assignments, ",") +
		" WHERE" + strings.Join(conditions, " AND")
	return t.translateQuery(query)
}
// UpdateArgs returns the Go argument list that accompanies UpdateQuery. It
// emits "row.<Name>" for each UpdateCols column (the SET placeholders), then
// "row.<Name>" for each primary-key column (the WHERE placeholders). Items are
// separated by ", ".
//
// All items are collected and joined, rather than prefixing each PK with
// ", ". This avoids a leading comma (invalid generated Go) when the table
// has no updatable columns.
func (t *table) UpdateArgs() string {
	setCols, keyCols := t.UpdateCols(), t.pkCols()
	args := make([]string, 0, len(setCols)+len(keyCols))
	for _, col := range setCols {
		args = append(args, "row."+col.Name)
	}
	for _, col := range keyCols {
		args = append(args, "row."+col.Name)
	}
	return strings.Join(args, ", ")
}
// UpdateFullCols returns the columns assigned by UpdateFullQuery: every
// non-primary-key column in declaration order. Unlike UpdateCols, it
// includes NoUpdate columns. The result is nil when no column qualifies.
func (t *table) UpdateFullCols() (cols []*column) {
	for i := range t.Columns {
		if t.Columns[i].PK {
			continue
		}
		cols = append(cols, t.Columns[i])
	}
	return cols
}
// UpdateFullQuery builds the full UPDATE statement for the table:
//
//	UPDATE <Name> SET c1=?,c2=? WHERE pk1=? AND pk2=?
//
// The SET list comes from UpdateFullCols (every non-PK column, including
// NoUpdate ones), and the WHERE clause from pkCols. Placeholders are
// converted with translateQuery.
//
// There is no guard for a table without primary-key columns. The query then
// ends in a bare " WHERE" and is not valid SQL.
func (t *table) UpdateFullQuery() string {
	var sets []string
	for _, c := range t.UpdateFullCols() {
		sets = append(sets, c.SqlName+"=?")
	}
	where := ""
	for i, c := range t.pkCols() {
		if i > 0 {
			where += " AND"
		}
		where += " " + c.SqlName + "=?"
	}
	return t.translateQuery("UPDATE " + t.Name + " SET " + strings.Join(sets, ",") + " WHERE" + where)
}
// UpdateFullArgs returns the Go argument list that accompanies
// UpdateFullQuery. It emits "row.<Name>" for each UpdateFullCols column (the
// SET placeholders), then "row.<Name>" for each primary-key column (the WHERE
// placeholders). Items are separated by ", ".
//
// All items are collected and joined, rather than prefixing each PK with
// ", ". This avoids a leading comma (invalid generated Go) when the table
// consists only of primary-key columns.
func (t *table) UpdateFullArgs() string {
	setCols, keyCols := t.UpdateFullCols(), t.pkCols()
	args := make([]string, 0, len(setCols)+len(keyCols))
	for _, col := range setCols {
		args = append(args, "row."+col.Name)
	}
	for _, col := range keyCols {
		args = append(args, "row."+col.Name)
	}
	return strings.Join(args, ", ")
}
// pkCols returns the primary-key columns in declaration order. Callers use
// them for the WHERE clauses and key arguments. The result is nil for a table
// without primary-key columns.
func (t *table) pkCols() (cols []*column) {
	for i := range t.Columns {
		if !t.Columns[i].PK {
			continue
		}
		cols = append(cols, t.Columns[i])
	}
	return cols
}
// PKFunctionArgs returns one "<Name> <Type>,\n" line per primary-key column,
// in declaration order. Every line, including the last, ends with a comma
// and newline, presumably for splicing into a multi-line Go parameter list.
// It returns "" when the table has no primary key.
func (t *table) PKFunctionArgs() string {
	var sb strings.Builder
	for _, pk := range t.pkCols() {
		sb.WriteString(pk.Name + " " + pk.Type + ",\n")
	}
	return sb.String()
}
// DeleteQuery builds "DELETE FROM <Name> WHERE pk1=? AND pk2=?" from the
// primary-key columns and converts the placeholders with translateQuery.
// Without primary-key columns the result ends in "WHERE " and is not valid
// SQL.
func (t *table) DeleteQuery() string {
	keys := t.pkCols()
	conds := make([]string, len(keys))
	for i, k := range keys {
		conds[i] = k.SqlName + "=?"
	}
	return t.translateQuery("DELETE FROM " + t.Name + " WHERE " + strings.Join(conds, " AND "))
}
// DeleteArgs returns the Go names of the primary-key columns joined by ","
// (no space and no "row." prefix). These are the same identifiers that
// PKFunctionArgs emits as parameter names.
func (t *table) DeleteArgs() string {
	keys := t.pkCols()
	names := make([]string, len(keys))
	for i, k := range keys {
		names[i] = k.Name
	}
	return strings.Join(names, ",")
}
// GetQuery returns SelectQuery restricted to a single row by primary key,
// "... WHERE pk1=? AND pk2=?", with placeholders converted by
// translateQuery. Without primary-key columns the result ends in "WHERE "
// and is not valid SQL.
func (t *table) GetQuery() string {
	keys := t.pkCols()
	conds := make([]string, len(keys))
	for i, k := range keys {
		conds[i] = k.SqlName + "=?"
	}
	return t.translateQuery(t.SelectQuery() + " WHERE " + strings.Join(conds, " AND "))
}
// ScanArgs returns "&row.<Name>" for every column in declaration order,
// joined by ", ". SelectQuery lists the columns in the same order, so these
// scan targets line up with the selected columns.
func (t *table) ScanArgs() string {
	targets := make([]string, len(t.Columns))
	for i, c := range t.Columns {
		targets[i] = "&row." + c.Name
	}
	return strings.Join(targets, ", ")
}