Skip to content

Commit

Permalink
Fix context download missing task_metadata file (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma authored Jan 28, 2025
1 parent 21c4257 commit 440c03f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
6 changes: 5 additions & 1 deletion tabrepo/contexts/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,17 @@ def construct_s3_download_map(
path_context: str,
split_key: str,
files_pp: List[str],
files_gt: List[str]
files_gt: List[str],
task_metadata: str | None = None,
) -> Dict[str, str]:
split_value = f"{s3_prefix}model_predictions/"
s3_download_map = {
# FIXME: COMPARISON ROUNDING ERROR
"configs.parquet": "configs.parquet",
"baselines.parquet": "baselines.parquet",
}
if task_metadata is not None:
s3_download_map[task_metadata] = task_metadata
s3_download_map = {f'{path_context}{k}': f'{s3_prefix}{v}' for k, v in s3_download_map.items()}
_s3_download_map_metadata_pp = {f"{split_key}{f}": f"{split_value}{f}" for f in files_pp}
_s3_download_map_metadata_gt = {f"{split_key}{f}": f"{split_value}{f}" for f in files_gt}
Expand Down Expand Up @@ -568,6 +571,7 @@ def construct_context(
split_key=split_key,
files_pp=_files_pp,
files_gt=_files_gt,
task_metadata=task_metadata,
)
else:
_s3_download_map = None
Expand Down
2 changes: 1 addition & 1 deletion tabrepo/contexts/context_2023_08_21.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
s3_prefix=s3_prefix,
folds=folds,
date=date,
task_metadata="task_metadata_244.csv",
task_metadata="task_metadata.csv",
metadata_join_column=metadata_join_column,
configs_hyperparameters=configs,
)
Expand Down
2 changes: 1 addition & 1 deletion tabrepo/contexts/context_2023_11_14.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
s3_prefix=s3_prefix,
folds=folds,
date=date,
task_metadata="task_metadata_244.csv",
task_metadata="task_metadata.csv",
metadata_join_column=metadata_join_column,
configs_hyperparameters=configs,
)
Expand Down
10 changes: 7 additions & 3 deletions tabrepo/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
import pathlib
import os
import urllib.request
from urllib.error import HTTPError
from tqdm import tqdm

from tabrepo.utils.parallel_for import parallel_for


def download_files(remote_to_local_tuple_list: list, dry_run: bool = False, verbose: bool = False):
def download_file(remote_path, local_path, dry_run):
def download_file(remote_path: str, local_path: str, dry_run: bool) -> None:
if dry_run:
print(f'Dry Run: Would download file "{remote_path}" to "{local_path}"')
return
directory = os.path.dirname(local_path)
if directory not in ["", "."]:
pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(remote_path, local_path)
try:
urllib.request.urlretrieve(remote_path, local_path)
except HTTPError as e:
raise Exception(f"Failed to download file '{remote_path}' ... Maybe this file does not exist or is not public?") from e

parallel_for(download_file, inputs=remote_to_local_tuple_list, context={"dry_run": dry_run})
parallel_for(download_file, inputs=remote_to_local_tuple_list, context={"dry_run": dry_run})

0 comments on commit 440c03f

Please sign in to comment.