diff --git a/modin/pandas/base.py b/modin/pandas/base.py index 298a376f972..bbb3d4068e4 100644 --- a/modin/pandas/base.py +++ b/modin/pandas/base.py @@ -16,6 +16,8 @@ from __future__ import annotations import abc +import copy +import functools import pickle as pkl import re import warnings @@ -105,6 +107,7 @@ "_ipython_canary_method_should_not_exist_", "_ipython_display_", "_repr_mimebundle_", + "_attrs", } _DEFAULT_BEHAVIOUR = { @@ -193,6 +196,26 @@ def _get_repr_axis_label_indexer(labels, num_for_repr): ) +def propagate_self_attrs(method): + """ + Wrap a BasePandasDataset/DataFrame/Series method with a function that automatically deep-copies self.attrs if present. + + This annotation should not be used on special methods like concat, str, and groupby, which may need to + examine multiple sources to reconcile `attrs`. + """ + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + result = method(self, *args, **kwargs) + if isinstance(result, BasePandasDataset) and len(self._attrs): + # If the result of the method call is a modin.pandas object and `self.attrs` is + # not empty, perform a deep copy of `self.attrs`. + result._attrs = copy.deepcopy(self._attrs) + return result + + return wrapper + + @_inherit_docstrings(pandas.DataFrame, apilink=["pandas.DataFrame", "pandas.Series"]) class BasePandasDataset(ClassLogger): """ @@ -208,6 +231,7 @@ class BasePandasDataset(ClassLogger): _pandas_class = pandas.core.generic.NDFrame _query_compiler: BaseQueryCompiler _siblings: list[BasePandasDataset] + _attrs: dict @cached_property def _is_dataframe(self) -> bool: @@ -1125,6 +1149,20 @@ def at(self, axis=None) -> _LocIndexer: # noqa: PR01, RT01, D200 return _LocIndexer(self) + def _set_attrs(self, key: Any, value: Any) -> dict: # noqa: PR01, RT01, D200 + """ + Set the dictionary of global attributes of this dataset. + """ + self._attrs[key] = value + + def _get_attrs(self) -> dict: # noqa: PR01, RT01, D200 + """ + Get the dictionary of global attributes of this dataset. + """ + return self._attrs + + attrs: dict = property(_get_attrs, _set_attrs) + def at_time(self, time, asof=False, axis=None) -> Self: # noqa: PR01, RT01, D200 """ Select values at particular time of day (e.g., 9:30AM). @@ -3221,6 +3259,7 @@ def tail(self, n=5) -> Self: # noqa: PR01, RT01, D200 return self.iloc[-n:] return self.iloc[len(self) :] + @propagate_self_attrs def take(self, indices, axis=0, **kwargs) -> Self: # noqa: PR01, RT01, D200 """ Return the elements in the given *positional* indices along an axis. diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py index 4f47c9374e9..774e6531229 100644 --- a/modin/pandas/dataframe.py +++ b/modin/pandas/dataframe.py @@ -151,8 +151,11 @@ def __init__( # Siblings are other dataframes that share the same query compiler. We # use this list to update inplace when there is a shallow copy. self._siblings = [] + self._attrs = {} if isinstance(data, (DataFrame, Series)): self._query_compiler = data._query_compiler.copy() + if len(data._attrs): + self._attrs = copy.deepcopy(data._attrs) if index is not None and any(i not in data.index for i in index): raise NotImplementedError( "Passing non-existant columns or index values to constructor not" @@ -2636,12 +2639,12 @@ def __setattr__(self, key, value) -> None: # - anything in self.__dict__. This includes any attributes that the # user has added to the dataframe with, e.g., `df.c = 3`, and # any attribute that Modin has added to the frame, e.g. - # `_query_compiler` and `_siblings` + # `_query_compiler`, `_siblings`, and "_attrs" # - `_query_compiler`, which Modin initializes before it appears in # __dict__ # - `_siblings`, which Modin initializes before it appears in __dict__ # before it appears in __dict__. - if key in ("_query_compiler", "_siblings") or key in self.__dict__: + if key in ("_attrs", "_query_compiler", "_siblings") or key in self.__dict__: pass # we have to check for the key in `dir(self)` first in order not to trigger columns computation elif key not in dir(self) and key in self: @@ -2938,17 +2941,6 @@ def __dataframe_consortium_standard__( ) return convert_to_standard_compliant_dataframe(self, api_version=api_version) - @property - def attrs(self) -> dict: # noqa: RT01, D200 - """ - Return dictionary of global attributes of this dataset. - """ - - def attrs(df): - return df.attrs - - return self._default_to_pandas(attrs) - @property def style(self): # noqa: RT01, D200 """ diff --git a/modin/pandas/series.py b/modin/pandas/series.py index 11188e85879..bc390b40743 100644 --- a/modin/pandas/series.py +++ b/modin/pandas/series.py @@ -114,8 +114,11 @@ def __init__( # Siblings are other dataframes that share the same query compiler. We # use this list to update inplace when there is a shallow copy. self._siblings = [] + self._attrs = {} if isinstance(data, type(self)): query_compiler = data._query_compiler.copy() + if len(data._attrs): + self._attrs = copy.deepcopy(data._attrs) if index is not None: if any(i not in data.index for i in index): raise NotImplementedError( @@ -2264,17 +2267,6 @@ def where( level=level, ) - @property - def attrs(self) -> dict: # noqa: RT01, D200 - """ - Return dictionary of global attributes of this dataset. - """ - - def attrs(df): - return df.attrs - - return self._default_to_pandas(attrs) - @property def array(self) -> ExtensionArray: # noqa: RT01, D200 """