Skip to content

Commit

Permalink
🔧 chore(dependencies): update dependencies in pyproject.toml
Browse files Browse the repository at this point in the history
🔨 refactor(event): remove unused AnimeIDF and related code

🚀 feat(controller): add new message handlers for various commands

♻️ refactor(controller): simplify tagger function and markdown replies
  • Loading branch information
sudoskys committed Dec 25, 2024
1 parent 5c3ce4d commit d1080ea
Show file tree
Hide file tree
Showing 4 changed files with 706 additions and 411 deletions.
211 changes: 163 additions & 48 deletions app/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import telegramify_markdown
from PIL import Image
from asgiref.sync import sync_to_async
from loguru import logger
from novelai_python.tool.image_metadata import ImageMetadata, ImageVerifier
from novelai_python.tool.random_prompt import RandomPromptGenerator
Expand All @@ -16,13 +15,16 @@
from telebot.async_telebot import AsyncTeleBot
from telebot.asyncio_helper import ApiTelegramException
from telebot.asyncio_storage import StateMemoryStorage
from telegramify_markdown import ContentTypes

from app.event import pipeline_tag
from app_conf import settings
from setting.telegrambot import BotSetting

StepCache = StateMemoryStorage()

prompt_generator = RandomPromptGenerator(nsfw_enabled=False)


def extract_between_multiple_markers(input_list, start_markers, end_markers):
extracting = False
Expand Down Expand Up @@ -69,31 +71,37 @@ async def read_a111(file: BytesIO):
logger.debug(f"Error {e}")
return []
else:
return [f"**📦 Prompt**\n```{prompt}```", f">`{message}`\n"]
return [
formatting.mbold("📦 Prompt", escape=False),
formatting.mcode(content=prompt, language="txt", escape=False),
formatting.mbold("📦 Negative Prompt", escape=False),
formatting.mcite(content=negative_prompt, expandable=True, escape=False),
]


async def read_comfyui(file: BytesIO):
try:
file.seek(0)
with Image.open(file) as img:
print(img.info)
parameter = img.info.get("prompt", None)
parameter = img.info.get("prompt")
if not parameter:
raise Exception("Empty Parameter")
return [
formatting.mbold("📦 Comfyui", escape=False),
formatting.mcode(content=parameter, language="txt", escape=False),
]
except Exception as e:
logger.debug(f"Error {e}")
return []
else:
return [f"**📦 Comfyui** \n```{parameter}```"]
return []


async def read_novelai(file: BytesIO):
message = []
try:
file.seek(0)
meta_data = ImageMetadata.load_image(file)
read_prompt = meta_data.Description
read_model = meta_data.used_model
with Image.open(file) as img:
meta_data = ImageMetadata.load_image(img)
rq_type = meta_data.Comment.request_type
mode = ""
if rq_type == "PromptGenerateRequest":
Expand All @@ -106,31 +114,50 @@ async def read_novelai(file: BytesIO):
logger.debug(f"Empty metadata {e}")
return []
else:
message.extend(
[
f"**📦 Prompt:** `{read_prompt}`" if read_prompt else "",
f"**📦 Model:** `{read_model.value}`" if read_model else "",
f"**📦 Source:** `{meta_data.Source}`" if meta_data.Source else "",
]
message.append(formatting.mbold(f"📦 NovelAI {mode}", escape=False))
if meta_data.Comment.prompt:
message.append(
formatting.mcode(
content=meta_data.Comment.prompt, language="txt", escape=False
)
)
if meta_data.Comment.negative_prompt:
message.append(
formatting.mcode(
content=meta_data.Comment.negative_prompt,
language="txt",
escape=False,
)
)
if meta_data.used_model:
message.append(
formatting.mbold(f"📦 Model #{meta_data.used_model}", escape=False),
)
if meta_data.Source:
message.append(
formatting.mbold(f"📦 Source #{meta_data.Source}", escape=False),
)
message.append(
formatting.mcode(
content=meta_data.Comment.model_dump_json(indent=2),
language="json",
escape=False,
)
)
try:
file.seek(0)
is_novelai, has_latent = ImageVerifier().verify(file)
with Image.open(file) as img:
is_novelai, has_latent = ImageVerifier().verify(img)
except Exception:
logger.debug("Not NovelAI")
else:
if is_novelai:
message.append("**🧊 Signed by NovelAI**")
message.append(formatting.mbold("🧊 Signed by NovelAI", escape=False))
if has_latent:
message.append("**🧊 Find Latent Space**")
message.append(formatting.mbold("🧊 Find Latent Space", escape=False))
return message


@sync_to_async
def sync_to_async_func():
pass


class BotRunner(object):
def __init__(self):
self.bot = AsyncTeleBot(BotSetting.token, state_storage=StepCache)
Expand All @@ -148,18 +175,18 @@ async def download(self, file):
downloaded_file = await self.bot.download_file(_file_info.file_path)
return downloaded_file

async def tagger(self, file, hidden_long_text=False) -> str:
async def tagger(self, file) -> str:
raw_file_data = await self.download(file=file)
if raw_file_data is None:
return "🥛 Not An image"
if isinstance(raw_file_data, bytes):
file_data = BytesIO(raw_file_data)
else:
file_data = raw_file_data
result = await pipeline_tag(trace_id="test", content=file_data)
# Infer Tags
infer = await pipeline_tag(trace_id="test", content=file_data)
infer_message = [
f"**🥽 AnimeScore: {result.anime_score}**",
"**🔍 Infer Tags**",
formatting.mbold("🥛 Tags", escape=False),
]
novelai_message = await read_novelai(file=file_data)
comfyui_message = await read_comfyui(file=file_data)
Expand All @@ -169,17 +196,25 @@ async def tagger(self, file, hidden_long_text=False) -> str:
filter(lambda msg: msg, [novelai_message, comfyui_message, a111_message]),
None,
)
if read_message and hidden_long_text:
infer_message.append(f"\n>`{result.anime_tags}`\n")
if read_message:
infer_message.append(
formatting.mcite(
content=infer.anime_tags, expandable=True, escape=False
)
)
if infer.characters:
infer_message.append(formatting.mbold("🥛 Characters", escape=False))
infer_message.append(
formatting.mcode(
content=",".join(infer.characters), language="txt", escape=False
)
)
if not read_message:
infer_message.append(formatting.mbold("🥛 No Metadata", escape=False))
else:
infer_message.append(f"```{result.anime_tags}```")
if result.characters:
infer_message.append(f"**🌟 Characters:** `{','.join(result.characters)}`")
read_message = read_message or ["🥛 No Metadata"]
content = infer_message + read_message
prompt = telegramify_markdown.convert("\n".join(content))
infer_message.extend(read_message)
file_data.close()
return prompt
return "\n".join(infer_message)

