优化群聊学习复读表现和插件权限管理

This commit is contained in:
CMHopeSunshine 2022-11-26 18:25:11 +08:00
parent fab1cc3f85
commit 689a7701d3
8 changed files with 208 additions and 217 deletions

View File

@ -2,6 +2,7 @@ from pathlib import Path
from nonebot import load_plugins, logger from nonebot import load_plugins, logger
from LittlePaimon import database, web from LittlePaimon import database, web
from LittlePaimon.config import PluginManager
from LittlePaimon.utils import DRIVER, __version__, NICKNAME, SUPERUSERS from LittlePaimon.utils import DRIVER, __version__, NICKNAME, SUPERUSERS
from LittlePaimon.utils.tool import check_resource from LittlePaimon.utils.tool import check_resource
@ -35,6 +36,7 @@ logo = """<g>
async def startup(): async def startup():
logger.opt(colors=True).info(logo) logger.opt(colors=True).info(logo)
await database.connect() await database.connect()
await PluginManager.init()
await check_resource() await check_resource()

View File

@ -1,17 +1,18 @@
import asyncio import contextlib
import datetime import datetime
from typing import Dict, List from typing import Dict, List
from nonebot import plugin as nb_plugin from nonebot import plugin as nb_plugin
from nonebot import get_bot
from nonebot.matcher import Matcher
from nonebot.exception import IgnoredException
from nonebot.message import run_preprocessor
from nonebot.adapters.onebot.v11 import MessageEvent, PrivateMessageEvent, GroupMessageEvent from nonebot.adapters.onebot.v11 import MessageEvent, PrivateMessageEvent, GroupMessageEvent
from LittlePaimon.utils import logger, DRIVER, SUPERUSERS from nonebot.exception import IgnoredException
from LittlePaimon.utils.path import PLUGIN_CONFIG from nonebot.matcher import Matcher
from nonebot.message import run_preprocessor
from tortoise.queryset import Q
from LittlePaimon.database.models import PluginPermission, PluginStatistics, PluginDisable
from LittlePaimon.utils import logger, SUPERUSERS
from LittlePaimon.utils.files import load_yaml, save_yaml from LittlePaimon.utils.files import load_yaml, save_yaml
from LittlePaimon.database.models import PluginPermission, PluginStatistics from LittlePaimon.utils.path import PLUGIN_CONFIG
from .model import MatcherInfo, PluginInfo from .model import MatcherInfo, PluginInfo
HIDDEN_PLUGINS = [ HIDDEN_PLUGINS = [
@ -43,42 +44,22 @@ class PluginManager:
@classmethod @classmethod
async def init(cls): async def init(cls):
plugin_list = nb_plugin.get_loaded_plugins() plugin_list = nb_plugin.get_loaded_plugins()
group_list = await get_bot().get_group_list() if not await PluginDisable.all().exists() and await PluginPermission.all().exists():
user_list = await get_bot().get_friend_list() perms = await PluginPermission.filter(Q(status=False) | Q(ban__not=[])).all()
for perm in perms:
with contextlib.suppress(Exception):
if perm.session_type == 'group':
if not perm.status:
await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id)
for ban_user in perm.ban:
await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id,
user_id=ban_user)
else:
if not perm.status:
await PluginDisable.update_or_create(name=perm.name, user_id=perm.session_id)
await PluginPermission.all().delete()
await PluginDisable.filter(global_disable=False, group_id=None, user_id=None).delete()
for plugin in plugin_list: for plugin in plugin_list:
if plugin.name not in HIDDEN_PLUGINS and PluginPermission._meta.default_connection is not None:
if group_list:
for group in group_list:
count = await PluginPermission.filter(
name=plugin.name, session_id=group['group_id'], session_type='group'
).count()
if count > 1:
first = await PluginPermission.filter(
name=plugin.name, session_id=group['group_id'], session_type='group'
).order_by('id').first()
await PluginPermission.filter(
name=plugin.name, session_id=group['group_id'], session_type='group'
).delete()
await first.save()
elif count == 0:
await PluginPermission.create(name=plugin.name, session_id=group['group_id'],
session_type='group')
if user_list:
for user in user_list:
count = await PluginPermission.filter(
name=plugin.name, session_id=user['user_id'], session_type='user'
).count()
if count > 1:
first = await PluginPermission.filter(
name=plugin.name, session_id=user['user_id'], session_type='user'
).order_by('id').first()
await PluginPermission.filter(
name=plugin.name, session_id=user['user_id'], session_type='user'
).delete()
await first.save()
elif count == 0:
await PluginPermission.create(name=plugin.name, session_id=user['user_id'],
session_type='user')
if plugin.name not in HIDDEN_PLUGINS: if plugin.name not in HIDDEN_PLUGINS:
if plugin.name not in cls.plugins: if plugin.name not in cls.plugins:
if metadata := plugin.metadata: if metadata := plugin.metadata:
@ -113,15 +94,23 @@ class PluginManager:
:param message_type: 消息类型 :param message_type: 消息类型
:param session_id: 消息ID :param session_id: 消息ID
""" """
load_plugins = nb_plugin.get_loaded_plugins() load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()]
load_plugins = [p.name for p in load_plugins]
plugin_list = sorted(cls.plugins.values(), key=lambda x: x.priority).copy() plugin_list = sorted(cls.plugins.values(), key=lambda x: x.priority).copy()
plugin_list = [p for p in plugin_list if p.show and p.module_name in load_plugins] plugin_list = [p for p in plugin_list if p.show and p.module_name in load_plugins]
for plugin in plugin_list: for plugin in plugin_list:
if not await PluginDisable.filter(name=plugin.module_name, global_disable=True).exists():
if message_type != 'guild': if message_type != 'guild':
plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id, # plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id,
session_type=message_type) # session_type=message_type)
plugin.status = True if plugin_info is None else plugin_info.status # plugin.status = True if plugin_info is None else plugin_info.status
if message_type == 'group':
plugin.status = not await PluginDisable.filter(name=plugin.module_name,
group_id=session_id).exists()
else:
plugin.status = not await PluginDisable.filter(name=plugin.module_name,
user_id=session_id).exists()
else:
plugin.status = True
else: else:
plugin.status = True plugin.status = True
if plugin.matchers: if plugin.matchers:
@ -134,47 +123,36 @@ class PluginManager:
""" """
获取插件列表供Web UI使用 获取插件列表供Web UI使用
""" """
load_plugins = nb_plugin.get_loaded_plugins() load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()]
load_plugins = [p.name for p in load_plugins]
plugin_list = [p.dict(exclude={'status'}) for p in cls.plugins.values()] plugin_list = [p.dict(exclude={'status'}) for p in cls.plugins.values()]
for plugin in plugin_list: for plugin in plugin_list:
plugin['matchers'].sort(key=lambda x: x['pm_priority']) plugin['matchers'].sort(key=lambda x: x['pm_priority'])
plugin['isLoad'] = plugin['module_name'] in load_plugins plugin['isLoad'] = plugin['module_name'] in load_plugins
plugin['status'] = await PluginPermission.filter(name=plugin['module_name'], status=True).exists() plugin['status'] = not await PluginDisable.filter(name=plugin['module_name'], global_disable=True).exists()
plugin_list.sort(key=lambda x: (x['isLoad'], x['status'], -x['priority']), reverse=True) plugin_list.sort(key=lambda x: (x['isLoad'], x['status'], -x['priority']), reverse=True)
return plugin_list return plugin_list
@DRIVER.on_bot_connect
async def _():
await PluginManager.init()
@run_preprocessor @run_preprocessor
async def _(event: MessageEvent, matcher: Matcher): async def _(event: MessageEvent, matcher: Matcher):
try:
if event.user_id in SUPERUSERS: if event.user_id in SUPERUSERS:
return return
if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS: if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS:
return return
if isinstance(event, PrivateMessageEvent): if not isinstance(event, (PrivateMessageEvent, GroupMessageEvent)):
session_id = event.user_id
session_type = 'user'
elif isinstance(event, GroupMessageEvent):
session_id = event.group_id
session_type = 'group'
else:
return return
# 权限检查 # 权限检查
perm = await PluginPermission.get_or_none(name=matcher.plugin_name, session_id=session_id, if await PluginDisable.get_or_none(name=matcher.plugin_name, global_disable=True):
session_type=session_type) raise IgnoredException('插件使用权限已禁用')
if not perm: if await PluginDisable.get_or_none(name=matcher.plugin_name, user_id=event.user_id, group_id=None):
await PluginPermission.create(name=matcher.plugin_name, session_id=session_id, session_type=session_type) raise IgnoredException('插件使用权限已禁用')
return elif isinstance(event, GroupMessageEvent) and (
if not perm.status: perms := await PluginDisable.filter(name=matcher.plugin_name, group_id=event.group_id)):
user_ids = [p.user_id for p in perms]
if None in user_ids or event.user_id in user_ids:
raise IgnoredException('插件使用权限已禁用') raise IgnoredException('插件使用权限已禁用')
if isinstance(event, GroupMessageEvent) and event.user_id in perm.ban:
raise IgnoredException('用户被禁止使用该插件')
# 命令调用统计 # 命令调用统计
if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state: if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state:
@ -186,5 +164,7 @@ async def _(event: MessageEvent, matcher: Matcher):
matcher_usage=matcher_info.pm_usage, matcher_usage=matcher_info.pm_usage,
group_id=event.group_id if isinstance(event, GroupMessageEvent) else None, group_id=event.group_id if isinstance(event, GroupMessageEvent) else None,
user_id=event.user_id, user_id=event.user_id,
message_type=session_type, message_type=event.message_type,
time=datetime.datetime.now()) time=datetime.datetime.now())
except Exception as e:
logger.info('插件管理器', f'插件权限检查<r>失败:{e}</r>')

View File

@ -7,6 +7,7 @@ from tortoise.models import Model
class PluginPermission(Model): class PluginPermission(Model):
"""将在N个版本后废弃"""
id = fields.IntField(pk=True, generated=True, auto_increment=True) id = fields.IntField(pk=True, generated=True, auto_increment=True)
name: str = fields.TextField() name: str = fields.TextField()
"""插件名称""" """插件名称"""
@ -25,6 +26,21 @@ class PluginPermission(Model):
table = 'plugin_permission' table = 'plugin_permission'
class PluginDisable(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
name: str = fields.TextField()
"""插件名称"""
global_disable: bool = fields.BooleanField(default=False)
"""全局禁用"""
user_id: int = fields.IntField(null=True)
"""用户id"""
group_id: int = fields.IntField(null=True)
"""群组id"""
class Meta:
table = 'plugin_disable'
class PluginStatistics(Model): class PluginStatistics(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True) id = fields.IntField(pk=True, generated=True, auto_increment=True)
plugin_name: str = fields.TextField() plugin_name: str = fields.TextField()

View File

@ -45,7 +45,6 @@ learning_chat = on_message(priority=99, block=False, rule=Rule(ChatRule), permis
@learning_chat.handle() @learning_chat.handle()
async def _(event: GroupMessageEvent, answers=Arg('answers')): async def _(event: GroupMessageEvent, answers=Arg('answers')):
for answer in answers: for answer in answers:
await asyncio.sleep(random.randint(1, 2))
try: try:
logger.info('群聊学习', f'{NICKNAME}将向群<m>{event.group_id}</m>回复<m>"{answer}"</m>') logger.info('群聊学习', f'{NICKNAME}将向群<m>{event.group_id}</m>回复<m>"{answer}"</m>')
msg = await learning_chat.send(Message(answer)) msg = await learning_chat.send(Message(answer))
@ -56,6 +55,7 @@ async def _(event: GroupMessageEvent, answers=Arg('answers')):
raw_message=answer, raw_message=answer,
time=int(time.time()), time=int(time.time()),
plain_text=Message(answer).extract_plain_text()) plain_text=Message(answer).extract_plain_text())
await asyncio.sleep(random.random() + 0.5)
except ActionFailed: except ActionFailed:
logger.info('群聊学习', f'{NICKNAME}向群<m>{event.group_id}</m>的回复<m>"{answer}"</m>发送<r>失败,可能处于风控中</r>') logger.info('群聊学习', f'{NICKNAME}向群<m>{event.group_id}</m>的回复<m>"{answer}"</m>发送<r>失败,可能处于风控中</r>')

View File

@ -1,3 +1,4 @@
import asyncio
import datetime import datetime
import random import random
import re import re
@ -90,7 +91,7 @@ class LearningChat:
return Result.Pass return Result.Pass
elif self.reply: elif self.reply:
# 如果是回复消息 # 如果是回复消息
if not (message := await ChatMessage.get_or_none(message_id=self.reply.message_id)): if not (message := await ChatMessage.filter(message_id=self.reply.message_id).first()):
# 回复的消息在数据库中有记录 # 回复的消息在数据库中有记录
logger.debug('群聊学习', '➤回复的消息不在数据库中,跳过') logger.debug('群聊学习', '➤回复的消息不在数据库中,跳过')
return Result.Pass return Result.Pass
@ -167,10 +168,17 @@ class LearningChat:
elif result == Result.Pass: elif result == Result.Pass:
# 跳过 # 跳过
return None return None
elif result == Result.Repeat and (messages := await ChatMessage.filter(group_id=self.data.group_id, elif result == Result.Repeat:
time__gte=self.data.time - 3600).limit( query_set = ChatMessage.filter(group_id=self.data.group_id, time__gte=self.data.time - 3600)
self.config.repeat_threshold)): if await query_set.limit(self.config.repeat_threshold + 5).filter(
# 如果达到阈值且bot没有回复过且不是全都为同一个人在说则进行复读 user_id=self.bot_id, message=self.data.message).exists():
# 如果在阈值+5条消息内bot已经回复过这句话则跳过
logger.debug('群聊学习', f'➤➤已经复读过了,跳过')
return None
if not (messages := await query_set.limit(
self.config.repeat_threshold + 5)):
return None
# 如果达到阈值,且不是全都为同一个人在说,则进行复读
if len(messages) >= self.config.repeat_threshold and all( if len(messages) >= self.config.repeat_threshold and all(
message.message == self.data.message and message.user_id != self.bot_id message.message == self.data.message and message.user_id != self.bot_id
for message in messages) and not all( for message in messages) and not all(
@ -181,12 +189,13 @@ class LearningChat:
else: else:
logger.debug('群聊学习', f'➤➤达到复读阈值,复读<m>{messages[0].message}</m>') logger.debug('群聊学习', f'➤➤达到复读阈值,复读<m>{messages[0].message}</m>')
return [self.data.message] return [self.data.message]
return None
else: else:
# 回复 # 回复
if self.data.is_plain_text and len(self.data.plain_text) <= 1: if self.data.is_plain_text and len(self.data.plain_text) <= 1:
logger.debug('群聊学习', '➤➤消息过短,不回复') logger.debug('群聊学习', '➤➤消息过短,不回复')
return None return None
if not (context := await ChatContext.get_or_none(keywords=self.data.keywords)): if not (context := await ChatContext.filter(keywords=self.data.keywords).first()):
logger.debug('群聊学习', '➤➤尚未有已学习的回复,不回复') logger.debug('群聊学习', '➤➤尚未有已学习的回复,不回复')
return None return None
@ -204,7 +213,8 @@ class LearningChat:
else: else:
answer_count_threshold = 1 answer_count_threshold = 1
cross_group_threshold = 1 cross_group_threshold = 1
logger.debug('群聊学习', f'➤➤本次回复阈值为<m>{answer_count_threshold}</m>,跨群阈值为<m>{cross_group_threshold}</m>') logger.debug('群聊学习',
f'➤➤本次回复阈值为<m>{answer_count_threshold}</m>,跨群阈值为<m>{cross_group_threshold}</m>')
# 获取满足跨群条件的回复 # 获取满足跨群条件的回复
answers_cross = await ChatAnswer.filter(context=context, count__gte=answer_count_threshold, answers_cross = await ChatAnswer.filter(context=context, count__gte=answer_count_threshold,
keywords__in=await ChatAnswer.annotate( keywords__in=await ChatAnswer.annotate(
@ -241,6 +251,7 @@ class LearningChat:
return None return None
result_message = random.choice(result.messages) result_message = random.choice(result.messages)
logger.debug('群聊学习', f'➤➤将回复<m>{result_message}</m>') logger.debug('群聊学习', f'➤➤将回复<m>{result_message}</m>')
await asyncio.sleep(random.random() + 0.5)
return [result_message] return [result_message]
async def _ban(self, message_id: Optional[int] = None) -> bool: async def _ban(self, message_id: Optional[int] = None) -> bool:
@ -248,7 +259,9 @@ class LearningChat:
bot = get_bot() bot = get_bot()
if message_id: if message_id:
# 如果有指定消息ID则屏蔽该消息 # 如果有指定消息ID则屏蔽该消息
if (message := await ChatMessage.get_or_none(message_id=message_id)) and message.message not in ALL_WORDS: if (
message := await ChatMessage.filter(
message_id=message_id).first()) and message.message not in ALL_WORDS:
keywords = message.keywords keywords = message.keywords
try: try:
await bot.delete_msg(message_id=message_id) await bot.delete_msg(message_id=message_id)
@ -266,7 +279,7 @@ class LearningChat:
logger.info('群聊学习', f'待禁用消息<m>{last_reply.message_id}</m>尝试撤回<r>失败</r>') logger.info('群聊学习', f'待禁用消息<m>{last_reply.message_id}</m>尝试撤回<r>失败</r>')
else: else:
return False return False
if ban_word := await ChatBlackList.get_or_none(keywords=keywords): if ban_word := await ChatBlackList.filter(keywords=keywords).first():
# 如果已有屏蔽记录 # 如果已有屏蔽记录
if self.data.group_id not in ban_word.ban_group_id: if self.data.group_id not in ban_word.ban_group_id:
# 如果不在屏蔽群列表中,则添加 # 如果不在屏蔽群列表中,则添加
@ -290,7 +303,7 @@ class LearningChat:
@staticmethod @staticmethod
async def add_ban(data: Union[ChatMessage, ChatContext, ChatAnswer]): async def add_ban(data: Union[ChatMessage, ChatContext, ChatAnswer]):
if ban_word := await ChatBlackList.get_or_none(keywords=data.keywords): if ban_word := await ChatBlackList.filter(keywords=data.keywords).first():
# 如果已有屏蔽记录 # 如果已有屏蔽记录
if isinstance(data, ChatMessage): if isinstance(data, ChatMessage):
if data.group_id not in ban_word.ban_group_id: if data.group_id not in ban_word.ban_group_id:
@ -360,7 +373,9 @@ class LearningChat:
continue continue
config = config_manager.get_group_config(group_id) config = config_manager.get_group_config(group_id)
ban_words = set(chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record', '[CQ:share']) ban_words = set(
chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record',
'[CQ:share'])
# 是否开启了主动发言 # 是否开启了主动发言
if not config.speak_enable: if not config.speak_enable:
@ -400,7 +415,7 @@ class LearningChat:
speak_list.append(message) speak_list.append(message)
while random.random() < config.speak_continuously_probability and len( while random.random() < config.speak_continuously_probability and len(
speak_list) < config.speak_continuously_max_len: speak_list) < config.speak_continuously_max_len:
if (follow_context := await ChatContext.get_or_none(keywords=answer.keywords)) and ( if (follow_context := await ChatContext.filter(keywords=answer.keywords).first()) and (
follow_answers := await ChatAnswer.filter( follow_answers := await ChatAnswer.filter(
group_id=group_id, group_id=group_id,
context=follow_context, context=follow_context,
@ -432,13 +447,13 @@ class LearningChat:
return None return None
async def _set_answer(self, message: ChatMessage): async def _set_answer(self, message: ChatMessage):
if context := await ChatContext.get_or_none(keywords=message.keywords): if context := await ChatContext.filter(keywords=message.keywords).first():
if context.count < chat_config.learn_max_count: if context.count < chat_config.learn_max_count:
context.count += 1 context.count += 1
context.time = self.data.time context.time = self.data.time
if answer := await ChatAnswer.get_or_none(keywords=self.data.keywords, if answer := await ChatAnswer.filter(keywords=self.data.keywords,
group_id=self.data.group_id, group_id=self.data.group_id,
context=context): context=context).first():
if answer.count < chat_config.learn_max_count: if answer.count < chat_config.learn_max_count:
answer.count += 1 answer.count += 1
answer.time = self.data.time answer.time = self.data.time
@ -476,7 +491,7 @@ class LearningChat:
if raw_message.startswith('&#91;') and raw_message.endswith('&#93;'): if raw_message.startswith('&#91;') and raw_message.endswith('&#93;'):
# logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>') # logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>')
return False return False
if ban_word := await ChatBlackList.get_or_none(keywords=message.keywords): if ban_word := await ChatBlackList.filter(keywords=message.keywords).first():
if ban_word.global_ban or message.group_id in ban_word.ban_group_id: if ban_word.global_ban or message.group_id in ban_word.ban_group_id:
# logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>') # logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>')
return False return False

View File

@ -1,8 +1,5 @@
import asyncio from nonebot import on_regex, on_command
from nonebot import on_regex, on_command, on_notice
from nonebot import plugin as nb_plugin
from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent, PrivateMessageEvent, MessageEvent from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent, PrivateMessageEvent, MessageEvent
from nonebot.adapters.onebot.v11 import NoticeEvent, FriendAddNoticeEvent, GroupIncreaseNoticeEvent
from nonebot.params import RegexDict, CommandArg from nonebot.params import RegexDict, CommandArg
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.plugin import PluginMetadata from nonebot.plugin import PluginMetadata
@ -10,8 +7,8 @@ from nonebot.rule import Rule
from nonebot.typing import T_State from nonebot.typing import T_State
from LittlePaimon import SUPERUSERS from LittlePaimon import SUPERUSERS
from LittlePaimon.config import ConfigManager, PluginManager, HIDDEN_PLUGINS from LittlePaimon.config import ConfigManager, PluginManager
from LittlePaimon.database import PluginPermission from LittlePaimon.database import PluginDisable
from LittlePaimon.utils import logger from LittlePaimon.utils import logger
from LittlePaimon.utils.message import CommandObjectID from LittlePaimon.utils.message import CommandObjectID
from .draw_help import draw_help from .draw_help import draw_help
@ -27,19 +24,12 @@ __plugin_meta__ = PluginMetadata(
) )
def notice_rule(event: NoticeEvent) -> bool:
if isinstance(event, FriendAddNoticeEvent):
return True
elif isinstance(event, GroupIncreaseNoticeEvent):
return event.user_id == event.self_id
def fullmatch(msg: Message = CommandArg()) -> bool: def fullmatch(msg: Message = CommandArg()) -> bool:
return not bool(msg) return not bool(msg)
manage_cmd = on_regex( manage_cmd = on_regex(
r'^pm (?P<func>ban|unban) (?P<plugin>([\w ]*)|all|全部) ?(-g (?P<group>[\d ]*) ?)?(-u (?P<user>[\d ]*) ?)?(?P<reserve>-r)?', r'^pm (?P<func>ban|unban) (?P<plugin>([\w ]*)|all|全部) ?(-g (?P<group>[\d ]*) ?)?(-u (?P<user>[\d ]*) ?)?',
priority=1, block=True, state={ priority=1, block=True, state={
'pm_name': 'pm-ban|unban', 'pm_name': 'pm-ban|unban',
'pm_description': '禁用|取消禁用插件的群|用户使用权限', 'pm_description': '禁用|取消禁用插件的群|用户使用权限',
@ -58,11 +48,6 @@ set_config_cmd = on_command('pm set', priority=1, permission=SUPERUSER, block=Tr
'pm_usage': 'pm set<配置名> <值>', 'pm_usage': 'pm set<配置名> <值>',
'pm_priority': 2 'pm_priority': 2
}) })
notices = on_notice(priority=1, rule=Rule(notice_rule), block=True, state={
'pm_name': 'pm-new-group-user',
'pm_description': '为新加入的群|用户添加插件使用权限',
'pm_show': False
})
cache_help = {} cache_help = {}
@ -73,10 +58,15 @@ async def _(event: GroupMessageEvent, state: T_State, match: dict = RegexDict(),
await manage_cmd.finish('你没有权限使用该命令', at_sender=True) await manage_cmd.finish('你没有权限使用该命令', at_sender=True)
state['session_id'] = session_id state['session_id'] = session_id
state['bool'] = match['func'] == 'unban' state['bool'] = match['func'] == 'unban'
state['plugin'] = []
state['plugin_no_exist'] = [] state['plugin_no_exist'] = []
if any(w in match['plugin'] for w in {'all', '全部'}):
state['is_all'] = True
state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager']
else:
state['is_all'] = False
state['plugin'] = []
for plugin in match['plugin'].strip().split(' '): for plugin in match['plugin'].strip().split(' '):
if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']: if plugin in PluginManager.plugins.keys():
state['plugin'].append(plugin) state['plugin'].append(plugin)
elif module_name := list( elif module_name := list(
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())): filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
@ -87,7 +77,7 @@ async def _(event: GroupMessageEvent, state: T_State, match: dict = RegexDict(),
state['group'] = [event.group_id] state['group'] = [event.group_id]
else: else:
state['group'] = [int(group) for group in match['group'].strip().split(' ')] state['group'] = [int(group) for group in match['group'].strip().split(' ')]
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else [] state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None
@manage_cmd.handle() @manage_cmd.handle()
@ -96,18 +86,23 @@ async def _(event: PrivateMessageEvent, state: T_State, match: dict = RegexDict(
await manage_cmd.finish('你没有权限使用该命令', at_sender=True) await manage_cmd.finish('你没有权限使用该命令', at_sender=True)
state['session_id'] = session_id state['session_id'] = session_id
state['bool'] = match['func'] == 'unban' state['bool'] = match['func'] == 'unban'
state['plugin'] = []
state['plugin_no_exist'] = [] state['plugin_no_exist'] = []
if any(w in match['plugin'] for w in {'all', '全部'}):
state['is_all'] = True
state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager']
else:
state['is_all'] = False
state['plugin'] = []
for plugin in match['plugin'].strip().split(' '): for plugin in match['plugin'].strip().split(' '):
if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']: if plugin in PluginManager.plugins.keys():
state['plugin'].append(plugin) state['plugin'].append(plugin)
elif module_name := list( elif module_name := list(
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())): filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
state['plugin'].append(module_name[0]) state['plugin'].append(module_name[0])
else: else:
state['plugin_no_exist'].append(plugin) state['plugin_no_exist'].append(plugin)
state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else [] state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else None
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else [] state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None
@manage_cmd.got('bool') @manage_cmd.got('bool')
@ -119,45 +114,40 @@ async def _(state: T_State):
if not state['plugin'] and state['plugin_no_exist']: if not state['plugin'] and state['plugin_no_exist']:
await manage_cmd.finish(f'没有叫{" ".join(state["plugin_no_exist"])}的插件') await manage_cmd.finish(f'没有叫{" ".join(state["plugin_no_exist"])}的插件')
extra_msg = f',但没有叫{" ".join(state["plugin_no_exist"])}的插件。' if state['plugin_no_exist'] else '' extra_msg = f',但没有叫{" ".join(state["plugin_no_exist"])}的插件。' if state['plugin_no_exist'] else ''
if state['group'] and not state['user']: filter_arg = {}
for group_id in state['group']: if state['group']:
if 'all' in state['plugin']: filter_arg['group_id__in'] = state['group']
await PluginPermission.filter(session_id=group_id, session_type='group').update(status=state['bool']) if state['user']:
else: filter_arg['user_id__in'] = state['user']
await PluginPermission.filter(name__in=state['plugin'], session_id=group_id,
session_type='group').update(
status=state['bool'])
logger.info('插件管理器', logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限') f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>中用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
await manage_cmd.finish( msg = f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
elif state['user'] and not state['group']:
for user_id in state['user']:
if 'all' in state['plugin']:
await PluginPermission.filter(session_id=user_id, session_type='user').update(status=state['bool'])
else: else:
await PluginPermission.filter(name__in=state['plugin'], session_id=user_id, session_type='user').update( filter_arg['user_id'] = None
status=state['bool'])
logger.info('插件管理器', logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限') f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
await manage_cmd.finish( msg = f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
f'{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
else: else:
for group_id in state['group']: filter_arg['user_id__in'] = state['user']
if 'all' in state['plugin']:
plugin_list = await PluginPermission.filter(session_id=group_id, session_type='group').all()
else:
plugin_list = await PluginPermission.filter(name__in=state['plugin'], session_id=group_id,
session_type='group').all()
if plugin_list:
for plugin in plugin_list:
plugin.ban = list(set(plugin.ban) - set(state['user'])) if state['bool'] else list(
set(plugin.ban) | set(state['user']))
await plugin.save()
logger.info('插件管理器', logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>中用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限') f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
await manage_cmd.finish( msg = f'{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}') if state['bool']:
await PluginDisable.filter(name__in=state['plugin'], **filter_arg).delete()
else:
for plugin in state['plugin']:
if state['group']:
for group in state['group']:
if state['user']:
for user in state['user']:
await PluginDisable.update_or_create(name=plugin, group_id=group, user_id=user)
else:
await PluginDisable.update_or_create(name=plugin, group_id=group)
else:
for user in state['user']:
await PluginDisable.update_or_create(name=plugin, user_id=user)
await manage_cmd.finish(msg)
@help_cmd.handle() @help_cmd.handle()
@ -181,15 +171,3 @@ async def _(event: MessageEvent, msg: Message = CommandArg()):
else: else:
result = ConfigManager.set_config(msg[0], msg[1]) result = ConfigManager.set_config(msg[0], msg[1])
await set_config_cmd.finish(result) await set_config_cmd.finish(result)
@notices.handle()
async def _(event: NoticeEvent):
plugin_list = nb_plugin.get_loaded_plugins()
if isinstance(event, FriendAddNoticeEvent):
await asyncio.gather(*[PluginPermission.update_or_create(name=plugin, session_id=event.user_id, session_type='user') for plugin
in plugin_list if plugin not in HIDDEN_PLUGINS])
elif isinstance(event, GroupIncreaseNoticeEvent):
await asyncio.gather(
*[PluginPermission.update_or_create(name=plugin, session_id=event.group_id, session_type='group') for plugin
in plugin_list if plugin not in HIDDEN_PLUGINS])

View File

@ -4,7 +4,7 @@ from nonebot import get_driver
from .logger import logger from .logger import logger
from .scheduler import scheduler from .scheduler import scheduler
__version__ = '3.0.0rc3' __version__ = '3.0.0rc4'
DRIVER = get_driver() DRIVER = get_driver()
try: try:

View File

@ -5,7 +5,7 @@ from fastapi import APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from LittlePaimon.config import ConfigManager, PluginManager, PluginInfo from LittlePaimon.config import ConfigManager, PluginManager, PluginInfo
from LittlePaimon.database import PluginPermission from LittlePaimon.database import PluginDisable
from .utils import authentication from .utils import authentication
@ -27,28 +27,31 @@ async def get_plugins():
@route.post('/set_plugin_status', response_class=JSONResponse, dependencies=[authentication()]) @route.post('/set_plugin_status', response_class=JSONResponse, dependencies=[authentication()])
async def set_plugin_status(data: dict): async def set_plugin_status(data: dict):
module_name = data.get('plugin') module_name: str = data.get('plugin')
status = data.get('status') status: bool = data.get('status')
try: try:
from LittlePaimon.plugins.plugin_manager import cache_help from LittlePaimon.plugins.plugin_manager import cache_help
cache_help.clear() cache_help.clear()
except Exception: except Exception:
pass pass
await PluginPermission.filter(name=module_name).update(status=status) if status:
await PluginDisable.filter(name=module_name, global_disable=True).delete()
else:
await PluginDisable.create(name=module_name, global_disable=True)
return {'status': 0, 'msg': f'成功设置{module_name}插件状态为{status}'} return {'status': 0, 'msg': f'成功设置{module_name}插件状态为{status}'}
@route.get('/get_plugin_bans', response_class=JSONResponse, dependencies=[authentication()]) @route.get('/get_plugin_bans', response_class=JSONResponse, dependencies=[authentication()])
async def get_plugin_status(module_name: str): async def get_plugin_status(module_name: str):
result = [] result = []
bans = await PluginPermission.filter(name=module_name).all() bans = await PluginDisable.filter(name=module_name).all()
for ban in bans: for ban in bans:
if ban.session_type == 'group': if ban.user_id and ban.group_id:
result.extend(f'{ban.session_id}.{b}' for b in ban.ban) result.append(f'{ban.group_id}.{ban.user_id}')
if not ban.status: elif ban.group_id and not ban.user_id:
result.append(f'{ban.session_id}') result.append(f'{ban.group_id}')
elif ban.session_type == 'user' and not ban.status: elif ban.user_id and not ban.group_id:
result.append(f'{ban.session_id}') result.append(f'{ban.user_id}')
return { return {
'status': 0, 'status': 0,
'msg': 'ok', 'msg': 'ok',
@ -63,20 +66,17 @@ async def get_plugin_status(module_name: str):
async def set_plugin_bans(data: dict): async def set_plugin_bans(data: dict):
bans = data['bans'] bans = data['bans']
name = data['module_name'] name = data['module_name']
await PluginPermission.filter(name=name).update(status=True, ban=[]) await PluginDisable.filter(name=name, global_disable=False).delete()
for ban in bans: for ban in bans:
if ban.startswith(''): if ban.startswith(''):
if '.' in ban: if '.' in ban:
group_id = int(ban.split('.')[0][1:]) group_id = int(ban.split('.')[0][1:])
user_id = int(ban.split('.')[1]) user_id = int(ban.split('.')[1])
plugin = await PluginPermission.filter(name=name, session_type='group', session_id=group_id).first() await PluginDisable.create(name=name, group_id=group_id, user_id=user_id)
plugin.ban.append(user_id)
await plugin.save()
else: else:
await PluginPermission.filter(name=name, session_type='group', session_id=int(ban[1:])).update( await PluginDisable.create(name=name, group_id=int(ban[1:]))
status=False)
else: else:
await PluginPermission.filter(name=name, session_type='user', session_id=int(ban)).update(status=False) await PluginDisable.create(name=name, user_id=int(ban))
try: try:
from LittlePaimon.plugins.plugin_manager import cache_help from LittlePaimon.plugins.plugin_manager import cache_help
cache_help.clear() cache_help.clear()