新增群聊记录学习发言,优化代码

This commit is contained in:
CMHopeSunshine 2022-05-28 23:34:15 +08:00
parent de09bb113f
commit b5323fd378
4 changed files with 1097 additions and 0 deletions

View File

@ -0,0 +1,214 @@
import random
import asyncio
import re
import time
import os
import threading
from nonebot import on_message, require, get_bot, logger
from nonebot.exception import ActionFailed
from nonebot.typing import T_State
from nonebot.rule import keyword, to_me, Rule
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import GroupMessageEvent
from nonebot.adapters.onebot.v11 import permission
from .model import Chat
from utils.config import config
message_id_lock = threading.Lock()
message_id_dict = {}
async def check_accounts(event: GroupMessageEvent) -> bool:
# 不响应其他nonebot_plugin_gocqhttp机器人账号的信息
if os.path.exists('accounts'):
accounts = [int(d) for d in os.listdir('accounts')
if d.isnumeric()]
if event.user_id in accounts:
return False
return True
async def get_answer(event: GroupMessageEvent, state: T_State) -> bool:
# 不响应被屏蔽的人的信息
if event.user_id in config.paimon_chat_ban:
return False
chat: Chat = Chat(event)
to_learn = True
# 多账号登陆,且在同一群中时;避免一条消息被处理多次
with message_id_lock:
message_id = event.message_id
group_id = event.group_id
if group_id in message_id_dict:
if message_id in message_id_dict[group_id]:
to_learn = False
else:
message_id_dict[group_id] = []
group_message = message_id_dict[group_id]
group_message.append(message_id)
if len(group_message) > 100:
group_message = group_message[:-10]
answers = chat.answer()
if to_learn:
chat.learn()
if answers:
state['answers'] = answers
return True
return False
any_msg = on_message(
priority=20,
block=False,
rule=Rule(check_accounts, get_answer),
permission=permission.GROUP # | permission.PRIVATE_FRIEND
)
async def is_shutup(self_id: int, group_id: int) -> bool:
info = await get_bot(str(self_id)).call_api('get_group_member_info', **{
'user_id': self_id,
'group_id': group_id
})
flag: bool = info['shut_up_timestamp'] > time.time()
if flag:
logger.info(f'repeater派蒙[{self_id}]在群[{group_id}] 处于禁言状态')
return flag
@any_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
delay = random.randint(2, 4)
for item in state['answers']:
logger.info(f'repeater派蒙[{event.self_id}]准备向群[{event.group_id}]回复[{item}]')
await asyncio.sleep(delay)
try:
await any_msg.send(item)
except ActionFailed:
# 自动删除失效消息。若 bot 处于风控期,请勿开启该功能
shutup = await is_shutup(event.self_id, event.group_id)
if not shutup: # 说明这条消息失效了
logger.info('repeater | bot [{}] ready to ban [{}] in group [{}]'.format(
event.self_id, str(item), event.group_id))
Chat.ban(event.group_id, event.self_id, str(item), 'ActionFailed')
break
delay = random.randint(2, 4)
async def is_reply(bot: Bot, event: GroupMessageEvent) -> bool:
return bool(event.reply)
ban_msg = on_message(
rule=to_me() & keyword('不可以', '达咩', '不行', 'no') & Rule(is_reply),
priority=5,
block=True,
permission=permission.GROUP_OWNER | permission.GROUP_ADMIN
)
@ban_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent):
if '[CQ:reply,' not in event.raw_message:
return False
raw_message = ''
for item in event.reply.message:
raw_reply = str(item)
# 去掉图片消息中的 url, subType 等字段
raw_message += re.sub(r'(\[CQ\:.+)(?:,url=*)(\])',
r'\1\2', raw_reply)
logger.info(f'repeater派蒙[{event.self_id}] ready to ban [{raw_message}] in group [{event.group_id}]')
if Chat.ban(event.group_id, event.self_id, raw_message, str(event.user_id)):
msg_send = ['派蒙知道错了...达咩!', '派蒙不会再这么说了...', '果面呐噻,派蒙说错话了...']
await ban_msg.finish(random.choice(msg_send))
scheduler = require('nonebot_plugin_apscheduler').scheduler
async def message_is_ban(bot: Bot, event: GroupMessageEvent) -> bool:
return event.get_plaintext().strip() == '不可以发这个'
ban_msg_latest = on_message(
rule=to_me() & Rule(message_is_ban),
priority=5,
block=True,
permission=permission.GROUP_OWNER | permission.GROUP_ADMIN
)
@ban_msg_latest.handle()
async def _(bot: Bot, event: GroupMessageEvent):
logger.info(
f'repeater派蒙[{event.self_id}]把群[{event.group_id}]最后的回复ban了')
if Chat.ban(event.group_id, event.self_id, '', str(event.user_id)):
msg_send = ['派蒙知道错了...达咩!', '派蒙不会再这么说了...', '果面呐噻,派蒙说错话了...']
await ban_msg_latest.finish(random.choice(msg_send))
@scheduler.scheduled_job('interval', seconds=5, misfire_grace_time=5)
async def speak_up():
ret = Chat.speak()
if not ret:
return
bot_id, group_id, messages = ret
for msg in messages:
logger.info(f'repeater派蒙[{bot_id}]准备向群[{group_id}]发送消息[{messages}]')
await get_bot(str(bot_id)).call_api('send_group_msg', **{
'message': msg,
'group_id': group_id
})
await asyncio.sleep(random.randint(2, 4))
update_scheduler = require('nonebot_plugin_apscheduler').scheduler
async def is_drink_msg(bot: Bot, event: GroupMessageEvent) -> bool:
return event.get_plaintext().strip() in ['派蒙干杯', '应急食品开餐', '派蒙干饭']
drink_msg = on_message(
rule=Rule(is_drink_msg),
priority=5,
block=True,
permission=permission.GROUP_OWNER | permission.GROUP_ADMIN
)
@drink_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent):
drunk_duration = random.randint(60, 600)
logger.info(f'repeater派蒙[{event.self_id}]即将在群[{event.group_id}]喝醉,在[{drunk_duration}秒]后醒来')
Chat.drink(event.group_id)
try:
await drink_msg.send('呀,旅行者。你今天走起路来,怎么看着摇摇晃晃的?')
except ActionFailed:
pass
await asyncio.sleep(drunk_duration)
ret = Chat.sober_up(event.group_id)
if ret:
logger.info(f'repeater派蒙[{event.self_id}]在群[{event.group_id}]醒酒了')
await drink_msg.finish('呃...头好疼...下次不能喝那么多了...')
@update_scheduler.scheduled_job('cron', hour='4')
def update_data():
Chat.clearup_context()
Chat.completely_sober()

