package redis

import (
	"encoding/json"
	"net"
	"strings"
	"sync"

	"github.com/fiskerinc/cloud-services/pkg/logger"
	"github.com/gomodule/redigo/redis"
	"github.com/pkg/errors"
)

// NewClient constructs a Client backed by a *Connection.
//
// If a redis.Conn is supplied, the first one becomes the underlying
// connection and any further arguments are ignored. Otherwise the
// connection is obtained lazily from NewPool() on first use (see GetConn).
func NewClient(args ...redis.Conn) Client {
	if len(args) > 0 {
		return &Connection{conn: args[0]}
	}
	return &Connection{}
}

// Client defines the operations for sending messages and for
// setting/getting objects in redis.
type Client interface {
	GetConn() redis.Conn
	SetConn(redis.Conn)
	Close() error
	Ping() error

	queueMessage(string, interface{}) error
	publishMessage(string, interface{}) error
	BatchQueueMessages(ids []string, messages []interface{}) error
	BatchPublishMessages(ids []string, messages []interface{}) error
	SafeQueueMessage(string, interface{}) error
	SafePublishMessage(string, interface{}) error

	// Simple redis operations
	Set(string, interface{}) error
	Get(string) (interface{}, error)
	Delete(...interface{}) error
	GetMulti(ids []string) ([]interface{}, error)
	SetMulti(ids []string, data []interface{}) error

	// Sets
	NewSet(string, interface{}, int) error
	GetSet(string, interface{}) error
	AddToSet(id string, data interface{}, expire int) error

	// Use objects when you wish to access individual fields in future
	SetObject(string, interface{}, int) error
	SetObjectField(string, string, interface{}) error
	SetObjects([]string, []interface{}, int) error
	GetObject(string, interface{}) error
	GetObjectField(string, string) (string, error)
	GetObjectMap(string) (map[string]string, error)
	GetObjectRaw(string) (map[string][]byte, error)
	GetObjectsMulti([]string, []interface{}) error
	GetObjectsMultiMap([]string) (map[string]map[string]string, error)
	GetValuesMulti(ids []string, data interface{}) error

	// General execution
	Retrieve(command string, data interface{}) error

	// Cache functions marshal/unmarshal any data type to redis
	SetCache(string, interface{}, int) error
	GetCache(string, interface{}, int) error

	// Thread-safe variations
	SafeSet(string, interface{}) error
	SafeGet(string) (interface{}, error)
	SafeDelete(...interface{}) error
	SafeNewSet(string, interface{}, int) error
	SafeGetSet(string, interface{}) error
	SafeSetObject(string, interface{}, int) error
	SafeGetObject(string, interface{}) error
	Execute(command ...interface{}) (interface{}, error)
	SafeExecute(command ...interface{}) (interface{}, error)
	ExecuteBatch(batch *RedisBatchCommands) (interface{}, error)
	SafeExecuteBatch(batch *RedisBatchCommands) (interface{}, error)
}

// Connection holds a single redis connection.
//
// Methods without the Safe prefix are NOT thread safe. The Safe* variants
// serialize access through mu, which only protects them against other
// Safe* calls on the same Connection.
type Connection struct {
	conn redis.Conn
	once sync.Once
	mu   sync.Mutex
}

// GetConn returns the underlying connection, obtaining one from NewPool()
// if none is set.
func (c *Connection) GetConn() redis.Conn {
	c.once.Do(func() {
		if c.conn == nil {
			c.SetConn(NewPool().Get())
		}
	})
	// The once has already fired after a Close(), so re-acquire explicitly
	// rather than returning a nil connection that would panic on use.
	if c.conn == nil {
		c.SetConn(NewPool().Get())
	}
	return c.conn
}

// SetConn replaces the underlying connection. The previous connection, if
// any, is not closed.
func (c *Connection) SetConn(conn redis.Conn) {
	c.conn = conn
}

// Close closes the underlying connection if one exists and clears it.
// If closing fails, the connection is kept and the error is returned.
func (c *Connection) Close() error {
	if c.conn == nil {
		return nil
	}
	if err := c.conn.Close(); err != nil {
		return errors.WithStack(err)
	}
	c.conn = nil
	return nil
}

