Skip to content

Commit

Permalink
return kwargs from adapter.to_file funtions (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
savente93 authored Aug 3, 2023
1 parent 50e50b5 commit c1f0f6f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 9 deletions.
8 changes: 6 additions & 2 deletions hydromt/data_adapter/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def to_file(
driver : str
Name of the driver used to read the data.
See :py:func:`~hydromt.data_catalog.DataCatalog.get_geodataset`.
kwargs: dict
The additional keyword arguments that were passed in.
"""
Expand All @@ -162,13 +164,15 @@ def to_file(
)
except IndexError as err: # out of bounds for time
logger.warning(str(err))
return None, None
return None, None, None

read_kwargs = dict()
if driver is None or driver == "csv":
# always write as CSV
driver = "csv"
fn_out = join(data_root, f"{data_name}.csv")
obj.to_csv(fn_out, **kwargs)
read_kwargs["index_col"] = 0
elif driver == "parquet":
fn_out = join(data_root, f"{data_name}.parquet")
obj.to_parquet(fn_out, **kwargs)
Expand All @@ -178,7 +182,7 @@ def to_file(
else:
raise ValueError(f"DataFrame: Driver {driver} is unknown.")

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
6 changes: 4 additions & 2 deletions hydromt/data_adapter/geodataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def to_file(
kwargs.pop("time_tuple", None)
gdf = self.get_data(bbox=bbox, variables=variables, logger=logger)
if gdf.index.size == 0:
return None, None
return None, None, None

read_kwargs = {}
if driver is None:
_lst = ["csv", "parquet", "xls", "xlsx", "xy", "vector_table"]
driver = "csv" if self.driver in _lst else "GPKG"
Expand All @@ -182,6 +183,7 @@ def to_file(
)
gdf["x"], gdf["y"] = gdf.geometry.x, gdf.geometry.y
gdf.drop(columns="geometry").to_csv(fn_out, **kwargs)
read_kwargs["index_col"] = 0
elif driver == "parquet":
fn_out = join(data_root, f"{data_name}.parquet")
if not np.all(gdf.geometry.type == "Point"):
Expand All @@ -200,7 +202,7 @@ def to_file(
gdf.to_file(fn_out, driver=driver, **kwargs)
driver = "vector"

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
6 changes: 4 additions & 2 deletions hydromt/data_adapter/geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def to_file(
single_var_as_array=variables is None,
)
if obj.vector.index.size == 0 or ("time" in obj.coords and obj.time.size == 0):
return None, None
return None, None, None

read_kwargs = {}

# much better for mem/storage/processing if dtypes are set correctly
for name, coord in obj.coords.items():
Expand Down Expand Up @@ -218,7 +220,7 @@ def to_file(
else:
raise ValueError(f"GeoDataset: Driver {driver} unknown.")

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
7 changes: 5 additions & 2 deletions hydromt/data_adapter/rasterdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def to_file(
driver: str
Name of driver to read data with, see
:py:func:`~hydromt.data_catalog.DataCatalog.get_rasterdataset`
kwargs: dict
the additional kwyeord arguments that were passed to `to_netcdf`
"""
try:
obj = self.get_data(
Expand All @@ -190,8 +192,9 @@ def to_file(
)
except IndexError as err: # out of bounds
logger.warning(str(err))
return None, None
return None, None, None

read_kwargs = {}
if driver is None:
# by default write 2D raster data to GeoTiff and 3D raster data to netcdf
driver = "netcdf" if len(obj.dims) == 3 else "GTiff"
Expand Down Expand Up @@ -228,7 +231,7 @@ def to_file(
)
driver = "raster"

return fn_out, driver
return fn_out, driver, read_kwargs

def get_data(
self,
Expand Down
4 changes: 3 additions & 1 deletion hydromt/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def export_data(
unit_add = source.unit_add
source.unit_mult = {}
source.unit_add = {}
fn_out, driver = source.to_file(
fn_out, driver, driver_kwargs = source.to_file(
data_root=data_root,
data_name=key,
variables=source_vars.get(key, None),
Expand All @@ -892,6 +892,8 @@ def export_data(
source.driver = driver
source.filesystem = "local"
source.driver_kwargs = {}
if driver_kwargs is not None:
source.driver_kwargs.update(driver_kwargs)
source.rename = {}
if key in sources_out:
self.logger.warning(
Expand Down

0 comments on commit c1f0f6f

Please sign in to comment.