add tables property to base model class #502

Merged: 7 commits, Sep 7, 2023

Changes from all commits

1 change: 1 addition & 0 deletions .gitignore
@@ -128,3 +128,4 @@ dask-worker-space/
#ruff linting
.ruff_cache
.envrc
pyrightconfig.json
1 change: 1 addition & 0 deletions docs/changelog.rst
@@ -19,6 +19,7 @@ Added
- add ``open_mfcsv`` function in ``io`` module for combining multiple CSV files into one dataset. (PR #486)
- Adapters can now clip data that is passed through a python object the same way as through the data catalog. (PR #481)
- Model objects now have a ``_MODEL_VERSION`` attribute that plugins can use for compatibility purposes. (PR #495)
- Model class now has methods for getting, setting, reading and writing arbitrary tabular data. (PR #502)

Changed
-------
97 changes: 86 additions & 11 deletions hydromt/models/model_api.py
@@ -35,6 +35,7 @@
"DeferedFileClose",
{"ds": xr.Dataset, "org_fn": str, "tmp_fn": str, "close_attempts": int},
)
XArrayDict = Dict[str, Union[xr.DataArray, xr.Dataset]]


class Model(object, metaclass=ABCMeta):
@@ -60,11 +61,12 @@ class Model(object, metaclass=ABCMeta):
"crs": CRS,
"config": Dict[str, Any],
"geoms": Dict[str, gpd.GeoDataFrame],
"maps": Dict[str, Union[xr.DataArray, xr.Dataset]],
"forcing": Dict[str, Union[xr.DataArray, xr.Dataset]],
"tables": Dict[str, pd.DataFrame],
"maps": XArrayDict,
"forcing": XArrayDict,
"region": gpd.GeoDataFrame,
"results": Dict[str, Union[xr.DataArray, xr.Dataset]],
"states": Dict[str, Union[xr.DataArray, xr.Dataset]],
"results": XArrayDict,
"states": XArrayDict,
}

def __init__(
@@ -109,15 +111,15 @@ def __init__(

# placeholders
# metadata maps that can be at different resolutions
# TODO do we want read/write maps?
self._config = None # nested dictionary
self._maps = None # dictionary of xr.DataArray and/or xr.Dataset
self._maps: Optional[XArrayDict] = None
self._tables: Optional[Dict[str, pd.DataFrame]] = None

# NOTE was staticgeoms in <=v0.5
self._geoms = None # dictionary of gdp.GeoDataFrame
self._forcing = None # dictionary of xr.DataArray and/or xr.Dataset
self._states = None # dictionary of xr.DataArray and/or xr.Dataset
self._results = None # dictionary of xr.DataArray and/or xr.Dataset
self._geoms: Optional[Dict[str, gpd.GeoDataFrame]] = None
self._forcing: Optional[XArrayDict] = None
self._states: Optional[XArrayDict] = None
self._results: Optional[XArrayDict] = None
# To be deprecated in future versions!
self._staticmaps = None
self._staticgeoms = None
@@ -509,6 +511,7 @@ def read(
"config",
"staticmaps",
"maps",
"tables",
"geoms",
"forcing",
"states",
@@ -537,6 +540,7 @@ def write(
components: List = [
"staticmaps",
"maps",
"tables",
"geoms",
"forcing",
"states",
@@ -705,6 +709,77 @@ def _configread(self, fn: str):
def _configwrite(self, fn: str):
return config.configwrite(fn, self.config)

@property
def tables(self) -> Dict[str, pd.DataFrame]:
"""Model tables."""
if self._tables is None:
self._tables = dict()
if self._read:
self.read_tables()
return self._tables

def write_tables(self, fn: str = "tables/{name}.csv", **kwargs) -> None:
"""Write tables at <root>/tables."""
if self.tables:
self._assert_write_mode
self.logger.info("Writing table files.")
local_kwargs = {"index": False, "header": True, "sep": ","}
local_kwargs.update(**kwargs)
for name in self.tables:
fn_out = join(self.root, fn.format(name=name))
os.makedirs(dirname(fn_out), exist_ok=True)
self.tables[name].to_csv(fn_out, **local_kwargs)
else:
self.logger.debug("No tables found, skip writing.")

def read_tables(self, fn: str = "tables/{name}.csv", **kwargs) -> None:
"""Read table files at <root>/tables and parse to dict of dataframes."""
self._assert_read_mode
self.logger.info("Reading model table files.")
fns = glob.glob(join(self.root, fn.format(name="*")))
if len(fns) > 0:
for fn in fns:
name = basename(fn).split(".")[0]
tbl = pd.read_csv(fn, **kwargs)
self.set_tables(tbl, name=name)

def set_tables(
self, tables: Union[pd.DataFrame, pd.Series, Dict], name=None
) -> None:
"""Add (a) table(s) <pandas.DataFrame> to model.

Parameters
----------
tables : pandas.DataFrame, pandas.Series or dict
Table(s) to add to model.
Multiple tables can be added at once by passing a dict of tables.
name : str, optional
Name of table, by default None. Required when tables is not a dict.
"""
if not isinstance(tables, dict) and name is None:
raise ValueError("name required when tables is not a dict")
elif not isinstance(tables, dict):
tables = {name: tables}
for name, df in tables.items():
if not (isinstance(df, pd.DataFrame) or isinstance(df, pd.Series)):
raise ValueError(
"table type not recognized, should be pandas DataFrame or Series."
)
if name in self.tables:
if not self._write:
raise IOError(f"Cannot overwrite table {name} in read-only mode")
elif self._read:
self.logger.warning(f"Overwriting table: {name}")

self.tables[name] = df

def get_tables_merged(self) -> pd.DataFrame:
"""Return all tables of a model merged into one dataframe."""
# This is mostly used for convenience and testing.
return pd.concat(
[df.assign(table_origin=name) for name, df in self.tables.items()], axis=0
)

def read_config(self, config_fn: Optional[str] = None):
"""Parse config from file.

@@ -1484,7 +1559,7 @@ def _cleanup(self, forceful_overwrite=False, max_close_attempts=2) -> List[str]:

def write_nc(
self,
nc_dict: Dict[str, Union[xr.DataArray, xr.Dataset]],
nc_dict: XArrayDict,
fn: str,
gdal_compliant: bool = False,
rename_dims: bool = False,
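
The diff above adds the tables property together with set_tables, read_tables, write_tables and get_tables_merged. A minimal usage sketch of the in-memory side (illustrative only, not part of the diff; it assumes GridModel as the concrete subclass, as the tests below do, and a made-up root path):

```python
import pandas as pd
from hydromt.models import GridModel  # concrete Model subclass, as used in the tests

mod = GridModel(root="./example_model", mode="w")  # path is illustrative

# a single table requires a name ...
mod.set_tables(pd.DataFrame({"city": ["a", "b"], "value": [1, 2]}), name="cities")
# ... while a dict adds several tables at once
mod.set_tables({"params": pd.DataFrame({"k": [0.1, 0.2]})})

# lazy dict property; in read mode the first access would trigger read_tables()
print(list(mod.tables.keys()))  # ['cities', 'params']

# all tables stacked into one frame, tagged with a 'table_origin' column
merged = mod.get_tables_merged()
print(merged["table_origin"].unique())  # ['cities' 'params']
```
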
17 changes: 13 additions & 4 deletions hydromt/models/model_grid.py
@@ -683,6 +683,7 @@ def read(
"config",
"grid",
"geoms",
"tables",
"forcing",
"states",
"results",
@@ -695,22 +696,30 @@
components : List, optional
List of model components to read, each should have an associated
read_<component> method. By default ['config', 'maps', 'grid',
'geoms', 'forcing', 'states', 'results']
'geoms', 'tables', 'forcing', 'states', 'results']
"""
super().read(components=components)

def write(
self,
components: List = ["config", "maps", "grid", "geoms", "forcing", "states"],
components: List = [
"config",
"maps",
"grid",
"geoms",
"tables",
"forcing",
"states",
],
) -> None:
"""Write the complete model schematization and configuration to model files.

Parameters
----------
components : List, optional
List of model components to write, each should have an
associated write_<component> method.
By default ['config', 'maps', 'grid', 'geoms', 'forcing', 'states']
associated write_<component> method. By default
['config', 'maps', 'grid', 'geoms', 'tables', 'forcing', 'states']
"""
super().write(components=components)

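
With tables added to the default component lists of GridModel.read and GridModel.write above, a plain write/read round trip picks tabular data up alongside the other components. A minimal sketch under the same assumptions (GridModel, made-up paths and data), restricted to the tables component:

```python
import pandas as pd
from hydromt.models import GridModel

# write side: each table ends up as <root>/tables/<name>.csv
mod = GridModel(root="./example_model", mode="w")  # path is illustrative
mod.set_tables(pd.DataFrame({"id": [1, 2], "area": [10.5, 3.2]}), name="catchments")
mod.write(components=["tables"])  # "tables" is now also part of the default list

# read side: the lazy tables property parses the CSV files back on first access
mod2 = GridModel(root="./example_model", mode="r")
assert "catchments" in mod2.tables
```
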
6 changes: 4 additions & 2 deletions hydromt/models/model_lumped.py
@@ -180,6 +180,7 @@ def read(
"config",
"response_units",
"geoms",
"tables",
"forcing",
"states",
"results",
@@ -192,7 +193,7 @@
components : List, optional
List of model components to read, each should have an
associated read_<component> method.
By default ['config', 'maps', 'response_units', 'geoms',
By default ['config', 'maps', 'response_units', 'geoms', 'tables',
'forcing', 'states', 'results']
"""
super().read(components=components)
@@ -203,6 +204,7 @@ def write(
"config",
"response_units",
"geoms",
"tables",
"forcing",
"states",
],
@@ -214,7 +216,7 @@
components : List, optional
List of model components to write, each should have an
associated write_<component> method. By default ['config',
'maps', 'response_units', 'geoms', 'forcing', 'states']
'maps', 'response_units', 'geoms', 'tables', 'forcing', 'states']
"""
super().write(components=components)

5 changes: 3 additions & 2 deletions hydromt/models/model_mesh.py
@@ -647,6 +647,7 @@ def read(
"config",
"mesh",
"geoms",
"tables",
"forcing",
"states",
"results",
@@ -665,7 +666,7 @@

def write(
self,
components: List = ["config", "mesh", "geoms", "forcing", "states"],
components: List = ["config", "mesh", "geoms", "tables", "forcing", "states"],
) -> None:
"""Write the complete model schematization and configuration to model files.

Expand All @@ -674,7 +675,7 @@ def write(
components : List, optional
List of model components to write, each should have an
associated write_<component> method. By default ['config', 'maps',
'mesh', 'geoms', 'forcing', 'states']
'mesh', 'geoms', 'tables', 'forcing', 'states']
"""
super().write(components=components)

6 changes: 4 additions & 2 deletions hydromt/models/model_network.py
@@ -46,6 +46,7 @@ def read(
"config",
"network",
"geoms",
"tables",
"forcing",
"states",
"results",
@@ -58,7 +59,7 @@
components : List, optional
List of model components to read, each should have an associated
read_<component> method. By default ['config', 'maps',
'network', 'geoms', 'forcing', 'states', 'results']
'network', 'geoms', 'tables', 'forcing', 'states', 'results']
"""
super().read(components=components)

@@ -68,6 +69,7 @@ def write(
"config",
"network",
"geoms",
"tables",
"forcing",
"states",
],
@@ -79,7 +81,7 @@
components : List, optional
List of model components to write, each should have an
associated write_<component> method. By default ['config', 'maps',
'network', 'geoms', 'forcing', 'states']
'network', 'geoms', 'tables', 'forcing', 'states']
"""
super().write(components=components)

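
The defaults added to the grid, lumped, mesh and network model classes above all route through write_tables/read_tables in model_api.py, where the fn pattern is formatted with each table name relative to the model root and extra keyword arguments are forwarded to DataFrame.to_csv / pandas.read_csv. A small sketch with made-up names, assuming a model root that already exists (append mode):

```python
import pandas as pd
from hydromt.models import GridModel

# "r+" (append) mode allows both reading and writing; the root must already exist
mod = GridModel(root="./existing_model", mode="r+")
mod.set_tables(pd.DataFrame({"name": ["x", "y"], "value": [1.0, 2.0]}), name="params")

# kwargs are merged over the CSV defaults (index=False, header=True, sep=",")
mod.write_tables(fn="tables/{name}.csv", sep=";")

# the same pattern is globbed back in; kwargs go to pandas.read_csv
mod.read_tables(fn="tables/{name}.csv", sep=";")
```
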
38 changes: 37 additions & 1 deletion tests/test_model.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Tests for the hydromt.models module of HydroMT."""

from copy import deepcopy
from os.path import abspath, dirname, isfile, join

import geopandas as gpd
@@ -171,7 +172,39 @@ def test_model(model, tmpdir):
assert np.all(model.region.total_bounds == model.staticmaps.raster.bounds)


def test_model_append(demda, tmpdir):
def test_model_tables(model, df, tmpdir):
# make a couple copies of the dfs for testing
dfs = {str(i): df.copy() for i in range(5)}
model.set_root(tmpdir, mode="r+") # append mode
clean_model = deepcopy(model)

with pytest.raises(KeyError):
model.tables[1]

for i, d in dfs.items():
model.set_tables(d, name=i)
assert df.equals(model.tables[i])

# now do the same but iterating over the tables instead
for i, d in model.tables.items():
model.set_tables(d, name=i)
assert df.equals(model.tables[i])

assert list(model.tables.keys()) == list(map(str, range(5)))

model.write_tables()
clean_model.read_tables()

model_merged = model.get_tables_merged().sort_values(["table_origin", "city"])
clean_model_merged = clean_model.get_tables_merged().sort_values(
["table_origin", "city"]
)
assert np.all(
np.equal(model_merged, clean_model_merged)
), f"model: {model_merged}\nclean_model: {clean_model_merged}"


def test_model_append(demda, df, tmpdir):
# write a model
demda.name = "dem"
mod = GridModel(mode="w", root=str(tmpdir))
@@ -181,6 +214,7 @@ def test_model_append(demda, tmpdir):
mod.set_forcing(demda, name="dem")
mod.set_states(demda, name="dem")
mod.set_geoms(demda.raster.box, name="dem")
mod.set_tables(df, name="df")
mod.write()
# append to model and check if previous data is still there
mod1 = GridModel(mode="r+", root=str(tmpdir))
@@ -196,6 +230,8 @@ def test_model_append(demda, tmpdir):
assert "dem" in mod1.states
mod1.set_geoms(demda.raster.box, name="dem1")
assert "dem" in mod1.geoms
mod1.set_tables(df, name="df1")
assert "df" in mod1.tables


@pytest.mark.filterwarnings("ignore:The setup_basemaps")