Initial cloud-services repo - gateway service + pkg modules
This commit is contained in:
110
pkg/db/db.go
Normal file
110
pkg/db/db.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/utils/envtool"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
pgtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-pg/pg.v10"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTableDoesntExist = errors.New("table doesn't exist")
|
||||
ErrTableAndModelDontMatch = errors.New("table and model columns don't match")
|
||||
)
|
||||
|
||||
func GetDefaultConn() *pg.DB {
|
||||
host := envtool.GetEnv("DB_HOST", "localhost")
|
||||
port := envtool.GetEnv("DB_PORT", "5432")
|
||||
user := envtool.GetEnv("DB_USER", "postgres")
|
||||
password := envtool.GetEnv("DB_PASSWORD", "REPLACE_ME")
|
||||
dbname := envtool.GetEnv("DB_NAME", "postgres")
|
||||
sslmode := envtool.GetEnv("DB_SSLMODE", "disable")
|
||||
poolSize := envtool.GetEnvInt("DB_POOLSIZE", 10)
|
||||
addr := fmt.Sprintf("%v:%v", host, port)
|
||||
|
||||
logger.Info().Msgf("Initializing database connection %s", addr)
|
||||
conn_str := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", user, password, host, port, dbname, sslmode)
|
||||
opts, err := pg.ParseURL(conn_str)
|
||||
opts.PoolSize = poolSize
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
}
|
||||
conn := pg.Connect(opts)
|
||||
|
||||
// Wrap the connection with the APM hook.
|
||||
pgtrace.Wrap(conn)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
type DBClientInterface interface {
|
||||
GetConn() *pg.DB
|
||||
SetConn(*pg.DB) error
|
||||
InitSchema([]interface{}) error
|
||||
RegisterManyToManyRel(tables []interface{})
|
||||
Close() error
|
||||
}
|
||||
|
||||
type DBClient struct {
|
||||
driver *pg.DB
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (db *DBClient) Close() error {
|
||||
logger.Info().Msg("Closing database connection")
|
||||
|
||||
if db.driver == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return db.driver.Close()
|
||||
}
|
||||
|
||||
func (db *DBClient) GetConn() *pg.DB {
|
||||
db.once.Do(func() {
|
||||
if db.driver != nil {
|
||||
return
|
||||
}
|
||||
db.driver = GetDefaultConn()
|
||||
})
|
||||
|
||||
return db.driver
|
||||
}
|
||||
|
||||
func (db *DBClient) SetConn(d *pg.DB) error {
|
||||
var err error
|
||||
if db.driver != nil {
|
||||
err = db.driver.Close()
|
||||
}
|
||||
|
||||
db.driver = d
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DBClient) InitSchema(models []interface{}) error {
|
||||
migrator := Migrator{
|
||||
DB: db.GetConn(),
|
||||
}
|
||||
defer migrator.Close()
|
||||
|
||||
for _, model := range models {
|
||||
err := migrator.Check(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DBClient) RegisterManyToManyRel(tables []interface{}) {
|
||||
for _, table := range tables {
|
||||
orm.RegisterTable(table)
|
||||
}
|
||||
}
|
||||
24
pkg/db/db_test.go
Normal file
24
pkg/db/db_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestDBClient(t *testing.T) {
|
||||
client := db.DBClient{}
|
||||
client.SetConn(pg.Connect(&pg.Options{}))
|
||||
defer client.Close()
|
||||
|
||||
err := client.InitSchema([]interface{}{
|
||||
(*m.Car)(nil),
|
||||
})
|
||||
if err == nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "TestDBCreateSchema", "error", err)
|
||||
}
|
||||
}
|
||||
461
pkg/db/migrator.go
Normal file
461
pkg/db/migrator.go
Normal file
@@ -0,0 +1,461 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/vmihailenco/tagparser"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
DB *pg.DB
|
||||
DropColumns bool
|
||||
DryRun bool
|
||||
Schema string
|
||||
allowedTypes []string
|
||||
allowedKinds []reflect.Kind
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
ColumnName string
|
||||
IsNullable bool
|
||||
DataType string
|
||||
OrdinalPosition int
|
||||
ColumnDefault string
|
||||
Unique bool
|
||||
Index bool
|
||||
CompositeKey string
|
||||
}
|
||||
|
||||
type CountResult struct {
|
||||
Count int32
|
||||
}
|
||||
|
||||
func (m *Migrator) Close() {
|
||||
m.DB = nil
|
||||
m.allowedTypes = nil
|
||||
m.allowedKinds = nil
|
||||
}
|
||||
|
||||
func (m *Migrator) Check(model interface{}) error {
|
||||
err := m.TableExists(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = orm.NewModel(model)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Send()
|
||||
}
|
||||
|
||||
err = m.UnmatchedFields(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) DropTable(model interface{}) error {
|
||||
sql := m.sqlDropTable()
|
||||
_, err := m.DB.Model(model).Exec(sql)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Migrator) updateColumns(model interface{}) error {
|
||||
dbColumns, err := m.getExistingColumns(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelFields := m.GetFields(reflect.TypeOf(model))
|
||||
table := m.GetTableName(model)
|
||||
mapOldColumns := m.makeColumnMap(dbColumns)
|
||||
mapModelColumns := m.makeColumnMap(modelFields)
|
||||
|
||||
tx, err := m.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Close()
|
||||
|
||||
// Check for dropped columns
|
||||
for _, c := range dbColumns {
|
||||
if c.DataType == "tsvector" {
|
||||
continue
|
||||
}
|
||||
if _, ok := mapModelColumns[c.ColumnName]; !ok {
|
||||
logger.Warn().Msgf("%s:%s is not in model. Consider dropping from db.", table, c.ColumnName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for new columns
|
||||
for _, c := range modelFields {
|
||||
if c.ColumnName == "id" || c.ColumnName == "table_name" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := mapOldColumns[c.ColumnName]; !ok {
|
||||
sql := m.sqlAddColumn(c)
|
||||
r, err := m.DB.Model(model).Exec(sql)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
tx.Rollback()
|
||||
return err
|
||||
} else if r.RowsAffected() > 0 {
|
||||
logger.Info().Msgf("added column %s:%s", table, c.ColumnName)
|
||||
}
|
||||
if c.CompositeKey != "" {
|
||||
logger.Warn().Msgf("%s:%s:%s composite key not created", table, c.ColumnName, c.CompositeKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) makeColumnMap(columns []*Column) map[string]*Column {
|
||||
result := make(map[string]*Column, len(columns))
|
||||
|
||||
for _, c := range columns {
|
||||
if c.DataType != "struct{}" {
|
||||
result[c.ColumnName] = c
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *Migrator) sqlDropTable() string {
|
||||
return "DROP TABLE ?TableName CASCADE"
|
||||
}
|
||||
|
||||
func (m *Migrator) sqlAddColumn(c *Column) string {
|
||||
sql := fmt.Sprintf("ALTER TABLE ?TableName ADD COLUMN \"%s\" %s", c.ColumnName, c.DataType)
|
||||
if c.Unique {
|
||||
sql += " UNIQUE"
|
||||
}
|
||||
if !c.IsNullable {
|
||||
sql += " NOT NULL"
|
||||
sql += " DEFAULT " + c.ColumnDefault
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func (m *Migrator) getExistingColumns(model interface{}) ([]*Column, error) {
|
||||
var columns []*Column
|
||||
|
||||
table := m.GetTableName(model)
|
||||
_, err := m.DB.Query(&columns, "SELECT column_name, ordinal_position, column_default, is_nullable, data_type FROM information_schema.COLUMNS WHERE TABLE_NAME = ?", table)
|
||||
|
||||
return columns, err
|
||||
}
|
||||
|
||||
func (m *Migrator) createTable(model interface{}) error {
|
||||
return m.DB.Model(model).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
FKConstraints: true,
|
||||
Temp: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Migrator) GetTableName(model interface{}) string {
|
||||
t := reflect.TypeOf(model)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
// check if table name is being overriden by struct tag
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Name == "tableName" {
|
||||
tag := tagparser.Parse(f.Tag.Get("pg"))
|
||||
if tag.Name != "" {
|
||||
return tag.Name
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return inflection.Plural(strcase.ToSnake(t.Name()))
|
||||
} else {
|
||||
return inflection.Plural(strcase.ToSnake(t.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migrator) getAllowedKinds() []reflect.Kind {
|
||||
if m.allowedKinds == nil {
|
||||
m.allowedKinds = []reflect.Kind{
|
||||
reflect.Bool,
|
||||
reflect.Int,
|
||||
reflect.Int8,
|
||||
reflect.Int16,
|
||||
reflect.Int32,
|
||||
reflect.Int64,
|
||||
reflect.Uint,
|
||||
reflect.Uint8,
|
||||
reflect.Uint16,
|
||||
reflect.Uint32,
|
||||
reflect.Uint64,
|
||||
reflect.Float32,
|
||||
reflect.Float64,
|
||||
reflect.String,
|
||||
}
|
||||
}
|
||||
|
||||
return m.allowedKinds
|
||||
}
|
||||
|
||||
func (m *Migrator) getAllowedTypes() []string {
|
||||
if m.allowedTypes == nil {
|
||||
m.allowedTypes = []string{"string", "int", "int32", "int64", "time.Time", "float64", "bool", "uint", "uint8", "uint16", "uint32", "uint64", "uuid.UUID"}
|
||||
for _, t := range m.allowedTypes {
|
||||
m.allowedTypes = append(m.allowedTypes, "*"+t)
|
||||
m.allowedTypes = append(m.allowedTypes, "[]"+t)
|
||||
m.allowedTypes = append(m.allowedTypes, "[]*"+t)
|
||||
}
|
||||
m.allowedTypes = append(m.allowedTypes, "map[string]interface {}")
|
||||
m.allowedTypes = append(m.allowedTypes, "interface {}")
|
||||
m.allowedTypes = append(m.allowedTypes, "struct {}")
|
||||
}
|
||||
|
||||
return m.allowedTypes
|
||||
}
|
||||
|
||||
// GetFields returns struct fields and database related metadata
|
||||
func (m *Migrator) GetFields(t reflect.Type) []*Column {
|
||||
var (
|
||||
res []*Column
|
||||
)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return res
|
||||
}
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if field.Anonymous {
|
||||
// embedded struct
|
||||
result := m.GetFields(field.Type)
|
||||
res = append(res, result...)
|
||||
continue
|
||||
}
|
||||
column := &Column{
|
||||
ColumnName: strcase.ToSnake(field.Name),
|
||||
}
|
||||
|
||||
fieldType := m.GetType(field.Type)
|
||||
|
||||
tag := field.Tag.Get("pg")
|
||||
ignoreField := false
|
||||
|
||||
if len(tag) > 0 {
|
||||
tags := strings.Split(tag, ",")
|
||||
for i, tagS := range tags {
|
||||
s := strings.ToLower(strings.TrimSpace(tagS))
|
||||
if s == "-" {
|
||||
ignoreField = true
|
||||
} else if s == "unique" {
|
||||
column.Unique = true
|
||||
} else if s == "index" {
|
||||
column.Index = true
|
||||
} else if strings.Contains(s, "unique:") {
|
||||
column.CompositeKey = strings.Replace(s, "unique:", "", 1)
|
||||
} else if strings.Contains(s, "type:") {
|
||||
ss := strings.Split(s, "type:")
|
||||
if len(ss) > 1 {
|
||||
column.DataType = ss[1]
|
||||
}
|
||||
} else if strings.Contains(s, "alias:") {
|
||||
column.ColumnName = strings.Replace(s, "alias:", "", 1)
|
||||
} else if i == 0 && len(s) > 0 && !strings.Contains(s, ":") {
|
||||
column.ColumnName = s
|
||||
} else if s == "notnull" {
|
||||
column.IsNullable = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accept := m.isAllowedType(fieldType) || len(column.DataType) > 0
|
||||
|
||||
if ignoreField || !accept {
|
||||
continue
|
||||
}
|
||||
|
||||
m.configColumn(fieldType, column)
|
||||
|
||||
res = append(res, column)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *Migrator) GetType(field reflect.Type) string {
|
||||
fieldType := field.String()
|
||||
|
||||
if m.isAllowedType(fieldType) {
|
||||
return fieldType
|
||||
}
|
||||
|
||||
// Try to get the underlying data type
|
||||
if m.isAllowedKind(field.Kind()) {
|
||||
return field.Kind().String()
|
||||
}
|
||||
|
||||
return fieldType
|
||||
}
|
||||
|
||||
func (m *Migrator) isAllowedType(fieldType string) bool {
|
||||
for _, at := range m.getAllowedTypes() {
|
||||
if at == fieldType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Migrator) isAllowedKind(kind reflect.Kind) bool {
|
||||
for _, k := range m.getAllowedKinds() {
|
||||
if k == kind {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Migrator) configColumn(fieldType string, column *Column) {
|
||||
switch fieldType {
|
||||
case "string":
|
||||
column.DataType = "text"
|
||||
column.IsNullable = false
|
||||
column.ColumnDefault = "''"
|
||||
case "*string":
|
||||
column.DataType = "text"
|
||||
column.IsNullable = true
|
||||
case "[]string", "[]*string":
|
||||
column.DataType = "text[]"
|
||||
case "int":
|
||||
column.DataType = "integer"
|
||||
column.IsNullable = false
|
||||
case "int64":
|
||||
column.DataType = "bigint"
|
||||
column.IsNullable = false
|
||||
case "*int64":
|
||||
column.DataType = "bigint"
|
||||
column.IsNullable = true
|
||||
case "[]int64", "[]*int64":
|
||||
column.DataType = "integer[]"
|
||||
column.IsNullable = true
|
||||
case "time.Time":
|
||||
column.IsNullable = false
|
||||
column.DataType = "timestamp with time zone"
|
||||
column.ColumnDefault = "NOW()"
|
||||
case "*time.Time":
|
||||
column.DataType = "timestamp with time zone"
|
||||
column.IsNullable = true
|
||||
case "float64":
|
||||
column.DataType = "numeric"
|
||||
column.IsNullable = false
|
||||
column.ColumnDefault = "0.00"
|
||||
case "*float64":
|
||||
column.DataType = "numeric"
|
||||
column.IsNullable = true
|
||||
case "[]float64", "[]*float64":
|
||||
column.DataType = "numeric[]"
|
||||
column.IsNullable = true
|
||||
case "bool":
|
||||
column.DataType = "boolean"
|
||||
column.IsNullable = false
|
||||
column.ColumnDefault = "false"
|
||||
case "*bool":
|
||||
column.DataType = "boolean"
|
||||
column.IsNullable = true
|
||||
case "[]bool", "[]*bool":
|
||||
column.DataType = "boolean[]"
|
||||
column.IsNullable = true
|
||||
case "map[string]interface", "interface":
|
||||
column.DataType = "jsonb"
|
||||
column.IsNullable = true
|
||||
case "uint", "uint32", "uint64":
|
||||
column.DataType = "bigint"
|
||||
column.IsNullable = false
|
||||
case "[]uint8":
|
||||
column.DataType = "bytea"
|
||||
case "uuid.UUID":
|
||||
column.DataType = "uuid"
|
||||
case "[]byte", "common.BinaryHex", "*common.BinaryHex":
|
||||
column.DataType = "bytea"
|
||||
column.IsNullable = true
|
||||
default:
|
||||
logger.Warn().Msgf("%s unknown type %s", column.ColumnName, fieldType)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migrator) getSchema() string {
|
||||
if m.Schema == "" {
|
||||
m.Schema = "public"
|
||||
}
|
||||
return m.Schema
|
||||
}
|
||||
|
||||
func (m *Migrator) TableExists(model interface{}) error {
|
||||
var count []*CountResult
|
||||
tablename := m.GetTableName(model)
|
||||
|
||||
result, err := m.DB.QueryOne(&count, "SELECT COUNT(*) FROM information_schema.TABLES WHERE table_name = ? AND table_schema = ? AND table_type = 'BASE TABLE'", tablename, m.getSchema())
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if result.RowsReturned() == 0 || count[0].Count == 0 {
|
||||
return errors.WithMessagef(ErrTableDoesntExist, "for model %t", model)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmatchedFields returns list of model fields that were not found in the DB.
|
||||
func (m *Migrator) UnmatchedFields(model interface{}) error {
|
||||
tablename := m.GetTableName(model)
|
||||
dbColumns, err := m.getExistingColumns(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFields := m.GetFields(reflect.TypeOf(model))
|
||||
|
||||
mapDBColumns := m.makeColumnMap(dbColumns)
|
||||
mapModelColumns := m.makeColumnMap(modelFields)
|
||||
|
||||
// Check if db table column exists in model
|
||||
for _, c := range dbColumns {
|
||||
if c.DataType == "tsvector" {
|
||||
continue
|
||||
}
|
||||
if modelProp, ok := mapModelColumns[c.ColumnName]; !ok {
|
||||
logger.Warn().Msgf("%t:%s is not in model. Consider dropping from db.", model, c.ColumnName)
|
||||
} else if modelProp.DataType != c.DataType {
|
||||
logger.Warn().Msgf("Type mismatch db %s:%s should be type %s instead of %s", tablename, c.ColumnName, modelProp.DataType, c.DataType)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model property exists in database table
|
||||
for colName := range mapModelColumns {
|
||||
if _, ok := mapDBColumns[colName]; !ok {
|
||||
logger.Warn().Msgf("%s:%s is not in db. Consider dropping from model or add to table.", tablename, colName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
152
pkg/db/migrator_test.go
Normal file
152
pkg/db/migrator_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/common/dbbasemodel"
|
||||
"fiskerinc.com/modules/db"
|
||||
)
|
||||
|
||||
var instance db.Migrator
|
||||
|
||||
type TestTable struct {
|
||||
ID int64
|
||||
Name string
|
||||
dbbasemodel.DBModelBase
|
||||
}
|
||||
|
||||
type TestTable2 struct {
|
||||
//lint:ignore U1000 metadata for go-pg to use alternative sql table name
|
||||
tableName struct{} `pg:"test_tables,alias:test_alias"`
|
||||
ID int64
|
||||
Name string `pg:",unique"`
|
||||
Address string `pg:",unique:group_name"`
|
||||
City string `pg:",unique:group_name"`
|
||||
dbbasemodel.DBModelBase
|
||||
}
|
||||
|
||||
type NonExistingTable struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
func TestMigratorIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
setupMigrator()
|
||||
migratorTableName(t)
|
||||
migratorTableExists(t)
|
||||
migrateTableFields(t)
|
||||
migrateTable(t)
|
||||
}
|
||||
|
||||
func setupMigrator() {
|
||||
instance = db.Migrator{
|
||||
DB: db.GetDefaultConn(),
|
||||
DryRun: false,
|
||||
}
|
||||
instance.DB.AddQueryHook(db.SQLLogger{})
|
||||
}
|
||||
|
||||
func migratorTableName(t *testing.T) {
|
||||
|
||||
tablename := instance.GetTableName((*TestTable)(nil))
|
||||
if tablename != "test_tables" {
|
||||
t.Error("Table name does not match")
|
||||
}
|
||||
|
||||
tablename = instance.GetTableName((*TestTable2)(nil))
|
||||
if tablename != "test_tables" {
|
||||
t.Error("Table name does not match")
|
||||
}
|
||||
}
|
||||
|
||||
func migrateTableFields(t *testing.T) {
|
||||
type TestCase struct {
|
||||
fieldName string
|
||||
fieldType string
|
||||
}
|
||||
tests := []TestCase{
|
||||
{
|
||||
fieldName: "id",
|
||||
fieldType: "bigint",
|
||||
},
|
||||
{
|
||||
fieldName: "name",
|
||||
fieldType: "text",
|
||||
},
|
||||
{
|
||||
fieldName: "created_at",
|
||||
fieldType: "timestamptz",
|
||||
},
|
||||
{
|
||||
fieldName: "updated_at",
|
||||
fieldType: "timestamptz",
|
||||
},
|
||||
}
|
||||
|
||||
columns := instance.GetFields(reflect.TypeOf((*TestTable)(nil)))
|
||||
if len(columns) != 4 {
|
||||
t.Error("Incorrect number of columns")
|
||||
}
|
||||
|
||||
test:
|
||||
for _, test := range tests {
|
||||
for _, col := range columns {
|
||||
if col.ColumnName == test.fieldName {
|
||||
if col.DataType == test.fieldType {
|
||||
continue test
|
||||
} else {
|
||||
t.Errorf("%s type %s is not %s", test.fieldName, col.DataType, test.fieldType)
|
||||
continue test
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("%s not found", test.fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
func migrateTable(t *testing.T) {
|
||||
err := instance.Check((*TestTable)(nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = instance.Check((*TestTable2)(nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = instance.Check((*TestTable)(nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = instance.DropTable((*TestTable)(nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func migratorTableExists(t *testing.T) {
|
||||
err := instance.TableExists((*NonExistingTable)(nil))
|
||||
if err == nil {
|
||||
t.Errorf("Table should not exist %v", err)
|
||||
}
|
||||
|
||||
err = instance.TableExists((*common.Car)(nil))
|
||||
if err != nil {
|
||||
t.Errorf("Table should exist %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMigrator(b *testing.B) {
|
||||
migrator := db.Migrator{
|
||||
DB: db.GetDefaultConn(),
|
||||
DryRun: false,
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
migrator.Check((*TestTable2)(nil))
|
||||
}
|
||||
}
|
||||
23
pkg/db/operation.go
Normal file
23
pkg/db/operation.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type Operation interface {
|
||||
Close() error
|
||||
Context() context.Context
|
||||
Exec(query interface{}, params ...interface{}) (orm.Result, error)
|
||||
ExecContext(c context.Context, query interface{}, params ...interface{}) (orm.Result, error)
|
||||
ExecOne(query interface{}, params ...interface{}) (orm.Result, error)
|
||||
ExecOneContext(c context.Context, query interface{}, params ...interface{}) (orm.Result, error)
|
||||
Formatter() orm.QueryFormatter
|
||||
Model(model ...interface{}) *orm.Query
|
||||
ModelContext(c context.Context, model ...interface{}) *orm.Query
|
||||
Query(model interface{}, query interface{}, params ...interface{}) (orm.Result, error)
|
||||
QueryContext(c context.Context, model interface{}, query interface{}, params ...interface{}) (orm.Result, error)
|
||||
QueryOne(model interface{}, query interface{}, params ...interface{}) (orm.Result, error)
|
||||
QueryOneContext(c context.Context, model interface{}, query interface{}, params ...interface{}) (orm.Result, error)
|
||||
}
|
||||
48
pkg/db/queries/action_logs.go
Normal file
48
pkg/db/queries/action_logs.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common/actionlogger"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type ActionLogInterface interface {
|
||||
Insert(log actionlogger.ActionLog) (err error)
|
||||
Select(filter actionlogger.ActionLogFilter) (logs []actionlogger.ActionLog, err error)
|
||||
}
|
||||
|
||||
type ActionLogDB struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
// Insert implements ActionLog.
|
||||
func (al *ActionLogDB) Insert(log actionlogger.ActionLog) (err error) {
|
||||
_, err = al.insert(&log)
|
||||
if err != nil {
|
||||
errors.WithStack(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Select implements ActionLog.
|
||||
func (al *ActionLogDB) Select(filter actionlogger.ActionLogFilter) (logs []actionlogger.ActionLog, err error) {
|
||||
query := al.GetDBConn().Model(&logs)
|
||||
if len(filter.VINs) > 0 {
|
||||
query.WhereIn("vin IN (?)", filter.VINs)
|
||||
}
|
||||
|
||||
if len(filter.Actions) > 0 {
|
||||
query.WhereIn("action IN (?)", filter.Actions)
|
||||
}
|
||||
|
||||
if filter.TrackingID != nil {
|
||||
query.Where("tracking_id = ?", &filter.TrackingID)
|
||||
}
|
||||
|
||||
err = query.Select()
|
||||
if err != nil {
|
||||
errors.WithStack(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var _ ActionLogInterface = &ActionLogDB{}
|
||||
60
pkg/db/queries/apicalls.go
Normal file
60
pkg/db/queries/apicalls.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type APICallsInterface interface {
|
||||
Insert(apiCall common.APICall) (orm.Result, error)
|
||||
Search(filter common.APICallsSearch, paging *PageQueryOptions) ([]common.APICall, int, error)
|
||||
}
|
||||
|
||||
type APICalls struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (kv *APICalls) Insert(apiCall common.APICall) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(apiCall)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return kv.insert(&apiCall)
|
||||
}
|
||||
|
||||
func (kv *APICalls) Search(filter common.APICallsSearch, paging *PageQueryOptions) ([]common.APICall, int, error) {
|
||||
calls := []common.APICall{}
|
||||
query := kv.GetDBConn().Model(&calls)
|
||||
|
||||
kv.pageQuery(query, paging)
|
||||
kv.applyFilters(query, filter)
|
||||
count, err := query.SelectAndCount()
|
||||
if err != nil {
|
||||
return nil, 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return calls, count, nil
|
||||
}
|
||||
|
||||
func (kv *APICalls) applyFilters(query *orm.Query, filter common.APICallsSearch) {
|
||||
if filter.Search != "" {
|
||||
search := strings.ToLower("%" + filter.Search + "%")
|
||||
query.Where("LOWER(client_id) LIKE ? "+
|
||||
"OR LOWER(endpoint) LIKE ? "+
|
||||
"OR LOWER(method) LIKE ?", search, search, search)
|
||||
}
|
||||
|
||||
if filter.From != nil {
|
||||
query.Where("created_at >= ?", filter.From)
|
||||
}
|
||||
|
||||
if filter.To != nil {
|
||||
query.Where("created_at <= ?", filter.To)
|
||||
}
|
||||
}
|
||||
109
pkg/db/queries/apitokens.go
Normal file
109
pkg/db/queries/apitokens.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type APITokensInterface interface {
|
||||
Delete(token string) (orm.Result, error)
|
||||
Insert(apitoken common.APIToken) (orm.Result, error)
|
||||
Update(apitoken *common.APIToken) (orm.Result, error)
|
||||
Get(token string) (*common.APIToken, error)
|
||||
Select(apitoken *common.APIToken, paging *PageQueryOptions) ([]common.APIToken, error)
|
||||
Count(apitoken *common.APIToken) (int, error)
|
||||
}
|
||||
|
||||
type APITokens struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (kv *APITokens) Delete(token string) (orm.Result, error) {
|
||||
if token == "" {
|
||||
return nil, errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "token required",
|
||||
})
|
||||
}
|
||||
|
||||
conn := kv.GetDBConn()
|
||||
result, err := conn.Model(&common.APIToken{
|
||||
Token: token,
|
||||
}).WherePK().Delete()
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (kv *APITokens) Insert(apiToken common.APIToken) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(apiToken)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return kv.insert(&apiToken)
|
||||
}
|
||||
|
||||
func (kv *APITokens) Get(token string) (*common.APIToken, error) {
|
||||
if token == "" {
|
||||
return nil, errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "token required",
|
||||
})
|
||||
}
|
||||
|
||||
keyvalues := []common.APIToken{}
|
||||
err := kv.GetDBConn().
|
||||
Model(&keyvalues).
|
||||
Where("token = ?", token).
|
||||
Where("expires_at > ? or expires_at is null", time.Now()).
|
||||
Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if len(keyvalues) == 0 {
|
||||
return nil, errors.New("token not found")
|
||||
}
|
||||
|
||||
return &keyvalues[0], nil
|
||||
}
|
||||
|
||||
func (kv *APITokens) selectFilter(query *orm.Query, filter *common.APIToken) {
|
||||
if filter.Token != "" {
|
||||
query.Where("token = ?", filter.Token)
|
||||
}
|
||||
|
||||
if filter.Roles != "" {
|
||||
query.Where("roles LIKE ?", fmt.Sprintf("%%%s%%", filter.Roles))
|
||||
}
|
||||
|
||||
if filter.Description != "" {
|
||||
query.Where("description = ?", filter.Description)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (kv *APITokens) Select(filter *common.APIToken, paging *PageQueryOptions) ([]common.APIToken, error) {
|
||||
items := []common.APIToken{}
|
||||
query := kv.GetDBConn().Model(&items)
|
||||
|
||||
kv.selectFilter(query, filter)
|
||||
if paging != nil {
|
||||
kv.pageQuery(query, paging)
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return items, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (kv *APITokens) Update(model *common.APIToken) (orm.Result, error) {
|
||||
return kv.update(model)
|
||||
}
|
||||
|
||||
func (kv *APITokens) Count(apitoken *common.APIToken) (int, error) {
|
||||
return kv.count(apitoken)
|
||||
}
|
||||
22
pkg/db/queries/car_config_data.go
Normal file
22
pkg/db/queries/car_config_data.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type CarConfigDataInterface interface {
|
||||
SelectByVIN(vin string) (common.CarConfigData, error)
|
||||
}
|
||||
|
||||
type CarConfigData struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *CarConfigData) SelectByVIN(vin string) (common.CarConfigData, error) {
|
||||
config := common.CarConfigData{}
|
||||
|
||||
err := c.GetDBConn().Model(&config).Where("vin = ?", vin).Select()
|
||||
|
||||
return config, errors.WithStack(err)
|
||||
}
|
||||
51
pkg/db/queries/car_versions_log.go
Normal file
51
pkg/db/queries/car_versions_log.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const versionsAtSQL = "SELECT version_source, version FROM public.car_version_logs WHERE id IN (SELECT MAX(id) as id FROM public.car_version_logs WHERE vin = ? AND created_at <= ? GROUP BY vin, version_source)"
|
||||
|
||||
type CarVersionsLogInterface interface {
|
||||
LogVersionChange(log *common.CarVersionLogs) (orm.Result, error)
|
||||
SelectByVIN(vin string, options *PageQueryOptions) ([]common.CarVersionLogs, int, error)
|
||||
GetCarVersions(vin string, timestamp time.Time) (map[string]string, error)
|
||||
}
|
||||
|
||||
// CarVersionsLog query methods
|
||||
type CarVersionsLog struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *CarVersionsLog) LogVersionChange(log *common.CarVersionLogs) (orm.Result, error) {
|
||||
return c.insert(log)
|
||||
}
|
||||
|
||||
func (c *CarVersionsLog) SelectByVIN(vin string, options *PageQueryOptions) ([]common.CarVersionLogs, int, error) {
|
||||
var logs []common.CarVersionLogs
|
||||
query := c.GetDBConn().Model(&logs).Where("vin = ?", vin)
|
||||
query = c.pageQuery(query, options)
|
||||
total, err := query.SelectAndCount()
|
||||
|
||||
return logs, total, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarVersionsLog) GetCarVersions(vin string, timestamp time.Time) (map[string]string, error) {
|
||||
logs := []common.CarVersionLogs{}
|
||||
result := map[string]string{}
|
||||
|
||||
_, err := c.GetDBConn().Query(&logs, versionsAtSQL, vin, timestamp)
|
||||
if err == nil {
|
||||
result = map[string]string{}
|
||||
for _, log := range logs {
|
||||
result[string(log.VersionSource)] = log.Version
|
||||
}
|
||||
}
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
1158
pkg/db/queries/cars.go
Normal file
1158
pkg/db/queries/cars.go
Normal file
File diff suppressed because it is too large
Load Diff
600
pkg/db/queries/cars_test.go
Normal file
600
pkg/db/queries/cars_test.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/google/uuid"
|
||||
"fiskerinc.com/modules/utils/elptr"
|
||||
)
|
||||
|
||||
const testCarVIN = "1GNGC26R1XJ407649"
|
||||
const testCarVIN2 = "1GNGC26R1XJ407648"
|
||||
|
||||
var qc queries.CarsInterface
|
||||
var conn *pg.DB
|
||||
var testCarID string
|
||||
var testDriverID string
|
||||
|
||||
func TestCarsIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
|
||||
defer testCarDelete(t, testCarVIN)
|
||||
defer testCarDelete(t, testCarVIN2)
|
||||
setupCarsTests()
|
||||
clearTestCar()
|
||||
testCarInsert(t, testCarVIN)
|
||||
testCarInsert(t, testCarVIN2)
|
||||
testCarECU(t)
|
||||
testCarAddDriver(t)
|
||||
testCarSelectByVIN(t)
|
||||
testCarSelect(t)
|
||||
testCarUpdate(t)
|
||||
testSetSetting(t)
|
||||
testCarSearch(t)
|
||||
testGetModels(t)
|
||||
testGetYears(t)
|
||||
testSelectCarToDriver(t)
|
||||
testGetCarsForDriver(t)
|
||||
testGetCount(t)
|
||||
testCarVINsSearch(t)
|
||||
}
|
||||
|
||||
func setupCarsTests() {
|
||||
instance := &queries.Cars{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
qc = instance
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*m.CarECU)(nil),
|
||||
(*m.Car)(nil),
|
||||
(*m.Driver)(nil),
|
||||
(*m.CarToDriver)(nil),
|
||||
(*m.CarSetting)(nil),
|
||||
})
|
||||
}
|
||||
|
||||
func testCarECU(t *testing.T) {
|
||||
ecu := m.CarECU{
|
||||
VIN: testCarVIN,
|
||||
ECU: "ECU1",
|
||||
Version: "1002",
|
||||
}
|
||||
|
||||
err := qc.UpdateCarECU(&ecu)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateCarECU insert", "No error", err)
|
||||
}
|
||||
|
||||
ecu.Version = "1000"
|
||||
err = qc.UpdateCarECU(&ecu)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateCarECU2 update", "No error", err)
|
||||
}
|
||||
|
||||
car := m.Car{
|
||||
VIN: testCarVIN,
|
||||
}
|
||||
err = qc.Load(&car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateCarECU2 update", "No error", err)
|
||||
}
|
||||
if !strings.Contains(car.ECUList, "ECU1 1000") {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateCarECU2 update", "ECU1 1001", car.ECUList)
|
||||
}
|
||||
}
|
||||
|
||||
func testCarSearch(t *testing.T) {
|
||||
search := m.CarSearch{
|
||||
Search: testCarVIN,
|
||||
}
|
||||
options := queries.PageQueryOptions{
|
||||
Offset: 0,
|
||||
Limit: 0,
|
||||
Order: "vin DESC",
|
||||
}
|
||||
|
||||
result, err := qc.Search(&search, &options)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Car Search", "No error", err)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Car Search result", 1, len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func testCarVINsSearch(t *testing.T) {
|
||||
search := m.CarSearch{
|
||||
VINs: testCarVIN + "," + testCarVIN2,
|
||||
}
|
||||
|
||||
options := queries.PageQueryOptions{
|
||||
Offset: 0,
|
||||
Limit: 0,
|
||||
Order: "vin DESC",
|
||||
}
|
||||
|
||||
result, err := qc.Search(&search, &options)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Car Search", "No error", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Car Search result", 2, len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func testSetSetting(t *testing.T) {
|
||||
setting := m.CarSetting{
|
||||
VIN: testCarVIN,
|
||||
DriverID: testDriverID,
|
||||
Name: "TestSetting",
|
||||
Value: "TestValue",
|
||||
}
|
||||
result, err := qc.SetSetting(&setting)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "testSetSetting", result, err)
|
||||
}
|
||||
}
|
||||
|
||||
func testCarInsert(t *testing.T, vin string) {
|
||||
car := m.Car{
|
||||
VIN: vin,
|
||||
Model: "Ocean",
|
||||
Year: 2022,
|
||||
Trim: "Base",
|
||||
}
|
||||
|
||||
_, err := qc.Insert(&car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectOrInsert", "nil", err)
|
||||
}
|
||||
|
||||
testCarID = car.VIN
|
||||
}
|
||||
|
||||
func testCarAddDriver(t *testing.T) {
|
||||
car := m.Car{
|
||||
VIN: testCarVIN,
|
||||
}
|
||||
|
||||
drivers := []m.Driver{
|
||||
{
|
||||
ID: "TEST-001",
|
||||
},
|
||||
{
|
||||
ID: "TEST-002",
|
||||
},
|
||||
}
|
||||
_, err := qc.AddDriver(&car, &m.Driver{}, "driver")
|
||||
if err == nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "AddDriver errored", "Error", err)
|
||||
}
|
||||
|
||||
qc.Load(&car)
|
||||
for _, driver := range drivers {
|
||||
_, err = conn.Model(&driver).Where("id = ?id").SelectOrInsert()
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Insert driver", "nil", err)
|
||||
}
|
||||
|
||||
_, err := qc.AddDriver(&car, &driver, "driver")
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "AddDriver", "nil", err)
|
||||
}
|
||||
|
||||
testDriverID = driver.ID
|
||||
|
||||
_, err = qc.AddDriver(&car, &driver, "driverX")
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "AddDriver", "nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
qc.Load(&car)
|
||||
if len(car.Drivers) != 2 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Added drivers", 2, len(car.Drivers))
|
||||
}
|
||||
}
|
||||
|
||||
func testCarSelectByVIN(t *testing.T) {
|
||||
car, err := qc.SelectByVIN(testCarVIN)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectByID error", "nil", err)
|
||||
}
|
||||
if car.VIN != testCarVIN {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "VIN", testCarVIN, car.VIN)
|
||||
}
|
||||
|
||||
err = qc.Load(car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Load error", "nil", err)
|
||||
}
|
||||
|
||||
if len(car.Drivers) != 2 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Drivers", 2, len(car.Drivers))
|
||||
}
|
||||
}
|
||||
|
||||
func testCarSelect(t *testing.T) {
|
||||
car := m.Car{
|
||||
Year: 2022,
|
||||
}
|
||||
options := queries.PageQueryOptions{
|
||||
Offset: 0,
|
||||
Limit: 0,
|
||||
Order: "vin DESC",
|
||||
}
|
||||
|
||||
cars, err := qc.Select(&car, &options)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Select error", "nil", err)
|
||||
}
|
||||
count := len(cars)
|
||||
if count == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", "more than 0", len(cars))
|
||||
}
|
||||
|
||||
options.Offset = count
|
||||
options.Limit = 10
|
||||
cars, err = qc.Select(&car, &options)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Select error", "nil", err)
|
||||
}
|
||||
if len(cars) != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", 0, len(cars))
|
||||
}
|
||||
}
|
||||
|
||||
func testCarUpdate(t *testing.T) {
|
||||
car := m.Car{
|
||||
VIN: testCarVIN,
|
||||
}
|
||||
|
||||
err := qc.Load(&car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Load", "nil", err)
|
||||
}
|
||||
|
||||
car.Year = 2020
|
||||
car.Model = "Ocean S"
|
||||
car.Trim = "Sport"
|
||||
result, err := qc.Update(&car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Update", "nil", err)
|
||||
}
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Update RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
|
||||
car = m.Car{
|
||||
VIN: testCarVIN,
|
||||
}
|
||||
err = qc.Load(&car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Reload", "nil", err)
|
||||
}
|
||||
if car.Year != 2020 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Update Year", 2020, car.Year)
|
||||
}
|
||||
if car.Model != "Ocean S" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Update Model", "Ocean S", car.Model)
|
||||
}
|
||||
if !car.CreatedAt.Before(*car.UpdatedAt) {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdatedAt", car.CreatedAt, car.UpdatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func testCarDelete(t *testing.T, vin string) {
|
||||
car := m.Car{
|
||||
Year: 2022,
|
||||
}
|
||||
|
||||
_, err := qc.Delete(&car)
|
||||
if err == nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "No ID", "No ID error", err)
|
||||
}
|
||||
|
||||
car = m.Car{
|
||||
VIN: vin,
|
||||
}
|
||||
carVIN := car.VIN
|
||||
|
||||
_, err = qc.Delete(&car)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Delete", "nil", err)
|
||||
}
|
||||
|
||||
cardrivers := []m.CarToDriver{}
|
||||
|
||||
count, err := conn.Model(&cardrivers).Where("vin = ?", carVIN).SelectAndCount()
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", "nil", err)
|
||||
}
|
||||
if count > 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", 0, count)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetModels(t *testing.T) {
|
||||
models, err := qc.GetModels()
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetModels", "nil", err)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Models Count", "More than 0", len(models))
|
||||
}
|
||||
}
|
||||
|
||||
func testGetYears(t *testing.T) {
|
||||
years, err := qc.GetYears()
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetYears", "nil", err)
|
||||
}
|
||||
if len(years) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Years Count", "More than 0", len(years))
|
||||
}
|
||||
}
|
||||
|
||||
func testSelectCarToDriver(t *testing.T) {
|
||||
filter := m.CarToDriver{
|
||||
DriverID: "TEST-001",
|
||||
}
|
||||
|
||||
drivers, err := qc.SelectCarToDriver(&filter)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectCarToDriver", "nil", err)
|
||||
} else if len(drivers) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectCarToDriver", "nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetCarsForDriver(t *testing.T) {
|
||||
carToDrivers, err := qc.GetCarsForDriver("TEST-001")
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetCarsForDriver", nil, err)
|
||||
}
|
||||
if len(carToDrivers) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Car to Drivers Count", "More than 0", len(carToDrivers))
|
||||
}
|
||||
}
|
||||
|
||||
func testGetCount(t *testing.T) {
|
||||
count, err := qc.Count(&m.Car{VIN: testCarVIN})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", nil, err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", 1, count)
|
||||
}
|
||||
}
|
||||
|
||||
func clearTestCar() {
|
||||
car := m.Car{
|
||||
VIN: testCarVIN,
|
||||
}
|
||||
result, err := conn.Model(&car).Where("vin = ?vin").Delete()
|
||||
fmt.Println(result, err)
|
||||
}
|
||||
|
||||
func TestSelectCarToDriver(t *testing.T) {
|
||||
t.Skip()
|
||||
cars := queries.Cars{}
|
||||
rel, err := cars.AddDriver(&m.Car{VIN: "1G1FP87S3GN100062"}, &m.Driver{ID: "ddf34966-9677-46de-b2eb-ddf501968ea5"}, "role")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer cars.RemoveDriver(rel.VIN, rel.DriverID)
|
||||
|
||||
subtypes := queries.SubscriptionTypes{}
|
||||
subtype := m.SubscriptionType{
|
||||
Name: fmt.Sprintf("test %v", uuid.New()),
|
||||
Description: "this is a test",
|
||||
Destination: "ICC",
|
||||
Currency: "USD",
|
||||
Price: 0,
|
||||
DurationValue: 1,
|
||||
DurationUnit: "Hours",
|
||||
}
|
||||
_, err = subtypes.Insert(&subtype)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer subtypes.Delete(&subtype)
|
||||
|
||||
subs := queries.Subscriptions{}
|
||||
sub, err := subs.Create(&subtype, rel)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer subs.Delete(&queries.SubscriptionDeleteRequest{ID: sub.ID})
|
||||
|
||||
setting := m.CarSetting{VIN: testCarVIN, DriverID: testDriverID, Name: "test123", Value: "XXXX"}
|
||||
_, err = cars.SetSetting(&setting)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer cars.DeleteSetting(&setting)
|
||||
|
||||
carToDrivers, err := cars.SelectCarToDriver(&m.CarToDriver{VIN: rel.VIN, DriverID: rel.DriverID})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(carToDrivers) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectCarToDriver count", 1, len(carToDrivers))
|
||||
return
|
||||
}
|
||||
if len(carToDrivers[0].Settings) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectCarToDriver settings", 1, len(carToDrivers))
|
||||
}
|
||||
if len(carToDrivers[0].Subscriptions) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectCarToDriver settings", 1, len(carToDrivers))
|
||||
}
|
||||
}
|
||||
|
||||
func queryToString(q *orm.Query) string {
|
||||
value, _ := q.AppendQuery(orm.NewFormatter(), nil)
|
||||
|
||||
return string(value)
|
||||
}
|
||||
|
||||
func Test_addOnlineFilter(t *testing.T) {
|
||||
type args struct {
|
||||
query *orm.Query
|
||||
filter *m.CarSearch
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "empty filter",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{},
|
||||
},
|
||||
expect: "SELECT *",
|
||||
},
|
||||
{
|
||||
name: "filter with online true but no online cars",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(true),
|
||||
VINsOnline: []string{},
|
||||
}},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin IN (''))",
|
||||
},
|
||||
{
|
||||
name: "filter with hmi online true but no online cars",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
HMI: elptr.ElPtr(true),
|
||||
VINsOnline: []string{},
|
||||
}},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin IN (''))",
|
||||
},
|
||||
|
||||
{
|
||||
name: "filter with online false",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(false),
|
||||
}},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin NOT IN (''))",
|
||||
},
|
||||
{
|
||||
name: "filter with hmi online false",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(false),
|
||||
}},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin NOT IN (''))",
|
||||
},
|
||||
{
|
||||
name: "filter with one of them being true 1",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{
|
||||
Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(true),
|
||||
HMI: elptr.ElPtr(false),
|
||||
},
|
||||
},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin IN (''))",
|
||||
},
|
||||
{
|
||||
name: "filter with one of them being true 2",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{
|
||||
Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(false),
|
||||
HMI: elptr.ElPtr(true),
|
||||
},
|
||||
},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin IN (''))",
|
||||
},
|
||||
{
|
||||
name: "filter with online true",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(true),
|
||||
VINsOnline: []string{"1G1FP87S3GN100062", "1G1FP87S3GN100063"},
|
||||
}},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin IN ('1G1FP87S3GN100062','1G1FP87S3GN100063'))",
|
||||
},
|
||||
{
|
||||
name: "filter with hmi online true",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(true),
|
||||
VINsOnline: []string{"1G1FP87S3GN100064", "1G1FP87S3GN100065"},
|
||||
}},
|
||||
},
|
||||
expect: "SELECT * WHERE (vin IN ('1G1FP87S3GN100064','1G1FP87S3GN100065'))",
|
||||
},
|
||||
{
|
||||
name: "filter with both being online true",
|
||||
args: args{
|
||||
query: new(orm.Query),
|
||||
filter: &m.CarSearch{Online: &m.CarOnlineFilter{
|
||||
Online: elptr.ElPtr(true),
|
||||
HMI: elptr.ElPtr(true),
|
||||
VINsOnline: []string{
|
||||
"1G1FP87S3GN100064",
|
||||
"1G1FP87S3GN100065",
|
||||
"1G1FP87S3GN100062",
|
||||
"1G1FP87S3GN100063"},
|
||||
}}},
|
||||
expect: "SELECT * WHERE (vin IN ('1G1FP87S3GN100064','1G1FP87S3GN100065','1G1FP87S3GN100062','1G1FP87S3GN100063'))",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
queries.AddOnlineFilter(tt.args.query, tt.args.filter)
|
||||
|
||||
if got := queryToString(tt.args.query); got != tt.expect {
|
||||
t.Errorf("addOnlineFilter() = %v, want %v", got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeLocal(t *testing.T) {
|
||||
instance := &queries.Cars{}
|
||||
conn = instance.GetDBConn()
|
||||
|
||||
certificat := &queries.Certificates{}
|
||||
//certificat.SetClient(conn)
|
||||
cert, err := certificat.SelectMostRecent("p", "j")
|
||||
t.Log(err)
|
||||
t.Log(cert)
|
||||
}
|
||||
426
pkg/db/queries/carupdates.go
Normal file
426
pkg/db/queries/carupdates.go
Normal file
@@ -0,0 +1,426 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/common/carupdatestatus"
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var FINAL_UPDATE_STATUS = carupdatestatus.FINAL_UPDATE_STATUS
|
||||
|
||||
type CarUpdatesInterface interface {
|
||||
Count(filter *common.CarUpdate) (int, error)
|
||||
Delete(update *common.CarUpdate) (orm.Result, error)
|
||||
CountUpdateStatuses(carupdateid int64) (int, error)
|
||||
GetUpdateStatuses(carupdateid int64, paging *PageQueryOptions) ([]common.CarUpdateStatus, error)
|
||||
TruncateRequirementsAwaitForUpdate(carupdateid int64) (orm.Result, error)
|
||||
Insert(update *common.CarUpdate) (orm.Result, error)
|
||||
InsertAndCreateStatus(update *common.CarUpdate) (orm.Result, error) // Insert a car update and create the appropriate update_status entry
|
||||
Load(update *common.CarUpdate) error
|
||||
LogStatus(update *common.CarUpdate) (orm.Result, error)
|
||||
SelectByID(id int64) (*common.CarUpdate, error)
|
||||
SelectByManifestID(int64) ([]common.CarUpdate, error)
|
||||
SelectByVIN(vin string) ([]common.CarUpdate, error)
|
||||
SelectMostRecentByVINs(vins []string) ([]common.CarUpdate, error)
|
||||
SelectOrInsert(update *common.CarUpdate) (bool, error)
|
||||
Select(filter *common.CarUpdate, paging *PageQueryOptions) ([]common.CarUpdate, error)
|
||||
UpdateStatus(update *common.CarUpdate) (orm.Result, error)
|
||||
UpdateStatusIfNotRepeat(update *common.CarUpdate) (orm.Result, error)
|
||||
GetManifest(carupdateid int64) (*common.UpdateManifest, error)
|
||||
HasPendingUpdates(manifestID int64, vin string) (bool, error)
|
||||
HasPendingUpdatesFromAftersalesUser(manifestID int64, vin string) (updateID int64, pendingUpdateAftersales bool, err error) // Cancel any pending aftersales update
|
||||
InsertMissingFlashpack(vin string, flashpackVersion string) error // Flashpack Version = OS Version Number
|
||||
}
|
||||
|
||||
// CarUpdate query methods
|
||||
type CarUpdates struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *CarUpdates) selectFilter(query *orm.Query, filter *common.CarUpdate) {
|
||||
if filter.ID > 0 {
|
||||
query.Where("car_update.id = ?", filter.ID)
|
||||
}
|
||||
|
||||
if filter.VIN != "" {
|
||||
query.Where("vin = ?", filter.VIN)
|
||||
}
|
||||
|
||||
if filter.UpdateManifestID > 0 {
|
||||
query.Where("update_manifest_id = ?", filter.UpdateManifestID)
|
||||
}
|
||||
|
||||
if filter.Status != "" {
|
||||
query.Where("status = ?", filter.Status)
|
||||
}
|
||||
|
||||
if filter.UpdateSource != "" {
|
||||
query.Where("update_source = ?", filter.UpdateSource)
|
||||
}
|
||||
}
|
||||
|
||||
// StatusHistoryFilter get CarUpdates with certain status' filtered out.
|
||||
// When a status is filtered out, CarUpdates with that status will still
|
||||
// be returned, but will contain their last status not ignored.
|
||||
func StatusHistoryFilter(query *orm.Query, ignoreStatuses []string) {
|
||||
query.ColumnExpr("car_update.id")
|
||||
query.ColumnExpr("car_update.vin")
|
||||
query.ColumnExpr("car_update.created_at")
|
||||
query.ColumnExpr("car_update.updated_at")
|
||||
query.ColumnExpr("car_update.update_manifest_id")
|
||||
query.ColumnExpr("car_update.error_code")
|
||||
query.ColumnExpr("car_update.info")
|
||||
query.ColumnExpr("car_update.username")
|
||||
query.ColumnExpr("car_update_display_status.display_status AS status")
|
||||
query.Join(`LEFT JOIN car_update_display_status ON car_update.id = car_update_display_status.car_update_id`)
|
||||
}
|
||||
|
||||
// CarUpdate select by car update id
|
||||
func (c *CarUpdates) SelectByID(id int64) (*common.CarUpdate, error) {
|
||||
update := common.CarUpdate{ID: id}
|
||||
|
||||
err := c.GetDBConn().Model(&update).WherePK().
|
||||
Relation("UpdateManifest").
|
||||
Select()
|
||||
|
||||
return &update, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// SelectByVIN returns list of car updates by VIN
|
||||
func (c *CarUpdates) SelectByVIN(vin string) ([]common.CarUpdate, error) {
|
||||
updates := []common.CarUpdate{}
|
||||
|
||||
err := c.GetDBConn().Model(&updates).Where("vin = ?", vin).
|
||||
Relation("UpdateManifest").
|
||||
Select()
|
||||
|
||||
return updates, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// SelectMostRecentByVINs returns a list of most recent car update for VINs
|
||||
func (c *CarUpdates) SelectMostRecentByVINs(vins []string) ([]common.CarUpdate, error) {
|
||||
updates := []common.CarUpdate{}
|
||||
|
||||
err := c.GetDBConn().Model(&updates).
|
||||
Where("vin IN (?)", pg.In(vins)).
|
||||
WhereGroup(func(q *pg.Query) (*pg.Query, error) {
|
||||
return q.Where("(car_update.created_at = (SELECT MAX(t2.created_at) FROM public.car_updates t2 WHERE t2.vin = car_update.vin))"), nil
|
||||
}).
|
||||
Relation("UpdateManifest").
|
||||
Select()
|
||||
|
||||
return updates, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// SelectByManifestID returns list of car updates by package update id
|
||||
func (c *CarUpdates) SelectByManifestID(manifest_id int64) ([]common.CarUpdate, error) {
|
||||
updates := []common.CarUpdate{}
|
||||
|
||||
err := c.GetDBConn().Model(&updates).Where("update_manifest.id = ?", manifest_id).
|
||||
Relation("UpdateManifest").
|
||||
Select()
|
||||
|
||||
return updates, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) SelectOrInsert(update *common.CarUpdate) (bool, error) {
|
||||
return c.insertSelectWithStack(c.GetDBConn().Model(update).Where("vin = ?vin AND update_manifest_id = ?update_manifest_id").SelectOrInsert())
|
||||
}
|
||||
|
||||
// Select returns list of cars
|
||||
func (c *CarUpdates) Select(filter *common.CarUpdate, paging *PageQueryOptions) ([]common.CarUpdate, error) {
|
||||
ups := []common.CarUpdate{}
|
||||
query := c.GetDBConn().Model(&ups)
|
||||
c.selectFilter(query, filter)
|
||||
StatusHistoryFilter(query, []string{"cleanup_succeeded"})
|
||||
c.pageQuery(query, paging)
|
||||
|
||||
if filter.UpdateManifest != nil && filter.UpdateManifest.ManifestType > 0 {
|
||||
query.
|
||||
Relation("UpdateManifest", func(q *orm.Query) (*orm.Query, error) {
|
||||
return q.Where("manifest_type = ?", filter.UpdateManifest.ManifestType), nil
|
||||
})
|
||||
} else {
|
||||
query.Relation("UpdateManifest")
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return ups, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) Delete(update *common.CarUpdate) (orm.Result, error) {
|
||||
err := validator.ValidateIDField(update.ID)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
conn := c.GetDBConn()
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
_, err = tx.Model(&common.CarUpdateStatus{CarUpdateID: update.ID}).Where("car_update_id = ?car_update_id").Delete()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
result, err := tx.Model(update).WherePK().Delete()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) UpdateStatus(update *common.CarUpdate) (orm.Result, error) {
|
||||
result, err := c.UpdateCarUpdate(update)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
_, err = c.LogStatus(update)
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *CarUpdates) UpdateStatusIfNotRepeat(update *common.CarUpdate) (orm.Result, error) {
|
||||
orm, err := c.LogStatusIfNotARepeat(update)
|
||||
if err != nil {
|
||||
return orm, err
|
||||
}
|
||||
result, err := c.UpdateCarUpdate(update)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *CarUpdates) UpdateCarUpdate(update *common.CarUpdate) (orm.Result, error) {
|
||||
err := validator.ValidateIDField(update.ID)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return c.resultWithStack(c.GetDBConn().Model(update).Column("status", "error_code").WherePK().Update())
|
||||
}
|
||||
|
||||
func (c *CarUpdates) LogStatus(update *common.CarUpdate) (orm.Result, error) {
|
||||
// Should this also be updated to have the catch of double statuses
|
||||
status := common.CarUpdateStatus{
|
||||
CarUpdateID: update.ID,
|
||||
Status: update.Status,
|
||||
ErrorCode: update.ErrorCode,
|
||||
Info: update.Info,
|
||||
}
|
||||
|
||||
result, err := c.resultWithStack(c.GetDBConn().Model(&status).Insert())
|
||||
if err != nil {
|
||||
err = errors.WithStack(fmt.Errorf("LogStatus::%s. with status %d %s %d %s", err.Error(), status.CarUpdateID, status.Status, status.ErrorCode, status.Info))
|
||||
logger.Err(err)
|
||||
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
// If the last status that was inserted is the same as the one we are trying to insert, we do not do it
|
||||
var RepeatedStatus = errors.New("RepeatedStatus")
|
||||
|
||||
func (c *CarUpdates) LogStatusIfNotARepeat(update *common.CarUpdate) (orm orm.Result, err error) {
|
||||
logger.Debug().Msgf("attempt to add new status %s for update %d", update.Status, update.ID) // CEC-5650 debugging
|
||||
status := common.CarUpdateStatus{
|
||||
CarUpdateID: update.ID,
|
||||
Status: update.Status,
|
||||
ErrorCode: update.ErrorCode,
|
||||
Info: update.Info,
|
||||
}
|
||||
|
||||
// This query insert's our status if the most recent status on the car_update_id does not match the new status
|
||||
// If it does match the new status, then we do not insert
|
||||
// 1: Update ID, 2: Status, 3: error_code, 4: The info, 5: UpdateID, 6: The status
|
||||
query := `INSERT INTO public.car_update_statuses(
|
||||
car_update_id, status, error_code, info)
|
||||
SELECT ?, ?, ?, ? WHERE NOT EXISTS(
|
||||
SELECT 1 FROM (SELECT id, car_update_id, status, error_code, created_at, updated_at, info
|
||||
FROM public.car_update_statuses WHERE car_update_id = ? ORDER BY created_at DESC LIMIT 1) AS temp WHERE status = ?)`
|
||||
orm, err = c.GetDBConn().Exec(query, status.CarUpdateID, status.Status, status.ErrorCode, status.Info, status.CarUpdateID, status.Status)
|
||||
if err != nil {
|
||||
err = errors.WithStack(fmt.Errorf("%s. with status %d %s %d %s", err.Error(), status.CarUpdateID, status.Status, status.ErrorCode, status.Info))
|
||||
logger.Err(err)
|
||||
return orm, err
|
||||
}
|
||||
if orm.RowsAffected() == 0 {
|
||||
return orm, RepeatedStatus
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *CarUpdates) Insert(update *common.CarUpdate) (orm.Result, error) {
|
||||
return c.insert(update)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) InsertAndCreateStatus(update *common.CarUpdate) (res orm.Result, err error) {
|
||||
res, err = c.insert(update)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.LogStatus(update)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) Load(update *common.CarUpdate) error {
|
||||
query := c.GetDBConn().Model(update)
|
||||
|
||||
if update.ID > 0 {
|
||||
query.WherePK()
|
||||
} else if update.VIN == "" && update.UpdateManifestID > 0 {
|
||||
query.Where("vin = ?vin AND update_manifest_id = ?update_manifest_id")
|
||||
} else {
|
||||
return errors.New("no id, vin, update_manifest_id")
|
||||
}
|
||||
|
||||
if update.UpdateManifest != nil && update.UpdateManifest.ManifestType > 0 {
|
||||
query.
|
||||
Relation("UpdateManifest", func(q *orm.Query) (*orm.Query, error) {
|
||||
return q.Where("manifest_type = ?", update.UpdateManifest.ManifestType), nil
|
||||
})
|
||||
} else {
|
||||
query.Relation("UpdateManifest")
|
||||
}
|
||||
|
||||
return errors.WithStack(query.
|
||||
Relation("UpdateManifest.ECUs").
|
||||
Relation("UpdateManifest.ECUs.Files").
|
||||
Relation("UpdateManifest.ECUs.Files.WriteRegion").
|
||||
Relation("UpdateManifest.ECUs.Files.EraseRegion").
|
||||
Select())
|
||||
}
|
||||
|
||||
func (c *CarUpdates) Count(filter *common.CarUpdate) (int, error) {
|
||||
query := c.GetDBConn().Model((*common.CarUpdate)(nil))
|
||||
|
||||
c.selectFilter(query, filter)
|
||||
|
||||
return c.countWithStack(query.Count())
|
||||
}
|
||||
|
||||
func (c *CarUpdates) CountUpdateStatuses(carupdateid int64) (int, error) {
|
||||
query := c.GetDBConn().Model((*common.CarUpdateStatus)(nil))
|
||||
|
||||
return c.countWithStack(query.Where("car_update_id = ?", carupdateid).Count())
|
||||
}
|
||||
|
||||
func (c *CarUpdates) GetUpdateStatuses(carupdateid int64, paging *PageQueryOptions) ([]common.CarUpdateStatus, error) {
|
||||
result := []common.CarUpdateStatus{}
|
||||
query := c.GetDBConn().Model(&result)
|
||||
|
||||
c.pageQuery(query, paging)
|
||||
|
||||
err := query.Where("car_update_id = ?", carupdateid).Select()
|
||||
if err == nil && len(result) == 0 {
|
||||
return result, errors.WithStack(pg.ErrNoRows)
|
||||
}
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) TruncateRequirementsAwaitForUpdate(carupdateid int64) (orm.Result, error) {
|
||||
queryString := `with row_nums as (select *, row_number() over (partition by car_update_id order by updated_at) as row
|
||||
from car_update_statuses where status = 'requirements_await'
|
||||
and car_update_id = ?)
|
||||
delete from car_update_statuses where id in (
|
||||
select id from row_nums where (row != (select max(row) from row_nums)) and (row != 1)
|
||||
)`
|
||||
|
||||
err := validator.ValidateIDField(carupdateid)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
conn := c.GetDBConn()
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
result, err := conn.Query(&common.CarUpdateStatus{}, queryString, carupdateid)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) GetManifest(carupdateid int64) (*common.UpdateManifest, error) {
|
||||
manifest := common.UpdateManifest{}
|
||||
|
||||
err := c.GetDBConn().Model(&manifest).Join("JOIN car_updates").JoinOn("car_updates.update_manifest_id = update_manifest.id").JoinOn("car_updates.id = ?", carupdateid).Relation("ECUs").Relation("ECUs.Files").Select()
|
||||
|
||||
return &manifest, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *CarUpdates) HasPendingUpdates(manifestID int64, vin string) (bool, error) {
|
||||
count, err := c.GetDBConn().Model(&common.CarUpdate{}).Where("update_manifest_id = ? AND vin = ? AND status NOT IN (?)", manifestID, vin, pg.In(FINAL_UPDATE_STATUS)).Count()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// An updateID of 0 means no pending update
|
||||
func (c *CarUpdates) HasPendingUpdatesFromAftersalesUser(manifestID int64, vin string) (updateID int64, pendingUpdateAftersales bool, err error) {
|
||||
var carUpdates []common.CarUpdate
|
||||
err = c.GetDBConn().Model(&carUpdates).Where("update_manifest_id = ? AND vin = ? AND status NOT IN (?)", manifestID, vin, pg.In(FINAL_UPDATE_STATUS)).Select()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(carUpdates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if len(carUpdates) > 1 {
|
||||
logger.Error().Msgf("Have more than one pending update for manifestID %d on vin %s. This should not happen and our logic needs to change", manifestID, vin)
|
||||
}
|
||||
// This should only ever be one
|
||||
for _, u := range carUpdates {
|
||||
updateID = u.ID
|
||||
if u.UpdateSource == common.UPDATE_SOURCE_AFTERSALES {
|
||||
pendingUpdateAftersales = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *CarUpdates) InsertMissingFlashpack(vin string, flashpackVersion string) (err error) {
|
||||
mf := common.MissingFlashpack{
|
||||
VIN: vin,
|
||||
FlashPackVersion: flashpackVersion,
|
||||
}
|
||||
_, err = c.insert(&mf)
|
||||
return
|
||||
}
|
||||
304
pkg/db/queries/carupdates_test.go
Normal file
304
pkg/db/queries/carupdates_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
const testCarUpdateVIN = "FISKER1234"
|
||||
const testCarUpdateVIN2 = "FISKER12345"
|
||||
const testCarUpdatesPackageName = "TEST CARUPDATES"
|
||||
|
||||
func TestCarUpdateIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
query := queries.CarUpdates{}
|
||||
client := query.GetClient()
|
||||
vin, manifestID := setupCarUpdates(&query, client, testCarUpdateVIN)
|
||||
vin2, manifestID2 := setupCarUpdates(&query, client, testCarUpdateVIN2)
|
||||
defer func() {
|
||||
if manifestID > 0 {
|
||||
manifest := queries.NewUpdateManifest(nil)
|
||||
manifest.Delete(&m.UpdateManifest{ID: manifestID})
|
||||
}
|
||||
}()
|
||||
carupdateID := testCarUpdateSelectOrInsert(t, &query, vin, manifestID)
|
||||
testCarUpdateSelectByID(t, &query, carupdateID)
|
||||
testCarUpdateSelectByVIN(t, &query, carupdateID)
|
||||
testCarUpdateSelectByManifestID(t, &query, manifestID, carupdateID)
|
||||
testCarUpdateGetManifest(t, &query, carupdateID)
|
||||
testCarUpdateSelect(t, &query, vin, carupdateID)
|
||||
testStatusHistoryFilter(t, &query)
|
||||
carupdateID2 := testCarUpdateSelectOrInsert(t, &query, vin2, manifestID2)
|
||||
testCarUpdateSelectMostRecentByVINs(t, &query, []int64{carupdateID, carupdateID2})
|
||||
testCarUpdateDelete(t, &query, carupdateID)
|
||||
testCarUpdateDelete(t, &query, carupdateID2)
|
||||
}
|
||||
|
||||
func testCarUpdateSelect(t *testing.T, query queries.CarUpdatesInterface, vin string, carupdateID int64) {
|
||||
filter := m.CarUpdate{
|
||||
VIN: vin,
|
||||
}
|
||||
options := &queries.PageQueryOptions{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
Order: "id desc",
|
||||
}
|
||||
|
||||
query.UpdateStatusIfNotRepeat(&m.CarUpdate{
|
||||
ID: carupdateID,
|
||||
VIN: vin,
|
||||
Status: "manifest_succeeded",
|
||||
})
|
||||
|
||||
carupdates, err := query.Select(&filter, options)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Select", nil, err)
|
||||
}
|
||||
|
||||
for _, cu := range carupdates {
|
||||
if cu.Status == "cleanup_succeeded" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Select", "manifest_succeeded", cu.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testStatusHistoryFilter(t *testing.T, query queries.CarUpdatesInterface) {
|
||||
mock := new(orm.Query)
|
||||
queries.StatusHistoryFilter(mock, []string{"cleanup_succeeded"})
|
||||
value, _ := mock.AppendQuery(orm.NewFormatter(), nil)
|
||||
actual := string(value)
|
||||
|
||||
expected := `SELECT car_update.id, car_update.vin, car_update.created_at, car_update.updated_at, car_update.update_manifest_id, car_update.error_code, car_update.info, car_update.username, cus.status AS status LEFT JOIN (
|
||||
SELECT car_update_id, status,
|
||||
ROW_NUMBER() OVER (PARTITION BY car_update_id ORDER BY updated_at DESC) AS rn
|
||||
FROM public.car_update_statuses
|
||||
WHERE status NOT IN ('cleanup_succeeded')
|
||||
) cus ON car_update.id = cus.car_update_id AND cus.rn = 1`
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "StatusHistoryFilter", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func testCarUpdateSelectOrInsert(t *testing.T, query queries.CarUpdatesInterface, vin string, manifestID int64) int64 {
|
||||
carupdate := m.CarUpdate{
|
||||
VIN: vin,
|
||||
UpdateManifestID: manifestID,
|
||||
}
|
||||
|
||||
_, err := query.SelectOrInsert(&carupdate)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectOrInsert", nil, err)
|
||||
}
|
||||
if carupdate.ID == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "ID", "Has ID", carupdate.ID)
|
||||
}
|
||||
|
||||
return carupdate.ID
|
||||
}
|
||||
|
||||
func testCarUpdateSelectByID(t *testing.T, query queries.CarUpdatesInterface, carupdateID int64) {
|
||||
carupdate, err := query.SelectByID(carupdateID)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectByID", nil, err)
|
||||
}
|
||||
checkCarUpdate(t, carupdate, carupdateID)
|
||||
}
|
||||
|
||||
func testCarUpdateSelectByVIN(t *testing.T, query queries.CarUpdatesInterface, carupdateID int64) {
|
||||
carupdates, err := query.SelectByVIN(testCarUpdateVIN)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectByVIN", nil, err)
|
||||
}
|
||||
if len(carupdates) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", "More than 0", len(carupdates))
|
||||
}
|
||||
for _, update := range carupdates {
|
||||
if update.ID == carupdateID {
|
||||
checkCarUpdate(t, &update, carupdateID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testCarUpdateSelectMostRecentByVINs(t *testing.T, query queries.CarUpdatesInterface, carupdateIDs []int64) {
|
||||
carupdates, err := query.SelectMostRecentByVINs([]string{testCarUpdateVIN, testCarUpdateVIN2})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectMostRecentByVINs", nil, err)
|
||||
}
|
||||
if len(carupdates) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", "More than 0", len(carupdates))
|
||||
}
|
||||
matchCount := 0
|
||||
for _, update := range carupdates {
|
||||
for _, carupdateID := range carupdateIDs {
|
||||
if update.ID == carupdateID {
|
||||
matchCount++
|
||||
checkCarUpdate(t, &update, carupdateID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if matchCount != 2 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Matches", "2", matchCount)
|
||||
}
|
||||
}
|
||||
|
||||
func testCarUpdateSelectByManifestID(t *testing.T, query queries.CarUpdatesInterface, manifestID int64, carupdateID int64) {
|
||||
carupdates, err := query.SelectByManifestID(manifestID)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SelectByVIN", nil, err)
|
||||
}
|
||||
if len(carupdates) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Count", "More than 0", len(carupdates))
|
||||
}
|
||||
for _, carupdate := range carupdates {
|
||||
if carupdate.VIN == "FISKER1234" {
|
||||
checkCarUpdate(t, &carupdate, carupdateID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Errorf(testhelper.TestErrorTemplate, "checkCarUpdate", "FISKER1234", "VIN not found")
|
||||
}
|
||||
|
||||
func testCarUpdateGetManifest(t *testing.T, query queries.CarUpdatesInterface, carupdateID int64) {
|
||||
manifest, err := query.GetManifest(carupdateID)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetManifest", nil, err)
|
||||
}
|
||||
if manifest.Name != testCarUpdatesPackageName {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetManifest Name", testCarUpdatesPackageName, manifest.Name)
|
||||
}
|
||||
if len(manifest.ECUs) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetManifest ECUs", 1, len(manifest.ECUs))
|
||||
}
|
||||
}
|
||||
|
||||
func testCarUpdateDelete(t *testing.T, query queries.CarUpdatesInterface, carupdateID int64) {
|
||||
result, err := query.Delete(&m.CarUpdate{ID: carupdateID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Delete", nil, err)
|
||||
}
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Delete RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
}
|
||||
|
||||
func checkCarUpdate(t *testing.T, carupdate *m.CarUpdate, carupdateID int64) {
|
||||
if carupdate.ID != carupdateID {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "ID", carupdateID, carupdate.ID)
|
||||
}
|
||||
if carupdate.Status != "pending" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "ID", "pending", carupdate.Status)
|
||||
}
|
||||
if carupdate.VIN != testCarUpdateVIN && carupdate.VIN != testCarUpdateVIN2 {
|
||||
expected := fmt.Sprintf("%s or %s", testCarUpdateVIN, testCarUpdateVIN2)
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Car", expected, carupdate.VIN)
|
||||
}
|
||||
}
|
||||
|
||||
func setupCarUpdates(query queries.CarUpdatesInterface, client *db.DBClient, vin string) (string, int64) {
|
||||
conn := client.GetConn()
|
||||
client.InitSchema([]interface{}{
|
||||
(*m.UpdateManifest)(nil),
|
||||
(*m.UpdateManifestECU)(nil),
|
||||
(*m.UpdateManifestFile)(nil),
|
||||
(*m.CarUpdate)(nil),
|
||||
})
|
||||
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
car := m.Car{
|
||||
VIN: vin,
|
||||
}
|
||||
updatemanifest := m.UpdateManifest{
|
||||
Name: testCarUpdatesPackageName,
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
_, err := conn.Model(&car).Where("vin = ?vin").SelectOrInsert()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
_, err = conn.Model(&updatemanifest).Where("name = ?name AND version = ?version").SelectOrInsert()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
ecu := m.UpdateManifestECU{
|
||||
UpdateManifestID: updatemanifest.ID,
|
||||
ECU: "TEST",
|
||||
Version: "version",
|
||||
}
|
||||
_, err = conn.Model(&ecu).Insert()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
_, err = conn.Model(&m.UpdateManifestFile{
|
||||
UpdateManifestECUID: ecu.ID,
|
||||
FileID: "f000000000000000",
|
||||
Filename: "TEST.bin",
|
||||
URL: "http://fiskerinc.com/download",
|
||||
}).Insert()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
return vin, updatemanifest.ID
|
||||
}
|
||||
|
||||
/*
|
||||
This test was for an integration test
|
||||
func TestIntegrationInsert(t *testing.T) {
|
||||
opts := pg.Options{
|
||||
Addr: "127.0.0.1:5432",
|
||||
User: "postgres",
|
||||
Password: "REPLACE_ME",
|
||||
}
|
||||
con := pg.Connect(&opts)
|
||||
cl := db.DBClient{}
|
||||
err := cl.SetConn(con)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
p := queries.CarUpdates{}
|
||||
p.QueryBase.SetClient(&cl)
|
||||
|
||||
update := m.CarUpdate{
|
||||
ID: 325,
|
||||
VIN: testCarUpdateVIN,
|
||||
UpdateManifestID: 325,
|
||||
Status: "requirements_failed",
|
||||
ErrorCode: 0,
|
||||
}
|
||||
_, err = p.LogStatusIfNotARepeat(&update)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
update.Status = "manifest_succeeded"
|
||||
_, err = p.LogStatusIfNotARepeat(&update)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
_, err = p.LogStatusIfNotARepeat(&update)
|
||||
if err == nil {
|
||||
t.Error("Repeated status did not produce an error")
|
||||
}
|
||||
if !errors.Is(err, queries.RepeatedStatus) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
*/
|
||||
142
pkg/db/queries/certificates.go
Normal file
142
pkg/db/queries/certificates.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
s "fiskerinc.com/modules/security"
|
||||
"fiskerinc.com/modules/validator"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const sqlSelectMostRecents = "SELECT c1.* FROM certificates c1 INNER JOIN (SELECT common_name, type, MAX(created_at) created_at FROM certificates WHERE common_name = ? AND type IN (?) GROUP BY common_name, type) c2 ON c1.common_name = c2.common_name AND c1.type = c2.type AND c1.created_at = c2.created_at"
|
||||
|
||||
type CertificatesInterface interface {
|
||||
Insert(c *common.Certificate) (orm.Result, error)
|
||||
Update(c *common.Certificate) (orm.Result, error)
|
||||
Remove(c *common.Certificate) (orm.Result, error)
|
||||
SelectByCommonName(cn string) ([]common.Certificate, error)
|
||||
SelectBySerial(serial string) (*common.Certificate, error)
|
||||
SelectMostRecent(cn string, certType string) (*common.Certificate, error)
|
||||
SelectMostRecents(cn string, certTypes []string) ([]common.Certificate, error)
|
||||
}
|
||||
|
||||
type Certificates struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *Certificates) Insert(certificate *common.Certificate) (orm.Result, error) {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certificate.EncryptedKey = encryptor.EncryptChunk([]byte(certificate.EncryptedKey))
|
||||
|
||||
return c.resultWithStack(c.GetDBConn().Model(certificate).Insert())
|
||||
}
|
||||
|
||||
func (c *Certificates) Update(certificate *common.Certificate) (orm.Result, error) {
|
||||
if certificate.SerialNumber == "" {
|
||||
return nil, errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "Serial required",
|
||||
})
|
||||
}
|
||||
|
||||
return c.resultWithStack(c.GetDBConn().Model(certificate).Column("valid").WherePK().Update())
|
||||
}
|
||||
|
||||
func (c *Certificates) Remove(certificate *common.Certificate) (orm.Result, error) {
|
||||
if certificate.SerialNumber == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "Serial required",
|
||||
}
|
||||
}
|
||||
|
||||
return c.resultWithStack(c.GetDBConn().Model(certificate).WherePK().Delete())
|
||||
}
|
||||
|
||||
func (c *Certificates) SelectByCommonName(cn string) ([]common.Certificate, error) {
|
||||
certificates := []common.Certificate{}
|
||||
|
||||
err := c.GetDBConn().Model(&certificates).Where("common_name = ?", cn).Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for i := range certificates {
|
||||
err = c.decrypt(&certificates[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certificates[i].PrivateKey = string(certificates[i].EncryptedKey)
|
||||
}
|
||||
return certificates, err
|
||||
}
|
||||
|
||||
func (c *Certificates) SelectBySerial(serial string) (*common.Certificate, error) {
|
||||
certificate := common.Certificate{}
|
||||
|
||||
err := c.GetDBConn().Model(&certificate).Where("serial_number = ?", serial).Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = c.decrypt(&certificate)
|
||||
certificate.PrivateKey = string(certificate.EncryptedKey)
|
||||
|
||||
return &certificate, err
|
||||
}
|
||||
|
||||
func (c *Certificates) SelectMostRecent(cn string, certType string) (*common.Certificate, error) {
|
||||
cert := common.Certificate{}
|
||||
err := c.GetDBConn().Model(&cert).Where("common_name = ? AND type = ?", cn, certType).Order("created_at desc").Limit(1).Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = c.decrypt(&cert)
|
||||
cert.PrivateKey = string(cert.EncryptedKey)
|
||||
|
||||
return &cert, err
|
||||
}
|
||||
|
||||
func (c *Certificates) SelectMostRecents(cn string, certTypes []string) ([]common.Certificate, error) {
|
||||
certificates := []common.Certificate{}
|
||||
|
||||
_, err := c.GetDBConn().Model().Query(&certificates, sqlSelectMostRecents, cn, pg.In(certTypes))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for i := range certificates {
|
||||
err = c.decrypt(&certificates[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certificates[i].PrivateKey = string(certificates[i].EncryptedKey)
|
||||
}
|
||||
|
||||
return certificates, err
|
||||
}
|
||||
|
||||
func (c *Certificates) decrypt(cert *common.Certificate) error {
|
||||
if cert.EncryptedKey == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pkey, err := encryptor.DecryptChunk([]byte(cert.EncryptedKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.EncryptedKey = pkey
|
||||
|
||||
return nil
|
||||
}
|
||||
22
pkg/db/queries/certificates_test.go
Normal file
22
pkg/db/queries/certificates_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
)
|
||||
|
||||
func TestSelectMostRecents(t *testing.T) {
|
||||
t.Skip()
|
||||
q := queries.Certificates{}
|
||||
q.GetDBConn().AddQueryHook(db.SQLLogger{})
|
||||
|
||||
result, err := q.SelectMostRecents("11111111111111111", []string{"CHARGING", "ICC", "TBOX"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("%v", result)
|
||||
}
|
||||
23
pkg/db/queries/driver_emails.go
Normal file
23
pkg/db/queries/driver_emails.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type DriverEmailsInterface interface {
|
||||
SelectByVINs(vins []string) ([]common.DriverEmail, error)
|
||||
}
|
||||
|
||||
type DriverEmails struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (d *DriverEmails) SelectByVINs(vins []string) ([]common.DriverEmail, error) {
|
||||
driverEmails := []common.DriverEmail{}
|
||||
query := d.GetDBConn().Model(&driverEmails)
|
||||
|
||||
err := query.WhereIn("vin IN (?)", vins).Select()
|
||||
|
||||
return driverEmails, errors.WithStack(err)
|
||||
}
|
||||
109
pkg/db/queries/drivers.go
Normal file
109
pkg/db/queries/drivers.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type DriversInterface interface {
|
||||
Delete(driver *common.Driver) (orm.Result, error)
|
||||
Insert(driver *common.Driver) (orm.Result, error)
|
||||
Select(filter *common.Driver) ([]common.Driver, error)
|
||||
SelectOrInsert(driver *common.Driver) (bool, error)
|
||||
Load(driver *common.Driver) error
|
||||
}
|
||||
|
||||
type Drivers struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
// Select returns list of drivers
|
||||
func (d *Drivers) Select(filter *common.Driver) ([]common.Driver, error) {
|
||||
drivers := []common.Driver{}
|
||||
query := d.GetDBConn().Model(&drivers)
|
||||
|
||||
d.selectFilter(query, filter)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return drivers, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (d *Drivers) SelectOrInsert(driver *common.Driver) (bool, error) {
|
||||
q := d.GetDBConn().Model(driver)
|
||||
|
||||
if driver.ID != "" {
|
||||
q.Where("id = ?id")
|
||||
} else {
|
||||
return false, errors.New("no ID")
|
||||
}
|
||||
|
||||
inserted, err := q.SelectOrInsert()
|
||||
|
||||
return inserted, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (d *Drivers) selectFilter(query *orm.Query, filter *common.Driver) {
|
||||
if filter.ID != "" {
|
||||
query.Where("ID = ?", filter.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Drivers) Insert(driver *common.Driver) (orm.Result, error) {
|
||||
|
||||
err := validator.ValidateStruct(driver)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
result, err := d.GetDBConn().Model(driver).Insert()
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (d *Drivers) Delete(driver *common.Driver) (orm.Result, error) {
|
||||
if driver.ID == "" {
|
||||
return nil, errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "ID required",
|
||||
})
|
||||
}
|
||||
|
||||
tx, err := d.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
defer tx.Close()
|
||||
|
||||
cardrivers := common.CarToDriver{
|
||||
DriverID: driver.ID,
|
||||
}
|
||||
_, err = tx.Model(&cardrivers).Where("driverId = ?driverId").Delete()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
result, err := tx.Model(driver).WherePK().Delete()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (d *Drivers) Load(driver *common.Driver) error {
|
||||
query := d.GetDBConn().Model(driver)
|
||||
|
||||
err := query.Select()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
223
pkg/db/queries/ecckeys.go
Normal file
223
pkg/db/queries/ecckeys.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/logger"
|
||||
s "fiskerinc.com/modules/security"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const sqlLastManifestEnv = "WITH sub AS (SELECT COALESCE((SELECT env FROM update_manifests WHERE id IN (SELECT update_manifest_id FROM car_updates WHERE id IN (SELECT MAX(id) FROM car_updates WHERE vin = ?))), 'current') as env)"
|
||||
const sqlCarUpdateIDEnv = "WITH sub AS (SELECT COALESCE((SELECT env FROM update_manifests WHERE update_manifests.id IN (SELECT update_manifest_id FROM car_updates WHERE id = ?)), 'current') as env)"
|
||||
const sqlECCPrivKeysEnv = " SELECT priv_key_level_1, priv_key_level_2, priv_key_level_3, ecu, ecc_keys.env FROM ecc_keys INNER JOIN sub ON ecc_keys.env = sub.env ORDER BY ecu;"
|
||||
|
||||
type EccKeysInterface interface {
|
||||
Insert(keys common.ECCKeys) (orm.Result, error)
|
||||
SelectAllPrivateKeys() ([]common.ECCKeys, error)
|
||||
SelectAllPrivateKeysByEnv(env string) ([]common.ECCKeys, error)
|
||||
SelectPublicKeysByECUByEnv(ecu string, env string) (common.ECCKeys, error)
|
||||
SelectAllPublicKeysByEnv(env string) ([]common.ECCKeys, error)
|
||||
SelectPrivateKeysByECUsEnv(ecus []string, env string) ([]common.ECCKeys, error)
|
||||
SelectAllPrivateKeysByVIN(vin string) ([]common.ECCKeys, error)
|
||||
SelectAllPrivateKeysByCarUpdateID(id int64) ([]common.ECCKeys, error)
|
||||
}
|
||||
|
||||
type EccKeys struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (ek *EccKeys) Insert(keys common.ECCKeys) (orm.Result, error) {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys.PubKey1.SetBytes(encryptor.EncryptChunk(keys.PubKey1.Bytes()))
|
||||
keys.PrivKey1.SetBytes(encryptor.EncryptChunk(keys.PrivKey1.Bytes()))
|
||||
keys.PubKey2.SetBytes(encryptor.EncryptChunk(keys.PubKey2.Bytes()))
|
||||
keys.PrivKey2.SetBytes(encryptor.EncryptChunk(keys.PrivKey2.Bytes()))
|
||||
keys.PubKey3.SetBytes(encryptor.EncryptChunk(keys.PubKey3.Bytes()))
|
||||
keys.PrivKey3.SetBytes(encryptor.EncryptChunk(keys.PrivKey3.Bytes()))
|
||||
|
||||
return ek.resultWithStack(ek.GetDBConn().Model(&keys).Insert())
|
||||
}
|
||||
|
||||
// Selects all private keys and ECU's
|
||||
func (ek *EccKeys) SelectAllPrivateKeys() ([]common.ECCKeys, error) {
|
||||
return ek.selectKeys("current", "priv_key_level_1", "priv_key_level_2", "priv_key_level_3", "ecu")
|
||||
}
|
||||
|
||||
// Selects all private keys and ECU's by env
|
||||
func (ek *EccKeys) SelectAllPrivateKeysByEnv(env string) ([]common.ECCKeys, error) {
|
||||
return ek.selectKeys(env, "priv_key_level_1", "priv_key_level_2", "priv_key_level_3", "ecu")
|
||||
}
|
||||
|
||||
// Selects public keys associated with ECU by env
|
||||
func (ek *EccKeys) SelectAllPublicKeysByEnv(env string) ([]common.ECCKeys, error) {
|
||||
return ek.selectKeys(env, "pub_key_level_1", "pub_key_level_2", "pub_key_level_3", "ecu")
|
||||
}
|
||||
|
||||
// Selects keys associated with ECU
|
||||
func (ek *EccKeys) SelectAllKeys() ([]common.ECCKeys, error) {
|
||||
return ek.selectKeys("current", "pub_key_level_1", "pub_key_level_2", "pub_key_level_3", "priv_key_level_1", "priv_key_level_2", "priv_key_level_3", "ecu")
|
||||
}
|
||||
|
||||
// Selects public keys associated with ECU by env
|
||||
func (ek *EccKeys) SelectPublicKeysByECUByEnv(ecu string, env string) (common.ECCKeys, error) {
|
||||
ecckey := common.ECCKeys{}
|
||||
|
||||
err := ek.GetDBConn().Model(&ecckey).Column("pub_key_level_1", "pub_key_level_2", "pub_key_level_3", "ecu").Where("ecu = ? AND env = ?", ecu, env).Select()
|
||||
if err != nil {
|
||||
return ecckey, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = ek.decrypt(&ecckey)
|
||||
|
||||
return ecckey, err
|
||||
}
|
||||
|
||||
func (ek *EccKeys) SelectPrivateKeysByECUsEnv(ecus []string, env string) ([]common.ECCKeys, error) {
|
||||
var keys []common.ECCKeys
|
||||
|
||||
if env == "" {
|
||||
env = common.EnvCurrent
|
||||
}
|
||||
|
||||
err := ek.GetDBConn().Model(&keys).
|
||||
Column("priv_key_level_1", "priv_key_level_2", "priv_key_level_3", "ecu").
|
||||
Where("ecu IN (?) AND env = ?", pg.In(ecus), env).
|
||||
Order("ecu").
|
||||
Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for i := range keys {
|
||||
err = ek.decrypt(&keys[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// SelectAllPrivateKeysByVIN returns the ECC keys for an environment based on the last update sent to the VIN
|
||||
// If no update was sent, it will send the keys for the current environment
|
||||
func (ek *EccKeys) SelectAllPrivateKeysByVIN(vin string) ([]common.ECCKeys, error) {
|
||||
var keys []common.ECCKeys
|
||||
|
||||
_, err := ek.GetDBConn().Query(&keys, sqlLastManifestEnv+sqlECCPrivKeysEnv, vin)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.Errorf("No ECC keys for %s", vin)
|
||||
}
|
||||
|
||||
logger.Info().Msgf("SelectAllPrivateKeysByVIN %s %s", vin, keys[0].Env)
|
||||
|
||||
for i := range keys {
|
||||
keys[i].Env = ""
|
||||
err = ek.decrypt(&keys[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (ek *EccKeys) SelectAllPrivateKeysByCarUpdateID(id int64) ([]common.ECCKeys, error) {
|
||||
var keys []common.ECCKeys
|
||||
|
||||
_, err := ek.GetDBConn().Query(&keys, sqlCarUpdateIDEnv+sqlECCPrivKeysEnv, id)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.Errorf("No ECC keys for car update id %d", id)
|
||||
}
|
||||
|
||||
logger.Info().Msgf("SelectAllPrivateKeysByCarUpdateID %d %s", id, keys[0].Env)
|
||||
|
||||
for i := range keys {
|
||||
keys[i].Env = ""
|
||||
err = ek.decrypt(&keys[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (ek *EccKeys) selectKeys(env string, columns ...string) ([]common.ECCKeys, error) {
|
||||
ecckeys := []common.ECCKeys{}
|
||||
|
||||
err := ek.GetDBConn().Model(&ecckeys).Column(columns...).Where("env = ?", env).Order("ecu").Select()
|
||||
if err != nil {
|
||||
return ecckeys, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for i := range ecckeys {
|
||||
err = ek.decrypt(&ecckeys[i])
|
||||
}
|
||||
|
||||
return ecckeys, err
|
||||
}
|
||||
|
||||
func (ek *EccKeys) decrypt(eccKeys *common.ECCKeys) error {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ek.decryptKey(encryptor, eccKeys.PrivKey1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ek.decryptKey(encryptor, eccKeys.PrivKey2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ek.decryptKey(encryptor, eccKeys.PrivKey3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ek.decryptKey(encryptor, eccKeys.PubKey1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ek.decryptKey(encryptor, eccKeys.PubKey2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ek.decryptKey(encryptor, eccKeys.PubKey3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ek *EccKeys) decryptKey(encryptor s.IEncryptor, key *common.BinaryHex) error {
|
||||
if key != nil {
|
||||
keys, err := encryptor.DecryptChunk(key.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.SetBytes(keys)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
101
pkg/db/queries/ecckeys_test.go
Normal file
101
pkg/db/queries/ecckeys_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestSelectAllPrivateKeysByVIN(t *testing.T) {
|
||||
t.Skip()
|
||||
client := db.DBClient{}
|
||||
client.GetConn().AddQueryHook(db.SQLLogger{})
|
||||
q := queries.EccKeys{}
|
||||
q.SetClient(&client)
|
||||
result, err := q.SelectAllPrivateKeysByVIN("3FAFP13P71R199432")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(result) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Existing VIN", "more than 0 keys", 0)
|
||||
}
|
||||
|
||||
result, err = q.SelectAllPrivateKeysByVIN("3FAFP13P71R19943X")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(result) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Non-existing VIN", "more than 0 keys", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectAllPrivateKeysByCarUpdateID(t *testing.T) {
|
||||
t.Skip()
|
||||
client := db.DBClient{}
|
||||
client.GetConn().AddQueryHook(db.SQLLogger{})
|
||||
q := queries.EccKeys{}
|
||||
q.SetClient(&client)
|
||||
result, err := q.SelectAllPrivateKeysByCarUpdateID(3497)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(result) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Existing car update id", "more than 0 keys", 0)
|
||||
}
|
||||
|
||||
result, err = q.SelectAllPrivateKeysByCarUpdateID(0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(result) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "non-existing car update id", "more than 0 keys", 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Use for getting keys from db
|
||||
func TestECCKeysAll(t *testing.T) {
|
||||
t.Skip()
|
||||
ek := queries.EccKeys{}
|
||||
|
||||
keys, err := ek.SelectAllKeys()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(keys)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Error(string(data))
|
||||
}
|
||||
|
||||
// Use for inserting keys back into db
|
||||
func TestInsertECCKeys(t *testing.T) {
|
||||
t.Skip()
|
||||
ek := queries.EccKeys{}
|
||||
ecckeys := []common.ECCKeys{}
|
||||
err := json.Unmarshal([]byte(dataprod), &ecckeys)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for _, keys := range ecckeys {
|
||||
_, err = ek.Insert(keys)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const dataprod = `[{"ecu":"BCM","pub_key_level_1":"bff0eab780e30ff9ece2fe487d1bab819ddf626ef75fc3886cab3f785180b0360b3dc2a09a64c6a64a2b66415c6438d9811aa8538fcb8d9dd47df3d84a35dfd4","level_1":"be298a33a95f80a782da14b071e49f18e3489f21e3d2e8798a5bc3796e3e78f2","pub_key_level_2":"456b9ed1d87b48c84a8085b59c9d464c842b6c9ab43c38ff86763145ea51613685cfc6fe450b57033a9ac54bd710f6aadb8678b30f49e9679e6abd15d112677b","level_2":"f140a2170d28a3be1f0f4d89627449e2340de90a255137ea621de0c45efc5146","pub_key_level_3":"9d9dbff29ef8bb930010f231d5231a6a9abe88b1db6221381748ad84ee52f3c71b35d45f1f5e051ccde71414b0961a533c9f6ffe0df8c303f43805979d619d8e","level_3":"22d92dcb2dad5436df8274309c1f2e39385733551ffb7cdac4932f14405dc9c9"}]`
|
||||
78
pkg/db/queries/ecu_dtc.go
Normal file
78
pkg/db/queries/ecu_dtc.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"time"
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type ECUInterface interface {
|
||||
Insert(dtc *[]common.DTC_ECU) (orm.Result, error)
|
||||
UpdateTimestamp(dtc *common.DTC_ECU) error
|
||||
Select(dtcecu common.DTC_ECUQuery, paging *PageQueryOptions) ([]common.DTC_ECU, error)
|
||||
Count(filter common.DTC_ECUQuery) (int, error)
|
||||
}
|
||||
|
||||
type ECU struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *ECU) Insert(ecuDtc *[]common.DTC_ECU) (orm.Result, error) {
|
||||
if len(*ecuDtc) > 0 {
|
||||
return c.insert(ecuDtc)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *ECU) UpdateTimestamp(dtc *common.DTC_ECU) error {
|
||||
_, err := c.GetDBConn().Model(dtc).
|
||||
Where("vin = ?vin AND ecu = ?ecu AND trouble_code = ?trouble_code").
|
||||
Set("updated_at = ?", time.Now()).
|
||||
Update()
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *ECU) Select(filter common.DTC_ECUQuery, paging *PageQueryOptions) ([]common.DTC_ECU, error) {
|
||||
|
||||
dtcEcu := []common.DTC_ECU{}
|
||||
query := c.GetDBConn().Model(&dtcEcu)
|
||||
|
||||
c.applyFilters(query, filter)
|
||||
c.pageQuery(query, paging)
|
||||
err := query.Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return dtcEcu, nil
|
||||
}
|
||||
|
||||
func (c *ECU) applyFilters(query *orm.Query, filter common.DTC_ECUQuery) {
|
||||
|
||||
query.Where("vin = ?", filter.VIN)
|
||||
|
||||
if filter.ECU != "" {
|
||||
query.Where("ecu = ?", filter.ECU)
|
||||
}
|
||||
|
||||
if filter.TroubleCode != 0 {
|
||||
query.Where("trouble_code = ?", filter.TroubleCode)
|
||||
}
|
||||
|
||||
if filter.StartTime != nil {
|
||||
query.Where("epoch_usec >= ?", filter.StartTime.Unix()*1000000)
|
||||
}
|
||||
|
||||
if filter.EndTime != nil {
|
||||
query.Where("epoch_usec <= ?", filter.EndTime.Unix()*1000000)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ECU) Count(filter common.DTC_ECUQuery) (int, error) {
|
||||
ecu_dtc := common.DTC_ECU{}
|
||||
query := c.GetDBConn().Model(&ecu_dtc)
|
||||
c.applyFilters(query, filter)
|
||||
return c.countWithStack(query.Count())
|
||||
}
|
||||
46
pkg/db/queries/ecu_dtc_test.go
Normal file
46
pkg/db/queries/ecu_dtc_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestDTCIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
query := setupECUDTC(t)
|
||||
|
||||
testDTCInsert(t, query)
|
||||
}
|
||||
|
||||
func setupECUDTC(t *testing.T) queries.ECU {
|
||||
instance := queries.ECU{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func testDTCInsert(t *testing.T, query queries.ECU) {
|
||||
|
||||
var EcuDtc = []m.DTC_ECU{
|
||||
{
|
||||
VIN: "1B7HF16Y8TS510206",
|
||||
ECU: "AMP",
|
||||
TroubleCode: 123,
|
||||
},
|
||||
{
|
||||
VIN: "1B7HF16Y8TS510206",
|
||||
ECU: "Brake",
|
||||
TroubleCode: 456,
|
||||
},
|
||||
}
|
||||
_, err := query.Insert(&EcuDtc)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Failed to insert DTCs", "No error", err)
|
||||
}
|
||||
|
||||
}
|
||||
153
pkg/db/queries/filekeys.go
Normal file
153
pkg/db/queries/filekeys.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/security"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
berrors "errors"
|
||||
)
|
||||
|
||||
type FileKeysInterface interface {
|
||||
Delete(fileID string) (orm.Result, error)
|
||||
Insert(filekey common.FileKey) (orm.Result, error)
|
||||
Get(fileID string) (*common.FileKey, error)
|
||||
GetMulti(fileIDs []string) ([]common.FileKey, error)
|
||||
}
|
||||
|
||||
type FileKeys struct {
|
||||
QueryBase
|
||||
encryptor security.IEncryptor
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (fk *FileKeys) getEncryptor() (security.IEncryptor, error) {
|
||||
var err error
|
||||
fk.once.Do(func() {
|
||||
encrypt := security.Encrypt{}
|
||||
fk.encryptor, err = encrypt.GetEncryptor()
|
||||
})
|
||||
|
||||
return fk.encryptor, err
|
||||
}
|
||||
|
||||
// Delete deletes fileID of FileKey from database
|
||||
func (fk *FileKeys) Delete(fileID string) (orm.Result, error) {
|
||||
if fileID == "" {
|
||||
return nil, errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "FileID required",
|
||||
})
|
||||
}
|
||||
|
||||
conn := fk.GetDBConn()
|
||||
return fk.resultWithStack(conn.Model(&common.FileKey{
|
||||
FileID: fileID,
|
||||
}).WherePK().Delete())
|
||||
}
|
||||
|
||||
// Insert makes a copy of FileKey, encrypts encryption parameters, and inserts into database
|
||||
func (fk *FileKeys) Insert(filekey common.FileKey) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(filekey)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = fk.encrypt(&filekey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fk.resultWithStack(fk.GetDBConn().Model(&filekey).Insert())
|
||||
}
|
||||
|
||||
func (fk *FileKeys) Get(fileID string) (*common.FileKey, error) {
|
||||
if fileID == "" {
|
||||
return nil, errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "FileID required",
|
||||
})
|
||||
}
|
||||
|
||||
filekey := []common.FileKey{}
|
||||
err := fk.GetDBConn().Model(&filekey).Where("file_id = ?", fileID).Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if len(filekey) == 0 {
|
||||
return &common.FileKey{}, nil
|
||||
}
|
||||
|
||||
err = fk.decrypt(&filekey[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &filekey[0], nil
|
||||
}
|
||||
|
||||
func (fk *FileKeys) GetMulti(fileIDs []string) ([]common.FileKey, error) {
|
||||
filekeys := []common.FileKey{}
|
||||
if len(fileIDs) == 0 {
|
||||
return filekeys, nil
|
||||
}
|
||||
err := fk.GetDBConn().Model(&filekeys).Where("file_id in (?)", pg.In(fileIDs)).Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
var masterError error
|
||||
for i := range filekeys {
|
||||
err = fk.decrypt(&filekeys[i])
|
||||
if err != nil {
|
||||
masterError = berrors.Join(masterError, err)
|
||||
filekeys[i].Error = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return filekeys, masterError
|
||||
}
|
||||
|
||||
func (fk *FileKeys) encrypt(filekey *common.FileKey) error {
|
||||
encryptor, err := fk.getEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filekey.Key = encryptor.EncryptChunk([]byte(filekey.Key))
|
||||
filekey.Auth = encryptor.EncryptChunk([]byte(filekey.Auth))
|
||||
filekey.Nonce = encryptor.EncryptChunk([]byte(filekey.Nonce))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fk *FileKeys) decrypt(filekey *common.FileKey) error {
|
||||
|
||||
encryptor, err := fk.getEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := encryptor.DecryptChunk([]byte(filekey.Key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filekey.Key = value
|
||||
value, err = encryptor.DecryptChunk([]byte(filekey.Auth))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filekey.Auth = value
|
||||
value, err = encryptor.DecryptChunk([]byte(filekey.Nonce))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filekey.Nonce = value
|
||||
return nil
|
||||
}
|
||||
155
pkg/db/queries/filekeys_test.go
Normal file
155
pkg/db/queries/filekeys_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
th "fiskerinc.com/modules/testhelper"
|
||||
"fiskerinc.com/modules/validator"
|
||||
)
|
||||
|
||||
const testFileID string = "07a4ed515543d4d8"
|
||||
const testFileIDNonExistent string = "3edffbecef53734d"
|
||||
const testFileKey string = "TEST_KEYTEST_KEYTEST_KEYTEST_KEY"
|
||||
const testNouce string = "NOUNCENOUNCE"
|
||||
|
||||
var fk queries.FileKeysInterface
|
||||
|
||||
func TestStructValidation(t *testing.T) {
|
||||
type TestCase struct {
|
||||
Name string
|
||||
Struct common.FileKey
|
||||
ExpectedError string
|
||||
}
|
||||
|
||||
tests := []TestCase{
|
||||
{
|
||||
Name: "No values",
|
||||
Struct: common.FileKey{},
|
||||
ExpectedError: `Key: 'FileKey.FileID' Error:Field validation for 'FileID' failed on the 'required' tag
|
||||
Key: 'FileKey.Key' Error:Field validation for 'Key' failed on the 'required' tag
|
||||
Key: 'FileKey.Auth' Error:Field validation for 'Auth' failed on the 'required' tag
|
||||
Key: 'FileKey.Nonce' Error:Field validation for 'Nonce' failed on the 'required' tag`,
|
||||
},
|
||||
{
|
||||
Name: "Bad file id",
|
||||
Struct: common.FileKey{
|
||||
FileID: "XXXXXXXXXXXXXXXX",
|
||||
Key: []byte("12345678901234561234567890123456"),
|
||||
Auth: []byte("12345"),
|
||||
Nonce: []byte("123456789012"),
|
||||
},
|
||||
ExpectedError: `Key: 'FileKey.FileID' Error:Field validation for 'FileID' failed on the 'hexadecimal' tag`,
|
||||
},
|
||||
{
|
||||
Name: "Bad file id 2",
|
||||
Struct: common.FileKey{
|
||||
FileID: "74fe7f75a9f59f3",
|
||||
Key: []byte("12345678901234561234567890123456"),
|
||||
Auth: []byte("12345"),
|
||||
Nonce: []byte("123456789012"),
|
||||
},
|
||||
ExpectedError: `Key: 'FileKey.FileID' Error:Field validation for 'FileID' failed on the 'len' tag`,
|
||||
},
|
||||
{
|
||||
Name: "Good",
|
||||
Struct: common.FileKey{
|
||||
FileID: "074fe7f75a9f59f3",
|
||||
Key: []byte("12345678901234561234567890123456"),
|
||||
Auth: []byte("12345"),
|
||||
Nonce: []byte("123456789012"),
|
||||
},
|
||||
ExpectedError: `Key: 'FileKey.FileID' Error:Field validation for 'FileID' failed on the 'len' tag`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
err := validator.ValidateStruct(test.Struct)
|
||||
if err != nil && err.Error() != test.ExpectedError {
|
||||
t.Errorf(testhelper.TestErrorTemplate, test.Name, test.ExpectedError, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileKeysIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
fk = &queries.FileKeys{}
|
||||
|
||||
fileKeysInsert(t)
|
||||
fileKeysGet(t)
|
||||
fileKeysDelete(t)
|
||||
|
||||
fk = nil
|
||||
}
|
||||
|
||||
func fileKeysInsert(t *testing.T) {
|
||||
_, err := fk.Insert(common.FileKey{})
|
||||
expectedError := `Key: 'FileKey.FileID' Error:Field validation for 'FileID' failed on the 'required' tag
|
||||
Key: 'FileKey.Key' Error:Field validation for 'Key' failed on the 'required' tag
|
||||
Key: 'FileKey.Auth' Error:Field validation for 'Auth' failed on the 'required' tag
|
||||
Key: 'FileKey.Nonce' Error:Field validation for 'Nonce' failed on the 'required' tag`
|
||||
if err != nil && err.Error() != expectedError {
|
||||
t.Errorf(th.TestErrorTemplate, "Bad insert", expectedError, err)
|
||||
} else if err == nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Bad insert", "Validation errors", err)
|
||||
}
|
||||
|
||||
result, err := fk.Insert(common.FileKey{
|
||||
FileID: testFileID,
|
||||
Key: []byte(testFileKey),
|
||||
Auth: []byte(testFileKey + "AUTH"),
|
||||
Nonce: []byte(testNouce),
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Good insert", "no error", err)
|
||||
}
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(th.TestErrorTemplate, "Good insert", "1 row", result.RowsAffected())
|
||||
}
|
||||
}
|
||||
|
||||
func fileKeysGet(t *testing.T) {
|
||||
filekey, err := fk.Get(testFileIDNonExistent)
|
||||
expectedError := "non-existent key"
|
||||
if err != nil && err.Error() != expectedError {
|
||||
t.Errorf(th.TestErrorTemplate, "Get Non-existent", expectedError, err)
|
||||
} else if err == nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Get Non-existent", "Results errors", err)
|
||||
}
|
||||
if filekey != nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Get Non-existent", "nil", filekey)
|
||||
}
|
||||
|
||||
filekey, err = fk.Get(testFileID)
|
||||
if err != nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Get", "get errors", err)
|
||||
}
|
||||
if string(filekey.Key) != testFileKey {
|
||||
t.Errorf(th.TestErrorTemplate, "Get", testFileKey, filekey.Key)
|
||||
}
|
||||
if string(filekey.Auth) != testFileKey+"AUTH" {
|
||||
t.Errorf(th.TestErrorTemplate, "Get", testFileKey+"AUTH", filekey.Auth)
|
||||
}
|
||||
if string(filekey.Nonce) != testNouce {
|
||||
t.Errorf(th.TestErrorTemplate, "Get", testNouce, filekey.Nonce)
|
||||
}
|
||||
}
|
||||
|
||||
func fileKeysDelete(t *testing.T) {
|
||||
result, err := fk.Delete(testFileIDNonExistent)
|
||||
if err != nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Delete non-existing file id", "no error", err)
|
||||
} else if result.RowsAffected() > 0 {
|
||||
t.Errorf(th.TestErrorTemplate, "Delete non-existing file id", "no rows", result.RowsAffected())
|
||||
}
|
||||
|
||||
result, err = fk.Delete(testFileID)
|
||||
if err != nil {
|
||||
t.Errorf(th.TestErrorTemplate, "Delete file id", "no error", err)
|
||||
}
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(th.TestErrorTemplate, "Delete file id", 1, result.RowsAffected())
|
||||
}
|
||||
}
|
||||
39
pkg/db/queries/helper.go
Normal file
39
pkg/db/queries/helper.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type PageQueryOptions struct {
|
||||
Order string `json:"order" validate:"max=512,sqlorder"` // Order only allows one field to be ordered, allows ASC and DESC as well. Leave empty to not apply order
|
||||
Limit int `json:"limit" validate:"gte=0,lte=100"`
|
||||
Offset int `json:"offset" validate:"gte=0"`
|
||||
Ignore []string `json:"ignore" validate:"dive"`
|
||||
}
|
||||
|
||||
var PageQueryOptionsLimitMaximum = 100
|
||||
|
||||
func (p *PageQueryOptions) String() string {
|
||||
return fmt.Sprintf("PageQueryOptions<%s %d %d>", p.Order, p.Limit, p.Offset)
|
||||
}
|
||||
|
||||
// ParsePageQuery parses PageQueryOptions from http request
|
||||
func ParsePageQuery(r *http.Request) (*PageQueryOptions, error) {
|
||||
decoder := schema.NewDecoder()
|
||||
options := PageQueryOptions{}
|
||||
|
||||
decoder.SetAliasTag("json")
|
||||
decoder.Decode(&options, r.URL.Query())
|
||||
err := validator.ValidateStruct(options)
|
||||
if err == nil && options.Limit == 0 {
|
||||
options.Limit = PageQueryOptionsLimitMaximum
|
||||
}
|
||||
|
||||
return &options, errors.WithStack(err)
|
||||
}
|
||||
35
pkg/db/queries/issue_images.go
Normal file
35
pkg/db/queries/issue_images.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type IssueImagesInterface interface {
|
||||
Insert(issueImage *[]common.IssueImage) (orm.Result, error)
|
||||
SearchByIssueID(issueID string) ([]common.IssueImage, error)
|
||||
}
|
||||
|
||||
type IssueImages struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *IssueImages) Insert(issueImages *[]common.IssueImage) (orm.Result, error) {
|
||||
|
||||
res, err := c.GetDBConn().Model(issueImages).Insert()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *IssueImages) SearchByIssueID(issueID string) ([]common.IssueImage, error) {
|
||||
issueImages := []common.IssueImage{}
|
||||
query := c.GetDBConn().Model(&issueImages)
|
||||
|
||||
err := query.Where("issue_id = ?", issueID).Select()
|
||||
|
||||
return issueImages, errors.WithStack(err)
|
||||
}
|
||||
62
pkg/db/queries/issue_images_test.go
Normal file
62
pkg/db/queries/issue_images_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestIssueImageIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
query := setupIssueImages(t)
|
||||
testIssueImagesInsert(t, query)
|
||||
testIssueImagesSearch(t, query)
|
||||
}
|
||||
func setupIssueImages(t *testing.T) queries.IssueImages {
|
||||
instance := queries.IssueImages{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*common.Issue)(nil),
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func testIssueImagesInsert(t *testing.T, query queries.IssueImages) {
|
||||
|
||||
issueImage := []m.IssueImage{
|
||||
{
|
||||
Image: []byte{72, 101, 108, 108, 111, 49},
|
||||
IssueID: 1,
|
||||
},
|
||||
{
|
||||
Image: []byte{72, 101, 108, 108, 111, 111},
|
||||
IssueID: 1,
|
||||
},
|
||||
}
|
||||
|
||||
res, err := query.Insert(&issueImage)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "IssueImages update", "No error", err)
|
||||
}
|
||||
|
||||
if res.RowsAffected() != 2 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "IssueImages insert RowsAffected", 1, res.RowsAffected())
|
||||
}
|
||||
}
|
||||
|
||||
func testIssueImagesSearch(t *testing.T, query queries.IssueImages) {
|
||||
|
||||
_, err := query.SearchByIssueID("22")
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Issue Image Search", "No error", err)
|
||||
}
|
||||
|
||||
}
|
||||
101
pkg/db/queries/issues.go
Normal file
101
pkg/db/queries/issues.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type IssuesInterface interface {
|
||||
Insert(issue *common.Issue) (orm.Result, error)
|
||||
Delete(id int) (orm.Result, error)
|
||||
SelectByID(id int) (*common.Issue, error)
|
||||
Search(filter *common.IssueSearch, paging *PageQueryOptions) ([]common.Issue, error)
|
||||
Count() (int, error)
|
||||
}
|
||||
|
||||
type Issues struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (c *Issues) load(query *orm.Query) error {
|
||||
err := query.Relation("IssueImages").Select()
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *Issues) Insert(issue *common.Issue) (orm.Result, error) {
|
||||
return c.insert(issue)
|
||||
}
|
||||
|
||||
func (c *Issues) Search(filter *common.IssueSearch, paging *PageQueryOptions) ([]common.Issue, error) {
|
||||
issues := []common.Issue{}
|
||||
query := c.GetDBConn().Model(&issues)
|
||||
|
||||
c.searchFilter(query, filter)
|
||||
c.pageQuery(query, paging)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return issues, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (c *Issues) SelectByID(id int) (*common.Issue, error) {
|
||||
|
||||
if id <= 0 {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "id cannot be less than 0",
|
||||
}
|
||||
}
|
||||
issue := common.Issue{}
|
||||
query := c.GetDBConn().Model(&issue)
|
||||
|
||||
query.Where("issue.id = ?", id)
|
||||
err := c.load(query)
|
||||
|
||||
return &issue, err
|
||||
}
|
||||
|
||||
func (c *Issues) Delete(id int) (orm.Result, error) {
|
||||
|
||||
if id <= 0 {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "id has to be a positive integer",
|
||||
}
|
||||
}
|
||||
|
||||
total := ORMResults{}
|
||||
issueImage := common.IssueImage{}
|
||||
issueImagesQuery := c.GetDBConn().Model(&issueImage)
|
||||
res, err := issueImagesQuery.Where("issue_id = ?", id).Delete()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
total.AddResult(res)
|
||||
|
||||
issue := common.Issue{}
|
||||
query := c.GetDBConn().Model(&issue)
|
||||
res, err = query.Where("issue.id = ?", id).Delete()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
total.AddResult(res)
|
||||
|
||||
return &total, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *Issues) Count() (int, error) {
|
||||
issue := common.Issue{}
|
||||
query := c.GetDBConn().Model(&issue)
|
||||
return c.countWithStack(query.Count())
|
||||
}
|
||||
|
||||
func (c *Issues) searchFilter(query *orm.Query, filter *common.IssueSearch) {
|
||||
if filter.Search != "" {
|
||||
query.Where("vin ILIKE ? OR title ILIKE ? OR driver_id ILIKE ?", fmt.Sprintf("%%%s%%", filter.Search), fmt.Sprintf("%%%s%%", filter.Search), fmt.Sprintf("%%%s%%", filter.Search))
|
||||
}
|
||||
}
|
||||
84
pkg/db/queries/issues_test.go
Normal file
84
pkg/db/queries/issues_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestIssueIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
query := setupIssues(t)
|
||||
|
||||
testIssueInsert(t, query)
|
||||
testIssueSearch(t, query)
|
||||
testIssueSelect(t, query)
|
||||
testIssueDelete(t, query)
|
||||
}
|
||||
|
||||
func setupIssues(t *testing.T) queries.Issues {
|
||||
instance := queries.Issues{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*common.Issue)(nil),
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func testIssueInsert(t *testing.T, query queries.Issues) {
|
||||
|
||||
issue := m.Issue{
|
||||
VIN: "1GNGC26RXXJ407648",
|
||||
Title: "Example HMI Problem",
|
||||
Description: "HMI blue screen",
|
||||
DriverID: "0b6b1930-b20a-4fce-967a-efac6a01fd10",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
res, err := query.Insert(&issue)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Issues Insert", "No error", err)
|
||||
}
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Issues insert RowsAffected", 1, res.RowsAffected())
|
||||
}
|
||||
}
|
||||
|
||||
func testIssueSearch(t *testing.T, query queries.Issues) {
|
||||
|
||||
options := queries.PageQueryOptions{
|
||||
Offset: 0,
|
||||
Limit: 0,
|
||||
Order: "id DESC",
|
||||
}
|
||||
|
||||
_, err := query.Search(nil, &options)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Issues Insert", "No error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testIssueSelect(t *testing.T, query queries.Issues) {
|
||||
|
||||
_, err := query.SelectByID(22)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Issues Select", "No error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testIssueDelete(t *testing.T, query queries.Issues) {
|
||||
_, err := query.Delete(14)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Issues Delete", "No error", err)
|
||||
}
|
||||
}
|
||||
20
pkg/db/queries/mocks/apicalls.go
Normal file
20
pkg/db/queries/mocks/apicalls.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockAPICalls struct {
|
||||
SearchMock func(filter common.APICallsSearch, paging *queries.PageQueryOptions) ([]common.APICall, int, error)
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockAPICalls) Insert(keyValue common.APICall) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockAPICalls) Search(filter common.APICallsSearch, paging *queries.PageQueryOptions) ([]common.APICall, int, error) {
|
||||
return m.SearchMock(filter, paging)
|
||||
}
|
||||
62
pkg/db/queries/mocks/apitokens.go
Normal file
62
pkg/db/queries/mocks/apitokens.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockAPITokens struct {
|
||||
ListResult []common.APIToken
|
||||
GetResult *common.APIToken
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
m.ListResult = list.([]common.APIToken)
|
||||
} else {
|
||||
m.ListResult = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) SetLoadResp(item interface{}) {
|
||||
if item != nil {
|
||||
m.GetResult = item.(*common.APIToken)
|
||||
} else {
|
||||
m.GetResult = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) Delete(token string) (orm.Result, error) {
|
||||
m.LastFilter = &common.APIToken{Token: token}
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) Insert(keyValue common.APIToken) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) Get(token string) (*common.APIToken, error) {
|
||||
m.LastFilter = &common.APIToken{Token: token}
|
||||
return m.GetResult, m.Error
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) Update(keyValue *common.APIToken) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) Select(filter *common.APIToken, paging *queries.PageQueryOptions) ([]common.APIToken, error) {
|
||||
m.LastFilter = filter
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.ListResult, m.Error
|
||||
}
|
||||
|
||||
func (m *MockAPITokens) Count(apitoken *common.APIToken) (int, error) {
|
||||
if m.Error != nil {
|
||||
return 0, m.Error
|
||||
}
|
||||
|
||||
return len(m.ListResult), nil
|
||||
}
|
||||
16
pkg/db/queries/mocks/car_config_data.go
Normal file
16
pkg/db/queries/mocks/car_config_data.go
Normal file
File diff suppressed because one or more lines are too long
32
pkg/db/queries/mocks/car_versions_log.go
Normal file
32
pkg/db/queries/mocks/car_versions_log.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockCarVersionsLog struct {
|
||||
MockLogVersionChange func(log *common.CarVersionLogs) (orm.Result, error)
|
||||
MockSelectByVIN func(vin string, options *queries.PageQueryOptions) ([]common.CarVersionLogs, int, error)
|
||||
GetCarVersionsResult map[string]string
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m MockCarVersionsLog) LogVersionChange(log *common.CarVersionLogs) (orm.Result, error) {
|
||||
return m.MockLogVersionChange(log)
|
||||
}
|
||||
|
||||
func (m MockCarVersionsLog) SelectByVIN(vin string, options *queries.PageQueryOptions) ([]common.CarVersionLogs, int, error) {
|
||||
return m.MockSelectByVIN(vin, options)
|
||||
}
|
||||
|
||||
func (m MockCarVersionsLog) GetCarVersions(vin string, timestamp time.Time) (map[string]string, error) {
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
|
||||
return m.GetCarVersionsResult, nil
|
||||
}
|
||||
516
pkg/db/queries/mocks/cars.go
Normal file
516
pkg/db/queries/mocks/cars.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/validator"
|
||||
"github.com/jinzhu/copier"
|
||||
)
|
||||
|
||||
// CarUpdate query methods
|
||||
type MockCars struct {
|
||||
SelectResponse *common.Car
|
||||
SelectC2DResponse *common.CarToDriver
|
||||
SelectCarsResponse []common.Car
|
||||
SelectCarECUs []common.CarECU
|
||||
SelectCarsForDrivers []common.CarToDriver
|
||||
SelectCarsForDriver common.CarToDriver
|
||||
SelectCarSettings []common.CarSetting
|
||||
SelectCarFlashpackVersions []common.CarFlashpackVersion
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
// GetSoftwareVersion implements queries.CarsInterface.
|
||||
func (c *MockCars) GetSoftwareVersion(vin string) (result common.CarPKCOSVersion, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// GetICCIDs implements queries.CarsInterface.
|
||||
func (c *MockCars) GetICCIDs(vins []string) (iccids []string, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// GetWhiteListCars implements queries.CarsInterface.
|
||||
func (c *MockCars) GetWhiteListCars() (vins []string, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// BlacklistCars implements queries.CarsInterface.
|
||||
func (c *MockCars) BlacklistCars(vin []string) (err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// WhitelistCars implements queries.CarsInterface.
|
||||
func (c *MockCars) WhitelistCars(vin []string, source string) (err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
var _ queries.CarsInterface = &MockCars{}
|
||||
|
||||
func (c *MockCars) UpdateICCID(car *common.Car) (orm.Result, error) {
|
||||
if car.VIN == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "VIN required",
|
||||
}
|
||||
}
|
||||
|
||||
c.ORMResponse = &MockORMResults{AffectedRows: 1}
|
||||
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) UpdateSoldStatus(car *common.Car) (orm.Result, error) {
|
||||
if car.VIN == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "VIN required",
|
||||
}
|
||||
}
|
||||
|
||||
c.ORMResponse = &MockORMResults{AffectedRows: 1}
|
||||
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) SelectByID(id int64) (*common.Car, error) {
|
||||
return c.SelectResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) SelectByVIN(vin string) (*common.Car, error) {
|
||||
return c.SelectResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) Search(filter *common.CarSearch, paging *queries.PageQueryOptions) ([]common.Car, error) {
|
||||
c.LastFilter = filter
|
||||
return c.Select(&filter.Car, paging)
|
||||
}
|
||||
|
||||
func (c *MockCars) Select(filter *common.Car, paging *queries.PageQueryOptions) ([]common.Car, error) {
|
||||
c.LastPaging = paging
|
||||
|
||||
return c.SelectCarsResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) SelectOrInsert(car *common.Car) (bool, error) {
|
||||
return c.SelectOrInsertResult, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) Delete(car *common.Car) (orm.Result, error) {
|
||||
if car.VIN == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "id required",
|
||||
}
|
||||
}
|
||||
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) Update(car *common.Car) (orm.Result, error) {
|
||||
if car.VIN == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "VIN required",
|
||||
}
|
||||
}
|
||||
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) CarsByManifest(manifest common.UpdateManifest, paging *queries.PageQueryOptions) ([]common.Car, error) {
|
||||
c.LastPaging = paging
|
||||
return c.SelectCarsResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) CountCarsByManifest(manifest common.UpdateManifest) (int, error) {
|
||||
return len(c.SelectCarsResponse), c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) Insert(car *common.Car) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) Load(car *common.Car) error {
|
||||
if c.Error != nil {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCars) Count(filter *common.Car) (int, error) {
|
||||
return len(c.SelectCarsResponse), c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) SearchCount(filter *common.CarSearch) (int, error) {
|
||||
return len(c.SelectCarsResponse), c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) AddDriver(car *common.Car, driver *common.Driver, role string) (*common.CarToDriver, error) {
|
||||
if c.Error != nil {
|
||||
return nil, c.Error
|
||||
}
|
||||
if c.DriverError != nil {
|
||||
return nil, c.DriverError
|
||||
}
|
||||
|
||||
return c.SelectC2DResponse, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) SelectCarToDriver(filter *common.CarToDriver) ([]common.CarToDriver, error) {
|
||||
if c.Error != nil {
|
||||
return nil, c.Error
|
||||
}
|
||||
|
||||
return c.SelectCarsForDrivers, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetDriver(id string) (common.CarToDriver, error) {
|
||||
if c.Error != nil {
|
||||
return c.SelectCarsForDriver, c.Error
|
||||
}
|
||||
|
||||
return c.SelectCarsForDriver, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetDrivers(vin string) ([]common.CarToDriver, error) {
|
||||
if c.Error != nil {
|
||||
return nil, c.Error
|
||||
}
|
||||
|
||||
return c.SelectCarsForDrivers, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) RemoveDriver(vin string, driverID string) (orm.Result, error) {
|
||||
if c.Error != nil {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
return c.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetTRexSetting(vin string) (common.TRexSetting, error) {
|
||||
tRexSetting := common.TRexSetting{}
|
||||
if c.Error != nil {
|
||||
return tRexSetting, c.Error
|
||||
}
|
||||
|
||||
return tRexSetting, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetModels() ([]string, error) {
|
||||
if c.Error != nil {
|
||||
return nil, c.Error
|
||||
}
|
||||
|
||||
return []string{"1G1FP87S3GN100062"}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetYears() ([]int, error) {
|
||||
if c.Error != nil {
|
||||
return nil, c.Error
|
||||
}
|
||||
|
||||
return []int{3000}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) SetSetting(setting *common.CarSetting) (orm.Result, error) {
|
||||
if c.Error != nil {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
return c.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetVehicleSpecificSettings(driver *common.CarToDriver) ([]common.CarSetting, error) {
|
||||
return c.SelectCarSettings, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) GetSettings(driver *common.CarToDriver) ([]common.CarSetting, error) {
|
||||
return []common.CarSetting{}, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) DeleteSetting(setting *common.CarSetting) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) UpdateCarECU(ecu *common.CarECU) error {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) UpdateCarECUs(ecus []common.CarECU) error {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) InsertCarECUs(ecus []common.CarECU) error {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) uniqueFilter() []common.CarECU {
|
||||
seen := make(map[string]bool)
|
||||
filteredECUs := make([]common.CarECU, 0)
|
||||
for _, ecu := range c.SelectCarECUs {
|
||||
if !seen[ecu.VIN+ecu.ECU] {
|
||||
filteredECUs = append(filteredECUs, ecu)
|
||||
seen[ecu.VIN+ecu.ECU] = true
|
||||
}
|
||||
}
|
||||
return filteredECUs
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarECUs(filter common.CarECUFilter, paging *queries.PageQueryOptions) ([]common.CarECU, error) {
|
||||
copiedList := []common.CarECU{}
|
||||
copier.CopyWithOption(&copiedList, &c.SelectCarECUs, copier.Option{DeepCopy: true})
|
||||
if filter.Search != "" {
|
||||
return []common.CarECU{copiedList[0]}, c.Error
|
||||
}
|
||||
|
||||
if filter.Unique {
|
||||
return c.uniqueFilter(), c.Error
|
||||
}
|
||||
|
||||
return copiedList, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarECUsCount(filter common.CarECUFilter) (int, error) {
|
||||
if filter.Search != "" {
|
||||
return 1, c.Error
|
||||
}
|
||||
|
||||
if filter.Unique {
|
||||
return len(c.uniqueFilter()), c.Error
|
||||
}
|
||||
|
||||
return len(c.SelectCarECUs), c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) UpdateCarFlashpackVersion(vin string, flashpack string) (orm.Result, error) {
|
||||
return c.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetFlashpackVersions(carModel string, carTrim string, carYear int, options *queries.PageQueryOptions) ([]common.CarFlashpackVersionResponse, error) {
|
||||
return []common.CarFlashpackVersionResponse{
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "43.19",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "41.14",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetFlashpackVersionsCount(carModel string, carTrim string, carYear int) (int, error) {
|
||||
return 2, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetNextFlashpackVersion(carModel string, carTrim string, flashpack string) (*common.CarFlashpackVersionResponse, error) {
|
||||
return &common.CarFlashpackVersionResponse{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "41.14",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarFlashpackVersionMappingsByModelTrim(carModel string, carTrim string, options *queries.PageQueryOptions) ([]common.CarFlashpackVersion, error) {
|
||||
if c.SelectCarFlashpackVersions != nil {
|
||||
return c.SelectCarFlashpackVersions, nil
|
||||
}
|
||||
|
||||
return []common.CarFlashpackVersion{
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "44.14",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion1",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "41.14",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "41.14",
|
||||
CarECUName: "ACUN",
|
||||
CarECUVersion: "ACUNVersion",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "BCM",
|
||||
CarECUVersion: "BCMVersion",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2024,
|
||||
Flashpack: "11.0",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion4",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion0",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "ACUN",
|
||||
CarECUVersion: "ACUNVersion0",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "PDI",
|
||||
CarECUVersion: "PDIVersion",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarFlashpackVersionMappingsByModelTrimYearFlashpack(carModel string, carTrim string, carYear int, flashpack string, options *queries.PageQueryOptions) ([]common.CarFlashpackVersion, error) {
|
||||
if c.SelectCarFlashpackVersions != nil {
|
||||
return c.SelectCarFlashpackVersions, nil
|
||||
}
|
||||
|
||||
return []common.CarFlashpackVersion{
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "44.14",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion1",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "41.14",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2024,
|
||||
Flashpack: "11.0",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion4",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "41.14",
|
||||
CarECUName: "ACUN",
|
||||
CarECUVersion: "ACUNVersion",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "BCM",
|
||||
CarECUVersion: "BCMVersion",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "ADAS",
|
||||
CarECUVersion: "ADASVersion0",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "ACUN",
|
||||
CarECUVersion: "ACUNVersion0",
|
||||
},
|
||||
{
|
||||
CarModel: "Ocean",
|
||||
CarTrim: "Base",
|
||||
CarYear: 2023,
|
||||
Flashpack: "39.14",
|
||||
CarECUName: "PDI",
|
||||
CarECUVersion: "PDIVersion",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarFlashpackVersionMappingsByModelTrimYearFlashpackCount(carModel string, carTrim string, carYear int, flashpack string) (int, error) {
|
||||
if c.SelectCarFlashpackVersions != nil {
|
||||
return len(c.SelectCarFlashpackVersions), nil
|
||||
}
|
||||
|
||||
return 8, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarECUNamesByModelTrim(carModel string, carTrim string) ([]string, error) {
|
||||
return []string{"ADAS", "ACU", "BMS"}, nil
|
||||
}
|
||||
|
||||
func (c *MockCars) AddCarFlashpackVersionMappings(carFlashpackVersions []common.CarFlashpackVersion) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCars) DeleteFlashpackVersion(carModel string, carTrim string, carYear int, flashpack string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCars) GetCarsForDriver(driverID string) ([]common.CarToDriver, error) {
|
||||
return c.SelectCarsForDrivers, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) UpdateBLEKey(vin string, driverid string, blekey string) (common.DriverExternal, error) {
|
||||
return common.DriverExternal{}, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) ECUUpdatedAt(ecu common.CarECU) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCars) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
c.SelectCarsResponse = list.([]common.Car)
|
||||
} else {
|
||||
c.SelectCarsResponse = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockCars) SetLoadResp(item interface{}) {
|
||||
if item != nil {
|
||||
c.SelectResponse = item.(*common.Car)
|
||||
} else {
|
||||
c.SelectResponse = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockCars) GetSoftwareAndPKCVersions(vins []string) (results []common.CarPKCOSVersion, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
179
pkg/db/queries/mocks/carupdates.go
Normal file
179
pkg/db/queries/mocks/carupdates.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/validator"
|
||||
)
|
||||
|
||||
type MockCarUpdates struct {
|
||||
SelectCarUpdateResponse *common.CarUpdate
|
||||
SelectCarUpdatesResponse []common.CarUpdate
|
||||
SelectCarUpdateStatusesResponse []common.CarUpdateStatus
|
||||
HasPendingUpdatesResponse bool
|
||||
LoadManifest *common.UpdateManifest
|
||||
PendingUpdateSameAfterSalesUsersResponse PendingUpdatesFromSameAftersaleUserResponse
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) HasPendingUpdatesFromAftersalesUser(manifestID int64, vin string) (updateID int64, pendingUpdateSameUser bool, err error) {
|
||||
return c.PendingUpdateSameAfterSalesUsersResponse.UpdateID, c.PendingUpdateSameAfterSalesUsersResponse.PendingUpdateSameUser, c.PendingUpdateSameAfterSalesUsersResponse.Err
|
||||
}
|
||||
|
||||
type PendingUpdatesFromSameAftersaleUserResponse struct {
|
||||
PendingUpdateSameUser bool
|
||||
UpdateID int64
|
||||
Err error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SelectByID(id int64) (*common.CarUpdate, error) {
|
||||
return c.SelectCarUpdateResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SelectByVIN(vin string) ([]common.CarUpdate, error) {
|
||||
return c.SelectCarUpdatesResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SelectMostRecentByVINs(vins []string) ([]common.CarUpdate, error) {
|
||||
return c.SelectCarUpdatesResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SelectByManifestID(manifest_id int64) ([]common.CarUpdate, error) {
|
||||
return c.SelectCarUpdatesResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SelectOrInsert(update *common.CarUpdate) (bool, error) {
|
||||
update.ID++
|
||||
return c.SelectOrInsertResult, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) Select(filter *common.CarUpdate, paging *queries.PageQueryOptions) ([]common.CarUpdate, error) {
|
||||
c.LastFilter = filter
|
||||
c.LastPaging = paging
|
||||
|
||||
return c.SelectCarUpdatesResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) Delete(update *common.CarUpdate) (orm.Result, error) {
|
||||
err := validator.ValidateIDField(update.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) UpdateStatus(update *common.CarUpdate) (orm.Result, error) {
|
||||
err := validator.ValidateIDField(update.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update.UpdatedAt = &time.Time{}
|
||||
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) Insert(update *common.CarUpdate) (orm.Result, error) {
|
||||
update.ID++
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) Load(update *common.CarUpdate) error {
|
||||
if c.Error != nil {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
if c.SelectCarUpdateResponse != nil {
|
||||
update.VIN = c.SelectCarUpdateResponse.VIN
|
||||
update.UpdateManifestID = c.SelectCarUpdateResponse.UpdateManifestID
|
||||
}
|
||||
|
||||
if c.LoadManifest == nil {
|
||||
update.UpdateManifest = &common.UpdateManifest{
|
||||
ID: update.UpdateManifestID,
|
||||
Name: "Test",
|
||||
Version: "1.2",
|
||||
ReleaseNotes: "http://releasenotes.com",
|
||||
}
|
||||
} else {
|
||||
update.UpdateManifest = c.LoadManifest
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) Count(filter *common.CarUpdate) (int, error) {
|
||||
return len(c.SelectCarUpdatesResponse), c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) GetUpdateStatuses(carupdateid int64, paging *queries.PageQueryOptions) ([]common.CarUpdateStatus, error) {
|
||||
c.LastPaging = paging
|
||||
|
||||
return c.SelectCarUpdateStatusesResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) TruncateRequirementsAwaitForUpdate(carupdateid int64) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) CountUpdateStatuses(carupdateid int64) (int, error) {
|
||||
return len(c.SelectCarUpdateStatusesResponse), c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) GetManifest(carupdateid int64) (*common.UpdateManifest, error) {
|
||||
if c.LoadManifest == nil {
|
||||
return nil, c.Error
|
||||
}
|
||||
return c.LoadManifest, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) LogStatus(update *common.CarUpdate) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SetLoadResp(item interface{}) {
|
||||
if item == nil {
|
||||
c.SelectCarUpdateResponse = nil
|
||||
} else {
|
||||
c.SelectCarUpdateResponse = item.(*common.CarUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
c.SelectCarUpdatesResponse = list.([]common.CarUpdate)
|
||||
} else {
|
||||
c.SelectCarUpdatesResponse = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) HasPendingUpdates(manifestID int64, vin string) (bool, error) {
|
||||
return c.HasPendingUpdatesResponse, c.Error
|
||||
}
|
||||
|
||||
var lastStatus string
|
||||
|
||||
func (c *MockCarUpdates) UpdateStatusIfNotRepeat(update *common.CarUpdate) (orm.Result, error) {
|
||||
if update.Status == lastStatus {
|
||||
return nil, queries.RepeatedStatus
|
||||
}
|
||||
|
||||
lastStatus = update.Status
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) InsertAndCreateStatus(update *common.CarUpdate) (orm.Result, error) {
|
||||
update.ID++
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCarUpdates) InsertMissingFlashpack(vin string, flashpackVersion string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ queries.CarUpdatesInterface = &MockCarUpdates{}
|
||||
53
pkg/db/queries/mocks/certificates.go
Normal file
53
pkg/db/queries/mocks/certificates.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
)
|
||||
|
||||
// CarUpdate query methods
|
||||
type MockCertificates struct {
|
||||
DBMockHelper
|
||||
MockListResponse []common.Certificate
|
||||
MockCertificate *common.Certificate
|
||||
}
|
||||
|
||||
func (c *MockCertificates) Insert(cert *common.Certificate) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCertificates) Update(cert *common.Certificate) (orm.Result, error) {
|
||||
if cert.SerialNumber == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "Serial number required",
|
||||
}
|
||||
}
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCertificates) Remove(cert *common.Certificate) (orm.Result, error) {
|
||||
if cert.SerialNumber == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "Serial number required",
|
||||
}
|
||||
}
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCertificates) SelectByCommonName(cn string) ([]common.Certificate, error) {
|
||||
return c.MockListResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCertificates) SelectBySerial(serial string) (*common.Certificate, error) {
|
||||
return c.MockCertificate, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCertificates) SelectMostRecent(cn string, certType string) (*common.Certificate, error) {
|
||||
return c.MockCertificate, c.Error
|
||||
}
|
||||
|
||||
func (c *MockCertificates) SelectMostRecents(cn string, certTypes []string) ([]common.Certificate, error) {
|
||||
return c.MockListResponse, c.Error
|
||||
}
|
||||
78
pkg/db/queries/mocks/dbhttptest.go
Normal file
78
pkg/db/queries/mocks/dbhttptest.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
th "fiskerinc.com/modules/testhelper"
|
||||
"fiskerinc.com/modules/validator"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
const IGNORE_EXPECTED_RESP = "IGNORE_EXPECTED_RESP"
|
||||
|
||||
// Deprecated. Use modules_go/testrunner/test_case.go and modules_go/db/queries/mocks/dbtestcase.go
|
||||
type DBHttpTest struct {
|
||||
Name string
|
||||
Request *http.Request
|
||||
ExpectedStatus int
|
||||
ExpectedResponse string
|
||||
ExpectedResponseRegex *regexp.Regexp
|
||||
ValidateResponse bool `default:"false"`
|
||||
|
||||
DBTestCase
|
||||
}
|
||||
|
||||
func (test *DBHttpTest) ValidateHttp(t *testing.T, w *httptest.ResponseRecorder) {
|
||||
if test.ExpectedStatus != w.Result().StatusCode {
|
||||
th.Equal(t, fmt.Sprintf("%s status code", test.Name), test.ExpectedStatus, w.Result().StatusCode)
|
||||
}
|
||||
|
||||
if test.ExpectedResponseRegex != nil {
|
||||
if !test.ExpectedResponseRegex.Match(w.Body.Bytes()) {
|
||||
th.Equal(t, fmt.Sprintf("%s body", test.Name), test.ExpectedResponseRegex, w.Body.String())
|
||||
}
|
||||
} else if test.ExpectedResponse != IGNORE_EXPECTED_RESP && test.ExpectedResponse != w.Body.String() {
|
||||
th.Equal(t, fmt.Sprintf("%s body", test.Name), test.ExpectedResponse, w.Body.String())
|
||||
}
|
||||
|
||||
if test.ValidateResponse {
|
||||
err := validator.ValidateStruct(w.Body)
|
||||
th.NoError(t, fmt.Sprintf("%s validate body", test.Name), err)
|
||||
}
|
||||
}
|
||||
|
||||
func RunDBTests(t *testing.T, tests []DBHttpTest, handler http.HandlerFunc, mock DBMockHelperInterface) {
|
||||
for _, test := range tests {
|
||||
test.SetupDB(mock)
|
||||
|
||||
w := th.ExecHTTPHandler(handler, test.Request)
|
||||
|
||||
test.ValidateHttp(t, w)
|
||||
test.Validate(t, test.Name, mock)
|
||||
}
|
||||
}
|
||||
|
||||
func ExecHTTPRouterHandler(handler http.HandlerFunc, routePath string, request *http.Request) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
router := httprouter.New()
|
||||
router.HandlerFunc(request.Method, routePath, handler)
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
return recorder
|
||||
}
|
||||
|
||||
func RunParamHttpTests(t *testing.T, tests []DBHttpTest, handler http.HandlerFunc, routePath string, mock DBMockHelperInterface) {
|
||||
for _, test := range tests {
|
||||
test.SetupDB(mock)
|
||||
|
||||
w := ExecHTTPRouterHandler(handler, routePath, test.Request)
|
||||
|
||||
test.ValidateHttp(t, w)
|
||||
test.Validate(t, test.Name, mock)
|
||||
}
|
||||
}
|
||||
50
pkg/db/queries/mocks/dbmockhelper.go
Normal file
50
pkg/db/queries/mocks/dbmockhelper.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type DBMockHelperInterface interface {
|
||||
GetFilter() fmt.Stringer
|
||||
GetPaging() *queries.PageQueryOptions
|
||||
SetListResp(list interface{})
|
||||
SetLoadResp(item interface{})
|
||||
SetErr(error)
|
||||
SetDriverError(error)
|
||||
}
|
||||
|
||||
type DBMockHelper struct {
|
||||
SelectOrInsertResult bool
|
||||
ORMResponse orm.Result
|
||||
Error error
|
||||
DriverError error
|
||||
LastFilter fmt.Stringer
|
||||
LastPaging *queries.PageQueryOptions
|
||||
}
|
||||
|
||||
func (m *DBMockHelper) GetFilter() fmt.Stringer {
|
||||
return m.LastFilter
|
||||
}
|
||||
|
||||
func (m *DBMockHelper) GetPaging() *queries.PageQueryOptions {
|
||||
return m.LastPaging
|
||||
}
|
||||
|
||||
func (m *DBMockHelper) SetListResp(list interface{}) {
|
||||
// fmt.Printf("override SetListResp() in %s\n", reflect.TypeOf(list))
|
||||
}
|
||||
|
||||
func (m *DBMockHelper) SetLoadResp(item interface{}) {
|
||||
// fmt.Printf("override SetLoadResp() in %s\n", reflect.TypeOf(item))
|
||||
}
|
||||
|
||||
func (m *DBMockHelper) SetErr(err error) {
|
||||
m.Error = err
|
||||
}
|
||||
|
||||
func (up *DBMockHelper) SetDriverError(err error) {
|
||||
up.DriverError = err
|
||||
}
|
||||
48
pkg/db/queries/mocks/dbtestcase.go
Normal file
48
pkg/db/queries/mocks/dbtestcase.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
th "fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
type DBTestCase struct {
|
||||
ExpectedFilter fmt.Stringer
|
||||
ExpectedPage *queries.PageQueryOptions
|
||||
MockListResponse interface{}
|
||||
MockLoadResponse interface{}
|
||||
SetupMockResponse func()
|
||||
MockError error
|
||||
MockDriverError error
|
||||
}
|
||||
|
||||
func (tc *DBTestCase) SetupDB(mock DBMockHelperInterface) {
|
||||
if mock != nil {
|
||||
mock.SetErr(tc.MockError)
|
||||
mock.SetListResp(tc.MockListResponse)
|
||||
mock.SetLoadResp(tc.MockLoadResponse)
|
||||
mock.SetDriverError(tc.MockDriverError)
|
||||
}
|
||||
|
||||
if tc.SetupMockResponse != nil {
|
||||
tc.SetupMockResponse()
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *DBTestCase) Validate(t *testing.T, name string, mock DBMockHelperInterface) {
|
||||
if mock != nil {
|
||||
if mock.GetFilter() != nil && tc.ExpectedFilter != nil && mock.GetFilter().String() != tc.ExpectedFilter.String() {
|
||||
t.Errorf(th.TestErrorTemplate, name, tc.ExpectedFilter.String(), mock.GetFilter().String())
|
||||
} else if mock.GetFilter() == nil && tc.ExpectedFilter != nil {
|
||||
t.Errorf(th.TestErrorTemplate, name, tc.ExpectedFilter.String(), nil)
|
||||
}
|
||||
|
||||
if mock.GetPaging() != nil && tc.ExpectedPage != nil && mock.GetPaging().String() != tc.ExpectedPage.String() {
|
||||
t.Errorf(th.TestErrorTemplate, name, tc.ExpectedPage.String(), mock.GetPaging().String())
|
||||
} else if mock.GetPaging() == nil && tc.ExpectedPage != nil {
|
||||
t.Errorf(th.TestErrorTemplate, name, tc.ExpectedPage.String(), nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
54
pkg/db/queries/mocks/drivers.go
Normal file
54
pkg/db/queries/mocks/drivers.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
)
|
||||
|
||||
// MockDrivers
|
||||
type MockDrivers struct {
|
||||
SelectResponse []common.Driver
|
||||
SelectListResponse []common.Driver
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (d *MockDrivers) Select(filter *common.Driver) ([]common.Driver, error) {
|
||||
d.LastFilter = filter
|
||||
|
||||
return d.SelectResponse, d.Error
|
||||
}
|
||||
|
||||
func (d *MockDrivers) SelectOrInsert(driver *common.Driver) (bool, error) {
|
||||
return d.SelectOrInsertResult, d.Error
|
||||
}
|
||||
|
||||
func (d *MockDrivers) Delete(driver *common.Driver) (orm.Result, error) {
|
||||
return d.ORMResponse, d.Error
|
||||
}
|
||||
|
||||
func (d *MockDrivers) Insert(driver *common.Driver) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(driver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.ORMResponse, d.Error
|
||||
}
|
||||
|
||||
func (d *MockDrivers) Load(driver *common.Driver) error {
|
||||
if d.Error != nil {
|
||||
return d.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *MockDrivers) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
d.SelectListResponse = list.([]common.Driver)
|
||||
} else {
|
||||
d.SelectListResponse = nil
|
||||
}
|
||||
}
|
||||
53
pkg/db/queries/mocks/ecckeys.go
Normal file
53
pkg/db/queries/mocks/ecckeys.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/jinzhu/copier"
|
||||
)
|
||||
|
||||
// EccKey query methods
|
||||
type MockEccKeys struct {
|
||||
DBMockHelper
|
||||
MockListResponse []common.ECCKeys
|
||||
MockEccKeys common.ECCKeys
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) Insert(keys common.ECCKeys) (orm.Result, error) {
|
||||
return ek.ORMResponse, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectAllPrivateKeys() ([]common.ECCKeys, error) {
|
||||
return ek.MockListResponse, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectAllPrivateKeysByEnv(env string) ([]common.ECCKeys, error) {
|
||||
return ek.MockListResponse, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectPublicKeysByECUByEnv(ecu string, env string) (common.ECCKeys, error) {
|
||||
return ek.MockEccKeys, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectAllPublicKeysByEnv(env string) ([]common.ECCKeys, error) {
|
||||
return ek.MockListResponse, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectPrivateKeysByECUsEnv(ecus []string, env string) ([]common.ECCKeys, error) {
|
||||
result := []common.ECCKeys{}
|
||||
copier.Copy(&result, &ek.MockListResponse)
|
||||
return result, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectAllPrivateKeysByVIN(env string) ([]common.ECCKeys, error) {
|
||||
result := []common.ECCKeys{}
|
||||
copier.Copy(&result, &ek.MockListResponse)
|
||||
return result, ek.Error
|
||||
}
|
||||
|
||||
func (ek MockEccKeys) SelectAllPrivateKeysByCarUpdateID(id int64) ([]common.ECCKeys, error) {
|
||||
result := []common.ECCKeys{}
|
||||
copier.Copy(&result, &ek.MockListResponse)
|
||||
return result, ek.Error
|
||||
}
|
||||
30
pkg/db/queries/mocks/ecu_dtc.go
Normal file
30
pkg/db/queries/mocks/ecu_dtc.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockEcuDtc struct {
|
||||
DBMockHelper
|
||||
SelectDTCECUResponse []common.DTC_ECU
|
||||
LastInsertCount int
|
||||
}
|
||||
|
||||
func (c *MockEcuDtc) Insert(ecudtc *[]common.DTC_ECU) (orm.Result, error) {
|
||||
c.LastInsertCount = len(*ecudtc)
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockEcuDtc) UpdateTimestamp(dtc *common.DTC_ECU) error {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
func (c *MockEcuDtc) Select(ecudtc common.DTC_ECUQuery, paging *queries.PageQueryOptions) ([]common.DTC_ECU, error) {
|
||||
return c.SelectDTCECUResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockEcuDtc) Count(filter common.DTC_ECUQuery) (int, error) {
|
||||
return 0, c.Error
|
||||
}
|
||||
48
pkg/db/queries/mocks/filekeys.go
Normal file
48
pkg/db/queries/mocks/filekeys.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockFileKeys struct {
|
||||
GetResponse *common.FileKey
|
||||
GetMultiResponse []common.FileKey
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (fk *MockFileKeys) Delete(fileID string) (orm.Result, error) {
|
||||
if fileID == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "FileID required",
|
||||
}
|
||||
}
|
||||
|
||||
return fk.ORMResponse, fk.Error
|
||||
}
|
||||
|
||||
func (fk *MockFileKeys) Insert(filekey common.FileKey) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(filekey)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fk.ORMResponse, fk.Error
|
||||
}
|
||||
|
||||
func (fk *MockFileKeys) Get(fileID string) (*common.FileKey, error) {
|
||||
if fileID == "" {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "FileID required",
|
||||
}
|
||||
}
|
||||
|
||||
return fk.GetResponse, fk.Error
|
||||
}
|
||||
|
||||
func (fk *MockFileKeys) GetMulti(fileIDs []string) ([]common.FileKey, error) {
|
||||
return fk.GetMultiResponse, fk.Error
|
||||
}
|
||||
109
pkg/db/queries/mocks/issues.go
Normal file
109
pkg/db/queries/mocks/issues.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/validator"
|
||||
)
|
||||
|
||||
type MockIssue struct {
|
||||
SelectIssuesResponse []common.Issue
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (c *MockIssue) Insert(issue *common.Issue) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockIssue) Delete(id int) (orm.Result, error) {
|
||||
if id <= 0 {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "id cannot be less than 0",
|
||||
}
|
||||
}
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockIssue) SelectByID(id int) (*common.Issue, error) {
|
||||
if id <= 0 {
|
||||
return nil, &validator.FieldError{
|
||||
ErrorMsg: "id cannot be less than 0",
|
||||
}
|
||||
}
|
||||
|
||||
issueImage := []common.IssueImage{
|
||||
{
|
||||
ID: 1,
|
||||
Image: []byte{},
|
||||
IssueID: 1,
|
||||
},
|
||||
}
|
||||
return &common.Issue{
|
||||
ID: 1,
|
||||
VIN: "",
|
||||
Title: "",
|
||||
Description: "",
|
||||
DriverID: "",
|
||||
Timestamp: time.Time{},
|
||||
IssueImages: issueImage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MockIssue) Search(filter *common.IssueSearch, paging *queries.PageQueryOptions) ([]common.Issue, error) {
|
||||
if c.SelectIssuesResponse != nil {
|
||||
if filter.Search != "" && strings.Contains(c.SelectIssuesResponse[0].Title, filter.Search) {
|
||||
return []common.Issue{c.SelectIssuesResponse[0]}, nil
|
||||
}
|
||||
|
||||
return c.SelectIssuesResponse, nil
|
||||
}
|
||||
|
||||
return []common.Issue{
|
||||
{
|
||||
ID: 1,
|
||||
VIN: "",
|
||||
Title: "",
|
||||
Description: "",
|
||||
DriverID: "",
|
||||
Timestamp: time.Time{},
|
||||
IssueImages: []common.IssueImage{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MockIssue) Count() (int, error) {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (c *MockIssue) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
c.SelectIssuesResponse = list.([]common.Issue)
|
||||
} else {
|
||||
c.SelectIssuesResponse = nil
|
||||
}
|
||||
}
|
||||
|
||||
type MockIssueImages struct {
|
||||
queries.QueryBase
|
||||
SearchByIssueIDResponse []common.IssueImage
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (c *MockIssueImages) Insert(issueImage *[]common.IssueImage) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockIssueImages) SearchByIssueID(issueID string) ([]common.IssueImage, error) {
|
||||
return []common.IssueImage{
|
||||
{
|
||||
ID: 1,
|
||||
Image: []byte{0, 1, 0},
|
||||
IssueID: 1,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
25
pkg/db/queries/mocks/ormresults.go
Normal file
25
pkg/db/queries/mocks/ormresults.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package mocks
|
||||
|
||||
import "github.com/go-pg/pg/v10/orm"
|
||||
|
||||
type MockORMResults struct {
|
||||
ORMModel orm.Model
|
||||
AffectedRows int
|
||||
ReturnedRows int
|
||||
}
|
||||
|
||||
func (r *MockORMResults) Model() orm.Model {
|
||||
return r.ORMModel
|
||||
}
|
||||
|
||||
func (r *MockORMResults) RowsAffected() int {
|
||||
return r.AffectedRows
|
||||
}
|
||||
|
||||
func (r *MockORMResults) RowsReturned() int {
|
||||
return r.ReturnedRows
|
||||
}
|
||||
|
||||
func (r *MockORMResults) SetModel(model orm.Model) {
|
||||
r.ORMModel = model
|
||||
}
|
||||
20
pkg/db/queries/mocks/rate_plan.go
Normal file
20
pkg/db/queries/mocks/rate_plan.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
)
|
||||
|
||||
type MockRatePlan struct {
|
||||
queries.QueryBase
|
||||
SelectResponse []common.RatePlanTMobile
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockRatePlan) Select(version string) (*common.RatePlanTMobile, error) {
|
||||
return &common.RatePlanTMobile{
|
||||
Country: "US",
|
||||
ProductID: "12345",
|
||||
PlanName: "Fisker US 5G",
|
||||
}, nil
|
||||
}
|
||||
44
pkg/db/queries/mocks/signed_images.go
Normal file
44
pkg/db/queries/mocks/signed_images.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
)
|
||||
|
||||
// EccKey query methods
|
||||
type MockSignedImages struct {
|
||||
DBMockHelper
|
||||
MockListResponse []common.SignedImage
|
||||
MockSignedImage common.SignedImage
|
||||
GetSigningCertResponse common.SupplierSigningCert
|
||||
GetSigningCertErr error
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) Insert(keys common.SignedImage) (orm.Result, error) {
|
||||
return si.ORMResponse, si.Error
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) SelectAll() ([]common.SignedImage, error) {
|
||||
return si.MockListResponse, si.Error
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) SelectBySupplier(email string) (common.SignedImage, error) {
|
||||
return si.MockSignedImage, si.Error
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) DeleteSigningCert(supplier_cert common.SupplierSigningCert) (orm.Result, error) {
|
||||
return si.ORMResponse, si.Error
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) GetSigningCert(supplier string, keyCert string) (common.SupplierSigningCert, error) {
|
||||
return si.GetSigningCertResponse, si.GetSigningCertErr
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) InsertSigningCert(supplier_cert common.SupplierSigningCert) (orm.Result, error) {
|
||||
return si.ORMResponse, si.Error
|
||||
}
|
||||
|
||||
func (si *MockSignedImages) SetListResp(list interface{}) {
|
||||
|
||||
}
|
||||
52
pkg/db/queries/mocks/subscription_configurations.go
Normal file
52
pkg/db/queries/mocks/subscription_configurations.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockSubscriptionConfigurations struct {
|
||||
ListResult []common.SubscriptionConfiguration
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
m.ListResult = list.([]common.SubscriptionConfiguration)
|
||||
} else {
|
||||
m.ListResult = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) SetLoadResp(item interface{}) {
|
||||
// no get result to set
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) Delete(model *common.SubscriptionConfiguration) (orm.Result, error) {
|
||||
m.LastFilter = model
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) Insert(model *common.SubscriptionConfiguration) (orm.Result, error) {
|
||||
m.LastFilter = model
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) Update(model *common.SubscriptionConfiguration) (orm.Result, error) {
|
||||
m.LastFilter = model
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) Count(filter *common.SubscriptionConfiguration) (int, error) {
|
||||
m.LastFilter = filter
|
||||
return len(m.ListResult), m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionConfigurations) Select(filter *common.SubscriptionConfiguration, paging *queries.PageQueryOptions) ([]common.SubscriptionConfiguration, error) {
|
||||
m.LastFilter = filter
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.ListResult, m.Error
|
||||
}
|
||||
78
pkg/db/queries/mocks/subscription_features.go
Normal file
78
pkg/db/queries/mocks/subscription_features.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type MockSubscriptionFeatures struct {
|
||||
ListResult []common.SubscriptionFeature
|
||||
LoadResult *common.SubscriptionFeature
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
m.ListResult = list.([]common.SubscriptionFeature)
|
||||
} else {
|
||||
m.ListResult = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) SetLoadResp(item interface{}) {
|
||||
if item != nil {
|
||||
m.LoadResult = item.(*common.SubscriptionFeature)
|
||||
} else {
|
||||
m.LoadResult = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) Delete(model *common.SubscriptionFeature) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) Insert(model *common.SubscriptionFeature) (orm.Result, error) {
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
|
||||
model.ID = uuid.MustParse("ecfb89e0-ca03-4aa9-a43a-a9d703256edb")
|
||||
|
||||
return m.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) Update(model *common.SubscriptionFeature) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) Count(filter *common.SubscriptionFeature) (int, error) {
|
||||
return len(m.ListResult), m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) Select(filter *common.SubscriptionFeature, paging *queries.PageQueryOptions) ([]common.SubscriptionFeature, error) {
|
||||
m.LastFilter = filter
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.ListResult, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionFeatures) Load(model *common.SubscriptionFeature) error {
|
||||
filter := *model
|
||||
m.LastFilter = &filter
|
||||
|
||||
if m.Error != nil {
|
||||
return m.Error
|
||||
}
|
||||
|
||||
if m.LoadResult != nil {
|
||||
model.ID = m.LoadResult.ID
|
||||
model.Name = m.LoadResult.Name
|
||||
model.Description = m.LoadResult.Description
|
||||
model.Configurations = m.LoadResult.Configurations
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
98
pkg/db/queries/mocks/subscription_packages.go
Normal file
98
pkg/db/queries/mocks/subscription_packages.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type MockSubscriptionPackages struct {
|
||||
ListResult []common.SubscriptionPackage
|
||||
LoadResult *common.SubscriptionPackage
|
||||
InsertResult bool
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
result, ok := list.([]common.SubscriptionPackage)
|
||||
if ok {
|
||||
m.ListResult = result
|
||||
return
|
||||
}
|
||||
}
|
||||
m.ListResult = nil
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) SetLoadResp(item interface{}) {
|
||||
if item != nil {
|
||||
result, ok := item.(common.SubscriptionPackage)
|
||||
if ok {
|
||||
m.LoadResult = &result
|
||||
return
|
||||
}
|
||||
}
|
||||
m.LoadResult = nil
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) Delete(model *common.SubscriptionPackage) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) Insert(model *common.SubscriptionPackage) (orm.Result, error) {
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
|
||||
model.ID = uuid.MustParse("0557bd1d-76d3-41e5-a44e-13c479e55ab0")
|
||||
|
||||
return m.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) Update(model *common.SubscriptionPackage) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) Count(filter *common.SubscriptionPackage) (int, error) {
|
||||
return len(m.ListResult), m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) Select(filter *common.SubscriptionPackage, paging *queries.PageQueryOptions) ([]common.SubscriptionPackage, error) {
|
||||
m.LastFilter = filter
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.ListResult, m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) Load(model *common.SubscriptionPackage) error {
|
||||
filter := *model
|
||||
m.LastFilter = &filter
|
||||
|
||||
if m.LoadResult != nil {
|
||||
model.ID = m.LoadResult.ID
|
||||
model.Name = m.LoadResult.Name
|
||||
model.Features = m.LoadResult.Features
|
||||
}
|
||||
|
||||
return m.Error
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) AddFeature(pack *common.SubscriptionPackage, feature *common.SubscriptionFeature) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
|
||||
pack.AddFeature(feature)
|
||||
|
||||
return m.InsertResult, nil
|
||||
}
|
||||
|
||||
func (m *MockSubscriptionPackages) AssociateFeature(packageid uuid.UUID, featureid uuid.UUID) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
|
||||
return m.InsertResult, nil
|
||||
}
|
||||
56
pkg/db/queries/mocks/subscriptions.go
Normal file
56
pkg/db/queries/mocks/subscriptions.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockSubscriptions struct {
|
||||
ListResult []common.Subscription
|
||||
ItemResult *common.Subscription
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
// Select returns list of drivers
|
||||
func (s *MockSubscriptions) Select(filter *common.Subscription) ([]common.Subscription, error) {
|
||||
return s.ListResult, s.Error
|
||||
}
|
||||
|
||||
func (s *MockSubscriptions) Insert(subtype *common.Subscription) (orm.Result, error) {
|
||||
return s.ORMResponse, s.Error
|
||||
}
|
||||
|
||||
func (s *MockSubscriptions) Update(subtype *common.Subscription) (orm.Result, error) {
|
||||
return s.ORMResponse, s.Error
|
||||
}
|
||||
|
||||
func (s *MockSubscriptions) Delete(req *queries.SubscriptionDeleteRequest) (orm.Result, error) {
|
||||
return s.ORMResponse, s.Error
|
||||
}
|
||||
|
||||
func (s *MockSubscriptions) Load(sub *common.Subscription) error {
|
||||
return s.Error
|
||||
|
||||
}
|
||||
|
||||
func (s *MockSubscriptions) Count(filter *common.Subscription) (int, error) {
|
||||
return len(s.ListResult), s.Error
|
||||
}
|
||||
|
||||
func (s *MockSubscriptions) Create(subtype *common.SubscriptionType, carToDriver *common.CarToDriver) (*common.Subscription, error) {
|
||||
if s.ItemResult == nil {
|
||||
return nil, s.Error
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.ItemResult.Name = subtype.Name
|
||||
s.ItemResult.SubscriptionTypeID = subtype.ID
|
||||
s.ItemResult.CreatedAt = &now
|
||||
s.ItemResult.UpdatedAt = &now
|
||||
|
||||
return s.ItemResult, s.Error
|
||||
}
|
||||
37
pkg/db/queries/mocks/subscriptiontypes.go
Normal file
37
pkg/db/queries/mocks/subscriptiontypes.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockSubscriptionTypes struct {
|
||||
ORMResult orm.Result
|
||||
Error error
|
||||
ListResult []common.SubscriptionType
|
||||
}
|
||||
|
||||
func (st *MockSubscriptionTypes) Select(filter *common.SubscriptionType) ([]common.SubscriptionType, error) {
|
||||
return st.ListResult, st.Error
|
||||
}
|
||||
|
||||
func (st *MockSubscriptionTypes) Insert(subtype *common.SubscriptionType) (orm.Result, error) {
|
||||
return st.ORMResult, st.Error
|
||||
}
|
||||
|
||||
func (st *MockSubscriptionTypes) Update(subtype *common.SubscriptionType) (orm.Result, error) {
|
||||
return st.ORMResult, st.Error
|
||||
}
|
||||
|
||||
func (st *MockSubscriptionTypes) Delete(subtype *common.SubscriptionType) (orm.Result, error) {
|
||||
return st.ORMResult, st.Error
|
||||
}
|
||||
|
||||
func (st *MockSubscriptionTypes) Load(subtype *common.SubscriptionType) error {
|
||||
return st.Error
|
||||
}
|
||||
|
||||
func (st *MockSubscriptionTypes) Count(filter *common.SubscriptionType) (int, error) {
|
||||
return len(st.ListResult), st.Error
|
||||
}
|
||||
33
pkg/db/queries/mocks/sums_versions.go
Normal file
33
pkg/db/queries/mocks/sums_versions.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockUpdateManifestVersions struct {
|
||||
queries.QueryBase
|
||||
SelectResponse []common.SUMSVersion
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifestVersions) SelectAll(options *queries.PageQueryOptions) ([]common.SUMSVersion, error) {
|
||||
return m.SelectResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifestVersions) SelectAllCount() (int, error) {
|
||||
return len(m.SelectResponse), nil
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifestVersions) Insert(u *common.SUMSVersion) (orm.Result, error) {
|
||||
return m.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifestVersions) Delete(u *common.SUMSVersion) (orm.Result, error) {
|
||||
return m.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifestVersions) Select(version string) (*common.SUMSVersion, error) {
|
||||
return nil, nil
|
||||
}
|
||||
70
pkg/db/queries/mocks/supplier_accounts.go
Normal file
70
pkg/db/queries/mocks/supplier_accounts.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
// CarUpdate query methods
|
||||
type MockSupplierAccounts struct {
|
||||
DBMockHelper
|
||||
MockListResponse []common.SupplierAccount
|
||||
MockSupplierAccount *common.SupplierAccount
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Count(account *common.SupplierAccount) (int, error) {
|
||||
return len(c.MockListResponse), c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Delete(account *common.SupplierAccount) (orm.Result, error) {
|
||||
c.LastFilter = account
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Insert(account *common.SupplierAccount) (orm.Result, error) {
|
||||
c.LastFilter = account
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Load(account *common.SupplierAccount) error {
|
||||
c.LastFilter = account
|
||||
|
||||
if c.MockSupplierAccount != nil {
|
||||
account.ECUs = c.MockSupplierAccount.ECUs
|
||||
account.ActivatedAt = c.MockSupplierAccount.ActivatedAt
|
||||
}
|
||||
|
||||
return c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Select(account *common.SupplierAccount, paging *queries.PageQueryOptions) ([]common.SupplierAccount, error) {
|
||||
c.LastFilter = account
|
||||
c.LastPaging = paging
|
||||
|
||||
return c.MockListResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Update(account *common.SupplierAccount) (orm.Result, error) {
|
||||
c.LastFilter = account
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) Approve(email string) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) UpdateTimestamp(email string, activity queries.SupplierTimestamp) (orm.Result, error) {
|
||||
return c.ORMResponse, c.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
c.MockListResponse = list.([]common.SupplierAccount)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockSupplierAccounts) SetLoadResp(item interface{}) {
|
||||
// fmt.Printf("override SetLoadResp() in %s\n", reflect.TypeOf(item))
|
||||
}
|
||||
43
pkg/db/queries/mocks/supplier_organizations.go
Normal file
43
pkg/db/queries/mocks/supplier_organizations.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
// EccKey query methods
|
||||
type MockSupplierOrganization struct {
|
||||
DBMockHelper
|
||||
MockListResponse []common.SupplierOrganization
|
||||
MockSupplierOrganization *common.SupplierOrganization
|
||||
}
|
||||
|
||||
func (so *MockSupplierOrganization) Count(supplierOrganization *common.SupplierOrganization) (int, error) {
|
||||
return len(so.MockListResponse), so.Error
|
||||
}
|
||||
|
||||
func (so *MockSupplierOrganization) Insert(supplierOrg *common.SupplierOrganization) (orm.Result, error) {
|
||||
return so.ORMResponse, so.Error
|
||||
}
|
||||
|
||||
func (so *MockSupplierOrganization) Update(supplierOrg *common.SupplierOrganization) (orm.Result, error) {
|
||||
return so.ORMResponse, so.Error
|
||||
}
|
||||
|
||||
func (so *MockSupplierOrganization) Delete(supplierOrg *common.SupplierOrganization) (orm.Result, error) {
|
||||
return so.ORMResponse, so.Error
|
||||
}
|
||||
|
||||
func (so *MockSupplierOrganization) Select(supplierOrg *common.SupplierOrganization, paging *queries.PageQueryOptions) ([]common.SupplierOrganization, error) {
|
||||
so.LastFilter = supplierOrg
|
||||
so.LastPaging = paging
|
||||
|
||||
return so.MockListResponse, so.Error
|
||||
}
|
||||
|
||||
func (c *MockSupplierOrganization) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
c.MockListResponse = list.([]common.SupplierOrganization)
|
||||
}
|
||||
}
|
||||
29
pkg/db/queries/mocks/swversion_rxswin.go
Normal file
29
pkg/db/queries/mocks/swversion_rxswin.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockSwVersionRxSwin struct {
|
||||
queries.QueryBase
|
||||
SelectResponse []common.SwVersionRxSwin
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockSwVersionRxSwin) SelectByVersion(version string, options *queries.PageQueryOptions) ([]common.SwVersionRxSwin, error) {
|
||||
return m.SelectResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockSwVersionRxSwin) SelectCountByVersion(version string) (int, error) {
|
||||
return len(m.SelectResponse), nil
|
||||
}
|
||||
|
||||
func (m *MockSwVersionRxSwin) Insert(swVersionRxSwin *common.SwVersionRxSwin) (orm.Result, error) {
|
||||
return m.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (m *MockSwVersionRxSwin) Delete(model *common.SwVersionRxSwin) (orm.Result, error) {
|
||||
return m.ORMResponse, nil
|
||||
}
|
||||
26
pkg/db/queries/mocks/symkeys.go
Normal file
26
pkg/db/queries/mocks/symkeys.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
)
|
||||
|
||||
// SymKey query methods
|
||||
type MockSymKeys struct {
|
||||
DBMockHelper
|
||||
MockListResponse []common.SymKeys
|
||||
MockSymKeys common.SymKeys
|
||||
}
|
||||
|
||||
func (sk *MockSymKeys) Insert(keys common.SymKeys) (orm.Result, error) {
|
||||
return sk.ORMResponse, sk.Error
|
||||
}
|
||||
|
||||
func (sk *MockSymKeys) SelectAll() ([]common.SymKeys, error) {
|
||||
return sk.MockListResponse, sk.Error
|
||||
}
|
||||
|
||||
func (sk *MockSymKeys) SelectByVIN(vin string) (common.SymKeys, error) {
|
||||
return sk.MockSymKeys, sk.Error
|
||||
}
|
||||
23
pkg/db/queries/mocks/tags.go
Normal file
23
pkg/db/queries/mocks/tags.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type MockTags struct {
|
||||
queries.QueryBase
|
||||
ReceivedTags []string
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (t *MockTags) Update(car *common.Car) (orm.Result, error) {
|
||||
t.ReceivedTags = car.Tags
|
||||
return t.ORMResponse, nil
|
||||
}
|
||||
|
||||
func (t *MockTags) UpdateDistinctTags(car *common.Car) (orm.Result, error) {
|
||||
t.ReceivedTags = car.Tags
|
||||
return t.ORMResponse, nil
|
||||
}
|
||||
137
pkg/db/queries/mocks/updatemanifests.go
Normal file
137
pkg/db/queries/mocks/updatemanifests.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type MockFingerprintParams struct {
|
||||
Serial string
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
func (m *MockFingerprintParams) ManifestSerial() string {
|
||||
return m.Serial
|
||||
}
|
||||
|
||||
func (m *MockFingerprintParams) CurTime() time.Time {
|
||||
return m.Time
|
||||
}
|
||||
|
||||
type MockUpdateManifests struct {
|
||||
queries.QueryBase
|
||||
SelectResponse []common.UpdateManifest
|
||||
SelectByVINResponse []common.StatusManifest
|
||||
LoadResponse *common.UpdateManifest
|
||||
ECUUpdatesMock func(man *common.UpdateManifestECU, vin string) ([]*common.UpdateManifestECU, error)
|
||||
FlashPackManifest common.UpdateManifest
|
||||
DBMockHelper
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Count(filter common.UpdateManifest) (int, error) {
|
||||
return len(m.SelectResponse), m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Delete(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
err := validator.ValidateIDField(manifest.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Insert(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
manifest.ID += 1
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) ECUInsert(ecu *common.UpdateManifestECU) (orm.Result, error) {
|
||||
ecu.ID = 1
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) FileInsert(file *common.UpdateManifestFile) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Select(filter *common.UpdateManifest, paging *queries.PageQueryOptions) ([]common.UpdateManifest, error) {
|
||||
m.LastFilter = filter
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.SelectResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Archive(ids []int64, active bool) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) SelectByVIN(vin string, paging *queries.PageQueryOptions) ([]common.StatusManifest, error) {
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.SelectByVINResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Update(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
m.LastFilter = manifest
|
||||
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Load(manifest *common.UpdateManifest) error {
|
||||
m.LastFilter = manifest
|
||||
|
||||
if m.LoadResponse != nil {
|
||||
m.LoadResponse.ID = manifest.ID
|
||||
data, err := json.Marshal(m.LoadResponse)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
err = json.Unmarshal(data, manifest)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) Search(filter common.UpdateManifestSearch, paging *queries.PageQueryOptions) ([]common.UpdateManifest, error) {
|
||||
m.LastFilter = &filter
|
||||
m.LastPaging = paging
|
||||
|
||||
return m.SelectResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) SearchCount(filter common.UpdateManifestSearch) (int, error) {
|
||||
m.LastFilter = &filter
|
||||
|
||||
return len(m.SelectResponse), m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) SetListResp(list interface{}) {
|
||||
if list != nil {
|
||||
m.SelectResponse = list.([]common.UpdateManifest)
|
||||
} else {
|
||||
m.SelectResponse = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) ECURollback(man *common.UpdateManifestECU, vin string) ([]*common.UpdateManifestECU, error) {
|
||||
return m.ECUUpdatesMock(man, vin)
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) AddSUMSVersion(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
return m.ORMResponse, m.Error
|
||||
}
|
||||
|
||||
func (m *MockUpdateManifests) SelectFlashPackByVersion(versionNumber string) (manifest common.UpdateManifest, err error) {
|
||||
return m.FlashPackManifest, nil
|
||||
}
|
||||
30
pkg/db/queries/orm_results.go
Normal file
30
pkg/db/queries/orm_results.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package queries
|
||||
|
||||
import "github.com/go-pg/pg/v10/orm"
|
||||
|
||||
type ORMResults struct {
|
||||
model orm.Model
|
||||
rowsAffected int
|
||||
rowsReturned int
|
||||
}
|
||||
|
||||
func (r *ORMResults) Model() orm.Model {
|
||||
return r.model
|
||||
}
|
||||
|
||||
func (r *ORMResults) RowsAffected() int {
|
||||
return r.rowsAffected
|
||||
}
|
||||
|
||||
func (r *ORMResults) RowsReturned() int {
|
||||
return r.rowsReturned
|
||||
}
|
||||
|
||||
func (r *ORMResults) SetModel(model orm.Model) {
|
||||
r.model = model
|
||||
}
|
||||
|
||||
func (r *ORMResults) AddResult(result orm.Result) {
|
||||
r.rowsAffected += result.RowsAffected()
|
||||
r.rowsReturned += result.RowsReturned()
|
||||
}
|
||||
101
pkg/db/queries/querybase.go
Normal file
101
pkg/db/queries/querybase.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"fiskerinc.com/modules/db"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type QueryBaseInterface interface {
|
||||
GetDBConn() *pg.DB
|
||||
GetClient() *db.DBClient
|
||||
SetClient(client *db.DBClient)
|
||||
}
|
||||
|
||||
type QueryBase struct {
|
||||
client *db.DBClient
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (q *QueryBase) GetDBConn() *pg.DB {
|
||||
return q.GetClient().GetConn()
|
||||
}
|
||||
|
||||
func (q *QueryBase) GetClient() *db.DBClient {
|
||||
q.once.Do(func() {
|
||||
if q.client == nil {
|
||||
q.client = &db.DBClient{}
|
||||
}
|
||||
})
|
||||
|
||||
return q.client
|
||||
}
|
||||
|
||||
func (q *QueryBase) SetClient(client *db.DBClient) {
|
||||
if q.client != nil {
|
||||
q.client.Close()
|
||||
}
|
||||
q.client = client
|
||||
}
|
||||
|
||||
// pageQuery update orm query with PageQueryOptions
|
||||
func (q *QueryBase) pageQuery(query *orm.Query, options *PageQueryOptions) *orm.Query {
|
||||
if options == nil {
|
||||
return query
|
||||
}
|
||||
|
||||
if options.Order != "" {
|
||||
query.Order(options.Order)
|
||||
}
|
||||
|
||||
// A limit of 0 is not respected
|
||||
if options.Limit > 0 {
|
||||
query.Limit(options.Limit)
|
||||
}
|
||||
|
||||
if options.Offset > 0 {
|
||||
query.Offset(options.Offset)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (q *QueryBase) countWithStack(count int, err error) (int, error) {
|
||||
return count, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (q *QueryBase) resultWithStack(result orm.Result, err error) (orm.Result, error) {
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (q *QueryBase) insertSelectWithStack(inserted bool, err error) (bool, error) {
|
||||
return inserted, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (q *QueryBase) count(filter interface{}) (int, error) {
|
||||
return q.countWithStack(q.GetDBConn().Model(filter).Count())
|
||||
}
|
||||
|
||||
func (q *QueryBase) delete(model interface{}) (orm.Result, error) {
|
||||
return q.resultWithStack(q.GetDBConn().Model(model).WherePK().Delete())
|
||||
}
|
||||
|
||||
func (q *QueryBase) insert(model interface{}) (orm.Result, error) {
|
||||
return q.resultWithStack(q.GetDBConn().Model(model).Insert())
|
||||
}
|
||||
|
||||
func (q *QueryBase) update(model interface{}) (orm.Result, error) {
|
||||
return q.resultWithStack(q.GetDBConn().Model(model).WherePK().Update())
|
||||
}
|
||||
|
||||
func (q *QueryBase) hasErrorResult(total *ORMResults, result orm.Result, err error) bool {
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
total.AddResult(result)
|
||||
return false
|
||||
}
|
||||
41
pkg/db/queries/rate_plan_tomobile.go
Normal file
41
pkg/db/queries/rate_plan_tomobile.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type RatePlanInterface interface {
|
||||
Select(string) (*common.RatePlanTMobile, error)
|
||||
}
|
||||
|
||||
type RatePlanTmobile struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (rpt *RatePlanTmobile) Select(country string) (*common.RatePlanTMobile, error) {
|
||||
ratePlan := []common.RatePlanTMobile{}
|
||||
query := rpt.CreateSelectQuery(country, &ratePlan)
|
||||
|
||||
err := query.Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if len(ratePlan) == 1 {
|
||||
return &ratePlan[0], nil
|
||||
} else if len(ratePlan) == 0 {
|
||||
return nil, fmt.Errorf("no rate plan exists for country %s", country)
|
||||
} else {
|
||||
return nil, fmt.Errorf("multiple rate plans exist for country %s", country)
|
||||
}
|
||||
}
|
||||
|
||||
func (rpt *RatePlanTmobile) CreateSelectQuery(country string, ratePlan *[]common.RatePlanTMobile) *orm.Query {
|
||||
return rpt.GetDBConn().
|
||||
Model(ratePlan).
|
||||
Where("country = ?", country)
|
||||
}
|
||||
20
pkg/db/queries/rate_plan_tomobile_test.go
Normal file
20
pkg/db/queries/rate_plan_tomobile_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
)
|
||||
|
||||
func TestCreateSelectQuery(t *testing.T) {
|
||||
ratePlanDB := queries.RatePlanTmobile{}
|
||||
ratePlanModel := []m.RatePlanTMobile{}
|
||||
|
||||
q := ratePlanDB.CreateSelectQuery("US", &ratePlanModel)
|
||||
q_str := queryToString(q)
|
||||
expected := `SELECT "r"."country", "r"."product_id", "r"."plan_name", "r"."created_at", "r"."updated_at" FROM "rate_plan_tmobile" AS "r" WHERE (country = 'US')`
|
||||
if q_str != expected {
|
||||
t.Errorf("Expected the query string %s, but got %s.", expected, q_str)
|
||||
}
|
||||
}
|
||||
98
pkg/db/queries/signed_images.go
Normal file
98
pkg/db/queries/signed_images.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
s "fiskerinc.com/modules/security"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SignedImagesInterface interface {
|
||||
Insert(SignedImage common.SignedImage) (orm.Result, error)
|
||||
SelectAll() ([]common.SignedImage, error)
|
||||
SelectBySupplier(email string) (common.SignedImage, error)
|
||||
DeleteSigningCert(supplier_cert common.SupplierSigningCert) (orm.Result, error)
|
||||
GetSigningCert(supplier string, keyCert string) (common.SupplierSigningCert, error)
|
||||
InsertSigningCert(supplier_cert common.SupplierSigningCert) (orm.Result, error)
|
||||
}
|
||||
|
||||
type SignedImages struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (si *SignedImages) Insert(signedImage common.SignedImage) (orm.Result, error) {
|
||||
return si.resultWithStack(si.GetDBConn().Model(&signedImage).Insert())
|
||||
}
|
||||
|
||||
// Selects all public keys and signatures
|
||||
func (si *SignedImages) SelectAll() ([]common.SignedImage, error) {
|
||||
signatures := []common.SignedImage{}
|
||||
|
||||
err := si.GetDBConn().Model(&signatures).Column("signature").Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// Selects all public keys and signatures
|
||||
func (si *SignedImages) SelectBySupplier(email string) (common.SignedImage, error) {
|
||||
signature := common.SignedImage{}
|
||||
|
||||
err := si.GetDBConn().Model(&signature).Where("email = ?", email).Order("created_at desc").Limit(1).Select()
|
||||
if err != nil {
|
||||
return signature, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return signature, err
|
||||
}
|
||||
|
||||
func (si *SignedImages) decryptSigningCert(cert *common.SupplierSigningCert) error {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cert.PrivateCertEncrypted != nil {
|
||||
key, err := encryptor.DecryptChunk(cert.PrivateCertEncrypted.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.PrivateCert.SetBytes(key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (si *SignedImages) GetSigningCert(supplier string, keyCert string) (common.SupplierSigningCert, error) {
|
||||
cert := common.SupplierSigningCert{
|
||||
Supplier: supplier,
|
||||
KeyCert: keyCert,
|
||||
}
|
||||
err := si.GetDBConn().Model(&cert).WherePK().Limit(1).Select()
|
||||
if err != nil {
|
||||
return cert, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = si.decryptSigningCert(&cert)
|
||||
|
||||
return cert, err
|
||||
}
|
||||
|
||||
func (si *SignedImages) InsertSigningCert(supplier_cert common.SupplierSigningCert) (orm.Result, error) {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
supplier_cert.PrivateCertEncrypted = encryptor.EncryptChunk([]byte(supplier_cert.PrivateCert))
|
||||
|
||||
return si.insert(&supplier_cert)
|
||||
}
|
||||
|
||||
func (si *SignedImages) DeleteSigningCert(supplier_cert common.SupplierSigningCert) (orm.Result, error) {
|
||||
return si.delete(&supplier_cert)
|
||||
}
|
||||
85
pkg/db/queries/signed_images_test.go
Normal file
85
pkg/db/queries/signed_images_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var publicTestCert = common.NewBinaryHex([]byte("9a1a6949d7f8a511df6e2e2771e444dbd6de97"))
|
||||
var privTestCert = common.NewBinaryHex([]byte("cda02810bad1b6f1b8c6202234a424b7d5b9a1"))
|
||||
|
||||
func TestSignedImagesIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
query := setupSignedImages(t)
|
||||
testInsertSigningCert(t, query)
|
||||
testGetSigningCert(t, query)
|
||||
testDeleteSigningCert(t, query)
|
||||
}
|
||||
|
||||
func setupSignedImages(t *testing.T) queries.SignedImagesInterface {
|
||||
instance := queries.SignedImages{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*common.Issue)(nil),
|
||||
})
|
||||
|
||||
return &instance
|
||||
}
|
||||
|
||||
func testInsertSigningCert(t *testing.T, query queries.SignedImagesInterface) {
|
||||
cert := m.SupplierSigningCert{
|
||||
Supplier: "TESTSUPPLER",
|
||||
KeyCert: "sbc_key_4096",
|
||||
PublicCert: publicTestCert,
|
||||
PrivateCert: privTestCert,
|
||||
}
|
||||
|
||||
res, err := query.InsertSigningCert(cert)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SupplierSigningCert insert", "No error", err)
|
||||
}
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SupplierSigningCert insert RowsAffected", 1, res.RowsAffected())
|
||||
}
|
||||
}
|
||||
|
||||
func testGetSigningCert(t *testing.T, query queries.SignedImagesInterface) {
|
||||
_, err := query.GetSigningCert("TESTSUPPLER", "verified_rsa4096_key")
|
||||
if !errors.Is(err, pg.ErrNoRows) {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetSigningCert", pg.ErrNoRows, err)
|
||||
}
|
||||
|
||||
cert, err := query.GetSigningCert("TESTSUPPLER", "sbc_key_4096")
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetSigningCert", pg.ErrNoRows, err)
|
||||
}
|
||||
|
||||
if cert.PublicCert.String() != publicTestCert.String() {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetSigningCert.PublicCert", publicTestCert.String(), cert.PublicCert.String())
|
||||
}
|
||||
|
||||
if cert.PrivateCert.String() != privTestCert.String() {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "GetSigningCert.PrivateCert", privTestCert, cert.PrivateCert)
|
||||
}
|
||||
}
|
||||
|
||||
func testDeleteSigningCert(t *testing.T, query queries.SignedImagesInterface) {
|
||||
_, err := query.DeleteSigningCert(m.SupplierSigningCert{
|
||||
Supplier: "TESTSUPPLER",
|
||||
KeyCert: "sbc_key_4096",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "DeleteSigningCert", nil, err)
|
||||
}
|
||||
}
|
||||
93
pkg/db/queries/subscription_configurations.go
Normal file
93
pkg/db/queries/subscription_configurations.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var subConfigurationCols = []string{}
|
||||
|
||||
type SubscriptionConfigurationsInterface interface {
|
||||
Delete(model *common.SubscriptionConfiguration) (orm.Result, error)
|
||||
Insert(model *common.SubscriptionConfiguration) (orm.Result, error)
|
||||
Update(model *common.SubscriptionConfiguration) (orm.Result, error)
|
||||
Count(filter *common.SubscriptionConfiguration) (int, error)
|
||||
Select(fitler *common.SubscriptionConfiguration, paging *PageQueryOptions) ([]common.SubscriptionConfiguration, error)
|
||||
}
|
||||
|
||||
type SubscriptionConfigurations struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (sc *SubscriptionConfigurations) Delete(model *common.SubscriptionConfiguration) (orm.Result, error) {
|
||||
return sc.delete(model)
|
||||
}
|
||||
|
||||
func (sc *SubscriptionConfigurations) Insert(model *common.SubscriptionConfiguration) (orm.Result, error) {
|
||||
return sc.insert(model)
|
||||
}
|
||||
|
||||
func (sc *SubscriptionConfigurations) Update(model *common.SubscriptionConfiguration) (orm.Result, error) {
|
||||
return sc.resultWithStack(sc.GetDBConn().Model(model).Column(subConfigurationCols...).WherePK().Update())
|
||||
}
|
||||
|
||||
func (sc *SubscriptionConfigurations) Count(filter *common.SubscriptionConfiguration) (int, error) {
|
||||
query := sc.GetDBConn().Model(filter)
|
||||
|
||||
sc.selectFilter(query, filter)
|
||||
|
||||
count, err := query.Count()
|
||||
|
||||
return count, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sc *SubscriptionConfigurations) selectFilter(query *orm.Query, filter *common.SubscriptionConfiguration) {
|
||||
if filter.SubscriptionFeatureID != uuid.Nil {
|
||||
query.Where("subscription_feature_id = ?", filter.SubscriptionFeatureID)
|
||||
}
|
||||
|
||||
if filter.ECU != "" {
|
||||
query.Where("ecu = ?", filter.ECU)
|
||||
}
|
||||
|
||||
if filter.HardwareVersion != "" {
|
||||
query.Where("hardware_version = ?", filter.HardwareVersion)
|
||||
}
|
||||
|
||||
if filter.SoftwareVersion != "" {
|
||||
query.Where("software_version = ?", filter.SoftwareVersion)
|
||||
}
|
||||
|
||||
if filter.DID != nil {
|
||||
query.Where("did = decode(?, 'hex')", filter.DID.String())
|
||||
}
|
||||
|
||||
if filter.PID != nil {
|
||||
query.Where("pid = decode(?, 'hex')", filter.PID.String())
|
||||
}
|
||||
|
||||
if filter.Configuration != nil {
|
||||
query.Where("configuration = decode(?, 'hex')", filter.Configuration.String())
|
||||
}
|
||||
|
||||
if filter.Mask != nil {
|
||||
query.Where("mask = decode(?, 'hex')", filter.Mask.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *SubscriptionConfigurations) Select(filter *common.SubscriptionConfiguration, paging *PageQueryOptions) ([]common.SubscriptionConfiguration, error) {
|
||||
items := []common.SubscriptionConfiguration{}
|
||||
query := sc.GetDBConn().Model(&items)
|
||||
|
||||
sc.selectFilter(query, filter)
|
||||
if paging != nil {
|
||||
sc.pageQuery(query, paging)
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return items, errors.WithStack(err)
|
||||
}
|
||||
141
pkg/db/queries/subscription_features.go
Normal file
141
pkg/db/queries/subscription_features.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var subFeaturesCols = []string{}
|
||||
|
||||
type SubscriptionFeaturesInterface interface {
|
||||
Delete(model *common.SubscriptionFeature) (orm.Result, error)
|
||||
Insert(model *common.SubscriptionFeature) (orm.Result, error)
|
||||
Update(model *common.SubscriptionFeature) (orm.Result, error)
|
||||
Count(filter *common.SubscriptionFeature) (int, error)
|
||||
Select(fitler *common.SubscriptionFeature, paging *PageQueryOptions) ([]common.SubscriptionFeature, error)
|
||||
Load(model *common.SubscriptionFeature) error
|
||||
}
|
||||
|
||||
type SubscriptionFeatures struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) Delete(model *common.SubscriptionFeature) (orm.Result, error) {
|
||||
var err error
|
||||
total := &ORMResults{}
|
||||
tx, err := sf.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
result, err := tx.Model(&common.SubscriptionConfiguration{SubscriptionFeatureID: model.ID}).Where("subscription_feature_id = ?subscription_feature_id").Delete()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
total.AddResult(result)
|
||||
|
||||
result, err = tx.Model(model).WherePK().Delete()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
total.AddResult(result)
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) Insert(model *common.SubscriptionFeature) (orm.Result, error) {
|
||||
return sf.insert(model)
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) Update(model *common.SubscriptionFeature) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(model)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return sf.resultWithStack(sf.GetDBConn().Model(model).Column(subFeaturesCols...).WherePK().Update())
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) Count(filter *common.SubscriptionFeature) (int, error) {
|
||||
query := sf.GetDBConn().Model(filter)
|
||||
|
||||
sf.selectFilter(query, filter)
|
||||
|
||||
count, err := query.Count()
|
||||
|
||||
return count, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) Load(model *common.SubscriptionFeature) error {
|
||||
var err error
|
||||
|
||||
tx, err := sf.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
err = tx.Model(model).WherePK().First()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
configs := []common.SubscriptionConfiguration{}
|
||||
err = tx.Model(&configs).Where("subscription_feature_id = ?", model.ID).Select()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
model.Configurations = configs
|
||||
|
||||
err = tx.Commit()
|
||||
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) selectFilter(query *orm.Query, filter *common.SubscriptionFeature) {
|
||||
if filter.ID != uuid.Nil {
|
||||
query.Where("ID = ?", filter.ID)
|
||||
}
|
||||
|
||||
if filter.Name != "" {
|
||||
query.Where("name = ?", filter.Name)
|
||||
}
|
||||
|
||||
if filter.Description != "" {
|
||||
query.Where("description = ?", filter.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func (sf *SubscriptionFeatures) Select(filter *common.SubscriptionFeature, paging *PageQueryOptions) ([]common.SubscriptionFeature, error) {
|
||||
items := []common.SubscriptionFeature{}
|
||||
query := sf.GetDBConn().Model(&items)
|
||||
|
||||
sf.selectFilter(query, filter)
|
||||
if paging != nil {
|
||||
sf.pageQuery(query, paging)
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return items, errors.WithStack(err)
|
||||
}
|
||||
156
pkg/db/queries/subscription_packages.go
Normal file
156
pkg/db/queries/subscription_packages.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SubscriptionPackagesInterface interface {
|
||||
Delete(model *common.SubscriptionPackage) (orm.Result, error)
|
||||
Insert(model *common.SubscriptionPackage) (orm.Result, error)
|
||||
Update(model *common.SubscriptionPackage) (orm.Result, error)
|
||||
Count(filter *common.SubscriptionPackage) (int, error)
|
||||
Select(fitler *common.SubscriptionPackage, paging *PageQueryOptions) ([]common.SubscriptionPackage, error)
|
||||
Load(model *common.SubscriptionPackage) error
|
||||
AddFeature(pack *common.SubscriptionPackage, feature *common.SubscriptionFeature) (bool, error)
|
||||
AssociateFeature(packageid uuid.UUID, featureid uuid.UUID) (bool, error)
|
||||
}
|
||||
|
||||
type SubscriptionPackages struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) Delete(model *common.SubscriptionPackage) (orm.Result, error) {
|
||||
var err error
|
||||
total := &ORMResults{}
|
||||
tx, err := sp.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
result, err := tx.Model(&common.SubscriptionPackageToFeature{SubscriptionPackageID: model.ID}).Where("subscription_package_id = ?subscription_package_id").Delete()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
total.AddResult(result)
|
||||
|
||||
result, err = tx.Model(model).WherePK().Delete()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
total.AddResult(result)
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return total, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) Insert(model *common.SubscriptionPackage) (orm.Result, error) {
|
||||
return sp.insert(model)
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) Update(model *common.SubscriptionPackage) (orm.Result, error) {
|
||||
return sp.update(model)
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) Count(filter *common.SubscriptionPackage) (int, error) {
|
||||
query := sp.GetDBConn().Model(filter)
|
||||
|
||||
sp.selectFilter(query, filter)
|
||||
|
||||
count, err := query.Count()
|
||||
|
||||
return count, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) selectFilter(query *orm.Query, filter *common.SubscriptionPackage) {
|
||||
if filter.ID != uuid.Nil {
|
||||
query.Where("id = ?", filter.ID)
|
||||
}
|
||||
|
||||
if filter.Name != "" {
|
||||
query.Where("name = ?", filter.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) Select(filter *common.SubscriptionPackage, paging *PageQueryOptions) ([]common.SubscriptionPackage, error) {
|
||||
items := []common.SubscriptionPackage{}
|
||||
query := sp.GetDBConn().Model(&items)
|
||||
|
||||
sp.selectFilter(query, filter)
|
||||
if paging != nil {
|
||||
sp.pageQuery(query, paging)
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return items, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) Load(model *common.SubscriptionPackage) error {
|
||||
var err error
|
||||
tx, err := sp.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
err = tx.Model(model).WherePK().First()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
features := []common.SubscriptionFeature{}
|
||||
err = tx.Model(&features).Join("JOIN subscription_package_to_features").JoinOn("subscription_package_to_features.subscription_feature_id = subscription_feature.id").Where("subscription_package_to_features.subscription_package_id = ?", model.ID).Select()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
model.Features = features
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) AddFeature(model *common.SubscriptionPackage, feature *common.SubscriptionFeature) (bool, error) {
|
||||
result, err := sp.AssociateFeature(model.ID, feature.ID)
|
||||
if err == nil {
|
||||
model.AddFeature(feature)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) AssociateFeature(packageid uuid.UUID, featureid uuid.UUID) (bool, error) {
|
||||
inserted, err := sp.GetDBConn().Model(&common.SubscriptionPackageToFeature{
|
||||
SubscriptionPackageID: packageid,
|
||||
SubscriptionFeatureID: featureid,
|
||||
}).SelectOrInsert()
|
||||
|
||||
return inserted, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sp *SubscriptionPackages) RemoveFeature(model *common.SubscriptionPackage, feature *common.SubscriptionFeature) (orm.Result, error) {
|
||||
result, err := sp.resultWithStack(sp.GetDBConn().Model((*common.SubscriptionPackageToFeature)(nil)).Where("package_id = ? AND feature_id = ?", model.ID, feature.ID).Delete())
|
||||
if err == nil {
|
||||
model.RemoveFeature(feature)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
252
pkg/db/queries/subscription_packages_test.go
Normal file
252
pkg/db/queries/subscription_packages_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
th "fiskerinc.com/modules/testhelper"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const testSubPackageName = "TESTSUBPACKAGE"
|
||||
|
||||
func TestSubscriptionPackagesIntegration(t *testing.T) {
|
||||
t.Skip()
|
||||
qsp := setupSubscriptionPackages()
|
||||
qsf := &queries.SubscriptionFeatures{}
|
||||
qsc := &queries.SubscriptionConfigurations{}
|
||||
|
||||
pack := testSubscriptionPackageInsert(t, qsp)
|
||||
if pack.ID == uuid.Nil {
|
||||
t.Error("unable to create package")
|
||||
return
|
||||
}
|
||||
|
||||
feature := testSubscriptionPackageFeatureInsert(t, qsf)
|
||||
if feature.ID == uuid.Nil {
|
||||
t.Error("unable to create feature")
|
||||
return
|
||||
}
|
||||
|
||||
config, err := testSubscriptionAddConfiguration(t, qsc, &feature)
|
||||
if err != nil {
|
||||
t.Error("unable to create configuration")
|
||||
return
|
||||
}
|
||||
|
||||
testSubscriptionAddFeature(t, qsp, &pack, &feature)
|
||||
testSubscriptionPackageSelect(t, qsp, &pack)
|
||||
testSubscriptionFeatureSelect(t, qsf, &feature)
|
||||
testSubscriptionConfigurationSelect(t, qsc, &config)
|
||||
testSubscriptionPackageCount(t, qsp, &pack)
|
||||
testSubscriptionFeatureCount(t, qsf, &feature)
|
||||
testSubscriptionConfigurationCount(t, qsc, &config)
|
||||
testSubscriptionPackageUpdate(t, qsp, &pack)
|
||||
testSubscriptionFeatureUpdate(t, qsf, &feature)
|
||||
testSubscriptionConfigurationUpdate(t, qsc, &config)
|
||||
testSubscriptionFeatureLoad(t, qsf, &feature)
|
||||
testSubscriptionPackageLoad(t, qsp, &pack)
|
||||
testSubscriptionPackageDelete(t, qsp, &pack)
|
||||
testSubscriptionFeatureDelete(t, qsf, &feature)
|
||||
testSubscriptionConfigDelete(t, qsc, &config)
|
||||
}
|
||||
|
||||
func setupSubscriptionPackages() queries.SubscriptionPackagesInterface {
|
||||
instance := &queries.SubscriptionPackages{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*m.SubscriptionPackageToFeature)(nil),
|
||||
(*m.SubscriptionPackage)(nil),
|
||||
(*m.SubscriptionFeature)(nil),
|
||||
(*m.SubscriptionConfiguration)(nil),
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func testSubscriptionPackageInsert(t *testing.T, q queries.SubscriptionPackagesInterface) m.SubscriptionPackage {
|
||||
model := m.SubscriptionPackage{Name: testSubPackageName}
|
||||
|
||||
result, err := q.Insert(&model)
|
||||
|
||||
if th.NoError(t, "Insert Package error", err) {
|
||||
return model
|
||||
}
|
||||
th.Equal(t, "Insert Package affected", 1, result.RowsAffected())
|
||||
th.Equal(t, "Insert Package returned", 1, result.RowsReturned())
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
func testSubscriptionPackageFeatureInsert(t *testing.T, q queries.SubscriptionFeaturesInterface) m.SubscriptionFeature {
|
||||
feature := m.SubscriptionFeature{
|
||||
Name: testSubPackageName,
|
||||
Description: "Test Description",
|
||||
}
|
||||
|
||||
result, err := q.Insert(&feature)
|
||||
|
||||
if th.NoError(t, "Insert Feature error", err) {
|
||||
return feature
|
||||
}
|
||||
th.Equal(t, "Insert Feature affected", 1, result.RowsAffected())
|
||||
th.Equal(t, "Insert Feature returned", 1, result.RowsReturned())
|
||||
|
||||
return feature
|
||||
}
|
||||
|
||||
func testSubscriptionAddFeature(t *testing.T, q queries.SubscriptionPackagesInterface, p *m.SubscriptionPackage, f *m.SubscriptionFeature) {
|
||||
result, err := q.AddFeature(p, f)
|
||||
|
||||
th.NoError(t, "AddFeature error", err)
|
||||
th.True(t, "AddFeature result", result)
|
||||
}
|
||||
|
||||
func testSubscriptionAddConfiguration(t *testing.T, q queries.SubscriptionConfigurationsInterface, f *m.SubscriptionFeature) (m.SubscriptionConfiguration, error) {
|
||||
model := m.SubscriptionConfiguration{
|
||||
SubscriptionFeatureID: f.ID,
|
||||
ECU: "TEST",
|
||||
SoftwareVersion: "SOFTWAREVERSION",
|
||||
HardwareVersion: "HARDWAREVERSION",
|
||||
Configuration: &m.BinaryHex{0x99},
|
||||
DID: &m.BinaryHex{0x01},
|
||||
PID: &m.BinaryHex{0x02},
|
||||
Mask: &m.BinaryHex{0x03},
|
||||
}
|
||||
|
||||
result, err := q.Insert(&model)
|
||||
|
||||
if th.NoError(t, "Insert Config error", err) {
|
||||
return model, err
|
||||
}
|
||||
th.Equal(t, "Insert Config affected", 1, result.RowsAffected())
|
||||
th.Equal(t, "Insert Config returned", 1, result.RowsReturned())
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func testSubscriptionPackageUpdate(t *testing.T, q queries.SubscriptionPackagesInterface, model *m.SubscriptionPackage) {
|
||||
model.Name = model.Name + "X"
|
||||
result, err := q.Update(model)
|
||||
|
||||
if th.NoError(t, "Update Package error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Update Package affected", 1, result.RowsAffected())
|
||||
th.Equal(t, "Update Package returned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
func testSubscriptionFeatureUpdate(t *testing.T, q queries.SubscriptionFeaturesInterface, model *m.SubscriptionFeature) {
|
||||
result, err := q.Update(model)
|
||||
|
||||
if th.NoError(t, "Update Feature error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Update Feature affected", 1, result.RowsAffected())
|
||||
th.Equal(t, "Update Feature returned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
func testSubscriptionConfigurationUpdate(t *testing.T, q queries.SubscriptionConfigurationsInterface, model *m.SubscriptionConfiguration) {
|
||||
result, err := q.Update(model)
|
||||
|
||||
if th.NoError(t, "Update Config error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Update Config affected", 1, result.RowsAffected())
|
||||
th.Equal(t, "Update Config returned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
func testSubscriptionPackageCount(t *testing.T, q queries.SubscriptionPackagesInterface, model *m.SubscriptionPackage) {
|
||||
result, err := q.Count(model)
|
||||
|
||||
if th.NoError(t, "Count Package error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Count Package result", 1, result)
|
||||
}
|
||||
|
||||
func testSubscriptionFeatureCount(t *testing.T, q queries.SubscriptionFeaturesInterface, model *m.SubscriptionFeature) {
|
||||
result, err := q.Count(model)
|
||||
|
||||
th.NoError(t, "Count Feature error", err)
|
||||
th.Equal(t, "Count Feature result", 1, result)
|
||||
}
|
||||
|
||||
func testSubscriptionConfigurationCount(t *testing.T, q queries.SubscriptionConfigurationsInterface, model *m.SubscriptionConfiguration) {
|
||||
result, err := q.Count(model)
|
||||
|
||||
th.NoError(t, "Count Config error", err)
|
||||
th.Equal(t, "Count Config result", 1, result)
|
||||
}
|
||||
|
||||
func testSubscriptionPackageSelect(t *testing.T, q queries.SubscriptionPackagesInterface, model *m.SubscriptionPackage) {
|
||||
result, err := q.Select(model, nil)
|
||||
|
||||
th.NoError(t, "Select Package error", err)
|
||||
th.Equal(t, "Select Package result", 1, len(result))
|
||||
}
|
||||
|
||||
func testSubscriptionFeatureSelect(t *testing.T, q queries.SubscriptionFeaturesInterface, model *m.SubscriptionFeature) {
|
||||
result, err := q.Select(model, nil)
|
||||
|
||||
th.NoError(t, "Select Feature error", err)
|
||||
th.Equal(t, "Select Feature result", 1, len(result))
|
||||
}
|
||||
|
||||
func testSubscriptionConfigurationSelect(t *testing.T, q queries.SubscriptionConfigurationsInterface, model *m.SubscriptionConfiguration) {
|
||||
result, err := q.Select(model, nil)
|
||||
|
||||
th.NoError(t, "Select Config error", err)
|
||||
th.Equal(t, "Select Config result", 1, len(result))
|
||||
}
|
||||
|
||||
func testSubscriptionFeatureLoad(t *testing.T, q queries.SubscriptionFeaturesInterface, model *m.SubscriptionFeature) {
|
||||
item := m.SubscriptionFeature{ID: model.ID}
|
||||
err := q.Load(&item)
|
||||
|
||||
th.NoError(t, "Load Feature error", err)
|
||||
th.Equal(t, "Load Feature configurations", 1, len(item.Configurations))
|
||||
}
|
||||
|
||||
func testSubscriptionPackageLoad(t *testing.T, q queries.SubscriptionPackagesInterface, model *m.SubscriptionPackage) {
|
||||
item := m.SubscriptionPackage{ID: model.ID}
|
||||
err := q.Load(&item)
|
||||
|
||||
th.NoError(t, "Load Package error", err)
|
||||
th.Equal(t, "Load Package features", 1, len(item.Features))
|
||||
}
|
||||
|
||||
func testSubscriptionPackageDelete(t *testing.T, q queries.SubscriptionPackagesInterface, model *m.SubscriptionPackage) {
|
||||
result, err := q.Delete(model)
|
||||
|
||||
if th.NoError(t, "Delete Package error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Delete Package affected", 2, result.RowsAffected())
|
||||
th.Equal(t, "Delete Package returned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
func testSubscriptionFeatureDelete(t *testing.T, q queries.SubscriptionFeaturesInterface, model *m.SubscriptionFeature) {
|
||||
result, err := q.Delete(model)
|
||||
|
||||
if th.NoError(t, "Delete Feature error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Delete Feature affected", 2, result.RowsAffected())
|
||||
th.Equal(t, "Delete Feature returned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
func testSubscriptionConfigDelete(t *testing.T, q queries.SubscriptionConfigurationsInterface, model *m.SubscriptionConfiguration) {
|
||||
result, err := q.Delete(model)
|
||||
|
||||
if th.NoError(t, "Delete Config error", err) {
|
||||
return
|
||||
}
|
||||
th.Equal(t, "Delete Config affected", 0, result.RowsAffected())
|
||||
th.Equal(t, "Delete Config returned", 0, result.RowsReturned())
|
||||
}
|
||||
182
pkg/db/queries/subscriptions._test.go
Normal file
182
pkg/db/queries/subscriptions._test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSubscriptions(t *testing.T) {
|
||||
t.Skip()
|
||||
query, carToDriver, subtype := setupSubscriptions(t)
|
||||
defer cleanUpSubscriptionTest(carToDriver, subtype)
|
||||
if query == nil || carToDriver == nil || subtype == nil {
|
||||
t.Error("setupSubscriptions error")
|
||||
return
|
||||
}
|
||||
|
||||
subID := createSubscription(t, query, carToDriver, subtype)
|
||||
if subID == 0 {
|
||||
return
|
||||
}
|
||||
selectSubscription(t, query, subID)
|
||||
updateSubscription(t, query, subID)
|
||||
deleteSubscription(t, query, subID)
|
||||
}
|
||||
|
||||
func setupSubscriptions(t *testing.T) (queries.SubscriptionsInterface, *common.CarToDriver, *common.SubscriptionType) {
|
||||
subtype := createTestSubscriptionType(t)
|
||||
carToDriver, err := createTestCarDriver(t)
|
||||
if err != nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
instance := &queries.Subscriptions{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
return instance, carToDriver, subtype
|
||||
}
|
||||
|
||||
func createTestSubscriptionType(t *testing.T) *common.SubscriptionType {
|
||||
subtypes := setupSubscriptionTypes(t)
|
||||
subtype := insertSubscriptionType(t, subtypes)
|
||||
|
||||
return subtype
|
||||
}
|
||||
|
||||
func createTestCarDriver(t *testing.T) (*common.CarToDriver, error) {
|
||||
driver := common.Driver{
|
||||
ID: uuid.New().String(),
|
||||
}
|
||||
drivers := queries.Drivers{}
|
||||
_, err := drivers.Insert(&driver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cars := queries.Cars{}
|
||||
car := common.Car{
|
||||
VIN: "4S3BJ6332P6953766",
|
||||
}
|
||||
_, err = cars.Insert(&car)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
carToDriver, err := cars.AddDriver(&car, &driver, "driver")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return carToDriver, nil
|
||||
}
|
||||
|
||||
func createSubscription(t *testing.T, query queries.SubscriptionsInterface, carToDriver *common.CarToDriver, subtype *common.SubscriptionType) int64 {
|
||||
sub, err := query.Create(subtype, carToDriver)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscription insert", nil, err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return sub.ID
|
||||
}
|
||||
|
||||
func selectSubscription(t *testing.T, query queries.SubscriptionsInterface, subID int64) {
|
||||
sub := common.Subscription{ID: subID}
|
||||
err := query.Load(&sub)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Load", nil, err)
|
||||
}
|
||||
if sub.Name == "" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Load Name", "not empty", sub.Name)
|
||||
}
|
||||
|
||||
subTypeValues := map[string]interface{}{
|
||||
"Name": "NAME",
|
||||
"Destination": "ICC",
|
||||
}
|
||||
testhelper.PropsTester(t, &sub, subTypeValues)
|
||||
|
||||
subtypes, err := query.Select(&common.Subscription{ID: subID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Select", nil, err)
|
||||
}
|
||||
if len(subtypes) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Select count", 1, len(subtypes))
|
||||
} else {
|
||||
testhelper.PropsTester(t, &subtypes[0], subTypeValues)
|
||||
}
|
||||
|
||||
count, err := query.Count(&common.Subscription{ID: subID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Count", nil, err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Count", 1, count)
|
||||
}
|
||||
}
|
||||
|
||||
func updateSubscription(t *testing.T, query queries.SubscriptionsInterface, subID int64) {
|
||||
sub := common.Subscription{ID: subID}
|
||||
err := query.Load(&sub)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Update Load", nil, err)
|
||||
}
|
||||
|
||||
sub.Name = "NAME2"
|
||||
result, err := query.Update(&sub)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Update", nil, err)
|
||||
} else {
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Update RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Update RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
}
|
||||
|
||||
sub = common.Subscription{ID: subID}
|
||||
err = query.Load(&sub)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Update Reload", nil, err)
|
||||
}
|
||||
if sub.Name != "NAME2" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions Update", "NAME2", sub.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteSubscription(t *testing.T, query queries.SubscriptionsInterface, subID int64) {
|
||||
result, err := query.Delete(&queries.SubscriptionDeleteRequest{VIN: "4S3BJ6332P6953766", Name: "NAME2"})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions delete", nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions delete RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Subscriptions delete RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
}
|
||||
|
||||
func cleanUpSubscriptionTest(carToDriver *common.CarToDriver, subtype *common.SubscriptionType) {
|
||||
drivers := queries.Drivers{}
|
||||
cars := queries.Cars{}
|
||||
|
||||
cars.RemoveDriver(carToDriver.VIN, carToDriver.DriverID)
|
||||
cars.Delete(&common.Car{VIN: carToDriver.VIN})
|
||||
|
||||
drivers.Delete(&common.Driver{ID: carToDriver.DriverID})
|
||||
|
||||
if subtype.ID != uuid.Nil {
|
||||
subtypes := queries.SubscriptionTypes{}
|
||||
subtypes.Delete(subtype)
|
||||
}
|
||||
}
|
||||
156
pkg/db/queries/subscriptions.go
Normal file
156
pkg/db/queries/subscriptions.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SubscriptionsInterface interface {
|
||||
Delete(req *SubscriptionDeleteRequest) (orm.Result, error)
|
||||
Insert(subtype *common.Subscription) (orm.Result, error)
|
||||
Update(subtype *common.Subscription) (orm.Result, error)
|
||||
Select(filter *common.Subscription) ([]common.Subscription, error)
|
||||
Load(subtype *common.Subscription) error
|
||||
Count(filter *common.Subscription) (int, error)
|
||||
Create(subtype *common.SubscriptionType, carToDriver *common.CarToDriver) (*common.Subscription, error)
|
||||
}
|
||||
|
||||
type Subscriptions struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
// Select returns list of drivers
|
||||
func (s *Subscriptions) Select(filter *common.Subscription) ([]common.Subscription, error) {
|
||||
subs := []common.Subscription{}
|
||||
query := s.GetDBConn().Model(&subs)
|
||||
|
||||
s.selectFilter(query, filter)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return subs, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (s *Subscriptions) selectFilter(query *orm.Query, filter *common.Subscription) {
|
||||
if filter.ID != 0 {
|
||||
query.Where("id = ?", filter.ID)
|
||||
}
|
||||
|
||||
if filter.SubscriptionTypeID != uuid.Nil {
|
||||
query.Where("subscription_type_id = ?", filter.SubscriptionTypeID)
|
||||
}
|
||||
|
||||
if filter.CarToDriverID != 0 {
|
||||
query.Where("car_to_driver_id = ?", filter.CarToDriverID)
|
||||
}
|
||||
|
||||
if filter.Name != "" {
|
||||
query.Where("name = ?", filter.Name)
|
||||
}
|
||||
|
||||
if filter.Name != "" {
|
||||
query.Where("destination = ?", filter.Destination)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscriptions) Insert(subtype *common.Subscription) (orm.Result, error) {
|
||||
return s.insert(subtype)
|
||||
}
|
||||
|
||||
func (s *Subscriptions) Update(subtype *common.Subscription) (orm.Result, error) {
|
||||
err := s.validateSubID(subtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.GetDBConn().Model(subtype).Column("subscription_type_id", "car_to_driver_id", "name", "destination", "expires").WherePK().Update()
|
||||
}
|
||||
|
||||
func (s *Subscriptions) Delete(req *SubscriptionDeleteRequest) (orm.Result, error) {
|
||||
if req.ID > 0 {
|
||||
return s.deleteByID(&common.Subscription{ID: req.ID})
|
||||
}
|
||||
|
||||
return s.deleteRequest(req)
|
||||
}
|
||||
|
||||
func (s *Subscriptions) deleteByID(sub *common.Subscription) (orm.Result, error) {
|
||||
err := s.validateSubID(sub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.resultWithStack(s.GetDBConn().Model(sub).WherePK().Delete())
|
||||
}
|
||||
|
||||
func (s *Subscriptions) deleteRequest(req *SubscriptionDeleteRequest) (orm.Result, error) {
|
||||
err := validator.ValidateStruct(req)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return s.resultWithStack(s.GetDBConn().Model((*common.Subscription)(nil)).Exec(`DELETE FROM "subscriptions" AS "subs" WHERE (subs.id IN (SELECT subs2.id FROM "subscriptions" AS "subs2" JOIN car_to_drivers AS c ON (c.id = subs2.car_to_driver_id) WHERE (c.vin = ? AND subs2.name = ?)))`, req.VIN, req.Name))
|
||||
}
|
||||
|
||||
func (s *Subscriptions) Load(sub *common.Subscription) error {
|
||||
query := s.GetDBConn().Model(sub)
|
||||
|
||||
if sub.ID != 0 {
|
||||
query.WherePK()
|
||||
} else if sub.SubscriptionTypeID != uuid.Nil && sub.CarToDriverID != 0 {
|
||||
query.Where("subscription_type_id = ?subscription_type_id AND driver_id = ?car_to_driver_id")
|
||||
} else {
|
||||
return errors.New("no id or subscription_type_id, driver_id")
|
||||
}
|
||||
|
||||
return errors.WithStack(query.Select())
|
||||
}
|
||||
|
||||
func (s *Subscriptions) Count(filter *common.Subscription) (int, error) {
|
||||
query := s.GetDBConn().Model((*common.Subscription)(nil))
|
||||
|
||||
s.selectFilter(query, filter)
|
||||
|
||||
count, err := query.Count()
|
||||
|
||||
return count, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (s *Subscriptions) validateSubID(subtype *common.Subscription) error {
|
||||
if subtype.ID == 0 {
|
||||
return errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "ID required",
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Subscriptions) Create(subtype *common.SubscriptionType, carToDriver *common.CarToDriver) (*common.Subscription, error) {
|
||||
sub := common.Subscription{
|
||||
Name: subtype.Name,
|
||||
Destination: subtype.Destination,
|
||||
CarToDriverID: carToDriver.ID,
|
||||
SubscriptionTypeID: subtype.ID,
|
||||
Expires: time.Now().Add(subtype.GetDuration()),
|
||||
}
|
||||
|
||||
_, err := s.Insert(&sub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
type SubscriptionDeleteRequest struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name" validate:"required,max=256"`
|
||||
VIN string `json:"vin" validate:"required,vin"`
|
||||
}
|
||||
165
pkg/db/queries/subscriptiontypes._test.go
Normal file
165
pkg/db/queries/subscriptiontypes._test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
m "fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSubscriptionTypes(t *testing.T) {
|
||||
t.Skip()
|
||||
query := setupSubscriptionTypes(t)
|
||||
subType := insertSubscriptionType(t, query)
|
||||
if subType == nil {
|
||||
return
|
||||
}
|
||||
selectSubscriptionType(t, query, subType.ID)
|
||||
updateSubscriptionType(t, query, subType.ID)
|
||||
deleteSubscriptionType(t, query, subType.ID)
|
||||
}
|
||||
|
||||
func setupSubscriptionTypes(t *testing.T) queries.SubscriptionTypesInterface {
|
||||
instance := &queries.SubscriptionTypes{}
|
||||
conn = instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*common.SubscriptionType)(nil),
|
||||
(*common.Subscription)(nil),
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func insertSubscriptionType(t *testing.T, query queries.SubscriptionTypesInterface) *m.SubscriptionType {
|
||||
subtype := common.SubscriptionType{
|
||||
Name: "NAME",
|
||||
Destination: "ICC",
|
||||
Description: "DESC",
|
||||
Price: 10099,
|
||||
Currency: "USD",
|
||||
DurationValue: 10,
|
||||
DurationUnit: "Hours",
|
||||
}
|
||||
|
||||
result, err := query.Insert(&subtype)
|
||||
if err != nil {
|
||||
result, err := getTestSubscriptionType(query, &subtype)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
} else {
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionType insert RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
if result.RowsReturned() != 1 {
|
||||
// file insert does not return row
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionType insert RowsReturned", 1, result.RowsReturned())
|
||||
}
|
||||
}
|
||||
|
||||
return &subtype
|
||||
}
|
||||
|
||||
func getTestSubscriptionType(query queries.SubscriptionTypesInterface, filter *common.SubscriptionType) (*common.SubscriptionType, error) {
|
||||
subtypes, err := query.Select(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(subtypes) == 0 {
|
||||
return nil, errors.New("no found")
|
||||
}
|
||||
return &subtypes[0], nil
|
||||
}
|
||||
|
||||
func selectSubscriptionType(t *testing.T, query queries.SubscriptionTypesInterface, subtypeID uuid.UUID) {
|
||||
subtype := common.SubscriptionType{ID: subtypeID}
|
||||
err := query.Load(&subtype)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Load", nil, err)
|
||||
}
|
||||
if subtype.Name == "" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Load Name", "not empty", subtype.Name)
|
||||
}
|
||||
|
||||
subTypeValues := map[string]interface{}{
|
||||
"Name": "NAME",
|
||||
"Destination": "ICC",
|
||||
"Description": "DESC",
|
||||
"Price": 100.99,
|
||||
"Currency": "USD",
|
||||
"DurationValue": 10,
|
||||
"DurationUnit": "Hours",
|
||||
}
|
||||
testhelper.PropsTester(t, &subtype, subTypeValues)
|
||||
|
||||
subtypes, err := query.Select(&common.SubscriptionType{ID: subtypeID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Select", nil, err)
|
||||
}
|
||||
if len(subtypes) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Select count", 1, len(subtypes))
|
||||
} else {
|
||||
testhelper.PropsTester(t, &subtypes[0], subTypeValues)
|
||||
}
|
||||
|
||||
count, err := query.Count(&common.SubscriptionType{ID: subtypeID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Count", nil, err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Count", 1, count)
|
||||
}
|
||||
}
|
||||
|
||||
func updateSubscriptionType(t *testing.T, query queries.SubscriptionTypesInterface, subtypeID uuid.UUID) {
|
||||
subtype := common.SubscriptionType{ID: subtypeID}
|
||||
err := query.Load(&subtype)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Update Load", nil, err)
|
||||
}
|
||||
|
||||
subtype.Name = "NAME2"
|
||||
result, err := query.Update(&subtype)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Update", nil, err)
|
||||
}
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Update RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Update RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
subtype = common.SubscriptionType{ID: subtypeID}
|
||||
err = query.Load(&subtype)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Update Reload", nil, err)
|
||||
}
|
||||
if subtype.Name != "NAME2" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes Update", "NAME2", subtype.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteSubscriptionType(t *testing.T, query queries.SubscriptionTypesInterface, subtypeID uuid.UUID) {
|
||||
result, err := query.Delete(&common.SubscriptionType{ID: subtypeID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes delete", nil, err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes delete RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "SubscriptionTypes delete RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
}
|
||||
118
pkg/db/queries/subscriptiontypes.go
Normal file
118
pkg/db/queries/subscriptiontypes.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SubscriptionTypesInterface interface {
|
||||
Delete(subtype *common.SubscriptionType) (orm.Result, error)
|
||||
Insert(subtype *common.SubscriptionType) (orm.Result, error)
|
||||
Update(subtype *common.SubscriptionType) (orm.Result, error)
|
||||
Select(filter *common.SubscriptionType) ([]common.SubscriptionType, error)
|
||||
Load(subtype *common.SubscriptionType) error
|
||||
Count(filter *common.SubscriptionType) (int, error)
|
||||
}
|
||||
|
||||
type SubscriptionTypes struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
// Select returns list of drivers
|
||||
func (st *SubscriptionTypes) Select(filter *common.SubscriptionType) ([]common.SubscriptionType, error) {
|
||||
subtypes := []common.SubscriptionType{}
|
||||
query := st.GetDBConn().Model(&subtypes)
|
||||
|
||||
st.selectFilter(query, filter)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return subtypes, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) selectFilter(query *orm.Query, filter *common.SubscriptionType) {
|
||||
if filter.ID != uuid.Nil {
|
||||
query.Where("ID = ?", filter.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) Insert(subtype *common.SubscriptionType) (orm.Result, error) {
|
||||
return st.insert(subtype)
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) Update(subtype *common.SubscriptionType) (orm.Result, error) {
|
||||
err := st.validateSubTypeID(subtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return st.GetDBConn().Model(subtype).Column("name", "destination", "price", "currency", "description", "duration_value", "duration_unit").WherePK().Update()
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) Delete(subtype *common.SubscriptionType) (orm.Result, error) {
|
||||
err := st.validateSubTypeID(subtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := st.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Close()
|
||||
|
||||
subscriptions := common.Subscription{
|
||||
SubscriptionTypeID: subtype.ID,
|
||||
}
|
||||
_, err = tx.Model(&subscriptions).Where("subscription_type_id = ?subscription_type_id").Delete()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := tx.Model(subtype).WherePK().Delete()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) Load(subtype *common.SubscriptionType) error {
|
||||
query := st.GetDBConn().Model(subtype)
|
||||
|
||||
if subtype.ID != uuid.Nil {
|
||||
query.WherePK()
|
||||
} else if subtype.Name != "" {
|
||||
query.Where("name = ?name")
|
||||
} else {
|
||||
return errors.New("no id or name")
|
||||
}
|
||||
|
||||
return query.Select()
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) Count(filter *common.SubscriptionType) (int, error) {
|
||||
query := st.GetDBConn().Model((*common.SubscriptionType)(nil))
|
||||
|
||||
st.selectFilter(query, filter)
|
||||
|
||||
return query.Count()
|
||||
}
|
||||
|
||||
func (st *SubscriptionTypes) validateSubTypeID(subtype *common.SubscriptionType) error {
|
||||
if subtype.ID == uuid.Nil {
|
||||
return errors.WithStack(&validator.FieldError{
|
||||
ErrorMsg: "ID required",
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
76
pkg/db/queries/sums_versions.go
Normal file
76
pkg/db/queries/sums_versions.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SUMSVersionsInterface interface {
|
||||
SelectAll(options *PageQueryOptions) ([]common.SUMSVersion, error)
|
||||
SelectAllCount() (int, error)
|
||||
Insert(u *common.SUMSVersion) (orm.Result, error)
|
||||
Delete(u *common.SUMSVersion) (orm.Result, error)
|
||||
Select(string) (*common.SUMSVersion, error)
|
||||
}
|
||||
|
||||
type SUMSVersions struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (umv *SUMSVersions) SelectAll(options *PageQueryOptions) ([]common.SUMSVersion, error) {
|
||||
allUpdateManifestVersions := []common.SUMSVersion{}
|
||||
|
||||
q := umv.GetDBConn().Model(&allUpdateManifestVersions)
|
||||
|
||||
// Adding a limit to prevent unreasonably large queries
|
||||
// Expecting a paged query from the front end, so this should not be used
|
||||
if options != nil {
|
||||
umv.pageQuery(q, options)
|
||||
} else {
|
||||
umv.pageQuery(q, &PageQueryOptions{
|
||||
Limit: 500,
|
||||
})
|
||||
}
|
||||
|
||||
err := q.Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return allUpdateManifestVersions, err
|
||||
}
|
||||
|
||||
func (umv *SUMSVersions) SelectAllCount() (int, error) {
|
||||
allUpdateManifestVersions := []common.SUMSVersion{}
|
||||
|
||||
return umv.GetDBConn().Model(&allUpdateManifestVersions).Count()
|
||||
}
|
||||
|
||||
func (umv *SUMSVersions) Insert(u *common.SUMSVersion) (orm.Result, error) {
|
||||
return umv.insert(u)
|
||||
}
|
||||
|
||||
func (umv *SUMSVersions) Delete(u *common.SUMSVersion) (orm.Result, error) {
|
||||
return umv.GetDBConn().Model(u).WherePK().Delete()
|
||||
}
|
||||
|
||||
func (umv *SUMSVersions) Select(version string) (*common.SUMSVersion, error) {
|
||||
allUpdateManifestVersions := []common.SUMSVersion{}
|
||||
|
||||
query := umv.GetDBConn().
|
||||
Model(&allUpdateManifestVersions).
|
||||
Where("version = ?", version).
|
||||
Order("os_version DESC")
|
||||
err := query.Select()
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if len(allUpdateManifestVersions) > 0 {
|
||||
return &allUpdateManifestVersions[0], nil
|
||||
} else {
|
||||
return nil, errors.New("empty result from database")
|
||||
}
|
||||
}
|
||||
104
pkg/db/queries/supplier_accounts.go
Normal file
104
pkg/db/queries/supplier_accounts.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SupplierTimestamp string
|
||||
|
||||
const (
|
||||
SupplierTimestampActivated SupplierTimestamp = "activated_at"
|
||||
SupplierTimestampSignIn SupplierTimestamp = "signin_at"
|
||||
SupplierTimestampKeys SupplierTimestamp = "keys_at"
|
||||
)
|
||||
|
||||
type SupplierAccountsInterface interface {
|
||||
Count(account *common.SupplierAccount) (int, error)
|
||||
Delete(account *common.SupplierAccount) (orm.Result, error)
|
||||
Insert(account *common.SupplierAccount) (orm.Result, error)
|
||||
Load(account *common.SupplierAccount) error
|
||||
Select(account *common.SupplierAccount, paging *PageQueryOptions) ([]common.SupplierAccount, error)
|
||||
Update(account *common.SupplierAccount) (orm.Result, error)
|
||||
Approve(email string) (orm.Result, error)
|
||||
UpdateTimestamp(email string, activity SupplierTimestamp) (orm.Result, error)
|
||||
}
|
||||
|
||||
type SupplierAccounts struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Count(account *common.SupplierAccount) (int, error) {
|
||||
return sf.count(account)
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Delete(account *common.SupplierAccount) (orm.Result, error) {
|
||||
return sf.delete(account)
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Insert(account *common.SupplierAccount) (orm.Result, error) {
|
||||
return sf.insert(account)
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Load(account *common.SupplierAccount) error {
|
||||
query := sf.GetDBConn().Model(account)
|
||||
|
||||
if account.Email != "" {
|
||||
query.Where("email = ?", account.Email)
|
||||
} else if account.Telephone != "" {
|
||||
query.Where("telephone = ?", account.Telephone)
|
||||
} else {
|
||||
return errors.New("requires email or telephone")
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Select(filter *common.SupplierAccount, paging *PageQueryOptions) ([]common.SupplierAccount, error) {
|
||||
accounts := []common.SupplierAccount{}
|
||||
query := sf.GetDBConn().Model(&accounts)
|
||||
|
||||
sf.selectFilter(query, filter)
|
||||
sf.pageQuery(query, paging)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) selectFilter(query *orm.Query, filter *common.SupplierAccount) {
|
||||
if filter.Email != "" {
|
||||
query.Where("email = ?", filter.Email)
|
||||
}
|
||||
|
||||
if filter.Telephone != "" {
|
||||
query.Where("telephone = ?", filter.Telephone)
|
||||
}
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Update(account *common.SupplierAccount) (orm.Result, error) {
|
||||
return sf.GetDBConn().Model(account).Column("supplier_organization_id", "contact", "company", "address", "telephone", "program", "ecus").Where("email = ?", account.Email).Update()
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) Approve(email string) (orm.Result, error) {
|
||||
return sf.UpdateTimestamp(email, SupplierTimestampActivated)
|
||||
}
|
||||
|
||||
func (sf *SupplierAccounts) UpdateTimestamp(email string, activity SupplierTimestamp) (orm.Result, error) {
|
||||
result, err := sf.GetDBConn().Model(&common.SupplierAccount{}).Set(fmt.Sprintf("%s = CURRENT_TIMESTAMP", string(activity))).Where("email = ?", email).Update()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
if result.RowsAffected() == 0 {
|
||||
return nil, errors.WithStack(pg.ErrNoRows)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
92
pkg/db/queries/supplier_organizations.go
Normal file
92
pkg/db/queries/supplier_organizations.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
s "fiskerinc.com/modules/security"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type SupplierOrganizationsInterface interface {
|
||||
Count(account *common.SupplierOrganization) (int, error)
|
||||
Insert(supplierOrganization *common.SupplierOrganization) (orm.Result, error)
|
||||
Select(account *common.SupplierOrganization, paging *PageQueryOptions) ([]common.SupplierOrganization, error)
|
||||
}
|
||||
|
||||
type SupplierOrganization struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (so *SupplierOrganization) Count(supplierOrganization *common.SupplierOrganization) (int, error) {
|
||||
return so.count(supplierOrganization)
|
||||
}
|
||||
|
||||
func (so *SupplierOrganization) Insert(supplierOrganization *common.SupplierOrganization) (orm.Result, error) {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
supplierOrganization.PubKey.SetBytes(encryptor.EncryptChunk(supplierOrganization.PubKey.Bytes()))
|
||||
supplierOrganization.PrivKey.SetBytes(encryptor.EncryptChunk(supplierOrganization.PrivKey.Bytes()))
|
||||
|
||||
return so.GetDBConn().Model(supplierOrganization).Insert()
|
||||
}
|
||||
|
||||
func (so *SupplierOrganization) Select(filter *common.SupplierOrganization, paging *PageQueryOptions) ([]common.SupplierOrganization, error) {
|
||||
supplierOrganizations := []common.SupplierOrganization{}
|
||||
query := so.GetDBConn().Model(&supplierOrganizations)
|
||||
|
||||
so.selectFilter(query, filter)
|
||||
|
||||
if paging != nil {
|
||||
so.pageQuery(query, paging)
|
||||
}
|
||||
|
||||
err := query.Select()
|
||||
|
||||
for _, supplierOrganization := range supplierOrganizations {
|
||||
err = so.decryptSupplierOrganization(&supplierOrganization)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return supplierOrganizations, err
|
||||
}
|
||||
|
||||
func (sf *SupplierOrganization) selectFilter(query *orm.Query, filter *common.SupplierOrganization) {
|
||||
if filter.DomainName != "" {
|
||||
query.Where("domain_name = ?", filter.DomainName)
|
||||
}
|
||||
|
||||
if filter.SupplierOrganizationID != 0 {
|
||||
query.Where("supplier_organization_id = ?", filter.SupplierOrganizationID)
|
||||
}
|
||||
}
|
||||
|
||||
func (si *SupplierOrganization) decryptSupplierOrganization(supplierOrganization *common.SupplierOrganization) error {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if supplierOrganization.PubKey != nil {
|
||||
key, err := encryptor.DecryptChunk(supplierOrganization.PubKey.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
supplierOrganization.PubKey.SetBytes(key)
|
||||
}
|
||||
|
||||
if supplierOrganization.PrivKey != nil {
|
||||
key, err := encryptor.DecryptChunk(supplierOrganization.PrivKey.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
supplierOrganization.PrivKey.SetBytes(key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
67
pkg/db/queries/swversion_rxswin.go
Normal file
67
pkg/db/queries/swversion_rxswin.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SwVersionRxSwinInterface interface {
|
||||
SelectByVersion(version string, options *PageQueryOptions) ([]common.SwVersionRxSwin, error)
|
||||
SelectCountByVersion(version string) (int, error)
|
||||
Insert(swVersionRxSwin *common.SwVersionRxSwin) (orm.Result, error)
|
||||
Delete(model *common.SwVersionRxSwin) (orm.Result, error)
|
||||
}
|
||||
|
||||
type SwVersionRxSwin struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (svrs *SwVersionRxSwin) SelectByVersion(version string, options *PageQueryOptions) ([]common.SwVersionRxSwin, error) {
|
||||
swVersionRxSwins := []common.SwVersionRxSwin{}
|
||||
query := svrs.GetDBConn().Model(&swVersionRxSwins).Where("version = ?", version)
|
||||
|
||||
svrs.pageQuery(query, options)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return swVersionRxSwins, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (svrs *SwVersionRxSwin) SelectCountByVersion(version string) (int, error) {
|
||||
query := svrs.GetDBConn().Model(&[]common.SwVersionRxSwin{}).Where("version = ?", version)
|
||||
|
||||
return query.Count()
|
||||
}
|
||||
|
||||
func (svrs *SwVersionRxSwin) Insert(swVersionRxSwin *common.SwVersionRxSwin) (orm.Result, error) {
|
||||
return svrs.insert(swVersionRxSwin)
|
||||
}
|
||||
|
||||
func (svrs *SwVersionRxSwin) Delete(model *common.SwVersionRxSwin) (orm.Result, error) {
|
||||
var err error
|
||||
total := &ORMResults{}
|
||||
tx, err := svrs.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
tx.Close()
|
||||
}()
|
||||
|
||||
result, err := tx.Model(model).WherePK().Delete()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total.AddResult(result)
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
88
pkg/db/queries/symkeys.go
Normal file
88
pkg/db/queries/symkeys.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
s "fiskerinc.com/modules/security"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SymKeysInterface interface {
|
||||
Insert(keys common.SymKeys) (orm.Result, error)
|
||||
SelectAll() ([]common.SymKeys, error)
|
||||
SelectByVIN(vin string) (common.SymKeys, error)
|
||||
}
|
||||
|
||||
type SymKeys struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (sk *SymKeys) Insert(keys common.SymKeys) (orm.Result, error) {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys.SecOC.SetBytes(encryptor.EncryptChunk(keys.SecOC.Bytes()))
|
||||
keys.SecureLogging.SetBytes(encryptor.EncryptChunk(keys.SecureLogging.Bytes()))
|
||||
keys.SecureStorage.SetBytes(encryptor.EncryptChunk(keys.SecureStorage.Bytes()))
|
||||
|
||||
return sk.GetDBConn().Model(&keys).Insert()
|
||||
}
|
||||
|
||||
func (sk *SymKeys) SelectAll() ([]common.SymKeys, error) {
|
||||
symkeys := []common.SymKeys{}
|
||||
|
||||
err := sk.GetDBConn().Model(&symkeys).Select()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for i := range symkeys {
|
||||
err = sk.decrypt(&symkeys[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return symkeys, err
|
||||
}
|
||||
|
||||
func (sk *SymKeys) SelectByVIN(vin string) (common.SymKeys, error) {
|
||||
symkey := common.SymKeys{}
|
||||
|
||||
err := sk.GetDBConn().Model(&symkey).Where("vin = ?", vin).Select()
|
||||
if err != nil {
|
||||
return symkey, errors.WithStack(err)
|
||||
}
|
||||
|
||||
err = sk.decrypt(&symkey)
|
||||
return symkey, err
|
||||
}
|
||||
|
||||
func (sk *SymKeys) decrypt(symkeys *common.SymKeys) error {
|
||||
enc := s.Encrypt{}
|
||||
encryptor, err := enc.GetEncryptor()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
secoc, err := encryptor.DecryptChunk(symkeys.SecOC.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
symkeys.SecOC.SetBytes(secoc)
|
||||
|
||||
logging, err := encryptor.DecryptChunk(symkeys.SecureLogging.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
symkeys.SecureLogging.SetBytes(logging)
|
||||
|
||||
storage, err := encryptor.DecryptChunk(symkeys.SecureStorage.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
symkeys.SecureStorage.SetBytes(storage)
|
||||
|
||||
return nil
|
||||
}
|
||||
34
pkg/db/queries/tags.go
Normal file
34
pkg/db/queries/tags.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fiskerinc.com/modules/common"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
)
|
||||
|
||||
type TagsInterface interface {
|
||||
QueryBaseInterface
|
||||
Update(c *common.Car) (orm.Result, error)
|
||||
UpdateDistinctTags(car *common.Car) (orm.Result, error)
|
||||
}
|
||||
|
||||
func NewTags() TagsInterface {
|
||||
return &Tags{}
|
||||
}
|
||||
|
||||
type Tags struct {
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (t *Tags) Update(car *common.Car) (orm.Result, error) {
|
||||
return t.resultWithStack(t.GetDBConn().Model(car).
|
||||
Column("tags").
|
||||
Where("vin = ?", car.VIN).Update())
|
||||
}
|
||||
|
||||
func (t *Tags) UpdateDistinctTags(car *common.Car) (orm.Result, error) {
|
||||
return t.resultWithStack(t.GetDBConn().Model(car).
|
||||
Where("vin = ?", car.VIN).
|
||||
Set("tags = array(SELECT DISTINCT unnest(tags || ?))", pg.Array(car.Tags)).
|
||||
Update())
|
||||
}
|
||||
22
pkg/db/queries/update_manifests_mode.go
Normal file
22
pkg/db/queries/update_manifests_mode.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package queries
|
||||
|
||||
import "github.com/go-pg/pg/v10/orm"
|
||||
|
||||
type UpdateManifestMode interface {
|
||||
LoadRelations(query *orm.Query) error
|
||||
SelectByVINCondition(query *orm.Query) *orm.Query
|
||||
}
|
||||
|
||||
type DefaultMode struct{}
|
||||
|
||||
func (DefaultMode) LoadRelations(query *orm.Query) error {
|
||||
return query.Relation("ECUs").
|
||||
Relation("ECUs.Files").
|
||||
Relation("ECUs.Files.WriteRegion").
|
||||
Relation("ECUs.Files.EraseRegion").
|
||||
Select()
|
||||
}
|
||||
|
||||
func (DefaultMode) SelectByVINCondition(query *orm.Query) *orm.Query {
|
||||
return query
|
||||
}
|
||||
566
pkg/db/queries/updatemanifests.go
Normal file
566
pkg/db/queries/updatemanifests.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/validator"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type UpdateManifestsInterface interface {
|
||||
QueryBaseInterface
|
||||
Archive(ids []int64, active bool) (orm.Result, error)
|
||||
Count(filter common.UpdateManifest) (int, error)
|
||||
Delete(manifest *common.UpdateManifest) (orm.Result, error)
|
||||
Insert(manifest *common.UpdateManifest) (orm.Result, error)
|
||||
Select(filter *common.UpdateManifest, paging *PageQueryOptions) ([]common.UpdateManifest, error)
|
||||
SelectByVIN(vin string, paging *PageQueryOptions) ([]common.StatusManifest, error)
|
||||
Update(manifest *common.UpdateManifest) (orm.Result, error)
|
||||
Load(manifest *common.UpdateManifest) error
|
||||
Search(filter common.UpdateManifestSearch, paging *PageQueryOptions) ([]common.UpdateManifest, error)
|
||||
SearchCount(filter common.UpdateManifestSearch) (int, error)
|
||||
ECURollback(man *common.UpdateManifestECU, vin string) ([]*common.UpdateManifestECU, error)
|
||||
ECUInsert(manifest *common.UpdateManifestECU) (orm.Result, error)
|
||||
FileInsert(manifest *common.UpdateManifestFile) (orm.Result, error)
|
||||
AddSUMSVersion(manifest *common.UpdateManifest) (orm.Result, error)
|
||||
SelectFlashPackByVersion(versionNumber string) (manifest common.UpdateManifest, err error)
|
||||
}
|
||||
|
||||
// NewUpdateManifest returns UpdateManifestInterface.
|
||||
// mode allows to use more flexible requests.
|
||||
// If mode is nil, default mode is used.
|
||||
func NewUpdateManifest(mode UpdateManifestMode) UpdateManifestsInterface {
|
||||
if mode == nil {
|
||||
return &UpdateManifests{
|
||||
mode: DefaultMode{},
|
||||
}
|
||||
}
|
||||
|
||||
return &UpdateManifests{
|
||||
mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
type UpdateManifests struct {
|
||||
mode UpdateManifestMode
|
||||
QueryBase
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Count(filter common.UpdateManifest) (int, error) {
|
||||
query := um.GetDBConn().Model((*common.UpdateManifest)(nil))
|
||||
|
||||
um.selectFilter(query, &filter)
|
||||
|
||||
return query.Count()
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Delete(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
total := ORMResults{}
|
||||
|
||||
err := validator.ValidateIDField(manifest.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = um.Load(manifest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := um.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = um.manifestDelete(tx, &total, manifest)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &total, err
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Insert(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
total := ORMResults{}
|
||||
|
||||
tx, err := um.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = um.manifestInsert(tx, &total, manifest)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &total, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) ECUInsert(ecu *common.UpdateManifestECU) (orm.Result, error) {
|
||||
total := ORMResults{}
|
||||
|
||||
tx, err := um.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = um.ecuInsert(tx, &total, ecu)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &total, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) FileInsert(file *common.UpdateManifestFile) (orm.Result, error) {
|
||||
total := ORMResults{}
|
||||
|
||||
tx, err := um.GetDBConn().Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = um.fileInsert(tx, &total, file)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &total, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Select(filter *common.UpdateManifest, paging *PageQueryOptions) ([]common.UpdateManifest, error) {
|
||||
manifests := []common.UpdateManifest{}
|
||||
query := um.GetDBConn().Model(&manifests)
|
||||
|
||||
um.selectFilter(query, filter)
|
||||
um.pageQuery(query, paging)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return manifests, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Archive(ids []int64, active bool) (orm.Result, error) {
|
||||
manifests := []common.UpdateManifest{}
|
||||
return um.resultWithStack(
|
||||
um.GetDBConn().Model(&manifests).
|
||||
Set("active = ?", active).
|
||||
Where("id IN (?)", pg.In(ids)).
|
||||
Update(),
|
||||
)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) SelectByVIN(vin string, paging *PageQueryOptions) ([]common.StatusManifest, error) {
|
||||
manifests := []common.StatusManifest{}
|
||||
query := um.GetDBConn().Model(&manifests)
|
||||
query.
|
||||
ColumnExpr("status_manifest.*").
|
||||
ColumnExpr("car_updates.status as status").
|
||||
ColumnExpr("car_updates.updated_at as status_updated").
|
||||
ColumnExpr("status_manifest.created_at as manifest_created").
|
||||
Join("LEFT JOIN car_updates ON car_updates.update_manifest_id=status_manifest.id").
|
||||
Where("vin = ?", vin)
|
||||
|
||||
query = um.mode.SelectByVINCondition(query)
|
||||
query = um.pageQuery(query, paging)
|
||||
err := query.Select()
|
||||
|
||||
return manifests, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Update(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
err := validator.ValidateIDField(manifest.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query := um.GetDBConn().Model(manifest).Column(
|
||||
"name",
|
||||
"release_notes",
|
||||
"type",
|
||||
"fingerprint",
|
||||
"rollback_enabled",
|
||||
"update_duration",
|
||||
"max_attempts",
|
||||
)
|
||||
if manifest.Active != nil {
|
||||
query = query.Column("active")
|
||||
}
|
||||
if manifest.Env != "" {
|
||||
query = query.Column("env")
|
||||
}
|
||||
if manifest.SUMS != "" {
|
||||
query = query.Column("sums")
|
||||
}
|
||||
|
||||
return query.WherePK().Update()
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Load(manifest *common.UpdateManifest) error {
|
||||
query := um.GetDBConn().Model(manifest)
|
||||
|
||||
if manifest.ID > 0 {
|
||||
query.WherePK()
|
||||
} else if manifest.Name != "" && manifest.Version != "" {
|
||||
query.Where("name = ?name AND version = ?version")
|
||||
} else {
|
||||
return errors.New("no id, name or version")
|
||||
}
|
||||
|
||||
// For Magna, we will add the manifest type, then we will only get it when it is equal to two
|
||||
if manifest.ManifestType > 0 {
|
||||
query.Where("manifest_type = ?", manifest.ManifestType)
|
||||
}
|
||||
|
||||
if manifest.Env == "" {
|
||||
manifest.Env = common.EnvCurrent
|
||||
}
|
||||
|
||||
return um.mode.LoadRelations(query)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) Search(filter common.UpdateManifestSearch, paging *PageQueryOptions) ([]common.UpdateManifest, error) {
|
||||
manifests := []common.UpdateManifest{}
|
||||
query := um.GetDBConn().Model(&manifests)
|
||||
|
||||
um.searchFilter(query, &filter)
|
||||
um.pageQuery(query, paging)
|
||||
|
||||
err := query.Select()
|
||||
|
||||
return manifests, errors.WithStack(err)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) SearchCount(filter common.UpdateManifestSearch) (int, error) {
|
||||
query := um.GetDBConn().Model((*common.UpdateManifest)(nil))
|
||||
|
||||
um.searchFilter(query, &filter)
|
||||
|
||||
return query.Count()
|
||||
}
|
||||
|
||||
// Used to get rollbacks for ecu's
|
||||
func (um *UpdateManifests) ECURollback(man *common.UpdateManifestECU, vin string) (rollBackManifest []*common.UpdateManifestECU, err error) {
|
||||
// Use our new rollback system, if it fails return the old get for rollbacks
|
||||
// This may fail as the car_ecus system is a bit corrupted in data
|
||||
targetedRollback, err := um.getECURollbackSpecificToCar(man, vin)
|
||||
if err == nil && targetedRollback != nil {
|
||||
rollBackManifest = append(rollBackManifest, targetedRollback)
|
||||
return
|
||||
}
|
||||
|
||||
return um.getRollbacksListOfPossible(man)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) getECURollbackSpecificToCar(man *common.UpdateManifestECU, vin string) (rollbackManifest *common.UpdateManifestECU, err error) {
|
||||
conn := um.GetDBConn()
|
||||
targetCarECU := common.CarECU{}
|
||||
targetCarECU.VIN = vin
|
||||
targetCarECU.ECU = man.ECU
|
||||
// Using the ecu and vin, get the latest carECU for that specific car
|
||||
err = conn.Model(&targetCarECU).
|
||||
Where("vin = ?", targetCarECU.VIN).
|
||||
Where("ecu = ?", targetCarECU.ECU).
|
||||
Where("version != 'unavailable'"). // We have so many of these though, like its not something that you can even roll back to
|
||||
Order("epoch_usec DESC").
|
||||
Limit(1).
|
||||
Select()
|
||||
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msgf("failed to get car ecu for rollback for car %s ecu %s", vin, man.ECU)
|
||||
return
|
||||
}
|
||||
|
||||
if targetCarECU.Version == "" {
|
||||
logger.Warn().Interface("target car ecu", targetCarECU).Msgf("failed to fetch a previous ecu version for car %s ecu %s", vin, man.ECU)
|
||||
err = fmt.Errorf("getRollBackNew error")
|
||||
return
|
||||
}
|
||||
|
||||
targetManifest := common.UpdateManifestECU{}
|
||||
err = conn.Model(&targetManifest).
|
||||
Where("version = ?", targetCarECU.Version).
|
||||
Where("ecu = ?", targetCarECU.ECU).
|
||||
Select()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msgf("failed to get ecu rollback update manifest ecu for car %s", vin)
|
||||
return
|
||||
}
|
||||
if targetManifest.ID == 0 {
|
||||
logger.Warn().Interface("target ecu", targetCarECU).Msgf("failed to get ecu rollback update manifest ecu for car %s", vin)
|
||||
err = fmt.Errorf("getRollBackNew error")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) getRollbacksListOfPossible(man *common.UpdateManifestECU) (rollBackManifest []*common.UpdateManifestECU, err error) {
|
||||
var manifests []*common.UpdateManifestECU
|
||||
conn := um.GetDBConn()
|
||||
sub := conn.Model(&manifests).ColumnExpr("version, max(created_at)").
|
||||
Where("ecu = ?", man.ECU).
|
||||
Where("hw_versions = ?", pg.Array(man.HWVersions)).
|
||||
Where("version != ?", man.Version).
|
||||
Group("version")
|
||||
err = conn.
|
||||
Model(&manifests).
|
||||
With("sub", sub).
|
||||
Column("update_manifest_ecu.version", "id").
|
||||
Join("INNER JOIN sub ON update_manifest_ecu.version = sub.version AND "+
|
||||
"update_manifest_ecu.created_at = sub.max").
|
||||
Where("ecu = ?", man.ECU).
|
||||
Where("hw_versions = ?", pg.Array(man.HWVersions)).
|
||||
Where("update_manifest_ecu.version != ?", man.Version).
|
||||
Relation("Files").
|
||||
Relation("Files.WriteRegion").
|
||||
Relation("Files.EraseRegion").
|
||||
Order("version").
|
||||
Select()
|
||||
if err != nil {
|
||||
errors.WithMessagef(err, "failed to do ECURollback() for manifest_id %d on ECU %s id: %d", man.UpdateManifestID, man.ECU, man.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) manifestDelete(operation db.Operation, resultsTotal *ORMResults, manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
// Removed manifest.ecuDelete as the deletes are now propagated by the database
|
||||
result, err := um.resultWithStack(operation.Model(manifest).WherePK().Delete())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resultsTotal, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) manifestInsert(operation db.Operation, resultsTotal *ORMResults, manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
result, err := um.resultWithStack(operation.Model(manifest).Insert())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
resultsTotal.SetModel(result.Model())
|
||||
|
||||
if manifest.ECUs != nil {
|
||||
for i := range manifest.ECUs {
|
||||
ecu := manifest.ECUs[i]
|
||||
ecu.UpdateManifestID = manifest.ID
|
||||
result, err = um.ecuInsert(operation, resultsTotal, ecu)
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resultsTotal, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) ecuDelete(operation db.Operation, resultsTotal *ORMResults, ecu *common.UpdateManifestECU) (orm.Result, error) {
|
||||
if ecu.Files != nil {
|
||||
for i := range ecu.Files {
|
||||
file := ecu.Files[i]
|
||||
result, err := um.fileDelete(operation, resultsTotal, file)
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := operation.Model(ecu).WherePK().Delete()
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resultsTotal, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) ecuInsert(operation db.Operation, resultsTotal *ORMResults, ecu *common.UpdateManifestECU) (orm.Result, error) {
|
||||
result, err := um.resultWithStack(operation.Model(ecu).Insert())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ecu.Files != nil {
|
||||
for i := range ecu.Files {
|
||||
file := ecu.Files[i]
|
||||
file.UpdateManifestECUID = ecu.ID
|
||||
result, err := um.fileInsert(operation, resultsTotal, file)
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resultsTotal, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) fileDelete(operation db.Operation, resultsTotal *ORMResults, file *common.UpdateManifestFile) (orm.Result, error) {
|
||||
result, err := um.resultWithStack(operation.Model(file).WherePK().Delete())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
resultsTotal.SetModel(result.Model())
|
||||
|
||||
result, err = um.resultWithStack(operation.Model(&common.MemoryRegion{ID: file.WriteRegionID}).WherePK().Delete())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if file.EraseRegionID != 0 {
|
||||
result, err = um.resultWithStack(operation.Model(&common.MemoryRegion{ID: file.EraseRegionID}).WherePK().Delete())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
result, err = um.fileKeyDelete(operation, resultsTotal, file)
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resultsTotal, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) fileKeyDelete(operation db.Operation, resultsTotal *ORMResults, file *common.UpdateManifestFile) (orm.Result, error) {
|
||||
fk := &common.FileKey{FileID: file.FileID}
|
||||
result, err := um.resultWithStack(operation.Model(fk).WherePK().Delete())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
return resultsTotal, err
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) fileInsert(operation db.Operation, resultsTotal *ORMResults, file *common.UpdateManifestFile) (orm.Result, error) {
|
||||
result, err := um.resultWithStack(operation.Model(&file.WriteRegion).Insert())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
file.WriteRegionID = file.WriteRegion.ID
|
||||
|
||||
if file.EraseRegion != nil {
|
||||
result, err = um.resultWithStack(operation.Model(file.EraseRegion).Insert())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
file.EraseRegionID = file.EraseRegion.ID
|
||||
}
|
||||
|
||||
result, err = um.resultWithStack(operation.Model(file).Insert())
|
||||
if um.hasErrorResult(resultsTotal, result, err) {
|
||||
return nil, err
|
||||
}
|
||||
resultsTotal.SetModel(result.Model())
|
||||
|
||||
return resultsTotal, nil
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) selectFilter(query *orm.Query, filter *common.UpdateManifest) {
|
||||
if filter.ID > 0 {
|
||||
query.Where("id = ?", filter.ID)
|
||||
}
|
||||
|
||||
if filter.Name != "" {
|
||||
query.Where("name = ?", filter.Name)
|
||||
}
|
||||
|
||||
if filter.Version != "" {
|
||||
query.Where("version = ?", filter.Version)
|
||||
}
|
||||
|
||||
if filter.Description != "" {
|
||||
query.Where("description LIKE ?", filter.Description)
|
||||
}
|
||||
|
||||
if filter.ManifestType > 0 {
|
||||
query.Where("manifest_type = ?", filter.ManifestType)
|
||||
}
|
||||
|
||||
if filter.Active != nil {
|
||||
query.Where("active = ?", filter.Active)
|
||||
}
|
||||
|
||||
if filter.Country != "" {
|
||||
query.Where("country = ? or country is null", filter.Country)
|
||||
}
|
||||
|
||||
if filter.PowerTrain != "" {
|
||||
query.Where("powertrain = ? or powertrain is null", filter.PowerTrain)
|
||||
}
|
||||
|
||||
if filter.Restraint != "" {
|
||||
query.Where("restraint = ? or restraint is null", filter.Restraint)
|
||||
}
|
||||
|
||||
if filter.Model != "" {
|
||||
query.Where("model = ? or model is null", filter.Model)
|
||||
}
|
||||
|
||||
if filter.Trim != "" {
|
||||
query.Where("trim = ? or trim is null", filter.Trim)
|
||||
}
|
||||
|
||||
if filter.Year != 0 {
|
||||
query.Where("year = ? or year is null", filter.Year)
|
||||
}
|
||||
|
||||
if filter.BodyType != "" {
|
||||
query.Where("body_type = ? or body_type is null", filter.BodyType)
|
||||
}
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) searchFilter(query *orm.Query, filter *common.UpdateManifestSearch) {
|
||||
if filter.Search != "" {
|
||||
query.Where("document @@ plainto_tsquery(?)", filter.Search)
|
||||
}
|
||||
|
||||
if filter.ManifestType == common.SoftwareUpdateType {
|
||||
filter.ManifestType = 0
|
||||
query.WhereInMulti("manifest_type in (?)", common.SoftwareUpdateType, common.MagnaManifestUpdateType, common.AftersalesUpdateType)
|
||||
}
|
||||
|
||||
um.selectFilter(query, &filter.UpdateManifest)
|
||||
}
|
||||
|
||||
func (um *UpdateManifests) AddSUMSVersion(manifest *common.UpdateManifest) (orm.Result, error) {
|
||||
return um.resultWithStack(um.GetDBConn().Model(manifest).Column("sums").Where("id = ?id AND sums IS NULL").Update())
|
||||
}
|
||||
|
||||
// Given a flashpack version number: 14.4, we will return the corresponding flashpack
|
||||
func (um *UpdateManifests) SelectFlashPackByVersion(versionNumber string) (manifest common.UpdateManifest, err error) {
|
||||
// Prepare for Like Statement
|
||||
versionNumber = "%" + versionNumber
|
||||
err = um.GetDBConn().Model(&manifest).Where("manifest_type = ? AND name LIKE ?", common.AsBuiltUpdateType, versionNumber).Limit(1).Select()
|
||||
return
|
||||
}
|
||||
313
pkg/db/queries/updatemanifests_test.go
Normal file
313
pkg/db/queries/updatemanifests_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package queries_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/db"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestUpdatePackagesManifest(t *testing.T) {
|
||||
t.Skip()
|
||||
query := setupUpdateManifest(t)
|
||||
manifestID := insertUpdateManifest(t, query)
|
||||
if manifestID == 0 {
|
||||
t.Error("setup failed")
|
||||
return
|
||||
}
|
||||
selectUpdateManifest(t, query, manifestID)
|
||||
searchUpdateManifest(t, query, manifestID)
|
||||
updateUpdateManifest(t, query, manifestID)
|
||||
archiveUpdateManifest(t, query, []int64{manifestID}, true)
|
||||
deleteUpdateManifest(t, query, manifestID)
|
||||
}
|
||||
|
||||
func setupUpdateManifest(t *testing.T) queries.UpdateManifestsInterface {
|
||||
instance := queries.NewUpdateManifest(nil)
|
||||
conn := instance.GetDBConn()
|
||||
conn.AddQueryHook(db.SQLLogger{})
|
||||
|
||||
client := instance.GetClient()
|
||||
client.InitSchema([]interface{}{
|
||||
(*common.UpdateManifest)(nil),
|
||||
(*common.UpdateManifestECU)(nil),
|
||||
(*common.UpdateManifestFile)(nil),
|
||||
(*common.MemoryRegion)(nil),
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func insertUpdateManifest(t *testing.T, query queries.UpdateManifestsInterface) int64 {
|
||||
expectedRows := 20
|
||||
manifest := common.UpdateManifest{
|
||||
Name: fmt.Sprintf("NAME %s", time.Now().String()),
|
||||
Version: "VERSION",
|
||||
Description: "DESCRIPTION",
|
||||
ReleaseNotes: "RELEASENOTES",
|
||||
RollbackEnabled: true,
|
||||
Type: "standard",
|
||||
Fingerprint: "10203040",
|
||||
Country: "US",
|
||||
PowerTrain: "MD23",
|
||||
Restraint: "None",
|
||||
Model: "Ocean",
|
||||
Trim: "Sport",
|
||||
Year: 2022,
|
||||
BodyType: "truck",
|
||||
ECUs: []*common.UpdateManifestECU{
|
||||
{
|
||||
ECU: "ECU",
|
||||
Version: "VERSION",
|
||||
HWVersion: "BIGBADVERSION",
|
||||
HWVersions: []string{"HWVERSION"},
|
||||
ConfigurationMask: "CONFIGURATIONMASK",
|
||||
Mode: "D",
|
||||
SelfDownload: true,
|
||||
Files: []*common.UpdateManifestFile{
|
||||
{
|
||||
FileID: "FILEID",
|
||||
URL: "URL",
|
||||
Filename: "FILENAME",
|
||||
FileSize: 5,
|
||||
Checksum: "CHECKSUM",
|
||||
FileType: "bootloader",
|
||||
FileOrder: 9,
|
||||
WriteRegionID: 100,
|
||||
WriteRegion: common.MemoryRegion{
|
||||
Offset: 101,
|
||||
Length: 102,
|
||||
},
|
||||
EraseRegionID: 200,
|
||||
EraseRegion: &common.MemoryRegion{
|
||||
Offset: 201,
|
||||
Length: 202,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
result, err := query.Insert(&manifest)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest insert", nil, err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if result.RowsAffected() != expectedRows {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest insert RowsAffected", expectedRows, result.RowsAffected())
|
||||
}
|
||||
if result.RowsReturned() != expectedRows {
|
||||
// file insert does not return row
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest insert RowsReturned", expectedRows, result.RowsReturned())
|
||||
}
|
||||
|
||||
return manifest.ID
|
||||
}
|
||||
|
||||
func selectUpdateManifest(t *testing.T, query queries.UpdateManifestsInterface, manifestID int64) {
|
||||
manifest := common.UpdateManifest{ID: manifestID}
|
||||
err := query.Load(&manifest)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Load", nil, err)
|
||||
}
|
||||
if manifest.Name == "" {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Load Name", "not empty", manifest.Name)
|
||||
}
|
||||
|
||||
manifestValues := map[string]interface{}{
|
||||
"Version": "VERSION",
|
||||
"Description": "DESCRIPTION",
|
||||
"ReleaseNotes": "RELEASENOTES",
|
||||
"RollbackEnabled": true,
|
||||
"Type": "standard",
|
||||
"Fingerprint": "10203040",
|
||||
}
|
||||
testhelper.PropsTester(t, &manifest, manifestValues)
|
||||
if len(manifest.ECUs) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Load ECUs", 1, len(manifest.ECUs))
|
||||
} else {
|
||||
ecu := manifest.ECUs[0]
|
||||
ecuValues := map[string]interface{}{
|
||||
"ECU": "ECU",
|
||||
"Version": "VERSION",
|
||||
"HWVersion": "",
|
||||
"HWVersions": []string{"HWVERSION"},
|
||||
"ConfigurationMask": "CONFIGURATIONMASK",
|
||||
"Mode": "D",
|
||||
"SelfDownload": true,
|
||||
}
|
||||
testhelper.PropsTester(t, ecu, ecuValues)
|
||||
|
||||
if len(ecu.Files) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Load Files", 1, len(ecu.Files))
|
||||
} else {
|
||||
file := ecu.Files[0]
|
||||
fileValues := map[string]interface{}{
|
||||
"FileID": "FILEID",
|
||||
"URL": "URL",
|
||||
"Filename": "FILENAME",
|
||||
"FileSize": uint64(5),
|
||||
"Checksum": "CHECKSUM",
|
||||
"FileType": "bootloader",
|
||||
"FileOrder": 9,
|
||||
}
|
||||
testhelper.PropsTester(t, file, fileValues)
|
||||
}
|
||||
}
|
||||
|
||||
manifests, err := query.Select(&common.UpdateManifest{ID: manifestID}, nil)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Select", nil, err)
|
||||
}
|
||||
if len(manifests) != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Select count", 1, len(manifests))
|
||||
} else {
|
||||
testhelper.PropsTester(t, &manifests[0], manifestValues)
|
||||
}
|
||||
|
||||
count, err := query.Count(common.UpdateManifest{ID: manifestID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Count", nil, err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Count", 1, count)
|
||||
}
|
||||
}
|
||||
|
||||
func archiveUpdateManifest(t *testing.T, query queries.UpdateManifestsInterface, ids []int64, active bool) {
|
||||
result, err := query.Archive(ids, active)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest archive", nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest archive RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest archive RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
}
|
||||
|
||||
func searchUpdateManifest(t *testing.T, query queries.UpdateManifestsInterface, manifestID int64) {
|
||||
manifest := common.UpdateManifest{ID: manifestID}
|
||||
err := query.Load(&manifest)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Search Load", nil, err)
|
||||
}
|
||||
|
||||
search := common.UpdateManifestSearch{
|
||||
Search: manifest.Name,
|
||||
}
|
||||
manifests, err := query.Search(search, nil)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Search", nil, err)
|
||||
}
|
||||
if len(manifests) == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Search len", "gt 0", len(manifests))
|
||||
}
|
||||
|
||||
count, err := query.SearchCount(search)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest SearchCount", nil, err)
|
||||
}
|
||||
if count == 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest SearchCount count", "gt 0", count)
|
||||
}
|
||||
}
|
||||
|
||||
func updateUpdateManifest(t *testing.T, query queries.UpdateManifestsInterface, manifestID int64) {
|
||||
manifest := common.UpdateManifest{ID: manifestID}
|
||||
err := query.Load(&manifest)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Update Load", nil, err)
|
||||
}
|
||||
|
||||
manifest.Type = "forced"
|
||||
manifest.Name = "some very nice update"
|
||||
result, err := query.Update(&manifest)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Update", nil, err)
|
||||
return
|
||||
}
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Update RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Update RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
|
||||
manifest = common.UpdateManifest{ID: manifestID}
|
||||
err = query.Load(&manifest)
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest Update Reload", nil, err)
|
||||
}
|
||||
if manifest.Type != "forced" && manifest.Name != "some very nice update" {
|
||||
t.Errorf(
|
||||
testhelper.TestErrorTemplate,
|
||||
"UpdateManifest Update",
|
||||
"forced, some very nice update",
|
||||
fmt.Sprintf("%s, %s", manifest.Type, manifest.Name))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func deleteUpdateManifest(t *testing.T, query queries.UpdateManifestsInterface, manifestID int64) {
|
||||
result, err := query.Delete(&common.UpdateManifest{ID: manifestID})
|
||||
if err != nil {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest delete", nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Used to be 15 as the delete wasn't cascaded and the count of rows deletes was the total of all the deletes being ran
|
||||
if result.RowsAffected() != 1 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest delete RowsAffected", 1, result.RowsAffected())
|
||||
}
|
||||
|
||||
if result.RowsReturned() != 0 {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "UpdateManifest delete RowsReturned", 0, result.RowsReturned())
|
||||
}
|
||||
}
|
||||
|
||||
// Testing to see if ECURollbakcs query works successfully on the database
|
||||
func TestECURollbacks(t *testing.T) {
|
||||
t.Skip()
|
||||
vin := "VCF1ZBU28PG003392"
|
||||
instance := queries.NewUpdateManifest(nil)
|
||||
|
||||
man := common.UpdateManifestECU{
|
||||
ECU: "ICC",
|
||||
}
|
||||
res, err := instance.ECURollback(&man, vin)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(res) != 1{
|
||||
t.Logf("Rollback length wrong: expected 1 got %d", len(res))
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestSearch(t *testing.T){
|
||||
t.Skip()
|
||||
filter := common.UpdateManifestSearch{}
|
||||
filter.ManifestType = common.AftersalesUpdateType
|
||||
|
||||
paging := queries.PageQueryOptions{
|
||||
Order: "id DESC",
|
||||
Limit: 5,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
query := setupUpdateManifest(t)
|
||||
_, err := query.Search(filter, &paging)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
34
pkg/db/sqllogger.go
Normal file
34
pkg/db/sqllogger.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
pg "github.com/go-pg/pg/v10"
|
||||
)
|
||||
|
||||
// SQLLogger query hook to display generated SQL
|
||||
// To use:
|
||||
// Driver.AddQueryHook(db.SQLLogger{})
|
||||
// For development use only. Do not enable in production
|
||||
type SQLLogger struct {
|
||||
Out chan []byte
|
||||
}
|
||||
|
||||
// AfterQuery query hook to display generated SQL
|
||||
func (d SQLLogger) AfterQuery(c context.Context, q *pg.QueryEvent) error {
|
||||
data, _ := q.FormattedQuery()
|
||||
|
||||
if d.Out != nil {
|
||||
d.Out <- data
|
||||
} else {
|
||||
logger.Debug().Msgf("AfterQuery: %s", string(data))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeQuery required for interface
|
||||
func (d SQLLogger) BeforeQuery(c context.Context, q *pg.QueryEvent) (context.Context, error) {
|
||||
return c, nil
|
||||
}
|
||||
Reference in New Issue
Block a user