From a676d0abf60dbb61effddf3dbb73436e8d6b3ded Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Fri, 5 Apr 2024 19:33:49 -0700
Subject: [PATCH] Fix tests with the intended "auto" default (#1559)

---
 keras_nlp/models/bart/bart_seq_2_seq_lm_test.py        | 2 +-
 keras_nlp/models/bloom/bloom_causal_lm_test.py         | 6 +++---
 keras_nlp/models/gemma/gemma_causal_lm_test.py         | 4 ++--
 keras_nlp/models/gpt2/gpt2_causal_lm_test.py           | 2 +-
 keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py | 2 +-
 keras_nlp/models/llama/llama_causal_lm.py              | 5 ++---
 keras_nlp/models/llama/llama_causal_lm_test.py         | 2 +-
 keras_nlp/models/mistral/mistral_causal_lm_test.py     | 2 +-
 keras_nlp/models/opt/opt_causal_lm_test.py             | 2 +-
 9 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
index 280ec33dc6..2edd6b3a94 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
@@ -87,7 +87,7 @@ def test_generate(self):
         # Int tensor input.
         seq_2_seq_lm.preprocessor = None
         preprocessed_batch = self.preprocessor.generate_preprocess(inputs)
-        outputs = seq_2_seq_lm.generate(preprocessed_batch)
+        outputs = seq_2_seq_lm.generate(preprocessed_batch, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["decoder_token_ids"][:, :5],
diff --git a/keras_nlp/models/bloom/bloom_causal_lm_test.py b/keras_nlp/models/bloom/bloom_causal_lm_test.py
index 70af6a2302..431a734da5 100644
--- a/keras_nlp/models/bloom/bloom_causal_lm_test.py
+++ b/keras_nlp/models/bloom/bloom_causal_lm_test.py
@@ -78,7 +78,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
@@ -103,7 +103,7 @@ def test_generate_with_bfloat16(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
@@ -128,7 +128,7 @@ def test_generate_with_mixed_float16(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py
index 4a47d162ef..8f681df52f 100644
--- a/keras_nlp/models/gemma/gemma_causal_lm_test.py
+++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py
@@ -73,7 +73,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
@@ -96,7 +96,7 @@ def test_generate_with_bfloat16(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
index 8999ebd9af..2d2f0e8a0a 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
index c8839c8be9..3bf0197e4a 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py
index b1e85d2925..8bc629281c 100644
--- a/keras_nlp/models/llama/llama_causal_lm.py
+++ b/keras_nlp/models/llama/llama_causal_lm.py
@@ -187,9 +187,8 @@ def next(prompt, cache, index):
         if stop_token_ids is not None:
             # Build a mask of stop token locations not in the original
             # prompt (not in locations where `padding_mask` is True).
-            end_locations = ops.logical_and(
-                any_equal(token_ids, stop_token_ids),
-                ops.logical_not(padding_mask),
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
             )
             end_locations = ops.cast(end_locations, "int32")
             # Use cumsum to get ones in all locations after end_locations.
diff --git a/keras_nlp/models/llama/llama_causal_lm_test.py b/keras_nlp/models/llama/llama_causal_lm_test.py
index c006f72783..7a449ad3ed 100644
--- a/keras_nlp/models/llama/llama_causal_lm_test.py
+++ b/keras_nlp/models/llama/llama_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
diff --git a/keras_nlp/models/mistral/mistral_causal_lm_test.py b/keras_nlp/models/mistral/mistral_causal_lm_test.py
index 13f0dad907..a5f1ac9d89 100644
--- a/keras_nlp/models/mistral/mistral_causal_lm_test.py
+++ b/keras_nlp/models/mistral/mistral_causal_lm_test.py
@@ -70,7 +70,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
diff --git a/keras_nlp/models/opt/opt_causal_lm_test.py b/keras_nlp/models/opt/opt_causal_lm_test.py
index 3ba27178d1..82c3614a73 100644
--- a/keras_nlp/models/opt/opt_causal_lm_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_test.py
@@ -69,7 +69,7 @@ def test_generate(self):
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess([prompt])
         causal_lm.preprocessor = None
-        outputs = causal_lm.generate(prompt_ids)
+        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
         self.assertAllEqual(
             outputs["token_ids"][:, :5],
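
Note: the pattern behind every test hunk above is the same. With the intended
`stop_token_ids="auto"` default, `generate()` stops at the end token, so tests
that assert over raw token ids of a fixed-length output (and that detach the
preprocessor, leaving nothing for "auto" to resolve against) must opt out
explicitly with `stop_token_ids=None`. A minimal sketch of the two behaviors,
assuming a KerasNLP version (0.9+) where `generate()` accepts `stop_token_ids`;
the preset name is illustrative:

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Default ("auto"): generation may stop early at the tokenizer's end token,
# so the output length is not guaranteed to reach max_length.
output = causal_lm.generate("a quick brown fox", max_length=10)

# stop_token_ids=None disables early stopping: generation always runs to
# max_length, which is what the tests above rely on when slicing token ids.
output = causal_lm.generate(
    "a quick brown fox", max_length=10, stop_token_ids=None
)
```

The `llama_causal_lm.py` hunk is a related cleanup: it folds the
`ops.logical_not(padding_mask)` mask into the `any_equal` utility call instead
of applying a separate `ops.logical_and`, matching how the other causal LMs in
the repo compute `end_locations`.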