From 965025c1891b6f08ec118ece473d3508d562c58c Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Thu, 10 Oct 2024 13:30:26 +0200 Subject: [PATCH 1/3] fix: Add support for huggingface_hub==0.25.0 (#73) --- model2vec/distill/distillation.py | 9 ++++++++- tests/test_distillation.py | 8 +++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index 2d8001f..35a98e3 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -3,7 +3,6 @@ import numpy as np from huggingface_hub import model_info -from huggingface_hub.utils._errors import RepositoryNotFoundError from sklearn.decomposition import PCA from tokenizers.models import BPE, Unigram from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast @@ -16,6 +15,14 @@ from model2vec.distill.utils import select_optimal_device from model2vec.model import StaticModel +try: + # For huggingface_hub>=0.25.0 + from huggingface_hub.errors import RepositoryNotFoundError +except ImportError: + # For huggingface_hub<0.25.0 + from huggingface_hub.utils._errors import RepositoryNotFoundError + + logger = logging.getLogger(__name__) diff --git a/tests/test_distillation.py b/tests/test_distillation.py index b4abac2..36da266 100644 --- a/tests/test_distillation.py +++ b/tests/test_distillation.py @@ -4,13 +4,19 @@ import numpy as np import pytest -from huggingface_hub.utils._errors import RepositoryNotFoundError from pytest import LogCaptureFixture from transformers import AutoModel, BertTokenizerFast from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill, distill_from_model from model2vec.model import StaticModel +try: + # For huggingface_hub>=0.25.0 + from huggingface_hub.errors import RepositoryNotFoundError +except ImportError: + # For huggingface_hub<0.25.0 + from huggingface_hub.utils._errors import RepositoryNotFoundError + rng = np.random.default_rng() From 2ce3c97badd539985b371a428b3ad582b092eb1a Mon Sep 17 00:00:00 2001 From: Thomas van Dongen Date: Thu, 10 Oct 2024 19:18:39 +0200 Subject: [PATCH 2/3] Bump version (#74) * Bumped version * Bumped version * Bumped version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 61ef3a5..ae7af19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "model2vec" description = "Distill a Small Fast Model from any Sentence Transformer" readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } -version = "0.2.3" +version = "0.2.4" requires-python = ">=3.10" authors = [{ name = "Stéphan Tulkens", email = "stephantul@gmail.com"}, {name = "Thomas van Dongen", email = "thomas123@live.nl"}] From 84570b680b7ee49aeb0bb8267a6612b8f42b6af7 Mon Sep 17 00:00:00 2001 From: Thomas van Dongen Date: Fri, 11 Oct 2024 15:30:29 +0200 Subject: [PATCH 3/3] fix: Fix token type ids not supported (#77) * Fixed token type ids not supported for every model error * Updated test * Updated test * Changed to inspect --- model2vec/distill/inference.py | 24 ++++++++++++++++-------- tests/conftest.py | 6 +++++- uv.lock | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py index 6459336..e480cbf 100644 --- a/model2vec/distill/inference.py +++ b/model2vec/distill/inference.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import inspect import logging from pathlib import Path from typing import Protocol @@ -127,20 +128,27 @@ def create_output_embeddings_from_model_name( for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)): batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device) with torch.no_grad(): - # NOTE: we create these masks because nomic embed requires them. - # Normally, we could set them to None - token_type_ids = torch.zeros_like(batch) attention_mask = torch.ones_like(batch) - encoded: BaseModelOutputWithPoolingAndCrossAttentions = model( - input_ids=batch.to(device), attention_mask=attention_mask, token_type_ids=token_type_ids - ) - out: torch.Tensor = encoded.last_hidden_state + # Prepare model inputs + model_inputs = {"input_ids": batch.to(device), "attention_mask": attention_mask} + + # Add token_type_ids only if the model supports it + if "token_type_ids" in inspect.getfullargspec(model.forward).args: + model_inputs["token_type_ids"] = torch.zeros_like(batch) + + # Perform the forward pass + encoded_output: BaseModelOutputWithPoolingAndCrossAttentions = model(**model_inputs) + out: torch.Tensor = encoded_output.last_hidden_state # NOTE: If the dtype is bfloat 16, we convert to float32, # because numpy does not suport bfloat16 # See here: https://github.com/numpy/numpy/issues/19808 if out.dtype == torch.bfloat16: out = out.float() - intermediate_weights.append(out[:, 1].cpu().numpy()) + + # Add the output to the intermediate weights + intermediate_weights.append(out[:, 1].detach().cpu().numpy()) + + # Concatenate the intermediate weights out_weights = np.concatenate(intermediate_weights) return tokenizer.convert_ids_to_tokens(ids), out_weights diff --git a/tests/conftest.py b/tests/conftest.py index ae83008..f7d2222 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ def to(self, device: str) -> MockPreTrainedModel: self.device = device return self - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def forward(self, *args: Any, **kwargs: Any) -> Any: # Simulate a last_hidden_state output for a transformer model batch_size, seq_length = kwargs["input_ids"].shape # Return a tensor of shape (batch_size, seq_length, 768) @@ -56,6 +56,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: }, ) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Simply call the forward method to simulate the same behavior as transformers models + return self.forward(*args, **kwargs) + return MockPreTrainedModel() diff --git a/uv.lock b/uv.lock index 73c7d3b..30ee1f4 100644 --- a/uv.lock +++ b/uv.lock @@ -433,7 +433,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.2.2" +version = "0.2.4" source = { editable = "." } dependencies = [ { name = "click" },