# ---------------------------------------------------------------------------
# Reconstructed from a newline-stripped `git format-patch` e-mail.
#   commit : b5323fd378b76ea93916270271acb1e73ef67830
#   author : CMHopeSunshine <277073121@qq.com>
#   date   : Sat, 28 May 2022 23:34:15 +0800
#   subject: [PATCH] add group-chat learning/speaking, code clean-up
#   files  : Paimon_Chat/Learning_repeate/main.py  (new, +214)
#            Paimon_Chat/Learning_repeate/model.py (new, +879)
#            Paimon_Info/draw_daily_note.py        (+2)
#            Paimon_Info/draw_month_info.py        (+2)
# ---------------------------------------------------------------------------
# File: Paimon_Chat/Learning_repeate/main.py
import random
import asyncio
import re
import time
import os
import threading

from nonebot import on_message, require, get_bot, logger
from nonebot.exception import ActionFailed
from nonebot.typing import T_State
from nonebot.rule import keyword, to_me, Rule
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import GroupMessageEvent

from nonebot.adapters.onebot.v11 import permission

from .model import Chat
from utils.config import config

# Upper bound on remembered message ids per group (used for de-duplication).
_MESSAGE_ID_CACHE_SIZE = 100

message_id_lock = threading.Lock()
message_id_dict = {}  # group_id -> list of recently seen message ids


async def check_accounts(event: GroupMessageEvent) -> bool:
    """Rule: ignore messages sent by other locally hosted bot accounts.

    Every numeric entry of the ``accounts`` directory (as created by
    nonebot_plugin_gocqhttp) is treated as a bot account id.
    """
    if os.path.exists('accounts'):
        accounts = [int(d) for d in os.listdir('accounts')
                    if d.isnumeric()]
        if event.user_id in accounts:
            return False
    return True


async def get_answer(event: GroupMessageEvent, state: T_State) -> bool:
    """Rule: learn from the message and decide whether there is a reply.

    Stores the reply generator in ``state['answers']`` and returns True when
    the chat model produced an answer; returns False otherwise.
    """
    # Do not react to banned users.
    if event.user_id in config.paimon_chat_ban:
        return False
    chat: Chat = Chat(event)
    to_learn = True

    # With several bot accounts in the same group the same message arrives
    # once per account; make sure it is learned only once.
    with message_id_lock:
        message_id = event.message_id
        group_message = message_id_dict.setdefault(event.group_id, [])
        if message_id in group_message:
            to_learn = False

        group_message.append(message_id)
        # Trim in place so the cached list itself shrinks, keeping the most
        # recent ids. (The previous code rebound a local name to a slice,
        # which left the cache growing without bound.)
        if len(group_message) > _MESSAGE_ID_CACHE_SIZE:
            del group_message[:-_MESSAGE_ID_CACHE_SIZE]

    answers = chat.answer()
    if to_learn:
        chat.learn()

    if answers:
        state['answers'] = answers
        return True
    return False


any_msg = on_message(
    priority=20,
    block=False,
    rule=Rule(check_accounts, get_answer),
    permission=permission.GROUP  # | permission.PRIVATE_FRIEND
)


async def is_shutup(self_id: int, group_id: int) -> bool:
    """Return True when bot ``self_id`` is currently muted in ``group_id``."""
    info = await get_bot(str(self_id)).call_api('get_group_member_info', **{
        'user_id': self_id,
        'group_id': group_id
    })
    flag: bool = info['shut_up_timestamp'] > time.time()

    if flag:
        logger.info(f'repeater:派蒙[{self_id}]在群[{group_id}] 处于禁言状态')

    return flag


@any_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
    """Send every answer produced by ``get_answer`` with a small random delay."""
    delay = random.randint(2, 4)
    for item in state['answers']:
        logger.info(f'repeater:派蒙[{event.self_id}]准备向群[{event.group_id}]回复[{item}]')

        await asyncio.sleep(delay)
        try:
            await any_msg.send(item)
        except ActionFailed:
            # Automatically ban messages that can no longer be sent.
            # Do not enable this while the bot account is risk-controlled.
            shutup = await is_shutup(event.self_id, event.group_id)
            if not shutup:  # not muted, so the message itself is invalid
                logger.info('repeater | bot [{}] ready to ban [{}] in group [{}]'.format(
                    event.self_id, str(item), event.group_id))
                Chat.ban(event.group_id, event.self_id, str(item), 'ActionFailed')
                break
        delay = random.randint(2, 4)


async def is_reply(bot: Bot, event: GroupMessageEvent) -> bool:
    """Rule: the message quotes (replies to) another message."""
    return bool(event.reply)


