diff --git a/bson/primitive/objectid_test.go b/bson/primitive/objectid_test.go index 36f64a2b9a..5ed296120b 100644 --- a/bson/primitive/objectid_test.go +++ b/bson/primitive/objectid_test.go @@ -37,6 +37,13 @@ func BenchmarkObjectIDFromHex(b *testing.B) { } } +func BenchmarkNewObjectIDFromTimestamp(b *testing.B) { + for i := 0; i < b.N; i++ { + timestamp := time.Now().Add(time.Duration(i) * time.Millisecond) + _ = NewObjectIDFromTimestamp(timestamp) + } +} + func TestFromHex_RoundTrip(t *testing.T) { before := NewObjectID() after, err := ObjectIDFromHex(before.Hex()) diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 87f896aec5..3fdb67b9a2 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -171,7 +171,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera if err != nil { return operation.InsertResult{}, err } - doc, _, err = ensureID(doc, primitive.NewObjectID(), bw.collection.bsonOpts, bw.collection.registry) + doc, _, err = ensureID(doc, primitive.NilObjectID, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.InsertResult{}, err } diff --git a/mongo/collection.go b/mongo/collection.go index ac173307ff..c7b2a8a113 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -256,7 +256,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, if err != nil { return nil, err } - bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NewObjectID(), coll.bsonOpts, coll.registry) + bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NilObjectID, coll.bsonOpts, coll.registry) if err != nil { return nil, err } diff --git a/mongo/mongo.go b/mongo/mongo.go index 393c5b7713..ec8e817c73 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -177,8 +177,11 @@ func marshal( } // ensureID inserts the given ObjectID as an element named "_id" at the -// beginning of the given BSON document if there is not an "_id" already. If -// there is already an element named "_id", the document is not modified. It +// beginning of the given BSON document if there is not an "_id" already. +// If the given ObjectID is primitive.NilObjectID, a new object ID will be +// generated with time.Now(). +// +// If there is already an element named "_id", the document is not modified. It // returns the resulting document and the decoded Go value of the "_id" element. func ensureID( doc bsoncore.Document, @@ -219,6 +222,9 @@ func ensureID( const extraSpace = 17 doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace) _, doc = bsoncore.ReserveLength(doc) + if oid.IsZero() { + oid = primitive.NewObjectID() + } doc = bsoncore.AppendObjectIDElement(doc, "_id", oid) // Remove and re-write the BSON document length header. diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index 8055236a86..b17422ce1e 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -134,6 +134,29 @@ func TestEnsureID(t *testing.T) { } } +func TestEnsureID_NilObjectID(t *testing.T) { + t.Parallel() + + doc := bsoncore.NewDocumentBuilder(). + AppendString("foo", "bar"). + Build() + + got, gotIDI, err := ensureID(doc, primitive.NilObjectID, nil, nil) + assert.NoError(t, err) + + gotID, ok := gotIDI.(primitive.ObjectID) + + assert.True(t, ok) + assert.NotEqual(t, primitive.NilObjectID, gotID) + + want := bsoncore.NewDocumentBuilder(). + AppendObjectID("_id", gotID). + AppendString("foo", "bar"). + Build() + + assert.Equal(t, want, got) +} + func TestMarshalAggregatePipeline(t *testing.T) { // []byte of [{{"$limit", 12345}}] index, arr := bsoncore.AppendArrayStart(nil)