View File

@ -0,0 +1,879 @@
from typing import Generator, List, Optional, Union, Tuple, Dict, Any
from functools import cached_property, cmp_to_key
from dataclasses import dataclass
from collections import defaultdict
import jieba_fast.analyse
import threading
import pypinyin
import pymongo
import time
import random
import re
import atexit
from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent
from nonebot.adapters.onebot.v11 import Message
from utils.config import config
mongo_client = pymongo.MongoClient(config.paimon_mongodb_url)
mongo_db = mongo_client['PaimonChat']
message_mongo = mongo_db['message']
message_mongo.create_index(name='time_index',
keys=[('time', pymongo.DESCENDING)])
context_mongo = mongo_db['context']
context_mongo.create_index(name='keywords_index',
keys=[('keywords', pymongo.HASHED)])
context_mongo.create_index(name='count_index',
keys=[('count', pymongo.DESCENDING)])
context_mongo.create_index(name='time_index',
keys=[('time', pymongo.DESCENDING)])
context_mongo.create_index(name='answers_index',
keys=[('answers.group_id', pymongo.TEXT),
('answers.keywords', pymongo.TEXT)],
default_language='none')
blacklist_mongo = mongo_db['blacklist']
blacklist_mongo.create_index(name='group_index',
keys=[('group_id', pymongo.HASHED)])
@dataclass
class ChatData:
group_id: int
user_id: int
raw_message: str
plain_text: str
time: int
bot_id: int
_keywords_size: int = 3
@cached_property
def is_plain_text(self) -> bool:
return '[CQ:' not in self.raw_message and len(self.plain_text) != 0
@cached_property
def is_image(self) -> bool:
return '[CQ:image,' in self.raw_message or '[CQ:face,' in self.raw_message
@cached_property
def keywords(self) -> str:
if not self.is_plain_text and len(self.plain_text) == 0:
return self.raw_message
keywords_list = jieba_fast.analyse.extract_tags(
self.plain_text, topK=ChatData._keywords_size)
if len(keywords_list) < 2:
return self.plain_text
else:
# keywords_list.sort()
return ' '.join(keywords_list)
@cached_property
def keywords_pinyin(self) -> str:
return ''.join([item[0] for item in pypinyin.pinyin(
self.keywords, style=pypinyin.NORMAL, errors='default')]).lower()
@cached_property
def to_me(self) -> bool:
return self.plain_text.startswith('派蒙')
class Chat:
answer_threshold = config.paimon_answer_threshold # answer 相关的阈值,值越小废话越多,越大话越少
answer_limit_threshold = config.paimon_answer_limit_threshold # 上限阈值,一般正常的上下文不可能发 50 遍,一般是其他 bot 的回复,禁了!
cross_group_threshold = config.paimon_cross_group_threshold # N 个群有相同的回复,就跨群作为全局回复
repeat_threshold = config.paimon_repeat_threshold # 复读的阈值,群里连续多少次有相同的发言,就复读
speak_threshold = config.paimon_speak_threshold # 主动发言的阈值,越小废话越多
drunk_probability = config.paimon_drunk_probability # 喝醉的概率(回复没达到阈值的话)
split_probability = 0.5 # 按逗号分割回复语的概率
voice_probability = config.paimon_voice_probability # 回复语音的概率(仅纯文字)
speak_continuously_probability = config.paimon_speak_continuously_probability # 连续主动说话的概率
speak_poke_probability = config.paimon_speak_poke_probability # 主动说话加上随机戳一戳群友的概率
speak_continuously_max_len = config.paimon_speak_continuously_max_len # 连续主动说话最多几句话
save_time_threshold = 3600 # 每隔多久进行一次持久化 ( 秒 )
save_count_threshold = 1000 # 单个群超过多少条聊天记录就进行一次持久化。与时间是或的关系
blacklist_answer = defaultdict(set)
blacklist_answer_reserve = defaultdict(set)
def __init__(self, data: Union[ChatData, GroupMessageEvent, PrivateMessageEvent]):
if isinstance(data, ChatData):
self.chat_data = data
elif isinstance(data, GroupMessageEvent):
self.chat_data = ChatData(
group_id=data.group_id,
user_id=data.user_id,
# 删除图片子类型字段,同一张图子类型经常不一样,影响判断
raw_message=re.sub(
r',subType=\d+\]',
r']',
data.raw_message),
plain_text=data.get_plaintext(),
time=data.time,
bot_id=data.self_id,
)
elif isinstance(data, PrivateMessageEvent):
event_dict = data.dict()
self.chat_data = ChatData(
group_id=data.user_id, # 故意加个符号,和群号区分开来
user_id=data.user_id,
# 删除图片子类型字段,同一张图子类型经常不一样,影响判断
raw_message=re.sub(
r',subType=\d+\]',
r']',
data.raw_message),
plain_text=data.get_plaintext(),
time=data.time,
bot_id=data.self_id,
)
def learn(self) -> bool:
"""
学习这句话
"""
if len(self.chat_data.raw_message.strip()) == 0:
return False
group_id = self.chat_data.group_id
if group_id in Chat._message_dict:
group_msgs = Chat._message_dict[group_id]
if group_msgs:
group_pre_msg = group_msgs[-1]
else:
group_pre_msg = None
# 群里的上一条发言
self._context_insert(group_pre_msg)
user_id = self.chat_data.user_id
if group_pre_msg and group_pre_msg['user_id'] != user_id:
# 该用户在群里的上一条发言(倒序三句之内)
for msg in group_msgs[:-3:-1]:
if msg['user_id'] == user_id:
self._context_insert(msg)
break
self._message_insert()
return True
def answer(self, with_limit: bool = True) -> Optional[Generator[Message, None, None]]:
"""
回复这句话可能会分多次回复也可能不回复
"""
group_id = self.chat_data.group_id
bot_id = self.chat_data.bot_id
group_bot_replies = Chat._reply_dict[group_id][bot_id]
if with_limit:
# # 不回复太短的对话,大部分是“?”、“草”
# if self.chat_data.is_plain_text and len(self.chat_data.plain_text) < 2:
# return None
if len(group_bot_replies):
latest_reply = group_bot_replies[-1]
# 限制发音频率,最多 6 秒一次
if int(time.time()) - latest_reply['time'] < 6:
return None
# # 不要一直回复同一个内容
# if self.chat_data.raw_message == latest_reply['pre_raw_message']:
# return None
# 有人复读了牛牛的回复,不继续回复
# if self.chat_data.raw_message == latest_reply['reply']:
# return None
results = self._context_find()
if results:
raw_message = self.chat_data.raw_message
keywords = self.chat_data.keywords
with Chat._reply_lock:
group_bot_replies.append({
'time': int(time.time()),
'pre_raw_message': raw_message,
'pre_keywords': keywords,
'reply': '[PaimonChat: Reply]', # flag
'reply_keywords': '[PaimonChat: Reply]', # flag
})
def yield_results(results: Tuple[List[str], str]) -> Generator[Message, None, None]:
answer_list, answer_keywords = results
group_bot_replies = Chat._reply_dict[group_id][bot_id]
for item in answer_list:
with Chat._reply_lock:
group_bot_replies.append({
'time': int(time.time()),
'pre_raw_message': raw_message,
'pre_keywords': keywords,
'reply': item,
'reply_keywords': answer_keywords,
})
if '[CQ:' not in item and len(item) > 1 \
and random.random() < Chat.voice_probability:
yield Chat._text_to_speech(item)
else:
yield Message(item)
with Chat._reply_lock:
group_bot_replies = group_bot_replies[-Chat._save_reserve_size:]
return yield_results(results)
return None
@staticmethod
def speak() -> Optional[Tuple[int, int, List[Message]]]:
"""
主动发言返回当前最希望发言的 bot 账号群号发言消息 List也有可能不发言
"""
basic_msgs_len = 10
basic_delay = 600
def group_popularity_cmp(lhs: Tuple[int, List[Dict[str, Any]]],
rhs: Tuple[int, List[Dict[str, Any]]]) -> int:
def cmp(a: Any, b: Any):
return (a > b) - (a < b)
lhs_group_id, lhs_msgs = lhs
rhs_group_id, rhs_msgs = rhs
lhs_len = len(lhs_msgs)
rhs_len = len(rhs_msgs)
# 默认是 0, 加个 1 避免乘没了
lhs_drunkenness = Chat._drunkenness_dict[lhs_group_id] + 1
rhs_drunkenness = Chat._drunkenness_dict[rhs_group_id] + 1
if lhs_len < basic_msgs_len or rhs_len < basic_msgs_len:
return cmp(lhs_len * lhs_drunkenness,
rhs_len * rhs_drunkenness)
lhs_duration = lhs_msgs[-1]['time'] - lhs_msgs[0]['time']
rhs_duration = rhs_msgs[-1]['time'] - rhs_msgs[0]['time']
if not lhs_duration or not rhs_duration:
return cmp(lhs_len, rhs_len)
return cmp(lhs_len * lhs_drunkenness / lhs_duration,
rhs_len * rhs_drunkenness / rhs_duration)
# 按群聊热度排序
popularity = sorted(Chat._message_dict.items(),
key=cmp_to_key(group_popularity_cmp))
cur_time = time.time()
for group_id, group_msgs in popularity:
group_replies = Chat._reply_dict[group_id]
if not len(group_replies) or len(group_msgs) < basic_msgs_len:
continue
# 一般来说所有牛牛都是一起回复的,最后发言时间应该是一样的,随意随便选一个[0]就好了
group_replies_front = list(group_replies.values())[0]
if not len(group_replies_front) or \
group_replies_front[-1]['time'] > group_msgs[-1]['time']:
continue
msgs_len = len(group_msgs)
latest_time = group_msgs[-1]['time']
duration = latest_time - group_msgs[0]['time']
avg_interval = duration / msgs_len
# 已经超过平均发言间隔 N 倍的时间没有人说话了,才主动发言
# print(cur_time - latest_time, '/', avg_interval *
# Chat.speak_threshold + basic_delay)
if cur_time - latest_time < avg_interval * Chat.speak_threshold + basic_delay:
continue
# append 一个 flag, 防止这个群热度特别高,但压根就没有可用的 context 时,每次 speak 都查这个群,浪费时间
with Chat._reply_lock:
group_replies_front.append({
'time': int(cur_time),
'pre_raw_message': '[PaimonChat: Speak]',
'pre_keywords': '[PaimonChat: Speak]',
'reply': '[PaimonChat: Speak]',
'reply_keywords': '[PaimonChat: Speak]',
})
available_time = cur_time - 24 * 3600
speak_context = context_mongo.aggregate([
{
'$match': {
'count': {
'$gt': Chat.answer_threshold
},
'time': {
'$gt': available_time
},
# 上面两行为了加快查找速度,对查找到的结果不产生影响
'answers.group_id': group_id,
'answers.time': {
'$gt': available_time
},
'answers.count': {
'$gt': Chat.answer_threshold
}
}
}, {
'$sample': {'size': 1} # 随机一条
}
])
speak_context = list(speak_context)
if not speak_context:
continue
ban_keywords = Chat._find_ban_keywords(
context=speak_context[0], group_id=group_id)
messages = [answer['messages']
for answer in speak_context[0]['answers']
if answer['count'] >= Chat.answer_threshold
and answer['keywords'] not in ban_keywords
and answer['group_id'] == group_id]
if not messages:
continue
speak = random.choice(random.choice(messages))
bot_id = random.choice(
[bid for bid in group_replies.keys() if bid])
with Chat._reply_lock:
group_replies[bot_id].append({
'time': int(cur_time),
'pre_raw_message': '[PaimonChat: Speak]',
'pre_keywords': '[PaimonChat: Speak]',
'reply': speak,
'reply_keywords': '[PaimonChat: Speak]',
})
speak_list = [Message(speak), ]
while random.random() < Chat.speak_continuously_probability \
and len(speak_list) < Chat.speak_continuously_max_len:
pre_msg = str(speak_list[-1])
answer = Chat(ChatData(group_id, 0, pre_msg,
pre_msg, cur_time, 0)).answer(False)
if not answer:
break
speak_list.extend(answer)
if random.random() < Chat.speak_poke_probability:
target_id = random.choice(
Chat._message_dict[group_id])['user_id']
speak_list.append(Message('[CQ:poke,qq={}]'.format(target_id)))
return bot_id, group_id, speak_list
return None
@staticmethod
def ban(group_id: int, bot_id: int, ban_raw_message: str, reason: str) -> bool:
"""
禁止以后回复这句话仅对该群有效果
"""
if group_id not in Chat._reply_dict:
return False
ban_reply = None
reply_data = Chat._reply_dict[group_id][bot_id][::-1]
for reply in reply_data:
cur_reply = reply['reply']
# 为空时就直接 ban 最后一条回复
if not ban_raw_message or ban_raw_message in cur_reply:
ban_reply = reply
break
# 这种情况一般是有些 CQ 码,牛牛发送的时候,和被回复的时候,里面的内容不一样
if not ban_reply:
search = re.search(r'(\[CQ:[a-zA-z0-9-_.]+)',
ban_raw_message)
if search:
type_keyword = search.group(1)
for reply in reply_data:
cur_reply = reply['reply']
if type_keyword in cur_reply:
ban_reply = reply
break
if not ban_reply:
return False
pre_keywords = reply['pre_keywords']
keywords = reply['reply_keywords']
# 考虑这句回复是从别的群捞过来的情况,所以这里要分两次 update
# context_mongo.update_one({
# 'keywords': pre_keywords,
# 'answers.keywords': keywords,
# 'answers.group_id': group_id
# }, {
# '$set': {
# 'answers.$.count': -99999
# }
# })
context_mongo.update_one({
'keywords': pre_keywords
}, {
'$push': {
'ban': {
'keywords': keywords,
'group_id': group_id,
'reason': reason,
'time': int(time.time())
}
}
})
if keywords in Chat.blacklist_answer_reserve[group_id]:
Chat.blacklist_answer[group_id].add(keywords)
if keywords in Chat.blacklist_answer_reserve[Chat._blacklist_flag]:
Chat.blacklist_answer[Chat._blacklist_flag].add(
keywords)
else:
Chat.blacklist_answer_reserve[group_id].add(keywords)
return True
@staticmethod
def drink(group_id: int) -> None:
"""
牛牛喝酒仅对该群有效果提高醉酒程度降低回复阈值的概率
"""
Chat._drunkenness_dict[group_id] += 1
@staticmethod
def sober_up(group_id: int) -> bool:
"""
牛牛醒酒仅对该群有效果返回醒酒是否成功
"""
Chat._drunkenness_dict[group_id] -= 1
return Chat._drunkenness_dict[group_id] <= 0
# private:
_reply_dict = defaultdict(lambda: defaultdict(list)) # 牛牛回复的消息缓存,暂未做持久化
_message_dict = {} # 群消息缓存
_drunkenness_dict = defaultdict(int) # 醉酒程度,不同群应用不同的数值
_save_reserve_size = 100 # 保存时,给内存中保留的大小
_late_save_time = 0 # 上次保存(消息数据持久化)的时刻 ( time.time(), 秒 )
_reply_lock = threading.Lock()
_message_lock = threading.Lock()
_blacklist_flag = 114514
def _message_insert(self):
group_id = self.chat_data.group_id
with Chat._message_lock:
if group_id not in Chat._message_dict:
Chat._message_dict[group_id] = []
Chat._message_dict[group_id].append({
'group_id': group_id,
'user_id': self.chat_data.user_id,
'raw_message': self.chat_data.raw_message,
'is_plain_text': self.chat_data.is_plain_text,
'plain_text': self.chat_data.plain_text,
'keywords': self.chat_data.keywords,
'time': self.chat_data.time,
})
cur_time = self.chat_data.time
if Chat._late_save_time == 0:
Chat._late_save_time = cur_time - 1
return
if len(Chat._message_dict[group_id]) > Chat.save_count_threshold:
Chat._sync(cur_time)
elif cur_time - Chat._late_save_time > Chat.save_time_threshold:
Chat._sync(cur_time)
@staticmethod
def _sync(cur_time: int = time.time()):
"""
持久化
"""
with Chat._message_lock:
save_list = [msg
for group_msgs in Chat._message_dict.values()
for msg in group_msgs
if msg['time'] > Chat._late_save_time]
if not save_list:
return
Chat._message_dict = {group_id: group_msgs[-Chat._save_reserve_size:]
for group_id, group_msgs in Chat._message_dict.items()}
Chat._late_save_time = cur_time
message_mongo.insert_many(save_list)
def _context_insert(self, pre_msg):
if not pre_msg:
return
raw_message = self.chat_data.raw_message
# 在复读,不学
if pre_msg['raw_message'] == raw_message:
return
# 回复别人的,不学
if '[CQ:reply,' in raw_message:
return
keywords = self.chat_data.keywords
group_id = self.chat_data.group_id
pre_keywords = pre_msg['keywords']
cur_time = self.chat_data.time
# update_key = {
# 'keywords': pre_keywords,
# 'answers.keywords': keywords,
# 'answers.group_id': group_id
# }
# update_value = {
# '$set': {'time': cur_time},
# '$inc': {'answers.$.count': 1},
# '$push': {'answers.$.messages': raw_message}
# }
# # update_value.update(update_key)
# context_mongo.update_one(
# update_key, update_value, upsert=True)
# 这个 upsert 太难写了搞不定_(:з」∠)_
# 先用 find + insert or update 凑合了
find_key = {'keywords': pre_keywords}
context = context_mongo.find_one(find_key)
if context:
update_value = {
'$set': {
'time': cur_time
},
'$inc': {'count': 1}
}
answer_index = next((idx for idx, answer in enumerate(context['answers'])
if answer['group_id'] == group_id
and answer['keywords'] == keywords), -1)
if answer_index != -1:
update_value['$inc'].update({
f'answers.{answer_index}.count': 1
})
update_value['$set'].update({
f'answers.{answer_index}.time': cur_time
})
# 不是纯文本的时候raw_message 是完全一样的,没必要 push
if self.chat_data.is_plain_text:
update_value['$push'] = {
f'answers.{answer_index}.messages': raw_message
}
else:
update_value['$push'] = {
'answers': {
'keywords': keywords,
'group_id': group_id,
'count': 1,
'time': cur_time,
'messages': [
raw_message
]
}
}
context_mongo.update_one(find_key, update_value)
else:
context = {
'keywords': pre_keywords,
'time': cur_time,
'count': 1,
'answers': [
{
'keywords': keywords,
'group_id': group_id,
'count': 1,
'time': cur_time,
'messages': [
raw_message
]
}
]
}
context_mongo.insert_one(context)
def _context_find(self) -> Optional[Tuple[List[str], str]]:
group_id = self.chat_data.group_id
raw_message = self.chat_data.raw_message
keywords = self.chat_data.keywords
bot_id = self.chat_data.bot_id
# 复读!
if group_id in Chat._message_dict:
group_msgs = Chat._message_dict[group_id]
if len(group_msgs) >= Chat.repeat_threshold and \
all(item['raw_message'] == raw_message
for item in group_msgs[:-Chat.repeat_threshold:-1]):
# 到这里说明当前群里是在复读
group_bot_replies = Chat._reply_dict[group_id][bot_id]
if len(group_bot_replies) and group_bot_replies[-1]['reply'] != raw_message:
return [raw_message, ], keywords
else:
# 复读过一次就不再回复这句话了
return None
context = context_mongo.find_one({'keywords': keywords})
if not context:
return None
if Chat._drunkenness_dict[group_id] > 0:
answer_count_threshold = 1
else:
answer_count_threshold = Chat.answer_threshold
if self.chat_data.to_me:
cross_group_threshold = 1
else:
cross_group_threshold = Chat.cross_group_threshold
ban_keywords = Chat._find_ban_keywords(
context=context, group_id=group_id)
candidate_answers = {}
other_group_cache = {}
answers_count = defaultdict(int)
def candidate_append(dst, answer):
answer_key = answer['keywords']
if answer_key not in dst:
dst[answer_key] = answer
else:
pre_answer = dst[answer_key]
pre_answer['count'] += answer['count']
pre_answer['messages'] += answer['messages']
for answer in context['answers']:
answer_key = answer['keywords']
if answer_key in ban_keywords or answer['count'] < answer_count_threshold:
continue
sample_msg = answer['messages'][0]
if self.chat_data.is_image and '[CQ:' not in sample_msg:
# 图片消息不回复纯文本。图片经常是表情包,后面的纯文本啥都有,很乱
continue
if answer['group_id'] == group_id:
candidate_append(candidate_answers, answer)
# 别的群的 at, 忽略
elif '[CQ:at,qq=' in sample_msg:
continue
else: # 有这么 N 个群都有相同的回复,就作为全局回复
answers_count[answer_key] += 1
cur_count = answers_count[answer_key]
if cur_count < cross_group_threshold: # 没达到阈值前,先缓存
candidate_append(other_group_cache, answer)
elif cur_count == cross_group_threshold: # 刚达到阈值时,将缓存加入
if cur_count > 1:
candidate_append(candidate_answers,
other_group_cache[answer_key])
candidate_append(candidate_answers, answer)
else: # 超过阈值后,加入
candidate_append(candidate_answers, answer)
if not candidate_answers:
return None
final_answer = random.choices(list(candidate_answers.values()), weights=[
# 防止某个回复权重太大,别的都 Roll 不到了
min(answer['count'], 10) for answer in candidate_answers.values()])[0]
answer_str = random.choice(final_answer['messages'])
answer_keywords = final_answer['keywords']
if 0 < answer_str.count('') <= 3 and random.random() < Chat.split_probability:
return answer_str.split(''), answer_keywords
return [answer_str, ], answer_keywords
@staticmethod
def _text_to_speech(text: str) -> Optional[Message]:
# if plugin_config.enable_voice:
# result = tts_client.synthesis(text, options={'per': 111}) # 度小萌
# if not isinstance(result, dict): # error message
# return MessageSegment.record(result)
return Message(f'[CQ:tts,text={text}]')
@staticmethod
def update_global_blacklist() -> None:
Chat._select_blacklist()
keywords_dict = defaultdict(int)
global_blacklist = set()
for _, keywords_list in Chat.blacklist_answer.items():
for keywords in keywords_list:
keywords_dict[keywords] += 1
if keywords_dict[keywords] == Chat.cross_group_threshold:
global_blacklist.add(keywords)
Chat.blacklist_answer[Chat._blacklist_flag] |= global_blacklist
@staticmethod
def _select_blacklist() -> None:
all_blacklist = blacklist_mongo.find()
for item in all_blacklist:
group_id = item['group_id']
if 'answers' in item:
Chat.blacklist_answer[group_id] |= set(item['answers'])
if 'answers_reserve' in item:
Chat.blacklist_answer_reserve[group_id] |= set(
item['answers_reserve'])
@staticmethod
def _sync_blacklist() -> None:
Chat._select_blacklist()
for group_id, answers in Chat.blacklist_answer.items():
if not len(answers):
continue
blacklist_mongo.update_one(
{'group_id': group_id},
{'$set': {'answers': list(answers)}},
upsert=True)
for group_id, answers in Chat.blacklist_answer_reserve.items():
if not len(answers):
continue
if group_id in Chat.blacklist_answer:
answers = answers - Chat.blacklist_answer[group_id]
blacklist_mongo.update_one(
{'group_id': group_id},
{'$set': {'answers_reserve': list(answers)}},
upsert=True)
@staticmethod
def clearup_context() -> None:
"""
清理所有超过 30 天没人说且没有学会的话
"""
cur_time = int(time.time())
expiration = cur_time - 30 * 24 * 3600 # 三十天前
context_mongo.delete_many({
'time': {'$lt': expiration},
'count': {'$lt': Chat.answer_threshold} # lt 是小于,不包括等于
})
all_context = context_mongo.find({
'count': {'$gt': 100},
'$or': [
# 历史遗留问题,老版本的数据没有 clear_time 字段
{"clear_time": {"$exists": False}},
{"clear_time": {"$lt": expiration}}
]
})
for context in all_context:
answers = [ans
for ans in context['answers']
# 历史遗留问题,老版本的数据没有 answers.$.time 字段
if ans['count'] > 1 or ('time' in ans and ans['time'] > expiration)]
context_mongo.update_one({
'keywords': context['keywords']
}, {
'$set': {
'answers': answers,
'clear_time': cur_time
}
})
@staticmethod
def completely_sober():
for key in Chat._drunkenness_dict.keys():
Chat._drunkenness_dict[key] = 0
@staticmethod
def _find_ban_keywords(context, group_id) -> set:
"""
找到在 group_id 群中对应 context 不能回复的关键词
"""
# 全局的黑名单
ban_keywords = Chat.blacklist_answer[Chat._blacklist_flag] | Chat.blacklist_answer[group_id]
# 针对单条回复的黑名单
if 'ban' in context:
ban_count = defaultdict(int)
for ban in context['ban']:
ban_key = ban['keywords']
if ban['group_id'] == group_id or ban['group_id'] == Chat._blacklist_flag:
ban_keywords.add(ban_key)
else:
# 超过 N 个群都把这句话 ban 了,那就全局 ban 掉
ban_count[ban_key] += 1
if ban_count[ban_key] == Chat.cross_group_threshold:
ban_keywords.add(ban_key)
return ban_keywords
# Auto sync on program start
Chat.update_global_blacklist()
def _chat_sync():
Chat._sync()
Chat._sync_blacklist()
# Auto sync on program exit
atexit.register(_chat_sync)
if __name__ == '__main__':
# Chat.clearup_context()
# # while True:
test_data: ChatData = ChatData(
group_id=1234567,
user_id=1111111,
raw_message='完了又有新bug',
plain_text='完了又有新bug',
time=time.time(),
bot_id=0,
)
test_chat: Chat = Chat(test_data)
print(test_chat.answer())
test_chat.learn()
test_answer_data: ChatData = ChatData(
group_id=1234567,
user_id=1111111,
raw_message='完了又有新bug',
plain_text='完了又有新bug',
time=time.time(),
bot_id=0,
)
test_answer: Chat = Chat(test_answer_data)
print(test_chat.answer())
test_answer.learn()
# time.sleep(5)
# print(Chat.speak())

View File

@ -28,6 +28,8 @@ async def draw_ring(per):
plt.savefig('temp.png', transparent=True)
img = Image.open('temp.png').resize((266, 266)).convert('RGBA')
os.remove('temp.png')
plt.cla()
plt.close("all")
return img

View File

@ -40,6 +40,8 @@ async def draw_ring(per, colors):
plt.savefig('temp.png', transparent=True)
img = Image.open('temp.png').resize((378, 378)).convert('RGBA')
os.remove('temp.png')
plt.cla()
plt.close("all")
return img