Initial cloud-services repo - gateway service + pkg modules
This commit is contained in:
461
pkg/db/migrator.go
Normal file
461
pkg/db/migrator.go
Normal file
@@ -0,0 +1,461 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vmihailenco/tagparser"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
DB *pg.DB
|
||||
DropColumns bool
|
||||
DryRun bool
|
||||
Schema string
|
||||
allowedTypes []string
|
||||
allowedKinds []reflect.Kind
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
ColumnName string
|
||||
IsNullable bool
|
||||
DataType string
|
||||
OrdinalPosition int
|
||||
ColumnDefault string
|
||||
Unique bool
|
||||
Index bool
|
||||
CompositeKey string
|
||||
}
|
||||
|
||||
type CountResult struct {
|
||||
Count int32
|
||||
}
|
||||
|
||||
func (m *Migrator) Close() {
|
||||
m.DB = nil
|
||||
m.allowedTypes = nil
|
||||
m.allowedKinds = nil
|
||||
}
|
||||
|
||||
func (m *Migrator) Check(model interface{}) error {
|
||||
err := m.TableExists(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = orm.NewModel(model)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Send()
|
||||
}
|
||||
|
||||
err = m.UnmatchedFields(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) DropTable(model interface{}) error {
|
||||
sql := m.sqlDropTable()
|
||||
_, err := m.DB.Model(model).Exec(sql)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Migrator) updateColumns(model interface{}) error {
|
||||
dbColumns, err := m.getExistingColumns(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelFields := m.GetFields(reflect.TypeOf(model))
|
||||
table := m.GetTableName(model)
|
||||
mapOldColumns := m.makeColumnMap(dbColumns)
|
||||
mapModelColumns := m.makeColumnMap(modelFields)
|
||||
|
||||
tx, err := m.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Close()
|
||||
|
||||
// Check for dropped columns
|
||||
for _, c := range dbColumns {
|
||||
if c.DataType == "tsvector" {
|
||||
continue
|
||||
}
|
||||
if _, ok := mapModelColumns[c.ColumnName]; !ok {
|
||||
logger.Warn().Msgf("%s:%s is not in model. Consider dropping from db.", table, c.ColumnName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for new columns
|
||||
for _, c := range modelFields {
|
||||
if c.ColumnName == "id" || c.ColumnName == "table_name" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := mapOldColumns[c.ColumnName]; !ok {
|
||||
sql := m.sqlAddColumn(c)
|
||||
r, err := m.DB.Model(model).Exec(sql)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
tx.Rollback()
|
||||
return err
|
||||
} else if r.RowsAffected() > 0 {
|
||||
logger.Info().Msgf("added column %s:%s", table, c.ColumnName)
|
||||
}
|
||||
if c.CompositeKey != "" {
|
||||
logger.Warn().Msgf("%s:%s:%s composite key not created", table, c.ColumnName, c.CompositeKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) makeColumnMap(columns []*Column) map[string]*Column {
|
||||
result := make(map[string]*Column, len(columns))
|
||||
|
||||
for _, c := range columns {
|
||||
if c.DataType != "struct{}" {
|
||||
result[c.ColumnName] = c
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *Migrator) sqlDropTable() string {
|
||||
return "DROP TABLE ?TableName CASCADE"
|
||||
}
|
||||
|
||||
func (m *Migrator) sqlAddColumn(c *Column) string {
|
||||
sql := fmt.Sprintf("ALTER TABLE ?TableName ADD COLUMN \"%s\" %s", c.ColumnName, c.DataType)
|
||||
if c.Unique {
|
||||
sql += " UNIQUE"
|
||||
}
|
||||
if !c.IsNullable {
|
||||
sql += " NOT NULL"
|
||||
sql += " DEFAULT " + c.ColumnDefault
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func (m *Migrator) getExistingColumns(model interface{}) ([]*Column, error) {
|
||||
var columns []*Column
|
||||
|
||||
table := m.GetTableName(model)
|
||||
_, err := m.DB.Query(&columns, "SELECT column_name, ordinal_position, column_default, is_nullable, data_type FROM information_schema.COLUMNS WHERE TABLE_NAME = ?", table)
|
||||
|
||||
return columns, err
|
||||
}
|
||||
|
||||
func (m *Migrator) createTable(model interface{}) error {
|
||||
return m.DB.Model(model).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
FKConstraints: true,
|
||||
Temp: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Migrator) GetTableName(model interface{}) string {
|
||||
t := reflect.TypeOf(model)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
// check if table name is being overriden by struct tag
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Name == "tableName" {
|
||||
tag := tagparser.Parse(f.Tag.Get("pg"))
|
||||
if tag.Name != "" {
|
||||
return tag.Name
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return inflection.Plural(strcase.ToSnake(t.Name()))
|
||||
} else {
|
||||
return inflection.Plural(strcase.ToSnake(t.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migrator) getAllowedKinds() []reflect.Kind {
|
||||
if m.allowedKinds == nil {
|
||||
m.allowedKinds = []reflect.Kind{
|
||||
reflect.Bool,
|
||||
reflect.Int,
|
||||
reflect.Int8,
|
||||
reflect.Int16,
|
||||
reflect.Int32,
|
||||
reflect.Int64,
|
||||
reflect.Uint,
|
||||
reflect.Uint8,
|
||||
reflect.Uint16,
|
||||
reflect.Uint32,
|
||||
reflect.Uint64,
|
||||
reflect.Float32,
|
||||
reflect.Float64,
|
||||
reflect.String,
|
||||
}
|
||||
}
|
||||
|
||||
return m.allowedKinds
|
||||
}
|
||||
|
||||
func (m *Migrator) getAllowedTypes() []string {
|
||||
if m.allowedTypes == nil {
|
||||
m.allowedTypes = []string{"string", "int", "int32", "int64", "time.Time", "float64", "bool", "uint", "uint8", "uint16", "uint32", "uint64", "uuid.UUID"}
|
||||
for _, t := range m.allowedTypes {
|
||||
m.allowedTypes = append(m.allowedTypes, "*"+t)
|
||||
m.allowedTypes = append(m.allowedTypes, "[]"+t)
|
||||
m.allowedTypes = append(m.allowedTypes, "[]*"+t)
|
||||
}
|
||||
m.allowedTypes = append(m.allowedTypes, "map[string]interface {}")
|
||||
m.allowedTypes = append(m.allowedTypes, "interface {}")
|
||||
m.allowedTypes = append(m.allowedTypes, "struct {}")
|
||||
}
|
||||
|
||||
return m.allowedTypes
|
||||
}
|
||||
|
||||
// GetFields returns struct fields and database related metadata
|
||||
func (m *Migrator) GetFields(t reflect.Type) []*Column {
|
||||
var (
|
||||
res []*Column
|
||||
)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return res
|
||||
}
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if field.Anonymous {
|
||||
// embedded struct
|
||||
result := m.GetFields(field.Type)
|
||||
res = append(res, result...)
|
||||
continue
|
||||
}
|
||||
column := &Column{
|
||||
ColumnName: strcase.ToSnake(field.Name),
|
||||
}
|
||||
|
||||
fieldType := m.GetType(field.Type)
|
||||
|
||||
tag := field.Tag.Get("pg")
|
||||
ignoreField := false
|
||||
|
||||
if len(tag) > 0 {
|
||||
tags := strings.Split(tag, ",")
|
||||
for i, tagS := range tags {
|
||||
s := strings.ToLower(strings.TrimSpace(tagS))
|
||||
if s == "-" {
|
||||
ignoreField = true
|
||||
} else if s == "unique" {
|
||||
column.Unique = true
|
||||
} else if s == "index" {
|
||||
column.Index = true
|
||||
} else if strings.Contains(s, "unique:") {
|
||||
column.CompositeKey = strings.Replace(s, "unique:", "", 1)
|
||||
} else if strings.Contains(s, "type:") {
|
||||
ss := strings.Split(s, "type:")
|
||||
if len(ss) > 1 {
|
||||
column.DataType = ss[1]
|
||||
}
|
||||
} else if strings.Contains(s, "alias:") {
|
||||
column.ColumnName = strings.Replace(s, "alias:", "", 1)
|
||||
} else if i == 0 && len(s) > 0 && !strings.Contains(s, ":") {
|
||||
column.ColumnName = s
|
||||
} else if s == "notnull" {
|
||||
column.IsNullable = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accept := m.isAllowedType(fieldType) || len(column.DataType) > 0
|
||||
|
||||
if ignoreField || !accept {
|
||||
continue
|
||||
}
|
||||
|
||||
m.configColumn(fieldType, column)
|
||||
|
||||
res = append(res, column)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *Migrator) GetType(field reflect.Type) string {
|
||||
fieldType := field.String()
|
||||
|
||||
if m.isAllowedType(fieldType) {
|
||||
return fieldType
|
||||
}
|
||||
|
||||
// Try to get the underlying data type
|
||||
if m.isAllowedKind(field.Kind()) {
|
||||
return field.Kind().String()
|
||||
}
|
||||
|
||||
return fieldType
|
||||
}
|
||||
|
||||
func (m *Migrator) isAllowedType(fieldType string) bool {
|
||||
for _, at := range m.getAllowedTypes() {
|
||||
if at == fieldType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Migrator) isAllowedKind(kind reflect.Kind) bool {
|
||||
for _, k := range m.getAllowedKinds() {
|
||||
if k == kind {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Migrator) configColumn(fieldType string, column *Column) {
|
||||
switch fieldType {
|
||||
case "string":
|
||||
column.DataType = "text"
|
||||
column.IsNullable = false
|
||||
column.ColumnDefault = "''"
|
||||
case "*string":
|
||||
column.DataType = "text"
|
||||
column.IsNullable = true
|
||||
case "[]string", "[]*string":
|
||||
column.DataType = "text[]"
|
||||
case "int":
|
||||
column.DataType = "integer"
|
||||
column.IsNullable = false
|
||||
case "int64":
|
||||
column.DataType = "bigint"
|
||||
column.IsNullable = false
|
||||
case "*int64":
|
||||
column.DataType = "bigint"
|
||||
column.IsNullable = true
|
||||
case "[]int64", "[]*int64":
|
||||
column.DataType = "integer[]"
|
||||
column.IsNullable = true
|
||||
case "time.Time":
|
||||
column.IsNullable = false
|
||||
column.DataType = "timestamp with time zone"
|
||||
column.ColumnDefault = "NOW()"
|
||||
case "*time.Time":
|
||||
column.DataType = "timestamp with time zone"
|
||||
column.IsNullable = true
|
||||
case "float64":
|
||||
column.DataType = "numeric"
|
||||
column.IsNullable = false
|
||||
column.ColumnDefault = "0.00"
|
||||
case "*float64":
|
||||
column.DataType = "numeric"
|
||||
column.IsNullable = true
|
||||
case "[]float64", "[]*float64":
|
||||
column.DataType = "numeric[]"
|
||||
column.IsNullable = true
|
||||
case "bool":
|
||||
column.DataType = "boolean"
|
||||
column.IsNullable = false
|
||||
column.ColumnDefault = "false"
|
||||
case "*bool":
|
||||
column.DataType = "boolean"
|
||||
column.IsNullable = true
|
||||
case "[]bool", "[]*bool":
|
||||
column.DataType = "boolean[]"
|
||||
column.IsNullable = true
|
||||
case "map[string]interface", "interface":
|
||||
column.DataType = "jsonb"
|
||||
column.IsNullable = true
|
||||
case "uint", "uint32", "uint64":
|
||||
column.DataType = "bigint"
|
||||
column.IsNullable = false
|
||||
case "[]uint8":
|
||||
column.DataType = "bytea"
|
||||
case "uuid.UUID":
|
||||
column.DataType = "uuid"
|
||||
case "[]byte", "common.BinaryHex", "*common.BinaryHex":
|
||||
column.DataType = "bytea"
|
||||
column.IsNullable = true
|
||||
default:
|
||||
logger.Warn().Msgf("%s unknown type %s", column.ColumnName, fieldType)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migrator) getSchema() string {
|
||||
if m.Schema == "" {
|
||||
m.Schema = "public"
|
||||
}
|
||||
return m.Schema
|
||||
}
|
||||
|
||||
func (m *Migrator) TableExists(model interface{}) error {
|
||||
var count []*CountResult
|
||||
tablename := m.GetTableName(model)
|
||||
|
||||
result, err := m.DB.QueryOne(&count, "SELECT COUNT(*) FROM information_schema.TABLES WHERE table_name = ? AND table_schema = ? AND table_type = 'BASE TABLE'", tablename, m.getSchema())
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if result.RowsReturned() == 0 || count[0].Count == 0 {
|
||||
return errors.WithMessagef(ErrTableDoesntExist, "for model %t", model)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmatchedFields returns list of model fields that were not found in the DB.
|
||||
func (m *Migrator) UnmatchedFields(model interface{}) error {
|
||||
tablename := m.GetTableName(model)
|
||||
dbColumns, err := m.getExistingColumns(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFields := m.GetFields(reflect.TypeOf(model))
|
||||
|
||||
mapDBColumns := m.makeColumnMap(dbColumns)
|
||||
mapModelColumns := m.makeColumnMap(modelFields)
|
||||
|
||||
// Check if db table column exists in model
|
||||
for _, c := range dbColumns {
|
||||
if c.DataType == "tsvector" {
|
||||
continue
|
||||
}
|
||||
if modelProp, ok := mapModelColumns[c.ColumnName]; !ok {
|
||||
logger.Warn().Msgf("%t:%s is not in model. Consider dropping from db.", model, c.ColumnName)
|
||||
} else if modelProp.DataType != c.DataType {
|
||||
logger.Warn().Msgf("Type mismatch db %s:%s should be type %s instead of %s", tablename, c.ColumnName, modelProp.DataType, c.DataType)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model property exists in database table
|
||||
for colName := range mapModelColumns {
|
||||
if _, ok := mapDBColumns[colName]; !ok {
|
||||
logger.Warn().Msgf("%s:%s is not in db. Consider dropping from model or add to table.", tablename, colName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user