Skip to content

Commit

Permalink
Use dataclasses for ByteRangeRequests (zarr-developers#2585)
Browse files Browse the repository at this point in the history
* Use TypedDicts for more literate ByteRangeRequests

* Update utility function

* fixes sharding

* Ignore mypy errors

* Fix offset in _normalize_byte_range_index

* Update get_partial_values for FsspecStore

* Re-add fs._cat_ranges argument

* Simplify typing

* Update _normalize to return start, stop

* Use explicit range

* Use dataclasses

* Update typing

* Update docstring

* Rename ExplicitRange to ExplicitByteRequest

* Rename OffsetRange to OffsetByteRequest

* Rename SuffixRange to SuffixByteRequest

* Use match; case instead of if; elif

* Revert "Use match; case instead of if; elif"

This reverts commit a7d35f8.

* Update ByteRangeRequest to ByteRequest

* Remove ByteRange definition from common

* Rename ExplicitByteRequest to RangeByteRequest

* Provide more informative error message

---------

Co-authored-by: Norman Rzepka <[email protected]>
  • Loading branch information
maxrjones and normanrz authored Jan 9, 2025
1 parent 22ebded commit 0328656
Show file tree
Hide file tree
Showing 15 changed files with 221 additions and 145 deletions.
50 changes: 42 additions & 8 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod
from asyncio import gather
from dataclasses import dataclass
from itertools import starmap
from typing import TYPE_CHECKING, Protocol, runtime_checkable

Expand All @@ -19,7 +20,34 @@

__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]

ByteRangeRequest: TypeAlias = tuple[int | None, int | None]

@dataclass
class RangeByteRequest:
"""Request a specific byte range"""

start: int
"""The start of the byte range request (inclusive)."""
end: int
"""The end of the byte range request (exclusive)."""


@dataclass
class OffsetByteRequest:
"""Request all bytes starting from a given byte offset"""

offset: int
"""The byte offset for the offset range request."""


@dataclass
class SuffixByteRequest:
"""Request up to the last `n` bytes"""

suffix: int
"""The number of bytes from the suffix to request."""


ByteRequest: TypeAlias = RangeByteRequest | OffsetByteRequest | SuffixByteRequest


