Skip to content

Commit

Permalink
integrate updates to BeamSearchScorer (mlfoundations#617)
Browse files Browse the repository at this point in the history
  • Loading branch information
sramshetty authored and Interpause committed May 23, 2024
1 parent a39f3a2 commit 643f085
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,10 @@ def _generate_beamsearch(
else logit_processor
)

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None

Expand Down Expand Up @@ -400,6 +400,7 @@ def _generate_beamsearch(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
Expand Down

0 comments on commit 643f085

Please sign in to comment.