diff --git a/libs/community/langchain_community/llms/huggingface_hub.py b/libs/community/langchain_community/llms/huggingface_hub.py
index 10eb8eb7a0464..f432727773121 100644
--- a/libs/community/langchain_community/llms/huggingface_hub.py
+++ b/libs/community/langchain_community/llms/huggingface_hub.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Mapping, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -7,8 +8,15 @@
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
-DEFAULT_REPO_ID = "gpt2"
-VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+# key: task
+# value: key in the output dictionary
+VALID_TASKS_DICT = {
+    "translation": "translation_text",
+    "summarization": "summary_text",
+    "conversational": "generated_text",
+    "text-generation": "generated_text",
+    "text2text-generation": "generated_text",
+}
 
 
 class HuggingFaceHub(LLM):
@@ -18,7 +26,8 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your
     API token, or pass it as a named parameter to the constructor.
 
-    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+    Supports `text-generation`, `text2text-generation`, `conversational`, `translation`,
+    and `summarization`.
 
     Example:
         .. code-block:: python
@@ -28,11 +37,13 @@ class HuggingFaceHub(LLM):
     """
 
     client: Any  #: :meta private:
-    repo_id: str = DEFAULT_REPO_ID
-    """Model name to use."""
+    repo_id: Optional[str] = None
+    """Model name to use.
+    If not provided, the default model for the chosen task will be used."""
    task: Optional[str] = None
     """Task to call the model with.
-    Should be a task that returns `generated_text` or `summary_text`."""
+    Should be a task that returns `generated_text`, `summary_text`,
+    or `translation_text`."""
     model_kwargs: Optional[dict] = None
     """Keyword arguments to pass to the model."""
 
@@ -50,18 +61,27 @@ def validate_environment(cls, values: Dict) -> Dict:
             values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
         )
         try:
-            from huggingface_hub.inference_api import InferenceApi
+            from huggingface_hub import HfApi, InferenceClient
 
             repo_id = values["repo_id"]
-            client = InferenceApi(
-                repo_id=repo_id,
+            client = InferenceClient(
+                model=repo_id,
                 token=huggingfacehub_api_token,
-                task=values.get("task"),
             )
-            if client.task not in VALID_TASKS:
+            if not values["task"]:
+                if not repo_id:
+                    raise ValueError(
+                        "Must specify either `repo_id` or `task`, or both."
+                    )
+                # Use the recommended task for the chosen model
+                model_info = HfApi(token=huggingfacehub_api_token).model_info(
+                    repo_id=repo_id
+                )
+                values["task"] = model_info.pipeline_tag
+            if values["task"] not in VALID_TASKS_DICT:
                 raise ValueError(
-                    f"Got invalid task {client.task}, "
-                    f"currently only {VALID_TASKS} are supported"
+                    f"Got invalid task {values['task']}, "
+                    f"currently only {VALID_TASKS_DICT.keys()} are supported"
                 )
             values["client"] = client
         except ImportError:
@@ -108,23 +128,20 @@ def _call(
         """
         _model_kwargs = self.model_kwargs or {}
         params = {**_model_kwargs, **kwargs}
-        response = self.client(inputs=prompt, params=params)
+
+        response = self.client.post(
+            json={"inputs": prompt, "parameters": params}, task=self.task
+        )
+        response = json.loads(response.decode())
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
-        if self.client.task == "text-generation":
-            # Text generation sometimes return includes the starter text.
-            text = response[0]["generated_text"]
-            if text.startswith(prompt):
-                text = response[0]["generated_text"][len(prompt) :]
-        elif self.client.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        elif self.client.task == "summarization":
-            text = response[0]["summary_text"]
+
+        response_key = VALID_TASKS_DICT[self.task]  # type: ignore
+        if isinstance(response, list):
+            text = response[0][response_key]
         else:
-            raise ValueError(
-                f"Got invalid task {self.client.task}, "
-                f"currently only {VALID_TASKS} are supported"
-            )
+            text = response[response_key]
+
         if stop is not None:
             # This is a bit hacky, but I can't figure out a better way to enforce
             # stop tokens when making calls to huggingface_hub.
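
For reviewers, a minimal usage sketch of the new behavior. The model id, prompt, and token below are illustrative placeholders, not taken from this diff; only the `repo_id`/`task` semantics come from the changed code.

    from langchain_community.llms import HuggingFaceHub

    # `task` omitted: validate_environment() now infers it from the model's
    # pipeline_tag via HfApi.model_info(). Model choice is hypothetical.
    llm = HuggingFaceHub(
        repo_id="google/flan-t5-small",
        huggingfacehub_api_token="hf_...",  # or set HUGGINGFACEHUB_API_TOKEN
    )
    print(llm.invoke("Translate to German: How are you?"))

    # `repo_id` omitted: InferenceClient falls back to the default model for
    # the task, which is why at least one of repo_id/task must be given.
    summarizer = HuggingFaceHub(task="summarization")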
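The _call() rewrite replaces the per-task if/elif chain with a single key lookup that handles both list- and dict-shaped payloads. A standalone sketch of that logic, with made-up payloads that illustrate the two shapes (not captured API responses):

    import json

    VALID_TASKS_DICT = {
        "translation": "translation_text",
        "summarization": "summary_text",
        "conversational": "generated_text",
        "text-generation": "generated_text",
        "text2text-generation": "generated_text",
    }

    def extract_text(raw: bytes, task: str) -> str:
        # Mirror the new _call(): decode the raw bytes returned by
        # InferenceClient.post(), then pick the task-specific output key.
        response = json.loads(raw.decode())
        if "error" in response:
            raise ValueError(f"Error raised by inference API: {response['error']}")
        key = VALID_TASKS_DICT[task]
        return response[0][key] if isinstance(response, list) else response[key]

    # Generation-style tasks return a list of dicts; some tasks return a
    # single dict, which the isinstance() branch covers.
    assert extract_text(b'[{"summary_text": "Hi."}]', "summarization") == "Hi."
    assert extract_text(b'{"generated_text": "Hello!"}', "conversational") == "Hello!"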