Skip to content

Commit

Permalink
V0.15.1.dev1 (#1844)
Browse files Browse the repository at this point in the history
* Preprocessing decorator fixes (#1843)

* Fix handling of bytestring input to tokenizers, preprocessing

* Fix no convert scope in multithreaded contexts

* Version bump dev release
  • Loading branch information
mattdangerw authored Sep 19, 2024
1 parent 8390c65 commit fbffd33
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
11 changes: 5 additions & 6 deletions keras_nlp/src/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,19 @@


# Per-thread nesting depth of `no_convert_scope()`. An attribute assigned to
# a `threading.local()` at import time exists only in the importing thread,
# so `count` is deliberately never initialized here; every read goes through
# `_no_convert_depth()`, which treats a missing attribute as depth 0.
NO_CONVERT_COUNTER = threading.local()


def _no_convert_depth():
    """Return the calling thread's current `no_convert_scope` nesting depth."""
    return getattr(NO_CONVERT_COUNTER, "count", 0)


@contextlib.contextmanager
def no_convert_scope():
    """Context manager that marks the calling thread as in a no-convert scope.

    Increments a thread-local counter on entry and decrements it on exit
    (including when the body raises), so scopes may be nested and each
    thread tracks its own depth independently.
    """
    # Increment before entering `try` so the `finally` decrement only ever
    # runs after a matching increment.
    NO_CONVERT_COUNTER.count = _no_convert_depth() + 1
    try:
        yield
    finally:
        NO_CONVERT_COUNTER.count = _no_convert_depth() - 1


def in_no_convert_scope():
    """Return `True` if the calling thread is inside `no_convert_scope()`."""
    return _no_convert_depth() > 0


def preprocessing_function(fn):
Expand Down Expand Up @@ -119,7 +118,7 @@ def convert_preprocessing_inputs(x):
return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()}
if isinstance(x, tuple):
return tuple(convert_preprocessing_inputs(v) for v in x)
if isinstance(x, str):
if isinstance(x, (str, bytes)):
return tf.constant(x)
if isinstance(x, list):
try:
Expand All @@ -132,7 +131,7 @@ def convert_preprocessing_inputs(x):
# If ragged conversion failed return to the numpy error.
raise e
# If we have a string input, use tf.tensor.
if numpy_x.dtype.type is np.str_:
if numpy_x.dtype.type is np.str_ or numpy_x.dtype.type is np.bytes_:
return tf.convert_to_tensor(x)
# Numpy will default to int64, int32 works with more ops.
if numpy_x.dtype == np.int64:
Expand Down
11 changes: 11 additions & 0 deletions keras_nlp/src/utils/tensor_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_strings(self):
self.assertIsInstance(outputs, list)
self.assertEqual(outputs, inputs)

def test_bytestrings(self):
    """Round-trip a list of UTF-8 bytestrings through tf and back."""
    raw = [b"one", b"two"]
    # Python -> tf: a list of bytes should become a single `tf.Tensor`.
    as_tensor = convert_preprocessing_inputs(raw)
    self.assertIsInstance(as_tensor, tf.Tensor)
    self.assertAllEqual(as_tensor, tf.constant(raw))
    # tf -> Python: the result is a list of decoded `str` values.
    round_tripped = convert_preprocessing_outputs(as_tensor)
    self.assertIsInstance(round_tripped, list)
    expected = [item.decode("utf-8") for item in raw]
    self.assertEqual(round_tripped, expected)

def test_ragged(self):
inputs = [np.ones((1, 3)), np.ones((1, 2))]
# Convert to tf.
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/src/version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from keras_nlp.src.api_export import keras_nlp_export

# Unique source of truth for the version number.
__version__ = "0.15.1.dev0"
__version__ = "0.15.1.dev1"


@keras_nlp_export("keras_nlp.version")
Expand Down

0 comments on commit fbffd33

Please sign in to comment.