Skip to content

Commit

Permalink
test only trained models for now
Browse files Browse the repository at this point in the history
  • Loading branch information
gpucce committed Oct 22, 2023
1 parent f9dcbde commit 20a2cb3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)

models_to_test = open_clip.list_generative_models()
models_to_test = open_clip.list_generative_models().difference(
{"coca_roberta-ViT-B-32", "coca_base"}
)


@pytest.mark.generative_regression_test
Expand Down Expand Up @@ -51,7 +53,7 @@ def test_generate_with_data(
gt_text = torch.load(gt_text_path)
with torch.no_grad(), torch.cuda.amp.autocast():
y_text = util_test.model_generate(model, preprocess_val, input_image)
assert (y_text == gt_text), f"text output differs @ {gt_text_path}"
assert y_text == gt_text, f"text output differs @ {gt_text_path}"
# logits
y_logits = util_test.forward_model(model, model_name, preprocess_val, input_image, gt_text)[
"logits"
Expand Down

0 comments on commit 20a2cb3

Please sign in to comment.