Skip to content

Commit

Permalink
fix stopping_criteria check
Browse files Browse the repository at this point in the history
  • Loading branch information
MengqingCao committed May 11, 2024
1 parent b3b5fea commit 001fcac
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StopStringCriteria,
EosTokenCriteria,
StoppingCriteriaList
)

Expand Down Expand Up @@ -298,7 +300,12 @@ def generate(

cur_len += 1

if stopping_criteria(out, None).any():
is_done = False
if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria:
is_done = stopping_criteria(out, None).all()
else:
is_done = stopping_criteria(out, None).any()
if is_done:
break

if num_dims == 1:
Expand Down Expand Up @@ -439,7 +446,14 @@ def _generate_beamsearch(

# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None).any():
is_done = False
if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria:
is_done = stopping_criteria(input_ids, None).all()
else:
is_done = stopping_criteria(input_ids, None).any()
if is_done:
break
if beam_scorer.is_done or is_done:
break

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
Expand Down

0 comments on commit 001fcac

Please sign in to comment.