ban_msg = on_message(
    rule=to_me() & keyword('不可以', '达咩', '不行', 'no') & Rule(is_reply),
    priority=5,
    block=True,
    permission=permission.GROUP_OWNER | permission.GROUP_ADMIN
)
@ban_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent):
    """Ban the quoted bot reply when a group admin objects to it."""
    if '[CQ:reply,' not in event.raw_message:
        return

    raw_message = ''
    for item in event.reply.message:
        raw_reply = str(item)
        # Meant to strip the url / subType fields of image segments.
        # NOTE(review): `,url=*` matches ",url" followed by zero or more "=",
        # not ",url=<anything>"; left untouched because Chat.ban() falls back
        # to matching on the CQ type when the exact text is not found.
        raw_message += re.sub(r'(\[CQ\:.+)(?:,url=*)(\])',
                              r'\1\2', raw_reply)

    logger.info(f'repeater:派蒙[{event.self_id}] ready to ban [{raw_message}] in group [{event.group_id}]')

    if Chat.ban(event.group_id, event.self_id, raw_message, str(event.user_id)):
        msg_send = ['派蒙知道错了...达咩!', '派蒙不会再这么说了...', '果面呐噻,派蒙说错话了...']
        await ban_msg.finish(random.choice(msg_send))


scheduler = require('nonebot_plugin_apscheduler').scheduler


async def message_is_ban(bot: Bot, event: GroupMessageEvent) -> bool:
    """Rule: the plain text is exactly the "ban the latest reply" command."""
    return event.get_plaintext().strip() == '不可以发这个'


ban_msg_latest = on_message(
    rule=to_me() & Rule(message_is_ban),
    priority=5,
    block=True,
    permission=permission.GROUP_OWNER | permission.GROUP_ADMIN
)


@ban_msg_latest.handle()
async def _(bot: Bot, event: GroupMessageEvent):
    """Ban the most recent reply the bot sent to this group."""
    logger.info(
        f'repeater:派蒙[{event.self_id}]把群[{event.group_id}]最后的回复ban了')

    # An empty ban text makes Chat.ban() pick the latest reply.
    if Chat.ban(event.group_id, event.self_id, '', str(event.user_id)):
        msg_send = ['派蒙知道错了...达咩!', '派蒙不会再这么说了...', '果面呐噻,派蒙说错话了...']
        await ban_msg_latest.finish(random.choice(msg_send))


@scheduler.scheduled_job('interval', seconds=5, misfire_grace_time=5)
async def speak_up():
    """Periodically let the bot speak spontaneously in a quiet group."""
    ret = Chat.speak()
    if not ret:
        return

    bot_id, group_id, messages = ret

    for msg in messages:
        # Log the message actually being sent, not the whole list.
        logger.info(f'repeater:派蒙[{bot_id}]准备向群[{group_id}]发送消息[{msg}]')
        await get_bot(str(bot_id)).call_api('send_group_msg', **{
            'message': msg,
            'group_id': group_id
        })
        await asyncio.sleep(random.randint(2, 4))


update_scheduler = require('nonebot_plugin_apscheduler').scheduler


async def is_drink_msg(bot: Bot, event: GroupMessageEvent) -> bool:
    """Rule: the plain text is one of the "drink" commands."""
    return event.get_plaintext().strip() in ['派蒙干杯', '应急食品开餐', '派蒙干饭']


drink_msg = on_message(
    rule=Rule(is_drink_msg),
    priority=5,
    block=True,
    permission=permission.GROUP_OWNER | permission.GROUP_ADMIN
)


@drink_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent):
    """Make the bot "drunk" in this group for a random 60-600 seconds."""
    drunk_duration = random.randint(60, 600)
    logger.info(f'repeater:派蒙[{event.self_id}]即将在群[{event.group_id}]喝醉,在[{drunk_duration}秒]后醒来')
    Chat.drink(event.group_id)
    try:
        await drink_msg.send('呀,旅行者。你今天走起路来,怎么看着摇摇晃晃的?')
    except ActionFailed:
        pass  # best effort: the state change matters, not the announcement

    await asyncio.sleep(drunk_duration)
    ret = Chat.sober_up(event.group_id)
    if ret:
        logger.info(f'repeater:派蒙[{event.self_id}]在群[{event.group_id}]醒酒了')
        await drink_msg.finish('呃...头好疼...下次不能喝那么多了...')


@update_scheduler.scheduled_job('cron', hour='4')
def update_data():
    """Daily maintenance at 04:00: prune stale contexts, reset drunkenness."""
    Chat.clearup_context()
    Chat.completely_sober()


# ---------------------------------------------------------------------------
# File: Paimon_Chat/Learning_repeate/model.py (new file in the patch)
# ---------------------------------------------------------------------------
from typing import Generator, List, Optional, Union, Tuple, Dict, Any
from functools import cached_property, cmp_to_key
from dataclasses import dataclass
from collections import defaultdict

import jieba_fast.analyse
import threading
import pypinyin
import pymongo
import time
import random
import re
import atexit

from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent
from nonebot.adapters.onebot.v11 import Message

from utils.config import config

mongo_client = pymongo.MongoClient(config.paimon_mongodb_url)

mongo_db = mongo_client['PaimonChat']

message_mongo = mongo_db['message']
message_mongo.create_index(name='time_index',
                           keys=[('time', pymongo.DESCENDING)])

