From ebc78f23b279341570339fc66d53e5f9f4c61cc6 Mon Sep 17 00:00:00 2001
From: Rue Tokino <79139779+ruecat@users.noreply.github.com>
Date: Sun, 12 May 2024 22:57:35 +0300
Subject: [PATCH] Markdown & Stream regression

---
 bot/func/{functions.py => interactions.py} |  50 ++---
 bot/run.py                                 | 230 +++++++++------------
 2 files changed, 113 insertions(+), 167 deletions(-)
 rename bot/func/{functions.py => interactions.py} (79%)

diff --git a/bot/func/functions.py b/bot/func/interactions.py
similarity index 79%
rename from bot/func/functions.py
rename to bot/func/interactions.py
index 4518f07..92cc3e6 100644
--- a/bot/func/functions.py
+++ b/bot/func/interactions.py
@@ -1,3 +1,4 @@
+# >> interactions
 import logging
 import os
 import aiohttp
@@ -6,31 +7,19 @@
 from asyncio import Lock
 from functools import wraps
 from dotenv import load_dotenv
-# --- Environment
 load_dotenv()
-# --- Environment Checker
 token = os.getenv("TOKEN")
 allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
 admin_ids = list(map(int, os.getenv("ADMIN_IDS", "").split(",")))
 ollama_base_url = os.getenv("OLLAMA_BASE_URL")
 ollama_port = os.getenv("OLLAMA_PORT", "11434")
 log_level_str = os.getenv("LOG_LEVEL", "INFO")
-
-# --- Other
 log_levels = list(logging._levelToName.values())
-# ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']
-
-# Set default level to be INFO
 if log_level_str not in log_levels:
     log_level = logging.DEBUG
 else:
     log_level = logging.getLevelName(log_level_str)
-
 logging.basicConfig(level=log_level)
-
-
-# Ollama API
-# Model List
 async def model_list():
     async with aiohttp.ClientSession() as session:
         url = f"http://{ollama_base_url}:{ollama_port}/api/tags"
@@ -41,20 +30,27 @@ async def model_list():
             else:
                 return []
 async def generate(payload: dict, modelname: str, prompt: str):
-    # try:
     async with aiohttp.ClientSession() as session:
         url = f"http://{ollama_base_url}:{ollama_port}/api/chat"
-        # Stream from API
-        async with session.post(url, json=payload) as response:
-            async for chunk in response.content:
-                if chunk:
-                    decoded_chunk = chunk.decode()
-                    if decoded_chunk.strip():
-                        yield json.loads(decoded_chunk)
+        try:
+            async with session.post(url, json=payload) as response:
+                if response.status != 200:
+                    raise aiohttp.ClientResponseError(
+                        response.request_info,
+                        response.history,
+                        status=response.status,
+                        message=response.reason,
+                    )
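+                # Ollama's /api/chat streams newline-delimited JSON, and
+                # iter_any() may split or merge lines at arbitrary byte
+                # boundaries, so buffer bytes and parse only complete lines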
+                buffer = b""
+
+                async for chunk in response.content.iter_any():
+                    buffer += chunk
+                    while b"\n" in buffer:
+                        line, buffer = buffer.split(b"\n", 1)
+                        line = line.strip()
+                        if line:
+                            yield json.loads(line)
+        except aiohttp.ClientError as e:
+            logging.error(f"Error during request: {e}")
-
-# Aiogram functions & wraps
 def perms_allowed(func):
     @wraps(func)
     async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
@@ -103,16 +99,6 @@ async def wrapper(message: types.Message = None, query: types.CallbackQuery = No
         )
     return wrapper
-
-
-def md_autofixer(text: str) -> str:
-    # In MarkdownV2, these characters must be escaped: _ * [ ] ( ) ~ ` > # + - = | { } . !
-    escape_chars = r"_[]()~>#+-=|{}.!"
-    # Use a backslash to escape special characters
-    return "".join("\\" + char if char in escape_chars else char for char in text)
-
-
-# Context-Related
 class contextLock:
     lock = Lock()

diff --git a/bot/run.py b/bot/run.py
index 53ee099..5073d16 100644
--- a/bot/run.py
+++ b/bot/run.py
@@ -3,9 +3,7 @@
 from aiogram.filters.command import Command, CommandStart
 from aiogram.types import Message
 from aiogram.utils.keyboard import InlineKeyboardBuilder
-from func.functions import *
-
-# Other
+from func.interactions import *
 import asyncio
 import traceback
 import io
@@ -13,10 +11,15 @@
 bot = Bot(token=token)
 dp = Dispatcher()
-builder = InlineKeyboardBuilder()
-builder.row(
-    types.InlineKeyboardButton(text="ℹ️ About", callback_data="info"),
-    types.InlineKeyboardButton(text="⚙️ Settings", callback_data="modelmanager"),
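+# Keyboard split: start_kb is shown on /start, settings_kb opens from the
+# new ⚙️ Settings button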
+start_kb = InlineKeyboardBuilder()
+settings_kb = InlineKeyboardBuilder()
+start_kb.row(
+    types.InlineKeyboardButton(text="ℹ️ About", callback_data="about"),
+    types.InlineKeyboardButton(text="⚙️ Settings", callback_data="settings"),
+)
+settings_kb.row(
+    types.InlineKeyboardButton(text="🔄 Switch LLM", callback_data="switchllm"),
+    types.InlineKeyboardButton(text="✏️ Edit system prompt", callback_data="editsystemprompt"),
 )
 commands = [
@@ -24,14 +27,10 @@
     types.BotCommand(command="reset", description="Reset Chat"),
     types.BotCommand(command="history", description="Look through messages"),
 ]
-
-# Context variables for OllamaAPI
 ACTIVE_CHATS = {}
 ACTIVE_CHATS_LOCK = contextLock()
 modelname = os.getenv("INITMODEL")
 mention = None
-
-# Telegram group types
 CHAT_TYPE_GROUP = "group"
 CHAT_TYPE_SUPERGROUP = "supergroup"
 
@@ -41,29 +40,21 @@ def is_mentioned_in_group_or_supergroup(message):
         (message.text is not None and message.text.startswith(mention))
         or (message.caption is not None and message.caption.startswith(mention))
     )
-
-
 async def get_bot_info():
     global mention
     if mention is None:
         get = await bot.get_me()
         mention = f"@{get.username}"
     return mention
-
-
-# /start command
 @dp.message(CommandStart())
 async def command_start_handler(message: Message) -> None:
     start_message = f"Welcome, {message.from_user.full_name}!"
     await message.answer(
         start_message,
         parse_mode=ParseMode.HTML,
-        reply_markup=builder.as_markup(),
+        reply_markup=start_kb.as_markup(),
         disable_web_page_preview=True,
     )
-
-
-# /reset command, wipes context (history)
 @dp.message(Command("reset"))
 async def command_reset_handler(message: Message) -> None:
     if message.from_user.id in allowed_ids:
@@ -75,9 +66,6 @@ async def command_reset_handler(message: Message) -> None:
             chat_id=message.chat.id,
             text="Chat has been reset",
         )
-
-
-# /history command | Displays dialogs between LLM and USER
 @dp.message(Command("history"))
 async def command_get_context_handler(message: Message) -> None:
     if message.from_user.id in allowed_ids:
@@ -96,12 +84,20 @@ async def command_get_context_handler(message: Message) -> None:
             chat_id=message.chat.id,
             text="No chat history available for this user",
         )
+@dp.callback_query(lambda query: query.data == "settings")
+async def settings_callback_handler(query: types.CallbackQuery):
+    await bot.send_message(
+        chat_id=query.message.chat.id,
+        text="Choose an option.",
+        parse_mode=ParseMode.HTML,
+        disable_web_page_preview=True,
+        reply_markup=settings_kb.as_markup()
+    )
 
-
-@dp.callback_query(lambda query: query.data == "modelmanager")
-async def modelmanager_callback_handler(query: types.CallbackQuery):
+@dp.callback_query(lambda query: query.data == "switchllm")
+async def switchllm_callback_handler(query: types.CallbackQuery):
     models = await model_list()
-    modelmanager_builder = InlineKeyboardBuilder()
+    switchllm_builder = InlineKeyboardBuilder()
     for model in models:
         modelname = model["name"]
         modelfamilies = ""
@@ -112,17 +108,14 @@ async def modelmanager_callback_handler(query: types.CallbackQuery):
                 [modelicon[family] for family in model["details"]["families"]]
             )
         except KeyError as e:
-            # Use a default value when the key is not found
             modelfamilies = f"✨"
-        # Add a button for each model
-        modelmanager_builder.row(
+        switchllm_builder.row(
             types.InlineKeyboardButton(
                 text=f"{modelname} {modelfamilies}", callback_data=f"model_{modelname}"
             )
         )
     await query.message.edit_text(
-        f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal",
-        reply_markup=modelmanager_builder.as_markup(),
+        f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal",
+        reply_markup=switchllm_builder.as_markup(),
     )
 
@@ -134,20 +127,17 @@ async def model_callback_handler(query: types.CallbackQuery):
     await query.answer(f"Chosen model: {modelname}")
-@dp.callback_query(lambda query: query.data == "info")
+@dp.callback_query(lambda query: query.data == "about")
 @perms_admins
-async def info_callback_handler(query: types.CallbackQuery):
+async def about_callback_handler(query: types.CallbackQuery):
     dotenv_model = os.getenv("INITMODEL")
     global modelname
     await bot.send_message(
         chat_id=query.message.chat.id,
-        text=f"About Models\nCurrent model: {modelname}\nDefault model: {dotenv_model}\nThis project is under MIT License.\nSource Code",
+        text=f"Your LLMs\nCurrently using: {modelname}\nDefault in .env: {dotenv_model}\nThis project is under MIT License.\nSource Code",
         parse_mode=ParseMode.HTML,
         disable_web_page_preview=True,
     )
-
-
-# React on message | LLM will respond on user's message or mention in groups
 @dp.message()
 @perms_allowed
 async def handle_message(message: types.Message):
@@ -155,61 +145,85 @@
     await get_bot_info()
     if message.chat.type == "private":
         await ollama_request(message)
     if is_mentioned_in_group_or_supergroup(message):
-        # Remove the mention from the message
         if message.text is not None:
             text_without_mention = message.text.replace(mention, "").strip()
             prompt = text_without_mention
         else:
             text_without_mention = message.caption.replace(mention, "").strip()
             prompt = text_without_mention
-
-        # Pass the modified text and bot instance to ollama_request
         await ollama_request(message, prompt)
-...
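+# Helper split out of ollama_request: downloads the largest photo variant
+# and base64-encodes it for the Ollama message payload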
+async def process_image(message):
+    image_base64 = ""
+    if message.content_type == "photo":
+        image_buffer = io.BytesIO()
+        await bot.download(message.photo[-1], destination=image_buffer)
+        image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
+    return image_base64
+
+async def add_prompt_to_active_chats(message, prompt, image_base64, modelname):
+    async with ACTIVE_CHATS_LOCK:
+        if ACTIVE_CHATS.get(message.from_user.id) is None:
+            ACTIVE_CHATS[message.from_user.id] = {
+                "model": modelname,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": prompt,
+                        "images": ([image_base64] if image_base64 else []),
+                    }
+                ],
+                "stream": True,
+            }
+        else:
+            ACTIVE_CHATS[message.from_user.id]["messages"].append(
+                {
+                    "role": "user",
+                    "content": prompt,
+                    "images": ([image_base64] if image_base64 else []),
+                }
+            )
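+# Sends the reply only once Ollama reports the stream is done; returns True
+# after the final message so the caller can stop consuming the generator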
+async def handle_response(message, response_data, full_response):
+    full_response_stripped = full_response.strip()
+    if full_response_stripped == "":
+        return False
+    if response_data.get("done"):
+        text = f"{full_response_stripped}\n\n⚙️ {modelname}\nGenerated in {response_data.get('total_duration') / 1e9:.2f}s."
+        await send_response(message, text)
+        async with ACTIVE_CHATS_LOCK:
+            if ACTIVE_CHATS.get(message.from_user.id) is not None:
+                ACTIVE_CHATS[message.from_user.id]["messages"].append(
+                    {"role": "assistant", "content": full_response_stripped}
+                )
+        logging.info(
+            f"[Response]: '{full_response_stripped}' for {message.from_user.first_name} {message.from_user.last_name}"
+        )
+        return True
+    return False
+
+async def send_response(message, text):
+    # A bot cannot edit another user's message, so reply in groups instead
+    if message.chat.id == message.from_user.id:
+        await bot.send_message(chat_id=message.chat.id, text=text)
+    else:
+        await bot.send_message(
+            chat_id=message.chat.id,
+            text=text,
+            reply_to_message_id=message.message_id,
+        )
 
 async def ollama_request(message: types.Message, prompt: str = None):
     try:
+        full_response = ""
         await bot.send_chat_action(message.chat.id, "typing")
-        image_base64 = ""
-        if message.content_type == "photo":
-            image_buffer = io.BytesIO()
-            await bot.download(message.photo[-1], destination=image_buffer)
-            image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
-
+        image_base64 = await process_image(message)
         if prompt is None:
             prompt = message.text or message.caption
-        full_response = ""
-        sent_message = None
-        last_sent_text = None
-
-        async with ACTIVE_CHATS_LOCK:
-            # Add prompt to active chats object
-            if ACTIVE_CHATS.get(message.from_user.id) is None:
-                ACTIVE_CHATS[message.from_user.id] = {
-                    "model": modelname,
-                    "messages": [
-                        {
-                            "role": "user",
-                            "content": prompt,
-                            "images": ([image_base64] if image_base64 else []),
-                        }
-                    ],
-                    "stream": True,
-                }
-            else:
-                ACTIVE_CHATS[message.from_user.id]["messages"].append(
-                    {
-                        "role": "user",
-                        "content": prompt,
-                        "images": ([image_base64] if image_base64 else []),
-                    }
-                )
+        await add_prompt_to_active_chats(message, prompt, image_base64, modelname)
         logging.info(
-            f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
+            f"[OllamaAPI]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
         )
         payload = ACTIVE_CHATS.get(message.from_user.id)
         async for response_data in generate(payload, modelname, prompt):
@@ -218,71 +232,17 @@
                 continue
             chunk = msg.get("content", "")
             full_response += chunk
-            full_response_stripped = full_response.strip()
-
-            # avoid Bad Request: message text is empty
-            if full_response_stripped == "":
-                continue
-
-            if "." in chunk or "\n" in chunk or "!" in chunk or "?" in chunk:
-                if sent_message:
-                    if last_sent_text != full_response_stripped:
-                        await bot.edit_message_text(
-                            chat_id=message.chat.id,
-                            message_id=sent_message.message_id,
-                            text=full_response_stripped,
-                        )
-                        last_sent_text = full_response_stripped
-                else:
-                    sent_message = await bot.send_message(
-                        chat_id=message.chat.id,
-                        text=full_response_stripped,
-                        reply_to_message_id=message.message_id,
-                    )
-                    last_sent_text = full_response_stripped
-
-            if response_data.get("done"):
-                if full_response_stripped and last_sent_text != full_response_stripped:
-                    if sent_message:
-                        await bot.edit_message_text(
-                            chat_id=message.chat.id,
-                            message_id=sent_message.message_id,
-                            text=full_response_stripped,
-                        )
-                    else:
-                        sent_message = await bot.send_message(
-                            chat_id=message.chat.id, text=full_response_stripped
-                        )
-                await bot.edit_message_text(
-                    chat_id=message.chat.id,
-                    message_id=sent_message.message_id,
-                    text=md_autofixer(
-                        full_response_stripped
-                        + f"\n\nCurrent Model: `{modelname}`**\n**Generated in {response_data.get('total_duration') / 1e9:.2f}s"
-                    ),
-                    parse_mode=ParseMode.MARKDOWN_V2,
-                )
-                async with ACTIVE_CHATS_LOCK:
-                    if ACTIVE_CHATS.get(message.from_user.id) is not None:
-                        # Add response to active chats object
-                        ACTIVE_CHATS[message.from_user.id]["messages"].append(
-                            {"role": "assistant", "content": full_response_stripped}
-                        )
-                        logging.info(
-                            f"[Response]: '{full_response_stripped}' for {message.from_user.first_name} {message.from_user.last_name}"
-                        )
-                    else:
-                        await bot.send_message(
-                            chat_id=message.chat.id, text="Chat was reset"
-                        )
+            if any(c in chunk for c in ".\n!?") or response_data.get("done"):
+                if await handle_response(message, response_data, full_response):
+                    break
-                break
     except Exception as e:
+        logging.error(f"-----\n[OllamaAPI-ERR] CAUGHT FAULT!\n{traceback.format_exc()}\n-----")
         await bot.send_message(
             chat_id=message.chat.id,
-            text=f"""Error occurred\n```\n{traceback.format_exc()}\n```""",
-            parse_mode=ParseMode.MARKDOWN_V2,
+            text="Something went wrong.",
+            parse_mode=ParseMode.HTML,
         )