Skip to content

Commit

Permalink
GODRIVER-2935 Add legacy tests back to auth and client
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Sep 7, 2023
1 parent c343f12 commit 25b7685
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,8 +733,8 @@ func TestClient(t *testing.T) {
pair := msgPairs[0]
assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s",
handshake.LegacyHello, pair.CommandName)
assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode,
"expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode,
"expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String())

// Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire
// version is now known to be >= 6.
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/speculative_scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) {
// Assert that the driver sent hello with the speculative authentication message.
assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d",
len(tc.payloads), (conn.Written))
helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, helloCmd, handshake.LegacyHello)

Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) {

assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, handshake.LegacyHello)
_, err = hello.LookupErr("speculativeAuthenticate")
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/speculative_x509_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) {

assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, handshake.LegacyHello)

Expand Down Expand Up @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) {

assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, handshake.LegacyHello)
_, err = hello.LookupErr("speculativeAuthenticate")
Expand Down
32 changes: 32 additions & 0 deletions x/mongo/driver/drivertest/channel_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte {
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
}

// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message.
func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) {
var ok bool
_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
if !ok {
return nil, errors.New("could not read header")
}
_, wm, ok = wiremessage.ReadQueryFlags(wm)
if !ok {
return nil, errors.New("could not read flags")
}
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
if !ok {
return nil, errors.New("could not read fullCollectionName")
}
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
if !ok {
return nil, errors.New("could not read numberToSkip")
}
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
if !ok {
return nil, errors.New("could not read numberToReturn")
}

var query bsoncore.Document
query, wm, ok = wiremessage.ReadQueryQuery(wm)
if !ok {
return nil, errors.New("could not read query")
}
return query, nil
}

// GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message.
func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) {
var ok bool
Expand Down

0 comments on commit 25b7685

Please sign in to comment.