Skip to content

Commit

Permalink
Updated test
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Oct 11, 2024
1 parent 826f01f commit 27d0a39
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion model2vec/distill/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def create_output_embeddings_from_model_name(
if out.dtype == torch.bfloat16:
out = out.float()

# Detach the tensor from the computation graph before converting to NumPy
# Add the output to the intermediate weights
intermediate_weights.append(out[:, 1].detach().cpu().numpy())

# Concatenate the intermediate weights
out_weights = np.concatenate(intermediate_weights)

return tokenizer.convert_ids_to_tokens(ids), out_weights

0 comments on commit 27d0a39

Please sign in to comment.