From 691b9767c5978551310156895721fcf4ca1cef5e Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Mon, 14 Aug 2023 13:59:06 -0600 Subject: [PATCH 1/8] GODRIVER-2101 Expand test to use pigeonhole principle --- mongo/collection.go | 89 ++++++++++--------- mongo/cursor.go | 4 +- .../integration/retryable_reads_prose_test.go | 79 ++++++++++++++++ mongo/util.go | 2 + x/mongo/driver/batch_cursor.go | 8 +- x/mongo/driver/operation.go | 12 +++ 6 files changed, 147 insertions(+), 47 deletions(-) diff --git a/mongo/collection.go b/mongo/collection.go index 997d877681..8e9bab53aa 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -185,8 +185,8 @@ func (coll *Collection) Database() *Database { // // The opts parameter can be used to specify options for the operation (see the options.BulkWriteOptions documentation.) func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, - opts ...*options.BulkWriteOptions) (*BulkWriteResult, error) { - + opts ...*options.BulkWriteOptions, +) (*BulkWriteResult, error) { if len(models) == 0 { return nil, ErrEmptySlice } @@ -242,8 +242,8 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, } func (coll *Collection) insert(ctx context.Context, documents []interface{}, - opts ...*options.InsertManyOptions) ([]interface{}, error) { - + opts ...*options.InsertManyOptions, +) ([]interface{}, error) { if ctx == nil { ctx = context.Background() } @@ -343,8 +343,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertOne(ctx context.Context, document interface{}, - opts ...*options.InsertOneOptions) (*InsertOneResult, error) { - + opts ...*options.InsertOneOptions, +) (*InsertOneResult, error) { ioOpts := options.MergeInsertOneOptions(opts...) imOpts := options.InsertMany() @@ -375,8 +375,8 @@ func (coll *Collection) InsertOne(ctx context.Context, document interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, - opts ...*options.InsertManyOptions) (*InsertManyResult, error) { - + opts ...*options.InsertManyOptions, +) (*InsertManyResult, error) { if len(documents) == 0 { return nil, ErrEmptySlice } @@ -410,8 +410,8 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, } func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOne bool, expectedRr returnResult, - opts ...*options.DeleteOptions) (*DeleteResult, error) { - + opts ...*options.DeleteOptions, +) (*DeleteResult, error) { if ctx == nil { ctx = context.Background() } @@ -514,8 +514,8 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/delete/. func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, - opts ...*options.DeleteOptions) (*DeleteResult, error) { - + opts ...*options.DeleteOptions, +) (*DeleteResult, error) { return coll.delete(ctx, filter, true, rrOne, opts...) } @@ -530,14 +530,14 @@ func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/delete/. 
func (coll *Collection) DeleteMany(ctx context.Context, filter interface{}, - opts ...*options.DeleteOptions) (*DeleteResult, error) { - + opts ...*options.DeleteOptions, +) (*DeleteResult, error) { return coll.delete(ctx, filter, false, rrMany, opts...) } func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Document, update interface{}, multi bool, - expectedRr returnResult, checkDollarKey bool, opts ...*options.UpdateOptions) (*UpdateResult, error) { - + expectedRr returnResult, checkDollarKey bool, opts ...*options.UpdateOptions, +) (*UpdateResult, error) { if ctx == nil { ctx = context.Background() } @@ -648,7 +648,8 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update interface{}, - opts ...*options.UpdateOptions) (*UpdateResult, error) { + opts ...*options.UpdateOptions, +) (*UpdateResult, error) { if id == nil { return nil, ErrNilValue } @@ -670,8 +671,8 @@ func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update i // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, update interface{}, - opts ...*options.UpdateOptions) (*UpdateResult, error) { - + opts ...*options.UpdateOptions, +) (*UpdateResult, error) { if ctx == nil { ctx = context.Background() } @@ -698,8 +699,8 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, update interface{}, - opts ...*options.UpdateOptions) (*UpdateResult, error) { - + opts ...*options.UpdateOptions, +) (*UpdateResult, error) { if ctx == nil { ctx = context.Background() } @@ -726,8 +727,8 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, - replacement interface{}, opts ...*options.ReplaceOptions) (*UpdateResult, error) { - + replacement interface{}, opts ...*options.ReplaceOptions, +) (*UpdateResult, error) { if ctx == nil { ctx = context.Background() } @@ -776,7 +777,8 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/aggregate/. func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, - opts ...*options.AggregateOptions) (*Cursor, error) { + opts ...*options.AggregateOptions, +) (*Cursor, error) { a := aggregateParams{ ctx: ctx, pipeline: pipeline, @@ -952,8 +954,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { // // The opts parameter can be used to specify options for the operation (see the options.CountOptions documentation). 
func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, - opts ...*options.CountOptions) (int64, error) { - + opts ...*options.CountOptions, +) (int64, error) { if ctx == nil { ctx = context.Background() } @@ -1037,8 +1039,8 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/count/. func (coll *Collection) EstimatedDocumentCount(ctx context.Context, - opts ...*options.EstimatedDocumentCountOptions) (int64, error) { - + opts ...*options.EstimatedDocumentCountOptions, +) (int64, error) { if ctx == nil { ctx = context.Background() } @@ -1099,8 +1101,8 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/distinct/. func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter interface{}, - opts ...*options.DistinctOptions) ([]interface{}, error) { - + opts ...*options.DistinctOptions, +) ([]interface{}, error) { if ctx == nil { ctx = context.Background() } @@ -1190,8 +1192,8 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) Find(ctx context.Context, filter interface{}, - opts ...*options.FindOptions) (cur *Cursor, err error) { - + opts ...*options.FindOptions, +) (cur *Cursor, err error) { if ctx == nil { ctx = context.Background() } @@ -1224,6 +1226,8 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, fo := options.MergeFindOptions(opts...) + fmt.Println("timeout on client: ", coll.client.timeout) + selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewFind(f). Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). @@ -1349,6 +1353,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if coll.client.retryReads { retry = driver.RetryOncePerCommand } + op = op.Retry(retry) if err = op.Execute(ctx); err != nil { @@ -1372,8 +1377,8 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) FindOne(ctx context.Context, filter interface{}, - opts ...*options.FindOneOptions) *SingleResult { - + opts ...*options.FindOneOptions, +) *SingleResult { if ctx == nil { ctx = context.Background() } @@ -1486,8 +1491,8 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}, - opts ...*options.FindOneAndDeleteOptions) *SingleResult { - + opts ...*options.FindOneAndDeleteOptions, +) *SingleResult { f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} @@ -1558,8 +1563,8 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. 
func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{}, - replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *SingleResult { - + replacement interface{}, opts ...*options.FindOneAndReplaceOptions, +) *SingleResult { f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} @@ -1648,8 +1653,8 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{}, - update interface{}, opts ...*options.FindOneAndUpdateOptions) *SingleResult { - + update interface{}, opts ...*options.FindOneAndUpdateOptions, +) *SingleResult { if ctx == nil { ctx = context.Background() } @@ -1752,8 +1757,8 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} // The opts parameter can be used to specify options for change stream creation (see the options.ChangeStreamOptions // documentation). func (coll *Collection) Watch(ctx context.Context, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { - + opts ...*options.ChangeStreamOptions, +) (*ChangeStream, error) { csConfig := changeStreamConfig{ readConcern: coll.readConcern, readPreference: coll.readPreference, diff --git a/mongo/cursor.go b/mongo/cursor.go index d2228ed9c4..f213771807 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -345,8 +345,8 @@ func (c *Cursor) RemainingBatchLength() int { // addFromBatch adds all documents from batch to sliceVal starting at the given index. It returns the new slice value, // the next empty index in the slice, and an error if one occurs. func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, batch *bsoncore.DocumentSequence, - index int) (reflect.Value, int, error) { - + index int, +) (reflect.Value, int, error) { docs, err := batch.Documents() if err != nil { return sliceVal, index, err diff --git a/mongo/integration/retryable_reads_prose_test.go b/mongo/integration/retryable_reads_prose_test.go index b83414e518..ce48e7bdb5 100644 --- a/mongo/integration/retryable_reads_prose_test.go +++ b/mongo/integration/retryable_reads_prose_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "fmt" "sync" "testing" "time" @@ -16,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/eventtest" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/integration/mtest" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -103,4 +105,81 @@ func TestRetryableReadsProse(t *testing.T) { "expected a find event, got a(n) %v event", cmdEvt.CommandName) } }) + + mtOpts = mtest.NewOptions().ClientOptions(clientOpts).MinServerVersion("4.2"). + Topologies(mtest.Sharded) + + mt.RunOpts("retry on different mongos", mtOpts, func(mt *mtest.T) { + const hostCount = 3 + + hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts + require.GreaterOrEqualf(mt, len(hosts), hostCount, "test cluster must have at least 2 mongos hosts") + + // Configure a failpoint for the first mongos host. 
+ failPoint := mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + ErrorCode: 11600, + CloseConnection: false, + }, + } + + // In order to ensure that each mongos in the hostCount-many mongos hosts + // are tried at least once (i.e. failures are deprioritized), we set a + // failpoint on all mongos hosts. The idea is that if we get hostCount-many + // failures, then by the pigeonhole principal all mongos hosts must have + // been tried. + for i := 0; i < hostCount; i++ { + mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + mt.SetFailPoint(failPoint) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the specific + // mongos when the test is done. + defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + defer mt.ClearFailPoints() + } + + findCommandFailedCount := 0 + + commandMonitor := &event.CommandMonitor{ + Failed: func(_ context.Context, _ *event.CommandFailedEvent) { + findCommandFailedCount++ + }, + } + + // Reset the client with exactly hostCount-many mongos hosts. + mt.ResetClient(options.Client(). + SetHosts(hosts[:hostCount]). + SetTimeout(10000 * time.Millisecond). + SetRetryReads(true). + SetMonitor(commandMonitor)) + + // ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + // defer cancel() + + err := mt.Coll.FindOne(context.Background(), bson.D{}).Err() + fmt.Println("err: ", err) + + assert.Equal(mt, hostCount, findCommandFailedCount) + + // Create a connection to a database for each mongos host + // mongosOpts := options.Client().ApplyURI(hosts[0]) + + // firstMongos, err := mongo.Connect(context.Background(), mongosOpts) + // require.NoError(mt, err) + + // result := firstMongos.Database("admin").RunCommand(context.Background(), doc) + // require.NoError(mt, result.Err()) + + // secondMongos, err := mongo.Connect(context.Background(), mongosOpts) + // require.NoError(mt, err) + + // result = secondMongos.Database("admin").RunCommand(context.Background(), doc) + // require.NoError(mt, result.Err()) + }) } diff --git a/mongo/util.go b/mongo/util.go index 270fa24a25..a15b21cdfb 100644 --- a/mongo/util.go +++ b/mongo/util.go @@ -5,3 +5,5 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package mongo + +// NOTE: meep diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index fefcfdb475..3d5b78d54a 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -508,9 +508,11 @@ type loadBalancedCursorDeployment struct { conn PinnedConnection } -var _ Deployment = (*loadBalancedCursorDeployment)(nil) -var _ Server = (*loadBalancedCursorDeployment)(nil) -var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) +var ( + _ Deployment = (*loadBalancedCursorDeployment)(nil) + _ Server = (*loadBalancedCursorDeployment)(nil) + _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) +) func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ description.ServerSelector) (Server, error) { return lbcd, nil diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index f5b0d7df21..a23aa48c7c 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -419,6 +419,11 @@ func (op Operation) Execute(ctx context.Context) error { if err != nil { return err } + //fmt.Println("") + //_, deadlineSet := ctx.Deadline() + 
//fmt.Println("!deadlineSet: ", !deadlineSet) + //fmt.Println("op.Timeout != nil : ", op.Timeout != nil) + //fmt.Println("!csot.IsTimeoutContext(ctx): ", !csot.IsTimeoutContext(ctx)) // If no deadline is set on the passed-in context, op.Timeout is set, and context is not already // a Timeout context, honor op.Timeout in new Timeout context for operation execution. @@ -479,6 +484,7 @@ func (op Operation) Execute(ctx context.Context) error { // resetForRetry records the error that caused the retry, decrements retries, and resets the // retry loop variables to request a new server and a new connection for the next attempt. resetForRetry := func(err error) { + fmt.Println("resetting for retry: ", retries) retries-- prevErr = err @@ -506,6 +512,10 @@ func (op Operation) Execute(ctx context.Context) error { if conn != nil { conn.Close() } + + // NOTE: We can no longer just nullify the server here. We need to + // NOTE: "remember" the server and pass it down to the server selector so + // NOTE: that it can be ignored when trying to select a new server. // Set the server and connection to nil to request a new server and connection. srvr = nil conn = nil @@ -528,6 +538,8 @@ func (op Operation) Execute(ctx context.Context) error { for { // If the server or connection are nil, try to select a new server and get a new connection. if srvr == nil || conn == nil { + // NOTE: Each time a "retry" occurs, the server will be reset and this + // NOTE: branch will be entered. srvr, conn, err = op.getServerAndConnection(ctx) if err != nil { // If the returned error is retryable and there are retries remaining (negative From 354597efda08c58870bbb92b1f9161f82bcb8dd7 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 17 Aug 2023 19:36:50 -0600 Subject: [PATCH 2/8] GODRIVER-2101 Direct read/write retries to another mongos if possible --- mongo/collection.go | 2 - .../integration/retryable_reads_prose_test.go | 155 ++++++++++-------- .../retryable_writes_prose_test.go | 92 +++++++++++ mongo/util.go | 2 - x/mongo/driver/operation.go | 100 +++++++++-- x/mongo/driver/operation_test.go | 138 +++++++++++++++- 6 files changed, 393 insertions(+), 96 deletions(-) diff --git a/mongo/collection.go b/mongo/collection.go index 8e9bab53aa..9826107072 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1226,8 +1226,6 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, fo := options.MergeFindOptions(opts...) - fmt.Println("timeout on client: ", coll.client.timeout) - selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewFind(f). Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). diff --git a/mongo/integration/retryable_reads_prose_test.go b/mongo/integration/retryable_reads_prose_test.go index ce48e7bdb5..fa72b53dfd 100644 --- a/mongo/integration/retryable_reads_prose_test.go +++ b/mongo/integration/retryable_reads_prose_test.go @@ -8,7 +8,6 @@ package integration import ( "context" - "fmt" "sync" "testing" "time" @@ -106,80 +105,94 @@ func TestRetryableReadsProse(t *testing.T) { } }) - mtOpts = mtest.NewOptions().ClientOptions(clientOpts).MinServerVersion("4.2"). 
- Topologies(mtest.Sharded) - - mt.RunOpts("retry on different mongos", mtOpts, func(mt *mtest.T) { - const hostCount = 3 - - hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts - require.GreaterOrEqualf(mt, len(hosts), hostCount, "test cluster must have at least 2 mongos hosts") - - // Configure a failpoint for the first mongos host. - failPoint := mtest.FailPoint{ - ConfigureFailPoint: "failCommand", - Mode: mtest.FailPointMode{ - Times: 1, + mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2") + mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) { + tests := []struct { + name string + + // Note that setting this value greater than 2 will result in false + // negatives. The current specification does not account for CSOT, which + // might allow for an "inifinite" number of retries over a period of time. + // Because of this, we only track the "previous server". + hostCount int + failpointErrorCode int32 + expectedFailCount int + expectedSuccessCount int + }{ + { + name: "retry on different mongos", + hostCount: 2, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 2, + expectedSuccessCount: 0, }, - Data: mtest.FailPointData{ - FailCommands: []string{"find"}, - ErrorCode: 11600, - CloseConnection: false, + { + name: "retry on same mongos", + hostCount: 1, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 1, + expectedSuccessCount: 1, }, } - // In order to ensure that each mongos in the hostCount-many mongos hosts - // are tried at least once (i.e. failures are deprioritized), we set a - // failpoint on all mongos hosts. The idea is that if we get hostCount-many - // failures, then by the pigeonhole principal all mongos hosts must have - // been tried. - for i := 0; i < hostCount; i++ { - mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) - mt.SetFailPoint(failPoint) - - // The automatic failpoint clearing may not clear failpoints set on - // specific hosts, so manually clear the failpoint we set on the specific - // mongos when the test is done. - defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) - defer mt.ClearFailPoints() + for _, tc := range tests { + mt.Run(tc.name, func(mt *mtest.T) { + hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts + require.GreaterOrEqualf(mt, len(hosts), tc.hostCount, + "test cluster must have at least %v mongos hosts", tc.hostCount) + + // Configure the failpoint options for each mongos. + failPoint := mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + ErrorCode: tc.failpointErrorCode, + CloseConnection: false, + }, + } + + // In order to ensure that each mongos in the hostCount-many mongos + // hosts are tried at least once (i.e. failures are deprioritized), we + // set a failpoint on all mongos hosts. The idea is that if we get + // hostCount-many failures, then by the pigeonhole principal all mongos + // hosts must have been tried. + for i := 0; i < tc.hostCount; i++ { + mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + mt.SetFailPoint(failPoint) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the + // specific mongos when the test is done. 
+ defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + defer mt.ClearFailPoints() + } + + failCount := 0 + successCount := 0 + + commandMonitor := &event.CommandMonitor{ + Failed: func(context.Context, *event.CommandFailedEvent) { + failCount++ + }, + Succeeded: func(context.Context, *event.CommandSucceededEvent) { + successCount++ + }, + } + + // Reset the client with exactly hostCount-many mongos hosts. + mt.ResetClient(options.Client(). + SetHosts(hosts[:tc.hostCount]). + SetRetryReads(true). + SetMonitor(commandMonitor)) + + mt.Coll.FindOne(context.Background(), bson.D{}) + + assert.Equal(mt, tc.expectedFailCount, failCount) + assert.Equal(mt, tc.expectedSuccessCount, successCount) + }) } - - findCommandFailedCount := 0 - - commandMonitor := &event.CommandMonitor{ - Failed: func(_ context.Context, _ *event.CommandFailedEvent) { - findCommandFailedCount++ - }, - } - - // Reset the client with exactly hostCount-many mongos hosts. - mt.ResetClient(options.Client(). - SetHosts(hosts[:hostCount]). - SetTimeout(10000 * time.Millisecond). - SetRetryReads(true). - SetMonitor(commandMonitor)) - - // ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - // defer cancel() - - err := mt.Coll.FindOne(context.Background(), bson.D{}).Err() - fmt.Println("err: ", err) - - assert.Equal(mt, hostCount, findCommandFailedCount) - - // Create a connection to a database for each mongos host - // mongosOpts := options.Client().ApplyURI(hosts[0]) - - // firstMongos, err := mongo.Connect(context.Background(), mongosOpts) - // require.NoError(mt, err) - - // result := firstMongos.Database("admin").RunCommand(context.Background(), doc) - // require.NoError(mt, result.Err()) - - // secondMongos, err := mongo.Connect(context.Background(), mongosOpts) - // require.NoError(mt, err) - - // result = secondMongos.Database("admin").RunCommand(context.Background(), doc) - // require.NoError(mt, result.Err()) }) } diff --git a/mongo/integration/retryable_writes_prose_test.go b/mongo/integration/retryable_writes_prose_test.go index e731a09ade..23d5d8896d 100644 --- a/mongo/integration/retryable_writes_prose_test.go +++ b/mongo/integration/retryable_writes_prose_test.go @@ -285,4 +285,96 @@ func TestRetryableWritesProse(t *testing.T) { // Assert that the "ShutdownInProgress" error is returned. require.True(mt, err.(mongo.WriteException).HasErrorCode(int(shutdownInProgressErrorCode))) }) + + mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2") + mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) { + tests := []struct { + name string + + // Note that setting this value greater than 2 will result in false + // negatives. The current specification does not account for CSOT, which + // might allow for an "inifinite" number of retries over a period of time. + // Because of this, we only track the "previous server". 
+ hostCount int + failpointErrorCode int32 + expectedFailCount int + expectedSuccessCount int + }{ + { + name: "retry on different mongos", + hostCount: 2, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 2, + expectedSuccessCount: 0, + }, + { + name: "retry on same mongos", + hostCount: 1, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 1, + expectedSuccessCount: 1, + }, + } + + for _, tc := range tests { + mt.Run(tc.name, func(mt *mtest.T) { + hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts + require.GreaterOrEqualf(mt, len(hosts), tc.hostCount, + "test cluster must have at least %v mongos hosts", tc.hostCount) + + // Configure the failpoint options for each mongos. + failPoint := mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"insert"}, + ErrorLabels: &[]string{"RetryableWriteError"}, + ErrorCode: tc.failpointErrorCode, + CloseConnection: false, + }, + } + + // In order to ensure that each mongos in the hostCount-many mongos + // hosts are tried at least once (i.e. failures are deprioritized), we + // set a failpoint on all mongos hosts. The idea is that if we get + // hostCount-many failures, then by the pigeonhole principal all mongos + // hosts must have been tried. + for i := 0; i < tc.hostCount; i++ { + mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + mt.SetFailPoint(failPoint) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the + // specific mongos when the test is done. + defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + defer mt.ClearFailPoints() + } + + failCount := 0 + successCount := 0 + + commandMonitor := &event.CommandMonitor{ + Failed: func(context.Context, *event.CommandFailedEvent) { + failCount++ + }, + Succeeded: func(context.Context, *event.CommandSucceededEvent) { + successCount++ + }, + } + + // Reset the client with exactly hostCount-many mongos hosts. + mt.ResetClient(options.Client(). + SetHosts(hosts[:tc.hostCount]). + SetRetryWrites(true). + SetMonitor(commandMonitor)) + + mt.Coll.InsertOne(context.Background(), bson.D{}) + + assert.Equal(mt, tc.expectedFailCount, failCount) + assert.Equal(mt, tc.expectedSuccessCount, successCount) + }) + } + }) } diff --git a/mongo/util.go b/mongo/util.go index a15b21cdfb..270fa24a25 100644 --- a/mongo/util.go +++ b/mongo/util.go @@ -5,5 +5,3 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package mongo - -// NOTE: meep diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index a23aa48c7c..17890863bf 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -320,8 +320,71 @@ func (op Operation) shouldEncrypt() bool { return op.Crypt != nil && !op.Crypt.BypassAutoEncryption() } +// opServerSelector is a wrapper for the server selector that is assigned to the +// operation. The purpose of this wrapper is to filter candidates with +// operation-specific logic, such as deprioritizing failing servers. 
+type opServerSelector struct { + selector description.ServerSelector + deprioritizedServers map[address.Address]bool +} + +func (oss *opServerSelector) setDeprioritizedServers(dpa []address.Address) { + oss.deprioritizedServers = make(map[address.Address]bool) + + for _, addr := range dpa { + oss.deprioritizedServers[addr] = true + } +} + +// filterDeprioritizedServers will filter out the server candidates that have +// been deprioritized by the operation due to failure. +// +// The server selector should try to select a server that is not in the +// deprioritization list. However, if this is not possible (e.g. there are no +// other healthy servers in the cluster), the selector may return a +// deprioritized server. +func filterDeprioritizedServers(oss opServerSelector, candidates []description.Server) []description.Server { + allowedIndexes := make([]int, 0, len(candidates)) + + // Iterate over the candidates and append them to the allowdIndexes slice if + // they are not in the deprioritizedServers list. + for i, candidate := range candidates { + if !oss.deprioritizedServers[candidate.Addr] { + allowedIndexes = append(allowedIndexes, i) + } + } + + allowed := make([]description.Server, len(allowedIndexes)) + for i, idx := range allowedIndexes { + allowed[i] = candidates[idx] + } + + // If nothing is allowed, then all available servers must have been + // deprioritized. In this case, return the candidates list as-is so that the + // selector can find a suitable server + if len(allowed) == 0 { + return candidates + } + + return allowed +} + +// SelectServer will filter candidates with operation-specific logic before +// passing them onto the user-defined or defualt selector. +func (oss opServerSelector) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if oss.selector == nil { + return []description.Server{}, nil + } + + candidates = filterDeprioritizedServers(oss, candidates) + return oss.selector.SelectServer(topo, candidates) +} + // selectServer handles performing server selection for an operation. -func (op Operation) selectServer(ctx context.Context) (Server, error) { +func (op Operation) selectServer(ctx context.Context, dpa []address.Address) (Server, error) { if err := op.Validate(); err != nil { return nil, err } @@ -338,12 +401,18 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) { }) } - return op.Deployment.SelectServer(ctx, selector) + oss := opServerSelector{selector: selector} + + if len(dpa) > 0 { + oss.setDeprioritizedServers(dpa) + } + + return op.Deployment.SelectServer(ctx, oss) } // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation. 
-func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) { - server, err := op.selectServer(ctx) +func (op Operation) getServerAndConnection(ctx context.Context, dpa []address.Address) (Server, Connection, error) { + server, err := op.selectServer(ctx, dpa) if err != nil { if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { @@ -419,11 +488,6 @@ func (op Operation) Execute(ctx context.Context) error { if err != nil { return err } - //fmt.Println("") - //_, deadlineSet := ctx.Deadline() - //fmt.Println("!deadlineSet: ", !deadlineSet) - //fmt.Println("op.Timeout != nil : ", op.Timeout != nil) - //fmt.Println("!csot.IsTimeoutContext(ctx): ", !csot.IsTimeoutContext(ctx)) // If no deadline is set on the passed-in context, op.Timeout is set, and context is not already // a Timeout context, honor op.Timeout in new Timeout context for operation execution. @@ -481,10 +545,15 @@ func (op Operation) Execute(ctx context.Context) error { first := true currIndex := 0 + // deprioritizedServers are a running list of servers that should be + // deprioritized during server selection. Per the specifications, we should + // only ever deprioritize the "previous server". Therefore, this list is + // initialized with a length of 1. + deprioritizedServers := []address.Address{} + // resetForRetry records the error that caused the retry, decrements retries, and resets the // retry loop variables to request a new server and a new connection for the next attempt. resetForRetry := func(err error) { - fmt.Println("resetting for retry: ", retries) retries-- prevErr = err @@ -510,13 +579,14 @@ func (op Operation) Execute(ctx context.Context) error { // If we got a connection, close it immediately to release pool resources for // subsequent retries. if conn != nil { + // Empty the deprioritizedServer list and + if desc := conn.Description; desc != nil { + deprioritizedServers = []address.Address{conn.Description().Addr} + } + conn.Close() } - // NOTE: We can no longer just nullify the server here. We need to - // NOTE: "remember" the server and pass it down to the server selector so - // NOTE: that it can be ignored when trying to select a new server. - // Set the server and connection to nil to request a new server and connection. srvr = nil conn = nil } @@ -540,7 +610,7 @@ func (op Operation) Execute(ctx context.Context) error { if srvr == nil || conn == nil { // NOTE: Each time a "retry" occurs, the server will be reset and this // NOTE: branch will be entered. - srvr, conn, err = op.getServerAndConnection(ctx) + srvr, conn, err = op.getServerAndConnection(ctx, deprioritizedServers) if err != nil { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index f87cc7bd1b..8299cc13db 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -1,5 +1,5 @@ // Copyright (C) MongoDB, Inc. 2022-present. -// +//x/mongo/driver/operation_test // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. 
You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -62,7 +62,7 @@ func TestOperation(t *testing.T) { t.Run("selectServer", func(t *testing.T) { t.Run("returns validation error", func(t *testing.T) { op := &Operation{} - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), nil) if err == nil { t.Error("Expected a validation error from selectServer, but got ") } @@ -76,10 +76,15 @@ func TestOperation(t *testing.T) { Database: "testing", Selector: want, } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), nil) noerr(t, err) - got := d.params.selector - if !cmp.Equal(got, want) { + //got := d.params.selector + + // Assert the the selector is an operation selector wrapper. + got, ok := d.params.selector.(*opServerSelector) + assert.True(t, ok) + + if !cmp.Equal(got.selector, want) { t.Errorf("Did not get expected server selector. got %v; want %v", got, want) } }) @@ -90,7 +95,7 @@ func TestOperation(t *testing.T) { Deployment: d, Database: "testing", } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), nil) noerr(t, err) if d.params.selector == nil { t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed .") @@ -876,3 +881,124 @@ func TestDecodeOpReply(t *testing.T) { assert.Equal(t, []bsoncore.Document(nil), reply.documents) }) } + +func TestFilterDeprioritizedServers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + oss opServerSelector + candidates []description.Server + want []description.Server + }{ + { + name: "empty", + oss: opServerSelector{}, + candidates: []description.Server{}, + want: []description.Server{}, + }, + { + name: "nil candidates", + oss: opServerSelector{}, + candidates: nil, + want: []description.Server{}, + }, + { + name: "nil deprioritized server list", + oss: opServerSelector{deprioritizedServers: nil}, + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + want: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + }, + { + name: "deprioritize single server candidate list", + oss: opServerSelector{ + deprioritizedServers: map[address.Address]bool{ + "mongodb://localhost:27017": true, + }, + }, + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + want: []description.Server{ + // Since all available servers were deprioritized, then the selector + // should return all candidates. 
+ { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + }, + { + name: "depriotirize one server in multi server candidate list", + oss: opServerSelector{ + deprioritizedServers: map[address.Address]bool{ + "mongodb://localhost:27017": true, + }, + }, + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + { + Addr: address.Address("mongodb://localhost:27018"), + }, + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + want: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27018"), + }, + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + }, + { + name: "depriotirize multiple servers in multi server candidate list", + oss: opServerSelector{ + deprioritizedServers: map[address.Address]bool{ + "mongodb://localhost:27017": true, + "mongodb://localhost:27018": true, + }, + }, + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + { + Addr: address.Address("mongodb://localhost:27018"), + }, + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + want: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + }, + } + + for _, tc := range tests { + tc := tc // Capture the range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := filterDeprioritizedServers(tc.oss, tc.candidates) + assert.ElementsMatch(t, got, tc.want) + }) + } +} From 44dd3f4075e0bcb555e5effe8c9295eccd1b21d6 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 17 Aug 2023 19:44:42 -0600 Subject: [PATCH 3/8] GODRIVER-2101 Revert unecessary changes --- x/mongo/driver/operation.go | 1 + x/mongo/driver/operation_test.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 17890863bf..ac6b71f259 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -587,6 +587,7 @@ func (op Operation) Execute(ctx context.Context) error { conn.Close() } + // Set the server and connection to nil to request a new server and connection. srvr = nil conn = nil } diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 8299cc13db..bb548fe89a 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -1,5 +1,5 @@ // Copyright (C) MongoDB, Inc. 2022-present. -//x/mongo/driver/operation_test +// // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 From ecea7518432a5820ba0d1c94a8b101fa376a2835 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 17 Aug 2023 19:46:36 -0600 Subject: [PATCH 4/8] GODRIVER-2101 revert changes to collection and cursor --- mongo/collection.go | 143 +++++++++++++++++++++++++------------------- mongo/cursor.go | 4 +- 2 files changed, 85 insertions(+), 62 deletions(-) diff --git a/mongo/collection.go b/mongo/collection.go index 9826107072..6abbea9792 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -185,8 +185,8 @@ func (coll *Collection) Database() *Database { // // The opts parameter can be used to specify options for the operation (see the options.BulkWriteOptions documentation.) 
func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, - opts ...*options.BulkWriteOptions, -) (*BulkWriteResult, error) { + opts ...*options.BulkWriteOptions) (*BulkWriteResult, error) { + if len(models) == 0 { return nil, ErrEmptySlice } @@ -242,8 +242,8 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, } func (coll *Collection) insert(ctx context.Context, documents []interface{}, - opts ...*options.InsertManyOptions, -) ([]interface{}, error) { + opts ...*options.InsertManyOptions) ([]interface{}, error) { + if ctx == nil { ctx = context.Background() } @@ -343,8 +343,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertOne(ctx context.Context, document interface{}, - opts ...*options.InsertOneOptions, -) (*InsertOneResult, error) { + opts ...*options.InsertOneOptions) (*InsertOneResult, error) { + ioOpts := options.MergeInsertOneOptions(opts...) imOpts := options.InsertMany() @@ -375,8 +375,8 @@ func (coll *Collection) InsertOne(ctx context.Context, document interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, - opts ...*options.InsertManyOptions, -) (*InsertManyResult, error) { + opts ...*options.InsertManyOptions) (*InsertManyResult, error) { + if len(documents) == 0 { return nil, ErrEmptySlice } @@ -410,8 +410,8 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, } func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOne bool, expectedRr returnResult, - opts ...*options.DeleteOptions, -) (*DeleteResult, error) { + opts ...*options.DeleteOptions) (*DeleteResult, error) { + if ctx == nil { ctx = context.Background() } @@ -514,8 +514,8 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/delete/. func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, - opts ...*options.DeleteOptions, -) (*DeleteResult, error) { + opts ...*options.DeleteOptions) (*DeleteResult, error) { + return coll.delete(ctx, filter, true, rrOne, opts...) } @@ -530,14 +530,14 @@ func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/delete/. func (coll *Collection) DeleteMany(ctx context.Context, filter interface{}, - opts ...*options.DeleteOptions, -) (*DeleteResult, error) { + opts ...*options.DeleteOptions) (*DeleteResult, error) { + return coll.delete(ctx, filter, false, rrMany, opts...) } func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Document, update interface{}, multi bool, - expectedRr returnResult, checkDollarKey bool, opts ...*options.UpdateOptions, -) (*UpdateResult, error) { + expectedRr returnResult, checkDollarKey bool, opts ...*options.UpdateOptions) (*UpdateResult, error) { + if ctx == nil { ctx = context.Background() } @@ -648,8 +648,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. 
func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update interface{}, - opts ...*options.UpdateOptions, -) (*UpdateResult, error) { + opts ...*options.UpdateOptions) (*UpdateResult, error) { if id == nil { return nil, ErrNilValue } @@ -671,8 +670,8 @@ func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update i // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, update interface{}, - opts ...*options.UpdateOptions, -) (*UpdateResult, error) { + opts ...*options.UpdateOptions) (*UpdateResult, error) { + if ctx == nil { ctx = context.Background() } @@ -699,8 +698,8 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, update interface{}, - opts ...*options.UpdateOptions, -) (*UpdateResult, error) { + opts ...*options.UpdateOptions) (*UpdateResult, error) { + if ctx == nil { ctx = context.Background() } @@ -727,8 +726,8 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, - replacement interface{}, opts ...*options.ReplaceOptions, -) (*UpdateResult, error) { + replacement interface{}, opts ...*options.ReplaceOptions) (*UpdateResult, error) { + if ctx == nil { ctx = context.Background() } @@ -777,8 +776,7 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/aggregate/. func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, - opts ...*options.AggregateOptions, -) (*Cursor, error) { + opts ...*options.AggregateOptions) (*Cursor, error) { a := aggregateParams{ ctx: ctx, pipeline: pipeline, @@ -954,8 +952,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { // // The opts parameter can be used to specify options for the operation (see the options.CountOptions documentation). func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, - opts ...*options.CountOptions, -) (int64, error) { + opts ...*options.CountOptions) (int64, error) { + if ctx == nil { ctx = context.Background() } @@ -1039,8 +1037,8 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/count/. func (coll *Collection) EstimatedDocumentCount(ctx context.Context, - opts ...*options.EstimatedDocumentCountOptions, -) (int64, error) { + opts ...*options.EstimatedDocumentCountOptions) (int64, error) { + if ctx == nil { ctx = context.Background() } @@ -1101,8 +1099,8 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/distinct/. 
func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter interface{}, - opts ...*options.DistinctOptions, -) ([]interface{}, error) { + opts ...*options.DistinctOptions) ([]interface{}, error) { + if ctx == nil { ctx = context.Background() } @@ -1192,8 +1190,8 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) Find(ctx context.Context, filter interface{}, - opts ...*options.FindOptions, -) (cur *Cursor, err error) { + opts ...*options.FindOptions) (cur *Cursor, err error) { + if ctx == nil { ctx = context.Background() } @@ -1351,7 +1349,6 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if coll.client.retryReads { retry = driver.RetryOncePerCommand } - op = op.Retry(retry) if err = op.Execute(ctx); err != nil { @@ -1375,8 +1372,8 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) FindOne(ctx context.Context, filter interface{}, - opts ...*options.FindOneOptions, -) *SingleResult { + opts ...*options.FindOneOptions) *SingleResult { + if ctx == nil { ctx = context.Background() } @@ -1489,8 +1486,8 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}, - opts ...*options.FindOneAndDeleteOptions, -) *SingleResult { + opts ...*options.FindOneAndDeleteOptions) *SingleResult { + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} @@ -1561,8 +1558,8 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{}, - replacement interface{}, opts ...*options.FindOneAndReplaceOptions, -) *SingleResult { + replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *SingleResult { + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} @@ -1651,8 +1648,8 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{}, - update interface{}, opts ...*options.FindOneAndUpdateOptions, -) *SingleResult { + update interface{}, opts ...*options.FindOneAndUpdateOptions) *SingleResult { + if ctx == nil { ctx = context.Background() } @@ -1755,8 +1752,8 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} // The opts parameter can be used to specify options for change stream creation (see the options.ChangeStreamOptions // documentation). 
func (coll *Collection) Watch(ctx context.Context, pipeline interface{}, - opts ...*options.ChangeStreamOptions, -) (*ChangeStream, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { + csConfig := changeStreamConfig{ readConcern: coll.readConcern, readPreference: coll.readPreference, @@ -1871,26 +1868,52 @@ func (coll *Collection) drop(ctx context.Context) error { return nil } -// makePinnedSelector makes a selector for a pinned session with a pinned server. Will attempt to do server selection on -// the pinned server but if that fails it will go through a list of default selectors -func makePinnedSelector(sess *session.Client, defaultSelector description.ServerSelector) description.ServerSelectorFunc { - return func(t description.Topology, svrs []description.Server) ([]description.Server, error) { - if sess != nil && sess.PinnedServer != nil { - // If there is a pinned server, try to find it in the list of candidates. - for _, candidate := range svrs { - if candidate.Addr == sess.PinnedServer.Addr { - return []description.Server{candidate}, nil - } - } +type pinnedServerSelector struct { + stringer fmt.Stringer + fallback description.ServerSelector + session *session.Client +} + +func (pss pinnedServerSelector) String() string { + if pss.stringer == nil { + return "" + } + + return pss.stringer.String() +} - return nil, nil +func (pss pinnedServerSelector) SelectServer( + t description.Topology, + svrs []description.Server, +) ([]description.Server, error) { + if pss.session != nil && pss.session.PinnedServer != nil { + // If there is a pinned server, try to find it in the list of candidates. + for _, candidate := range svrs { + if candidate.Addr == pss.session.PinnedServer.Addr { + return []description.Server{candidate}, nil + } } - return defaultSelector.SelectServer(t, svrs) + return nil, nil + } + + return pss.fallback.SelectServer(t, svrs) +} + +func makePinnedSelector(sess *session.Client, fallback description.ServerSelector) description.ServerSelector { + pss := pinnedServerSelector{ + session: sess, + fallback: fallback, + } + + if srvSelectorStringer, ok := fallback.(fmt.Stringer); ok { + pss.stringer = srvSelectorStringer } + + return pss } -func makeReadPrefSelector(sess *session.Client, selector description.ServerSelector, localThreshold time.Duration) description.ServerSelectorFunc { +func makeReadPrefSelector(sess *session.Client, selector description.ServerSelector, localThreshold time.Duration) description.ServerSelector { if sess != nil && sess.TransactionRunning() { selector = description.CompositeSelector([]description.ServerSelector{ description.ReadPrefSelector(sess.CurrentRp), @@ -1901,7 +1924,7 @@ func makeReadPrefSelector(sess *session.Client, selector description.ServerSelec return makePinnedSelector(sess, selector) } -func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, localThreshold time.Duration) description.ServerSelectorFunc { +func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, localThreshold time.Duration) description.ServerSelector { if sess != nil && sess.TransactionRunning() { // Use current transaction's read preference if available rp = sess.CurrentRp diff --git a/mongo/cursor.go b/mongo/cursor.go index f213771807..d2228ed9c4 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -345,8 +345,8 @@ func (c *Cursor) RemainingBatchLength() int { // addFromBatch adds all documents from batch to sliceVal starting at the given index. 
It returns the new slice value, // the next empty index in the slice, and an error if one occurs. func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, batch *bsoncore.DocumentSequence, - index int, -) (reflect.Value, int, error) { + index int) (reflect.Value, int, error) { + docs, err := batch.Documents() if err != nil { return sliceVal, index, err From 3e242f801b604e8ac5751eb399635b16f14700cf Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Mon, 21 Aug 2023 17:31:31 -0600 Subject: [PATCH 5/8] GODRIVER-2101 Apply opServerSelector --- x/mongo/driver/batch_cursor.go | 8 ++-- x/mongo/driver/operation.go | 82 ++++++++++++++++---------------- x/mongo/driver/operation_test.go | 51 ++++++++++---------- 3 files changed, 70 insertions(+), 71 deletions(-) diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 3d5b78d54a..fefcfdb475 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -508,11 +508,9 @@ type loadBalancedCursorDeployment struct { conn PinnedConnection } -var ( - _ Deployment = (*loadBalancedCursorDeployment)(nil) - _ Server = (*loadBalancedCursorDeployment)(nil) - _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) -) +var _ Deployment = (*loadBalancedCursorDeployment)(nil) +var _ Server = (*loadBalancedCursorDeployment)(nil) +var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ description.ServerSelector) (Server, error) { return lbcd, nil diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 82bdcab48c..950732b714 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -321,22 +321,6 @@ func (op Operation) shouldEncrypt() bool { return op.Crypt != nil && !op.Crypt.BypassAutoEncryption() } -// opServerSelector is a wrapper for the server selector that is assigned to the -// operation. The purpose of this wrapper is to filter candidates with -// operation-specific logic, such as deprioritizing failing servers. -type opServerSelector struct { - selector description.ServerSelector - deprioritizedServers map[address.Address]bool -} - -func (oss *opServerSelector) setDeprioritizedServers(dpa []address.Address) { - oss.deprioritizedServers = make(map[address.Address]bool) - - for _, addr := range dpa { - oss.deprioritizedServers[addr] = true - } -} - // filterDeprioritizedServers will filter out the server candidates that have // been deprioritized by the operation due to failure. // @@ -344,13 +328,22 @@ func (oss *opServerSelector) setDeprioritizedServers(dpa []address.Address) { // deprioritization list. However, if this is not possible (e.g. there are no // other healthy servers in the cluster), the selector may return a // deprioritized server. -func filterDeprioritizedServers(oss opServerSelector, candidates []description.Server) []description.Server { +func filterDeprioritizedServers(candidates, deprioritized []description.Server) []description.Server { + if len(deprioritized) == 0 { + return candidates + } + + dpaSet := make(map[address.Address]*description.Server) + for _, srv := range deprioritized { + dpaSet[srv.Addr] = &srv + } + allowedIndexes := make([]int, 0, len(candidates)) // Iterate over the candidates and append them to the allowdIndexes slice if // they are not in the deprioritizedServers list. 
for i, candidate := range candidates { - if !oss.deprioritizedServers[candidate.Addr] { + if srv := dpaSet[candidate.Addr]; srv == nil || !srv.Equal(candidate) { allowedIndexes = append(allowedIndexes, i) } } @@ -370,22 +363,32 @@ func filterDeprioritizedServers(oss opServerSelector, candidates []description.S return allowed } +// opServerSelector is a wrapper for the server selector that is assigned to the +// operation. The purpose of this wrapper is to filter candidates with +// operation-specific logic, such as deprioritizing failing servers. +type opServerSelector struct { + selector description.ServerSelector + deprioritizedServers []description.Server +} + // SelectServer will filter candidates with operation-specific logic before // passing them onto the user-defined or defualt selector. -func (oss opServerSelector) SelectServer( +func (oss *opServerSelector) SelectServer( topo description.Topology, candidates []description.Server, ) ([]description.Server, error) { - if oss.selector == nil { - return []description.Server{}, nil + selectedServers, err := oss.selector.SelectServer(topo, candidates) + if err != nil { + return nil, err } - candidates = filterDeprioritizedServers(oss, candidates) - return oss.selector.SelectServer(topo, candidates) + filteredServers := filterDeprioritizedServers(selectedServers, oss.deprioritizedServers) + + return filteredServers, nil } // selectServer handles performing server selection for an operation. -func (op Operation) selectServer(ctx context.Context, dpa []address.Address) (Server, error) { +func (op Operation) selectServer(ctx context.Context, deprioritized []description.Server) (Server, error) { if err := op.Validate(); err != nil { return nil, err } @@ -402,10 +405,9 @@ func (op Operation) selectServer(ctx context.Context, dpa []address.Address) (Se }) } - oss := opServerSelector{selector: selector} - - if len(dpa) > 0 { - oss.setDeprioritizedServers(dpa) + oss := &opServerSelector{ + selector: selector, + deprioritizedServers: deprioritized, } ctx = logger.WithOperationName(ctx, op.Name) @@ -415,8 +417,11 @@ func (op Operation) selectServer(ctx context.Context, dpa []address.Address) (Se } // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation. -func (op Operation) getServerAndConnection(ctx context.Context, dpa []address.Address) (Server, Connection, error) { - server, err := op.selectServer(ctx, dpa) +func (op Operation) getServerAndConnection( + ctx context.Context, + deprioritized []description.Server, +) (Server, Connection, error) { + server, err := op.selectServer(ctx, deprioritized) if err != nil { if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { @@ -551,9 +556,8 @@ func (op Operation) Execute(ctx context.Context) error { // deprioritizedServers are a running list of servers that should be // deprioritized during server selection. Per the specifications, we should - // only ever deprioritize the "previous server". Therefore, this list is - // initialized with a length of 1. - deprioritizedServers := []address.Address{} + // only ever deprioritize the "previous server". + var deprioritizedServers []description.Server // resetForRetry records the error that caused the retry, decrements retries, and resets the // retry loop variables to request a new server and a new connection for the next attempt. 
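
To make the retry flow above concrete: on a sharded deployment, each failed attempt records the previous server's description, and the operation-level selector wrapper filters that server out of the next attempt's candidates, falling back to the full candidate list when nothing else is available. The following self-contained Go sketch illustrates that composition; the Server and Selector types, the allSelector, and the mongos addresses are simplified placeholders for illustration, not the driver's actual description.Server or description.ServerSelector API.

package main

import "fmt"

// Server is a simplified stand-in for the driver's server description.
type Server struct{ Addr string }

// Selector is a simplified stand-in for a server selector.
type Selector interface {
	Select(candidates []Server) []Server
}

// allSelector stands in for whatever selector the operation was configured
// with; it simply keeps every candidate.
type allSelector struct{}

func (allSelector) Select(candidates []Server) []Server { return candidates }

// opSelector wraps a base selector and filters out servers that failed on a
// previous attempt, unless doing so would leave no candidates at all.
type opSelector struct {
	base          Selector
	deprioritized map[string]bool
}

func (s opSelector) Select(candidates []Server) []Server {
	selected := s.base.Select(candidates)

	allowed := make([]Server, 0, len(selected))
	for _, c := range selected {
		if !s.deprioritized[c.Addr] {
			allowed = append(allowed, c)
		}
	}

	// If every selected server was deprioritized, fall back to the original
	// selection rather than returning nothing.
	if len(allowed) == 0 {
		return selected
	}
	return allowed
}

func main() {
	candidates := []Server{{Addr: "mongos-a:27017"}, {Addr: "mongos-b:27017"}}

	// The first attempt failed against mongos-a, so it is deprioritized on retry.
	sel := opSelector{
		base:          allSelector{},
		deprioritized: map[string]bool{"mongos-a:27017": true},
	}
	fmt.Println(sel.Select(candidates)) // [{mongos-b:27017}]

	// When the failed server is the only candidate, it is still returned.
	fmt.Println(sel.Select([]Server{{Addr: "mongos-a:27017"}})) // [{mongos-a:27017}]
}

Keeping the fallback inside the wrapper means callers never have to special-case the situation where only the failed server remains, which is exactly the pigeonhole behavior the expanded prose test exercises.
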
@@ -580,12 +584,12 @@ func (op Operation) Execute(ctx context.Context) error { } } - // If we got a connection, close it immediately to release pool resources for - // subsequent retries. + // If we got a connection, close it immediately to release pool resources + // for subsequent retries. + if conn != nil { - // Empty the deprioritizedServer list and - if desc := conn.Description; desc != nil { - deprioritizedServers = []address.Address{conn.Description().Addr} + if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded { + deprioritizedServers = []description.Server{conn.Description()} } conn.Close() @@ -615,8 +619,6 @@ func (op Operation) Execute(ctx context.Context) error { // If the server or connection are nil, try to select a new server and get a new connection. if srvr == nil || conn == nil { - // NOTE: Each time a "retry" occurs, the server will be reset and this - // NOTE: branch will be entered. srvr, conn, err = op.getServerAndConnection(ctx, deprioritizedServers) if err != nil { // If the returned error is retryable and there are retries remaining (negative diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 9f0f167166..c308f97241 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/internal/handshake" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/uuid" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" @@ -78,14 +79,13 @@ func TestOperation(t *testing.T) { } _, err := op.selectServer(context.Background(), nil) noerr(t, err) - //got := d.params.selector // Assert the the selector is an operation selector wrapper. - got, ok := d.params.selector.(*opServerSelector) - assert.True(t, ok) + oss, ok := d.params.selector.(*opServerSelector) + require.True(t, ok) - if !cmp.Equal(got.selector, want) { - t.Errorf("Did not get expected server selector. got %v; want %v", got, want) + if !cmp.Equal(oss.selector, want) { + t.Errorf("Did not get expected server selector. 
got %v; want %v", oss.selector, want) } }) t.Run("uses a default server selector", func(t *testing.T) { @@ -890,26 +890,23 @@ func TestFilterDeprioritizedServers(t *testing.T) { t.Parallel() tests := []struct { - name string - oss opServerSelector - candidates []description.Server - want []description.Server + name string + deprioritized []description.Server + candidates []description.Server + want []description.Server }{ { name: "empty", - oss: opServerSelector{}, candidates: []description.Server{}, want: []description.Server{}, }, { name: "nil candidates", - oss: opServerSelector{}, candidates: nil, want: []description.Server{}, }, { name: "nil deprioritized server list", - oss: opServerSelector{deprioritizedServers: nil}, candidates: []description.Server{ { Addr: address.Address("mongodb://localhost:27017"), @@ -923,12 +920,12 @@ func TestFilterDeprioritizedServers(t *testing.T) { }, { name: "deprioritize single server candidate list", - oss: opServerSelector{ - deprioritizedServers: map[address.Address]bool{ - "mongodb://localhost:27017": true, + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), }, }, - candidates: []description.Server{ + deprioritized: []description.Server{ { Addr: address.Address("mongodb://localhost:27017"), }, @@ -943,11 +940,6 @@ func TestFilterDeprioritizedServers(t *testing.T) { }, { name: "depriotirize one server in multi server candidate list", - oss: opServerSelector{ - deprioritizedServers: map[address.Address]bool{ - "mongodb://localhost:27017": true, - }, - }, candidates: []description.Server{ { Addr: address.Address("mongodb://localhost:27017"), @@ -959,6 +951,11 @@ func TestFilterDeprioritizedServers(t *testing.T) { Addr: address.Address("mongodb://localhost:27019"), }, }, + deprioritized: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, want: []description.Server{ { Addr: address.Address("mongodb://localhost:27018"), @@ -970,10 +967,12 @@ func TestFilterDeprioritizedServers(t *testing.T) { }, { name: "depriotirize multiple servers in multi server candidate list", - oss: opServerSelector{ - deprioritizedServers: map[address.Address]bool{ - "mongodb://localhost:27017": true, - "mongodb://localhost:27018": true, + deprioritized: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + { + Addr: address.Address("mongodb://localhost:27018"), }, }, candidates: []description.Server{ @@ -1001,7 +1000,7 @@ func TestFilterDeprioritizedServers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - got := filterDeprioritizedServers(tc.oss, tc.candidates) + got := filterDeprioritizedServers(tc.candidates, tc.deprioritized) assert.ElementsMatch(t, got, tc.want) }) } From 4787e2041804fbf86f5f7c0271044a9577324572 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Tue, 22 Aug 2023 10:42:35 -0600 Subject: [PATCH 6/8] GODRIVER-2101 Fix static analysis errors --- mongo/integration/retryable_writes_prose_test.go | 2 +- x/mongo/driver/operation.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mongo/integration/retryable_writes_prose_test.go b/mongo/integration/retryable_writes_prose_test.go index 550918a19c..1c8d353f14 100644 --- a/mongo/integration/retryable_writes_prose_test.go +++ b/mongo/integration/retryable_writes_prose_test.go @@ -369,7 +369,7 @@ func TestRetryableWritesProse(t *testing.T) { SetRetryWrites(true). 
SetMonitor(commandMonitor)) - mt.Coll.InsertOne(context.Background(), bson.D{}) + _, _ = mt.Coll.InsertOne(context.Background(), bson.D{}) assert.Equal(mt, tc.expectedFailCount, failCount) assert.Equal(mt, tc.expectedSuccessCount, successCount) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 950732b714..84c94bacf1 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -334,8 +334,8 @@ func filterDeprioritizedServers(candidates, deprioritized []description.Server) } dpaSet := make(map[address.Address]*description.Server) - for _, srv := range deprioritized { - dpaSet[srv.Addr] = &srv + for i, srv := range deprioritized { + dpaSet[srv.Addr] = &deprioritized[i] } allowedIndexes := make([]int, 0, len(candidates)) @@ -372,7 +372,7 @@ type opServerSelector struct { } // SelectServer will filter candidates with operation-specific logic before -// passing them onto the user-defined or defualt selector. +// passing them onto the user-defined or default selector. func (oss *opServerSelector) SelectServer( topo description.Topology, candidates []description.Server, From d189587ae5370b1c08a12e75e0d6097e153f1189 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Tue, 22 Aug 2023 14:49:28 -0600 Subject: [PATCH 7/8] GODRIVER-2101 Remove empty line --- x/mongo/driver/operation.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 84c94bacf1..8dff264fab 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -586,8 +586,9 @@ func (op Operation) Execute(ctx context.Context) error { // If we got a connection, close it immediately to release pool resources // for subsequent retries. - if conn != nil { + // If we are dealing with a sharded cluster, then mark the failed server + // as "deprioritized". if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded { deprioritizedServers = []description.Server{conn.Description()} } From 96c42ef9d39678fb9127adacdff083c4f43e999a Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 30 Aug 2023 19:19:48 -0600 Subject: [PATCH 8/8] GODRIVER-2101 Use map 'ok' value --- x/mongo/driver/operation.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 8dff264fab..4c759c1505 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -338,21 +338,16 @@ func filterDeprioritizedServers(candidates, deprioritized []description.Server) dpaSet[srv.Addr] = &deprioritized[i] } - allowedIndexes := make([]int, 0, len(candidates)) + allowed := []description.Server{} // Iterate over the candidates and append them to the allowdIndexes slice if // they are not in the deprioritizedServers list. - for i, candidate := range candidates { - if srv := dpaSet[candidate.Addr]; srv == nil || !srv.Equal(candidate) { - allowedIndexes = append(allowedIndexes, i) + for _, candidate := range candidates { + if srv, ok := dpaSet[candidate.Addr]; !ok || !srv.Equal(candidate) { + allowed = append(allowed, candidate) } } - allowed := make([]description.Server, len(allowedIndexes)) - for i, idx := range allowedIndexes { - allowed[i] = candidates[idx] - } - // If nothing is allowed, then all available servers must have been // deprioritized. 
In this case, return the candidates list as-is so that the // selector can find a suitable server
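
One detail in the final helper deserves a closer look: PATCH 6 changes dpaSet[srv.Addr] = &srv to dpaSet[srv.Addr] = &deprioritized[i]. Before Go 1.22, a range loop reuses a single loop variable, so taking its address would leave every map entry pointing at the last deprioritized server. The sketch below reproduces the shape of the final helper with a simplified, hypothetical entry type and plain struct equality in place of description.Server and its Equal method; it is an illustration of the pattern, not the driver's implementation.

package main

import "fmt"

// entry is a simplified stand-in for description.Server.
type entry struct{ Addr string }

// filter mirrors the shape of the final helper: index the deprioritized slice
// by address, then keep any candidate that either has no entry in the map or
// does not compare equal to the deprioritized entry at the same address.
func filter(candidates, deprioritized []entry) []entry {
	if len(deprioritized) == 0 {
		return candidates
	}

	set := make(map[string]*entry)
	for i, e := range deprioritized {
		// &deprioritized[i], not &e: before Go 1.22 the loop variable e is
		// reused on every iteration, so &e would make every map value point
		// at the same element.
		set[e.Addr] = &deprioritized[i]
	}

	allowed := []entry{}
	for _, c := range candidates {
		// The real code uses description.Server.Equal; plain struct equality
		// is enough for this sketch.
		if d, ok := set[c.Addr]; !ok || *d != c {
			allowed = append(allowed, c)
		}
	}

	// If everything was deprioritized, return the candidates unchanged so a
	// server can still be selected.
	if len(allowed) == 0 {
		return candidates
	}
	return allowed
}

func main() {
	candidates := []entry{{Addr: "a:27017"}, {Addr: "b:27017"}}
	fmt.Println(filter(candidates, []entry{{Addr: "a:27017"}})) // only b:27017
	fmt.Println(filter(candidates, candidates))                 // both, via the fallback
}

Run as-is, the first call drops the deprioritized address while the second returns the candidates unchanged, because filtering them all out would leave nothing for the selector to choose from.
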