context_mongo = mongo_db['context']
context_mongo.create_index(name='keywords_index',
                           keys=[('keywords', pymongo.HASHED)])
context_mongo.create_index(name='count_index',
                           keys=[('count', pymongo.DESCENDING)])
context_mongo.create_index(name='time_index',
                           keys=[('time', pymongo.DESCENDING)])
context_mongo.create_index(name='answers_index',
                           keys=[('answers.group_id', pymongo.TEXT),
                                 ('answers.keywords', pymongo.TEXT)],
                           default_language='none')

blacklist_mongo = mongo_db['blacklist']
blacklist_mongo.create_index(name='group_index',
                             keys=[('group_id', pymongo.HASHED)])


@dataclass
class ChatData:
    """One chat message plus lazily derived features (keywords, flags)."""

    group_id: int
    user_id: int
    raw_message: str
    plain_text: str
    time: int
    bot_id: int

    _keywords_size: int = 3  # number of tags requested from jieba

    @cached_property
    def is_plain_text(self) -> bool:
        """True when the message has text and contains no CQ code."""
        return '[CQ:' not in self.raw_message and len(self.plain_text) != 0

    @cached_property
    def is_image(self) -> bool:
        """True when the message contains an image or face CQ code."""
        return '[CQ:image,' in self.raw_message or '[CQ:face,' in self.raw_message

    @cached_property
    def keywords(self) -> str:
        """Key used to look the message up in the context collection.

        Messages without any plain text are keyed by their raw message; text
        yielding fewer than two tags is keyed by the full plain text;
        otherwise the extracted tags are joined with spaces.
        """
        if not self.is_plain_text and len(self.plain_text) == 0:
            return self.raw_message

        keywords_list = jieba_fast.analyse.extract_tags(
            self.plain_text, topK=ChatData._keywords_size)
        if len(keywords_list) < 2:
            return self.plain_text
        return ' '.join(keywords_list)

    @cached_property
    def keywords_pinyin(self) -> str:
        """Lower-cased pinyin transliteration of ``keywords``."""
        return ''.join([item[0] for item in pypinyin.pinyin(
            self.keywords, style=pypinyin.NORMAL, errors='default')]).lower()

    @cached_property
    def to_me(self) -> bool:
        """True when the text starts with the bot's name."""
        return self.plain_text.startswith('派蒙')


