diff --git a/amiyabot/__init__.py b/amiyabot/__init__.py index 16129a0..c62dca7 100644 --- a/amiyabot/__init__.py +++ b/amiyabot/__init__.py @@ -15,7 +15,6 @@ from amiyabot.adapters.onebot.v12 import OneBot12Instance from amiyabot.adapters.tencent.qqGuild import QQGuildBotInstance, QQGuildSandboxBotInstance from amiyabot.adapters.comwechat import ComWeChatBotInstance -from amiyabot.adapters.common import CQCode # network from amiyabot.network.httpServer import HttpServer @@ -33,7 +32,7 @@ from amiyabot.builtin.lib.browserService import BrowserLaunchConfig, basic_browser_service # message -from amiyabot.builtin.messageChain import Chain, ChainBuilder, InlineKeyboard +from amiyabot.builtin.messageChain import Chain, ChainBuilder, InlineKeyboard, CQCode from amiyabot.builtin.message import ( Event, EventList, diff --git a/amiyabot/adapters/common.py b/amiyabot/adapters/common.py deleted file mode 100644 index 8afa10e..0000000 --- a/amiyabot/adapters/common.py +++ /dev/null @@ -1,39 +0,0 @@ -import jieba - -from dataclasses import dataclass -from amiyabot.util import remove_punctuation, chinese_to_digits -from amiyabot.builtin.message import Message - - -@dataclass -class CQCode: - code: str - - -def text_convert(message: Message, text: str, original: str): - """ - 消息文本的最终处理 - - :param message: Message 对象 - :param text: 预处理消息文本 - :param original: 未经预处理的原始消息文本 - :return: Message 对象 - """ - - message.text = text - message.text_digits = chinese_to_digits(message.text) - message.text_unsigned = remove_punctuation(original) - message.text_original = original - - chars = cut_by_jieba(message.text) + cut_by_jieba(message.text_digits) - - words = list(set(chars)) - words = sorted(words, key=chars.index) - - message.text_words = words - - return message - - -def cut_by_jieba(text: str): - return jieba.lcut(text.lower().replace(' ', '')) diff --git a/amiyabot/adapters/kook/package.py b/amiyabot/adapters/kook/package.py index ba57029..98ab417 100644 --- a/amiyabot/adapters/kook/package.py +++ b/amiyabot/adapters/kook/package.py @@ -4,8 +4,6 @@ from amiyabot.builtin.message import Event, Message, File from amiyabot.adapters import BotAdapterProtocol -from ..common import text_convert - class RolePermissionCache: guild_role: Dict[str, Dict[str, int]] = {} @@ -78,4 +76,5 @@ async def package_kook_message(instance: BotAdapterProtocol, message: dict): if extra['quote']['type'] == 2: data.image.append(extra['quote']['content']) - return text_convert(data, text.strip(), text) + data.set_text(text) + return data diff --git a/amiyabot/adapters/mirai/package.py b/amiyabot/adapters/mirai/package.py index 114d2da..ca36ae1 100644 --- a/amiyabot/adapters/mirai/package.py +++ b/amiyabot/adapters/mirai/package.py @@ -1,8 +1,6 @@ from amiyabot.builtin.message import Event, Message from amiyabot.adapters import BotAdapterProtocol -from ..common import text_convert - def package_mirai_message(instance: BotAdapterProtocol, account: str, data: dict): if 'type' not in data: @@ -50,4 +48,5 @@ def package_mirai_message(instance: BotAdapterProtocol, account: str, data: dict if chain['type'] == 'Image': msg.image.append(chain['url'].strip()) - return text_convert(msg, text, text) + msg.set_text(text) + return msg diff --git a/amiyabot/adapters/onebot/v11/package.py b/amiyabot/adapters/onebot/v11/package.py index 890da62..b1d8f97 100644 --- a/amiyabot/adapters/onebot/v11/package.py +++ b/amiyabot/adapters/onebot/v11/package.py @@ -1,8 +1,6 @@ from amiyabot.builtin.message import Event, EventList, Message from amiyabot.adapters import BotAdapterProtocol -from amiyabot.adapters.common import text_convert - async def package_onebot11_message(instance: BotAdapterProtocol, account: str, data: dict): if 'post_type' not in data: @@ -77,4 +75,5 @@ async def package_onebot11_message(instance: BotAdapterProtocol, account: str, d if chain['type'] == 'image': msg.image.append(chain_data['url'].strip()) - return text_convert(msg, text, text) + msg.set_text(text) + return msg diff --git a/amiyabot/adapters/onebot/v12/package.py b/amiyabot/adapters/onebot/v12/package.py index da9df6a..c235551 100644 --- a/amiyabot/adapters/onebot/v12/package.py +++ b/amiyabot/adapters/onebot/v12/package.py @@ -1,7 +1,6 @@ from amiyabot.builtin.message import Event, EventList, Message from amiyabot.adapters import BotAdapterProtocol -from amiyabot.adapters.common import text_convert from .api import OneBot12API @@ -62,7 +61,8 @@ async def package_onebot12_message(instance: BotAdapterProtocol, data: dict): if chain['type'] == 'video': msg.video = await get_file(instance, chain_data) - return text_convert(msg, text, text) + msg.set_text(text) + return msg event_list = EventList([Event(instance, message_type, data)]) diff --git a/amiyabot/adapters/tencent/qqGroup/package.py b/amiyabot/adapters/tencent/qqGroup/package.py index 8cf6a91..fa2204c 100644 --- a/amiyabot/adapters/tencent/qqGroup/package.py +++ b/amiyabot/adapters/tencent/qqGroup/package.py @@ -1,6 +1,5 @@ from amiyabot.builtin.message import Event, Message from amiyabot.adapters import BotAdapterProtocol -from amiyabot.adapters.common import text_convert async def package_qq_group_message(instance: BotAdapterProtocol, event: str, message: dict, is_reference: bool = False): @@ -29,7 +28,7 @@ async def package_qq_group_message(instance: BotAdapterProtocol, event: str, mes data.image.append(item['url']) if 'content' in message: - data = text_convert(data, message['content'].strip(), message['content']) + data.set_text(message['content']) return data diff --git a/amiyabot/adapters/tencent/qqGuild/package.py b/amiyabot/adapters/tencent/qqGuild/package.py index 14539ba..8420c10 100644 --- a/amiyabot/adapters/tencent/qqGuild/package.py +++ b/amiyabot/adapters/tencent/qqGuild/package.py @@ -2,7 +2,6 @@ from amiyabot.builtin.message import Event, Message from amiyabot.adapters import BotAdapterProtocol -from amiyabot.adapters.common import text_convert from .api import QQGuildAPI @@ -59,7 +58,7 @@ async def package_qq_guild_message(instance: BotAdapterProtocol, event: str, mes for fid in face_list: data.face.append(fid) - data = text_convert(data, text.strip(), message['content']) + data.set_text(text) if 'message_reference' in message: reference = await api.get_message(message['channel_id'], message['message_reference']['message_id']) diff --git a/amiyabot/adapters/test/server.py b/amiyabot/adapters/test/server.py index 46b1fa2..8ccf5e8 100644 --- a/amiyabot/adapters/test/server.py +++ b/amiyabot/adapters/test/server.py @@ -12,8 +12,6 @@ from amiyabot.util import random_code, create_dir from amiyabot import log -from ..common import text_convert - @dataclass class ReceivedMessage: @@ -95,7 +93,8 @@ async def package_message(self, event: str, event_id: str, message: dict): text = message.get('message', '') - return text_convert(msg, text, text) + msg.set_text(text) + return msg def base64_to_temp_url(self, base64_string: str): data = base64_string.split('base64,')[-1] diff --git a/amiyabot/builtin/message/structure.py b/amiyabot/builtin/message/structure.py index fb0e14b..863d9d3 100644 --- a/amiyabot/builtin/message/structure.py +++ b/amiyabot/builtin/message/structure.py @@ -4,6 +4,7 @@ from typing import Any, List, Union, Optional, Callable from dataclasses import dataclass from amiyabot.typeIndexes import * +from amiyabot.util import remove_punctuation, chinese_to_digits, cut_by_jieba class EventStructure: @@ -36,6 +37,7 @@ def __init__(self, instance: T_BotAdapterProtocol, message: Optional[dict] = Non self.video = '' self.text = '' + self.text_prefix = '' self.text_digits = '' self.text_unsigned = '' self.text_original = '' @@ -80,6 +82,24 @@ def __str__(self): } ) + def set_text(self, text: str, set_original: bool = True): + if set_original: + self.text_original = text + + self.text = text.strip() + self.text_convert() + + def text_convert(self): + self.text_digits = chinese_to_digits(self.text) + self.text_unsigned = remove_punctuation(self.text) + + chars = cut_by_jieba(self.text) + cut_by_jieba(self.text_digits) + + words = list(set(chars)) + words = sorted(words, key=chars.index) + + self.text_words = words + @abc.abstractmethod async def send(self, reply: T_Chain): raise NotImplementedError diff --git a/amiyabot/builtin/messageChain/element.py b/amiyabot/builtin/messageChain/element.py index 97695e5..5faa602 100644 --- a/amiyabot/builtin/messageChain/element.py +++ b/amiyabot/builtin/messageChain/element.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import List, Any from amiyabot.builtin.lib.browserService import * -from amiyabot.adapters.common import CQCode from amiyabot import log from .keyboard import InlineKeyboard @@ -221,6 +220,11 @@ def get(self): return self.data +@dataclass +class CQCode: + code: str + + CHAIN_ITEM = Union[ At, AtAll, diff --git a/amiyabot/factory/implemented.py b/amiyabot/factory/implemented.py index 9e5d327..3ac3821 100644 --- a/amiyabot/factory/implemented.py +++ b/amiyabot/factory/implemented.py @@ -1,6 +1,7 @@ import re from dataclasses import dataclass +from amiyabot.util import remove_prefix_once from amiyabot.builtin.message import Message, MessageMatch, Verify, Equal from amiyabot.factory.factoryTyping import MessageHandlerItem, KeywordsType @@ -30,47 +31,44 @@ def __check(self, data: Message, obj: KeywordsType) -> Verify: return Verify(False) async def verify(self, data: Message): + # 检查是否支持私信 direct_only = self.direct_only or (self.group_config and self.group_config.direct_only) - if self.check_prefix is None: - if self.group_config: - need_check_prefix = self.group_config and self.group_config.check_prefix - else: - need_check_prefix = True - else: - need_check_prefix = self.check_prefix - if data.is_direct: if not direct_only: - # 检查是否支持私信 if self.allow_direct is None: if not self.group_config or not self.group_config.allow_direct: return Verify(False) if self.allow_direct is False: return Verify(False) - else: - # 是否仅支持私信 if direct_only: return Verify(False) - # 检查是否包含“前缀触发词”或被 @ + # 检查是否包含前缀触发词或被 @ flag = False + + if self.check_prefix is None: + need_check_prefix = self.group_config.check_prefix if self.group_config else True + else: + need_check_prefix = self.check_prefix + if need_check_prefix: if data.is_at: flag = True else: prefix_keywords = need_check_prefix if isinstance(need_check_prefix, list) else self.prefix_keywords() - # 未设置前缀触发词允许直接通过 if not prefix_keywords: flag = True - for word in prefix_keywords: - if data.text.startswith(word): - flag = True - break + # 如果前缀校验通过,再次修正 Message 对象的属性值 + text, prefix = remove_prefix_once(data.text, prefix_keywords) + if prefix: + flag = True + data.text_prefix = prefix + data.set_text(text, set_original=False) # 若不通过以上检查,且关键字不为全等句式(Equal) # 则允许当关键字为列表时,筛选列表内的全等句式继续执行校验,否则校验不通过 diff --git a/amiyabot/util/toolsUtils.py b/amiyabot/util/toolsUtils.py index 2530b2c..2315016 100644 --- a/amiyabot/util/toolsUtils.py +++ b/amiyabot/util/toolsUtils.py @@ -1,8 +1,10 @@ import re import dhash +import jieba import string import random +from typing import List from string import punctuation from zhon.hanzi import punctuation as punctuation_cn from io import BytesIO @@ -25,6 +27,17 @@ def remove_punctuation(text: str): return text +def remove_prefix_once(sentence: str, prefix_keywords: List[str]): + for prefix in prefix_keywords: + if sentence.startswith(prefix): + return sentence[len(prefix) :], prefix + return sentence, '' + + +def cut_by_jieba(text: str): + return jieba.lcut(text.lower().replace(' ', '')) + + def chinese_to_digits(text: str): character_relation = { '零': 0,