diff --git a/libs/community/langchain_community/llms/huggingface_hub.py b/libs/community/langchain_community/llms/huggingface_hub.py
index 32facc244b0cd..1db668fe7d32e 100644
--- a/libs/community/langchain_community/llms/huggingface_hub.py
+++ b/libs/community/langchain_community/llms/huggingface_hub.py
@@ -6,6 +6,7 @@
 from langchain_core.utils import get_from_dict_or_env
 
 from langchain_community.llms.utils import enforce_stop_tokens
+from langchain_community.utils.huggingface_hub import is_inference_client_supported
 
 DEFAULT_REPO_ID = "gpt2"
 VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
@@ -18,8 +19,6 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
 
-    Only supports `text-generation`, `text2text-generation`
-    and `summarization` for now.
-
     Example:
         .. code-block:: python
 
@@ -28,11 +27,19 @@ class HuggingFaceHub(LLM):
     """
 
     client: Any  #: :meta private:
-    repo_id: str = DEFAULT_REPO_ID
+    repo_id: Optional[str] = DEFAULT_REPO_ID
+    """(Deprecated) Model name to use when the installed
+    huggingface_hub version is earlier than 0.17."""
+    model: Optional[str] = None
     """Model name to use."""
     task: Optional[str] = None
-    """Task to call the model with.
-    Should be a task that returns `generated_text` or `summary_text`."""
+    """Task to perform on the inference. Used only to select a recommended
+    default model when `model` is not provided; at least one of `model` or
+    `task` must be provided.
+
+    With huggingface_hub earlier than 0.17, this is the task the model is
+    called with; only `text-generation`, `text2text-generation` and
+    `summarization` are supported, returning `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Keyword arguments to pass to the model."""
@@ -50,19 +57,30 @@ def validate_environment(cls, values: Dict) -> Dict:
             values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
         )
         try:
-            from huggingface_hub.inference_api import InferenceApi
-
             repo_id = values["repo_id"]
-            client = InferenceApi(
-                repo_id=repo_id,
-                token=huggingfacehub_api_token,
-                task=values.get("task"),
-            )
-            if client.task not in VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {client.task}, "
-                    f"currently only {VALID_TASKS} are supported"
-                )
+            repo_id_specified = DEFAULT_REPO_ID != repo_id
+
+            if repo_id_specified or not is_inference_client_supported():
+                from huggingface_hub import InferenceApi
+
+                client = InferenceApi(
+                    repo_id=repo_id,
+                    token=huggingfacehub_api_token,
+                    task=values.get("task"),
+                )
+
+                if client.task not in VALID_TASKS:
+                    raise ValueError(
+                        f"Got invalid task {client.task}, "
+                        f"currently only {VALID_TASKS} are supported"
+                    )
+            else:
+                from huggingface_hub import InferenceClient
+
+                client = InferenceClient(
+                    model=values.get("model"), token=huggingfacehub_api_token
+                )
+
             values["client"] = client
         except ImportError:
             raise ValueError(
@@ -76,7 +94,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         _model_kwargs = self.model_kwargs or {}
         return {
-            **{"repo_id": self.repo_id, "task": self.task},
+            **{"repo_id": self.repo_id, "model": self.model, "task": self.task},
             **{"model_kwargs": _model_kwargs},
         }
@@ -91,7 +109,24 @@ def _call(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
+    ) -> str:
+        if type(self.client).__name__ == "InferenceClient":
+            return self._call_with_inference_client(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+        elif type(self.client).__name__ == "InferenceApi":
+            return self._call_with_inference_api(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+        raise ValueError(f"Unexpected client type: {type(self.client).__name__}")
+
+    def _call_with_inference_api(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
@@ -128,3 +163,25 @@ def _call(
         # stop tokens when making calls to huggingface_hub.
         text = enforce_stop_tokens(text, stop)
         return text
+
+    def _call_with_inference_client(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        _model_kwargs = self.model_kwargs or {}
+        json = {**_model_kwargs, **kwargs}
+
+        if prompt:
+            json["inputs"] = prompt
+
+        response = self.client.post(json=json, task=self.task)
+
+        text = response.decode("utf-8")
+        if stop is not None:
+            # Mirror the InferenceApi path: the raw client does not apply
+            # stop sequences, so enforce them client-side.
+            text = enforce_stop_tokens(text, stop)
+        return text
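To illustrate the selection logic above, here is a minimal usage sketch (not part of the diff; the model names are placeholders and `HUGGINGFACEHUB_API_TOKEN` is assumed to be set in the environment):

```python
from langchain_community.llms.huggingface_hub import HuggingFaceHub

# New style: `model` (plus an optional `task`) routes to InferenceClient,
# provided the installed huggingface_hub version is 0.17 or newer.
llm = HuggingFaceHub(model="google/flan-t5-xxl", task="text2text-generation")

# Legacy style: a non-default `repo_id` keeps the deprecated InferenceApi path,
# where `task` must be one of VALID_TASKS.
legacy_llm = HuggingFaceHub(
    repo_id="google/flan-t5-xxl", task="text2text-generation"
)

print(llm("What is the capital of France?"))
```

One subtlety of the `repo_id_specified` check: explicitly passing the default `repo_id="gpt2"` is indistinguishable from omitting it, so that call still routes to `InferenceClient` on huggingface_hub 0.17 and newer.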
diff --git a/libs/community/langchain_community/utils/huggingface_hub.py b/libs/community/langchain_community/utils/huggingface_hub.py
new file mode 100644
index 0000000000000..174dd24783511
--- /dev/null
+++ b/libs/community/langchain_community/utils/huggingface_hub.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from importlib.metadata import version
+
+from packaging.version import parse
+
+
+def is_inference_client_supported() -> bool:
+    """Return whether the installed huggingface_hub supports InferenceClient."""
+    # InferenceApi was deprecated in 0.17, which introduced InferenceClient.
+    # See https://github.com/huggingface/huggingface_hub/commit/0a02b04e6cab31a906ddeaf61fce0d5df4b4f7be.
+    _version = parse(version("huggingface_hub"))
+    return not (_version.major == 0 and _version.minor < 17)
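As a sanity check on the version gate, a small sketch (not part of the diff) applies the same comparison to example version strings; `_supports` is a hypothetical stand-in that takes the version string as an argument instead of reading installed package metadata:

```python
from packaging.version import parse


def _supports(version_str: str) -> bool:
    # Same predicate as is_inference_client_supported, parameterized
    # for illustration.
    v = parse(version_str)
    return not (v.major == 0 and v.minor < 17)


assert _supports("0.20.1") is True   # InferenceClient path
assert _supports("0.17.0") is True   # 0.17 introduced InferenceClient
assert _supports("0.16.4") is False  # falls back to InferenceApi
assert _supports("1.0.0") is True    # any 1.x release passes the gate
```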
diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock
index e4f801d984cee..7c03d6c5726a0 100644
--- a/libs/community/poetry.lock
+++ b/libs/community/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "aenum"
@@ -3009,13 +3009,13 @@ files = [
 
 [[package]]
 name = "huggingface-hub"
-version = "0.19.4"
+version = "0.20.1"
 description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"},
-    {file = "huggingface_hub-0.19.4.tar.gz", hash = "sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"},
+    {file = "huggingface_hub-0.20.1-py3-none-any.whl", hash = "sha256:ecfdea395a8bc68cd160106c5bd857f7e010768d95f9e1862a779010cc304831"},
+    {file = "huggingface_hub-0.20.1.tar.gz", hash = "sha256:8c88c4c3c8853e22f2dfb4d84c3d493f4e1af52fb3856a90e1eeddcf191ddbb1"},
 ]
 
 [package.dependencies]
@@ -3028,15 +3028,14 @@ tqdm = ">=4.42.1"
 typing-extensions = ">=3.7.4.3"
 
 [package.extras]
-all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
 cli = ["InquirerPy (==0.3.4)"]
-dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
-docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
 fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
 inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
 quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
 tensorflow = ["graphviz", "pydot", "tensorflow"]
"pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] @@ -6164,6 +6163,7 @@ files = [ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"}, {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"}, {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"}, + {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"}, @@ -9167,9 +9167,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hologres-vector", "html2text", "huggingface-hub", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", 
"rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "454e721d2b68b5769f70bf8deb0d43285d270e2ca8ae99dc72c773dfa827835b" +content-hash = "a869f09e3f4eda23987e5de73390f67b9c6fbc7f359b8c173b0aba18ea6117ca" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index ec0f6088c3a11..15eb5dee53cb2 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -87,6 +87,7 @@ datasets = {version = "^2.15.0", optional = true} azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true} oracle-ads = {version = "^2.9.1", optional = true} zhipuai = {version = "^1.0.7", optional = true} +huggingface-hub = {version = "^0.20.1", optional = true} [tool.poetry.group.test] optional = true @@ -216,6 +217,7 @@ extended_testing = [ "streamlit", "pyspark", "openai", + "huggingface_hub", "sympy", "rapidfuzz", "jsonschema", diff --git a/libs/community/tests/unit_tests/llms/test_huggingface_hub.py b/libs/community/tests/unit_tests/llms/test_huggingface_hub.py new file mode 100644 index 0000000000000..83d559727aa32 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_huggingface_hub.py @@ -0,0 +1,75 @@ +import os +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from langchain_community.llms.huggingface_hub import HuggingFaceHub + +os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hb-token" + + +@pytest.mark.requires("huggingface_hub") +@patch( + "langchain_community.llms.huggingface_hub.is_inference_client_supported", + return_value=True, +) +def test_huggingface_hub_with_model(mock_is_inference_client_supported) -> None: + from huggingface_hub import InferenceClient + + llm = HuggingFaceHub(model="model-1") + + assert type(llm.client) is InferenceClient + assert llm.client.model == "model-1" + assert llm.client.headers["authorization"] == "Bearer hb-token" + + mock_response = MagicMock() + type(mock_response).content = PropertyMock(return_value=b"Hello, world!") + + with patch("huggingface_hub.inference._client.get_session") as mock_session: + mock_session.return_value.post.return_value = mock_response + + completion = llm("Hi, how are you?") + + assert completion == "Hello, world!" 
+
+
+@pytest.mark.requires("huggingface_hub")
+@patch(
+    "langchain_community.llms.huggingface_hub.is_inference_client_supported",
+    return_value=True,
+)
+def test_huggingface_hub_with_task(mock_is_inference_client_supported) -> None:
+    from huggingface_hub import InferenceClient
+
+    llm = HuggingFaceHub(task="task-1")
+
+    assert type(llm.client) is InferenceClient
+    assert llm.client.model is None
+    assert llm.client.headers["authorization"] == "Bearer hb-token"
+    assert llm.task == "task-1"
+
+
+@pytest.mark.requires("huggingface_hub")
+@patch(
+    "langchain_community.llms.huggingface_hub.is_inference_client_supported",
+    return_value=False,
+)
+def test_huggingface_hub_param_with_inference_api(
+    mock_is_inference_client_supported,
+) -> None:
+    from huggingface_hub import InferenceApi
+
+    with patch("huggingface_hub.inference_api.HfApi") as mock_hfapi:
+        llm = HuggingFaceHub(repo_id="model-1", task="text-generation")
+
+        mock_hfapi.assert_called_once_with(token="hb-token")
+
+    assert type(llm.client) is InferenceApi
+    assert "model-1" in llm.client.api_url
+    assert llm.client.task == "text-generation"
+    assert llm.client.headers["authorization"] == "Bearer hb-token"
+
+    llm("Hi, how are you?")
+
+
+# TODO: Assert that model invocation selects the task
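A possible shape for the TODO above, offered as a hedged sketch rather than a definitive test: it assumes the same module-level imports and token setup as the new test file, patches `post` on the constructed client so no network call is made, and asserts that the configured task and prompt are forwarded.

```python
@pytest.mark.requires("huggingface_hub")
@patch(
    "langchain_community.llms.huggingface_hub.is_inference_client_supported",
    return_value=True,
)
def test_huggingface_hub_selects_task(mock_is_inference_client_supported) -> None:
    llm = HuggingFaceHub(model="model-1", task="text-generation")

    # Patch the client's raw `post` so no network call is made.
    with patch.object(llm.client, "post", return_value=b"Hello!") as mock_post:
        completion = llm("Hi, how are you?")

    assert completion == "Hello!"
    _, kwargs = mock_post.call_args
    assert kwargs["task"] == "text-generation"
    assert kwargs["json"]["inputs"] == "Hi, how are you?"
```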