Skip to content

Commit

Permalink
stopping generation at the first ending token
Browse files Browse the repository at this point in the history
  • Loading branch information
fractalego committed Dec 6, 2023
1 parent 883a1e3 commit 813321c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions wafl/connectors/local/local_llm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,14 @@ def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
num_ending_tokens = 0
max_endings = input_ids.shape[0]
for token_ids in input_ids:
generated_text = self._tokenizer.decode(token_ids)
for last_string in self._last_strings:
if generated_text.endswith(last_string):
num_ending_tokens += 1
break

if num_ending_tokens >= max_endings:
if num_ending_tokens >= 1:
return True

return False
Expand Down
4 changes: 2 additions & 2 deletions wafl/connectors/remote/remote_llm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def predict(self, prompt: str, temperature=None, num_tokens=None) -> str:
) as session:
async with session.post(self._server_url, json=payload) as response:
answer = await response.text()
return select_best_answer(answer.split("<||>"))
return select_best_answer(answer.split("<||>"), self._last_strings)

return "UNKNOWN"

Expand Down Expand Up @@ -93,7 +93,7 @@ async def generate(self, prompt: str) -> str:
return candidate_answer

async def check_connection(self):
payload = {"data": "test", "temperature": 0.6, "num_tokens": 1}
payload = {"data": "test", "temperature": 0.6, "num_tokens": 1, "num_replicas": 3}
try:
async with aiohttp.ClientSession(
conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
Expand Down
4 changes: 2 additions & 2 deletions wafl/connectors/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
def select_best_answer(answers):
special_words = ["</remember>", "</execute>", "result ="]
def select_best_answer(answers, last_strings):
special_words = last_strings + ["</remember>", "</execute>", "result ="]
return sorted(answers, key=lambda x: sum([x.count(word) for word in special_words]))[-1]

0 comments on commit 813321c

Please sign in to comment.