Skip to content

Commit

Permalink
GODRIVER-3054 Handshake connection should not use legacy for LB (mong…
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez authored Dec 4, 2023
1 parent 919d4ae commit a8fa12a
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 28 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ evg-test-load-balancers:
go test $(BUILD_TAGS) ./mongo/integration -run TestChangeStreamSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestInitialDNSSeedlistDiscoverySpec/load_balanced -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancedConnectionHandshake -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite

.PHONY: evg-test-search-index
Expand Down
21 changes: 0 additions & 21 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,27 +768,6 @@ func TestClient(t *testing.T) {
"expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
}
})

// Test that OP_MSG is used for handshakes when loadBalanced is true.
opMsgLBOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("5.0").Topologies(mtest.LoadBalanced)
mt.RunOpts("OP_MSG used for handshakes when loadBalanced is true", opMsgLBOpts, func(mt *mtest.T) {
err := mt.Client.Ping(context.Background(), mtest.PrimaryRp)
assert.Nil(mt, err, "Ping error: %v", err)

msgPairs := mt.GetProxiedMessages()
assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs))

// First three messages should be connection handshakes: one for the heartbeat connection, another for the
// application connection, and a final one for the RTT monitor connection.
for idx, pair := range msgPairs[:3] {
assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx,
pair.CommandName)

// Assert that appended OpCode is OP_MSG when loadBalanced is true.
assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode,
"expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
}
})
}

func TestClient_BSONOptions(t *testing.T) {
Expand Down
51 changes: 51 additions & 0 deletions mongo/integration/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/version"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

func TestHandshakeProse(t *testing.T) {
Expand Down Expand Up @@ -199,3 +200,53 @@ func TestHandshakeProse(t *testing.T) {
})
}
}

func TestLoadBalancedConnectionHandshake(t *testing.T) {
mt := mtest.New(t)

lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
mtest.LoadBalanced)

mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) {
// Ping the server to ensure the handshake has completed.
err := mt.Client.Ping(context.Background(), nil)
require.NoError(mt, err, "Ping error: %v", err)

messages := mt.GetProxiedMessages()
handshakeMessage := messages[:1][0]

// Per the specifications, if loadBalanced=true, drivers MUST use the hello
// command for the initial handshake and use the OP_MSG protocol.
assert.Equal(mt, "hello", handshakeMessage.CommandName)
assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode)
})

opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
mtest.ReplicaSet,
mtest.Sharded,
mtest.Single,
mtest.ShardedReplicaSet)

mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) {
// Ping the server to ensure the handshake has completed.
err := mt.Client.Ping(context.Background(), nil)
require.NoError(mt, err, "Ping error: %v", err)

messages := mt.GetProxiedMessages()
handshakeMessage := messages[:1][0]

want := wiremessage.OpQuery

hello := handshake.LegacyHello
if os.Getenv("REQUIRE_API_VERSION") == "true" {
hello = "hello"

// If the server API version is requested, then we should use OP_MSG
// regardless of the topology
want = wiremessage.OpMsg
}

assert.Equal(mt, hello, handshakeMessage.CommandName)
assert.Equal(mt, want, handshakeMessage.Sent.OpCode)
})
}
3 changes: 2 additions & 1 deletion testdata/load-balancers/sdam-error-handling.json
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@
},
"data": {
"failCommands": [
"isMaster"
"isMaster",
"hello"
],
"closeConnection": true,
"appName": "lbSDAMErrorTestClient"
Expand Down
2 changes: 1 addition & 1 deletion testdata/load-balancers/sdam-error-handling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ tests:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: [isMaster]
failCommands: [isMaster, hello]
closeConnection: true
appName: *singleClientAppName
- name: insertOne
Expand Down
10 changes: 5 additions & 5 deletions x/mongo/driver/operation/hello.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func (h *Hello) handshakeCommand(dst []byte, desc description.SelectedServer) ([
func (h *Hello) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
// Use "hello" if topology is LoadBalanced, API version is declared or server
// has responded with "helloOk". Otherwise, use legacy hello.
if desc.Kind == description.LoadBalanced || h.serverAPI != nil || desc.Server.HelloOK {
if h.loadBalanced || h.serverAPI != nil || desc.Server.HelloOK {
dst = bsoncore.AppendInt32Element(dst, "hello", 1)
} else {
dst = bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1)
Expand Down Expand Up @@ -575,8 +575,8 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti
// loadBalanced is False. If this is the case, then the drivers MUST use legacy
// hello for the first message of the initial handshake with the OP_QUERY
// protocol
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool {
return srvAPI == nil && deployment.Kind() != description.LoadBalanced
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, loadbalanced bool) bool {
return srvAPI == nil && !loadbalanced
}

func (h *Hello) createOperation() driver.Operation {
Expand All @@ -592,7 +592,7 @@ func (h *Hello) createOperation() driver.Operation {
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, h.d) {
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}

Expand All @@ -616,7 +616,7 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address,
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, deployment) {
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}

Expand Down

0 comments on commit a8fa12a

Please sign in to comment.