diff --git a/tests/conftest.py b/tests/conftest.py
index 382c3ad..ae83008 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,7 +9,7 @@
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
-from transformers import AutoModel, BertTokenizerFast
+from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
 
 
 @pytest.fixture
@@ -26,10 +26,9 @@ def mock_tokenizer() -> Tokenizer:
 
 
 @pytest.fixture
-def mock_berttokenizer() -> BertTokenizerFast:
+def mock_berttokenizer() -> PreTrainedTokenizerFast:
     """Load the real BertTokenizerFast from the provided tokenizer.json file."""
-    tokenizer_path = Path("tests/data/test_tokenizer/tokenizer.json")
-    return BertTokenizerFast(tokenizer_file=str(tokenizer_path))
+    return AutoTokenizer.from_pretrained("tests/data/test_tokenizer")
 
 
 @pytest.fixture
diff --git a/tests/test_distillation.py b/tests/test_distillation.py
index a43ccdf..32830c9 100644
--- a/tests/test_distillation.py
+++ b/tests/test_distillation.py
@@ -1,3 +1,4 @@
+import json
 from importlib import import_module
 from unittest.mock import MagicMock, patch
 
@@ -6,12 +7,85 @@
 from pytest import LogCaptureFixture
 from transformers import AutoModel, BertTokenizerFast
 
-from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill
+from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill, distill_from_model
 from model2vec.model import StaticModel
 
 rng = np.random.default_rng()
 
 
+@pytest.mark.parametrize(
+    "vocabulary, use_subword, pca_dims, apply_zipf",
+    [
+        (None, True, 256, True),  # Output vocab with subwords, PCA applied
+        (
+            ["wordA", "wordB"],
+            False,
+            4,
+            False,
+        ),  # Custom vocab without subwords, PCA applied
+        (["wordA", "wordB"], True, 4, False),  # Custom vocab with subwords, PCA applied
+        (None, True, None, True),  # No PCA applied
+        (["wordA", "wordB"], False, 4, True),  # Custom vocab without subwords, PCA and Zipf applied
+        (None, False, 256, True),  # use_subword=False without passing a vocabulary should raise an error
+    ],
+)
+@patch.object(import_module("model2vec.distill.distillation"), "model_info")
+@patch("transformers.AutoModel.from_pretrained")
+def test_distill_from_model(
+    mock_auto_model: MagicMock,
+    mock_model_info: MagicMock,
+    mock_berttokenizer: BertTokenizerFast,
+    mock_transformer: AutoModel,
+    vocabulary: list[str] | None,
+    use_subword: bool,
+    pca_dims: int | None,
+    apply_zipf: bool,
+) -> None:
+    """Test distill_from_model with different parameters and check that it matches distill."""
+    # Mock the return value of model_info to avoid calling the Hugging Face API
+    mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
+
+    # Patch the model loader to return the mock transformer instance
+    mock_auto_model.return_value = mock_transformer
+
+    if vocabulary is None and not use_subword:
+        with pytest.raises(ValueError):
+            distill_from_model(
+                model=mock_transformer,
+                tokenizer=mock_berttokenizer,
+                vocabulary=vocabulary,
+                device="cpu",
+                pca_dims=pca_dims,
+                apply_zipf=apply_zipf,
+                use_subword=use_subword,
+            )
+    else:
+        # Call distill_from_model with the parametrized inputs
+        static_model = distill_from_model(
+            model=mock_transformer,
+            tokenizer=mock_berttokenizer,
+            vocabulary=vocabulary,
+            device="cpu",
+            pca_dims=pca_dims,
+            apply_zipf=apply_zipf,
+            use_subword=use_subword,
+        )
+
+        static_model2 = distill(
model_name="tests/data/test_tokenizer", + vocabulary=vocabulary, + device="cpu", + pca_dims=pca_dims, + apply_zipf=apply_zipf, + use_subword=use_subword, + ) + + assert static_model.embedding.weight.shape == static_model2.embedding.weight.shape + assert static_model.config == static_model2.config + assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str()) + assert static_model.base_model_name == static_model2.base_model_name + + @pytest.mark.parametrize( "vocabulary, use_subword, pca_dims, apply_zipf, expected_shape", [ @@ -30,13 +105,10 @@ ], ) @patch.object(import_module("model2vec.distill.distillation"), "model_info") -@patch("transformers.AutoTokenizer.from_pretrained") @patch("transformers.AutoModel.from_pretrained") def test_distill( mock_auto_model: MagicMock, - mock_auto_tokenizer: MagicMock, mock_model_info: MagicMock, - mock_berttokenizer: BertTokenizerFast, mock_transformer: AutoModel, vocabulary: list[str] | None, use_subword: bool, @@ -49,10 +121,9 @@ def test_distill( mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}}) # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances - mock_auto_tokenizer.return_value = mock_berttokenizer mock_auto_model.return_value = mock_transformer - model_name = "mock-model" + model_name = "tests/data/test_tokenizer" if vocabulary is None and not use_subword: with pytest.raises(ValueError):