
Commit

wip
jhamman committed Jan 2, 2024
1 parent 84b497f commit f4822c7
Showing 11 changed files with 546 additions and 387 deletions.
2 changes: 1 addition & 1 deletion zarr/v3/abc/codec.py
@@ -16,7 +16,7 @@
import numpy as np

from zarr.v3.common import BytesLike, SliceSelection
-from zarr.v3.store import StorePath
+from zarr.v3.stores import StorePath


if TYPE_CHECKING:
33 changes: 30 additions & 3 deletions zarr/v3/abc/store.py
@@ -23,12 +23,12 @@ async def get(self, key: str) -> bytes:
...

@abstractmethod
-async def get_partial_values(self, key_ranges: List[Tuple[str, int]]) -> bytes:
+async def get_partial_values(self, key_ranges: List[Tuple[str, Tuple[int, int]]]) -> List[bytes]:
"""Retrieve possibly partial values from given key_ranges.
Parameters
----------
-key_ranges : list[tuple[str, int]]
+key_ranges : list[tuple[str, tuple[int, int]]]
Ordered set of (key, range) pairs; a key may occur multiple times with different ranges
Returns
@@ -38,6 +38,19 @@ async def get_partial_values(self, key_ranges: List[Tuple[str, int]]) -> bytes:
"""
...

async def exists(self, key: str) -> bool:
"""Check if a key exists in the store.
Parameters
----------
key : str
Returns
-------
bool
"""
...


class WriteStore(ReadStore):
@abstractmethod
@@ -51,6 +64,20 @@ async def set(self, key: str, value: bytes) -> None:
"""
...

async def delete(self, key: str) -> None:
"""Remove a key from the store.
Parameters
----------
key : str
"""
...


class PartialWriteStore(WriteStore):
# TODO, instead of using this, should we just check if the store is a PartialWriteStore?
supports_partial_writes = True

@abstractmethod
async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
"""Store values at a given key, starting at byte range_start.
@@ -78,7 +105,7 @@ async def list(self) -> List[str]:

@abstractmethod
async def list_prefix(self, prefix: str) -> List[str]:
"""Retrieve all keys in the store.
"""Retrieve all keys in the store with a given prefix.
Parameters
----------
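To make the reworked store interface concrete, here is a minimal in-memory sketch written against the new signatures (exists, delete, and get_partial_values taking (key, range) pairs). It is illustrative only: the MemoryStore name is invented, it duck-types the interface instead of subclassing ReadStore/WriteStore (the full class hierarchy is not visible in this diff), and it assumes each range is a (start, length) pair, since the Returns section of the docstring is not shown in this hunk.

    from typing import Dict, List, Optional, Tuple

    class MemoryStore:
        """Illustrative in-memory store matching the updated abstract signatures."""

        def __init__(self) -> None:
            self._data: Dict[str, bytes] = {}

        async def get(self, key: str) -> Optional[bytes]:
            return self._data.get(key)

        async def get_partial_values(
            self, key_ranges: List[Tuple[str, Tuple[int, int]]]
        ) -> List[bytes]:
            # One result per (key, (start, length)) pair; a key may repeat with different ranges.
            return [self._data[key][start : start + length] for key, (start, length) in key_ranges]

        async def exists(self, key: str) -> bool:
            return key in self._data

        async def set(self, key: str, value: bytes) -> None:
            self._data[key] = value

        async def delete(self, key: str) -> None:
            self._data.pop(key, None)

        async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
            # Write each value into the existing bytes for its key, starting at range_start.
            for key, start, value in key_start_values:
                buf = bytearray(self._data.get(key, b""))
                if len(buf) < start:
                    buf.extend(b"\x00" * (start - len(buf)))
                buf[start : start + len(value)] = value
                self._data[key] = bytes(buf)

Under those assumptions, await store.get_partial_values([("zarr.json", (0, 2)), ("zarr.json", (2, 2))]) returns two byte slices, one per requested range.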
26 changes: 13 additions & 13 deletions zarr/v3/array.py
@@ -74,7 +74,7 @@ async def create(
) -> AsyncArray:
store_path = make_store_path(store)
if not exists_ok:
-assert not await (store_path / ZARR_JSON).exists_async()
+assert not await (store_path / ZARR_JSON).exists()

data_type = (
DataType[dtype] if isinstance(dtype, str) else DataType[dtype_to_data_type[dtype.str]]
@@ -152,7 +152,7 @@ async def open(
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> AsyncArray:
store_path = make_store_path(store)
-zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
+zarr_json_bytes = await (store_path / ZARR_JSON).get()
assert zarr_json_bytes is not None
return cls.from_json(
store_path,
@@ -167,7 +167,7 @@ async def open_auto(
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> AsyncArray: # TODO: Union[AsyncArray, ArrayV2]
store_path = make_store_path(store)
-v3_metadata_bytes = await (store_path / ZARR_JSON).get_async()
+v3_metadata_bytes = await (store_path / ZARR_JSON).get()
if v3_metadata_bytes is not None:
return cls.from_json(
store_path,
@@ -176,7 +176,7 @@
)
else:
raise ValueError("no v2 support yet")
-# return await ArrayV2.open_async(store_path)
+# return await ArrayV2.open(store_path)

@property
def ndim(self) -> int:
@@ -230,7 +230,7 @@ async def getitem(self, selection: Selection):
async def _save_metadata(self) -> None:
self._validate_metadata()

-await (self.store_path / ZARR_JSON).set_async(self.metadata.to_bytes())
+await (self.store_path / ZARR_JSON).set(self.metadata.to_bytes())

def _validate_metadata(self) -> None:
assert len(self.metadata.shape) == len(
@@ -263,7 +263,7 @@ async def _read_chunk(
else:
out[out_selection] = self.metadata.fill_value
else:
-chunk_bytes = await store_path.get_async()
+chunk_bytes = await store_path.get()
if chunk_bytes is not None:
chunk_array = await self.codec_pipeline.decode(chunk_bytes)
tmp = chunk_array[chunk_selection]
@@ -345,7 +345,7 @@ async def _write_chunk(
else:
# writing partial chunks
# read chunk first
-chunk_bytes = await store_path.get_async()
+chunk_bytes = await store_path.get()

# merge new value
if chunk_bytes is None:
@@ -365,13 +365,13 @@ async def _write_chunk_to_store(self, store_path: StorePath, chunk_array: np.ndarray):
async def _write_chunk_to_store(self, store_path: StorePath, chunk_array: np.ndarray):
if np.all(chunk_array == self.metadata.fill_value):
# chunks that only contain fill_value will be removed
-await store_path.delete_async()
+await store_path.delete()
else:
chunk_bytes = await self.codec_pipeline.encode(chunk_array)
if chunk_bytes is None:
-await store_path.delete_async()
+await store_path.delete()
else:
-await store_path.set_async(chunk_bytes)
+await store_path.set(chunk_bytes)

async def resize(self, new_shape: ChunkCoords) -> AsyncArray:
assert len(new_shape) == len(self.metadata.shape)
@@ -384,7 +384,7 @@ async def resize(self, new_shape: ChunkCoords) -> AsyncArray:
new_chunk_coords = set(all_chunk_coords(new_shape, chunk_shape))

async def _delete_key(key: str) -> None:
-await (self.store_path / key).delete_async()
+await (self.store_path / key).delete()

await concurrent_map(
[
@@ -396,14 +396,14 @@ async def _delete_key(key: str) -> None:
)

# Write new metadata
-await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
+await (self.store_path / ZARR_JSON).set(new_metadata.to_bytes())
return evolve(self, metadata=new_metadata)

async def update_attributes(self, new_attributes: Dict[str, Any]) -> Array:
new_metadata = evolve(self.metadata, attributes=new_attributes)

# Write new metadata
-await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
+await (self.store_path / ZARR_JSON).set(new_metadata.to_bytes())
return evolve(self, metadata=new_metadata)

def __repr__(self):
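The array.py hunks above are a mechanical rename of the StorePath helpers: exists_async/get_async/set_async/delete_async become exists/get/set/delete, all still awaitable. Below is a hedged sketch of the resulting call pattern; the function name and the literal "zarr.json" key (the ZARR_JSON constant used above) are assumptions for illustration.

    async def reset_metadata(store_path) -> None:
        # store_path is assumed to be a zarr.v3.stores.StorePath; "/" derives a child path,
        # mirroring (store_path / ZARR_JSON) in the diff above.
        zarr_json = store_path / "zarr.json"
        if await zarr_json.exists():            # was exists_async()
            await zarr_json.delete()            # was delete_async()
        await zarr_json.set(b"{}")              # was set_async()
        assert await zarr_json.get() == b"{}"   # was get_async()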
2 changes: 1 addition & 1 deletion zarr/v3/codecs/sharding.py
@@ -43,7 +43,7 @@
CodecMetadata,
ShardingCodecIndexLocation,
)
-from zarr.v3.store import StorePath
+from zarr.v3.stores import StorePath

MAX_UINT_64 = 2**64 - 1

32 changes: 22 additions & 10 deletions zarr/v3/group.py
@@ -58,9 +58,9 @@ async def create(
store_path = make_store_path(store)
if not exists_ok:
if zarr_format == 3:
-assert not await (store_path / ZARR_JSON).exists_async()
+assert not await (store_path / ZARR_JSON).exists()
elif zarr_format == 2:
-assert not await (store_path / ZGROUP_JSON).exists_async()
+assert not await (store_path / ZGROUP_JSON).exists()
group = cls(
metadata=GroupMetadata(attributes=attributes or {}, zarr_format=zarr_format),
store_path=store_path,
@@ -82,7 +82,7 @@ async def open(
if zarr_format == 3:
# V3 groups are comprised of a zarr.json object
# (it is optional in the case of implicit groups)
-zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
+zarr_json_bytes = await (store_path / ZARR_JSON).get()
zarr_json = (
json.loads(zarr_json_bytes) if zarr_json_bytes is not None else {"zarr_format": 3}
)
Expand All @@ -91,7 +91,7 @@ async def open(
# V2 groups are comprised of a .zgroup and .zattrs objects
# (both are optional in the case of implicit groups)
zgroup_bytes, zattrs_bytes = await asyncio.gather(
-(store_path / ZGROUP_JSON).get_async(), (store_path / ZATTRS_JSON).get_async()
+(store_path / ZGROUP_JSON).get(), (store_path / ZATTRS_JSON).get()
)
zgroup = (
json.loads(json.loads(zgroup_bytes))
@@ -126,7 +126,7 @@ async def getitem(
store_path = self.store_path / key

if self.metadata.zarr_format == 3:
-zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
+zarr_json_bytes = await (store_path / ZARR_JSON).get()
if zarr_json_bytes is None:
# implicit group?
logger.warning("group at {} is an implicit group", store_path)
@@ -147,9 +147,9 @@
# Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
# This guarantees that we will always make at least one extra request to the store
zgroup_bytes, zarray_bytes, zattrs_bytes = await asyncio.gather(
-(store_path / ZGROUP_JSON).get_async(),
-(store_path / ZARRAY_JSON).get_async(),
-(store_path / ZATTRS_JSON).get_async(),
+(store_path / ZGROUP_JSON).get(),
+(store_path / ZARRAY_JSON).get(),
+(store_path / ZATTRS_JSON).get(),
)

# unpack the zarray, if this is None then we must be opening a group
@@ -177,6 +177,18 @@ async def getitem(
else:
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")

async def delitem(self, key: str) -> None:
store_path = self.store_path / key
if self.metadata.zarr_format == 3:
await (store_path / ZARR_JSON).delete()
elif self.metadata.zarr_format == 2:
await asyncio.gather(
(store_path / ZGROUP_JSON).delete(), # TODO: missing_ok=False
(store_path / ZATTRS_JSON).delete(), # TODO: missing_ok=True
)
else:
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")

async def _save_metadata(self) -> None:
to_save = self.metadata.to_bytes()
awaitables = [(self.store_path / key).set_async(value) for key, value in to_save.items()]
@@ -320,8 +332,8 @@ def __getitem__(self, path: str) -> Union[Array, Group]:
else:
return Group(obj)

-def __delitem__(self, key):
-raise NotImplementedError
+def __delitem__(self, key) -> None:
+self._sync(self._async_group.delitem(key))

def __iter__(self):
raise NotImplementedError
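group.py gains AsyncGroup.delitem and routes the synchronous Group.__delitem__ through it. Here is a hedged usage sketch; the member name and the two helper functions are invented for illustration and assume Group wraps an AsyncGroup exactly as the surrounding code suggests.

    async def drop_member(async_group, key: str) -> None:
        # For zarr_format == 3 this deletes the member's zarr.json document;
        # for zarr_format == 2 it deletes the .zgroup and .zattrs documents.
        await async_group.delitem(key)

    def drop_member_sync(group, key: str) -> None:
        # Equivalent synchronous path through the new Group.__delitem__.
        del group[key]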
[Diffs for the remaining changed files are not shown.]
