package tester import ( "encoding/json" "fmt" "strconv" "strings" "sync" "fiskerinc.com/modules/common" "fiskerinc.com/modules/redis" "github.com/pkg/errors" ) const ( errInsufficientArgs = "insufficient number of args" errBadSetValue = "bad set value" errBadExpireValue = "bad expire value" errExpectedKey = "expected key to be string" errExpectedMessage = "expected message to be []byte" ) func NewRedisMock() *MockRedis { redis.MockRedisConnection() conn := redis.GetMockPool().Get() mockRedis := &MockRedis{} mockRedis.SetConn(conn) mockRedis.Reset() return mockRedis } type MockRedis struct { redis.Connection mu sync.Mutex // HGET results for key and field HGETResults map[string]map[string]interface{} // HGETALL results array for key HGETALLResults map[string][]interface{} // SMSMEMBER results for key and member SISMEMBEResults map[string]map[string]interface{} // Results for get commands (GET, EXISTS, SMEMBERS) and key GetCommandResult map[string]map[string]interface{} ExecuteResults interface{} GetResults interface{} GetCacheResults string GetSetResults string RetrieveResult string Error error PublishedMessages map[string]interface{} GetObjectResults map[string]string GetObjectRawResults map[string][]byte GetMultiResults []interface{} SetValues map[string]ExpiringCache ExecutedCommands []interface{} Closed bool } func (m *MockRedis) Delete(id ...interface{}) error { return m.processDelCommand(append([]interface{}{"DEL"}, id...)) } func (m *MockRedis) Close() error { m.mu.Lock() defer m.mu.Unlock() m.Closed = true return m.Error } func (m *MockRedis) Execute(command ...interface{}) (interface{}, error) { _, _ = m.executeBatch(&redis.RedisBatchCommands{Commands: [][]interface{}{command}}) return m.ExecuteResults, m.Error } func (m *MockRedis) ExecuteBatch(batch *redis.RedisBatchCommands) (interface{}, error) { if m.Error != nil { return nil, m.Error } if batch.IsEmpty() { return nil, nil } return m.executeBatch(batch) } func (c *MockRedis) SafeQueueMessage(id string, message interface{}) error { c.mu.Lock() defer c.mu.Unlock() return c.QueueMessage(id, message) } // SafeQueueMessage is the thread-safe implementation of QueueMessage func (c *MockRedis) SafePublishMessage(id string, message interface{}) error { c.mu.Lock() defer c.mu.Unlock() return c.PublishMessage(id, message) } func (m *MockRedis) SetObjectField(key, field string, value interface{}) error { return m.processHSetCommand([]interface{}{"HSET", key, field, value}) } func (m *MockRedis) GetObjectField(string, string) (string, error) { return "", nil } func (m *MockRedis) SafeExecuteBatch(batch *redis.RedisBatchCommands) (interface{}, error) { if m.Error != nil { return nil, m.Error } if batch.IsEmpty() { return nil, nil } m.mu.Lock() defer m.mu.Unlock() return m.executeBatch(batch) } func (m *MockRedis) executeBatch(batch *redis.RedisBatchCommands) (interface{}, error) { if m.Error != nil { return nil, m.Error } var err error var val interface{} var vals []interface{} results := []interface{}{} defer batch.Clear() for _, command := range batch.Commands { m.ExecutedCommands = append(m.ExecutedCommands, command) val = int64(0) switch command[0] { case "DEL": err = m.processDelCommand(command) case "GET", "EXISTS", "SMEMBERS": val, err = m.processGetCommand(command) case "EXPIRE", "EXPIREAT": err = m.processExpireCommand(command) case "HGETALL": vals, err = m.processHGetAllCommand(command) if err == nil { results = append(results, vals) continue } case "HGET": val, err = m.processHGetCommand(command) case "HSET": err = m.processHSetCommand(command) case "PUBLISH": err = m.processPublishCommand(command) case "SET": err = m.processSetCommand(command) case "RPUSH": err = m.processQueueCommand(command) case "SISMEMBER": val, err = m.processSISMemberCommand(command) case "SADD": err = m.processSADDCommand(command) } if err == nil { results = append(results, val) } } return results, err } func (m *MockRedis) processGetCommand(command []interface{}) (interface{}, error) { if len(command) != 2 { return nil, errors.New(errInsufficientArgs) } return m.getMapMapResult(m.GetCommandResult, command), nil } func (m *MockRedis) processSetCommand(command []interface{}) error { cache := ExpiringCache{} if len(command) < 3 { return errors.New(errInsufficientArgs) } else { data, ok := command[2].([]byte) if !ok { return errors.New(errBadSetValue) } cache.Value = string(data) } if len(command) == 5 && command[3] == "EX" { expires, ok := command[4].(int) if !ok { return errors.New(errBadExpireValue) } cache.Expires = expires } key := fmt.Sprintf("%v", command[1]) m.SetValues[key] = cache return nil } func (m *MockRedis) getMapMapResult(mmap map[string]map[string]interface{}, command []interface{}) interface{} { if mmap == nil { return nil } value, ok := mmap[command[0].(string)] if !ok { return nil } result, ok := value[command[1].(string)] if ok { return result } return nil } func (m *MockRedis) processHGetCommand(command []interface{}) (interface{}, error) { if len(command) != 3 { return nil, errors.New("HGET incorrect number of parameters") } return m.getMapMapResult(m.HGETResults, command[1:]), nil } func (m *MockRedis) processHSetCommand(command []interface{}) error { cache := ExpiringCache{} numArgs := len(command) if numArgs < 4 || numArgs%2 != 0 { return errors.New(errInsufficientArgs) } else { obj, expire := m.getValueCache(command[1].(string)) for i, value := range command[2:] { if i%2 == 0 { key, ok := value.(string) if !ok { return errors.New(errExpectedKey) } obj[key] = command[i+3] } } data, err := json.Marshal(obj) if err != nil { return err } cache.Value = string(data) cache.Expires = expire } key := fmt.Sprintf("%v", command[1]) m.SetValues[key] = cache return nil } func (m *MockRedis) getValueCache(key string) (map[string]interface{}, int) { obj := map[string]interface{}{} if cache, ok := m.SetValues[key]; ok { data := cache.Value.(string) err := json.Unmarshal([]byte(data), &obj) if err != nil { panic(fmt.Sprintf("getValueCache %s %v", key, err)) } return obj, cache.Expires } return obj, 0 } func (m *MockRedis) processExpireCommand(command []interface{}) error { if len(command) != 3 { return errors.New(errInsufficientArgs) } key := fmt.Sprintf("%v", command[1]) cache, ok := m.SetValues[key] if ok { expires, err := strconv.Atoi(fmt.Sprint(command[2])) if err == nil { cache.Expires = expires m.SetValues[key] = cache } else { return errors.New(errBadExpireValue) } } return nil } func (m *MockRedis) processSISMemberCommand(command []interface{}) (interface{}, error) { if len(command) != 3 { return nil, errors.New(errInsufficientArgs) } return m.getMapMapResult(m.SISMEMBEResults, command[1:]), nil } func (m *MockRedis) processHGetAllCommand(command []interface{}) ([]interface{}, error) { if len(command) != 2 { return nil, errors.New(errInsufficientArgs) } if m.HGETALLResults == nil { return nil, nil } values, ok := m.HGETALLResults[command[1].(string)] if !ok { return nil, nil } return values, nil } func (m *MockRedis) processQueueCommand(command []interface{}) error { if len(command) != 3 { return errors.New(errInsufficientArgs) } return m.QueueMessage(command[1].(string), command[2]) } func (m *MockRedis) processPublishCommand(command []interface{}) error { if len(command) != 3 { return errors.New(errInsufficientArgs) } // Publish message is passed in as []byte, but mock PublishMessage expects object data, ok := command[2].([]byte) if !ok { return errors.New(errExpectedMessage) } msg := map[string]interface{}{} err := json.Unmarshal(data, &msg) if err != nil { return errors.WithStack(err) } return m.PublishMessage(command[1].(string), msg) } func (m *MockRedis) processDelCommand(command []interface{}) error { if len(command) != 2 { return errors.New(errInsufficientArgs) } key := command[1].(string) cache := ExpiringCache{Value: "DELETED"} m.SetValues[key] = cache return nil } func (m *MockRedis) processSADDCommand(command []interface{}) error { key := command[1].(string) cache := ExpiringCache{Value: command[2:]} m.SetValues[key] = cache return nil } func (m *MockRedis) Get(string) (interface{}, error) { return m.GetResults, m.Error } func (m *MockRedis) GetCache(id string, data interface{}, expire int) error { err := json.Unmarshal([]byte(m.GetCacheResults), data) if err != nil { return errors.WithStack(err) } return nil } func (m *MockRedis) GetSet(id string, data interface{}) error { if m.Error != nil { return m.Error } err := json.Unmarshal([]byte(m.GetSetResults), data) if err != nil { return err } return nil } func (m *MockRedis) PublishMessage(id string, msg interface{}) error { if m.Error != nil { return m.Error } if m.PublishedMessages == nil { m.PublishedMessages = map[string]interface{}{} } // In the real thing, you message is converted to JSON and sent out, so further changes // to the struct do not change redis, but because we are keeping the struct, any internal pointer // can still be affected msg, _ = BeginDeepCopy(msg) // trim prefix for publishing and queueing of message key := strings.Replace(strings.Replace(id, "channel:", "", 1), "queue:", "", 1) m.PublishedMessages[key] = msg return nil } func (m *MockRedis) QueueMessage(id string, msg interface{}) error { return m.PublishMessage(id, msg) } func (m *MockRedis) BatchPublishMessages(ids []string, messages []interface{}) error { if len(ids) != len(messages) { return errors.Errorf( "mismatch number of ids and messages. have %d ids and %d messages", len(ids), len(messages), ) } for i := 0; i < len(ids); i++ { err := m.SafePublishMessage(ids[i], messages[i]) if err != nil { return err } } return nil } func (m *MockRedis) BatchQueueMessages(ids []string, messages []interface{}) error { if len(ids) != len(messages) { return errors.Errorf( "mismatch number of ids and messages. have %d ids and %d messages", len(ids), len(messages), ) } for i := 0; i < len(ids); i++ { err := m.QueueMessage(ids[i], messages[i]) if err != nil { return err } } return nil } func (m *MockRedis) Retrieve(id string, data interface{}) error { if m.Error != nil { return m.Error } err := json.Unmarshal([]byte(m.RetrieveResult), data) if err != nil { return errors.WithStack(err) } return nil } func (m *MockRedis) Set(key string, value interface{}) error { if m.Error != nil { return m.Error } m.SetValues[key] = ExpiringCache{ Value: value, } return nil } func (m *MockRedis) SetCache(key string, value interface{}, expires int) error { if m.Error != nil { return m.Error } m.SetValues[key] = ExpiringCache{Value: value, Expires: expires} return nil } func (m *MockRedis) GetObject(key string, obj interface{}) error { if m.Error != nil { return m.Error } data, ok := m.GetObjectResults[key] if !ok { return redis.ErrNilObject } err := json.Unmarshal([]byte(data), obj) if err != nil { return errors.WithStack(err) } return nil } func (m *MockRedis) SetObject(key string, value interface{}, expires int) error { return m.SetCache(key, value, expires) } func (m *MockRedis) GetMulti(ids []string) ([]interface{}, error) { return m.GetMultiResults, m.Error } // Test helper methods func (m *MockRedis) HasMessage(id string, msg string) (string, bool) { var compare string if value, ok := m.PublishedMessages[id]; ok { if compare, ok = m.isByteSlice(value); !ok { if compare, ok = m.getJSON(value); !ok { return "", false } } if compare == msg { return compare, true } } return compare, false } func (m *MockRedis) getJSON(value interface{}) (string, bool) { result, err := json.Marshal(value) if err != nil { return "", false } return string(result), true } func (m *MockRedis) isByteSlice(value interface{}) (string, bool) { // convert the []byte message into string if data, ok := value.([]byte); ok { return string(data), true } return "", false } func (m *MockRedis) FetchCache(id string) (value ExpiringCache, ok bool) { value, ok = m.SetValues[id] return } func (m *MockRedis) NewSet(id string, value interface{}, expires int) error { if m.Error != nil { return m.Error } m.SetValues[id] = ExpiringCache{Value: value, Expires: expires} return nil } func (m *MockRedis) Ping() error { return m.Error } func (m *MockRedis) GetObjectRaw(string) (map[string][]byte, error) { if m.Error != nil { return nil, m.Error } return m.GetObjectRawResults, nil } func (m *MockRedis) Reset() { m.ExecuteResults = nil m.GetResults = nil m.Error = nil m.PublishedMessages = map[string]interface{}{} m.SetValues = map[string]ExpiringCache{} m.ExecutedCommands = []interface{}{} m.Closed = false } // We attempt to type convert our interfaces so we can copy them. // If we try to do the json.Marshal copy without doing this, then we get map[string]interface{} // which requires code changes to multiple different areas to get working func BeginDeepCopy(original interface{}) (copy interface{}, err error) { switch original.(type) { case common.Message: copy, err = deepCopyMessage(original.(common.Message)) default: err = errors.New("no match") } if err != nil { return original, err } return } // If we want to deep copy other data structs, just need to add their type in here func deepCopyMessage(original common.Message) (copy common.Message, err error) { copy = original switch original.Data.(type) { case ExampleDeepItem: data := original.Data.(ExampleDeepItem) copy.Data, err = copyObject(data) case common.UpdateManifest: data := original.Data.(common.UpdateManifest) copy.Data, err = copyObject(data) case common.CarUpdate: data := original.Data.(common.CarUpdate) copy.Data, err = copyObject(data) default: err = errors.New("no result") } if err != nil { return original, err } return } // If this is used on an interface, it makes it a map[string]interface // so your object needs to be type casted first func copyObject[V any](original V) (copy V, err error) { b, err := json.Marshal(original) if err != nil { return original, err } err = json.Unmarshal(b, ©) if err != nil { return original, err } return } type ExampleDeepItem struct { Title string NestedObject []*NestedDeepItem } type NestedDeepItem struct { ID int Description *string }