Skip to content

Commit

Permalink
arrays proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Oct 17, 2024
1 parent 77115ff commit ed67cc3
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 8 deletions.
56 changes: 50 additions & 6 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,36 @@ def flatten(
return metadata


class ArraysProxy:
"""
Proxy for arrays in a group.
Used to implement the `Group.arrays` property
"""

def __init__(self, group: Group) -> None:
self._group = group

def __getitem__(self, key: str) -> Array:
obj = self._group[key]
if isinstance(obj, Array):
return obj
raise KeyError(key)

def __setitem__(self, key: str, value: npt.ArrayLike) -> None:
"""
Set an array in the group.
"""
self._group._sync(self._group._async_group.set_array(key, value))

def __iter__(self) -> Generator[tuple[str, Array], None]:
for name, async_array in self._group._sync_iter(self._group._async_group.arrays()):
yield name, Array(async_array)

def __call__(self) -> Generator[tuple[str, Array], None]:
return iter(self)


@dataclass(frozen=True)
class GroupMetadata(Metadata):
attributes: dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -596,7 +626,16 @@ def from_dict(
store_path=store_path,
)

async def setitem(self, key: str, value: Any) -> None:
async def set_array(self, key: str, value: Any) -> None:
"""fastpath for creating a new array
Parameters
----------
key : str
Array name
value : array-like
Array data
"""
path = self.store_path / key
await async_api.save_array(
store=path, arr=value, zarr_format=self.metadata.zarr_format, exists_ok=True
Expand Down Expand Up @@ -1374,9 +1413,14 @@ def __iter__(self) -> Iterator[str]:
def __len__(self) -> int:
return self.nmembers()

@deprecated("Use Group.arrays setter instead.")
def __setitem__(self, key: str, value: Any) -> None:
"""Create a new array"""
self._sync(self._async_group.setitem(key, value))
"""Create a new array
.. deprecated:: 3.0.0
Use Group.arrays.setter instead.
"""
self._sync(self._async_group.set_array(key, value))

def __repr__(self) -> str:
return f"<Group {self.store_path}>"
Expand Down Expand Up @@ -1473,9 +1517,9 @@ def group_values(self) -> Generator[Group, None]:
for _, group in self.groups():
yield group

def arrays(self) -> Generator[tuple[str, Array], None]:
for name, async_array in self._sync_iter(self._async_group.arrays()):
yield name, Array(async_array)
@property
def arrays(self) -> ArraysProxy:
return ArraysProxy(self)

def array_keys(self) -> Generator[str, None]:
for name, _ in self.arrays():
Expand Down
31 changes: 29 additions & 2 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
"""
group = Group.from_store(store, zarr_format=zarr_format)
arr = np.ones((2, 4))
group["key"] = arr
with pytest.warns(DeprecationWarning):
group["key"] = arr
assert group["key"].shape == (2, 4)
np.testing.assert_array_equal(group["key"][:], arr)

Expand All @@ -405,7 +406,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:

# overwrite with another array
arr = np.zeros((3, 5))
group[key] = arr
with pytest.warns(DeprecationWarning):
group[key] = arr
assert group[key].shape == (3, 5)
np.testing.assert_array_equal(group[key], arr)

Expand All @@ -416,6 +418,31 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
# assert group["key"][:] == 1


def test_group_arrays_setter(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the `Group.__setitem__` method.
"""
group = Group.from_store(store, zarr_format=zarr_format)
arr = np.ones((2, 4))
group.arrays["key"] = arr
assert group["key"].shape == (2, 4)
np.testing.assert_array_equal(group["key"][:], arr)

if store.supports_deletes:
key = "key"
else:
# overwriting with another array requires deletes
# for stores that don't support this, we just use a new key
key = "key2"

# overwrite with another array
arr = np.zeros((3, 5))
with pytest.warns(DeprecationWarning):
group[key] = arr
assert group[key].shape == (3, 5)
np.testing.assert_array_equal(group[key], arr)


def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the `Group.__contains__` method
Expand Down

0 comments on commit ed67cc3

Please sign in to comment.