Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Oct 19, 2024
1 parent cb5e8aa commit 3bb6bbf
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 92 deletions.
6 changes: 4 additions & 2 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,8 +909,10 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels,
if bwo.VerboseResults == nil || !(*bwo.VerboseResults) {
op.errorsOnly = true
}
err = op.execute(ctx)
return &op.result, replaceErrors(err)
if err = op.execute(ctx); err != nil {
return nil, replaceErrors(err)
}
return &op.result, nil
}

// newLogger will use the LoggerOptions to create an internal logger and publish
Expand Down
13 changes: 11 additions & 2 deletions mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,16 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
exception.WriteErrors = batches.writeErrors
}
if exception != nil {
exception.PartialResult = batches.result
var hasSuccess bool
if bw.ordered == nil || *bw.ordered {
_, ok := batches.writeErrors[0]
hasSuccess = !ok
} else {
hasSuccess = len(batches.writeErrors) < len(bw.models)
}
if hasSuccess {
exception.PartialResult = batches.result
}
return *exception
}
return err
Expand Down Expand Up @@ -471,7 +480,7 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
WriteErrors: mb.writeErrors,
PartialResult: mb.result,
}
if !res.Ok || res.NErrors > 0 {
if !res.Ok {
exception.TopLevelError = &WriteError{
Code: int(res.Code),
Message: res.Errmsg,
Expand Down
135 changes: 94 additions & 41 deletions mongo/integration/crud_prose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"bytes"
"context"
"errors"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -415,18 +416,20 @@ func TestErrorsCodeNamePropagated(t *testing.T) {
}

func TestClientBulkWrite(t *testing.T) {
mtOpts := mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).CreateClient(false)
mtOpts := mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).CreateClient(false).ClientType(mtest.Pinned)
mt := mtest.New(t, mtOpts)

mt.Run("input with greater than maxWriteBatchSize", func(mt *mtest.T) {
var opsCnt []int
monitor := &event.CommandMonitor{
Started: func(_ context.Context, e *event.CommandStartedEvent) {
if e.CommandName == "bulkWrite" {
v := e.Command.Lookup("ops")
elems, err := v.Array().Elements()
require.NoError(mt, err, "monitor error")
opsCnt = append(opsCnt, len(elems))
var c struct {
Ops []bson.D
}
err := bson.Unmarshal(e.Command, &c)
require.NoError(mt, err)
opsCnt = append(opsCnt, len(c.Ops))
}
},
}
Expand Down Expand Up @@ -455,10 +458,12 @@ func TestClientBulkWrite(t *testing.T) {
monitor := &event.CommandMonitor{
Started: func(_ context.Context, e *event.CommandStartedEvent) {
if e.CommandName == "bulkWrite" {
v := e.Command.Lookup("ops")
elems, err := v.Array().Elements()
require.NoError(mt, err, "monitor error")
opsCnt = append(opsCnt, len(elems))
var c struct {
Ops []bson.D
}
err := bson.Unmarshal(e.Command, &c)
require.NoError(mt, err)
opsCnt = append(opsCnt, len(c.Ops))
}
},
}
Expand Down Expand Up @@ -744,11 +749,20 @@ func TestClientBulkWrite(t *testing.T) {
})

mt.Run("bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size", func(mt *mtest.T) {
var bwCmd []bsoncore.Document
type cmd struct {
Ops []bson.D
NsInfo []struct {
Ns string
}
}
var bwCmd []cmd
monitor := &event.CommandMonitor{
Started: func(_ context.Context, e *event.CommandStartedEvent) {
if e.CommandName == "bulkWrite" {
bwCmd = append(bwCmd, bsoncore.Document(e.Command))
var c cmd
err := bson.Unmarshal(e.Command, &c)
require.NoError(mt, err)
bwCmd = append(bwCmd, c)
}
},
}
Expand Down Expand Up @@ -794,17 +808,9 @@ func TestClientBulkWrite(t *testing.T) {
assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount)
require.Len(mt, bwCmd, 1, "expected %d bulkWrite call, got %d", 1, len(bwCmd))

var cmd struct {
Ops []bson.D
NsInfo []struct {
Ns string
}
}
err = bson.Unmarshal(bwCmd[0], &cmd)
require.NoError(mt, err)
assert.Len(mt, cmd.Ops, numModels+1, "expected ops: %d, got: %d", numModels+1, len(cmd.Ops))
require.Len(mt, cmd.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(cmd.NsInfo))
assert.Equal(mt, "db.coll", cmd.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", cmd.NsInfo[0].Ns)
assert.Len(mt, bwCmd[0].Ops, numModels+1, "expected ops: %d, got: %d", numModels+1, len(bwCmd[0].Ops))
require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo))
assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns)
})
mt.Run("batch-splitting required", func(mt *mtest.T) {
bwCmd = bwCmd[:0]
Expand All @@ -818,27 +824,15 @@ func TestClientBulkWrite(t *testing.T) {
result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite())
require.NoError(mt, err)
assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount)
require.Len(mt, bwCmd, 2, "expected %d bulkWrite call, got %d", 2, len(bwCmd))
require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got %d", 2, len(bwCmd))

type cmd struct {
Ops []bson.D
NsInfo []struct {
Ns string
}
}
var c1 cmd
err = bson.Unmarshal(bwCmd[0], &c1)
require.NoError(mt, err)
assert.Len(mt, c1.Ops, numModels, "expected ops: %d, got: %d", numModels, len(c1.Ops))
require.Len(mt, c1.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(c1.NsInfo))
assert.Equal(mt, "db.coll", c1.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", c1.NsInfo[0].Ns)
assert.Len(mt, bwCmd[0].Ops, numModels, "expected ops: %d, got: %d", numModels, len(bwCmd[0].Ops))
require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo))
assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns)

var c2 cmd
err = bson.Unmarshal(bwCmd[1], &c2)
require.NoError(mt, err)
assert.Len(mt, c2.Ops, 1, "expected ops: %d, got: %d", 1, len(c2.Ops))
require.Len(mt, c2.NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(c2.NsInfo))
assert.Equal(mt, "db."+coll, c2.NsInfo[0].Ns, "expected namespace: %s, got: %s", "db."+coll, c2.NsInfo[0].Ns)
assert.Len(mt, bwCmd[1].Ops, 1, "expected ops: %d, got: %d", 1, len(bwCmd[1].Ops))
require.Len(mt, bwCmd[1].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[1].NsInfo))
assert.Equal(mt, "db."+coll, bwCmd[1].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db."+coll, bwCmd[1].NsInfo[0].Ns)
})
})

