From d7c9ff758e78d7059bb35a99794c57e9d1913b9c Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 28 Sep 2023 16:47:18 -0400 Subject: [PATCH] WIP --- benchmark/operation_test.go | 21 +++------------------ benchmark/single.go | 5 +---- mongo/integration/client_options_test.go | 4 +--- mongo/integration/mtest/mongotest.go | 9 +++------ x/mongo/driver/session/client_session.go | 16 ++++++++++------ 5 files changed, 18 insertions(+), 37 deletions(-) diff --git a/benchmark/operation_test.go b/benchmark/operation_test.go index 80f20ddc75..04e618e7bc 100644 --- a/benchmark/operation_test.go +++ b/benchmark/operation_test.go @@ -32,15 +32,10 @@ func BenchmarkClientWrite(b *testing.B) { } for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { - client, err := mongo.NewClient(bm.opt) + client, err := mongo.Connect(context.Background(), bm.opt) if err != nil { b.Fatalf("error creating client: %v", err) } - ctx := context.Background() - err = client.Connect(ctx) - if err != nil { - b.Fatalf("error connecting: %v", err) - } defer client.Disconnect(context.Background()) coll := client.Database("test").Collection("test") _, err = coll.DeleteMany(context.Background(), bson.D{}) @@ -76,15 +71,10 @@ func BenchmarkClientBulkWrite(b *testing.B) { } for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { - client, err := mongo.NewClient(bm.opt) + client, err := mongo.Connect(context.Background(), bm.opt) if err != nil { b.Fatalf("error creating client: %v", err) } - ctx := context.Background() - err = client.Connect(ctx) - if err != nil { - b.Fatalf("error connecting: %v", err) - } defer client.Disconnect(context.Background()) coll := client.Database("test").Collection("test") _, err = coll.DeleteMany(context.Background(), bson.D{}) @@ -125,15 +115,10 @@ func BenchmarkClientRead(b *testing.B) { } for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { - client, err := mongo.NewClient(bm.opt) + client, err := mongo.Connect(context.Background(), bm.opt) if err != nil { b.Fatalf("error creating client: %v", err) } - ctx := context.Background() - err = client.Connect(ctx) - if err != nil { - b.Fatalf("error connecting: %v", err) - } defer client.Disconnect(context.Background()) coll := client.Database("test").Collection("test") _, err = coll.DeleteMany(context.Background(), bson.D{}) diff --git a/benchmark/single.go b/benchmark/single.go index b85b46f34f..333a8f66be 100644 --- a/benchmark/single.go +++ b/benchmark/single.go @@ -29,13 +29,10 @@ func getClientDB(ctx context.Context) (*mongo.Database, error) { if err != nil { return nil, err } - client, err := mongo.NewClient(options.Client().ApplyURI(cs.String())) + client, err := mongo.Connect(ctx, options.Client().ApplyURI(cs.String())) if err != nil { return nil, err } - if err = client.Connect(ctx); err != nil { - return nil, err - } db := client.Database(integtest.GetDBName(cs)) return db, nil diff --git a/mongo/integration/client_options_test.go b/mongo/integration/client_options_test.go index 0fb068bc5e..43703d5e33 100644 --- a/mongo/integration/client_options_test.go +++ b/mongo/integration/client_options_test.go @@ -24,9 +24,7 @@ func TestClientOptions_CustomDialer(t *testing.T) { cs := integtest.ConnString(t) opts := options.Client().ApplyURI(cs.String()).SetDialer(td) integtest.AddTestServerAPIVersion(opts) - client, err := mongo.NewClient(opts) - require.NoError(t, err) - err = client.Connect(context.Background()) + client, err := mongo.Connect(context.Background(), opts) require.NoError(t, err) _, err = client.ListDatabases(context.Background(), bson.D{}) require.NoError(t, err) diff --git a/mongo/integration/mtest/mongotest.go b/mongo/integration/mtest/mongotest.go index 22aaa99a0a..ed7bddade7 100644 --- a/mongo/integration/mtest/mongotest.go +++ b/mongo/integration/mtest/mongotest.go @@ -692,13 +692,13 @@ func (t *T) createTestClient() { // pin to first mongos pinnedHostList := []string{testContext.connString.Hosts[0]} uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList) - t.Client, err = mongo.NewClient(uriOpts, clientOpts) + t.Client, err = mongo.Connect(context.Background(), uriOpts, clientOpts) case Mock: // clear pool monitor to avoid configuration error clientOpts.PoolMonitor = nil t.mockDeployment = newMockDeployment() clientOpts.Deployment = t.mockDeployment - t.Client, err = mongo.NewClient(clientOpts) + t.Client, err = mongo.Connect(context.Background(), clientOpts) case Proxy: t.proxyDialer = newProxyDialer() clientOpts.SetDialer(t.proxyDialer) @@ -716,14 +716,11 @@ func (t *T) createTestClient() { } // Pass in uriOpts first so clientOpts wins if there are any conflicting settings. - t.Client, err = mongo.NewClient(uriOpts, clientOpts) + t.Client, err = mongo.Connect(context.Background(), uriOpts, clientOpts) } if err != nil { t.Fatalf("error creating client: %v", err) } - if err := t.Client.Connect(context.Background()); err != nil { - t.Fatalf("error connecting client: %v", err) - } } func (t *T) createTestCollection() { diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index 32e556ff1f..4181b6d3de 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -465,13 +465,17 @@ func (c *Client) CommitTransaction() error { // w timeout of 10 seconds. This should be called after a commit transaction operation fails with a // retryable error or after a successful commit transaction operation. func (c *Client) UpdateCommitTransactionWriteConcern() { - if c.CurrentWc == nil { - c.CurrentWc = &writeconcern.WriteConcern{} - } - c.CurrentWc.W = "majority" - if c.CurrentWc.WTimeout == 0 { - c.CurrentWc.WTimeout = 10 * time.Second + wc := &writeconcern.WriteConcern{} + timeout := 10 * time.Second + if c.CurrentWc != nil { + *wc = *c.CurrentWc + if c.CurrentWc.WTimeout != 0 { + timeout = c.CurrentWc.WTimeout + } } + wc.W = "majority" + wc.WTimeout = timeout + c.CurrentWc = wc } // CheckAbortTransaction checks to see if allowed to abort transaction and returns