package mongo

import (
	"context"

	"github.com/fiskerinc/cloud-services/pkg/db/queries"
	e "github.com/fiskerinc/cloud-services/pkg/mongo/error"
	"github.com/pkg/errors"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// NewCollection wraps a driver *mongo.Collection so it can be used through
// CollectionInterface.
func NewCollection(collection *mongo.Collection) CollectionInterface {
	return &Collection{
		collection: collection,
	}
}

// CollectionInterface is the set of collection operations this package
// exposes. Implementations bound each driver call with the package's
// `timeout` value (declared elsewhere in this package).
type CollectionInterface interface {
	// InsertOne inserts a document and returns its inserted ID.
	InsertOne(document interface{}) (interface{}, error)
	// FindOne decodes the first document matching filter into object.
	FindOne(filter interface{}, object interface{}, projection interface{}) error
	// ReplaceOne replaces the first document matching filter.
	ReplaceOne(filter interface{}, update interface{}) error
	// DeleteOne deletes the first document matching the given filter.
	DeleteOne(document interface{}) error
	// UpdateOne applies update to the first document matching filter.
	UpdateOne(filter interface{}, update interface{}) (*mongo.UpdateResult, error)
	// UpdateMany applies update to every document matching filter.
	UpdateMany(filter interface{}, update interface{}) error
	// Find decodes all documents matching filter into objects.
	Find(filter interface{}, objects interface{}, pq *queries.PageQueryOptions) error
	// Count returns the number of documents matching filter.
	Count(filter interface{}) (int64, error)
	// Aggregate runs pipeline and decodes all results into object.
	Aggregate(pipeline mongo.Pipeline, object interface{}) error
}

// Compile-time check that *Collection satisfies CollectionInterface.
var _ CollectionInterface = (*Collection)(nil)

// Collection implements CollectionInterface on top of a driver collection.
type Collection struct {
	collection *mongo.Collection
}

// InsertOne is an abstraction over Mongo's library. It inserts document and
// returns the InsertedID reported by the driver. On failure it returns a nil
// ID together with the driver error wrapped with a stack trace.
func (c *Collection) InsertOne(document interface{}) (interface{}, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	r, err := c.collection.InsertOne(ctx, document)
	if err != nil {
		// Return an untyped nil: returning r would hand callers a non-nil
		// interface{} wrapping a *mongo.InsertOneResult (possibly a nil
		// pointer), which is also not the InsertedID the success path returns.
		return nil, errors.WithStack(err)
	}
	return r.InsertedID, nil
}

// FindOne is an abstraction over Mongo's library. It decodes the first
// document matching filter into object, applying projection to the returned
// fields. Driver errors (including mongo.ErrNoDocuments when nothing matches)
// are returned wrapped with a stack trace.
func (c *Collection) FindOne(filter interface{}, object interface{}, projection interface{}) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	opts := options.FindOne().SetProjection(projection)
	if err := c.collection.FindOne(ctx, filter, opts).Decode(object); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// ReplaceOne is an abstraction over Mongo's library. It replaces the first
// document matching filter with update. It returns e.ErrInvalidNumberOfDocs
// when no document matched.
func (c *Collection) ReplaceOne(filter interface{}, update interface{}) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	r, err := c.collection.ReplaceOne(ctx, filter, update)
	if err != nil {
		return errors.WithStack(err)
	}
	if r.MatchedCount == 0 {
		return e.ErrInvalidNumberOfDocs
	}
	return nil
}

// DeleteOne is an abstraction over Mongo's library. It deletes the first
// document matching the given filter and returns e.ErrUnableToDelete unless
// exactly one document was removed.
func (c *Collection) DeleteOne(document interface{}) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	r, err := c.collection.DeleteOne(ctx, document)
	if err != nil {
		return errors.WithStack(err)
	}
	if r.DeletedCount != 1 {
		return e.ErrUnableToDelete
	}
	return nil
}

// UpdateOne applies update to the first document matching filter and returns
// the driver's result. It returns e.ErrInvalidNumberOfDocs (and a nil result)
// when no document matched.
func (c *Collection) UpdateOne(filter interface{}, update interface{}) (*mongo.UpdateResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	r, err := c.collection.UpdateOne(ctx, filter, update)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	if r.MatchedCount == 0 {
		return nil, e.ErrInvalidNumberOfDocs
	}
	return r, nil
}

// UpdateMany applies update to every document matching filter. It returns
// e.ErrInvalidNumberOfDocs when no document matched.
func (c *Collection) UpdateMany(filter interface{}, update interface{}) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	r, err := c.collection.UpdateMany(ctx, filter, update)
	if err != nil {
		return errors.WithStack(err)
	}
	if r.MatchedCount == 0 {
		return e.ErrInvalidNumberOfDocs
	}
	return nil
}

// Find decodes every document matching filter into objects, which the
// driver's Cursor.All requires to be a pointer to a slice. pq supplies the
// limit, offset, sort order and fields to exclude. A nil pq runs the query
// without any paging/sort/projection options instead of panicking.
func (c *Collection) Find(filter interface{}, objects interface{}, pq *queries.PageQueryOptions) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	opts := options.Find()
	if pq != nil {
		opts.SetLimit(int64(pq.Limit)).
			SetSkip(int64(pq.Offset)).
			SetSort(getOrder(pq.Order)).
			SetProjection(getFieldsToIgnore(pq.Ignore))
	}

	cur, err := c.collection.Find(ctx, filter, opts)
	if err != nil {
		return errors.WithStack(err)
	}
	// All drains the cursor and closes it, so no explicit Close is needed.
	if err = cur.All(ctx, objects); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// Count returns the number of documents matching filter.
func (c *Collection) Count(filter interface{}) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	t, err := c.collection.CountDocuments(ctx, filter)
	if err != nil {
		return 0, errors.WithStack(err)
	}
	return t, nil
}

// Aggregate when all else fails: runs pipeline and decodes every resulting
// document into object (a pointer to a slice, as Cursor.All requires).
func (c *Collection) Aggregate(pipeline mongo.Pipeline, object interface{}) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	cur, err := c.collection.Aggregate(ctx, pipeline)
	if err != nil {
		return errors.WithStack(err)
	}
	// All drains the cursor and closes it, so no explicit Close is needed.
	if err = cur.All(ctx, object); err != nil {
		return errors.WithStack(err)
	}
	return nil
}