class Store(ABC):
Expand Down Expand Up @@ -141,14 +169,20 @@ async def get(
self,
key: str,
prototype: BufferPrototype,
byte_range: ByteRangeRequest | None = None,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
"""Retrieve the value associated with a given key.
Parameters
----------
key : str
byte_range : tuple[int | None, int | None], optional
byte_range : ByteRequest, optional
ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved.
- RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned.
- OffsetByteRequest(int): Request all bytes starting from a given byte offset. This is equivalent to bytes={int}- as an HTTP header.
- SuffixByteRequest(int): Request the last int bytes. Note that here, int is the size of the request, not the byte offset. This is equivalent to bytes=-{int} as an HTTP header.
Returns
-------
Expand All @@ -160,7 +194,7 @@ async def get(
async def get_partial_values(
self,
prototype: BufferPrototype,
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
"""Retrieve possibly partial values from given key_ranges.
Expand Down Expand Up @@ -338,7 +372,7 @@ def close(self) -> None:
self._is_open = False

async def _get_many(
self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]]
self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]]
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
"""
Retrieve a collection of objects from storage. In general this method does not guarantee
Expand Down Expand Up @@ -416,17 +450,17 @@ async def getsize_prefix(self, prefix: str) -> int:
@runtime_checkable
class ByteGetter(Protocol):
async def get(
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None: ...


@runtime_checkable
class ByteSetter(Protocol):
async def get(
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None: ...

async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: ...
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: ...

async def delete(self) -> None: ...

Expand Down
24 changes: 16 additions & 8 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
Codec,
CodecPipeline,
)
from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter
from zarr.abc.store import (
ByteGetter,
ByteRequest,
ByteSetter,
RangeByteRequest,
SuffixByteRequest,
)
from zarr.codecs.bytes import BytesCodec
from zarr.codecs.crc32c_ import Crc32cCodec
from zarr.core.array_spec import ArrayConfig, ArraySpec
Expand Down Expand Up @@ -77,7 +83,7 @@ class _ShardingByteGetter(ByteGetter):
chunk_coords: ChunkCoords

async def get(
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
assert byte_range is None, "byte_range is not supported within shards"
assert (
Expand All @@ -90,7 +96,7 @@ async def get(
class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
shard_dict: ShardMutableMapping

async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None:
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
assert byte_range is None, "byte_range is not supported within shards"
self.shard_dict[self.chunk_coords] = value

Expand Down Expand Up @@ -129,7 +135,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
return None
else:
return (int(chunk_start), int(chunk_len))
return (int(chunk_start), int(chunk_start + chunk_len))

def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None:
localized_chunk = self._localize_chunk(chunk_coords)
Expand Down Expand Up @@ -203,7 +209,7 @@ def create_empty(
def __getitem__(self, chunk_coords: ChunkCoords) -> Buffer:
chunk_byte_slice = self.index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
return self.buf[chunk_byte_slice[0] : (chunk_byte_slice[0] + chunk_byte_slice[1])]
return self.buf[chunk_byte_slice[0] : chunk_byte_slice[1]]
raise KeyError

def __len__(self) -> int:
Expand Down Expand Up @@ -504,7 +510,8 @@ async def _decode_partial_single(
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
chunk_bytes = await byte_getter.get(
prototype=chunk_spec.prototype, byte_range=chunk_byte_slice
prototype=chunk_spec.prototype,
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
)
if chunk_bytes:
shard_dict[chunk_coords] = chunk_bytes
Expand Down Expand Up @@ -696,11 +703,12 @@ async def _load_shard_index_maybe(
shard_index_size = self._shard_index_size(chunks_per_shard)
if self.index_location == ShardingCodecIndexLocation.start:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=(0, shard_index_size)
prototype=numpy_buffer_prototype(),
byte_range=RangeByteRequest(0, shard_index_size),
)
else:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=(-shard_index_size, None)
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
)
if index_bytes is not None:
return await self._decode_shard_index(index_bytes, chunks_per_shard)
Expand Down
1 change: 0 additions & 1 deletion src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
ZATTRS_JSON = ".zattrs"
ZMETADATA_V2_JSON = ".zmetadata"

ByteRangeRequest = tuple[int | None, int | None]
BytesLike = bytes | bytearray | memoryview
ShapeLike = tuple[int, ...] | int
ChunkCoords = tuple[int, ...]
Expand Down
10 changes: 5 additions & 5 deletions src/zarr/storage/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from zarr.abc.store import ByteRangeRequest, Store
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, AccessModeLiteral, ZarrFormat
from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError
Expand Down Expand Up @@ -102,7 +102,7 @@ async def open(
async def get(
self,
prototype: BufferPrototype | None = None,
byte_range: ByteRangeRequest | None = None,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
"""
Read bytes from the store.
Expand All @@ -111,7 +111,7 @@ async def get(
----------
prototype : BufferPrototype, optional
The buffer prototype to use when reading the bytes.
byte_range : ByteRangeRequest, optional
byte_range : ByteRequest, optional
The range of bytes to read.
Returns
Expand All @@ -123,15 +123,15 @@ async def get(
prototype = default_buffer_prototype()
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)

async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None:
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
"""
Write bytes to the store.
Parameters
----------
value : Buffer
The buffer to write.
byte_range : ByteRangeRequest, optional
byte_range : ByteRequest, optional
The range of bytes to write. If None, the entire buffer is written.
Raises
Expand Down
81 changes: 50 additions & 31 deletions src/zarr/storage/_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import warnings
from typing import TYPE_CHECKING, Any

from zarr.abc.store import ByteRangeRequest, Store
from zarr.abc.store import (
ByteRequest,
OffsetByteRequest,
RangeByteRequest,
Store,
SuffixByteRequest,
)
from zarr.storage._common import _dereference_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,31 +205,34 @@ async def get(
self,
key: str,
prototype: BufferPrototype,
byte_range: ByteRangeRequest | None = None,
byte_range: ByteRequest | None = None,
) -> Buffer | None:
# docstring inherited
if not self._is_open:
await self._open()
path = _dereference_path(self.path, key)

try:
if byte_range:
# fsspec uses start/end, not start/length
start, length = byte_range
if start is not None and length is not None:
end = start + length
elif length is not None:
end = length
else:
end = None
value = prototype.buffer.from_bytes(
await (
self.fs._cat_file(path, start=byte_range[0], end=end)
if byte_range
else self.fs._cat_file(path)
if byte_range is None:
value = prototype.buffer.from_bytes(await self.fs._cat_file(path))
elif isinstance(byte_range, RangeByteRequest):
value = prototype.buffer.from_bytes(
await self.fs._cat_file(
path,
start=byte_range.start,
end=byte_range.end,
)
)
)

elif isinstance(byte_range, OffsetByteRequest):
value = prototype.buffer.from_bytes(
await self.fs._cat_file(path, start=byte_range.offset, end=None)
)
elif isinstance(byte_range, SuffixByteRequest):
value = prototype.buffer.from_bytes(
await self.fs._cat_file(path, start=-byte_range.suffix, end=None)
)
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
except self.allowed_exceptions:
return None
except OSError as e:
Expand Down Expand Up @@ -270,25 +279,35 @@ async def exists(self, key: str) -> bool:
async def get_partial_values(
self,
prototype: BufferPrototype,
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
# docstring inherited
if key_ranges:
paths, starts, stops = zip(
*(
(
_dereference_path(self.path, k[0]),
k[1][0],
((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None,
)
for k in key_ranges
),
strict=False,
)
# _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest.
key_ranges = list(key_ranges)
paths: list[str] = []
starts: list[int | None] = []
stops: list[int | None] = []
for key, byte_range in key_ranges:
paths.append(_dereference_path(self.path, key))
if byte_range is None:
starts.append(None)
stops.append(None)
elif isinstance(byte_range, RangeByteRequest):
starts.append(byte_range.start)
stops.append(byte_range.end)
elif isinstance(byte_range, OffsetByteRequest):
starts.append(byte_range.offset)
stops.append(None)
elif isinstance(byte_range, SuffixByteRequest):
starts.append(-byte_range.suffix)
stops.append(None)
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
else:
return []
# TODO: expectations for exceptions or missing keys?
res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return")
res = await self.fs._cat_ranges(paths, starts, stops, on_error="return")
# the following is an s3-specific condition we probably don't want to leak
res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res]
for r in res:
Expand Down
Loading

0 comments on commit 0328656

Please sign in to comment.