Initial cloud-services repo - gateway service + pkg modules

This commit is contained in:
Chris Rai
2026-01-30 23:14:52 -05:00
commit fbb820d7b3
1037 changed files with 171318 additions and 0 deletions

461
pkg/db/migrator.go Normal file
View File

@@ -0,0 +1,461 @@
package db
import (
"fmt"
"reflect"
"strings"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/iancoleman/strcase"
"github.com/jinzhu/inflection"
"github.com/pkg/errors"
"github.com/vmihailenco/tagparser"
"fiskerinc.com/modules/logger"
)
type Migrator struct {
DB *pg.DB
DropColumns bool
DryRun bool
Schema string
allowedTypes []string
allowedKinds []reflect.Kind
}
type Column struct {
ColumnName string
IsNullable bool
DataType string
OrdinalPosition int
ColumnDefault string
Unique bool
Index bool
CompositeKey string
}
type CountResult struct {
Count int32
}
func (m *Migrator) Close() {
m.DB = nil
m.allowedTypes = nil
m.allowedKinds = nil
}
func (m *Migrator) Check(model interface{}) error {
err := m.TableExists(model)
if err != nil {
return err
}
_, err = orm.NewModel(model)
if err != nil {
logger.Warn().Err(err).Send()
}
err = m.UnmatchedFields(model)
if err != nil {
return err
}
return nil
}
func (m *Migrator) DropTable(model interface{}) error {
sql := m.sqlDropTable()
_, err := m.DB.Model(model).Exec(sql)
return err
}
func (m *Migrator) updateColumns(model interface{}) error {
dbColumns, err := m.getExistingColumns(model)
if err != nil {
return err
}
modelFields := m.GetFields(reflect.TypeOf(model))
table := m.GetTableName(model)
mapOldColumns := m.makeColumnMap(dbColumns)
mapModelColumns := m.makeColumnMap(modelFields)
tx, err := m.DB.Begin()
if err != nil {
return err
}
defer tx.Close()
// Check for dropped columns
for _, c := range dbColumns {
if c.DataType == "tsvector" {
continue
}
if _, ok := mapModelColumns[c.ColumnName]; !ok {
logger.Warn().Msgf("%s:%s is not in model. Consider dropping from db.", table, c.ColumnName)
}
}
// Check for new columns
for _, c := range modelFields {
if c.ColumnName == "id" || c.ColumnName == "table_name" {
continue
}
if _, ok := mapOldColumns[c.ColumnName]; !ok {
sql := m.sqlAddColumn(c)
r, err := m.DB.Model(model).Exec(sql)
if err != nil {
logger.Error().Err(err).Send()
tx.Rollback()
return err
} else if r.RowsAffected() > 0 {
logger.Info().Msgf("added column %s:%s", table, c.ColumnName)
}
if c.CompositeKey != "" {
logger.Warn().Msgf("%s:%s:%s composite key not created", table, c.ColumnName, c.CompositeKey)
}
}
}
tx.Commit()
return nil
}
func (m *Migrator) makeColumnMap(columns []*Column) map[string]*Column {
result := make(map[string]*Column, len(columns))
for _, c := range columns {
if c.DataType != "struct{}" {
result[c.ColumnName] = c
}
}
return result
}
func (m *Migrator) sqlDropTable() string {
return "DROP TABLE ?TableName CASCADE"
}
func (m *Migrator) sqlAddColumn(c *Column) string {
sql := fmt.Sprintf("ALTER TABLE ?TableName ADD COLUMN \"%s\" %s", c.ColumnName, c.DataType)
if c.Unique {
sql += " UNIQUE"
}
if !c.IsNullable {
sql += " NOT NULL"
sql += " DEFAULT " + c.ColumnDefault
}
return sql
}
func (m *Migrator) getExistingColumns(model interface{}) ([]*Column, error) {
var columns []*Column
table := m.GetTableName(model)
_, err := m.DB.Query(&columns, "SELECT column_name, ordinal_position, column_default, is_nullable, data_type FROM information_schema.COLUMNS WHERE TABLE_NAME = ?", table)
return columns, err
}
func (m *Migrator) createTable(model interface{}) error {
return m.DB.Model(model).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
FKConstraints: true,
Temp: true,
})
}
func (m *Migrator) GetTableName(model interface{}) string {
t := reflect.TypeOf(model)
if t.Kind() == reflect.Ptr {
t = t.Elem()
// check if table name is being overriden by struct tag
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Name == "tableName" {
tag := tagparser.Parse(f.Tag.Get("pg"))
if tag.Name != "" {
return tag.Name
} else {
break
}
}
}
return inflection.Plural(strcase.ToSnake(t.Name()))
} else {
return inflection.Plural(strcase.ToSnake(t.Name()))
}
}
func (m *Migrator) getAllowedKinds() []reflect.Kind {
if m.allowedKinds == nil {
m.allowedKinds = []reflect.Kind{
reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Float32,
reflect.Float64,
reflect.String,
}
}
return m.allowedKinds
}
func (m *Migrator) getAllowedTypes() []string {
if m.allowedTypes == nil {
m.allowedTypes = []string{"string", "int", "int32", "int64", "time.Time", "float64", "bool", "uint", "uint8", "uint16", "uint32", "uint64", "uuid.UUID"}
for _, t := range m.allowedTypes {
m.allowedTypes = append(m.allowedTypes, "*"+t)
m.allowedTypes = append(m.allowedTypes, "[]"+t)
m.allowedTypes = append(m.allowedTypes, "[]*"+t)
}
m.allowedTypes = append(m.allowedTypes, "map[string]interface {}")
m.allowedTypes = append(m.allowedTypes, "interface {}")
m.allowedTypes = append(m.allowedTypes, "struct {}")
}
return m.allowedTypes
}
// GetFields returns struct fields and database related metadata
func (m *Migrator) GetFields(t reflect.Type) []*Column {
var (
res []*Column
)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return res
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Anonymous {
// embedded struct
result := m.GetFields(field.Type)
res = append(res, result...)
continue
}
column := &Column{
ColumnName: strcase.ToSnake(field.Name),
}
fieldType := m.GetType(field.Type)
tag := field.Tag.Get("pg")
ignoreField := false
if len(tag) > 0 {
tags := strings.Split(tag, ",")
for i, tagS := range tags {
s := strings.ToLower(strings.TrimSpace(tagS))
if s == "-" {
ignoreField = true
} else if s == "unique" {
column.Unique = true
} else if s == "index" {
column.Index = true
} else if strings.Contains(s, "unique:") {
column.CompositeKey = strings.Replace(s, "unique:", "", 1)
} else if strings.Contains(s, "type:") {
ss := strings.Split(s, "type:")
if len(ss) > 1 {
column.DataType = ss[1]
}
} else if strings.Contains(s, "alias:") {
column.ColumnName = strings.Replace(s, "alias:", "", 1)
} else if i == 0 && len(s) > 0 && !strings.Contains(s, ":") {
column.ColumnName = s
} else if s == "notnull" {
column.IsNullable = false
}
}
}
accept := m.isAllowedType(fieldType) || len(column.DataType) > 0
if ignoreField || !accept {
continue
}
m.configColumn(fieldType, column)
res = append(res, column)
}
return res
}
func (m *Migrator) GetType(field reflect.Type) string {
fieldType := field.String()
if m.isAllowedType(fieldType) {
return fieldType
}
// Try to get the underlying data type
if m.isAllowedKind(field.Kind()) {
return field.Kind().String()
}
return fieldType
}
func (m *Migrator) isAllowedType(fieldType string) bool {
for _, at := range m.getAllowedTypes() {
if at == fieldType {
return true
}
}
return false
}
func (m *Migrator) isAllowedKind(kind reflect.Kind) bool {
for _, k := range m.getAllowedKinds() {
if k == kind {
return true
}
}
return false
}
func (m *Migrator) configColumn(fieldType string, column *Column) {
switch fieldType {
case "string":
column.DataType = "text"
column.IsNullable = false
column.ColumnDefault = "''"
case "*string":
column.DataType = "text"
column.IsNullable = true
case "[]string", "[]*string":
column.DataType = "text[]"
case "int":
column.DataType = "integer"
column.IsNullable = false
case "int64":
column.DataType = "bigint"
column.IsNullable = false
case "*int64":
column.DataType = "bigint"
column.IsNullable = true
case "[]int64", "[]*int64":
column.DataType = "integer[]"
column.IsNullable = true
case "time.Time":
column.IsNullable = false
column.DataType = "timestamp with time zone"
column.ColumnDefault = "NOW()"
case "*time.Time":
column.DataType = "timestamp with time zone"
column.IsNullable = true
case "float64":
column.DataType = "numeric"
column.IsNullable = false
column.ColumnDefault = "0.00"
case "*float64":
column.DataType = "numeric"
column.IsNullable = true
case "[]float64", "[]*float64":
column.DataType = "numeric[]"
column.IsNullable = true
case "bool":
column.DataType = "boolean"
column.IsNullable = false
column.ColumnDefault = "false"
case "*bool":
column.DataType = "boolean"
column.IsNullable = true
case "[]bool", "[]*bool":
column.DataType = "boolean[]"
column.IsNullable = true
case "map[string]interface", "interface":
column.DataType = "jsonb"
column.IsNullable = true
case "uint", "uint32", "uint64":
column.DataType = "bigint"
column.IsNullable = false
case "[]uint8":
column.DataType = "bytea"
case "uuid.UUID":
column.DataType = "uuid"
case "[]byte", "common.BinaryHex", "*common.BinaryHex":
column.DataType = "bytea"
column.IsNullable = true
default:
logger.Warn().Msgf("%s unknown type %s", column.ColumnName, fieldType)
}
}
func (m *Migrator) getSchema() string {
if m.Schema == "" {
m.Schema = "public"
}
return m.Schema
}
func (m *Migrator) TableExists(model interface{}) error {
var count []*CountResult
tablename := m.GetTableName(model)
result, err := m.DB.QueryOne(&count, "SELECT COUNT(*) FROM information_schema.TABLES WHERE table_name = ? AND table_schema = ? AND table_type = 'BASE TABLE'", tablename, m.getSchema())
if err != nil {
return errors.WithStack(err)
}
if result.RowsReturned() == 0 || count[0].Count == 0 {
return errors.WithMessagef(ErrTableDoesntExist, "for model %t", model)
}
return nil
}
// UnmatchedFields returns list of model fields that were not found in the DB.
func (m *Migrator) UnmatchedFields(model interface{}) error {
tablename := m.GetTableName(model)
dbColumns, err := m.getExistingColumns(model)
if err != nil {
return err
}
modelFields := m.GetFields(reflect.TypeOf(model))
mapDBColumns := m.makeColumnMap(dbColumns)
mapModelColumns := m.makeColumnMap(modelFields)
// Check if db table column exists in model
for _, c := range dbColumns {
if c.DataType == "tsvector" {
continue
}
if modelProp, ok := mapModelColumns[c.ColumnName]; !ok {
logger.Warn().Msgf("%t:%s is not in model. Consider dropping from db.", model, c.ColumnName)
} else if modelProp.DataType != c.DataType {
logger.Warn().Msgf("Type mismatch db %s:%s should be type %s instead of %s", tablename, c.ColumnName, modelProp.DataType, c.DataType)
}
}
// Check if model property exists in database table
for colName := range mapModelColumns {
if _, ok := mapDBColumns[colName]; !ok {
logger.Warn().Msgf("%s:%s is not in db. Consider dropping from model or add to table.", tablename, colName)
}
}
return nil
}