diff --git a/Makefile b/Makefile index 73572ce598..72ea34f6d9 100644 --- a/Makefile +++ b/Makefile @@ -148,6 +148,7 @@ evg-test-load-balancers: go test $(BUILD_TAGS) ./internal/test/integration -run TestChangeStreamSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite go test $(BUILD_TAGS) ./internal/test/integration -run TestInitialDNSSeedlistDiscoverySpec/load_balanced -v -timeout $(TEST_TIMEOUT)s >> test.suite go test $(BUILD_TAGS) ./internal/test/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite + go test $(BUILD_TAGS) ./internal/test/integration -run TestLoadBalancedConnectionHandshake -v -timeout $(TEST_TIMEOUT)s >> test.suite go test $(BUILD_TAGS) ./internal/test/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite .PHONY: evg-test-search-index diff --git a/internal/test/integration/client_test.go b/internal/test/integration/client_test.go index 0167c18f35..bcaaec9df6 100644 --- a/internal/test/integration/client_test.go +++ b/internal/test/integration/client_test.go @@ -760,27 +760,6 @@ func TestClient(t *testing.T) { "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) } }) - - // Test that OP_MSG is used for handshakes when loadBalanced is true. - opMsgLBOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("5.0").Topologies(mtest.LoadBalanced) - mt.RunOpts("OP_MSG used for handshakes when loadBalanced is true", opMsgLBOpts, func(mt *mtest.T) { - err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) - assert.Nil(mt, err, "Ping error: %v", err) - - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs)) - - // First three messages should be connection handshakes: one for the heartbeat connection, another for the - // application connection, and a final one for the RTT monitor connection. - for idx, pair := range msgPairs[:3] { - assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx, - pair.CommandName) - - // Assert that appended OpCode is OP_MSG when loadBalanced is true. - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) - } - }) } func TestClient_BSONOptions(t *testing.T) { diff --git a/internal/test/integration/handshake_test.go b/internal/test/integration/handshake_test.go index 7535db4c6b..b897198d73 100644 --- a/internal/test/integration/handshake_test.go +++ b/internal/test/integration/handshake_test.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/internal/test/integration/mtest" "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) func TestHandshakeProse(t *testing.T) { @@ -199,3 +200,53 @@ func TestHandshakeProse(t *testing.T) { }) } } + +func TestLoadBalancedConnectionHandshake(t *testing.T) { + mt := mtest.New(t) + + lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies( + mtest.LoadBalanced) + + mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) { + // Ping the server to ensure the handshake has completed. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + messages := mt.GetProxiedMessages() + handshakeMessage := messages[:1][0] + + // Per the specifications, if loadBalanced=true, drivers MUST use the hello + // command for the initial handshake and use the OP_MSG protocol. + assert.Equal(mt, "hello", handshakeMessage.CommandName) + assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode) + }) + + opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies( + mtest.ReplicaSet, + mtest.Sharded, + mtest.Single, + mtest.ShardedReplicaSet) + + mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) { + // Ping the server to ensure the handshake has completed. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + messages := mt.GetProxiedMessages() + handshakeMessage := messages[:1][0] + + want := wiremessage.OpQuery + + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + + // If the server API version is requested, then we should use OP_MSG + // regardless of the topology + want = wiremessage.OpMsg + } + + assert.Equal(mt, hello, handshakeMessage.CommandName) + assert.Equal(mt, want, handshakeMessage.Sent.OpCode) + }) +} diff --git a/testdata/load-balancers/sdam-error-handling.json b/testdata/load-balancers/sdam-error-handling.json index c0f114cdfb..b9a11f2527 100644 --- a/testdata/load-balancers/sdam-error-handling.json +++ b/testdata/load-balancers/sdam-error-handling.json @@ -279,7 +279,8 @@ }, "data": { "failCommands": [ - "isMaster" + "isMaster", + "hello" ], "closeConnection": true, "appName": "lbSDAMErrorTestClient" diff --git a/testdata/load-balancers/sdam-error-handling.yml b/testdata/load-balancers/sdam-error-handling.yml index 0e6c8993af..0f93b8a249 100644 --- a/testdata/load-balancers/sdam-error-handling.yml +++ b/testdata/load-balancers/sdam-error-handling.yml @@ -153,7 +153,7 @@ tests: configureFailPoint: failCommand mode: { times: 1 } data: - failCommands: [isMaster] + failCommands: [isMaster, hello] closeConnection: true appName: *singleClientAppName - name: insertOne diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 6e750fd034..16f2ebf6c0 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -530,7 +530,7 @@ func (h *Hello) handshakeCommand(dst []byte, desc description.SelectedServer) ([ func (h *Hello) command(dst []byte, desc description.SelectedServer) ([]byte, error) { // Use "hello" if topology is LoadBalanced, API version is declared or server // has responded with "helloOk". Otherwise, use legacy hello. - if desc.Kind == description.LoadBalanced || h.serverAPI != nil || desc.Server.HelloOK { + if h.loadBalanced || h.serverAPI != nil || desc.Server.HelloOK { dst = bsoncore.AppendInt32Element(dst, "hello", 1) } else { dst = bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1) @@ -575,8 +575,8 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti // loadBalanced is False. If this is the case, then the drivers MUST use legacy // hello for the first message of the initial handshake with the OP_QUERY // protocol -func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool { - return srvAPI == nil && deployment.Kind() != description.LoadBalanced +func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, loadbalanced bool) bool { + return srvAPI == nil && !loadbalanced } func (h *Hello) createOperation() driver.Operation { @@ -592,7 +592,7 @@ func (h *Hello) createOperation() driver.Operation { ServerAPI: h.serverAPI, } - if isLegacyHandshake(h.serverAPI, h.d) { + if isLegacyHandshake(h.serverAPI, h.loadBalanced) { op.Legacy = driver.LegacyHandshake } @@ -616,7 +616,7 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, ServerAPI: h.serverAPI, } - if isLegacyHandshake(h.serverAPI, deployment) { + if isLegacyHandshake(h.serverAPI, h.loadBalanced) { op.Legacy = driver.LegacyHandshake }