From 89c5b46e0a90a72f49d8ac9ec5330086b8f03924 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Tue, 7 Jan 2025 20:22:54 -0800 Subject: [PATCH] fix: threadpool configuration --- src/zarr/core/sync.py | 8 +++++--- tests/test_sync.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/zarr/core/sync.py b/src/zarr/core/sync.py index f7d4529478..6a2de855e8 100644 --- a/src/zarr/core/sync.py +++ b/src/zarr/core/sync.py @@ -54,9 +54,7 @@ def _get_executor() -> ThreadPoolExecutor: global _executor if not _executor: max_workers = config.get("threading.max_workers", None) - print(max_workers) - # if max_workers is not None and max_workers > 0: - # raise ValueError(max_workers) + logger.debug("Creating Zarr ThreadPoolExecutor with max_workers=%s", max_workers) _executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="zarr_pool") _get_loop().set_default_executor(_executor) return _executor @@ -118,6 +116,9 @@ def sync( # NB: if the loop is not running *yet*, it is OK to submit work # and we will wait for it loop = _get_loop() + if _executor is None and config.get("threading.max_workers", None) is not None: + # trigger executor creation and attach to loop + _ = _get_executor() if not isinstance(loop, asyncio.AbstractEventLoop): raise TypeError(f"loop cannot be of type {type(loop)}") if loop.is_closed(): @@ -153,6 +154,7 @@ def _get_loop() -> asyncio.AbstractEventLoop: # repeat the check just in case the loop got filled between the # previous two calls from another thread if loop[0] is None: + logger.debug("Creating Zarr event loop") new_loop = asyncio.new_event_loop() loop[0] = new_loop iothread[0] = threading.Thread(target=new_loop.run_forever, name="zarr_io") diff --git a/tests/test_sync.py b/tests/test_sync.py index b0a6ecffd0..e0002fc5a7 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -12,6 +12,7 @@ _get_lock, _get_loop, cleanup_resources, + loop, sync, ) from zarr.storage import MemoryStore @@ -148,11 +149,20 @@ def test_open_positional_args_deprecate(): @pytest.mark.parametrize("workers", [None, 1, 2]) -def test_get_executor(clean_state, workers) -> None: +def test_threadpool_executor(clean_state, workers: int | None) -> None: with zarr.config.set({"threading.max_workers": workers}): - e = _get_executor() - if workers is not None and workers != 0: - assert e._max_workers == workers + _ = zarr.zeros(shape=(1,)) # trigger executor creation + assert loop != [None] # confirm loop was created + if workers is None: + # confirm no executor was created if no workers were specified + # (this is the default behavior) + assert loop[0]._default_executor is None + else: + # confirm executor was created and attached to loop as the default executor + # note: python doesn't have a direct way to get the default executor so we + # use the private attribute + assert _get_executor() is loop[0]._default_executor + assert _get_executor()._max_workers == workers def test_cleanup_resources_idempotent() -> None: