Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Oct 13, 2024
1 parent 63a973b commit 97d5b6b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 62 deletions.
12 changes: 9 additions & 3 deletions mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,16 @@ func (mb *modelBatches) IsOrdered() *bool {

func (mb *modelBatches) AdvanceBatches(n int) {
mb.offset += n
if mb.offset > len(mb.models) {
mb.offset = len(mb.models)
}
}

func (mb *modelBatches) End() bool {
return len(mb.models) <= mb.offset
func (mb *modelBatches) Size() int {
if mb.offset > len(mb.models) {
return 0
}
return len(mb.models) - mb.offset
}

func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
Expand Down Expand Up @@ -208,7 +214,7 @@ type functionSet struct {
}

func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
if mb.End() {
if mb.Size() == 0 {
return 0, dst, io.EOF
}

Expand Down
2 changes: 1 addition & 1 deletion mongo/integration/client_side_encryption_prose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
cpt.cseStarted = cpt.cseStarted[:0]
_, err = cpt.cseColl.InsertMany(context.Background(), []interface{}{firstBulkDoc, secondBulkDoc})
assert.Nil(mt, err, "InsertMany error for large documents: %v", err)
assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d", len(cpt.cseStarted))
assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d with size %d %d", len(cpt.cseStarted), len(str), len(limitsDoc))

// insert a document slightly smaller than 16MiB and expect the operation to succeed
doc = bson.D{{"_id", "under_16mib"}, {"unencrypted", complete16mbStr[:maxBsonObjSize-2000]}}
Expand Down
22 changes: 14 additions & 8 deletions x/mongo/driver/batches.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Batches struct {
}

func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
if b.End() {
if b.Size() == 0 {
return 0, dst, io.EOF
}
l := len(dst)
Expand All @@ -34,7 +34,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz
idx, dst = bsoncore.ReserveLength(dst)
dst = append(dst, b.Identifier...)
dst = append(dst, 0x00)
size := len(dst) - l
var size int
var n int
for i := b.offset; i < len(b.Documents); i++ {
if n == maxCount {
Expand All @@ -45,7 +45,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz
break
}
size += len(doc)
if size >= totalSize {
if size > maxDocSize {
break
}
dst = append(dst, doc...)
Expand All @@ -59,12 +59,12 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz
}

func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
if b.End() {
if b.Size() == 0 {
return 0, dst, io.EOF
}
l := len(dst)
aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier)
size := len(dst) - l
var size int
var n int
for i := b.offset; i < len(b.Documents); i++ {
if n == maxCount {
Expand All @@ -75,7 +75,7 @@ func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize i
break
}
size += len(doc)
if size >= totalSize {
if size > totalSize {
break
}
dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(n), doc)
Expand All @@ -98,8 +98,14 @@ func (b *Batches) IsOrdered() *bool {

func (b *Batches) AdvanceBatches(n int) {
b.offset += n
if b.offset > len(b.Documents) {
b.offset = len(b.Documents)
}
}

func (b *Batches) End() bool {
return len(b.Documents) <= b.offset
func (b *Batches) Size() int {
if b.offset > len(b.Documents) {
return 0
}
return len(b.Documents) - b.offset
}
93 changes: 43 additions & 50 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ type Operation struct {
AppendBatchArray(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error)
IsOrdered() *bool
AdvanceBatches(n int)
End() bool
Size() int
}

// Legacy sets the legacy type for this operation. There are only 3 types that require legacy
Expand Down Expand Up @@ -719,8 +719,9 @@ func (op Operation) Execute(ctx context.Context) error {

desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()}

var moreToCome bool
var startedInfo startedInformation
*wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)
*wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)

if err != nil {
return err
Expand All @@ -746,9 +747,6 @@ func (op Operation) Execute(ctx context.Context) error {

op.publishStartedEvent(ctx, startedInfo)

// get the moreToCome flag information before we compress
moreToCome := wiremessage.IsMsgMoreToCome(*wm)

// compress wiremessage if allowed
if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) {
b := memoryPool.Get().(*[]byte)
Expand Down Expand Up @@ -872,15 +870,14 @@ func (op Operation) Execute(ctx context.Context) error {
// }
}

if op.Batches != nil && len(tt.WriteErrors) > 0 && currIndex > 0 {
for i := range tt.WriteErrors {
tt.WriteErrors[i].Index += int64(currIndex)
}
}

// If batching is enabled and either ordered is the default (which is true) or
// explicitly set to true and we have write errors, return the errors.
if op.Batches != nil && len(tt.WriteErrors) > 0 {
if currIndex > 0 {
for i := range tt.WriteErrors {
tt.WriteErrors[i].Index += int64(currIndex)
}
}
if isOrdered := op.Batches.IsOrdered(); isOrdered == nil || *isOrdered {
return tt
}
Expand Down Expand Up @@ -1015,7 +1012,6 @@ func (op Operation) Execute(ctx context.Context) error {
}
perr := op.ProcessResponseFn(ctx, res, info)
if perr != nil {
fmt.Println("op", perr)
return perr
}
}
Expand All @@ -1036,7 +1032,7 @@ func (op Operation) Execute(ctx context.Context) error {
// If we're batching and there are batches remaining, advance to the next batch. This isn't
// a retry, so increment the transaction number, reset the retries number, and don't set
// server or connection to nil to continue using the same connection.
if op.Batches != nil {
if op.Batches != nil && op.Batches.Size() > startedInfo.processedBatches {
// If retries are supported for the current operation on the current server description,
// the session isn't nil, and client retries are enabled, increment the txn number.
// Calling IncrementTxnNumber() for server descriptions or topologies that do not
Expand All @@ -1053,7 +1049,7 @@ func (op Operation) Execute(ctx context.Context) error {
}
currIndex += startedInfo.processedBatches
op.Batches.AdvanceBatches(startedInfo.processedBatches)
if !op.Batches.End() {
if op.Batches.Size() > 0 {
continue
}
}
Expand Down Expand Up @@ -1289,21 +1285,11 @@ func (op Operation) createMsgWireMessage(
cmdFn func([]byte, description.SelectedServer) ([]byte, error),
) ([]byte, []byte, error) {
var flags wiremessage.MsgFlag
// We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
// aren't batching or we are encoding the last batch.
var batching bool
if op.Batches != nil && !op.Batches.End() {
batching = true
}
if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && !batching {
flags = wiremessage.MoreToCome
}
// Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can
// respond with the MoreToCome flag and then stream responses over this connection.
if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() {
flags |= wiremessage.ExhaustAllowed
flags = wiremessage.ExhaustAllowed
}

dst = wiremessage.AppendMsgFlags(dst, flags)
// Body
dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument)
Expand Down Expand Up @@ -1365,11 +1351,12 @@ func (op Operation) createWireMessage(
desc description.SelectedServer,
conn Connection,
requestID int32,
) ([]byte, startedInformation, error) {
) ([]byte, bool, startedInformation, error) {
var info startedInformation
var wmindex int32
var err error

fIdx := len(dst)
isLegacy := isLegacyHandshake(op, desc)
shouldEncrypt := op.shouldEncrypt()
if !isLegacy && !shouldEncrypt {
Expand All @@ -1395,23 +1382,11 @@ func (op Operation) createWireMessage(
}
} else if shouldEncrypt {
if desc.WireVersion.Max < cryptMinWireVersion {
return dst, info, errors.New("auto-encryption requires a MongoDB version of 4.2")
return dst, false, info, errors.New("auto-encryption requires a MongoDB version of 4.2")
}
cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) {
// create temporary command document
var cmdDst []byte
info.processedBatches, cmdDst, err = op.addEncryptCommandFields(nil, desc)
if err != nil {
return nil, err
}
// encrypt the command
encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
if err != nil {
return nil, err
}
// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
dst = append(dst, encrypted[4:len(encrypted)-1]...)
return dst, nil
info.processedBatches, dst, err = op.addEncryptCommandFields(ctx, dst, desc)
return dst, err
}
wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg)
dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, cmdFn)
Expand All @@ -1425,32 +1400,43 @@ func (op Operation) createWireMessage(
dst, info.cmd, err = op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc, cmdFn)
}
if err != nil {
return nil, info, err
return nil, false, info, err
}

var moreToCome bool
// We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
// aren't batching or we are encoding the last batch.
unacknowledged := op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern)
batching := op.Batches != nil && op.Batches.Size() > info.processedBatches
if !isLegacy && unacknowledged && !batching {
dst[fIdx] |= byte(wiremessage.MoreToCome)
moreToCome = true
}
info.requestID = requestID
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), moreToCome, info, nil
}

func (op Operation) addEncryptCommandFields(dst []byte, desc description.SelectedServer) (int, []byte, error) {
var idx int32
idx, dst = bsoncore.AppendDocumentStart(dst)
func (op Operation) addEncryptCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) (int, []byte, error) {
idx, cmdDst := bsoncore.AppendDocumentStart(nil)
var err error
dst, err = op.CommandFn(dst, desc)
// create temporary command document
cmdDst, err = op.CommandFn(cmdDst, desc)
if err != nil {
return 0, nil, err
}
var n int
if op.Batches != nil {
maxBatchCount := int(desc.MaxBatchCount)
maxDocumentSize := int(desc.MaxDocumentSize)
fmt.Println("addEncryptCommandFields", cryptMaxBsonObjectSize, maxDocumentSize)
if maxBatchCount > 1 {
n, dst, err = op.Batches.AppendBatchArray(dst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize)
n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize)
if err != nil {
return 0, nil, err
}
}
if n == 0 {
n, dst, err = op.Batches.AppendBatchArray(dst, 1, maxDocumentSize, maxDocumentSize)
n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, 1, maxDocumentSize, maxDocumentSize)
if err != nil {
return 0, nil, err
}
Expand All @@ -1459,10 +1445,17 @@ func (op Operation) addEncryptCommandFields(dst []byte, desc description.Selecte
}
}
}
dst, err = bsoncore.AppendDocumentEnd(dst, idx)
cmdDst, err = bsoncore.AppendDocumentEnd(cmdDst, idx)
if err != nil {
return 0, nil, err
}
// encrypt the command
encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
if err != nil {
return 0, nil, err
}
// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
dst = append(dst, encrypted[4:len(encrypted)-1]...)
return n, dst, nil
}

Expand Down

0 comments on commit 97d5b6b

Please sign in to comment.