From ebc78f23b279341570339fc66d53e5f9f4c61cc6 Mon Sep 17 00:00:00 2001
From: Rue Tokino <79139779+ruecat@users.noreply.github.com>
Date: Sun, 12 May 2024 22:57:35 +0300
Subject: [PATCH] Markdown & Stream regression
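
The stream handler decoded every raw chunk as a complete JSON object,
which breaks on partial reads, and the MarkdownV2 post-processing could
produce malformed final messages. Buffer the stream and split it on
newlines before parsing, send replies as plain text, rename
bot/func/functions.py to bot/func/interactions.py, split the start and
settings keyboards, and factor ollama_request() into focused helpers
(process_image, add_prompt_to_active_chats, handle_response,
send_response).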
---
bot/func/{functions.py => interactions.py} | 50 ++---
bot/run.py | 230 +++++++++------------
2 files changed, 113 insertions(+), 167 deletions(-)
rename bot/func/{functions.py => interactions.py} (79%)
diff --git a/bot/func/functions.py b/bot/func/interactions.py
similarity index 79%
rename from bot/func/functions.py
rename to bot/func/interactions.py
index 4518f07..92cc3e6 100644
--- a/bot/func/functions.py
+++ b/bot/func/interactions.py
@@ -1,3 +1,4 @@
+# >> interactions
import logging
import os
import aiohttp
@@ -6,31 +7,19 @@
from asyncio import Lock
from functools import wraps
from dotenv import load_dotenv
-# --- Environment
load_dotenv()
-# --- Environment Checker
token = os.getenv("TOKEN")
allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
admin_ids = list(map(int, os.getenv("ADMIN_IDS", "").split(",")))
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
ollama_port = os.getenv("OLLAMA_PORT", "11434")
log_level_str = os.getenv("LOG_LEVEL", "INFO")
-
-# --- Other
log_levels = list(logging._levelToName.values())
-# ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']
-
-# Set default level to be INFO
if log_level_str not in log_levels:
log_level = logging.DEBUG
else:
log_level = logging.getLevelName(log_level_str)
-
logging.basicConfig(level=log_level)
-
-
-# Ollama API
-# Model List
async def model_list():
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/tags"
@@ -41,20 +30,27 @@ async def model_list():
else:
return []
async def generate(payload: dict, modelname: str, prompt: str):
- # try:
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/chat"
- # Stream from API
- async with session.post(url, json=payload) as response:
- async for chunk in response.content:
- if chunk:
- decoded_chunk = chunk.decode()
- if decoded_chunk.strip():
- yield json.loads(decoded_chunk)
+ try:
+ async with session.post(url, json=payload) as response:
+ if response.status != 200:
+                raise aiohttp.ClientResponseError(
+                    response.request_info,
+                    response.history,
+                    status=response.status,
+                    message=response.reason,
+                )
+ buffer = b""
+
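+            # Ollama streams newline-delimited JSON: accumulate raw chunks in a
+            # buffer and only json.loads() complete lines.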
+ async for chunk in response.content.iter_any():
+ buffer += chunk
+ while b"\n" in buffer:
+ line, buffer = buffer.split(b"\n", 1)
+ line = line.strip()
+ if line:
+ yield json.loads(line)
+ except aiohttp.ClientError as e:
+        logging.error(f"Error during request: {e}")
-
-# Aiogram functions & wraps
def perms_allowed(func):
@wraps(func)
async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
@@ -103,16 +99,6 @@ async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
)
return wrapper
-
-
-def md_autofixer(text: str) -> str:
- # In MarkdownV2, these characters must be escaped: _ * [ ] ( ) ~ ` > # + - = | { } . !
- escape_chars = r"_[]()~>#+-=|{}.!"
- # Use a backslash to escape special characters
- return "".join("\\" + char if char in escape_chars else char for char in text)
-
-
-# Context-Related
class contextLock:
lock = Lock()
diff --git a/bot/run.py b/bot/run.py
index 53ee099..5073d16 100644
--- a/bot/run.py
+++ b/bot/run.py
@@ -3,9 +3,7 @@
from aiogram.filters.command import Command, CommandStart
from aiogram.types import Message
from aiogram.utils.keyboard import InlineKeyboardBuilder
-from func.functions import *
-
-# Other
+from func.interactions import *
import asyncio
import traceback
import io
@@ -13,10 +11,15 @@
bot = Bot(token=token)
dp = Dispatcher()
-builder = InlineKeyboardBuilder()
-builder.row(
- types.InlineKeyboardButton(text="ℹ️ About", callback_data="info"),
- types.InlineKeyboardButton(text="⚙️ Settings", callback_data="modelmanager"),
+start_kb = InlineKeyboardBuilder()
+settings_kb = InlineKeyboardBuilder()
+start_kb.row(
+ types.InlineKeyboardButton(text="ℹ️ About", callback_data="about"),
+ types.InlineKeyboardButton(text="⚙️ Settings", callback_data="settings"),
+)
+settings_kb.row(
+ types.InlineKeyboardButton(text="🔄 Switch LLM", callback_data="switchllm"),
+ types.InlineKeyboardButton(text="✏️ Edit system prompt", callback_data="editsystemprompt"),
)
commands = [
@@ -24,14 +27,10 @@
types.BotCommand(command="reset", description="Reset Chat"),
types.BotCommand(command="history", description="Look through messages"),
]
-
-# Context variables for OllamaAPI
ACTIVE_CHATS = {}
ACTIVE_CHATS_LOCK = contextLock()
modelname = os.getenv("INITMODEL")
mention = None
-
-# Telegram group types
CHAT_TYPE_GROUP = "group"
CHAT_TYPE_SUPERGROUP = "supergroup"
@@ -41,29 +40,21 @@ def is_mentioned_in_group_or_supergroup(message):
(message.text is not None and message.text.startswith(mention))
or (message.caption is not None and message.caption.startswith(mention))
)
-
-
async def get_bot_info():
global mention
if mention is None:
get = await bot.get_me()
mention = f"@{get.username}"
return mention
-
-
-# /start command
@dp.message(CommandStart())
async def command_start_handler(message: Message) -> None:
start_message = f"Welcome, {message.from_user.full_name}!"
await message.answer(
start_message,
parse_mode=ParseMode.HTML,
- reply_markup=builder.as_markup(),
+ reply_markup=start_kb.as_markup(),
disable_web_page_preview=True,
)
-
-
-# /reset command, wipes context (history)
@dp.message(Command("reset"))
async def command_reset_handler(message: Message) -> None:
if message.from_user.id in allowed_ids:
@@ -75,9 +66,6 @@ async def command_reset_handler(message: Message) -> None:
chat_id=message.chat.id,
text="Chat has been reset",
)
-
-
-# /history command | Displays dialogs between LLM and USER
@dp.message(Command("history"))
async def command_get_context_handler(message: Message) -> None:
if message.from_user.id in allowed_ids:
@@ -96,12 +84,20 @@ async def command_get_context_handler(message: Message) -> None:
chat_id=message.chat.id,
text="No chat history available for this user",
)
+@dp.callback_query(lambda query: query.data == "settings")
+async def settings_callback_handler(query: types.CallbackQuery):
+ await bot.send_message(
+ chat_id=query.message.chat.id,
+        text="Choose an option:",
+ parse_mode=ParseMode.HTML,
+ disable_web_page_preview=True,
+ reply_markup=settings_kb.as_markup()
+ )
-
-@dp.callback_query(lambda query: query.data == "modelmanager")
-async def modelmanager_callback_handler(query: types.CallbackQuery):
+@dp.callback_query(lambda query: query.data == "switchllm")
+async def switchllm_callback_handler(query: types.CallbackQuery):
models = await model_list()
- modelmanager_builder = InlineKeyboardBuilder()
+ switchllm_builder = InlineKeyboardBuilder()
for model in models:
modelname = model["name"]
modelfamilies = ""
@@ -112,17 +108,14 @@ async def modelmanager_callback_handler(query: types.CallbackQuery):
[modelicon[family] for family in model["details"]["families"]]
)
except KeyError as e:
- # Use a default value when the key is not found
modelfamilies = f"✨"
- # Add a button for each model
- modelmanager_builder.row(
+ switchllm_builder.row(
types.InlineKeyboardButton(
text=f"{modelname} {modelfamilies}", callback_data=f"model_{modelname}"
)
)
await query.message.edit_text(
- f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal",
- reply_markup=modelmanager_builder.as_markup(),
+        f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal",
+        reply_markup=switchllm_builder.as_markup(),
)
@@ -134,20 +127,17 @@ async def model_callback_handler(query: types.CallbackQuery):
await query.answer(f"Chosen model: {modelname}")
-@dp.callback_query(lambda query: query.data == "info")
+@dp.callback_query(lambda query: query.data == "about")
@perms_admins
-async def info_callback_handler(query: types.CallbackQuery):
+async def about_callback_handler(query: types.CallbackQuery):
dotenv_model = os.getenv("INITMODEL")
global modelname
await bot.send_message(
chat_id=query.message.chat.id,
-        text=f"About Models\nCurrent model: {modelname}\nDefault model: {dotenv_model}\nThis project is under MIT License.\nSource Code",
+        text=f"Your LLMs\nCurrently using: {modelname}\nDefault in .env: {dotenv_model}\nThis project is under MIT License.\nSource Code",
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)
-
-
-# React on message | LLM will respond on user's message or mention in groups
@dp.message()
@perms_allowed
async def handle_message(message: types.Message):
@@ -155,61 +145,85 @@ async def handle_message(message: types.Message):
if message.chat.type == "private":
await ollama_request(message)
if is_mentioned_in_group_or_supergroup(message):
- # Remove the mention from the message
if message.text is not None:
text_without_mention = message.text.replace(mention, "").strip()
prompt = text_without_mention
else:
text_without_mention = message.caption.replace(mention, "").strip()
prompt = text_without_mention
-
- # Pass the modified text and bot instance to ollama_request
await ollama_request(message, prompt)
-...
+async def process_image(message):
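+    # Download the largest photo size and return it base64-encoded for the
+    # Ollama payload; returns "" when the message carries no photo.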
+ image_base64 = ""
+ if message.content_type == "photo":
+ image_buffer = io.BytesIO()
+ await bot.download(message.photo[-1], destination=image_buffer)
+ image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
+ return image_base64
+
+async def add_prompt_to_active_chats(message, prompt, image_base64, modelname):
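+    # Create a fresh per-user context on first contact; otherwise append the
+    # new user turn (with any attached image) to the running history.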
+ async with ACTIVE_CHATS_LOCK:
+ if ACTIVE_CHATS.get(message.from_user.id) is None:
+ ACTIVE_CHATS[message.from_user.id] = {
+ "model": modelname,
+ "messages": [
+ {
+ "role": "user",
+ "content": prompt,
+ "images": ([image_base64] if image_base64 else []),
+ }
+ ],
+ "stream": True,
+ }
+ else:
+ ACTIVE_CHATS[message.from_user.id]["messages"].append(
+ {
+ "role": "user",
+ "content": prompt,
+ "images": ([image_base64] if image_base64 else []),
+ }
+ )
+
+async def handle_response(message, response_data, full_response):
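+    # Ignore empty partials; on the final chunk, send the reply with model and
+    # timing info and record the assistant turn in ACTIVE_CHATS.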
+ full_response_stripped = full_response.strip()
+ if full_response_stripped == "":
+ return
+ if response_data.get("done"):
+ text = f"{full_response_stripped}\n\n⚙️ {modelname}\nGenerated in {response_data.get('total_duration') / 1e9:.2f}s."
+ await send_response(message, text)
+ async with ACTIVE_CHATS_LOCK:
+ if ACTIVE_CHATS.get(message.from_user.id) is not None:
+ ACTIVE_CHATS[message.from_user.id]["messages"].append(
+ {"role": "assistant", "content": full_response_stripped}
+ )
+ logging.info(
+ f"[Response]: '{full_response_stripped}' for {message.from_user.first_name} {message.from_user.last_name}"
+ )
+ return True
+ return False
+
+async def send_response(message, text):
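+    # Send directly in private chats; in groups, reply to the triggering
+    # message (bots cannot edit a user's message).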
+ if message.chat.id == message.from_user.id:
+ await bot.send_message(chat_id=message.chat.id, text=text)
+    else:
+        await bot.send_message(
+            chat_id=message.chat.id,
+            text=text,
+            reply_to_message_id=message.message_id,
+        )
async def ollama_request(message: types.Message, prompt: str = None):
try:
+ full_response = ""
await bot.send_chat_action(message.chat.id, "typing")
- image_base64 = ""
- if message.content_type == "photo":
- image_buffer = io.BytesIO()
- await bot.download(message.photo[-1], destination=image_buffer)
- image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
-
+ image_base64 = await process_image(message)
if prompt is None:
prompt = message.text or message.caption
- full_response = ""
- sent_message = None
- last_sent_text = None
-
- async with ACTIVE_CHATS_LOCK:
- # Add prompt to active chats object
- if ACTIVE_CHATS.get(message.from_user.id) is None:
- ACTIVE_CHATS[message.from_user.id] = {
- "model": modelname,
- "messages": [
- {
- "role": "user",
- "content": prompt,
- "images": ([image_base64] if image_base64 else []),
- }
- ],
- "stream": True,
- }
- else:
- ACTIVE_CHATS[message.from_user.id]["messages"].append(
- {
- "role": "user",
- "content": prompt,
- "images": ([image_base64] if image_base64 else []),
- }
- )
+ await add_prompt_to_active_chats(message, prompt, image_base64, modelname)
logging.info(
- f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
+ f"[OllamaAPI]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
)
payload = ACTIVE_CHATS.get(message.from_user.id)
async for response_data in generate(payload, modelname, prompt):
@@ -218,71 +232,17 @@ async def ollama_request(message: types.Message, prompt: str = None):
continue
chunk = msg.get("content", "")
full_response += chunk
- full_response_stripped = full_response.strip()
-
- # avoid Bad Request: message text is empty
- if full_response_stripped == "":
- continue
-
- if "." in chunk or "\n" in chunk or "!" in chunk or "?" in chunk:
- if sent_message:
- if last_sent_text != full_response_stripped:
- await bot.edit_message_text(
- chat_id=message.chat.id,
- message_id=sent_message.message_id,
- text=full_response_stripped,
- )
- last_sent_text = full_response_stripped
- else:
- sent_message = await bot.send_message(
- chat_id=message.chat.id,
- text=full_response_stripped,
- reply_to_message_id=message.message_id,
- )
- last_sent_text = full_response_stripped
-
- if response_data.get("done"):
- if full_response_stripped and last_sent_text != full_response_stripped:
- if sent_message:
- await bot.edit_message_text(
- chat_id=message.chat.id,
- message_id=sent_message.message_id,
- text=full_response_stripped,
- )
- else:
- sent_message = await bot.send_message(
- chat_id=message.chat.id, text=full_response_stripped
- )
- await bot.edit_message_text(
- chat_id=message.chat.id,
- message_id=sent_message.message_id,
- text=md_autofixer(
- full_response_stripped
- + f"\n\nCurrent Model: `{modelname}`**\n**Generated in {response_data.get('total_duration') / 1e9:.2f}s"
- ),
- parse_mode=ParseMode.MARKDOWN_V2,
- )
- async with ACTIVE_CHATS_LOCK:
- if ACTIVE_CHATS.get(message.from_user.id) is not None:
- # Add response to active chats object
- ACTIVE_CHATS[message.from_user.id]["messages"].append(
- {"role": "assistant", "content": full_response_stripped}
- )
- logging.info(
- f"[Response]: '{full_response_stripped}' for {message.from_user.first_name} {message.from_user.last_name}"
- )
- else:
- await bot.send_message(
- chat_id=message.chat.id, text="Chat was reset"
- )
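+            # Delegate on sentence boundaries or the final chunk;
+            # handle_response() returns True once the reply has been sent.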
+            if any(c in chunk for c in ".\n!?") or response_data.get("done"):
+ if await handle_response(message, response_data, full_response):
+ break
- break
except Exception as e:
+        logging.error(f"[OllamaAPI-ERR] CAUGHT FAULT!\n{traceback.format_exc()}")
await bot.send_message(
chat_id=message.chat.id,
- text=f"""Error occurred\n```\n{traceback.format_exc()}\n```""",
- parse_mode=ParseMode.MARKDOWN_V2,
+        text="Something went wrong.",
+ parse_mode=ParseMode.HTML,
)