diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index bd0a7ad503..e6a5518a4b 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from asyncio import gather +from dataclasses import dataclass from itertools import starmap from typing import TYPE_CHECKING, Protocol, runtime_checkable @@ -19,7 +20,34 @@ __all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] -ByteRangeRequest: TypeAlias = tuple[int | None, int | None] + +@dataclass +class RangeByteRequest: + """Request a specific byte range""" + + start: int + """The start of the byte range request (inclusive).""" + end: int + """The end of the byte range request (exclusive).""" + + +@dataclass +class OffsetByteRequest: + """Request all bytes starting from a given byte offset""" + + offset: int + """The byte offset for the offset range request.""" + + +@dataclass +class SuffixByteRequest: + """Request up to the last `n` bytes""" + + suffix: int + """The number of bytes from the suffix to request.""" + + +ByteRequest: TypeAlias = RangeByteRequest | OffsetByteRequest | SuffixByteRequest class Store(ABC): @@ -141,14 +169,20 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: ByteRangeRequest | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. Parameters ---------- key : str - byte_range : tuple[int | None, int | None], optional + byte_range : ByteRequest, optional + + ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved. + + - RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned. + - OffsetByteRequest(int): Request all bytes starting from a given byte offset. This is equivalent to bytes={int}- as an HTTP header. + - SuffixByteRequest(int): Request the last int bytes. Note that here, int is the size of the request, not the byte offset. This is equivalent to bytes=-{int} as an HTTP header. Returns ------- @@ -160,7 +194,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. @@ -338,7 +372,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]] + self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage. In general this method does not guarantee @@ -416,17 +450,17 @@ async def getsize_prefix(self, prefix: str) -> int: @runtime_checkable class ByteGetter(Protocol): async def get( - self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: ... @runtime_checkable class ByteSetter(Protocol): async def get( - self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: ... - async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: ... + async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: ... async def delete(self) -> None: ... diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index a01145b3b2..160a74e892 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -17,7 +17,13 @@ Codec, CodecPipeline, ) -from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter +from zarr.abc.store import ( + ByteGetter, + ByteRequest, + ByteSetter, + RangeByteRequest, + SuffixByteRequest, +) from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec from zarr.core.array_spec import ArrayConfig, ArraySpec @@ -77,7 +83,7 @@ class _ShardingByteGetter(ByteGetter): chunk_coords: ChunkCoords async def get( - self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" assert ( @@ -90,7 +96,7 @@ async def get( class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): shard_dict: ShardMutableMapping - async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: + async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: assert byte_range is None, "byte_range is not supported within shards" self.shard_dict[self.chunk_coords] = value @@ -129,7 +135,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None: if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64): return None else: - return (int(chunk_start), int(chunk_len)) + return (int(chunk_start), int(chunk_start + chunk_len)) def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None: localized_chunk = self._localize_chunk(chunk_coords) @@ -203,7 +209,7 @@ def create_empty( def __getitem__(self, chunk_coords: ChunkCoords) -> Buffer: chunk_byte_slice = self.index.get_chunk_slice(chunk_coords) if chunk_byte_slice: - return self.buf[chunk_byte_slice[0] : (chunk_byte_slice[0] + chunk_byte_slice[1])] + return self.buf[chunk_byte_slice[0] : chunk_byte_slice[1]] raise KeyError def __len__(self) -> int: @@ -504,7 +510,8 @@ async def _decode_partial_single( chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) if chunk_byte_slice: chunk_bytes = await byte_getter.get( - prototype=chunk_spec.prototype, byte_range=chunk_byte_slice + prototype=chunk_spec.prototype, + byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), ) if chunk_bytes: shard_dict[chunk_coords] = chunk_bytes @@ -696,11 +703,12 @@ async def _load_shard_index_maybe( shard_index_size = self._shard_index_size(chunks_per_shard) if self.index_location == ShardingCodecIndexLocation.start: index_bytes = await byte_getter.get( - prototype=numpy_buffer_prototype(), byte_range=(0, shard_index_size) + prototype=numpy_buffer_prototype(), + byte_range=RangeByteRequest(0, shard_index_size), ) else: index_bytes = await byte_getter.get( - prototype=numpy_buffer_prototype(), byte_range=(-shard_index_size, None) + prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size) ) if index_bytes is not None: return await self._decode_shard_index(index_bytes, chunks_per_shard) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 7205b8c206..ad3316b619 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -31,7 +31,6 @@ ZATTRS_JSON = ".zattrs" ZMETADATA_V2_JSON = ".zmetadata" -ByteRangeRequest = tuple[int | None, int | None] BytesLike = bytes | bytearray | memoryview ShapeLike = tuple[int, ...] | int ChunkCoords = tuple[int, ...] diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 523e470671..6ab539bb0a 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, AccessModeLiteral, ZarrFormat from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError @@ -102,7 +102,7 @@ async def open( async def get( self, prototype: BufferPrototype | None = None, - byte_range: ByteRangeRequest | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: """ Read bytes from the store. @@ -111,7 +111,7 @@ async def get( ---------- prototype : BufferPrototype, optional The buffer prototype to use when reading the bytes. - byte_range : ByteRangeRequest, optional + byte_range : ByteRequest, optional The range of bytes to read. Returns @@ -123,7 +123,7 @@ async def get( prototype = default_buffer_prototype() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) - async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: + async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: """ Write bytes to the store. @@ -131,7 +131,7 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - ---------- value : Buffer The buffer to write. - byte_range : ByteRangeRequest, optional + byte_range : ByteRequest, optional The range of bytes to write. If None, the entire buffer is written. Raises diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 89d80320dd..99c8c778e7 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -3,7 +3,13 @@ import warnings from typing import TYPE_CHECKING, Any -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import ( + ByteRequest, + OffsetByteRequest, + RangeByteRequest, + Store, + SuffixByteRequest, +) from zarr.storage._common import _dereference_path if TYPE_CHECKING: @@ -199,7 +205,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: ByteRangeRequest | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: @@ -207,23 +213,26 @@ async def get( path = _dereference_path(self.path, key) try: - if byte_range: - # fsspec uses start/end, not start/length - start, length = byte_range - if start is not None and length is not None: - end = start + length - elif length is not None: - end = length - else: - end = None - value = prototype.buffer.from_bytes( - await ( - self.fs._cat_file(path, start=byte_range[0], end=end) - if byte_range - else self.fs._cat_file(path) + if byte_range is None: + value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) + elif isinstance(byte_range, RangeByteRequest): + value = prototype.buffer.from_bytes( + await self.fs._cat_file( + path, + start=byte_range.start, + end=byte_range.end, + ) ) - ) - + elif isinstance(byte_range, OffsetByteRequest): + value = prototype.buffer.from_bytes( + await self.fs._cat_file(path, start=byte_range.offset, end=None) + ) + elif isinstance(byte_range, SuffixByteRequest): + value = prototype.buffer.from_bytes( + await self.fs._cat_file(path, start=-byte_range.suffix, end=None) + ) + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}.") except self.allowed_exceptions: return None except OSError as e: @@ -270,25 +279,35 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited if key_ranges: - paths, starts, stops = zip( - *( - ( - _dereference_path(self.path, k[0]), - k[1][0], - ((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None, - ) - for k in key_ranges - ), - strict=False, - ) + # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. + key_ranges = list(key_ranges) + paths: list[str] = [] + starts: list[int | None] = [] + stops: list[int | None] = [] + for key, byte_range in key_ranges: + paths.append(_dereference_path(self.path, key)) + if byte_range is None: + starts.append(None) + stops.append(None) + elif isinstance(byte_range, RangeByteRequest): + starts.append(byte_range.start) + stops.append(byte_range.end) + elif isinstance(byte_range, OffsetByteRequest): + starts.append(byte_range.offset) + stops.append(None) + elif isinstance(byte_range, SuffixByteRequest): + starts.append(-byte_range.suffix) + stops.append(None) + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}.") else: return [] # TODO: expectations for exceptions or missing keys? - res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return") + res = await self.fs._cat_ranges(paths, starts, stops, on_error="return") # the following is an s3-specific condition we probably don't want to leak res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res] for r in res: diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f4226792cb..5eaa85c592 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -7,7 +7,13 @@ from pathlib import Path from typing import TYPE_CHECKING -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import ( + ByteRequest, + OffsetByteRequest, + RangeByteRequest, + Store, + SuffixByteRequest, +) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map @@ -18,29 +24,20 @@ from zarr.core.buffer import BufferPrototype -def _get( - path: Path, prototype: BufferPrototype, byte_range: tuple[int | None, int | None] | None -) -> Buffer: - if byte_range is not None: - if byte_range[0] is None: - start = 0 - else: - start = byte_range[0] - - end = (start + byte_range[1]) if byte_range[1] is not None else None - else: +def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: + if byte_range is None: return prototype.buffer.from_bytes(path.read_bytes()) with path.open("rb") as f: size = f.seek(0, io.SEEK_END) - if start is not None: - if start >= 0: - f.seek(start) - else: - f.seek(max(0, size + start)) - if end is not None: - if end < 0: - end = size + end - return prototype.buffer.from_bytes(f.read(end - f.tell())) + if isinstance(byte_range, RangeByteRequest): + f.seek(byte_range.start) + return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + elif isinstance(byte_range, OffsetByteRequest): + f.seek(byte_range.offset) + elif isinstance(byte_range, SuffixByteRequest): + f.seek(max(0, size - byte_range.suffix)) + else: + raise TypeError(f"Unexpected byte_range, got {byte_range}.") return prototype.buffer.from_bytes(f.read()) @@ -127,7 +124,7 @@ async def get( self, key: str, prototype: BufferPrototype | None = None, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: @@ -145,7 +142,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited args = [] diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index 45ddeef40c..5ca716df2c 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable - from zarr.abc.store import ByteRangeRequest + from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer, BufferPrototype counter: defaultdict[str, int] @@ -161,7 +161,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited with self.log(key): @@ -170,7 +170,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited keys = ",".join([k[0] for k in key_ranges]) diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 1f8dd75768..d35ecbe33d 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -3,10 +3,10 @@ from logging import getLogger from typing import TYPE_CHECKING, Self -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.common import concurrent_map -from zarr.storage._utils import _normalize_interval_index +from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, MutableMapping @@ -75,7 +75,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: @@ -83,20 +83,20 @@ async def get( assert isinstance(key, str) try: value = self._store_dict[key] - start, length = _normalize_interval_index(value, byte_range) - return prototype.buffer.from_buffer(value[start : start + length]) + start, stop = _normalize_byte_range_index(value, byte_range) + return prototype.buffer.from_buffer(value[start:stop]) except KeyError: return None async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: ByteRangeRequest) -> Buffer | None: + async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: return await self.get(key, prototype=prototype, byte_range=byte_range) return await concurrent_map(key_ranges, _get, limit=None) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 7ba82b00fd..4fc3171eb8 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -4,7 +4,10 @@ from pathlib import Path from typing import TYPE_CHECKING +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest + if TYPE_CHECKING: + from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer @@ -44,25 +47,22 @@ def normalize_path(path: str | bytes | Path | None) -> str: return result -def _normalize_interval_index( - data: Buffer, interval: tuple[int | None, int | None] | None -) -> tuple[int, int]: +def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> tuple[int, int]: """ - Convert an implicit interval into an explicit start and length + Convert an ByteRequest into an explicit start and stop """ - if interval is None: + if byte_range is None: start = 0 - length = len(data) + stop = len(data) + 1 + elif isinstance(byte_range, RangeByteRequest): + start = byte_range.start + stop = byte_range.end + elif isinstance(byte_range, OffsetByteRequest): + start = byte_range.offset + stop = len(data) + 1 + elif isinstance(byte_range, SuffixByteRequest): + start = len(data) - byte_range.suffix + stop = len(data) + 1 else: - maybe_start, maybe_len = interval - if maybe_start is None: - start = 0 - else: - start = maybe_start - - if maybe_len is None: - length = len(data) - start - else: - length = maybe_len - - return (start, length) + raise ValueError(f"Unexpected byte_range, got {byte_range}.") + return (start, stop) diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index c160100084..255e965439 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -7,7 +7,7 @@ from types import TracebackType from typing import Any, Self - from zarr.abc.store import ByteRangeRequest + from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.common import BytesLike @@ -70,14 +70,14 @@ def __eq__(self, value: object) -> bool: return type(self) is type(value) and self._store.__eq__(value) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) @@ -133,7 +133,7 @@ def close(self) -> None: self._store.close() async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]] + self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index a186b3cf59..e808b80e4e 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -7,7 +7,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import ( + ByteRequest, + OffsetByteRequest, + RangeByteRequest, + Store, + SuffixByteRequest, +) from zarr.core.buffer import Buffer, BufferPrototype if TYPE_CHECKING: @@ -138,23 +144,24 @@ def _get( self, key: str, prototype: BufferPrototype, - byte_range: ByteRangeRequest | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited try: with self._zf.open(key) as f: # will raise KeyError if byte_range is None: return prototype.buffer.from_bytes(f.read()) - start, length = byte_range - if start: - if start < 0: - start = f.seek(start, os.SEEK_END) + start - else: - start = f.seek(start, os.SEEK_SET) - if length: - return prototype.buffer.from_bytes(f.read(length)) + elif isinstance(byte_range, RangeByteRequest): + f.seek(byte_range.start) + return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + size = f.seek(0, os.SEEK_END) + if isinstance(byte_range, OffsetByteRequest): + f.seek(byte_range.offset) + elif isinstance(byte_range, SuffixByteRequest): + f.seek(max(0, size - byte_range.suffix)) else: - return prototype.buffer.from_bytes(f.read()) + raise TypeError(f"Unexpected byte_range, got {byte_range}.") + return prototype.buffer.from_bytes(f.read()) except KeyError: return None @@ -162,7 +169,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: ByteRangeRequest | None = None, + byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited assert isinstance(key, str) @@ -173,7 +180,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRangeRequest]], + key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited out = [] diff --git a/src/zarr/testing/stateful.py b/src/zarr/testing/stateful.py index cc0f220807..1a1ef0e3a3 100644 --- a/src/zarr/testing/stateful.py +++ b/src/zarr/testing/stateful.py @@ -355,9 +355,8 @@ def get_partial_values(self, data: DataObject) -> None: model_vals_ls = [] for key, byte_range in key_range: - start = byte_range[0] or 0 - step = byte_range[1] - stop = start + step if step is not None else None + start = byte_range.start + stop = byte_range.end model_vals_ls.append(self.model[key][start:stop]) assert all( diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ada028c273..602d001693 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -9,15 +9,21 @@ if TYPE_CHECKING: from typing import Any - from zarr.abc.store import ByteRangeRequest + from zarr.abc.store import ByteRequest from zarr.core.buffer.core import BufferPrototype import pytest -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import ( + ByteRequest, + OffsetByteRequest, + RangeByteRequest, + Store, + SuffixByteRequest, +) from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.sync import _collect_aiterator -from zarr.storage._utils import _normalize_interval_index +from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal __all__ = ["StoreTests"] @@ -115,18 +121,18 @@ def test_store_supports_listing(self, store: S) -> None: @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) - @pytest.mark.parametrize("byte_range", [None, (0, None), (1, None), (1, 2), (None, 1)]) - async def test_get( - self, store: S, key: str, data: bytes, byte_range: tuple[int | None, int | None] | None - ) -> None: + @pytest.mark.parametrize( + "byte_range", [None, RangeByteRequest(1, 4), OffsetByteRequest(1), SuffixByteRequest(1)] + ) + async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None: """ Ensure that data can be read from the store using the store.get method. """ data_buf = self.buffer_cls.from_bytes(data) await self.set(store, key, data_buf) observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range) - start, length = _normalize_interval_index(data_buf, interval=byte_range) - expected = data_buf[start : start + length] + start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range) + expected = data_buf[start:stop] assert_bytes_equal(observed, expected) async def test_get_many(self, store: S) -> None: @@ -179,13 +185,17 @@ async def test_set_many(self, store: S) -> None: "key_ranges", [ [], - [("zarr.json", (0, 1))], - [("c/0", (0, 1)), ("zarr.json", (0, None))], - [("c/0/0", (0, 1)), ("c/0/1", (None, 2)), ("c/0/2", (0, 3))], + [("zarr.json", RangeByteRequest(0, 2))], + [("c/0", RangeByteRequest(0, 2)), ("zarr.json", None)], + [ + ("c/0/0", RangeByteRequest(0, 2)), + ("c/0/1", SuffixByteRequest(2)), + ("c/0/2", OffsetByteRequest(2)), + ], ], ) async def test_get_partial_values( - self, store: S, key_ranges: list[tuple[str, tuple[int | None, int | None]]] + self, store: S, key_ranges: list[tuple[str, ByteRequest]] ) -> None: # put all of the data for key, _ in key_ranges: @@ -367,7 +377,7 @@ async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None + self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: """ Add latency to the ``get`` method. @@ -380,7 +390,7 @@ async def get( The key to get prototype : BufferPrototype The BufferPrototype to use. - byte_range : ByteRangeRequest, optional + byte_range : ByteRequest, optional An optional byte range. Returns diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 1bde01b8f9..b948651ce6 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -7,6 +7,7 @@ from hypothesis.strategies import SearchStrategy import zarr +from zarr.abc.store import RangeByteRequest from zarr.core.array import Array from zarr.core.common import ZarrFormat from zarr.core.sync import sync @@ -194,12 +195,13 @@ def key_ranges( Function to generate key_ranges strategy for get_partial_values() returns list strategy w/ form:: - [(key, (range_start, range_step)), - (key, (range_start, range_step)),...] + [(key, (range_start, range_end)), + (key, (range_start, range_end)),...] """ - byte_ranges = st.tuples( - st.none() | st.integers(min_value=0, max_value=max_size), - st.none() | st.integers(min_value=0, max_value=max_size), + byte_ranges = st.builds( + RangeByteRequest, + start=st.integers(min_value=0, max_value=max_size), + end=st.integers(min_value=0, max_value=max_size), ) key_tuple = st.tuples(keys, byte_ranges) return st.lists(key_tuple, min_size=1, max_size=10) diff --git a/tests/test_store/test_fsspec.py b/tests/test_store/test_fsspec.py index b307f2cdf4..2713a2969d 100644 --- a/tests/test_store/test_fsspec.py +++ b/tests/test_store/test_fsspec.py @@ -8,6 +8,7 @@ from botocore.session import Session import zarr.api.asynchronous +from zarr.abc.store import OffsetByteRequest from zarr.core.buffer import Buffer, cpu, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync from zarr.storage import FsspecStore @@ -97,7 +98,7 @@ async def test_basic() -> None: assert await store.exists("foo") assert (await store.get("foo", prototype=default_buffer_prototype())).to_bytes() == data out = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=[("foo", (1, None))] + prototype=default_buffer_prototype(), key_ranges=[("foo", OffsetByteRequest(1))] ) assert out[0].to_bytes() == data[1:]