package db import ( "fmt" "reflect" "strings" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/iancoleman/strcase" "github.com/jinzhu/inflection" "github.com/pkg/errors" "github.com/vmihailenco/tagparser" "fiskerinc.com/modules/logger" ) type Migrator struct { DB *pg.DB DropColumns bool DryRun bool Schema string allowedTypes []string allowedKinds []reflect.Kind } type Column struct { ColumnName string IsNullable bool DataType string OrdinalPosition int ColumnDefault string Unique bool Index bool CompositeKey string } type CountResult struct { Count int32 } func (m *Migrator) Close() { m.DB = nil m.allowedTypes = nil m.allowedKinds = nil } func (m *Migrator) Check(model interface{}) error { err := m.TableExists(model) if err != nil { return err } _, err = orm.NewModel(model) if err != nil { logger.Warn().Err(err).Send() } err = m.UnmatchedFields(model) if err != nil { return err } return nil } func (m *Migrator) DropTable(model interface{}) error { sql := m.sqlDropTable() _, err := m.DB.Model(model).Exec(sql) return err } func (m *Migrator) updateColumns(model interface{}) error { dbColumns, err := m.getExistingColumns(model) if err != nil { return err } modelFields := m.GetFields(reflect.TypeOf(model)) table := m.GetTableName(model) mapOldColumns := m.makeColumnMap(dbColumns) mapModelColumns := m.makeColumnMap(modelFields) tx, err := m.DB.Begin() if err != nil { return err } defer tx.Close() // Check for dropped columns for _, c := range dbColumns { if c.DataType == "tsvector" { continue } if _, ok := mapModelColumns[c.ColumnName]; !ok { logger.Warn().Msgf("%s:%s is not in model. Consider dropping from db.", table, c.ColumnName) } } // Check for new columns for _, c := range modelFields { if c.ColumnName == "id" || c.ColumnName == "table_name" { continue } if _, ok := mapOldColumns[c.ColumnName]; !ok { sql := m.sqlAddColumn(c) r, err := m.DB.Model(model).Exec(sql) if err != nil { logger.Error().Err(err).Send() tx.Rollback() return err } else if r.RowsAffected() > 0 { logger.Info().Msgf("added column %s:%s", table, c.ColumnName) } if c.CompositeKey != "" { logger.Warn().Msgf("%s:%s:%s composite key not created", table, c.ColumnName, c.CompositeKey) } } } tx.Commit() return nil } func (m *Migrator) makeColumnMap(columns []*Column) map[string]*Column { result := make(map[string]*Column, len(columns)) for _, c := range columns { if c.DataType != "struct{}" { result[c.ColumnName] = c } } return result } func (m *Migrator) sqlDropTable() string { return "DROP TABLE ?TableName CASCADE" } func (m *Migrator) sqlAddColumn(c *Column) string { sql := fmt.Sprintf("ALTER TABLE ?TableName ADD COLUMN \"%s\" %s", c.ColumnName, c.DataType) if c.Unique { sql += " UNIQUE" } if !c.IsNullable { sql += " NOT NULL" sql += " DEFAULT " + c.ColumnDefault } return sql } func (m *Migrator) getExistingColumns(model interface{}) ([]*Column, error) { var columns []*Column table := m.GetTableName(model) _, err := m.DB.Query(&columns, "SELECT column_name, ordinal_position, column_default, is_nullable, data_type FROM information_schema.COLUMNS WHERE TABLE_NAME = ?", table) return columns, err } func (m *Migrator) createTable(model interface{}) error { return m.DB.Model(model).CreateTable(&orm.CreateTableOptions{ IfNotExists: true, FKConstraints: true, Temp: true, }) } func (m *Migrator) GetTableName(model interface{}) string { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() // check if table name is being overriden by struct tag for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Name == "tableName" { tag := tagparser.Parse(f.Tag.Get("pg")) if tag.Name != "" { return tag.Name } else { break } } } return inflection.Plural(strcase.ToSnake(t.Name())) } else { return inflection.Plural(strcase.ToSnake(t.Name())) } } func (m *Migrator) getAllowedKinds() []reflect.Kind { if m.allowedKinds == nil { m.allowedKinds = []reflect.Kind{ reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String, } } return m.allowedKinds } func (m *Migrator) getAllowedTypes() []string { if m.allowedTypes == nil { m.allowedTypes = []string{"string", "int", "int32", "int64", "time.Time", "float64", "bool", "uint", "uint8", "uint16", "uint32", "uint64", "uuid.UUID"} for _, t := range m.allowedTypes { m.allowedTypes = append(m.allowedTypes, "*"+t) m.allowedTypes = append(m.allowedTypes, "[]"+t) m.allowedTypes = append(m.allowedTypes, "[]*"+t) } m.allowedTypes = append(m.allowedTypes, "map[string]interface {}") m.allowedTypes = append(m.allowedTypes, "interface {}") m.allowedTypes = append(m.allowedTypes, "struct {}") } return m.allowedTypes } // GetFields returns struct fields and database related metadata func (m *Migrator) GetFields(t reflect.Type) []*Column { var ( res []*Column ) if t.Kind() == reflect.Ptr { t = t.Elem() } if t.Kind() != reflect.Struct { return res } for i := 0; i < t.NumField(); i++ { field := t.Field(i) if field.Anonymous { // embedded struct result := m.GetFields(field.Type) res = append(res, result...) continue } column := &Column{ ColumnName: strcase.ToSnake(field.Name), } fieldType := m.GetType(field.Type) tag := field.Tag.Get("pg") ignoreField := false if len(tag) > 0 { tags := strings.Split(tag, ",") for i, tagS := range tags { s := strings.ToLower(strings.TrimSpace(tagS)) if s == "-" { ignoreField = true } else if s == "unique" { column.Unique = true } else if s == "index" { column.Index = true } else if strings.Contains(s, "unique:") { column.CompositeKey = strings.Replace(s, "unique:", "", 1) } else if strings.Contains(s, "type:") { ss := strings.Split(s, "type:") if len(ss) > 1 { column.DataType = ss[1] } } else if strings.Contains(s, "alias:") { column.ColumnName = strings.Replace(s, "alias:", "", 1) } else if i == 0 && len(s) > 0 && !strings.Contains(s, ":") { column.ColumnName = s } else if s == "notnull" { column.IsNullable = false } } } accept := m.isAllowedType(fieldType) || len(column.DataType) > 0 if ignoreField || !accept { continue } m.configColumn(fieldType, column) res = append(res, column) } return res } func (m *Migrator) GetType(field reflect.Type) string { fieldType := field.String() if m.isAllowedType(fieldType) { return fieldType } // Try to get the underlying data type if m.isAllowedKind(field.Kind()) { return field.Kind().String() } return fieldType } func (m *Migrator) isAllowedType(fieldType string) bool { for _, at := range m.getAllowedTypes() { if at == fieldType { return true } } return false } func (m *Migrator) isAllowedKind(kind reflect.Kind) bool { for _, k := range m.getAllowedKinds() { if k == kind { return true } } return false } func (m *Migrator) configColumn(fieldType string, column *Column) { switch fieldType { case "string": column.DataType = "text" column.IsNullable = false column.ColumnDefault = "''" case "*string": column.DataType = "text" column.IsNullable = true case "[]string", "[]*string": column.DataType = "text[]" case "int": column.DataType = "integer" column.IsNullable = false case "int64": column.DataType = "bigint" column.IsNullable = false case "*int64": column.DataType = "bigint" column.IsNullable = true case "[]int64", "[]*int64": column.DataType = "integer[]" column.IsNullable = true case "time.Time": column.IsNullable = false column.DataType = "timestamp with time zone" column.ColumnDefault = "NOW()" case "*time.Time": column.DataType = "timestamp with time zone" column.IsNullable = true case "float64": column.DataType = "numeric" column.IsNullable = false column.ColumnDefault = "0.00" case "*float64": column.DataType = "numeric" column.IsNullable = true case "[]float64", "[]*float64": column.DataType = "numeric[]" column.IsNullable = true case "bool": column.DataType = "boolean" column.IsNullable = false column.ColumnDefault = "false" case "*bool": column.DataType = "boolean" column.IsNullable = true case "[]bool", "[]*bool": column.DataType = "boolean[]" column.IsNullable = true case "map[string]interface", "interface": column.DataType = "jsonb" column.IsNullable = true case "uint", "uint32", "uint64": column.DataType = "bigint" column.IsNullable = false case "[]uint8": column.DataType = "bytea" case "uuid.UUID": column.DataType = "uuid" case "[]byte", "common.BinaryHex", "*common.BinaryHex": column.DataType = "bytea" column.IsNullable = true default: logger.Warn().Msgf("%s unknown type %s", column.ColumnName, fieldType) } } func (m *Migrator) getSchema() string { if m.Schema == "" { m.Schema = "public" } return m.Schema } func (m *Migrator) TableExists(model interface{}) error { var count []*CountResult tablename := m.GetTableName(model) result, err := m.DB.QueryOne(&count, "SELECT COUNT(*) FROM information_schema.TABLES WHERE table_name = ? AND table_schema = ? AND table_type = 'BASE TABLE'", tablename, m.getSchema()) if err != nil { return errors.WithStack(err) } if result.RowsReturned() == 0 || count[0].Count == 0 { return errors.WithMessagef(ErrTableDoesntExist, "for model %t", model) } return nil } // UnmatchedFields returns list of model fields that were not found in the DB. func (m *Migrator) UnmatchedFields(model interface{}) error { tablename := m.GetTableName(model) dbColumns, err := m.getExistingColumns(model) if err != nil { return err } modelFields := m.GetFields(reflect.TypeOf(model)) mapDBColumns := m.makeColumnMap(dbColumns) mapModelColumns := m.makeColumnMap(modelFields) // Check if db table column exists in model for _, c := range dbColumns { if c.DataType == "tsvector" { continue } if modelProp, ok := mapModelColumns[c.ColumnName]; !ok { logger.Warn().Msgf("%t:%s is not in model. Consider dropping from db.", model, c.ColumnName) } else if modelProp.DataType != c.DataType { logger.Warn().Msgf("Type mismatch db %s:%s should be type %s instead of %s", tablename, c.ColumnName, modelProp.DataType, c.DataType) } } // Check if model property exists in database table for colName := range mapModelColumns { if _, ok := mapDBColumns[colName]; !ok { logger.Warn().Msgf("%s:%s is not in db. Consider dropping from model or add to table.", tablename, colName) } } return nil }