From 25b76854724e947a013f160dcb3db88e9ccd0da2 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 7 Sep 2023 09:00:54 -0600 Subject: [PATCH] GODRIVER-2935 Add legacy tests back to auth and client --- mongo/integration/client_test.go | 4 +-- x/mongo/driver/auth/speculative_scram_test.go | 4 +-- x/mongo/driver/auth/speculative_x509_test.go | 4 +-- x/mongo/driver/drivertest/channel_conn.go | 32 +++++++++++++++++++ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 914ca863b7..007427824b 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -733,8 +733,8 @@ func TestClient(t *testing.T) { pair := msgPairs[0] assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, + "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire // version is now known to be >= 6. diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index f2234e227c..a159891adc 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) { // Assert that the driver sent hello with the speculative authentication message. assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d", len(tc.payloads), (conn.Written)) - helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, helloCmd, handshake.LegacyHello) @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 13fdf2b185..cf46de6ffd 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index d2ae8df248..27be4c264d 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte { return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) } +// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message. +func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) { + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + return nil, errors.New("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + return nil, errors.New("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + return nil, errors.New("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + return nil, errors.New("could not read numberToReturn") + } + + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("could not read query") + } + return query, nil +} + // GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message. func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) { var ok bool