preset testcase added
BytePairTokenizer must not split sequences of \n (keras-team#1910)

* fix for loading of special tokens in Llama tokenizer

* fix for Llama tokenizer which can have multiple end tokens

* bug fix

* adding some missing tokens to Llama3 tokenizer

* fixed tests and Llama3Tokenizer init.

* now loading the correct eos_token config from the Hugging Face checkpoint. Using a hack for the Keras checkpoint because it does not have this info

* fix for BytePairTokenizer to make Llama3-instruct work in chat: \n\n sequences are significant in the chat template and must be preserved by the tokenizer (see the sketch below)

---------

Co-authored-by: Martin Görner <[email protected]>
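
A minimal sketch of the behavior this fix protects. It assumes the public `keras_hub` API; the `llama3_8b_en` preset name is illustrative and not part of the commit:

```python
import keras_hub

# "\n\n" is significant in Llama 3 chat templates and must survive a
# tokenize -> detokenize round trip unchanged.
tokenizer = keras_hub.models.Llama3Tokenizer.from_preset("llama3_8b_en")
token_ids = tokenizer("Hello\n\nWorld")
print(tokenizer.detokenize(token_ids))  # Expect "Hello\n\nWorld".
```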

fix for generation that never stops in Llama3-Instruct variants (keras-team#1904)

* fix for loading of special tokens in Llama tokenizer

* fix for Llama tokenizer which can have multiple end tokens

* bug fix

* adding some missing tokens to Llama3 tokenizer

* fixed tests and Llama3Tokenizer init.

* now loading the correct eos_token config from the Hugging Face checkpoint. Using a hack for the Keras checkpoint because it does not have this info

---------

Co-authored-by: Martin Görner <[email protected]>
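
A hedged sketch of what this fix enables; the preset name is illustrative:

```python
import keras_hub

causal_lm = keras_hub.models.Llama3CausalLM.from_preset("llama3_instruct_8b_en")

# Instruct variants end turns with <|eot_id|> rather than <|end_of_text|>.
# With the fix, stop_token_ids="auto" (the default) covers both end tokens,
# so generation no longer runs on to max_length.
print(causal_lm.generate("What is Keras?", max_length=64))
```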

fix failing JAX GPU test (keras-team#1911)

* fix tests

* fix test

Refactor `MMDiT`, add `ImageToImage` and `Inpaint` for SD3 (keras-team#1909)

* Refactor `MMDiT` and add `ImageToImage`

* Update model version

* Fix minor bugs.

* Add `Inpaint` for SD3.

* Fix warnings of MMDiT.

* Add comment to `Inpaint`

* Simplify `MMDiT` implementation and info of `summary()`.

* Refactor `generate()` API of `TextToImage`, `ImageToImage` and `Inpaint`.

Minor bug fix (keras-team#1915)

Change to `image_converter.image_size`, since it is a tuple and not a callable function.
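
A short sketch of the corrected usage, reusing a preset that appears later in this diff:

```python
import keras_hub

converter = keras_hub.layers.ImageConverter.from_preset("resnet_50_imagenet")

# `image_size` is a plain (height, width) tuple, not a callable:
height, width = converter.image_size
```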

[Mix Transformer] Add Presets for MiTB0...MiTB5 (keras-team#1893)

* add presets for mit

* add standin paths

* register presets in __init__.py

* fix op in overlapping patching and embedding, start adding conversion utils

* style

* add padding to the MiT patching-and-embedding layer

* update to support other presets

* update conversion script

* fix link for b5

* add cityscapes weights

* update presets

* update presets

* update conversion script to make directories

* use save_preset

* change name of output dir

* add preprocessor flow

* api gen and add preprocessor to mits

* conform to new image classifier style

* format

* resizing image converter -> ImageConverter

* address comments

refactoring
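
A loading sketch for the new presets. The classifier class comes from this changeset; the preset name is an assumption:

```python
import keras_hub

classifier = keras_hub.models.MiTImageClassifier.from_preset("mit_b0_ade20k_512")
```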

remove default resizing for vision backbones (keras-team#1916)

* remove default resizing

* fix GPU test

Update VGG model to be compatible with HF and add conversion scripts (keras-team#1914)

Deeplab presets (keras-team#1918)

* add preset configurations for deeplabv3

* fix uri

* Add training details

update presets to point to the main Keras Kaggle page (keras-team#1921)

* update presets to point to the main keras page

* update mit path

Added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates (keras-team#1912)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates

* uncommented the test lines that were commented out by mistake

* fixed linter errors

Task models fix (keras-team#1922)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates

* fix for wrongly configured task models Llama, PaliGemma, Mistral, and Phi3 + test

* comments

* uncommented the test lines that were commented out by mistake

* fixed linter errors

adding option strip_prompt to generate() (keras-team#1913)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates

* uncommented the test lines that were commented out by mistake

* fixed linter errors

* added option strip_prompt to generate() (usage sketch below)

* fix for TensorFlow: the compiled version of generate(strip_prompt=True) now works, plus code refactoring to make it more understandable

* added test for generate(strip_prompt=True)

* minor edits
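
A usage sketch for the new option; the preset name is illustrative:

```python
import keras_hub

causal_lm = keras_hub.models.MistralCausalLM.from_preset("mistral_7b_en")

# Default behavior: the returned string starts with the prompt itself.
full = causal_lm.generate("The capital of France is", max_length=32)

# With strip_prompt=True, only the newly generated continuation is returned.
completion = causal_lm.generate(
    "The capital of France is", max_length=32, strip_prompt=True
)
```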

Layout map for Llama (keras-team#1923)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates

* uncommented the test lines that were commented out by mistake

* fixed linter errors

* added default layout map for Llama

* minor fixes in tests
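
A hypothetical sketch of using the new default layout map, assuming the API mirrors `GemmaBackbone.get_layout_map`; none of these calls appear in this commit:

```python
import keras
import keras_hub

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)
layout_map = keras_hub.models.LlamaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
```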

Update deeplab_v3_presets.py (keras-team#1924)

Add paths to get SAM weights from (keras-team#1925)

Two fixes for image resizing in preprocessing (keras-team#1927)

1. Properly display in `model.summary()` when we are not resizing the
   input image.
2. Allow setting the `image_size` directly on a preprocessing layer.

Item 2 is just to allow a more consistent way to set the input shape
across tasks. We now have:

```python
text_classifier = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
)
text_classifier.preprocessor.sequence_length = 256

image_classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet",
)
image_classifier.preprocessor.image_size = (256, 256)

multi_modal_lm = keras_hub.models.CausalLM.from_preset(
    "some_preset",
)
multi_modal_lm.preprocessor.sequence_length = 256
multi_modal_lm.preprocessor.image_size = (256, 256)
```

add back default image resizing (keras-team#1926)

Update deeplab_v3_presets.py (keras-team#1928)

* Update deeplab_v3_presets.py

* Update deeplab_v3_presets.py

Update PaliGemma to remove `include_rescaling` arg (keras-team#1917)

* update PaliGemma

* update conversion script

* fix GPU tests

fix path (keras-team#1929)

* fix path

* nit

Fix paligemma checkpoint conversion script (keras-team#1931)

* add back default image resizing

* fix bug in image converter

* fix paligemma checkpoint conversion file

* fix preset name

* remove debug code

* revert unintended changes

update preset path to point to latest version of models (keras-team#1932)

Update sdv3 path (keras-team#1934)

update sam docstring to show correct backbone in docstring (keras-team#1936)

Convert input dict to tensors during train_on_batch (keras-team#1919)

Register VGG presets. (keras-team#1935)

* register vgg preset

* nit

* nit

* nit

Add ResNetVD presets (keras-team#1897)

* Add ResNetVD presets

* Updated Kaggle handles

* Add weight conversion script for ResNet_vd

* Add usage

rebase conflict resolved

conflict resolved

Update sam_presets.py (keras-team#1940)

Update vit_det_backbone.py (keras-team#1941)

fix gpu test (keras-team#1939)

* fix gpu test

* cast input

* update dtype

* change to resnet preset

* remove arg

Added Support for Returning Attention Scores in TransformerEncoder call (keras-team#1879)

* Added: Return attention scores argument to transformer encoder

* Added: docstring for return_attention_scores and a test to check the behavior of the argument (see the usage sketch below)

* Fixed: test case by removing print statements and using self.assertAllEqual

* Fixed: Linting
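
A minimal usage sketch of the new argument, following the shape asserted in the new test:

```python
import numpy as np
import keras_hub

encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=8)
inputs = np.random.uniform(size=(2, 10, 64))

outputs, scores = encoder(inputs, return_attention_scores=True)
# scores shape: (batch_size, num_heads, seq_length, seq_length)
print(scores.shape)  # (2, 8, 10, 10)
```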

Mark preset tests as large (keras-team#1942)

* fix tests

* fix test

* Update preset_utils_test.py

version bump to 0.17.0.dev0 (keras-team#1944)

Update stable_diffusion_3_presets.py (keras-team#1946)

[Semantic Segmentation] - Add SegFormer Architecture, Weight Conversion Script and Presets (keras-team#1883)

* initial commit - tf-based, kcv

* porting to keras_hub structure - removing aliases, presets, etc.

* enable instantiation of segformer backbone with custom MiT backbone

* remove num_classes from backbone

* fix input

* add imports to __init__

* update preset

* update docstrings

* add basic tests

* remove redundant imports

* update docstrings

* remove unused import

* running api_gen.py

* undo refactor of mit

* update docstrings

* add presets for mit

* add standin paths

* add presets for segformer backbone

* register presets in __init__.py

* addressing comments

* addressing comments

* addressing comments

* update most tests

* add remaining tests

* remove copyright

* fix test

* override from_config

* fix op in overlapping patching and embedding, start adding conversion utils

* style

* add padding to the MiT patching-and-embedding layer

* update to support other presets

* update conversion script

* fix link for b5

* add cityscapes weights

* update presets

* update presets

* update conversion script to make directories

* use save_preset

* change name of output dir

* add preprocessor flow

* api gen and add preprocessor to mits

* conform to new image classifier style

* format

* resizing image converter -> ImageConverter

* merge mit branch into segformer branch

* add preprocessor and converter

* address comments

* clarify backbone usage

* add conversion script

* numerical equivalence changes

* fix numerical inaccuracies

* update conversion script

* update conversion script

* remove transpose

* add preprocessor to segformer class

* fix preset path

* update test shape

* update presets

* update test shape

* expand docstrings

* add rescaling and normalization to preprocessor

* remove backbone presets, remove copyrights, remove backbone cls from segmenter

* remove copyright and unused import

* apply same transformation to masks as input images

* fix import

* fix shape in tests
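
A hedged loading sketch for the new task. The segmenter class is added in this changeset; the preset name is an assumption:

```python
import keras_hub

segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
    "segformer_b0_ade20k_512"
)
```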

Update readme (keras-team#1949)

* Update README.md

* Update README.md

Update llama_backbone.py docstring (keras-team#1950)

Update path (keras-team#1953)

Update preset path for keras.io.

The old preset path pointed to a LLaMA2 page that does not exist on keras.io.
This is the actual link:
https://keras.io/api/keras_hub/models/llama2

Vicuna does not have its own model directory, since it is also part of Llama, so its path was updated as well.

Update SD3 init parameters (replacing `height`, `width` with `image_shape`) (keras-team#1951)

* Replace SD3 `height` and `width` with `image_shape`

* Update URI

* Revert comment

* Update SD3 handle

* Replace `height` and `width` with `image_shape`

* Update docstrings

* Fix CI
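
A sketch of the updated argument; the preset name is illustrative:

```python
import keras_hub

# `image_shape` replaces the former `height`/`width` pair.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
image = text_to_image.generate("a photograph of an astronaut riding a horse")
```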

Update docstring (keras-team#1954)

AudioConverter is registered as `keras_hub.layers.WhisperAudioConverter`, not as part of `keras_hub.models`.
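
For example, assuming the `whisper_tiny_en` preset:

```python
import keras_hub

# Registered under layers, not models:
converter = keras_hub.layers.WhisperAudioConverter.from_preset("whisper_tiny_en")
```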

updated MobileNet backbone to match the torch implementation

timm script added

checkpoint conversion added

Refactoring
ushareng committed Oct 24, 2024
1 parent c40088d commit 318e7d6
Showing 101 changed files with 5,527 additions and 1,192 deletions.
30 changes: 11 additions & 19 deletions README.md
@@ -4,7 +4,7 @@
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/keras-team/keras-hub/issues)

> [!IMPORTANT]
> 📢 KerasNLP is becoming KerasHub! 📢 Read
> 📢 KerasNLP is now KerasHub! 📢 Read
> [the announcement](https://github.com/keras-team/keras-hub/issues/1831).
>
> We have renamed the repo to KerasHub in preparation for the release, but have not yet
@@ -26,7 +26,7 @@ All models support JAX, TensorFlow, and PyTorch from a single model
definition and can be fine-tuned on GPUs and TPUs out of the box. Models can
be trained on individual accelerators with built-in PEFT techniques, or
fine-tuned at scale with model and data parallel training. See our
[Getting Started guide](https://keras.io/guides/keras_nlp/getting_started)
[Getting Started guide](https://keras.io/guides/keras_hub/getting_started)
to start learning our API. Browse our models on
[Kaggle](https://www.kaggle.com/organizations/keras/models).
We welcome contributions.
@@ -35,9 +35,9 @@ We welcome contributions.

### For everyone

- [Home Page](https://keras.io/keras_nlp)
- [Developer Guides](https://keras.io/guides/keras_nlp)
- [API Reference](https://keras.io/api/keras_nlp)
- [Home Page](https://keras.io/keras_hub)
- [Developer Guides](https://keras.io/guides/keras_hub)
- [API Reference](https://keras.io/api/keras_hub)
- [Pre-trained Models](https://www.kaggle.com/organizations/keras/models)

### For contributors
@@ -56,7 +56,7 @@ Fine-tune a BERT classifier on IMDb movie reviews:
import os
os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch"!

import keras_nlp
import keras_hub
import tensorflow_datasets as tfds

imdb_train, imdb_test = tfds.load(
@@ -67,7 +67,7 @@ imdb_train, imdb_test = tfds.load(
)

# Load a BERT model.
classifier = keras_nlp.models.Classifier.from_preset(
classifier = keras_hub.models.Classifier.from_preset(
"bert_base_en",
num_classes=2,
activation="softmax",
@@ -79,25 +79,17 @@ classifier.fit(imdb_train, validation_data=imdb_test)
classifier.predict(["What an amazing movie!", "A total waste of my time."])
```

Try it out [in a colab](https://colab.research.google.com/gist/mattdangerw/e457e42d5ea827110c8d5cb4eb9d9a07/kerasnlp-quickstart.ipynb).
Try it out [in a colab](https://colab.research.google.com/drive/1gSWkh3yOLwmKAaNh2dQQ6kQIlnGte7P2?usp=sharing).
For more in depth guides and examples, visit
[keras.io/keras_nlp](https://keras.io/keras_nlp/).
[keras.io/keras_hub](https://keras.io/keras_hub/).

## Installation

KerasHub is currently in pre-release. Note that pre-release versions may
introduce breaking changes to the API in future versions. For a stable and
supported experience, we recommend installing `keras-nlp` version 0.15.1:

```bash
pip install keras-nlp==0.15.1
```

To try out the latest pre-release version of KerasHub, you can use
To try out the latest version of KerasHub, you can use
our nightly package:

```bash
pip install keras-hub-nightly
pip install keras-hub
```

KerasHub currently requires TensorFlow to be installed for use of the
7 changes: 7 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -40,6 +40,9 @@
from keras_hub.src.models.densenet.densenet_image_converter import (
DenseNetImageConverter,
)
from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import (
MiTImageConverter,
)
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)
@@ -52,6 +55,10 @@
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.segformer.segformer_image_converter import (
SegFormerImageConverter,
)
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
27 changes: 22 additions & 5 deletions keras_hub/api/models/__init__.py
@@ -180,6 +180,8 @@
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
)
from keras_hub.src.models.image_to_image import ImageToImage
from keras_hub.src.models.inpaint import Inpaint
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -200,11 +202,10 @@
MistralCausalLMPreprocessor,
)
from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_hub.src.models.mix_transformer.mix_transformer_classifier import (
MiTImageClassifier,
from keras_hub.src.models.mit.mit_backbone import MiTBackbone
from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier
from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
MiTImageClassifierPreprocessor,
)
from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
@@ -268,11 +269,24 @@
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
SAMImageSegmenterPreprocessor,
)
from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
from keras_hub.src.models.segformer.segformer_image_segmenter import (
SegFormerImageSegmenter,
)
from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
SegFormerImageSegmenterPreprocessor,
)
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
StableDiffusion3Backbone,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import (
StableDiffusion3ImageToImage,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import (
StableDiffusion3Inpaint,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
StableDiffusion3TextToImage,
)
@@ -291,6 +305,9 @@
from keras_hub.src.models.text_to_image import TextToImage
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
35 changes: 28 additions & 7 deletions keras_hub/src/layers/modeling/transformer_encoder.py
@@ -170,7 +170,12 @@ def build(self, inputs_shape):
self.built = True

def call(
self, inputs, padding_mask=None, attention_mask=None, training=None
self,
inputs,
padding_mask=None,
attention_mask=None,
training=None,
return_attention_scores=False,
):
"""Forward pass of the TransformerEncoder.
@@ -185,6 +190,7 @@ def call(
[batch_size, sequence_length, sequence_length].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`.
Returns:
A Tensor of the same shape as the `inputs`.
@@ -200,12 +206,24 @@
residual = x
if self.normalize_first:
x = self._self_attention_layer_norm(x)
x = self._self_attention_layer(
query=x,
value=x,
attention_mask=self_attention_mask,
training=training,
)

if return_attention_scores:
x, attention_scores = self._self_attention_layer(
query=x,
value=x,
attention_mask=self_attention_mask,
return_attention_scores=return_attention_scores,
training=training,
)
return x, attention_scores
else:
x = self._self_attention_layer(
query=x,
value=x,
attention_mask=self_attention_mask,
training=training,
)

x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
@@ -222,6 +240,9 @@
if not self.normalize_first:
x = self._feedforward_layer_norm(x)

if return_attention_scores:
return x, attention_scores

return x

def get_config(self):
11 changes: 11 additions & 0 deletions keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -95,3 +95,14 @@ def test_mask_propagation(self):
inputs._keras_mask = mask
outputs = encoder(inputs)
self.assertAllEqual(outputs._keras_mask, mask)

def test_attention_scores(self):
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
inputs = random.uniform(shape=[1, 4, 6])
outputs, attention_scores = encoder(
inputs, return_attention_scores=True
)
self.assertAllEqual(outputs.shape, inputs.shape)

# attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length)
self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4])
3 changes: 2 additions & 1 deletion keras_hub/src/layers/preprocessing/image_converter.py
@@ -145,8 +145,9 @@ def image_size(self, value):

@preprocessing_function
def call(self, inputs):
x = inputs
if self.image_size is not None:
x = self.resizing(inputs)
x = self.resizing(x)
if self.scale is not None:
x = x * self._expand_non_channel_dims(self.scale, x)
if self.offset is not None:
27 changes: 10 additions & 17 deletions keras_hub/src/layers/preprocessing/image_converter_test.py
@@ -6,12 +6,10 @@
from keras import ops

from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
PaliGemmaBackbone,
)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.tests.test_case import TestCase


@@ -86,24 +84,19 @@ def test_from_preset_errors(self):
def test_save_to_preset(self):
save_dir = self.get_temp_dir()
converter = ImageConverter.from_preset(
"pali_gemma_3b_mix_224",
"resnet_50_imagenet",
interpolation="nearest",
)
converter.save_to_preset(save_dir)
# Save a tiny backbone so the preset is valid.
backbone = PaliGemmaBackbone(
vocabulary_size=100,
image_size=224,
num_layers=1,
num_query_heads=1,
num_key_value_heads=1,
hidden_dim=8,
intermediate_dim=16,
head_dim=8,
vit_patch_size=14,
vit_num_heads=1,
vit_hidden_dim=8,
vit_num_layers=1,
backbone = ResNetBackbone(
input_conv_filters=[64],
input_conv_kernel_sizes=[7],
stackwise_num_filters=[64, 64, 64],
stackwise_num_blocks=[2, 2, 2],
stackwise_num_strides=[1, 2, 2],
block_type="basic_block",
use_pre_activation=True,
)
backbone.save_to_preset(save_dir)

42 changes: 41 additions & 1 deletion keras_hub/src/models/causal_lm.py
@@ -274,6 +274,7 @@ def generate(
inputs,
max_length=None,
stop_token_ids="auto",
strip_prompt=False,
):
"""Generate text given prompt `inputs`.
@@ -309,6 +310,9 @@
specify a list of token id's the model should stop on. Note that
sequences of tokens will each be interpreted as a stop token,
multi-token stop sequences are not supported.
strip_prompt: Optional. By default, generate() returns the full prompt
followed by its completion generated by the model. If this option
is set to True, only the newly generated text is returned.
"""
# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
@@ -326,6 +330,10 @@
)
elif stop_token_ids == "auto":
stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
# Some models like Llama3 use two end tokens: <|eot_id|> in
# "instruct" versions and <|end_of_text|> in others.
if hasattr(self.preprocessor.tokenizer, "end_token2_id"):
stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id)

def preprocess(x):
return self.preprocessor.generate_preprocess(
@@ -335,6 +343,33 @@ def preprocess(x):
def generate(x):
return generate_function(x, stop_token_ids=stop_token_ids)

def strip_prompt_function(x, prompt):
# This function removes the prompt from the generated
# response, in a batch-friendly fashion.
y = {}
prompt_mask = prompt["padding_mask"]
seq_len = prompt_mask.shape[1]

# We need to shift every output sequence by the size of the prompt.
shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len
ix = ops.arange(seq_len, dtype="int")
ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1)

# This produces the desired shift (in fact a rollover).
def roll_sequence(seq):
return ops.take_along_axis(seq, ix, axis=1)

# The shifting rolls the content over so the prompt is at the end of
# the sequence and the generated text is at the beginning. We mask
# it to retain the generated text only.
y["padding_mask"] = ops.logical_xor(
roll_sequence(prompt_mask), roll_sequence(x["padding_mask"])
)
# We assume the mask is enough and there is no need to zero out the values.
y["token_ids"] = roll_sequence(x["token_ids"])

return y

def postprocess(x):
return self.preprocessor.generate_postprocess(x)

@@ -343,7 +378,12 @@ def postprocess(x):

if self.preprocessor is not None:
inputs = [preprocess(x) for x in inputs]
outputs = [generate(x) for x in inputs]

if strip_prompt:
outputs = [strip_prompt_function(generate(x), x) for x in inputs]
else:
outputs = [generate(x) for x in inputs]

if self.preprocessor is not None:
outputs = [postprocess(x) for x in outputs]

@@ -51,6 +51,7 @@ def test_saved_model(self):
cls=DeepLabV3Backbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
atol=0.00001,
)


