Skip to content

Commit

Permalink
fix: threadpool configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Jan 8, 2025
1 parent 29ef41d commit 89c5b46
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
8 changes: 5 additions & 3 deletions src/zarr/core/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def _get_executor() -> ThreadPoolExecutor:
global _executor
if not _executor:
max_workers = config.get("threading.max_workers", None)
print(max_workers)
# if max_workers is not None and max_workers > 0:
# raise ValueError(max_workers)
logger.debug("Creating Zarr ThreadPoolExecutor with max_workers=%s", max_workers)
_executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="zarr_pool")
_get_loop().set_default_executor(_executor)
return _executor
Expand Down Expand Up @@ -118,6 +116,9 @@ def sync(
# NB: if the loop is not running *yet*, it is OK to submit work
# and we will wait for it
loop = _get_loop()
if _executor is None and config.get("threading.max_workers", None) is not None:
# trigger executor creation and attach to loop
_ = _get_executor()
if not isinstance(loop, asyncio.AbstractEventLoop):
raise TypeError(f"loop cannot be of type {type(loop)}")
if loop.is_closed():
Expand Down Expand Up @@ -153,6 +154,7 @@ def _get_loop() -> asyncio.AbstractEventLoop:
# repeat the check just in case the loop got filled between the
# previous two calls from another thread
if loop[0] is None:
logger.debug("Creating Zarr event loop")
new_loop = asyncio.new_event_loop()
loop[0] = new_loop
iothread[0] = threading.Thread(target=new_loop.run_forever, name="zarr_io")
Expand Down
18 changes: 14 additions & 4 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_get_lock,
_get_loop,
cleanup_resources,
loop,
sync,
)
from zarr.storage import MemoryStore
Expand Down Expand Up @@ -148,11 +149,20 @@ def test_open_positional_args_deprecate():


@pytest.mark.parametrize("workers", [None, 1, 2])
def test_get_executor(clean_state, workers) -> None:
def test_threadpool_executor(clean_state, workers: int | None) -> None:
with zarr.config.set({"threading.max_workers": workers}):
e = _get_executor()
if workers is not None and workers != 0:
assert e._max_workers == workers
_ = zarr.zeros(shape=(1,)) # trigger loop creation (and executor creation when max_workers is set)
assert loop != [None] # confirm loop was created
if workers is None:
# confirm no executor was created if no workers were specified
# (this is the default behavior)
assert loop[0]._default_executor is None
else:
# confirm executor was created and attached to loop as the default executor
# note: asyncio exposes no public accessor for a loop's default executor, so we
# read the private `_default_executor` attribute
assert _get_executor() is loop[0]._default_executor
assert _get_executor()._max_workers == workers


def test_cleanup_resources_idempotent() -> None:
Expand Down

0 comments on commit 89c5b46

Please sign in to comment.