Skip to content

Commit

Permalink
update: 支持提取前缀词,从原始句子中分离
Browse files Browse the repository at this point in the history
  • Loading branch information
vivien8261 committed Jul 1, 2024
1 parent bad1606 commit 99ada6b
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 77 deletions.
3 changes: 1 addition & 2 deletions amiyabot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from amiyabot.adapters.onebot.v12 import OneBot12Instance
from amiyabot.adapters.tencent.qqGuild import QQGuildBotInstance, QQGuildSandboxBotInstance
from amiyabot.adapters.comwechat import ComWeChatBotInstance
from amiyabot.adapters.common import CQCode

# network
from amiyabot.network.httpServer import HttpServer
Expand All @@ -33,7 +32,7 @@
from amiyabot.builtin.lib.browserService import BrowserLaunchConfig, basic_browser_service

# message
from amiyabot.builtin.messageChain import Chain, ChainBuilder, InlineKeyboard
from amiyabot.builtin.messageChain import Chain, ChainBuilder, InlineKeyboard, CQCode
from amiyabot.builtin.message import (
Event,
EventList,
Expand Down
39 changes: 0 additions & 39 deletions amiyabot/adapters/common.py

This file was deleted.

5 changes: 2 additions & 3 deletions amiyabot/adapters/kook/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from amiyabot.builtin.message import Event, Message, File
from amiyabot.adapters import BotAdapterProtocol

from ..common import text_convert


class RolePermissionCache:
guild_role: Dict[str, Dict[str, int]] = {}
Expand Down Expand Up @@ -78,4 +76,5 @@ async def package_kook_message(instance: BotAdapterProtocol, message: dict):
if extra['quote']['type'] == 2:
data.image.append(extra['quote']['content'])

return text_convert(data, text.strip(), text)
data.set_text(text)
return data
5 changes: 2 additions & 3 deletions amiyabot/adapters/mirai/package.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from amiyabot.builtin.message import Event, Message
from amiyabot.adapters import BotAdapterProtocol

from ..common import text_convert


def package_mirai_message(instance: BotAdapterProtocol, account: str, data: dict):
if 'type' not in data:
Expand Down Expand Up @@ -50,4 +48,5 @@ def package_mirai_message(instance: BotAdapterProtocol, account: str, data: dict
if chain['type'] == 'Image':
msg.image.append(chain['url'].strip())

return text_convert(msg, text, text)
msg.set_text(text)
return msg
5 changes: 2 additions & 3 deletions amiyabot/adapters/onebot/v11/package.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from amiyabot.builtin.message import Event, EventList, Message
from amiyabot.adapters import BotAdapterProtocol

from amiyabot.adapters.common import text_convert


async def package_onebot11_message(instance: BotAdapterProtocol, account: str, data: dict):
if 'post_type' not in data:
Expand Down Expand Up @@ -77,4 +75,5 @@ async def package_onebot11_message(instance: BotAdapterProtocol, account: str, d
if chain['type'] == 'image':
msg.image.append(chain_data['url'].strip())

return text_convert(msg, text, text)
msg.set_text(text)
return msg
4 changes: 2 additions & 2 deletions amiyabot/adapters/onebot/v12/package.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from amiyabot.builtin.message import Event, EventList, Message
from amiyabot.adapters import BotAdapterProtocol

from amiyabot.adapters.common import text_convert
from .api import OneBot12API


Expand Down Expand Up @@ -62,7 +61,8 @@ async def package_onebot12_message(instance: BotAdapterProtocol, data: dict):
if chain['type'] == 'video':
msg.video = await get_file(instance, chain_data)

return text_convert(msg, text, text)
msg.set_text(text)
return msg

event_list = EventList([Event(instance, message_type, data)])

Expand Down
3 changes: 1 addition & 2 deletions amiyabot/adapters/tencent/qqGroup/package.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from amiyabot.builtin.message import Event, Message
from amiyabot.adapters import BotAdapterProtocol
from amiyabot.adapters.common import text_convert


async def package_qq_group_message(instance: BotAdapterProtocol, event: str, message: dict, is_reference: bool = False):
Expand Down Expand Up @@ -29,7 +28,7 @@ async def package_qq_group_message(instance: BotAdapterProtocol, event: str, mes
data.image.append(item['url'])

if 'content' in message:
data = text_convert(data, message['content'].strip(), message['content'])
data.set_text(message['content'])

return data

Expand Down
3 changes: 1 addition & 2 deletions amiyabot/adapters/tencent/qqGuild/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from amiyabot.builtin.message import Event, Message
from amiyabot.adapters import BotAdapterProtocol
from amiyabot.adapters.common import text_convert

from .api import QQGuildAPI

Expand Down Expand Up @@ -59,7 +58,7 @@ async def package_qq_guild_message(instance: BotAdapterProtocol, event: str, mes
for fid in face_list:
data.face.append(fid)

data = text_convert(data, text.strip(), message['content'])
data.set_text(text)

if 'message_reference' in message:
reference = await api.get_message(message['channel_id'], message['message_reference']['message_id'])
Expand Down
5 changes: 2 additions & 3 deletions amiyabot/adapters/test/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from amiyabot.util import random_code, create_dir
from amiyabot import log

from ..common import text_convert


@dataclass
class ReceivedMessage:
Expand Down Expand Up @@ -95,7 +93,8 @@ async def package_message(self, event: str, event_id: str, message: dict):

text = message.get('message', '')

return text_convert(msg, text, text)
msg.set_text(text)
return msg

def base64_to_temp_url(self, base64_string: str):
data = base64_string.split('base64,')[-1]
Expand Down
20 changes: 20 additions & 0 deletions amiyabot/builtin/message/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, List, Union, Optional, Callable
from dataclasses import dataclass
from amiyabot.typeIndexes import *
from amiyabot.util import remove_punctuation, chinese_to_digits, cut_by_jieba


class EventStructure:
Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(self, instance: T_BotAdapterProtocol, message: Optional[dict] = Non
self.video = ''

self.text = ''
self.text_prefix = ''
self.text_digits = ''
self.text_unsigned = ''
self.text_original = ''
Expand Down Expand Up @@ -80,6 +82,24 @@ def __str__(self):
}
)

def set_text(self, text: str, set_original: bool = True):
    """Store the message text and refresh the attributes derived from it.

    :param text: the text to store; it is stripped of surrounding whitespace
        before being assigned to ``self.text``.
    :param set_original: when true, the unstripped ``text`` is also stored in
        ``self.text_original``.
    """
    if set_original:
        # Keep the untouched input alongside the stripped form.
        self.text_original = text

    stripped = text.strip()
    self.text = stripped

    # Recompute the attributes that text_convert derives from self.text.
    self.text_convert()

def text_convert(self):
    """Recompute the derived text attributes from ``self.text``.

    Sets:

    * ``text_digits`` - result of ``chinese_to_digits(self.text)``
    * ``text_unsigned`` - result of ``remove_punctuation(self.text)``
    * ``text_words`` - the segments returned by ``cut_by_jieba`` for both
      ``self.text`` and ``self.text_digits``, with duplicates removed and
      each word kept at the position of its first occurrence.
    """
    self.text_digits = chinese_to_digits(self.text)
    self.text_unsigned = remove_punctuation(self.text)

    chars = cut_by_jieba(self.text) + cut_by_jieba(self.text_digits)

    # dict keys preserve insertion order, so this removes duplicates while
    # keeping first-occurrence order in a single pass (no repeated
    # list.index scans as with sorted(set(chars), key=chars.index)).
    self.text_words = list(dict.fromkeys(chars))

@abc.abstractmethod
async def send(self, reply: T_Chain):
raise NotImplementedError
Expand Down
6 changes: 5 additions & 1 deletion amiyabot/builtin/messageChain/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import dataclass
from typing import List, Any
from amiyabot.builtin.lib.browserService import *
from amiyabot.adapters.common import CQCode
from amiyabot import log

from .keyboard import InlineKeyboard
Expand Down Expand Up @@ -221,6 +220,11 @@ def get(self):
return self.data


@dataclass
class CQCode:
    """Message-chain element holding a single code string.

    NOTE(review): presumably ``code`` is a raw CQ code as used by
    OneBot-style adapters - confirm against the adapters that consume it.
    """

    # The code string carried by this element.
    code: str


CHAIN_ITEM = Union[
At,
AtAll,
Expand Down
32 changes: 15 additions & 17 deletions amiyabot/factory/implemented.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

from dataclasses import dataclass
from amiyabot.util import remove_prefix_once
from amiyabot.builtin.message import Message, MessageMatch, Verify, Equal
from amiyabot.factory.factoryTyping import MessageHandlerItem, KeywordsType

Expand Down Expand Up @@ -30,47 +31,44 @@ def __check(self, data: Message, obj: KeywordsType) -> Verify:
return Verify(False)

async def verify(self, data: Message):
# 检查是否支持私信
direct_only = self.direct_only or (self.group_config and self.group_config.direct_only)

if self.check_prefix is None:
if self.group_config:
need_check_prefix = self.group_config and self.group_config.check_prefix
else:
need_check_prefix = True
else:
need_check_prefix = self.check_prefix

if data.is_direct:
if not direct_only:
# 检查是否支持私信
if self.allow_direct is None:
if not self.group_config or not self.group_config.allow_direct:
return Verify(False)

if self.allow_direct is False:
return Verify(False)

else:
# 是否仅支持私信
if direct_only:
return Verify(False)

# 检查是否包含“前缀触发词”或被 @
# 检查是否包含前缀触发词或被 @
flag = False

if self.check_prefix is None:
need_check_prefix = self.group_config.check_prefix if self.group_config else True
else:
need_check_prefix = self.check_prefix

if need_check_prefix:
if data.is_at:
flag = True
else:
prefix_keywords = need_check_prefix if isinstance(need_check_prefix, list) else self.prefix_keywords()

# 未设置前缀触发词允许直接通过
if not prefix_keywords:
flag = True

for word in prefix_keywords:
if data.text.startswith(word):
flag = True
break
# 如果前缀校验通过,再次修正 Message 对象的属性值
text, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
flag = True
data.text_prefix = prefix
data.set_text(text, set_original=False)

# 若不通过以上检查,且关键字不为全等句式(Equal)
# 则允许当关键字为列表时,筛选列表内的全等句式继续执行校验,否则校验不通过
Expand Down
13 changes: 13 additions & 0 deletions amiyabot/util/toolsUtils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re
import dhash
import jieba
import string
import random

from typing import List
from string import punctuation
from zhon.hanzi import punctuation as punctuation_cn
from io import BytesIO
Expand All @@ -25,6 +27,17 @@ def remove_punctuation(text: str):
return text


def remove_prefix_once(sentence: str, prefix_keywords: List[str]):
    """Strip the first matching prefix keyword from the start of a sentence.

    Keywords are tried in the given order and at most one is removed.
    Empty keywords are skipped: ``str.startswith('')`` is always true, so an
    empty entry would otherwise end the search before any later keyword is
    tried.

    :param sentence: the text to inspect.
    :param prefix_keywords: candidate prefixes, in priority order.
    :return: a ``(remaining_text, matched_prefix)`` tuple; when no keyword
        matches, ``(sentence, '')``.
    """
    for prefix in prefix_keywords:
        if prefix and sentence.startswith(prefix):
            return sentence[len(prefix):], prefix
    return sentence, ''


def cut_by_jieba(text: str):
    """Segment ``text`` into a list of words using ``jieba.lcut``.

    The text is lower-cased and its space characters are removed before it
    is handed to jieba.
    """
    normalized = text.lower().replace(' ', '')
    return jieba.lcut(normalized)


def chinese_to_digits(text: str):
character_relation = {
'零': 0,
Expand Down

0 comments on commit 99ada6b

Please sign in to comment.