From cb5e8aa64963786c6e0c0ab5665329ed54f7f6b7 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 18 Oct 2024 00:26:00 -0400 Subject: [PATCH] WIP --- mongo/client_bulk_write.go | 68 +++++-- mongo/integration/crud_prose_test.go | 284 +++++++++++++++++++++++---- x/mongo/driver/operation.go | 14 +- 3 files changed, 304 insertions(+), 62 deletions(-) diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 19d43684b8..5762cc8033 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -73,8 +73,31 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { Authenticator: bw.client.authenticator, Name: "bulkWrite", }.Execute(ctx) - if err != nil && errors.Is(err, driver.ErrUnacknowledgedWrite) { - return nil + var exception *ClientBulkWriteException + switch tt := err.(type) { + case CommandError: + exception = &ClientBulkWriteException{ + TopLevelError: &WriteError{ + Code: int(tt.Code), + Message: tt.Message, + Raw: tt.Raw, + }, + } + default: + if errors.Is(err, driver.ErrUnacknowledgedWrite) { + err = nil + } + } + if len(batches.writeConcernErrors) > 0 || len(batches.writeErrors) > 0 { + if exception == nil { + exception = new(ClientBulkWriteException) + } + exception.WriteConcernErrors = batches.writeConcernErrors + exception.WriteErrors = batches.writeErrors + } + if exception != nil { + exception.PartialResult = batches.result + return *exception } return err } @@ -219,7 +242,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD return 0, dst, io.EOF } - mb.cursorHandlers = make([]func(*cursorInfo, bson.Raw) bool, len(mb.models)) + mb.cursorHandlers = mb.cursorHandlers[:0] mb.newIDMap = make(map[int]interface{}) nsMap := make(map[string]int) @@ -240,6 +263,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD opsIdx, dst := fn.appendStart(dst, "ops") nsIdx, nsDst := fn.appendStart(nil, "nsInfo") + totalSize -= 1000 size := (len(dst) - l) * 2 var n int for i := mb.offset; i < len(mb.models); i++ { @@ -254,7 +278,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD var err error switch model := mb.models[i].model.(type) { case *ClientInsertOneModel: - mb.cursorHandlers[i] = mb.appendInsertResult + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendInsertResult) var id interface{} id, doc, err = (&clientInsertDoc{ namespace: nsIdx, @@ -265,7 +289,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD } mb.newIDMap[i] = id case *ClientUpdateOneModel: - mb.cursorHandlers[i] = mb.appendUpdateResult + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendUpdateResult) doc, err = (&clientUpdateDoc{ namespace: nsIdx, filter: model.Filter, @@ -279,7 +303,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientUpdateManyModel: canRetry = false - mb.cursorHandlers[i] = mb.appendUpdateResult + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendUpdateResult) doc, err = (&clientUpdateDoc{ namespace: nsIdx, filter: model.Filter, @@ -292,7 +316,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD checkDollarKey: true, }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientReplaceOneModel: - mb.cursorHandlers[i] = mb.appendUpdateResult + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendUpdateResult) doc, err = (&clientUpdateDoc{ namespace: nsIdx, filter: model.Filter, @@ -305,7 +329,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD checkDollarKey: false, }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientDeleteOneModel: - mb.cursorHandlers[i] = mb.appendDeleteResult + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendDeleteResult) doc, err = (&clientDeleteDoc{ namespace: nsIdx, filter: model.Filter, @@ -315,7 +339,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientDeleteManyModel: canRetry = false - mb.cursorHandlers[i] = mb.appendDeleteResult + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendDeleteResult) doc, err = (&clientDeleteDoc{ namespace: nsIdx, filter: model.Filter, @@ -323,14 +347,19 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD hint: model.Hint, multi: true, }).marshal(mb.client.bsonOpts, mb.client.registry) + default: + mb.cursorHandlers = append(mb.cursorHandlers, nil) } if err != nil { return 0, nil, err } - length := len(doc) + len(ns) + length := len(doc) if length > maxDocSize { break } + if !exists { + length += len(ns) + } size += length if size >= totalSize { break @@ -369,7 +398,6 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum mb.writeConcernErrors = append(mb.writeConcernErrors, *wce) } } - // closeImplicitSession(sess) if len(resp) == 0 { return nil } @@ -435,8 +463,9 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum if err != nil { return err } + isOrdered := mb.ordered == nil || *mb.ordered fmt.Println("ProcessResponse toplevelerror", res.Ok, res.NErrors, res.Code, res.Errmsg) - if writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0 { + if isOrdered && (writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0) { exception := ClientBulkWriteException{ WriteConcernErrors: mb.writeConcernErrors, WriteErrors: mb.writeErrors, @@ -455,48 +484,51 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum } func (mb *modelBatches) appendDeleteResult(cur *cursorInfo, raw bson.Raw) bool { + idx := int(cur.Idx) + mb.offset if err := cur.extractError(); err != nil { err.Raw = raw if mb.writeErrors == nil { mb.writeErrors = make(map[int]WriteError) } - mb.writeErrors[int(cur.Idx)] = *err + mb.writeErrors[idx] = *err return false } if mb.result.DeleteResults == nil { mb.result.DeleteResults = make(map[int]ClientDeleteResult) } - mb.result.DeleteResults[int(cur.Idx)] = ClientDeleteResult{int64(cur.N)} + mb.result.DeleteResults[idx] = ClientDeleteResult{int64(cur.N)} return true } func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool { + idx := int(cur.Idx) + mb.offset if err := cur.extractError(); err != nil { err.Raw = raw if mb.writeErrors == nil { mb.writeErrors = make(map[int]WriteError) } - mb.writeErrors[int(cur.Idx)] = *err + mb.writeErrors[idx] = *err return false } if mb.result.InsertResults == nil { mb.result.InsertResults = make(map[int]ClientInsertResult) } - mb.result.InsertResults[int(cur.Idx)] = ClientInsertResult{mb.newIDMap[int(cur.Idx)]} + mb.result.InsertResults[idx] = ClientInsertResult{mb.newIDMap[idx]} return true } func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { + idx := int(cur.Idx) + mb.offset if err := cur.extractError(); err != nil { err.Raw = raw if mb.writeErrors == nil { mb.writeErrors = make(map[int]WriteError) } - mb.writeErrors[int(cur.Idx)] = *err + mb.writeErrors[idx] = *err return false } @@ -512,7 +544,7 @@ func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { if cur.Upserted != nil { result.UpsertedID = cur.Upserted.ID } - mb.result.UpdateResults[int(cur.Idx)] = result + mb.result.UpdateResults[idx] = result return true } diff --git a/mongo/integration/crud_prose_test.go b/mongo/integration/crud_prose_test.go index ba8e9d6606..c2ce360537 100644 --- a/mongo/integration/crud_prose_test.go +++ b/mongo/integration/crud_prose_test.go @@ -22,6 +22,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) func TestWriteErrorsWithLabels(t *testing.T) { @@ -442,11 +443,11 @@ func TestClientBulkWrite(t *testing.T) { }) } result, err := mt.Client.BulkWrite(context.Background(), models) - require.NoError(mt, err, "BulkWrite error") - assert.Equal(mt, hello.MaxWriteBatchSize+1, int(result.InsertedCount), "InsertedCount expected to be %d", hello.MaxWriteBatchSize+1) + require.NoError(mt, err, "BulkWrite error", err) + assert.Equal(mt, hello.MaxWriteBatchSize+1, int(result.InsertedCount), "expected InsertedCount: %d, got %d", hello.MaxWriteBatchSize+1, int(result.InsertedCount)) require.Len(mt, opsCnt, 2, "expected 2 bulkWrite commands") - assert.Equal(mt, hello.MaxWriteBatchSize, opsCnt[0], "the length of firstEvent.command.ops is %d", hello.MaxWriteBatchSize) - assert.Equal(mt, 1, opsCnt[1], "the length of secondEvent.command.ops is 1") + assert.Equal(mt, hello.MaxWriteBatchSize, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", hello.MaxWriteBatchSize, opsCnt[0]) + assert.Equal(mt, 1, opsCnt[1], "expected %d secondEvent.command.ops, got %d", 1, opsCnt[1]) }) mt.Run("input with greater than maxMessageSizeBytes", func(mt *mtest.T) { @@ -476,11 +477,11 @@ func TestClientBulkWrite(t *testing.T) { }) } result, err := mt.Client.BulkWrite(context.Background(), models) - require.NoError(mt, err, "BulkWrite error") - assert.Equal(mt, numModels, int(result.InsertedCount), "InsertedCount expected to be %d", numModels) + require.NoError(mt, err, "BulkWrite error", err) + assert.Equal(mt, numModels, int(result.InsertedCount), "expected InsertedCount: %d, got: %d", numModels, int(result.InsertedCount)) require.Len(mt, opsCnt, 2, "expected 2 bulkWrite commands") - assert.Equal(mt, numModels-1, opsCnt[0], "the length of firstEvent.command.ops is %d", numModels-1) - assert.Equal(mt, 1, opsCnt[1], "the length of secondEvent.command.ops is 1") + assert.Equal(mt, numModels-1, opsCnt[0], "expected %d firstEvent.command.ops, got %d", numModels-1, opsCnt[0]) + assert.Equal(mt, 1, opsCnt[1], "expected %d secondEvent.command.ops, got: %d", 1, opsCnt[1]) }) mt.Run("bulkWrite collects WriteConcernErrors across batches", func(mt *mtest.T) { @@ -488,7 +489,8 @@ func TestClientBulkWrite(t *testing.T) { var hello struct { MaxWriteBatchSize int } - require.NoError(mt, mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello), "Hello error") + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error") mt.SetFailPoint(mtest.FailPoint{ ConfigureFailPoint: "failCommand", @@ -511,14 +513,14 @@ func TestClientBulkWrite(t *testing.T) { Document: bson.D{{"a", "b"}}, }) } - _, err := mt.Client.BulkWrite(context.Background(), models) + _, err = mt.Client.BulkWrite(context.Background(), models) require.Error(mt, err) bwe, ok := err.(mongo.ClientBulkWriteException) - require.True(mt, ok, "expected a BulkWriteException, got %T", err) - assert.Len(mt, bwe.WriteConcernErrors, 2, "expected 2 writeConcernErrors") + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + assert.Len(mt, bwe.WriteConcernErrors, 2, "expected writeConcernErrors: %d, got: %d", 2, len(bwe.WriteConcernErrors)) require.NotNil(mt, bwe.PartialResult) assert.Equal(mt, hello.MaxWriteBatchSize+1, int(bwe.PartialResult.InsertedCount), - "InsertedCount expected to be %d", hello.MaxWriteBatchSize+1) + "expected InsertedCount: %d, got: %d", hello.MaxWriteBatchSize+1, int(bwe.PartialResult.InsertedCount)) }) mt.Run("bulkWrite handles individual WriteErrors across batches", func(mt *mtest.T) { @@ -550,25 +552,25 @@ func TestClientBulkWrite(t *testing.T) { }) } - mt.Run("Unordered", func(mt *mtest.T) { + mt.Run("unordered", func(mt *mtest.T) { eventCnt = 0 mt.ResetClient(options.Client().SetMonitor(monitor)) _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false)) require.Error(mt, err) bwe, ok := err.(mongo.ClientBulkWriteException) - require.True(mt, ok, "expected a BulkWriteException, got %T", err) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) assert.Len(mt, bwe.WriteErrors, hello.MaxWriteBatchSize+1, "expected %d writeErrors, got %d", hello.MaxWriteBatchSize+1, len(bwe.WriteErrors)) - require.Equal(mt, 2, eventCnt, "expected 2 bulkWrite commands, got %d", eventCnt) + require.Equal(mt, 2, eventCnt, "expected %d bulkWrite commands, got %d", 2, eventCnt) }) - mt.Run("Ordered", func(mt *mtest.T) { + mt.Run("ordered", func(mt *mtest.T) { eventCnt = 0 mt.ResetClient(options.Client().SetMonitor(monitor)) _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(true)) require.Error(mt, err) bwe, ok := err.(mongo.ClientBulkWriteException) - require.True(mt, ok, "expected a BulkWriteException, got %T", err) - assert.Len(mt, bwe.WriteErrors, 1, "expected %d writeErrors, got %d", 1, len(bwe.WriteErrors)) - require.Equal(mt, 1, eventCnt, "expected 1 bulkWrite commands, got %d", eventCnt) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + assert.Len(mt, bwe.WriteErrors, 1, "expected writeErrors: %d, got: %d", 1, len(bwe.WriteErrors)) + require.Equal(mt, 1, eventCnt, "expected %d bulkWrite commands, got %d", 1, eventCnt) }) }) @@ -577,11 +579,11 @@ func TestClientBulkWrite(t *testing.T) { err := coll.Drop(context.Background()) require.NoError(mt, err, "Drop error") - var getMoreCalled bool + var getMoreCalled int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { if e.CommandName == "getMore" { - getMoreCalled = true + getMoreCalled++ } }, } @@ -605,9 +607,9 @@ func TestClientBulkWrite(t *testing.T) { }) result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) require.NoError(mt, err, "BulkWrite error") - assert.Equal(mt, 2, int(result.UpsertedCount), "InsertedCount expected to be %d, got %d", 2, result.UpsertedCount) - assert.Len(mt, result.UpdateResults, 2, "expected %d UpdateResults, got %d", 2, len(result.UpdateResults)) - assert.True(mt, getMoreCalled, "the getMore was not called") + assert.Equal(mt, 2, int(result.UpsertedCount), "expected InsertedCount: %d, got: %d", 2, result.UpsertedCount) + assert.Len(mt, result.UpdateResults, 2, "expected UpdateResults: %d, got: %d", 2, len(result.UpdateResults)) + assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got %d", 1, getMoreCalled) }) mt.Run("bulkWrite handles a cursor requiring a getMore within a transaction", func(mt *mtest.T) { @@ -615,11 +617,11 @@ func TestClientBulkWrite(t *testing.T) { err := coll.Drop(context.Background()) require.NoError(mt, err, "Drop error") - var getMoreCalled bool + var getMoreCalled int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { if e.CommandName == "getMore" { - getMoreCalled = true + getMoreCalled++ } }, } @@ -650,12 +652,69 @@ func TestClientBulkWrite(t *testing.T) { require.NoError(mt, err, "BulkWrite error") cbwResult, ok := result.(*mongo.ClientBulkWriteResult) require.True(mt, ok, "expected a ClientBulkWriteResult") - assert.Equal(mt, 2, int(cbwResult.UpsertedCount), "InsertedCount expected to be %d, got %d", 2, cbwResult.UpsertedCount) - assert.Len(mt, cbwResult.UpdateResults, 2, "expected %d UpdateResults, got %d", 2, len(cbwResult.UpdateResults)) - assert.True(mt, getMoreCalled, "the getMore was not called") + assert.Equal(mt, 2, int(cbwResult.UpsertedCount), "expected InsertedCount: %d, got: %d", 2, cbwResult.UpsertedCount) + assert.Len(mt, cbwResult.UpdateResults, 2, "expected UpdateResults: %d, got: %d", 2, len(cbwResult.UpdateResults)) + assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got %d", 1, getMoreCalled) }) - mt.Run("bulkWrite handles a getMore error", func(_ *mtest.T) { + mt.Run("bulkWrite handles a getMore error", func(mt *mtest.T) { + var getMoreCalled int + var killCursorsCalled int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + switch e.CommandName { + case "getMore": + getMoreCalled++ + case "killCursors": + killCursorsCalled++ + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxBsonObjectSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error") + + mt.SetFailPoint(mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"getMore"}, + ErrorCode: 8, + }, + }) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, true) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error") + + upsert := true + models := (&mongo.ClientWriteModels{}). + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }). + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }) + _, err = mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) + assert.Error(mt, err) + bwe, ok := err.(mongo.ClientBulkWriteException) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + require.NotNil(mt, bwe.PartialResult) + require.NotNil(mt, bwe.TopLevelError) + assert.Equal(mt, 8, bwe.TopLevelError.Code, "expected top level error code: %d, got; %d", 8, bwe.TopLevelError.Code) + assert.Equal(mt, int64(2), bwe.PartialResult.UpsertedCount, "expected UpsertedCount: %d, got: %d", 2, bwe.PartialResult.UpsertedCount) + assert.Len(mt, bwe.PartialResult.UpdateResults, 1, "expected UpdateResults: %d, got: %d", 1, len(bwe.PartialResult.UpdateResults)) + assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got %d", 1, getMoreCalled) + assert.Equal(mt, 1, killCursorsCalled, "expected %d killCursors call, got %d", 1, killCursorsCalled) }) mt.Run("bulkWrite returns error for unacknowledged too-large insert", func(mt *mtest.T) { @@ -663,14 +722,165 @@ func TestClientBulkWrite(t *testing.T) { var hello struct { MaxBsonObjectSize int } - require.NoError(mt, mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello), "Hello error") + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error") + mt.Run("insert", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + mt.Run("replace", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendReplaceOne("db", "coll", &mongo.ClientReplaceOneModel{ + Filter: bson.D{}, + Replacement: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + }) + + mt.Run("bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size", func(mt *mtest.T) { + var bwCmd []bsoncore.Document + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + bwCmd = append(bwCmd, bsoncore.Document(e.Command)) + } + }, + } + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error") + + newModels := func() (int, *mongo.ClientWriteModels) { + maxBsonObjectSize := hello.MaxBsonObjectSize + opsBytes := hello.MaxMessageSizeBytes - 1122 + numModels := opsBytes / maxBsonObjectSize + + models := &mongo.ClientWriteModels{} + n := numModels + for i := 0; i < n; i++ { + models. + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", maxBsonObjectSize-57)}}, + }) + } + if remainderBytes := opsBytes % maxBsonObjectSize; remainderBytes > 217 { + n++ + models. + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", remainderBytes-57)}}, + }) + } + return n, models + } + mt.Run("no batch-splitting required", func(mt *mtest.T) { + bwCmd = bwCmd[:0] + mt.ResetClient(options.Client().SetMonitor(monitor)) + + numModels, models := newModels() + models.AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite()) + require.NoError(mt, err) + assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) + require.Len(mt, bwCmd, 1, "expected %d bulkWrite call, got %d", 1, len(bwCmd)) + + var cmd struct { + Ops []bson.D + NsInfo []struct { + Ns string + } + } + err = bson.Unmarshal(bwCmd[0], &cmd) + require.NoError(mt, err) + assert.Len(mt, cmd.Ops, numModels+1, "expected ops: %d, got: %d", numModels+1, len(cmd.Ops)) + require.Len(mt, cmd.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(cmd.NsInfo)) + assert.Equal(mt, "db.coll", cmd.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", cmd.NsInfo[0].Ns) + }) + mt.Run("batch-splitting required", func(mt *mtest.T) { + bwCmd = bwCmd[:0] + mt.ResetClient(options.Client().SetMonitor(monitor)) + + coll := strings.Repeat("c", 200) + numModels, models := newModels() + models.AppendInsertOne("db", coll, &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite()) + require.NoError(mt, err) + assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) + require.Len(mt, bwCmd, 2, "expected %d bulkWrite call, got %d", 2, len(bwCmd)) + + type cmd struct { + Ops []bson.D + NsInfo []struct { + Ns string + } + } + var c1 cmd + err = bson.Unmarshal(bwCmd[0], &c1) + require.NoError(mt, err) + assert.Len(mt, c1.Ops, numModels, "expected ops: %d, got: %d", numModels, len(c1.Ops)) + require.Len(mt, c1.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(c1.NsInfo)) + assert.Equal(mt, "db.coll", c1.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", c1.NsInfo[0].Ns) + + var c2 cmd + err = bson.Unmarshal(bwCmd[1], &c2) + require.NoError(mt, err) + assert.Len(mt, c2.Ops, 1, "expected ops: %d, got: %d", 1, len(c2.Ops)) + require.Len(mt, c2.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(c2.NsInfo)) + assert.Equal(mt, "db."+coll, c2.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db."+coll, c2.NsInfo[0].Ns) + }) + }) + + mt.Run("bulkWrite returns an error if no operations can be added to ops", func(mt *mtest.T) { + var hello struct { + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error") + mt.Run("document too large", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxMessageSizeBytes)}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + mt.Run("document too large", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendInsertOne("db", strings.Repeat("c", hello.MaxMessageSizeBytes), &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + }) + + mt.Run("bulkWrite returns an error if auto-encryption is configured", func(mt *mtest.T) { + autoEncryptionOpts := options.AutoEncryption(). + SetKeyVaultNamespace("db.coll"). + SetKmsProviders(map[string]map[string]interface{}{ + "aws": { + "accessKeyId": "foo", + "secretAccessKey": "bar", + }, + }) + mt.ResetClient(options.Client().SetAutoEncryptionOptions(autoEncryptionOpts)) models := (&mongo.ClientWriteModels{}). AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ - Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, + Document: bson.D{{"a", "b"}}, }) - result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetWriteConcern(writeconcern.Unacknowledged())) - require.NoError(mt, err, "BulkWrite error") - assert.Equal(mt, 2, int(result.UpsertedCount), "InsertedCount expected to be %d, got %d", 2, result.UpsertedCount) - assert.Len(mt, result.UpdateResults, 2, "expected %d UpdateResults, got %d", 2, len(result.UpdateResults)) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.ErrorContains(mt, err, "bulkWrite does not currently support automatic encryption") }) } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 2eb9a1be70..c2f4601947 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -862,9 +862,6 @@ func (op Operation) Execute(ctx context.Context) error { Error: tt, } _ = op.ProcessResponseFn(ctx, res, info) - // if perr != nil { - // return perr - // } } // If batching is enabled and either ordered is the default (which is true) or @@ -985,9 +982,6 @@ func (op Operation) Execute(ctx context.Context) error { Error: tt, } _ = op.ProcessResponseFn(ctx, res, info) - // if perr != nil { - // return perr - // } } if op.Client != nil && op.Client.Committing && (retryableErr || tt.Code == 50) { @@ -1385,7 +1379,10 @@ func (op Operation) createWireMessage( dsOffset := len(dst) processedBatches, dst, err = op.Batches.AppendBatchSequence(dst, int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxMessageSize)) if err != nil { - return dst, err + return nil, err + } + if processedBatches == 0 { + return nil, ErrDocumentTooLarge } info.processedBatches = processedBatches info.documentSequence = make([]byte, 0) @@ -1492,6 +1489,9 @@ func (op Operation) addLegacyCommandFields(dst []byte, desc description.Selected if err != nil { return 0, nil, err } + if n == 0 { + return 0, nil, ErrDocumentTooLarge + } return n, dst, nil }