Merge pull request #72 from fractalego/llm-webserving
Llm webserving
fractalego authored Dec 6, 2023
2 parents ccb83d5 + 813321c commit fc7abb8
Showing 5 changed files with 44 additions and 9 deletions.
7 changes: 6 additions & 1 deletion todo.txt
@@ -1,13 +1,18 @@
### TODO

* only one rule at a time!!
* if a rule is executed, it is then consumed

* bug: the system kept executing "The bot predicts:"

**** what to do with conversational collapse?
- the system just repeats the last utterance
- how much is 2+2, what is the real name of bon jovi, how tall is mt everest
- the collapse is due to <execute>NUMBER</execute> returning unknown (execute becomes more likely after one prior <execute>)
- the system is also more likely to return unknown after one unknown. Select the answer that has no unknowns?

* solve math expression execute (imports do not work in eval, and exec needs a print on stdout; see the sketch below)

* add error reporting when loading the config file (log to stderr)
* add a memory that the execute command was called/not called.

* no more than one rule (with two rules it already gets confused)
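The "solve math expression execute" item points at two Python quirks: import statements are not allowed inside eval(), and exec() returns nothing, so its output has to be read from stdout. A minimal sketch of one possible workaround, using a hypothetical helper that is not part of this commit; the math module is passed in explicitly and printed output is captured:

```python
import contextlib
import io
import math


def run_math_expression(code: str) -> str:
    # hypothetical helper, not part of this commit
    try:
        # eval() handles bare expressions; math is injected via the
        # globals dict because imports cannot happen inside eval()
        return str(eval(code, {"math": math}, {}))
    except SyntaxError:
        # statements need exec(); exec() returns None, so capture
        # whatever the snippet prints to stdout instead
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            exec(code, {"math": math}, {})
        return buffer.getvalue().strip()


print(run_math_expression("math.sqrt(2) + 2"))  # 3.4142135623730951
print(run_math_expression("print(2 + 2)"))      # 4
```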
26 changes: 25 additions & 1 deletion wafl/answerer/dialogue_answerer.py
@@ -11,6 +11,23 @@
from wafl.simple_text_processing.deixis import from_user_to_bot
from wafl.simple_text_processing.questions import is_question


def get_last_bot_utterance(dialogue_items):
for item in reversed(dialogue_items):
if item[1].startswith("bot:"):
return item[1]

return ""


def get_last_user_utterance(dialogue_items):
for item in reversed(dialogue_items):
if item[1].startswith("user:"):
return item[1]

return ""


class DialogueAnswerer(BaseAnswerer):
def __init__(self, config, knowledge, interface, code_path, logger):
self._bridge = LLMChitChatAnswerBridge(config)
@@ -23,6 +40,7 @@ def __init__(self, config, knowledge, interface, code_path, logger):
self._prior_facts = []
self._prior_rules = []
self._init_python_module(code_path.replace(".py", ""))
self._max_predictions = 3

async def answer(self, query_text):
print(__name__)
@@ -45,10 +63,12 @@ async def answer(self, query_text):

dialogue_items = dialogue
dialogue_items = sorted(dialogue_items, key=lambda x: x[0])
last_bot_utterance = get_last_bot_utterance(dialogue_items)
last_user_utterance = get_last_user_utterance(dialogue_items)
dialogue_items = [item[1] for item in dialogue_items if item[0] >= start_time]
dialogue_items = "\n".join(dialogue_items)

while True:
for _ in range(self._max_predictions):
original_answer_text = await self._bridge.get_answer(
text=facts,
dialogue=dialogue_items,
@@ -58,6 +78,10 @@
answer_text, memories = await self._substitute_memory_in_answer_and_get_memories_if_present(
await self._substitute_results_in_answer(original_answer_text)
)
if answer_text == last_bot_utterance:
dialogue_items = last_user_utterance
continue

if not memories:
break

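The two new helpers scan the dialogue from most recent to oldest and return the latest bot or user line. Together with the self._max_predictions cap they replace the unbounded while True loop: when the model reproduces its previous reply, the answerer re-prompts with only the last user utterance as context, and gives up after at most three attempts instead of looping forever. A quick usage sketch, assuming dialogue items are (timestamp, "speaker: text") tuples as the diff suggests:

```python
# assumed item shape, inferred from the diff: (timestamp, "speaker: text")
dialogue_items = [
    (1.0, "user: how much is 2+2"),
    (2.0, "bot: the answer is 4"),
    (3.0, "user: what is the real name of bon jovi"),
]

print(get_last_bot_utterance(dialogue_items))   # bot: the answer is 4
print(get_last_user_utterance(dialogue_items))  # user: what is the real name of bon jovi
```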
12 changes: 9 additions & 3 deletions wafl/connectors/local/local_llm_connector.py
@@ -129,9 +129,15 @@ def __init__(self, tokenizer: "AutoTokenizer", last_strings: List[str]):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
generated_text = self._tokenizer.decode(input_ids[0], skip_special_tokens=True)
for last_string in self._last_strings:
if generated_text.endswith(last_string):
num_ending_tokens = 0
for token_ids in input_ids:
generated_text = self._tokenizer.decode(token_ids)
for last_string in self._last_strings:
if generated_text.endswith(last_string):
num_ending_tokens += 1
break

if num_ending_tokens >= 1:
return True

return False
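The revised criterion decodes every sequence in the batch (the server now generates several replicas per request) and fires as soon as at least one of them ends with a stop string. A sketch of how such a criterion is typically plugged into Hugging Face generation; the class name, model, and stop strings here are illustrative, not taken from this repository:

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


class StopOnStrings(StoppingCriteria):
    # illustrative re-implementation of the pattern in the diff
    def __init__(self, tokenizer, last_strings):
        self._tokenizer = tokenizer
        self._last_strings = last_strings

    def __call__(self, input_ids, scores, **kwargs):
        # decode each sequence in the batch and stop once any one
        # of them ends with one of the stop strings
        for token_ids in input_ids:
            text = self._tokenizer.decode(token_ids)
            if any(text.endswith(s) for s in self._last_strings):
                return True
        return False


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("bot: ", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    num_return_sequences=3,  # several replicas, as in the new payload
    stopping_criteria=StoppingCriteriaList(
        [StopOnStrings(tokenizer, ["\nuser:", "</execute>"])]
    ),
)
```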
4 changes: 2 additions & 2 deletions wafl/connectors/remote/remote_llm_connector.py
@@ -53,7 +53,7 @@ async def predict(self, prompt: str, temperature=None, num_tokens=None) -> str:
) as session:
async with session.post(self._server_url, json=payload) as response:
answer = await response.text()
return select_best_answer(answer.split("<||>"))
return select_best_answer(answer.split("<||>"), self._last_strings)

return "UNKNOWN"

@@ -93,7 +93,7 @@ async def generate(self, prompt: str) -> str:
return candidate_answer

async def check_connection(self):
payload = {"data": "test", "temperature": 0.6, "num_tokens": 1}
payload = {"data": "test", "temperature": 0.6, "num_tokens": 1, "num_replicas": 3}
try:
async with aiohttp.ClientSession(
conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
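check_connection now sends num_replicas: 3, matching the predict path: the server returns several candidate completions joined by the "<||>" separator, and the client picks one with select_best_answer. A rough sketch of that round trip under the same payload shape; the URL and num_tokens value are placeholders, and a compatible server is assumed to be running:

```python
import asyncio

import aiohttp

from wafl.connectors.utils import select_best_answer


async def predict_once(server_url: str, prompt: str, last_strings) -> str:
    # payload mirrors the shape visible in the diff; num_tokens is a
    # placeholder value, not taken from this commit
    payload = {
        "data": prompt,
        "temperature": 0.6,
        "num_tokens": 256,
        "num_replicas": 3,
    }
    async with aiohttp.ClientSession(
        conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
    ) as session:
        async with session.post(server_url, json=payload) as response:
            answer = await response.text()
            # the server joins the replicas with "<||>"
            return select_best_answer(answer.split("<||>"), last_strings)


# usage (placeholder URL, assumes a compatible server):
# asyncio.run(predict_once("https://localhost:8080/predictions/bot", "user: hello", ["\nuser:"]))
```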
4 changes: 2 additions & 2 deletions wafl/connectors/utils.py
@@ -1,3 +1,3 @@
def select_best_answer(answers):
special_words = ["</remember>", "</execute>", "result ="]
def select_best_answer(answers, last_strings):
special_words = last_strings + ["</remember>", "</execute>", "result ="]
return sorted(answers, key=lambda x: sum([x.count(word) for word in special_words]))[-1]
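select_best_answer sorts the candidates by how many special words they contain and returns the last (highest-scoring) one, so completions that actually produced <execute>/<remember> blocks or reached a stop string beat bare refusals. A small illustration, with an assumed stop string:

```python
candidates = [
    "unknown",
    "the answer is <execute>2 + 2</execute> result = 4",
]

# the second candidate contains "</execute>" and "result =",
# so it scores 2 against 0 and is returned
print(select_best_answer(candidates, last_strings=["\nuser:"]))
```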
