Commit: Updated test

Pringled committed Oct 11, 2024
1 parent ec6765e commit 826f01f
Showing 3 changed files with 25 additions and 71 deletions.
88 changes: 19 additions & 69 deletions model2vec/distill/inference.py
@@ -126,77 +126,27 @@ def create_output_embeddings_from_model_name(
    intermediate_weights: list[np.ndarray] = []
    for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
        batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device)
-
-        attention_mask = torch.ones_like(batch)
-        # Prepare model inputs
-        model_inputs = {"input_ids": batch.to(device), "attention_mask": attention_mask}
-
-        # Add token_type_ids only if the model supports it
-        if "token_type_ids" in model.forward.__code__.co_varnames:
-            model_inputs["token_type_ids"] = torch.zeros_like(batch)
-
-        # Perform the forward pass
-        encoded_output: BaseModelOutputWithPoolingAndCrossAttentions = model(**model_inputs)
-
-        out: torch.Tensor = encoded_output.last_hidden_state
-
-        # Convert bfloat16 to float32 if necessary
-        if out.dtype == torch.bfloat16:
-            out = out.float()
-
-        # Detach the tensor from the computation graph before converting to NumPy
+        with torch.no_grad():
+            attention_mask = torch.ones_like(batch)
+            # Prepare model inputs
+            model_inputs = {"input_ids": batch.to(device), "attention_mask": attention_mask}
+
+            # Add token_type_ids only if the model supports it
+            if "token_type_ids" in model.forward.__code__.co_varnames:
+                model_inputs["token_type_ids"] = torch.zeros_like(batch)
+
+            # Perform the forward pass
+            encoded_output: BaseModelOutputWithPoolingAndCrossAttentions = model(**model_inputs)
+            out: torch.Tensor = encoded_output.last_hidden_state
+            # NOTE: If the dtype is bfloat 16, we convert to float32,
+            # because numpy does not suport bfloat16
+            # See here: https://github.com/numpy/numpy/issues/19808
+            if out.dtype == torch.bfloat16:
+                out = out.float()
+
+        # Detach the tensor from the computation graph before converting to NumPy
        intermediate_weights.append(out[:, 1].detach().cpu().numpy())

    out_weights = np.concatenate(intermediate_weights)

    return tokenizer.convert_ids_to_tokens(ids), out_weights


-# def create_output_embeddings_from_model_name(
-#     model: PreTrainedModel,
-#     tokenizer: PreTrainedTokenizer,
-#     device: str,
-# ) -> tuple[list[str], np.ndarray]:
-#     """
-#     Create output embeddings for a bunch of tokens from a model name.
-
-#     It does a forward pass for all ids in the tokenizer.
-
-#     :param model: The model name to use.
-#     :param tokenizer: The tokenizer to use.
-#     :param device: The torch device to use.
-#     :return: The tokens and output embeddings.
-#     """
-#     model = model.to(device)
-#     ids = torch.arange(tokenizer.vocab_size)
-
-#     # Work-around to get the eos and bos token ids without having to go into tokenizer internals.
-#     dummy_encoding = tokenizer.encode("A")
-#     eos_token_id, bos_token_id = dummy_encoding[0], dummy_encoding[-1]
-
-#     eos = torch.full([len(ids)], fill_value=eos_token_id)
-#     bos = torch.full([len(ids)], fill_value=bos_token_id)
-
-#     stacked = torch.stack([bos, ids, eos], dim=1)
-
-#     intermediate_weights: list[np.ndarray] = []
-#     for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
-#         batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device)
-#         with torch.no_grad():
-#             # NOTE: we create these masks because nomic embed requires them.
-#             # Normally, we could set them to None
-#             token_type_ids = torch.zeros_like(batch)
-#             attention_mask = torch.ones_like(batch)
-#             encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(
-#                 input_ids=batch.to(device), attention_mask=attention_mask, token_type_ids=token_type_ids
-#             )
-#             out: torch.Tensor = encoded.last_hidden_state
-#             # NOTE: If the dtype is bfloat 16, we convert to float32,
-#             # because numpy does not suport bfloat16
-#             # See here: https://github.com/numpy/numpy/issues/19808
-#             if out.dtype == torch.bfloat16:
-#                 out = out.float()
-#             intermediate_weights.append(out[:, 1].cpu().numpy())
-#     out_weights = np.concatenate(intermediate_weights)
-
-#     return tokenizer.convert_ids_to_tokens(ids), out_weights
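
In short, the forward pass over each vocabulary batch now runs under torch.no_grad(), so no autograd graph is built for these inference-only passes (the subsequent .detach() becomes redundant but harmless), and the old commented-out variant of the function is removed. A minimal sketch of the same pattern, assuming a hypothetical TinyModel stand-in rather than the actual distilled transformer:

import torch

# Hypothetical stand-in for the distilled transformer; not part of model2vec.
class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(100, 8)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)

model = TinyModel()
batch = torch.randint(0, 100, (4, 3))  # 4 tokens, each encoded as [bos, id, eos]

# Under no_grad no computation graph is built, so the output never requires
# grad and can go straight to NumPy after moving to CPU.
with torch.no_grad():
    out = model(input_ids=batch, attention_mask=torch.ones_like(batch))

# NumPy cannot represent bfloat16, hence the float() cast in the diff above.
if out.dtype == torch.bfloat16:
    out = out.float()

weights = out[:, 1].cpu().numpy()  # position 1 is the real token between bos/eos
print(weights.shape)  # (4, 8)
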
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -44,7 +44,7 @@ def to(self, device: str) -> MockPreTrainedModel:
            self.device = device
            return self

-        def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        def forward(self, *args: Any, **kwargs: Any) -> Any:
            # Simulate a last_hidden_state output for a transformer model
            batch_size, seq_length = kwargs["input_ids"].shape
            # Return a tensor of shape (batch_size, seq_length, 768)
@@ -56,6 +56,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
                },
            )

+        def __call__(self, *args: Any, **kwargs: Any) -> Any:
+            # Simply call the forward method to simulate the same behavior as transformers models
+            return self.forward(*args, **kwargs)

    return MockPreTrainedModel()
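
The reason the mock gains a separate forward method: the distillation code checks "token_type_ids" in model.forward.__code__.co_varnames before building its inputs, and that introspection needs model.forward to be a plain Python function, while __call__ keeps the mock callable the way a Hugging Face model is (calling the model dispatches to forward). With the fixture's *args/**kwargs signature the check simply evaluates to False, so the test exercises the path without token_type_ids. A rough sketch of how the check behaves, using a hypothetical MockModel that does declare the argument, not the actual fixture:

from typing import Any

class MockModel:
    def forward(self, input_ids: Any, attention_mask: Any, token_type_ids: Any = None) -> Any:
        # A real mock would return an object with .last_hidden_state; omitted here.
        return input_ids

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Mirror transformers behaviour: calling the model delegates to forward().
        return self.forward(*args, **kwargs)

model = MockModel()

# The same kind of signature inspection used in the distillation code:
print("token_type_ids" in model.forward.__code__.co_varnames)  # True here; False for a *args/**kwargs signature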


2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.
