diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 49a613aef8..093531f538 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -476,6 +476,10 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // read the length as an int32 size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) + if size < 4 { + err = fmt.Errorf("malformatted message length: %d", size) + return nil, err.Error(), err + } // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded // defaultMaxMessageSize instead. maxMessageSize := c.desc.MaxMessageSize diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 946f74d8f2..e7247969ab 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -546,6 +546,23 @@ func TestConnection(t *testing.T) { } listener.assertCalledOnce(t) }) + t.Run("size too small errors", func(t *testing.T) { + err := errors.New("malformatted message length: 3") + tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}} + conn := &connection{id: "foobar", nc: tnc, state: connConnected} + listener := newTestCancellationListener(false) + conn.cancellationListener = listener + + want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()} + _, got := conn.readWireMessage(context.Background()) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tnc.closed { + t.Errorf("failed to closeConnection net.Conn after error writing bytes.") + } + listener.assertCalledOnce(t) + }) t.Run("full message read errors", func(t *testing.T) { err := errors.New("Read error") tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}