async def run(self):
logger.info("Bot Start")
Expand All @@ -190,20 +225,89 @@ async def run(self):
asyncio_helper.proxy = BotSetting.proxy_address
logger.info("Proxy tunnels are being used!")

async def reply_markdown(
chat_id: int,
text: str,
reply_to_message_id: int = None,
):
blocks = await telegramify_markdown.telegramify(text)
for item in blocks:
if item.content_type == ContentTypes.TEXT:
await bot.send_message(
chat_id=chat_id,
reply_to_message_id=reply_to_message_id,
text=item.content,
parse_mode="MarkdownV2",
)
elif item.content_type == ContentTypes.PHOTO:
await bot.send_photo(
chat_id,
(item.file_name, item.file_data),
caption=item.caption,
reply_to_message_id=reply_to_message_id,
parse_mode="MarkdownV2",
)
elif item.content_type == ContentTypes.FILE:
await bot.send_document(
chat_id,
(item.file_name, item.file_data),
caption=item.caption,
reply_to_message_id=reply_to_message_id,
parse_mode="MarkdownV2",
)

@bot.message_handler(
content_types=["photo", "document"], chat_types=["private"]
)
async def start(message: types.Message):
async def listen_pm(message: types.Message):
if settings.mode.only_white:
if message.chat.id not in settings.mode.white_group:
return logger.info(f"White List Out {message.chat.id}")
logger.info(f"Report in {message.chat.id} {message.from_user.id}")
if message.photo:
prompt = await self.tagger(file=message.photo[-1])
await bot.reply_to(message, text=prompt, parse_mode="MarkdownV2")
await reply_markdown(
chat_id=message.chat.id, reply_to_message_id=message.id, text=prompt
)
if message.document:
prompt = await self.tagger(file=message.document)
await bot.reply_to(message, text=prompt, parse_mode="MarkdownV2")
await reply_markdown(
chat_id=message.chat.id, reply_to_message_id=message.id, text=prompt
)

@bot.message_handler(
commands="scene_composition", chat_types=["supergroup", "group", "private"]
)
async def scene_composition(message: types.Message):
if settings.mode.only_white:
if message.chat.id not in settings.mode.white_group:
return logger.info(f"White List Out {message.chat.id}")
contents = prompt_generator.generate_scene_tags()
prompt = [formatting.mbold("🥛 Scene Composition Prompt", escape=False)]
for content in contents:
prompt.append(f"- `{content}`")
return await reply_markdown(
chat_id=message.chat.id,
reply_to_message_id=message.id,
text="\n".join(prompt),
)

