package redis

import (
	"context"
	"sync"

	"github.com/fiskerinc/cloud-services/pkg/logger"
	"github.com/gomodule/redigo/redis"
	"github.com/pkg/errors"
)

// NewPubSub returns a PubSub ready for use.
//
// If one or more connections are passed, the first is wrapped and the rest
// are ignored; otherwise a connection is obtained from NewClient().GetConn().
func NewPubSub(args ...redis.Conn) *PubSub {
	var conn redis.Conn
	if len(args) > 0 {
		conn = args[0]
	} else {
		conn = NewClient().GetConn()
	}

	return &PubSub{
		connection:    redis.PubSubConn{Conn: conn},
		subscriptions: make(Set),
	}
}

// PubSub subscribes to Redis channels keyed by ID and dispatches incoming
// messages to a handler.
//
// It follows the Listener interface.
type PubSub struct {
	connection    PubSubClient
	subscriptions Set
	// mu is held while the subscription set is read or mutated and while
	// (un)subscribe commands are sent or the connection is replaced.
	mu sync.Mutex
}

// PubSubClient is the set of connection operations PubSub relies on;
// redis.PubSubConn satisfies it.
type PubSubClient interface {
	Receive() interface{}
	Subscribe(...interface{}) error
	Unsubscribe(...interface{}) error
	Close() error
}

// Add records id in the subscription set and subscribes to its channel
// (ChannelKey(id)).
//
// It returns an error if id is already in the set, or the wrapped error from
// Subscribe. On a Subscribe failure id remains in the set, so a later Restart
// will attempt to subscribe to it again.
func (ps *PubSub) Add(id string) error {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	ok := ps.subscriptions.Add(id)
	if !ok {
		return errors.Errorf("%v already in subscriptions", id)
	}

	if err := ps.connection.Subscribe(ChannelKey(id)); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// Remove deletes id from the subscription set and unsubscribes from its
// channel (ChannelKey(id)).
//
// It returns an error if id is not in the set, or the wrapped error from
// Unsubscribe. On an Unsubscribe failure id has already been removed from
// the set.
func (ps *PubSub) Remove(id string) error {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	ok := ps.subscriptions.Remove(id)
	if !ok {
		return errors.Errorf("%v does not exist in subscriptions", id)
	}

	if err := ps.connection.Unsubscribe(ChannelKey(id)); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// Listen receives messages from the subscribed channels until ctx is
// cancelled, dispatching each message to handler in its own goroutine with
// the ID parsed from the channel name and the raw payload. Handler errors are
// logged, not returned.
//
// On cancellation Listen unsubscribes from all channels and returns nil once
// it observes the cancellation, normally when a subscription reply reports
// zero remaining channels. A receive error ends the loop and is returned
// wrapped with a stack trace.
func (ps *PubSub) Listen(ctx context.Context, handler func(string, []byte) error) error {
	// stopping is closed once ctx is cancelled. Closing a channel is a
	// synchronizing event, so the receive loop can observe it without racing
	// with the watcher goroutine (a plain shared bool would be a data race).
	stopping := make(chan struct{})
	// done is closed on every return so the watcher goroutine exits even when
	// ctx is never cancelled.
	done := make(chan struct{})
	defer close(done)

	isStopping := func() bool {
		select {
		case <-stopping:
			return true
		default:
			return false
		}
	}

	go func() {
		select {
		case <-ctx.Done():
			close(stopping)
		case <-done:
			return
		}

		ps.mu.Lock()
		defer ps.mu.Unlock()
		// No arguments: unsubscribe from every channel (Redis UNSUBSCRIBE
		// semantics). The resulting subscription replies let the receive
		// loop below reach a zero count and exit.
		if err := ps.connection.Unsubscribe(); err != nil {
			logger.Error().Err(err).Send()
		}
	}()

	for !isStopping() {
		switch m := ps.connection.Receive().(type) {
		case error:
			return errors.WithStack(m)
		case redis.Message:
			id := ParseChannelKey(m.Channel)
			go func(channel string, data []byte) {
				if err := handler(channel, data); err != nil {
					logger.At(logger.Error(), channel, "redis").
						Err(err).Send()
				} else {
					logger.At(logger.Debug(), channel, "redis").
						Str("msg", string(data)).
						Msgf("sent published msg to %s", channel)
				}
			}(id, m.Data)
		case redis.Subscription:
			// A zero count only means "stop" once cancellation has begun;
			// otherwise it just reflects the last ID having been removed.
			if m.Count == 0 && isStopping() {
				return nil
			}
		}
	}

	return nil
}

// ListenChannel is intended to deliver Redis messages on a Go channel rather
// than through a handler callback.
//
// It is currently an unimplemented stub that does nothing and returns nil.
func (ps *PubSub) ListenChannel() error {
	// TODO: implement; this stub does not receive anything.
	return nil
}

// Length returns the number of IDs currently in the subscription set.
func (ps *PubSub) Length() int {
	// Add, Remove and Restart mutate or iterate the set under mu, so reading
	// its length without the lock would race with them.
	ps.mu.Lock()
	defer ps.mu.Unlock()

	return len(ps.subscriptions)
}

// Restart closes the current connection (if any), opens a new one from
// NewClient().GetConn(), and re-subscribes to every ID in the subscription
// set.
//
// It returns the first wrapped Subscribe error; remaining IDs are then not
// re-subscribed. An error closing the old connection is logged rather than
// returned, since that connection is being discarded anyway.
//
// NOTE(review): Listen calls ps.connection.Receive() without holding mu, so
// calling Restart while Listen is running races on ps.connection. Confirm
// that callers stop Listen before restarting.
func (ps *PubSub) Restart() error {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	if ps.connection != nil {
		if err := ps.connection.Close(); err != nil {
			logger.Error().Err(err).Send()
		}
	}

	ps.connection = redis.PubSubConn{Conn: NewClient().GetConn()}
	for id := range ps.subscriptions {
		if err := ps.connection.Subscribe(ChannelKey(id)); err != nil {
			return errors.WithStack(err)
		}
	}

	return nil
}