Skip to content

Commit

Permalink
GODRIVER-3054 Add prose tests
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Nov 29, 2023
1 parent e458f9d commit ef46a0d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
48 changes: 48 additions & 0 deletions mongo/integration/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/version"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

func TestHandshakeProse(t *testing.T) {
Expand Down Expand Up @@ -199,3 +200,50 @@ func TestHandshakeProse(t *testing.T) {
})
}
}

func TestLoadBalancedConnectionHandshake(t *testing.T) {
mt := mtest.New(t)

lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
mtest.LoadBalanced)

mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) {
// Ping the server to ensure the handshake has completed.
err := mt.Client.Ping(context.Background(), nil)
require.NoError(mt, err, "Ping error: %v", err)

messages := mt.GetProxiedMessages()
handshakeMessage := messages[:1][0]

hello := handshake.LegacyHello
if os.Getenv("REQUIRE_API_VERSION") == "true" {
hello = "hello"
}

assert.Equal(mt, hello, handshakeMessage.CommandName)
assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode)
})

opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
mtest.ReplicaSet,
mtest.Sharded,
mtest.Single,
mtest.ShardedReplicaSet)

mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) {
// Ping the server to ensure the handshake has completed.
err := mt.Client.Ping(context.Background(), nil)
require.NoError(mt, err, "Ping error: %v", err)

messages := mt.GetProxiedMessages()
handshakeMessage := messages[:1][0]

hello := handshake.LegacyHello
if os.Getenv("REQUIRE_API_VERSION") == "true" {
hello = "hello"
}

assert.Equal(mt, hello, handshakeMessage.CommandName)
assert.Equal(mt, wiremessage.OpQuery, handshakeMessage.Sent.OpCode)
})
}
8 changes: 4 additions & 4 deletions x/mongo/driver/operation/hello.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,8 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti
// loadBalanced is False. If this is the case, then the drivers MUST use legacy
// hello for the first message of the initial handshake with the OP_QUERY
// protocol
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool {
return srvAPI == nil && deployment.Kind() != description.LoadBalanced
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, loadbalanced bool) bool {
return srvAPI == nil && !loadbalanced
}

func (h *Hello) createOperation() driver.Operation {
Expand All @@ -592,7 +592,7 @@ func (h *Hello) createOperation() driver.Operation {
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, h.d) {
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}

Expand All @@ -616,7 +616,7 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address,
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, deployment) {
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}

Expand Down

0 comments on commit ef46a0d

Please sign in to comment.