@bot.message_handler(
commands="scene", chat_types=["supergroup", "group", "private"]
)
async def scene(message: types.Message):
if settings.mode.only_white:
if message.chat.id not in settings.mode.white_group:
return logger.info(f"White List Out {message.chat.id}")
contents = prompt_generator.generate_scene_tags()
prompt = [formatting.mbold("🥛 Scene Prompt", escape=False)]
for content in contents:
prompt.append(f"- `{content}`")
return await reply_markdown(
chat_id=message.chat.id,
reply_to_message_id=message.id,
text="\n".join(prompt),
)

@bot.message_handler(
commands="nsfw", chat_types=["supergroup", "group", "private"]
Expand All @@ -212,7 +316,7 @@ async def nsfw(message: types.Message):
if settings.mode.only_white:
if message.chat.id not in settings.mode.white_group:
return logger.info(f"White List Out {message.chat.id}")
contents = RandomPromptGenerator(nsfw_enabled=True).generate()
contents = prompt_generator.generate_common_tags(nsfw=True)
prompt = formatting.format_text(
formatting.mbold("🥛 NSFW Prompt"), formatting.mcode(content=contents)
)
Expand All @@ -225,7 +329,7 @@ async def sfw(message: types.Message):
if settings.mode.only_white:
if message.chat.id not in settings.mode.white_group:
return logger.info(f"White List Out {message.chat.id}")
contents = RandomPromptGenerator(nsfw_enabled=False).generate()
contents = prompt_generator.generate_common_tags(nsfw=False)
prompt = formatting.format_text(
formatting.mbold("🥛 SFW Prompt"), formatting.mcode(content=contents)
)
Expand All @@ -240,24 +344,35 @@ async def tag(message: types.Message):
if not message.reply_to_message:
return await bot.reply_to(
message,
text=f"🍡 please reply to message with this command ({message.chat.id})",
text=f"🍡 Please reply a photo with this command, chat id:({message.chat.id})",
)
logger.info(f"Report in {message.chat.id} {message.from_user.id}")
reply_message = message.reply_to_message
reply_message_ph = reply_message.photo
reply_message_doc = reply_message.document
if reply_message_ph:
prompt = await self.tagger(
file=reply_message_ph[-1], hidden_long_text=True
prompt = await self.tagger(file=reply_message_ph[-1])
return await reply_markdown(
chat_id=message.chat.id, reply_to_message_id=message.id, text=prompt
)
return await bot.reply_to(message, text=prompt, parse_mode="MarkdownV2")
if reply_message_doc:
prompt = await self.tagger(
file=reply_message_doc, hidden_long_text=True
prompt = await self.tagger(file=reply_message_doc)
return await reply_markdown(
chat_id=message.chat.id, reply_to_message_id=message.id, text=prompt
)
return await bot.reply_to(message, text=prompt, parse_mode="MarkdownV2")
return await bot.reply_to(message, text="🥛 Not image")

await bot.set_my_commands(
commands=[
types.BotCommand("tag", "Tag Image"),
types.BotCommand("scene", "Generate Scene Prompt"),
types.BotCommand(
"scene_composition", "Generate Scene Composition Prompt"
),
types.BotCommand("nsfw", "Generate NSFW Prompt"),
types.BotCommand("sfw", "Generate SFW Prompt"),
],
)
try:
await bot.polling(
non_stop=True, allowed_updates=util.update_types, skip_pending=True
Expand Down
12 changes: 2 additions & 10 deletions app/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,19 @@
from io import TextIOBase
from typing import Union, IO, Optional

from anime_identify import AnimeIDF
from loguru import logger
from pydantic import BaseModel

from app.utils import WdTaggerSDK
from setting.wdtagger import TaggerSetting

ANIME = AnimeIDF()


class TaggerResult(BaseModel):
anime_score: float
anime_tags: Optional[str] = ""
characters: Optional[list] = []


async def pipeline_tag(trace_id, content: Union[IO, TextIOBase]) -> TaggerResult:
content.seek(0)
anime_score = ANIME.predict_image(content=content)
content.seek(0)
raw_output_wd = await WdTaggerSDK(base_url=TaggerSetting.wd_api_endpoint).upload(
file=content.read(),
Expand All @@ -35,7 +29,5 @@ async def pipeline_tag(trace_id, content: Union[IO, TextIOBase]) -> TaggerResult
tag_result: str = raw_output_wd["sorted_general_strings"]
character_res: dict = raw_output_wd["character_res"]
characters = list(character_res.keys())
logger.info(f"Censored {trace_id},score {anime_score},result {tag_result}")
return TaggerResult(
anime_score=anime_score, anime_tags=tag_result, characters=characters
)
logger.info(f"Processed {trace_id},result {tag_result}")
return TaggerResult(anime_tags=tag_result, characters=characters)
Loading

0 comments on commit d1080ea

Please sign in to comment.