From c1f0f6f13d79b3efd07f31be6c07c7641ff6fb1a Mon Sep 17 00:00:00 2001 From: Sam Vente Date: Thu, 3 Aug 2023 18:24:53 +0200 Subject: [PATCH] return kwargs from adapter.to_file functions (#456) --- hydromt/data_adapter/dataframe.py | 8 ++++++-- hydromt/data_adapter/geodataframe.py | 6 ++++-- hydromt/data_adapter/geodataset.py | 6 ++++-- hydromt/data_adapter/rasterdataset.py | 7 +++++-- hydromt/data_catalog.py | 4 +++- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/hydromt/data_adapter/dataframe.py b/hydromt/data_adapter/dataframe.py index be4093baf..f10bd1527 100644 --- a/hydromt/data_adapter/dataframe.py +++ b/hydromt/data_adapter/dataframe.py @@ -152,6 +152,8 @@ def to_file( driver : str Name of the driver used to read the data. See :py:func:`~hydromt.data_catalog.DataCatalog.get_geodataset`. + kwargs: dict + The additional keyword arguments that were passed in. """ @@ -162,13 +164,15 @@ def to_file( ) except IndexError as err: # out of bounds for time logger.warning(str(err)) - return None, None + return None, None, None + read_kwargs = dict() if driver is None or driver == "csv": # always write as CSV driver = "csv" fn_out = join(data_root, f"{data_name}.csv") obj.to_csv(fn_out, **kwargs) + read_kwargs["index_col"] = 0 elif driver == "parquet": fn_out = join(data_root, f"{data_name}.parquet") obj.to_parquet(fn_out, **kwargs) @@ -178,7 +182,7 @@ def to_file( else: raise ValueError(f"DataFrame: Driver {driver} is unknown.") - return fn_out, driver + return fn_out, driver, read_kwargs def get_data( self, diff --git a/hydromt/data_adapter/geodataframe.py b/hydromt/data_adapter/geodataframe.py index 3b47ba518..d692d12a6 100644 --- a/hydromt/data_adapter/geodataframe.py +++ b/hydromt/data_adapter/geodataframe.py @@ -167,8 +167,9 @@ def to_file( kwargs.pop("time_tuple", None) gdf = self.get_data(bbox=bbox, variables=variables, logger=logger) if gdf.index.size == 0: - return None, None + return None, None, None + read_kwargs = {} if driver is None: _lst = 
["csv", "parquet", "xls", "xlsx", "xy", "vector_table"] driver = "csv" if self.driver in _lst else "GPKG" @@ -182,6 +183,7 @@ def to_file( ) gdf["x"], gdf["y"] = gdf.geometry.x, gdf.geometry.y gdf.drop(columns="geometry").to_csv(fn_out, **kwargs) + read_kwargs["index_col"] = 0 elif driver == "parquet": fn_out = join(data_root, f"{data_name}.parquet") if not np.all(gdf.geometry.type == "Point"): @@ -200,7 +202,7 @@ def to_file( gdf.to_file(fn_out, driver=driver, **kwargs) driver = "vector" - return fn_out, driver + return fn_out, driver, read_kwargs def get_data( self, diff --git a/hydromt/data_adapter/geodataset.py b/hydromt/data_adapter/geodataset.py index 635c0da13..56554e909 100644 --- a/hydromt/data_adapter/geodataset.py +++ b/hydromt/data_adapter/geodataset.py @@ -182,7 +182,9 @@ def to_file( single_var_as_array=variables is None, ) if obj.vector.index.size == 0 or ("time" in obj.coords and obj.time.size == 0): - return None, None + return None, None, None + + read_kwargs = {} # much better for mem/storage/processing if dtypes are set correctly for name, coord in obj.coords.items(): @@ -218,7 +220,7 @@ def to_file( else: raise ValueError(f"GeoDataset: Driver {driver} unknown.") - return fn_out, driver + return fn_out, driver, read_kwargs def get_data( self, diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index ee46055b2..9834d5ef0 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -179,6 +179,8 @@ def to_file( driver: str Name of driver to read data with, see :py:func:`~hydromt.data_catalog.DataCatalog.get_rasterdataset` + kwargs: dict + the additional keyword arguments that were passed to `to_netcdf` """ try: obj = self.get_data( @@ -190,8 +192,9 @@ def to_file( ) except IndexError as err: # out of bounds logger.warning(str(err)) - return None, None + return None, None, None + read_kwargs = {} if driver is None: # by default write 2D raster data to GeoTiff and 3D raster 
data to netcdf driver = "netcdf" if len(obj.dims) == 3 else "GTiff" @@ -228,7 +231,7 @@ def to_file( ) driver = "raster" - return fn_out, driver + return fn_out, driver, read_kwargs def get_data( self, diff --git a/hydromt/data_catalog.py b/hydromt/data_catalog.py index 66ac96de6..f619289af 100644 --- a/hydromt/data_catalog.py +++ b/hydromt/data_catalog.py @@ -867,7 +867,7 @@ def export_data( unit_add = source.unit_add source.unit_mult = {} source.unit_add = {} - fn_out, driver = source.to_file( + fn_out, driver, driver_kwargs = source.to_file( data_root=data_root, data_name=key, variables=source_vars.get(key, None), @@ -892,6 +892,8 @@ def export_data( source.driver = driver source.filesystem = "local" source.driver_kwargs = {} + if driver_kwargs is not None: + source.driver_kwargs.update(driver_kwargs) source.rename = {} if key in sources_out: self.logger.warning(