462 lines
11 KiB
Go
462 lines
11 KiB
Go
package db
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/go-pg/pg/v10"
|
|
"github.com/go-pg/pg/v10/orm"
|
|
"github.com/iancoleman/strcase"
|
|
"github.com/jinzhu/inflection"
|
|
"github.com/pkg/errors"
|
|
"github.com/vmihailenco/tagparser"
|
|
|
|
"fiskerinc.com/modules/logger"
|
|
)
|
|
|
|
type Migrator struct {
|
|
DB *pg.DB
|
|
DropColumns bool
|
|
DryRun bool
|
|
Schema string
|
|
allowedTypes []string
|
|
allowedKinds []reflect.Kind
|
|
}
|
|
|
|
// Column describes one table column. It is filled either by scanning
// information_schema rows (getExistingColumns) or from a model struct field
// (GetFields).
type Column struct {
	ColumnName      string
	IsNullable      bool
	DataType        string // SQL type, e.g. "text" or "bigint"
	OrdinalPosition int
	ColumnDefault   string // SQL default expression, e.g. "''" or "NOW()"
	Unique, Index   bool
	// CompositeKey is the group name taken from a `pg:"unique:<name>"` tag.
	CompositeKey string
}
|
|
|
|
// CountResult receives the single `count` column of a COUNT(*) query.
type CountResult struct{ Count int32 }
|
|
|
|
func (m *Migrator) Close() {
|
|
m.DB = nil
|
|
m.allowedTypes = nil
|
|
m.allowedKinds = nil
|
|
}
|
|
|
|
func (m *Migrator) Check(model interface{}) error {
|
|
err := m.TableExists(model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = orm.NewModel(model)
|
|
if err != nil {
|
|
logger.Warn().Err(err).Send()
|
|
}
|
|
|
|
err = m.UnmatchedFields(model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) DropTable(model interface{}) error {
|
|
sql := m.sqlDropTable()
|
|
_, err := m.DB.Model(model).Exec(sql)
|
|
return err
|
|
}
|
|
|
|
func (m *Migrator) updateColumns(model interface{}) error {
|
|
dbColumns, err := m.getExistingColumns(model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
modelFields := m.GetFields(reflect.TypeOf(model))
|
|
table := m.GetTableName(model)
|
|
mapOldColumns := m.makeColumnMap(dbColumns)
|
|
mapModelColumns := m.makeColumnMap(modelFields)
|
|
|
|
tx, err := m.DB.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Close()
|
|
|
|
// Check for dropped columns
|
|
for _, c := range dbColumns {
|
|
if c.DataType == "tsvector" {
|
|
continue
|
|
}
|
|
if _, ok := mapModelColumns[c.ColumnName]; !ok {
|
|
logger.Warn().Msgf("%s:%s is not in model. Consider dropping from db.", table, c.ColumnName)
|
|
}
|
|
}
|
|
|
|
// Check for new columns
|
|
for _, c := range modelFields {
|
|
if c.ColumnName == "id" || c.ColumnName == "table_name" {
|
|
continue
|
|
}
|
|
|
|
if _, ok := mapOldColumns[c.ColumnName]; !ok {
|
|
sql := m.sqlAddColumn(c)
|
|
r, err := m.DB.Model(model).Exec(sql)
|
|
if err != nil {
|
|
logger.Error().Err(err).Send()
|
|
tx.Rollback()
|
|
return err
|
|
} else if r.RowsAffected() > 0 {
|
|
logger.Info().Msgf("added column %s:%s", table, c.ColumnName)
|
|
}
|
|
if c.CompositeKey != "" {
|
|
logger.Warn().Msgf("%s:%s:%s composite key not created", table, c.ColumnName, c.CompositeKey)
|
|
}
|
|
}
|
|
}
|
|
|
|
tx.Commit()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) makeColumnMap(columns []*Column) map[string]*Column {
|
|
result := make(map[string]*Column, len(columns))
|
|
|
|
for _, c := range columns {
|
|
if c.DataType != "struct{}" {
|
|
result[c.ColumnName] = c
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
func (m *Migrator) sqlDropTable() string {
|
|
return "DROP TABLE ?TableName CASCADE"
|
|
}
|
|
|
|
func (m *Migrator) sqlAddColumn(c *Column) string {
|
|
sql := fmt.Sprintf("ALTER TABLE ?TableName ADD COLUMN \"%s\" %s", c.ColumnName, c.DataType)
|
|
if c.Unique {
|
|
sql += " UNIQUE"
|
|
}
|
|
if !c.IsNullable {
|
|
sql += " NOT NULL"
|
|
sql += " DEFAULT " + c.ColumnDefault
|
|
}
|
|
|
|
return sql
|
|
}
|
|
|
|
func (m *Migrator) getExistingColumns(model interface{}) ([]*Column, error) {
|
|
var columns []*Column
|
|
|
|
table := m.GetTableName(model)
|
|
_, err := m.DB.Query(&columns, "SELECT column_name, ordinal_position, column_default, is_nullable, data_type FROM information_schema.COLUMNS WHERE TABLE_NAME = ?", table)
|
|
|
|
return columns, err
|
|
}
|
|
|
|
func (m *Migrator) createTable(model interface{}) error {
|
|
return m.DB.Model(model).CreateTable(&orm.CreateTableOptions{
|
|
IfNotExists: true,
|
|
FKConstraints: true,
|
|
Temp: true,
|
|
})
|
|
}
|
|
|
|
func (m *Migrator) GetTableName(model interface{}) string {
|
|
t := reflect.TypeOf(model)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
// check if table name is being overriden by struct tag
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
if f.Name == "tableName" {
|
|
tag := tagparser.Parse(f.Tag.Get("pg"))
|
|
if tag.Name != "" {
|
|
return tag.Name
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return inflection.Plural(strcase.ToSnake(t.Name()))
|
|
} else {
|
|
return inflection.Plural(strcase.ToSnake(t.Name()))
|
|
}
|
|
}
|
|
|
|
func (m *Migrator) getAllowedKinds() []reflect.Kind {
|
|
if m.allowedKinds == nil {
|
|
m.allowedKinds = []reflect.Kind{
|
|
reflect.Bool,
|
|
reflect.Int,
|
|
reflect.Int8,
|
|
reflect.Int16,
|
|
reflect.Int32,
|
|
reflect.Int64,
|
|
reflect.Uint,
|
|
reflect.Uint8,
|
|
reflect.Uint16,
|
|
reflect.Uint32,
|
|
reflect.Uint64,
|
|
reflect.Float32,
|
|
reflect.Float64,
|
|
reflect.String,
|
|
}
|
|
}
|
|
|
|
return m.allowedKinds
|
|
}
|
|
|
|
func (m *Migrator) getAllowedTypes() []string {
|
|
if m.allowedTypes == nil {
|
|
m.allowedTypes = []string{"string", "int", "int32", "int64", "time.Time", "float64", "bool", "uint", "uint8", "uint16", "uint32", "uint64", "uuid.UUID"}
|
|
for _, t := range m.allowedTypes {
|
|
m.allowedTypes = append(m.allowedTypes, "*"+t)
|
|
m.allowedTypes = append(m.allowedTypes, "[]"+t)
|
|
m.allowedTypes = append(m.allowedTypes, "[]*"+t)
|
|
}
|
|
m.allowedTypes = append(m.allowedTypes, "map[string]interface {}")
|
|
m.allowedTypes = append(m.allowedTypes, "interface {}")
|
|
m.allowedTypes = append(m.allowedTypes, "struct {}")
|
|
}
|
|
|
|
return m.allowedTypes
|
|
}
|
|
|
|
// GetFields returns struct fields and database related metadata
|
|
func (m *Migrator) GetFields(t reflect.Type) []*Column {
|
|
var (
|
|
res []*Column
|
|
)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
if t.Kind() != reflect.Struct {
|
|
return res
|
|
}
|
|
for i := 0; i < t.NumField(); i++ {
|
|
field := t.Field(i)
|
|
if field.Anonymous {
|
|
// embedded struct
|
|
result := m.GetFields(field.Type)
|
|
res = append(res, result...)
|
|
continue
|
|
}
|
|
column := &Column{
|
|
ColumnName: strcase.ToSnake(field.Name),
|
|
}
|
|
|
|
fieldType := m.GetType(field.Type)
|
|
|
|
tag := field.Tag.Get("pg")
|
|
ignoreField := false
|
|
|
|
if len(tag) > 0 {
|
|
tags := strings.Split(tag, ",")
|
|
for i, tagS := range tags {
|
|
s := strings.ToLower(strings.TrimSpace(tagS))
|
|
if s == "-" {
|
|
ignoreField = true
|
|
} else if s == "unique" {
|
|
column.Unique = true
|
|
} else if s == "index" {
|
|
column.Index = true
|
|
} else if strings.Contains(s, "unique:") {
|
|
column.CompositeKey = strings.Replace(s, "unique:", "", 1)
|
|
} else if strings.Contains(s, "type:") {
|
|
ss := strings.Split(s, "type:")
|
|
if len(ss) > 1 {
|
|
column.DataType = ss[1]
|
|
}
|
|
} else if strings.Contains(s, "alias:") {
|
|
column.ColumnName = strings.Replace(s, "alias:", "", 1)
|
|
} else if i == 0 && len(s) > 0 && !strings.Contains(s, ":") {
|
|
column.ColumnName = s
|
|
} else if s == "notnull" {
|
|
column.IsNullable = false
|
|
}
|
|
}
|
|
}
|
|
|
|
accept := m.isAllowedType(fieldType) || len(column.DataType) > 0
|
|
|
|
if ignoreField || !accept {
|
|
continue
|
|
}
|
|
|
|
m.configColumn(fieldType, column)
|
|
|
|
res = append(res, column)
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (m *Migrator) GetType(field reflect.Type) string {
|
|
fieldType := field.String()
|
|
|
|
if m.isAllowedType(fieldType) {
|
|
return fieldType
|
|
}
|
|
|
|
// Try to get the underlying data type
|
|
if m.isAllowedKind(field.Kind()) {
|
|
return field.Kind().String()
|
|
}
|
|
|
|
return fieldType
|
|
}
|
|
|
|
func (m *Migrator) isAllowedType(fieldType string) bool {
|
|
for _, at := range m.getAllowedTypes() {
|
|
if at == fieldType {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (m *Migrator) isAllowedKind(kind reflect.Kind) bool {
|
|
for _, k := range m.getAllowedKinds() {
|
|
if k == kind {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (m *Migrator) configColumn(fieldType string, column *Column) {
|
|
switch fieldType {
|
|
case "string":
|
|
column.DataType = "text"
|
|
column.IsNullable = false
|
|
column.ColumnDefault = "''"
|
|
case "*string":
|
|
column.DataType = "text"
|
|
column.IsNullable = true
|
|
case "[]string", "[]*string":
|
|
column.DataType = "text[]"
|
|
case "int":
|
|
column.DataType = "integer"
|
|
column.IsNullable = false
|
|
case "int64":
|
|
column.DataType = "bigint"
|
|
column.IsNullable = false
|
|
case "*int64":
|
|
column.DataType = "bigint"
|
|
column.IsNullable = true
|
|
case "[]int64", "[]*int64":
|
|
column.DataType = "integer[]"
|
|
column.IsNullable = true
|
|
case "time.Time":
|
|
column.IsNullable = false
|
|
column.DataType = "timestamp with time zone"
|
|
column.ColumnDefault = "NOW()"
|
|
case "*time.Time":
|
|
column.DataType = "timestamp with time zone"
|
|
column.IsNullable = true
|
|
case "float64":
|
|
column.DataType = "numeric"
|
|
column.IsNullable = false
|
|
column.ColumnDefault = "0.00"
|
|
case "*float64":
|
|
column.DataType = "numeric"
|
|
column.IsNullable = true
|
|
case "[]float64", "[]*float64":
|
|
column.DataType = "numeric[]"
|
|
column.IsNullable = true
|
|
case "bool":
|
|
column.DataType = "boolean"
|
|
column.IsNullable = false
|
|
column.ColumnDefault = "false"
|
|
case "*bool":
|
|
column.DataType = "boolean"
|
|
column.IsNullable = true
|
|
case "[]bool", "[]*bool":
|
|
column.DataType = "boolean[]"
|
|
column.IsNullable = true
|
|
case "map[string]interface", "interface":
|
|
column.DataType = "jsonb"
|
|
column.IsNullable = true
|
|
case "uint", "uint32", "uint64":
|
|
column.DataType = "bigint"
|
|
column.IsNullable = false
|
|
case "[]uint8":
|
|
column.DataType = "bytea"
|
|
case "uuid.UUID":
|
|
column.DataType = "uuid"
|
|
case "[]byte", "common.BinaryHex", "*common.BinaryHex":
|
|
column.DataType = "bytea"
|
|
column.IsNullable = true
|
|
default:
|
|
logger.Warn().Msgf("%s unknown type %s", column.ColumnName, fieldType)
|
|
}
|
|
}
|
|
|
|
func (m *Migrator) getSchema() string {
|
|
if m.Schema == "" {
|
|
m.Schema = "public"
|
|
}
|
|
return m.Schema
|
|
}
|
|
|
|
func (m *Migrator) TableExists(model interface{}) error {
|
|
var count []*CountResult
|
|
tablename := m.GetTableName(model)
|
|
|
|
result, err := m.DB.QueryOne(&count, "SELECT COUNT(*) FROM information_schema.TABLES WHERE table_name = ? AND table_schema = ? AND table_type = 'BASE TABLE'", tablename, m.getSchema())
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
if result.RowsReturned() == 0 || count[0].Count == 0 {
|
|
return errors.WithMessagef(ErrTableDoesntExist, "for model %t", model)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UnmatchedFields returns list of model fields that were not found in the DB.
|
|
func (m *Migrator) UnmatchedFields(model interface{}) error {
|
|
tablename := m.GetTableName(model)
|
|
dbColumns, err := m.getExistingColumns(model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
modelFields := m.GetFields(reflect.TypeOf(model))
|
|
|
|
mapDBColumns := m.makeColumnMap(dbColumns)
|
|
mapModelColumns := m.makeColumnMap(modelFields)
|
|
|
|
// Check if db table column exists in model
|
|
for _, c := range dbColumns {
|
|
if c.DataType == "tsvector" {
|
|
continue
|
|
}
|
|
if modelProp, ok := mapModelColumns[c.ColumnName]; !ok {
|
|
logger.Warn().Msgf("%t:%s is not in model. Consider dropping from db.", model, c.ColumnName)
|
|
} else if modelProp.DataType != c.DataType {
|
|
logger.Warn().Msgf("Type mismatch db %s:%s should be type %s instead of %s", tablename, c.ColumnName, modelProp.DataType, c.DataType)
|
|
}
|
|
}
|
|
|
|
// Check if model property exists in database table
|
|
for colName := range mapModelColumns {
|
|
if _, ok := mapDBColumns[colName]; !ok {
|
|
logger.Warn().Msgf("%s:%s is not in db. Consider dropping from model or add to table.", tablename, colName)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|