From 97d5b6bf503ea6c4f642589a5ab81a4bc72a2fbc Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Sun, 13 Oct 2024 15:10:20 -0400 Subject: [PATCH] WIP --- mongo/client_bulk_write.go | 12 ++- .../client_side_encryption_prose_test.go | 2 +- x/mongo/driver/batches.go | 22 +++-- x/mongo/driver/operation.go | 93 +++++++++---------- 4 files changed, 67 insertions(+), 62 deletions(-) diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 962b00efd3..2e00b854c8 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -161,10 +161,16 @@ func (mb *modelBatches) IsOrdered() *bool { func (mb *modelBatches) AdvanceBatches(n int) { mb.offset += n + if mb.offset > len(mb.models) { + mb.offset = len(mb.models) + } } -func (mb *modelBatches) End() bool { - return len(mb.models) <= mb.offset +func (mb *modelBatches) Size() int { + if mb.offset > len(mb.models) { + return 0 + } + return len(mb.models) - mb.offset } func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { @@ -208,7 +214,7 @@ type functionSet struct { } func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { - if mb.End() { + if mb.Size() == 0 { return 0, dst, io.EOF } diff --git a/mongo/integration/client_side_encryption_prose_test.go b/mongo/integration/client_side_encryption_prose_test.go index 3d2b2304a2..9b52c44aeb 100644 --- a/mongo/integration/client_side_encryption_prose_test.go +++ b/mongo/integration/client_side_encryption_prose_test.go @@ -470,7 +470,7 @@ func TestClientSideEncryptionProse(t *testing.T) { cpt.cseStarted = cpt.cseStarted[:0] _, err = cpt.cseColl.InsertMany(context.Background(), []interface{}{firstBulkDoc, secondBulkDoc}) assert.Nil(mt, err, "InsertMany error for large documents: %v", err) - assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d", len(cpt.cseStarted)) + assert.Equal(mt, 2, len(cpt.cseStarted), 
"expected 2 insert events, got %d with size %d %d", len(cpt.cseStarted), len(str), len(limitsDoc)) // insert a document slightly smaller than 16MiB and expect the operation to succeed doc = bson.D{{"_id", "under_16mib"}, {"unencrypted", complete16mbStr[:maxBsonObjSize-2000]}} diff --git a/x/mongo/driver/batches.go b/x/mongo/driver/batches.go index fbcd169a9f..fdafafd0bd 100644 --- a/x/mongo/driver/batches.go +++ b/x/mongo/driver/batches.go @@ -25,7 +25,7 @@ type Batches struct { } func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { - if b.End() { + if b.Size() == 0 { return 0, dst, io.EOF } l := len(dst) @@ -34,7 +34,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz idx, dst = bsoncore.ReserveLength(dst) dst = append(dst, b.Identifier...) dst = append(dst, 0x00) - size := len(dst) - l + var size int var n int for i := b.offset; i < len(b.Documents); i++ { if n == maxCount { @@ -45,7 +45,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz break } size += len(doc) - if size >= totalSize { + if size > maxDocSize { break } dst = append(dst, doc...) 
@@ -59,12 +59,12 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz } func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { - if b.End() { + if b.Size() == 0 { return 0, dst, io.EOF } l := len(dst) aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier) - size := len(dst) - l + var size int var n int for i := b.offset; i < len(b.Documents); i++ { if n == maxCount { @@ -75,7 +75,7 @@ func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize i break } size += len(doc) - if size >= totalSize { + if size > totalSize { break } dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(n), doc) @@ -98,8 +98,14 @@ func (b *Batches) IsOrdered() *bool { func (b *Batches) AdvanceBatches(n int) { b.offset += n + if b.offset > len(b.Documents) { + b.offset = len(b.Documents) + } } -func (b *Batches) End() bool { - return len(b.Documents) <= b.offset +func (b *Batches) Size() int { + if b.offset > len(b.Documents) { + return 0 + } + return len(b.Documents) - b.offset } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 0c98eabb60..afe23828ae 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -284,7 +284,7 @@ type Operation struct { AppendBatchArray(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error) IsOrdered() *bool AdvanceBatches(n int) - End() bool + Size() int } // Legacy sets the legacy type for this operation. 
There are only 3 types that require legacy @@ -719,8 +719,9 @@ func (op Operation) Execute(ctx context.Context) error { desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()} + var moreToCome bool var startedInfo startedInformation - *wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) + *wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err @@ -746,9 +747,6 @@ func (op Operation) Execute(ctx context.Context) error { op.publishStartedEvent(ctx, startedInfo) - // get the moreToCome flag information before we compress - moreToCome := wiremessage.IsMsgMoreToCome(*wm) - // compress wiremessage if allowed if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) { b := memoryPool.Get().(*[]byte) @@ -872,15 +870,14 @@ func (op Operation) Execute(ctx context.Context) error { // } } - if op.Batches != nil && len(tt.WriteErrors) > 0 && currIndex > 0 { - for i := range tt.WriteErrors { - tt.WriteErrors[i].Index += int64(currIndex) - } - } - // If batching is enabled and either ordered is the default (which is true) or // explicitly set to true and we have write errors, return the errors. if op.Batches != nil && len(tt.WriteErrors) > 0 { + if currIndex > 0 { + for i := range tt.WriteErrors { + tt.WriteErrors[i].Index += int64(currIndex) + } + } if isOrdered := op.Batches.IsOrdered(); isOrdered == nil || *isOrdered { return tt } @@ -1015,7 +1012,6 @@ func (op Operation) Execute(ctx context.Context) error { } perr := op.ProcessResponseFn(ctx, res, info) if perr != nil { - fmt.Println("op", perr) return perr } } @@ -1036,7 +1032,7 @@ func (op Operation) Execute(ctx context.Context) error { // If we're batching and there are batches remaining, advance to the next batch. 
This isn't // a retry, so increment the transaction number, reset the retries number, and don't set // server or connection to nil to continue using the same connection. - if op.Batches != nil { + if op.Batches != nil && op.Batches.Size() > startedInfo.processedBatches { // If retries are supported for the current operation on the current server description, // the session isn't nil, and client retries are enabled, increment the txn number. // Calling IncrementTxnNumber() for server descriptions or topologies that do not @@ -1053,7 +1049,7 @@ func (op Operation) Execute(ctx context.Context) error { } currIndex += startedInfo.processedBatches op.Batches.AdvanceBatches(startedInfo.processedBatches) - if !op.Batches.End() { + if op.Batches.Size() > 0 { continue } } @@ -1289,21 +1285,11 @@ func (op Operation) createMsgWireMessage( cmdFn func([]byte, description.SelectedServer) ([]byte, error), ) ([]byte, []byte, error) { var flags wiremessage.MsgFlag - // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either - // aren't batching or we are encoding the last batch. - var batching bool - if op.Batches != nil && !op.Batches.End() { - batching = true - } - if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && !batching { - flags = wiremessage.MoreToCome - } // Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can // respond with the MoreToCome flag and then stream responses over this connection. 
if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() { - flags |= wiremessage.ExhaustAllowed + flags = wiremessage.ExhaustAllowed } - dst = wiremessage.AppendMsgFlags(dst, flags) // Body dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument) @@ -1365,11 +1351,12 @@ func (op Operation) createWireMessage( desc description.SelectedServer, conn Connection, requestID int32, -) ([]byte, startedInformation, error) { +) ([]byte, bool, startedInformation, error) { var info startedInformation var wmindex int32 var err error + fIdx := len(dst) isLegacy := isLegacyHandshake(op, desc) shouldEncrypt := op.shouldEncrypt() if !isLegacy && !shouldEncrypt { @@ -1395,23 +1382,11 @@ func (op Operation) createWireMessage( } } else if shouldEncrypt { if desc.WireVersion.Max < cryptMinWireVersion { - return dst, info, errors.New("auto-encryption requires a MongoDB version of 4.2") + return dst, false, info, errors.New("auto-encryption requires a MongoDB version of 4.2") } cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) { - // create temporary command document - var cmdDst []byte - info.processedBatches, cmdDst, err = op.addEncryptCommandFields(nil, desc) - if err != nil { - return nil, err - } - // encrypt the command - encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst) - if err != nil { - return nil, err - } - // append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator) - dst = append(dst, encrypted[4:len(encrypted)-1]...) 
- return dst, nil + info.processedBatches, dst, err = op.addEncryptCommandFields(ctx, dst, desc) + return dst, err } wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, cmdFn) @@ -1425,17 +1400,27 @@ dst, info.cmd, err = op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc, cmdFn) } if err != nil { - return nil, info, err + return nil, false, info, err + } + + var moreToCome bool + // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either + // aren't batching or we are encoding the last batch. + unacknowledged := op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) + batching := op.Batches != nil && op.Batches.Size() > info.processedBatches + if !isLegacy && unacknowledged && !batching { + dst[fIdx] |= byte(wiremessage.MoreToCome) + moreToCome = true } info.requestID = requestID - return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), moreToCome, info, nil } -func (op Operation) addEncryptCommandFields(dst []byte, desc description.SelectedServer) (int, []byte, error) { - var idx int32 - idx, dst = bsoncore.AppendDocumentStart(dst) +func (op Operation) addEncryptCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) (int, []byte, error) { + idx, cmdDst := bsoncore.AppendDocumentStart(nil) var err error - dst, err = op.CommandFn(dst, desc) + // create temporary command document + cmdDst, err = op.CommandFn(cmdDst, desc) if err != nil { return 0, nil, err } @@ -1443,14 +1428,14 @@ func (op Operation) addEncryptCommandFields(dst []byte, desc description.Selecte if op.Batches != nil { maxBatchCount := int(desc.MaxBatchCount) maxDocumentSize := int(desc.MaxDocumentSize) if 
maxBatchCount > 1 { - n, dst, err = op.Batches.AppendBatchArray(dst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize) + n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize) if err != nil { return 0, nil, err } } if n == 0 { - n, dst, err = op.Batches.AppendBatchArray(dst, 1, maxDocumentSize, maxDocumentSize) + n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, 1, maxDocumentSize, maxDocumentSize) if err != nil { return 0, nil, err } @@ -1459,10 +1445,17 @@ func (op Operation) addEncryptCommandFields(dst []byte, desc description.Selecte } } } - dst, err = bsoncore.AppendDocumentEnd(dst, idx) + cmdDst, err = bsoncore.AppendDocumentEnd(cmdDst, idx) + if err != nil { + return 0, nil, err + } + // encrypt the command + encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst) if err != nil { return 0, nil, err } + // append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator) + dst = append(dst, encrypted[4:len(encrypted)-1]...) return n, dst, nil }