Skip to content

Commit

Permalink
Add model card loading (#45)
Browse files Browse the repository at this point in the history
* Add model card loading

* Add tests
  • Loading branch information
stephantul authored Sep 28, 2024
1 parent 0ca9d00 commit a4a20b9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
6 changes: 4 additions & 2 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ def from_pretrained(
:param token: The huggingface token to use.
:return: A StaticEmbedder
"""
embeddings, tokenizer, config = load_pretrained(path, token=token)
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)

return cls(embeddings, tokenizer, config)
return cls(
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
)

def encode(
self,
Expand Down
29 changes: 26 additions & 3 deletions model2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _create_model_card(

def load_pretrained(
folder_or_repo_path: str | Path, token: str | None = None
) -> tuple[np.ndarray, Tokenizer, dict[str, Any]]:
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
"""
Loads a pretrained model from a folder.
Expand All @@ -111,7 +111,7 @@ def load_pretrained(
- If the local path is not found, we will attempt to load from the huggingface hub.
:param token: The huggingface token to use.
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
:return: The embeddings, tokenizer, and config.
:return: The embeddings, tokenizer, config, and metadata.
"""
folder_or_repo_path = Path(folder_or_repo_path)
Expand All @@ -133,6 +133,10 @@ def load_pretrained(
if not tokenizer_path.exists():
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")

# README is optional, so this is a bit finicky.
readme_path = folder_or_repo_path / "README.md"
metadata = _get_metadata_from_readme(readme_path)

else:
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
try:
Expand All @@ -148,6 +152,13 @@ def load_pretrained(
# Raise original exception.
raise e

try:
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
metadata = _get_metadata_from_readme(Path(readme_path))
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No README found in the model folder. No model card loaded.")
metadata = {}

config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)

Expand All @@ -162,7 +173,19 @@ def load_pretrained(
f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
)

return embeddings, tokenizer, config
return embeddings, tokenizer, config, metadata


def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
    """
    Read the model card data stored in a README file.

    :param readme_path: The path to the README file.
    :return: The model card data as a dictionary. An empty dictionary is returned
        when the file does not exist; the result is also empty when the loaded
        card holds no data.
    """
    if readme_path.exists():
        metadata: dict[str, Any] = ModelCard.load(readme_path).data.to_dict()
        if not metadata:
            # The file is there, but the card carried no data; tell the user why nothing was loaded.
            logger.info("File README.md exists, but was empty. No model card loaded.")
        return metadata

    # A missing README is not an error: report it and fall back to empty metadata.
    logger.info(f"README file not found in {readme_path}. No model card loaded.")
    return {}


def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: str | None) -> None:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory

from model2vec.utils import _get_metadata_from_readme


def test__get_metadata_from_readme_not_exists() -> None:
    """Test that a README path that is not present yields empty metadata."""
    # "zzz" is presumed not to exist in the working directory the tests run from.
    missing_readme = Path("zzz")
    metadata = _get_metadata_from_readme(missing_readme)
    assert metadata == {}


def test__get_metadata_from_readme_mocked_file() -> None:
    """Test that a key in the README front matter is returned in the metadata."""
    # Write the file inside a temporary directory instead of reopening a still-open
    # NamedTemporaryFile by name, which is not possible on every platform.
    with TemporaryDirectory() as folder:
        readme_path = Path(folder) / "README.md"
        readme_path.write_bytes(b"---\nkey: value\n---\n")
        assert _get_metadata_from_readme(readme_path)["key"] == "value"


def test__get_metadata_from_readme_mocked_file_keys() -> None:
    """Test that an existing but empty README yields metadata without any keys."""
    # Write the file inside a temporary directory instead of reopening a still-open
    # NamedTemporaryFile by name, which is not possible on every platform.
    with TemporaryDirectory() as folder:
        readme_path = Path(folder) / "README.md"
        readme_path.write_bytes(b"")
        metadata = _get_metadata_from_readme(readme_path)
        assert set(metadata) == set()

0 comments on commit a4a20b9

Please sign in to comment.