From d8c52a4eeb24f07127945a36b00d38081db6eaeb Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 29 Jun 2023 16:33:36 +0200 Subject: [PATCH 01/27] [no ci] WIP --- hydromt/data_catalog.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 1b5f08040..234e164f5 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -127,6 +127,10 @@ def sources(self) -> Dict: @property def keys(self) -> List: + warnings.warn( + 'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_sources()', + DeprecationWarning, + ) """Returns list of data source names.""" return list(self._sources.keys()) @@ -139,10 +143,18 @@ def predefined_catalogs(self) -> Dict: def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" + warnings.warn( + 'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_source("name")', + DeprecationWarning, + ) return self._sources[key] def __setitem__(self, key: str, value: DataAdapter) -> None: """Set or update adaptors.""" + warnings.warn( + 'Using DataCatalog as a dictionary directly is deprecated. Please use cat.add_source(adapter)', + DeprecationWarning, + ) if not isinstance(value, DataAdapter): raise ValueError(f"Value must be DataAdapter, not {type(key).__name__}.") if key in self._sources: @@ -151,10 +163,18 @@ def __setitem__(self, key: str, value: DataAdapter) -> None: def __iter__(self): """Iterate over sources.""" + warnings.warn( + 'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_sources()', + DeprecationWarning, + ) return self._sources.__iter__() def __len__(self): """Return number of sources.""" + warnings.warn( + 'Using len on DataCatalog directly is deprecated. Please use len(cat.get_sources())', + DeprecationWarning, + ) return self._sources.__len__() def __repr__(self): @@ -165,7 +185,7 @@ def _repr_html_(self): return self.to_dataframe()._repr_html_() def update(self, **kwargs) -> None: - """Add data sources to library.""" + """Add data sources to library or update them.""" for k, v in kwargs.items(): self[k] = v From 7245aadd26837e57fd4abbd078cbd53d38ecf92d Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 3 Jul 2023 10:39:59 +0200 Subject: [PATCH 02/27] [no ci] WIP --- hydromt/data_catalog.py | 5 +-- hydromt/models/model_grid.py | 2 +- tests/test_data_catalog.py | 66 ++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 234e164f5..df5b74fe8 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -127,11 +127,12 @@ def sources(self) -> Dict: @property def keys(self) -> List: + """Returns list of data source names.""" warnings.warn( - 'Using iterating over the DataCatalog directly is deprecated. 
Please use cat.get_sources()', + 'Using iterating over the DataCatalog directly is deprecated.'\ + 'Please use cat.get_sources()', DeprecationWarning, ) - """Returns list of data source names.""" return list(self._sources.keys()) @property diff --git a/hydromt/models/model_grid.py b/hydromt/models/model_grid.py index bc7f1f548..2ff587d45 100644 --- a/hydromt/models/model_grid.py +++ b/hydromt/models/model_grid.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -class GridMixin(object): +class GridMixin(Model): # placeholders # xr.Dataset representation of all static parameter maps at the same resolution and # bounds - renamed from staticmaps diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 5274bcc21..ff9c42f29 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -85,6 +85,72 @@ def test_data_catalog_io(tmpdir): # test print print(data_catalog["merit_hydro"]) +def test_versioned_catalogs(tmpdir): + # we want to keep a legacy version embeded in the test code since we're presumably + # going to change the actual catalog. + legacy_esa_catalog = ''' +esa_worldcover: + crs: 4326 + data_type: RasterDataset + driver: raster + kwargs: + chunks: + x: 36000 + y: 36000 + meta: + category: landuse + source_license: CC BY 4.0 + source_url: https://doi.org/10.5281/zenodo.5571936 + source_version: v100 + path: landuse/esa_worldcover/esa-worldcover.vrt +''' + + aws_esa_catalog = ''' +esa_worldcover: + crs: 4326 + data_type: RasterDataset + driver: raster + filesystem: s3 + kwargs: + storage_options: + anon: true + meta: + category: landuse + source_license: CC BY 4.0 + source_url: https://doi.org/10.5281/zenodo.5571936 + source_version: v100 + path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt + rename: + ESA_WorldCover_10m_2020_v100_Map_AWS: landuse +''' + + aws_base_esa_catalog = ''' +esa_worldcover: + crs: 4326 + data_type: RasterDataset + driver: raster + meta: + category: landuse + source_license: CC BY 4.0 + source_url: https://doi.org/10.5281/zenodo.5571936 + source_version: v100 + versions: + - catalog_name: aws_data + path: path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt + rename: + ESA_WorldCover_10m_2020_v100_Map_AWS: landuse + filesystem: s3 + kwargs: + storage_options: + anon: true + - catalog_name: deltares_data + path: landuse/esa_worldcover/esa-worldcover.vrt + kwargs: + chunks: + x: 36000 + y: 36000 +''' + legacy_data_catalog = DataCatalog(data_libs=[legacy_esa_catalog]) @pytest.mark.filterwarnings('ignore:"from_artifacts" is deprecated:DeprecationWarning') def test_data_catalog(tmpdir): From b5fb06f0ea89c5d6a6d75dc6aa3abf0d1d227c79 Mon Sep 17 00:00:00 2001 From: Dirk Eilander Date: Wed, 28 Jun 2023 17:48:49 +0200 Subject: [PATCH 03/27] fix #413 --- hydromt/data_adapter/caching.py | 5 +++-- hydromt/data_catalog.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/hydromt/data_adapter/caching.py b/hydromt/data_adapter/caching.py index fc3b3536c..ad07ff527 100644 --- a/hydromt/data_adapter/caching.py +++ b/hydromt/data_adapter/caching.py @@ -5,6 +5,7 @@ from ast import literal_eval from os.path import basename, dirname, isdir, isfile, join from pathlib import Path +from typing import Union from urllib.parse import urlparse import geopandas as gpd @@ -19,10 +20,10 @@ HYDROMT_DATADIR = join(Path.home(), ".hydromt_data") -def _uri_validator(uri: str) -> bool: +def _uri_validator(uri: Union[str, Path]) -> bool: """Check if uri is valid.""" try: - result = 
urlparse(uri) + result = urlparse(str(uri)) return all([result.scheme, result.netloc]) except ValueError | AttributeError: return False diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index df5b74fe8..44c00b874 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -988,7 +988,7 @@ def _parse_data_dict( # Only for local files path = source.pop("path") # if remote path, keep as is else call abs_path method to solve local files - if not _uri_validator(path): + if not _uri_validator(str(path)): path = abs_path(root, path) meta = source.pop("meta", {}) if "category" not in meta and category is not None: @@ -1025,7 +1025,7 @@ def _parse_data_dict( def _yml_from_uri_or_path(uri_or_path: Union[Path, str]) -> Dict: - if _uri_validator(uri_or_path): + if _uri_validator(str(uri_or_path)): with requests.get(uri_or_path, stream=True) as r: if r.status_code != 200: raise IOError(f"URL {r.content}: {uri_or_path}") From 977c680cf9cd950a925d2e6f8f1f181a6e409132 Mon Sep 17 00:00:00 2001 From: Dirk Eilander Date: Thu, 29 Jun 2023 14:12:47 +0200 Subject: [PATCH 04/27] add test and changelog --- docs/changelog.rst | 1 + tests/test_data_catalog.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 157b3ac2c..ebac50351 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -26,6 +26,7 @@ Fixed ----- - Order of renaming variables in get_rasterdataset for x,y dimensions. PR #324 - fix bug in ``get_basin_geometry`` for region kind 'subbasin' if no stream or outlet option is specified. +- fix use of Path objects in parsing data catalog files. PR #429 Deprecated ---------- diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index ff9c42f29..0dd751495 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -2,6 +2,7 @@ import os from os.path import abspath, dirname, join +from pathlib import Path import geopandas as gpd import pandas as pd @@ -27,6 +28,10 @@ def test_parser(): dd_out = _parse_data_dict(dd, root=root) assert isinstance(dd_out["test"], RasterDatasetAdapter) assert dd_out["test"].path == abspath(dd["test"]["path"]) + # test with Path object + dd["test"].update(path=Path(dd["test"]["path"])) + dd_out = _parse_data_dict(dd, root=root) + assert dd_out["test"].path == abspath(dd["test"]["path"]) # rel path dd = { "test": { @@ -62,6 +67,7 @@ def test_parser(): assert len(dd_out) == 6 assert dd_out["test_a_1"].path == abspath(join(root, "data_1.tif")) assert "placeholders" not in dd_out["test_a_1"].to_dict() + # errors with pytest.raises(ValueError, match="Missing required path argument"): _parse_data_dict({"test": {}}) From 124afafc29afb887185edc9bbb8935a1c1baa00a Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 3 Jul 2023 12:45:00 +0200 Subject: [PATCH 05/27] [no ci] WIP --- hydromt/data_catalog.py | 8 +++ tests/data/aws_esa_worldcover.yml | 16 +++++ tests/data/legacy_esa_worldcover.yml | 14 ++++ tests/data/merged_esa_worldcover.yml | 24 +++++++ tests/test_data_catalog.py | 103 ++++++++++----------------- 5 files changed, 101 insertions(+), 64 deletions(-) create mode 100644 tests/data/aws_esa_worldcover.yml create mode 100644 tests/data/legacy_esa_worldcover.yml create mode 100644 tests/data/merged_esa_worldcover.yml diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 44c00b874..a9e7416a6 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -141,6 +141,14 @@ def predefined_catalogs(self) -> Dict: if not self._catalogs: 
             self.set_predefined_catalogs()
         return self._catalogs
+
+    def get_source(self, key: str, catalog_name=None) -> DataAdapter:
+        """Get the source."""
+        warnings.warn(
+            'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_source("name")',
+            DeprecationWarning,
+        )
+        return self._sources[key]
 
     def __getitem__(self, key: str) -> DataAdapter:
         """Get the source."""
diff --git a/tests/data/aws_esa_worldcover.yml b/tests/data/aws_esa_worldcover.yml
new file mode 100644
index 000000000..6da764bda
--- /dev/null
+++ b/tests/data/aws_esa_worldcover.yml
@@ -0,0 +1,16 @@
+esa_worldcover:
+  crs: 4326
+  data_type: RasterDataset
+  driver: raster
+  filesystem: s3
+  kwargs:
+    storage_options:
+      anon: true
+  meta:
+    category: landuse
+    source_license: CC BY 4.0
+    source_url: https://doi.org/10.5281/zenodo.5571936
+    source_version: v100
+  path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt
+  rename:
+    ESA_WorldCover_10m_2020_v100_Map_AWS: landuse
\ No newline at end of file
diff --git a/tests/data/legacy_esa_worldcover.yml b/tests/data/legacy_esa_worldcover.yml
new file mode 100644
index 000000000..24f981fa2
--- /dev/null
+++ b/tests/data/legacy_esa_worldcover.yml
@@ -0,0 +1,14 @@
+esa_worldcover:
+  crs: 4326
+  data_type: RasterDataset
+  driver: raster
+  kwargs:
+    chunks:
+      x: 36000
+      y: 36000
+  meta:
+    category: landuse
+    source_license: CC BY 4.0
+    source_url: https://doi.org/10.5281/zenodo.5571936
+    source_version: v100
+  path: landuse/esa_worldcover/esa-worldcover.vrt
\ No newline at end of file
diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml
new file mode 100644
index 000000000..1ed00f1ad
--- /dev/null
+++ b/tests/data/merged_esa_worldcover.yml
@@ -0,0 +1,24 @@
+esa_worldcover:
+  crs: 4326
+  data_type: RasterDataset
+  driver: raster
+  meta:
+    category: landuse
+    source_license: CC BY 4.0
+    source_url: https://doi.org/10.5281/zenodo.5571936
+    source_version: v100
+  versions:
+    - catalog_name: aws_data
+      path: path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt
+      rename:
+        ESA_WorldCover_10m_2020_v100_Map_AWS: landuse
+      filesystem: s3
+      kwargs:
+        storage_options:
+          anon: true
+    - catalog_name: deltares_data
+      path: landuse/esa_worldcover/esa-worldcover.vrt
+      kwargs:
+        chunks:
+          x: 36000
+          y: 36000
\ No newline at end of file
diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py
index 0dd751495..ef8e32862 100644
--- a/tests/test_data_catalog.py
+++ b/tests/test_data_catalog.py
@@ -8,11 +8,13 @@
 import pandas as pd
 import pytest
 import xarray as xr
+from yaml import load, safe_load, FullLoader, dump
 
 from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter
 from hydromt.data_catalog import DataCatalog, _parse_data_dict
 
 CATALOGDIR = join(dirname(abspath(__file__)), "..", "data", "catalogs")
+DATADIR = join(dirname(abspath(__file__)), "data")
 
 
 def test_parser():
@@ -94,73 +96,46 @@ def test_data_catalog_io(tmpdir):
     # test print
     print(data_catalog["merit_hydro"])
+
 
 def test_versioned_catalogs(tmpdir):
     # we want to keep a legacy version embeded in the test code since we're presumably
     # going to change the actual catalog.
- legacy_esa_catalog = ''' -esa_worldcover: - crs: 4326 - data_type: RasterDataset - driver: raster - kwargs: - chunks: - x: 36000 - y: 36000 - meta: - category: landuse - source_license: CC BY 4.0 - source_url: https://doi.org/10.5281/zenodo.5571936 - source_version: v100 - path: landuse/esa_worldcover/esa-worldcover.vrt -''' - - aws_esa_catalog = ''' -esa_worldcover: - crs: 4326 - data_type: RasterDataset - driver: raster - filesystem: s3 - kwargs: - storage_options: - anon: true - meta: - category: landuse - source_license: CC BY 4.0 - source_url: https://doi.org/10.5281/zenodo.5571936 - source_version: v100 - path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt - rename: - ESA_WorldCover_10m_2020_v100_Map_AWS: landuse -''' - - aws_base_esa_catalog = ''' -esa_worldcover: - crs: 4326 - data_type: RasterDataset - driver: raster - meta: - category: landuse - source_license: CC BY 4.0 - source_url: https://doi.org/10.5281/zenodo.5571936 - source_version: v100 - versions: - - catalog_name: aws_data - path: path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt - rename: - ESA_WorldCover_10m_2020_v100_Map_AWS: landuse - filesystem: s3 - kwargs: - storage_options: - anon: true - - catalog_name: deltares_data - path: landuse/esa_worldcover/esa-worldcover.vrt - kwargs: - chunks: - x: 36000 - y: 36000 -''' - legacy_data_catalog = DataCatalog(data_libs=[legacy_esa_catalog]) + + legacy_yml_fn = join(DATADIR, "legacy_esa_worldcover.yml") + aws_yml_fn = join(DATADIR, "aws_esa_worldcover.yml") + merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") + legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn]) + with open(legacy_yml_fn, 'r') as f: + legacy_esa_yml = load(f, Loader=FullLoader) + with pytest.deprecated_call(): + _ = legacy_data_catalog['esa_worldcover'] + # with pytest.raises(KeyError): + # legacy_data_catalog.get_source('esa_worldcover', catalog_name="aws_data") + print(legacy_data_catalog.get_source('esa_worldcover')) + print(dump(legacy_esa_yml)) + assert legacy_data_catalog.get_source('esa_worldcover') == dump(legacy_esa_yml) + + aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) + with open(aws_yml_fn, 'r') as f: + aws_esa_yml = safe_load(f,Loader=FullLoader) + assert aws_data_catalog.get_source('esa_worldcover') == aws_esa_yml + assert aws_data_catalog.get_source('esa_worldcover', catalog_name="aws_data") == aws_esa_yml + + merged_catalog = DataCatalog(data_libs=[legacy_esa_yml,aws_esa_yml]) + with open(merged_yml_fn, 'r') as f: + merged_esa_yml = safe_load(f,Loader=FullLoader) + assert merged_catalog.get_source('esa_worldcover') == aws_esa_yml + assert merged_catalog.get_source('esa_worldcover', catalog_name="aws_data") == aws_esa_yml + assert merged_catalog.get_source('esa_worldcover', catalog_name="deltares_data") == legacy_esa_yml + + dst_merged_yml_fn = join(tmpdir, "dst_merged_esa_worldcover.yml") + merged_esa_yml.to_yml(dst_merged_yml_fn) + + with open(dst_merged_yml_fn, 'r') as f: + reloaded_merged_esa_yml = load(f) + + assert reloaded_merged_esa_yml == merged_esa_yml + @pytest.mark.filterwarnings('ignore:"from_artifacts" is deprecated:DeprecationWarning') def test_data_catalog(tmpdir): - data_catalog = DataCatalog(data_libs=None) # NOTE: legacy code! 
+ data_catalog = DataCatalog(data_libs=None) # initialized with empty dict assert len(data_catalog._sources) == 0 # global data sources from artifacts are automatically added From 5a2b6b4f0235559dc527f72790b23033e497dcad Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 3 Jul 2023 13:27:51 +0200 Subject: [PATCH 06/27] wip --- hydromt/data_catalog.py | 16 +++++------- tests/test_data_catalog.py | 50 ++++++++++++++------------------------ 2 files changed, 24 insertions(+), 42 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index a9e7416a6..99382857f 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -129,8 +129,8 @@ def sources(self) -> Dict: def keys(self) -> List: """Returns list of data source names.""" warnings.warn( - 'Using iterating over the DataCatalog directly is deprecated.'\ - 'Please use cat.get_sources()', + "Using iterating over the DataCatalog directly is deprecated." + "Please use cat.get_sources()", DeprecationWarning, ) return list(self._sources.keys()) @@ -141,13 +141,9 @@ def predefined_catalogs(self) -> Dict: if not self._catalogs: self.set_predefined_catalogs() return self._catalogs - + def get_source(self, key: str, catalog_name=None) -> DataAdapter: """Get the source.""" - warnings.warn( - 'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_source("name")', - DeprecationWarning, - ) return self._sources[key] def __getitem__(self, key: str) -> DataAdapter: @@ -161,7 +157,7 @@ def __getitem__(self, key: str) -> DataAdapter: def __setitem__(self, key: str, value: DataAdapter) -> None: """Set or update adaptors.""" warnings.warn( - 'Using DataCatalog as a dictionary directly is deprecated. Please use cat.add_source(adapter)', + "Using DataCatalog as a dictionary directly is deprecated. Please use cat.add_source(adapter)", DeprecationWarning, ) if not isinstance(value, DataAdapter): @@ -173,7 +169,7 @@ def __setitem__(self, key: str, value: DataAdapter) -> None: def __iter__(self): """Iterate over sources.""" warnings.warn( - 'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_sources()', + "Using iterating over the DataCatalog directly is deprecated. Please use cat.get_sources()", DeprecationWarning, ) return self._sources.__iter__() @@ -181,7 +177,7 @@ def __iter__(self): def __len__(self): """Return number of sources.""" warnings.warn( - 'Using len on DataCatalog directly is deprecated. Please use len(cat.get_sources())', + "Using len on DataCatalog directly is deprecated. Please use len(cat.get_sources())", DeprecationWarning, ) return self._sources.__len__() diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index ef8e32862..3210a07b5 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -8,7 +8,6 @@ import pandas as pd import pytest import xarray as xr -from yaml import load, safe_load, FullLoader, dump from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter from hydromt.data_catalog import DataCatalog, _parse_data_dict @@ -93,45 +92,32 @@ def test_data_catalog_io(tmpdir): # test print print(data_catalog["merit_hydro"]) + def test_versioned_catalogs(tmpdir): # we want to keep a legacy version embeded in the test code since we're presumably # going to change the actual catalog. 
legacy_yml_fn = join(DATADIR, "legacy_esa_worldcover.yml") aws_yml_fn = join(DATADIR, "aws_esa_worldcover.yml") - merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") + # merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn]) - with open(legacy_yml_fn, 'r') as f: - legacy_esa_yml = load(f, Loader=FullLoader) - with pytest.deprecated_call(): - _ = legacy_data_catalog['esa_worldcover'] - # with pytest.raises(KeyError): - # legacy_data_catalog.get_source('esa_worldcover', catalog_name="aws_data") - print(legacy_data_catalog.get_source('esa_worldcover')) - print(dump(legacy_esa_yml)) - assert legacy_data_catalog.get_source('esa_worldcover') == dump(legacy_esa_yml) - + assert legacy_data_catalog.get_source("esa_worldcover").path.endswith( + "landuse/esa_worldcover/esa-worldcover.vrt" + ) aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) - with open(aws_yml_fn, 'r') as f: - aws_esa_yml = safe_load(f,Loader=FullLoader) - assert aws_data_catalog.get_source('esa_worldcover') == aws_esa_yml - assert aws_data_catalog.get_source('esa_worldcover', catalog_name="aws_data") == aws_esa_yml - - merged_catalog = DataCatalog(data_libs=[legacy_esa_yml,aws_esa_yml]) - with open(merged_yml_fn, 'r') as f: - merged_esa_yml = safe_load(f,Loader=FullLoader) - assert merged_catalog.get_source('esa_worldcover') == aws_esa_yml - assert merged_catalog.get_source('esa_worldcover', catalog_name="aws_data") == aws_esa_yml - assert merged_catalog.get_source('esa_worldcover', catalog_name="deltares_data") == legacy_esa_yml - - dst_merged_yml_fn = join(tmpdir, "dst_merged_esa_worldcover.yml") - merged_esa_yml.to_yml(dst_merged_yml_fn) - - with open(dst_merged_yml_fn, 'r') as f: - reloaded_merged_esa_yml = load(f) - - assert reloaded_merged_esa_yml == merged_esa_yml - + assert ( + aws_data_catalog.get_source("esa_worldcover").path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + merged_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) + assert ( + merged_catalog.get_source("esa_worldcover").path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + assert merged_catalog.get_source( + "esa_worldcover", catalog_name="legacy_esa_worldcover" + ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") + @pytest.mark.filterwarnings('ignore:"from_artifacts" is deprecated:DeprecationWarning') def test_data_catalog(tmpdir): From 8deff0b6ce63c5f3fd6758ae54a5babc6f7ddf9d Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Tue, 4 Jul 2023 07:44:28 +0200 Subject: [PATCH 07/27] [no ci] WIP --- hydromt/data_catalog.py | 23 +++++++++++++++-------- tests/test_data_catalog.py | 2 ++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 99382857f..fc37a63e0 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -145,6 +145,11 @@ def predefined_catalogs(self) -> Dict: def get_source(self, key: str, catalog_name=None) -> DataAdapter: """Get the source.""" return self._sources[key] + + def add_source(self,key: str, adapter: DataAdapter) -> None: + if key in self._sources: + self.logger.warning(f"Overwriting data source {key}.") + return self._sources.__setitem__(key, value) def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -152,7 +157,7 @@ def __getitem__(self, key: str) -> DataAdapter: 'Using iterating over the DataCatalog directly is deprecated. 
Please use cat.get_source("name")', DeprecationWarning, ) - return self._sources[key] + return self.get_source(key) def __setitem__(self, key: str, value: DataAdapter) -> None: """Set or update adaptors.""" @@ -160,16 +165,13 @@ def __setitem__(self, key: str, value: DataAdapter) -> None: "Using DataCatalog as a dictionary directly is deprecated. Please use cat.add_source(adapter)", DeprecationWarning, ) - if not isinstance(value, DataAdapter): - raise ValueError(f"Value must be DataAdapter, not {type(key).__name__}.") - if key in self._sources: - self.logger.warning(f"Overwriting data source {key}.") - return self._sources.__setitem__(key, value) + self.add_source(key,value) + def __iter__(self): """Iterate over sources.""" warnings.warn( - "Using iterating over the DataCatalog directly is deprecated. Please use cat.get_sources()", + "Using iterating over the DataCatalog directly is deprecated. Please use cat.iter_sources()", DeprecationWarning, ) return self._sources.__iter__() @@ -192,7 +194,12 @@ def _repr_html_(self): def update(self, **kwargs) -> None: """Add data sources to library or update them.""" for k, v in kwargs.items(): - self[k] = v + self.add_source(k,v) + + def update_sources(self, **kwargs) -> None: + """Add data sources to library or update them.""" + for k, v in kwargs.items(): + self.add_source(k,v) def set_predefined_catalogs(self, urlpath: Union[Path, str] = None) -> Dict: """Initialise the predefined catalogs.""" diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 3210a07b5..2d4347a12 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -114,6 +114,8 @@ def test_versioned_catalogs(tmpdir): merged_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) + print(merged_catalog.get_source("esa_worldcover")) + breakpoint() assert merged_catalog.get_source( "esa_worldcover", catalog_name="legacy_esa_worldcover" ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") From cd4a3ffd84953c45794530e41b5e15cdb0259f7d Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 10 Jul 2023 10:30:21 +0200 Subject: [PATCH 08/27] [no ci] WIP --- hydromt/cli/api.py | 5 +-- hydromt/data_catalog.py | 64 ++++++++++++++++++++++------ tests/data/aws_esa_worldcover.yml | 2 +- tests/data/legacy_esa_worldcover.yml | 2 +- tests/data/merged_esa_worldcover.yml | 2 +- tests/test_basin_mask.py | 2 +- tests/test_data_adapter.py | 44 +++++++++++-------- tests/test_data_catalog.py | 17 ++++---- 8 files changed, 91 insertions(+), 47 deletions(-) diff --git a/hydromt/cli/api.py b/hydromt/cli/api.py index 8d4aaa390..76b52434a 100644 --- a/hydromt/cli/api.py +++ b/hydromt/cli/api.py @@ -100,13 +100,12 @@ def get_datasets(data_libs: Union[List, str]) -> Dict: for accepted yaml format. """ data_catalog = DataCatalog(data_libs) - datasets = data_catalog.sources dataset_sources = { "RasterDatasetSource": [], "GeoDatasetSource": [], "GeoDataframeSource": [], } - for k, v in datasets.items(): + for k, v in data_catalog.iter_sources(): if v.data_type == "RasterDataset": dataset_sources["RasterDatasetSource"].append(k) elif v.data_type == "GeoDataFrame": @@ -167,7 +166,7 @@ def get_region( # retrieve global hydrography data (lazy!) 
ds_org = data_catalog.get_rasterdataset(hydrography_fn) if "bounds" not in region: - region.update(basin_index=data_catalog[basin_index_fn]) + region.update(basin_index=data_catalog.get_source(basin_index_fn)) # get basin geometry geom, xy = workflows.get_basin_geometry( ds=ds_org, diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index c49ca194c..9e6271a0c 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -128,15 +128,18 @@ def sources(self) -> Dict: return self._sources @property - def keys(self) -> List: + def keys(self) -> List[str]: """Returns list of data source names.""" warnings.warn( "Using iterating over the DataCatalog directly is deprecated." - "Please use cat.get_sources()", + "Please use cat.get_source()", DeprecationWarning, ) return list(self._sources.keys()) + def get_source_names(self) -> List[str]: + return list(self._sources.keys()) + @property def predefined_catalogs(self) -> Dict: """Return all predefined catalogs.""" @@ -144,14 +147,36 @@ def predefined_catalogs(self) -> Dict: self.set_predefined_catalogs() return self._catalogs - def get_source(self, key: str, catalog_name=None) -> DataAdapter: + def get_source(self, key: str, provider=None) -> DataAdapter: """Get the source.""" + if key not in self._sources: + raise KeyError( + f"Requested unknown data source: {key} available sources are: {sorted(list(self._sources.keys()))}" + ) + + available_providers = self._sources[key] + if provider is not None: + if provider not in available_providers: + raise KeyError( + f"Requested unknown proveder {provider} for data_source {key} available providers are {sorted(list(available_providers.keys()))}" + ) + else: + return available_providers[provider] + else: + return available_providers["last"] + + def add_source(self, key: str, adapter: DataAdapter) -> None: + if not isinstance(adapter, DataAdapter): + raise ValueError("Value must be DataAdapter") + + if key not in self._sources: + self._sources[key] = dict() + + # TODO catalgos here need to be diffed to construct common base + + self._sources[key]["last"] = adapter + self._sources[key][adapter.catalog_name] = adapter return self._sources[key] - - def add_source(self,key: str, adapter: DataAdapter) -> None: - if key in self._sources: - self.logger.warning(f"Overwriting data source {key}.") - return self._sources.__setitem__(key, value) def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -167,8 +192,13 @@ def __setitem__(self, key: str, value: DataAdapter) -> None: "Using DataCatalog as a dictionary directly is deprecated. Please use cat.add_source(adapter)", DeprecationWarning, ) - self.add_source(key,value) + self.add_source(key, value) + def iter_sources(self) -> Tuple[str, DataAdapter]: + for source_name, available_providers in self._sources.items(): + # print(available_providers) + for provider, adapter in available_providers.items(): + yield (source_name, adapter) def __iter__(self): """Iterate over sources.""" @@ -181,7 +211,7 @@ def __iter__(self): def __len__(self): """Return number of sources.""" warnings.warn( - "Using len on DataCatalog directly is deprecated. Please use len(cat.get_sources())", + "Using len on DataCatalog directly is deprecated. 
Please use len(cat.get_source())", DeprecationWarning, ) return self._sources.__len__() @@ -196,12 +226,12 @@ def _repr_html_(self): def update(self, **kwargs) -> None: """Add data sources to library or update them.""" for k, v in kwargs.items(): - self.add_source(k,v) + self.add_source(k, v) def update_sources(self, **kwargs) -> None: """Add data sources to library or update them.""" for k, v in kwargs.items(): - self.add_source(k,v) + self.add_source(k, v) def set_predefined_catalogs(self, urlpath: Union[Path, str] = None) -> Dict: """Initialise the predefined catalogs.""" @@ -516,7 +546,9 @@ def to_dict( root = abspath(root) meta.update(**{"root": root}) root_drive = os.path.splitdrive(root)[0] - for name, source in sorted(self._sources.items()): # alphabetical order + for name, source in sorted( + self.iter_sources(), key=lambda x: x[0] + ): # alphabetical order if source_names is not None and name not in source_names: continue source_dict = source.to_dict() @@ -540,7 +572,7 @@ def to_dict( def to_dataframe(self, source_names: List = []) -> pd.DataFrame: """Return data catalog summary as DataFrame.""" d = dict() - for name, source in self._sources.items(): + for name, source in self.iter_sources(): if len(source_names) > 0 and name not in source_names: continue d[name] = source.summary() @@ -732,6 +764,7 @@ def get_rasterdataset( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) + source = self.get_source(name) self.logger.info( f"DataCatalog: Getting {name} RasterDataset {source.driver} data from" f" {source.path}" @@ -815,6 +848,7 @@ def get_geodataframe( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) + source = self.get_source(name) self.logger.info( f"DataCatalog: Getting {name} GeoDataFrame {source.driver} data" f" from {source.path}" @@ -901,6 +935,7 @@ def get_geodataset( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) + source = self.get_source(name) self.logger.info( f"DataCatalog: Getting {name} GeoDataset {source.driver} data" f" from {source.path}" @@ -963,6 +998,7 @@ def get_dataframe( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) + source = self.get_source(name) self.logger.info( f"DataCatalog: Getting {name} DataFrame {source.driver} data" f" from {source.path}" diff --git a/tests/data/aws_esa_worldcover.yml b/tests/data/aws_esa_worldcover.yml index 6da764bda..8f8820425 100644 --- a/tests/data/aws_esa_worldcover.yml +++ b/tests/data/aws_esa_worldcover.yml @@ -13,4 +13,4 @@ esa_worldcover: source_version: v100 path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt rename: - ESA_WorldCover_10m_2020_v100_Map_AWS: landuse \ No newline at end of file + ESA_WorldCover_10m_2020_v100_Map_AWS: landuse diff --git a/tests/data/legacy_esa_worldcover.yml b/tests/data/legacy_esa_worldcover.yml index 24f981fa2..314facd1e 100644 --- a/tests/data/legacy_esa_worldcover.yml +++ b/tests/data/legacy_esa_worldcover.yml @@ -11,4 +11,4 @@ esa_worldcover: source_license: CC BY 4.0 source_url: https://doi.org/10.5281/zenodo.5571936 source_version: v100 - path: landuse/esa_worldcover/esa-worldcover.vrt \ No newline at end of file + path: landuse/esa_worldcover/esa-worldcover.vrt diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml index 1ed00f1ad..e0d72d9cf 100644 --- a/tests/data/merged_esa_worldcover.yml +++ 
b/tests/data/merged_esa_worldcover.yml @@ -21,4 +21,4 @@ esa_worldcover: kwargs: chunks: x: 36000 - y: 36000 \ No newline at end of file + y: 36000 diff --git a/tests/test_basin_mask.py b/tests/test_basin_mask.py index 4bba755f0..5c27f3016 100644 --- a/tests/test_basin_mask.py +++ b/tests/test_basin_mask.py @@ -123,7 +123,7 @@ def test_basin(caplog): data_catalog = hydromt.DataCatalog(logger=logger) ds = data_catalog.get_rasterdataset("merit_hydro") gdf_bas_index = data_catalog.get_geodataframe("merit_hydro_index") - bas_index = data_catalog["merit_hydro_index"] + bas_index = data_catalog.get_source("merit_hydro_index") with pytest.raises(ValueError, match=r"No basins found"): gdf_bas, gdf_out = get_basin_geometry( diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index a4d77aca8..1cc3fe04b 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -41,12 +41,12 @@ def test_resolve_path(tmpdir): cat = DataCatalog() cat.from_dict(dd) # test - assert len(cat["test"].resolve_paths()) == 48 - assert len(cat["test"].resolve_paths(variables=["precip"])) == 24 + assert len(cat.get_source("test").resolve_paths()) == 48 + assert len(cat.get_source("test").resolve_paths(variables=["precip"])) == 24 kwargs = dict(variables=["precip"], time_tuple=("2021-03-01", "2021-05-01")) - assert len(cat["test"].resolve_paths(**kwargs)) == 3 + assert len(cat.get_source("test").resolve_paths(**kwargs)) == 3 with pytest.raises(FileNotFoundError, match="No such file found:"): - cat["test"].resolve_paths(variables=["waves"]) + cat.get_source("test").resolve_paths(variables=["waves"]) def test_rasterdataset(rioda, tmpdir): @@ -61,7 +61,8 @@ def test_rasterdataset(rioda, tmpdir): data_catalog.get_rasterdataset("no_file.tif") -@pytest.mark.skipif(not compat.HAS_GCSFS, reason="GCSFS not installed.") +# @pytest.mark.skipif(not compat.HAS_GCSFS, reason="GCSFS not installed.") +@pytest.mark.skip() def test_gcs_cmip6(tmpdir): # TODO switch to pre-defined catalogs when pushed to main catalog_fn = join(CATALOGDIR, "gcs_cmip6_data.yml") @@ -107,16 +108,23 @@ def test_rasterdataset_zoomlevels(rioda_large, tmpdir): } data_catalog = DataCatalog() data_catalog.from_dict(yml_dict) - assert data_catalog[name]._parse_zoom_level() == 0 # default to first - assert data_catalog[name]._parse_zoom_level(zoom_level=1) == 1 - assert data_catalog[name]._parse_zoom_level(zoom_level=(0.3, "degree")) == 1 - assert data_catalog[name]._parse_zoom_level(zoom_level=(0.29, "degree")) == 0 - assert data_catalog[name]._parse_zoom_level(zoom_level=(0.1, "degree")) == 0 - assert data_catalog[name]._parse_zoom_level(zoom_level=(1, "meter")) == 0 + assert data_catalog.get_source(name)._parse_zoom_level() == 0 # default to first + assert data_catalog.get_source(name)._parse_zoom_level(zoom_level=1) == 1 + assert ( + data_catalog.get_source(name)._parse_zoom_level(zoom_level=(0.3, "degree")) == 1 + ) + assert ( + data_catalog.get_source(name)._parse_zoom_level(zoom_level=(0.29, "degree")) + == 0 + ) + assert ( + data_catalog.get_source(name)._parse_zoom_level(zoom_level=(0.1, "degree")) == 0 + ) + assert data_catalog.get_source(name)._parse_zoom_level(zoom_level=(1, "meter")) == 0 with pytest.raises(TypeError, match="zoom_level unit"): - data_catalog[name]._parse_zoom_level(zoom_level=(1, "asfd")) + data_catalog.get_source(name)._parse_zoom_level(zoom_level=(1, "asfd")) with pytest.raises(TypeError, match="zoom_level argument"): - data_catalog[name]._parse_zoom_level(zoom_level=(1, "asfd", "asdf")) + 
data_catalog.get_source(name)._parse_zoom_level(zoom_level=(1, "asfd", "asdf")) def test_rasterdataset_driver_kwargs(artifact_data: DataCatalog, tmpdir): @@ -158,7 +166,7 @@ def test_rasterdataset_driver_kwargs(artifact_data: DataCatalog, tmpdir): def test_rasterdataset_unit_attrs(artifact_data: DataCatalog): - era5_dict = {"era5": artifact_data.sources["era5"].to_dict()} + era5_dict = {"era5": artifact_data.get_source("era5").to_dict()} attrs = { "temp": {"unit": "degrees C", "long_name": "temperature"}, "temp_max": {"unit": "degrees C", "long_name": "maximum temperature"}, @@ -219,7 +227,7 @@ def test_geodataset(geoda, geodf, ts, tmpdir): def test_geodataset_unit_attrs(artifact_data: DataCatalog): - gtsm_dict = {"gtsmv3_eu_era5": artifact_data.sources["gtsmv3_eu_era5"].to_dict()} + gtsm_dict = {"gtsmv3_eu_era5": artifact_data.get_source("gtsmv3_eu_era5").to_dict()} attrs = { "waterlevel": { "long_name": "sea surface height above mean sea level", @@ -235,7 +243,7 @@ def test_geodataset_unit_attrs(artifact_data: DataCatalog): def test_geodataset_unit_conversion(artifact_data: DataCatalog): gtsm_geodataarray = artifact_data.get_geodataset("gtsmv3_eu_era5") - gtsm_dict = {"gtsmv3_eu_era5": artifact_data.sources["gtsmv3_eu_era5"].to_dict()} + gtsm_dict = {"gtsmv3_eu_era5": artifact_data.get_source("gtsmv3_eu_era5").to_dict()} gtsm_dict["gtsmv3_eu_era5"].update(dict(unit_mult=dict(waterlevel=1000))) datacatalog = DataCatalog() datacatalog.from_dict(gtsm_dict) @@ -244,7 +252,7 @@ def test_geodataset_unit_conversion(artifact_data: DataCatalog): def test_geodataset_set_nodata(artifact_data: DataCatalog): - gtsm_dict = {"gtsmv3_eu_era5": artifact_data.sources["gtsmv3_eu_era5"].to_dict()} + gtsm_dict = {"gtsmv3_eu_era5": artifact_data.get_source("gtsmv3_eu_era5").to_dict()} gtsm_dict["gtsmv3_eu_era5"].update(dict(nodata=-99)) datacatalog = DataCatalog() datacatalog.from_dict(gtsm_dict) @@ -268,7 +276,7 @@ def test_geodataframe(geodf, tmpdir): def test_geodataframe_unit_attrs(artifact_data: DataCatalog): - gadm_level1 = {"gadm_level1": artifact_data.sources["gadm_level1"].to_dict()} + gadm_level1 = {"gadm_level1": artifact_data.get_source("gadm_level1").to_dict()} attrs = {"NAME_0": {"long_name": "Country names"}} gadm_level1["gadm_level1"].update(dict(attrs=attrs)) artifact_data.from_dict(gadm_level1) diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index b98df2ec3..17e9ab76a 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -90,7 +90,7 @@ def test_data_catalog_io(tmpdir): fn_yml = join(tmpdir, "test1.yml") DataCatalog(fallback_lib=None).to_yml(fn_yml) # test print - print(data_catalog["merit_hydro"]) + print(data_catalog.get_source("merit_hydro")) def test_versioned_catalogs(tmpdir): @@ -115,9 +115,9 @@ def test_versioned_catalogs(tmpdir): == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) print(merged_catalog.get_source("esa_worldcover")) - breakpoint() + # breakpoint() assert merged_catalog.get_source( - "esa_worldcover", catalog_name="legacy_esa_worldcover" + "esa_worldcover", provider="legacy_esa_worldcover" ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") @@ -129,10 +129,10 @@ def test_data_catalog(tmpdir): # global data sources from artifacts are automatically added assert len(data_catalog.sources) > 0 # test keys, getitem, - keys = data_catalog.keys - source = data_catalog[keys[0]] + keys = [key for key, _ in data_catalog.iter_sources()] + source = data_catalog.get_source(keys[0]) assert isinstance(source, 
DataAdapter) - assert keys[0] in data_catalog + assert keys[0] in data_catalog.get_source_names() # add source from dict data_dict = {keys[0]: source.to_dict()} data_catalog.from_dict(data_dict) @@ -140,7 +140,7 @@ def test_data_catalog(tmpdir): assert isinstance(data_catalog._repr_html_(), str) assert isinstance(data_catalog.to_dataframe(), pd.DataFrame) with pytest.raises(ValueError, match="Value must be DataAdapter"): - data_catalog["test"] = "string" + data_catalog.add_source("test", "string") # check that no sources are loaded if fallback_lib is None assert not DataCatalog(fallback_lib=None).sources # test artifact keys (NOTE: legacy code!) @@ -149,7 +149,8 @@ def test_data_catalog(tmpdir): data_catalog.from_artifacts("deltares_data") assert len(data_catalog._sources) > 0 with pytest.raises(IOError, match="URL b'404: Not Found'"): - data_catalog = DataCatalog(deltares_data="unknown_version") + with pytest.deprecated_call(): + data_catalog = DataCatalog(deltares_data="unknown_version") # test hydromt version in meta data fn_yml = join(tmpdir, "test.yml") From bbd5d5ebe6e906aaaa9b989662bcd6d016f75e4d Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 10 Jul 2023 14:13:23 +0200 Subject: [PATCH 09/27] [no ci] WIP --- hydromt/__init__.py | 2 + hydromt/data_catalog.py | 103 +++++++++++++++++++++-------------- hydromt/models/model_api.py | 6 +- hydromt/models/model_grid.py | 2 +- hydromt/workflows/forcing.py | 1 + tests/conftest.py | 3 +- tests/test_data_adapter.py | 2 +- tests/test_data_catalog.py | 20 ++++--- tests/test_model.py | 15 ++--- 9 files changed, 93 insertions(+), 61 deletions(-) diff --git a/hydromt/__init__.py b/hydromt/__init__.py index aa7dfe89f..750464464 100644 --- a/hydromt/__init__.py +++ b/hydromt/__init__.py @@ -15,6 +15,8 @@ warnings.filterwarnings("ignore", category=DeprecationWarning) +import dask +dask.config.set(scheduler='single-threaded') # required for accessor style documentation from xarray import DataArray, Dataset diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 9e6271a0c..dba96e933 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -194,11 +194,14 @@ def __setitem__(self, key: str, value: DataAdapter) -> None: ) self.add_source(key, value) - def iter_sources(self) -> Tuple[str, DataAdapter]: + def iter_sources(self) -> List[Tuple[str, DataAdapter]]: + ans = [] for source_name, available_providers in self._sources.items(): # print(available_providers) for provider, adapter in available_providers.items(): - yield (source_name, adapter) + ans.append((source_name, adapter)) + + return ans def __iter__(self): """Iterate over sources.""" @@ -625,9 +628,17 @@ def export_data( # deduce variables from name if "[" in name: variables = name.split("[")[-1].split("]")[0].split(",") + breakpoint() name = name.split("[")[0] source_vars[name] = variables - sources[name] = copy.deepcopy(self.sources[name]) + + if name not in sources: + sources[name] = {} + + source = self.get_source(name) + sources[name]['last'] = copy.deepcopy(source) + sources[name][source.catalog_name] = copy.deepcopy(source) + else: sources = copy.deepcopy(self.sources) @@ -635,54 +646,64 @@ def export_data( fn = join(data_root, "data_catalog.yml") if isfile(fn) and append: self.logger.info(f"Appending existing data catalog {fn}") + breakpoint() sources_out = DataCatalog(fn).sources else: sources_out = {} # export data and update sources - for key, source in sources.items(): - try: - # read slice of source and write to file - self.logger.debug(f"Exporting {key}.") 
- if not unit_conversion: - unit_mult = source.unit_mult - unit_add = source.unit_add - source.unit_mult = {} - source.unit_add = {} - fn_out, driver = source.to_file( - data_root=data_root, - data_name=key, - variables=source_vars.get(key, None), - bbox=bbox, - time_tuple=time_tuple, - logger=self.logger, - ) - if fn_out is None: - self.logger.warning(f"{key} file contains no data within domain") + for key, available_providers in sources.items(): + for provider, adapter in available_providers.items(): + if provider == "last": continue - # update path & driver and remove kwargs and rename in output sources - if unit_conversion: - source.unit_mult = {} - source.unit_add = {} - else: - source.unit_mult = unit_mult - source.unit_add = unit_add - source.path = fn_out - source.driver = driver - source.filesystem = "local" - source.driver_kwargs = {} - source.rename = {} - if key in sources_out: - self.logger.warning( - f"{key} already exists in data catalog and is overwritten." + try: + # read slice of source and write to file + self.logger.debug(f"Exporting {key}.") + if not unit_conversion: + unit_mult = source.unit_mult + unit_add = source.unit_add + source.unit_mult = {} + source.unit_add = {} + breakpoint() + fn_out, driver = source.to_file( + data_root=data_root, + data_name=key, + variables=source_vars.get(key, None), + bbox=bbox, + time_tuple=time_tuple, + logger=self.logger, ) - sources_out[key] = source - except FileNotFoundError: - self.logger.warning(f"{key} file not found at {source.path}") + if fn_out is None: + self.logger.warning(f"{key} file contains no data within domain") + continue + # update path & driver and remove kwargs and rename in output sources + if unit_conversion: + source.unit_mult = {} + source.unit_add = {} + else: + source.unit_mult = unit_mult + source.unit_add = unit_add + source.path = fn_out + source.driver = driver + source.filesystem = "local" + source.driver_kwargs = {} + source.rename = {} + if key in sources_out: + self.logger.warning( + f"{key} already exists in data catalog and is overwritten." + ) + if not isinstance(source, DataAdapter): + breakpoint() + sources_out[key] = source + except FileNotFoundError: + self.logger.warning(f"{key} file not found at {source.path}") # write data catalog to yml data_catalog_out = DataCatalog() - data_catalog_out._sources = sources_out + for key, adapter in sources_out.items(): + # if not isinstance(adapter, DataAdapter): + # breakpoint() + data_catalog_out.add_source(key,adapter) data_catalog_out.to_yml(fn, root="auto", meta=meta) def get_rasterdataset( diff --git a/hydromt/models/model_api.py b/hydromt/models/model_api.py index 9fee8aed4..acefea53f 100644 --- a/hydromt/models/model_api.py +++ b/hydromt/models/model_api.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +## will be deprecated -*- coding: utf-8 -*- """General and basic API for models in HydroMT.""" import glob @@ -346,10 +346,12 @@ def setup_region( kind, region = workflows.parse_region(region, logger=self.logger) # NOTE: kind=outlet is deprecated! if kind in ["basin", "subbasin", "interbasin", "outlet"]: + if kind == "outlet": + warning.warn("Using outlet as kind in setup_region is deprecated", DeprecationWarning) # retrieve global hydrography data (lazy!) 
ds_org = self.data_catalog.get_rasterdataset(hydrography_fn) if "bounds" not in region: - region.update(basin_index=self.data_catalog[basin_index_fn]) + region.update(basin_index=self.data_catalog.get_source(basin_index_fn)) # get basin geometry geom, xy = workflows.get_basin_geometry( ds=ds_org, diff --git a/hydromt/models/model_grid.py b/hydromt/models/model_grid.py index 2ff587d45..ff467c18b 100644 --- a/hydromt/models/model_grid.py +++ b/hydromt/models/model_grid.py @@ -577,7 +577,7 @@ def setup_grid( # retrieve global hydrography data (lazy!) ds_hyd = self.data_catalog.get_rasterdataset(hydrography_fn) if "bounds" not in region: - region.update(basin_index=self.data_catalog[basin_index_fn]) + region.update(basin_index=self.data_catalog.get_source(basin_index_fn)) # get basin geometry geom, xy = workflows.get_basin_geometry( ds=ds_hyd, diff --git a/hydromt/workflows/forcing.py b/hydromt/workflows/forcing.py index c32b6cc29..55e2872e6 100644 --- a/hydromt/workflows/forcing.py +++ b/hydromt/workflows/forcing.py @@ -80,6 +80,7 @@ def precip( p_out.attrs.update(unit="mm") if freq is not None: resample_kwargs.update(upsampling="bfill", downsampling="sum", logger=logger) + p_out = resample_time(p_out, freq, conserve_mass=True, **resample_kwargs) return p_out diff --git a/tests/conftest.py b/tests/conftest.py index 9a0ec3c8c..0e77f636f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -191,7 +191,8 @@ def model(demda, world, obsda): mod = Model() mod.setup_region({"geom": demda.raster.box}) mod.setup_config(**{"header": {"setting": "value"}}) - mod.set_staticmaps(demda, "elevtn") # will be deprecated + with pytest.deprecated_call(): + mod.set_staticmaps(demda, "elevtn") mod.set_geoms(world, "world") mod.set_maps(demda, "elevtn") mod.set_forcing(obsda, "waterlevel") diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index 1cc3fe04b..f40375c2e 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -162,7 +162,7 @@ def test_rasterdataset_driver_kwargs(artifact_data: DataCatalog, tmpdir): datacatalog.from_dict(data_dict2) era5_nc = datacatalog.get_rasterdataset("era5_nc") assert era5_zarr.equals(era5_nc) - datacatalog["era5_zarr"].to_file(tmpdir, "era5_zarr", driver="zarr") + datacatalog.get_source("era5_zarr").to_file(tmpdir, "era5_zarr", driver="zarr") def test_rasterdataset_unit_attrs(artifact_data: DataCatalog): diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 17e9ab76a..9eb92481e 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -104,6 +104,9 @@ def test_versioned_catalogs(tmpdir): assert legacy_data_catalog.get_source("esa_worldcover").path.endswith( "landuse/esa_worldcover/esa-worldcover.vrt" ) + with pytest.deprecated_call(): + _ = legacy_data_catalog['esa_worldcover'] + aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) assert ( aws_data_catalog.get_source("esa_worldcover").path @@ -121,7 +124,6 @@ def test_versioned_catalogs(tmpdir): ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") -@pytest.mark.filterwarnings('ignore:"from_artifacts" is deprecated:DeprecationWarning') def test_data_catalog(tmpdir): data_catalog = DataCatalog(data_libs=None) # initialized with empty dict @@ -144,9 +146,11 @@ def test_data_catalog(tmpdir): # check that no sources are loaded if fallback_lib is None assert not DataCatalog(fallback_lib=None).sources # test artifact keys (NOTE: legacy code!) 
- data_catalog = DataCatalog(deltares_data=False) + with pytest.deprecated_call(): + data_catalog = DataCatalog(deltares_data=False) assert len(data_catalog._sources) == 0 - data_catalog.from_artifacts("deltares_data") + with pytest.deprecated_call(): + data_catalog.from_artifacts("deltares_data") assert len(data_catalog._sources) > 0 with pytest.raises(IOError, match="URL b'404: Not Found'"): with pytest.deprecated_call(): @@ -168,8 +172,8 @@ def test_from_archive(tmpdir): data_catalog.predefined_catalogs["artifact_data"]["versions"].values() )[0] data_catalog.from_archive(urlpath.format(version=version_hash)) - assert len(data_catalog._sources) > 0 - source0 = data_catalog._sources[[k for k in data_catalog.sources.keys()][0]] + assert len(data_catalog.iter_sources()) > 0 + source0 = data_catalog.get_source(next(iter([source_name for source_name,_ in data_catalog.iter_sources()]))) assert ".hydromt_data" in str(source0.path) # failed to download with pytest.raises(ConnectionError, match="Data download failed"): @@ -290,7 +294,7 @@ def test_get_data(df): data_catalog = DataCatalog("artifact_data") # read artifacts # raster dataset using three different ways - da = data_catalog.get_rasterdataset(data_catalog["koppen_geiger"].path) + da = data_catalog.get_rasterdataset(data_catalog.get_source("koppen_geiger").path) assert isinstance(da, xr.DataArray) da = data_catalog.get_rasterdataset("koppen_geiger") assert isinstance(da, xr.DataArray) @@ -300,7 +304,7 @@ def test_get_data(df): data_catalog.get_rasterdataset([]) # vector dataset using three different ways - gdf = data_catalog.get_geodataframe(data_catalog["osm_coastlines"].path) + gdf = data_catalog.get_geodataframe(data_catalog.get_source("osm_coastlines").path) assert isinstance(gdf, gpd.GeoDataFrame) gdf = data_catalog.get_geodataframe("osm_coastlines") assert isinstance(gdf, gpd.GeoDataFrame) @@ -310,7 +314,7 @@ def test_get_data(df): data_catalog.get_geodataframe([]) # geodataset using three different ways - da = data_catalog.get_geodataset(data_catalog["gtsmv3_eu_era5"].path) + da = data_catalog.get_geodataset(data_catalog.get_source("gtsmv3_eu_era5").path) assert isinstance(da, xr.DataArray) da = data_catalog.get_geodataset("gtsmv3_eu_era5") assert isinstance(da, xr.DataArray) diff --git a/tests/test_model.py b/tests/test_model.py index ac41f2806..5923fbfdf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -125,13 +125,11 @@ def test_write_data_catalog(tmpdir): assert list(DataCatalog(data_lib_fn).sources.keys()) == sources[:2] -@pytest.mark.filterwarnings( - 'ignore:Defining "region" based on staticmaps:DeprecationWarning' -) def test_model(model, tmpdir): # Staticmaps -> moved from _test_model_api as it is deprecated model._API.update({"staticmaps": xr.Dataset}) - non_compliant = model._test_model_api() + with pytest.deprecated_call(): + non_compliant = model._test_model_api() assert len(non_compliant) == 0, non_compliant # write model model.set_root(str(tmpdir), mode="w") @@ -140,16 +138,19 @@ def test_model(model, tmpdir): model.read() # read model model1 = Model(str(tmpdir), mode="r") - model1.read() + with pytest.deprecated_call(): + model1.read() with pytest.raises(IOError, match="Model opened in read-only mode"): model1.write() # check if equal model._results = {} # reset results for comparison - equal, errors = model._test_equal(model1) + with pytest.deprecated_call(): + equal, errors = model._test_equal(model1) assert equal, errors # read region from staticmaps model._geoms.pop("region") - assert 
np.all(model.region.total_bounds == model.staticmaps.raster.bounds) + with pytest.deprecated_call(): + assert np.all(model.region.total_bounds == model.staticmaps.raster.bounds) @pytest.mark.filterwarnings("ignore:The setup_basemaps") From 438f3750b896fd2cf12a2695ea4246213567a80a Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 10 Jul 2023 15:10:17 +0200 Subject: [PATCH 10/27] fix tests, catalog still need diffing --- hydromt/__init__.py | 3 +- hydromt/data_catalog.py | 66 ++++++++++++++++++---------- hydromt/models/model_api.py | 5 ++- tests/conftest.py | 2 +- tests/data/merged_esa_worldcover.yml | 28 ++++++------ tests/test_data_catalog.py | 22 ++++++---- 6 files changed, 77 insertions(+), 49 deletions(-) diff --git a/hydromt/__init__.py b/hydromt/__init__.py index 750464464..ae36d07e2 100644 --- a/hydromt/__init__.py +++ b/hydromt/__init__.py @@ -16,7 +16,8 @@ warnings.filterwarnings("ignore", category=DeprecationWarning) import dask -dask.config.set(scheduler='single-threaded') + +dask.config.set(scheduler="single-threaded") # required for accessor style documentation from xarray import DataArray, Dataset diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index dba96e933..0f497e30d 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -138,6 +138,7 @@ def keys(self) -> List[str]: return list(self._sources.keys()) def get_source_names(self) -> List[str]: + """Return a list of all available data source names.""" return list(self._sources.keys()) @property @@ -150,15 +151,19 @@ def predefined_catalogs(self) -> Dict: def get_source(self, key: str, provider=None) -> DataAdapter: """Get the source.""" if key not in self._sources: + available_sources = sorted(list(self._sources.keys())) raise KeyError( - f"Requested unknown data source: {key} available sources are: {sorted(list(self._sources.keys()))}" + f"Requested unknown data source: {key} " + f"available sources are: {available_sources}" ) available_providers = self._sources[key] if provider is not None: if provider not in available_providers: + providers = sorted(list(available_providers.keys())) raise KeyError( - f"Requested unknown proveder {provider} for data_source {key} available providers are {sorted(list(available_providers.keys()))}" + f"Requested unknown proveder {provider} for data_source {key}" + f" available providers are {providers}" ) else: return available_providers[provider] @@ -166,6 +171,7 @@ def get_source(self, key: str, provider=None) -> DataAdapter: return available_providers["last"] def add_source(self, key: str, adapter: DataAdapter) -> None: + """Add a new data source to the data catalog.""" if not isinstance(adapter, DataAdapter): raise ValueError("Value must be DataAdapter") @@ -181,7 +187,8 @@ def add_source(self, key: str, adapter: DataAdapter) -> None: def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" warnings.warn( - 'Using iterating over the DataCatalog directly is deprecated. Please use cat.get_source("name")', + 'Using iterating over the DataCatalog directly is deprecated."\ + " Please use cat.get_source("name")', DeprecationWarning, ) return self.get_source(key) @@ -189,24 +196,29 @@ def __getitem__(self, key: str) -> DataAdapter: def __setitem__(self, key: str, value: DataAdapter) -> None: """Set or update adaptors.""" warnings.warn( - "Using DataCatalog as a dictionary directly is deprecated. Please use cat.add_source(adapter)", + "Using DataCatalog as a dictionary directly is deprecated." 
+ " Please use cat.add_source(adapter)", DeprecationWarning, ) self.add_source(key, value) def iter_sources(self) -> List[Tuple[str, DataAdapter]]: + """Return a flat list of all available data sources with no duplicates.""" ans = [] for source_name, available_providers in self._sources.items(): # print(available_providers) for provider, adapter in available_providers.items(): - ans.append((source_name, adapter)) + if provider == "last": + continue + ans.append((source_name, adapter)) return ans def __iter__(self): """Iterate over sources.""" warnings.warn( - "Using iterating over the DataCatalog directly is deprecated. Please use cat.iter_sources()", + "Using iterating over the DataCatalog directly is deprecated." + " Please use cat.iter_sources()", DeprecationWarning, ) return self._sources.__iter__() @@ -214,7 +226,8 @@ def __iter__(self): def __len__(self): """Return number of sources.""" warnings.warn( - "Using len on DataCatalog directly is deprecated. Please use len(cat.get_source())", + "Using len on DataCatalog directly is deprecated." + " Please use len(cat.get_source())", DeprecationWarning, ) return self._sources.__len__() @@ -628,17 +641,16 @@ def export_data( # deduce variables from name if "[" in name: variables = name.split("[")[-1].split("]")[0].split(",") - breakpoint() name = name.split("[")[0] source_vars[name] = variables - if name not in sources: + if name not in sources: sources[name] = {} source = self.get_source(name) - sources[name]['last'] = copy.deepcopy(source) + sources[name]["last"] = copy.deepcopy(source) sources[name][source.catalog_name] = copy.deepcopy(source) - + else: sources = copy.deepcopy(self.sources) @@ -646,14 +658,14 @@ def export_data( fn = join(data_root, "data_catalog.yml") if isfile(fn) and append: self.logger.info(f"Appending existing data catalog {fn}") - breakpoint() + # breakpoint() sources_out = DataCatalog(fn).sources else: sources_out = {} # export data and update sources for key, available_providers in sources.items(): - for provider, adapter in available_providers.items(): + for provider, source in available_providers.items(): if provider == "last": continue try: @@ -664,7 +676,7 @@ def export_data( unit_add = source.unit_add source.unit_mult = {} source.unit_add = {} - breakpoint() + # breakpoint() fn_out, driver = source.to_file( data_root=data_root, data_name=key, @@ -674,9 +686,12 @@ def export_data( logger=self.logger, ) if fn_out is None: - self.logger.warning(f"{key} file contains no data within domain") + self.logger.warning( + f"{key} file contains no data within domain" + ) continue - # update path & driver and remove kwargs and rename in output sources + # update path & driver and remove kwargs + # and rename in output sources if unit_conversion: source.unit_mult = {} source.unit_add = {} @@ -689,21 +704,26 @@ def export_data( source.driver_kwargs = {} source.rename = {} if key in sources_out: + # breakpoint() self.logger.warning( f"{key} already exists in data catalog and is overwritten." 
) - if not isinstance(source, DataAdapter): - breakpoint() - sources_out[key] = source + if key not in sources_out: + sources_out[key] = {} + + sources_out[key][source.catalog_name] = source + sources_out[key]["last"] = source except FileNotFoundError: self.logger.warning(f"{key} file not found at {source.path}") # write data catalog to yml data_catalog_out = DataCatalog() - for key, adapter in sources_out.items(): - # if not isinstance(adapter, DataAdapter): - # breakpoint() - data_catalog_out.add_source(key,adapter) + for key, available_profviders in sources_out.items(): + for provider, adapter in available_providers.items(): + if provider == "last": + continue + data_catalog_out.add_source(key, adapter) + data_catalog_out.to_yml(fn, root="auto", meta=meta) def get_rasterdataset( diff --git a/hydromt/models/model_api.py b/hydromt/models/model_api.py index acefea53f..9ba8c80f6 100644 --- a/hydromt/models/model_api.py +++ b/hydromt/models/model_api.py @@ -347,7 +347,10 @@ def setup_region( # NOTE: kind=outlet is deprecated! if kind in ["basin", "subbasin", "interbasin", "outlet"]: if kind == "outlet": - warning.warn("Using outlet as kind in setup_region is deprecated", DeprecationWarning) + warnings.warn( + "Using outlet as kind in setup_region is deprecated", + DeprecationWarning, + ) # retrieve global hydrography data (lazy!) ds_org = self.data_catalog.get_rasterdataset(hydrography_fn) if "bounds" not in region: diff --git a/tests/conftest.py b/tests/conftest.py index 0e77f636f..0a6d1b9fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,7 +192,7 @@ def model(demda, world, obsda): mod.setup_region({"geom": demda.raster.box}) mod.setup_config(**{"header": {"setting": "value"}}) with pytest.deprecated_call(): - mod.set_staticmaps(demda, "elevtn") + mod.set_staticmaps(demda, "elevtn") mod.set_geoms(world, "world") mod.set_maps(demda, "elevtn") mod.set_forcing(obsda, "waterlevel") diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml index e0d72d9cf..694cee677 100644 --- a/tests/data/merged_esa_worldcover.yml +++ b/tests/data/merged_esa_worldcover.yml @@ -8,17 +8,17 @@ esa_worldcover: source_url: https://doi.org/10.5281/zenodo.5571936 source_version: v100 versions: - - catalog_name: aws_data - path: path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt - rename: - ESA_WorldCover_10m_2020_v100_Map_AWS: landuse - filesystem: s3 - kwargs: - storage_options: - anon: true - - catalog_name: deltares_data - path: landuse/esa_worldcover/esa-worldcover.vrt - kwargs: - chunks: - x: 36000 - y: 36000 + - aws_data: + path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt + rename: + ESA_WorldCover_10m_2020_v100_Map_AWS: landuse + filesystem: s3 + kwargs: + storage_options: + anon: true + - deltares_data: + path: landuse/esa_worldcover/esa-worldcover.vrt + kwargs: + chunks: + x: 36000 + y: 36000 diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 9eb92481e..02385ad11 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -105,7 +105,7 @@ def test_versioned_catalogs(tmpdir): "landuse/esa_worldcover/esa-worldcover.vrt" ) with pytest.deprecated_call(): - _ = legacy_data_catalog['esa_worldcover'] + _ = legacy_data_catalog["esa_worldcover"] aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) assert ( @@ -152,9 +152,10 @@ def test_data_catalog(tmpdir): with pytest.deprecated_call(): data_catalog.from_artifacts("deltares_data") assert len(data_catalog._sources) > 0 - with 
pytest.raises(IOError, match="URL b'404: Not Found'"): - with pytest.deprecated_call(): - data_catalog = DataCatalog(deltares_data="unknown_version") + with pytest.raises( + IOError, match="URL b'404: Not Found'" + ), pytest.deprecated_call(): + data_catalog = DataCatalog(deltares_data="unknown_version") # test hydromt version in meta data fn_yml = join(tmpdir, "test.yml") @@ -173,7 +174,9 @@ def test_from_archive(tmpdir): )[0] data_catalog.from_archive(urlpath.format(version=version_hash)) assert len(data_catalog.iter_sources()) > 0 - source0 = data_catalog.get_source(next(iter([source_name for source_name,_ in data_catalog.iter_sources()]))) + source0 = data_catalog.get_source( + next(iter([source_name for source_name, _ in data_catalog.iter_sources()])) + ) assert ".hydromt_data" in str(source0.path) # failed to download with pytest.raises(ConnectionError, match="Data download failed"): @@ -235,7 +238,7 @@ def test_export_global_datasets(tmpdir): assert yml_list[2].strip().startswith("root:") # check if data is parsed correctly data_catalog1 = DataCatalog(data_lib_fn) - for key, source in data_catalog1.sources.items(): + for key, source in data_catalog1.iter_sources(): source_type = type(source).__name__ dtypes = DTYPES[source_type] obj = source.get_data() @@ -278,13 +281,14 @@ def test_export_dataframe(tmpdir, df, df_time): time_tuple=("2010-02-01", "2010-02-14"), bbox=[11.70, 45.35, 12.95, 46.70], ) + # breakpoint() data_catalog1 = DataCatalog(str(tmpdir.join("data_catalog.yml"))) - assert len(data_catalog1) == 1 + assert len(data_catalog1.iter_sources()) == 1 data_catalog.export_data(str(tmpdir)) data_catalog1 = DataCatalog(str(tmpdir.join("data_catalog.yml"))) - assert len(data_catalog1) == 2 - for key, source in data_catalog1.sources.items(): + assert len(data_catalog1.iter_sources()) == 2 + for key, source in data_catalog1.iter_sources(): dtypes = pd.DataFrame obj = source.get_data() assert isinstance(obj, dtypes), key From c97a7ba1bf4c9201ae66083202f980b57e65cad6 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 10 Jul 2023 20:05:41 +0200 Subject: [PATCH 11/27] get started on data catalog partitioning --- hydromt/__init__.py | 1 + hydromt/data_catalog.py | 20 ++++--- hydromt/utils.py | 36 +++++++++++++ tests/test_data_catalog.py | 20 +++++-- tests/test_utils.py | 104 +++++++++++++++++++++++++++++++++++++ 5 files changed, 170 insertions(+), 11 deletions(-) create mode 100644 hydromt/utils.py create mode 100644 tests/test_utils.py diff --git a/hydromt/__init__.py b/hydromt/__init__.py index ae36d07e2..d2a0aad27 100644 --- a/hydromt/__init__.py +++ b/hydromt/__init__.py @@ -29,3 +29,4 @@ # high-level methods from .models import * +from .utils import * diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 0f497e30d..c94317740 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -32,6 +32,7 @@ RasterDatasetAdapter, ) from .data_adapter.caching import HYDROMT_DATADIR, _copyfile, _uri_validator +from .utils import partition_dictionaries logger = logging.getLogger(__name__) @@ -178,11 +179,17 @@ def add_source(self, key: str, adapter: DataAdapter) -> None: if key not in self._sources: self._sources[key] = dict() - # TODO catalgos here need to be diffed to construct common base - - self._sources[key]["last"] = adapter - self._sources[key][adapter.catalog_name] = adapter - return self._sources[key] + existing = self._sources[key] + base, diff_existing, diff_new = partition_dictionaries( + existing, adapter.to_dict() + ) + if base == {}: + 
self._sources[key]["last"] = adapter + self._sources[key][adapter.catalog_name] = adapter + else: + base_adapter = DataAdapter.from_dict(base) + self._sources[key]["base"] = base_adapter + self._source[key]["versions"] = [diff_existing, diff_new] def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -658,7 +665,6 @@ def export_data( fn = join(data_root, "data_catalog.yml") if isfile(fn) and append: self.logger.info(f"Appending existing data catalog {fn}") - # breakpoint() sources_out = DataCatalog(fn).sources else: sources_out = {} @@ -676,7 +682,6 @@ def export_data( unit_add = source.unit_add source.unit_mult = {} source.unit_add = {} - # breakpoint() fn_out, driver = source.to_file( data_root=data_root, data_name=key, @@ -704,7 +709,6 @@ def export_data( source.driver_kwargs = {} source.rename = {} if key in sources_out: - # breakpoint() self.logger.warning( f"{key} already exists in data catalog and is overwritten." ) diff --git a/hydromt/utils.py b/hydromt/utils.py new file mode 100644 index 000000000..284f79f04 --- /dev/null +++ b/hydromt/utils.py @@ -0,0 +1,36 @@ +"""Utility functions for hydromt that have no other home.""" + + +def partition_dictionaries(left, right): + """Calculate a partitioning of the two dictionaries. + + given dictionaries A and B this function will the follwing partition: + (A ∩ B, A - B, B - A) + """ + common = {} + left_less_right = {} + right_less_left = {} + key_union = set(left.keys()) | set(right.keys()) + + for key in key_union: + value_left = left.get(key, None) + value_right = right.get(key, None) + if isinstance(value_left, dict) and isinstance(value_right, dict): + ( + common_children, + unique_left_children, + unique_right_children, + ) = partition_dictionaries(value_left, value_right) + common[key] = common_children + if unique_left_children != unique_right_children: + left_less_right[key] = unique_left_children + right_less_left[key] = unique_right_children + elif value_left == value_right: + common[key] = value_left + else: + if value_left is not None: + left_less_right[key] = value_left + if value_right is not None: + right_less_left[key] = value_right + + return common, left_less_right, right_less_left diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 02385ad11..63bdb49e5 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -8,6 +8,7 @@ import pandas as pd import pytest import xarray as xr +import yaml from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter from hydromt.data_catalog import DataCatalog, _parse_data_dict @@ -117,11 +118,25 @@ def test_versioned_catalogs(tmpdir): merged_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) - print(merged_catalog.get_source("esa_worldcover")) - # breakpoint() assert merged_catalog.get_source( "esa_worldcover", provider="legacy_esa_worldcover" ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") + assert ( + merged_catalog.get_source("esa_worldcover", provider="aws_esa_worldcover").path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + aws_data_catalog.from_yml(legacy_yml_fn) + # assert merged_catalog.to_dict() == aws_data_catalog.to_dict() + with open(join(DATADIR, "merged_esa_worldcover.yml"), "r") as f: + expected_merged_catalog_dict = yaml.load(f, Loader=yaml.Loader) + + import json + + print( + "expected: ", json.dumps(expected_merged_catalog_dict, sort_keys=True, indent=2) + ) + print("computed: ", 
json.dumps(merged_catalog.to_dict(), sort_keys=True, indent=2)) + assert expected_merged_catalog_dict == merged_catalog.to_dict() def test_data_catalog(tmpdir): @@ -281,7 +296,6 @@ def test_export_dataframe(tmpdir, df, df_time): time_tuple=("2010-02-01", "2010-02-14"), bbox=[11.70, 45.35, 12.95, 46.70], ) - # breakpoint() data_catalog1 = DataCatalog(str(tmpdir.join("data_catalog.yml"))) assert len(data_catalog1.iter_sources()) == 1 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..16ae7443d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,104 @@ +"""Testing for the internal hydromt utility functions.""" +from hydromt.utils import partition_dictionaries + + +def test_flat_dict_partition(): + left = {"a": 1, "b": 2, "pi": 3.14} + right = {"a": 1, "b": 2, "e": 2.71} + common, left_less_right, right_less_left = partition_dictionaries(left, right) + assert common == {"a": 1, "b": 2} + assert left_less_right == {"pi": 3.14} + assert right_less_left == {"e": 2.71} + + +def test_nested_disjoint_leaves(): + left = {"a": 1, "b": 2, "maths": {"constants": {"pi": 3.14}}} + right = {"a": 1, "b": 2, "maths": {"constants": {"e": 2.71}}} + common, left_less_right, right_less_left = partition_dictionaries(left, right) + assert common == {"a": 1, "b": 2, "maths": {"constants": {}}} + assert left_less_right == {"maths": {"constants": {"pi": 3.14}}} + assert right_less_left == {"maths": {"constants": {"e": 2.71}}} + + +def test_nested_common_siblings(): + left = { + "a": 1, + "b": 2, + "maths": { + "constants": {"pi": 3.14}, + "integration": {"numeric": None, "analytic": None}, + }, + } + right = { + "a": 1, + "b": 2, + "maths": { + "constants": {"e": 2.71}, + "integration": {"numeric": None, "analytic": None}, + }, + } + common, left_less_right, right_less_left = partition_dictionaries(left, right) + assert common == { + "a": 1, + "b": 2, + "maths": {"constants": {}, "integration": {"numeric": None, "analytic": None}}, + } + assert left_less_right == {"maths": {"constants": {"pi": 3.14}}} + assert right_less_left == {"maths": {"constants": {"e": 2.71}}} + + +def test_nested_key_conflict(): + left = { + "a": 1, + "b": 2, + "c": 3, + "d": 4, + "e": 5, + "maths": {"constants": {"pi": 3.14}}, + } + right = {"a": 1, "b": 2, "c": 3, "d": 4, "maths": {"constants": {"e": 2.71}}} + + common, left_less_right, right_less_left = partition_dictionaries(left, right) + + assert common == {"a": 1, "b": 2, "c": 3, "d": 4, "maths": {"constants": {}}} + assert left_less_right == { + "e": 5, + "maths": {"constants": {"pi": 3.14}}, + } + assert right_less_left == { + "maths": {"constants": {"e": 2.71}}, + } + + +def test_common_ancestory_distinct_children(): + left = { + "a": {"i": -1, "ii": -2, "iii": -3}, + "b": 2, + "c": 3, + "d": 4, + "e": 5, + "maths": {"constants": {"pi": 3.14}}, + } + right = { + "a": {"i": -1, "ii": -2, "iii": -3}, + "b": 2, + "c": 3, + "d": 4, + "maths": {"constants": {"e": 2.71}}, + } + + common, left_less_right, right_less_left = partition_dictionaries(left, right) + assert common == { + "a": {"i": -1, "ii": -2, "iii": -3}, + "b": 2, + "c": 3, + "d": 4, + "maths": {"constants": {}}, + } + assert left_less_right == { + "e": 5, + "maths": {"constants": {"pi": 3.14}}, + } + assert right_less_left == { + "maths": {"constants": {"e": 2.71}}, + } From ee9b803f3756e1c5956b3f999dccba815e45fbb6 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Tue, 11 Jul 2023 19:46:34 +0200 Subject: [PATCH 12/27] fix reading merged catalogs, write impl missing --- 
hydromt/data_adapter/data_adapter.py | 2 + hydromt/data_adapter/dataframe.py | 2 + hydromt/data_adapter/geodataframe.py | 2 + hydromt/data_adapter/geodataset.py | 2 + hydromt/data_adapter/rasterdataset.py | 2 + hydromt/data_catalog.py | 78 ++++++++++++++++++--------- tests/data/merged_esa_worldcover.yml | 14 ++--- tests/test_data_catalog.py | 49 ++++++++++------- 8 files changed, 100 insertions(+), 51 deletions(-) diff --git a/hydromt/data_adapter/data_adapter.py b/hydromt/data_adapter/data_adapter.py index 10cf09d16..f48701f5b 100644 --- a/hydromt/data_adapter/data_adapter.py +++ b/hydromt/data_adapter/data_adapter.py @@ -125,6 +125,7 @@ def __init__( driver_kwargs={}, name="", # optional for now catalog_name="", # optional for now + version_name=None, ): """General Interface to data source for HydroMT. @@ -170,6 +171,7 @@ def __init__( """ self.name = name self.catalog_name = catalog_name + self.version_name = version_name # general arguments self.path = path # driver and driver keyword-arguments diff --git a/hydromt/data_adapter/dataframe.py b/hydromt/data_adapter/dataframe.py index d100ebe54..ba8114d0c 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -38,6 +38,7 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now + version_name=None, **kwargs, ): """Initiate data adapter for 2D tabular data. @@ -106,6 +107,7 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, + version_name=version_name, ) def to_file( diff --git a/hydromt/data_adapter/geodataframe.py b/hydromt/data_adapter/geodataframe.py index ed911ae10..6d805a367 100644 --- a/hydromt/data_adapter/geodataframe.py +++ b/hydromt/data_adapter/geodataframe.py @@ -46,6 +46,7 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now + version_name=None, **kwargs, ): """Initiate data adapter for geospatial vector data. @@ -116,6 +117,7 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, + version_name=version_name, ) self.crs = crs diff --git a/hydromt/data_adapter/geodataset.py b/hydromt/data_adapter/geodataset.py index 057609588..a59344e95 100644 --- a/hydromt/data_adapter/geodataset.py +++ b/hydromt/data_adapter/geodataset.py @@ -47,6 +47,7 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now + version_name=None, **kwargs, ): """Initiate data adapter for geospatial timeseries data. @@ -123,6 +124,7 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, + version_name=version_name, ) self.crs = crs diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index 72f9fd99d..733d43392 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -50,6 +50,7 @@ def __init__( zoom_levels: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now + version_name=None, **kwargs, ): """Initiate data adapter for geospatial raster data. 
@@ -127,6 +128,7 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, + version_name=version_name, ) self.crs = crs self.zoom_levels = zoom_levels diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index c94317740..74daa4410 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -13,7 +13,7 @@ import warnings from os.path import abspath, basename, exists, isdir, isfile, join from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import geopandas as gpd import numpy as np @@ -32,7 +32,6 @@ RasterDatasetAdapter, ) from .data_adapter.caching import HYDROMT_DATADIR, _copyfile, _uri_validator -from .utils import partition_dictionaries logger = logging.getLogger(__name__) @@ -154,7 +153,7 @@ def get_source(self, key: str, provider=None) -> DataAdapter: if key not in self._sources: available_sources = sorted(list(self._sources.keys())) raise KeyError( - f"Requested unknown data source: {key} " + f"Requested unknown data source: '{key}' " f"available sources are: {available_sources}" ) @@ -163,7 +162,7 @@ def get_source(self, key: str, provider=None) -> DataAdapter: if provider not in available_providers: providers = sorted(list(available_providers.keys())) raise KeyError( - f"Requested unknown proveder {provider} for data_source {key}" + f"Requested unknown proveder '{provider}' for data_source '{key}'" f" available providers are {providers}" ) else: @@ -177,19 +176,13 @@ def add_source(self, key: str, adapter: DataAdapter) -> None: raise ValueError("Value must be DataAdapter") if key not in self._sources: - self._sources[key] = dict() + self._sources[key] = {} - existing = self._sources[key] - base, diff_existing, diff_new = partition_dictionaries( - existing, adapter.to_dict() - ) - if base == {}: - self._sources[key]["last"] = adapter - self._sources[key][adapter.catalog_name] = adapter + self._sources[key]["last"] = adapter + if hasattr(adapter, "version_name") and adapter.version_name is not None: + self._sources[key][adapter.version_name] = adapter else: - base_adapter = DataAdapter.from_dict(base) - self._sources[key]["base"] = base_adapter - self._source[key]["versions"] = [diff_existing, diff_new] + self._sources[key][adapter.catalog_name] = adapter def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -253,8 +246,7 @@ def update(self, **kwargs) -> None: def update_sources(self, **kwargs) -> None: """Add data sources to library or update them.""" - for k, v in kwargs.items(): - self.add_source(k, v) + self.update(**kwargs) def set_predefined_catalogs(self, urlpath: Union[Path, str] = None) -> Dict: """Initialise the predefined catalogs.""" @@ -490,15 +482,14 @@ def from_dict( } """ - data_dict = _parse_data_dict( - data_dict, - catalog_name=catalog_name, - root=root, - category=category, - ) - self.update(**data_dict) - if mark_used: - self._used_data.extend(list(data_dict.keys())) + data_dicts = _normalise_data_dict(data_dict, catalog_name=catalog_name) + for d in data_dicts: + parsed_dict = _parse_data_dict( + d, catalog_name=catalog_name, root=root, category=category + ) + self.update(**parsed_dict) + if mark_used: + self._used_data.extend(list(parsed_dict.keys())) def to_yml( self, @@ -1118,6 +1109,19 @@ def _parse_data_dict( for opt in source: if "fn" in opt: # get absolute paths for file names source.update({opt: abs_path(root, source[opt])}) + dict_catalog_name = source.pop("catalog_name", None) + if dict_catalog_name is 
not None and dict_catalog_name != catalog_name: + raise RuntimeError( + "catalog name passed as argument and differs from one in dictionary" + ) + + dict_name = source.pop("name", None) + if dict_name is not None and dict_name != name: + raise RuntimeError( + "Source name passed as argument and differs from one in dictionary" + ) + + version_name = source.pop("version_name", None) if "placeholders" in source: # pop avoid placeholders being passed to adapter options = source.pop("placeholders") @@ -1132,6 +1136,7 @@ def _parse_data_dict( path=path_n, name=name_n, catalog_name=catalog_name, + version_name=version_name, meta=meta, attrs=attrs, driver_kwargs=driver_kwargs, @@ -1143,6 +1148,7 @@ def _parse_data_dict( path=path, name=name, catalog_name=catalog_name, + version_name=version_name, meta=meta, attrs=attrs, driver_kwargs=driver_kwargs, @@ -1175,6 +1181,26 @@ def _process_dict(d: Dict, logger=logger) -> Dict: return d +def _normalise_data_dict(data_dict, catalog_name) -> List[Dict[str, Any]]: + # first do a pass to expand possible versions + dicts = [] + for name, source in data_dict.items(): + if "versions" in source: + versions = source.pop("versions") + for version in versions: + version_name, diff = version.popitem() + source_copy = copy.deepcopy(source) + diff["version_name"] = version_name + diff["name"] = name + diff["catalog_name"] = catalog_name + source_copy.update(**diff) + dicts.append({name: source_copy}) + else: + dicts.append({name: source}) + + return dicts + + def abs_path(root: Union[Path, str], rel_path: Union[Path, str]) -> str: path = Path(str(rel_path)) if not path.is_absolute(): diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml index 694cee677..de338e455 100644 --- a/tests/data/merged_esa_worldcover.yml +++ b/tests/data/merged_esa_worldcover.yml @@ -8,7 +8,13 @@ esa_worldcover: source_url: https://doi.org/10.5281/zenodo.5571936 source_version: v100 versions: - - aws_data: + - legacy_esa_worldcover: + path: landuse/esa_worldcover/esa-worldcover.vrt + kwargs: + chunks: + x: 36000 + y: 36000 + - aws_esa_worldcover: path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt rename: ESA_WorldCover_10m_2020_v100_Map_AWS: landuse @@ -16,9 +22,3 @@ esa_worldcover: kwargs: storage_options: anon: true - - deltares_data: - path: landuse/esa_worldcover/esa-worldcover.vrt - kwargs: - chunks: - x: 36000 - y: 36000 diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 63bdb49e5..9960b10ae 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -95,48 +95,61 @@ def test_data_catalog_io(tmpdir): def test_versioned_catalogs(tmpdir): - # we want to keep a legacy version embeded in the test code since we're presumably - # going to change the actual catalog. 
- + # make sure the catalogs individually still work legacy_yml_fn = join(DATADIR, "legacy_esa_worldcover.yml") - aws_yml_fn = join(DATADIR, "aws_esa_worldcover.yml") - # merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn]) assert legacy_data_catalog.get_source("esa_worldcover").path.endswith( "landuse/esa_worldcover/esa-worldcover.vrt" ) + # make sure we raise deprecation warning here with pytest.deprecated_call(): _ = legacy_data_catalog["esa_worldcover"] + aws_yml_fn = join(DATADIR, "aws_esa_worldcover.yml") aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) assert ( aws_data_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) - merged_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) + + # make sure we can read merged catalogs + merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") + read_merged_catalog = DataCatalog(data_libs=[merged_yml_fn]) assert ( - merged_catalog.get_source("esa_worldcover").path + read_merged_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) - assert merged_catalog.get_source( + assert ( + read_merged_catalog.get_source( + "esa_worldcover", provider="aws_esa_worldcover" + ).path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + assert read_merged_catalog.get_source( "esa_worldcover", provider="legacy_esa_worldcover" ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") + + # Make sure we can queiry for the version we want + aws_and_legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) assert ( - merged_catalog.get_source("esa_worldcover", provider="aws_esa_worldcover").path + aws_and_legacy_data_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) - aws_data_catalog.from_yml(legacy_yml_fn) - # assert merged_catalog.to_dict() == aws_data_catalog.to_dict() + + assert ( + aws_and_legacy_data_catalog.get_source( + "esa_worldcover", provider="aws_esa_worldcover" + ).path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + assert aws_and_legacy_data_catalog.get_source( + "esa_worldcover", provider="legacy_esa_worldcover" + ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") + with open(join(DATADIR, "merged_esa_worldcover.yml"), "r") as f: expected_merged_catalog_dict = yaml.load(f, Loader=yaml.Loader) - import json - - print( - "expected: ", json.dumps(expected_merged_catalog_dict, sort_keys=True, indent=2) - ) - print("computed: ", json.dumps(merged_catalog.to_dict(), sort_keys=True, indent=2)) - assert expected_merged_catalog_dict == merged_catalog.to_dict() + assert aws_and_legacy_data_catalog.to_dict() == expected_merged_catalog_dict def test_data_catalog(tmpdir): From 228db747208f668804a27fc7f59d0f0c559eca32 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Tue, 11 Jul 2023 23:59:19 +0200 Subject: [PATCH 13/27] wip --- hydromt/data_catalog.py | 38 ++++++++++++++++++++++++++++++++------ hydromt/utils.py | 6 ++++++ tests/test_data_catalog.py | 6 +++++- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 74daa4410..fb27f43c9 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -23,6 +23,8 @@ import yaml from packaging.version import Version +from hydromt.utils import partition_dictionaries + from . 
import __version__ from .data_adapter import ( DataAdapter, @@ -166,6 +168,8 @@ def get_source(self, key: str, provider=None) -> DataAdapter: f" available providers are {providers}" ) else: + adapter = available_providers[provider] + adapter.version_name = provider return available_providers[provider] else: return available_providers["last"] @@ -482,7 +486,7 @@ def from_dict( } """ - data_dicts = _normalise_data_dict(data_dict, catalog_name=catalog_name) + data_dicts = _denormalise_data_dict(data_dict, catalog_name=catalog_name) for d in data_dicts: parsed_dict = _parse_data_dict( d, catalog_name=catalog_name, root=root, category=category @@ -560,12 +564,12 @@ def to_dict( root = abspath(root) meta.update(**{"root": root}) root_drive = os.path.splitdrive(root)[0] - for name, source in sorted( - self.iter_sources(), key=lambda x: x[0] - ): # alphabetical order + sorted_sources = sorted(self.iter_sources(), key=lambda x: x[0]) + for name, source in sorted_sources: # alphabetical order if source_names is not None and name not in source_names: continue source_dict = source.to_dict() + if root is not None: path = source_dict["path"] # is abspath source_drive = os.path.splitdrive(path)[0] @@ -578,7 +582,22 @@ def to_dict( ).replace("\\", "/") # remove non serializable entries to prevent errors source_dict = _process_dict(source_dict, logger=self.logger) # TODO TEST - sources_out.update({name: source_dict}) + if name in sources_out: + existing = sources_out.pop(name) + base, diff_existing, diff_new = partition_dictionaries( + source_dict, existing + ) + # TODO how to deal with driver_kwargs vs kwargs when writing? + _ = base.pop("driver_kwargs", None) + existing_version_name = diff_existing.pop("version_name") + new_version_name = diff_new.pop("version_name") + base["versions"] = [ + {new_version_name: diff_new}, + {existing_version_name: diff_existing}, + ] + sources_out[name] = base + else: + sources_out.update({name: source_dict}) if meta: sources_out = {"meta": meta, **sources_out} return sources_out @@ -1181,7 +1200,7 @@ def _process_dict(d: Dict, logger=logger) -> Dict: return d -def _normalise_data_dict(data_dict, catalog_name) -> List[Dict[str, Any]]: +def _denormalise_data_dict(data_dict, catalog_name) -> List[Dict[str, Any]]: # first do a pass to expand possible versions dicts = [] for name, source in data_dict.items(): @@ -1201,6 +1220,13 @@ def _normalise_data_dict(data_dict, catalog_name) -> List[Dict[str, Any]]: return dicts +def _normalise_data_dict(data_dict) -> List[Dict[str, Any]]: + # first do a pass to expand possible versions + dicts = [] + + return dicts + + def abs_path(root: Union[Path, str], rel_path: Union[Path, str]) -> str: path = Path(str(rel_path)) if not path.is_absolute(): diff --git a/hydromt/utils.py b/hydromt/utils.py index 284f79f04..bb4d6d0db 100644 --- a/hydromt/utils.py +++ b/hydromt/utils.py @@ -34,3 +34,9 @@ def partition_dictionaries(left, right): right_less_left[key] = value_right return common, left_less_right, right_less_left + + +def _dict_pprint(d): + import json + + return json.dumps(d, indent=2) diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 9960b10ae..4a1bc5dd3 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -12,6 +12,7 @@ from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter from hydromt.data_catalog import DataCatalog, _parse_data_dict +from hydromt.utils import _dict_pprint CATALOGDIR = join(dirname(abspath(__file__)), "..", "data", "catalogs") DATADIR = 
join(dirname(abspath(__file__)), "data") @@ -149,7 +150,10 @@ def test_versioned_catalogs(tmpdir): with open(join(DATADIR, "merged_esa_worldcover.yml"), "r") as f: expected_merged_catalog_dict = yaml.load(f, Loader=yaml.Loader) - assert aws_and_legacy_data_catalog.to_dict() == expected_merged_catalog_dict + catalog_dict = aws_and_legacy_data_catalog.to_dict() + print("expected: ", _dict_pprint(expected_merged_catalog_dict)) + print("computed: ", _dict_pprint(catalog_dict)) + assert catalog_dict == expected_merged_catalog_dict def test_data_catalog(tmpdir): From ada662142cb094b2a365d36ecc44d84852bd48d8 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Wed, 12 Jul 2023 00:39:09 +0200 Subject: [PATCH 14/27] [no ci] WIP --- tests/data/aws_esa_worldcover.yml | 2 +- tests/data/legacy_esa_worldcover.yml | 2 +- tests/data/merged_esa_worldcover.yml | 5 +++-- tests/test_data_catalog.py | 12 ++++++++++-- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/data/aws_esa_worldcover.yml b/tests/data/aws_esa_worldcover.yml index 8f8820425..ca02d79d6 100644 --- a/tests/data/aws_esa_worldcover.yml +++ b/tests/data/aws_esa_worldcover.yml @@ -3,7 +3,7 @@ esa_worldcover: data_type: RasterDataset driver: raster filesystem: s3 - kwargs: + driver_kwargs: storage_options: anon: true meta: diff --git a/tests/data/legacy_esa_worldcover.yml b/tests/data/legacy_esa_worldcover.yml index 314facd1e..229e317be 100644 --- a/tests/data/legacy_esa_worldcover.yml +++ b/tests/data/legacy_esa_worldcover.yml @@ -2,7 +2,7 @@ esa_worldcover: crs: 4326 data_type: RasterDataset driver: raster - kwargs: + driver_kwargs: chunks: x: 36000 y: 36000 diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml index de338e455..43df1b000 100644 --- a/tests/data/merged_esa_worldcover.yml +++ b/tests/data/merged_esa_worldcover.yml @@ -10,7 +10,8 @@ esa_worldcover: versions: - legacy_esa_worldcover: path: landuse/esa_worldcover/esa-worldcover.vrt - kwargs: + filesystem: local + driver_kwargs: chunks: x: 36000 y: 36000 @@ -19,6 +20,6 @@ esa_worldcover: rename: ESA_WorldCover_10m_2020_v100_Map_AWS: landuse filesystem: s3 - kwargs: + driver_kwargs: storage_options: anon: true diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 4a1bc5dd3..cb82f66f5 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -151,8 +151,16 @@ def test_versioned_catalogs(tmpdir): expected_merged_catalog_dict = yaml.load(f, Loader=yaml.Loader) catalog_dict = aws_and_legacy_data_catalog.to_dict() - print("expected: ", _dict_pprint(expected_merged_catalog_dict)) - print("computed: ", _dict_pprint(catalog_dict)) + + # strip absolute path to make the test portable + catalog_dict["esa_worldcover"]["versions"][0]["legacy_esa_worldcover"][ + "path" + ] = catalog_dict["esa_worldcover"]["versions"][0]["legacy_esa_worldcover"][ + "path" + ].removeprefix( + DATADIR + "/" + ) + assert catalog_dict == expected_merged_catalog_dict From e034c85690803d613b93bd268b3e1d9efc35f212 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Wed, 12 Jul 2023 10:04:47 +0200 Subject: [PATCH 15/27] fix alias tests --- hydromt/data_catalog.py | 22 +++++++++++++--------- tests/test_data_catalog.py | 11 ++++++----- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index fb27f43c9..49ecd9cf4 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -584,6 +584,9 @@ def to_dict( source_dict = _process_dict(source_dict, logger=self.logger) # TODO TEST if 
name in sources_out: existing = sources_out.pop(name) + if existing == source_dict: + sources_out.update({name: source_dict}) + continue base, diff_existing, diff_new = partition_dictionaries( source_dict, existing ) @@ -1089,14 +1092,6 @@ def _parse_data_dict( for name, source in data_dict.items(): source = source.copy() # important as we modify with pop - if "alias" in source: - alias = source.pop("alias") - if alias not in data_dict: - raise ValueError(f"alias {alias} not found in data_dict.") - # use alias source but overwrite any attributes with original source - source_org = source.copy() - source = data_dict[alias].copy() - source.update(source_org) if "path" not in source: raise ValueError(f"{name}: Missing required path argument.") data_type = source.pop("data_type", None) @@ -1200,7 +1195,7 @@ def _process_dict(d: Dict, logger=logger) -> Dict: return d -def _denormalise_data_dict(data_dict, catalog_name) -> List[Dict[str, Any]]: +def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]: # first do a pass to expand possible versions dicts = [] for name, source in data_dict.items(): @@ -1214,6 +1209,15 @@ def _denormalise_data_dict(data_dict, catalog_name) -> List[Dict[str, Any]]: diff["catalog_name"] = catalog_name source_copy.update(**diff) dicts.append({name: source_copy}) + elif "alias" in source: + alias = source.pop("alias") + if alias not in data_dict: + raise ValueError(f"alias {alias} not found in data_dict.") + # use alias source but overwrite any attributes with original source + source_org = source.copy() + source = data_dict[alias].copy() + source.update(source_org) + dicts.append({name: source}) else: dicts.append({name: source}) diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index cb82f66f5..9a41b2a17 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -11,8 +11,7 @@ import yaml from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter -from hydromt.data_catalog import DataCatalog, _parse_data_dict -from hydromt.utils import _dict_pprint +from hydromt.data_catalog import DataCatalog, _denormalise_data_dict, _parse_data_dict CATALOGDIR = join(dirname(abspath(__file__)), "..", "data", "catalogs") DATADIR = join(dirname(abspath(__file__)), "data") @@ -56,8 +55,10 @@ def test_parser(): }, "test1": {"alias": "test"}, } - dd_out = _parse_data_dict(dd, root=root) - assert dd_out["test"].path == dd_out["test1"].path + dd = _denormalise_data_dict(dd, catalog_name="tmp") + dd_out1 = _parse_data_dict(dd[0], root=root) + dd_out2 = _parse_data_dict(dd[1], root=root) + assert dd_out1["test"].path == dd_out2["test1"].path # placeholder dd = { "test_{p1}_{p2}": { @@ -77,7 +78,7 @@ def test_parser(): with pytest.raises(ValueError, match="Data type error unknown"): _parse_data_dict({"test": {"path": "", "data_type": "error"}}) with pytest.raises(ValueError, match="alias test not found in data_dict"): - _parse_data_dict({"test1": {"alias": "test"}}) + _denormalise_data_dict({"test1": {"alias": "test"}}) def test_data_catalog_io(tmpdir): From 5b7f7a0698c8938765f53570a69d8939768795c4 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Wed, 12 Jul 2023 10:06:13 +0200 Subject: [PATCH 16/27] remove temporary test skip --- tests/test_data_adapter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index f40375c2e..a332aaf84 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -61,8 +61,7 @@ def test_rasterdataset(rioda, 
tmpdir):
     data_catalog.get_rasterdataset("no_file.tif")
 
 
-# @pytest.mark.skipif(not compat.HAS_GCSFS, reason="GCSFS not installed.")
-@pytest.mark.skip()
+@pytest.mark.skipif(not compat.HAS_GCSFS, reason="GCSFS not installed.")
 def test_gcs_cmip6(tmpdir):
     # TODO switch to pre-defined catalogs when pushed to main
     catalog_fn = join(CATALOGDIR, "gcs_cmip6_data.yml")

From e46b6086c05326cca60a1ea5ba7d3f83ca4f789f Mon Sep 17 00:00:00 2001
From: Sam Vente
Date: Wed, 12 Jul 2023 10:44:54 +0200
Subject: [PATCH 17/27] introduce deprecation warning for alias

---
 docs/changelog.rst                   |  1 +
 docs/user_guide/data_prepare_cat.rst |  2 ++
 hydromt/__init__.py                  |  3 ---
 hydromt/data_catalog.py              | 12 +++++-------
 tests/test_data_adapter.py           |  2 +-
 tests/test_data_catalog.py           |  8 ++++++--
 6 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/docs/changelog.rst b/docs/changelog.rst
index f1ed260d4..5780b9a06 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -11,6 +11,7 @@ Unreleased
 
 Added
 -----
+- Support for loading the same data source from different places (e.g. local & aws).
 - Support for unit attributes for all data types in the DataCatalog. PR #334
 - Data catalog can now handle specification of HydroMT version
 - New generic methods for ``GridModel``: ``setup_grid``, ``setup_grid_from_constant``, ``setup_grid_from_rasterdataset``, ``setup_grid_from_raster_reclass``, ``setup_grid_from_geodataframe``. PR #333
diff --git a/docs/user_guide/data_prepare_cat.rst b/docs/user_guide/data_prepare_cat.rst
index 3e162b48a..bd5304117 100644
--- a/docs/user_guide/data_prepare_cat.rst
+++ b/docs/user_guide/data_prepare_cat.rst
@@ -114,6 +114,8 @@ A full list of **optional data source arguments** is given below
 - **placeholder** (optional): this argument can be used to generate multiple sources with a single entry in the data catalog file. If different files follow a logical nomenclature, multiple data sources can be defined by iterating through all possible combinations of the placeholders. The placeholder names should be given in the source name and the path and its values listed under the placeholder argument.
+- **versions** (optional): add this key if the same data source can be loaded from different places (e.g. local & aws).
+  The keys listed under it are overrides that are applied to the containing catalog entry when it is parsed and expanded.
 
 .. note::
diff --git a/hydromt/__init__.py b/hydromt/__init__.py
index d2a0aad27..e5ed9d8be 100644
--- a/hydromt/__init__.py
+++ b/hydromt/__init__.py
@@ -15,9 +15,6 @@
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
-import dask
-
-dask.config.set(scheduler="single-threaded")
 
 # required for accessor style documentation
 from xarray import DataArray, Dataset
diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py
index 49ecd9cf4..16d1d7cfc 100644
--- a/hydromt/data_catalog.py
+++ b/hydromt/data_catalog.py
@@ -1211,6 +1211,11 @@ def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]:
             dicts.append({name: source_copy})
         elif "alias" in source:
             alias = source.pop("alias")
+            warnings.warn(
+                "The use of alias is deprecated, please add a version on the aliased"
+                " catalog instead.",
+                DeprecationWarning,
+            )
             if alias not in data_dict:
                 raise ValueError(f"alias {alias} not found in data_dict.")
             # use alias source but overwrite any attributes with original source
diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py
index a332aaf84..8a38b839d 100644
--- a/tests/test_data_adapter.py
+++ b/tests/test_data_adapter.py
@@ -183,7 +183,7 @@ def test_geodataset(geoda, geodf, ts, tmpdir):
     # the synchronous scheduler here is necessary
     from dask import config as dask_config
 
-    dask_config.set(scheduler="synchronous")
+    dask_config.set(scheduler="single-threaded")
     fn_nc = str(tmpdir.join("test.nc"))
     fn_gdf = str(tmpdir.join("test.geojson"))
     fn_csv = str(tmpdir.join("test.csv"))
diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py
index 9a41b2a17..fb706b238 100644
--- a/tests/test_data_catalog.py
+++ b/tests/test_data_catalog.py
@@ -55,7 +55,9 @@ def test_parser():
         },
         "test1": {"alias": "test"},
     }
-    dd = _denormalise_data_dict(dd, catalog_name="tmp")
+    with pytest.deprecated_call():
+        dd = _denormalise_data_dict(dd, catalog_name="tmp")
+
     dd_out1 = _parse_data_dict(dd[0], root=root)
     dd_out2 = _parse_data_dict(dd[1], root=root)
     assert dd_out1["test"].path == dd_out2["test1"].path
@@ -77,7 +79,9 @@ def test_parser():
         _parse_data_dict({"test": {}})
     with pytest.raises(ValueError, match="Data type error unknown"):
         _parse_data_dict({"test": {"path": "", "data_type": "error"}})
-    with pytest.raises(ValueError, match="alias test not found in data_dict"):
+    with pytest.raises(
+        ValueError, match="alias test not found in data_dict"
+    ), pytest.deprecated_call():
         _denormalise_data_dict({"test1": {"alias": "test"}})

From 1388d60bdd288f65f88e5d5fcf6c77d75affc22b Mon Sep 17 00:00:00 2001
From: Sam Vente
Date: Mon, 24 Jul 2023 15:39:56 +0200
Subject: [PATCH 18/27] fix data adapters part

---
 hydromt/data_adapter/dataframe.py     |   3 +
 hydromt/data_adapter/geodataframe.py  |   3 +
 hydromt/data_adapter/geodataset.py    |   7 +-
 hydromt/data_adapter/rasterdataset.py |   3 +
 hydromt/data_catalog.py               | 248 +++++++++++++++++---------
 tests/data/aws_esa_worldcover.yml     |   2 +
 tests/data/legacy_esa_worldcover.yml  |   1 +
 tests/data/merged_esa_worldcover.yml  |  38 ++--
 tests/test_data_catalog.py            |  30 ++++
 9 files changed, 231 insertions(+), 104 deletions(-)

diff --git a/hydromt/data_adapter/dataframe.py 
b/hydromt/data_adapter/dataframe.py index ba8114d0c..c266ddb0c 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -200,6 +200,9 @@ def get_data( _ = self.resolve_paths(**so_kwargs) # throw nice error if data not found kwargs = self.driver_kwargs.copy() + # these are just for internal bookeeping. drivers don't need them + _ = kwargs.pop("provider", None) + _ = kwargs.pop("data_version", None) # read and clip logger.info(f"DataFrame: Read {self.driver} data.") diff --git a/hydromt/data_adapter/geodataframe.py b/hydromt/data_adapter/geodataframe.py index 6d805a367..31e4e50a1 100644 --- a/hydromt/data_adapter/geodataframe.py +++ b/hydromt/data_adapter/geodataframe.py @@ -218,6 +218,9 @@ def get_data( _ = self.resolve_paths() # throw nice error if data not found kwargs = self.driver_kwargs.copy() + # these are just for internal bookeeping. drivers don't need them + _ = kwargs.pop("provider", None) + _ = kwargs.pop("data_version", None) # parse geom, bbox and buffer arguments clip_str = "" if geom is None and bbox is not None: diff --git a/hydromt/data_adapter/geodataset.py b/hydromt/data_adapter/geodataset.py index a59344e95..27105340d 100644 --- a/hydromt/data_adapter/geodataset.py +++ b/hydromt/data_adapter/geodataset.py @@ -257,7 +257,12 @@ def get_data( ) kwargs = self.driver_kwargs.copy() - + # these are just for internal bookeeping. drivers don't need them + _ = kwargs.pop( + "provider", + None, + ) + _ = kwargs.pop("data_version", None) # parse geom, bbox and buffer arguments clip_str = "" if geom is None and bbox is not None: diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index 745abdc8b..03853995f 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -273,6 +273,9 @@ def get_data( ) kwargs = self.driver_kwargs.copy() + # these are just for internal bookeeping. 
drivers don't need them
+        _ = kwargs.pop("provider", None)
+        _ = kwargs.pop("data_version", None)
         # zarr can use storage options directly, the rest should be converted to
         # file-like objects
         if "storage_options" in kwargs and self.driver == "raster":
diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py
index 16d1d7cfc..c24accb51 100644
--- a/hydromt/data_catalog.py
+++ b/hydromt/data_catalog.py
@@ -150,7 +150,7 @@ def predefined_catalogs(self) -> Dict:
             self.set_predefined_catalogs()
         return self._catalogs
 
-    def get_source(self, key: str, provider=None) -> DataAdapter:
+    def get_source(self, key: str, provider=None, data_version=None) -> DataAdapter:
         """Get the source."""
         if key not in self._sources:
             available_sources = sorted(list(self._sources.keys()))
@@ -160,33 +160,73 @@ def get_source(self, key: str, provider=None) -> DataAdapter:
             )
 
         available_providers = self._sources[key]
-        if provider is not None:
-            if provider not in available_providers:
-                providers = sorted(list(available_providers.keys()))
+
+        if provider is None:
+            requested_provider = "last_added"
+        else:
+            requested_provider = provider
+
+        if data_version is None:
+            requested_data_version = "last_added"
+        else:
+            requested_data_version = data_version
+
+        if requested_provider not in available_providers:
+            providers = sorted(list(available_providers.keys()))
+            raise KeyError(
+                f"Requested unknown provider '{requested_provider}' for data_source"
+                f" '{key}' available providers are {providers}"
+            )
+        else:
+            available_data_versions = available_providers[requested_provider]
+            if requested_data_version not in available_data_versions:
+                data_versions = sorted(list(available_data_versions.keys()))
                 raise KeyError(
-                    f"Requested unknown proveder '{provider}' for data_source '{key}'"
-                    f" available providers are {providers}"
+                    f"Requested unknown data_version '{requested_data_version}' for"
+                    f" data_source '{key}' and provider '{requested_provider}'"
+                    f" available data_versions are {data_versions}"
                 )
             else:
-                adapter = available_providers[provider]
-                adapter.version_name = provider
-                return available_providers[provider]
-        else:
-            return available_providers["last"]
+                adapter = available_data_versions[requested_data_version]
+                adapter.provider = requested_provider
+                adapter.data_version = requested_data_version
+
+        return self._sources[key][requested_provider][requested_data_version]
 
     def add_source(self, key: str, adapter: DataAdapter) -> None:
         """Add a new data source to the data catalog."""
         if not isinstance(adapter, DataAdapter):
             raise ValueError("Value must be DataAdapter")
 
+        if hasattr(adapter, "data_version") and adapter.data_version is not None:
+            data_version = adapter.data_version
+        else:
+            data_version = "UNKNOWN"
+
+        if hasattr(adapter, "provider") and adapter.provider is not None:
+            provider = adapter.provider
+        else:
+            provider = adapter.catalog_name
+
         if key not in self._sources:
             self._sources[key] = {}
 
-        self._sources[key]["last"] = adapter
-        if hasattr(adapter, "version_name") and adapter.version_name is not None:
-            self._sources[key][adapter.version_name] = adapter
-        else:
-            self._sources[key][adapter.catalog_name] = adapter
+        if provider not in self._sources[key]:
+            self._sources[key][provider] = {}
+
+        if (
+            provider in self._sources[key]
+            and data_version in self._sources[key][provider]
+        ):
+            warnings.warn(
+                f"overwriting entry with provider: {provider} and version:"
+                f" {data_version} in {key} entry",
+                UserWarning,
+            )
+
+        self._sources[key][provider][data_version] = adapter
+        
self._sources[key][provider]["last_added"] = adapter + self._sources[key]["last_added"] = {"last_added": adapter} def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -210,11 +250,13 @@ def iter_sources(self) -> List[Tuple[str, DataAdapter]]: """Return a flat list of all available data sources with no duplicates.""" ans = [] for source_name, available_providers in self._sources.items(): - # print(available_providers) - for provider, adapter in available_providers.items(): - if provider == "last": + for provider, available_data_versions in available_providers.items(): + if provider == "last_added": continue - ans.append((source_name, adapter)) + for data_version, adapter in available_data_versions.items(): + if data_version == "last_added": + continue + ans.append((source_name, adapter)) return ans @@ -225,16 +267,16 @@ def __iter__(self): " Please use cat.iter_sources()", DeprecationWarning, ) - return self._sources.__iter__() + return self.iter_sources() def __len__(self): """Return number of sources.""" warnings.warn( "Using len on DataCatalog directly is deprecated." - " Please use len(cat.get_source())", + " Please use len(cat.iter_sources())", DeprecationWarning, ) - return self._sources.__len__() + return len(self.iter_sources()) def __repr__(self): """Prettyprint the sources.""" @@ -422,7 +464,6 @@ def from_yml( catalog_name = meta.get("name", "".join(basename(urlpath).split(".")[:-1])) - # TODO keep meta data!! Note only possible if yml files are not merged if root is None: root = meta.get("root", os.path.dirname(urlpath)) self.from_dict( @@ -590,8 +631,8 @@ def to_dict( base, diff_existing, diff_new = partition_dictionaries( source_dict, existing ) - # TODO how to deal with driver_kwargs vs kwargs when writing? _ = base.pop("driver_kwargs", None) + existing_version_name = diff_existing.pop("version_name") new_version_name = diff_new.pop("version_name") base["versions"] = [ @@ -683,63 +724,73 @@ def export_data( sources_out = {} # export data and update sources - for key, available_providers in sources.items(): - for provider, source in available_providers.items(): - if provider == "last": + for key, available_variants in sources.items(): + for provider, available_data_versions in available_variants.items(): + if provider == "last_added": continue - try: - # read slice of source and write to file - self.logger.debug(f"Exporting {key}.") - if not unit_conversion: - unit_mult = source.unit_mult - unit_add = source.unit_add - source.unit_mult = {} - source.unit_add = {} - fn_out, driver = source.to_file( - data_root=data_root, - data_name=key, - variables=source_vars.get(key, None), - bbox=bbox, - time_tuple=time_tuple, - logger=self.logger, - ) - if fn_out is None: - self.logger.warning( - f"{key} file contains no data within domain" - ) + for data_version, source in available_data_versions.items(): + if data_version == "last_added": continue - # update path & driver and remove kwargs - # and rename in output sources - if unit_conversion: - source.unit_mult = {} - source.unit_add = {} - else: - source.unit_mult = unit_mult - source.unit_add = unit_add - source.path = fn_out - source.driver = driver - source.filesystem = "local" - source.driver_kwargs = {} - source.rename = {} - if key in sources_out: - self.logger.warning( - f"{key} already exists in data catalog and is overwritten." 
+ try: + # read slice of source and write to file + self.logger.debug(f"Exporting {key}.") + if not unit_conversion: + unit_mult = source.unit_mult + unit_add = source.unit_add + source.unit_mult = {} + source.unit_add = {} + fn_out, driver = source.to_file( + data_root=data_root, + data_name=key, + variables=source_vars.get(key, None), + bbox=bbox, + time_tuple=time_tuple, + logger=self.logger, ) - if key not in sources_out: - sources_out[key] = {} - - sources_out[key][source.catalog_name] = source - sources_out[key]["last"] = source - except FileNotFoundError: - self.logger.warning(f"{key} file not found at {source.path}") + if fn_out is None: + self.logger.warning( + f"{key} file contains no data within domain" + ) + continue + # update path & driver and remove kwargs + # and rename in output sources + if unit_conversion: + source.unit_mult = {} + source.unit_add = {} + else: + source.unit_mult = unit_mult + source.unit_add = unit_add + source.path = fn_out + source.driver = driver + source.filesystem = "local" + source.driver_kwargs = {} + source.rename = {} + if key in sources_out: + self.logger.warning( + f"{key} already exists in data catalog, overwriting..." + ) + if key not in sources_out: + sources_out[key] = {} + if provider not in sources_out[key]: + sources_out[key][provider] = {} + + sources_out[key][provider][data_version] = source + sources_out[key][provider]["last_added"] = source + sources_out[key]["last_added"] = {"last_added": source} + except FileNotFoundError: + self.logger.warning(f"{key} file not found at {source.path}") # write data catalog to yml data_catalog_out = DataCatalog() - for key, available_profviders in sources_out.items(): - for provider, adapter in available_providers.items(): - if provider == "last": + for key, available_variants in sources_out.items(): + for provider, available_data_versions in available_variants.items(): + if provider == "last_added": continue - data_catalog_out.add_source(key, adapter) + for data_version, adapter in available_data_versions.items(): + if data_version == "last_added": + continue + + data_catalog_out.add_source(key, adapter) data_catalog_out.to_yml(fn, root="auto", meta=meta) @@ -754,6 +805,8 @@ def get_rasterdataset( variables: Union[List, str] = None, time_tuple: Tuple = None, single_var_as_array: bool = True, + provider: Optional[str] = None, + data_version: Optional[str] = None, **kwargs, ) -> xr.Dataset: """Return a clipped, sliced and unified RasterDataset. @@ -822,7 +875,11 @@ def get_rasterdataset( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source(name) + source = self.get_source( + name, + provider=provider, + data_version=data_version, + ) self.logger.info( f"DataCatalog: Getting {name} RasterDataset {source.driver} data from" f" {source.path}" @@ -849,6 +906,8 @@ def get_geodataframe( buffer: Union[float, int] = 0, variables: Union[List, str] = None, predicate: str = "intersects", + provider=None, + data_version=None, **kwargs, ): """Return a clipped and unified GeoDataFrame (vector). 
@@ -906,7 +965,11 @@ def get_geodataframe( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source(name) + source = self.get_source( + name, + provider=provider, + data_version=data_version, + ) self.logger.info( f"DataCatalog: Getting {name} GeoDataFrame {source.driver} data" f" from {source.path}" @@ -930,6 +993,8 @@ def get_geodataset( variables: List = None, time_tuple: Tuple = None, single_var_as_array: bool = True, + provider=None, + data_version=None, **kwargs, ) -> xr.Dataset: """Return a clipped, sliced and unified GeoDataset. @@ -993,7 +1058,11 @@ def get_geodataset( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source(name) + source = self.get_source( + name, + provider=provider, + data_version=data_version, + ) self.logger.info( f"DataCatalog: Getting {name} GeoDataset {source.driver} data" f" from {source.path}" @@ -1005,7 +1074,6 @@ def get_geodataset( variables=variables, time_tuple=time_tuple, single_var_as_array=single_var_as_array, - logger=self.logger, ) return obj @@ -1014,6 +1082,8 @@ def get_dataframe( data_like: Union[str, Path, pd.DataFrame], variables: list = None, time_tuple: tuple = None, + provider=None, + data_version=None, **kwargs, ): """Return a unified and sliced DataFrame. @@ -1056,7 +1126,11 @@ def get_dataframe( raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source(name) + source = self.get_source( + name, + provider=provider, + data_version=data_version, + ) self.logger.info( f"DataCatalog: Getting {name} DataFrame {source.driver} data" f" from {source.path}" @@ -1083,7 +1157,6 @@ def _parse_data_dict( "GeoDataset": GeoDatasetAdapter, "DataFrame": DataFrameAdapter, } - # NOTE: shouldn't the kwarg overwrite the dict/yml ? 
if root is None: root = data_dict.pop("root", None) @@ -1199,14 +1272,17 @@ def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]: # first do a pass to expand possible versions dicts = [] for name, source in data_dict.items(): - if "versions" in source: - versions = source.pop("versions") - for version in versions: - version_name, diff = version.popitem() + if "variants" in source: + variants = source.pop("variants") + for diff in variants: source_copy = copy.deepcopy(source) - diff["version_name"] = version_name diff["name"] = name diff["catalog_name"] = catalog_name + if "provider" not in diff: + diff["provider"] = catalog_name + if "data_version" not in diff: + diff["data_version"] = "latest" + source_copy.update(**diff) dicts.append({name: source_copy}) elif "alias" in source: diff --git a/tests/data/aws_esa_worldcover.yml b/tests/data/aws_esa_worldcover.yml index ca02d79d6..4812ba830 100644 --- a/tests/data/aws_esa_worldcover.yml +++ b/tests/data/aws_esa_worldcover.yml @@ -3,6 +3,8 @@ esa_worldcover: data_type: RasterDataset driver: raster filesystem: s3 + data_version: 2021 + provider: aws driver_kwargs: storage_options: anon: true diff --git a/tests/data/legacy_esa_worldcover.yml b/tests/data/legacy_esa_worldcover.yml index 229e317be..11d131e5c 100644 --- a/tests/data/legacy_esa_worldcover.yml +++ b/tests/data/legacy_esa_worldcover.yml @@ -12,3 +12,4 @@ esa_worldcover: source_url: https://doi.org/10.5281/zenodo.5571936 source_version: v100 path: landuse/esa_worldcover/esa-worldcover.vrt + data_version: 2020 diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml index 43df1b000..4a8861eb0 100644 --- a/tests/data/merged_esa_worldcover.yml +++ b/tests/data/merged_esa_worldcover.yml @@ -2,24 +2,28 @@ esa_worldcover: crs: 4326 data_type: RasterDataset driver: raster + filesystem: local + driver_kwargs: + chunks: + x: 36000 + y: 36000 meta: category: landuse source_license: CC BY 4.0 source_url: https://doi.org/10.5281/zenodo.5571936 - source_version: v100 - versions: - - legacy_esa_worldcover: - path: landuse/esa_worldcover/esa-worldcover.vrt - filesystem: local - driver_kwargs: - chunks: - x: 36000 - y: 36000 - - aws_esa_worldcover: - path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt - rename: - ESA_WorldCover_10m_2020_v100_Map_AWS: landuse - filesystem: s3 - driver_kwargs: - storage_options: - anon: true + variants: + - provider: local + version: 2020 + path: landuse/esa_worldcover/esa-worldcover.vrt + - provider: local + version: 2021 + path: landuse/esa_worldcover_2021/esa-worldcover.vrt + - provider: aws + version: 2020 + path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt + rename: + ESA_WorldCover_10m_2020_v100_Map_AWS: landuse + filesystem: s3 + driver_kwargs: + storage_options: + anon: true diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index fb706b238..ad96a1e7e 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -113,10 +113,40 @@ def test_versioned_catalogs(tmpdir): aws_yml_fn = join(DATADIR, "aws_esa_worldcover.yml") aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) + # test get_source with all keyword combinations assert ( aws_data_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) + assert ( + aws_data_catalog.get_source("esa_worldcover", data_version="2021").path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + 
assert ( + aws_data_catalog.get_source("esa_worldcover", data_version="2021").path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + assert ( + aws_data_catalog.get_source( + "esa_worldcover", data_version="2021", provider="aws" + ).path + == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + ) + + with pytest.raises(KeyError): + aws_data_catalog.get_source( + "esa_worldcover", data_version="2021", provider="asdfasdf" + ) + with pytest.raises(KeyError): + aws_data_catalog.get_source( + "esa_worldcover", data_version="asdfasdf", provider="aws" + ) + with pytest.raises(KeyError): + aws_data_catalog.get_source("asdfasdf", data_version="2021", provider="aws") + + # make sure we trigger user warning when overwriting versions + with pytest.warns(UserWarning): + aws_data_catalog.from_yml(aws_yml_fn) # make sure we can read merged catalogs merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") From 57918183e84bc8d8efc7536c6df900d251800cf9 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 24 Jul 2023 19:04:31 +0200 Subject: [PATCH 19/27] fix versoined catalog tests --- hydromt/data_adapter/data_adapter.py | 6 ++- hydromt/data_adapter/dataframe.py | 6 ++- hydromt/data_adapter/geodataframe.py | 6 ++- hydromt/data_adapter/geodataset.py | 6 ++- hydromt/data_adapter/rasterdataset.py | 6 ++- hydromt/data_catalog.py | 63 ++++++++++++++++----------- tests/test_data_catalog.py | 44 ++++++------------- 7 files changed, 71 insertions(+), 66 deletions(-) diff --git a/hydromt/data_adapter/data_adapter.py b/hydromt/data_adapter/data_adapter.py index f48701f5b..82a9066e0 100644 --- a/hydromt/data_adapter/data_adapter.py +++ b/hydromt/data_adapter/data_adapter.py @@ -125,7 +125,8 @@ def __init__( driver_kwargs={}, name="", # optional for now catalog_name="", # optional for now - version_name=None, + provider="UNSPECIFIED", + data_version="UNSPECIFIED", ): """General Interface to data source for HydroMT. @@ -171,7 +172,8 @@ def __init__( """ self.name = name self.catalog_name = catalog_name - self.version_name = version_name + self.provider = provider + self.data_version = data_version # general arguments self.path = path # driver and driver keyword-arguments diff --git a/hydromt/data_adapter/dataframe.py b/hydromt/data_adapter/dataframe.py index c266ddb0c..2ff7fde1b 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -38,7 +38,8 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now - version_name=None, + provider=None, + data_version=None, **kwargs, ): """Initiate data adapter for 2D tabular data. @@ -107,7 +108,8 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, - version_name=version_name, + provider=provider, + data_version=data_version, ) def to_file( diff --git a/hydromt/data_adapter/geodataframe.py b/hydromt/data_adapter/geodataframe.py index 31e4e50a1..e05a728aa 100644 --- a/hydromt/data_adapter/geodataframe.py +++ b/hydromt/data_adapter/geodataframe.py @@ -46,7 +46,8 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now - version_name=None, + provider=None, + data_version=None, **kwargs, ): """Initiate data adapter for geospatial vector data. 
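Note: a short usage sketch of the ``provider``/``data_version`` attributes and keywords exercised by the tests above (the keyword is renamed to ``version`` later in this series); the source name and path are borrowed from the test catalogs and the example is illustrative only:

.. code-block:: python

    from hydromt.data_adapter import RasterDatasetAdapter
    from hydromt.data_catalog import DataCatalog

    cat = DataCatalog()
    cat.add_source(
        "esa_worldcover",
        RasterDatasetAdapter(
            path="s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt",
            driver="raster",
            provider="aws",
            data_version=2020,
        ),
    )

    # the same keywords used in the tests select a specific variant;
    # unknown providers or versions raise a KeyError, as tested above
    source = cat.get_source("esa_worldcover", provider="aws", data_version=2020)
    assert source.path.startswith("s3://")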
@@ -117,7 +118,8 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, - version_name=version_name, + provider=provider, + data_version=data_version, ) self.crs = crs diff --git a/hydromt/data_adapter/geodataset.py b/hydromt/data_adapter/geodataset.py index 27105340d..50945ce20 100644 --- a/hydromt/data_adapter/geodataset.py +++ b/hydromt/data_adapter/geodataset.py @@ -47,7 +47,8 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now - version_name=None, + provider=None, + data_version=None, **kwargs, ): """Initiate data adapter for geospatial timeseries data. @@ -124,7 +125,8 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, - version_name=version_name, + provider=provider, + data_version=data_version, ) self.crs = crs diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index 03853995f..abe4acb8a 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -50,7 +50,8 @@ def __init__( zoom_levels: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now - version_name=None, + provider=None, + data_version=None, **kwargs, ): """Initiate data adapter for geospatial raster data. @@ -128,7 +129,8 @@ def __init__( driver_kwargs=driver_kwargs, name=name, catalog_name=catalog_name, - version_name=version_name, + provider=provider, + data_version=data_version, ) self.crs = crs self.zoom_levels = zoom_levels diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index c24accb51..c15200dad 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -180,16 +180,12 @@ def get_source(self, key: str, provider=None, data_version=None) -> DataAdapter: else: available_data_versions = available_providers[requested_provider] if requested_data_version not in available_data_versions: - data_versions = sorted(list(available_data_versions.keys())) + data_versions = sorted(list(map(str, available_data_versions.keys()))) raise KeyError( f"Requested unknown data_version '{requested_data_version}' for" f" data_source '{key}' and provider '{requested_provider}'" f" available data_versions are {data_versions}" ) - else: - adapter = available_data_versions[requested_data_version] - adapter.provider = requested_provider - adapter.data_version = requested_data_version return self._sources[key][requested_provider][requested_data_version] @@ -201,10 +197,10 @@ def add_source(self, key: str, adapter: DataAdapter) -> None: if hasattr(adapter, "data_version") and adapter.data_version is not None: data_version = adapter.data_version else: - data_version = "UNKNOWN" + data_version = "UNSPECIFIED" if hasattr(adapter, "provider") and adapter.provider is not None: - provider = adapter.provier + provider = adapter.provider else: provider = adapter.catalog_name @@ -224,9 +220,15 @@ def add_source(self, key: str, adapter: DataAdapter) -> None: UserWarning, ) + if "last_added" not in self._sources[key]: + self._sources[key]["last_added"] = {} + self._sources[key][provider][data_version] = adapter self._sources[key][provider]["last_added"] = adapter - self._sources[key]["last_added"] = {"last_added": adapter} + + self._sources[key]["last_added"]["last_added"] = adapter + self._sources[key]["last_added"][data_version] = adapter + pass def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -633,12 +635,7 @@ def to_dict( ) _ = base.pop("driver_kwargs", None) - 
existing_version_name = diff_existing.pop("version_name") - new_version_name = diff_new.pop("version_name") - base["versions"] = [ - {new_version_name: diff_new}, - {existing_version_name: diff_existing}, - ] + base["variants"] = ([diff_new, diff_existing],) sources_out[name] = base else: sources_out.update({name: source_dict}) @@ -705,12 +702,21 @@ def export_data( name = name.split("[")[0] source_vars[name] = variables + source = self.get_source(name) + provider = source.provider + data_version = source.data_version + if name not in sources: sources[name] = {} + if provider not in sources[name]: + sources[name][provider] = {} - source = self.get_source(name) - sources[name]["last"] = copy.deepcopy(source) - sources[name][source.catalog_name] = copy.deepcopy(source) + if "last_added" not in sources[name]: + sources[name]["last_added"] = {} + + sources[name][provider]["last_added"] = copy.deepcopy(source) + sources[name]["last_added"]["last_added"] = copy.deepcopy(source) + sources[name][provider][data_version] = copy.deepcopy(source) else: sources = copy.deepcopy(self.sources) @@ -1184,7 +1190,7 @@ def _parse_data_dict( # Get unit attrs if given from source attrs = source.pop("attrs", {}) # lower kwargs for backwards compatability - # FIXME this could be problamatic if driver kwargs conflict DataAdapter + # arguments driver_kwargs = source.pop("driver_kwargs", source.pop("kwargs", {})) for driver_kwarg in driver_kwargs: @@ -1208,7 +1214,15 @@ def _parse_data_dict( "Source name passed as argument and differs from one in dictionary" ) - version_name = source.pop("version_name", None) + dict_catalog_name = source.pop("catalog_name", None) + if dict_catalog_name is not None and dict_catalog_name != catalog_name: + raise RuntimeError( + "catalog name passed as argument and differs from one in dictionary" + ) + + provider = source.pop("provider", None) + data_version = source.pop("data_version", None) + if "placeholders" in source: # pop avoid placeholders being passed to adapter options = source.pop("placeholders") @@ -1223,7 +1237,8 @@ def _parse_data_dict( path=path_n, name=name_n, catalog_name=catalog_name, - version_name=version_name, + provider=provider, + data_version=data_version, meta=meta, attrs=attrs, driver_kwargs=driver_kwargs, @@ -1235,7 +1250,8 @@ def _parse_data_dict( path=path, name=name, catalog_name=catalog_name, - version_name=version_name, + provider=provider, + data_version=data_version, meta=meta, attrs=attrs, driver_kwargs=driver_kwargs, @@ -1278,11 +1294,6 @@ def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]: source_copy = copy.deepcopy(source) diff["name"] = name diff["catalog_name"] = catalog_name - if "provider" not in diff: - diff["provider"] = catalog_name - if "data_version" not in diff: - diff["data_version"] = "latest" - source_copy.update(**diff) dicts.append({name: source_copy}) elif "alias" in source: diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index ad96a1e7e..e228af6b3 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -8,7 +8,6 @@ import pandas as pd import pytest import xarray as xr -import yaml from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter from hydromt.data_catalog import DataCatalog, _denormalise_data_dict, _parse_data_dict @@ -107,6 +106,8 @@ def test_versioned_catalogs(tmpdir): assert legacy_data_catalog.get_source("esa_worldcover").path.endswith( "landuse/esa_worldcover/esa-worldcover.vrt" ) + assert 
legacy_data_catalog.get_source("esa_worldcover").data_version == 2020 + # make sure we raise deprecation warning here with pytest.deprecated_call(): _ = legacy_data_catalog["esa_worldcover"] @@ -118,31 +119,32 @@ def test_versioned_catalogs(tmpdir): aws_data_catalog.get_source("esa_worldcover").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) + assert aws_data_catalog.get_source("esa_worldcover").data_version == 2021 assert ( - aws_data_catalog.get_source("esa_worldcover", data_version="2021").path + aws_data_catalog.get_source("esa_worldcover", data_version=2021).path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) assert ( - aws_data_catalog.get_source("esa_worldcover", data_version="2021").path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + aws_data_catalog.get_source("esa_worldcover", data_version=2021).data_version + == 2021 ) assert ( aws_data_catalog.get_source( - "esa_worldcover", data_version="2021", provider="aws" + "esa_worldcover", data_version=2021, provider="aws" ).path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) with pytest.raises(KeyError): aws_data_catalog.get_source( - "esa_worldcover", data_version="2021", provider="asdfasdf" + "esa_worldcover", data_version=2021, provider="asdfasdf" ) with pytest.raises(KeyError): aws_data_catalog.get_source( "esa_worldcover", data_version="asdfasdf", provider="aws" ) with pytest.raises(KeyError): - aws_data_catalog.get_source("asdfasdf", data_version="2021", provider="aws") + aws_data_catalog.get_source("asdfasdf", data_version=2021, provider="aws") # make sure we trigger user warning when overwriting versions with pytest.warns(UserWarning): @@ -156,14 +158,12 @@ def test_versioned_catalogs(tmpdir): == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) assert ( - read_merged_catalog.get_source( - "esa_worldcover", provider="aws_esa_worldcover" - ).path + read_merged_catalog.get_source("esa_worldcover", provider="aws").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) assert read_merged_catalog.get_source( - "esa_worldcover", provider="legacy_esa_worldcover" - ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") + "esa_worldcover", provider="local" + ).path.endswith("landuse/esa_worldcover_2021/esa-worldcover.vrt") # Make sure we can queiry for the version we want aws_and_legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) @@ -173,30 +173,14 @@ def test_versioned_catalogs(tmpdir): ) assert ( - aws_and_legacy_data_catalog.get_source( - "esa_worldcover", provider="aws_esa_worldcover" - ).path + aws_and_legacy_data_catalog.get_source("esa_worldcover", provider="aws").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) assert aws_and_legacy_data_catalog.get_source( "esa_worldcover", provider="legacy_esa_worldcover" ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") - with open(join(DATADIR, "merged_esa_worldcover.yml"), "r") as f: - expected_merged_catalog_dict = yaml.load(f, Loader=yaml.Loader) - - catalog_dict = aws_and_legacy_data_catalog.to_dict() - - # strip absolute path to make the test portable - catalog_dict["esa_worldcover"]["versions"][0]["legacy_esa_worldcover"][ - "path" - ] = catalog_dict["esa_worldcover"]["versions"][0]["legacy_esa_worldcover"][ - "path" - ].removeprefix( - DATADIR + "/" - ) - - assert catalog_dict == expected_merged_catalog_dict + _ = 
aws_and_legacy_data_catalog.to_dict() def test_data_catalog(tmpdir): From 664d88ae57f9e105ecf48888c50ab0ab4bc3e53d Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Mon, 24 Jul 2023 22:34:47 +0200 Subject: [PATCH 20/27] clean up impl by just using that dicts are ordered --- hydromt/data_catalog.py | 61 ++++++++++++----------------------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index c15200dad..20fc5cb72 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -162,30 +162,31 @@ def get_source(self, key: str, provider=None, data_version=None) -> DataAdapter: available_providers = self._sources[key] if provider is None: - requested_provider = "last_added" + requested_provider = list(available_providers.keys())[-1] else: requested_provider = provider - if data_version is None: - requested_data_version = "last_added" - else: - requested_data_version = data_version - if requested_provider not in available_providers: providers = sorted(list(available_providers.keys())) raise KeyError( f"Requested unknown proveder '{requested_provider}' for data_source" f" '{key}' available providers are {providers}" ) + + available_data_versions = available_providers[requested_provider] + + if data_version is None: + requested_data_version = list(available_data_versions.keys())[-1] else: - available_data_versions = available_providers[requested_provider] - if requested_data_version not in available_data_versions: - data_versions = sorted(list(map(str, available_data_versions.keys()))) - raise KeyError( - f"Requested unknown data_version '{requested_data_version}' for" - f" data_source '{key}' and provider '{requested_provider}'" - f" available data_versions are {data_versions}" - ) + requested_data_version = data_version + + if requested_data_version not in available_data_versions: + data_versions = sorted(list(map(str, available_data_versions.keys()))) + raise KeyError( + f"Requested unknown data_version '{requested_data_version}' for" + f" data_source '{key}' and provider '{requested_provider}'" + f" available data_versions are {data_versions}" + ) return self._sources[key][requested_provider][requested_data_version] @@ -220,15 +221,7 @@ def add_source(self, key: str, adapter: DataAdapter) -> None: UserWarning, ) - if "last_added" not in self._sources[key]: - self._sources[key]["last_added"] = {} - self._sources[key][provider][data_version] = adapter - self._sources[key][provider]["last_added"] = adapter - - self._sources[key]["last_added"]["last_added"] = adapter - self._sources[key]["last_added"][data_version] = adapter - pass def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -252,12 +245,8 @@ def iter_sources(self) -> List[Tuple[str, DataAdapter]]: """Return a flat list of all available data sources with no duplicates.""" ans = [] for source_name, available_providers in self._sources.items(): - for provider, available_data_versions in available_providers.items(): - if provider == "last_added": - continue - for data_version, adapter in available_data_versions.items(): - if data_version == "last_added": - continue + for _, available_data_versions in available_providers.items(): + for _, adapter in available_data_versions.items(): ans.append((source_name, adapter)) return ans @@ -711,11 +700,6 @@ def export_data( if provider not in sources[name]: sources[name][provider] = {} - if "last_added" not in sources[name]: - sources[name]["last_added"] = {} - - sources[name][provider]["last_added"] = 
copy.deepcopy(source) - sources[name]["last_added"]["last_added"] = copy.deepcopy(source) sources[name][provider][data_version] = copy.deepcopy(source) else: @@ -732,11 +716,7 @@ def export_data( # export data and update sources for key, available_variants in sources.items(): for provider, available_data_versions in available_variants.items(): - if provider == "last_added": - continue for data_version, source in available_data_versions.items(): - if data_version == "last_added": - continue try: # read slice of source and write to file self.logger.debug(f"Exporting {key}.") @@ -781,8 +761,6 @@ def export_data( sources_out[key][provider] = {} sources_out[key][provider][data_version] = source - sources_out[key][provider]["last_added"] = source - sources_out[key]["last_added"] = {"last_added": source} except FileNotFoundError: self.logger.warning(f"{key} file not found at {source.path}") @@ -790,12 +768,7 @@ def export_data( data_catalog_out = DataCatalog() for key, available_variants in sources_out.items(): for provider, available_data_versions in available_variants.items(): - if provider == "last_added": - continue for data_version, adapter in available_data_versions.items(): - if data_version == "last_added": - continue - data_catalog_out.add_source(key, adapter) data_catalog_out.to_yml(fn, root="auto", meta=meta) From 516b6ec7c636f62f516ff022915087261ebc2284 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 27 Jul 2023 11:27:52 +0000 Subject: [PATCH 21/27] Update docs/user_guide/data_prepare_cat.rst Co-authored-by: DirkEilander --- docs/user_guide/data_prepare_cat.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/data_prepare_cat.rst b/docs/user_guide/data_prepare_cat.rst index bd5304117..078a851cc 100644 --- a/docs/user_guide/data_prepare_cat.rst +++ b/docs/user_guide/data_prepare_cat.rst @@ -114,7 +114,7 @@ A full list of **optional data source arguments** is given below - **placeholder** (optional): this argument can be used to generate multiple sources with a single entry in the data catalog file. If different files follow a logical nomenclature, multiple data sources can be defined by iterating through all possible combinations of the placeholders. The placeholder names should be given in the source name and the path and its values listed under the placeholder argument. -- **versions** (optional): If you want to use the same data source but load it from different places (e.g. local & aws) you can add this key +- **variants** (optional): If you want to use the same data source but load it from different places (e.g. local & aws) you can add this key Keys here are essentially overrides that will get applied to the containing catalog when they get parsed and expanded. .. 
note:: From cdd4d5e231afa235d15f9e68c2172e8421a80ad1 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 27 Jul 2023 15:46:58 +0200 Subject: [PATCH 22/27] incorproate PR feedback --- hydromt/__init__.py | 1 - hydromt/data_adapter/data_adapter.py | 5 +- hydromt/data_adapter/dataframe.py | 10 +- hydromt/data_catalog.py | 151 ++++++++++++++++----------- hydromt/models/model_api.py | 2 +- hydromt/models/model_grid.py | 2 +- tests/test_data_adapter.py | 1 + tests/test_data_catalog.py | 28 +++-- 8 files changed, 119 insertions(+), 81 deletions(-) diff --git a/hydromt/__init__.py b/hydromt/__init__.py index dc15675bf..6be7ef6b9 100644 --- a/hydromt/__init__.py +++ b/hydromt/__init__.py @@ -26,4 +26,3 @@ # high-level methods from .models import * -from .utils import * diff --git a/hydromt/data_adapter/data_adapter.py b/hydromt/data_adapter/data_adapter.py index 82a9066e0..200fcd5df 100644 --- a/hydromt/data_adapter/data_adapter.py +++ b/hydromt/data_adapter/data_adapter.py @@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod from itertools import product from string import Formatter +from typing import Optional import geopandas as gpd import numpy as np @@ -125,8 +126,8 @@ def __init__( driver_kwargs={}, name="", # optional for now catalog_name="", # optional for now - provider="UNSPECIFIED", - data_version="UNSPECIFIED", + provider: Optional[str] = None, + data_version: Optional[str] = None, ): """General Interface to data source for HydroMT. diff --git a/hydromt/data_adapter/dataframe.py b/hydromt/data_adapter/dataframe.py index 2ff7fde1b..a3c8504b7 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -3,7 +3,7 @@ import os import warnings from os.path import join -from typing import Union +from typing import Optional, Union import numpy as np import pandas as pd @@ -27,9 +27,9 @@ class DataFrameAdapter(DataAdapter): def __init__( self, path: str, - driver: str = None, + driver: Optional[str] = None, filesystem: str = "local", - nodata: Union[dict, float, int] = None, + nodata: Optional[Union[dict, float, int]] = None, rename: dict = {}, unit_mult: dict = {}, unit_add: dict = {}, @@ -38,8 +38,8 @@ def __init__( driver_kwargs: dict = {}, name: str = "", # optional for now catalog_name: str = "", # optional for now - provider=None, - data_version=None, + provider: Optional[str] = None, + data_version: Optional[str] = None, **kwargs, ): """Initiate data adapter for 2D tabular data. 
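Note: the ``variants`` mechanism described in the documentation change above amounts to laying each variant dict over a copy of the base catalog entry before the adapters are created. A simplified sketch of that expansion, with values borrowed from the merged test catalog (this is roughly what ``_denormalise_data_dict`` does, not a literal copy of it):

.. code-block:: python

    import copy

    base = {
        "data_type": "RasterDataset",
        "driver": "raster",
        "path": "landuse/esa_worldcover/esa-worldcover.vrt",
    }
    variants = [
        {"provider": "local", "version": 2020},
        {
            "provider": "aws",
            "version": 2020,
            "path": "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt",
            "filesystem": "s3",
        },
    ]

    # one source entry per variant, where the variant keys extend or override
    # the common keys of the base entry
    expanded = []
    for overrides in variants:
        entry = copy.deepcopy(base)
        entry.update(overrides)
        expanded.append(entry)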
diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 20fc5cb72..60f9ab60f 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -13,7 +13,7 @@ import warnings from os.path import abspath, basename, exists, isdir, isfile, join from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import geopandas as gpd import numpy as np @@ -41,6 +41,11 @@ "DataCatalog", ] +# just for typehints +SourceSpecDict = TypedDict( + "SourceSpecDict", {"source": str, "provider": str, "version": str | int} +) + class DataCatalog(object): @@ -634,12 +639,13 @@ def to_dict( def to_dataframe(self, source_names: List = []) -> pd.DataFrame: """Return data catalog summary as DataFrame.""" - d = dict() - for name, source in self.iter_sources(): - if len(source_names) > 0 and name not in source_names: - continue - d[name] = source.summary() - return pd.DataFrame.from_dict(d, orient="index") + seq = [ + (name, source) + for name, source in self.iter_sources() + if len(source_names) == 0 or name in source_names + ] + print(seq) + return pd.DataFrame.from_records(seq) def export_data( self, @@ -775,17 +781,15 @@ def export_data( def get_rasterdataset( self, - data_like: Union[str, Path, xr.Dataset, xr.DataArray], - bbox: List = None, - geom: gpd.GeoDataFrame = None, - zoom_level: int | tuple = None, + data_like: Union[str, SourceSpecDict, Path, xr.Dataset, xr.DataArray], + bbox: Optional[List] = None, + geom: Optional[gpd.GeoDataFrame] = None, + zoom_level: Optional[int | tuple] = None, buffer: Union[float, int] = 0, - align: bool = None, - variables: Union[List, str] = None, - time_tuple: Tuple = None, - single_var_as_array: bool = True, - provider: Optional[str] = None, - data_version: Optional[str] = None, + align: Optional[bool] = None, + variables: Optional[Union[List, str]] = None, + time_tuple: Optional[Tuple] = None, + single_var_as_array: Optional[bool] = True, **kwargs, ) -> xr.Dataset: """Return a clipped, sliced and unified RasterDataset. 
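Note: ``SourceSpecDict`` lets the ``data_like`` argument of the ``get_*`` methods be a dict selecting a specific provider and version. A usage sketch of the behaviour as it ends up at the end of this series (earlier commits in the series still pass ``data_version`` to ``get_source``); the catalog file name is illustrative and assumed to exist:

.. code-block:: python

    from hydromt.data_catalog import DataCatalog

    cat = DataCatalog(data_libs=["merged_esa_worldcover.yml"])

    # resolved internally via get_source(**data_like); equivalent to
    # cat.get_source("esa_worldcover", provider="aws", version=2020)
    ds = cat.get_rasterdataset(
        {"source": "esa_worldcover", "provider": "aws", "version": 2020}
    )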
@@ -847,18 +851,22 @@ def get_rasterdataset( name = basename(data_like).split(".")[0] source = RasterDatasetAdapter(path=path, **kwargs) self.update(**{name: source}) - elif data_like in self.sources: + elif isinstance(data_like, dict): + if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): + unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ + raise ValueError( + f"Unknown keys in requested data source: {unknown_keys}" + ) + else: + source = self.get_source(**data_like) + name = source.name + elif isinstance(data_like, str) and data_like in self.sources: name = data_like - source = self.sources[name] + source = self.get_source(name) else: raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source( - name, - provider=provider, - data_version=data_version, - ) self.logger.info( f"DataCatalog: Getting {name} RasterDataset {source.driver} data from" f" {source.path}" @@ -879,14 +887,12 @@ def get_rasterdataset( def get_geodataframe( self, - data_like: Union[str, Path, gpd.GeoDataFrame], - bbox: List = None, - geom: gpd.GeoDataFrame = None, + data_like: Union[str, SourceSpecDict, Path, xr.Dataset, xr.DataArray], + bbox: Optional[List] = None, + geom: Optional[gpd.GeoDataFrame] = None, buffer: Union[float, int] = 0, - variables: Union[List, str] = None, + variables: Optional[Union[List, str]] = None, predicate: str = "intersects", - provider=None, - data_version=None, **kwargs, ): """Return a clipped and unified GeoDataFrame (vector). @@ -937,18 +943,22 @@ def get_geodataframe( name = basename(data_like).split(".")[0] source = GeoDataFrameAdapter(path=path, **kwargs) self.update(**{name: source}) - elif data_like in self.sources: + elif isinstance(data_like, dict): + if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): + unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ + raise ValueError( + f"Unknown keys in requested data source: {unknown_keys}" + ) + else: + source = self.get_source(**data_like) + name = source.name + elif isinstance(data_like, str) and data_like in self.sources: name = data_like - source = self.sources[name] + source = self.get_source(name) else: raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source( - name, - provider=provider, - data_version=data_version, - ) self.logger.info( f"DataCatalog: Getting {name} GeoDataFrame {source.driver} data" f" from {source.path}" @@ -965,15 +975,13 @@ def get_geodataframe( def get_geodataset( self, - data_like: Union[Path, str, xr.DataArray, xr.Dataset], - bbox: List = None, - geom: gpd.GeoDataFrame = None, + data_like: Union[str, SourceSpecDict, Path, xr.Dataset, xr.DataArray], + bbox: Optional[List] = None, + geom: Optional[gpd.GeoDataFrame] = None, buffer: Union[float, int] = 0, - variables: List = None, - time_tuple: Tuple = None, + variables: Optional[List] = None, + time_tuple: Optional[Tuple] = None, single_var_as_array: bool = True, - provider=None, - data_version=None, **kwargs, ) -> xr.Dataset: """Return a clipped, sliced and unified GeoDataset. 
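Note: the dict-validation block above is repeated verbatim in ``get_rasterdataset``, ``get_geodataframe``, ``get_geodataset`` and ``get_dataframe``. A possible factoring with the same behaviour, not part of this patch and assuming ``SourceSpecDict`` remains importable from ``hydromt.data_catalog``:

.. code-block:: python

    from hydromt.data_catalog import DataCatalog, SourceSpecDict

    def get_source_from_spec(cat: DataCatalog, data_like: dict):
        """Resolve a SourceSpecDict-style dict to a single data source."""
        unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__
        if unknown_keys:
            raise ValueError(f"Unknown keys in requested data source: {unknown_keys}")
        return cat.get_source(**data_like)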
@@ -1030,18 +1038,22 @@ def get_geodataset( name = basename(data_like).split(".")[0] source = GeoDatasetAdapter(path=path, **kwargs) self.update(**{name: source}) - elif data_like in self.sources: + elif isinstance(data_like, dict): + if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): + unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ + raise ValueError( + f"Unknown keys in requested data source: {unknown_keys}" + ) + else: + source = self.get_source(**data_like) + name = source.name + elif isinstance(data_like, str) and data_like in self.sources: name = data_like - source = self.sources[name] + source = self.get_source(name) else: raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source( - name, - provider=provider, - data_version=data_version, - ) self.logger.info( f"DataCatalog: Getting {name} GeoDataset {source.driver} data" f" from {source.path}" @@ -1058,11 +1070,9 @@ def get_geodataset( def get_dataframe( self, - data_like: Union[str, Path, pd.DataFrame], - variables: list = None, - time_tuple: tuple = None, - provider=None, - data_version=None, + data_like: Union[str, SourceSpecDict, Path, xr.Dataset, xr.DataArray], + variables: Optional[list] = None, + time_tuple: Optional[Tuple] = None, **kwargs, ): """Return a unified and sliced DataFrame. @@ -1098,18 +1108,22 @@ def get_dataframe( name = basename(data_like).split(".")[0] source = DataFrameAdapter(path=path, **kwargs) self.update(**{name: source}) + elif isinstance(data_like, dict): + if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): + unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ + raise ValueError( + f"Unknown keys in requested data source: {unknown_keys}" + ) + else: + source = self.get_source(**data_like) + name = source.name elif data_like in self.sources: name = data_like - source = self.sources[name] + source = self.get_source(name) else: raise FileNotFoundError(f"No such file or catalog key: {data_like}") self._used_data.append(name) - source = self.get_source( - name, - provider=provider, - data_version=data_version, - ) self.logger.info( f"DataCatalog: Getting {name} DataFrame {source.driver} data" f" from {source.path}" @@ -1231,7 +1245,7 @@ def _parse_data_dict( **source, ) - return data + return data def _yml_from_uri_or_path(uri_or_path: Union[Path, str]) -> Dict: @@ -1286,6 +1300,17 @@ def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]: else: dicts.append({name: source}) + for d in dicts: + if "placeholders" in d: + # pop avoid placeholders being passed to adapter + options = d.pop("placeholders") + for combination in itertools.product(*options.values()): + path_n = d["path"] + name_n = d["name"] + for k, v in zip(options.keys(), combination): + path_n = path_n.replace("{" + k + "}", v) + name_n = name_n.replace("{" + k + "}", v) + return dicts diff --git a/hydromt/models/model_api.py b/hydromt/models/model_api.py index 9ba8c80f6..99d29a14f 100644 --- a/hydromt/models/model_api.py +++ b/hydromt/models/model_api.py @@ -1,4 +1,4 @@ -## will be deprecated -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- """General and basic API for models in HydroMT.""" import glob diff --git a/hydromt/models/model_grid.py b/hydromt/models/model_grid.py index ff467c18b..d96323ce5 100644 --- a/hydromt/models/model_grid.py +++ b/hydromt/models/model_grid.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -class GridMixin(Model): +class 
GridMixin(object): # placeholders # xr.Dataset representation of all static parameter maps at the same resolution and # bounds - renamed from staticmaps diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index 0b90b5272..7062a1845 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -182,6 +182,7 @@ def test_rasterdataset_unit_attrs(artifact_data: DataCatalog): assert raster["temp_max"].attrs["long_name"] == attrs["temp_max"]["long_name"] +# @pytest.mark.skip() def test_geodataset(geoda, geodf, ts, tmpdir): # this test can sometimes hang because of threading issues therefore # the synchronous scheduler here is necessary diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index e228af6b3..9fec6e486 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -103,8 +103,9 @@ def test_versioned_catalogs(tmpdir): # make sure the catalogs individually still work legacy_yml_fn = join(DATADIR, "legacy_esa_worldcover.yml") legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn]) - assert legacy_data_catalog.get_source("esa_worldcover").path.endswith( - "landuse/esa_worldcover/esa-worldcover.vrt" + assert ( + Path(legacy_data_catalog.get_source("esa_worldcover").path).name + == "esa-worldcover.vrt" ) assert legacy_data_catalog.get_source("esa_worldcover").data_version == 2020 @@ -161,9 +162,15 @@ def test_versioned_catalogs(tmpdir): read_merged_catalog.get_source("esa_worldcover", provider="aws").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) - assert read_merged_catalog.get_source( - "esa_worldcover", provider="local" - ).path.endswith("landuse/esa_worldcover_2021/esa-worldcover.vrt") + assert ( + Path( + read_merged_catalog.get_source("esa_worldcover", provider="local").path + ).name + == "esa-worldcover.vrt" + ) + + # make sure dataframe doesn't merge different variants + assert len(read_merged_catalog.to_dataframe()) == 2 # Make sure we can queiry for the version we want aws_and_legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) @@ -176,9 +183,14 @@ def test_versioned_catalogs(tmpdir): aws_and_legacy_data_catalog.get_source("esa_worldcover", provider="aws").path == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" ) - assert aws_and_legacy_data_catalog.get_source( - "esa_worldcover", provider="legacy_esa_worldcover" - ).path.endswith("landuse/esa_worldcover/esa-worldcover.vrt") + assert ( + Path( + aws_and_legacy_data_catalog.get_source( + "esa_worldcover", provider="legacy_esa_worldcover" + ).path + ).name + == "esa-worldcover.vrt" + ) _ = aws_and_legacy_data_catalog.to_dict() From 9e59f6fc1c913288aa9c23427b3996de76e1c144 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 27 Jul 2023 15:51:00 +0200 Subject: [PATCH 23/27] remove stray debug statement --- hydromt/data_catalog.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 60f9ab60f..ea9557fcf 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -639,13 +639,13 @@ def to_dict( def to_dataframe(self, source_names: List = []) -> pd.DataFrame: """Return data catalog summary as DataFrame.""" - seq = [ - (name, source) - for name, source in self.iter_sources() - if len(source_names) == 0 or name in source_names - ] - print(seq) - return pd.DataFrame.from_records(seq) + return pd.DataFrame.from_records( + [ + (name, source) + for name, source in self.iter_sources() + if len(source_names) 
== 0 or name in source_names + ] + ) def export_data( self, From 3f814a43921ab796fd51330b342a8413784e406c Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 27 Jul 2023 15:58:04 +0200 Subject: [PATCH 24/27] move type to explicit Union --- hydromt/data_catalog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index ea9557fcf..258b993d8 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -43,7 +43,7 @@ # just for typehints SourceSpecDict = TypedDict( - "SourceSpecDict", {"source": str, "provider": str, "version": str | int} + "SourceSpecDict", {"source": str, "provider": str, "version": Union[str, int]} ) From 7bf4de374ddef5e7dbbe41f53ea6e9b7eaa74ae1 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 27 Jul 2023 16:44:13 +0200 Subject: [PATCH 25/27] fix to_dataframe --- hydromt/data_catalog.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 258b993d8..9abdccff4 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -639,13 +639,19 @@ def to_dict( def to_dataframe(self, source_names: List = []) -> pd.DataFrame: """Return data catalog summary as DataFrame.""" - return pd.DataFrame.from_records( - [ - (name, source) - for name, source in self.iter_sources() - if len(source_names) == 0 or name in source_names - ] - ) + d = [] + for name, source in self.iter_sources(): + if len(source_names) > 0 and name not in source_names: + continue + d.append( + { + "name": name, + "provider": source.provider, + "data_version": source.data_version, + **source.summary(), + } + ) + return pd.DataFrame.from_records(d).set_index("name") def export_data( self, From 11296c4126e60f489767f3a003c20fd8cdf9b0a4 Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 27 Jul 2023 16:52:57 +0200 Subject: [PATCH 26/27] fix docs warning --- docs/dev/dev_install.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/dev_install.rst b/docs/dev/dev_install.rst index 35ce98d90..694e1a572 100644 --- a/docs/dev/dev_install.rst +++ b/docs/dev/dev_install.rst @@ -99,7 +99,7 @@ Finally, create a developer installation of HydroMT: see :ref:`installation guide ` for the difference between both. Fine tuned installation ----------------------- +----------------------- If you want a more fine tuned installation you can also specify exactly which dependency groups you'd like. 
For instance, this will create an environment From b907c9072b1fc750f7d91e22f30f69ac08c9f6ef Mon Sep 17 00:00:00 2001 From: Dirk Eilander Date: Sun, 30 Jul 2023 23:37:24 +0200 Subject: [PATCH 27/27] fix inconsistent argument names; fix to_dict with more than 2 variants; return newest version by default --- docs/api.rst | 9 +- docs/user_guide/data_prepare_cat.rst | 77 ++- docs/user_guide/model_config.rst | 6 + hydromt/data_adapter/data_adapter.py | 11 +- hydromt/data_adapter/dataframe.py | 7 +- hydromt/data_adapter/geodataframe.py | 7 +- hydromt/data_adapter/geodataset.py | 10 +- hydromt/data_adapter/rasterdataset.py | 8 +- hydromt/data_catalog.py | 758 +++++++++++++++----------- tests/data/aws_esa_worldcover.yml | 2 +- tests/data/legacy_esa_worldcover.yml | 2 +- tests/data/merged_esa_worldcover.yml | 6 +- tests/test_data_adapter.py | 14 +- tests/test_data_catalog.py | 266 +++++---- 14 files changed, 693 insertions(+), 490 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 74d1a6fb6..42228ec58 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -22,6 +22,8 @@ General :toctree: _generated data_catalog.DataCatalog + data_catalog.DataCatalog.get_source + data_catalog.DataCatalog.iter_sources data_catalog.DataCatalog.sources data_catalog.DataCatalog.keys data_catalog.DataCatalog.predefined_catalogs @@ -36,12 +38,13 @@ Add data sources .. autosummary:: :toctree: _generated - data_catalog.DataCatalog.set_predefined_catalogs + data_catalog.DataCatalog.add_source + data_catalog.DataCatalog.update data_catalog.DataCatalog.from_predefined_catalogs data_catalog.DataCatalog.from_archive data_catalog.DataCatalog.from_yml data_catalog.DataCatalog.from_dict - data_catalog.DataCatalog.update + data_catalog.DataCatalog.set_predefined_catalogs .. _api_data_catalog_get: @@ -54,7 +57,7 @@ Get data data_catalog.DataCatalog.get_rasterdataset data_catalog.DataCatalog.get_geodataset data_catalog.DataCatalog.get_geodataframe - + data_catalog.DataCatalog.get_dataframe RasterDataset diff --git a/docs/user_guide/data_prepare_cat.rst b/docs/user_guide/data_prepare_cat.rst index 078a851cc..53494c267 100644 --- a/docs/user_guide/data_prepare_cat.rst +++ b/docs/user_guide/data_prepare_cat.rst @@ -31,6 +31,7 @@ The ``rename``, ``nodata``, ``unit_add`` and ``unit_mult`` options are set per v meta: root: /path/to/data_root/ version: version + name: data_catalog_name my_dataset: crs: EPSG/WKT data_type: RasterDataset/GeoDataset/GeoDataFrame/DataFrame @@ -48,8 +49,6 @@ The ``rename``, ``nodata``, ``unit_add`` and ``unit_mult`` options are set per v nodata: new_variable_name: value path: /absolut_path/to/my_dataset.extension OR relative_path/to_my_dataset.extension - placeholders: - [placeholder_key: [placeholder_values]] rename: old_variable_name: new_variable_name unit_add: @@ -91,6 +90,10 @@ A full list of **optional data source arguments** is given below - **filesystem** (required if different than local): specify if the data is stored locally or remotely (e.g cloud). Supported filesystems are *local* for local data, *gcs* for data stored on Google Cloud Storage, and *aws* for data stored on Amazon Web Services. Profile or authentication information can be passed to ``driver_kwargs`` via *storage_options*. +- **version** (recommended): data source version + *NOTE*: New in HydroMT v0.8.1 +- **provider** (recommended): data source provider + *NOTE*: New in HydroMT v0.8.1 - **meta** (recommended): additional information on the dataset organized in a sub-list. 
Good meta data includes a *source_url*, *source_license*, *source_version*, *paper_ref*, *paper_doi*, *category*, etc. These are added to the data attributes. Usual categories within HydroMT are *geography*, *topography*, *hydrography*, *meteo*, *landuse*, *ocean*, *socio-economic*, *observed data* @@ -103,25 +106,81 @@ A full list of **optional data source arguments** is given below - **unit_mult**: multiply the input data by a value for unit conversion (e.g. 1000 for conversion from m to mm of precipitation). - **attrs** (optional): This argument allows for setting attributes like the unit or long name to variables. *NOTE*: New in HydroMT v0.7.2 +- **placeholder** (optional): this argument can be used to generate multiple sources with a single entry in the data catalog file. If different files follow a logical + nomenclature, multiple data sources can be defined by iterating through all possible combinations of the placeholders. The placeholder names should be given in the + source name and the path and its values listed under the placeholder argument. +- **variants** (optional): This argument can be used to generate multiple sources with the same name, but from different providers or versions. + Any keys here are essentially used to extend/overwrite the base arguments. + +The following are **optional data source arguments** for *RasterDataset*, *GeoDataFrame*, and *GeoDataset*: + - **crs** (required if missing in the data): EPSG code or WKT string of the reference coordinate system of the data. Only used if not crs can be inferred from the input data. + +The following are **optional data source arguments** for *RasterDataset*: + - **zoom_level** (optional): this argument can be used for a *RasterDatasets* that contain multiple zoom levels of different resolution. It should contain a list of numeric zoom levels that correspond to the `zoom_level` key in file path, e.g., ``"path/to/my/files/{zoom_level}/data.tif"`` and corresponding resolution, expressed in the unit of the data crs. The *crs* argument is therefore required when using zoom_levels to correctly interpret the unit of the resolution. The required zoom level can be requested from HydroMT as argument to the `DataCatalog.get_rasterdataset` method, see `Reading tiled raster data with different zoom levels <../_examples/working_with_tiled_raster_data.ipynb>`_. -- **placeholder** (optional): this argument can be used to generate multiple sources with a single entry in the data catalog file. If different files follow a logical - nomenclature, multiple data sources can be defined by iterating through all possible combinations of the placeholders. The placeholder names should be given in the - source name and the path and its values listed under the placeholder argument. -- **variants** (optional): If you want to use the same data source but load it from different places (e.g. local & aws) you can add this key - Keys here are essentially overrides that will get applied to the containing catalog when they get parsed and expanded. .. note:: - The **alias** argument will be deprecated and should no longer be used, see `github issue for more information `_ + The **alias** argument will be deprecated and should no longer be used, see + `github issue for more information `_ .. warning:: - Using cloud data is still experimental and only supported for *DataFrame*, *RasterDataset* and *Geodataset* with *zarr*. 
*RasterDataset* with *raster* driver is also possible + Using cloud data is still experimental and only supported for *DataFrame*, *RasterDataset* and + *Geodataset* with *zarr*. *RasterDataset* with *raster* driver is also possible but in case of multiple files (mosaic) we strongly recommend using a vrt file for speed and computation efficiency. + +Data variants +------------- + +Data variants are used to define multiple data sources with the same name, but from different providers or versions. +Below, we show an example of a data catalog for a RasterDataset with multiple variants of the same data source (esa_worldcover), +but this works identical for other data types. +Here, the *crs*, *data_type*, *driver* and *filesystem* are common arguments used for all variants. +The variant arguments are used to extend and/or overwrite the common arguments, creating new sources. + +.. code-block:: yaml + + esa_worldcover: + crs: 4326 + data_type: RasterDataset + driver: raster + filesystem: local + variants: + - provider: local + version: 2021 + path: landuse/esa_worldcover_2021/esa-worldcover.vrt + - provider: local + version: 2020 + path: landuse/esa_worldcover/esa-worldcover.vrt + - provider: aws + version: 2020 + path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt + filesystem: s3 + + +To request a specific variant, the variant arguments can be used as keyword arguments +to the `DataCatalog.get_rasterdataset` method, see code below. +By default the newest version from the last provider is returned when requesting a data +source with specific version or provider. +Requesting a specific version from a HydroMT configuration file is also possible, see :ref:`model_config`. + +.. code-block:: python + + from hydromt import DataCatalog + dc = DataCatalog.from_yml("data_catalog.yml") + # get the default version. This will return the latest (2020) version from the last provider (aws) + ds = dc.get_rasterdataset("esa_worldcover") + # get a 2020 version. This will return the 2020 version from the last provider (aws) + ds = dc.get_rasterdataset("esa_worldcover", version=2020) + # get a 2021 version. This will return the 2021 version from the local provider as this verion is not available from aws . + ds = dc.get_rasterdataset("esa_worldcover", version=2021) + # get the 2020 version from the local provider + ds = dc.get_rasterdataset("esa_worldcover", version=2020, provider="local") diff --git a/docs/user_guide/model_config.rst b/docs/user_guide/model_config.rst index 0a21d0e02..c274ad757 100644 --- a/docs/user_guide/model_config.rst +++ b/docs/user_guide/model_config.rst @@ -54,3 +54,9 @@ An example .yaml file is shown below. 
Note that this .yaml file does not apply t setup_manning_roughness: lulc_fn: globcover # source name of landuse-landcover data mapping_fn: globcover_mapping # source name of mapping table converting lulc classes to N values + + setup_infiltration: + soil_fn: + source: soil_data # source name of soil data with specific version + version: 1.0 # version of soil data + mapping_fn: soil_mapping # source name of mapping table converting soil classes to infiltration parameters diff --git a/hydromt/data_adapter/data_adapter.py b/hydromt/data_adapter/data_adapter.py index 200fcd5df..2ae7eea49 100644 --- a/hydromt/data_adapter/data_adapter.py +++ b/hydromt/data_adapter/data_adapter.py @@ -127,7 +127,7 @@ def __init__( name="", # optional for now catalog_name="", # optional for now provider: Optional[str] = None, - data_version: Optional[str] = None, + version: Optional[str] = None, ): """General Interface to data source for HydroMT. @@ -174,7 +174,7 @@ def __init__( self.name = name self.catalog_name = catalog_name self.provider = provider - self.data_version = data_version + self.version = str(version) if version is not None else None # version as str # general arguments self.path = path # driver and driver keyword-arguments @@ -232,6 +232,13 @@ def __repr__(self): """Pretty print string representation of self.""" return self.__str__() + def __eq__(self, other: object) -> bool: + """Return True if self and other are equal.""" + if type(other) is type(self): + return self.to_dict() == other.to_dict() + else: + return False + def _parse_zoom_level( self, zoom_level: int | tuple = None, diff --git a/hydromt/data_adapter/dataframe.py b/hydromt/data_adapter/dataframe.py index a3c8504b7..55f79c668 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -39,7 +39,7 @@ def __init__( name: str = "", # optional for now catalog_name: str = "", # optional for now provider: Optional[str] = None, - data_version: Optional[str] = None, + version: Optional[str] = None, **kwargs, ): """Initiate data adapter for 2D tabular data. @@ -109,7 +109,7 @@ def __init__( name=name, catalog_name=catalog_name, provider=provider, - data_version=data_version, + version=version, ) def to_file( @@ -202,9 +202,6 @@ def get_data( _ = self.resolve_paths(**so_kwargs) # throw nice error if data not found kwargs = self.driver_kwargs.copy() - # these are just for internal bookeeping. drivers don't need them - _ = kwargs.pop("provider", None) - _ = kwargs.pop("data_version", None) # read and clip logger.info(f"DataFrame: Read {self.driver} data.") diff --git a/hydromt/data_adapter/geodataframe.py b/hydromt/data_adapter/geodataframe.py index e05a728aa..a08bba869 100644 --- a/hydromt/data_adapter/geodataframe.py +++ b/hydromt/data_adapter/geodataframe.py @@ -47,7 +47,7 @@ def __init__( name: str = "", # optional for now catalog_name: str = "", # optional for now provider=None, - data_version=None, + version=None, **kwargs, ): """Initiate data adapter for geospatial vector data. @@ -119,7 +119,7 @@ def __init__( name=name, catalog_name=catalog_name, provider=provider, - data_version=data_version, + version=version, ) self.crs = crs @@ -220,9 +220,6 @@ def get_data( _ = self.resolve_paths() # throw nice error if data not found kwargs = self.driver_kwargs.copy() - # these are just for internal bookeeping. 
drivers don't need them - _ = kwargs.pop("provider", None) - _ = kwargs.pop("data_version", None) # parse geom, bbox and buffer arguments clip_str = "" if geom is None and bbox is not None: diff --git a/hydromt/data_adapter/geodataset.py b/hydromt/data_adapter/geodataset.py index 50945ce20..635c0da13 100644 --- a/hydromt/data_adapter/geodataset.py +++ b/hydromt/data_adapter/geodataset.py @@ -48,7 +48,7 @@ def __init__( name: str = "", # optional for now catalog_name: str = "", # optional for now provider=None, - data_version=None, + version=None, **kwargs, ): """Initiate data adapter for geospatial timeseries data. @@ -126,7 +126,7 @@ def __init__( name=name, catalog_name=catalog_name, provider=provider, - data_version=data_version, + version=version, ) self.crs = crs @@ -259,12 +259,6 @@ def get_data( ) kwargs = self.driver_kwargs.copy() - # these are just for internal bookeeping. drivers don't need them - _ = kwargs.pop( - "provider", - None, - ) - _ = kwargs.pop("data_version", None) # parse geom, bbox and buffer arguments clip_str = "" if geom is None and bbox is not None: diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index abe4acb8a..ee46055b2 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -51,7 +51,7 @@ def __init__( name: str = "", # optional for now catalog_name: str = "", # optional for now provider=None, - data_version=None, + version=None, **kwargs, ): """Initiate data adapter for geospatial raster data. @@ -130,7 +130,7 @@ def __init__( name=name, catalog_name=catalog_name, provider=provider, - data_version=data_version, + version=version, ) self.crs = crs self.zoom_levels = zoom_levels @@ -275,9 +275,7 @@ def get_data( ) kwargs = self.driver_kwargs.copy() - # these are just for internal bookeeping. drivers don't need them - _ = kwargs.pop("provider", None) - _ = kwargs.pop("data_version", None) + # zarr can use storage options directly, the rest should be converted to # file-like objects if "storage_options" in kwargs and self.driver == "raster": diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 9abdccff4..66ac96de6 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -5,7 +5,6 @@ from __future__ import annotations import copy -import inspect import itertools import logging import os @@ -13,7 +12,7 @@ import warnings from os.path import abspath, basename, exists, isdir, isfile, join from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union +from typing import Dict, List, Optional, Tuple, TypedDict, Union import geopandas as gpd import numpy as np @@ -155,78 +154,126 @@ def predefined_catalogs(self) -> Dict: self.set_predefined_catalogs() return self._catalogs - def get_source(self, key: str, provider=None, data_version=None) -> DataAdapter: - """Get the source.""" - if key not in self._sources: + def get_source( + self, source: str, provider: Optional[str] = None, version: Optional[str] = None + ) -> DataAdapter: + """Return a data source. + + Parameters + ---------- + source : str + Name of the data source. + provider : str, optional + Name of the data provider, by default None. + By default the last added provider is returned. + version : str, optional + Version of the data source, by default None. + By default the newest version of the requested provider is returned. + + Returns + ------- + DataAdapter + DataAdapter object. 
+ """ + source = str(source) + if source not in self._sources: available_sources = sorted(list(self._sources.keys())) raise KeyError( - f"Requested unknown data source: '{key}' " + f"Requested unknown data source '{source}' " f"available sources are: {available_sources}" ) + available_providers = self._sources[source] + + # make sure all arguments are strings + provider = str(provider) if provider is not None else provider + version = str(version) if version is not None else version - available_providers = self._sources[key] + # find provider matching requested version + if provider is None and version is not None: + providers = [p for p, v in available_providers.items() if version in v] + if len(providers) > 0: # error raised later if no provider found + provider = providers[-1] + # check if provider is available, otherwise use last added provider if provider is None: requested_provider = list(available_providers.keys())[-1] else: requested_provider = provider + if requested_provider not in available_providers: + providers = sorted(list(available_providers.keys())) + raise KeyError( + f"Requested unknown provider '{requested_provider}' for " + f"data source '{source}' available providers are: {providers}" + ) + available_versions = available_providers[requested_provider] - if requested_provider not in available_providers: - providers = sorted(list(available_providers.keys())) - raise KeyError( - f"Requested unknown proveder '{requested_provider}' for data_source" - f" '{key}' available providers are {providers}" - ) - - available_data_versions = available_providers[requested_provider] - - if data_version is None: - requested_data_version = list(available_data_versions.keys())[-1] + # check if version is available, otherwise use last added version which is + # always the newest version + if version is None: + requested_version = list(available_versions.keys())[-1] else: - requested_data_version = data_version + requested_version = version + if requested_version not in available_versions: + data_versions = sorted(list(map(str, available_versions.keys()))) + raise KeyError( + f"Requested unknown version '{requested_version}' for " + f"data source '{source}' and provider '{requested_provider}' " + f"available versions are {data_versions}" + ) - if requested_data_version not in available_data_versions: - data_versions = sorted(list(map(str, available_data_versions.keys()))) - raise KeyError( - f"Requested unknown data_version '{requested_data_version}' for" - f" data_source '{key}' and provider '{requested_provider}'" - f" available data_versions are {data_versions}" - ) + return self._sources[source][requested_provider][requested_version] + + def add_source(self, source: str, adapter: DataAdapter) -> None: + """Add a new data source to the data catalog. - return self._sources[key][requested_provider][requested_data_version] + The data version and provider are extracted from the DataAdapter object. - def add_source(self, key: str, adapter: DataAdapter) -> None: - """Add a new data source to the data catalog.""" + Parameters + ---------- + source : str + Name of the data source. + adapter : DataAdapter + DataAdapter object. 
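A small sketch of the bookkeeping add_source performs, assuming two local versions of the same dataset are registered; the adapter paths are placeholders:

    from hydromt.data_catalog import DataCatalog
    from hydromt.data_adapter import RasterDatasetAdapter

    cat = DataCatalog()
    cat.add_source(
        "esa_worldcover",
        RasterDatasetAdapter(path="esa-worldcover-2020.vrt", provider="local", version="2020"),
    )
    cat.add_source(
        "esa_worldcover",
        RasterDatasetAdapter(path="esa-worldcover-2021.vrt", provider="local", version="2021"),
    )
    # versions are kept sorted per provider, so the newest ("2021") is returned by default
    assert cat.get_source("esa_worldcover").version == "2021"
    assert cat.get_source("esa_worldcover", version="2020").version == "2020"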
+ """ if not isinstance(adapter, DataAdapter): raise ValueError("Value must be DataAdapter") - if hasattr(adapter, "data_version") and adapter.data_version is not None: - data_version = adapter.data_version + if hasattr(adapter, "version") and adapter.version is not None: + version = adapter.version else: - data_version = "UNSPECIFIED" + version = "_UNSPECIFIED_" # make sure this comes first in sorted list if hasattr(adapter, "provider") and adapter.provider is not None: provider = adapter.provider else: provider = adapter.catalog_name - if key not in self._sources: - self._sources[key] = {} - - if provider not in self._sources[key]: - self._sources[key][provider] = {} + if source not in self._sources: + self._sources[source] = {} + else: # check if data type is the same as adapter with same name + adapter0 = next(iter(next(iter(self._sources[source].values())).values())) + if adapter0.data_type != adapter.data_type: + raise ValueError( + f"Data source '{source}' already exists with data type " + f"'{adapter0.data_type}' but new data source has data type " + f"'{adapter.data_type}'." + ) - if ( - provider in self._sources[key] - and data_version in self._sources[key][provider] - ): - warnings.warn( - f"overwriting entry with provider: {provider} and version:" - f" {data_version} in {key} entry", - UserWarning, - ) + if provider not in self._sources[source]: + versions = {version: adapter} + else: + versions = self._sources[source][provider] + if provider in self._sources[source] and version in versions: + warnings.warn( + f"overwriting data source '{source}' with " + f"provider {provider} and version {version}.", + UserWarning, + ) + # update and sort dictionary -> make sure newest version is last + versions.update({version: adapter}) + versions = {k: versions[k] for k in sorted(list(versions.keys()))} - self._sources[key][provider][data_version] = adapter + self._sources[source][provider] = versions def __getitem__(self, key: str) -> DataAdapter: """Get the source.""" @@ -250,8 +297,8 @@ def iter_sources(self) -> List[Tuple[str, DataAdapter]]: """Return a flat list of all available data sources with no duplicates.""" ans = [] for source_name, available_providers in self._sources.items(): - for _, available_data_versions in available_providers.items(): - for _, adapter in available_data_versions.items(): + for _, available_versions in available_providers.items(): + for _, adapter in available_versions.items(): ans.append((source_name, adapter)) return ans @@ -267,17 +314,29 @@ def __iter__(self): def __len__(self): """Return number of sources.""" - warnings.warn( - "Using len on DataCatalog directly is deprecated." - " Please use len(cat.iter_sources())", - DeprecationWarning, - ) return len(self.iter_sources()) def __repr__(self): """Prettyprint the sources.""" return self.to_dataframe().__repr__() + def __eq__(self, other) -> bool: + if type(other) is type(self): + if len(self) != len(other): + return False + for name, source in self.iter_sources(): + try: + other_source = other.get_source( + name, provider=source.provider, version=source.version + ) + except KeyError: + return False + if source != other_source: + return False + else: + return False + return True + def _repr_html_(self): return self.to_dataframe()._repr_html_() @@ -314,7 +373,7 @@ def set_predefined_catalogs(self, urlpath: Union[Path, str] = None) -> Dict: def from_artifacts( self, name: str = "artifact_data", version: str = "latest" - ) -> None: + ) -> DataCatalog: """Parse artifacts. Deprecated method. 
Use @@ -326,15 +385,35 @@ def from_artifacts( Catalog name. If None (default) sample data is downloaded. version : str, optional Release version. By default it takes the latest known release. + + Returns + ------- + DataCatalog + DataCatalog object with parsed artifact data. """ warnings.warn( '"from_artifacts" is deprecated. Use "from_predefined_catalogs instead".', DeprecationWarning, ) - self.from_predefined_catalogs(name, version) + return self.from_predefined_catalogs(name, version) - def from_predefined_catalogs(self, name: str, version: str = "latest") -> None: - """Generate a catalogue from one of the predefined ones.""" + def from_predefined_catalogs( + self, name: str, version: str = "latest" + ) -> DataCatalog: + """Add data sources from a predefined data catalog. + + Parameters + ---------- + name : str + Catalog name. + version : str, optional + Catlog release version. By default it takes the latest known release. + + Returns + ------- + DataCatalog + DataCatalog object with parsed predefined catalog added. + """ if "=" in name: name, version = name.split("=")[0], name.split("=")[-1] if name not in self.predefined_catalogs: @@ -355,12 +434,27 @@ def from_predefined_catalogs(self, name: str, version: str = "latest") -> None: self.from_archive(urlpath, name=name, version=version) else: self.logger.info(f"Reading data catalog {name} {version}") - self.from_yml(urlpath) + self.from_yml(urlpath, catalog_name=name) def from_archive( self, urlpath: Union[Path, str], version: str = None, name: str = None - ) -> None: - """Read a data archive including a data_catalog.yml file.""" + ) -> DataCatalog: + """Read a data archive including a data_catalog.yml file. + + Parameters + ---------- + urlpath : str, Path + Path or url to data archive. + version : str, optional + Version of data archive, by default None. + name : str, optional + Name of data catalog, by default None. + + Returns + ------- + DataCatalog + DataCatalog object with parsed data archive added. + """ name = basename(urlpath).split(".")[0] if name is None else name root = join(self._cache_dir, name) if version is not None: @@ -377,11 +471,15 @@ def from_archive( self.logger.debug(f"Unpacking data from {archive_fn}") shutil.unpack_archive(archive_fn, root) # parse catalog - self.from_yml(yml_fn) + return self.from_yml(yml_fn, catalog_name=name) def from_yml( - self, urlpath: Union[Path, str], root: str = None, mark_used: bool = False - ) -> None: + self, + urlpath: Union[Path, str], + root: str = None, + catalog_name: str = None, + mark_used: bool = False, + ) -> DataCatalog: """Add data sources based on yaml file. Parameters @@ -412,9 +510,8 @@ def from_yml( data_type: driver: filesystem: - kwargs: + driver_kwargs: : - crs: nodata: : rename: @@ -433,9 +530,11 @@ def from_yml( placeholders: : : - zoom_levels: - : - : + + Returns + ------- + DataCatalog + DataCatalog object with parsed yaml file added. 
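As a rough illustration of the catalog layout this method parses, a from_dict call that mirrors the yaml structure above; the root, path and source name are made up for the example:

    from hydromt.data_catalog import DataCatalog

    cat = DataCatalog().from_dict(
        {
            "meta": {"root": "/data", "category": "landuse"},
            "esa_worldcover": {
                "data_type": "RasterDataset",
                "driver": "raster",
                "path": "esa_worldcover/esa-worldcover.vrt",
                "provider": "local",
                "version": 2020,
            },
        }
    )
    src = cat.get_source("esa_worldcover", provider="local", version=2020)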
""" self.logger.info(f"Parsing data catalog from {urlpath}") yml = _yml_from_uri_or_path(urlpath) @@ -443,23 +542,31 @@ def from_yml( meta = dict() # legacy code with root/category at highest yml level if "root" in yml: + warnings.warn( + "The 'root' key is deprecated, use 'meta: root' instead.", + DeprecationWarning, + ) meta.update(root=yml.pop("root")) if "category" in yml: + warnings.warn( + "The 'category' key is deprecated, use 'meta: category' instead.", + DeprecationWarning, + ) meta.update(category=yml.pop("category")) + # read meta data meta = yml.pop("meta", meta) - self_version = Version(__version__) + # check version required hydromt version hydromt_version = meta.get("hydromt_version", __version__) + self_version = Version(__version__) yml_version = Version(hydromt_version) - if yml_version > self_version: self.logger.warning( f"Specified HydroMT version ({hydromt_version}) \ more recent than installed version ({__version__}).", ) - - catalog_name = meta.get("name", "".join(basename(urlpath).split(".")[:-1])) - + if catalog_name is None: + catalog_name = meta.get("name", "".join(basename(urlpath).split(".")[:-1])) if root is None: root = meta.get("root", os.path.dirname(urlpath)) self.from_dict( @@ -477,7 +584,7 @@ def from_dict( root: Union[str, Path] = None, category: str = None, mark_used: bool = False, - ) -> None: + ) -> DataCatalog: """Add data sources based on dictionary. Parameters @@ -507,15 +614,13 @@ def from_dict( "data_type": , "driver": , "filesystem": , - "kwargs": {: }, - "crs": , + "driver_kwargs": {: }, "nodata": , "rename": {: }, "unit_add": {: }, "unit_mult": {: }, "meta": {...}, "placeholders": {: }, - "zoom_levels": {: }, } : { ... @@ -523,14 +628,26 @@ def from_dict( } """ - data_dicts = _denormalise_data_dict(data_dict, catalog_name=catalog_name) - for d in data_dicts: - parsed_dict = _parse_data_dict( - d, catalog_name=catalog_name, root=root, category=category + meta = data_dict.pop("meta", {}) + if "root" in meta and root is None: + root = meta.pop("root") + if "category" in meta and category is None: + category = meta.pop("category") + if "name" in meta and catalog_name is None: + catalog_name = meta.pop("name") + for name, source_dict in _denormalise_data_dict(data_dict): + adapter = _parse_data_source_dict( + name, + source_dict, + catalog_name=catalog_name, + root=root, + category=category, ) - self.update(**parsed_dict) + self.add_source(name, adapter) if mark_used: - self._used_data.extend(list(parsed_dict.keys())) + self._used_data.append(name) + + return self def to_yml( self, @@ -624,12 +741,25 @@ def to_dict( if existing == source_dict: sources_out.update({name: source_dict}) continue - base, diff_existing, diff_new = partition_dictionaries( - source_dict, existing - ) - _ = base.pop("driver_kwargs", None) - - base["variants"] = ([diff_new, diff_existing],) + if "variants" in existing: + variants = existing.pop("variants") + _, variant, _ = partition_dictionaries(source_dict, existing) + variants.append(variant) + existing["variants"] = variants + else: + base, diff_existing, diff_new = partition_dictionaries( + source_dict, existing + ) + # provider and version should always be in variants list + provider = base.pop("provider", None) + if provider is not None: + diff_existing["provider"] = provider + diff_new["provider"] = provider + version = base.pop("version", None) + if version is not None: + diff_existing["version"] = version + diff_new["version"] = version + base["variants"] = [diff_new, diff_existing] sources_out[name] = base else: 
sources_out.update({name: source_dict}) @@ -647,7 +777,7 @@ def to_dataframe(self, source_names: List = []) -> pd.DataFrame: { "name": name, "provider": source.provider, - "data_version": source.data_version, + "version": source.version, **source.summary(), } ) @@ -705,14 +835,14 @@ def export_data( source = self.get_source(name) provider = source.provider - data_version = source.data_version + version = source.version if name not in sources: sources[name] = {} if provider not in sources[name]: sources[name][provider] = {} - sources[name][provider][data_version] = copy.deepcopy(source) + sources[name][provider][version] = copy.deepcopy(source) else: sources = copy.deepcopy(self.sources) @@ -727,8 +857,8 @@ def export_data( # export data and update sources for key, available_variants in sources.items(): - for provider, available_data_versions in available_variants.items(): - for data_version, source in available_data_versions.items(): + for provider, available_versions in available_variants.items(): + for version, source in available_versions.items(): try: # read slice of source and write to file self.logger.debug(f"Exporting {key}.") @@ -772,15 +902,15 @@ def export_data( if provider not in sources_out[key]: sources_out[key][provider] = {} - sources_out[key][provider][data_version] = source + sources_out[key][provider][version] = source except FileNotFoundError: self.logger.warning(f"{key} file not found at {source.path}") # write data catalog to yml data_catalog_out = DataCatalog() for key, available_variants in sources_out.items(): - for provider, available_data_versions in available_variants.items(): - for data_version, adapter in available_data_versions.items(): + for provider, available_versions in available_variants.items(): + for version, adapter in available_versions.items(): data_catalog_out.add_source(key, adapter) data_catalog_out.to_yml(fn, root="auto", meta=meta) @@ -796,6 +926,8 @@ def get_rasterdataset( variables: Optional[Union[List, str]] = None, time_tuple: Optional[Tuple] = None, single_var_as_array: Optional[bool] = True, + provider: Optional[str] = None, + version: Optional[str] = None, **kwargs, ) -> xr.Dataset: """Return a clipped, sliced and unified RasterDataset. @@ -812,10 +944,12 @@ def get_rasterdataset( Arguments --------- - data_like: str, Path, xr.Dataset, xr.Datarray - Data catalog key, path to raster file or raster xarray data object. + data_like: str, Path, Dict, xr.Dataset, xr.Datarray + DataCatalog key, path to raster file or raster xarray data object. + The catalog key can be a string or a dictionary with the following keys: + {'name', 'provider', 'version'}. If a path to a raster file is provided it will be added - to the data_catalog with its based on the file basename without extension. + to the catalog with its based on the file basename. bbox : array-like of floats (xmin, ymin, xmax, ymax) bounding box of area of interest (in WGS84 coordinates). @@ -838,6 +972,10 @@ def get_rasterdataset( single_var_as_array: bool, optional If True, return a DataArray if the dataset consists of a single variable. If False, always return a Dataset. By default True. + provider: str, optional + Data source provider. If None (default) the last added provider is used. + version: str, optional + Data source version. If None (default) the newest version is used. **kwargs: Additional keyword arguments that are passed to the `RasterDatasetAdapter` function. Only used if `data_like` is a path to a raster file. 
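The dictionary form of data_like can be sketched as follows; this mirrors the tests further below (note that the parser expects the key 'source', not 'name') and assumes the predefined artifact_data catalog can be downloaded. The same pattern applies to get_geodataframe, get_geodataset and get_dataframe:

    from hydromt.data_catalog import DataCatalog

    cat = DataCatalog("artifact_data")
    da = cat.get_rasterdataset(
        {"source": "koppen_geiger", "provider": "artifact_data"},
        bbox=[11.7, 45.3, 12.8, 46.2],  # illustrative bounding box
    )
    # equivalent keyword form
    da = cat.get_rasterdataset("koppen_geiger", provider="artifact_data")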
@@ -847,30 +985,29 @@ def get_rasterdataset( obj: xarray.Dataset or xarray.DataArray RasterDataset """ - if isinstance(data_like, (xr.DataArray, xr.Dataset)): - return data_like - elif not isinstance(data_like, (str, Path)): - raise ValueError(f'Unknown raster data type "{type(data_like).__name__}"') + if isinstance(data_like, dict): + data_like, provider, version = _parse_data_like_dict( + data_like, provider, version + ) - if data_like not in self.sources and exists(abspath(data_like)): - path = str(abspath(data_like)) - name = basename(data_like).split(".")[0] - source = RasterDatasetAdapter(path=path, **kwargs) - self.update(**{name: source}) - elif isinstance(data_like, dict): - if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): - unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ - raise ValueError( - f"Unknown keys in requested data source: {unknown_keys}" - ) + if isinstance(data_like, (str, Path)): + if isinstance(data_like, str) and data_like in self.sources: + name = data_like + source = self.get_source(name, provider=provider, version=version) + elif exists(abspath(data_like)): + path = str(abspath(data_like)) + if "provider" not in kwargs: + kwargs.update({"provider": "local"}) + source = RasterDatasetAdapter(path=path, **kwargs) + name = basename(data_like) + self.add_source(name, source) else: - source = self.get_source(**data_like) - name = source.name - elif isinstance(data_like, str) and data_like in self.sources: - name = data_like - source = self.get_source(name) + raise FileNotFoundError(f"No such file or catalog source: {data_like}") + elif isinstance(data_like, (xr.DataArray, xr.Dataset)): + # TODO apply bbox, geom, buffer, align, variables, time_tuple + return data_like else: - raise FileNotFoundError(f"No such file or catalog key: {data_like}") + raise ValueError(f'Unknown raster data type "{type(data_like).__name__}"') self._used_data.append(name) self.logger.info( @@ -899,6 +1036,8 @@ def get_geodataframe( buffer: Union[float, int] = 0, variables: Optional[Union[List, str]] = None, predicate: str = "intersects", + provider: Optional[str] = None, + version: Optional[str] = None, **kwargs, ): """Return a clipped and unified GeoDataFrame (vector). @@ -912,8 +1051,10 @@ def get_geodataframe( --------- data_like: str, Path, gpd.GeoDataFrame Data catalog key, path to vector file or a vector geopandas object. + The catalog key can be a string or a dictionary with the following keys: + {'name', 'provider', 'version'}. If a path to a vector file is provided it will be added - to the data_catalog with its based on the file basename without extension. + to the catalog with its based on the file basename. bbox : array-like of floats (xmin, ymin, xmax, ymax) bounding box of area of interest (in WGS84 coordinates). @@ -930,6 +1071,10 @@ def get_geodataframe( variables : str or list of str, optional. Names of GeoDataFrame columns to return. By default all columns are returned. + provider: str, optional + Data source provider. If None (default) the last added provider is used. + version: str, optional + Data source version. If None (default) the newest version is used. **kwargs: Additional keyword arguments that are passed to the `GeoDataFrameAdapter` function. Only used if `data_like` is a path to a vector file. 
@@ -939,30 +1084,28 @@ def get_geodataframe( gdf: geopandas.GeoDataFrame GeoDataFrame """ - if isinstance(data_like, gpd.GeoDataFrame): - return data_like - elif not isinstance(data_like, (str, Path)): - raise ValueError(f'Unknown vector data type "{type(data_like).__name__}"') - - if data_like not in self.sources and exists(abspath(data_like)): - path = str(abspath(data_like)) - name = basename(data_like).split(".")[0] - source = GeoDataFrameAdapter(path=path, **kwargs) - self.update(**{name: source}) - elif isinstance(data_like, dict): - if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): - unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ - raise ValueError( - f"Unknown keys in requested data source: {unknown_keys}" - ) + if isinstance(data_like, dict): + data_like, provider, version = _parse_data_like_dict( + data_like, provider, version + ) + if isinstance(data_like, (str, Path)): + if str(data_like) in self.sources: + name = data_like + source = self.get_source(name, provider=provider, version=version) + elif exists(abspath(data_like)): + path = str(abspath(data_like)) + if "provider" not in kwargs: + kwargs.update({"provider": "local"}) + source = GeoDataFrameAdapter(path=path, **kwargs) + name = basename(data_like) + self.add_source(name, source) else: - source = self.get_source(**data_like) - name = source.name - elif isinstance(data_like, str) and data_like in self.sources: - name = data_like - source = self.get_source(name) + raise FileNotFoundError(f"No such file or catalog source: {data_like}") + elif isinstance(data_like, gpd.GeoDataFrame): + # TODO apply bbox, geom, buffer, predicate, variables + return data_like else: - raise FileNotFoundError(f"No such file or catalog key: {data_like}") + raise ValueError(f'Unknown vector data type "{type(data_like).__name__}"') self._used_data.append(name) self.logger.info( @@ -988,6 +1131,8 @@ def get_geodataset( variables: Optional[List] = None, time_tuple: Optional[Tuple] = None, single_var_as_array: bool = True, + provider: Optional[str] = None, + version: Optional[str] = None, **kwargs, ) -> xr.Dataset: """Return a clipped, sliced and unified GeoDataset. @@ -1005,8 +1150,10 @@ def get_geodataset( --------- data_like: str, Path, xr.Dataset, xr.DataArray Data catalog key, path to geodataset file or geodataset xarray object. + The catalog key can be a string or a dictionary with the following keys: + {'name', 'provider', 'version'}. If a path to a file is provided it will be added - to the data_catalog with its based on the file basename without extension. + to the catalog with its based on the file basename. bbox : array-like of floats (xmin, ymin, xmax, ymax) bounding box of area of interest (in WGS84 coordinates). @@ -1014,8 +1161,6 @@ def get_geodataset( A geometry defining the area of interest. buffer : float, optional Buffer around the `bbox` or `geom` area of interest in meters. By default 0. - align : float, optional - Resolution to align the bounding box, by default None variables : str or list of str, optional. Names of GeoDataset variables to return. By default all dataset variables are returned. 
@@ -1034,30 +1179,28 @@ def get_geodataset( obj: xarray.Dataset or xarray.DataArray GeoDataset """ - if isinstance(data_like, (xr.DataArray, xr.Dataset)): - return data_like - elif not isinstance(data_like, (str, Path)): - raise ValueError(f'Unknown geo data type "{type(data_like).__name__}"') - - if data_like not in self.sources and exists(abspath(data_like)): - path = str(abspath(data_like)) - name = basename(data_like).split(".")[0] - source = GeoDatasetAdapter(path=path, **kwargs) - self.update(**{name: source}) - elif isinstance(data_like, dict): - if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): - unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ - raise ValueError( - f"Unknown keys in requested data source: {unknown_keys}" - ) + if isinstance(data_like, dict): + data_like, provider, version = _parse_data_like_dict( + data_like, provider, version + ) + if isinstance(data_like, (str, Path)): + if isinstance(data_like, str) and data_like in self.sources: + name = data_like + source = self.get_source(name, provider=provider, version=version) + elif exists(abspath(data_like)): + path = str(abspath(data_like)) + if "provider" not in kwargs: + kwargs.update({"provider": "local"}) + source = GeoDatasetAdapter(path=path, **kwargs) + name = basename(data_like) + self.add_source(name, source) else: - source = self.get_source(**data_like) - name = source.name - elif isinstance(data_like, str) and data_like in self.sources: - name = data_like - source = self.get_source(name) + raise FileNotFoundError(f"No such file or catalog source: {data_like}") + elif isinstance(data_like, (xr.DataArray, xr.Dataset)): + # TODO apply bbox, geom, buffer, variables, time_tuple + return data_like else: - raise FileNotFoundError(f"No such file or catalog key: {data_like}") + raise ValueError(f'Unknown geo data type "{type(data_like).__name__}"') self._used_data.append(name) self.logger.info( @@ -1079,6 +1222,8 @@ def get_dataframe( data_like: Union[str, SourceSpecDict, Path, xr.Dataset, xr.DataArray], variables: Optional[list] = None, time_tuple: Optional[Tuple] = None, + provider: Optional[str] = None, + version: Optional[str] = None, **kwargs, ): """Return a unified and sliced DataFrame. @@ -1086,9 +1231,11 @@ def get_dataframe( Parameters ---------- data_like : str, Path, pd.DataFrame - Data catalog key, path to tabular data file or tabular pandas dataframe - object. If a path to a tabular data file is provided it will be added - to the data_catalog with its based on the file basename without extension. + Data catalog key, path to tabular data file or tabular pandas dataframe. + The catalog key can be a string or a dictionary with the following keys: + {'name', 'provider', 'version'}. + If a path to a tabular data file is provided it will be added + to the catalog with its based on the file basename. variables : str or list of str, optional. Names of GeoDataset variables to return. By default all dataset variables are returned. 
@@ -1104,30 +1251,27 @@ def get_dataframe( pd.DataFrame Tabular data """ - if isinstance(data_like, pd.DataFrame): - return data_like - elif not isinstance(data_like, (str, Path)): - raise ValueError(f'Unknown tabular data type "{type(data_like).__name__}"') - - if data_like not in self.sources and exists(abspath(data_like)): - path = str(abspath(data_like)) - name = basename(data_like).split(".")[0] - source = DataFrameAdapter(path=path, **kwargs) - self.update(**{name: source}) - elif isinstance(data_like, dict): - if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): - unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ - raise ValueError( - f"Unknown keys in requested data source: {unknown_keys}" - ) + if isinstance(data_like, dict): + data_like, provider, version = _parse_data_like_dict( + data_like, provider, version + ) + if isinstance(data_like, (str, Path)): + if isinstance(data_like, str) and data_like in self.sources: + name = data_like + source = self.get_source(name, provider=provider, version=version) + elif exists(abspath(data_like)): + path = str(abspath(data_like)) + if "provider" not in kwargs: + kwargs.update({"provider": "local"}) + source = DataFrameAdapter(path=path, **kwargs) + name = basename(data_like) + self.add_source(name, source) else: - source = self.get_source(**data_like) - name = source.name - elif data_like in self.sources: - name = data_like - source = self.get_source(name) + raise FileNotFoundError(f"No such file or catalog source: {data_like}") + elif isinstance(data_like, pd.DataFrame): + return data_like else: - raise FileNotFoundError(f"No such file or catalog key: {data_like}") + raise ValueError(f'Unknown tabular data type "{type(data_like).__name__}"') self._used_data.append(name) self.logger.info( @@ -1142,8 +1286,26 @@ def get_dataframe( return obj -def _parse_data_dict( - data_dict: Dict, +def _parse_data_like_dict( + data_like: SourceSpecDict, + provider: Optional[str] = None, + version: Optional[str] = None, +): + if not SourceSpecDict.__required_keys__.issuperset(set(data_like.keys())): + unknown_keys = set(data_like.keys()) - SourceSpecDict.__required_keys__ + raise ValueError(f"Unknown keys in requested data source: {unknown_keys}") + elif "source" not in data_like: + raise ValueError("No source key found in requested data source") + else: + source = data_like.get("source") + provider = data_like.get("provider", provider) + version = data_like.get("version", version) + return source, provider, version + + +def _parse_data_source_dict( + name: str, + data_source_dict: Dict, catalog_name: str = "", root: Union[Path, str] = None, category: str = None, @@ -1156,102 +1318,45 @@ def _parse_data_dict( "GeoDataset": GeoDatasetAdapter, "DataFrame": DataFrameAdapter, } - if root is None: - root = data_dict.pop("root", None) - # parse data - data = dict() - for name, source in data_dict.items(): - source = source.copy() # important as we modify with pop - - if "path" not in source: - raise ValueError(f"{name}: Missing required path argument.") - data_type = source.pop("data_type", None) - if data_type is None: - raise ValueError(f"{name}: Data type missing.") - elif data_type not in ADAPTERS: - raise ValueError(f"{name}: Data type {data_type} unknown") - adapter = ADAPTERS.get(data_type) - # Only for local files - path = source.pop("path") - # if remote path, keep as is else call abs_path method to solve local files - if not _uri_validator(str(path)): - path = abs_path(root, path) - meta = source.pop("meta", 
{}) - if "category" not in meta and category is not None: - meta.update(category=category) - # Get unit attrs if given from source - attrs = source.pop("attrs", {}) - # lower kwargs for backwards compatability - - # arguments - driver_kwargs = source.pop("driver_kwargs", source.pop("kwargs", {})) - for driver_kwarg in driver_kwargs: - # required for geodataset where driver_kwargs can be a path - if "fn" in driver_kwarg: - driver_kwargs.update( - {driver_kwarg: abs_path(root, driver_kwargs[driver_kwarg])} - ) - for opt in source: - if "fn" in opt: # get absolute paths for file names - source.update({opt: abs_path(root, source[opt])}) - dict_catalog_name = source.pop("catalog_name", None) - if dict_catalog_name is not None and dict_catalog_name != catalog_name: - raise RuntimeError( - "catalog name passed as argument and differs from one in dictionary" - ) - - dict_name = source.pop("name", None) - if dict_name is not None and dict_name != name: - raise RuntimeError( - "Source name passed as argument and differs from one in dictionary" - ) - - dict_catalog_name = source.pop("catalog_name", None) - if dict_catalog_name is not None and dict_catalog_name != catalog_name: - raise RuntimeError( - "catalog name passed as argument and differs from one in dictionary" + source = data_source_dict.copy() # important as we modify with pop + + # parse path + if "path" not in source: + raise ValueError(f"{name}: Missing required path argument.") + # if remote path, keep as is else call abs_path method to solve local files + path = source.pop("path") + if not _uri_validator(str(path)): + path = abs_path(root, path) + # parse data type > adapter + data_type = source.pop("data_type", None) + if data_type is None: + raise ValueError(f"{name}: Data type missing.") + elif data_type not in ADAPTERS: + raise ValueError(f"{name}: Data type {data_type} unknown") + adapter = ADAPTERS.get(data_type) + # source meta data + meta = source.pop("meta", {}) + if "category" not in meta and category is not None: + meta.update(category=category) + + # driver arguments + driver_kwargs = source.pop("driver_kwargs", source.pop("kwargs", {})) + for driver_kwarg in driver_kwargs: + # required for geodataset where driver_kwargs can be a path + if "fn" in driver_kwarg: + driver_kwargs.update( + {driver_kwarg: abs_path(root, driver_kwargs[driver_kwarg])} ) - provider = source.pop("provider", None) - data_version = source.pop("data_version", None) - - if "placeholders" in source: - # pop avoid placeholders being passed to adapter - options = source.pop("placeholders") - for combination in itertools.product(*options.values()): - path_n = path - name_n = name - for k, v in zip(options.keys(), combination): - path_n = path_n.replace("{" + k + "}", v) - name_n = name_n.replace("{" + k + "}", v) - - data[name_n] = adapter( - path=path_n, - name=name_n, - catalog_name=catalog_name, - provider=provider, - data_version=data_version, - meta=meta, - attrs=attrs, - driver_kwargs=driver_kwargs, - **source, # key word arguments specific to certain adaptors - ) - - else: - data[name] = adapter( - path=path, - name=name, - catalog_name=catalog_name, - provider=provider, - data_version=data_version, - meta=meta, - attrs=attrs, - driver_kwargs=driver_kwargs, - **source, - ) - - return data + return adapter( + path=path, + name=name, + catalog_name=catalog_name, + meta=meta, + driver_kwargs=driver_kwargs, + **source, + ) def _yml_from_uri_or_path(uri_or_path: Union[Path, str]) -> Dict: @@ -1277,19 +1382,16 @@ def _process_dict(d: Dict, logger=logger) 
-> Dict: return d -def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]: - # first do a pass to expand possible versions - dicts = [] +def _denormalise_data_dict(data_dict) -> List[Tuple[str, Dict]]: + """Return a flat list of with data name, dictionary of input data_dict. + + Expand possible versions, aliases and variants in data_dict. + """ + data_list = [] for name, source in data_dict.items(): - if "variants" in source: - variants = source.pop("variants") - for diff in variants: - source_copy = copy.deepcopy(source) - diff["name"] = name - diff["catalog_name"] = catalog_name - source_copy.update(**diff) - dicts.append({name: source_copy}) - elif "alias" in source: + source = copy.deepcopy(source) + data_dicts = [] + if "alias" in source: alias = source.pop("alias") warnings.warn( "The use of alias is deprecated, please add a version on the aliased" @@ -1299,25 +1401,33 @@ def _denormalise_data_dict(data_dict, catalog_name="") -> List[Dict[str, Any]]: if alias not in data_dict: raise ValueError(f"alias {alias} not found in data_dict.") # use alias source but overwrite any attributes with original source - source_org = source.copy() - source = data_dict[alias].copy() - source.update(source_org) - dicts.append({name: source}) - else: - dicts.append({name: source}) - - for d in dicts: - if "placeholders" in d: - # pop avoid placeholders being passed to adapter - options = d.pop("placeholders") + source_copy = data_dict[alias].copy() + source_copy.update(source) + data_dicts.append({name: source_copy}) + elif "variants" in source: + variants = source.pop("variants") + for diff in variants: + source_copy = copy.deepcopy(source) + source_copy.update(**diff) + data_dicts.append({name: source_copy}) + elif "placeholders" in source: + options = source.pop("placeholders") for combination in itertools.product(*options.values()): - path_n = d["path"] - name_n = d["name"] + source_copy = copy.deepcopy(source) + name_copy = name for k, v in zip(options.keys(), combination): - path_n = path_n.replace("{" + k + "}", v) - name_n = name_n.replace("{" + k + "}", v) + name_copy = name_copy.replace("{" + k + "}", v) + source_copy["path"] = source_copy["path"].replace("{" + k + "}", v) + data_dicts.append({name_copy: source_copy}) + else: + data_list.append((name, source)) + continue + + # recursively denormalise in case of multiple denormalise keys in source + for item in data_dicts: + data_list.extend(_denormalise_data_dict(item)) - return dicts + return data_list def abs_path(root: Union[Path, str], rel_path: Union[Path, str]) -> str: @@ -1327,15 +1437,3 @@ def abs_path(root: Union[Path, str], rel_path: Union[Path, str]) -> str: rel_path = join(root, rel_path) path = Path(abspath(rel_path)) return str(path) - - -def _seperate_driver_kwargs_from_kwargs( - kwargs: dict, data_adapter: DataAdapter -) -> Tuple[dict]: - driver_kwargs = kwargs - driver_kwargs_copy = driver_kwargs.copy() - kwargs = {} - for k, v in driver_kwargs_copy.items(): - if k in inspect.signature(data_adapter.__init__).parameters.keys(): - kwargs.update({k: driver_kwargs.pop(k)}) - return kwargs, driver_kwargs diff --git a/tests/data/aws_esa_worldcover.yml b/tests/data/aws_esa_worldcover.yml index 4812ba830..9a895581c 100644 --- a/tests/data/aws_esa_worldcover.yml +++ b/tests/data/aws_esa_worldcover.yml @@ -3,7 +3,7 @@ esa_worldcover: data_type: RasterDataset driver: raster filesystem: s3 - data_version: 2021 + version: 2021 provider: aws driver_kwargs: storage_options: diff --git 
a/tests/data/legacy_esa_worldcover.yml b/tests/data/legacy_esa_worldcover.yml index 11d131e5c..1d9c1d4b1 100644 --- a/tests/data/legacy_esa_worldcover.yml +++ b/tests/data/legacy_esa_worldcover.yml @@ -12,4 +12,4 @@ esa_worldcover: source_url: https://doi.org/10.5281/zenodo.5571936 source_version: v100 path: landuse/esa_worldcover/esa-worldcover.vrt - data_version: 2020 + version: 2020 diff --git a/tests/data/merged_esa_worldcover.yml b/tests/data/merged_esa_worldcover.yml index 4a8861eb0..7636b370d 100644 --- a/tests/data/merged_esa_worldcover.yml +++ b/tests/data/merged_esa_worldcover.yml @@ -12,12 +12,12 @@ esa_worldcover: source_license: CC BY 4.0 source_url: https://doi.org/10.5281/zenodo.5571936 variants: - - provider: local - version: 2020 - path: landuse/esa_worldcover/esa-worldcover.vrt - provider: local version: 2021 path: landuse/esa_worldcover_2021/esa-worldcover.vrt + - provider: local + version: 2020 + path: landuse/esa_worldcover/esa-worldcover.vrt - provider: aws version: 2020 path: s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index 7062a1845..dd90b69ea 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -57,12 +57,12 @@ def test_rasterdataset(rioda, tmpdir): da1 = data_catalog.get_rasterdataset(fn_tif, bbox=rioda.raster.bounds) assert np.all(da1 == rioda_utm) geom = rioda.raster.box - da1 = data_catalog.get_rasterdataset("test", geom=geom) + da1 = data_catalog.get_rasterdataset("test.tif", geom=geom) assert np.all(da1 == rioda_utm) - with pytest.raises(FileNotFoundError, match="No such file or catalog key"): + with pytest.raises(FileNotFoundError, match="No such file or catalog source"): data_catalog.get_rasterdataset("no_file.tif") with pytest.raises(IndexError, match="RasterDataset: No data within"): - data_catalog.get_rasterdataset("test", bbox=[40, 50, 41, 51]) + data_catalog.get_rasterdataset("test.tif", bbox=[40, 50, 41, 51]) @pytest.mark.skipif(not compat.HAS_GCSFS, reason="GCSFS not installed.") @@ -204,7 +204,7 @@ def test_geodataset(geoda, geodf, ts, tmpdir): ).sortby("index") assert np.allclose(da1, geoda) assert da1.name == "test1" - ds1 = data_catalog.get_geodataset("test", single_var_as_array=False) + ds1 = data_catalog.get_geodataset("test.nc", single_var_as_array=False) assert isinstance(ds1, xr.Dataset) assert "test" in ds1 da2 = data_catalog.get_geodataset( @@ -217,7 +217,7 @@ def test_geodataset(geoda, geodf, ts, tmpdir): ).sortby("index") assert np.allclose(da3, geoda) assert da3.vector.crs.to_epsg() == 4326 - with pytest.raises(FileNotFoundError, match="No such file or catalog key"): + with pytest.raises(FileNotFoundError, match="No such file or catalog source"): data_catalog.get_geodataset("no_file.geojson") # Test nc file writing to file with tempfile.TemporaryDirectory() as td: @@ -272,10 +272,10 @@ def test_geodataframe(geodf, tmpdir): assert isinstance(gdf1, gpd.GeoDataFrame) assert np.all(gdf1 == geodf) gdf1 = data_catalog.get_geodataframe( - "test", bbox=geodf.total_bounds, buffer=1000, rename={"test": "test1"} + "test.geojson", bbox=geodf.total_bounds, buffer=1000, rename={"test": "test1"} ) assert np.all(gdf1 == geodf) - with pytest.raises(FileNotFoundError, match="No such file or catalog key"): + with pytest.raises(FileNotFoundError, match="No such file or catalog source"): data_catalog.get_geodataframe("no_file.geojson") diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 9fec6e486..4b1bddfcf 100644 
--- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -10,7 +10,11 @@ import xarray as xr from hydromt.data_adapter import DataAdapter, RasterDatasetAdapter -from hydromt.data_catalog import DataCatalog, _denormalise_data_dict, _parse_data_dict +from hydromt.data_catalog import ( + DataCatalog, + _denormalise_data_dict, + _parse_data_source_dict, +) CATALOGDIR = join(dirname(abspath(__file__)), "..", "data", "catalogs") DATADIR = join(dirname(abspath(__file__)), "data") @@ -20,32 +24,27 @@ def test_parser(): # valid abs root on windows and linux! root = "c:/root" if os.name == "nt" else "/c/root" # simple; abs path - dd = { - "test": { - "data_type": "RasterDataset", - "path": f"{root}/to/data.tif", - } + source = { + "data_type": "RasterDataset", + "path": f"{root}/to/data.tif", } - dd_out = _parse_data_dict(dd, root=root) - assert isinstance(dd_out["test"], RasterDatasetAdapter) - assert dd_out["test"].path == abspath(dd["test"]["path"]) + adapter = _parse_data_source_dict("test", source, root=root) + assert isinstance(adapter, RasterDatasetAdapter) + assert adapter.path == abspath(source["path"]) # test with Path object - dd["test"].update(path=Path(dd["test"]["path"])) - dd_out = _parse_data_dict(dd, root=root) - assert dd_out["test"].path == abspath(dd["test"]["path"]) + source.update(path=Path(source["path"])) + adapter = _parse_data_source_dict("test", source, root=root) + assert adapter.path == abspath(source["path"]) # rel path - dd = { - "test": { - "data_type": "RasterDataset", - "path": "path/to/data.tif", - "kwargs": {"fn": "test"}, - }, - "root": root, + source = { + "data_type": "RasterDataset", + "path": "path/to/data.tif", + "kwargs": {"fn": "test"}, } - dd_out = _parse_data_dict(dd) - assert dd_out["test"].path == abspath(join(root, dd["test"]["path"])) + adapter = _parse_data_source_dict("test", source, root=root) + assert adapter.path == abspath(join(root, source["path"])) # check if path in kwargs is also absolute - assert dd_out["test"].driver_kwargs["fn"] == abspath(join(root, "test")) + assert adapter.driver_kwargs["fn"] == abspath(join(root, "test")) # alias dd = { "test": { @@ -55,11 +54,12 @@ def test_parser(): "test1": {"alias": "test"}, } with pytest.deprecated_call(): - dd = _denormalise_data_dict(dd, catalog_name="tmp") - - dd_out1 = _parse_data_dict(dd[0], root=root) - dd_out2 = _parse_data_dict(dd[1], root=root) - assert dd_out1["test"].path == dd_out2["test1"].path + sources = _denormalise_data_dict(dd) + assert len(sources) == 2 + for name, source in sources: + adapter = _parse_data_source_dict(name, source, root=root, catalog_name="tmp") + assert adapter.path == abspath(join(root, dd["test"]["path"])) + assert adapter.catalog_name == "tmp" # placeholder dd = { "test_{p1}_{p2}": { @@ -68,16 +68,36 @@ def test_parser(): "placeholders": {"p1": ["a", "b"], "p2": ["1", "2", "3"]}, }, } - dd_out = _parse_data_dict(dd, root=root) - assert len(dd_out) == 6 - assert dd_out["test_a_1"].path == abspath(join(root, "data_1.tif")) - assert "placeholders" not in dd_out["test_a_1"].to_dict() + sources = _denormalise_data_dict(dd) + assert len(sources) == 6 + for name, source in sources: + assert "placeholders" not in source + adapter = _parse_data_source_dict(name, source, root=root) + assert adapter.path == abspath(join(root, f"data_{name[-1]}.tif")) + # variants + dd = { + "test": { + "data_type": "RasterDataset", + "variants": [ + {"path": "path/to/data1.tif", "version": "1"}, + {"path": "path/to/data2.tif", "provider": "local"}, + ], + }, + } + 
sources = _denormalise_data_dict(dd) + assert len(sources) == 2 + for i, (name, source) in enumerate(sources): + assert "variants" not in source + adapter = _parse_data_source_dict(name, source, root=root, catalog_name="tmp") + assert adapter.version == dd["test"]["variants"][i].get("version", None) + assert adapter.provider == dd["test"]["variants"][i].get("provider", None) + assert adapter.catalog_name == "tmp" # errors with pytest.raises(ValueError, match="Missing required path argument"): - _parse_data_dict({"test": {}}) + _parse_data_source_dict("test", {}) with pytest.raises(ValueError, match="Data type error unknown"): - _parse_data_dict({"test": {"path": "", "data_type": "error"}}) + _parse_data_source_dict("test", {"path": "", "data_type": "error"}) with pytest.raises( ValueError, match="alias test not found in data_dict" ), pytest.deprecated_call(): @@ -103,49 +123,43 @@ def test_versioned_catalogs(tmpdir): # make sure the catalogs individually still work legacy_yml_fn = join(DATADIR, "legacy_esa_worldcover.yml") legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn]) - assert ( - Path(legacy_data_catalog.get_source("esa_worldcover").path).name - == "esa-worldcover.vrt" - ) - assert legacy_data_catalog.get_source("esa_worldcover").data_version == 2020 - + assert len(legacy_data_catalog) == 1 + source = legacy_data_catalog.get_source("esa_worldcover") + assert Path(source.path).name == "esa-worldcover.vrt" + assert source.version == "2020" + # test round trip to and from dict + legacy_data_catalog2 = DataCatalog().from_dict(legacy_data_catalog.to_dict()) + assert legacy_data_catalog2 == legacy_data_catalog # make sure we raise deprecation warning here with pytest.deprecated_call(): _ = legacy_data_catalog["esa_worldcover"] + # second catalog aws_yml_fn = join(DATADIR, "aws_esa_worldcover.yml") aws_data_catalog = DataCatalog(data_libs=[aws_yml_fn]) + assert len(aws_data_catalog) == 1 # test get_source with all keyword combinations - assert ( - aws_data_catalog.get_source("esa_worldcover").path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" - ) - assert aws_data_catalog.get_source("esa_worldcover").data_version == 2021 - assert ( - aws_data_catalog.get_source("esa_worldcover", data_version=2021).path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" - ) - assert ( - aws_data_catalog.get_source("esa_worldcover", data_version=2021).data_version - == 2021 - ) - assert ( - aws_data_catalog.get_source( - "esa_worldcover", data_version=2021, provider="aws" - ).path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" - ) - + source = aws_data_catalog.get_source("esa_worldcover") + assert source.path.endswith("ESA_WorldCover_10m_2020_v100_Map_AWS.vrt") + assert source.version == "2021" + source = aws_data_catalog.get_source("esa_worldcover", version=2021) + assert source.path.endswith("ESA_WorldCover_10m_2020_v100_Map_AWS.vrt") + assert source.version == "2021" + source = aws_data_catalog.get_source("esa_worldcover", version=2021, provider="aws") + assert source.path.endswith("ESA_WorldCover_10m_2020_v100_Map_AWS.vrt") + # test round trip to and from dict + aws_data_catalog2 = DataCatalog().from_dict(aws_data_catalog.to_dict()) + assert aws_data_catalog2 == aws_data_catalog + + # test errors with pytest.raises(KeyError): - aws_data_catalog.get_source( - "esa_worldcover", data_version=2021, provider="asdfasdf" - ) + aws_data_catalog.get_source("esa_worldcover", version=2021, provider="asdfasdf") with 
pytest.raises(KeyError): aws_data_catalog.get_source( - "esa_worldcover", data_version="asdfasdf", provider="aws" + "esa_worldcover", version="asdfasdf", provider="aws" ) with pytest.raises(KeyError): - aws_data_catalog.get_source("asdfasdf", data_version=2021, provider="aws") + aws_data_catalog.get_source("asdfasdf", version=2021, provider="aws") # make sure we trigger user warning when overwriting versions with pytest.warns(UserWarning): @@ -153,46 +167,35 @@ def test_versioned_catalogs(tmpdir): # make sure we can read merged catalogs merged_yml_fn = join(DATADIR, "merged_esa_worldcover.yml") - read_merged_catalog = DataCatalog(data_libs=[merged_yml_fn]) - assert ( - read_merged_catalog.get_source("esa_worldcover").path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" - ) - assert ( - read_merged_catalog.get_source("esa_worldcover", provider="aws").path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" - ) - assert ( - Path( - read_merged_catalog.get_source("esa_worldcover", provider="local").path - ).name - == "esa-worldcover.vrt" - ) - - # make sure dataframe doesn't merge different variants - assert len(read_merged_catalog.to_dataframe()) == 2 - - # Make sure we can queiry for the version we want - aws_and_legacy_data_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) - assert ( - aws_and_legacy_data_catalog.get_source("esa_worldcover").path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" - ) - - assert ( - aws_and_legacy_data_catalog.get_source("esa_worldcover", provider="aws").path - == "s3://esa-worldcover/v100/2020/ESA_WorldCover_10m_2020_v100_Map_AWS.vrt" + merged_catalog = DataCatalog(data_libs=[merged_yml_fn]) + assert len(merged_catalog) == 3 + source_aws = merged_catalog.get_source("esa_worldcover") # last variant is default + assert source_aws.filesystem == "s3" + assert merged_catalog.get_source("esa_worldcover", provider="aws") == source_aws + source_loc = merged_catalog.get_source("esa_worldcover", provider="local") + assert source_loc != source_aws + assert source_loc.filesystem == "local" + assert source_loc.version == "2021" # get newest version + # test get_source with version only + assert merged_catalog.get_source("esa_worldcover", version=2021) == source_loc + # test round trip to and from dict + merged_catalog2 = DataCatalog().from_dict(merged_catalog.to_dict()) + assert merged_catalog2 == merged_catalog + + # Make sure we can query for the version we want + aws_and_legacy_catalog = DataCatalog(data_libs=[legacy_yml_fn, aws_yml_fn]) + assert len(aws_and_legacy_catalog) == 2 + source_aws = aws_and_legacy_catalog.get_source("esa_worldcover") + assert source_aws.filesystem == "s3" + source_aws2 = aws_and_legacy_catalog.get_source("esa_worldcover", provider="aws") + assert source_aws2 == source_aws + source_loc = aws_and_legacy_catalog.get_source( + "esa_worldcover", provider="legacy_esa_worldcover" # provider is filename ) - assert ( - Path( - aws_and_legacy_data_catalog.get_source( - "esa_worldcover", provider="legacy_esa_worldcover" - ).path - ).name - == "esa-worldcover.vrt" - ) - - _ = aws_and_legacy_data_catalog.to_dict() + assert Path(source_loc.path).name == "esa-worldcover.vrt" + # test round trip to and from dict + aws_and_legacy_catalog2 = DataCatalog().from_dict(aws_and_legacy_catalog.to_dict()) + assert aws_and_legacy_catalog2 == aws_and_legacy_catalog def test_data_catalog(tmpdir): @@ -364,44 +367,85 @@ def test_export_dataframe(tmpdir, df, df_time): assert 
isinstance(obj, dtypes), key -def test_get_data(df): +def test_get_data(df, tmpdir): data_catalog = DataCatalog("artifact_data") # read artifacts - + n = len(data_catalog) # raster dataset using three different ways - da = data_catalog.get_rasterdataset(data_catalog.get_source("koppen_geiger").path) + name = "koppen_geiger" + da = data_catalog.get_rasterdataset(data_catalog.get_source(name).path) + assert len(data_catalog) == n + 1 assert isinstance(da, xr.DataArray) - da = data_catalog.get_rasterdataset("koppen_geiger") + da = data_catalog.get_rasterdataset(name, provider="artifact_data") assert isinstance(da, xr.DataArray) da = data_catalog.get_rasterdataset(da) assert isinstance(da, xr.DataArray) + data = {"source": name, "provider": "artifact_data"} + da = data_catalog.get_rasterdataset(data) with pytest.raises(ValueError, match='Unknown raster data type "list"'): data_catalog.get_rasterdataset([]) + with pytest.raises(FileNotFoundError): + data_catalog.get_rasterdataset("test1.tif") + with pytest.raises(ValueError, match="Unknown keys in requested data"): + data_catalog.get_rasterdataset({"name": "test"}) # vector dataset using three different ways - gdf = data_catalog.get_geodataframe(data_catalog.get_source("osm_coastlines").path) + name = "osm_coastlines" + gdf = data_catalog.get_geodataframe(data_catalog.get_source(name).path) + assert len(data_catalog) == n + 2 assert isinstance(gdf, gpd.GeoDataFrame) - gdf = data_catalog.get_geodataframe("osm_coastlines") + gdf = data_catalog.get_geodataframe(name, provider="artifact_data") assert isinstance(gdf, gpd.GeoDataFrame) gdf = data_catalog.get_geodataframe(gdf) assert isinstance(gdf, gpd.GeoDataFrame) + data = {"source": name, "provider": "artifact_data"} + gdf = data_catalog.get_geodataframe(data) + assert isinstance(gdf, gpd.GeoDataFrame) with pytest.raises(ValueError, match='Unknown vector data type "list"'): data_catalog.get_geodataframe([]) + with pytest.raises(FileNotFoundError): + data_catalog.get_geodataframe("test1.gpkg") + with pytest.raises(ValueError, match="Unknown keys in requested data"): + data_catalog.get_geodataframe({"name": "test"}) # geodataset using three different ways - da = data_catalog.get_geodataset(data_catalog.get_source("gtsmv3_eu_era5").path) + name = "gtsmv3_eu_era5" + da = data_catalog.get_geodataset(data_catalog.get_source(name).path) + assert len(data_catalog) == n + 3 assert isinstance(da, xr.DataArray) - da = data_catalog.get_geodataset("gtsmv3_eu_era5") + da = data_catalog.get_geodataset(name, provider="artifact_data") assert isinstance(da, xr.DataArray) da = data_catalog.get_geodataset(da) assert isinstance(da, xr.DataArray) + data = {"source": name, "provider": "artifact_data"} + gdf = data_catalog.get_geodataset(data) + assert isinstance(gdf, xr.DataArray) with pytest.raises(ValueError, match='Unknown geo data type "list"'): data_catalog.get_geodataset([]) + with pytest.raises(FileNotFoundError): + data_catalog.get_geodataset("test1.nc") + with pytest.raises(ValueError, match="Unknown keys in requested data"): + data_catalog.get_geodataset({"name": "test"}) # dataframe using single way + name = "test.csv" + fn = str(tmpdir.join(name)) + df.to_csv(fn) + df = data_catalog.get_dataframe(fn, driver_kwargs=dict(index_col=0)) + assert len(data_catalog) == n + 4 + assert isinstance(df, pd.DataFrame) + df = data_catalog.get_dataframe(name, provider="local") + assert isinstance(df, pd.DataFrame) df = data_catalog.get_dataframe(df) assert isinstance(df, pd.DataFrame) + data = {"source": name, "provider": 
"local"} + gdf = data_catalog.get_dataframe(data) + assert isinstance(gdf, pd.DataFrame) with pytest.raises(ValueError, match='Unknown tabular data type "list"'): data_catalog.get_dataframe([]) + with pytest.raises(FileNotFoundError): + data_catalog.get_dataframe("test1.csv") + with pytest.raises(ValueError, match="Unknown keys in requested data"): + data_catalog.get_dataframe({"name": "test"}) def test_deprecation_warnings(artifact_data):