
feat(llms): support more tasks in HuggingFaceHub LLM and remove deprecated dep #14406

Merged
7 commits merged on Jan 24, 2024
Changes from 6 commits
69 changes: 44 additions & 25 deletions libs/community/langchain_community/llms/huggingface_hub.py
@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -7,8 +8,15 @@

from langchain_community.llms.utils import enforce_stop_tokens

DEFAULT_REPO_ID = "gpt2"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
# key: task
# value: key in the output dictionary
VALID_TASKS_DICT = {
"translation": "translation_text",
"summarization": "summary_text",
"conversational": "generated_text",
"text-generation": "generated_text",
"text2text-generation": "generated_text",
}


class HuggingFaceHub(LLM):
@@ -18,7 +26,8 @@ class HuggingFaceHub(LLM):
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.

Only supports `text-generation`, `text2text-generation` and `summarization` for now.
Supports `text-generation`, `text2text-generation`, `conversational`, `translation`,
and `summarization`.

Example:
.. code-block:: python
@@ -28,11 +37,13 @@ class HuggingFaceHub(LLM):
"""

client: Any #: :meta private:
repo_id: str = DEFAULT_REPO_ID
"""Model name to use."""
repo_id: Optional[str] = None
"""Model name to use.
If not provided, the default model for the chosen task will be used."""
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""
Should be a task that returns `generated_text`, `summary_text`,
or `translation_text`."""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""

@@ -50,18 +61,27 @@ def validate_environment(cls, values: Dict) -> Dict:
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi
from huggingface_hub import HfApi, InferenceClient

repo_id = values["repo_id"]
client = InferenceApi(
repo_id=repo_id,
client = InferenceClient(
model=repo_id,
token=huggingfacehub_api_token,
task=values.get("task"),
)
if client.task not in VALID_TASKS:
if not values["task"]:
if not repo_id:
raise ValueError(
"Must specify either `repo_id` or `task`, or both."
)
# Use the recommended task for the chosen model
model_info = HfApi(token=huggingfacehub_api_token).model_info(
repo_id=repo_id
)
values["task"] = model_info.pipeline_tag
if values["task"] not in VALID_TASKS_DICT:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {VALID_TASKS} are supported"
f"Got invalid task {values['task']}, "
f"currently only {VALID_TASKS_DICT.keys()} are supported"
)
values["client"] = client
except ImportError:
@@ -108,21 +128,20 @@ def _call(
"""
_model_kwargs = self.model_kwargs or {}
params = {**_model_kwargs, **kwargs}
response = self.client(inputs=prompt, params=params)

response = self.client.post(
json={"inputs": prompt, "params": params}, task=self.task
)
response = json.loads(response.decode())
Contributor @Wauplin commented on lines +132 to +135:
Here I would use the task-specific methods instead of client.post directly. For the record, client.post(...) is what's used internally by the client to make the calls. It is perfectly valid to use it, but you also have higher-level methods, like client.text_generation(...), that are more appropriate for handling inputs and outputs correctly. See the supported tasks and their documentation on this page.

So instead of those two lines, I would suggest something like

if self.task == "text-generation":
    response = client.text_generation(...)
elif self.task == "conversational":
    response = client.conversational(...)
...

Inputs and outputs may vary depending on the selected task. It would also avoid the logic if isinstance(response, list): ... else: ... below that might be error prone.

mspronesti (Contributor, Author) commented on Jan 8, 2024:

Hi @Wauplin, thanks for your precious feedback! I had a look at the InferenceClient implementation in hf-hub, but I couldn't find any method handling 'text2text-generation'. This task is still supported by post though.

Does this porting to the newer higher-level API look good to you?

        if self.task == "text-generation":
            response = self.client.text_generation(prompt, **params)
        elif self.task == "conversational":
            response = self.client.conversational(prompt, **params)["generated_text"]
        elif self.task == "translation":
            response = self.client.translation(prompt, **params)
        elif self.task == "summarization":
            response = self.client.summarization(prompt, parameters=params)
        else:
            raise ValueError(f"Invalid task {self.task}")
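The dispatch proposed in the comment above can be exercised offline with a stand-in client. `FakeClient` and its canned return values below are hypothetical, not part of huggingface_hub; the shapes (strings for most tasks, a dict with `generated_text` for conversational) follow the proposal quoted above.

```python
class FakeClient:
    # Hypothetical stand-in for huggingface_hub.InferenceClient, returning
    # canned values in the shape each task-specific method is assumed to produce.
    def text_generation(self, prompt, **params):
        return " world"

    def conversational(self, prompt, **params):
        return {"generated_text": "Hi there"}

    def translation(self, prompt, **params):
        return "Bonjour"

    def summarization(self, prompt, parameters=None):
        return "A summary."


def call_task(client, task, prompt, params=None):
    """Dispatch to the task-specific method, as suggested in the review."""
    params = params or {}
    if task == "text-generation":
        return client.text_generation(prompt, **params)
    elif task == "conversational":
        return client.conversational(prompt, **params)["generated_text"]
    elif task == "translation":
        return client.translation(prompt, **params)
    elif task == "summarization":
        return client.summarization(prompt, parameters=params)
    raise ValueError(f"Invalid task {task}")
```

Keeping the dispatch in one helper also makes the unsupported-task error a single code path instead of a fall-through.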

Wauplin (Contributor) replied on Jan 8, 2024:

Looks good to me, yes! It should be more robust as well. Thanks for making the change :)

About text2text-generation, we are trying to deprecate it in favor of text-generation (for harmonization) cc @osanseviero. I think that using self.client.text_generation should be fine, but it would be best to test it first. According to the docs, the API should be the same (see here).


The API is the same, but the underlying pipelines in transformers are different (one is encoder-decoder while the other is decoder-only), so I'm not sure they would work out of the box.

mspronesti (Contributor, Author) commented:

@hwchase17 @baskaryan Any thoughts/suggestions on this?

if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
if self.client.task == "text-generation":
# Text generation return includes the starter text.
text = response[0]["generated_text"][len(prompt) :]
elif self.client.task == "text2text-generation":
text = response[0]["generated_text"]
elif self.client.task == "summarization":
text = response[0]["summary_text"]

response_key = VALID_TASKS_DICT[self.task] # type: ignore
if isinstance(response, list):
text = response[0][response_key]
else:
raise ValueError(
f"Got invalid task {self.client.task}, "
f"currently only {VALID_TASKS} are supported"
)
text = response[response_key]
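The list-vs-dict handling above can be checked on sample payloads. `extract_text` below is a hypothetical helper, and the payload shapes are assumed from the Inference API's JSON responses (a list of dicts for generation tasks, a plain dict otherwise).

```python
import json

VALID_TASKS_DICT = {
    "translation": "translation_text",
    "summarization": "summary_text",
    "conversational": "generated_text",
    "text-generation": "generated_text",
    "text2text-generation": "generated_text",
}


def extract_text(raw_bytes, task):
    """Decode a raw response and pull out the task-specific text field."""
    response = json.loads(raw_bytes.decode())
    if "error" in response:
        raise ValueError(f"Error raised by inference API: {response['error']}")
    response_key = VALID_TASKS_DICT[task]
    # Some tasks return a list of candidates, others a single dict.
    if isinstance(response, list):
        return response[0][response_key]
    return response[response_key]
```

Note that the `"error" in response` check only matches dict payloads; for a list payload the membership test is over its elements and is simply False.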

if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
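For reference, a minimal sketch of what `enforce_stop_tokens` does, modeled on `langchain_community.llms.utils` (the real helper may differ in details such as escaping):

```python
import re


def enforce_stop_tokens(text, stop):
    """Cut the generation off at the first occurrence of any stop sequence."""
    return re.split("|".join(re.escape(s) for s in stop), text, maxsplit=1)[0]


print(enforce_stop_tokens("Paris is nice.\nObservation:", ["\nObservation:"]))
```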