diff --git a/lib/idseq_utils/idseq_utils/batch_run_helpers.py b/lib/idseq_utils/idseq_utils/batch_run_helpers.py index cf817bd9..13908468 100644 --- a/lib/idseq_utils/idseq_utils/batch_run_helpers.py +++ b/lib/idseq_utils/idseq_utils/batch_run_helpers.py @@ -13,6 +13,8 @@ from typing import Dict, List, Optional from urllib.parse import urlparse +from itertools import tee + from idseq_utils.diamond_scatter import blastx_join from idseq_utils.minimap2_scatter import minimap2_merge @@ -295,6 +297,10 @@ def _db_chunks(bucket: str, prefix): for obj in page["Contents"]: yield obj["Key"] +def count_generator(gen): + gen, gen_copy = tee(gen) + generator_length = sum(1 for _ in gen_copy) + return generator_length, gen def run_alignment( input_dir: str, @@ -321,7 +327,8 @@ def run_alignment( ] for chunk_id, db_chunk in enumerate(_db_chunks(db_bucket, db_prefix)) ) - with Pool(len(list(chunks))) as p: + chunk_length, chunks = count_generator(chunks) + with Pool(chunk_length) as p: p.starmap(_run_chunk, chunks) run(["s3parcp", "--recursive", chunk_dir, "chunks"], check=True) if os.path.exists(os.path.join("chunks", "cache")):