class Chat:
    """Learns message -> reply contexts from group chats and replays them."""

    # Answer threshold: the smaller, the more talkative the bot is.
    answer_threshold = config.paimon_answer_threshold
    # Upper threshold: a normal context is not repeated that often; such
    # answers are usually another bot's replies.
    answer_limit_threshold = config.paimon_answer_limit_threshold
    # An answer seen in N groups becomes a global (cross-group) answer.
    cross_group_threshold = config.paimon_cross_group_threshold
    # Number of identical consecutive messages that triggers a repeat.
    repeat_threshold = config.paimon_repeat_threshold
    # Threshold for spontaneous speaking: the smaller, the more talkative.
    speak_threshold = config.paimon_speak_threshold

    # Probability of being "drunk" (when the answer threshold is not met).
    drunk_probability = config.paimon_drunk_probability
    split_probability = 0.5  # probability of splitting a reply at commas
    # Probability of answering with voice (plain-text replies only).
    voice_probability = config.paimon_voice_probability
    # Probability of continuing to speak after a spontaneous message.
    speak_continuously_probability = config.paimon_speak_continuously_probability
    # Probability of appending a random poke to a spontaneous message.
    speak_poke_probability = config.paimon_speak_poke_probability
    # Maximum number of messages in one spontaneous burst.
    speak_continuously_max_len = config.paimon_speak_continuously_max_len

    save_time_threshold = 3600  # persist at least this often (seconds)
    # Persist once a single group caches more than this many messages
    # (combined with the time threshold by OR).
    save_count_threshold = 1000

    blacklist_answer = defaultdict(set)
    blacklist_answer_reserve = defaultdict(set)

    def __init__(self, data: Union[ChatData, GroupMessageEvent, PrivateMessageEvent]):
        """Wrap a ``ChatData`` or build one from a OneBot message event.

        For private messages the sender's user id is used as the group id.
        """
        if isinstance(data, ChatData):
            self.chat_data = data
        elif isinstance(data, (GroupMessageEvent, PrivateMessageEvent)):
            if isinstance(data, GroupMessageEvent):
                group_id = data.group_id
            else:
                group_id = data.user_id
            self.chat_data = ChatData(
                group_id=group_id,
                user_id=data.user_id,
                # Drop the image subType field: it often differs for the
                # same picture and would break equality checks.
                raw_message=re.sub(
                    r',subType=\d+\]',
                    r']',
                    data.raw_message),
                plain_text=data.get_plaintext(),
                time=data.time,
                bot_id=data.self_id,
            )

    def learn(self) -> bool:
        """Learn this message as an answer to the preceding message(s).

        Returns False for blank messages, True otherwise.
        """
        if len(self.chat_data.raw_message.strip()) == 0:
            return False

        group_id = self.chat_data.group_id
        if group_id in Chat._message_dict:
            group_msgs = Chat._message_dict[group_id]
            group_pre_msg = group_msgs[-1] if group_msgs else None

            # Previous message in the group.
            self._context_insert(group_pre_msg)

            user_id = self.chat_data.user_id
            if group_pre_msg and group_pre_msg['user_id'] != user_id:
                # The same user's own previous message, if it is among the
                # most recent cached messages of the group.
                for msg in group_msgs[:-3:-1]:
                    if msg['user_id'] == user_id:
                        self._context_insert(msg)
                        break

        self._message_insert()
        return True

    def answer(self, with_limit: bool = True) -> Optional[Generator[Message, None, None]]:
        """Return a generator of reply messages, or None when not replying.

        :param with_limit: when True, reply at most once every 6 seconds.
        """
        group_id = self.chat_data.group_id
        bot_id = self.chat_data.bot_id
        group_bot_replies = Chat._reply_dict[group_id][bot_id]

        if with_limit and len(group_bot_replies):
            latest_reply = group_bot_replies[-1]
            # Rate limit: at most one reply every 6 seconds.
            if int(time.time()) - latest_reply['time'] < 6:
                return None

        results = self._context_find()
        if not results:
            return None

        raw_message = self.chat_data.raw_message
        keywords = self.chat_data.keywords
        # Placeholder entry so the rate limit applies before the generator
        # is actually consumed.
        with Chat._reply_lock:
            group_bot_replies.append({
                'time': int(time.time()),
                'pre_raw_message': raw_message,
                'pre_keywords': keywords,
                'reply': '[PaimonChat: Reply]',  # flag
                'reply_keywords': '[PaimonChat: Reply]',  # flag
            })

        def yield_results(results: Tuple[List[str], str]) -> Generator[Message, None, None]:
            answer_list, answer_keywords = results
            replies = Chat._reply_dict[group_id][bot_id]
            for item in answer_list:
                with Chat._reply_lock:
                    replies.append({
                        'time': int(time.time()),
                        'pre_raw_message': raw_message,
                        'pre_keywords': keywords,
                        'reply': item,
                        'reply_keywords': answer_keywords,
                    })
                if '[CQ:' not in item and len(item) > 1 \
                        and random.random() < Chat.voice_probability:
                    yield Chat._text_to_speech(item)
                else:
                    yield Message(item)

            # Trim the cache in place; rebinding a local name (as before)
            # never shrank the shared list.
            with Chat._reply_lock:
                del replies[:-Chat._save_reserve_size]

        return yield_results(results)

    @staticmethod
    def speak() -> Optional[Tuple[int, int, List[Message]]]:
        """Pick a group that went quiet and something to say there.

        Returns ``(bot_id, group_id, messages)`` or None when the bot should
        stay silent.
        """
        basic_msgs_len = 10
        basic_delay = 600

        def group_popularity_cmp(lhs: Tuple[int, List[Dict[str, Any]]],
                                 rhs: Tuple[int, List[Dict[str, Any]]]) -> int:

            def cmp(a: Any, b: Any):
                return (a > b) - (a < b)

            lhs_group_id, lhs_msgs = lhs
            rhs_group_id, rhs_msgs = rhs

            lhs_len = len(lhs_msgs)
            rhs_len = len(rhs_msgs)

            # Drunkenness defaults to 0; add 1 so the product is not zeroed.
            lhs_drunkenness = Chat._drunkenness_dict[lhs_group_id] + 1
            rhs_drunkenness = Chat._drunkenness_dict[rhs_group_id] + 1

            if lhs_len < basic_msgs_len or rhs_len < basic_msgs_len:
                return cmp(lhs_len * lhs_drunkenness,
                           rhs_len * rhs_drunkenness)

            lhs_duration = lhs_msgs[-1]['time'] - lhs_msgs[0]['time']
            rhs_duration = rhs_msgs[-1]['time'] - rhs_msgs[0]['time']

            if not lhs_duration or not rhs_duration:
                return cmp(lhs_len, rhs_len)

            return cmp(lhs_len * lhs_drunkenness / lhs_duration,
                       rhs_len * rhs_drunkenness / rhs_duration)

        # Order groups by chat activity.
        popularity = sorted(Chat._message_dict.items(),
                            key=cmp_to_key(group_popularity_cmp))

        cur_time = time.time()
        for group_id, group_msgs in popularity:
            group_replies = Chat._reply_dict[group_id]
            if not len(group_replies) or len(group_msgs) < basic_msgs_len:
                continue

            # All bot accounts usually reply together, so their last reply
            # times should match; just look at the first one.
            group_replies_front = list(group_replies.values())[0]
            if not len(group_replies_front) or \
                    group_replies_front[-1]['time'] > group_msgs[-1]['time']:
                continue

            msgs_len = len(group_msgs)
            latest_time = group_msgs[-1]['time']
            duration = latest_time - group_msgs[0]['time']
            avg_interval = duration / msgs_len

            # Only speak after the group has been silent for N times its
            # average message interval (plus a base delay).
            if cur_time - latest_time < avg_interval * Chat.speak_threshold + basic_delay:
                continue

            # Append a flag so a very active group without any usable
            # context is not queried again on every tick.
            with Chat._reply_lock:
                group_replies_front.append({
                    'time': int(cur_time),
                    'pre_raw_message': '[PaimonChat: Speak]',
                    'pre_keywords': '[PaimonChat: Speak]',
                    'reply': '[PaimonChat: Speak]',
                    'reply_keywords': '[PaimonChat: Speak]',
                })

            available_time = cur_time - 24 * 3600
            speak_context = context_mongo.aggregate([
                {
                    '$match': {
                        # These two conditions only speed the lookup up;
                        # they do not change the result.
                        'count': {
                            '$gt': Chat.answer_threshold
                        },
                        'time': {
                            '$gt': available_time
                        },
                        'answers.group_id': group_id,
                        'answers.time': {
                            '$gt': available_time
                        },
                        'answers.count': {
                            '$gt': Chat.answer_threshold
                        }
                    }
                }, {
                    '$sample': {'size': 1}  # one random document
                }
            ])

            speak_context = list(speak_context)
            if not speak_context:
                continue

            ban_keywords = Chat._find_ban_keywords(
                context=speak_context[0], group_id=group_id)
            messages = [answer['messages']
                        for answer in speak_context[0]['answers']
                        if answer['count'] >= Chat.answer_threshold
                        and answer['keywords'] not in ban_keywords
                        and answer['group_id'] == group_id]

            if not messages:
                continue

            # Bot id 0 is the placeholder used for simulated follow-up
            # answers below; skip the group if no real bot is known.
            bot_ids = [bid for bid in group_replies.keys() if bid]
            if not bot_ids:
                continue

            speak = random.choice(random.choice(messages))
            bot_id = random.choice(bot_ids)

            with Chat._reply_lock:
                group_replies[bot_id].append({
                    'time': int(cur_time),
                    'pre_raw_message': '[PaimonChat: Speak]',
                    'pre_keywords': '[PaimonChat: Speak]',
                    'reply': speak,
                    'reply_keywords': '[PaimonChat: Speak]',
                })

            speak_list = [Message(speak), ]
            while random.random() < Chat.speak_continuously_probability \
                    and len(speak_list) < Chat.speak_continuously_max_len:
                pre_msg = str(speak_list[-1])
                answer = Chat(ChatData(group_id, 0, pre_msg,
                                       pre_msg, cur_time, 0)).answer(False)
                if not answer:
                    break
                speak_list.extend(answer)

            if random.random() < Chat.speak_poke_probability:
                target_id = random.choice(
                    Chat._message_dict[group_id])['user_id']
                speak_list.append(Message('[CQ:poke,qq={}]'.format(target_id)))

            return bot_id, group_id, speak_list

        return None

    @staticmethod
    def ban(group_id: int, bot_id: int, ban_raw_message: str, reason: str) -> bool:
        """Forbid a reply in this group; returns whether a reply was banned.

        An empty ``ban_raw_message`` bans the most recent reply.
        """
        if group_id not in Chat._reply_dict:
            return False

        ban_reply = None
        reply_data = Chat._reply_dict[group_id][bot_id][::-1]

        for reply in reply_data:
            cur_reply = reply['reply']
            # Empty text: ban the latest reply directly.
            if not ban_raw_message or ban_raw_message in cur_reply:
                ban_reply = reply
                break

        # Some CQ codes differ between what the bot sent and what the quote
        # contains; fall back to matching on the CQ type only.
        if not ban_reply:
            search = re.search(r'(\[CQ:[a-zA-Z0-9_.-]+)',
                               ban_raw_message)
            if search:
                type_keyword = search.group(1)
                for reply in reply_data:
                    cur_reply = reply['reply']
                    if type_keyword in cur_reply:
                        ban_reply = reply
                        break

        if not ban_reply:
            return False

        pre_keywords = ban_reply['pre_keywords']
        keywords = ban_reply['reply_keywords']

        # The reply may have come from another group, so the ban is stored
        # on the context document rather than on one of its answers.
        context_mongo.update_one({
            'keywords': pre_keywords
        }, {
            '$push': {
                'ban': {
                    'keywords': keywords,
                    'group_id': group_id,
                    'reason': reason,
                    'time': int(time.time())
                }
            }
        })
        # First ban goes to the reserve list; a second ban promotes it.
        if keywords in Chat.blacklist_answer_reserve[group_id]:
            Chat.blacklist_answer[group_id].add(keywords)
            if keywords in Chat.blacklist_answer_reserve[Chat._blacklist_flag]:
                Chat.blacklist_answer[Chat._blacklist_flag].add(
                    keywords)
        else:
            Chat.blacklist_answer_reserve[group_id].add(keywords)

        return True

    @staticmethod
    def drink(group_id: int) -> None:
        """Raise the drunkenness level of one group."""
        Chat._drunkenness_dict[group_id] += 1

    @staticmethod
    def sober_up(group_id: int) -> bool:
        """Lower the drunkenness level; True when the group is sober again."""
        Chat._drunkenness_dict[group_id] -= 1
        return Chat._drunkenness_dict[group_id] <= 0

    # private:
    # Cache of bot replies: group_id -> bot_id -> list (not persisted).
    _reply_dict = defaultdict(lambda: defaultdict(list))
    _message_dict = {}  # group message cache: group_id -> list of messages
    _drunkenness_dict = defaultdict(int)  # drunkenness level per group

    _save_reserve_size = 100  # entries kept in memory after trimming
    _late_save_time = 0  # time of the last message persistence (seconds)

    _reply_lock = threading.Lock()
    _message_lock = threading.Lock()
    _blacklist_flag = 114514  # pseudo group id of the global blacklist

    def _message_insert(self):
        """Cache this message and persist the cache when a threshold hits."""
        group_id = self.chat_data.group_id

        with Chat._message_lock:
            if group_id not in Chat._message_dict:
                Chat._message_dict[group_id] = []

            Chat._message_dict[group_id].append({
                'group_id': group_id,
                'user_id': self.chat_data.user_id,
                'raw_message': self.chat_data.raw_message,
                'is_plain_text': self.chat_data.is_plain_text,
                'plain_text': self.chat_data.plain_text,
                'keywords': self.chat_data.keywords,
                'time': self.chat_data.time,
            })

        # _sync() takes the (non-reentrant) lock itself, so it is called
        # only after the lock above has been released.
        cur_time = self.chat_data.time
        if Chat._late_save_time == 0:
            Chat._late_save_time = cur_time - 1
            return

        if len(Chat._message_dict[group_id]) > Chat.save_count_threshold:
            Chat._sync(cur_time)

        elif cur_time - Chat._late_save_time > Chat.save_time_threshold:
            Chat._sync(cur_time)

    @staticmethod
    def _sync(cur_time: Optional[float] = None):
        """Persist messages cached since the last sync and trim the cache.

        :param cur_time: timestamp recorded as the new "last saved" time;
            defaults to the current time at call time. (A ``time.time()``
            default would be evaluated only once, at definition time.)
        """
        if cur_time is None:
            cur_time = time.time()

        with Chat._message_lock:
            save_list = [msg
                         for group_msgs in Chat._message_dict.values()
                         for msg in group_msgs
                         if msg['time'] > Chat._late_save_time]
            if not save_list:
                return

            Chat._message_dict = {group_id: group_msgs[-Chat._save_reserve_size:]
                                  for group_id, group_msgs in Chat._message_dict.items()}

            Chat._late_save_time = cur_time

        message_mongo.insert_many(save_list)

    def _context_insert(self, pre_msg):
        """Record this message as an answer to ``pre_msg`` in MongoDB."""
        if not pre_msg:
            return

        raw_message = self.chat_data.raw_message

        # A repeat of the previous message: nothing to learn.
        if pre_msg['raw_message'] == raw_message:
            return

        # A quote-reply to somebody: nothing to learn.
        if '[CQ:reply,' in raw_message:
            return

        keywords = self.chat_data.keywords
        group_id = self.chat_data.group_id
        pre_keywords = pre_msg['keywords']
        cur_time = self.chat_data.time

        # A single upsert is awkward here, so this is find + insert/update.
        find_key = {'keywords': pre_keywords}
        context = context_mongo.find_one(find_key)
        if context:
            update_value = {
                '$set': {
                    'time': cur_time
                },
                '$inc': {'count': 1}
            }
            answer_index = next((idx for idx, answer in enumerate(context['answers'])
                                 if answer['group_id'] == group_id
                                 and answer['keywords'] == keywords), -1)
            if answer_index != -1:
                update_value['$inc'].update({
                    f'answers.{answer_index}.count': 1
                })
                update_value['$set'].update({
                    f'answers.{answer_index}.time': cur_time
                })
                # Non-text messages share the exact same raw_message, so
                # there is no point in pushing them again.
                if self.chat_data.is_plain_text:
                    update_value['$push'] = {
                        f'answers.{answer_index}.messages': raw_message
                    }
            else:
                update_value['$push'] = {
                    'answers': {
                        'keywords': keywords,
                        'group_id': group_id,
                        'count': 1,
                        'time': cur_time,
                        'messages': [
                            raw_message
                        ]
                    }
                }

            context_mongo.update_one(find_key, update_value)
        else:
            context = {
                'keywords': pre_keywords,
                'time': cur_time,
                'count': 1,
                'answers': [
                    {
                        'keywords': keywords,
                        'group_id': group_id,
                        'count': 1,
                        'time': cur_time,
                        'messages': [
                            raw_message
                        ]
                    }
                ]
            }
            context_mongo.insert_one(context)

    def _context_find(self) -> Optional[Tuple[List[str], str]]:
        """Choose a reply; returns ``(messages, answer_keywords)`` or None."""
        group_id = self.chat_data.group_id
        raw_message = self.chat_data.raw_message
        keywords = self.chat_data.keywords
        bot_id = self.chat_data.bot_id

        # Repeat when the group itself is repeating this message.
        if group_id in Chat._message_dict:
            group_msgs = Chat._message_dict[group_id]
            if len(group_msgs) >= Chat.repeat_threshold and \
                    all(item['raw_message'] == raw_message
                        for item in group_msgs[:-Chat.repeat_threshold:-1]):
                group_bot_replies = Chat._reply_dict[group_id][bot_id]
                if len(group_bot_replies) and group_bot_replies[-1]['reply'] != raw_message:
                    return [raw_message, ], keywords
                # Already repeated once: do not answer this message again.
                return None

        context = context_mongo.find_one({'keywords': keywords})

        if not context:
            return None

        if Chat._drunkenness_dict[group_id] > 0:
            answer_count_threshold = 1
        else:
            answer_count_threshold = Chat.answer_threshold

        if self.chat_data.to_me:
            cross_group_threshold = 1
        else:
            cross_group_threshold = Chat.cross_group_threshold

        ban_keywords = Chat._find_ban_keywords(
            context=context, group_id=group_id)

        candidate_answers = {}
        other_group_cache = {}
        answers_count = defaultdict(int)

        def candidate_append(dst, answer):
            answer_key = answer['keywords']
            if answer_key not in dst:
                dst[answer_key] = answer
            else:
                pre_answer = dst[answer_key]
                pre_answer['count'] += answer['count']
                pre_answer['messages'] += answer['messages']

        for answer in context['answers']:
            answer_key = answer['keywords']
            if answer_key in ban_keywords or answer['count'] < answer_count_threshold:
                continue

            sample_msg = answer['messages'][0]
            if self.chat_data.is_image and '[CQ:' not in sample_msg:
                # Do not answer pictures with plain text: pictures are often
                # stickers and the text that follows them is arbitrary.
                continue

            if answer['group_id'] == group_id:
                candidate_append(candidate_answers, answer)
            # An @-mention from another group: ignore.
            elif '[CQ:at,qq=' in sample_msg:
                continue
            else:  # the same answer in N groups becomes a global answer
                answers_count[answer_key] += 1
                cur_count = answers_count[answer_key]
                if cur_count < cross_group_threshold:  # cache until reached
                    candidate_append(other_group_cache, answer)
                elif cur_count == cross_group_threshold:  # flush the cache
                    if cur_count > 1:
                        candidate_append(candidate_answers,
                                         other_group_cache[answer_key])
                    candidate_append(candidate_answers, answer)
                else:  # beyond the threshold: add directly
                    candidate_append(candidate_answers, answer)

        if not candidate_answers:
            return None

        final_answer = random.choices(list(candidate_answers.values()), weights=[
            # Cap the weight so one answer cannot crowd out all the others.
            min(answer['count'], 10) for answer in candidate_answers.values()])[0]
        answer_str = random.choice(final_answer['messages'])
        answer_keywords = final_answer['keywords']

        if 0 < answer_str.count(',') <= 3 and random.random() < Chat.split_probability:
            return answer_str.split(','), answer_keywords
        return [answer_str, ], answer_keywords

    @staticmethod
    def _text_to_speech(text: str) -> Optional[Message]:
        """Wrap ``text`` in a ``[CQ:tts]`` segment."""
        return Message(f'[CQ:tts,text={text}]')

    @staticmethod
    def update_global_blacklist() -> None:
        """Reload blacklists and promote widely banned keywords to global."""
        Chat._select_blacklist()

        keywords_dict = defaultdict(int)
        global_blacklist = set()
        for _, keywords_list in Chat.blacklist_answer.items():
            for keywords in keywords_list:
                keywords_dict[keywords] += 1
                if keywords_dict[keywords] == Chat.cross_group_threshold:
                    global_blacklist.add(keywords)

        Chat.blacklist_answer[Chat._blacklist_flag] |= global_blacklist

    @staticmethod
    def _select_blacklist() -> None:
        """Merge the blacklists stored in MongoDB into the in-memory sets."""
        all_blacklist = blacklist_mongo.find()

        for item in all_blacklist:
            group_id = item['group_id']
            if 'answers' in item:
                Chat.blacklist_answer[group_id] |= set(item['answers'])
            if 'answers_reserve' in item:
                Chat.blacklist_answer_reserve[group_id] |= set(
                    item['answers_reserve'])

    @staticmethod
    def _sync_blacklist() -> None:
        """Write the in-memory blacklists back to MongoDB."""
        Chat._select_blacklist()

        for group_id, answers in Chat.blacklist_answer.items():
            if not len(answers):
                continue
            blacklist_mongo.update_one(
                {'group_id': group_id},
                {'$set': {'answers': list(answers)}},
                upsert=True)

        for group_id, answers in Chat.blacklist_answer_reserve.items():
            if not len(answers):
                continue
            if group_id in Chat.blacklist_answer:
                answers = answers - Chat.blacklist_answer[group_id]

            blacklist_mongo.update_one(
                {'group_id': group_id},
                {'$set': {'answers_reserve': list(answers)}},
                upsert=True)

    @staticmethod
    def clearup_context() -> None:
        """Delete contexts unused for 30 days that never got learned."""
        cur_time = int(time.time())
        expiration = cur_time - 30 * 24 * 3600  # thirty days ago

        context_mongo.delete_many({
            'time': {'$lt': expiration},
            'count': {'$lt': Chat.answer_threshold}  # strictly less than
        })

        all_context = context_mongo.find({
            'count': {'$gt': 100},
            '$or': [
                # Legacy documents have no clear_time field.
                {"clear_time": {"$exists": False}},
                {"clear_time": {"$lt": expiration}}
            ]
        })
        for context in all_context:
            answers = [ans
                       for ans in context['answers']
                       # Legacy answers have no time field.
                       if ans['count'] > 1 or ('time' in ans and ans['time'] > expiration)]
            context_mongo.update_one({
                'keywords': context['keywords']
            }, {
                '$set': {
                    'answers': answers,
                    'clear_time': cur_time
                }
            })

    @staticmethod
    def completely_sober():
        """Reset the drunkenness level of every group to zero."""
        for key in Chat._drunkenness_dict.keys():
            Chat._drunkenness_dict[key] = 0

    @staticmethod
    def _find_ban_keywords(context, group_id) -> set:
        """Return the answer keywords that must not be used in ``group_id``."""
        # Global blacklist plus this group's blacklist (a new set).
        ban_keywords = Chat.blacklist_answer[Chat._blacklist_flag] | Chat.blacklist_answer[group_id]
        # Bans attached to this particular context.
        if 'ban' in context:
            ban_count = defaultdict(int)
            for ban in context['ban']:
                ban_key = ban['keywords']
                if ban['group_id'] == group_id or ban['group_id'] == Chat._blacklist_flag:
                    ban_keywords.add(ban_key)
                else:
                    # Banned by N other groups: treat it as a global ban.
                    ban_count[ban_key] += 1
                    if ban_count[ban_key] == Chat.cross_group_threshold:
                        ban_keywords.add(ban_key)
        return ban_keywords


