Skip to content

Commit

Permalink
Generate new ObjectID only when required (#1479)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhocore authored Dec 12, 2023
1 parent 4dbe540 commit 820366e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 4 deletions.
7 changes: 7 additions & 0 deletions bson/primitive/objectid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ func BenchmarkObjectIDFromHex(b *testing.B) {
}
}

func BenchmarkNewObjectIDFromTimestamp(b *testing.B) {
for i := 0; i < b.N; i++ {
timestamp := time.Now().Add(time.Duration(i) * time.Millisecond)
_ = NewObjectIDFromTimestamp(timestamp)
}
}

func TestFromHex_RoundTrip(t *testing.T) {
before := NewObjectID()
after, err := ObjectIDFromHex(before.Hex())
Expand Down
2 changes: 1 addition & 1 deletion mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera
if err != nil {
return operation.InsertResult{}, err
}
doc, _, err = ensureID(doc, primitive.NewObjectID(), bw.collection.bsonOpts, bw.collection.registry)
doc, _, err = ensureID(doc, primitive.NilObjectID, bw.collection.bsonOpts, bw.collection.registry)
if err != nil {
return operation.InsertResult{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{},
if err != nil {
return nil, err
}
bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NewObjectID(), coll.bsonOpts, coll.registry)
bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NilObjectID, coll.bsonOpts, coll.registry)
if err != nil {
return nil, err
}
Expand Down
10 changes: 8 additions & 2 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,11 @@ func marshal(
}

// ensureID inserts the given ObjectID as an element named "_id" at the
// beginning of the given BSON document if there is not an "_id" already. If
// there is already an element named "_id", the document is not modified. It
// beginning of the given BSON document if there is not an "_id" already.
// If the given ObjectID is primitive.NilObjectID, a new object ID will be
// generated with time.Now().
//
// If there is already an element named "_id", the document is not modified. It
// returns the resulting document and the decoded Go value of the "_id" element.
func ensureID(
doc bsoncore.Document,
Expand Down Expand Up @@ -219,6 +222,9 @@ func ensureID(
const extraSpace = 17
doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace)
_, doc = bsoncore.ReserveLength(doc)
if oid.IsZero() {
oid = primitive.NewObjectID()
}
doc = bsoncore.AppendObjectIDElement(doc, "_id", oid)

// Remove and re-write the BSON document length header.
Expand Down
23 changes: 23 additions & 0 deletions mongo/mongo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,29 @@ func TestEnsureID(t *testing.T) {
}
}

func TestEnsureID_NilObjectID(t *testing.T) {
t.Parallel()

doc := bsoncore.NewDocumentBuilder().
AppendString("foo", "bar").
Build()

got, gotIDI, err := ensureID(doc, primitive.NilObjectID, nil, nil)
assert.NoError(t, err)

gotID, ok := gotIDI.(primitive.ObjectID)

assert.True(t, ok)
assert.NotEqual(t, primitive.NilObjectID, gotID)

want := bsoncore.NewDocumentBuilder().
AppendObjectID("_id", gotID).
AppendString("foo", "bar").
Build()

assert.Equal(t, want, got)
}

func TestMarshalAggregatePipeline(t *testing.T) {
// []byte of [{{"$limit", 12345}}]
index, arr := bsoncore.AppendArrayStart(nil)
Expand Down

0 comments on commit 820366e

Please sign in to comment.