Expand Down Expand Up @@ -867,6 +861,10 @@ func TestClientBulkWrite(t *testing.T) {
})

mt.Run("bulkWrite returns an error if auto-encryption is configured", func(mt *mtest.T) {
if os.Getenv("DOCKER_RUNNING") != "" {
mt.Skip("skipping test in docker environment")
}

autoEncryptionOpts := options.AutoEncryption().
SetKeyVaultNamespace("db.coll").
SetKmsProviders(map[string]map[string]interface{}{
Expand All @@ -883,4 +881,59 @@ func TestClientBulkWrite(t *testing.T) {
_, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged()))
require.ErrorContains(mt, err, "bulkWrite does not currently support automatic encryption")
})

mt.Run("bulkWrite with unacknowledged write concern uses w:0 for all batches", func(mt *mtest.T) {
type cmd struct {
Ops []bson.D
WriteConcern struct {
W interface{}
}
}
var bwCmd []cmd
monitor := &event.CommandMonitor{
Started: func(_ context.Context, e *event.CommandStartedEvent) {
if e.CommandName == "bulkWrite" {
var c cmd
err := bson.Unmarshal(e.Command, &c)
require.NoError(mt, err)

bwCmd = append(bwCmd, c)
}
},
}
var hello struct {
MaxBsonObjectSize int
MaxMessageSizeBytes int
}
err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello)
require.NoError(mt, err, "Hello error")

mt.ResetClient(options.Client().SetMonitor(monitor))

coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, true)
err = coll.Drop(context.Background())
require.NoError(mt, err, "Drop error")

numModels := hello.MaxMessageSizeBytes / hello.MaxBsonObjectSize
models := &mongo.ClientWriteModels{}
for i := 0; i < numModels+1; i++ {
models.
AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{
Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}},
})
}
_, err = mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged()))
require.NoError(mt, err)
require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got %d", 2, len(bwCmd))

assert.Len(mt, bwCmd[0].Ops, numModels, "expected ops: %d, got: %d", numModels, len(bwCmd[0].Ops))
assert.Equal(mt, int32(0), bwCmd[0].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[0].WriteConcern.W)

assert.Len(mt, bwCmd[1].Ops, 1, "expected ops: %d, got: %d", 1, len(bwCmd[1].Ops))
assert.Equal(mt, int32(0), bwCmd[1].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[1].WriteConcern.W)

n, err := coll.CountDocuments(context.Background(), bson.D{})
require.NoError(mt, err)
assert.Equal(mt, numModels+1, int(n), "expected %d documents, got %d", numModels+1, n)
})
}
99 changes: 52 additions & 47 deletions mongo/integration/unified/client_operation_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package unified

