From 2d3ef0b3d7d4436454afcac41399af3a05f4abac Mon Sep 17 00:00:00 2001
From: Tomoya Kose
Date: Mon, 4 Nov 2024 15:44:19 +0900
Subject: [PATCH] Load examples from JSONLines Dataset without duplication.

---
 src/torch_wae/cli/convert_dataset_to_web.py |  2 +-
 src/torch_wae/dataset.py                    | 12 ++++++++++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/torch_wae/cli/convert_dataset_to_web.py b/src/torch_wae/cli/convert_dataset_to_web.py
index bd25a70..24365b7 100644
--- a/src/torch_wae/cli/convert_dataset_to_web.py
+++ b/src/torch_wae/cli/convert_dataset_to_web.py
@@ -61,7 +61,7 @@ def main(
     )
     output.mkdir(parents=True, exist_ok=True)
 
-    pattern = str(output / "%06d.tar")
+    pattern = str(output / "%04d.tar")
     max_size = size_shard * 1024**2
 
     with tqdm(total=n) as progress:
diff --git a/src/torch_wae/dataset.py b/src/torch_wae/dataset.py
index 8d21764..c77aefd 100644
--- a/src/torch_wae/dataset.py
+++ b/src/torch_wae/dataset.py
@@ -15,9 +15,17 @@ def __init__(self, path: Path) -> None:
         self.__path = path
 
     def __iter__(self) -> Iterator[Any]:
+        worker_info = data.get_worker_info()
+
         with self.__path.open() as f:
-            for line in f:
-                yield json.loads(line)
+            for i, line in enumerate(f):
+                if worker_info is None:
+                    yield json.loads(line)
+                else:
+                    worker_id = worker_info.id
+                    num_workers = worker_info.num_workers
+                    if i % num_workers == worker_id:
+                        yield json.loads(line)
 
 
 @dataclass(frozen=True)