From fd48a1a714793e7d7b1179b54af6d4d99c490dee Mon Sep 17 00:00:00 2001 From: zion <51308183+nousr@users.noreply.github.com> Date: Fri, 4 Nov 2022 13:42:13 -0700 Subject: [PATCH] update key toggles in inf.main (#201) --- clip_retrieval/clip_inference/main.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/clip_retrieval/clip_inference/main.py b/clip_retrieval/clip_inference/main.py index 67a1040..cc762a8 100644 --- a/clip_retrieval/clip_inference/main.py +++ b/clip_retrieval/clip_inference/main.py @@ -21,6 +21,9 @@ def calculate_partition_count( ): """ Calculate the partition count needed to store the resulting embeddings. + + Return: + - the output partition count and the updated toggles for image, text and metadata. """ sample_count = 0 @@ -59,7 +62,7 @@ def calculate_partition_count( output_partition_count = math.ceil(sample_count / write_batch_size) - return output_partition_count + return output_partition_count, enable_text, enable_image, enable_metadata # pylint: disable=unused-argument @@ -103,7 +106,7 @@ def main( # compute this now for the distributors to use if output_partition_count is None: - output_partition_count = calculate_partition_count( + output_partition_count, enable_text, enable_image, enable_metadata = calculate_partition_count( input_format=input_format, input_dataset=expanded_dataset, enable_image=enable_image, @@ -113,7 +116,11 @@ def main( wds_number_file_per_input_file=wds_number_file_per_input_file, ) + # update the local args to match the computed values local_args["output_partition_count"] = output_partition_count + local_args["enable_text"] = enable_text + local_args["enable_image"] = enable_image + local_args["enable_metadata"] = enable_metadata local_args.pop("wds_number_file_per_input_file") local_args.pop("write_batch_size")