# Auto sync on program start
Chat.update_global_blacklist()


def _chat_sync():
    """Flush cached messages and blacklists to MongoDB."""
    Chat._sync()
    Chat._sync_blacklist()


# Auto sync on program exit
atexit.register(_chat_sync)


if __name__ == '__main__':
    # Manual smoke test: ask for an answer, then learn, twice in a row.
    test_data: ChatData = ChatData(
        group_id=1234567,
        user_id=1111111,
        raw_message='完了又有新bug',
        plain_text='完了又有新bug',
        time=int(time.time()),  # ChatData.time is declared as int
        bot_id=0,
    )

    test_chat: Chat = Chat(test_data)

    print(test_chat.answer())
    test_chat.learn()

    test_answer_data: ChatData = ChatData(
        group_id=1234567,
        user_id=1111111,
        raw_message='完了又有新bug',
        plain_text='完了又有新bug',
        time=int(time.time()),
        bot_id=0,
    )

    test_answer: Chat = Chat(test_answer_data)
    # Query the second instance (the original printed test_chat again).
    print(test_answer.answer())
    test_answer.learn()

# ---------------------------------------------------------------------------
# Remaining hunks of the patch. They touch other files and show only part of
# `draw_ring`, so they are kept verbatim as reference:
#
# diff --git a/Paimon_Info/draw_daily_note.py b/Paimon_Info/draw_daily_note.py
# index 4f39640..34c6d49 100644
# --- a/Paimon_Info/draw_daily_note.py
# +++ b/Paimon_Info/draw_daily_note.py
# @@ -28,6 +28,8 @@ async def draw_ring(per):
#      plt.savefig('temp.png', transparent=True)
#      img = Image.open('temp.png').resize((266, 266)).convert('RGBA')
#      os.remove('temp.png')
# +    plt.cla()
# +    plt.close("all")
#      return img
#
# diff --git a/Paimon_Info/draw_month_info.py b/Paimon_Info/draw_month_info.py
# index a9c73a3..0c68750 100644
# --- a/Paimon_Info/draw_month_info.py
# +++ b/Paimon_Info/draw_month_info.py
# @@ -40,6 +40,8 @@ async def draw_ring(per, colors):
#      plt.savefig('temp.png', transparent=True)
#      img = Image.open('temp.png').resize((378, 378)).convert('RGBA')
#      os.remove('temp.png')
# +    plt.cla()
# +    plt.close("all")
#      return img
# ---------------------------------------------------------------------------