Fix tests with the intended "auto" default (keras-team#1559)
mattdangerw authored Apr 6, 2024
1 parent b50f19c commit a676d0a
Showing 9 changed files with 13 additions and 14 deletions.
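
Every test change below follows the same pattern: the test detaches the model's preprocessor and passes pre-tokenized ids straight to generate(), so the new "auto" default for stop_token_ids has no tokenizer to read end tokens from, and the tests now opt out by passing stop_token_ids=None explicitly. A minimal sketch of that calling pattern, with a hypothetical preset name and prompt (not taken from this diff):

# Sketch only; "gpt2_base_en" and the prompt are illustrative assumptions.
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# With a preprocessor attached, the "auto" default can resolve stop token
# ids from the tokenizer, so no extra argument is needed.
causal_lm.generate(["airplane at airport"])

# With raw token ids and no preprocessor, there is no tokenizer to consult,
# so stopping on an end token is disabled explicitly.
prompt_ids = causal_lm.preprocessor.generate_preprocess(["airplane at airport"])
causal_lm.preprocessor = None
causal_lm.generate(prompt_ids, stop_token_ids=None)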
2 changes: 1 addition & 1 deletion keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
@@ -87,7 +87,7 @@ def test_generate(self):
         # Int tensor input.
         seq_2_seq_lm.preprocessor = None
         preprocessed_batch = self.preprocessor.generate_preprocess(inputs)
-        outputs = seq_2_seq_lm.generate(preprocessed_batch)
+        outputs = seq_2_seq_lm.generate(preprocessed_batch, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["decoder_token_ids"][:, :5],
6 changes: 3 additions & 3 deletions keras_nlp/models/bloom/bloom_causal_lm_test.py
@@ -78,7 +78,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
@@ -103,7 +103,7 @@ def test_generate_with_bfloat16(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
@@ -128,7 +128,7 @@ def test_generate_with_mixed_float16(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
4 changes: 2 additions & 2 deletions keras_nlp/models/gemma/gemma_causal_lm_test.py
@@ -73,7 +73,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
@@ -96,7 +96,7 @@ def test_generate_with_bfloat16(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
2 changes: 1 addition & 1 deletion keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
2 changes: 1 addition & 1 deletion keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
5 changes: 2 additions & 3 deletions keras_nlp/models/llama/llama_causal_lm.py
@@ -187,9 +187,8 @@ def next(prompt, cache, index):
         if stop_token_ids is not None:
             # Build a mask of stop token locations not in the original
             # prompt (not in locations where `padding_mask` is True).
-            end_locations = ops.logical_and(
-                any_equal(token_ids, stop_token_ids),
-                ops.logical_not(padding_mask),
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
             )
             end_locations = ops.cast(end_locations, "int32")
             # Use cumsum to get ones in all locations after end_locations.
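
For context on the refactor above, here is a standalone illustration of the masking the comments describe, using numpy in place of keras.ops. The any_equal defined here is an assumption about the helper's behavior (element-wise membership ANDed with a validity mask), not the keras_nlp implementation:

# Illustration only; this any_equal is a stand-in, not keras_nlp's helper.
import numpy as np

def any_equal(inputs, values, valid_mask):
    # True where `inputs` matches any id in `values` and `valid_mask` is True.
    equal = np.zeros(inputs.shape, dtype=bool)
    for value in values:
        equal = equal | (inputs == value)
    return equal & valid_mask

token_ids = np.array([[5, 7, 9, 2, 0, 0]])  # 2 plays the role of a stop token
padding_mask = np.array([[True, True, False, False, False, False]])  # prompt = first two ids
end_locations = any_equal(token_ids, [2], np.logical_not(padding_mask))
# Cumsum marks the stop token and everything after it.
after_end = np.cumsum(end_locations.astype("int32"), axis=-1) >= 1
print(end_locations)  # [[False False False  True False False]]
print(after_end)      # [[False False False  True  True  True]]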
2 changes: 1 addition & 1 deletion keras_nlp/models/llama/llama_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
2 changes: 1 addition & 1 deletion keras_nlp/models/mistral/mistral_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
2 changes: 1 addition & 1 deletion keras_nlp/models/opt/opt_causal_lm_test.py
@@ -69,7 +69,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
