diff --git a/mongo/client.go b/mongo/client.go index edc6cd487f..5d8e7ece67 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -426,8 +426,6 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) return nil, replaceErrors(err) } - // Writes are not retryable on standalones, so let operation determine whether to retry - sess.RetryWrite = false sess.RetryRead = c.retryReads return &sessionImpl{ diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index b07da1701d..19d43684b8 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -27,7 +27,7 @@ import ( // bulkWrite performs a bulkwrite operation type clientBulkWrite struct { - models []interface{} + models []clientWriteModel errorsOnly bool ordered *bool bypassDocumentValidation *bool @@ -46,11 +46,12 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { return errors.New("empty write models") } batches := &modelBatches{ - session: bw.session, - client: bw.client, - ordered: bw.ordered, - models: bw.models, - result: &bw.result, + session: bw.session, + client: bw.client, + ordered: bw.ordered, + models: bw.models, + result: &bw.result, + retryMode: driver.RetryOnce, } err := driver.Operation{ CommandFn: bw.newCommand(), @@ -142,7 +143,7 @@ type modelBatches struct { client *Client ordered *bool - models []interface{} + models []clientWriteModel offset int @@ -222,17 +223,14 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD mb.newIDMap = make(map[int]interface{}) nsMap := make(map[string]int) - getNsIndex := func(namespace string) (int, bsoncore.Document) { - idx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendStringElement(doc, "ns", namespace) - doc, _ = bsoncore.AppendDocumentEnd(doc, idx) - - if v, ok := nsMap[namespace]; ok { - return v, doc + getNsIndex := func(namespace string) (int, bool) { + v, ok := nsMap[namespace] + if ok { + return v, ok } nsIdx := len(nsMap) nsMap[namespace] = nsIdx - return nsIdx, doc + return nsIdx, ok } canRetry := true @@ -249,12 +247,13 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD break } - var nsIdx int - var ns, doc bsoncore.Document + ns := mb.models[i].namespace + nsIdx, exists := getNsIndex(ns) + + var doc bsoncore.Document var err error - switch model := mb.models[i].(type) { + switch model := mb.models[i].model.(type) { case *ClientInsertOneModel: - nsIdx, ns = getNsIndex(model.Namespace) mb.cursorHandlers[i] = mb.appendInsertResult var id interface{} id, doc, err = (&clientInsertDoc{ @@ -266,7 +265,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD } mb.newIDMap[i] = id case *ClientUpdateOneModel: - nsIdx, ns = getNsIndex(model.Namespace) mb.cursorHandlers[i] = mb.appendUpdateResult doc, err = (&clientUpdateDoc{ namespace: nsIdx, @@ -281,7 +279,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientUpdateManyModel: canRetry = false - nsIdx, ns = getNsIndex(model.Namespace) mb.cursorHandlers[i] = mb.appendUpdateResult doc, err = (&clientUpdateDoc{ namespace: nsIdx, @@ -295,7 +292,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD checkDollarKey: true, }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientReplaceOneModel: - nsIdx, ns = getNsIndex(model.Namespace) mb.cursorHandlers[i] = mb.appendUpdateResult doc, err = (&clientUpdateDoc{ namespace: nsIdx, @@ -309,7 +305,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD checkDollarKey: false, }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientDeleteOneModel: - nsIdx, ns = getNsIndex(model.Namespace) mb.cursorHandlers[i] = mb.appendDeleteResult doc, err = (&clientDeleteDoc{ namespace: nsIdx, @@ -320,7 +315,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD }).marshal(mb.client.bsonOpts, mb.client.registry) case *ClientDeleteManyModel: canRetry = false - nsIdx, ns = getNsIndex(model.Namespace) mb.cursorHandlers[i] = mb.appendDeleteResult doc, err = (&clientDeleteDoc{ namespace: nsIdx, @@ -343,7 +337,12 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD } dst = fn.appendDocument(dst, strconv.Itoa(n), doc) - nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), ns) + if !exists { + idx, doc := bsoncore.AppendDocumentStart(nil) + doc = bsoncore.AppendStringElement(doc, "ns", ns) + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), doc) + } n++ } if n == 0 { @@ -356,7 +355,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD mb.retryMode = driver.RetryNone if mb.client.retryWrites && canRetry { - mb.retryMode = driver.RetryOncePerCommand + mb.retryMode = driver.RetryOnce } return n, dst, nil } @@ -414,7 +413,7 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum return err } var cursor *Cursor - cursor, err = newCursorWithSession(bCursor, mb.client.bsonOpts, mb.client.registry, mb.session) + cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry) if err != nil { return err } @@ -430,7 +429,7 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum if int(cur.Idx) >= len(mb.cursorHandlers) { continue } - ok = ok && mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current) + ok = mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current) && ok } err = cursor.Err() if err != nil { @@ -456,32 +455,51 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum } func (mb *modelBatches) appendDeleteResult(cur *cursorInfo, raw bson.Raw) bool { + if err := cur.extractError(); err != nil { + err.Raw = raw + if mb.writeErrors == nil { + mb.writeErrors = make(map[int]WriteError) + } + mb.writeErrors[int(cur.Idx)] = *err + return false + } + if mb.result.DeleteResults == nil { mb.result.DeleteResults = make(map[int]ClientDeleteResult) } mb.result.DeleteResults[int(cur.Idx)] = ClientDeleteResult{int64(cur.N)} + + return true +} + +func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool { if err := cur.extractError(); err != nil { err.Raw = raw + if mb.writeErrors == nil { + mb.writeErrors = make(map[int]WriteError) + } mb.writeErrors[int(cur.Idx)] = *err return false } - return true -} -func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool { if mb.result.InsertResults == nil { mb.result.InsertResults = make(map[int]ClientInsertResult) } mb.result.InsertResults[int(cur.Idx)] = ClientInsertResult{mb.newIDMap[int(cur.Idx)]} + + return true +} + +func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { if err := cur.extractError(); err != nil { err.Raw = raw + if mb.writeErrors == nil { + mb.writeErrors = make(map[int]WriteError) + } mb.writeErrors[int(cur.Idx)] = *err return false } - return true -} -func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { if mb.result.UpdateResults == nil { mb.result.UpdateResults = make(map[int]ClientUpdateResult) } @@ -495,11 +513,7 @@ func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { result.UpsertedID = cur.Upserted.ID } mb.result.UpdateResults[int(cur.Idx)] = result - if err := cur.extractError(); err != nil { - err.Raw = raw - mb.writeErrors[int(cur.Idx)] = *err - return false - } + return true } diff --git a/mongo/client_bulk_write_models.go b/mongo/client_bulk_write_models.go index dfd090715e..4a2259a5c9 100644 --- a/mongo/client_bulk_write_models.go +++ b/mongo/client_bulk_write_models.go @@ -7,89 +7,108 @@ package mongo import ( + "fmt" + "go.mongodb.org/mongo-driver/mongo/options" ) // ClientWriteModels is a struct that can be used in a client-level BulkWrite operation. type ClientWriteModels struct { - models []interface{} + models []clientWriteModel +} + +type clientWriteModel struct { + namespace string + model interface{} } // AppendInsertOne appends ClientInsertOneModels. -func (m *ClientWriteModels) AppendInsertOne(models ...*ClientInsertOneModel) *ClientWriteModels { +func (m *ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) *ClientWriteModels { if m == nil { m = &ClientWriteModels{} } for _, model := range models { - m.models = append(m.models, model) + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) } return m } // AppendUpdateOne appends ClientUpdateOneModels. -func (m *ClientWriteModels) AppendUpdateOne(models ...*ClientUpdateOneModel) *ClientWriteModels { +func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) *ClientWriteModels { if m == nil { m = &ClientWriteModels{} } for _, model := range models { - m.models = append(m.models, model) + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) } return m } // AppendUpdateMany appends ClientUpdateManyModels. -func (m *ClientWriteModels) AppendUpdateMany(models ...*ClientUpdateManyModel) *ClientWriteModels { +func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) *ClientWriteModels { if m == nil { m = &ClientWriteModels{} } for _, model := range models { - m.models = append(m.models, model) + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) } return m } // AppendReplaceOne appends ClientReplaceOneModels. -func (m *ClientWriteModels) AppendReplaceOne(models ...*ClientReplaceOneModel) *ClientWriteModels { +func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) *ClientWriteModels { if m == nil { m = &ClientWriteModels{} } for _, model := range models { - m.models = append(m.models, model) + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) } return m } // AppendDeleteOne appends ClientDeleteOneModels. -func (m *ClientWriteModels) AppendDeleteOne(models ...*ClientDeleteOneModel) *ClientWriteModels { +func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) *ClientWriteModels { if m == nil { m = &ClientWriteModels{} } for _, model := range models { - m.models = append(m.models, model) + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) } return m } // AppendDeleteMany appends ClientDeleteManyModels. -func (m *ClientWriteModels) AppendDeleteMany(models ...*ClientDeleteManyModel) *ClientWriteModels { +func (m *ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) *ClientWriteModels { if m == nil { m = &ClientWriteModels{} } for _, model := range models { - m.models = append(m.models, model) + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) } return m } // ClientInsertOneModel is used to insert a single document in a BulkWrite operation. type ClientInsertOneModel struct { - Namespace string - Document interface{} -} - -// NewClientInsertOneModel creates a new ClientInsertOneModel. -func NewClientInsertOneModel(namespace string) *ClientInsertOneModel { - return &ClientInsertOneModel{Namespace: namespace} + Document interface{} } // SetDocument specifies the document to be inserted. The document cannot be nil. If it does not have an _id field when @@ -102,7 +121,6 @@ func (iom *ClientInsertOneModel) SetDocument(doc interface{}) *ClientInsertOneMo // ClientUpdateOneModel is used to update at most one document in a client-level BulkWrite operation. type ClientUpdateOneModel struct { - Namespace string Collation *options.Collation Upsert *bool Filter interface{} @@ -111,11 +129,6 @@ type ClientUpdateOneModel struct { Hint interface{} } -// ClientNewUpdateOneModel creates a new ClientUpdateOneModel. -func ClientNewUpdateOneModel(namespace string) *ClientUpdateOneModel { - return &ClientUpdateOneModel{Namespace: namespace} -} - // SetHint specifies the index to use for the operation. This should either be the index name as a string or the index // specification as a document. The default value is nil, which means that no hint will be sent. func (uom *ClientUpdateOneModel) SetHint(hint interface{}) *ClientUpdateOneModel { @@ -162,7 +175,6 @@ func (uom *ClientUpdateOneModel) SetUpsert(upsert bool) *ClientUpdateOneModel { // ClientUpdateManyModel is used to update multiple documents in a client-level BulkWrite operation. type ClientUpdateManyModel struct { - Namespace string Collation *options.Collation Upsert *bool Filter interface{} @@ -171,11 +183,6 @@ type ClientUpdateManyModel struct { Hint interface{} } -// NewClientUpdateManyModel creates a new ClientUpdateManyModel. -func NewClientUpdateManyModel(namespace string) *ClientUpdateManyModel { - return &ClientUpdateManyModel{Namespace: namespace} -} - // SetHint specifies the index to use for the operation. This should either be the index name as a string or the index // specification as a document. The default value is nil, which means that no hint will be sent. func (umm *ClientUpdateManyModel) SetHint(hint interface{}) *ClientUpdateManyModel { @@ -221,7 +228,6 @@ func (umm *ClientUpdateManyModel) SetUpsert(upsert bool) *ClientUpdateManyModel // ClientReplaceOneModel is used to replace at most one document in a client-level BulkWrite operation. type ClientReplaceOneModel struct { - Namespace string Collation *options.Collation Upsert *bool Filter interface{} @@ -229,11 +235,6 @@ type ClientReplaceOneModel struct { Hint interface{} } -// NewClientReplaceOneModel creates a new ClientReplaceOneModel. -func NewClientReplaceOneModel(namespace string) *ClientReplaceOneModel { - return &ClientReplaceOneModel{Namespace: namespace} -} - // SetHint specifies the index to use for the operation. This should either be the index name as a string or the index // specification as a document. The default value is nil, which means that no hint will be sent. func (rom *ClientReplaceOneModel) SetHint(hint interface{}) *ClientReplaceOneModel { @@ -273,17 +274,11 @@ func (rom *ClientReplaceOneModel) SetUpsert(upsert bool) *ClientReplaceOneModel // ClientDeleteOneModel is used to delete at most one document in a client-level BulkWriteOperation. type ClientDeleteOneModel struct { - Namespace string Filter interface{} Collation *options.Collation Hint interface{} } -// NewClientDeleteOneModel creates a new ClientDeleteOneModel. -func NewClientDeleteOneModel(namespace string) *ClientDeleteOneModel { - return &ClientDeleteOneModel{Namespace: namespace} -} - // SetFilter specifies a filter to use to select the document to delete. The filter must be a document containing query // operators. It cannot be nil. If the filter matches multiple documents, one will be selected from the matching // documents. @@ -308,17 +303,11 @@ func (dom *ClientDeleteOneModel) SetHint(hint interface{}) *ClientDeleteOneModel // ClientDeleteManyModel is used to delete multiple documents in a client-level BulkWrite operation. type ClientDeleteManyModel struct { - Namespace string Filter interface{} Collation *options.Collation Hint interface{} } -// NewClientDeleteManyModel creates a new ClientDeleteManyModel. -func NewClientDeleteManyModel(namespace string) *ClientDeleteManyModel { - return &ClientDeleteManyModel{Namespace: namespace} -} - // SetFilter specifies a filter to use to select documents to delete. The filter must be a document containing query // operators. It cannot be nil. func (dmm *ClientDeleteManyModel) SetFilter(filter interface{}) *ClientDeleteManyModel { diff --git a/mongo/integration/crud_prose_test.go b/mongo/integration/crud_prose_test.go index 2c61c73e4c..ba8e9d6606 100644 --- a/mongo/integration/crud_prose_test.go +++ b/mongo/integration/crud_prose_test.go @@ -437,9 +437,8 @@ func TestClientBulkWrite(t *testing.T) { models := &mongo.ClientWriteModels{} for i := 0; i < hello.MaxWriteBatchSize+1; i++ { models. - AppendInsertOne(&mongo.ClientInsertOneModel{ - Namespace: "db.coll", - Document: bson.D{{"a", "b"}}, + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, }) } result, err := mt.Client.BulkWrite(context.Background(), models) @@ -472,9 +471,8 @@ func TestClientBulkWrite(t *testing.T) { numModels := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 for i := 0; i < numModels; i++ { models. - AppendInsertOne(&mongo.ClientInsertOneModel{ - Namespace: "db.coll", - Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, }) } result, err := mt.Client.BulkWrite(context.Background(), models) @@ -509,9 +507,8 @@ func TestClientBulkWrite(t *testing.T) { models := &mongo.ClientWriteModels{} for i := 0; i < hello.MaxWriteBatchSize+1; i++ { models. - AppendInsertOne(&mongo.ClientInsertOneModel{ - Namespace: "db.coll", - Document: bson.D{{"a", "b"}}, + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, }) } _, err := mt.Client.BulkWrite(context.Background(), models) @@ -548,9 +545,8 @@ func TestClientBulkWrite(t *testing.T) { models := &mongo.ClientWriteModels{} for i := 0; i < hello.MaxWriteBatchSize+1; i++ { models. - AppendInsertOne(&mongo.ClientInsertOneModel{ - Namespace: "db.coll", - Document: bson.D{{"_id", 1}}, + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"_id", 1}}, }) } @@ -597,17 +593,15 @@ func TestClientBulkWrite(t *testing.T) { require.NoError(mt, err, "Hello error") upsert := true models := (&mongo.ClientWriteModels{}). - AppendUpdateOne(&mongo.ClientUpdateOneModel{ - Namespace: "db.coll", - Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, }). - AppendUpdateOne(&mongo.ClientUpdateOneModel{ - Namespace: "db.coll", - Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, }) result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) require.NoError(mt, err, "BulkWrite error") @@ -640,17 +634,15 @@ func TestClientBulkWrite(t *testing.T) { defer session.EndSession(context.Background()) upsert := true models := (&mongo.ClientWriteModels{}). - AppendUpdateOne(&mongo.ClientUpdateOneModel{ - Namespace: "db.coll", - Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, }). - AppendUpdateOne(&mongo.ClientUpdateOneModel{ - Namespace: "db.coll", - Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, }) result, err := session.WithTransaction(context.Background(), func(mongo.SessionContext) (interface{}, error) { return mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) @@ -673,9 +665,8 @@ func TestClientBulkWrite(t *testing.T) { } require.NoError(mt, mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello), "Hello error") models := (&mongo.ClientWriteModels{}). - AppendInsertOne(&mongo.ClientInsertOneModel{ - Namespace: "db.coll", - Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, }) result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetWriteConcern(writeconcern.Unacknowledged())) require.NoError(mt, err, "BulkWrite error") diff --git a/mongo/integration/unified/client_operation_execution.go b/mongo/integration/unified/client_operation_execution.go index 17cb284315..24f1d145a5 100644 --- a/mongo/integration/unified/client_operation_execution.go +++ b/mongo/integration/unified/client_operation_execution.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "strconv" + "strings" "time" "go.mongodb.org/mongo-driver/bson" @@ -286,61 +287,66 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati func appendClientBulkWriteModel(key string, value bson.Raw, model *mongo.ClientWriteModels) error { switch key { case "insertOne": - m, err := createClientInsertOneModel(value) + namespace, m, err := createClientInsertOneModel(value) if err != nil { return err } - model.AppendInsertOne(m) + ns := strings.SplitN(namespace, ".", 2) + model.AppendInsertOne(ns[0], ns[1], m) case "updateOne": - m, err := createClientUpdateOneModel(value) + namespace, m, err := createClientUpdateOneModel(value) if err != nil { return err } - model.AppendUpdateOne(m) + ns := strings.SplitN(namespace, ".", 2) + model.AppendUpdateOne(ns[0], ns[1], m) case "updateMany": - m, err := createClientUpdateManyModel(value) + namespace, m, err := createClientUpdateManyModel(value) if err != nil { return err } - model.AppendUpdateMany(m) + ns := strings.SplitN(namespace, ".", 2) + model.AppendUpdateMany(ns[0], ns[1], m) case "replaceOne": - m, err := createClientReplaceOneModel(value) + namespace, m, err := createClientReplaceOneModel(value) if err != nil { return err } - model.AppendReplaceOne(m) + ns := strings.SplitN(namespace, ".", 2) + model.AppendReplaceOne(ns[0], ns[1], m) case "deleteOne": - m, err := createClientDeleteOneModel(value) + namespace, m, err := createClientDeleteOneModel(value) if err != nil { return err } - model.AppendDeleteOne(m) + ns := strings.SplitN(namespace, ".", 2) + model.AppendDeleteOne(ns[0], ns[1], m) case "deleteMany": - m, err := createClientDeleteManyModel(value) + namespace, m, err := createClientDeleteManyModel(value) if err != nil { return err } - model.AppendDeleteMany(m) + ns := strings.SplitN(namespace, ".", 2) + model.AppendDeleteMany(ns[0], ns[1], m) } return nil } -func createClientInsertOneModel(value bson.Raw) (*mongo.ClientInsertOneModel, error) { +func createClientInsertOneModel(value bson.Raw) (string, *mongo.ClientInsertOneModel, error) { var v struct { Namespace string Document bson.Raw } err := bson.Unmarshal(value, &v) if err != nil { - return nil, err + return "", nil, err } - return &mongo.ClientInsertOneModel{ - Namespace: v.Namespace, - Document: v.Document, + return v.Namespace, &mongo.ClientInsertOneModel{ + Document: v.Document, }, nil } -func createClientUpdateOneModel(value bson.Raw) (*mongo.ClientUpdateOneModel, error) { +func createClientUpdateOneModel(value bson.Raw) (string, *mongo.ClientUpdateOneModel, error) { var v struct { Namespace string Filter bson.Raw @@ -352,17 +358,16 @@ func createClientUpdateOneModel(value bson.Raw) (*mongo.ClientUpdateOneModel, er } err := bson.Unmarshal(value, &v) if err != nil { - return nil, err + return "", nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return nil, err + return "", nil, err } } model := &mongo.ClientUpdateOneModel{ - Namespace: v.Namespace, Filter: v.Filter, Update: v.Update, Collation: v.Collation, @@ -372,11 +377,11 @@ func createClientUpdateOneModel(value bson.Raw) (*mongo.ClientUpdateOneModel, er if len(v.ArrayFilters) > 0 { model.ArrayFilters = &options.ArrayFilters{Filters: v.ArrayFilters} } - return model, nil + return v.Namespace, model, nil } -func createClientUpdateManyModel(value bson.Raw) (*mongo.ClientUpdateManyModel, error) { +func createClientUpdateManyModel(value bson.Raw) (string, *mongo.ClientUpdateManyModel, error) { var v struct { Namespace string Filter bson.Raw @@ -388,17 +393,16 @@ func createClientUpdateManyModel(value bson.Raw) (*mongo.ClientUpdateManyModel, } err := bson.Unmarshal(value, &v) if err != nil { - return nil, err + return "", nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return nil, err + return "", nil, err } } model := &mongo.ClientUpdateManyModel{ - Namespace: v.Namespace, Filter: v.Filter, Update: v.Update, Collation: v.Collation, @@ -408,10 +412,10 @@ func createClientUpdateManyModel(value bson.Raw) (*mongo.ClientUpdateManyModel, if len(v.ArrayFilters) > 0 { model.ArrayFilters = &options.ArrayFilters{Filters: v.ArrayFilters} } - return model, nil + return v.Namespace, model, nil } -func createClientReplaceOneModel(value bson.Raw) (*mongo.ClientReplaceOneModel, error) { +func createClientReplaceOneModel(value bson.Raw) (string, *mongo.ClientReplaceOneModel, error) { var v struct { Namespace string Filter bson.Raw @@ -422,17 +426,16 @@ func createClientReplaceOneModel(value bson.Raw) (*mongo.ClientReplaceOneModel, } err := bson.Unmarshal(value, &v) if err != nil { - return nil, err + return "", nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return nil, err + return "", nil, err } } - return &mongo.ClientReplaceOneModel{ - Namespace: v.Namespace, + return v.Namespace, &mongo.ClientReplaceOneModel{ Filter: v.Filter, Replacement: v.Replacement, Collation: v.Collation, @@ -441,7 +444,7 @@ func createClientReplaceOneModel(value bson.Raw) (*mongo.ClientReplaceOneModel, }, nil } -func createClientDeleteOneModel(value bson.Raw) (*mongo.ClientDeleteOneModel, error) { +func createClientDeleteOneModel(value bson.Raw) (string, *mongo.ClientDeleteOneModel, error) { var v struct { Namespace string Filter bson.Raw @@ -450,24 +453,23 @@ func createClientDeleteOneModel(value bson.Raw) (*mongo.ClientDeleteOneModel, er } err := bson.Unmarshal(value, &v) if err != nil { - return nil, err + return "", nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return nil, err + return "", nil, err } } - return &mongo.ClientDeleteOneModel{ - Namespace: v.Namespace, + return v.Namespace, &mongo.ClientDeleteOneModel{ Filter: v.Filter, Collation: v.Collation, Hint: hint, }, nil } -func createClientDeleteManyModel(value bson.Raw) (*mongo.ClientDeleteManyModel, error) { +func createClientDeleteManyModel(value bson.Raw) (string, *mongo.ClientDeleteManyModel, error) { var v struct { Namespace string Filter bson.Raw @@ -476,17 +478,16 @@ func createClientDeleteManyModel(value bson.Raw) (*mongo.ClientDeleteManyModel, } err := bson.Unmarshal(value, &v) if err != nil { - return nil, err + return "", nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return nil, err + return "", nil, err } } - return &mongo.ClientDeleteManyModel{ - Namespace: v.Namespace, + return v.Namespace, &mongo.ClientDeleteManyModel{ Filter: v.Filter, Collation: v.Collation, Hint: hint, diff --git a/mongo/integration/unified/error.go b/mongo/integration/unified/error.go index f69a3da341..2bf4cf380a 100644 --- a/mongo/integration/unified/error.go +++ b/mongo/integration/unified/error.go @@ -183,8 +183,10 @@ func extractErrorDetails(err error) (errorDetails, bool) { } details.labels = converted.Labels case mongo.ClientBulkWriteException: - details.raw = converted.TopLevelError.Raw - details.codes = append(details.codes, int32(converted.TopLevelError.Code)) + if converted.TopLevelError != nil { + details.raw = converted.TopLevelError.Raw + details.codes = append(details.codes, int32(converted.TopLevelError.Code)) + } default: return errorDetails{}, false } diff --git a/mongo/integration/unified/operation.go b/mongo/integration/unified/operation.go index 179cf16793..dc0bbcbb62 100644 --- a/mongo/integration/unified/operation.go +++ b/mongo/integration/unified/operation.go @@ -126,6 +126,8 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat return executeListDatabases(ctx, op, false) case "listDatabaseNames": return executeListDatabases(ctx, op, true) + case "clientBulkWrite": + return executeClientBulkWrite(ctx, op) // Database operations case "createCollection": @@ -148,8 +150,6 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat return executeAggregate(ctx, op) case "bulkWrite": return executeBulkWrite(ctx, op) - case "clientBulkWrite": - return executeClientBulkWrite(ctx, op) case "countDocuments": return executeCountDocuments(ctx, op) case "createFindCursor": diff --git a/testdata/command-monitoring/unacknowledgedBulkWrite.json b/testdata/command-monitoring/unacknowledgedBulkWrite.json index b30e1540f4..61bb00726c 100644 --- a/testdata/command-monitoring/unacknowledgedBulkWrite.json +++ b/testdata/command-monitoring/unacknowledgedBulkWrite.json @@ -91,7 +91,8 @@ } } } - ] + ], + "ordered": false }, "expectResult": { "insertedCount": { @@ -158,7 +159,7 @@ "command": { "bulkWrite": 1, "errorsOnly": true, - "ordered": true, + "ordered": false, "ops": [ { "insert": 0, diff --git a/testdata/command-monitoring/unacknowledgedBulkWrite.yml b/testdata/command-monitoring/unacknowledgedBulkWrite.yml index 35b8d556fb..2d54525953 100644 --- a/testdata/command-monitoring/unacknowledgedBulkWrite.yml +++ b/testdata/command-monitoring/unacknowledgedBulkWrite.yml @@ -50,6 +50,7 @@ tests: namespace: *namespace filter: { _id: 3 } update: { $set: { x: 333 } } + ordered: false expectResult: insertedCount: $$unsetOrMatches: 0 @@ -89,7 +90,7 @@ tests: command: bulkWrite: 1 errorsOnly: true - ordered: true + ordered: false ops: - insert: 0 document: { _id: 4, x: 44 } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index f16afe1da6..2eb9a1be70 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -554,12 +554,12 @@ func (op Operation) Execute(ctx context.Context) error { retries = -1 } } - } - // If context is a Timeout context, automatically set retries to -1 (infinite) if retrying is - // enabled. - retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled() - if csot.IsTimeoutContext(ctx) && retryEnabled { - retries = -1 + + // If context is a Timeout context, automatically set retries to -1 (infinite) if retrying is + // enabled. + if csot.IsTimeoutContext(ctx) && op.RetryMode.Enabled() { + retries = -1 + } } var srvr Server @@ -693,14 +693,10 @@ func (op Operation) Execute(ctx context.Context) error { // Calling IncrementTxnNumber() for server descriptions or topologies that do not // support retries (e.g. standalone topologies) will cause server errors. Only do this // check for the first attempt to keep retried writes in the same transaction. - if retrySupported && op.RetryMode != nil && op.Type == Write && op.Client != nil { - op.Client.RetryWrite = false - if op.RetryMode.Enabled() { - op.Client.RetryWrite = true - if !op.Client.Committing && !op.Client.Aborting { - op.Client.IncrementTxnNumber() - } - } + retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled() + needToIncrease := op.Client != nil && !op.Client.Committing && !op.Client.Aborting + if retrySupported && op.Type == Write && retryEnabled && needToIncrease { + op.Client.IncrementTxnNumber() } first = false @@ -726,6 +722,7 @@ func (op Operation) Execute(ctx context.Context) error { if err != nil { return err } + retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled() // set extra data and send event if possible startedInfo.connID = conn.ID() @@ -835,7 +832,7 @@ func (op Operation) Execute(ctx context.Context) error { // If retries are supported for the current operation on the first server description, // the error is considered retryable, and there are retries remaining (negative retries // means retry indefinitely), then retry the operation. - if retrySupported && retryableErr && retries != 0 { + if retrySupported && retryEnabled && retryableErr && retries != 0 { if op.Client != nil && op.Client.Committing { // Apply majority write concern for retries op.Client.UpdateCommitTransactionWriteConcern() @@ -958,7 +955,7 @@ func (op Operation) Execute(ctx context.Context) error { // If retries are supported for the current operation on the first server description, // the error is considered retryable, and there are retries remaining (negative retries // means retry indefinitely), then retry the operation. - if retrySupported && retryableErr && retries != 0 { + if retrySupported && retryEnabled && retryableErr && retries != 0 { if op.Client != nil && op.Client.Committing { // Apply majority write concern for retries op.Client.UpdateCommitTransactionWriteConcern() @@ -1037,10 +1034,9 @@ func (op Operation) Execute(ctx context.Context) error { // the session isn't nil, and client retries are enabled, increment the txn number. // Calling IncrementTxnNumber() for server descriptions or topologies that do not // support retries (e.g. standalone topologies) will cause server errors. - if retrySupported && op.Client != nil && op.RetryMode != nil { - if op.RetryMode.Enabled() { - op.Client.IncrementTxnNumber() - } + if retrySupported && op.Client != nil && retryEnabled { + op.Client.IncrementTxnNumber() + // Reset the retries number for RetryOncePerCommand unless context is a Timeout context, in // which case retries should remain as -1 (as many times as possible). if *op.RetryMode == RetryOncePerCommand && !csot.IsTimeoutContext(ctx) { @@ -1049,9 +1045,7 @@ func (op Operation) Execute(ctx context.Context) error { } currIndex += startedInfo.processedBatches op.Batches.AdvanceBatches(startedInfo.processedBatches) - if op.Batches.Size() > 0 { - continue - } + continue } break } @@ -1250,7 +1244,7 @@ func (op Operation) createLegacyHandshakeWireMessage( return dst, nil, err } - dst, err = op.addSession(dst, desc) + dst, err = op.addSession(dst, desc, false) if err != nil { return dst, nil, err } @@ -1266,9 +1260,10 @@ func (op Operation) createLegacyHandshakeWireMessage( dst, _ = bsoncore.AppendDocumentEnd(dst, idx) if len(rp) > 0 { + idx = wrapper var err error dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) - dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + dst, err = bsoncore.AppendDocumentEnd(dst, idx) if err != nil { return dst, nil, err } @@ -1309,7 +1304,11 @@ func (op Operation) createMsgWireMessage( if err != nil { return dst, nil, err } - dst, err = op.addSession(dst, desc) + retryWrite := false + if op.retryable(conn.Description()) && op.RetryMode != nil && op.RetryMode.Enabled() { + retryWrite = true + } + dst, err = op.addSession(dst, desc, retryWrite) if err != nil { return dst, nil, err } @@ -1356,9 +1355,10 @@ func (op Operation) createWireMessage( var wmindex int32 var err error - fIdx := len(dst) + fIdx := -1 isLegacy := isLegacyHandshake(op, desc) - if isLegacy { + switch { + case isLegacy: cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) { info.processedBatches, dst, err = op.addLegacyCommandFields(dst, desc) return dst, err @@ -1366,7 +1366,7 @@ func (op Operation) createWireMessage( requestID := wiremessage.NextRequestID() wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpQuery) dst, info.cmd, err = op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc, cmdFn) - } else if op.shouldEncrypt() { + case op.shouldEncrypt(): if desc.WireVersion.Max < cryptMinWireVersion { return dst, false, info, errors.New("auto-encryption requires a MongoDB version of 4.2") } @@ -1375,26 +1375,47 @@ func (op Operation) createWireMessage( return dst, err } wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) + fIdx = len(dst) dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, cmdFn) - } else { + default: wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) - dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, op.CommandFn) - if err == nil && op.Batches != nil { + fIdx = len(dst) + appendBatches := func(dst []byte) ([]byte, error) { var processedBatches int dsOffset := len(dst) processedBatches, dst, err = op.Batches.AppendBatchSequence(dst, int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxMessageSize)) - if err == nil { - info.processedBatches = processedBatches - info.documentSequence = make([]byte, 0) - for b := dst[dsOffset:]; len(b) > 0; /* nothing */ { - var seq []byte - var ok bool - seq, b, ok = wiremessage.DocumentSequenceToArray(b) - if !ok { - break - } - info.documentSequence = append(info.documentSequence, seq...) + if err != nil { + return dst, err + } + info.processedBatches = processedBatches + info.documentSequence = make([]byte, 0) + for b := dst[dsOffset:]; len(b) > 0; /* nothing */ { + var seq []byte + var ok bool + seq, b, ok = wiremessage.DocumentSequenceToArray(b) + if !ok { + break } + info.documentSequence = append(info.documentSequence, seq...) + } + return dst, nil + } + switch op.Batches.(type) { + case *Batches: + dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, op.CommandFn) + if err == nil && op.Batches != nil { + dst, err = appendBatches(dst) + } + default: + var batches []byte + if op.Batches != nil { + batches, err = appendBatches(batches) + } + if err == nil { + dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, op.CommandFn) + } + if err == nil && len(batches) > 0 { + dst = append(dst, batches...) } } } @@ -1407,7 +1428,7 @@ func (op Operation) createWireMessage( // aren't batching or we are encoding the last batch. unacknowledged := op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) batching := op.Batches != nil && op.Batches.Size() > info.processedBatches - if !isLegacy && unacknowledged && !batching { + if fIdx > 0 && unacknowledged && !batching { dst[fIdx] |= byte(wiremessage.MoreToCome) moreToCome = true } @@ -1562,7 +1583,7 @@ func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) return append(bsoncore.AppendHeader(dst, t, "writeConcern"), data...), nil } -func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) { +func (op Operation) addSession(dst []byte, desc description.SelectedServer, retryWrite bool) ([]byte, error) { client := op.Client // If the operation is defined for an explicit session but the server @@ -1580,7 +1601,7 @@ func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]b dst = bsoncore.AppendDocumentElement(dst, "lsid", client.SessionID) var addedTxnNumber bool - if op.Type == Write && client.RetryWrite { + if op.Type == Write && retryWrite { addedTxnNumber = true dst = bsoncore.AppendInt64Element(dst, "txnNumber", op.Client.TxnNumber) } diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index eff27bfe33..5403f49c20 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -112,7 +112,6 @@ type Client struct { RetryingCommit bool Committing bool Aborting bool - RetryWrite bool RetryRead bool Snapshot bool