Skip to content

Commit

Permalink
Preprocessing decorator fixes (#1843)
Browse files Browse the repository at this point in the history
* Fix handling of bytestring input to tokenizers and preprocessing

* Fix `no_convert_scope` in multithreaded contexts
  • Loading branch information
mattdangerw authored Sep 19, 2024
1 parent 5d2e5f4 commit fb3dc3b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
11 changes: 5 additions & 6 deletions keras_nlp/src/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,19 @@


NO_CONVERT_COUNTER = threading.local()
NO_CONVERT_COUNTER.count = 0


@contextlib.contextmanager
def no_convert_scope():
try:
NO_CONVERT_COUNTER.count += 1
NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) + 1
yield
finally:
NO_CONVERT_COUNTER.count -= 1
NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1


def in_no_convert_scope():
return NO_CONVERT_COUNTER.count > 0
return getattr(NO_CONVERT_COUNTER, "count", 0) > 0


def preprocessing_function(fn):
Expand Down Expand Up @@ -119,7 +118,7 @@ def convert_preprocessing_inputs(x):
return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()}
if isinstance(x, tuple):
return tuple(convert_preprocessing_inputs(v) for v in x)
if isinstance(x, str):
if isinstance(x, (str, bytes)):
return tf.constant(x)
if isinstance(x, list):
try:
Expand All @@ -132,7 +131,7 @@ def convert_preprocessing_inputs(x):
# If ragged conversion failed return to the numpy error.
raise e
# If we have a string input, use tf.tensor.
if numpy_x.dtype.type is np.str_:
if numpy_x.dtype.type is np.str_ or numpy_x.dtype.type is np.bytes_:
return tf.convert_to_tensor(x)
# Numpy will default to int64, int32 works with more ops.
if numpy_x.dtype == np.int64:
Expand Down
11 changes: 11 additions & 0 deletions keras_nlp/src/utils/tensor_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_strings(self):
self.assertIsInstance(outputs, list)
self.assertEqual(outputs, inputs)

def test_bytestrings(self):
inputs = ["one".encode("utf-8"), "two".encode("utf-8")]
# Convert to tf.
outputs = convert_preprocessing_inputs(inputs)
self.assertIsInstance(outputs, tf.Tensor)
self.assertAllEqual(outputs, tf.constant(inputs))
# Convert from tf.
outputs = convert_preprocessing_outputs(outputs)
self.assertIsInstance(outputs, list)
self.assertEqual(outputs, [x.decode("utf-8") for x in inputs])

def test_ragged(self):
inputs = [np.ones((1, 3)), np.ones((1, 2))]
# Convert to tf.
Expand Down

0 comments on commit fb3dc3b

Please sign in to comment.