Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return kwargs from adapter.to_file funtions #456

Merged
merged 10 commits into from
Aug 3, 2023
8 changes: 6 additions & 2 deletions hydromt/data_adapter/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def to_file(
driver : str
Name of the driver used to read the data.
See :py:func:`~hydromt.data_catalog.DataCatalog.get_geodataset`.
kwargs: dict
The additional keyword arguments that were passed in.


"""
Expand All @@ -162,13 +164,15 @@ def to_file(
)
except IndexError as err: # out of bounds for time
logger.warning(str(err))
return None, None
return None, None, None

read_kwargs = dict()
if driver is None or driver == "csv":
# always write as CSV
driver = "csv"
fn_out = join(data_root, f"{data_name}.csv")
obj.to_csv(fn_out, **kwargs)
read_kwargs["index_col"] = 0
elif driver == "parquet":
fn_out = join(data_root, f"{data_name}.parquet")
obj.to_parquet(fn_out, **kwargs)
Expand All @@ -178,7 +182,7 @@ def to_file(
else:
raise ValueError(f"DataFrame: Driver {driver} is unknown.")

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
6 changes: 4 additions & 2 deletions hydromt/data_adapter/geodataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def to_file(
kwargs.pop("time_tuple", None)
gdf = self.get_data(bbox=bbox, variables=variables, logger=logger)
if gdf.index.size == 0:
return None, None
return None, None, None

read_kwargs = {}
if driver is None:
_lst = ["csv", "parquet", "xls", "xlsx", "xy", "vector_table"]
driver = "csv" if self.driver in _lst else "GPKG"
Expand All @@ -182,6 +183,7 @@ def to_file(
)
gdf["x"], gdf["y"] = gdf.geometry.x, gdf.geometry.y
gdf.drop(columns="geometry").to_csv(fn_out, **kwargs)
read_kwargs["index_col"] = 0
elif driver == "parquet":
fn_out = join(data_root, f"{data_name}.parquet")
if not np.all(gdf.geometry.type == "Point"):
Expand All @@ -200,7 +202,7 @@ def to_file(
gdf.to_file(fn_out, driver=driver, **kwargs)
driver = "vector"

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
6 changes: 4 additions & 2 deletions hydromt/data_adapter/geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def to_file(
single_var_as_array=variables is None,
)
if obj.vector.index.size == 0 or ("time" in obj.coords and obj.time.size == 0):
return None, None
return None, None, None

read_kwargs = {}

# much better for mem/storage/processing if dtypes are set correctly
for name, coord in obj.coords.items():
Expand Down Expand Up @@ -218,7 +220,7 @@ def to_file(
else:
raise ValueError(f"GeoDataset: Driver {driver} unknown.")

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
7 changes: 5 additions & 2 deletions hydromt/data_adapter/rasterdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def to_file(
driver: str
Name of driver to read data with, see
:py:func:`~hydromt.data_catalog.DataCatalog.get_rasterdataset`
kwargs: dict
the additional kwyeord arguments that were passed to `to_netcdf`
"""
try:
obj = self.get_data(
Expand All @@ -190,8 +192,9 @@ def to_file(
)
except IndexError as err: # out of bounds
logger.warning(str(err))
return None, None
return None, None, None

read_kwargs = {}
if driver is None:
# by default write 2D raster data to GeoTiff and 3D raster data to netcdf
driver = "netcdf" if len(obj.dims) == 3 else "GTiff"
Expand Down Expand Up @@ -228,7 +231,7 @@ def to_file(
)
driver = "raster"

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
4 changes: 3 additions & 1 deletion hydromt/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def export_data(
unit_add = source.unit_add
source.unit_mult = {}
source.unit_add = {}
fn_out, driver = source.to_file(
fn_out, driver, driver_kwargs = source.to_file(
data_root=data_root,
data_name=key,
variables=source_vars.get(key, None),
Expand All @@ -892,6 +892,8 @@ def export_data(
source.driver = driver
source.filesystem = "local"
source.driver_kwargs = {}
if driver_kwargs is not None:
source.driver_kwargs.update(driver_kwargs)
source.rename = {}
if key in sources_out:
self.logger.warning(
Expand Down
Loading