You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
5.4 KiB
255 lines
5.4 KiB
package mongo |
|
|
|
import ( |
|
"context" |
|
|
|
"github.com/OhYee/rainbow/errors" |
|
"go.mongodb.org/mongo-driver/bson" |
|
"go.mongodb.org/mongo-driver/bson/primitive" |
|
"go.mongodb.org/mongo-driver/mongo" |
|
"go.mongodb.org/mongo-driver/mongo/options" |
|
) |
|
|
|
// countResult is the shape of the single document produced by an aggregation
// stage of the form {"$count": "count"}; Aggregate decodes into it to obtain
// the total number of matching documents.
type countResult struct {
	Count int64 `bson:"count"` // value emitted under the "count" key
}
|
|
|
type Conn struct { |
|
Client *mongo.Client |
|
Collection *mongo.Collection |
|
} |
|
|
|
func NewConn(databaseName string, collectionName string) (conn *Conn, err error) { |
|
defer func() { |
|
if err != nil { |
|
err = errors.NewErr(err) |
|
} |
|
}() |
|
conn = &Conn{} |
|
|
|
conn.Client, err = mongo.Connect(context.TODO(), getClientOptions()) |
|
if err != nil { |
|
return |
|
} |
|
|
|
err = conn.Client.Ping(context.TODO(), nil) |
|
if err != nil { |
|
return |
|
} |
|
|
|
conn.Collection = conn.Client.Database(databaseName).Collection(collectionName) |
|
return |
|
} |
|
|
|
func (conn *Conn) Close() { |
|
if conn.Client != nil { |
|
conn.Client.Disconnect(context.TODO()) |
|
} |
|
} |
|
|
|
func Find(databaseName string, collectionName string, filter interface{}, |
|
opt *options.FindOptions, res interface{}) (total int64, err error) { |
|
defer func() { |
|
if err != nil { |
|
err = errors.NewErr(err) |
|
} |
|
}() |
|
|
|
conn, err := NewConn(databaseName, collectionName) |
|
defer conn.Close() |
|
if err != nil { |
|
return |
|
} |
|
|
|
cur, err := conn.Collection.Find(context.TODO(), filter, opt) |
|
if err != nil { |
|
return |
|
} |
|
defer cur.Close(context.TODO()) |
|
|
|
if total, err = conn.Collection.CountDocuments(context.TODO(), filter, nil); err != nil { |
|
return |
|
} |
|
|
|
if res != nil { |
|
err = cur.All(context.TODO(), res) |
|
} |
|
return |
|
} |
|
|
|
func Aggregate(databaseName string, collectionName string, pipeline interface{}, |
|
opt *options.AggregateOptions, res interface{}) (total int64, err error) { |
|
defer func() { |
|
if err != nil { |
|
err = errors.NewErr(err) |
|
} |
|
}() |
|
|
|
conn, err := NewConn(databaseName, collectionName) |
|
defer conn.Close() |
|
if err != nil { |
|
return |
|
} |
|
|
|
cur, err := conn.Collection.Aggregate(context.TODO(), pipeline, opt) |
|
if err != nil { |
|
return |
|
} |
|
defer cur.Close(context.TODO()) |
|
|
|
if res != nil { |
|
if err = cur.All(context.TODO(), res); err != nil { |
|
return |
|
} |
|
} |
|
|
|
count := countResult{} |
|
countPipeline, err := pipelineTruncated(pipeline) |
|
countPipeline = append(countPipeline, bson.M{"$count": "count"}) |
|
if err != nil { |
|
return |
|
} |
|
countCur, err := conn.Collection.Aggregate(context.TODO(), countPipeline, opt) |
|
if err != nil { |
|
return |
|
} |
|
defer countCur.Close(context.TODO()) |
|
if countCur.Next(context.TODO()) { |
|
if err = countCur.Decode(&count); err != nil { |
|
return |
|
} |
|
} |
|
total = count.Count |
|
|
|
return |
|
} |
|
|
|
func bsonFormat(b interface{}) (bb []bson.M, err error) { |
|
switch b.(type) { |
|
case bson.D: |
|
bb = []bson.M{b.(bson.D).Map()} |
|
case []bson.E: |
|
bb = []bson.M{bson.D(b.([]bson.E)).Map()} |
|
case bson.E: |
|
bb = []bson.M{bson.D([]bson.E{b.(bson.E)}).Map()} |
|
case bson.M: |
|
bb = []bson.M{b.(bson.M)} |
|
case map[string]interface{}: |
|
bb = []bson.M{bson.M(b.(map[string]interface{}))} |
|
case []bson.M: |
|
bb = b.([]bson.M) |
|
case []map[string]interface{}: |
|
m := b.([]map[string]interface{}) |
|
bb = make([]bson.M, len(m)) |
|
for idx, data := range m { |
|
bb[idx] = bson.M(data) |
|
} |
|
default: |
|
err = errors.New("Can format bson: %+v", b) |
|
bb = []bson.M{} |
|
} |
|
return |
|
} |
|
|
|
func pipelineTruncated(pipeline interface{}) (res []bson.M, err error) { |
|
m, err := bsonFormat(pipeline) |
|
end := -1 |
|
for i := len(m) - 1; i >= 0; i-- { |
|
if _, exist := m[i]["$limit"]; exist { |
|
continue |
|
} |
|
if _, exist := m[i]["$skip"]; exist { |
|
continue |
|
} |
|
end = i |
|
break |
|
} |
|
res = m[0 : end+1] |
|
return |
|
} |
|
|
|
func Add(databaseName string, collectionName string, |
|
opt *options.InsertManyOptions, documents ...interface{}) (ids []interface{}, err error) { |
|
conn, err := NewConn(databaseName, collectionName) |
|
defer conn.Close() |
|
if err != nil { |
|
return |
|
} |
|
|
|
result, err := conn.Collection.InsertMany(context.TODO(), documents, opt) |
|
if err != nil { |
|
return |
|
} |
|
ids = result.InsertedIDs |
|
return |
|
} |
|
|
|
func Update(databaseName string, collectionName string, filter interface{}, update interface{}, |
|
opt *options.UpdateOptions) (result *mongo.UpdateResult, err error) { |
|
conn, err := NewConn(databaseName, collectionName) |
|
defer conn.Close() |
|
if err != nil { |
|
return |
|
} |
|
|
|
result, err = conn.Collection.UpdateMany(context.TODO(), filter, update, opt) |
|
if err != nil { |
|
return |
|
} |
|
return |
|
} |
|
|
|
func Remove(databaseName string, collectionName string, filter interface{}, |
|
opt *options.DeleteOptions) (count int64, err error) { |
|
conn, err := NewConn(databaseName, collectionName) |
|
defer conn.Close() |
|
if err != nil { |
|
return |
|
} |
|
|
|
result, err := conn.Collection.DeleteMany(context.TODO(), filter, opt) |
|
if err != nil { |
|
return |
|
} |
|
count = result.DeletedCount |
|
return |
|
} |
|
|
|
// AggregateOffset using offset in aggregate |
|
func AggregateOffset(offset int64, number int64) []bson.M { |
|
return []bson.M{ |
|
bson.M{"$limit": offset + number}, |
|
bson.M{"$skip": offset}, |
|
} |
|
} |
|
|
|
func StringToObjectIDs(idStrings ...string) (ids []primitive.ObjectID) { |
|
ids = make([]primitive.ObjectID, 0) |
|
for _, s := range idStrings { |
|
id, e := primitive.ObjectIDFromHex(s) |
|
if e == nil { |
|
ids = append(ids, id) |
|
} |
|
} |
|
return ids |
|
} |
|
|
|
func CollectionExists(databaseName string, collectionName string) (exist bool, err error) { |
|
conn, err := NewConn(databaseName, collectionName) |
|
if err != nil { |
|
return |
|
} |
|
defer conn.Close() |
|
|
|
db := conn.Client.Database(databaseName) |
|
|
|
cursor, err := db.ListCollections(context.TODO(), bson.M{"name": collectionName}) |
|
defer cursor.Close(context.TODO()) |
|
|
|
if err != nil { |
|
return |
|
} |
|
|
|
exist = cursor.TryNext(context.TODO()) |
|
return |
|
}
|
|
|