From 22e21affd0caca9c2b9ef1604abf74175c6f0bbb Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Mon, 21 Aug 2023 16:37:48 -0700
Subject: [PATCH] Stop asserting key order in bart preprocessor (#1221)

---
 keras_nlp/models/bart/bart_preprocessor.py      |  2 +-
 keras_nlp/models/bart/bart_preprocessor_test.py | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py
index 04626ddb9c..0493c89c6f 100644
--- a/keras_nlp/models/bart/bart_preprocessor.py
+++ b/keras_nlp/models/bart/bart_preprocessor.py
@@ -179,7 +179,7 @@ def get_config(self):
     def call(self, x, y=None, sample_weight=None):
         if not (
             isinstance(x, dict)
-            and ["encoder_text", "decoder_text"] == list(x.keys())
+            and all(k in x for k in ("encoder_text", "decoder_text"))
         ):
             raise ValueError(
                 '`x` must be a dictionary, containing the keys `"encoder_text"`'
diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py
index df569c5de1..0b1929895b 100644
--- a/keras_nlp/models/bart/bart_preprocessor_test.py
+++ b/keras_nlp/models/bart/bart_preprocessor_test.py
@@ -74,6 +74,22 @@ def test_tokenize_strings(self):
             output["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 0]
         )
 
+    def test_key_order(self):
+        self.assertAllClose(
+            self.preprocessor(
+                {
+                    "encoder_text": " airplane at airport",
+                    "decoder_text": " kohli is the best",
+                }
+            ),
+            self.preprocessor(
+                {
+                    "decoder_text": " kohli is the best",
+                    "encoder_text": " airplane at airport",
+                }
+            ),
+        )
+
     def test_tokenize_list_of_strings(self):
         input_data = {
             "encoder_text": [" airplane at airport"] * 4,