Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to clear the BufferedReader class from IndexError #94

Merged
merged 6 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 41 additions & 52 deletions asynch/proto/streams/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ def __init__(self, reader: StreamReader, buffer_max_size: int = constants.BUFFER
self.current_buffer_size = 0
self.position = 0

async def _refill_buffer(self):
if self.position == self.current_buffer_size:
self._reset_buffer()
await self._read_into_buffer()

def _is_buffer_empty(self):
return not (self.buffer or self.position)

async def _is_buffer_readable(self) -> bool:
await self._refill_buffer()
if self._is_buffer_empty():
return False
return True

def _reset_buffer(self):
self.position = 0
self.buffer = bytearray()

async def _read_into_buffer(self):
packet = await self.reader.read(self.buffer_max_size)
self.buffer.extend(packet)
Expand All @@ -124,95 +142,66 @@ def _read_one(self):
async def read_varint(self):
packets = bytearray()
while True:
if self.position == self.current_buffer_size:
self._reset_buffer()
await self._read_into_buffer()
if not (await self._is_buffer_readable()):
break
packet = self._read_one()
packets.append(packet)
if packet < 0x80:
break
return leb128.u.decode(packets)

def _reset_buffer(self):
self.position = 0
self.buffer = bytearray()

async def read_str(self, as_bytes: bool = False):
length = await self.read_varint()
packet = await self.read_bytes(length)
if as_bytes:
return packet
return packet.decode()

async def read_fixed_str(self, length: int, as_bytes: bool = False):
packet = await self.read_bytes(length)
if as_bytes:
return packet
return packet.decode()

async def read_bytes(self, length: int):
packets = bytearray()
while length > 0:
if self.position == self.current_buffer_size:
self._reset_buffer()
await self._read_into_buffer()

if not (await self._is_buffer_readable()):
break
read_position = self.position + length
packet = self.buffer[self.position : read_position] # noqa: E203
length -= len(packet)
self.position += len(packet)
packets.extend(packet)

return packets

async def read_fixed_str(self, length: int, as_bytes: bool = False):
packet = await self.read_bytes(length)
if as_bytes:
return packet
return packet.decode()

async def read_str(self, as_bytes: bool = False):
length = await self.read_varint()
return await self.read_fixed_str(length=length, as_bytes=as_bytes)
stankudrow marked this conversation as resolved.
Show resolved Hide resolved

async def read_int(self, fmt: str):
s = struct.Struct("<" + fmt)
packet = await self.read_bytes(s.size)
return s.unpack(packet)[0]

async def read_int8(
self,
):
async def read_int8(self):
return await self.read_int("b")

async def read_int16(
self,
):
async def read_int16(self):
return await self.read_int("h")

async def read_int32(
self,
):
async def read_int32(self):
return await self.read_int("i")

async def read_int64(
self,
):
async def read_int64(self):
return await self.read_int("q")

async def read_uint8(
self,
):
async def read_uint8(self):
return await self.read_int("B")

async def read_uint16(
self,
):
async def read_uint16(self):
return await self.read_int("H")

async def read_uint32(
self,
):
async def read_uint32(self):
return await self.read_int("I")

async def read_uint64(
self,
):
async def read_uint64(self):
return await self.read_int("Q")

async def read_uint128(
self,
):
async def read_uint128(self):
hi = await self.read_int("Q")
lo = await self.read_int("Q")
return (hi << 64) + lo
Expand Down
44 changes: 44 additions & 0 deletions tests/test_proto/streams/buffered/test_buffered_readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from asyncio import StreamReader

import pytest

from asynch.proto.streams.buffered import BufferedReader


@pytest.mark.parametrize(
("stream_data", "answer"),
[
(b"9", 57),
(b"32", 51),
],
)
async def test_read_varint(stream_data: bytes, answer: bytes):
"""When `(b"", 0)`, the reading gets stuck."""

stream_reader = StreamReader()
stream_reader.feed_data(stream_data)
reader = BufferedReader(stream_reader)

result = await reader.read_varint()

assert answer == result


@pytest.mark.parametrize(
("stream_data", "bytes_to_read", "answer"),
[
(b"", 0, b""),
(b"02", 1, b"0"),
(b"3456", 4, b"3456"),
],
)
async def test_read_bytes(stream_data: bytes, bytes_to_read: int, answer: bytes):
"""If `bytes_to_read > len(stream_data)`, the reading gets stuck."""

stream_reader = StreamReader()
stream_reader.feed_data(stream_data)
reader = BufferedReader(stream_reader, 1)

result = await reader.read_bytes(bytes_to_read)

assert answer == result
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
from asyncio import StreamReader
from unittest.mock import AsyncMock

import pytest

from asynch.proto.streams.buffered import BufferedReader, BufferedWriter


@pytest.mark.asyncio
async def test_BufferedReader_overflow():
stream_data = b"1234"

stream_reader = StreamReader()
stream_reader.feed_data(stream_data)
reader = BufferedReader(stream_reader, 1)

result = await reader.read_bytes(4)

assert result == stream_data
from asynch.proto.streams.buffered import BufferedWriter


@pytest.mark.asyncio
Expand Down
Loading