From b043f277a2a3a4737fe2f217dcd9fbf4b76dcc59 Mon Sep 17 00:00:00 2001 From: Shivaen Date: Sun, 3 Sep 2023 00:51:59 -0700 Subject: [PATCH] integrate updates to BeamSearchScorer --- src/open_clip/coca_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 039453af7..ad81fb665 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -320,10 +320,10 @@ def _generate_beamsearch( else logit_processor ) - batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups batch_beam_size, cur_len = input_ids.shape beam_indices = None @@ -400,6 +400,7 @@ def _generate_beamsearch( pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=process_beam_indices, + group_index=beam_group_idx, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"]