diff --git a/mongo/client.go b/mongo/client.go index 5d8e7ece67..cebd06559c 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -909,8 +909,10 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, if bwo.VerboseResults == nil || !(*bwo.VerboseResults) { op.errorsOnly = true } - err = op.execute(ctx) - return &op.result, replaceErrors(err) + if err = op.execute(ctx); err != nil { + return nil, replaceErrors(err) + } + return &op.result, nil } // newLogger will use the LoggerOptions to create an internal logger and publish diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 5762cc8033..22b06ec23c 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -96,7 +96,16 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { exception.WriteErrors = batches.writeErrors } if exception != nil { - exception.PartialResult = batches.result + var hasSuccess bool + if bw.ordered == nil || *bw.ordered { + _, ok := batches.writeErrors[0] + hasSuccess = !ok + } else { + hasSuccess = len(batches.writeErrors) < len(bw.models) + } + if hasSuccess { + exception.PartialResult = batches.result + } return *exception } return err @@ -471,7 +480,7 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum WriteErrors: mb.writeErrors, PartialResult: mb.result, } - if !res.Ok || res.NErrors > 0 { + if !res.Ok { exception.TopLevelError = &WriteError{ Code: int(res.Code), Message: res.Errmsg, diff --git a/mongo/integration/crud_prose_test.go b/mongo/integration/crud_prose_test.go index c2ce360537..89aa7f1296 100644 --- a/mongo/integration/crud_prose_test.go +++ b/mongo/integration/crud_prose_test.go @@ -10,6 +10,7 @@ import ( "bytes" "context" "errors" + "os" "strings" "testing" @@ -415,7 +416,7 @@ func TestErrorsCodeNamePropagated(t *testing.T) { } func TestClientBulkWrite(t *testing.T) { - mtOpts := mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).CreateClient(false) + mtOpts := mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).CreateClient(false).ClientType(mtest.Pinned) mt := mtest.New(t, mtOpts) mt.Run("input with greater than maxWriteBatchSize", func(mt *mtest.T) { @@ -423,10 +424,12 @@ func TestClientBulkWrite(t *testing.T) { monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { if e.CommandName == "bulkWrite" { - v := e.Command.Lookup("ops") - elems, err := v.Array().Elements() - require.NoError(mt, err, "monitor error") - opsCnt = append(opsCnt, len(elems)) + var c struct { + Ops []bson.D + } + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err) + opsCnt = append(opsCnt, len(c.Ops)) } }, } @@ -455,10 +458,12 @@ func TestClientBulkWrite(t *testing.T) { monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { if e.CommandName == "bulkWrite" { - v := e.Command.Lookup("ops") - elems, err := v.Array().Elements() - require.NoError(mt, err, "monitor error") - opsCnt = append(opsCnt, len(elems)) + var c struct { + Ops []bson.D + } + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err) + opsCnt = append(opsCnt, len(c.Ops)) } }, } @@ -744,11 +749,20 @@ func TestClientBulkWrite(t *testing.T) { }) mt.Run("bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size", func(mt *mtest.T) { - var bwCmd []bsoncore.Document + type cmd struct { + Ops []bson.D + NsInfo []struct { + Ns string + } + } + var bwCmd []cmd monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { if e.CommandName == "bulkWrite" { - bwCmd = append(bwCmd, bsoncore.Document(e.Command)) + var c cmd + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err) + bwCmd = append(bwCmd, c) } }, } @@ -794,17 +808,9 @@ func TestClientBulkWrite(t *testing.T) { assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) require.Len(mt, bwCmd, 1, "expected %d bulkWrite call, got %d", 1, len(bwCmd)) - var cmd struct { - Ops []bson.D - NsInfo []struct { - Ns string - } - } - err = bson.Unmarshal(bwCmd[0], &cmd) - require.NoError(mt, err) - assert.Len(mt, cmd.Ops, numModels+1, "expected ops: %d, got: %d", numModels+1, len(cmd.Ops)) - require.Len(mt, cmd.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(cmd.NsInfo)) - assert.Equal(mt, "db.coll", cmd.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", cmd.NsInfo[0].Ns) + assert.Len(mt, bwCmd[0].Ops, numModels+1, "expected ops: %d, got: %d", numModels+1, len(bwCmd[0].Ops)) + require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo)) + assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns) }) mt.Run("batch-splitting required", func(mt *mtest.T) { bwCmd = bwCmd[:0] @@ -818,27 +824,15 @@ func TestClientBulkWrite(t *testing.T) { result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite()) require.NoError(mt, err) assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) - require.Len(mt, bwCmd, 2, "expected %d bulkWrite call, got %d", 2, len(bwCmd)) + require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got %d", 2, len(bwCmd)) - type cmd struct { - Ops []bson.D - NsInfo []struct { - Ns string - } - } - var c1 cmd - err = bson.Unmarshal(bwCmd[0], &c1) - require.NoError(mt, err) - assert.Len(mt, c1.Ops, numModels, "expected ops: %d, got: %d", numModels, len(c1.Ops)) - require.Len(mt, c1.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(c1.NsInfo)) - assert.Equal(mt, "db.coll", c1.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", c1.NsInfo[0].Ns) + assert.Len(mt, bwCmd[0].Ops, numModels, "expected ops: %d, got: %d", numModels, len(bwCmd[0].Ops)) + require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo)) + assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns) - var c2 cmd - err = bson.Unmarshal(bwCmd[1], &c2) - require.NoError(mt, err) - assert.Len(mt, c2.Ops, 1, "expected ops: %d, got: %d", 1, len(c2.Ops)) - require.Len(mt, c2.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(c2.NsInfo)) - assert.Equal(mt, "db."+coll, c2.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db."+coll, c2.NsInfo[0].Ns) + assert.Len(mt, bwCmd[1].Ops, 1, "expected ops: %d, got: %d", 1, len(bwCmd[1].Ops)) + require.Len(mt, bwCmd[1].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[1].NsInfo)) + assert.Equal(mt, "db."+coll, bwCmd[1].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db."+coll, bwCmd[1].NsInfo[0].Ns) }) }) @@ -867,6 +861,10 @@ func TestClientBulkWrite(t *testing.T) { }) mt.Run("bulkWrite returns an error if auto-encryption is configured", func(mt *mtest.T) { + if os.Getenv("DOCKER_RUNNING") != "" { + mt.Skip("skipping test in docker environment") + } + autoEncryptionOpts := options.AutoEncryption(). SetKeyVaultNamespace("db.coll"). SetKmsProviders(map[string]map[string]interface{}{ @@ -883,4 +881,59 @@ func TestClientBulkWrite(t *testing.T) { _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) require.ErrorContains(mt, err, "bulkWrite does not currently support automatic encryption") }) + + mt.Run("bulkWrite with unacknowledged write concern uses w:0 for all batches", func(mt *mtest.T) { + type cmd struct { + Ops []bson.D + WriteConcern struct { + W interface{} + } + } + var bwCmd []cmd + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + var c cmd + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err) + + bwCmd = append(bwCmd, c) + } + }, + } + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error") + + mt.ResetClient(options.Client().SetMonitor(monitor)) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, true) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error") + + numModels := hello.MaxMessageSizeBytes / hello.MaxBsonObjectSize + models := &mongo.ClientWriteModels{} + for i := 0; i < numModels+1; i++ { + models. + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, + }) + } + _, err = mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.NoError(mt, err) + require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got %d", 2, len(bwCmd)) + + assert.Len(mt, bwCmd[0].Ops, numModels, "expected ops: %d, got: %d", numModels, len(bwCmd[0].Ops)) + assert.Equal(mt, int32(0), bwCmd[0].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[0].WriteConcern.W) + + assert.Len(mt, bwCmd[1].Ops, 1, "expected ops: %d, got: %d", 1, len(bwCmd[1].Ops)) + assert.Equal(mt, int32(0), bwCmd[1].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[1].WriteConcern.W) + + n, err := coll.CountDocuments(context.Background(), bson.D{}) + require.NoError(mt, err) + assert.Equal(mt, numModels+1, int(n), "expected %d documents, got %d", numModels+1, n) + }) } diff --git a/mongo/integration/unified/client_operation_execution.go b/mongo/integration/unified/client_operation_execution.go index 24f1d145a5..90ec577ebd 100644 --- a/mongo/integration/unified/client_operation_execution.go +++ b/mongo/integration/unified/client_operation_execution.go @@ -8,6 +8,7 @@ package unified import ( "context" + "errors" "fmt" "strconv" "strings" @@ -228,60 +229,64 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati } res, err := client.BulkWrite(ctx, wirteModels, opts) - raw := emptyCoreDocument - if res != nil { - rawBuilder := bsoncore.NewDocumentBuilder(). - AppendInt64("deletedCount", res.DeletedCount). - AppendInt64("insertedCount", res.InsertedCount). - AppendInt64("matchedCount", res.MatchedCount). - AppendInt64("modifiedCount", res.ModifiedCount). - AppendInt64("upsertedCount", res.UpsertedCount) - - var resBuilder *bsoncore.DocumentBuilder - - resBuilder = bsoncore.NewDocumentBuilder() - for k, v := range res.DeleteResults { - resBuilder.AppendDocument(strconv.Itoa(k), - bsoncore.NewDocumentBuilder(). - AppendInt64("deletedCount", v.DeletedCount). - Build(), - ) + if res == nil { + var bwe mongo.ClientBulkWriteException + if !errors.As(err, &bwe) || bwe.PartialResult == nil { + return newDocumentResult(emptyCoreDocument, err), nil + } else { + res = bwe.PartialResult } - rawBuilder.AppendDocument("deleteResults", resBuilder.Build()) + } + rawBuilder := bsoncore.NewDocumentBuilder(). + AppendInt64("deletedCount", res.DeletedCount). + AppendInt64("insertedCount", res.InsertedCount). + AppendInt64("matchedCount", res.MatchedCount). + AppendInt64("modifiedCount", res.ModifiedCount). + AppendInt64("upsertedCount", res.UpsertedCount) + + var resBuilder *bsoncore.DocumentBuilder + + resBuilder = bsoncore.NewDocumentBuilder() + for k, v := range res.DeleteResults { + resBuilder.AppendDocument(strconv.Itoa(k), + bsoncore.NewDocumentBuilder(). + AppendInt64("deletedCount", v.DeletedCount). + Build(), + ) + } + rawBuilder.AppendDocument("deleteResults", resBuilder.Build()) - resBuilder = bsoncore.NewDocumentBuilder() - for k, v := range res.InsertResults { - t, d, err := bson.MarshalValue(v.InsertedID) + resBuilder = bsoncore.NewDocumentBuilder() + for k, v := range res.InsertResults { + t, d, err := bson.MarshalValue(v.InsertedID) + if err != nil { + return nil, err + } + resBuilder.AppendDocument(strconv.Itoa(k), + bsoncore.NewDocumentBuilder(). + AppendValue("insertedId", bsoncore.Value{Type: t, Data: d}). + Build(), + ) + } + rawBuilder.AppendDocument("insertResults", resBuilder.Build()) + + resBuilder = bsoncore.NewDocumentBuilder() + for k, v := range res.UpdateResults { + b := bsoncore.NewDocumentBuilder(). + AppendInt64("matchedCount", v.MatchedCount). + AppendInt64("modifiedCount", v.ModifiedCount) + if v.UpsertedID != nil { + t, d, err := bson.MarshalValue(v.UpsertedID) if err != nil { return nil, err } - resBuilder.AppendDocument(strconv.Itoa(k), - bsoncore.NewDocumentBuilder(). - AppendValue("insertedId", bsoncore.Value{Type: t, Data: d}). - Build(), - ) + b.AppendValue("upsertedId", bsoncore.Value{Type: t, Data: d}) } - rawBuilder.AppendDocument("insertResults", resBuilder.Build()) - - resBuilder = bsoncore.NewDocumentBuilder() - for k, v := range res.UpdateResults { - b := bsoncore.NewDocumentBuilder(). - AppendInt64("matchedCount", v.MatchedCount). - AppendInt64("modifiedCount", v.ModifiedCount) - if v.UpsertedID != nil { - t, d, err := bson.MarshalValue(v.UpsertedID) - if err != nil { - return nil, err - } - b.AppendValue("upsertedId", bsoncore.Value{Type: t, Data: d}) - } - resBuilder.AppendDocument(strconv.Itoa(k), b.Build()) - } - rawBuilder.AppendDocument("updateResults", resBuilder.Build()) - - raw = rawBuilder.Build() + resBuilder.AppendDocument(strconv.Itoa(k), b.Build()) } - return newDocumentResult(raw, err), nil + rawBuilder.AppendDocument("updateResults", resBuilder.Build()) + + return newDocumentResult(rawBuilder.Build(), err), nil } func appendClientBulkWriteModel(key string, value bson.Raw, model *mongo.ClientWriteModels) error {