feat: Add apply_parallel function to RunGroupBy #262

Merged · 7 commits · Oct 13, 2023
Changes from 6 commits
2 changes: 1 addition & 1 deletion LICENSE
@@ -11,4 +11,4 @@ Redistribution and use in source and binary forms, with or without modification,

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2 changes: 1 addition & 1 deletion docs/source/_static/tables.css
@@ -1,4 +1,4 @@
/* Force wide pandas tables to scroll */
.output.text_html {
    overflow: scroll;
-}
\ No newline at end of file
+}
7 changes: 7 additions & 0 deletions docs/source/api/scmdata.pyam_compat.rst
@@ -5,3 +5,10 @@ scmdata.pyam\_compat

.. currentmodule:: scmdata.pyam_compat

+
+
+LongDatetimeIamDataFrame
+========================
+
+.. autoclass:: LongDatetimeIamDataFrame
+   :members:
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -173,6 +173,10 @@ def setup(app):
"https://pint.readthedocs.io/en/latest",
None,
),
"joblib": (
"https://joblib.readthedocs.io/en/latest",
None,
),
}


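For reference, the docstring of the new `apply_parallel` method below cross-references :class:`joblib.Parallel`; this intersphinx entry is what lets Sphinx resolve that reference to joblib's hosted docs. A trimmed sketch of the resulting mapping in docs/source/conf.py, keeping only the two entries visible in this diff (the real file has more):

# docs/source/conf.py (sketch; only two entries shown)
intersphinx_mapping = {
    "pint": ("https://pint.readthedocs.io/en/latest", None),
    "joblib": ("https://joblib.readthedocs.io/en/latest", None),
}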
4 changes: 2 additions & 2 deletions docs/source/notebooks/README.md
@@ -6,6 +6,6 @@ This directory contains notebooks of examples for how to use `scmdata`.

We use [Jupytext](https://github.com/mwouts/jupytext) to encode the notebooks as standard .py files. This makes
it easier to version control notebooks as you get a clear, meaningful diffs. Jupytext also enables these notebooks
to be edited via Jupyter Lab or Jupyter Notebook.

-As part of the CI, these notebooks are run and checked for any errors.
\ No newline at end of file
+As part of the CI, these notebooks are run and checked for any errors.
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -47,12 +47,13 @@ netCDF4 = { version = "*", optional = true }
openpyxl = { version = "*", optional = true }
xlrd = { version = "*", optional = true }
scipy = { version = "*", optional = true }
+joblib = { version = "*", optional = true }
notebook = { version = ">=7", optional = true }
pyam-iamc = { version = "<2", optional = true }

[tool.poetry.extras]
plots = ["matplotlib", "seaborn" , "nc-time-axis"]
-optional = ["netCDF4", "openpyxl", "xlrd", "scipy", "pyam-iamc"]
+optional = ["netCDF4", "openpyxl", "xlrd", "scipy", "pyam-iamc", "joblib" ]
notebooks = ["notebook"]

[tool.poetry.group.tests.dependencies]
@@ -77,6 +78,7 @@ ruff = "0.0.288"
pre-commit = "^3.3.1"
towncrier = "^23.6.0"
liccheck = "^0.9.1"
+pandas-stubs = "<2"

[tool.poetry.group.notebooks.dependencies]
myst-nb = "^0.17.0"
@@ -101,7 +103,7 @@ show_missing = true
# Regexes for lines to exclude from consideration in addition to the defaults
exclude_also = [
    # Don't complain about missing type checking code:
-    "if TYPE_CHECKING",
+    "if TYPE_CHECKING:",
]

[tool.mypy]
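Because joblib lands in the extras rather than the core dependencies, plain installs won't have it. A small sketch, using only the standard library (the error message text is illustrative), of probing for it before relying on the new parallel API:

import importlib.util

# joblib is an optional extra, so check for it before calling apply_parallel
if importlib.util.find_spec("joblib") is None:
    raise ImportError(
        "joblib missing; install the extras, e.g. pip install 'scmdata[optional]'"
    )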
4 changes: 3 additions & 1 deletion src/scmdata/database/__init__.py
@@ -10,4 +10,6 @@
climate models and then load just the ``Surface Temperature`` data for all models.
"""

-from ._database import ScmDatabase  # noqa: F401
+from ._database import ScmDatabase
+
+__all__ = ["ScmDatabase"]
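Swapping the `# noqa: F401` re-export for an explicit `__all__` makes the subpackage's public surface explicit and star-import safe. A sketch of the effect (the database path is hypothetical):

# Only the names listed in __all__ (here just ScmDatabase) are pulled in
from scmdata.database import *

db = ScmDatabase("/tmp/scm-database")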
170 changes: 123 additions & 47 deletions src/scmdata/groupby.py
@@ -3,79 +3,89 @@
"""
import warnings
from collections.abc import Iterable
+from typing import TYPE_CHECKING, Callable, Generic, Iterator, Optional, Sequence, Union

import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
from xarray.core import ops
from xarray.core.common import ImplementsArrayReduce

+from scmdata._typing import MetadataValue
+from scmdata.run import GenericRun
+
+if TYPE_CHECKING:
+    from pandas.core.groupby.generic import DataFrameGroupBy
+    from typing_extensions import Concatenate, ParamSpec
+
+    P = ParamSpec("P")
+
-def _maybe_wrap_array(original, new_array):
-    """
-    Wrap a transformed array with ``__array_wrap__`` if it can be done safely.
-
-    This lets us treat arbitrary functions that take and return ndarray objects
-    like ufuncs, as long as they return an array with the same shape.
-    """
-    # in case func lost array's metadata
-    if isinstance(new_array, np.ndarray) and new_array.shape == original.shape:
-        return original.__array_wrap__(new_array)
-    else:
-        return new_array
-
-
-class _GroupBy(ImplementsArrayReduce):
-    def __init__(self, meta, groups, na_fill_value=-10000):
-        m = meta.reset_index(drop=True)
+class RunGroupBy(ImplementsArrayReduce, Generic[GenericRun]):
+    """
+    GroupBy object specialized to grouping ScmRun objects
+    """
+
+    def __init__(
+        self, run: "GenericRun", groups: "Iterable[str]", na_fill_value: float = -10000
+    ):
+        self.run = run
+        self.group_keys = groups
+
+        m = run.meta.reset_index(drop=True)
        self.na_fill_value = float(na_fill_value)

        # Work around the bad handling of NaN values in groupbys
        if any([np.issubdtype(m[c].dtype, np.number) for c in m]):
-            if (meta == na_fill_value).any(axis=None):
+            if (m == na_fill_value).any(axis=None):
                raise ValueError(
                    "na_fill_value conflicts with data value. Choose a na_fill_value "
                    "not in meta"
                )
            else:
                m = m.fillna(na_fill_value)

-        self._grouper = m.groupby(list(groups), group_keys=True)
+        self._grouper: "DataFrameGroupBy" = m.groupby(list(groups), group_keys=True)

-    def _iter_grouped(self):
-        def _try_fill_value(v):
+    def _iter_grouped(self) -> "Iterator[GenericRun]":
+        def _try_fill_value(v: MetadataValue) -> MetadataValue:
            try:
                if float(v) == float(self.na_fill_value):
                    return np.nan
            except ValueError:
                pass
            return v

-        for indices in self._grouper.groups:
+        groups: Iterable[
+            Union[MetadataValue, tuple[MetadataValue, ...]]
+        ] = self._grouper.groups
+        for indices in groups:
            if not isinstance(indices, Iterable) or isinstance(indices, str):
-                indices = [indices]  # noqa: PLW2901
+                indices_clean: tuple[MetadataValue, ...] = (indices,)
+            else:
+                indices_clean = indices

-            indices = [_try_fill_value(v) for v in indices]  # noqa: PLW2901
-            res = self.run.filter(**{k: v for k, v in zip(self.group_keys, indices)})
+            indices_clean = tuple(_try_fill_value(v) for v in indices_clean)
+            filter_kwargs = {k: v for k, v in zip(self.group_keys, indices_clean)}
+            res = self.run.filter(**filter_kwargs)  # type: ignore
            if not len(res):
                raise ValueError(
-                    f"Empty group for {list(zip(self.group_keys, indices))}"
+                    f"Empty group for {list(zip(self.group_keys, indices_clean))}"
                )
            yield res

-    def __iter__(self):
+    def __iter__(self) -> Iterator[GenericRun]:
        """
        Iterate over the groups
        """
        return self._iter_grouped()

-
-class RunGroupBy(_GroupBy):
-    """
-    GroupBy object specialized to grouping ScmRun objects
-    """
-
-    def __init__(self, run, groups):
-        self.run = run
-        self.group_keys = groups
-        super().__init__(run.meta, groups)
-
-    def apply(self, func, *args, **kwargs):
+    def apply(
+        self,
+        func: "Callable[Concatenate[GenericRun, P], Union[GenericRun, pd.DataFrame, None]]",
+        *args: "P.args",
+        **kwargs: "P.kwargs",
+    ) -> "GenericRun":
        """
        Apply a function to each group and append the results

@@ -91,10 +101,9 @@
        --------
        .. code:: python

-            >>> def write_csv(arr):
+            >>> def write_csv(arr: scmdata.ScmRun) -> None:
            ...     variable = arr.get_unique_meta("variable")
            ...     arr.to_csv("out-{}.csv".format(variable))
-            ...
            >>> df.groupby("variable").apply(write_csv)

        Parameters
@@ -114,9 +123,67 @@
            The result of splitting, applying and combining this array.
        """
        grouped = self._iter_grouped()
-        applied = [
-            _maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped
-        ]
+        applied = [func(arr, *args, **kwargs) for arr in grouped]
        return self._combine(applied)

+    def apply_parallel(
+        self,
+        func: "Callable[Concatenate[GenericRun, P], Union[GenericRun, pd.DataFrame, None]]",
+        n_jobs: int = 1,
+        backend: str = "loky",
+        *args: "P.args",
+        **kwargs: "P.kwargs",
+    ) -> "GenericRun":
+        """
+        Apply a function to each group in parallel and append the results
+
+        Provides the same functionality as :func:`~apply` except that :mod:`joblib` is used to apply
+        `func` to each group in parallel. This can be slower than using :func:`~apply` for small
+        numbers of groups or in the case where `func` is fast as there is overhead setting up the
+        processing pool.
+
+        See Also
+        --------
+        :func:`~apply`
+
+        Parameters
+        ----------
+        func
+            Callable to apply to each timeseries.
+
+        n_jobs
+            Number of jobs to run in parallel (defaults to a single job which is useful for
+            debugging purposes). If `-1` all CPUs are used.
+
+        backend
+            Backend used for parallelisation. Defaults to 'loky' which uses separate processes for
+            each worker.
+
+            See :class:`joblib.Parallel` for a more complete description of the available
+            options.
+
+        ``*args``
+            Positional arguments passed to `func`.
+
+        ``**kwargs``
+            Used to call `func(ar, **kwargs)` for each array `ar`.
+
+        Returns
+        -------
+        applied : :class:`ScmRun <scmdata.run.ScmRun>`
+            The result of splitting, applying and combining this array.
+        """
+        try:
+            import joblib  # type: ignore
+        except ImportError as e:  # pragma: no cover
+            raise ImportError(
+                "joblib is not installed. Run 'pip install joblib'"
+            ) from e
+
+        grouped = self._iter_grouped()
+        applied: "list[Union[GenericRun, pd.DataFrame, None]]" = joblib.Parallel(
+            n_jobs=n_jobs, backend=backend
+        )(joblib.delayed(func)(arr, *args, **kwargs) for arr in grouped)
+        return self._combine(applied)
+
    def map(self, func, *args, **kwargs):
@@ -134,21 +201,30 @@
        warnings.warn("Use RunGroupby.apply instead", DeprecationWarning)
        return self.apply(func, *args, **kwargs)

-    def _combine(self, applied):
+    def _combine(
+        self, applied: "Sequence[Union[GenericRun, pd.DataFrame, None]]"
+    ) -> "GenericRun":
        """
        Recombine the applied objects like the original.
        """
        from scmdata.run import run_append

        # Remove all None values
-        applied = [df for df in applied if df is not None]
+        applied_clean = [df for df in applied if df is not None]

-        if len(applied) == 0:
-            return None
+        if len(applied_clean) == 0:
+            return self.run.__class__()
        else:
-            return run_append(applied)
+            return run_append(applied_clean)

-    def reduce(self, func, dim=None, axis=None, **kwargs):
+    def reduce(
+        self,
+        func: "Callable[Concatenate[NDArray[np.float_], P], NDArray[np.float_]]",
+        dim: "Optional[Union[str, Iterable[str]]]" = None,
+        axis: "Optional[Union[str, Iterable[int]]]" = None,
+        *args: "P.args",
+        **kwargs: "P.kwargs",
+    ) -> "GenericRun":
        """
        Reduce the items in this group by applying `func` along some dimension(s).

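To see the new method end to end, a short usage sketch. This is illustrative, not part of the PR: it assumes scmdata's `ScmRun` constructor conventions (data shaped time by timeseries) and its existing `time_mean` aggregator, and the data values are made up.

import numpy as np
import scmdata


def annual_mean(arr: scmdata.ScmRun) -> scmdata.ScmRun:
    # Runs once per group; "AC" takes annual-centred means within each group
    return arr.time_mean("AC")


run = scmdata.ScmRun(
    data=np.arange(6).reshape(3, 2),
    index=[2010, 2020, 2030],
    columns={
        "model": "example",
        "scenario": ["low", "high"],
        "variable": "Surface Temperature",
        "unit": "K",
        "region": "World",
    },
)

# Same split-apply-combine semantics as apply(), but groups are dispatched
# to a joblib worker pool ("loky" processes by default; n_jobs=-1 uses all CPUs)
result = run.groupby("scenario").apply_parallel(annual_mean, n_jobs=2)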
2 changes: 1 addition & 1 deletion src/scmdata/pyam_compat.py
@@ -6,7 +6,7 @@
from dateutil import parser

try:
-    from pyam import IamDataFrame
+    from pyam import IamDataFrame  # type: ignore

    # mypy can't work out try-except block forces IamDataFrame to be here
    class LongDatetimeIamDataFrame(IamDataFrame):  # type: ignore
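As an aside, the guarded import above is the usual optional-dependency pattern; a simplified sketch of its shape (the None fallback here is illustrative, standing in for the module's real handling when pyam is absent):

try:
    from pyam import IamDataFrame
except ImportError:
    # pyam not installed: downstream code can test the name against None
    # instead of catching ImportError at every call site
    IamDataFrame = None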