// do runs a command and waits for its reply. On a network error the
// connection is replaced (see checkConnError) before the error is returned.
func (c *Connection) do(commandName string, args ...interface{}) (interface{}, error) {
	reply, err := c.GetConn().Do(commandName, args...)
	if c.checkConnError(err) {
		return reply, errors.WithStack(err)
	}
	return reply, nil
}

// send writes a command to the connection's output buffer without waiting
// for a reply (redigo pipelining); used to queue commands inside MULTI.
func (c *Connection) send(commandName string, args ...interface{}) error {
	if err := c.GetConn().Send(commandName, args...); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// discard aborts an open MULTI after a queued command could not be sent.
// It is best-effort: the caller is already returning the original error,
// so a DISCARD failure is deliberately ignored.
func (c *Connection) discard() {
	_, _ = c.GetConn().Do("DISCARD")
}

// Ping sends PING to the server. A network error triggers a reconnect.
func (c *Connection) Ping() error {
	_, err := c.do("PING")
	return err
}

// checkConnError reports whether err is non-nil and, if it is a network
// error, replaces the connection with a fresh one from the pool.
func (c *Connection) checkConnError(err error) bool {
	if err == nil {
		return false
	}
	if c.isNetErr(err) {
		c.reconnect()
	}
	return true
}

// isNetErr reports whether the root cause of err is a *net.OpError.
func (c *Connection) isNetErr(err error) bool {
	_, ok := errors.Cause(err).(*net.OpError)
	return ok
}

// reconnect closes the current connection and installs a new one from
// NewPool(). A Close error is ignored because the connection is being
// replaced regardless.
func (c *Connection) reconnect() {
	_ = c.Close()
	c.SetConn(NewPool().Get())
}

// queueMessage JSON-encodes message and appends it (RPUSH) to the list at
// QueueKey(id). It then sets a 1 hour TTL on that list. The TTL is
// best-effort: the message is already queued, so an EXPIRE failure is
// logged rather than returned.
// Messages are guaranteed to be delivered upon websocket connection.
func (c *Connection) queueMessage(id string, message interface{}) error {
	data, err := json.Marshal(message)
	if err != nil {
		return errors.WithStack(err)
	}
	if _, err = c.do("RPUSH", QueueKey(id), data); err != nil {
		logger.At(logger.Error(), ChannelKey(id), "redis").
			Str("msg", string(data)).Err(err).Send()
		return err
	}
	if _, err = c.do("EXPIRE", QueueKey(id), 3600); err != nil { // 3600 = 1hr
		logger.At(logger.Error(), ChannelKey(id), "redis").
			Str("msg", string(data)).Err(err).Send()
	}
	return nil
}

// publishMessage JSON-encodes message and PUBLISHes it on ChannelKey(id).
// This is a fire and forget mechanism.
func (c *Connection) publishMessage(id string, message interface{}) error {
	data, err := json.Marshal(message)
	if err != nil {
		return errors.WithStack(err)
	}
	if _, err = c.do("PUBLISH", ChannelKey(id), data); err != nil {
		return err
	}
	logger.At(logger.Debug(), ChannelKey(id), "redis").
		Str("msg", string(data)).
		Msgf("sent redis msg to %s", id)
	return nil
}

// BatchQueueMessages is the batched form of queueMessage. messages[i] is
// JSON-encoded and appended to QueueKey(ids[i]), and each queue gets the
// same 1 hour TTL, all in one MULTI/EXEC. ids and messages must have
// equal length.
func (c *Connection) BatchQueueMessages(ids []string, messages []interface{}) error {
	if len(ids) != len(messages) {
		return errors.Errorf(
			"mismatch number of ids and messages. have %d ids and %d messages",
			len(ids), len(messages),
		)
	}
	batch := NewRedisBatchCommands()
	for i, id := range ids {
		data, err := json.Marshal(messages[i])
		if err != nil {
			return errors.WithStack(err)
		}
		batch.Add("RPUSH", QueueKey(id), data)
		// Mirror queueMessage so batched queues do not live forever.
		batch.Add("EXPIRE", QueueKey(id), 3600) // 3600 = 1hr
	}
	_, err := c.ExecuteBatch(batch)
	return err
}

// BatchPublishMessages is the batched form of publishMessage: messages[i]
// is JSON-encoded and PUBLISHed on ChannelKey(ids[i]) in one MULTI/EXEC.
// ids and messages must have equal length.
func (c *Connection) BatchPublishMessages(ids []string, messages []interface{}) error {
	if len(ids) != len(messages) {
		return errors.Errorf(
			"mismatch number of ids and messages. have %d ids and %d messages",
			len(ids), len(messages),
		)
	}
	batch := NewRedisBatchCommands()
	for i, id := range ids {
		data, err := json.Marshal(messages[i])
		if err != nil {
			return errors.WithStack(err)
		}
		batch.Add("PUBLISH", ChannelKey(id), data)
	}
	_, err := c.ExecuteBatch(batch)
	return err
}

// SafeQueueMessage is the mutex-guarded version of queueMessage.
func (c *Connection) SafeQueueMessage(id string, message interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.queueMessage(id, message)
}

// SafePublishMessage is the mutex-guarded version of publishMessage.
func (c *Connection) SafePublishMessage(id string, message interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.publishMessage(id, message)
}

// Set replicates redis "SET".
func (c *Connection) Set(id string, data interface{}) error {
	_, err := c.do("SET", id, data)
	return err
}

// Get replicates redis "GET"; the caller must unpack the returned reply
// (e.g. with redis.String or redis.Bytes).
func (c *Connection) Get(id string) (interface{}, error) {
	return c.do("GET", id)
}

// Delete removes every key in id. It returns an error if redis reports
// fewer deletions than keys given (e.g. some keys did not exist).
func (c *Connection) Delete(id ...interface{}) error {
	numDeleted, err := redis.Int(c.do("DEL", id...))
	if err != nil {
		return err
	}
	if numDeleted != len(id) {
		return errors.Errorf(
			"tried to delete %v (total: %v), however only %v were deleted",
			id, len(id), numDeleted,
		)
	}
	return nil
}

// NewSet replaces the set at id with the members in data (DEL then SADD)
// and sets a TTL of expire seconds when expire > 0, all in one MULTI/EXEC.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) NewSet(id string, data interface{}, expire int) error {
	if err := c.send("MULTI"); err != nil {
		return err
	}
	if err := c.send("DEL", id); err != nil {
		c.discard()
		return err
	}
	if err := c.send("SADD", redisArgs(id, data)...); err != nil {
		c.discard()
		return err
	}
	if expire > 0 {
		if err := c.send("EXPIRE", id, expire); err != nil {
			c.discard()
			return errors.WithStack(err)
		}
	}
	if _, err := c.do("EXEC"); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// AddToSet adds data to the set at id (SADD) and sets a TTL of expire
// seconds when expire > 0, in one MULTI/EXEC.
func (c *Connection) AddToSet(id string, data interface{}, expire int) error {
	if err := c.send("MULTI"); err != nil {
		return err
	}
	if err := c.send("SADD", id, data); err != nil {
		c.discard()
		return err
	}
	if expire > 0 {
		if err := c.send("EXPIRE", id, expire); err != nil {
			c.discard()
			return errors.WithStack(err)
		}
	}
	if _, err := c.do("EXEC"); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// GetSet retrieves the members of the set at id (SMEMBERS) and scans them
// into data with redis.ScanSlice.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetSet(id string, data interface{}) error {
	values, err := redis.Values(c.do("SMEMBERS", id))
	if err != nil {
		return err
	}
	return redis.ScanSlice(values, data)
}

// SetObject stores data as a hash at key id, flattened with
// redis.Args.AddFlat. Expire is in seconds; any value <= 0 (e.g. -1)
// means no expiry.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) SetObject(id string, data interface{}, expire int) error {
	if expire > 0 {
		return c.setExpiringObject(id, data, expire)
	}
	return c.setPersistentObject(id, data)
}

