Commit 98308eb: Add tests

stephantul committed Oct 3, 2024
1 parent 4d0e085
Showing 2 changed files with 80 additions and 10 deletions.
tests/conftest.py (7 changes: 3 additions & 4 deletions)

@@ -9,7 +9,7 @@
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
-from transformers import AutoModel, BertTokenizerFast
+from transformers import AutoModel, AutoTokenizer
 
 
 @pytest.fixture
@@ -26,10 +26,9 @@ def mock_tokenizer() -> Tokenizer:
 
 
 @pytest.fixture
-def mock_berttokenizer() -> BertTokenizerFast:
+def mock_berttokenizer() -> AutoTokenizer:
     """Load the real BertTokenizerFast from the provided tokenizer.json file."""
-    tokenizer_path = Path("tests/data/test_tokenizer/tokenizer.json")
-    return BertTokenizerFast(tokenizer_file=str(tokenizer_path))
+    return AutoTokenizer.from_pretrained("tests/data/test_tokenizer")
 
 
 @pytest.fixture
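The fixture change above swaps a hard-coded BertTokenizerFast for AutoTokenizer resolution of a local directory. A minimal sketch of what that call does, assuming tests/data/test_tokenizer contains a tokenizer.json (plus a tokenizer_config.json naming the tokenizer class, which AutoTokenizer needs in order to pick the right wrapper):

```python
# Sketch only: the directory layout is an assumption based on the diff,
# not something verified outside this commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer")

# For a BERT-style tokenizer.json this typically resolves to BertTokenizerFast,
# which is why the fixture name and docstring still mention that class.
print(type(tokenizer).__name__)
print(tokenizer.tokenize("hello world"))
```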
tests/test_distillation.py (83 changes: 77 additions & 6 deletions)

@@ -1,3 +1,4 @@
+import json
 from importlib import import_module
 from unittest.mock import MagicMock, patch
 
@@ -6,12 +7,86 @@
 from pytest import LogCaptureFixture
 from transformers import AutoModel, BertTokenizerFast
 
-from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill
+from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill, distill_from_model
 from model2vec.model import StaticModel
 
 rng = np.random.default_rng()
 
 
+@pytest.mark.parametrize(
+    "vocabulary, use_subword, pca_dims, apply_zipf",
+    [
+        (None, True, 256, True),  # Output vocab with subwords, PCA applied
+        (["wordA", "wordB"], False, 4, False),  # Custom vocab without subwords, PCA applied
+        (["wordA", "wordB"], True, 4, False),  # Custom vocab with subwords, PCA applied
+        (None, True, None, True),  # No PCA applied
+        (["wordA", "wordB"], False, 4, True),  # Custom vocab without subwords, PCA and Zipf applied
+        (None, False, 256, True),  # use_subword = False without passing a vocabulary should raise an error
+    ],
+)
+@patch.object(import_module("model2vec.distill.distillation"), "model_info")
+@patch("transformers.AutoModel.from_pretrained")
+def test_distill_from_model(
+    mock_auto_model: MagicMock,
+    mock_model_info: MagicMock,
+    mock_berttokenizer: BertTokenizerFast,
+    mock_transformer: AutoModel,
+    vocabulary: list[str] | None,
+    use_subword: bool,
+    pca_dims: int | None,
+    apply_zipf: bool,
+) -> None:
+    """Test distill_from_model and distill with different parameters."""
+    # Mock the return value of model_info to avoid calling the Hugging Face API
+    mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
+
+    # Patch the model loader to return the mock transformer instance
+    mock_auto_model.return_value = mock_transformer
+
+    if vocabulary is None and not use_subword:
+        with pytest.raises(ValueError):
+            static_model = distill_from_model(
+                model=mock_transformer,
+                tokenizer=mock_berttokenizer,
+                vocabulary=vocabulary,
+                device="cpu",
+                pca_dims=pca_dims,
+                apply_zipf=apply_zipf,
+                use_subword=use_subword,
+            )
+    else:
+        # Call the distill functions with the parametrized inputs
+        static_model = distill_from_model(
+            model=mock_transformer,
+            tokenizer=mock_berttokenizer,
+            vocabulary=vocabulary,
+            device="cpu",
+            pca_dims=pca_dims,
+            apply_zipf=apply_zipf,
+            use_subword=use_subword,
+        )
+
+        static_model2 = distill(
+            model_name="tests/data/test_tokenizer",
+            vocabulary=vocabulary,
+            device="cpu",
+            pca_dims=pca_dims,
+            apply_zipf=apply_zipf,
+            use_subword=use_subword,
+        )
+
+        assert static_model.embedding.weight.shape == static_model2.embedding.weight.shape
+        assert static_model.config == static_model2.config
+        assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str())
+        assert static_model.base_model_name == static_model2.base_model_name
+
+
 @pytest.mark.parametrize(
     "vocabulary, use_subword, pca_dims, apply_zipf, expected_shape",
     [
@@ -30,13 +105,10 @@
     ],
 )
 @patch.object(import_module("model2vec.distill.distillation"), "model_info")
-@patch("transformers.AutoTokenizer.from_pretrained")
 @patch("transformers.AutoModel.from_pretrained")
 def test_distill(
     mock_auto_model: MagicMock,
-    mock_auto_tokenizer: MagicMock,
     mock_model_info: MagicMock,
-    mock_berttokenizer: BertTokenizerFast,
     mock_transformer: AutoModel,
     vocabulary: list[str] | None,
     use_subword: bool,
@@ -49,10 +121,9 @@ def test_distill(
     mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
 
     # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
-    mock_auto_tokenizer.return_value = mock_berttokenizer
     mock_auto_model.return_value = mock_transformer
 
-    model_name = "mock-model"
+    model_name = "tests/data/test_tokenizer"
 
     if vocabulary is None and not use_subword:
         with pytest.raises(ValueError):
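The new test asserts that the two distillation entry points are equivalent. Below is a hedged sketch of the same check outside the test harness; the model name is illustrative (in the test the transformer and the Hugging Face API are mocked), and the keyword arguments mirror the diff rather than a verified public signature:

```python
from transformers import AutoModel, AutoTokenizer

from model2vec.distill.distillation import distill, distill_from_model

model_name = "bert-base-uncased"  # assumption: any BERT-style checkpoint

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Distill from preloaded objects...
static_a = distill_from_model(
    model=model, tokenizer=tokenizer, vocabulary=None,
    device="cpu", pca_dims=256, apply_zipf=True, use_subword=True,
)
# ...and by name; both paths should yield the same static model.
static_b = distill(
    model_name=model_name, vocabulary=None,
    device="cpu", pca_dims=256, apply_zipf=True, use_subword=True,
)

assert static_a.embedding.weight.shape == static_b.embedding.weight.shape
assert static_a.config == static_b.config
```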

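As a usage note, the tests touched by this commit can be run in isolation from the repository root with `pytest tests/test_distillation.py -k distill`, assuming pytest and the package's test dependencies are installed.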