diff --git a/.gitignore b/.gitignore index 95d7ffd0e..42c4a83d1 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,4 @@ dask-worker-space/ #ruff linting .ruff_cache +.envrc diff --git a/docs/changelog.rst b/docs/changelog.rst index a2286c79c..e9212c611 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,7 @@ Added - Add support for reading model configs in ``TOML`` format. (PR #444) - new ``force-overwrite`` option in ``hydromt update`` CLI to force overwritting updated netcdf files. (PR #460) - add ``open_mfcsv`` function in ``io`` module for combining multiple CSV files into one dataset. (PR #486) +- Adapters can now clip data that is passed through a python object the same way as through the data catalog. (PR #481) Changed ------- @@ -29,7 +30,6 @@ Changed Fixed ----- - when a model component (eg maps, forcing, grid) is updated using the set_ methods, it will first be read to avoid loosing data. (PR #460) -- Deprecated ---------- diff --git a/hydromt/data_adapter/data_adapter.py b/hydromt/data_adapter/data_adapter.py index 2ae7eea49..ec42ee11a 100644 --- a/hydromt/data_adapter/data_adapter.py +++ b/hydromt/data_adapter/data_adapter.py @@ -7,16 +7,14 @@ from string import Formatter from typing import Optional -import geopandas as gpd import numpy as np import pandas as pd import xarray as xr import yaml from fsspec.implementations import local -from pyproj import CRS from upath import UPath -from .. import _compat, gis_utils +from .. import _compat logger = logging.getLogger(__name__) @@ -239,80 +237,11 @@ def __eq__(self, other: object) -> bool: else: return False - def _parse_zoom_level( + def _resolve_paths( self, - zoom_level: int | tuple = None, - geom: gpd.GeoSeries = None, - bbox: list = None, - logger=logger, - ) -> int: - """Return nearest smaller zoom level. - - Based on zoom resolutions defined in data catalog. 
- """ - # common pyproj crs axis units - known_units = ["degree", "metre", "US survey foot"] - if self.zoom_levels is None or len(self.zoom_levels) == 0: - logger.warning("No zoom levels available, default to zero") - return 0 - zls = list(self.zoom_levels.keys()) - if zoom_level is None: # return first zoomlevel (assume these are ordered) - return next(iter(zls)) - # parse zoom_level argument - if ( - isinstance(zoom_level, tuple) - and isinstance(zoom_level[0], (int, float)) - and isinstance(zoom_level[1], str) - and len(zoom_level) == 2 - ): - res, unit = zoom_level - # covert 'meter' and foot to official pyproj units - unit = {"meter": "metre", "foot": "US survey foot"}.get(unit, unit) - if unit not in known_units: - raise TypeError( - f"zoom_level unit {unit} not understood;" - f" should be one of {known_units}" - ) - elif not isinstance(zoom_level, int): - raise TypeError( - f"zoom_level argument not understood: {zoom_level}; should be a float" - ) - else: - return zoom_level - if self.crs: - # convert res if different unit than crs - crs = CRS.from_user_input(self.crs) - crs_unit = crs.axis_info[0].unit_name - if crs_unit != unit and crs_unit not in known_units: - raise NotImplementedError( - f"no conversion available for {unit} to {crs_unit}" - ) - if unit != crs_unit: - lat = 0 - if bbox is not None: - lat = (bbox[1] + bbox[3]) / 2 - elif geom is not None: - lat = geom.to_crs(4326).centroid.y.item() - conversions = { - "degree": np.hypot(*gis_utils.cellres(lat=lat)), - "US survey foot": 0.3048, - } - res = res * conversions.get(unit, 1) / conversions.get(crs_unit, 1) - # find nearest smaller zoomlevel - eps = 1e-5 # allow for rounding errors - smaller = [x < (res + eps) for x in self.zoom_levels.values()] - zl = zls[-1] if all(smaller) else zls[max(smaller.index(False) - 1, 0)] - logger.info(f"Getting data for zoom_level {zl} based on res {zoom_level}") - return zl - - def resolve_paths( - self, - time_tuple: tuple = None, - variables: list = None, - zoom_level: int | tuple = None, - geom: gpd.GeoSeries = None, - bbox: list = None, - logger=logger, + time_tuple: Optional[tuple] = None, + variables: Optional[list] = None, + zoom_level: int = 0, **kwargs, ): """Resolve {year}, {month} and {variable} keywords in self.path. @@ -326,22 +255,18 @@ def resolve_paths( :py:func:`pandas.to_timedelta`, by default None variables : list of str, optional List of variable names, by default None - zoom_level : int | tuple, optional - zoom level of dataset, can be provided as tuple of - (, ) + zoom_level : int + Parsed zoom level to use, by default 0 + See :py:meth:`RasterDataAdapter._parse_zoom_level` for more info + logger: + The logger to use. If none is provided, the devault logger will be used. **kwargs key-word arguments are passed to fsspec FileSystem objects. Arguments depend on protocal (local, gcs, s3...). - geom: - A geoSeries describing the geometries. - bbox: - A list of bounding boxes. - logger: - The logger to use. If none is provided, the devault logger will be used. 
Returns ------- - List: + fns: list of str list of filenames matching the path pattern given date range and variables """ known_keys = ["year", "month", "zoom_level", "variable"] @@ -365,6 +290,7 @@ def resolve_paths( else: path = path + key_str keys.append(key) + # resolve dates: month & year keys dates, vrs, postfix = [None], [None], "" if time_tuple is not None: @@ -375,21 +301,16 @@ def resolve_paths( postfix += "; date range: " + " - ".join([t.strftime(strf) for t in trange]) # resolve variables if variables is not None: + variables = np.atleast_1d(variables).tolist() mv_inv = {v: k for k, v in self.rename.items()} vrs = [mv_inv.get(var, var) for var in variables] postfix += f"; variables: {variables}" - # parse zoom level - if "zoom_level" in keys: - # returns the first zoom_level if zoom_level is None - zoom_level = self._parse_zoom_level( - zoom_level=zoom_level, bbox=bbox, geom=geom, logger=logger - ) # get filenames with glob for all date / variable combinations fs = self.get_filesystem(**kwargs) fmt = {} # update based on zoomlevel (size = 1) - if zoom_level is not None: + if "zoom_level" in keys: fmt.update(zoom_level=zoom_level) # update based on dates and variables (size >= 1) for date, var in product(dates, vrs): @@ -398,6 +319,7 @@ def resolve_paths( if var is not None: fmt.update(variable=var) fns.extend(fs.glob(path.format(**fmt))) + if len(fns) == 0: raise FileNotFoundError(f"No such file found: {path}{postfix}") @@ -409,7 +331,8 @@ def resolve_paths( last_parent = UPath(path).parents[-1] # add the rest of the path fns = [last_parent.joinpath(*UPath(fn).parts[1:]) for fn in fns] - return list(set(fns)) # return unique paths + fns = list(set(fns)) # return unique paths + return fns def get_filesystem(self, **kwargs): """Return an initialised filesystem object.""" @@ -445,3 +368,17 @@ def get_data(self, bbox, geom, buffer): If bbox of mask are given, clip data to that extent. """ + + @staticmethod + def _single_var_as_array(ds, single_var_as_array, variable_name=None): + # return data array if single variable dataset + dvars = list(ds.data_vars.keys()) + if single_var_as_array and len(dvars) == 1: + da = ds[dvars[0]] + if isinstance(variable_name, list) and len(variable_name) == 1: + da.name = variable_name[0] + elif isinstance(variable_name, str): + da.name = variable_name + return da + else: + return ds diff --git a/hydromt/data_adapter/dataframe.py b/hydromt/data_adapter/dataframe.py index f10bd1527..0d763bbd0 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -139,9 +139,6 @@ def to_file( time_tuple : tuple of str or datetime, optional Start and end date of the period of interest. By default, the entire time period of the DataFrame is included. - logger : Logger, optional - Logger object to log warnings or messages. By default, the module - logger is used. **kwargs : dict Additional keyword arguments to be passed to the file writing method. @@ -158,13 +155,7 @@ def to_file( """ kwargs.pop("bbox", None) - try: - obj = self.get_data( - time_tuple=time_tuple, variables=variables, logger=logger - ) - except IndexError as err: # out of bounds for time - logger.warning(str(err)) - return None, None, None + obj = self.get_data(time_tuple=time_tuple, variables=variables, logger=logger) read_kwargs = dict() if driver is None or driver == "csv": @@ -196,6 +187,20 @@ def get_data( based on the properties of this DataFrameAdapter. 
For a detailed description see: :py:func:`~hydromt.data_catalog.DataCatalog.get_dataframe` """ + # load data + fns = self._resolve_paths(variables) + df = self._read_data(fns, logger=logger) + # rename variables and parse nodata + df = self._rename_vars(df) + df = self._set_nodata(df) + # slice data + df = DataFrameAdapter._slice_data(df, variables, time_tuple, logger=logger) + # uniformize data + df = self._apply_unit_conversion(df, logger=logger) + df = self._set_metadata(df) + return df + + def _resolve_paths(self, variables=None): # Extract storage_options from kwargs to instantiate fsspec object correctly so_kwargs = {} if "storage_options" in self.driver_kwargs: @@ -206,68 +211,100 @@ def get_data( os.environ["AWS_NO_SIGN_REQUEST"] = "YES" else: os.environ["AWS_NO_SIGN_REQUEST"] = "NO" - _ = self.resolve_paths(**so_kwargs) # throw nice error if data not found - kwargs = self.driver_kwargs.copy() + # throw nice error if data not found + fns = super()._resolve_paths(variables=variables, **so_kwargs) - # read and clip - logger.info(f"DataFrame: Read {self.driver} data.") + return fns + def _read_data(self, fns, logger=logger): + if len(fns) > 1: + raise ValueError( + f"DataFrame: Reading multiple {self.driver} files is not supported." + ) + kwargs = self.driver_kwargs.copy() + path = fns[0] + logger.info(f"Reading {self.name} {self.driver} data from {self.path}") if self.driver in ["csv"]: - df = pd.read_csv(self.path, **kwargs) + df = pd.read_csv(path, **kwargs) elif self.driver == "parquet": - df = pd.read_parquet(self.path, **kwargs) + _ = kwargs.pop("index_col", None) + df = pd.read_parquet(path, **kwargs) elif self.driver in ["xls", "xlsx", "excel"]: - df = pd.read_excel(self.path, engine="openpyxl", **kwargs) + df = pd.read_excel(path, engine="openpyxl", **kwargs) elif self.driver in ["fwf"]: - df = pd.read_fwf(self.path, **kwargs) + df = pd.read_fwf(path, **kwargs) else: raise IOError(f"DataFrame: driver {self.driver} unknown.") - # rename and select columns + return df + + def _rename_vars(self, df): if self.rename: rename = {k: v for k, v in self.rename.items() if k in df.columns} df = df.rename(columns=rename) + return df + + def _set_nodata(self, df): + # parse nodata values + cols = df.select_dtypes([np.number]).columns + if self.nodata is not None and len(cols) > 0: + if not isinstance(self.nodata, dict): + nodata = {c: self.nodata for c in cols} + else: + nodata = self.nodata + for c in cols: + mv = nodata.get(c, None) + if mv is not None: + is_nodata = np.isin(df[c], np.atleast_1d(mv)) + df[c] = np.where(is_nodata, np.nan, df[c]) + return df + + @staticmethod + def _slice_data(df, variables=None, time_tuple=None, logger=logger): + """Return a sliced DataFrame. + + Parameters + ---------- + df : pd.DataFrame + the dataframe to be sliced. + variables : list of str, optional + Names of DataFrame columns to include in the output. By default all columns + time_tuple : tuple of str, datetime, optional + Start and end date of period of interest. By default the entire time period + of the dataset is returned. 
+ + Returns + ------- + pd.DataFrame + Tabular data + """ if variables is not None: + variables = np.atleast_1d(variables).tolist() if np.any([var not in df.columns for var in variables]): raise ValueError(f"DataFrame: Not all variables found: {variables}") df = df.loc[:, variables] - # nodata and unit conversion for numeric data - if df.index.size == 0: - logger.warning(f"DataFrame: No data within spatial domain {self.path}.") - else: - # parse nodata values - cols = df.select_dtypes([np.number]).columns - if self.nodata is not None and len(cols) > 0: - if not isinstance(self.nodata, dict): - nodata = {c: self.nodata for c in cols} - else: - nodata = self.nodata - for c in cols: - mv = nodata.get(c, None) - if mv is not None: - is_nodata = np.isin(df[c], np.atleast_1d(mv)) - df[c] = np.where(is_nodata, np.nan, df[c]) - - # unit conversion - unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys()) - unit_names = [k for k in unit_names if k in df.columns] - if len(unit_names) > 0: - logger.debug(f"DataFrame: Convert units for {len(unit_names)} columns.") - for name in list(set(unit_names)): # unique - m = self.unit_mult.get(name, 1) - a = self.unit_add.get(name, 0) - df[name] = df[name] * m + a - - # clip time slice if time_tuple is not None and np.dtype(df.index).type == np.datetime64: - logger.debug(f"DataFrame: Slicing time dime {time_tuple}") + logger.debug(f"Slicing time dime {time_tuple}") df = df[df.index.slice_indexer(*time_tuple)] if df.size == 0: raise IndexError("DataFrame: Time slice out of range.") - # set meta data + return df + + def _apply_unit_conversion(self, df, logger=logger): + unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys()) + unit_names = [k for k in unit_names if k in df.columns] + if len(unit_names) > 0: + logger.debug(f"Convert units for {len(unit_names)} columns.") + for name in list(set(unit_names)): # unique + m = self.unit_mult.get(name, 1) + a = self.unit_add.get(name, 0) + df[name] = df[name] * m + a + return df + + def _set_metadata(self, df): df.attrs.update(self.meta) # set column attributes diff --git a/hydromt/data_adapter/geodataframe.py b/hydromt/data_adapter/geodataframe.py index d692d12a6..929554059 100644 --- a/hydromt/data_adapter/geodataframe.py +++ b/hydromt/data_adapter/geodataframe.py @@ -5,11 +5,10 @@ from pathlib import Path from typing import NewType, Union -import geopandas as gpd import numpy as np -from shapely.geometry import box +import pyproj -from .. import io +from .. import gis_utils, io from .data_adapter import DataAdapter logger = logging.getLogger(__name__) @@ -150,9 +149,6 @@ def to_file( variables : list of str, optional Names of GeoDataset variables to return. By default all dataset variables are returned. - logger : logger object, optional - The logger object used for logging messages. If not provided, the default - logger will be used. **kwargs Additional keyword arguments that are passed to the geopandas driver. @@ -166,8 +162,6 @@ def to_file( """ kwargs.pop("time_tuple", None) gdf = self.get_data(bbox=bbox, variables=variables, logger=logger) - if gdf.index.size == 0: - return None, None, None read_kwargs = {} if driver is None: @@ -208,50 +202,54 @@ def get_data( self, bbox=None, geom=None, - predicate="intersects", buffer=0, + predicate="intersects", logger=logger, variables=None, - # **kwargs, # this is not used, for testing only ): """Return a clipped and unified GeoDataFrame (vector). 
For a detailed description see: :py:func:`~hydromt.data_catalog.DataCatalog.get_geodataframe` """ - # If variable is string, convert to list - if variables: - variables = np.atleast_1d(variables).tolist() + # load + fns = self._resolve_paths(variables) + gdf = self._read_data(fns, bbox, geom, buffer, predicate, logger=logger) + # rename variables and parse crs & nodata + gdf = self._rename_vars(gdf) + gdf = self._set_crs(gdf, logger=logger) + gdf = self._set_nodata(gdf) + # slice + gdf = GeoDataFrameAdapter._slice_data( + gdf, variables, geom, bbox, buffer, predicate, logger=logger + ) + # uniformize + gdf = self._apply_unit_conversions(gdf, logger=logger) + gdf = self._set_metadata(gdf) + return gdf + def _resolve_paths(self, variables): + # storage options for fsspec (TODO: not implemented yet) if "storage_options" in self.driver_kwargs: # not sure if storage options can be passed to fiona.open() # for now throw NotImplemented Error raise NotImplementedError( "Remote file storage_options not implemented for GeoDataFrame" ) - _ = self.resolve_paths() # throw nice error if data not found - kwargs = self.driver_kwargs.copy() - # parse geom, bbox and buffer arguments - clip_str = "" - if geom is None and bbox is not None: - # convert bbox to geom with crs EPGS:4326 to apply buffer later - geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=4326) - clip_str = " and clip to bbox (epsg:4326)" - elif geom is not None: - clip_str = f" and clip to geom (epsg:{geom.crs.to_epsg():d})" - if geom is not None: - # make sure geom is projected > buffer in meters! - if geom.crs.is_geographic and buffer > 0: - geom = geom.to_crs(3857) - geom = geom.buffer(buffer) # a buffer with zero fixes some topology errors - bbox_str = ", ".join([f"{c:.3f}" for c in geom.total_bounds]) - clip_str = f"{clip_str} [{bbox_str}]" - if kwargs.pop("within", False): # for backward compatibility - predicate = "contains" + # resolve paths + fns = super()._resolve_paths(variables=variables) + + return fns - # read and clip - logger.info(f"GeoDataFrame: Read {self.driver} data{clip_str}.") + def _read_data(self, fns, bbox, geom, buffer, predicate, logger=logger): + if len(fns) > 1: + raise ValueError( + f"GeoDataFrame: Reading multiple {self.driver} files is not supported." + ) + kwargs = self.driver_kwargs.copy() + path = fns[0] + logger.info(f"Reading {self.name} {self.driver} data from {self.path}") if self.driver in [ "csv", "parquet", @@ -268,55 +266,120 @@ def get_data( "using the driver setting is deprecated. Please use" "vector_table instead." ) - kwargs.update(driver=self.driver) + # parse bbox and geom to (buffere) geom + if bbox is not None or geom is not None: + geom = gis_utils.parse_geom_bbox_buffer(geom, bbox, buffer) # Check if file-object is required because of additional options gdf = io.open_vector( - self.path, crs=self.crs, geom=geom, predicate=predicate, **kwargs + path, crs=self.crs, geom=geom, predicate=predicate, **kwargs ) else: raise ValueError(f"GeoDataFrame: driver {self.driver} unknown.") + return gdf + + def _rename_vars(self, gdf): # rename and select columns if self.rename: rename = {k: v for k, v in self.rename.items() if k in gdf.columns} gdf = gdf.rename(columns=rename) + return gdf + + def _set_crs(self, gdf, logger=logger): + if self.crs is not None and gdf.crs is None: + gdf.set_crs(self.crs, inplace=True) + elif gdf.crs is None: + raise ValueError( + f"GeoDataFrame {self.name}: CRS not defined in data catalog or data." 
+ ) + elif self.crs is not None and gdf.crs != pyproj.CRS.from_user_input(self.crs): + logger.warning( + f"GeoDataFrame {self.name}: CRS from data catalog does not match CRS of" + " data. The original CRS will be used. Please check your data catalog." + ) + return gdf + + @staticmethod + def _slice_data( + gdf, + variables=None, + geom=None, + bbox=None, + buffer=0, + predicate="intersects", + logger=logger, + ): + """Return a clipped GeoDataFrame (vector). + + Arguments + --------- + variables : str or list of str, optional. + Names of GeoDataFrame columns to return. + geom : geopandas.GeoDataFrame/Series, optional + A geometry defining the area of interest. + bbox : array-like of floats, optional + (xmin, ymin, xmax, ymax) bounding box of area of interest + (in WGS84 coordinates). + buffer : float, optional + Buffer around the `bbox` or `geom` area of interest in meters. By default 0. + predicate : str, optional + Predicate used to filter the GeoDataFrame, see + :py:func:`hydromt.gis_utils.filter_gdf` for details. + + Returns + ------- + gdf: geopandas.GeoDataFrame + GeoDataFrame + """ if variables is not None: + variables = np.atleast_1d(variables).tolist() if np.any([var not in gdf.columns for var in variables]): raise ValueError(f"GeoDataFrame: Not all variables found: {variables}") if "geometry" not in variables: # always keep geometry column variables = variables + ["geometry"] gdf = gdf.loc[:, variables] - # nodata and unit conversion for numeric data - if gdf.index.size == 0: - logger.warning(f"GeoDataFrame: No data within spatial domain {self.path}.") - else: - # parse nodata values - cols = gdf.select_dtypes([np.number]).columns - if self.nodata is not None and len(cols) > 0: - if not isinstance(self.nodata, dict): - nodata = {c: self.nodata for c in cols} - else: - nodata = self.nodata - for c in cols: - mv = nodata.get(c, None) - if mv is not None: - is_nodata = np.isin(gdf[c], np.atleast_1d(mv)) - gdf[c] = np.where(is_nodata, np.nan, gdf[c]) - - # unit conversion - unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys()) - unit_names = [k for k in unit_names if k in gdf.columns] - if len(unit_names) > 0: - logger.debug( - f"GeoDataFrame: Convert units for {len(unit_names)} columns." - ) - for name in list(set(unit_names)): # unique - m = self.unit_mult.get(name, 1) - a = self.unit_add.get(name, 0) - gdf[name] = gdf[name] * m + a + if geom is not None or bbox is not None: + # NOTE if we read with vector driver this is already done .. 
+ geom = gis_utils.parse_geom_bbox_buffer(geom, bbox, buffer) + bbox_str = ", ".join([f"{c:.3f}" for c in geom.total_bounds]) + epsg = geom.crs.to_epsg() + logger.debug(f"Clip {predicate} [{bbox_str}] (EPSG:{epsg})") + idxs = gis_utils.filter_gdf(gdf, geom=geom, predicate=predicate) + if idxs.size == 0: + raise IndexError("No data within spatial domain.") + gdf = gdf.iloc[idxs] + return gdf + def _set_nodata(self, gdf): + # parse nodata values + cols = gdf.select_dtypes([np.number]).columns + if self.nodata is not None and len(cols) > 0: + if not isinstance(self.nodata, dict): + nodata = {c: self.nodata for c in cols} + else: + nodata = self.nodata + for c in cols: + mv = nodata.get(c, None) + if mv is not None: + is_nodata = np.isin(gdf[c], np.atleast_1d(mv)) + gdf[c] = np.where(is_nodata, np.nan, gdf[c]) + return gdf + + def _apply_unit_conversions(self, gdf, logger=logger): + # unit conversion + unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys()) + unit_names = [k for k in unit_names if k in gdf.columns] + if len(unit_names) > 0: + logger.debug(f"Convert units for {len(unit_names)} columns.") + for name in list(set(unit_names)): # unique + m = self.unit_mult.get(name, 1) + a = self.unit_add.get(name, 0) + gdf[name] = gdf[name] * m + a + return gdf + + def _set_metadata(self, gdf): # set meta data gdf.attrs.update(self.meta) @@ -324,4 +387,5 @@ def get_data( for col in self.attrs: if col in gdf.columns: gdf[col].attrs.update(**self.attrs[col]) + return gdf diff --git a/hydromt/data_adapter/geodataset.py b/hydromt/data_adapter/geodataset.py index 56554e909..9c0fb8f1c 100644 --- a/hydromt/data_adapter/geodataset.py +++ b/hydromt/data_adapter/geodataset.py @@ -6,11 +6,10 @@ from pathlib import Path from typing import NewType, Union -import geopandas as gpd import numpy as np import pandas as pd +import pyproj import xarray as xr -from shapely.geometry import box from .. import gis_utils, io from ..raster import GEO_MAP_COORD @@ -159,9 +158,6 @@ def to_file( variables : list of str, optional Names of GeoDataset variables to return. By default all dataset variables are returned. - logger : logger object, optional - The logger object used for logging messages. If not provided, the default - logger will be used. **kwargs Additional keyword arguments that are passed to the `to_zarr` function. 
@@ -181,8 +177,6 @@ def to_file( logger=logger, single_var_as_array=variables is None, ) - if obj.vector.index.size == 0 or ("time" in obj.coords and obj.time.size == 0): - return None, None, None read_kwargs = {} @@ -227,6 +221,7 @@ def get_data( bbox=None, geom=None, buffer=0, + predicate="intersects", variables=None, time_tuple=None, single_var_as_array=True, @@ -237,10 +232,26 @@ def get_data( For a detailed description see: :py:func:`~hydromt.data_catalog.DataCatalog.get_geodataset` """ - # If variable is string, convert to list - if variables: - variables = np.atleast_1d(variables).tolist() + # load data + fns = self._resolve_paths(variables, time_tuple) + ds = self._read_data(fns, logger=logger) + # rename variables and parse data and attrs + ds = self._rename_vars(ds) + ds = self._validate_spatial_coords(ds) + ds = self._set_crs(ds, logger=logger) + ds = self._set_nodata(ds) + ds = self._shift_time(ds, logger=logger) + # slice + ds = GeoDatasetAdapter._slice_data( + ds, variables, geom, bbox, buffer, predicate, time_tuple, logger=logger + ) + # uniformize + ds = self._apply_unit_conversion(ds, logger=logger) + ds = self._set_metadata(ds) + # return array if single var and single_var_as_array + return self._single_var_as_array(ds, single_var_as_array, variables) + def _resolve_paths(self, variables, time_tuple): # Extract storage_options from kwargs to instantiate fsspec object correctly so_kwargs = dict() if "storage_options" in self.driver_kwargs and self.driver == "zarr": @@ -256,154 +267,207 @@ def get_data( raise NotImplementedError( "Remote (cloud) GeoDataset only supported with driver zarr." ) - fns = self.resolve_paths( + + # resolve paths + fns = super()._resolve_paths( time_tuple=time_tuple, variables=variables, **so_kwargs ) + return fns + + def _read_data(self, fns, logger=logger): kwargs = self.driver_kwargs.copy() - # parse geom, bbox and buffer arguments - clip_str = "" - if geom is None and bbox is not None: - # convert bbox to geom with crs EPGS:4326 to apply buffer later - geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=4326) - clip_str = " and clip to bbox (epsg:4326)" - elif geom is not None: - clip_str = f" and clip to geom (epsg:{geom.crs.to_epsg():d})" - if geom is not None: - # make sure geom is projected > buffer in meters! - if buffer > 0 and geom.crs.is_geographic: - geom = geom.to_crs(3857) - geom = geom.buffer(buffer) - bbox_str = ", ".join([f"{c:.3f}" for c in geom.total_bounds]) - clip_str = f"{clip_str} [{bbox_str}]" - if kwargs.pop("within", False): # for backward compatibility - kwargs.update(predicate="contains") - - # read and clip - logger.info(f"GeoDataset: Read {self.driver} data{clip_str}.") + if len(fns) > 1 and self.driver in ["vector", "zarr"]: + raise ValueError( + f"GeoDataset: Reading multiple {self.driver} files is not supported." + ) + logger.info(f"Reading {self.name} {self.driver} data from {self.path}") if self.driver in ["netcdf"]: - ds_out = xr.open_mfdataset(fns, **kwargs) + ds = xr.open_mfdataset(fns, **kwargs) elif self.driver == "zarr": - if len(fns) > 1: - raise ValueError( - "GeoDataset: Opening multiple zarr data files is not supported." 
- ) - ds_out = xr.open_zarr(fns[0], **kwargs) + ds = xr.open_zarr(fns[0], **kwargs) elif self.driver == "vector": - # read geodataset from point + time series file - ds_out = io.open_geodataset( - fn_locs=fns[0], geom=geom, crs=self.crs, **kwargs - ) - geom = None # already clipped + ds = io.open_geodataset(fn_locs=fns[0], crs=self.crs, **kwargs) else: raise ValueError(f"GeoDataset: Driver {self.driver} unknown") - if GEO_MAP_COORD in ds_out.data_vars: - ds_out = ds_out.set_coords(GEO_MAP_COORD) - # rename and select vars - if variables and len(ds_out.vector.vars) == 1 and len(self.rename) == 0: - rm = {ds_out.vector.vars[0]: variables[0]} - else: - rm = {k: v for k, v in self.rename.items() if k in ds_out} - ds_out = ds_out.rename(rm) - # check spatial dims and make sure all are set as coordinates + return ds + + def _rename_vars(self, ds): + rm = {k: v for k, v in self.rename.items() if k in ds} + ds = ds.rename(rm) + return ds + + def _validate_spatial_coords(self, ds): + if GEO_MAP_COORD in ds.data_vars: + ds = ds.set_coords(GEO_MAP_COORD) try: - ds_out.vector.set_spatial_dims() - idim = ds_out.vector.index_dim - if idim not in ds_out: # set coordinates for index dimension if missing - ds_out[idim] = xr.IndexVariable(idim, np.arange(ds_out.dims[idim])) - coords = [ds_out.vector.x_name, ds_out.vector.y_name, idim] + ds.vector.set_spatial_dims() + idim = ds.vector.index_dim + if idim not in ds: # set coordinates for index dimension if missing + ds[idim] = xr.IndexVariable(idim, np.arange(ds.dims[idim])) + coords = [ds.vector.x_name, ds.vector.y_name, idim] coords = [item for item in coords if item is not None] - ds_out = ds_out.set_coords(coords) + ds = ds.set_coords(coords) except ValueError: - raise ValueError(f"GeoDataset: No spatial coords found in data {self.path}") - if variables is not None: - if np.any([var not in ds_out.data_vars for var in variables]): - raise ValueError(f"GeoDataset: Not all variables found: {variables}") - ds_out = ds_out[variables] + raise ValueError( + f"GeoDataset: No spatial geometry dimension found in data {self.path}" + ) + return ds + def _set_crs(self, ds, logger=logger): # set crs - if ds_out.vector.crs is None and self.crs is not None: - ds_out.vector.set_crs(self.crs) - if ds_out.vector.crs is None: + if ds.vector.crs is None and self.crs is not None: + ds.vector.set_crs(self.crs) + elif ds.vector.crs is None: raise ValueError( - "GeoDataset: The data has no CRS, set in GeoDatasetAdapter." + f"GeoDataset {self.name}: CRS not defined in data catalog or data." ) - - # clip - if geom is not None: - bbox = geom.to_crs(4326).total_bounds - if ds_out.vector.crs.to_epsg() == 4326: - e = ds_out.vector.geometry.total_bounds[2] - if e > 180 or (bbox is not None and (bbox[0] < -180 or bbox[2] > 180)): - ds_out = gis_utils.meridian_offset(ds_out, ds_out.vector.x_name, bbox) - if geom is not None: - predicate = kwargs.pop("predicate", "intersects") - ds_out = ds_out.vector.clip_geom(geom, predicate=predicate) - if ds_out.vector.index.size == 0: + elif self.crs is not None and ds.vector.crs != pyproj.CRS.from_user_input( + self.crs + ): logger.warning( - f"GeoDataset: No data within spatial domain for {self.path}." + f"GeoDataset {self.name}: CRS from data catalog does not match CRS of" + " data. The original CRS will be used. Please check your data catalog." 
) + return ds - # clip tslice + @staticmethod + def _slice_data( + ds, + variables=None, + geom=None, + bbox=None, + buffer=0, + predicate="intersects", + time_tuple=None, + logger=logger, + ): + """Slice the dataset in space and time. + + Arguments + --------- + ds : xarray.Dataset or xarray.DataArray + The GeoDataset to slice. + variables : str or list of str, optional. + Names of variables to return. + geom : geopandas.GeoDataFrame/Series, + A geometry defining the area of interest. + bbox : array-like of floats + (xmin, ymin, xmax, ymax) bounding box of area of interest + (in WGS84 coordinates). + buffer : float, optional + Buffer distance [m] applied to the geometry or bbox. By default 0 m. + predicate : str, optional + Predicate used to filter the GeoDataFrame, see + :py:func:`hydromt.gis_utils.filter_gdf` for details. + time_tuple : tuple of str, datetime, optional + Start and end date of period of interest. By default the entire time period + of the dataset is returned. + + Returns + ------- + ds : xarray.Dataset + The sliced GeoDataset. + """ + if isinstance(ds, xr.DataArray): + if ds.name is None: + # dummy name, required to create dataset + # renamed to variable in _single_var_as_array + ds.name = "data" + ds = ds.to_dataset() + elif variables is not None: + variables = np.atleast_1d(variables).tolist() + if len(variables) > 1 or len(ds.data_vars) > 1: + mvars = [var not in ds.data_vars for var in variables] + if any(mvars): + raise ValueError(f"GeoDataset: variables not found {mvars}") + ds = ds[variables] + if time_tuple is not None: + ds = GeoDatasetAdapter._slice_temporal_dimension( + ds, time_tuple, logger=logger + ) + if geom is not None or bbox is not None: + ds = GeoDatasetAdapter._slice_spatial_dimension( + ds, geom, bbox, buffer, predicate, logger=logger + ) + return ds + + @staticmethod + def _slice_spatial_dimension(ds, geom, bbox, buffer, predicate, logger=logger): + geom = gis_utils.parse_geom_bbox_buffer(geom, bbox, buffer) + bbox_str = ", ".join([f"{c:.3f}" for c in geom.total_bounds]) + epsg = geom.crs.to_epsg() + logger.debug(f"Clip {predicate} [{bbox_str}] (EPSG:{epsg})") + ds = ds.vector.clip_geom(geom, predicate=predicate) + if ds.vector.index.size == 0: + raise IndexError("No data within spatial domain.") + return ds + + def _shift_time(self, ds, logger=logger): + dt = self.unit_add.get("time", 0) + if ( + dt != 0 + and "time" in ds.dims + and ds["time"].size > 1 + and np.issubdtype(ds["time"].dtype, np.datetime64) + ): + logger.debug(f"Shifting time labels with {dt} sec.") + ds["time"] = ds["time"] + pd.to_timedelta(dt, unit="s") + elif dt != 0: + logger.warning("Time shift not applied, time dimension not found.") + return ds + + @staticmethod + def _slice_temporal_dimension(ds, time_tuple, logger=logger): if ( - "time" in ds_out.dims - and ds_out["time"].size > 1 - and np.issubdtype(ds_out["time"].dtype, np.datetime64) + "time" in ds.dims + and ds["time"].size > 1 + and np.issubdtype(ds["time"].dtype, np.datetime64) ): - dt = self.unit_add.get("time", 0) - if dt != 0: - logger.debug(f"GeoDataset: Shifting time labels with {dt} sec.") - ds_out["time"] = ds_out["time"] + pd.to_timedelta(dt, unit="s") - if time_tuple is not None: - logger.debug(f"GeoDataset: Slicing time dim {time_tuple}") - ds_out = ds_out.sel(time=slice(*time_tuple)) - if ds_out.time.size == 0: - logger.warning("GeoDataset: Time slice out of range.") - drop_vars = [v for v in ds_out.data_vars if "time" in ds_out[v].dims] - ds_out = ds_out.drop_vars(drop_vars) - - # set nodata value + 
logger.debug(f"Slicing time dim {time_tuple}") + ds = ds.sel(time=slice(*time_tuple)) + if ds.time.size == 0: + raise IndexError("GeoDataset: Time slice out of range.") + return ds + + def _set_metadata(self, ds): + if self.attrs: + if isinstance(ds, xr.DataArray): + ds.attrs.update(self.attrs[ds.name]) + else: + for k in self.attrs: + ds[k].attrs.update(self.attrs[k]) + + ds.attrs.update(self.meta) + return ds + + def _set_nodata(self, ds): if self.nodata is not None: if not isinstance(self.nodata, dict): - nodata = {k: self.nodata for k in ds_out.data_vars.keys()} + nodata = {k: self.nodata for k in ds.data_vars.keys()} else: nodata = self.nodata - for k in ds_out.data_vars: + for k in ds.data_vars: mv = nodata.get(k, None) - if mv is not None and ds_out[k].vector.nodata is None: - ds_out[k].vector.set_nodata(mv) + if mv is not None and ds[k].vector.nodata is None: + ds[k].vector.set_nodata(mv) + return ds - # unit conversion + def _apply_unit_conversion(self, ds, logger=logger): unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys()) - unit_names = [k for k in unit_names if k in ds_out.data_vars] + unit_names = [k for k in unit_names if k in ds.data_vars] if len(unit_names) > 0: - logger.debug(f"GeoDataset: Convert units for {len(unit_names)} variables.") + logger.debug(f"Convert units for {len(unit_names)} variables.") for name in list(set(unit_names)): # unique m = self.unit_mult.get(name, 1) a = self.unit_add.get(name, 0) - da = ds_out[name] + da = ds[name] attrs = da.attrs.copy() nodata_isnan = da.vector.nodata is None or np.isnan(da.vector.nodata) # nodata value is explicitly set to NaN in case no nodata value is provided nodata = np.nan if nodata_isnan else da.vector.nodata data_bool = ~np.isnan(da) if nodata_isnan else da != nodata - ds_out[name] = xr.where(data_bool, da * m + a, nodata) - ds_out[name].attrs.update(attrs) # set original attributes - - # return data array if single var - if single_var_as_array and len(ds_out.vector.vars) == 1: - ds_out = ds_out[ds_out.vector.vars[0]] - - # Set variable attribute data - if self.attrs: - if isinstance(ds_out, xr.DataArray): - ds_out.attrs.update(self.attrs[ds_out.name]) - else: - for k in self.attrs: - ds_out[k].attrs.update(self.attrs[k]) - - # set meta data - ds_out.attrs.update(self.meta) - - return ds_out + ds[name] = xr.where(data_bool, da * m + a, nodata) + ds[name].attrs.update(attrs) # set original attributes + return ds diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index 9834d5ef0..bb1af79ed 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -1,11 +1,14 @@ """Implementation for the RasterDatasetAdapter.""" +from __future__ import annotations + import logging import os import warnings from os import PathLike from os.path import join -from typing import NewType, Union +from typing import NewType, Optional, Union +import geopandas as gpd import numpy as np import pandas as pd import pyproj @@ -37,10 +40,10 @@ class RasterDatasetAdapter(DataAdapter): def __init__( self, path: str, - driver: str = None, + driver: Optional[str] = None, filesystem: str = "local", - crs: Union[int, str, dict] = None, - nodata: Union[dict, float, int] = None, + crs: Optional[Union[int, str, dict]] = None, + nodata: Optional[Union[dict, float, int]] = None, rename: dict = {}, unit_mult: dict = {}, unit_add: dict = {}, @@ -165,9 +168,6 @@ def to_file( variables : list of str, optional Names of GeoDataset variables to return. 
By default all dataset variables are returned. - logger : logger object, optional - The logger object used for logging messages. If not provided, the default - logger will be used. **kwargs Additional keyword arguments that are passed to the `to_netcdf` function. @@ -182,24 +182,20 @@ def to_file( kwargs: dict the additional kwyeord arguments that were passed to `to_netcdf` """ - try: - obj = self.get_data( - bbox=bbox, - time_tuple=time_tuple, - variables=variables, - logger=logger, - single_var_as_array=variables is None, - ) - except IndexError as err: # out of bounds - logger.warning(str(err)) - return None, None, None + obj = self.get_data( + bbox=bbox, + time_tuple=time_tuple, + variables=variables, + logger=logger, + single_var_as_array=variables is None, + ) read_kwargs = {} if driver is None: # by default write 2D raster data to GeoTiff and 3D raster data to netcdf driver = "netcdf" if len(obj.dims) == 3 else "GTiff" # write using various writers - if driver in ["netcdf"]: # TODO complete list + if driver == "netcdf": dvars = [obj.name] if isinstance(obj, xr.DataArray) else obj.raster.vars if variables is None: encoding = {k: {"zlib": True} for k in dvars} @@ -251,10 +247,34 @@ def get_data( For a detailed description see: :py:func:`~hydromt.data_catalog.DataCatalog.get_rasterdataset` """ - # If variable is string, convert to list - if variables: - variables = np.atleast_1d(variables).tolist() + # load data + fns = self._resolve_paths(time_tuple, variables, zoom_level, geom, bbox, logger) + ds = self._read_data(fns, geom, bbox, cache_root, logger) + # rename variables and parse data and attrs + ds = self._rename_vars(ds) + ds = self._validate_spatial_dims(ds) + ds = self._set_crs(ds, logger) + ds = self._set_nodata(ds) + ds = self._shift_time(ds, logger) + # slice data + ds = RasterDatasetAdapter._slice_data( + ds, variables, geom, bbox, buffer, align, time_tuple, logger + ) + # uniformize data + ds = self._apply_unit_conversions(ds, logger) + ds = self._set_metadata(ds) + # return array if single var and single_var_as_array + return self._single_var_as_array(ds, single_var_as_array, variables) + def _resolve_paths( + self, + time_tuple: Optional[tuple] = None, + variables: Optional[list] = None, + zoom_level: int = 0, + geom: gpd.GeoSeries = None, + bbox: Optional[list] = None, + logger=logger, + ): # Extract storage_options from kwargs to instantiate fsspec object correctly so_kwargs = dict() if "storage_options" in self.driver_kwargs: @@ -266,19 +286,22 @@ def get_data( else: os.environ["AWS_NO_SIGN_REQUEST"] = "NO" + # parse zoom level (raster only) + if len(self.zoom_levels) > 0: + zoom_level = self._parse_zoom_level(zoom_level, geom, bbox, logger=logger) + # resolve path based on time, zoom level and/or variables - fns = self.resolve_paths( + fns = super()._resolve_paths( time_tuple=time_tuple, variables=variables, zoom_level=zoom_level, - geom=geom, - bbox=bbox, - logger=logger, **so_kwargs, ) - kwargs = self.driver_kwargs.copy() + return fns + def _read_data(self, fns, geom, bbox, cache_root, logger=logger): + kwargs = self.driver_kwargs.copy() # zarr can use storage options directly, the rest should be converted to # file-like objects if "storage_options" in kwargs and self.driver == "raster": @@ -287,12 +310,13 @@ def get_data( fns = [fs.open(f) for f in fns] # read using various readers - if self.driver in ["netcdf"]: # TODO complete list + logger.info(f"Reading {self.name} {self.driver} data from {self.path}") + if self.driver == "netcdf": if self.filesystem == 
"local": if "preprocess" in kwargs: preprocess = PREPROCESSORS.get(kwargs["preprocess"], None) kwargs.update(preprocess=preprocess) - ds_out = xr.open_mfdataset(fns, decode_coords="all", **kwargs) + ds = xr.open_mfdataset(fns, decode_coords="all", **kwargs) else: raise NotImplementedError( "Remote (cloud) RasterDataset not supported with driver netcdf." @@ -309,14 +333,12 @@ def get_data( if do_preprocess: ds = preprocess(ds) ds_lst.append(ds) - ds_out = xr.merge(ds_lst) + ds = xr.merge(ds_lst) elif self.driver == "raster_tindex": if self.filesystem == "local": if np.issubdtype(type(self.nodata), np.number): kwargs.update(nodata=self.nodata) - ds_out = io.open_raster_from_tindex( - fns[0], bbox=bbox, geom=geom, **kwargs - ) + ds = io.open_raster_from_tindex(fns[0], bbox=bbox, geom=geom, **kwargs) else: raise NotImplementedError( "Remote (cloud) RasterDataset not supported " @@ -334,61 +356,143 @@ def get_data( fns = fns_cached if np.issubdtype(type(self.nodata), np.number): kwargs.update(nodata=self.nodata) - ds_out = io.open_mfraster(fns, logger=logger, **kwargs) + ds = io.open_mfraster(fns, logger=logger, **kwargs) else: raise ValueError(f"RasterDataset: Driver {self.driver} unknown") - if GEO_MAP_COORD in ds_out.data_vars: - ds_out = ds_out.set_coords(GEO_MAP_COORD) - - # rename and select vars - if variables and len(ds_out.raster.vars) == 1 and len(self.rename) == 0: - rm = {ds_out.raster.vars[0]: variables[0]} - if set(rm.keys()) != set(rm.values()): - warnings.warn( - "Automatic renaming of single var array will be deprecated, rename" - f" {rm} in the data catalog instead.", - DeprecationWarning, - ) - else: - rm = {k: v for k, v in self.rename.items() if k in ds_out} - ds_out = ds_out.rename(rm) - if variables is not None: - if np.any([var not in ds_out.data_vars for var in variables]): - raise ValueError(f"RasterDataset: Not all variables found: {variables}") - ds_out = ds_out[variables] - - # transpose dims to get y and x dim last - x_dim = ds_out.raster.x_dim - y_dim = ds_out.raster.y_dim - ds_out = ds_out.transpose(..., y_dim, x_dim) - - # clip tslice - if ( - "time" in ds_out.dims - and ds_out["time"].size > 1 - and np.issubdtype(ds_out["time"].dtype, np.datetime64) - ): - dt = self.unit_add.get("time", 0) - if dt != 0: - logger.debug(f"RasterDataset: Shifting time labels with {dt} sec.") - ds_out["time"] = ds_out["time"] + pd.to_timedelta(dt, unit="s") - if time_tuple is not None: - logger.debug(f"RasterDataset: Slicing time dim {time_tuple}") - ds_out = ds_out.sel({"time": slice(*time_tuple)}) - if ds_out.time.size == 0: - raise IndexError("RasterDataset: Time slice out of range.") + return ds + + def _rename_vars(self, ds): + rm = {k: v for k, v in self.rename.items() if k in ds} + ds = ds.rename(rm) + return ds + + def _validate_spatial_dims(self, ds): + if GEO_MAP_COORD in ds.data_vars: + ds = ds.set_coords(GEO_MAP_COORD) + try: + ds.raster.set_spatial_dims() + # transpose dims to get y and x dim last + x_dim = ds.raster.x_dim + y_dim = ds.raster.y_dim + ds = ds.transpose(..., y_dim, x_dim) + except ValueError: + raise ValueError( + f"RasterDataset: No valid spatial coords found in data {self.path}" + ) + return ds + + def _set_crs(self, ds, logger=logger): # set crs - if ds_out.raster.crs is None and self.crs is not None: - ds_out.raster.set_crs(self.crs) - elif ds_out.raster.crs is None: + if ds.raster.crs is None and self.crs is not None: + ds.raster.set_crs(self.crs) + elif ds.raster.crs is None: raise ValueError( - "RasterDataset: The data has no CRS, set in 
RasterDatasetAdapter." + f"RasterDataset {self.name}: CRS not defined in data catalog or data." + ) + elif self.crs is not None and ds.raster.crs != pyproj.CRS.from_user_input( + self.crs + ): + logger.warning( + f"RasterDataset {self.name}: CRS from data catalog does not match CRS " + " of data. The original CRS will be used. Please check your catalog." + ) + return ds + + @staticmethod + def _slice_data( + ds, + variables=None, + geom=None, + bbox=None, + buffer=0, + align=None, + time_tuple=None, + logger=logger, + ): + """Return a RasterDataset sliced in both spatial and temporal dimensions. + + Arguments + --------- + ds : xarray.Dataset or xarray.DataArray + The RasterDataset to slice. + variables : list of str, optional + Names of variables to return. By default all dataset variables + geom : geopandas.GeoDataFrame/Series, optional + A geometry defining the area of interest. + bbox : array-like of floats, optional + (xmin, ymin, xmax, ymax) bounding box of area of interest + (in WGS84 coordinates). + buffer : int, optional + Buffer around the `bbox` or `geom` area of interest in pixels. By default 0. + align : float, optional + Resolution to align the bounding box, by default None + time_tuple : Tuple of datetime, optional + A tuple consisting of the lower and upper bounds of time that the + result should contain + + Returns + ------- + ds : xarray.Dataset + The sliced RasterDataset. + """ + if isinstance(ds, xr.DataArray): + if ds.name is None: + # dummy name, required to create dataset + # renamed to variable in _single_var_as_array + ds.name = "data" + ds = ds.to_dataset() + elif variables is not None: + variables = np.atleast_1d(variables).tolist() + if len(variables) > 1 or len(ds.data_vars) > 1: + mvars = [var not in ds.data_vars for var in variables] + if any(mvars): + raise ValueError(f"RasterDataset: variables not found {mvars}") + ds = ds[variables] + if time_tuple is not None: + ds = RasterDatasetAdapter._slice_temporal_dimension( + ds, + time_tuple, + logger=logger, + ) + if geom is not None or bbox is not None: + ds = RasterDatasetAdapter._slice_spatial_dimensions( + ds, geom, bbox, buffer, align, logger=logger ) + return ds - # clip + def _shift_time(self, ds, logger=logger): + dt = self.unit_add.get("time", 0) + if ( + dt != 0 + and "time" in ds.dims + and ds["time"].size > 1 + and np.issubdtype(ds["time"].dtype, np.datetime64) + ): + logger.debug(f"Shifting time labels with {dt} sec.") + ds["time"] = ds["time"] + pd.to_timedelta(dt, unit="s") + elif dt != 0: + logger.warning("Time shift not applied, time dimension not found.") + return ds + + @staticmethod + def _slice_temporal_dimension(ds, time_tuple, logger=logger): + if ( + "time" in ds.dims + and ds["time"].size > 1 + and np.issubdtype(ds["time"].dtype, np.datetime64) + ): + if time_tuple is not None: + logger.debug(f"Slicing time dim {time_tuple}") + ds = ds.sel({"time": slice(*time_tuple)}) + if ds.time.size == 0: + raise IndexError("Time slice out of range.") + return ds + + @staticmethod + def _slice_spatial_dimensions(ds, geom, bbox, buffer, align, logger=logger): # make sure bbox is in data crs - crs = ds_out.raster.crs + crs = ds.raster.crs epsg = crs.to_epsg() # this could return None if geom is not None: bbox = geom.to_crs(crs).total_bounds @@ -397,59 +501,124 @@ def get_data( bbox = rasterio.warp.transform_bounds(crs4326, crs, *bbox) # work with 4326 data that is defined at 0-360 degrees longtitude if epsg == 4326: - e = ds_out.raster.bounds[2] + e = ds.raster.bounds[2] if e > 180 or (bbox is not None 
and (bbox[0] < -180 or bbox[2] > 180)): - x_dim = ds_out.raster.x_dim - ds_out = gis_utils.meridian_offset(ds_out, x_dim, bbox).sortby(x_dim) + x_dim = ds.raster.x_dim + ds = gis_utils.meridian_offset(ds, x_dim, bbox).sortby(x_dim) + # clip with bbox if bbox is not None: bbox_str = ", ".join([f"{c:.3f}" for c in bbox]) - logger.debug(f"RasterDataset: Clip with bbox - [{bbox_str}] (epsg:{epsg}))") - ds_out = ds_out.raster.clip_bbox(bbox, buffer=buffer, align=align) - if np.any(np.array(ds_out.raster.shape) < 2): - raise IndexError( - f"RasterDataset: No data within spatial domain for {self.path}." - ) + logger.debug(f"Clip to [{bbox_str}] (epsg:{epsg}))") + ds = ds.raster.clip_bbox(bbox, buffer=buffer, align=align) + if np.any(np.array(ds.raster.shape) < 2): + raise IndexError("RasterDataset: No data within spatial domain.") - # set nodata value - if self.nodata is not None: - if not isinstance(self.nodata, dict): - nodata = {k: self.nodata for k in ds_out.data_vars.keys()} - else: - nodata = self.nodata - for k in ds_out.data_vars: - mv = nodata.get(k, None) - if mv is not None and ds_out[k].raster.nodata is None: - ds_out[k].raster.set_nodata(mv) + return ds - # unit conversion + def _apply_unit_conversions(self, ds, logger=logger): unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys()) - unit_names = [k for k in unit_names if k in ds_out.data_vars] + unit_names = [k for k in unit_names if k in ds.data_vars] if len(unit_names) > 0: - logger.debug( - f"RasterDataset: Convert units for {len(unit_names)} variables." - ) + logger.debug(f"Convert units for {len(unit_names)} variables.") for name in list(set(unit_names)): # unique m = self.unit_mult.get(name, 1) a = self.unit_add.get(name, 0) - da = ds_out[name] + da = ds[name] attrs = da.attrs.copy() nodata_isnan = da.raster.nodata is None or np.isnan(da.raster.nodata) # nodata value is explicitly set to NaN in case no nodata value is provided nodata = np.nan if nodata_isnan else da.raster.nodata data_bool = ~np.isnan(da) if nodata_isnan else da != nodata - ds_out[name] = xr.where(data_bool, da * m + a, nodata) - ds_out[name].attrs.update(attrs) # set original attributes - ds_out[name].raster.set_nodata(nodata) # reset nodata in case of change + ds[name] = xr.where(data_bool, da * m + a, nodata) + ds[name].attrs.update(attrs) # set original attributes + ds[name].raster.set_nodata(nodata) # reset nodata in case of change + + return ds + def _set_nodata(self, ds): + # set nodata value + if self.nodata is not None: + if not isinstance(self.nodata, dict): + nodata = {k: self.nodata for k in ds.data_vars.keys()} + else: + nodata = self.nodata + for k in ds.data_vars: + mv = nodata.get(k, None) + if mv is not None and ds[k].raster.nodata is None: + ds[k].raster.set_nodata(mv) + return ds + + def _set_metadata(self, ds): # unit attributes for k in self.attrs: - ds_out[k].attrs.update(self.attrs[k]) + ds[k].attrs.update(self.attrs[k]) + # set meta data + ds.attrs.update(self.meta) + return ds - # return data array if single var - if single_var_as_array and len(ds_out.raster.vars) == 1: - ds_out = ds_out[ds_out.raster.vars[0]] + def _parse_zoom_level( + self, + zoom_level: Optional[int | tuple] = None, + geom: gpd.GeoSeries = None, + bbox: Optional[list] = None, + logger=logger, + ) -> int: + """Return nearest smaller zoom level. - # set meta data - ds_out.attrs.update(self.meta) - return ds_out + Based on zoom resolutions defined in data catalog. 
+ """ + # common pyproj crs axis units + known_units = ["degree", "metre", "US survey foot"] + if self.zoom_levels is None or len(self.zoom_levels) == 0: + logger.warning("No zoom levels available, default to zero") + return 0 + zls = list(self.zoom_levels.keys()) + if zoom_level is None: # return first zoomlevel (assume these are ordered) + return next(iter(zls)) + # parse zoom_level argument + if ( + isinstance(zoom_level, tuple) + and isinstance(zoom_level[0], (int, float)) + and isinstance(zoom_level[1], str) + and len(zoom_level) == 2 + ): + res, unit = zoom_level + # covert 'meter' and foot to official pyproj units + unit = {"meter": "metre", "foot": "US survey foot"}.get(unit, unit) + if unit not in known_units: + raise TypeError( + f"zoom_level unit {unit} not understood;" + f" should be one of {known_units}" + ) + elif not isinstance(zoom_level, int): + raise TypeError( + f"zoom_level argument not understood: {zoom_level}; should be a float" + ) + else: + return zoom_level + if self.crs: + # convert res if different unit than crs + crs = pyproj.CRS.from_user_input(self.crs) + crs_unit = crs.axis_info[0].unit_name + if crs_unit != unit and crs_unit not in known_units: + raise NotImplementedError( + f"no conversion available for {unit} to {crs_unit}" + ) + if unit != crs_unit: + lat = 0 + if bbox is not None: + lat = (bbox[1] + bbox[3]) / 2 + elif geom is not None: + lat = geom.to_crs(4326).centroid.y.item() + conversions = { + "degree": np.hypot(*gis_utils.cellres(lat=lat)), + "US survey foot": 0.3048, + } + res = res * conversions.get(unit, 1) / conversions.get(crs_unit, 1) + # find nearest smaller zoomlevel + eps = 1e-5 # allow for rounding errors + smaller = [x < (res + eps) for x in self.zoom_levels.values()] + zl = zls[-1] if all(smaller) else zls[max(smaller.index(False) - 1, 0)] + logger.info(f"Getting data for zoom_level {zl} based on res {zoom_level}") + return zl diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 52f6c8ef6..a50644c9c 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -867,18 +867,17 @@ def export_data( unit_add = source.unit_add source.unit_mult = {} source.unit_add = {} - fn_out, driver, driver_kwargs = source.to_file( - data_root=data_root, - data_name=key, - variables=source_vars.get(key, None), - bbox=bbox, - time_tuple=time_tuple, - logger=self.logger, - ) - if fn_out is None: - self.logger.warning( - f"{key} file contains no data within domain" + try: + fn_out, driver, driver_kwargs = source.to_file( + data_root=data_root, + data_name=key, + variables=source_vars.get(key, None), + bbox=bbox, + time_tuple=time_tuple, + logger=self.logger, ) + except IndexError as e: + self.logger.warning(f"{key} file contains no data: {e}") continue # update path & driver and remove kwargs # and rename in output sources @@ -1006,16 +1005,24 @@ def get_rasterdataset( else: raise FileNotFoundError(f"No such file or catalog source: {data_like}") elif isinstance(data_like, (xr.DataArray, xr.Dataset)): - # TODO apply bbox, geom, buffer, align, variables, time_tuple - return data_like + data_like = RasterDatasetAdapter._slice_data( + data_like, + variables, + geom, + bbox, + buffer, + align, + time_tuple, + logger=self.logger, + ) + return RasterDatasetAdapter._single_var_as_array( + data_like, single_var_as_array, variables + ) else: raise ValueError(f'Unknown raster data type "{type(data_like).__name__}"') + # TODO add also provider and version to used data self._used_data.append(name) - self.logger.info( - f"DataCatalog: Getting 
{name} RasterDataset {source.driver} data from" - f" {source.path}" - ) obj = source.get_data( bbox=bbox, geom=geom, @@ -1104,16 +1111,13 @@ def get_geodataframe( else: raise FileNotFoundError(f"No such file or catalog source: {data_like}") elif isinstance(data_like, gpd.GeoDataFrame): - # TODO apply bbox, geom, buffer, predicate, variables - return data_like + return GeoDataFrameAdapter._slice_data( + data_like, variables, geom, bbox, buffer, predicate, logger=self.logger + ) else: raise ValueError(f'Unknown vector data type "{type(data_like).__name__}"') self._used_data.append(name) - self.logger.info( - f"DataCatalog: Getting {name} GeoDataFrame {source.driver} data" - f" from {source.path}" - ) gdf = source.get_data( bbox=bbox, geom=geom, @@ -1130,6 +1134,7 @@ def get_geodataset( bbox: Optional[List] = None, geom: Optional[gpd.GeoDataFrame] = None, buffer: Union[float, int] = 0, + predicate: str = "intersects", variables: Optional[List] = None, time_tuple: Optional[Tuple] = None, single_var_as_array: bool = True, @@ -1163,6 +1168,10 @@ def get_geodataset( A geometry defining the area of interest. buffer : float, optional Buffer around the `bbox` or `geom` area of interest in meters. By default 0. + predicate : {'intersects', 'within', 'contains', 'overlaps', + 'crosses', 'touches'}, optional If predicate is provided, + the GeoDataFrame is filtered by testing the predicate function + against each item. Requires bbox or mask. By default 'intersects' variables : str or list of str, optional. Names of GeoDataset variables to return. By default all dataset variables are returned. @@ -1199,20 +1208,28 @@ def get_geodataset( else: raise FileNotFoundError(f"No such file or catalog source: {data_like}") elif isinstance(data_like, (xr.DataArray, xr.Dataset)): - # TODO apply bbox, geom, buffer, variables, time_tuple - return data_like + data_like = GeoDatasetAdapter._slice_data( + data_like, + variables, + geom, + bbox, + buffer, + predicate, + time_tuple, + logger=self.logger, + ) + return GeoDatasetAdapter._single_var_as_array( + data_like, single_var_as_array, variables + ) else: raise ValueError(f'Unknown geo data type "{type(data_like).__name__}"') self._used_data.append(name) - self.logger.info( - f"DataCatalog: Getting {name} GeoDataset {source.driver} data" - f" from {source.path}" - ) obj = source.get_data( bbox=bbox, geom=geom, buffer=buffer, + predicate=predicate, variables=variables, time_tuple=time_tuple, single_var_as_array=single_var_as_array, @@ -1271,15 +1288,13 @@ def get_dataframe( else: raise FileNotFoundError(f"No such file or catalog source: {data_like}") elif isinstance(data_like, pd.DataFrame): - return data_like + return DataFrameAdapter._slice_data( + data_like, variables, time_tuple, logger=self.logger + ) else: raise ValueError(f'Unknown tabular data type "{type(data_like).__name__}"') self._used_data.append(name) - self.logger.info( - f"DataCatalog: Getting {name} DataFrame {source.driver} data" - f" from {source.path}" - ) obj = source.get_data( variables=variables, time_tuple=time_tuple, diff --git a/hydromt/gis_utils.py b/hydromt/gis_utils.py index d90ae387b..88dae516e 100644 --- a/hydromt/gis_utils.py +++ b/hydromt/gis_utils.py @@ -226,10 +226,42 @@ def filter_gdf(gdf, geom=None, bbox=None, crs=None, predicate="intersects"): geom = geom.to_crs(gdf.crs) # convert geopandas to geometry geom = geom.unary_union - idx = gdf.sindex.query(geom, predicate=predicate) + idx = np.sort(gdf.sindex.query(geom, predicate=predicate)) return idx +def 
parse_geom_bbox_buffer(geom=None, bbox=None, buffer=0):
+    """Parse geom or bbox to a (buffered) geometry.
+
+    Arguments
+    ---------
+    geom : geopandas.GeoDataFrame/Series, optional
+        A geometry defining the area of interest.
+    bbox : array-like of floats, optional
+        (xmin, ymin, xmax, ymax) bounding box of area of interest
+        (in WGS84 coordinates).
+    buffer : float, optional
+        Buffer around the `bbox` or `geom` area of interest in meters. By default 0.
+
+    Returns
+    -------
+    geom: geopandas.GeoDataFrame or GeoSeries
+        The (buffered) geometry of the area of interest.
+    """
+    if geom is None and bbox is not None:
+        # convert bbox to geom with crs EPSG:4326 to apply buffer later
+        geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=4326)
+    elif geom is None:
+        raise ValueError("No geom or bbox provided.")
+
+    if buffer > 0:
+        # make sure geom is projected > buffer in meters!
+        if geom.crs.is_geographic:
+            geom = geom.to_crs(3857)
+        geom = geom.buffer(buffer)
+    return geom
+
+
 # REPROJ
 def utm_crs(bbox):
     """Return wkt string of nearest UTM projects.
diff --git a/hydromt/io.py b/hydromt/io.py
index 9ffac2754..06f23634f 100644
--- a/hydromt/io.py
+++ b/hydromt/io.py
@@ -242,7 +242,7 @@ def open_mfcsv(
         Dictionary containing a id -> filename mapping. Here the ids,
         should correspond to the values of the `concat_dim` dimension.
     concat_dim : str,
-        name of the dimention that will be created by concatinating
+        name of the dimension that will be created by concatenating
         all of the supplied csv files.
     driver_kwargs : Dict[str, Any],
         Any additional arguments to be passed to pandas' `read_csv` function.
@@ -397,9 +397,9 @@ def open_geodataset(
         Filter features by given bounding box described by [xmin, ymin, xmax, ymax]
         Cannot be used with geom.
     index_dim:
-        The dimention to index on.
+        The dimension to index on.
     chunks:
-        The dimentions of the chunks to store the underlying data in.
+        The dimensions of the chunks to store the underlying data in.
     geom : GeoDataFrame or GeoSeries | shapely Geometry, default None
         Filter for features that intersect with the geom.
         CRS mis-matches are resolved if given a GeoSeries or GeoDataFrame.
@@ -451,7 +451,7 @@ def open_timeseries_from_table(
     name: str
         variable name, derived from basename of fn if None.
     index_dim:
-        the dimention to index on.
+        the dimension to index on. 
**kwargs: key-word arguments are passed to the reader method logger: diff --git a/hydromt/workflows/forcing.py b/hydromt/workflows/forcing.py index 55e2872e6..b653477a1 100644 --- a/hydromt/workflows/forcing.py +++ b/hydromt/workflows/forcing.py @@ -759,7 +759,7 @@ def delta_freq(da_or_freq, da_or_freq1): def to_timedelta(da_or_freq): - """Convert time dimention or frequency to timedelta.""" + """Convert time dimension or frequency to timedelta.""" if isinstance(da_or_freq, (xr.DataArray, xr.Dataset)): freq = da_to_timedelta(da_or_freq) else: diff --git a/tests/conftest.py b/tests/conftest.py index cce907877..cb70e6c4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ from dask import config as dask_config from shapely.geometry import box +dask_config.set(scheduler="single-threaded") + import hydromt._compat as compat if compat.HAS_XUGRID: diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index 68d841d47..160fbab8e 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -40,13 +40,16 @@ def test_resolve_path(tmpdir): } cat = DataCatalog() cat.from_dict(dd) + source = cat.get_source("test") # test - assert len(cat.get_source("test").resolve_paths()) == 48 - assert len(cat.get_source("test").resolve_paths(variables=["precip"])) == 24 - kwargs = dict(variables=["precip"], time_tuple=("2021-03-01", "2021-05-01")) - assert len(cat.get_source("test").resolve_paths(**kwargs)) == 3 + fns = source._resolve_paths() + assert len(fns) == 48 + fns = source._resolve_paths(variables=["precip"]) + assert len(fns) == 24 + fns = source._resolve_paths(("2021-03-01", "2021-05-01"), ["precip"]) + assert len(fns) == 3 with pytest.raises(FileNotFoundError, match="No such file found:"): - cat.get_source("test").resolve_paths(variables=["waves"]) + source._resolve_paths(variables=["waves"]) def test_rasterdataset(rioda, tmpdir): @@ -184,11 +187,6 @@ def test_rasterdataset_unit_attrs(artifact_data: DataCatalog): # @pytest.mark.skip() def test_geodataset(geoda, geodf, ts, tmpdir): - # this test can sometimes hang because of threading issues therefore - # the synchronous scheduler here is necessary - from dask import config as dask_config - - dask_config.set(scheduler="single-threaded") fn_nc = str(tmpdir.join("test.nc")) fn_gdf = str(tmpdir.join("test.geojson")) fn_csv = str(tmpdir.join("test.csv")) @@ -210,6 +208,7 @@ def test_geodataset(geoda, geodf, ts, tmpdir): da2 = data_catalog.get_geodataset( fn_gdf, driver_kwargs=dict(fn_data=fn_csv) ).sortby("index") + assert isinstance(da2, xr.DataArray), type(da2) assert np.allclose(da2, geoda) # test with xy locs da3 = data_catalog.get_geodataset( diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index 5d8321cb0..86f0fab7c 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -5,6 +5,7 @@ from pathlib import Path import geopandas as gpd +import numpy as np import pandas as pd import pytest import xarray as xr @@ -391,10 +392,13 @@ def test_get_data(df, tmpdir): assert isinstance(da, xr.DataArray) da = data_catalog.get_rasterdataset(name, provider="artifact_data") assert isinstance(da, xr.DataArray) - da = data_catalog.get_rasterdataset(da) + bbox = [12.0, 46.0, 13.0, 46.5] + da = data_catalog.get_rasterdataset(da, bbox=bbox) assert isinstance(da, xr.DataArray) + assert np.allclose(da.raster.bounds, bbox) data = {"source": name, "provider": "artifact_data"} - da = data_catalog.get_rasterdataset(data) + ds = data_catalog.get_rasterdataset(data, single_var_as_array=False) + 
assert isinstance(ds, xr.Dataset) with pytest.raises(ValueError, match='Unknown raster data type "list"'): data_catalog.get_rasterdataset([]) with pytest.raises(FileNotFoundError): @@ -409,8 +413,10 @@ def test_get_data(df, tmpdir): assert isinstance(gdf, gpd.GeoDataFrame) gdf = data_catalog.get_geodataframe(name, provider="artifact_data") assert isinstance(gdf, gpd.GeoDataFrame) - gdf = data_catalog.get_geodataframe(gdf) + assert gdf.index.size == 2 + gdf = data_catalog.get_geodataframe(gdf, geom=gdf.iloc[[0],], predicate="within") assert isinstance(gdf, gpd.GeoDataFrame) + assert gdf.index.size == 1 data = {"source": name, "provider": "artifact_data"} gdf = data_catalog.get_geodataframe(data) assert isinstance(gdf, gpd.GeoDataFrame) @@ -427,12 +433,18 @@ def test_get_data(df, tmpdir): assert len(data_catalog) == n + 3 assert isinstance(da, xr.DataArray) da = data_catalog.get_geodataset(name, provider="artifact_data") + assert da.vector.index.size == 19 assert isinstance(da, xr.DataArray) - da = data_catalog.get_geodataset(da) + bbox = [12.22412, 45.25635, 12.25342, 45.271] + da = data_catalog.get_geodataset( + da, bbox=bbox, time_tuple=("2010-02-01", "2010-02-05") + ) + assert da.vector.index.size == 2 + assert da.time.size == 720 assert isinstance(da, xr.DataArray) data = {"source": name, "provider": "artifact_data"} - gdf = data_catalog.get_geodataset(data) - assert isinstance(gdf, xr.DataArray) + ds = data_catalog.get_geodataset(data, single_var_as_array=False) + assert isinstance(ds, xr.Dataset) with pytest.raises(ValueError, match='Unknown geo data type "list"'): data_catalog.get_geodataset([]) with pytest.raises(FileNotFoundError): @@ -449,8 +461,9 @@ def test_get_data(df, tmpdir): assert isinstance(df, pd.DataFrame) df = data_catalog.get_dataframe(name, provider="local") assert isinstance(df, pd.DataFrame) - df = data_catalog.get_dataframe(df) + df = data_catalog.get_dataframe(df, variables=["city"]) assert isinstance(df, pd.DataFrame) + assert df.columns == ["city"] data = {"source": name, "provider": "local"} gdf = data_catalog.get_dataframe(data) assert isinstance(gdf, pd.DataFrame)
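
Note: the pass-through changes above mean the DataCatalog getters now apply the same spatial, temporal and variable filters to in-memory Python objects as they do to catalog sources. A minimal usage sketch (illustrative only; `da` is assumed to be an xarray.DataArray with valid raster metadata and `gdf` a geopandas.GeoDataFrame, mirroring the calls exercised in test_get_data above):

    from hydromt import DataCatalog

    cat = DataCatalog()
    # clip an in-memory raster to a bounding box, same as for a catalog source
    da_clipped = cat.get_rasterdataset(da, bbox=[12.0, 46.0, 13.0, 46.5])
    # filter an in-memory GeoDataFrame against a geometry with a spatial predicate
    gdf_within = cat.get_geodataframe(gdf, geom=gdf.iloc[[0]], predicate="within")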