// setPersistentObject HSETs the flattened data at id without a TTL.
func (c *Connection) setPersistentObject(id string, data interface{}) error {
	_, err := c.do("HSET", redisArgs(id, data)...)
	return err
}

// setExpiringObject HSETs the flattened data at id and applies a TTL of
// expire seconds in one MULTI/EXEC.
func (c *Connection) setExpiringObject(id string, data interface{}, expire int) error {
	if err := c.send("MULTI"); err != nil {
		return err
	}
	if err := c.send("HSET", redisArgs(id, data)...); err != nil {
		c.discard()
		return err
	}
	if err := c.send("EXPIRE", redisArgs(id, expire)...); err != nil {
		c.discard()
		return err
	}
	_, err := c.do("EXEC")
	return err
}

// SetObjectField sets a single field key of the hash at id.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) SetObjectField(id string, key string, data interface{}) error {
	_, err := c.do("HSET", redisArgsMulti(id, key, data)...)
	return err
}

// SetObjects is SetObject for many objects in one MULTI/EXEC: data[i] is
// stored at ids[i]. Both slices must be non-empty and of equal length.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) SetObjects(ids []string, data []interface{}, expire int) error {
	if len(ids) <= 0 || len(data) <= 0 || len(ids) != len(data) {
		return errors.Errorf("invalid lengths entered, lengths must match and be > 0. ids length: %v data length: %v",
			len(ids), len(data),
		)
	}
	if err := c.send("MULTI"); err != nil {
		return err
	}
	for i := range ids {
		if err := c.send("HSET", redisArgs(ids[i], data[i])...); err != nil {
			c.discard()
			return err
		}
	}
	if expire > 0 {
		for _, id := range ids {
			if err := c.send("EXPIRE", id, expire); err != nil {
				c.discard()
				return err
			}
		}
	}
	_, err := c.do("EXEC")
	return err
}

