mirror of
https://github.com/xuthus83/LittlePaimon.git
synced 2024-12-16 13:40:53 +08:00
新增群聊记录学习发言,优化代码
This commit is contained in:
parent
de09bb113f
commit
b5323fd378
214
Paimon_Chat/Learning_repeate/main.py
Normal file
214
Paimon_Chat/Learning_repeate/main.py
Normal file
@ -0,0 +1,214 @@
|
||||
import random
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
|
||||
from nonebot import on_message, require, get_bot, logger
|
||||
from nonebot.exception import ActionFailed
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.rule import keyword, to_me, Rule
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters.onebot.v11 import GroupMessageEvent
|
||||
|
||||
from nonebot.adapters.onebot.v11 import permission
|
||||
|
||||
from .model import Chat
|
||||
from utils.config import config
|
||||
|
||||
# Serialises access to message_id_dict.
message_id_lock = threading.Lock()
# Maps a group id to the message ids already seen in that group; used by
# get_answer so one message is not learned more than once.
message_id_dict = dict()
|
||||
|
||||
|
||||
async def check_accounts(event: GroupMessageEvent) -> bool:
    """Rule: reject messages sent by other locally hosted bot accounts.

    Every all-digit entry of the ``accounts`` directory (relative to the
    working directory) is treated as a bot account id.

    :param event: the incoming group message.
    :return: ``False`` when the sender is one of those accounts, else ``True``.
    """
    # isdir() instead of exists(): a plain file named "accounts" would make
    # os.listdir() raise.
    if not os.path.isdir('accounts'):
        return True
    # isdecimal() instead of isnumeric(): isnumeric() also accepts characters
    # such as superscripts or fractions, which int() cannot parse.
    is_bot_account = any(
        int(name) == event.user_id
        for name in os.listdir('accounts')
        if name.isdecimal()
    )
    return not is_bot_account
|
||||
|
||||
|
||||
async def get_answer(event: GroupMessageEvent, state: T_State) -> bool:
    """Rule: learn from the message and decide whether there is a reply.

    :param event: the incoming group message.
    :param state: matcher state; on success the answers are stored under
        ``state['answers']`` for the handler to send.
    :return: ``True`` when ``Chat.answer`` produced something, else ``False``.
    """
    # Ignore messages from users on the configured ban list.
    if event.user_id in config.paimon_chat_ban:
        return False

    chat: Chat = Chat(event)

    # With several bot accounts in the same group the same message arrives
    # once per account; remember its id so it is only learned once.
    with message_id_lock:
        seen_ids = message_id_dict.setdefault(event.group_id, [])
        to_learn = event.message_id not in seen_ids
        if to_learn:
            seen_ids.append(event.message_id)
            if len(seen_ids) > 100:
                # Trim in place and keep the newest ids. The previous code
                # rebound a local name, so the stored list grew without bound.
                del seen_ids[:-10]

    answers = chat.answer()
    if to_learn:
        chat.learn()

    if answers:
        state['answers'] = answers
        return True
    return False
|
||||
|
||||
|
||||
# Matcher for group messages (priority 20, non-blocking). It only fires when
# both rules pass: the sender is not a local bot account, and get_answer found
# something to reply with.
any_msg = on_message(
    rule=Rule(check_accounts, get_answer),
    permission=permission.GROUP,  # | permission.PRIVATE_FRIEND
    priority=20,
    block=False,
)
|
||||
|
||||
|
||||
async def is_shutup(self_id: int, group_id: int) -> bool:
    """Tell whether the bot account is currently muted in a group.

    Asks the ``get_group_member_info`` API for the bot's own membership record
    and compares its ``shut_up_timestamp`` with the current time.

    :param self_id: account id of the bot.
    :param group_id: group to check.
    :return: ``True`` while the mute is still in effect.
    """
    bot = get_bot(str(self_id))
    member_info = await bot.call_api(
        'get_group_member_info', user_id=self_id, group_id=group_id)
    muted: bool = member_info['shut_up_timestamp'] > time.time()

    if muted:
        logger.info(f'repeater:派蒙[{self_id}]在群[{group_id}] 处于禁言状态')

    return muted
|
||||
|
||||
|
||||
@any_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
    """Send the answers prepared by get_answer, pausing 2-4 seconds before each.

    If a send fails while the bot is not muted, the answer itself is assumed
    to be invalid: it is banned for this group and the remaining answers are
    dropped.
    """
    pause = random.randint(2, 4)
    for answer in state['answers']:
        logger.info(f'repeater:派蒙[{event.self_id}]准备向群[{event.group_id}]回复[{answer}]')

        await asyncio.sleep(pause)
        try:
            await any_msg.send(answer)
        except ActionFailed:
            # Automatically drop messages that can no longer be sent. Do not
            # rely on this while the bot is under risk control.
            muted = await is_shutup(event.self_id, event.group_id)
            if not muted:  # not muted, so the message itself is the problem
                logger.info('repeater | bot [{}] ready to ban [{}] in group [{}]'.format(
                    event.self_id, str(answer), event.group_id))
                Chat.ban(event.group_id, event.self_id, str(answer), 'ActionFailed')
            break
        pause = random.randint(2, 4)
|
||||
|
||||
|
||||
async def is_reply(bot: Bot, event: GroupMessageEvent) -> bool:
    """Rule: pass only when the message quotes (replies to) another message."""
    quoted = event.reply
    return bool(quoted)
