diff --git a/mongo/integration/handshake_test.go b/mongo/integration/handshake_test.go index fc1d25eba9..b4a40da61f 100644 --- a/mongo/integration/handshake_test.go +++ b/mongo/integration/handshake_test.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/integration/mtest" "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) func TestHandshakeProse(t *testing.T) { @@ -199,3 +200,50 @@ func TestHandshakeProse(t *testing.T) { }) } } + +func TestLoadBalancedConnectionHandshake(t *testing.T) { + mt := mtest.New(t) + + lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies( + mtest.LoadBalanced) + + mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) { + // Ping the server to ensure the handshake has completed. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + messages := mt.GetProxiedMessages() + handshakeMessage := messages[:1][0] + + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + } + + assert.Equal(mt, hello, handshakeMessage.CommandName) + assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode) + }) + + opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies( + mtest.ReplicaSet, + mtest.Sharded, + mtest.Single, + mtest.ShardedReplicaSet) + + mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) { + // Ping the server to ensure the handshake has completed. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + messages := mt.GetProxiedMessages() + handshakeMessage := messages[:1][0] + + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + } + + assert.Equal(mt, hello, handshakeMessage.CommandName) + assert.Equal(mt, wiremessage.OpQuery, handshakeMessage.Sent.OpCode) + }) +} diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 6e750fd034..52f656de94 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -575,8 +575,8 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti // loadBalanced is False. If this is the case, then the drivers MUST use legacy // hello for the first message of the initial handshake with the OP_QUERY // protocol -func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool { - return srvAPI == nil && deployment.Kind() != description.LoadBalanced +func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, loadbalanced bool) bool { + return srvAPI == nil && !loadbalanced } func (h *Hello) createOperation() driver.Operation { @@ -592,7 +592,7 @@ func (h *Hello) createOperation() driver.Operation { ServerAPI: h.serverAPI, } - if isLegacyHandshake(h.serverAPI, h.d) { + if isLegacyHandshake(h.serverAPI, h.loadBalanced) { op.Legacy = driver.LegacyHandshake } @@ -616,7 +616,7 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, ServerAPI: h.serverAPI, } - if isLegacyHandshake(h.serverAPI, deployment) { + if isLegacyHandshake(h.serverAPI, h.loadBalanced) { op.Legacy = driver.LegacyHandshake }