import (
"context"
"errors"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -228,60 +229,64 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati
}

res, err := client.BulkWrite(ctx, wirteModels, opts)
raw := emptyCoreDocument
if res != nil {
rawBuilder := bsoncore.NewDocumentBuilder().
AppendInt64("deletedCount", res.DeletedCount).
AppendInt64("insertedCount", res.InsertedCount).
AppendInt64("matchedCount", res.MatchedCount).
AppendInt64("modifiedCount", res.ModifiedCount).
AppendInt64("upsertedCount", res.UpsertedCount)

var resBuilder *bsoncore.DocumentBuilder

resBuilder = bsoncore.NewDocumentBuilder()
for k, v := range res.DeleteResults {
resBuilder.AppendDocument(strconv.Itoa(k),
bsoncore.NewDocumentBuilder().
AppendInt64("deletedCount", v.DeletedCount).
Build(),
)
if res == nil {
var bwe mongo.ClientBulkWriteException
if !errors.As(err, &bwe) || bwe.PartialResult == nil {
return newDocumentResult(emptyCoreDocument, err), nil
} else {
res = bwe.PartialResult
}
rawBuilder.AppendDocument("deleteResults", resBuilder.Build())
}
rawBuilder := bsoncore.NewDocumentBuilder().
AppendInt64("deletedCount", res.DeletedCount).
AppendInt64("insertedCount", res.InsertedCount).
AppendInt64("matchedCount", res.MatchedCount).
AppendInt64("modifiedCount", res.ModifiedCount).
AppendInt64("upsertedCount", res.UpsertedCount)

var resBuilder *bsoncore.DocumentBuilder

resBuilder = bsoncore.NewDocumentBuilder()
for k, v := range res.DeleteResults {
resBuilder.AppendDocument(strconv.Itoa(k),
bsoncore.NewDocumentBuilder().
AppendInt64("deletedCount", v.DeletedCount).
Build(),
)
}
rawBuilder.AppendDocument("deleteResults", resBuilder.Build())

resBuilder = bsoncore.NewDocumentBuilder()
for k, v := range res.InsertResults {
t, d, err := bson.MarshalValue(v.InsertedID)
resBuilder = bsoncore.NewDocumentBuilder()
for k, v := range res.InsertResults {
t, d, err := bson.MarshalValue(v.InsertedID)
if err != nil {
return nil, err
}
resBuilder.AppendDocument(strconv.Itoa(k),
bsoncore.NewDocumentBuilder().
AppendValue("insertedId", bsoncore.Value{Type: t, Data: d}).
Build(),
)
}
rawBuilder.AppendDocument("insertResults", resBuilder.Build())

resBuilder = bsoncore.NewDocumentBuilder()
for k, v := range res.UpdateResults {
b := bsoncore.NewDocumentBuilder().
AppendInt64("matchedCount", v.MatchedCount).
AppendInt64("modifiedCount", v.ModifiedCount)
if v.UpsertedID != nil {
t, d, err := bson.MarshalValue(v.UpsertedID)
if err != nil {
return nil, err
}
resBuilder.AppendDocument(strconv.Itoa(k),
bsoncore.NewDocumentBuilder().
AppendValue("insertedId", bsoncore.Value{Type: t, Data: d}).
Build(),
)
b.AppendValue("upsertedId", bsoncore.Value{Type: t, Data: d})
}
rawBuilder.AppendDocument("insertResults", resBuilder.Build())

resBuilder = bsoncore.NewDocumentBuilder()
for k, v := range res.UpdateResults {
b := bsoncore.NewDocumentBuilder().
AppendInt64("matchedCount", v.MatchedCount).
AppendInt64("modifiedCount", v.ModifiedCount)
if v.UpsertedID != nil {
t, d, err := bson.MarshalValue(v.UpsertedID)
if err != nil {
return nil, err
}
b.AppendValue("upsertedId", bsoncore.Value{Type: t, Data: d})
}
resBuilder.AppendDocument(strconv.Itoa(k), b.Build())
}
rawBuilder.AppendDocument("updateResults", resBuilder.Build())

raw = rawBuilder.Build()
resBuilder.AppendDocument(strconv.Itoa(k), b.Build())
}
return newDocumentResult(raw, err), nil
rawBuilder.AppendDocument("updateResults", resBuilder.Build())

return newDocumentResult(rawBuilder.Build(), err), nil
}

func appendClientBulkWriteModel(key string, value bson.Raw, model *mongo.ClientWriteModels) error {
Expand Down

0 comments on commit 3bb6bbf

Please sign in to comment.