|
||||
|
||||
|
||||
# Matcher for "do not say that" commands: a group owner or admin addresses the
# bot, quotes a message and includes one of the keywords below. Priority 5,
# blocking.
ban_msg = on_message(
    rule=to_me() & keyword('不可以', '达咩', '不行', 'no') & Rule(is_reply),
    permission=permission.GROUP_OWNER | permission.GROUP_ADMIN,
    priority=5,
    block=True,
)
|
||||
|
||||
|
||||
@ban_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent):
    """Ban the quoted bot reply for this group and acknowledge it.

    The quoted message is rebuilt as a CQ string with the ``url`` and trailing
    ``subType`` fields removed, then handed to ``Chat.ban``. A random apology
    is sent when the ban succeeds.
    """
    if '[CQ:reply,' not in event.raw_message:
        return

    raw_message = ''
    for item in event.reply.message:
        raw_reply = str(item)
        # Strip the url field from CQ codes. The previous pattern ``,url=*``
        # matched "url" followed by zero or more "=" and so never removed a
        # real url value.
        raw_reply = re.sub(r'(\[CQ:[^\]]*?),url=[^,\]]*', r'\1', raw_reply)
        # Drop a trailing subType field with the same pattern Chat uses when
        # it normalises incoming messages.
        raw_message += re.sub(r',subType=\d+\]', r']', raw_reply)

    logger.info(f'repeater:派蒙[{event.self_id}] ready to ban [{raw_message}] in group [{event.group_id}]')

    if Chat.ban(event.group_id, event.self_id, raw_message, str(event.user_id)):
        msg_send = ['派蒙知道错了...达咩!', '派蒙不会再这么说了...', '果面呐噻,派蒙说错话了...']
        await ban_msg.finish(random.choice(msg_send))
|
||||
|
||||
|
||||
scheduler = require('nonebot_plugin_apscheduler').scheduler
|
||||
|
||||
|
||||
async def message_is_ban(bot: Bot, event: GroupMessageEvent) -> bool:
    """Rule: pass when the stripped plain text is exactly the ban command."""
    text = event.get_plaintext().strip()
    return text == '不可以发这个'
|
||||
|
||||
|
||||
# Matcher for the "ban your latest reply" command: a group owner or admin
# addresses the bot with the exact command text. Priority 5, blocking.
ban_msg_latest = on_message(
    rule=to_me() & Rule(message_is_ban),
    permission=permission.GROUP_OWNER | permission.GROUP_ADMIN,
    priority=5,
    block=True,
)
|
||||
|
||||
|
||||
@ban_msg_latest.handle()
async def _(bot: Bot, event: GroupMessageEvent):
    """Ban the bot's most recent reply in this group and acknowledge it."""
    logger.info(
        f'repeater:派蒙[{event.self_id}]把群[{event.group_id}]最后的回复ban了')

    # An empty message text makes Chat.ban pick the latest cached reply.
    banned = Chat.ban(event.group_id, event.self_id, '', str(event.user_id))
    if banned:
        msg_send = ['派蒙知道错了...达咩!', '派蒙不会再这么说了...', '果面呐噻,派蒙说错话了...']
        await ban_msg_latest.finish(random.choice(msg_send))
|
||||
|
||||
|
||||
@scheduler.scheduled_job('interval', seconds=5, misfire_grace_time=5)
async def speak_up():
    """Every 5 seconds, let the bot speak unprompted if Chat.speak says so.

    ``Chat.speak`` returns either nothing or a ``(bot_id, group_id, messages)``
    tuple; each message is sent to the group followed by a 2-4 second pause.
    """
    ret = Chat.speak()
    if not ret:
        return

    bot_id, group_id, messages = ret

    for msg in messages:
        # Log the message being sent now; the previous code logged the whole
        # list on every iteration.
        logger.info(f'repeater:派蒙[{bot_id}]准备向群[{group_id}]发送消息[{msg}]')
        await get_bot(str(bot_id)).call_api('send_group_msg', **{
            'message': msg,
            'group_id': group_id
        })
        await asyncio.sleep(random.randint(2, 4))
|
||||
|
||||
|
||||
update_scheduler = require('nonebot_plugin_apscheduler').scheduler
|
||||
|
||||
|
||||
async def is_drink_msg(bot: Bot, event: GroupMessageEvent) -> bool:
    """Rule: pass when the stripped plain text is one of the drink commands."""
    drink_commands = ('派蒙干杯', '应急食品开餐', '派蒙干饭')
    text = event.get_plaintext().strip()
    return text in drink_commands
|
||||
|
||||
|
||||
# Matcher for the drink commands, usable by group owners and admins only.
# Priority 5, blocking.
drink_msg = on_message(
    rule=Rule(is_drink_msg),
    permission=permission.GROUP_OWNER | permission.GROUP_ADMIN,
    priority=5,
    block=True,
)
|
||||
|
||||
|
||||
@drink_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent):
    """Make the bot drunk in this group for a random 60-600 seconds.

    Raises the group's drunkenness, announces it (send failures are ignored),
    waits, then lowers the drunkenness again and announces sobering up when
    ``Chat.sober_up`` reports the group is fully sober.
    """
    group_id = event.group_id
    drunk_duration = random.randint(60, 600)
    logger.info(f'repeater:派蒙[{event.self_id}]即将在群[{event.group_id}]喝醉,在[{drunk_duration}秒]后醒来')
    Chat.drink(group_id)
    try:
        await drink_msg.send('呀,旅行者。你今天走起路来,怎么看着摇摇晃晃的?')
    except ActionFailed:
        # Best effort: the announcement is optional.
        pass

    await asyncio.sleep(drunk_duration)

    if Chat.sober_up(group_id):
        logger.info(f'repeater:派蒙[{event.self_id}]在群[{event.group_id}]醒酒了')
        await drink_msg.finish('呃...头好疼...下次不能喝那么多了...')
