package sqlgen import ( "fmt" "strings" ) type schema struct { Tables []*table } type table struct { Name string // Name in SQL Type string // Go type NoInsert bool NoUpdate bool NoDelete bool Columns []*column } type column struct { Name string Type string SqlName string // Defaults to Name PK bool // PK won't be updated NoInsert bool NoUpdate bool // Don't update column in update function } // ---------------------------------------------------------------------------- func (t *table) colSQLNames() []string { names := make([]string, len(t.Columns)) for i := range names { names[i] = t.Columns[i].SqlName } return names } func (t *table) SelectQuery() string { return fmt.Sprintf(`SELECT %s FROM %s`, strings.Join(t.colSQLNames(), ","), t.Name) } func (t *table) insertCols() (cols []*column) { for _, c := range t.Columns { if !c.NoInsert { cols = append(cols, c) } } return cols } func (t *table) InsertQuery() string { cols := t.insertCols() b := &strings.Builder{} b.WriteString(`INSERT INTO `) b.WriteString(t.Name) b.WriteString(`(`) for i, c := range cols { if i != 0 { b.WriteString(`,`) } b.WriteString(c.SqlName) } b.WriteString(`) VALUES(`) for i := range cols { if i != 0 { b.WriteString(`,`) } b.WriteString(`?`) } b.WriteString(`)`) return b.String() } func (t *table) InsertArgs() string { args := []string{} for i, col := range t.Columns { if !col.NoInsert { args = append(args, "row."+t.Columns[i].Name) } } return strings.Join(args, ", ") } func (t *table) UpdateCols() (cols []*column) { for _, c := range t.Columns { if !(c.PK || c.NoUpdate) { cols = append(cols, c) } } return cols } func (t *table) UpdateQuery() string { cols := t.UpdateCols() b := &strings.Builder{} b.WriteString(`UPDATE `) b.WriteString(t.Name + ` SET `) for i, col := range cols { if i != 0 { b.WriteByte(',') } b.WriteString(col.SqlName + `=?`) } b.WriteString(` WHERE`) for i, c := range t.pkCols() { if i != 0 { b.WriteString(` AND`) } b.WriteString(` ` + c.SqlName + `=?`) } return b.String() } func (t *table) UpdateArgs() string { cols := t.UpdateCols() b := &strings.Builder{} for i, col := range cols { if i != 0 { b.WriteString(`, `) } b.WriteString("row." + col.Name) } for _, col := range t.pkCols() { b.WriteString(", row." + col.Name) } return b.String() } func (t *table) UpdateFullCols() (cols []*column) { for _, c := range t.Columns { if !c.PK { cols = append(cols, c) } } return cols } func (t *table) UpdateFullQuery() string { cols := t.UpdateFullCols() b := &strings.Builder{} b.WriteString(`UPDATE `) b.WriteString(t.Name + ` SET `) for i, col := range cols { if i != 0 { b.WriteByte(',') } b.WriteString(col.SqlName + `=?`) } b.WriteString(` WHERE`) for i, c := range t.pkCols() { if i != 0 { b.WriteString(` AND`) } b.WriteString(` ` + c.SqlName + `=?`) } return b.String() } func (t *table) UpdateFullArgs() string { cols := t.UpdateFullCols() b := &strings.Builder{} for i, col := range cols { if i != 0 { b.WriteString(`, `) } b.WriteString("row." + col.Name) } for _, col := range t.pkCols() { b.WriteString(", row." + col.Name) } return b.String() } func (t *table) pkCols() (cols []*column) { for _, c := range t.Columns { if c.PK { cols = append(cols, c) } } return cols } func (t *table) PKFunctionArgs() string { b := &strings.Builder{} for _, col := range t.pkCols() { b.WriteString(col.Name) b.WriteString(` `) b.WriteString(col.Type) b.WriteString(",\n") } return b.String() } func (t *table) DeleteQuery() string { cols := t.pkCols() b := &strings.Builder{} b.WriteString(`DELETE FROM `) b.WriteString(t.Name) b.WriteString(` WHERE `) for i, col := range cols { if i != 0 { b.WriteString(` AND `) } b.WriteString(col.SqlName) b.WriteString(`=?`) } return b.String() } func (t *table) DeleteArgs() string { cols := t.pkCols() b := &strings.Builder{} for i, col := range cols { if i != 0 { b.WriteString(`,`) } b.WriteString(col.Name) } return b.String() } func (t *table) GetQuery() string { b := &strings.Builder{} b.WriteString(t.SelectQuery()) b.WriteString(` WHERE `) for i, col := range t.pkCols() { if i != 0 { b.WriteString(` AND `) } b.WriteString(col.SqlName + `=?`) } return b.String() } func (t *table) ScanArgs() string { b := &strings.Builder{} for i, col := range t.Columns { if i != 0 { b.WriteString(`, `) } b.WriteString(`&row.` + col.Name) } return b.String() }