Skip to content

Commit

Permalink
FEAT-#7401: Implement DataFrame/Series.attrs
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Shi <[email protected]>
  • Loading branch information
noloerino committed Sep 24, 2024
1 parent 1c4d173 commit cee01de
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
39 changes: 39 additions & 0 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from __future__ import annotations

import abc
import copy
import functools

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'functools' is imported with both 'import' and 'import from'.
import pickle as pkl
import re
import warnings
Expand Down Expand Up @@ -105,6 +107,7 @@
"_ipython_canary_method_should_not_exist_",
"_ipython_display_",
"_repr_mimebundle_",
"_attrs",
}

_DEFAULT_BEHAVIOUR = {
Expand Down Expand Up @@ -193,6 +196,26 @@ def _get_repr_axis_label_indexer(labels, num_for_repr):
)


def propagate_self_attrs(method):
"""
Wrap a BasePandasDataset/DataFrame/Series method with a function that automatically deep-copies self.attrs if present.
This annotation should not be used on special methods like concat, str, and groupby, which may need to
examine multiple sources to reconcile `attrs`.
"""

@functools.wraps(method)
def wrapper(self, *args, **kwargs):
result = method(self, *args, **kwargs)
if isinstance(result, BasePandasDataset) and len(self._attrs):
# If the result of the method call is a modin.pandas object and `self.attrs` is
# not empty, perform a deep copy of `self.attrs`.
result._attrs = copy.deepcopy(self._attrs)

Check warning on line 213 in modin/pandas/base.py

View check run for this annotation

Codecov / codecov/patch

modin/pandas/base.py#L213

Added line #L213 was not covered by tests
return result

return wrapper


@_inherit_docstrings(pandas.DataFrame, apilink=["pandas.DataFrame", "pandas.Series"])
class BasePandasDataset(ClassLogger):
"""
Expand All @@ -208,6 +231,7 @@ class BasePandasDataset(ClassLogger):
_pandas_class = pandas.core.generic.NDFrame
_query_compiler: BaseQueryCompiler
_siblings: list[BasePandasDataset]
_attrs: dict

@cached_property
def _is_dataframe(self) -> bool:
Expand Down Expand Up @@ -1125,6 +1149,20 @@ def at(self, axis=None) -> _LocIndexer: # noqa: PR01, RT01, D200

return _LocIndexer(self)

def _set_attrs(self, key: Any, value: Any) -> dict: # noqa: PR01, RT01, D200
"""
Set the dictionary of global attributes of this dataset.
"""
self._attrs[key] = value

Check warning on line 1156 in modin/pandas/base.py

View check run for this annotation

Codecov / codecov/patch

modin/pandas/base.py#L1156

Added line #L1156 was not covered by tests

def _get_attrs(self) -> dict: # noqa: PR01, RT01, D200
"""
Get the dictionary of global attributes of this dataset.
"""
return self._attrs

attrs: dict = property(_get_attrs, _set_attrs)

def at_time(self, time, asof=False, axis=None) -> Self: # noqa: PR01, RT01, D200
"""
Select values at particular time of day (e.g., 9:30AM).
Expand Down Expand Up @@ -3221,6 +3259,7 @@ def tail(self, n=5) -> Self: # noqa: PR01, RT01, D200
return self.iloc[-n:]
return self.iloc[len(self) :]

@propagate_self_attrs
def take(self, indices, axis=0, **kwargs) -> Self: # noqa: PR01, RT01, D200
"""
Return the elements in the given *positional* indices along an axis.
Expand Down
18 changes: 5 additions & 13 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ def __init__(
# Siblings are other dataframes that share the same query compiler. We
# use this list to update inplace when there is a shallow copy.
self._siblings = []
self._attrs = {}
if isinstance(data, (DataFrame, Series)):
self._query_compiler = data._query_compiler.copy()
if len(data._attrs):
self._attrs = copy.deepcopy(data._attrs)

Check warning on line 158 in modin/pandas/dataframe.py

View check run for this annotation

Codecov / codecov/patch

modin/pandas/dataframe.py#L158

Added line #L158 was not covered by tests
if index is not None and any(i not in data.index for i in index):
raise NotImplementedError(
"Passing non-existant columns or index values to constructor not"
Expand Down Expand Up @@ -2636,12 +2639,12 @@ def __setattr__(self, key, value) -> None:
# - anything in self.__dict__. This includes any attributes that the
# user has added to the dataframe with, e.g., `df.c = 3`, and
# any attribute that Modin has added to the frame, e.g.
# `_query_compiler` and `_siblings`
# `_query_compiler`, `_siblings`, and "_attrs"
# - `_query_compiler`, which Modin initializes before it appears in
# __dict__
# - `_siblings`, which Modin initializes before it appears in __dict__
# before it appears in __dict__.
if key in ("_query_compiler", "_siblings") or key in self.__dict__:
if key in ("_attrs", "_query_compiler", "_siblings") or key in self.__dict__:
pass
# we have to check for the key in `dir(self)` first in order not to trigger columns computation
elif key not in dir(self) and key in self:
Expand Down Expand Up @@ -2938,17 +2941,6 @@ def __dataframe_consortium_standard__(
)
return convert_to_standard_compliant_dataframe(self, api_version=api_version)

@property
def attrs(self) -> dict: # noqa: RT01, D200
"""
Return dictionary of global attributes of this dataset.
"""

def attrs(df):
return df.attrs

return self._default_to_pandas(attrs)

@property
def style(self): # noqa: RT01, D200
"""
Expand Down
14 changes: 3 additions & 11 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ def __init__(
# Siblings are other dataframes that share the same query compiler. We
# use this list to update inplace when there is a shallow copy.
self._siblings = []
self._attrs = {}
if isinstance(data, type(self)):
query_compiler = data._query_compiler.copy()
if len(data._attrs):
self._attrs = copy.deepcopy(data._attrs)

Check warning on line 121 in modin/pandas/series.py

View check run for this annotation

Codecov / codecov/patch

modin/pandas/series.py#L121

Added line #L121 was not covered by tests
if index is not None:
if any(i not in data.index for i in index):
raise NotImplementedError(
Expand Down Expand Up @@ -2264,17 +2267,6 @@ def where(
level=level,
)

@property
def attrs(self) -> dict: # noqa: RT01, D200
"""
Return dictionary of global attributes of this dataset.
"""

def attrs(df):
return df.attrs

return self._default_to_pandas(attrs)

@property
def array(self) -> ExtensionArray: # noqa: RT01, D200
"""
Expand Down

0 comments on commit cee01de

Please sign in to comment.