|
||||
|
||||
|
||||
@update_scheduler.scheduled_job('cron', hour='4')
def update_data():
    """Daily maintenance at hour 4: prune learned contexts, then sober up everywhere."""
    for task in (Chat.clearup_context, Chat.completely_sober):
        task()
|
879
Paimon_Chat/Learning_repeate/model.py
Normal file
879
Paimon_Chat/Learning_repeate/model.py
Normal file
@ -0,0 +1,879 @@
|
||||
from typing import Generator, List, Optional, Union, Tuple, Dict, Any
|
||||
from functools import cached_property, cmp_to_key
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
import jieba_fast.analyse
|
||||
import threading
|
||||
import pypinyin
|
||||
import pymongo
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
import atexit
|
||||
|
||||
from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent
|
||||
from nonebot.adapters.onebot.v11 import Message
|
||||
|
||||
from utils.config import config
|
||||
|
||||
# Connection to the MongoDB instance named in the plugin configuration; all
# chat data lives in the "PaimonChat" database.
mongo_client = pymongo.MongoClient(config.paimon_mongodb_url)
mongo_db = mongo_client['PaimonChat']

# Raw chat messages.
message_mongo = mongo_db['message']
message_mongo.create_index([('time', pymongo.DESCENDING)],
                           name='time_index')

# Learned contexts: a message's keywords plus the answers seen after it.
context_mongo = mongo_db['context']
context_mongo.create_index([('keywords', pymongo.HASHED)],
                           name='keywords_index')
context_mongo.create_index([('count', pymongo.DESCENDING)],
                           name='count_index')
context_mongo.create_index([('time', pymongo.DESCENDING)],
                           name='time_index')
context_mongo.create_index([('answers.group_id', pymongo.TEXT),
                            ('answers.keywords', pymongo.TEXT)],
                           name='answers_index',
                           default_language='none')

# Per-group answer blacklists.
blacklist_mongo = mongo_db['blacklist']
blacklist_mongo.create_index([('group_id', pymongo.HASHED)],
                             name='group_index')
|
||||
|
||||
|
||||
@dataclass
class ChatData:
    """One chat message plus lazily computed, cached views of it."""

    group_id: int
    user_id: int
    raw_message: str
    plain_text: str
    time: int
    bot_id: int

    # Maximum number of keywords extracted from the plain text.
    _keywords_size: int = 3

    @cached_property
    def is_plain_text(self) -> bool:
        """True when there is plain text and the raw message has no CQ code."""
        return len(self.plain_text) != 0 and '[CQ:' not in self.raw_message

    @cached_property
    def is_image(self) -> bool:
        """True when the raw message contains an image or face CQ code."""
        raw = self.raw_message
        return any(marker in raw for marker in ('[CQ:image,', '[CQ:face,'))

    @cached_property
    def keywords(self) -> str:
        """Key used to look this message up in the learned contexts.

        Messages with no plain text at all use the raw message. Otherwise the
        extracted tags are joined with spaces, falling back to the full plain
        text when fewer than two tags are found.
        """
        text = self.plain_text
        if not self.is_plain_text and len(text) == 0:
            return self.raw_message

        tags = jieba_fast.analyse.extract_tags(
            text, topK=ChatData._keywords_size)
        if len(tags) >= 2:
            return ' '.join(tags)
        return text

    @cached_property
    def keywords_pinyin(self) -> str:
        """Lower-cased pinyin rendering of ``keywords``, concatenated."""
        syllables = pypinyin.pinyin(
            self.keywords, style=pypinyin.NORMAL, errors='default')
        return ''.join(syllable[0] for syllable in syllables).lower()

    @cached_property
    def to_me(self) -> bool:
        """True when the plain text starts with the bot's name."""
        return self.plain_text.startswith('派蒙')
