Skip to content

Commit

Permalink
Make loading a lot faster
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Oct 11, 2024
1 parent 2ce3c97 commit ddb982b
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 33 deletions.
9 changes: 8 additions & 1 deletion model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from time import time
from typing import Any, Iterator

import numpy as np
Expand Down Expand Up @@ -176,6 +177,8 @@ def from_pretrained(
cls: type[StaticModel],
path: PathLike,
token: str | None = None,
local_files_only: bool = False,
load_readme: bool = True,
) -> StaticModel:
"""
Load a StaticModel from a local path or huggingface hub path.
Expand All @@ -184,9 +187,13 @@ def from_pretrained(
:param path: The path to load your static model from.
:param token: The huggingface token to use.
:param local_files_only: Whether to only load local files.
:param load_readme: Whether to load the readme.
:return: A StaticEmbedder
"""
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)
embeddings, tokenizer, config, metadata = load_pretrained(
path, token=token, local_files_only=local_files_only, load_readme=load_readme
)

return cls(
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
Expand Down
100 changes: 68 additions & 32 deletions model2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import huggingface_hub.errors
import numpy as np
import safetensors
from huggingface_hub import ModelCard, ModelCardData
from huggingface_hub import ModelCard, ModelCardData, get_hf_file_metadata
from rich.logging import RichHandler
from safetensors.numpy import save_file
from tokenizers import Tokenizer
Expand Down Expand Up @@ -101,7 +101,7 @@ def _create_model_card(


def load_pretrained(
folder_or_repo_path: str | Path, token: str | None = None
folder_or_repo_path: str | Path, token: str | None = None, local_files_only: bool = False, load_readme: bool = True
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
"""
Loads a pretrained model from a folder.
Expand All @@ -110,57 +110,67 @@ def load_pretrained(
- If this is a local path, we will load from the local path.
- If the local path is not found, we will attempt to load from the huggingface hub.
:param token: The huggingface token to use.
:param local_files_only: Whether to only use local files.
:param load_readme: Whether to load the README file.
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
:return: The embeddings, tokenizer, config, and metadata.
"""
folder_or_repo_path = Path(folder_or_repo_path)
if folder_or_repo_path.exists():
embeddings_path = folder_or_repo_path / "model.safetensors"
if not embeddings_path.exists():
old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
if old_embeddings_path.exists():
logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.")
embeddings_path = old_embeddings_path
else:
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")

config_path = folder_or_repo_path / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}")

tokenizer_path = folder_or_repo_path / "tokenizer.json"
if not tokenizer_path.exists():
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")

# README is optional, so this is a bit finicky.
readme_path = folder_or_repo_path / "README.md"
metadata = _get_metadata_from_readme(readme_path)

embeddings_path, config_path, tokenizer_path, metadata = _get_local_paths(folder_or_repo_path)
else:
logger.info("Folder does not exist locally, attempting to use huggingface hub.")

try:
embeddings_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(), "model.safetensors", token=token
folder_or_repo_path.as_posix(), "model.safetensors", token=token, local_files_only=local_files_only
)
path = Path(embeddings_path).parent
revision = path.parts[-1]
except huggingface_hub.utils.EntryNotFoundError as e:
try:
embeddings_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token
folder_or_repo_path.as_posix(),
"embeddings.safetensors",
token=token,
local_files_only=local_files_only,
revision=revision,
)
except huggingface_hub.utils.EntryNotFoundError:
# Raise original exception.
raise e

try:
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
metadata = _get_metadata_from_readme(Path(readme_path))
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No README found in the model folder. No model card loaded.")
if load_readme:
try:
readme_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(),
"README.md",
token=token,
local_files_only=local_files_only,
revision=revision,
)
metadata = _get_metadata_from_readme(Path(readme_path))
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No README found in the model folder. No model card loaded.")
metadata = {}
else:
metadata = {}

config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)
config_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(),
"config.json",
token=token,
local_files_only=local_files_only,
revision=revision,
)
tokenizer_path = huggingface_hub.hf_hub_download(
folder_or_repo_path.as_posix(),
"tokenizer.json",
token=token,
local_files_only=local_files_only,
revision=revision,
)

opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
embeddings = opened_tensor_file.get_tensor("embeddings")
Expand All @@ -176,6 +186,32 @@ def load_pretrained(
return embeddings, tokenizer, config, metadata


def _get_local_paths(folder_or_repo_path: Path) -> tuple[Path, Path, Path, dict[str, Any]]:
"""Get the local paths for a folder."""
embeddings_path = folder_or_repo_path / "model.safetensors"
if not embeddings_path.exists():
old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
if old_embeddings_path.exists():
logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.")
embeddings_path = old_embeddings_path
else:
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")

config_path = folder_or_repo_path / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}")

tokenizer_path = folder_or_repo_path / "tokenizer.json"
if not tokenizer_path.exists():
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")

# README is optional, so this is a bit finicky.
readme_path = folder_or_repo_path / "README.md"
metadata = _get_metadata_from_readme(readme_path)

return embeddings_path, config_path, tokenizer_path, metadata


def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
"""Get metadata from a README file."""
if not readme_path.exists():
Expand Down

0 comments on commit ddb982b

Please sign in to comment.