// GetObject retrieves the hash at id (HGETALL) and scans it into dest with
// redis.ScanStruct; dest should be a pointer to a struct.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetObject(id string, dest interface{}) error {
	values, err := redis.Values(c.do("HGETALL", id))
	if err != nil {
		return err
	}
	if err = redis.ScanStruct(values, dest); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// GetObjectMap retrieves the hash at id (HGETALL) as a map[string]string.
// GetObject() is preferred if you know the object being retrieved.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetObjectMap(id string) (map[string]string, error) {
	return redis.StringMap(c.do("HGETALL", id))
}

// GetObjectRaw retrieves the hash at id (HGETALL) as a map[string][]byte.
// Use this method when you have a hash of objects that you want to
// unmarshal.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetObjectRaw(id string) (map[string][]byte, error) {
	values, err := redis.Values(c.do("HGETALL", id))
	if err != nil {
		return nil, err
	}
	if len(values)%2 != 0 {
		return nil, errors.New("GetObjectRaw expects even number of values in result")
	}
	m := make(map[string][]byte, len(values)/2)
	for i := 0; i < len(values); i += 2 {
		key, okKey := values[i].([]byte)
		value, okValue := values[i+1].([]byte)
		if !okKey || !okValue {
			return nil, errors.New("cannot parse object into map[string][]byte")
		}
		m[string(key)] = value
	}
	return m, nil
}

// GetObjectField retrieves the value of field key in the hash at id.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetObjectField(id string, key string) (string, error) {
	return redis.String(c.do("HGET", redisArgs(id, key)...))
}

// GetObjectsMulti retrieves the hashes at ids in one MULTI/EXEC and scans
// the i-th hash into data[i] with redis.ScanStruct (so each element should
// be a pointer to a struct). ids and data must have equal length.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetObjectsMulti(ids []string, data []interface{}) error {
	if len(ids) != len(data) {
		return errors.Errorf(
			"number of ids given %v does not match size of data slice %v",
			len(ids), len(data),
		)
	}
	if err := c.send("MULTI"); err != nil {
		return err
	}
	for _, id := range ids {
		if err := c.send("HGETALL", id); err != nil {
			c.discard()
			return err
		}
	}
	values, err := redis.Values(c.do("EXEC"))
	if err != nil {
		return err
	}
	for i, value := range values {
		fields, err := redis.Values(value, nil)
		if err != nil {
			return errors.WithStack(err)
		}
		if err = redis.ScanStruct(fields, data[i]); err != nil {
			return errors.WithStack(err)
		}
	}
	return nil
}

