diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 09552c20..8a4b7e18 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -43,7 +43,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ "ubuntu-latest" ]
-        python-version: [ "3.9", "3.10", "3.11" ]
+        python-version: [ "3.8", "3.9", "3.10", "3.11" ]
     runs-on: "${{ matrix.os }}"
     steps:
       - name: Check out repository
diff --git a/changelog/268.trivial.md b/changelog/268.trivial.md
new file mode 100644
index 00000000..9f899124
--- /dev/null
+++ b/changelog/268.trivial.md
@@ -0,0 +1,3 @@
+Refactored :meth:`RunGroupBy.apply_parallel` to allow better user control and dependency injection
+
+Also updated associated type hints throughout `scmdata.groupby`
diff --git a/docs/source/api/scmdata.groupby.rst b/docs/source/api/scmdata.groupby.rst
index 8afd4a75..f8f79124 100644
--- a/docs/source/api/scmdata.groupby.rst
+++ b/docs/source/api/scmdata.groupby.rst
@@ -12,3 +12,9 @@ RunGroupBy
 
 .. autoclass:: RunGroupBy
     :members:
+
+
+get\_joblib\_parallel\_processor
+================================
+
+.. autofunction:: get_joblib_parallel_processor
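
The new public helper documented above pairs with a zero-configuration default: calling :meth:`RunGroupBy.apply_parallel` with no ``parallel_processor`` builds the joblib-backed processor internally. A minimal sketch of that default path (the example data and the ``identity`` function are illustrative assumptions, not part of this PR):

    import numpy as np
    import scmdata

    # Two scenarios, three time points of assumed example data.
    run = scmdata.ScmRun(
        data=np.arange(6).reshape(3, 2),
        index=[2010, 2020, 2030],
        columns={
            "scenario": ["a", "b"],
            "model": "example",
            "region": "World",
            "variable": "Emissions|CO2",
            "unit": "GtC / yr",
        },
    )

    def identity(group):
        # Any callable returning an ScmRun, a DataFrame or None works here.
        return group

    # No parallel_processor given, so the joblib default is constructed.
    res = run.groupby("scenario").apply_parallel(identity)
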
diff --git a/src/scmdata/groupby.py b/src/scmdata/groupby.py
index 372bb6b2..95e4ba18 100644
--- a/src/scmdata/groupby.py
+++ b/src/scmdata/groupby.py
@@ -1,9 +1,11 @@
 """
 Functionality for grouping and filtering ScmRun objects
 """
+from __future__ import annotations
+
 import warnings
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, TypeVar, Union
 
 import numpy as np
 import pandas as pd
@@ -12,13 +14,25 @@ from xarray.core.common import ImplementsArrayReduce
 
 from scmdata._typing import MetadataValue
-from scmdata.run import GenericRun
+from scmdata.run import BaseScmRun, GenericRun
 
 if TYPE_CHECKING:
     from pandas.core.groupby.generic import DataFrameGroupBy
     from typing_extensions import Concatenate, ParamSpec
 
     P = ParamSpec("P")
+    Q = ParamSpec("Q")
+    RunLike = TypeVar("RunLike", bound=BaseScmRun)
+    ApplyCallableReturnType = Union[RunLike, pd.DataFrame, None]
+    ApplyCallable = Callable[Concatenate[RunLike, Q], ApplyCallableReturnType[RunLike]]
+    ParallelProcessor = Callable[
+        Concatenate[
+            ApplyCallable[RunLike, Q],
+            Iterable[RunLike],
+            Q,
+        ],
+        Iterable[ApplyCallableReturnType[RunLike]],
+    ]
 
 
 class RunGroupBy(ImplementsArrayReduce, Generic[GenericRun]):
@@ -27,7 +41,7 @@ class RunGroupBy(ImplementsArrayReduce, Generic[GenericRun]):
     """
 
     def __init__(
-        self, run: "GenericRun", groups: "Iterable[str]", na_fill_value: float = -10000
+        self, run: GenericRun, groups: Iterable[str], na_fill_value: float = -10000
     ):
         self.run = run
         self.group_keys = groups
@@ -45,9 +59,9 @@ def __init__(
         else:
             m = m.fillna(na_fill_value)
 
-        self._grouper: "DataFrameGroupBy" = m.groupby(list(groups), group_keys=True)
+        self._grouper: DataFrameGroupBy = m.groupby(list(groups), group_keys=True)
 
-    def _iter_grouped(self) -> "Iterator[GenericRun]":
+    def _iter_grouped(self) -> Iterator[GenericRun]:
         def _try_fill_value(v: MetadataValue) -> MetadataValue:
             try:
                 if float(v) == float(self.na_fill_value):
@@ -57,7 +71,7 @@ def _try_fill_value(v: MetadataValue) -> MetadataValue:
             return v
 
         groups: Iterable[
-            Union[MetadataValue, tuple[MetadataValue, ...]]
+            MetadataValue | tuple[MetadataValue, ...]
         ] = self._grouper.groups
         for indices in groups:
             if not isinstance(indices, Iterable) or isinstance(indices, str):
@@ -82,20 +96,20 @@ def __iter__(self) -> Iterator[GenericRun]:
 
     def apply(
         self,
-        func: "Callable[Concatenate[GenericRun, P], Union[GenericRun, pd.DataFrame, None]]",
-        *args: "P.args",
-        **kwargs: "P.kwargs",
-    ) -> "GenericRun":
+        func: Callable[Concatenate[GenericRun, P], GenericRun | pd.DataFrame | None],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> GenericRun:
         """
         Apply a function to each group and append the results
 
-        `func` is called like `func(ar, *args, **kwargs)` for each :class:`ScmRun <scmdata.run.ScmRun>` ``ar``
-        in this group. If the result of this function call is None, than it is
+        `func` is called like `func(ar, *args, **kwargs)` for each :class:`ScmRun <scmdata.run.ScmRun>`
+        group. If the result of this function call is ``None``, then it is
         excluded from the results. The results are appended together using
         :func:`run_append`. The function
-        can change the size of the input :class:`ScmRun <scmdata.run.ScmRun>` as long as :func:`run_append`
-        can be applied to all results.
+        can change the size of the input :class:`ScmRun <scmdata.run.ScmRun>`
+        as long as :func:`run_append` can be applied to all results.
 
         Examples
         --------
@@ -109,18 +123,17 @@ def apply(
         Parameters
         ----------
         func
-            Callable to apply to each timeseries.
+            Callable to apply to each group.
 
-        ``*args``
+        *args
             Positional arguments passed to `func`.
 
-        ``**kwargs``
-            Used to call `func(ar, **kwargs)` for each array `ar`.
+        **kwargs
+            Keyword arguments passed to `func`.
 
         Returns
         -------
-        applied : :class:`ScmRun <scmdata.run.ScmRun>`
-            The result of splitting, applying and combining this array.
+            The result of applying and combining.
         """
         grouped = self._iter_grouped()
         applied = [func(arr, *args, **kwargs) for arr in grouped]
@@ -128,16 +141,16 @@ def apply(
 
     def apply_parallel(
         self,
-        func: "Callable[Concatenate[GenericRun, P], Union[GenericRun, pd.DataFrame, None]]",
-        n_jobs: int = 1,
-        backend: str = "loky",
-        *args: "P.args",
-        **kwargs: "P.kwargs",
-    ) -> "GenericRun":
+        func: ApplyCallable[GenericRun, P],
+        parallel_processor: ParallelProcessor[GenericRun, P] | None = None,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> GenericRun:
        """
         Apply a function to each group in parallel and append the results
 
-        Provides the same functionality as :func:`~apply` except that :mod:`joblib` is used to apply
+        Provides the same functionality as :func:`~apply` except that parallel processing can be
+        used via the ``parallel_processor`` argument. By default, :mod:`joblib` is used to apply
         `func` to each group in parallel. This can be slower than using :func:`~apply` for small
         numbers of groups or in the case where `func` is fast as there is overhead setting up the
         processing pool.
@@ -149,41 +162,29 @@ def apply_parallel(
         Parameters
         ----------
         func
-            Callable to apply to each timeseries.
-
-        n_jobs
-            Number of jobs to run in parallel (defaults to a single job which is useful for
-            debugging purposes). If `-1` all CPUs are used.
+            Callable to apply to each group.
 
-        backend
-            Backend used for parallelisation. Defaults to 'loky' which uses separate processes for
-            each worker.
+        parallel_processor
+            Parallel processor to use to process the groups. If not provided,
+            a default joblib parallel processor is used (for details, see
+            :func:`get_joblib_parallel_processor`).
 
-            See :class:`joblib.Parallel` for a more complete description of the available
-            options.
-
-        ``*args``
+        *args
             Positional arguments passed to `func`.
 
-        ``**kwargs``
-            Used to call `func(ar, **kwargs)` for each array `ar`.
+        **kwargs
+            Keyword arguments passed to `func`.
 
         Returns
         -------
-        applied : :class:`ScmRun <scmdata.run.ScmRun>`
-            The result of splitting, applying and combining this array.
+            The result of applying and combining.
         """
-        try:
-            import joblib  # type: ignore
-        except ImportError as e:  # pragma: no cover
-            raise ImportError(
-                "joblib is not installed. Run 'pip install joblib'"
-            ) from e
+        if parallel_processor is None:
+            parallel_processor = get_joblib_parallel_processor()
 
         grouped = self._iter_grouped()
-        applied: "list[Union[GenericRun, pd.DataFrame, None]]" = joblib.Parallel(
-            n_jobs=n_jobs, backend=backend
-        )(joblib.delayed(func)(arr, *args, **kwargs) for arr in grouped)
+        applied = parallel_processor(func, grouped, *args, **kwargs)
+
         return self._combine(applied)
 
     def map(self, func, *args, **kwargs):
@@ -202,8 +203,8 @@ def map(self, func, *args, **kwargs):
         return self.apply(func, *args, **kwargs)
 
     def _combine(
-        self, applied: "Sequence[Union[GenericRun, pd.DataFrame, None]]"
-    ) -> "GenericRun":
+        self, applied: Iterable[GenericRun | pd.DataFrame | None]
+    ) -> GenericRun:
         """
         Recombine the applied objects like the original.
         """
@@ -219,12 +220,12 @@ def _combine(
 
     def reduce(
         self,
-        func: "Callable[Concatenate[NDArray[np.float_], P], NDArray[np.float_]]",
-        dim: "Optional[Union[str, Iterable[str]]]" = None,
-        axis: "Optional[Union[str, Iterable[int]]]" = None,
-        *args: "P.args",
-        **kwargs: "P.kwargs",
-    ) -> "GenericRun":
+        func: Callable[Concatenate[NDArray[np.float_], P], NDArray[np.float_]],
+        dim: str | Iterable[str] | None = None,
+        axis: str | Iterable[int] | None = None,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> GenericRun:
         """
         Reduce the items in this group by applying `func` along some dimension(s).
@@ -260,4 +261,59 @@ def reduce_array(ar):
     return self.apply(reduce_array)
 
 
+def get_joblib_parallel_processor(
+    n_jobs: int = -1,
+    backend: str = "loky",
+    *args: Any,
+    **kwargs: Any,
+) -> ParallelProcessor[RunLike, Q]:
+    """
+    Get parallel processor using :mod:`joblib` as the backend.
+
+    Parameters
+    ----------
+    n_jobs
+        Number of jobs to run in parallel. If `-1` all CPUs are used.
+
+    backend
+        Backend used for parallelisation. Defaults to 'loky' which uses separate processes for
+        each worker.
+        See :class:`joblib.Parallel` for a more complete description of the available
+        options.
+
+    *args
+        Passed to initialiser of :class:`joblib.Parallel`
+
+    **kwargs
+        Passed to initialiser of :class:`joblib.Parallel`
+
+    Returns
+    -------
+        Function that can be used for parallel processing in
+        :meth:`RunGroupBy.apply_parallel`
+    """
+    try:
+        import joblib
+    except ImportError as e:  # pragma: no cover
+        raise ImportError("joblib is not installed. Run 'pip install joblib'") from e
+
+    processor = joblib.Parallel(*args, n_jobs=n_jobs, backend=backend, **kwargs)
+
+    def joblib_parallel_processor(
+        func: ApplyCallable[RunLike, Q],
+        groups: Iterable[RunLike],
+        /,
+        *args: Q.args,
+        **kwargs: Q.kwargs,
+    ) -> Iterable[ApplyCallableReturnType[RunLike]]:
+        prepped_groups = (
+            joblib.delayed(func)(group, *args, **kwargs) for group in groups
+        )
+        applied = processor(prepped_groups)
+
+        return applied
+
+    return joblib_parallel_processor
+
+
 ops.inject_reduce_methods(RunGroupBy)
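
The ``ParallelProcessor`` alias above is the whole injection contract: any callable accepting ``(func, groups, *args, **kwargs)`` and returning an iterable of per-group results can be passed to ``apply_parallel``. Two alternative processors as a sketch (the names and implementations are illustrative assumptions, not part of this PR):

    from __future__ import annotations

    from concurrent.futures import ProcessPoolExecutor
    from typing import Any, Callable, Iterable

    def serial_processor(
        func: Callable[..., Any], groups: Iterable[Any], /, *args: Any, **kwargs: Any
    ) -> list[Any]:
        # No pool at all: handy for debugging, akin to the old n_jobs=1 behaviour.
        return [func(group, *args, **kwargs) for group in groups]

    def futures_processor(
        func: Callable[..., Any], groups: Iterable[Any], /, *args: Any, **kwargs: Any
    ) -> list[Any]:
        # Standard-library stand-in for the joblib default.
        with ProcessPoolExecutor() as pool:
            futures = [pool.submit(func, group, *args, **kwargs) for group in groups]
            return [future.result() for future in futures]

Either one plugs straight into ``run.groupby(...).apply_parallel(func, parallel_processor=serial_processor)``, which is exactly the seam this refactor opens up.
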
diff --git a/src/scmdata/run.py b/src/scmdata/run.py
index 81abf401..b59a0768 100644
--- a/src/scmdata/run.py
+++ b/src/scmdata/run.py
@@ -1982,10 +1982,10 @@ def apply(
         func : function
             Callable to apply to each timeseries.
 
-        ``*args``
+        *args
            Positional arguments passed to `func`.
 
-        ``**kwargs``
+        **kwargs
            Used to call `func(ar, **kwargs)` for each array `ar`.
 
         Returns
diff --git a/stubs/.gitkeep b/stubs/.gitkeep
deleted file mode 100644
index e69de29b..00000000
diff --git a/stubs/joblib.pyi b/stubs/joblib.pyi
new file mode 100644
index 00000000..a716011a
--- /dev/null
+++ b/stubs/joblib.pyi
@@ -0,0 +1,22 @@
+from typing import Any, Callable, Iterable, ParamSpec, TypeVar
+
+T = TypeVar("T")
+P = ParamSpec("P")
+Q = ParamSpec("Q")
+
+def delayed(
+    func: Callable[P, T]
+) -> Callable[..., tuple[Callable[P, T], Q.args, Q.kwargs]]: ...
+
+class Parallel:
+    def __init__(self, *args: Any, n_jobs: int, backend: str, **kwargs: Any): ...
+    def __call__(
+        self,
+        iterable: Iterable[
+            tuple[
+                Callable[P, T],
+                P.args,
+                P.kwargs,
+            ]
+        ],
+    ) -> Iterable[T]: ...
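
The stub types just the two joblib entry points the default processor touches: ``delayed`` wraps a callable for later execution and ``Parallel`` consumes the iterable of wrapped calls. Assuming joblib is installed, a plain call in that shape (this snippet is an illustration, not from the PR) is what the stub lets mypy check:

    import joblib

    def square(x: int) -> int:
        return x * x

    # delayed(square)(i) produces the (callable, args, kwargs) triple the
    # stub describes; Parallel(...) evaluates the whole iterable of them.
    results = joblib.Parallel(n_jobs=2, backend="loky")(
        joblib.delayed(square)(i) for i in range(4)
    )
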
diff --git a/tests/unit/test_groupby.py b/tests/unit/test_groupby.py
index 766fa1b9..99dbb1a9 100644
--- a/tests/unit/test_groupby.py
+++ b/tests/unit/test_groupby.py
@@ -2,6 +2,7 @@
 import pytest
 
 from scmdata import ScmRun
+from scmdata.groupby import get_joblib_parallel_processor
 from scmdata.testing import assert_scmdf_almost_equal
 
 group_tests = pytest.mark.parametrize(
@@ -28,7 +29,7 @@ def func(df):
         return df
 
     if parallel:
-        res = scm_run.groupby(*g).apply_parallel(func, n_jobs=-1)
+        res = scm_run.groupby(*g).apply_parallel(func)
     else:
         res = scm_run.groupby(*g).apply(func)
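
As the test change shows, callers no longer pass joblib options to ``apply_parallel`` directly; those options move to the processor factory. A migration sketch (``run`` and ``func`` stand for whatever the caller already has):

    from scmdata.groupby import get_joblib_parallel_processor

    # Before this PR:
    #     res = run.groupby("scenario").apply_parallel(func, n_jobs=-1)
    # After it, the same joblib options go through the factory:
    processor = get_joblib_parallel_processor(n_jobs=-1, backend="loky")
    # res = run.groupby("scenario").apply_parallel(func, parallel_processor=processor)
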