
Refactor apply_parallel #268

Merged
merged 9 commits on Oct 16, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -43,7 +43,7 @@ jobs:
fail-fast: false
matrix:
os: [ "ubuntu-latest" ]
python-version: [ "3.9", "3.10", "3.11" ]
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
runs-on: "${{ matrix.os }}"
steps:
- name: Check out repository
3 changes: 3 additions & 0 deletions changelog/268.trivial.md
@@ -0,0 +1,3 @@
Refactored :meth:`RunGroupBy.apply_parallel` to allow better user control and dependency injection

Also updated associated type hints throughout `scmdata.groupby`
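For orientation, a minimal sketch of the call patterns this refactor enables (assuming an `ScmRun` named `scm_run` with a `"scenario"` metadata column; the per-group function is illustrative, not part of the diff):

```python
from scmdata.groupby import get_joblib_parallel_processor


def process_group(run):
    # Illustrative per-group transformation: returning the group
    # unchanged, a modified ScmRun, a DataFrame or None are all valid.
    return run


# Default: a joblib-backed processor is constructed internally.
res = scm_run.groupby("scenario").apply_parallel(process_group)

# Injected: configure joblib explicitly via the new hook.
processor = get_joblib_parallel_processor(n_jobs=4, backend="loky")
res = scm_run.groupby("scenario").apply_parallel(
    process_group, parallel_processor=processor
)
```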
6 changes: 6 additions & 0 deletions docs/source/api/scmdata.groupby.rst
@@ -12,3 +12,9 @@ RunGroupBy

.. autoclass:: RunGroupBy
:members:


get\_joblib\_parallel\_processor
================================

.. autofunction:: get_joblib_parallel_processor
176 changes: 116 additions & 60 deletions src/scmdata/groupby.py
@@ -1,9 +1,11 @@
"""
Functionality for grouping and filtering ScmRun objects
"""
from __future__ import annotations

import warnings
from collections.abc import Iterable
- from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional, Sequence, Union
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, TypeVar, Union

import numpy as np
import pandas as pd
@@ -12,13 +14,25 @@
from xarray.core.common import ImplementsArrayReduce

from scmdata._typing import MetadataValue
- from scmdata.run import GenericRun
+ from scmdata.run import BaseScmRun, GenericRun

if TYPE_CHECKING:
from pandas.core.groupby.generic import DataFrameGroupBy
from typing_extensions import Concatenate, ParamSpec


P = ParamSpec("P")
Q = ParamSpec("Q")
RunLike = TypeVar("RunLike", bound=BaseScmRun)
ApplyCallableReturnType = Union[RunLike, pd.DataFrame, None]
ApplyCallable = Callable[Concatenate[RunLike, Q], ApplyCallableReturnType[RunLike]]
ParallelProcessor = Callable[
Concatenate[
ApplyCallable[RunLike, Q],
Iterable[RunLike],
Q,
],
Iterable[ApplyCallableReturnType[RunLike]],
]


class RunGroupBy(ImplementsArrayReduce, Generic[GenericRun]):
@@ -27,7 +41,7 @@
"""

def __init__(
self, run: "GenericRun", groups: "Iterable[str]", na_fill_value: float = -10000
self, run: GenericRun, groups: Iterable[str], na_fill_value: float = -10000
):
self.run = run
self.group_keys = groups
@@ -45,9 +59,9 @@
else:
m = m.fillna(na_fill_value)

self._grouper: "DataFrameGroupBy" = m.groupby(list(groups), group_keys=True)
self._grouper: DataFrameGroupBy = m.groupby(list(groups), group_keys=True)

def _iter_grouped(self) -> "Iterator[GenericRun]":
def _iter_grouped(self) -> Iterator[GenericRun]:
def _try_fill_value(v: MetadataValue) -> MetadataValue:
try:
if float(v) == float(self.na_fill_value):
@@ -57,7 +71,7 @@
return v

groups: Iterable[
- Union[MetadataValue, tuple[MetadataValue, ...]]
+ MetadataValue | tuple[MetadataValue, ...]
] = self._grouper.groups
for indices in groups:
if not isinstance(indices, Iterable) or isinstance(indices, str):
@@ -82,20 +96,20 @@

def apply(
self,
func: "Callable[Concatenate[GenericRun, P], Union[GenericRun, pd.DataFrame, None]]",
*args: "P.args",
**kwargs: "P.kwargs",
) -> "GenericRun":
func: Callable[Concatenate[GenericRun, P], GenericRun | (pd.DataFrame | None)],
*args: P.args,
**kwargs: P.kwargs,
) -> GenericRun:
"""
Apply a function to each group and append the results

- `func` is called like `func(ar, *args, **kwargs)` for each :class:`ScmRun <scmdata.run.ScmRun>` ``ar``
- in this group. If the result of this function call is None, than it is
+ `func` is called like `func(ar, *args, **kwargs)` for each :class:`ScmRun <scmdata.run.ScmRun>`
+ group. If the result of this function call is ``None``, then it is
excluded from the results.

The results are appended together using :func:`run_append`. The function
- can change the size of the input :class:`ScmRun <scmdata.run.ScmRun>` as long as :func:`run_append`
- can be applied to all results.
+ can change the size of the input :class:`ScmRun <scmdata.run.ScmRun>`
+ as long as :func:`run_append` can be applied to all results.

Examples
--------
@@ -109,35 +123,34 @@
Parameters
----------
func
- Callable to apply to each timeseries.
+ Callable to apply to each group.

- ``*args``
+ *args
Positional arguments passed to `func`.

- ``**kwargs``
- Used to call `func(ar, **kwargs)` for each array `ar`.
+ **kwargs
+ Keyword arguments passed to `func`.

Returns
-------
applied : :class:`ScmRun <scmdata.run.ScmRun>`
- The result of splitting, applying and combining this array.
+ The result of applying and combining.
"""
grouped = self._iter_grouped()
applied = [func(arr, *args, **kwargs) for arr in grouped]
return self._combine(applied)

def apply_parallel(
self,
func: "Callable[Concatenate[GenericRun, P], Union[GenericRun, pd.DataFrame, None]]",
n_jobs: int = 1,
backend: str = "loky",
*args: "P.args",
**kwargs: "P.kwargs",
) -> "GenericRun":
func: ApplyCallable[GenericRun, P],
parallel_processor: ParallelProcessor[GenericRun, P] | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> GenericRun:
"""
Apply a function to each group in parallel and append the results

- Provides the same functionality as :func:`~apply` except that :mod:`joblib` is used to apply
+ Provides the same functionality as :func:`~apply` except that parallel processing can be
+ used via the ``parallel_processor`` argument. By default, :mod:`joblib` is used to apply
`func` to each group in parallel. This can be slower than using :func:`~apply` for small
numbers of groups or in the case where `func` is fast as there is overhead setting up the
processing pool.
@@ -149,41 +162,29 @@
Parameters
----------
func
- Callable to apply to each timeseries.
-
- n_jobs
- Number of jobs to run in parallel (defaults to a single job which is useful for
- debugging purposes). If `-1` all CPUs are used.
+ Callable to apply to each group.

- backend
- Backend used for parallelisation. Defaults to 'loky' which uses separate processes for
- each worker.
-
- See :class:`joblib.Parallel` for a more complete description of the available
- options.
+ parallel_processor
+ Parallel processor to use to process the groups. If not provided,
+ a default joblib parallel processor is used (for details, see
+ :func:`get_joblib_parallel_processor`).

- ``*args``
+ *args
Positional arguments passed to `func`.

- ``**kwargs``
- Used to call `func(ar, **kwargs)` for each array `ar`.
+ **kwargs
+ Keyword arguments passed to `func`.

Returns
-------
applied : :class:`ScmRun <scmdata.run.ScmRun>`
- The result of splitting, applying and combining this array.
+ The result of applying and combining.
"""
- try:
- import joblib  # type: ignore
- except ImportError as e:  # pragma: no cover
- raise ImportError(
- "joblib is not installed. Run 'pip install joblib'"
- ) from e
+ if parallel_processor is None:
+ parallel_processor = get_joblib_parallel_processor()

grouped = self._iter_grouped()
- applied: "list[Union[GenericRun, pd.DataFrame, None]]" = joblib.Parallel(
- n_jobs=n_jobs, backend=backend
- )(joblib.delayed(func)(arr, *args, **kwargs) for arr in grouped)
+ applied = parallel_processor(func, grouped, *args, **kwargs)

return self._combine(applied)

def map(self, func, *args, **kwargs):
@@ -202,8 +203,8 @@
return self.apply(func, *args, **kwargs)

def _combine(
self, applied: "Sequence[Union[GenericRun, pd.DataFrame, None]]"
) -> "GenericRun":
self, applied: Iterable[GenericRun | (pd.DataFrame | None)]
) -> GenericRun:
"""
Recombine the applied objects like the original.
"""
@@ -219,12 +220,12 @@

def reduce(
self,
func: "Callable[Concatenate[NDArray[np.float_], P], NDArray[np.float_]]",
dim: "Optional[Union[str, Iterable[str]]]" = None,
axis: "Optional[Union[str, Iterable[int]]]" = None,
*args: "P.args",
**kwargs: "P.kwargs",
) -> "GenericRun":
func: Callable[Concatenate[NDArray[np.float_], P], NDArray[np.float_]],
dim: str | Iterable[str] | None = None,
axis: str | Iterable[int] | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> GenericRun:
"""
Reduce the items in this group by applying `func` along some dimension(s).

@@ -260,4 +261,59 @@
return self.apply(reduce_array)


def get_joblib_parallel_processor(
n_jobs: int = -1,
backend: str = "loky",
*args: Any,
**kwargs: Any,
) -> ParallelProcessor[RunLike, Q]:
"""
Get parallel processor using :mod:`joblib` as the backend.

Parameters
----------
n_jobs
Number of jobs to run in parallel. If `-1` all CPUs are used.

backend
Backend used for parallelisation. Defaults to 'loky' which uses separate processes for
each worker.
See :class:`joblib.Parallel` for a more complete description of the available
options.

*args
Passed to initialiser of :class:`joblib.Parallel`

**kwargs
Passed to initialiser of :class:`joblib.Parallel`

Returns
-------
Function that can be used for parallel processing in
:meth:`RunGroupBy.apply_parallel`
"""
try:
import joblib
except ImportError as e: # pragma: no cover
raise ImportError("joblib is not installed. Run 'pip install joblib'") from e

processor = joblib.Parallel(*args, n_jobs=n_jobs, backend=backend, **kwargs)

def joblib_parallel_processor(
func: ApplyCallable[RunLike, Q],
groups: Iterable[RunLike],
/,
*args: Q.args,
**kwargs: Q.kwargs,
) -> Iterable[ApplyCallableReturnType[RunLike]]:
prepped_groups = (
joblib.delayed(func)(group, *args, **kwargs) for group in groups
)
applied = processor(prepped_groups)

return applied

return joblib_parallel_processor


ops.inject_reduce_methods(RunGroupBy)
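Because `parallel_processor` only has to match the `ParallelProcessor` shape — `(func, groups, *args, **kwargs) -> iterable of results` — callers are not tied to joblib. A hypothetical drop-in built on the standard library (not part of this PR), to illustrate the injection point:

```python
from concurrent.futures import ThreadPoolExecutor


def thread_pool_processor(func, groups, /, *args, **kwargs):
    # Apply ``func(group, *args, **kwargs)`` to each group on a thread
    # pool and return the results for RunGroupBy._combine to append.
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(func, group, *args, **kwargs) for group in groups
        ]
        return [future.result() for future in futures]


# Usage mirrors the joblib default:
# res = scm_run.groupby("scenario").apply_parallel(
#     process_group, parallel_processor=thread_pool_processor
# )
```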
4 changes: 2 additions & 2 deletions src/scmdata/run.py
@@ -71,11 +71,11 @@

if TYPE_CHECKING:
from numpy.typing import NDArray
from typing_extensions import Concatenate, ParamSpec


from scmdata.groupby import RunGroupBy


P = ParamSpec("P")



def _read_file( # pylint: disable=missing-return-doc
@@ -1982,10 +1982,10 @@
func : function
Callable to apply to each timeseries.

- ``*args``
+ *args
Positional arguments passed to `func`.

- ``**kwargs``
+ **kwargs
Used to call `func(ar, **kwargs)` for each array `ar`.

Returns
@@ -2400,7 +2400,7 @@
if len(m) != 1: # pragma: no cover
raise AssertionError(m)

meta: dict[str, MetadataValue | Iterable[MetadataValue]] = m.to_dict( # type: ignore
"list"
)

Empty file removed stubs/.gitkeep
22 changes: 22 additions & 0 deletions stubs/joblib.pyi
@@ -0,0 +1,22 @@
from typing import Any, Callable, Iterable, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")
Q = ParamSpec("Q")

def delayed(
func: Callable[P, T]
) -> Callable[..., tuple[Callable[P, T], Q.args, Q.kwargs]]: ...

class Parallel:
def __init__(self, *args: Any, n_jobs: int, backend: str, **kwargs: Any): ...
def __call__(
self,
iterable: Iterable[
tuple[
Callable[P, T],
P.args,
P.kwargs,
]
],
) -> Iterable[T]: ...
3 changes: 2 additions & 1 deletion tests/unit/test_groupby.py
@@ -2,6 +2,7 @@
import pytest

from scmdata import ScmRun
from scmdata.groupby import get_joblib_parallel_processor
from scmdata.testing import assert_scmdf_almost_equal

group_tests = pytest.mark.parametrize(
@@ -28,7 +29,7 @@ def func(df):
return df

if parallel:
- res = scm_run.groupby(*g).apply_parallel(func, n_jobs=-1)
+ res = scm_run.groupby(*g).apply_parallel(func)
else:
res = scm_run.groupby(*g).apply(func)

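Note that the old `n_jobs=1` debugging default is gone (the joblib processor now defaults to `n_jobs=-1`); serial execution is still one small injected processor away. A sketch, reusing the names from the test above:

```python
def serial_processor(func, groups, /, *args, **kwargs):
    # Process groups one at a time in-process; equivalent in spirit to
    # the old n_jobs=1 behaviour and handy for debugging.
    return [func(group, *args, **kwargs) for group in groups]


# res = scm_run.groupby(*g).apply_parallel(
#     func, parallel_processor=serial_processor
# )
```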