// GetObjectsMultiMap retrieves the hashes at ids in one MULTI/EXEC and
// returns them keyed by id, each as a map[string]string. On error the
// (possibly partially filled) map is returned alongside the error.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetObjectsMultiMap(ids []string) (map[string]map[string]string, error) {
	objects := make(map[string]map[string]string)
	if err := c.send("MULTI"); err != nil {
		return objects, err
	}
	for _, id := range ids {
		if err := c.send("HGETALL", id); err != nil {
			c.discard()
			return objects, err
		}
	}
	// EXEC ends the transaction whether or not it succeeds, so there is
	// nothing left to DISCARD on failure.
	values, err := redis.Values(c.do("EXEC"))
	if err != nil {
		return objects, err
	}
	for i, value := range values {
		o, err := redis.StringMap(value, nil)
		if err != nil {
			return objects, errors.WithStack(err)
		}
		objects[ids[i]] = o
	}
	return objects, nil
}

// makeKeys converts ids to []interface{} for use as variadic redis args.
func (c *Connection) makeKeys(ids []string) []interface{} {
	keys := make([]interface{}, len(ids))
	for i, id := range ids {
		keys[i] = id
	}
	return keys
}

// GetMulti retrieves the values at ids with a single MGET. It returns an
// error if ids is empty.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetMulti(ids []string) ([]interface{}, error) {
	if len(ids) == 0 {
		return nil, errors.WithStack(errors.New("cannot call redis MGET with no keys"))
	}
	// MGET is a single command, not a transaction, so there is nothing to
	// DISCARD on failure.
	result, err := redis.Values(c.do("MGET", c.makeKeys(ids)...))
	if err != nil {
		return result, errors.WithStack(err)
	}
	return result, nil
}

// GetValuesMulti retrieves the values at ids with a single MGET (one call
// to the server) and scans them into data with redis.ScanSlice. It is a
// no-op returning nil when ids is empty.
func (c *Connection) GetValuesMulti(ids []string, data interface{}) error {
	if len(ids) == 0 {
		return nil
	}
	values, err := c.GetMulti(ids)
	if err != nil {
		return errors.WithStack(err)
	}
	if err = redis.ScanSlice(values, data); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// SetMulti sets ids[i] to data[i] with a single MSET. ids and data must
// have equal length.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) SetMulti(ids []string, data []interface{}) error {
	if len(ids) != len(data) {
		return errors.New("id and data lengths do not match")
	}
	pairs := make([]interface{}, 2*len(ids))
	for i, id := range ids {
		pairs[2*i] = id
		pairs[2*i+1] = data[i]
	}
	_, err := c.do("MSET", redisArgsMulti(pairs...)...)
	return err
}

// Retrieve runs a space-separated command string (keyword followed by at
// least one argument) and JSON-unmarshals the byte reply into dest.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) Retrieve(command string, dest interface{}) error {
	input := strings.Split(command, " ")
	if len(input) < 2 {
		return ErrInvalidCommand
	}
	keyword := input[0]
	args := make([]interface{}, len(input)-1)
	for i, in := range input[1:] {
		args[i] = in
	}
	reply, err := redis.Bytes(c.do(keyword, redisArgsMulti(args...)...))
	if err != nil {
		return err
	}
	if err = json.Unmarshal(reply, dest); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// SetCache JSON-encodes data and SETs it at id, applying a TTL of expire
// seconds when expire > 0, in one MULTI/EXEC.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) SetCache(id string, data interface{}, expire int) error {
	serialized, err := json.Marshal(data)
	if err != nil {
		return errors.WithStack(err)
	}
	if err = c.send("MULTI"); err != nil {
		return err
	}
	if err = c.send("SET", id, serialized); err != nil {
		c.discard()
		return err
	}
	if expire > 0 {
		if err = c.send("EXPIRE", id, expire); err != nil {
			c.discard()
			return err
		}
	}
	if _, err = c.do("EXEC"); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// GetCache GETs id and JSON-unmarshals it into dest (must be a pointer).
