
Use HB Inference Client in llms package
dosuken123 committed Jan 8, 2024
1 parent 7025fa2 commit 8033249
Showing 5 changed files with 168 additions and 27 deletions.
85 changes: 68 additions & 17 deletions libs/community/langchain_community/llms/huggingface_hub.py
@@ -6,6 +6,7 @@
from langchain_core.utils import get_from_dict_or_env

from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.utils.huggingface_hub import is_inference_client_supported

DEFAULT_REPO_ID = "gpt2"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
@@ -18,8 +19,6 @@ class HuggingFaceHub(LLM):
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
Example:
.. code-block:: python
@@ -28,11 +27,19 @@ class HuggingFaceHub(LLM):
"""

client: Any #: :meta private:
repo_id: str = DEFAULT_REPO_ID
repo_id: Optional[str] = DEFAULT_REPO_ID
"""(Deprecated) Model name to use
when huggingface_hub package is version 0.17 or below."""
model: Optional[str] = None
"""Model name to use."""
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""
"""Task to perform on the inference. Used only to default to a recommended model
if `model` is not provided. At least `model` or `task` must be provided.
When huggingface_hub package version is 0.17 or below,
it's a task to call the model with.
only `text-generation`, `text2text-generation` and `summarization` are supported.
The task should return `generated_text` or `summary_text`."""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""

@@ -50,19 +57,30 @@ def validate_environment(cls, values: Dict) -> Dict:
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi

repo_id = values["repo_id"]
client = InferenceApi(
repo_id=repo_id,
token=huggingfacehub_api_token,
task=values.get("task"),
)
if client.task not in VALID_TASKS:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {VALID_TASKS} are supported"
repo_id_specified = DEFAULT_REPO_ID != repo_id

if repo_id_specified or not is_inference_client_supported():
from huggingface_hub import InferenceApi

client = InferenceApi(
repo_id=repo_id,
token=huggingfacehub_api_token,
task=values.get("task"),
)

if client.task not in VALID_TASKS:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {VALID_TASKS} are supported"
)
else:
from huggingface_hub import InferenceClient

client = InferenceClient(
model=values.get("model"), token=huggingfacehub_api_token
)

values["client"] = client
except ImportError:
raise ValueError(
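
Condensed for readability (a hypothetical rewrite, not code from the commit), the branching above selects the client like this:

    # Use the legacy client if a non-default repo_id was given,
    # or if the installed huggingface_hub predates InferenceClient.
    use_legacy = (repo_id != DEFAULT_REPO_ID) or not is_inference_client_supported()
    client = (
        InferenceApi(repo_id=repo_id, token=huggingfacehub_api_token, task=values.get("task"))
        if use_legacy
        else InferenceClient(model=values.get("model"), token=huggingfacehub_api_token)
    )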
@@ -76,7 +94,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"repo_id": self.repo_id, "task": self.task},
**{"repo_id": self.repo_id, "model": self.model, "task": self.task},
**{"model_kwargs": _model_kwargs},
}

@@ -91,6 +109,22 @@ def _call(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if type(self.client).__name__ == "InferenceClient":
return self._call_with_inference_client(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
elif type(self.client).__name__ == "InferenceApi":
return self._call_with_inference_api(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
else:
raise ValueError(f"Unknown client type: {type(self.client).__name__}")

def _call_with_inference_api(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
@@ -128,3 +162,20 @@ def _call(
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text

def _call_with_inference_client(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
_model_kwargs = self.model_kwargs or {}
json = {**_model_kwargs, **kwargs}

if prompt:
json["inputs"] = prompt

response = self.client.post(json=json, task=self.task)

return response.decode("utf-8")
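
For a text-generation task, the raw bytes returned here are typically a JSON payload; a hedged decoding sketch (the response shape is assumed from the Hub Inference API, not asserted by this commit):

    import json
    raw = llm.client.post(json={"inputs": "Hello"}, task="text-generation")
    # Assumed shape for text-generation: [{"generated_text": "..."}]
    payload = json.loads(raw.decode("utf-8"))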
13 changes: 13 additions & 0 deletions libs/community/langchain_community/utils/huggingface_hub.py
@@ -0,0 +1,13 @@
from __future__ import annotations

from importlib.metadata import version

from packaging.version import parse


def is_inference_client_supported() -> bool:
"""Return whether HuggingFace Hub Client library supports InferenceClient."""
# InferenceAPI was deprecated 0.17.
# See https://github.com/huggingface/huggingface_hub/commit/0a02b04e6cab31a906ddeaf61fce0d5df4b4f7be.
_version = parse(version("huggingface_hub"))
return not (_version.major == 0 and _version.minor < 17)
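
A quick illustration of the gate, assuming packaging's usual version semantics:

    from packaging.version import parse
    assert parse("0.16.4").minor < 17  # legacy InferenceApi path
    v = parse("0.20.1")
    assert not (v.major == 0 and v.minor < 17)  # InferenceClient path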
20 changes: 10 additions & 10 deletions libs/community/poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions libs/community/pyproject.toml
@@ -87,6 +87,7 @@ datasets = {version = "^2.15.0", optional = true}
azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
oracle-ads = {version = "^2.9.1", optional = true}
zhipuai = {version = "^1.0.7", optional = true}
huggingface-hub = {version = "^0.20.1", optional = true}

[tool.poetry.group.test]
optional = true
@@ -216,6 +217,7 @@ extended_testing = [
"streamlit",
"pyspark",
"openai",
"huggingface_hub",
"sympy",
"rapidfuzz",
"jsonschema",
75 changes: 75 additions & 0 deletions libs/community/tests/unit_tests/llms/test_huggingface_hub.py
@@ -0,0 +1,75 @@
import os
from unittest.mock import MagicMock, PropertyMock, patch

import pytest

from langchain_community.llms.huggingface_hub import HuggingFaceHub

os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hb-token"


@pytest.mark.requires("huggingface_hub")
@patch(
"langchain_community.llms.huggingface_hub.is_inference_client_supported",
return_value=True,
)
def test_huggingface_hub_with_model(mock_is_inference_client_supported) -> None:
from huggingface_hub import InferenceClient

llm = HuggingFaceHub(model="model-1")

assert type(llm.client) is InferenceClient
assert llm.client.model == "model-1"
assert llm.client.headers["authorization"] == "Bearer hb-token"

mock_response = MagicMock()
type(mock_response).content = PropertyMock(return_value=b"Hello, world!")

with patch("huggingface_hub.inference._client.get_session") as mock_session:
mock_session.return_value.post.return_value = mock_response

completion = llm("Hi, how are you?")

assert completion == "Hello, world!"


@pytest.mark.requires("huggingface_hub")
@patch(
"langchain_community.llms.huggingface_hub.is_inference_client_supported",
return_value=True,
)
def test_huggingface_hub_with_task(mock_is_inference_client_supported) -> None:
from huggingface_hub import InferenceClient

llm = HuggingFaceHub(task="task-1")

assert type(llm.client) is InferenceClient
assert llm.client.model is None
assert llm.client.headers["authorization"] == "Bearer hb-token"
assert llm.task == "task-1"


@pytest.mark.requires("huggingface_hub")
@patch(
"langchain_community.llms.huggingface_hub.is_inference_client_supported",
return_value=False,
)
def test_huggingface_hub_param_with_inference_api(
mock_is_inference_client_supported,
) -> None:
from huggingface_hub import InferenceApi

with patch("huggingface_hub.inference_api.HfApi") as mock_hfapi:
llm = HuggingFaceHub(repo_id="model-1", task="text-generation")

mock_hfapi.assert_called_once_with(token="hb-token")

assert type(llm.client) is InferenceApi
assert "model-1" in llm.client.api_url
assert llm.client.task == "text-generation"
assert llm.client.headers["authorization"] == "Bearer hb-token"

llm("Hi, how are you?")


# TODO: Assert that the model invocation selects the task
