From a4a20b9d22a79caa322add73911ad50afa818b8a Mon Sep 17 00:00:00 2001 From: Stephan Tulkens Date: Sat, 28 Sep 2024 19:46:40 +0200 Subject: [PATCH] Add model card loading (#45) * Add model card loading * Add tests --- model2vec/model.py | 6 ++++-- model2vec/utils.py | 29 ++++++++++++++++++++++++++--- tests/test_utils.py | 25 +++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) create mode 100644 tests/test_utils.py diff --git a/model2vec/model.py b/model2vec/model.py index 0764d37..f6d9f65 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -154,9 +154,11 @@ def from_pretrained( :param token: The huggingface token to use. :return: A StaticEmbedder """ - embeddings, tokenizer, config = load_pretrained(path, token=token) + embeddings, tokenizer, config, metadata = load_pretrained(path, token=token) - return cls(embeddings, tokenizer, config) + return cls( + embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language") + ) def encode( self, diff --git a/model2vec/utils.py b/model2vec/utils.py index d9c2875..7ec8ee7 100644 --- a/model2vec/utils.py +++ b/model2vec/utils.py @@ -102,7 +102,7 @@ def _create_model_card( def load_pretrained( folder_or_repo_path: str | Path, token: str | None = None -) -> tuple[np.ndarray, Tokenizer, dict[str, Any]]: +) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]: """ Loads a pretrained model from a folder. @@ -111,7 +111,7 @@ def load_pretrained( - If the local path is not found, we will attempt to load from the huggingface hub. :param token: The huggingface token to use. :raises: FileNotFoundError if the folder exists, but the file does not exist locally. - :return: The embeddings, tokenizer, and config. + :return: The embeddings, tokenizer, config, and metadata. """ folder_or_repo_path = Path(folder_or_repo_path) @@ -133,6 +133,10 @@ def load_pretrained( if not tokenizer_path.exists(): raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}") + # README is optional, so this is a bit finicky. + readme_path = folder_or_repo_path / "README.md" + metadata = _get_metadata_from_readme(readme_path) + else: logger.info("Folder does not exist locally, attempting to use huggingface hub.") try: @@ -148,6 +152,13 @@ def load_pretrained( # Raise original exception. raise e + try: + readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token) + metadata = _get_metadata_from_readme(Path(readme_path)) + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No README found in the model folder. No model card loaded.") + metadata = {} + config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token) tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token) @@ -162,7 +173,19 @@ def load_pretrained( f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`" ) - return embeddings, tokenizer, config + return embeddings, tokenizer, config, metadata + + +def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]: + """Get metadata from a README file.""" + if not readme_path.exists(): + logger.info(f"README file not found in {readme_path}. No model card loaded.") + return {} + model_card = ModelCard.load(readme_path) + data: dict[str, Any] = model_card.data.to_dict() + if not data: + logger.info("File README.md exists, but was empty. No model card loaded.") + return data def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: str | None) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..7ac4c70 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,25 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile + +from model2vec.utils import _get_metadata_from_readme + + +def test__get_metadata_from_readme_not_exists() -> None: + """Test getting metadata from a README.""" + assert _get_metadata_from_readme(Path("zzz")) == {} + + +def test__get_metadata_from_readme_mocked_file() -> None: + """Test getting metadata from a README.""" + with NamedTemporaryFile() as f: + f.write(b"---\nkey: value\n---\n") + f.flush() + assert _get_metadata_from_readme(Path(f.name))["key"] == "value" + + +def test__get_metadata_from_readme_mocked_file_keys() -> None: + """Test getting metadata from a README.""" + with NamedTemporaryFile() as f: + f.write(b"") + f.flush() + assert set(_get_metadata_from_readme(Path(f.name))) == set()