// When expire > 0 the TTL is reset to expire seconds in the same
// MULTI/EXEC. Returns ErrNilObject if the key holds no value.
//
// Deprecated: prefer Execute for new code.
func (c *Connection) GetCache(id string, dest interface{}, expire int) error {
	if err := c.send("MULTI"); err != nil {
		return err
	}
	if err := c.send("GET", id); err != nil {
		c.discard()
		return err
	}
	if expire > 0 {
		if err := c.send("EXPIRE", id, expire); err != nil {
			c.discard()
			return err
		}
	}
	values, err := redis.Values(c.do("EXEC"))
	if err != nil {
		return err
	}
	// An empty EXEC reply means there is no GET result to decode; report
	// it instead of returning success with dest untouched.
	if len(values) == 0 || values[0] == nil {
		return ErrNilObject
	}
	data, err := redis.Bytes(values[0], nil)
	if err != nil {
		return errors.WithStack(err)
	}
	if err = json.Unmarshal(data, dest); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// SafeSet is the mutex-guarded version of Set().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeSet(id string, data interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Set(id, data)
}

// SafeGet is the mutex-guarded version of Get().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeGet(id string) (interface{}, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Get(id)
}

// SafeDelete is the mutex-guarded version of Delete().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeDelete(id ...interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Delete(id...)
}

// SafeNewSet is the mutex-guarded version of NewSet().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeNewSet(id string, data interface{}, expire int) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.NewSet(id, data, expire)
}

// SafeGetSet is the mutex-guarded version of GetSet().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeGetSet(id string, data interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.GetSet(id, data)
}

// SafeSetObject is the mutex-guarded version of SetObject().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeSetObject(id string, data interface{}, expire int) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.SetObject(id, data, expire)
}

// SafeGetObject is the mutex-guarded version of GetObject().
//
// Deprecated: prefer SafeExecute for new code.
func (c *Connection) SafeGetObject(id string, data interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.GetObject(id, data)
}

// Execute provides an abstraction over redigo's Do method; use this method
// over specialized methods within this library. command[0] must be the
// command keyword as a string, followed by at least one argument;
// otherwise ErrInvalidCommand is returned.
func (c *Connection) Execute(command ...interface{}) (interface{}, error) {
	if len(command) < 2 {
		return nil, ErrInvalidCommand
	}
	keyword, ok := command[0].(string)
	if !ok {
		return nil, ErrInvalidCommand
	}
	return c.do(keyword, command[1:]...)
}

// SafeExecute is the mutex-guarded version of Execute.
func (c *Connection) SafeExecute(command ...interface{}) (interface{}, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Spread the slice: passing it as one argument would always fail the
	// length check in Execute.
	return c.Execute(command...)
}

// ExecuteBatch sends all commands stored in batch to redis inside one
// MULTI/EXEC. The batch is cleared regardless of success or failure; an
// empty batch is a no-op returning (nil, nil).
func (c *Connection) ExecuteBatch(batch *RedisBatchCommands) (interface{}, error) {
	if batch.IsEmpty() {
		return nil, nil
	}
	return c.executeBatch(batch)
}

// SafeExecuteBatch is the mutex-guarded version of ExecuteBatch.
func (c *Connection) SafeExecuteBatch(batch *RedisBatchCommands) (interface{}, error) {
	if batch.IsEmpty() {
		return nil, nil
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.executeBatch(batch)
}

// executeBatch queues every command of batch between MULTI and EXEC and
// returns the EXEC reply. Each command's first element must be its keyword
// as a string; an empty or malformed command aborts the transaction with
// ErrInvalidCommand. The batch is always cleared on return.
func (c *Connection) executeBatch(batch *RedisBatchCommands) (interface{}, error) {
	defer batch.Clear()

	if err := c.send("MULTI"); err != nil {
		return nil, err
	}
	for _, command := range batch.Commands {
		if len(command) == 0 {
			c.discard()
			return nil, ErrInvalidCommand
		}
		keyword, ok := command[0].(string)
		if !ok {
			c.discard()
			return nil, ErrInvalidCommand
		}
		if err := c.send(keyword, command[1:]...); err != nil {
			c.discard()
			return nil, err
		}
	}
	return c.do("EXEC")
}

// redisArgs builds redis args of id followed by data flattened with
// redis.Args.AddFlat (struct fields / map entries become field-value pairs).
func redisArgs(id string, data interface{}) redis.Args {
	return redis.Args{}.Add(id).AddFlat(data)
}

// redisArgsMulti builds redis args from data without flattening; used for
// hash map fields and arrays.
func redisArgsMulti(data ...interface{}) redis.Args {
	return redis.Args{}.Add(data...)
}