fix: Fix token type ids not supported (#77)
* Fixed the error where token type ids are not supported by every model

* Updated test

* Updated test

* Changed to use inspect
Pringled authored Oct 11, 2024
1 parent 2ce3c97 commit 84570b6
Showing 3 changed files with 22 additions and 10 deletions.
24 changes: 16 additions & 8 deletions model2vec/distill/inference.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+import inspect
 import logging
 from pathlib import Path
 from typing import Protocol
@@ -127,20 +128,27 @@ def create_output_embeddings_from_model_name(
     for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
         batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device)
         with torch.no_grad():
-            # NOTE: we create these masks because nomic embed requires them.
-            # Normally, we could set them to None
-            token_type_ids = torch.zeros_like(batch)
             attention_mask = torch.ones_like(batch)
-            encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(
-                input_ids=batch.to(device), attention_mask=attention_mask, token_type_ids=token_type_ids
-            )
-            out: torch.Tensor = encoded.last_hidden_state
+            # Prepare model inputs
+            model_inputs = {"input_ids": batch.to(device), "attention_mask": attention_mask}
+
+            # Add token_type_ids only if the model supports it
+            if "token_type_ids" in inspect.getfullargspec(model.forward).args:
+                model_inputs["token_type_ids"] = torch.zeros_like(batch)
+
+            # Perform the forward pass
+            encoded_output: BaseModelOutputWithPoolingAndCrossAttentions = model(**model_inputs)
+            out: torch.Tensor = encoded_output.last_hidden_state
             # NOTE: If the dtype is bfloat16, we convert to float32,
             # because numpy does not support bfloat16
             # See here: https://github.com/numpy/numpy/issues/19808
             if out.dtype == torch.bfloat16:
                 out = out.float()
-            intermediate_weights.append(out[:, 1].cpu().numpy())
+
+            # Add the output to the intermediate weights
+            intermediate_weights.append(out[:, 1].detach().cpu().numpy())
+
+    # Concatenate the intermediate weights
     out_weights = np.concatenate(intermediate_weights)

     return tokenizer.convert_ids_to_tokens(ids), out_weights
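
The heart of the fix: instead of always passing token_type_ids (which models such as DistilBERT do not accept), the forward signature is checked first with inspect. A minimal standalone sketch of the same check, with a hypothetical helper name and toy modules that are not part of the repo:

    import inspect

    import torch


    def build_model_inputs(model: torch.nn.Module, batch: torch.Tensor) -> dict:
        # Hypothetical helper: pass token_type_ids only when forward() accepts it.
        inputs = {"input_ids": batch, "attention_mask": torch.ones_like(batch)}
        if "token_type_ids" in inspect.getfullargspec(model.forward).args:
            inputs["token_type_ids"] = torch.zeros_like(batch)
        return inputs


    class WithTokenTypes(torch.nn.Module):
        def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
            return input_ids


    class WithoutTokenTypes(torch.nn.Module):
        def forward(self, input_ids=None, attention_mask=None):
            return input_ids


    batch = torch.ones(2, 4, dtype=torch.long)
    print(sorted(build_model_inputs(WithTokenTypes(), batch)))    # includes 'token_type_ids'
    print(sorted(build_model_inputs(WithoutTokenTypes(), batch))) # no 'token_type_ids'

An alternative would be inspect.signature(model.forward).parameters, which also covers keyword-only arguments; getfullargspec(...).args handles the positional-or-keyword parameters that transformers models use.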
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -44,7 +44,7 @@ def to(self, device: str) -> MockPreTrainedModel:
             self.device = device
             return self

-        def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        def forward(self, *args: Any, **kwargs: Any) -> Any:
             # Simulate a last_hidden_state output for a transformer model
             batch_size, seq_length = kwargs["input_ids"].shape
             # Return a tensor of shape (batch_size, seq_length, 768)
@@ -56,6 +56,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
                 },
             )

+        def __call__(self, *args: Any, **kwargs: Any) -> Any:
+            # Simply call the forward method to simulate the same behavior as transformers models
+            return self.forward(*args, **kwargs)
+
     return MockPreTrainedModel()
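
The mock change mirrors what the fix relies on: inference now inspects model.forward, so a mock that only defines __call__ would expose nothing to inspect. Real transformers models inherit from torch.nn.Module, whose __call__ dispatches to forward; the mock reproduces that shape by hand. A tiny illustration (class name hypothetical):

    import inspect


    class TinyMock:
        def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
            return input_ids

        def __call__(self, *args, **kwargs):
            # Delegate to forward, as torch.nn.Module.__call__ does.
            return self.forward(*args, **kwargs)


    print(inspect.getfullargspec(TinyMock().forward).args)
    # ['self', 'input_ids', 'attention_mask', 'token_type_ids']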


2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.
