From 7dff5e50af3f2d3b4e7dc31f0ca9809599fd379d Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Tue, 9 Apr 2024 09:35:48 -0700 Subject: [PATCH] feature(store): make list_* methods async generators --- src/zarr/v3/abc/store.py | 13 +++++----- src/zarr/v3/group.py | 32 ++++++++++++++----------- src/zarr/v3/store/local.py | 47 +++++++++++++++++-------------------- src/zarr/v3/store/memory.py | 33 +++++++++++++------------- tests/v3/test_group.py | 5 ++++ 5 files changed, 68 insertions(+), 62 deletions(-) diff --git a/src/zarr/v3/abc/store.py b/src/zarr/v3/abc/store.py index ce5de279c4..c9845b7ae7 100644 --- a/src/zarr/v3/abc/store.py +++ b/src/zarr/v3/abc/store.py @@ -1,4 +1,5 @@ from abc import abstractmethod, ABC +from collections.abc import AsyncGenerator from typing import List, Tuple, Optional @@ -106,17 +107,17 @@ def supports_listing(self) -> bool: ... @abstractmethod - async def list(self) -> List[str]: + async def list(self) -> AsyncGenerator[str, None]: """Retrieve all keys in the store. Returns ------- - list[str] + AsyncGenerator[str, None] """ ... @abstractmethod - async def list_prefix(self, prefix: str) -> List[str]: + async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: """Retrieve all keys in the store with a given prefix. Parameters @@ -125,12 +126,12 @@ async def list_prefix(self, prefix: str) -> List[str]: Returns ------- - list[str] + AsyncGenerator[str, None] """ ... @abstractmethod - async def list_dir(self, prefix: str) -> List[str]: + async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ Retrieve all keys and prefixes with a given prefix and which do not contain the character “/” after the given prefix. @@ -141,6 +142,6 @@ async def list_dir(self, prefix: str) -> List[str]: Returns ------- - list[str] + AsyncGenerator[str, None] """ ... diff --git a/src/zarr/v3/group.py b/src/zarr/v3/group.py index bfb6440cf3..031d9a0ad9 100644 --- a/src/zarr/v3/group.py +++ b/src/zarr/v3/group.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING from dataclasses import asdict, dataclass, field, replace @@ -9,7 +10,6 @@ if TYPE_CHECKING: from typing import ( Any, - AsyncGenerator, Literal, AsyncIterator, Iterator, @@ -161,6 +161,7 @@ async def getitem( ) -> AsyncArray | AsyncGroup: store_path = self.store_path / key + logger.warning("key=%s, store_path=%s", key, store_path) # Note: # in zarr-python v2, we first check if `key` references an Array, else if `key` references @@ -289,7 +290,7 @@ def __repr__(self): async def nchildren(self) -> int: raise NotImplementedError - async def children(self) -> AsyncGenerator[AsyncArray, AsyncGroup]: + async def children(self) -> AsyncGenerator[AsyncArray | AsyncGroup, None]: """ Returns an AsyncGenerator over the arrays and groups contained in this group. This method requires that `store_path.store` supports directory listing. @@ -303,18 +304,21 @@ async def children(self) -> AsyncGenerator[AsyncArray, AsyncGroup]: ) raise ValueError(msg) - subkeys = await self.store_path.store.list_dir(self.store_path.path) - # would be nice to make these special keys accessible programmatically, - # and scoped to specific zarr versions - subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys) - # is there a better way to schedule this? - for subkey in subkeys_filtered: - try: - yield await self.getitem(subkey) - except KeyError: - # keyerror is raised when `subkey``names an object in the store - # in which case `subkey` cannot be the name of a sub-array or sub-group. - pass + + async for key in self.store_path.store.list_dir(self.store_path.path): + # these keys are not valid child names so we make sure to skip them + # TODO: it would be nice to make these special keys accessible programmatically, + # and scoped to specific zarr versions + if key not in ("zarr.json", ".zgroup", ".zattrs"): + try: + # TODO: performance optimization -- batch + print(key) + child = await self.getitem(key) + # keyerror is raised when `subkey``names an object in the store + # in which case `subkey` cannot be the name of a sub-array or sub-group. + yield child + except KeyError: + pass async def contains(self, child: str) -> bool: raise NotImplementedError diff --git a/src/zarr/v3/store/local.py b/src/zarr/v3/store/local.py index 1677e08ddf..1a87c450a0 100644 --- a/src/zarr/v3/store/local.py +++ b/src/zarr/v3/store/local.py @@ -2,6 +2,7 @@ import io import shutil +from collections.abc import AsyncGenerator from pathlib import Path from typing import Union, Optional, List, Tuple @@ -142,21 +143,19 @@ async def exists(self, key: str) -> bool: path = self.root / key return await to_thread(path.is_file) - async def list(self) -> List[str]: + async def list(self) -> AsyncGenerator[str, None]: """Retrieve all keys in the store. Returns ------- - list[str] + AsyncGenerator[str, None] """ - # Q: do we want to return strings or Paths? - def _list(root: Path) -> List[str]: - files = [str(p) for p in root.rglob("") if p.is_file()] - return files + for p in self.root.rglob(""): + if p.is_file(): + yield str(p) - return await to_thread(_list, self.root) - async def list_prefix(self, prefix: str) -> List[str]: + async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: """Retrieve all keys in the store with a given prefix. Parameters @@ -165,16 +164,14 @@ async def list_prefix(self, prefix: str) -> List[str]: Returns ------- - list[str] + AsyncGenerator[str, None] """ + for p in (self.root / prefix).rglob("*"): + if p.is_file(): + yield str(p) - def _list_prefix(root: Path, prefix: str) -> List[str]: - files = [p for p in (root / prefix).rglob("*") if p.is_file()] - return files - return await to_thread(_list_prefix, self.root, prefix) - - async def list_dir(self, prefix: str) -> List[str]: + async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ Retrieve all keys and prefixes with a given prefix and which do not contain the character “/” after the given prefix. @@ -185,16 +182,16 @@ async def list_dir(self, prefix: str) -> List[str]: Returns ------- - list[str] + AsyncGenerator[str, None] """ + base = self.root / prefix + to_strip = str(base) + "/" + + try: + key_iter = base.iterdir() + except (FileNotFoundError, NotADirectoryError): + key_iter = [] - def _list_dir(root: Path, prefix: str) -> List[str]: - - base = root / prefix - to_strip = str(base) + "/" - try: - return [str(key).replace(to_strip, "") for key in base.iterdir()] - except (FileNotFoundError, NotADirectoryError): - return [] + for key in key_iter: + yield str(key).replace(to_strip, "") - return await to_thread(_list_dir, self.root, prefix) diff --git a/src/zarr/v3/store/memory.py b/src/zarr/v3/store/memory.py index afacfa4321..dbfa537edf 100644 --- a/src/zarr/v3/store/memory.py +++ b/src/zarr/v3/store/memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import AsyncGenerator from typing import Optional, MutableMapping, List, Tuple from zarr.v3.common import BytesLike @@ -67,20 +68,18 @@ async def delete(self, key: str) -> None: async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None: raise NotImplementedError - async def list(self) -> List[str]: - return list(self._store_dict.keys()) - - async def list_prefix(self, prefix: str) -> List[str]: - return [key for key in self._store_dict if key.startswith(prefix)] - - async def list_dir(self, prefix: str) -> List[str]: - if prefix == "": - return list({key.split("/", maxsplit=1)[0] for key in self._store_dict}) - else: - return list( - { - key.strip(prefix + "/").split("/")[0] - for key in self._store_dict - if (key.startswith(prefix + "/") and key != prefix) - } - ) + async def list(self) -> AsyncGenerator[str, None]: + for key in self._store_dict: + yield key + + async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: + for key in self._store_dict: + if key.startswith(prefix): + yield key + + async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: + print('prefix', prefix) + print('keys in list_dir', list(self._store_dict)) + for key in self._store_dict: + if key.startswith(prefix + "/") and key != prefix: + yield key.strip(prefix + "/").rsplit("/", maxsplit=1)[0] diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index f3530c6042..136a7917ad 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -56,13 +56,18 @@ def test_group_children(store: MemoryStore | LocalStore): # if group.children guarantees a particular order for the children. # If order is not guaranteed, then the better version of this test is # to compare two sets, but presently neither the group nor array classes are hashable. + print('getting children') observed = group.children + print(observed) + print(list([subgroup, subarray, implicit_subgroup])) assert len(observed) == 3 assert subarray in observed assert implicit_subgroup in observed assert subgroup in observed + + @pytest.mark.parametrize("store", (("local", "memory")), indirect=["store"]) def test_group(store: MemoryStore | LocalStore) -> None: store_path = StorePath(store)