Skip to content

Commit

Permalink
Implement compute_output_spec() for tokenizers with vocabulary. (kera…
Browse files Browse the repository at this point in the history
…s-team#1523)

* Implement compute_output_spec() for tokenizers with vocabulary. (restarted from new point in master branch)

* Remove type annotation from compute_output_spec() in tokenizers
  • Loading branch information
briango28 authored Mar 29, 2024
1 parent e8f75c8 commit 5341426
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
6 changes: 6 additions & 0 deletions keras_nlp/tokenizers/byte_pair_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Iterable
from typing import List

import keras
import regex as re
import tensorflow as tf

Expand Down Expand Up @@ -605,6 +606,11 @@ def detokenize(self, inputs):
outputs = tf.squeeze(outputs, 0)
return outputs

def compute_output_spec(self, input_spec):
return keras.KerasTensor(
input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
)

def _transform_bytes(self, tokens):
"""Map token bytes to unicode using `byte2unicode`."""
split_bytes = tf.strings.bytes_split(tokens)
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/tokenizers/sentence_piece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
from typing import List

import keras
import tensorflow as tf

from keras_nlp.api_export import keras_nlp_export
Expand Down Expand Up @@ -255,3 +256,8 @@ def detokenize(self, inputs):
if unbatched:
outputs = tf.squeeze(outputs, 0)
return outputs

def compute_output_spec(self, input_spec):
return keras.KerasTensor(
input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
)
6 changes: 6 additions & 0 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Iterable
from typing import List

import keras
import tensorflow as tf

from keras_nlp.api_export import keras_nlp_export
Expand Down Expand Up @@ -528,3 +529,8 @@ def detokenize(self, inputs):
if unbatched:
outputs = tf.squeeze(outputs, 0)
return outputs

def compute_output_spec(self, input_spec):
return keras.KerasTensor(
input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
)

0 comments on commit 5341426

Please sign in to comment.