diff --git a/vizro-core/changelog.d/20240502_175131_antony.milne_data_source_mapping.md b/vizro-core/changelog.d/20240502_175131_antony.milne_data_source_mapping.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-core/changelog.d/20240502_175131_antony.milne_data_source_mapping.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-core/src/vizro/actions/_actions_utils.py b/vizro-core/src/vizro/actions/_actions_utils.py index 9eb153997..7227ac694 100644 --- a/vizro-core/src/vizro/actions/_actions_utils.py +++ b/vizro-core/src/vizro/actions/_actions_utils.py @@ -173,8 +173,8 @@ def _get_filtered_data( ) -> Dict[ModelID, pd.DataFrame]: filtered_data = {} for target in targets: - data_frame = data_manager._get_component_data(target) - + data_source_name = model_manager[target]["data_frame"] + data_frame = data_manager[data_source_name].load() data_frame = _apply_filters(data_frame=data_frame, ctds_filters=ctds_filters, target=target) data_frame = _apply_filter_interaction( data_frame=data_frame, ctds_filter_interaction=ctds_filter_interaction, target=target diff --git a/vizro-core/src/vizro/managers/_data_manager.py b/vizro-core/src/vizro/managers/_data_manager.py index 137b381c2..288852219 100644 --- a/vizro-core/src/vizro/managers/_data_manager.py +++ b/vizro-core/src/vizro/managers/_data_manager.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging -import os import warnings from typing import Any, Callable, Dict, Optional, Protocol, Union @@ -15,10 +14,8 @@ logger = logging.getLogger(__name__) -# Really ComponentID and DataSourceName should be NewType and not just aliases but then for a user's code to type check +# Really DataSourceName should be a NewType and not just an alias but then for a user's code to type check # correctly they would need to cast all strings to these types. 
-# TODO: remove these type aliases once have moved component to data mapping to models -ComponentID = str DataSourceName = str pd_DataFrameCallable = Callable[[], pd.DataFrame] @@ -186,7 +183,6 @@ class DataManager: def __init__(self): self.__data: Dict[DataSourceName, Union[_DynamicData, _StaticData]] = {} - self.__component_to_data: Dict[ComponentID, DataSourceName] = {} self._frozen_state = False self.cache = Cache(config={"CACHE_TYPE": "NullCache"}) # In future, possibly we will accept just a config dict. Would need to work out whether to handle merging with @@ -227,33 +223,6 @@ def __getitem__(self, name: DataSourceName) -> Union[_DynamicData, _StaticData]: except KeyError as exc: raise KeyError(f"Data source {name} does not exist.") from exc - @_state_modifier - def _add_component(self, component_id: ComponentID, name: DataSourceName): - """Adds a mapping from `component_id` to `name`.""" - # TODO: once have removed self.__component_to_data, we shouldn't need this function any more. - # Maybe always updated capturedcallable data_frame to data source name string then. - if name not in self.__data: - raise KeyError(f"Data source {name} does not exist.") - if component_id in self.__component_to_data: - raise ValueError( - f"Component with id={component_id} already exists and is mapped to data " - f"{self.__component_to_data[component_id]}. Components must uniquely map to a data source across the " - f"whole dashboard. If you are working from a Jupyter Notebook, please either restart the kernel, or " - f"use 'from vizro import Vizro; Vizro._reset()`." - ) - self.__component_to_data[component_id] = name - - def _get_component_data(self, component_id: ComponentID) -> pd.DataFrame: - # TODO: once have removed self.__component_to_data, we shouldn't need this function any more. Calling - # functions would just do data_manager[name].load(). 
- """Returns the original data for `component_id`.""" - if component_id not in self.__component_to_data: - raise KeyError(f"Component {component_id} does not exist. You need to call add_component first.") - name = self.__component_to_data[component_id] - - logger.debug("Loading data %s on process %s", name, os.getpid()) - return self[name].load() - def _clear(self): # We do not actually call self.cache.clear() because (a) it would only work when self._cache_has_app is True, # which is not the case when e.g. Vizro._reset is called, and (b) because we do not want to accidentally diff --git a/vizro-core/src/vizro/models/_components/_components_utils.py b/vizro-core/src/vizro/models/_components/_components_utils.py index 005b8e2e8..51b01dbd2 100644 --- a/vizro-core/src/vizro/models/_components/_components_utils.py +++ b/vizro-core/src/vizro/models/_components/_components_utils.py @@ -24,28 +24,30 @@ def _callable_mode_validator_factory(mode: str): return validator("figure", allow_reuse=True)(check_callable_mode) -def _process_callable_data_frame(captured_callable, values): +def _process_callable_data_frame(captured_callable): + # Possibly all this validator's functionality should move into CapturedCallable (or a subclass of it) in the + # future. This would mean that data is added to the data manager outside the context of a dashboard though, + # which might not be desirable. data_frame = captured_callable["data_frame"] if isinstance(data_frame, str): # Named data source, which could be dynamic or static. This means px.scatter("iris") from the Python API and - # specification of "data_frame": "iris" through JSON. In these cases, data already exists in the data manager - # and just needs to be linked to the component. - data_source_name = data_frame - else: - # Unnamed data source, which must be a pd.DataFrame and hence static data. This means px.scatter(pd.DataFrame()) - # and is only possible from the Python API. 
Extract dataframe from the captured function and put it into the - # data manager. - # Unlike with model_manager, it doesn't matter if the random seed is different across workers here. So long as - # we always fetch static data from the data manager by going through the appropriate Figure component, the right - # data source name will be fetched. It also doesn't matter if multiple Figures with the same underlying data - # each have their own entry in the data manager, since the underlying pd.DataFrame will still be the same and - # not copied into each one, so no memory is wasted. - logger.debug("Adding data to data manager for Figure with id %s", values["id"]) - data_source_name = str(uuid.uuid4()) - data_manager[data_source_name] = data_frame - - data_manager._add_component(values["id"], data_source_name) - # No need to keep the data in the captured function any more so remove it to save memory. - del captured_callable["data_frame"] + # specification of "data_frame": "iris" through JSON. In these cases, data already exists in the data manager. + return captured_callable + + # Unnamed data source, which must be a pd.DataFrame and hence static data. This means px.scatter(pd.DataFrame()) + # and is only possible from the Python API. Extract dataframe from the captured function and put it into the + # data manager. + # Unlike with model_manager, it doesn't matter if the random seed is different across workers here. So long as + # we always fetch static data from the data manager by going through the appropriate Figure component, the right + # data source name will be fetched. It also doesn't matter if multiple Figures with the same underlying data + # each have their own entry in the data manager, since the underlying pd.DataFrame will still be the same and + # not copied into each one, so no memory is wasted. + # Replace the "data_frame" argument in the captured callable with the data_source_name for consistency with + # dynamic data and to save memory. 
This way we always access data via the same interface regardless of whether it's + # static or dynamic. + data_source_name = str(uuid.uuid4()) + data_manager[data_source_name] = data_frame + captured_callable["data_frame"] = data_source_name + return captured_callable diff --git a/vizro-core/src/vizro/models/_components/ag_grid.py b/vizro-core/src/vizro/models/_components/ag_grid.py index b3b319433..604850c86 100644 --- a/vizro-core/src/vizro/models/_components/ag_grid.py +++ b/vizro-core/src/vizro/models/_components/ag_grid.py @@ -51,7 +51,7 @@ class AgGrid(VizroBaseModel): # Convenience wrapper/syntactic sugar. def __call__(self, **kwargs): - kwargs.setdefault("data_frame", data_manager._get_component_data(self.id)) + kwargs.setdefault("data_frame", data_manager[self["data_frame"]].load()) figure = self.figure(**kwargs) figure.id = self._input_component_id return figure diff --git a/vizro-core/src/vizro/models/_components/graph.py b/vizro-core/src/vizro/models/_components/graph.py index 64f83d93a..126944e9a 100644 --- a/vizro-core/src/vizro/models/_components/graph.py +++ b/vizro-core/src/vizro/models/_components/graph.py @@ -50,7 +50,10 @@ class Graph(VizroBaseModel): # Convenience wrapper/syntactic sugar. def __call__(self, **kwargs): - kwargs.setdefault("data_frame", data_manager._get_component_data(str(self.id))) + # This default value is not actually used anywhere at the moment since __call__ is always used with data_frame + # specified. It's here to match Table and AgGrid and because we might want to use __call__ more in future. + # If the functionality of process_callable_data_frame moves to CapturedCallable then this would move there too. 
+ kwargs.setdefault("data_frame", data_manager[self["data_frame"]].load()) fig = self.figure(**kwargs) # Remove top margin if title is provided diff --git a/vizro-core/src/vizro/models/_components/table.py b/vizro-core/src/vizro/models/_components/table.py index 4c681e2a0..4dec78de2 100644 --- a/vizro-core/src/vizro/models/_components/table.py +++ b/vizro-core/src/vizro/models/_components/table.py @@ -50,7 +50,7 @@ class Table(VizroBaseModel): # Convenience wrapper/syntactic sugar. def __call__(self, **kwargs): - kwargs.setdefault("data_frame", data_manager._get_component_data(self.id)) + kwargs.setdefault("data_frame", data_manager[self["data_frame"]].load()) figure = self.figure(**kwargs) figure.id = self._input_component_id return figure diff --git a/vizro-core/src/vizro/models/_controls/filter.py b/vizro-core/src/vizro/models/_controls/filter.py index 26d808f5b..407d5fcb6 100644 --- a/vizro-core/src/vizro/models/_controls/filter.py +++ b/vizro-core/src/vizro/models/_controls/filter.py @@ -112,14 +112,20 @@ def _set_targets(self): for component_id in model_manager._get_page_model_ids_with_figure( page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id))) ): - data_frame = data_manager._get_component_data(component_id) + # TODO: consider making a helper method in data_manager or elsewhere to reduce this operation being + # duplicated across Filter so much, and/or consider storing the result to avoid repeating it. + # Need to think about this in connection with how to update filters on the fly and duplicated calls + # issue outlined in https://github.com/mckinsey/vizro/pull/398#discussion_r1559120849. 
+ data_source_name = model_manager[component_id]["data_frame"] + data_frame = data_manager[data_source_name].load() if self.column in data_frame.columns: self.targets.append(component_id) if not self.targets: raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.") def _set_column_type(self): - data_frame = data_manager._get_component_data(self.targets[0]) + data_source_name = model_manager[self.targets[0]]["data_frame"] + data_frame = data_manager[data_source_name].load() if is_numeric_dtype(data_frame[self.column]): self._column_type = "numerical" @@ -146,7 +152,8 @@ def _set_numerical_and_temporal_selectors_values(self): min_values = [] max_values = [] for target_id in self.targets: - data_frame = data_manager._get_component_data(target_id) + data_source_name = model_manager[target_id]["data_frame"] + data_frame = data_manager[data_source_name].load() min_values.append(data_frame[self.column].min()) max_values.append(data_frame[self.column].max()) @@ -173,7 +180,8 @@ def _set_categorical_selectors_options(self): if isinstance(self.selector, SELECTORS["categorical"]) and not self.selector.options: options = set() for target_id in self.targets: - data_frame = data_manager._get_component_data(target_id) + data_source_name = model_manager[target_id]["data_frame"] + data_frame = data_manager[data_source_name].load() options |= set(data_frame[self.column]) self.selector.options = sorted(options) diff --git a/vizro-core/src/vizro/models/types.py b/vizro-core/src/vizro/models/types.py index 0e5f2dd36..aea1a100f 100644 --- a/vizro-core/src/vizro/models/types.py +++ b/vizro-core/src/vizro/models/types.py @@ -138,9 +138,9 @@ def __getitem__(self, arg_name: str): """Gets the value of a bound argument.""" return self.__bound_arguments[arg_name] - def __delitem__(self, arg_name: str): - """Deletes a bound argument.""" - del self.__bound_arguments[arg_name] + def __setitem__(self, arg_name: str, value): + """Sets the value of a bound 
argument.""" + self.__bound_arguments[arg_name] = value @property def _arguments(self): diff --git a/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py b/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py index d9e1a0cf2..4a08b35c2 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py @@ -91,15 +91,11 @@ class TestProcessAgGridDataFrame: def test_process_figure_data_frame_str_df(self, dash_ag_grid_with_str_dataframe, gapminder): data_manager["gapminder"] = gapminder ag_grid = vm.AgGrid(id="ag_grid", figure=dash_ag_grid_with_str_dataframe) - assert data_manager._get_component_data("ag_grid").equals(gapminder) - with pytest.raises(KeyError, match="'data_frame'"): - ag_grid["data_frame"] + assert data_manager[ag_grid["data_frame"]].load().equals(gapminder) def test_process_figure_data_frame_df(self, standard_ag_grid, gapminder): ag_grid = vm.AgGrid(id="ag_grid", figure=standard_ag_grid) - assert data_manager._get_component_data("ag_grid").equals(gapminder) - with pytest.raises(KeyError, match="'data_frame'"): - ag_grid["data_frame"] + assert data_manager[ag_grid["data_frame"]].load().equals(gapminder) class TestPreBuildAgGrid: diff --git a/vizro-core/tests/unit/vizro/models/_components/test_graph.py b/vizro-core/tests/unit/vizro/models/_components/test_graph.py index 9779536b9..d38c24c14 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_graph.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_graph.py @@ -123,15 +123,11 @@ class TestProcessGraphDataFrame: def test_process_figure_data_frame_str_df(self, standard_px_chart_with_str_dataframe, gapminder): data_manager["gapminder"] = gapminder graph = vm.Graph(id="graph", figure=standard_px_chart_with_str_dataframe) - assert data_manager._get_component_data("graph").equals(gapminder) - with pytest.raises(KeyError, match="'data_frame'"): - graph["data_frame"] + assert 
data_manager[graph["data_frame"]].load().equals(gapminder) def test_process_figure_data_frame_df(self, standard_px_chart, gapminder): graph = vm.Graph(id="graph", figure=standard_px_chart) - assert data_manager._get_component_data("graph").equals(gapminder) - with pytest.raises(KeyError, match="'data_frame'"): - graph["data_frame"] + assert data_manager[graph["data_frame"]].load().equals(gapminder) class TestBuild: diff --git a/vizro-core/tests/unit/vizro/models/_components/test_table.py b/vizro-core/tests/unit/vizro/models/_components/test_table.py index 363eb673c..6c140af75 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_table.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_table.py @@ -91,15 +91,11 @@ class TestProcessTableDataFrame: def test_process_figure_data_frame_str_df(self, dash_table_with_str_dataframe, gapminder): data_manager["gapminder"] = gapminder table = vm.Table(id="table", figure=dash_table_with_str_dataframe) - assert data_manager._get_component_data("table").equals(gapminder) - with pytest.raises(KeyError, match="'data_frame'"): - table["data_frame"] + assert data_manager[table["data_frame"]].load().equals(gapminder) def test_process_figure_data_frame_df(self, standard_dash_table, gapminder): table = vm.Table(id="table", figure=standard_dash_table) - assert data_manager._get_component_data("table").equals(gapminder) - with pytest.raises(KeyError, match="'data_frame'"): - table["data_frame"] + assert data_manager[table["data_frame"]].load().equals(gapminder) class TestPreBuildTable: diff --git a/vizro-core/tests/unit/vizro/models/test_types.py b/vizro-core/tests/unit/vizro/models/test_types.py index 619379b64..54d3e7bbc 100644 --- a/vizro-core/tests/unit/vizro/models/test_types.py +++ b/vizro-core/tests/unit/vizro/models/test_types.py @@ -76,11 +76,9 @@ def test_getitem_unknown_args(self, captured_callable): with pytest.raises(KeyError): captured_callable["c"] - def test_delitem(self, captured_callable): - del 
captured_callable["a"] - - with pytest.raises(KeyError): - captured_callable["a"] + def test_setitem(self, captured_callable): + captured_callable["a"] = 2 + assert captured_callable["a"] == 2 @pytest.mark.parametrize(