|
||||
|
||||
|
||||
class Chat:
    """Learn which messages follow which in group chats and replay them.

    An instance wraps one message (``self.chat_data``). All learning state is
    shared: in-memory caches held in class attributes, plus the MongoDB
    collections ``message_mongo``, ``context_mongo`` and ``blacklist_mongo``.
    """

    # Answer threshold: the smaller it is, the more the bot talks.
    answer_threshold = config.paimon_answer_threshold
    # Upper limit; a normal context is unlikely to repeat this often, such
    # answers usually come from other bots.
    answer_limit_threshold = config.paimon_answer_limit_threshold
    # An answer seen in this many groups is used as a cross-group answer.
    cross_group_threshold = config.paimon_cross_group_threshold
    # Number of identical consecutive messages that triggers a repeat.
    repeat_threshold = config.paimon_repeat_threshold
    # Threshold for speaking unprompted: the smaller, the more chatter.
    speak_threshold = config.paimon_speak_threshold

    # Probability of being drunk (when a reply does not reach the threshold).
    drunk_probability = config.paimon_drunk_probability
    # Probability of splitting an answer at its commas.
    split_probability = 0.5
    # Probability of answering with voice (plain text answers only).
    voice_probability = config.paimon_voice_probability
    # Probability of continuing to speak after an unprompted message.
    speak_continuously_probability = config.paimon_speak_continuously_probability
    # Probability of poking a random member after speaking unprompted.
    speak_poke_probability = config.paimon_speak_poke_probability
    # Maximum number of messages in one unprompted burst.
    speak_continuously_max_len = config.paimon_speak_continuously_max_len

    # Persist cached messages at least this often (seconds) ...
    save_time_threshold = 3600
    # ... or as soon as one group has cached more than this many messages.
    save_count_threshold = 1000

    blacklist_answer = defaultdict(set)
    blacklist_answer_reserve = defaultdict(set)

    # private:
    # Cache of the bot's replies: group_id -> bot_id -> list of reply records.
    # Not persisted.
    _reply_dict = defaultdict(lambda: defaultdict(list))
    # Cache of group messages: group_id -> list of message records.
    _message_dict = {}
    # Drunkenness per group.
    _drunkenness_dict = defaultdict(int)

    # Number of cached records kept in memory after trimming.
    _save_reserve_size = 100
    # Time of the last message persistence (seconds); 0 means "never".
    _late_save_time = 0

    _reply_lock = threading.Lock()
    _message_lock = threading.Lock()
    # Pseudo group id under which global blacklist entries are stored.
    _blacklist_flag = 114514

    def __init__(self, data: Union[ChatData, GroupMessageEvent, PrivateMessageEvent]):
        """Wrap a message.

        :param data: a ready ``ChatData`` or a group/private message event.
        :raises TypeError: for any other input, instead of leaving the
            instance without ``chat_data``.
        """
        if isinstance(data, ChatData):
            self.chat_data = data
        elif isinstance(data, (GroupMessageEvent, PrivateMessageEvent)):
            if isinstance(data, GroupMessageEvent):
                group_id = data.group_id
            else:
                # Private chats use the peer's user id as the group id.
                group_id = data.user_id
            self.chat_data = ChatData(
                group_id=group_id,
                user_id=data.user_id,
                # Drop the image subType field: the same picture often comes
                # with different sub types, which would defeat comparisons.
                raw_message=re.sub(
                    r',subType=\d+\]',
                    r']',
                    data.raw_message),
                plain_text=data.get_plaintext(),
                time=data.time,
                bot_id=data.self_id,
            )
        else:
            raise TypeError(
                f'unsupported chat data type: {type(data).__name__}')

    def learn(self) -> bool:
        """Learn this message as an answer to the messages before it.

        :return: ``False`` for a blank message, otherwise ``True``.
        """
        if len(self.chat_data.raw_message.strip()) == 0:
            return False

        group_id = self.chat_data.group_id
        if group_id in Chat._message_dict:
            group_msgs = Chat._message_dict[group_id]
            group_pre_msg = group_msgs[-1] if group_msgs else None

            # The previous message in the group.
            self._context_insert(group_pre_msg)

            user_id = self.chat_data.user_id
            if group_pre_msg and group_pre_msg['user_id'] != user_id:
                # The same user's own previous message, searched backwards.
                # NOTE(review): this slice covers only the last two cached
                # messages although the original comment says "three".
                for msg in group_msgs[:-3:-1]:
                    if msg['user_id'] == user_id:
                        self._context_insert(msg)
                        break

        self._message_insert()
        return True

    def answer(self, with_limit: bool = True) -> Optional[Generator[Message, None, None]]:
        """Answer this message; may yield several messages or none at all.

        :param with_limit: when true, refuse to answer within 6 seconds of
            this bot's latest cached reply in the group.
        :return: a generator of messages to send, or ``None``.
        """
        group_id = self.chat_data.group_id
        bot_id = self.chat_data.bot_id
        group_bot_replies = Chat._reply_dict[group_id][bot_id]

        if with_limit and len(group_bot_replies):
            latest_reply = group_bot_replies[-1]
            # Rate limit: at most one answer every 6 seconds.
            if int(time.time()) - latest_reply['time'] < 6:
                return None

        results = self._context_find()
        if not results:
            return None

        raw_message = self.chat_data.raw_message
        keywords = self.chat_data.keywords
        with Chat._reply_lock:
            # Placeholder record, appended before anything is actually sent.
            group_bot_replies.append({
                'time': int(time.time()),
                'pre_raw_message': raw_message,
                'pre_keywords': keywords,
                'reply': '[PaimonChat: Reply]',  # flag
                'reply_keywords': '[PaimonChat: Reply]',  # flag
            })

        def yield_results(results: Tuple[List[str], str]) -> Generator[Message, None, None]:
            answer_list, answer_keywords = results
            group_bot_replies = Chat._reply_dict[group_id][bot_id]
            for item in answer_list:
                with Chat._reply_lock:
                    group_bot_replies.append({
                        'time': int(time.time()),
                        'pre_raw_message': raw_message,
                        'pre_keywords': keywords,
                        'reply': item,
                        'reply_keywords': answer_keywords,
                    })
                if '[CQ:' not in item and len(item) > 1 \
                        and random.random() < Chat.voice_probability:
                    yield Chat._text_to_speech(item)
                else:
                    yield Message(item)

            with Chat._reply_lock:
                # Trim the cache in place. The previous code rebound a local
                # name to the slice, so the cached list was never shortened.
                del group_bot_replies[:-Chat._save_reserve_size]

        return yield_results(results)

    @staticmethod
    def speak() -> Optional[Tuple[int, int, List[Message]]]:
        """Pick something to say unprompted.

        :return: ``(bot_id, group_id, messages)`` for the first group that
            qualifies, or ``None`` when no group does.
        """
        basic_msgs_len = 10
        basic_delay = 600

        def group_popularity_cmp(lhs: Tuple[int, List[Dict[str, Any]]],
                                 rhs: Tuple[int, List[Dict[str, Any]]]) -> int:

            def cmp(a: Any, b: Any):
                return (a > b) - (a < b)

            lhs_group_id, lhs_msgs = lhs
            rhs_group_id, rhs_msgs = rhs

            lhs_len = len(lhs_msgs)
            rhs_len = len(rhs_msgs)

            # Drunkenness defaults to 0; add 1 so the product is not wiped out.
            lhs_drunkenness = Chat._drunkenness_dict[lhs_group_id] + 1
            rhs_drunkenness = Chat._drunkenness_dict[rhs_group_id] + 1

            if lhs_len < basic_msgs_len or rhs_len < basic_msgs_len:
                return cmp(lhs_len * lhs_drunkenness,
                           rhs_len * rhs_drunkenness)

            lhs_duration = lhs_msgs[-1]['time'] - lhs_msgs[0]['time']
            rhs_duration = rhs_msgs[-1]['time'] - rhs_msgs[0]['time']

            if not lhs_duration or not rhs_duration:
                return cmp(lhs_len, rhs_len)

            return cmp(lhs_len * lhs_drunkenness / lhs_duration,
                       rhs_len * rhs_drunkenness / rhs_duration)

        # Order the groups by chat activity.
        popularity = sorted(Chat._message_dict.items(),
                            key=cmp_to_key(group_popularity_cmp))

        cur_time = time.time()
        for group_id, group_msgs in popularity:
            group_replies = Chat._reply_dict[group_id]
            if not len(group_replies) or len(group_msgs) < basic_msgs_len:
                continue

            # All bots normally answer together, so their latest reply times
            # should match; just look at the first bot's records.
            group_replies_front = list(group_replies.values())[0]
            if not len(group_replies_front) or \
                    group_replies_front[-1]['time'] > group_msgs[-1]['time']:
                continue

            msgs_len = len(group_msgs)
            latest_time = group_msgs[-1]['time']
            duration = latest_time - group_msgs[0]['time']
            avg_interval = duration / msgs_len

            # Only speak once the group has been silent for N times its
            # average message interval (plus a fixed delay).
            if cur_time - latest_time < avg_interval * Chat.speak_threshold + basic_delay:
                continue

            # Append a flag record so that a busy group with no usable context
            # is not queried again on every call.
            with Chat._reply_lock:
                group_replies_front.append({
                    'time': int(cur_time),
                    'pre_raw_message': '[PaimonChat: Speak]',
                    'pre_keywords': '[PaimonChat: Speak]',
                    'reply': '[PaimonChat: Speak]',
                    'reply_keywords': '[PaimonChat: Speak]',
                })

            available_time = cur_time - 24 * 3600
            speak_context = context_mongo.aggregate([
                {
                    '$match': {
                        'count': {
                            '$gt': Chat.answer_threshold
                        },
                        'time': {
                            '$gt': available_time
                        },
                        # The two conditions above only speed up the lookup;
                        # they do not change the result.
                        'answers.group_id': group_id,
                        'answers.time': {
                            '$gt': available_time
                        },
                        'answers.count': {
                            '$gt': Chat.answer_threshold
                        }
                    }
                }, {
                    '$sample': {'size': 1}  # one random context
                }
            ])

            speak_context = list(speak_context)
            if not speak_context:
                continue

            ban_keywords = Chat._find_ban_keywords(
                context=speak_context[0], group_id=group_id)
            messages = [answer['messages']
                        for answer in speak_context[0]['answers']
                        if answer['count'] >= Chat.answer_threshold
                        and answer['keywords'] not in ban_keywords
                        and answer['group_id'] == group_id]

            if not messages:
                continue

            # Bot id 0 is the placeholder used for follow-up answers below; a
            # group that only has that entry has no real bot to speak, and
            # random.choice() would raise IndexError on the empty list.
            bot_ids = [bid for bid in group_replies if bid]
            if not bot_ids:
                continue

            speak = random.choice(random.choice(messages))

            bot_id = random.choice(bot_ids)
            with Chat._reply_lock:
                group_replies[bot_id].append({
                    'time': int(cur_time),
                    'pre_raw_message': '[PaimonChat: Speak]',
                    'pre_keywords': '[PaimonChat: Speak]',
                    'reply': speak,
                    'reply_keywords': '[PaimonChat: Speak]',
                })

            speak_list = [Message(speak), ]
            while random.random() < Chat.speak_continuously_probability \
                    and len(speak_list) < Chat.speak_continuously_max_len:
                # Keep talking by answering the bot's own previous message.
                pre_msg = str(speak_list[-1])
                answer = Chat(ChatData(group_id, 0, pre_msg,
                                       pre_msg, cur_time, 0)).answer(False)
                if not answer:
                    break
                speak_list.extend(answer)

            if random.random() < Chat.speak_poke_probability:
                target_id = random.choice(
                    Chat._message_dict[group_id])['user_id']
                speak_list.append(Message('[CQ:poke,qq={}]'.format(target_id)))

            return bot_id, group_id, speak_list

        return None

    @staticmethod
    def ban(group_id: int, bot_id: int, ban_raw_message: str, reason: str) -> bool:
        """Forbid a reply from being used again in this group.

        :param group_id: group in which the reply is banned.
        :param bot_id: bot account whose cached replies are searched.
        :param ban_raw_message: text of the reply; empty selects the latest one.
        :param reason: stored with the ban record.
        :return: ``True`` when a matching cached reply was found and banned.
        """
        if group_id not in Chat._reply_dict:
            return False

        ban_reply = None
        reply_data = Chat._reply_dict[group_id][bot_id][::-1]

        for reply in reply_data:
            cur_reply = reply['reply']
            # An empty text means: ban the most recent reply.
            if not ban_raw_message or ban_raw_message in cur_reply:
                ban_reply = reply
                break

        # Some CQ codes differ between what the bot sent and what shows up in
        # the quoted message; fall back to matching on the CQ type only.
        if not ban_reply:
            # "A-Z" here: the previous class "A-z" also matched "[", "\", "]",
            # "^" and "`".
            search = re.search(r'(\[CQ:[a-zA-Z0-9_.-]+)',
                               ban_raw_message)
            if search:
                type_keyword = search.group(1)
                for reply in reply_data:
                    cur_reply = reply['reply']
                    if type_keyword in cur_reply:
                        ban_reply = reply
                        break

        if not ban_reply:
            return False

        # Read from the matched record explicitly instead of relying on the
        # loop variable left over from the searches above.
        pre_keywords = ban_reply['pre_keywords']
        keywords = ban_reply['reply_keywords']

        # The reply may have been borrowed from another group, so the ban is
        # recorded on the context rather than on one group's answer entry.
        context_mongo.update_one({
            'keywords': pre_keywords
        }, {
            '$push': {
                'ban': {
                    'keywords': keywords,
                    'group_id': group_id,
                    'reason': reason,
                    'time': int(time.time())
                }
            }
        })

        # First ban puts the answer on the reserve list; a second ban promotes
        # it to the group's blacklist (and to the global one when the global
        # reserve list has it as well).
        if keywords in Chat.blacklist_answer_reserve[group_id]:
            Chat.blacklist_answer[group_id].add(keywords)
            if keywords in Chat.blacklist_answer_reserve[Chat._blacklist_flag]:
                Chat.blacklist_answer[Chat._blacklist_flag].add(
                    keywords)
        else:
            Chat.blacklist_answer_reserve[group_id].add(keywords)

        return True

    @staticmethod
    def drink(group_id: int) -> None:
        """Raise the bot's drunkenness in this group by one."""
        Chat._drunkenness_dict[group_id] += 1

    @staticmethod
    def sober_up(group_id: int) -> bool:
        """Lower the drunkenness in this group by one.

        :return: ``True`` when the group is fully sober afterwards.
        """
        Chat._drunkenness_dict[group_id] -= 1
        return Chat._drunkenness_dict[group_id] <= 0

    def _message_insert(self):
        """Cache this message and persist the cache when a threshold is hit."""
        group_id = self.chat_data.group_id

        with Chat._message_lock:
            if group_id not in Chat._message_dict:
                Chat._message_dict[group_id] = []

            Chat._message_dict[group_id].append({
                'group_id': group_id,
                'user_id': self.chat_data.user_id,
                'raw_message': self.chat_data.raw_message,
                'is_plain_text': self.chat_data.is_plain_text,
                'plain_text': self.chat_data.plain_text,
                'keywords': self.chat_data.keywords,
                'time': self.chat_data.time,
            })

        cur_time = self.chat_data.time
        if Chat._late_save_time == 0:
            # First message since start-up: only initialise the save clock.
            Chat._late_save_time = cur_time - 1
            return

        if len(Chat._message_dict[group_id]) > Chat.save_count_threshold:
            Chat._sync(cur_time)
        elif cur_time - Chat._late_save_time > Chat.save_time_threshold:
            Chat._sync(cur_time)

    @staticmethod
    def _sync(cur_time: Optional[int] = None):
        """Persist cached messages newer than the last save, then trim the cache.

        :param cur_time: timestamp recorded as the new "last save" time;
            defaults to the current time. The previous default,
            ``time.time()`` in the signature, was evaluated once at import.
        """
        if cur_time is None:
            cur_time = int(time.time())

        with Chat._message_lock:
            save_list = [msg
                         for group_msgs in Chat._message_dict.values()
                         for msg in group_msgs
                         if msg['time'] > Chat._late_save_time]
            if not save_list:
                return

            Chat._message_dict = {group_id: group_msgs[-Chat._save_reserve_size:]
                                  for group_id, group_msgs in Chat._message_dict.items()}

            Chat._late_save_time = cur_time

        message_mongo.insert_many(save_list)

    def _context_insert(self, pre_msg):
        """Record this message as an answer to ``pre_msg`` in the context store.

        :param pre_msg: a cached message record, or a falsy value to do nothing.
        """
        if not pre_msg:
            return

        raw_message = self.chat_data.raw_message

        # A plain repeat of the previous message is not learned.
        if pre_msg['raw_message'] == raw_message:
            return

        # Quoted replies to someone are not learned.
        if '[CQ:reply,' in raw_message:
            return

        keywords = self.chat_data.keywords
        group_id = self.chat_data.group_id
        pre_keywords = pre_msg['keywords']
        cur_time = self.chat_data.time

        # A single upsert proved too hard to express, so this is
        # find + (insert or update).
        find_key = {'keywords': pre_keywords}
        context = context_mongo.find_one(find_key)
        if not context:
            context_mongo.insert_one({
                'keywords': pre_keywords,
                'time': cur_time,
                'count': 1,
                'answers': [
                    {
                        'keywords': keywords,
                        'group_id': group_id,
                        'count': 1,
                        'time': cur_time,
                        'messages': [
                            raw_message
                        ]
                    }
                ]
            })
            return

        update_value = {
            '$set': {
                'time': cur_time
            },
            '$inc': {'count': 1}
        }
        answer_index = next((idx for idx, answer in enumerate(context['answers'])
                             if answer['group_id'] == group_id
                             and answer['keywords'] == keywords), -1)
        if answer_index != -1:
            update_value['$inc'].update({
                f'answers.{answer_index}.count': 1
            })
            update_value['$set'].update({
                f'answers.{answer_index}.time': cur_time
            })
            # For non plain text the raw message is identical every time, so
            # there is no point in pushing it again.
            if self.chat_data.is_plain_text:
                update_value['$push'] = {
                    f'answers.{answer_index}.messages': raw_message
                }
        else:
            update_value['$push'] = {
                'answers': {
                    'keywords': keywords,
                    'group_id': group_id,
                    'count': 1,
                    'time': cur_time,
                    'messages': [
                        raw_message
                    ]
                }
            }

        context_mongo.update_one(find_key, update_value)

    def _context_find(self) -> Optional[Tuple[List[str], str]]:
        """Find an answer for this message.

        :return: ``(messages, answer_keywords)`` or ``None`` when nothing fits.
        """
        group_id = self.chat_data.group_id
        raw_message = self.chat_data.raw_message
        keywords = self.chat_data.keywords
        bot_id = self.chat_data.bot_id

        # Repeat along with the group.
        if group_id in Chat._message_dict:
            group_msgs = Chat._message_dict[group_id]
            if len(group_msgs) >= Chat.repeat_threshold and \
                    all(item['raw_message'] == raw_message
                        for item in group_msgs[:-Chat.repeat_threshold:-1]):
                # The group is repeating this message.
                group_bot_replies = Chat._reply_dict[group_id][bot_id]
                if len(group_bot_replies) and group_bot_replies[-1]['reply'] != raw_message:
                    return [raw_message, ], keywords
                # Already repeated once (or no reply cached yet): stay silent.
                return None

        context = context_mongo.find_one({'keywords': keywords})
        if not context:
            return None

        if Chat._drunkenness_dict[group_id] > 0:
            answer_count_threshold = 1
        else:
            answer_count_threshold = Chat.answer_threshold

        if self.chat_data.to_me:
            cross_group_threshold = 1
        else:
            cross_group_threshold = Chat.cross_group_threshold

        ban_keywords = Chat._find_ban_keywords(
            context=context, group_id=group_id)

        candidate_answers = {}
        other_group_cache = {}
        answers_count = defaultdict(int)

        def candidate_append(dst, answer):
            # Merge answers that share the same keywords.
            answer_key = answer['keywords']
            if answer_key not in dst:
                dst[answer_key] = answer
            else:
                pre_answer = dst[answer_key]
                pre_answer['count'] += answer['count']
                pre_answer['messages'] += answer['messages']

        for answer in context['answers']:
            answer_key = answer['keywords']
            if answer_key in ban_keywords or answer['count'] < answer_count_threshold:
                continue

            sample_msg = answer['messages'][0]
            if self.chat_data.is_image and '[CQ:' not in sample_msg:
                # Do not answer pictures with plain text: pictures are mostly
                # stickers and the text following them is all over the place.
                continue

            if answer['group_id'] == group_id:
                candidate_append(candidate_answers, answer)
            elif '[CQ:at,qq=' in sample_msg:
                # An @-mention from another group: ignore.
                continue
            else:
                # The same answer in N groups becomes a global answer.
                answers_count[answer_key] += 1
                cur_count = answers_count[answer_key]
                if cur_count < cross_group_threshold:
                    # Below the threshold: cache for now.
                    candidate_append(other_group_cache, answer)
                elif cur_count == cross_group_threshold:
                    # Threshold just reached: add the cached ones too.
                    if cur_count > 1:
                        candidate_append(candidate_answers,
                                         other_group_cache[answer_key])
                    candidate_append(candidate_answers, answer)
                else:
                    # Past the threshold: add directly.
                    candidate_append(candidate_answers, answer)

        if not candidate_answers:
            return None

        final_answer = random.choices(list(candidate_answers.values()), weights=[
            # Cap the weight so one answer cannot crowd out all the others.
            min(answer['count'], 10) for answer in candidate_answers.values()])[0]
        answer_str = random.choice(final_answer['messages'])
        answer_keywords = final_answer['keywords']

        if 0 < answer_str.count(',') <= 3 and random.random() < Chat.split_probability:
            return answer_str.split(','), answer_keywords
        return [answer_str, ], answer_keywords

    @staticmethod
    def _text_to_speech(text: str) -> Optional[Message]:
        """Wrap ``text`` in a ``[CQ:tts,...]`` message."""
        return Message(f'[CQ:tts,text={text}]')

    @staticmethod
    def update_global_blacklist() -> None:
        """Reload the blacklists and promote widely banned answers to global.

        An answer blacklisted in ``cross_group_threshold`` groups is added to
        the global blacklist.
        """
        Chat._select_blacklist()

        keywords_dict = defaultdict(int)
        global_blacklist = set()
        for _, keywords_list in Chat.blacklist_answer.items():
            for keywords in keywords_list:
                keywords_dict[keywords] += 1
                if keywords_dict[keywords] == Chat.cross_group_threshold:
                    global_blacklist.add(keywords)

        Chat.blacklist_answer[Chat._blacklist_flag] |= global_blacklist

    @staticmethod
    def _select_blacklist() -> None:
        """Merge the blacklists stored in MongoDB into the in-memory sets."""
        all_blacklist = blacklist_mongo.find()

        for item in all_blacklist:
            group_id = item['group_id']
            if 'answers' in item:
                Chat.blacklist_answer[group_id] |= set(item['answers'])
            if 'answers_reserve' in item:
                Chat.blacklist_answer_reserve[group_id] |= set(
                    item['answers_reserve'])

    @staticmethod
    def _sync_blacklist() -> None:
        """Write the in-memory blacklists back to MongoDB."""
        Chat._select_blacklist()

        for group_id, answers in Chat.blacklist_answer.items():
            if not len(answers):
                continue
            blacklist_mongo.update_one(
                {'group_id': group_id},
                {'$set': {'answers': list(answers)}},
                upsert=True)

        for group_id, answers in Chat.blacklist_answer_reserve.items():
            if not len(answers):
                continue
            if group_id in Chat.blacklist_answer:
                # Entries already promoted need not stay on the reserve list.
                answers = answers - Chat.blacklist_answer[group_id]

            blacklist_mongo.update_one(
                {'group_id': group_id},
                {'$set': {'answers_reserve': list(answers)}},
                upsert=True)

    @staticmethod
    def clearup_context() -> None:
        """Drop contexts unused for over 30 days that were never learned.

        Contexts with a count above 100 that were not cleaned in the last 30
        days also have their rarely used, stale answers pruned.
        """
        cur_time = int(time.time())
        expiration = cur_time - 30 * 24 * 3600  # thirty days ago

        context_mongo.delete_many({
            'time': {'$lt': expiration},
            'count': {'$lt': Chat.answer_threshold}  # strictly less than
        })

        all_context = context_mongo.find({
            'count': {'$gt': 100},
            '$or': [
                # Legacy data: old records have no clear_time field.
                {"clear_time": {"$exists": False}},
                {"clear_time": {"$lt": expiration}}
            ]
        })
        for context in all_context:
            answers = [ans
                       for ans in context['answers']
                       # Legacy data: old records have no answers.$.time field.
                       if ans['count'] > 1 or ('time' in ans and ans['time'] > expiration)]
            context_mongo.update_one({
                'keywords': context['keywords']
            }, {
                '$set': {
                    'answers': answers,
                    'clear_time': cur_time
                }
            })

    @staticmethod
    def completely_sober():
        """Reset the drunkenness of every known group to zero."""
        for key in Chat._drunkenness_dict.keys():
            Chat._drunkenness_dict[key] = 0

    @staticmethod
    def _find_ban_keywords(context, group_id) -> set:
        """Collect answer keywords that must not be used for ``context`` in a group.

        :param context: a context document, optionally with a ``ban`` list.
        :param group_id: the group the answer would be sent to.
        :return: a new set of banned answer keywords.
        """
        # Global blacklist plus this group's blacklist.
        ban_keywords = Chat.blacklist_answer[Chat._blacklist_flag] | Chat.blacklist_answer[group_id]
        # Bans recorded on this particular context.
        if 'ban' in context:
            ban_count = defaultdict(int)
            for ban in context['ban']:
                ban_key = ban['keywords']
                if ban['group_id'] == group_id or ban['group_id'] == Chat._blacklist_flag:
                    ban_keywords.add(ban_key)
                else:
                    # Banned in N other groups: treat it as banned everywhere.
                    ban_count[ban_key] += 1
                    if ban_count[ban_key] == Chat.cross_group_threshold:
                        ban_keywords.add(ban_key)
        return ban_keywords
