Adding off-the-shelf Stanford-NLP loading, removing the add_pooling_layer parameters and casting the Dense layer to dtype if set in model_kwargs
NohTow committed Sep 11, 2024
1 parent 68d3b86 commit 554ae17
Showing 2 changed files with 53 additions and 5 deletions.
36 changes: 36 additions & 0 deletions pylate/models/Dense.py
@@ -2,10 +2,12 @@
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
from torch import nn
from transformers.utils import cached_file

__all__ = ["Dense"]

@@ -77,6 +79,40 @@ def from_sentence_transformers(dense: DenseSentenceTransformer) -> "Dense":
        model.load_state_dict(dense.state_dict())
        return model

    @staticmethod
    def from_stanford_weights(
        model_name_or_path,
        cache_folder,
        revision,
        local_files_only,
        token,
        use_auth_token,
    ) -> "Dense":
        # Check if the model is locally available
        if not os.path.exists(model_name_or_path):
            # Else download the model/use the cached version
            model_name_or_path = cached_file(
                model_name_or_path,
                filename="model.safetensors",
                cache_dir=cache_folder,
                revision=revision,
                local_files_only=local_files_only,
                token=token,
                use_auth_token=use_auth_token,
            )
        with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
            state_dict = {"linear.weight": f.get_tensor("linear.weight")}

        # Determine input and output dimensions
        in_features = state_dict["linear.weight"].shape[1]
        out_features = state_dict["linear.weight"].shape[0]

        # Create Dense layer instance
        model = Dense(in_features=in_features, out_features=out_features, bias=False)

        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def load(input_path) -> "Dense":
        """Load a Dense layer."""
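For context, a minimal sketch (not part of the commit) of calling the new helper directly; the checkpoint id is illustrative, and the import goes through the pylate.models.Dense module added above:

from pylate.models.Dense import Dense

# Hypothetical example: extract only the projection weights from a
# stanford-nlp ColBERT checkpoint on the Hub (repository id is illustrative).
dense = Dense.from_stanford_weights(
    "colbert-ir/colbertv2.0",   # model_name_or_path
    cache_folder=None,
    revision=None,
    local_files_only=False,
    token=None,
    use_auth_token=None,
)
print(dense.linear.weight.shape)  # (embedding_size, hidden_size), e.g. (128, 768)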
22 changes: 17 additions & 5 deletions pylate/models/colbert.py
@@ -219,9 +219,6 @@ def __init__(
        config_kwargs: dict | None = None,
        model_card_data: Optional[SentenceTransformerModelCardData] = None,
    ) -> None:
        model_kwargs = {} if model_kwargs is None else model_kwargs
        model_kwargs["add_pooling_layer"] = False

        self.query_prefix = query_prefix
        self.document_prefix = document_prefix
        self.query_length = query_length
@@ -250,9 +250,21 @@
        )

        hidden_size = self[0].get_word_embedding_dimension()

        # If the model is a stanford-nlp ColBERT, load the weights of the dense layer
        if self[0].auto_model.config.architectures[0] == "HF_ColBERT":
            self.append(
                Dense.from_stanford_weights(
                    model_name_or_path,
                    cache_folder,
                    revision,
                    local_files_only,
                    token,
                    use_auth_token,
                )
            )
            logger.warning("Loaded the ColBERT model from Stanford NLP.")
        # Add a linear projection layer to the model in order to project the embeddings to the desired size.
        if len(self) < 2:
        elif len(self) < 2:
            # Add a linear projection layer to the model in order to project the embeddings to the desired size
            embedding_size = embedding_size or 128

@@ -282,6 +291,9 @@
        else:
            logger.warning("Pylate model loaded successfully.")

        if model_kwargs is not None and "torch_dtype" in model_kwargs:
            self[1].to(model_kwargs["torch_dtype"])

        self.to(device)
        self.is_hpu_graph_enabled = False

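Taken together, the two changes mean a stanford-nlp ColBERT checkpoint can be loaded off the shelf, with the appended Dense layer cast to the dtype requested through model_kwargs. A rough usage sketch, assuming the usual pylate.models entry point; the checkpoint id and dtype are illustrative, not taken from the commit:

import torch
from pylate import models

# Hypothetical usage: the HF_ColBERT architecture is detected automatically and
# the projection layer is cast to the dtype passed via model_kwargs.
model = models.ColBERT(
    model_name_or_path="colbert-ir/colbertv2.0",
    model_kwargs={"torch_dtype": torch.float16},
)
print(next(model[1].parameters()).dtype)  # torch.float16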
