diff --git a/asynch/proto/streams/buffered.py b/asynch/proto/streams/buffered.py index f7d3dd0..f146140 100644 --- a/asynch/proto/streams/buffered.py +++ b/asynch/proto/streams/buffered.py @@ -111,6 +111,24 @@ def __init__(self, reader: StreamReader, buffer_max_size: int = constants.BUFFER self.current_buffer_size = 0 self.position = 0 + async def _refill_buffer(self): + if self.position == self.current_buffer_size: + self._reset_buffer() + await self._read_into_buffer() + + def _is_buffer_empty(self): + return not (self.buffer or self.position) + + async def _is_buffer_readable(self) -> bool: + await self._refill_buffer() + if self._is_buffer_empty(): + return False + return True + + def _reset_buffer(self): + self.position = 0 + self.buffer = bytearray() + async def _read_into_buffer(self): packet = await self.reader.read(self.buffer_max_size) self.buffer.extend(packet) @@ -124,95 +142,66 @@ def _read_one(self): async def read_varint(self): packets = bytearray() while True: - if self.position == self.current_buffer_size: - self._reset_buffer() - await self._read_into_buffer() + if not (await self._is_buffer_readable()): + break packet = self._read_one() packets.append(packet) if packet < 0x80: break return leb128.u.decode(packets) - def _reset_buffer(self): - self.position = 0 - self.buffer = bytearray() - - async def read_str(self, as_bytes: bool = False): - length = await self.read_varint() - packet = await self.read_bytes(length) - if as_bytes: - return packet - return packet.decode() - - async def read_fixed_str(self, length: int, as_bytes: bool = False): - packet = await self.read_bytes(length) - if as_bytes: - return packet - return packet.decode() - async def read_bytes(self, length: int): packets = bytearray() while length > 0: - if self.position == self.current_buffer_size: - self._reset_buffer() - await self._read_into_buffer() - + if not (await self._is_buffer_readable()): + break read_position = self.position + length packet = self.buffer[self.position : read_position] # noqa: E203 length -= len(packet) self.position += len(packet) packets.extend(packet) - return packets + async def read_fixed_str(self, length: int, as_bytes: bool = False): + packet = await self.read_bytes(length) + if as_bytes: + return packet + return packet.decode() + + async def read_str(self, as_bytes: bool = False): + length = await self.read_varint() + return await self.read_fixed_str(length=length, as_bytes=as_bytes) + async def read_int(self, fmt: str): s = struct.Struct("<" + fmt) packet = await self.read_bytes(s.size) return s.unpack(packet)[0] - async def read_int8( - self, - ): + async def read_int8(self): return await self.read_int("b") - async def read_int16( - self, - ): + async def read_int16(self): return await self.read_int("h") - async def read_int32( - self, - ): + async def read_int32(self): return await self.read_int("i") - async def read_int64( - self, - ): + async def read_int64(self): return await self.read_int("q") - async def read_uint8( - self, - ): + async def read_uint8(self): return await self.read_int("B") - async def read_uint16( - self, - ): + async def read_uint16(self): return await self.read_int("H") - async def read_uint32( - self, - ): + async def read_uint32(self): return await self.read_int("I") - async def read_uint64( - self, - ): + async def read_uint64(self): return await self.read_int("Q") - async def read_uint128( - self, - ): + async def read_uint128(self): hi = await self.read_int("Q") lo = await self.read_int("Q") return (hi << 64) + lo diff --git a/tests/test_proto/streams/buffered/test_buffered_readers.py b/tests/test_proto/streams/buffered/test_buffered_readers.py new file mode 100644 index 0000000..dc11a97 --- /dev/null +++ b/tests/test_proto/streams/buffered/test_buffered_readers.py @@ -0,0 +1,44 @@ +from asyncio import StreamReader + +import pytest + +from asynch.proto.streams.buffered import BufferedReader + + +@pytest.mark.parametrize( + ("stream_data", "answer"), + [ + (b"9", 57), + (b"32", 51), + ], +) +async def test_read_varint(stream_data: bytes, answer: bytes): + """When `(b"", 0)`, the reading gets stuck.""" + + stream_reader = StreamReader() + stream_reader.feed_data(stream_data) + reader = BufferedReader(stream_reader) + + result = await reader.read_varint() + + assert answer == result + + +@pytest.mark.parametrize( + ("stream_data", "bytes_to_read", "answer"), + [ + (b"", 0, b""), + (b"02", 1, b"0"), + (b"3456", 4, b"3456"), + ], +) +async def test_read_bytes(stream_data: bytes, bytes_to_read: int, answer: bytes): + """If `bytes_to_read > len(stream_data)`, the reading gets stuck.""" + + stream_reader = StreamReader() + stream_reader.feed_data(stream_data) + reader = BufferedReader(stream_reader, 1) + + result = await reader.read_bytes(bytes_to_read) + + assert answer == result diff --git a/tests/test_proto/test_io.py b/tests/test_proto/streams/buffered/test_buffered_writers.py similarity index 64% rename from tests/test_proto/test_io.py rename to tests/test_proto/streams/buffered/test_buffered_writers.py index e93c105..8c5b2a4 100644 --- a/tests/test_proto/test_io.py +++ b/tests/test_proto/streams/buffered/test_buffered_writers.py @@ -1,22 +1,8 @@ -from asyncio import StreamReader from unittest.mock import AsyncMock import pytest -from asynch.proto.streams.buffered import BufferedReader, BufferedWriter - - -@pytest.mark.asyncio -async def test_BufferedReader_overflow(): - stream_data = b"1234" - - stream_reader = StreamReader() - stream_reader.feed_data(stream_data) - reader = BufferedReader(stream_reader, 1) - - result = await reader.read_bytes(4) - - assert result == stream_data +from asynch.proto.streams.buffered import BufferedWriter @pytest.mark.asyncio