Skip to content

Commit

Permalink
v3 refactor - fix for sync constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Feb 7, 2024
1 parent 661c0d6 commit 61f58e6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
13 changes: 8 additions & 5 deletions src/zarr/v3/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from zarr.v3.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, ZGROUP_JSON, make_cattr
from zarr.v3.config import RuntimeConfiguration, SyncConfiguration
from zarr.v3.store import StoreLike, StorePath, make_store_path
from zarr.v3.sync import SyncMixin
from zarr.v3.sync import SyncMixin, sync

logger = logging.getLogger("zarr.group")

Expand All @@ -21,7 +21,7 @@
class GroupMetadata:
attributes: Dict[str, Any] = field(factory=dict)
zarr_format: Literal[2, 3] = 3 # field(default=3, validator=validators.in_([2, 3]))
node_type: Literal["group"] = field(default="group", init=False)
node_type: Literal["group"] = field(default="group", init=True)

def to_bytes(self) -> Dict[str, bytes]:
if self.zarr_format == 3:
Expand Down Expand Up @@ -305,13 +305,14 @@ def create(
exists_ok: bool = False,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> Group:
obj = cls._sync(
obj = sync(
AsyncGroup.create(
store,
attributes=attributes,
exists_ok=exists_ok,
runtime_configuration=runtime_configuration,
)
),
loop=runtime_configuration.asyncio_loop,
)

return cls(obj)
Expand All @@ -322,7 +323,9 @@ def open(
store: StoreLike,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> Group:
obj = cls._sync(AsyncGroup.open(store, runtime_configuration))
obj = sync(
AsyncGroup.open(store, runtime_configuration), loop=runtime_configuration.asyncio_loop
)
return cls(obj)

def __getitem__(self, path: str) -> Union[Array, Group]:
Expand Down
38 changes: 37 additions & 1 deletion tests/test_group_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def store_path(tmpdir):
return p


def test_group(store_path) -> None:
def test_group_async_constructor(store_path) -> None:

agroup = AsyncGroup(
metadata=GroupMetadata(),
Expand Down Expand Up @@ -54,3 +54,39 @@ def test_group(store_path) -> None:
# and the attrs were modified in the store
bar3 = foo["bar"]
assert dict(bar3.attrs) == {"baz": "qux", "name": "bar"}


def test_group_sync_constructor(store_path) -> None:

group = Group.create(store=store_path, runtime_configuration=RuntimeConfiguration())

# create two groups
foo = group.create_group("foo")
bar = foo.create_group("bar", attributes={"baz": "qux"})

# create an array from the "bar" group
data = np.arange(0, 4 * 4, dtype="uint16").reshape((4, 4))
arr = bar.create_array(
"baz", shape=data.shape, dtype=data.dtype, chunk_shape=(2, 2), exists_ok=True
)
arr[:] = data

# check the array
assert arr == bar["baz"]
assert arr.shape == data.shape
assert arr.dtype == data.dtype

# TODO: update this once the array api settles down
# assert arr.chunk_shape == (2, 2)

bar2 = foo["bar"]
assert dict(bar2.attrs) == {"baz": "qux"}

# update a group's attributes
bar2.attrs.update({"name": "bar"})
# bar.attrs was modified in-place
assert dict(bar2.attrs) == {"baz": "qux", "name": "bar"}

# and the attrs were modified in the store
bar3 = foo["bar"]
assert dict(bar3.attrs) == {"baz": "qux", "name": "bar"}

0 comments on commit 61f58e6

Please sign in to comment.