revert dep change, update deltachat_chatbot/gpt4all.py
adbenitez committed Sep 4, 2024
1 parent ae46ea0 commit b2f21b2
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions deltachat_chatbot/gpt4all.py
@@ -9,7 +9,6 @@
from gpt4all._pyllmodel import (
LLModel,
PromptCallback,
RecalculateCallback,
ResponseCallback,
ResponseCallbackType,
empty_response_callback,
@@ -76,7 +75,7 @@ def prompt_model(
ctypes.c_char_p(prompt_template.encode()),
PromptCallback(self._prompt_callback),
ResponseCallback(self._callback_decoder(callback)),
RecalculateCallback(self._recalculate_callback),
True,
self.context,
special,
ctypes.c_char_p(fake_reply.encode()) if fake_reply else ctypes.c_char_p(),
@@ -104,14 +103,15 @@ def generate( # noqa
repeat_last_n: int = 64,
n_batch: int = 8,
n_predict: int | None = None,
streaming: bool = False,
fake_reply: str = "",
callback: ResponseCallbackType = empty_response_callback,
fake_reply="",
) -> Any:
"""
Generate outputs from any GPT4All model.
Args:
prompt: The prompt for the model the complete.
prompt: The prompt for the model to complete.
max_tokens: The maximum number of tokens to generate.
temp: The model temperature. Larger values increase creativity but decrease factuality.
top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.
@@ -121,6 +121,8 @@
repeat_last_n: How far in the models generation history to apply the repeat penalty.
n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
n_predict: Equivalent to max_tokens, exists for backwards compatibility.
streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
fake_reply: A spoofed reply for the given prompt, used as a way to load chat history.
callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.
Returns:
@@ -199,6 +201,15 @@ def _callback(token_id: int, response: str) -> bool:

return _callback

# Send the request to the model
if streaming:
return self.model.prompt_model_streaming(
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)

self.model.prompt_model( # noqa
prompt,
prompt_template,
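For reference, the streaming flag added above mirrors the upstream gpt4all Python bindings, so the patched generate() is driven the same way as the stock wrapper. Below is a minimal usage sketch against the upstream GPT4All class; the model file name and prompt are illustrative, and the vendored class in deltachat_chatbot may expose a different entry point.

# Minimal sketch, assuming the upstream gpt4all bindings are installed and a
# local GGUF model is available; names below are illustrative.
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b-gguf2-q4_0.gguf")

# Default path: generate() collects the whole completion and returns it as a string.
reply = model.generate("Say hello in one short sentence.", max_tokens=32)
print(reply)

# Streaming path: with streaming=True, generate() instead returns a generator
# that yields tokens as the model produces them.
for token in model.generate("Say hello in one short sentence.", max_tokens=32, streaming=True):
    print(token, end="", flush=True)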
