diff --git a/libs/community/langchain_community/llms/gpt4all.py b/libs/community/langchain_community/llms/gpt4all.py
index 8b347ceb5c38b..d7fb811bf71f2 100644
--- a/libs/community/langchain_community/llms/gpt4all.py
+++ b/libs/community/langchain_community/llms/gpt4all.py
@@ -1,7 +1,10 @@
 from functools import partial
 from typing import Any, Dict, List, Mapping, Optional, Set
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
 
@@ -211,3 +214,40 @@ def _call(
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Asynchronously call out to GPT4All's generate method.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of strings to stop generation when encountered.
+            run_manager: An async callback manager for streaming new tokens.
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                prompt = "Once upon a time, "
+                response = await model.ainvoke(prompt, n_predict=55)
+        """
+        text_callback = None
+        if run_manager:
+            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
+
+        params = {**self._default_params(), **kwargs}
+
+        text = ""
+        for token in self.client.generate(prompt, **params):
+            if text_callback:
+                await text_callback(token)
+            text += token
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+        return text
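
For reviewers, here is a minimal, hypothetical sketch of how the new async path gets exercised end to end. It is not part of the diff: it assumes the `gpt4all` and `langchain-community` packages are installed, the model path is a placeholder for a real local model file, and `PrintTokenHandler` is an illustrative callback class, not something this change introduces.

```python
# Hypothetical usage sketch for the new _acall path; not part of this diff.
# Assumes `pip install gpt4all langchain-community` and a real model file at
# the placeholder path below.
import asyncio

from langchain_community.llms import GPT4All
from langchain_core.callbacks import AsyncCallbackHandler


class PrintTokenHandler(AsyncCallbackHandler):
    """Illustrative handler: prints each token as _acall streams it."""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


async def main() -> None:
    llm = GPT4All(model="./models/local-model.bin")  # placeholder path
    # ainvoke routes through BaseLLM.agenerate to the new _acall; the async
    # run_manager forwards each generated token to on_llm_new_token above.
    text = await llm.ainvoke(
        "Once upon a time, ",
        config={"callbacks": [PrintTokenHandler()]},
    )
    print("\n---\n", text)


asyncio.run(main())
```

One design note worth flagging: `self.client.generate(...)` is a synchronous generator, so a long generation in `_acall` only yields to the event loop at each `await text_callback(token)`. If that matters for a workload, offloading the iteration (e.g., via `asyncio.to_thread` or `loop.run_in_executor`) would be one option, though that is beyond the scope of this diff.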