|
||||
|
||||
|
||||
# On import: load the stored blacklists and rebuild the global one.
Chat.update_global_blacklist()


# On interpreter exit: flush cached messages and blacklists to MongoDB.
@atexit.register
def _chat_sync():
    """Persist the in-memory message cache and blacklists."""
    Chat._sync()
    Chat._sync_blacklist()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: answer and learn the same message twice in a row.
    # It talks to the configured MongoDB instance.
    test_data: ChatData = ChatData(
        group_id=1234567,
        user_id=1111111,
        raw_message='完了又有新bug',
        plain_text='完了又有新bug',
        time=int(time.time()),  # ChatData.time is declared as int
        bot_id=0,
    )

    test_chat: Chat = Chat(test_data)

    print(test_chat.answer())
    test_chat.learn()

    test_answer_data: ChatData = ChatData(
        group_id=1234567,
        user_id=1111111,
        raw_message='完了又有新bug',
        plain_text='完了又有新bug',
        time=int(time.time()),
        bot_id=0,
    )

    test_answer: Chat = Chat(test_answer_data)
    # Query the second object; the previous code printed test_chat.answer()
    # again and never exercised test_answer.answer().
    print(test_answer.answer())
    test_answer.learn()
|
@ -28,6 +28,8 @@ async def draw_ring(per):
|
||||
plt.savefig('temp.png', transparent=True)
|
||||
img = Image.open('temp.png').resize((266, 266)).convert('RGBA')
|
||||
os.remove('temp.png')
|
||||
plt.cla()
|
||||
plt.close("all")
|
||||
return img
|
||||
|
||||
|
||||
|
@ -40,6 +40,8 @@ async def draw_ring(per, colors):
|
||||
plt.savefig('temp.png', transparent=True)
|
||||
img = Image.open('temp.png').resize((378, 378)).convert('RGBA')
|
||||
os.remove('temp.png')
|
||||
plt.cla()
|
||||
plt.close("all")
|
||||
return img
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user