diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index e55a52df9d..5b3d09546b 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -10,6 +10,8 @@ import ( "context" "errors" "fmt" + "io" + "strconv" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" @@ -20,6 +22,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/session" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) // bulkWrite performs a bulkwrite operation @@ -35,29 +38,26 @@ type clientBulkWrite struct { selector description.ServerSelector writeConcern *writeconcern.WriteConcern - cursorHandlers []func(*cursorInfo, bson.Raw) error - insIDMap map[int]interface{} - - result ClientBulkWriteResult - writeConcernErrors []WriteConcernError - writeErrors map[int]WriteError + result ClientBulkWriteResult } func (bw *clientBulkWrite) execute(ctx context.Context) error { if len(bw.models) == 0 { return errors.New("empty write models") } - bw.writeErrors = make(map[int]WriteError) - batches, retry, err := bw.processModels() - if err != nil { - return err + batches := &modelBatches{ + session: bw.session, + client: bw.client, + ordered: bw.ordered, + models: bw.models, + result: &bw.result, } - err = driver.Operation{ + err := driver.Operation{ CommandFn: bw.newCommand(), - ProcessResponseFn: bw.ProcessResponse, + ProcessResponseFn: batches.processResponse, Client: bw.session, Clock: bw.client.clock, - RetryMode: retry, + RetryMode: &batches.retryMode, Type: driver.Write, Batches: batches, CommandMonitor: bw.client.monitor, @@ -75,10 +75,38 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { if err != nil && errors.Is(err, driver.ErrUnacknowledgedWrite) { return nil } - fmt.Println("exec", len(bw.writeErrors), err) return err } +func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) { + return func(dst []byte, desc description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1) + + dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly) + if bw.bypassDocumentValidation != nil && (desc.WireVersion != nil && desc.WireVersion.Includes(4)) { + dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *bw.bypassDocumentValidation) + } + if bw.comment != nil { + comment, err := marshalValue(bw.comment, bw.client.bsonOpts, bw.client.registry) + if err != nil { + return nil, err + } + dst = bsoncore.AppendValueElement(dst, "comment", comment) + } + if bw.ordered != nil { + dst = bsoncore.AppendBooleanElement(dst, "ordered", *bw.ordered) + } + if bw.let != nil { + let, err := marshal(bw.let, bw.client.bsonOpts, bw.client.registry) + if err != nil { + return nil, err + } + dst = bsoncore.AppendDocumentElement(dst, "let", let) + } + return dst, nil + } +} + type cursorInfo struct { Ok bool Idx int32 @@ -109,21 +137,237 @@ func (cur *cursorInfo) extractError() *WriteError { return err } -func (bw *clientBulkWrite) ProcessResponse(ctx context.Context, info driver.ResponseInfo) error { +type modelBatches struct { + session *session.Client + client *Client + + ordered *bool + models []interface{} + + offset int + + retryMode driver.RetryMode // RetryNone by default + cursorHandlers []func(*cursorInfo, bson.Raw) bool + newIDMap map[int]interface{} + + result *ClientBulkWriteResult + writeConcernErrors []WriteConcernError + writeErrors map[int]WriteError +} + 
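+// modelBatches is the client bulkWrite implementation of the batch-splitting
+// contract that driver.Operation now expects in place of a []driver.Batches
+// slice (the interface is added to Operation later in this patch, in
+// x/mongo/driver/operation.go):
+//
+//	AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error)
+//	AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error)
+//	IsOrdered() *bool
+//	AdvanceBatches(n int)
+//	End() bool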
+func (mb *modelBatches) IsOrdered() *bool { + return mb.ordered +} + +func (mb *modelBatches) AdvanceBatches(n int) { + mb.offset += n +} + +func (mb *modelBatches) End() bool { + return len(mb.models) <= mb.offset +} + +func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + fn := functionSet{ + appendStart: func(dst []byte, identifier string) (int32, []byte) { + var idx int32 + dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence) + idx, dst = bsoncore.ReserveLength(dst) + dst = append(dst, identifier...) + dst = append(dst, 0x00) + return idx, dst + }, + appendDocument: func(dst []byte, _ string, doc []byte) []byte { + dst = append(dst, doc...) + return dst + }, + appendEnd: func(dst []byte, idx, length int32) []byte { + dst = bsoncore.UpdateLength(dst, idx, length) + return dst + }, + } + return mb.appendBatches(fn, dst, maxCount, maxDocSize, totalSize) +} + +func (mb *modelBatches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + fn := functionSet{ + appendStart: bsoncore.AppendArrayElementStart, + appendDocument: bsoncore.AppendDocumentElement, + appendEnd: func(dst []byte, idx, _ int32) []byte { + dst, _ = bsoncore.AppendArrayEnd(dst, idx) + return dst + }, + } + return mb.appendBatches(fn, dst, maxCount, maxDocSize, totalSize) +} + +type functionSet struct { + appendStart func([]byte, string) (int32, []byte) + appendDocument func([]byte, string, []byte) []byte + appendEnd func([]byte, int32, int32) []byte +} + +func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + if mb.End() { + return 0, dst, io.EOF + } + + mb.cursorHandlers = make([]func(*cursorInfo, bson.Raw) bool, len(mb.models)) + mb.newIDMap = make(map[int]interface{}) + + nsMap := make(map[string]int) + getNsIndex := func(namespace string) (int, bsoncore.Document) { + idx, doc := bsoncore.AppendDocumentStart(nil) + doc = bsoncore.AppendStringElement(doc, "ns", namespace) + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + + if v, ok := nsMap[namespace]; ok { + return v, doc + } + nsIdx := len(nsMap) + nsMap[namespace] = nsIdx + return nsIdx, doc + } + + canRetry := true + + l := len(dst) + + opsIdx, dst := fn.appendStart(dst, "ops") + nsIdx, nsDst := fn.appendStart(nil, "nsInfo") + + size := (len(dst) - l) * 2 + var n int + for i := mb.offset; i < len(mb.models); i++ { + if n == maxCount { + break + } + + var nsIdx int + var ns, doc bsoncore.Document + var err error + switch model := mb.models[i].(type) { + case *ClientInsertOneModel: + nsIdx, ns = getNsIndex(model.Namespace) + mb.cursorHandlers[i] = mb.appendInsertResult + var id interface{} + id, doc, err = (&clientInsertDoc{ + namespace: nsIdx, + document: model.Document, + }).marshal(mb.client.bsonOpts, mb.client.registry) + if err != nil { + break + } + mb.newIDMap[i] = id + case *ClientUpdateOneModel: + nsIdx, ns = getNsIndex(model.Namespace) + mb.cursorHandlers[i] = mb.appendUpdateResult + doc, err = (&clientUpdateDoc{ + namespace: nsIdx, + filter: model.Filter, + update: model.Update, + hint: model.Hint, + arrayFilters: model.ArrayFilters, + collation: model.Collation, + upsert: model.Upsert, + multi: false, + checkDollarKey: true, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientUpdateManyModel: + canRetry = false + nsIdx, ns = getNsIndex(model.Namespace) + mb.cursorHandlers[i] = mb.appendUpdateResult + doc, err = (&clientUpdateDoc{ + namespace: nsIdx, + 
filter: model.Filter, + update: model.Update, + hint: model.Hint, + arrayFilters: model.ArrayFilters, + collation: model.Collation, + upsert: model.Upsert, + multi: true, + checkDollarKey: true, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientReplaceOneModel: + nsIdx, ns = getNsIndex(model.Namespace) + mb.cursorHandlers[i] = mb.appendUpdateResult + doc, err = (&clientUpdateDoc{ + namespace: nsIdx, + filter: model.Filter, + update: model.Replacement, + hint: model.Hint, + arrayFilters: nil, + collation: model.Collation, + upsert: model.Upsert, + multi: false, + checkDollarKey: false, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientDeleteOneModel: + nsIdx, ns = getNsIndex(model.Namespace) + mb.cursorHandlers[i] = mb.appendDeleteResult + doc, err = (&clientDeleteDoc{ + namespace: nsIdx, + filter: model.Filter, + collation: model.Collation, + hint: model.Hint, + multi: false, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientDeleteManyModel: + canRetry = false + nsIdx, ns = getNsIndex(model.Namespace) + mb.cursorHandlers[i] = mb.appendDeleteResult + doc, err = (&clientDeleteDoc{ + namespace: nsIdx, + filter: model.Filter, + collation: model.Collation, + hint: model.Hint, + multi: true, + }).marshal(mb.client.bsonOpts, mb.client.registry) + } + if err != nil { + return 0, nil, err + } + length := len(doc) + len(ns) + if length > maxDocSize { + return 0, nil, driver.ErrDocumentTooLarge(i) + } + size += length + if size >= totalSize { + break + } + + dst = fn.appendDocument(dst, strconv.Itoa(n), doc) + nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), ns) + n++ + } + + dst = fn.appendEnd(dst, opsIdx, int32(len(dst[opsIdx:]))) + nsDst = fn.appendEnd(nsDst, nsIdx, int32(len(nsDst[nsIdx:]))) + dst = append(dst, nsDst...) 
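+	// At this point dst holds the batched operations followed by their namespaces.
+	// For AppendBatchSequence each of "ops" and "nsInfo" is an OP_MSG kind-1
+	// (document sequence) section: the section kind byte, an int32 size that
+	// covers itself, the NUL-terminated identifier, and the concatenated BSON
+	// documents. For AppendBatchArray the same loop instead emits ordinary BSON
+	// arrays keyed "0", "1", ... so the batch can be embedded in the command
+	// document (the auto-encryption path).
+
+	// The per-index write errors recorded later by the cursor handlers need a
+	// non-nil map; initialize it lazily so repeated batches do not clear
+	// entries from earlier batches.
+	if mb.writeErrors == nil {
+		mb.writeErrors = make(map[int]WriteError)
+	}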
+ + mb.retryMode = driver.RetryNone + if mb.client.retryWrites && canRetry { + mb.retryMode = driver.RetryOncePerCommand + } + return n, dst, nil +} + +func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { fmt.Println("ProcessResponse", info.Error) var writeCmdErr driver.WriteCommandError if errors.As(info.Error, &writeCmdErr) && writeCmdErr.WriteConcernError != nil { wce := convertDriverWriteConcernError(writeCmdErr.WriteConcernError) if wce != nil { - bw.writeConcernErrors = append(bw.writeConcernErrors, *wce) + mb.writeConcernErrors = append(mb.writeConcernErrors, *wce) } } // closeImplicitSession(sess) - if len(info.ServerResponse) == 0 { + if len(resp) == 0 { return nil } var res struct { Ok bool + Cursor bsoncore.Document NDeleted int32 NInserted int32 NMatched int32 @@ -133,235 +377,130 @@ func (bw *clientBulkWrite) ProcessResponse(ctx context.Context, info driver.Resp Code int32 Errmsg string } - err := bson.Unmarshal(info.ServerResponse, &res) + err := bson.UnmarshalWithRegistry(mb.client.registry, resp, &res) if err != nil { return err } - bw.result.DeletedCount += int64(res.NDeleted) - bw.result.InsertedCount += int64(res.NInserted) - bw.result.MatchedCount += int64(res.NMatched) - bw.result.ModifiedCount += int64(res.NModified) - bw.result.UpsertedCount += int64(res.NUpserted) + mb.result.DeletedCount += int64(res.NDeleted) + mb.result.InsertedCount += int64(res.NInserted) + mb.result.MatchedCount += int64(res.NMatched) + mb.result.ModifiedCount += int64(res.NModified) + mb.result.UpsertedCount += int64(res.NUpserted) var cursorRes driver.CursorResponse - cursorRes, err = driver.NewCursorResponse(info) + cursorRes, err = driver.NewCursorResponse(res.Cursor, info) if err != nil { return err } var bCursor *driver.BatchCursor - bCursor, err = driver.NewBatchCursor(cursorRes, bw.session, bw.client.clock, + bCursor, err = driver.NewBatchCursor(cursorRes, mb.session, mb.client.clock, driver.CursorOptions{ - CommandMonitor: bw.client.monitor, - Crypt: bw.client.cryptFLE, - ServerAPI: bw.client.serverAPI, - MarshalValueEncoderFn: newEncoderFn(bw.client.bsonOpts, bw.client.registry), + CommandMonitor: mb.client.monitor, + Crypt: mb.client.cryptFLE, + ServerAPI: mb.client.serverAPI, + MarshalValueEncoderFn: newEncoderFn(mb.client.bsonOpts, mb.client.registry), }, ) if err != nil { return err } var cursor *Cursor - cursor, err = newCursor(bCursor, bw.client.bsonOpts, bw.client.registry) + cursor, err = newCursorWithSession(bCursor, mb.client.bsonOpts, mb.client.registry, mb.session) if err != nil { return err } defer cursor.Close(ctx) + ok := true for cursor.Next(ctx) { var cur cursorInfo cursor.Decode(&cur) - if int(cur.Idx) >= len(bw.cursorHandlers) { + if int(cur.Idx) >= len(mb.cursorHandlers) { continue } - if err := bw.cursorHandlers[int(cur.Idx)](&cur, cursor.Current); err != nil { - fmt.Println("ProcessResponse cursorHandlers", err) - return err - } + ok = ok && mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current) } err = cursor.Err() if err != nil { return err } fmt.Println("ProcessResponse toplevelerror", res.Ok, res.NErrors, res.Code, res.Errmsg) - // if !res.Ok || res.NErrors > 0 { - // exception := bw.formException() - // exception.TopLevelError = &WriteError{ - // Code: int(res.Code), - // Message: res.Errmsg, - // Raw: bson.Raw(info.ServerResponse), - // } - // return exception - // } + if writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0 { + exception := ClientBulkWriteException{ + 
WriteConcernErrors: mb.writeConcernErrors, + WriteErrors: mb.writeErrors, + PartialResult: mb.result, + } + if !res.Ok || res.NErrors > 0 { + exception.TopLevelError = &WriteError{ + Code: int(res.Code), + Message: res.Errmsg, + Raw: bson.Raw(resp), + } + } + return exception + } return nil } -func (bw *clientBulkWrite) processModels() ([]driver.Batches, *driver.RetryMode, error) { - nsMap := make(map[string]int) - var nsList []bsoncore.Document - getNsIndex := func(namespace string) int { - if v, ok := nsMap[namespace]; ok { - return v - } - nsIdx := len(nsList) - nsMap[namespace] = nsIdx - idx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendStringElement(doc, "ns", namespace) - doc, _ = bsoncore.AppendDocumentEnd(doc, idx) - nsList = append(nsList, doc) - return nsIdx +func (mb *modelBatches) appendDeleteResult(cur *cursorInfo, raw bson.Raw) bool { + if mb.result.DeleteResults == nil { + mb.result.DeleteResults = make(map[int]ClientDeleteResult) } - - bw.cursorHandlers = make([]func(*cursorInfo, bson.Raw) error, len(bw.models)) - bw.insIDMap = make(map[int]interface{}) - canRetry := true - docs := make([]bsoncore.Document, len(bw.models)) - for i, v := range bw.models { - var doc bsoncore.Document - var err error - switch model := v.(type) { - case *ClientInsertOneModel: - nsIdx := getNsIndex(model.Namespace) - bw.cursorHandlers[i] = bw.appendInsertResult - var id interface{} - id, doc, err = createClientInsertDoc(int32(nsIdx), model.Document, bw.client.bsonOpts, bw.client.registry) - if err != nil { - break - } - bw.insIDMap[i] = id - case *ClientUpdateOneModel: - nsIdx := getNsIndex(model.Namespace) - bw.cursorHandlers[i] = bw.appendUpdateResult - doc, err = createClientUpdateDoc( - int32(nsIdx), - model.Filter, - model.Update, - model.Hint, - model.ArrayFilters, - model.Collation, - model.Upsert, - false, - true, - bw.client.bsonOpts, - bw.client.registry) - case *ClientUpdateManyModel: - canRetry = false - nsIdx := getNsIndex(model.Namespace) - bw.cursorHandlers[i] = bw.appendUpdateResult - doc, err = createClientUpdateDoc( - int32(nsIdx), - model.Filter, - model.Update, - model.Hint, - model.ArrayFilters, - model.Collation, - model.Upsert, - true, - true, - bw.client.bsonOpts, - bw.client.registry) - case *ClientReplaceOneModel: - nsIdx := getNsIndex(model.Namespace) - bw.cursorHandlers[i] = bw.appendUpdateResult - doc, err = createClientUpdateDoc( - int32(nsIdx), - model.Filter, - model.Replacement, - model.Hint, - nil, - model.Collation, - model.Upsert, - false, - false, - bw.client.bsonOpts, - bw.client.registry) - case *ClientDeleteOneModel: - nsIdx := getNsIndex(model.Namespace) - bw.cursorHandlers[i] = bw.appendDeleteResult - doc, err = createClientDeleteDoc( - int32(nsIdx), - model.Filter, - model.Collation, - model.Hint, - false, - bw.client.bsonOpts, - bw.client.registry) - case *ClientDeleteManyModel: - canRetry = false - nsIdx := getNsIndex(model.Namespace) - bw.cursorHandlers[i] = bw.appendDeleteResult - doc, err = createClientDeleteDoc( - int32(nsIdx), - model.Filter, - model.Collation, - model.Hint, - true, - bw.client.bsonOpts, - bw.client.registry) - } - if err != nil { - return nil, nil, err - } - docs[i] = doc - } - retry := driver.RetryNone - if bw.client.retryWrites && canRetry { - retry = driver.RetryOncePerCommand - } - ordered := false - return []driver.Batches{ - { - Identifier: "ops", - Documents: docs, - Ordered: bw.ordered, - }, - { - Identifier: "nsInfo", - Documents: nsList, - Ordered: &ordered, - }, - }, - &retry, nil + 
mb.result.DeleteResults[int(cur.Idx)] = ClientDeleteResult{int64(cur.N)} + if err := cur.extractError(); err != nil { + err.Raw = raw + mb.writeErrors[int(cur.Idx)] = *err + return false + } + return true } -func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) { - return func(dst []byte, desc description.SelectedServer) ([]byte, error) { - dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1) +func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool { + if mb.result.InsertResults == nil { + mb.result.InsertResults = make(map[int]ClientInsertResult) + } + mb.result.InsertResults[int(cur.Idx)] = ClientInsertResult{mb.newIDMap[int(cur.Idx)]} + if err := cur.extractError(); err != nil { + err.Raw = raw + mb.writeErrors[int(cur.Idx)] = *err + return false + } + return true +} - dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly) - if bw.bypassDocumentValidation != nil && (desc.WireVersion != nil && desc.WireVersion.Includes(4)) { - dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *bw.bypassDocumentValidation) - } - if bw.comment != nil { - comment, err := marshalValue(bw.comment, bw.client.bsonOpts, bw.client.registry) - if err != nil { - return nil, err - } - dst = bsoncore.AppendValueElement(dst, "comment", comment) - } - if bw.ordered != nil { - dst = bsoncore.AppendBooleanElement(dst, "ordered", *bw.ordered) - } - if bw.let != nil { - let, err := marshal(bw.let, bw.client.bsonOpts, bw.client.registry) - if err != nil { - return nil, err - } - dst = bsoncore.AppendDocumentElement(dst, "let", let) - } - return dst, nil +func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { + if mb.result.UpdateResults == nil { + mb.result.UpdateResults = make(map[int]ClientUpdateResult) + } + result := ClientUpdateResult{ + MatchedCount: int64(cur.N), + } + if cur.NModified != nil { + result.ModifiedCount = int64(*cur.NModified) } + if cur.Upserted != nil { + result.UpsertedID = (*cur.Upserted).ID + } + mb.result.UpdateResults[int(cur.Idx)] = result + if err := cur.extractError(); err != nil { + err.Raw = raw + mb.writeErrors[int(cur.Idx)] = *err + return false + } + return true +} + +type clientInsertDoc struct { + namespace int + document interface{} } -func createClientInsertDoc( - namespace int32, - document interface{}, - bsonOpts *options.BSONOptions, - registry *bsoncodec.Registry, -) (interface{}, bsoncore.Document, error) { +func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (interface{}, bsoncore.Document, error) { uidx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "insert", namespace) - f, err := marshal(document, bsonOpts, registry) + doc = bsoncore.AppendInt32Element(doc, "insert", int32(d.namespace)) + f, err := marshal(d.document, bsonOpts, registry) if err != nil { return nil, nil, err } @@ -375,61 +514,61 @@ func createClientInsertDoc( return id, doc, err } -func createClientUpdateDoc( - namespace int32, - filter interface{}, - update interface{}, - hint interface{}, - arrayFilters *options.ArrayFilters, - collation *options.Collation, - upsert *bool, - multi bool, - checkDollarKey bool, - bsonOpts *options.BSONOptions, - registry *bsoncodec.Registry, -) (bsoncore.Document, error) { +type clientUpdateDoc struct { + namespace int + filter interface{} + update interface{} + hint interface{} + arrayFilters *options.ArrayFilters + collation *options.Collation + upsert *bool + multi bool 
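+	// checkDollarKey mirrors the argument passed to marshalUpdateValue: the update
+	// models set it to true so update documents are validated to begin with a $
+	// operator, while ClientReplaceOneModel sets it to false for plain replacement
+	// documents.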
+ checkDollarKey bool +} + +func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (bsoncore.Document, error) { uidx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "update", namespace) + doc = bsoncore.AppendInt32Element(doc, "update", int32(d.namespace)) - f, err := marshal(filter, bsonOpts, registry) + f, err := marshal(d.filter, bsonOpts, registry) if err != nil { return nil, err } doc = bsoncore.AppendDocumentElement(doc, "filter", f) - u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey) + u, err := marshalUpdateValue(d.update, bsonOpts, registry, d.checkDollarKey) if err != nil { return nil, err } doc = bsoncore.AppendValueElement(doc, "updateMods", u) - doc = bsoncore.AppendBooleanElement(doc, "multi", multi) + doc = bsoncore.AppendBooleanElement(doc, "multi", d.multi) - if arrayFilters != nil { + if d.arrayFilters != nil { reg := registry - if arrayFilters.Registry != nil { - reg = arrayFilters.Registry + if d.arrayFilters.Registry != nil { + reg = d.arrayFilters.Registry } - arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg) + arr, err := marshalValue(d.arrayFilters.Filters, bsonOpts, reg) if err != nil { return nil, err } doc = bsoncore.AppendArrayElement(doc, "arrayFilters", arr.Data) } - if collation != nil { - doc = bsoncore.AppendDocumentElement(doc, "collation", bsoncore.Document(collation.ToDocument())) + if d.collation != nil { + doc = bsoncore.AppendDocumentElement(doc, "collation", bsoncore.Document(d.collation.ToDocument())) } - if upsert != nil { - doc = bsoncore.AppendBooleanElement(doc, "upsert", *upsert) + if d.upsert != nil { + doc = bsoncore.AppendBooleanElement(doc, "upsert", *d.upsert) } - if hint != nil { - if isUnorderedMap(hint) { + if d.hint != nil { + if isUnorderedMap(d.hint) { return nil, ErrMapForOrderedArgument{"hint"} } - hintVal, err := marshalValue(hint, bsonOpts, registry) + hintVal, err := marshalValue(d.hint, bsonOpts, registry) if err != nil { return nil, err } @@ -439,34 +578,34 @@ func createClientUpdateDoc( return bsoncore.AppendDocumentEnd(doc, uidx) } -func createClientDeleteDoc( - namespace int32, - filter interface{}, - collation *options.Collation, - hint interface{}, - multi bool, - bsonOpts *options.BSONOptions, - registry *bsoncodec.Registry, -) (bsoncore.Document, error) { +type clientDeleteDoc struct { + namespace int + filter interface{} + collation *options.Collation + hint interface{} + multi bool +} + +func (d *clientDeleteDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (bsoncore.Document, error) { didx, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "delete", namespace) + doc = bsoncore.AppendInt32Element(doc, "delete", int32(d.namespace)) - f, err := marshal(filter, bsonOpts, registry) + f, err := marshal(d.filter, bsonOpts, registry) if err != nil { return nil, err } doc = bsoncore.AppendDocumentElement(doc, "filter", f) - doc = bsoncore.AppendBooleanElement(doc, "multi", multi) + doc = bsoncore.AppendBooleanElement(doc, "multi", d.multi) - if collation != nil { - doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument()) + if d.collation != nil { + doc = bsoncore.AppendDocumentElement(doc, "collation", d.collation.ToDocument()) } - if hint != nil { - if isUnorderedMap(hint) { + if d.hint != nil { + if isUnorderedMap(d.hint) { return nil, ErrMapForOrderedArgument{"hint"} } - hintVal, err := marshalValue(hint, bsonOpts, registry) + hintVal, 
err := marshalValue(d.hint, bsonOpts, registry) if err != nil { return nil, err } @@ -474,65 +613,3 @@ func createClientDeleteDoc( } return bsoncore.AppendDocumentEnd(doc, didx) } - -func (bw *clientBulkWrite) appendDeleteResult(cur *cursorInfo, raw bson.Raw) error { - if bw.result.DeleteResults == nil { - bw.result.DeleteResults = make(map[int]ClientDeleteResult) - } - bw.result.DeleteResults[int(cur.Idx)] = ClientDeleteResult{int64(cur.N)} - if err := cur.extractError(); err != nil { - err.Raw = raw - bw.writeErrors[int(cur.Idx)] = *err - if bw.ordered != nil && *bw.ordered { - return bw.formException() - } - } - return nil -} - -func (bw *clientBulkWrite) appendInsertResult(cur *cursorInfo, raw bson.Raw) error { - if bw.result.InsertResults == nil { - bw.result.InsertResults = make(map[int]ClientInsertResult) - } - bw.result.InsertResults[int(cur.Idx)] = ClientInsertResult{bw.insIDMap[int(cur.Idx)]} - if err := cur.extractError(); err != nil { - err.Raw = raw - bw.writeErrors[int(cur.Idx)] = *err - if bw.ordered != nil && *bw.ordered { - return bw.formException() - } - } - return nil -} - -func (bw *clientBulkWrite) appendUpdateResult(cur *cursorInfo, raw bson.Raw) error { - if bw.result.UpdateResults == nil { - bw.result.UpdateResults = make(map[int]ClientUpdateResult) - } - result := ClientUpdateResult{ - MatchedCount: int64(cur.N), - } - if cur.NModified != nil { - result.ModifiedCount = int64(*cur.NModified) - } - if cur.Upserted != nil { - result.UpsertedID = (*cur.Upserted).ID - } - bw.result.UpdateResults[int(cur.Idx)] = result - if err := cur.extractError(); err != nil { - err.Raw = raw - bw.writeErrors[int(cur.Idx)] = *err - if bw.ordered != nil && *bw.ordered { - return bw.formException() - } - } - return nil -} - -func (bw *clientBulkWrite) formException() ClientBulkWriteException { - return ClientBulkWriteException{ - WriteConcernErrors: bw.writeConcernErrors, - WriteErrors: bw.writeErrors, - PartialResult: &bw.result, - } -} diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index db9e24fe48..fcc3895589 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -71,25 +71,29 @@ type CursorResponse struct { postBatchResumeToken bsoncore.Document } -// NewCursorResponse constructs a cursor response from the given response and -// server. If the provided database response does not contain a cursor, it -// returns ErrNoCursor. -// -// NewCursorResponse can be used within the ProcessResponse method for an operation. -func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { - response := info.ServerResponse +// ExtractCursorDocument retrieves cursor document from a database response. If the +// provided response does not contain a cursor, it returns ErrNoCursor. 
+func ExtractCursorDocument(response bsoncore.Document) (bsoncore.Document, error) { cur, err := response.LookupErr("cursor") if errors.Is(err, bsoncore.ErrElementNotFound) { - return CursorResponse{}, ErrNoCursor + return nil, ErrNoCursor } if err != nil { - return CursorResponse{}, fmt.Errorf("error getting cursor from database response: %w", err) + return nil, fmt.Errorf("error getting cursor from database response: %w", err) } curDoc, ok := cur.DocumentOK() if !ok { - return CursorResponse{}, fmt.Errorf("cursor should be an embedded document but is BSON type %s", cur.Type) + return nil, fmt.Errorf("cursor should be an embedded document but is BSON type %s", cur.Type) } - elems, err := curDoc.Elements() + return curDoc, nil +} + +// NewCursorResponse constructs a cursor response from the given cursor document +// extracted from a database response. +// +// NewCursorResponse can be used within the ProcessResponse method for an operation. +func NewCursorResponse(response bsoncore.Document, info ResponseInfo) (CursorResponse, error) { + elems, err := response.Elements() if err != nil { return CursorResponse{}, fmt.Errorf("error getting elements from cursor: %w", err) } @@ -115,15 +119,17 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { curresp.Database = database curresp.Collection = collection case "id": - curresp.ID, ok = elem.Value().Int64OK() + id, ok := elem.Value().Int64OK() if !ok { return CursorResponse{}, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type) } + curresp.ID = id case "postBatchResumeToken": - curresp.postBatchResumeToken, ok = elem.Value().DocumentOK() + token, ok := elem.Value().DocumentOK() if !ok { return CursorResponse{}, fmt.Errorf("post batch resume token should be a document but it is a BSON %s", elem.Value().Type) } + curresp.postBatchResumeToken = token } } @@ -393,8 +399,8 @@ func (bc *BatchCursor) getMore(ctx context.Context) { }, Database: bc.database, Deployment: bc.getOperationDeployment(), - ProcessResponseFn: func(_ context.Context, info ResponseInfo) error { - response := info.ServerResponse + ProcessResponseFn: func(_ context.Context, response bsoncore.Document, info ResponseInfo) error { + // response := info.ServerResponse id, ok := response.Lookup("cursor", "id").Int64OK() if !ok { return fmt.Errorf("cursor.id should be an int64 but is a BSON %s", response.Lookup("cursor", "id").Type) diff --git a/x/mongo/driver/batches.go b/x/mongo/driver/batches.go index c6b17ceb08..d8c2d23cf2 100644 --- a/x/mongo/driver/batches.go +++ b/x/mongo/driver/batches.go @@ -7,66 +7,98 @@ package driver import ( - "errors" + "fmt" + "io" + "strconv" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) // ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a // server is passed to an insert command. -var ErrDocumentTooLarge = errors.New("an inserted document is too large") +type ErrDocumentTooLarge int + +func (e ErrDocumentTooLarge) Error() string { + return fmt.Sprintf("document %d is too large", int(e)) +} // Batches contains the necessary information to batch split an operation. This is only used for write // operations. type Batches struct { Identifier string Documents []bsoncore.Document - Current []bsoncore.Document Ordered *bool -} -// ClearBatch clears the Current batch. This must be called before AdvanceBatch will advance to the -// next batch. 
-func (b *Batches) ClearBatch() { b.Current = b.Current[:0] } - -// AdvanceBatch splits the next batch using maxCount and targetBatchSize. This method will do nothing if -// the current batch has not been cleared. We do this so that when this is called during execute we -// can call it without first needing to check if we already have a batch, which makes the code -// simpler and makes retrying easier. -// The maxDocSize parameter is used to check that any one document is not too large. If the first document is bigger -// than targetBatchSize but smaller than maxDocSize, a batch of size 1 containing that document will be created. -func (b *Batches) AdvanceBatch(maxCount, targetBatchSize, maxDocSize int) error { - if len(b.Current) > 0 { - return nil - } + offset int +} - if maxCount <= 0 { - maxCount = 1 +func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + if b.End() { + return 0, dst, io.EOF } - - splitAfter := 0 - size := 0 - for i, doc := range b.Documents { - if i == maxCount { + l := len(dst) + var idx int32 + dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence) + idx, dst = bsoncore.ReserveLength(dst) + dst = append(dst, b.Identifier...) + dst = append(dst, 0x00) + size := len(dst) - l + var n int + for i := b.offset; i < len(b.Documents); i++ { + if n == maxCount { break } + doc := b.Documents[i] if len(doc) > maxDocSize { - return ErrDocumentTooLarge + return 0, dst[:l], ErrDocumentTooLarge(i) } - if size+len(doc) > targetBatchSize { + size += len(doc) + if size >= totalSize { break } + dst = append(dst, doc...) + n++ + } + dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) + return n, dst, nil +} +func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + if b.End() { + return 0, dst, io.EOF + } + l := len(dst) + aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier) + size := len(dst) - l + var n int + for i := b.offset; i < len(b.Documents); i++ { + if n == maxCount { + break + } + doc := b.Documents[i] + if len(doc) > maxDocSize { + return 0, dst[:l], ErrDocumentTooLarge(i) + } size += len(doc) - splitAfter++ + if size >= totalSize { + break + } + dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(n), doc) + n++ } + dst, _ = bsoncore.AppendArrayEnd(dst, aidx) + return n, dst, nil +} - // if there are no documents, take the first one. - // this can happen if there is a document that is smaller than maxDocSize but greater than targetBatchSize. 
- if splitAfter == 0 { - splitAfter = 1 - } +func (b *Batches) IsOrdered() *bool { + return b.Ordered +} + +func (b *Batches) AdvanceBatches(n int) { + b.offset += n +} - b.Current, b.Documents = b.Documents[:splitAfter], b.Documents[splitAfter:] - return nil +func (b *Batches) End() bool { + return len(b.Documents) <= b.offset } diff --git a/x/mongo/driver/batches_test.go b/x/mongo/driver/batches_test.go index a72813833a..65b4e97c60 100644 --- a/x/mongo/driver/batches_test.go +++ b/x/mongo/driver/batches_test.go @@ -7,109 +7,118 @@ package driver import ( + "bytes" "testing" - "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) func TestBatches(t *testing.T) { - t.Run("ClearBatch", func(t *testing.T) { - batches := &Batches{Identifier: "documents", Current: make([]bsoncore.Document, 2, 10)} - if len(batches.Current) != 2 { - t.Fatalf("Length of current batch should be 2, but is %d", len(batches.Current)) - } - batches.ClearBatch() - if len(batches.Current) != 0 { - t.Fatalf("Length of current batch should be 0, but is %d", len(batches.Current)) - } + t.Run("AppendBatchArray too large", func(t *testing.T) { + batches := &Batches{Identifier: "documents", Documents: []bsoncore.Document{bytes.Repeat([]byte("a"), 100)}} + n, _, err := batches.AppendBatchArray(nil, 2, 50, 500) + assert.Equal(t, 0, n) + assert.ErrorIs(t, err, ErrDocumentTooLarge(0)) }) - t.Run("AdvanceBatch", func(t *testing.T) { - documents := make([]bsoncore.Document, 0) - for i := 0; i < 5; i++ { - doc := make(bsoncore.Document, 100) - documents = append(documents, doc) - } - - testCases := []struct { - name string - batches *Batches - maxCount int - targetBatchSize int - maxDocSize int - err error - want *Batches - }{ - { - "current batch non-zero", - &Batches{Current: make([]bsoncore.Document, 2, 10)}, - 0, 0, 0, nil, - &Batches{Current: make([]bsoncore.Document, 2, 10)}, - }, - { - // all of the documents in the batch fit in targetBatchSize so the batch is created successfully - "documents fit in targetBatchSize", - &Batches{Documents: documents}, - 10, 600, 1000, nil, - &Batches{Documents: documents[:0], Current: documents[0:]}, - }, - { - // the first doc is bigger than targetBatchSize but smaller than maxDocSize so it is taken alone - "first document larger than targetBatchSize, smaller than maxDocSize", - &Batches{Documents: documents}, - 10, 5, 100, nil, - &Batches{Documents: documents[1:], Current: documents[:1]}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := tc.batches.AdvanceBatch(tc.maxCount, tc.targetBatchSize, tc.maxDocSize) - if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) { - t.Errorf("Errors do not match. got %v; want %v", err, tc.err) - } - if !cmp.Equal(tc.batches, tc.want) { - t.Errorf("Batches is not in correct state after AdvanceBatch. got %v; want %v", tc.batches, tc.want) - } - }) - } - - t.Run("middle document larger than targetBatchSize, smaller than maxDocSize", func(t *testing.T) { - // a batch is made but one document is too big, so everything before it is taken. 
- // on the second call to AdvanceBatch, only the large document is taken + /* + t.Run("ClearBatch", func(t *testing.T) { + batches := &Batches{Identifier: "documents", Current: make([]bsoncore.Document, 2, 10)} + if len(batches.Current) != 2 { + t.Fatalf("Length of current batch should be 2, but is %d", len(batches.Current)) + } + batches.ClearBatch() + if len(batches.Current) != 0 { + t.Fatalf("Length of current batch should be 0, but is %d", len(batches.Current)) + } + }) - middleLargeDoc := make([]bsoncore.Document, 0) + t.Run("AdvanceBatch", func(t *testing.T) { + documents := make([]bsoncore.Document, 0) for i := 0; i < 5; i++ { doc := make(bsoncore.Document, 100) - middleLargeDoc = append(middleLargeDoc, doc) + documents = append(documents, doc) } - largeDoc := make(bsoncore.Document, 900) - middleLargeDoc[2] = largeDoc - batches := &Batches{Documents: middleLargeDoc} - maxCount := 10 - targetSize := 600 - maxDocSize := 1000 - // first batch should take first 2 docs (size 100 each) - err := batches.AdvanceBatch(maxCount, targetSize, maxDocSize) - assert.Nil(t, err, "AdvanceBatch error: %v", err) - want := &Batches{Current: middleLargeDoc[:2], Documents: middleLargeDoc[2:]} - assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + testCases := []struct { + name string + batches *Batches + maxCount int + targetBatchSize int + maxDocSize int + err error + want *Batches + }{ + { + "current batch non-zero", + &Batches{Current: make([]bsoncore.Document, 2, 10)}, + 0, 0, 0, nil, + &Batches{Current: make([]bsoncore.Document, 2, 10)}, + }, + { + // all of the documents in the batch fit in targetBatchSize so the batch is created successfully + "documents fit in targetBatchSize", + &Batches{Documents: documents}, + 10, 600, 1000, nil, + &Batches{Documents: documents[:0], Current: documents[0:]}, + }, + { + // the first doc is bigger than targetBatchSize but smaller than maxDocSize so it is taken alone + "first document larger than targetBatchSize, smaller than maxDocSize", + &Batches{Documents: documents}, + 10, 5, 100, nil, + &Batches{Documents: documents[1:], Current: documents[:1]}, + }, + } - // second batch should take single large doc (size 900) - batches.ClearBatch() - err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize) - assert.Nil(t, err, "AdvanceBatch error: %v", err) - want = &Batches{Current: middleLargeDoc[2:3], Documents: middleLargeDoc[3:]} - assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.batches.AdvanceBatch(tc.maxCount, tc.targetBatchSize, tc.maxDocSize) + if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) { + t.Errorf("Errors do not match. got %v; want %v", err, tc.err) + } + if !cmp.Equal(tc.batches, tc.want) { + t.Errorf("Batches is not in correct state after AdvanceBatch. got %v; want %v", tc.batches, tc.want) + } + }) + } - // last batch should take last 2 docs (size 100 each) - batches.ClearBatch() - err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize) - assert.Nil(t, err, "AdvanceBatch error: %v", err) - want = &Batches{Current: middleLargeDoc[3:], Documents: middleLargeDoc[:0]} - assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + t.Run("middle document larger than targetBatchSize, smaller than maxDocSize", func(t *testing.T) { + // a batch is made but one document is too big, so everything before it is taken. 
+ // on the second call to AdvanceBatch, only the large document is taken + + middleLargeDoc := make([]bsoncore.Document, 0) + for i := 0; i < 5; i++ { + doc := make(bsoncore.Document, 100) + middleLargeDoc = append(middleLargeDoc, doc) + } + largeDoc := make(bsoncore.Document, 900) + middleLargeDoc[2] = largeDoc + batches := &Batches{Documents: middleLargeDoc} + maxCount := 10 + targetSize := 600 + maxDocSize := 1000 + + // first batch should take first 2 docs (size 100 each) + err := batches.AdvanceBatch(maxCount, targetSize, maxDocSize) + assert.Nil(t, err, "AdvanceBatch error: %v", err) + want := &Batches{Current: middleLargeDoc[:2], Documents: middleLargeDoc[2:]} + assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + + // second batch should take single large doc (size 900) + batches.ClearBatch() + err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize) + assert.Nil(t, err, "AdvanceBatch error: %v", err) + want = &Batches{Current: middleLargeDoc[2:3], Documents: middleLargeDoc[3:]} + assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + + // last batch should take last 2 docs (size 100 each) + batches.ClearBatch() + err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize) + assert.Nil(t, err, "AdvanceBatch error: %v", err) + want = &Batches{Current: middleLargeDoc[3:], Documents: middleLargeDoc[:0]} + assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + }) }) - }) + */ } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 4f45a78ce4..b5b50dcc32 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -53,7 +53,7 @@ var ( const ( // maximum BSON object size when client side encryption is enabled - cryptMaxBsonObjectSize uint32 = 2097152 + cryptMaxBsonObjectSize int = 2097152 // minimum wire version necessary to use automatic encryption cryptMinWireVersion int32 = 8 // minimum wire version necessary to use read snapshots @@ -92,16 +92,17 @@ type opReply struct { // startedInformation keeps track of all of the information necessary for monitoring started events. type startedInformation struct { - cmd bsoncore.Document - requestID int32 - cmdName string - documentSequenceIncluded bool - connID string - driverConnectionID uint64 // TODO(GODRIVER-2824): change type to int64. - serverConnID *int64 - redacted bool - serviceID *primitive.ObjectID - serverAddress address.Address + cmd bsoncore.Document + requestID int32 + cmdName string + documentSequence []byte + processedBatches int + connID string + driverConnectionID uint64 // TODO(GODRIVER-2824): change type to int64. + serverConnID *int64 + redacted bool + serviceID *primitive.ObjectID + serverAddress address.Address } // finishedInformation keeps track of all of the information necessary for monitoring success and failure events. @@ -151,7 +152,6 @@ func (info finishedInformation) success() bool { // ResponseInfo contains the context required to parse a server response. type ResponseInfo struct { - ServerResponse bsoncore.Document Server Server Connection Connection ConnectionDescription description.Server @@ -159,7 +159,7 @@ type ResponseInfo struct { Error error } -func redactStartedInformationCmd(info startedInformation, batches []Batches) bson.Raw { +func redactStartedInformationCmd(info startedInformation) bson.Raw { var cmdCopy bson.Raw // Make a copy of the command. 
Redact if the command is security @@ -169,13 +169,10 @@ func redactStartedInformationCmd(info startedInformation, batches []Batches) bso cmdCopy = make([]byte, len(info.cmd)) copy(cmdCopy, info.cmd) - if info.documentSequenceIncluded { + if len(info.documentSequence) > 0 { // remove 0 byte at end cmdCopy = cmdCopy[:len(info.cmd)-1] - for i := 0; i < len(batches); i++ { - cmdCopy = addBatchArray(cmdCopy, batches[i].Identifier, batches[i].Current) - } - + cmdCopy = append(cmdCopy, info.documentSequence...) // add back 0 byte and update length cmdCopy, _ = bsoncore.AppendDocumentEnd(cmdCopy, 0) } @@ -221,7 +218,7 @@ type Operation struct { // ProcessResponseFn is called after a response to the command is returned. The server is // provided for types like Cursor that are required to run subsequent commands using the same // server. - ProcessResponseFn func(context.Context, ResponseInfo) error + ProcessResponseFn func(context.Context, bsoncore.Document, ResponseInfo) error // Selector is the server selector that's used during both initial server selection and // subsequent selection for retries. Depending on the Deployment implementation, the @@ -279,7 +276,13 @@ type Operation struct { // has more documents than can fit in a single command. This should only be specified for // commands that are batch compatible. For more information, please refer to the definition of // Batches. - Batches []Batches + Batches interface { + AppendBatchSequence(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error) + AppendBatchArray(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error) + IsOrdered() *bool + AdvanceBatches(n int) + End() bool + } // Legacy sets the legacy type for this operation. There are only 3 types that require legacy // support: find, getMore, and killCursors. For more information about LegacyOperationKind, @@ -562,7 +565,6 @@ func (op Operation) Execute(ctx context.Context) error { var operationErr WriteCommandError var prevErr error var prevIndefiniteErr error - batching := len(op.Batches) > 0 retrySupported := false first := true currIndex := 0 @@ -714,26 +716,6 @@ func (op Operation) Execute(ctx context.Context) error { desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()} - if batching { - targetBatchSize := desc.MaxDocumentSize - maxDocSize := desc.MaxDocumentSize - if op.shouldEncrypt() { - // For client-side encryption, we want the batch to be split at 2 MiB instead of 16MiB. - // If there's only one document in the batch, it can be up to 16MiB, so we set target batch size to - // 2MiB but max document size to 16MiB. This will allow the AdvanceBatch call to create a batch - // with a single large document. - targetBatchSize = cryptMaxBsonObjectSize - } - - for i := 0; i < len(op.Batches); i++ { - err = op.Batches[i].AdvanceBatch(int(desc.MaxBatchCount), int(targetBatchSize), int(maxDocSize)) - if err != nil { - // TODO(GODRIVER-982): Should we also be returning operationErr? 
- return err - } - } - } - var startedInfo startedInformation *wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) @@ -875,20 +857,19 @@ func (op Operation) Execute(ctx context.Context) error { // If the operation isn't being retried, process the response if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, Error: tt, } - _ = op.ProcessResponseFn(ctx, info) + _ = op.ProcessResponseFn(ctx, res, info) // if perr != nil { // return perr // } } - if batching && len(tt.WriteErrors) > 0 && currIndex > 0 { + if op.Batches != nil && len(tt.WriteErrors) > 0 && currIndex > 0 { for i := range tt.WriteErrors { tt.WriteErrors[i].Index += int64(currIndex) } @@ -896,16 +877,11 @@ func (op Operation) Execute(ctx context.Context) error { // If batching is enabled and either ordered is the default (which is true) or // explicitly set to true and we have write errors, return the errors. - var ordered bool - for i := 0; i < len(op.Batches); i++ { - if op.Batches[i].Ordered == nil || *op.Batches[i].Ordered { - ordered = true - break + if op.Batches != nil && len(tt.WriteErrors) > 0 { + if isOrdered := op.Batches.IsOrdered(); isOrdered == nil || *isOrdered { + return tt } } - if batching && ordered && len(tt.WriteErrors) > 0 { - return tt - } if op.Client != nil && op.Client.Committing && tt.WriteConcernError != nil { // When running commitTransaction we return WriteConcernErrors as an Error. err := Error{ @@ -1005,14 +981,13 @@ func (op Operation) Execute(ctx context.Context) error { // If the operation isn't being retried, process the response if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, Error: tt, } - _ = op.ProcessResponseFn(ctx, info) + _ = op.ProcessResponseFn(ctx, res, info) // if perr != nil { // return perr // } @@ -1029,14 +1004,13 @@ func (op Operation) Execute(ctx context.Context) error { } if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, Error: tt, } - perr := op.ProcessResponseFn(ctx, info) + perr := op.ProcessResponseFn(ctx, res, info) if perr != nil { fmt.Println("op", perr) return perr @@ -1045,14 +1019,13 @@ func (op Operation) Execute(ctx context.Context) error { default: if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, Error: tt, } - _ = op.ProcessResponseFn(ctx, info) + _ = op.ProcessResponseFn(ctx, res, info) } return err } @@ -1060,13 +1033,7 @@ func (op Operation) Execute(ctx context.Context) error { // If we're batching and there are batches remaining, advance to the next batch. This isn't // a retry, so increment the transaction number, reset the retries number, and don't set // server or connection to nil to continue using the same connection. - var advancing []int - for i := 0; i < len(op.Batches); i++ { - if len(op.Batches[i].Documents) > 0 { - advancing = append(advancing, i) - } - } - if batching && len(advancing) > 0 { + if op.Batches != nil { // If retries are supported for the current operation on the current server description, // the session isn't nil, and client retries are enabled, increment the txn number. 
// Calling IncrementTxnNumber() for server descriptions or topologies that do not @@ -1081,11 +1048,11 @@ func (op Operation) Execute(ctx context.Context) error { retries = 1 } } - for _, i := range advancing { - currIndex += len(op.Batches[i].Current) - op.Batches[i].ClearBatch() + currIndex += startedInfo.processedBatches + op.Batches.AdvanceBatches(startedInfo.processedBatches) + if !op.Batches.End() { + continue } - continue } break } @@ -1241,28 +1208,13 @@ func (Operation) decompressWireMessage(wm []byte) (wiremessage.OpCode, []byte, e return opcode, uncompressed, nil } -func addBatchArray(dst []byte, identifier string, docs []bsoncore.Document) []byte { - if len(docs) == 0 { - return dst - } - aidx, dst := bsoncore.AppendArrayElementStart(dst, identifier) - for i, doc := range docs { - dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc) - } - dst, _ = bsoncore.AppendArrayEnd(dst, aidx) - return dst -} - func (op Operation) createLegacyHandshakeWireMessage( maxTimeMS uint64, dst []byte, desc description.SelectedServer, -) ([]byte, startedInformation, error) { - var info startedInformation + cmdFn func([]byte, description.SelectedServer) ([]byte, error), +) ([]byte, []byte, error) { flags := op.secondaryOK(desc) - var wmindex int32 - info.requestID = wiremessage.NextRequestID() - wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) dst = wiremessage.AppendQueryFlags(dst, flags) dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} @@ -1277,35 +1229,31 @@ func (op Operation) createLegacyHandshakeWireMessage( wrapper := int32(-1) rp, err := op.createReadPref(desc, true) if err != nil { - return dst, info, err + return dst, nil, err } if len(rp) > 0 { wrapper, dst = bsoncore.AppendDocumentStart(dst) dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query") } idx, dst := bsoncore.AppendDocumentStart(dst) - dst, err = op.CommandFn(dst, desc) + dst, err = cmdFn(dst, desc) if err != nil { - return dst, info, err - } - - for i := 0; i < len(op.Batches); i++ { - dst = addBatchArray(dst, op.Batches[i].Identifier, op.Batches[i].Current) + return dst, nil, err } dst, err = op.addReadConcern(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addWriteConcern(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addSession(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst = op.addClusterTime(dst, desc) @@ -1317,40 +1265,32 @@ func (op Operation) createLegacyHandshakeWireMessage( } dst, _ = bsoncore.AppendDocumentEnd(dst, idx) - // Command monitoring only reports the document inside $query - info.cmd = dst[idx:] if len(rp) > 0 { var err error dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) if err != nil { - return dst, info, err + return dst, nil, err } } - return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil + return dst, dst[idx:], nil } func (op Operation) createMsgWireMessage( - ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, conn Connection, - requestID int32, -) ([]byte, startedInformation, error) { - var info startedInformation + cmdFn func([]byte, description.SelectedServer) ([]byte, error), +) ([]byte, []byte, error) { var flags wiremessage.MsgFlag - var wmindex int32 // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either // aren't batching or we are encoding the last 
batch. var batching bool - for i := 0; i < len(op.Batches); i++ { - if len(op.Batches[i].Documents) > 0 { - batching = true - break - } + if op.Batches != nil && !op.Batches.End() { + batching = true } if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && !batching { flags = wiremessage.MoreToCome @@ -1361,29 +1301,28 @@ func (op Operation) createMsgWireMessage( flags |= wiremessage.ExhaustAllowed } - info.requestID = requestID - wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg) dst = wiremessage.AppendMsgFlags(dst, flags) // Body dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument) idx, dst := bsoncore.AppendDocumentStart(dst) - dst, err := op.addCommandFields(ctx, dst, desc) + var err error + dst, err = cmdFn(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addReadConcern(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addWriteConcern(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addSession(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst = op.addClusterTime(dst, desc) @@ -1397,39 +1336,15 @@ func (op Operation) createMsgWireMessage( dst = bsoncore.AppendStringElement(dst, "$db", op.Database) rp, err := op.createReadPref(desc, false) if err != nil { - return dst, info, err + return dst, nil, err } if len(rp) > 0 { dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) } dst, _ = bsoncore.AppendDocumentEnd(dst, idx) - // The command document for monitoring shouldn't include the type 1 payload as a document sequence - info.cmd = dst[idx:] - - // add batch as a document sequence if auto encryption is not enabled - // if auto encryption is enabled, the batch will already be an array in the command document - if !op.shouldEncrypt() { - for i := 0; i < len(op.Batches); i++ { - if len(op.Batches[i].Current) == 0 { - continue - } - info.documentSequenceIncluded = true - dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence) - idx, dst = bsoncore.ReserveLength(dst) - - dst = append(dst, op.Batches[i].Identifier...) - dst = append(dst, 0x00) - - for _, doc := range op.Batches[i].Current { - dst = append(dst, doc...) 
- } - dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) - } - } - - return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil + return dst, dst[idx:], nil } // isLegacyHandshake returns True if the operation is the first message of @@ -1448,45 +1363,124 @@ func (op Operation) createWireMessage( conn Connection, requestID int32, ) ([]byte, startedInformation, error) { - if isLegacyHandshake(op, desc) { - return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) - } + var info startedInformation + var wmindex int32 + var err error - return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) + isLegacy := isLegacyHandshake(op, desc) + shouldEncrypt := op.shouldEncrypt() + if !isLegacy && !shouldEncrypt { + wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) + dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, op.CommandFn) + if err == nil && op.Batches != nil { + var processedBatches int + dsOffset := len(dst) + processedBatches, dst, err = op.Batches.AppendBatchSequence(dst, int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxMessageSize)) + if err == nil { + info.processedBatches = processedBatches + info.documentSequence = make([]byte, 0) + for b := dst[dsOffset:]; len(b) > 0; /* nothing */ { + var seq []byte + var ok bool + seq, b, ok = wiremessage.DocumentSequenceToArray(b) + if !ok { + break + } + info.documentSequence = append(info.documentSequence, seq...) + } + } + } + } else if shouldEncrypt { + if desc.WireVersion.Max < cryptMinWireVersion { + return dst, info, errors.New("auto-encryption requires a MongoDB version of 4.2") + } + cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) { + // create temporary command document + var cmdDst []byte + info.processedBatches, cmdDst, err = op.addEncryptCommandFields(nil, desc) + if err != nil { + return nil, err + } + // encrypt the command + encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst) + if err != nil { + return nil, err + } + // append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator) + dst = append(dst, encrypted[4:len(encrypted)-1]...) + return dst, nil + } + wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) + dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, cmdFn) + } else { // isLegacy + cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) { + info.processedBatches, dst, err = op.addLegacyCommandFields(dst, desc) + return dst, err + } + requestID := wiremessage.NextRequestID() + wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpQuery) + dst, info.cmd, err = op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc, cmdFn) + } + if err != nil { + return nil, info, err + } + info.requestID = requestID + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil } -// addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document -// has already been added and does not add the final 0 byte. 
-func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) {
-	if !op.shouldEncrypt() {
-		return op.CommandFn(dst, desc)
+func (op Operation) addEncryptCommandFields(dst []byte, desc description.SelectedServer) (int, []byte, error) {
+	var idx int32
+	idx, dst = bsoncore.AppendDocumentStart(dst)
+	var err error
+	dst, err = op.CommandFn(dst, desc)
+	if err != nil {
+		return 0, nil, err
 	}
-
-	if desc.WireVersion.Max < cryptMinWireVersion {
-		return dst, errors.New("auto-encryption requires a MongoDB version of 4.2")
+	if op.Batches == nil {
+		return 0, dst, nil
+	}
+	maxBatchCount := int(desc.MaxBatchCount)
+	maxDocumentSize := int(desc.MaxDocumentSize)
+	var n int
+	if maxBatchCount > 1 {
+		n, dst, err = op.Batches.AppendBatchArray(dst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize)
+		if err != nil {
+			var documentTooLarge ErrDocumentTooLarge
+			if errors.As(err, &documentTooLarge) {
+				maxBatchCount = 1
+			} else {
+				return 0, nil, err
+			}
+		}
+	}
+	if maxBatchCount == 1 {
+		n, dst, err = op.Batches.AppendBatchArray(dst, maxBatchCount, maxDocumentSize, maxDocumentSize)
+		if err != nil {
+			return 0, nil, err
+		}
+	}
+	dst, err = bsoncore.AppendDocumentEnd(dst, idx)
+	if err != nil {
+		return 0, nil, err
 	}
+	return n, dst, nil
+}

-	// create temporary command document
-	cidx, cmdDst := bsoncore.AppendDocumentStart(nil)
+func (op Operation) addLegacyCommandFields(dst []byte, desc description.SelectedServer) (int, []byte, error) {
 	var err error
-	cmdDst, err = op.CommandFn(cmdDst, desc)
+	dst, err = op.CommandFn(dst, desc)
 	if err != nil {
-		return dst, err
+		return 0, nil, err
 	}
-	// use a BSON array instead of a type 1 payload because mongocryptd will convert to arrays regardless
-	for i := 0; i < len(op.Batches); i++ {
-		cmdDst = addBatchArray(cmdDst, op.Batches[i].Identifier, op.Batches[i].Current)
+	if op.Batches == nil {
+		return 0, dst, nil
 	}
-	cmdDst, _ = bsoncore.AppendDocumentEnd(cmdDst, cidx)
-
-	// encrypt the command
-	encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
+	var n int
+	n, dst, err = op.Batches.AppendBatchArray(dst, int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxDocumentSize))
 	if err != nil {
-		return dst, err
+		return 0, nil, err
 	}
-	// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
-	dst = append(dst, encrypted[4:len(encrypted)-1]...)
-	return dst, nil
+	return n, dst, nil
 }

 // addServerAPI adds the relevant fields for server API specification to the wire message in dst.
@@ -2022,7 +2016,7 @@ func (op Operation) publishStartedEvent(ctx context.Context, info startedInforma
 	if op.canLogCommandMessage() {
 		host, port, _ := net.SplitHostPort(info.serverAddress.String())

-		redactedCmd := redactStartedInformationCmd(info, op.Batches).String()
+		redactedCmd := redactStartedInformationCmd(info).String()
 		formattedCmd := logger.FormatMessage(redactedCmd, op.Logger.MaxDocumentLength)

 		op.Logger.Print(logger.LevelDebug,
@@ -2045,7 +2039,7 @@ func (op Operation) publishStartedEvent(ctx context.Context, info startedInforma
 	if op.canPublishStartedEvent() {
 		started := &event.CommandStartedEvent{
-			Command:      redactStartedInformationCmd(info, op.Batches),
+			Command:      redactStartedInformationCmd(info),
 			DatabaseName: op.Database,
 			CommandName:  info.cmdName,
 			RequestID:    int64(info.requestID),
diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go
index 3155233e42..c3104cdb90 100644
--- a/x/mongo/driver/operation/abort_transaction.go
+++ b/x/mongo/driver/operation/abort_transaction.go
@@ -41,9 +41,8 @@ func NewAbortTransaction() *AbortTransaction {
 	return &AbortTransaction{}
 }

-func (at *AbortTransaction) processResponse(context.Context, driver.ResponseInfo) error {
-	var err error
-	return err
+func (at *AbortTransaction) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
+	return nil
 }

 // Execute runs this operations and returns an error if the operation did not execute successfully.
diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go
index 17b0c6a70e..b95a8205af 100644
--- a/x/mongo/driver/operation/aggregate.go
+++ b/x/mongo/driver/operation/aggregate.go
@@ -79,10 +79,12 @@ func (a *Aggregate) ResultCursorResponse() driver.CursorResponse {
 	return a.result
 }

-func (a *Aggregate) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	var err error
-
-	a.result, err = driver.NewCursorResponse(info)
+func (a *Aggregate) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
+	curDoc, err := driver.ExtractCursorDocument(resp)
+	if err != nil {
+		return err
+	}
+	a.result, err = driver.NewCursorResponse(curDoc, info)
 	return err
 }

diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go
index d099067392..42de03a49e 100644
--- a/x/mongo/driver/operation/command.go
+++ b/x/mongo/driver/operation/command.go
@@ -82,11 +82,15 @@ func (c *Command) Execute(ctx context.Context) error {
 		CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
 			return append(dst, c.command[4:len(c.command)-1]...), nil
 		},
-		ProcessResponseFn: func(_ context.Context, info driver.ResponseInfo) error {
-			c.resultResponse = info.ServerResponse
+		ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
+			c.resultResponse = resp

 			if c.createCursor {
-				cursorRes, err := driver.NewCursorResponse(info)
+				curDoc, err := driver.ExtractCursorDocument(resp)
+				if err != nil {
+					return err
+				}
+				cursorRes, err := driver.NewCursorResponse(curDoc, info)
 				if err != nil {
 					return err
 				}
diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go
index cf20d9abec..fda31da542 100644
--- a/x/mongo/driver/operation/commit_transaction.go
+++ b/x/mongo/driver/operation/commit_transaction.go
@@ -42,9 +42,8 @@ func NewCommitTransaction() *CommitTransaction {
 	return &CommitTransaction{}
 }

-func (ct *CommitTransaction) processResponse(context.Context, driver.ResponseInfo) error {
-	var err error
-	return err
+func (ct *CommitTransaction) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
+	return nil
 }

 // Execute runs this operations and returns an error if the operation did not execute successfully.
diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go
index 8c0e30c031..7ef3f549e3 100644
--- a/x/mongo/driver/operation/count.go
+++ b/x/mongo/driver/operation/count.go
@@ -99,9 +99,9 @@ func NewCount() *Count {
 // Result returns the result of executing this operation.
 func (c *Count) Result() CountResult { return c.result }

-func (c *Count) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (c *Count) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	c.result, err = buildCountResult(info.ServerResponse)
+	c.result, err = buildCountResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go
index 6a6301d6b1..73a02feb4d 100644
--- a/x/mongo/driver/operation/create.go
+++ b/x/mongo/driver/operation/create.go
@@ -56,7 +56,7 @@ func NewCreate(collectionName string) *Create {
 	}
 }

-func (c *Create) processResponse(context.Context, driver.ResponseInfo) error {
+func (c *Create) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
 	return nil
 }

diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go
index 64bd90940e..ae73d49e5e 100644
--- a/x/mongo/driver/operation/create_indexes.go
+++ b/x/mongo/driver/operation/create_indexes.go
@@ -93,9 +93,9 @@ func NewCreateIndexes(indexes bsoncore.Document) *CreateIndexes {
 // Result returns the result of executing this operation.
 func (ci *CreateIndexes) Result() CreateIndexesResult { return ci.result }

-func (ci *CreateIndexes) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (ci *CreateIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	ci.result, err = buildCreateIndexesResult(info.ServerResponse)
+	ci.result, err = buildCreateIndexesResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go
index 90495231c7..180858b3ec 100644
--- a/x/mongo/driver/operation/create_search_indexes.go
+++ b/x/mongo/driver/operation/create_search_indexes.go
@@ -93,9 +93,9 @@ func NewCreateSearchIndexes(indexes bsoncore.Document) *CreateSearchIndexes {
 // Result returns the result of executing this operation.
 func (csi *CreateSearchIndexes) Result() CreateSearchIndexesResult { return csi.result }

-func (csi *CreateSearchIndexes) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (csi *CreateSearchIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	csi.result, err = buildCreateSearchIndexesResult(info.ServerResponse)
+	csi.result, err = buildCreateSearchIndexesResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go
index e03047a9cf..04308cc239 100644
--- a/x/mongo/driver/operation/delete.go
+++ b/x/mongo/driver/operation/delete.go
@@ -81,8 +81,8 @@ func NewDelete(deletes ...bsoncore.Document) *Delete {
 // Result returns the result of executing this operation.
 func (d *Delete) Result() DeleteResult { return d.result }

-func (d *Delete) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	dr, err := buildDeleteResult(info.ServerResponse)
+func (d *Delete) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
+	dr, err := buildDeleteResult(resp)
 	d.result.N += dr.N
 	return err
 }
@@ -96,12 +96,10 @@ func (d *Delete) Execute(ctx context.Context) error {
 	return driver.Operation{
 		CommandFn: d.command,
 		ProcessResponseFn: d.processResponse,
-		Batches: []driver.Batches{
-			{
-				Identifier: "deletes",
-				Documents: d.deletes,
-				Ordered: d.ordered,
-			},
+		Batches: &driver.Batches{
+			Identifier: "deletes",
+			Documents: d.deletes,
+			Ordered: d.ordered,
 		},
 		RetryMode: d.retry,
 		Type: driver.Write,
diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go
index 946aa8aa4c..882b4c558e 100644
--- a/x/mongo/driver/operation/distinct.go
+++ b/x/mongo/driver/operation/distinct.go
@@ -77,9 +77,9 @@ func NewDistinct(key string, query bsoncore.Document) *Distinct {
 // Result returns the result of executing this operation.
 func (d *Distinct) Result() DistinctResult { return d.result }

-func (d *Distinct) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (d *Distinct) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	d.result, err = buildDistinctResult(info.ServerResponse)
+	d.result, err = buildDistinctResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go
index 11f4b74f82..b1c18d4083 100644
--- a/x/mongo/driver/operation/drop_collection.go
+++ b/x/mongo/driver/operation/drop_collection.go
@@ -79,9 +79,9 @@ func NewDropCollection() *DropCollection {
 // Result returns the result of executing this operation.
 func (dc *DropCollection) Result() DropCollectionResult { return dc.result }

-func (dc *DropCollection) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (dc *DropCollection) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	dc.result, err = buildDropCollectionResult(info.ServerResponse)
+	dc.result, err = buildDropCollectionResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go
index 70932ed0fb..1171d81937 100644
--- a/x/mongo/driver/operation/drop_indexes.go
+++ b/x/mongo/driver/operation/drop_indexes.go
@@ -74,9 +74,9 @@ func NewDropIndexes(index any) *DropIndexes {
 // Result returns the result of executing this operation.
 func (di *DropIndexes) Result() DropIndexesResult { return di.result }

-func (di *DropIndexes) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (di *DropIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	di.result, err = buildDropIndexesResult(info.ServerResponse)
+	di.result, err = buildDropIndexesResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go
index 688a6e2280..55d675a84f 100644
--- a/x/mongo/driver/operation/drop_search_index.go
+++ b/x/mongo/driver/operation/drop_search_index.go
@@ -69,9 +69,9 @@ func NewDropSearchIndex(index string) *DropSearchIndex {
 // Result returns the result of executing this operation.
 func (dsi *DropSearchIndex) Result() DropSearchIndexResult { return dsi.result }

-func (dsi *DropSearchIndex) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (dsi *DropSearchIndex) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	dsi.result, err = buildDropSearchIndexResult(info.ServerResponse)
+	dsi.result, err = buildDropSearchIndexResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go
index 4ddf781683..eaf03d2ced 100644
--- a/x/mongo/driver/operation/end_sessions.go
+++ b/x/mongo/driver/operation/end_sessions.go
@@ -39,9 +39,8 @@ func NewEndSessions(sessionIDs bsoncore.Document) *EndSessions {
 	}
 }

-func (es *EndSessions) processResponse(context.Context, driver.ResponseInfo) error {
-	var err error
-	return err
+func (es *EndSessions) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
+	return nil
 }

 // Execute runs this operations and returns an error if the operation did not execute successfully.
diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go
index fd611ebde8..e8f7640e94 100644
--- a/x/mongo/driver/operation/find.go
+++ b/x/mongo/driver/operation/find.go
@@ -80,9 +80,12 @@ func (f *Find) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
 	return driver.NewBatchCursor(f.result, f.session, f.clock, opts)
 }

-func (f *Find) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	var err error
-	f.result, err = driver.NewCursorResponse(info)
+func (f *Find) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
+	curDoc, err := driver.ExtractCursorDocument(resp)
+	if err != nil {
+		return err
+	}
+	f.result, err = driver.NewCursorResponse(curDoc, info)
 	return err
 }

diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go
index 1ec799be7c..76f34b9255 100644
--- a/x/mongo/driver/operation/find_and_modify.go
+++ b/x/mongo/driver/operation/find_and_modify.go
@@ -114,10 +114,10 @@ func NewFindAndModify(query bsoncore.Document) *FindAndModify {
 // Result returns the result of executing this operation.
 func (fam *FindAndModify) Result() FindAndModifyResult { return fam.result }

-func (fam *FindAndModify) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (fam *FindAndModify) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	fam.result, err = buildFindAndModifyResult(info.ServerResponse)
+	fam.result, err = buildFindAndModifyResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go
index 8e38dd0473..77086f8dc6 100644
--- a/x/mongo/driver/operation/hello.go
+++ b/x/mongo/driver/operation/hello.go
@@ -586,8 +586,8 @@ func (h *Hello) createOperation() driver.Operation {
 		CommandFn: h.command,
 		Database: "admin",
 		Deployment: h.d,
-		ProcessResponseFn: func(_ context.Context, info driver.ResponseInfo) error {
-			h.res = info.ServerResponse
+		ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
+			h.res = resp
 			return nil
 		},
 		ServerAPI: h.serverAPI,
@@ -610,8 +610,8 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address,
 		CommandFn: h.handshakeCommand,
 		Deployment: deployment,
 		Database: "admin",
-		ProcessResponseFn: func(_ context.Context, info driver.ResponseInfo) error {
-			h.res = info.ServerResponse
+		ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
+			h.res = resp
 			return nil
 		},
 		ServerAPI: h.serverAPI,
diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go
index 71817a9c42..7f7c6a5453 100644
--- a/x/mongo/driver/operation/insert.go
+++ b/x/mongo/driver/operation/insert.go
@@ -80,8 +80,8 @@ func NewInsert(documents ...bsoncore.Document) *Insert {
 // Result returns the result of executing this operation.
 func (i *Insert) Result() InsertResult { return i.result }

-func (i *Insert) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	ir, err := buildInsertResult(info.ServerResponse)
+func (i *Insert) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
+	ir, err := buildInsertResult(resp)
 	i.result.N += ir.N
 	return err
 }
@@ -95,12 +95,10 @@ func (i *Insert) Execute(ctx context.Context) error {
 	return driver.Operation{
 		CommandFn: i.command,
 		ProcessResponseFn: i.processResponse,
-		Batches: []driver.Batches{
-			{
-				Identifier: "documents",
-				Documents: i.documents,
-				Ordered: i.ordered,
-			},
+		Batches: &driver.Batches{
+			Identifier: "documents",
+			Documents: i.documents,
+			Ordered: i.ordered,
 		},
 		RetryMode: i.retry,
 		Type: driver.Write,
diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go
index 927f82be48..701f7ea01e 100644
--- a/x/mongo/driver/operation/list_collections.go
+++ b/x/mongo/driver/operation/list_collections.go
@@ -55,9 +55,12 @@ func (lc *ListCollections) Result(opts driver.CursorOptions) (*driver.BatchCurso
 	return driver.NewBatchCursor(lc.result, lc.session, lc.clock, opts)
 }

-func (lc *ListCollections) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	var err error
-	lc.result, err = driver.NewCursorResponse(info)
+func (lc *ListCollections) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
+	curDoc, err := driver.ExtractCursorDocument(resp)
+	if err != nil {
+		return err
+	}
+	lc.result, err = driver.NewCursorResponse(curDoc, info)
 	return err
 }

diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/list_databases.go
similarity index 98%
rename from x/mongo/driver/operation/listDatabases.go
rename to x/mongo/driver/operation/list_databases.go
index 1f8719e9fc..37371d49cb 100644
--- a/x/mongo/driver/operation/listDatabases.go
+++ b/x/mongo/driver/operation/list_databases.go
@@ -135,10 +135,10 @@ func NewListDatabases(filter bsoncore.Document) *ListDatabases {
 // Result returns the result of executing this operation.
 func (ld *ListDatabases) Result() ListDatabasesResult { return ld.result }

-func (ld *ListDatabases) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (ld *ListDatabases) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	ld.result, err = buildListDatabasesResult(info.ServerResponse)
+	ld.result, err = buildListDatabasesResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go
index 93cd288842..a9cf200779 100644
--- a/x/mongo/driver/operation/list_indexes.go
+++ b/x/mongo/driver/operation/list_indexes.go
@@ -54,10 +54,12 @@ func (li *ListIndexes) Result(opts driver.CursorOptions) (*driver.BatchCursor, e
 	return driver.NewBatchCursor(li.result, clientSession, clock, opts)
 }

-func (li *ListIndexes) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	var err error
-
-	li.result, err = driver.NewCursorResponse(info)
+func (li *ListIndexes) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
+	curDoc, err := driver.ExtractCursorDocument(resp)
+	if err != nil {
+		return err
+	}
+	li.result, err = driver.NewCursorResponse(curDoc, info)
 	return err
 }

diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go
index 2b63e750a2..2c9c9cba28 100644
--- a/x/mongo/driver/operation/update.go
+++ b/x/mongo/driver/operation/update.go
@@ -124,8 +124,8 @@ func NewUpdate(updates ...bsoncore.Document) *Update {
 // Result returns the result of executing this operation.
 func (u *Update) Result() UpdateResult { return u.result }

-func (u *Update) processResponse(_ context.Context, info driver.ResponseInfo) error {
-	ur, err := buildUpdateResult(info.ServerResponse)
+func (u *Update) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
+	ur, err := buildUpdateResult(resp)
 	u.result.N += ur.N
 	u.result.NModified += ur.NModified
@@ -148,12 +148,10 @@ func (u *Update) Execute(ctx context.Context) error {
 	return driver.Operation{
 		CommandFn: u.command,
 		ProcessResponseFn: u.processResponse,
-		Batches: []driver.Batches{
-			{
-				Identifier: "updates",
-				Documents: u.updates,
-				Ordered: u.ordered,
-			},
+		Batches: &driver.Batches{
+			Identifier: "updates",
+			Documents: u.updates,
+			Ordered: u.ordered,
 		},
 		RetryMode: u.retry,
 		Type: driver.Write,
diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go
index 6df4874790..60fc0d4c04 100644
--- a/x/mongo/driver/operation/update_search_index.go
+++ b/x/mongo/driver/operation/update_search_index.go
@@ -71,9 +71,9 @@ func NewUpdateSearchIndex(index string, definition bsoncore.Document) *UpdateSea
 // Result returns the result of executing this operation.
 func (usi *UpdateSearchIndex) Result() UpdateSearchIndexResult { return usi.result }

-func (usi *UpdateSearchIndex) processResponse(_ context.Context, info driver.ResponseInfo) error {
+func (usi *UpdateSearchIndex) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
 	var err error
-	usi.result, err = buildUpdateSearchIndexResult(info.ServerResponse)
+	usi.result, err = buildUpdateSearchIndexResult(resp)
 	return err
 }

diff --git a/x/mongo/driver/operation_exhaust.go b/x/mongo/driver/operation_exhaust.go
index 1836ec4657..db1bd881e9 100644
--- a/x/mongo/driver/operation_exhaust.go
+++ b/x/mongo/driver/operation_exhaust.go
@@ -25,10 +25,9 @@ func (op Operation) ExecuteExhaust(ctx context.Context, conn StreamerConnection)
 	if op.ProcessResponseFn != nil {
 		// Server, ConnectionDescription, and CurrentIndex are unused in this mode.
 		info := ResponseInfo{
-			ServerResponse: res,
-			Connection:     conn,
+			Connection: conn,
 		}
-		if err = op.ProcessResponseFn(ctx, info); err != nil {
+		if err = op.ProcessResponseFn(ctx, res, info); err != nil {
 			return err
 		}
 	}
diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go
index 987ae16c08..0394272891 100644
--- a/x/mongo/driver/wiremessage/wiremessage.go
+++ b/x/mongo/driver/wiremessage/wiremessage.go
@@ -16,6 +16,7 @@ package wiremessage
 import (
 	"bytes"
 	"encoding/binary"
+	"strconv"
 	"strings"
 	"sync/atomic"

@@ -422,6 +423,38 @@ func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []by
 	return identifier, rem, rest, true
 }

+func DocumentSequenceToArray(src []byte) (data bsoncore.Array, rem []byte, ok bool) {
+	stype, rem, ok := ReadMsgSectionType(src)
+	if !ok || stype != DocumentSequence {
+		return nil, src, false
+	}
+	var identifier string
+	var ret []byte
+	identifier, rem, ret, ok = ReadMsgSectionRawDocumentSequence(rem)
+	if !ok {
+		return nil, src, false
+	}
+
+	aidx, dst := bsoncore.AppendArrayElementStart(nil, identifier)
+	i := 0
+	for {
+		var doc bsoncore.Document
+		doc, rem, ok = bsoncore.ReadDocument(rem)
+		if !ok {
+			break
+		}
+		dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc)
+		i++
+	}
+	if len(rem) > 0 {
+		return nil, src, false
+	}
+
+	dst, _ = bsoncore.AppendArrayEnd(dst, aidx)
+
+	return dst, ret, true
+}
+
 // ReadMsgChecksum reads a checksum from src.
 func ReadMsgChecksum(src []byte) (checksum uint32, rem []byte, ok bool) {
 	i32, rem, ok := readi32(src)