diff --git a/deltachat_chatbot/gpt4all.py b/deltachat_chatbot/gpt4all.py
index 6ca2608..0a0027e 100644
--- a/deltachat_chatbot/gpt4all.py
+++ b/deltachat_chatbot/gpt4all.py
@@ -9,7 +9,6 @@
 from gpt4all._pyllmodel import (
     LLModel,
     PromptCallback,
-    RecalculateCallback,
     ResponseCallback,
     ResponseCallbackType,
     empty_response_callback,
@@ -76,7 +75,7 @@ def prompt_model(
             ctypes.c_char_p(prompt_template.encode()),
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._callback_decoder(callback)),
-            RecalculateCallback(self._recalculate_callback),
+            True,
             self.context,
             special,
             ctypes.c_char_p(fake_reply.encode()) if fake_reply else ctypes.c_char_p(),
@@ -104,14 +103,15 @@ def generate(  # noqa
         repeat_last_n: int = 64,
         n_batch: int = 8,
         n_predict: int | None = None,
+        streaming: bool = False,
+        fake_reply: str = "",
         callback: ResponseCallbackType = empty_response_callback,
-        fake_reply="",
     ) -> Any:
         """
         Generate outputs from any GPT4All model.
 
         Args:
-            prompt: The prompt for the model the complete.
+            prompt: The prompt for the model to complete.
             max_tokens: The maximum number of tokens to generate.
             temp: The model temperature. Larger values increase creativity but decrease factuality.
             top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.
@@ -121,6 +121,8 @@ def generate(  # noqa
             repeat_last_n: How far in the models generation history to apply the repeat penalty.
             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
             n_predict: Equivalent to max_tokens, exists for backwards compatibility.
+            streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
+            fake_reply: A spoofed reply for the given prompt, used as a way to load chat history.
             callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.
 
         Returns:
@@ -199,6 +201,15 @@ def _callback(token_id: int, response: str) -> bool:
 
             return _callback
 
+        # Send the request to the model
+        if streaming:
+            return self.model.prompt_model_streaming(
+                prompt,
+                prompt_template,
+                _callback_wrapper(callback, output_collector),
+                **generate_kwargs,
+            )
+
         self.model.prompt_model(  # noqa
             prompt,
             prompt_template,
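
A minimal usage sketch of the two parameters this patch adds to generate(). The class name GPT4All, the import path, and the model file name below are assumptions for illustration; the patched module may expose a different wrapper, but the streaming and fake_reply semantics shown match the docstring added in the diff.

# Sketch only: assumes the patched module exposes a GPT4All-like class whose
# constructor takes a local model file, mirroring the upstream gpt4all API.
from deltachat_chatbot.gpt4all import GPT4All  # hypothetical import path

model = GPT4All("orca-mini-3b-gguf2-q4_0.gguf")  # any local GGUF model file

# Non-streaming call: fake_reply seeds the prompt with a spoofed prior answer,
# which is how chat history can be replayed into the model.
model.generate("What is Delta Chat?", fake_reply="Delta Chat is a messenger.")

# Streaming call: generate() returns a generator that yields tokens as the
# model produces them instead of a single completed string.
for token in model.generate("Summarize the previous answer.", streaming=True):
    print(token, end="", flush=True)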