From 689a7701d36c09c610f508e7819fab7a6b1b269e Mon Sep 17 00:00:00 2001 From: CMHopeSunshine <277073121@qq.com> Date: Sat, 26 Nov 2022 18:25:11 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E4=BC=98=E5=8C=96`=E7=BE=A4?= =?UTF-8?q?=E8=81=8A=E5=AD=A6=E4=B9=A0`=E5=A4=8D=E8=AF=BB=E8=A1=A8?= =?UTF-8?q?=E7=8E=B0=E5=92=8C`=E6=8F=92=E4=BB=B6=E6=9D=83=E9=99=90?= =?UTF-8?q?=E7=AE=A1=E7=90=86`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- LittlePaimon/__init__.py | 2 + LittlePaimon/config/plugin/manage.py | 162 ++++++++---------- LittlePaimon/database/models/manage.py | 16 ++ .../plugins/Learning_Chat/__init__.py | 2 +- LittlePaimon/plugins/Learning_Chat/handler.py | 49 ++++-- .../plugins/plugin_manager/__init__.py | 156 ++++++++--------- LittlePaimon/utils/__init__.py | 2 +- LittlePaimon/web/api/plugin.py | 36 ++-- 8 files changed, 208 insertions(+), 217 deletions(-) diff --git a/LittlePaimon/__init__.py b/LittlePaimon/__init__.py index a49f7d1..8243c2a 100644 --- a/LittlePaimon/__init__.py +++ b/LittlePaimon/__init__.py @@ -2,6 +2,7 @@ from pathlib import Path from nonebot import load_plugins, logger from LittlePaimon import database, web +from LittlePaimon.config import PluginManager from LittlePaimon.utils import DRIVER, __version__, NICKNAME, SUPERUSERS from LittlePaimon.utils.tool import check_resource @@ -35,6 +36,7 @@ logo = """ async def startup(): logger.opt(colors=True).info(logo) await database.connect() + await PluginManager.init() await check_resource() diff --git a/LittlePaimon/config/plugin/manage.py b/LittlePaimon/config/plugin/manage.py index fb5580b..03769b5 100644 --- a/LittlePaimon/config/plugin/manage.py +++ b/LittlePaimon/config/plugin/manage.py @@ -1,17 +1,18 @@ -import asyncio +import contextlib import datetime from typing import Dict, List from nonebot import plugin as nb_plugin -from nonebot import get_bot -from nonebot.matcher import Matcher -from nonebot.exception import IgnoredException -from nonebot.message import run_preprocessor from nonebot.adapters.onebot.v11 import MessageEvent, PrivateMessageEvent, GroupMessageEvent -from LittlePaimon.utils import logger, DRIVER, SUPERUSERS -from LittlePaimon.utils.path import PLUGIN_CONFIG +from nonebot.exception import IgnoredException +from nonebot.matcher import Matcher +from nonebot.message import run_preprocessor +from tortoise.queryset import Q + +from LittlePaimon.database.models import PluginPermission, PluginStatistics, PluginDisable +from LittlePaimon.utils import logger, SUPERUSERS from LittlePaimon.utils.files import load_yaml, save_yaml -from LittlePaimon.database.models import PluginPermission, PluginStatistics +from LittlePaimon.utils.path import PLUGIN_CONFIG from .model import MatcherInfo, PluginInfo HIDDEN_PLUGINS = [ @@ -43,42 +44,22 @@ class PluginManager: @classmethod async def init(cls): plugin_list = nb_plugin.get_loaded_plugins() - group_list = await get_bot().get_group_list() - user_list = await get_bot().get_friend_list() + if not await PluginDisable.all().exists() and await PluginPermission.all().exists(): + perms = await PluginPermission.filter(Q(status=False) | Q(ban__not=[])).all() + for perm in perms: + with contextlib.suppress(Exception): + if perm.session_type == 'group': + if not perm.status: + await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id) + for ban_user in perm.ban: + await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id, + user_id=ban_user) + else: + if not perm.status: + await PluginDisable.update_or_create(name=perm.name, user_id=perm.session_id) + await PluginPermission.all().delete() + await PluginDisable.filter(global_disable=False, group_id=None, user_id=None).delete() for plugin in plugin_list: - if plugin.name not in HIDDEN_PLUGINS and PluginPermission._meta.default_connection is not None: - if group_list: - for group in group_list: - count = await PluginPermission.filter( - name=plugin.name, session_id=group['group_id'], session_type='group' - ).count() - if count > 1: - first = await PluginPermission.filter( - name=plugin.name, session_id=group['group_id'], session_type='group' - ).order_by('id').first() - await PluginPermission.filter( - name=plugin.name, session_id=group['group_id'], session_type='group' - ).delete() - await first.save() - elif count == 0: - await PluginPermission.create(name=plugin.name, session_id=group['group_id'], - session_type='group') - if user_list: - for user in user_list: - count = await PluginPermission.filter( - name=plugin.name, session_id=user['user_id'], session_type='user' - ).count() - if count > 1: - first = await PluginPermission.filter( - name=plugin.name, session_id=user['user_id'], session_type='user' - ).order_by('id').first() - await PluginPermission.filter( - name=plugin.name, session_id=user['user_id'], session_type='user' - ).delete() - await first.save() - elif count == 0: - await PluginPermission.create(name=plugin.name, session_id=user['user_id'], - session_type='user') if plugin.name not in HIDDEN_PLUGINS: if plugin.name not in cls.plugins: if metadata := plugin.metadata: @@ -113,15 +94,23 @@ class PluginManager: :param message_type: 消息类型 :param session_id: 消息ID """ - load_plugins = nb_plugin.get_loaded_plugins() - load_plugins = [p.name for p in load_plugins] + load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()] plugin_list = sorted(cls.plugins.values(), key=lambda x: x.priority).copy() plugin_list = [p for p in plugin_list if p.show and p.module_name in load_plugins] for plugin in plugin_list: - if message_type != 'guild': - plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id, - session_type=message_type) - plugin.status = True if plugin_info is None else plugin_info.status + if not await PluginDisable.filter(name=plugin.module_name, global_disable=True).exists(): + if message_type != 'guild': + # plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id, + # session_type=message_type) + # plugin.status = True if plugin_info is None else plugin_info.status + if message_type == 'group': + plugin.status = not await PluginDisable.filter(name=plugin.module_name, + group_id=session_id).exists() + else: + plugin.status = not await PluginDisable.filter(name=plugin.module_name, + user_id=session_id).exists() + else: + plugin.status = True else: plugin.status = True if plugin.matchers: @@ -134,57 +123,48 @@ class PluginManager: """ 获取插件列表(供Web UI使用) """ - load_plugins = nb_plugin.get_loaded_plugins() - load_plugins = [p.name for p in load_plugins] + load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()] plugin_list = [p.dict(exclude={'status'}) for p in cls.plugins.values()] for plugin in plugin_list: plugin['matchers'].sort(key=lambda x: x['pm_priority']) plugin['isLoad'] = plugin['module_name'] in load_plugins - plugin['status'] = await PluginPermission.filter(name=plugin['module_name'], status=True).exists() + plugin['status'] = not await PluginDisable.filter(name=plugin['module_name'], global_disable=True).exists() plugin_list.sort(key=lambda x: (x['isLoad'], x['status'], -x['priority']), reverse=True) return plugin_list -@DRIVER.on_bot_connect -async def _(): - await PluginManager.init() - - @run_preprocessor async def _(event: MessageEvent, matcher: Matcher): - if event.user_id in SUPERUSERS: - return - if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS: - return - if isinstance(event, PrivateMessageEvent): - session_id = event.user_id - session_type = 'user' - elif isinstance(event, GroupMessageEvent): - session_id = event.group_id - session_type = 'group' - else: - return + try: + if event.user_id in SUPERUSERS: + return + if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS: + return + if not isinstance(event, (PrivateMessageEvent, GroupMessageEvent)): + return - # 权限检查 - perm = await PluginPermission.get_or_none(name=matcher.plugin_name, session_id=session_id, - session_type=session_type) - if not perm: - await PluginPermission.create(name=matcher.plugin_name, session_id=session_id, session_type=session_type) - return - if not perm.status: - raise IgnoredException('插件使用权限已禁用') - if isinstance(event, GroupMessageEvent) and event.user_id in perm.ban: - raise IgnoredException('用户被禁止使用该插件') + # 权限检查 + if await PluginDisable.get_or_none(name=matcher.plugin_name, global_disable=True): + raise IgnoredException('插件使用权限已禁用') + if await PluginDisable.get_or_none(name=matcher.plugin_name, user_id=event.user_id, group_id=None): + raise IgnoredException('插件使用权限已禁用') + elif isinstance(event, GroupMessageEvent) and ( + perms := await PluginDisable.filter(name=matcher.plugin_name, group_id=event.group_id)): + user_ids = [p.user_id for p in perms] + if None in user_ids or event.user_id in user_ids: + raise IgnoredException('插件使用权限已禁用') - # 命令调用统计 - if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state: - if matcher_info := list(filter(lambda x: x.pm_name == matcher.state['pm_name'], - PluginManager.plugins[matcher.plugin_name].matchers)): - matcher_info = matcher_info[0] - await PluginStatistics.create(plugin_name=matcher.plugin_name, - matcher_name=matcher_info.pm_name, - matcher_usage=matcher_info.pm_usage, - group_id=event.group_id if isinstance(event, GroupMessageEvent) else None, - user_id=event.user_id, - message_type=session_type, - time=datetime.datetime.now()) + # 命令调用统计 + if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state: + if matcher_info := list(filter(lambda x: x.pm_name == matcher.state['pm_name'], + PluginManager.plugins[matcher.plugin_name].matchers)): + matcher_info = matcher_info[0] + await PluginStatistics.create(plugin_name=matcher.plugin_name, + matcher_name=matcher_info.pm_name, + matcher_usage=matcher_info.pm_usage, + group_id=event.group_id if isinstance(event, GroupMessageEvent) else None, + user_id=event.user_id, + message_type=event.message_type, + time=datetime.datetime.now()) + except Exception as e: + logger.info('插件管理器', f'插件权限检查失败:{e}') diff --git a/LittlePaimon/database/models/manage.py b/LittlePaimon/database/models/manage.py index 9665e67..b1390b2 100644 --- a/LittlePaimon/database/models/manage.py +++ b/LittlePaimon/database/models/manage.py @@ -7,6 +7,7 @@ from tortoise.models import Model class PluginPermission(Model): + """将在N个版本后废弃""" id = fields.IntField(pk=True, generated=True, auto_increment=True) name: str = fields.TextField() """插件名称""" @@ -25,6 +26,21 @@ class PluginPermission(Model): table = 'plugin_permission' +class PluginDisable(Model): + id = fields.IntField(pk=True, generated=True, auto_increment=True) + name: str = fields.TextField() + """插件名称""" + global_disable: bool = fields.BooleanField(default=False) + """全局禁用""" + user_id: int = fields.IntField(null=True) + """用户id""" + group_id: int = fields.IntField(null=True) + """群组id""" + + class Meta: + table = 'plugin_disable' + + class PluginStatistics(Model): id = fields.IntField(pk=True, generated=True, auto_increment=True) plugin_name: str = fields.TextField() diff --git a/LittlePaimon/plugins/Learning_Chat/__init__.py b/LittlePaimon/plugins/Learning_Chat/__init__.py index 5472d16..163b73c 100644 --- a/LittlePaimon/plugins/Learning_Chat/__init__.py +++ b/LittlePaimon/plugins/Learning_Chat/__init__.py @@ -45,7 +45,6 @@ learning_chat = on_message(priority=99, block=False, rule=Rule(ChatRule), permis @learning_chat.handle() async def _(event: GroupMessageEvent, answers=Arg('answers')): for answer in answers: - await asyncio.sleep(random.randint(1, 2)) try: logger.info('群聊学习', f'{NICKNAME}将向群{event.group_id}回复"{answer}"') msg = await learning_chat.send(Message(answer)) @@ -56,6 +55,7 @@ async def _(event: GroupMessageEvent, answers=Arg('answers')): raw_message=answer, time=int(time.time()), plain_text=Message(answer).extract_plain_text()) + await asyncio.sleep(random.random() + 0.5) except ActionFailed: logger.info('群聊学习', f'{NICKNAME}向群{event.group_id}的回复"{answer}"发送失败,可能处于风控中') diff --git a/LittlePaimon/plugins/Learning_Chat/handler.py b/LittlePaimon/plugins/Learning_Chat/handler.py index 86143ef..e38dd77 100644 --- a/LittlePaimon/plugins/Learning_Chat/handler.py +++ b/LittlePaimon/plugins/Learning_Chat/handler.py @@ -1,3 +1,4 @@ +import asyncio import datetime import random import re @@ -90,7 +91,7 @@ class LearningChat: return Result.Pass elif self.reply: # 如果是回复消息 - if not (message := await ChatMessage.get_or_none(message_id=self.reply.message_id)): + if not (message := await ChatMessage.filter(message_id=self.reply.message_id).first()): # 回复的消息在数据库中有记录 logger.debug('群聊学习', '➤回复的消息不在数据库中,跳过') return Result.Pass @@ -167,10 +168,17 @@ class LearningChat: elif result == Result.Pass: # 跳过 return None - elif result == Result.Repeat and (messages := await ChatMessage.filter(group_id=self.data.group_id, - time__gte=self.data.time - 3600).limit( - self.config.repeat_threshold)): - # 如果达到阈值,且bot没有回复过,且不是全都为同一个人在说,则进行复读 + elif result == Result.Repeat: + query_set = ChatMessage.filter(group_id=self.data.group_id, time__gte=self.data.time - 3600) + if await query_set.limit(self.config.repeat_threshold + 5).filter( + user_id=self.bot_id, message=self.data.message).exists(): + # 如果在阈值+5条消息内,bot已经回复过这句话,则跳过 + logger.debug('群聊学习', f'➤➤已经复读过了,跳过') + return None + if not (messages := await query_set.limit( + self.config.repeat_threshold + 5)): + return None + # 如果达到阈值,且不是全都为同一个人在说,则进行复读 if len(messages) >= self.config.repeat_threshold and all( message.message == self.data.message and message.user_id != self.bot_id for message in messages) and not all( @@ -181,12 +189,13 @@ class LearningChat: else: logger.debug('群聊学习', f'➤➤达到复读阈值,复读{messages[0].message}') return [self.data.message] + return None else: # 回复 if self.data.is_plain_text and len(self.data.plain_text) <= 1: logger.debug('群聊学习', '➤➤消息过短,不回复') return None - if not (context := await ChatContext.get_or_none(keywords=self.data.keywords)): + if not (context := await ChatContext.filter(keywords=self.data.keywords).first()): logger.debug('群聊学习', '➤➤尚未有已学习的回复,不回复') return None @@ -204,7 +213,8 @@ class LearningChat: else: answer_count_threshold = 1 cross_group_threshold = 1 - logger.debug('群聊学习', f'➤➤本次回复阈值为{answer_count_threshold},跨群阈值为{cross_group_threshold}') + logger.debug('群聊学习', + f'➤➤本次回复阈值为{answer_count_threshold},跨群阈值为{cross_group_threshold}') # 获取满足跨群条件的回复 answers_cross = await ChatAnswer.filter(context=context, count__gte=answer_count_threshold, keywords__in=await ChatAnswer.annotate( @@ -241,6 +251,7 @@ class LearningChat: return None result_message = random.choice(result.messages) logger.debug('群聊学习', f'➤➤将回复{result_message}') + await asyncio.sleep(random.random() + 0.5) return [result_message] async def _ban(self, message_id: Optional[int] = None) -> bool: @@ -248,7 +259,9 @@ class LearningChat: bot = get_bot() if message_id: # 如果有指定消息ID,则屏蔽该消息 - if (message := await ChatMessage.get_or_none(message_id=message_id)) and message.message not in ALL_WORDS: + if ( + message := await ChatMessage.filter( + message_id=message_id).first()) and message.message not in ALL_WORDS: keywords = message.keywords try: await bot.delete_msg(message_id=message_id) @@ -266,7 +279,7 @@ class LearningChat: logger.info('群聊学习', f'待禁用消息{last_reply.message_id}尝试撤回失败') else: return False - if ban_word := await ChatBlackList.get_or_none(keywords=keywords): + if ban_word := await ChatBlackList.filter(keywords=keywords).first(): # 如果已有屏蔽记录 if self.data.group_id not in ban_word.ban_group_id: # 如果不在屏蔽群列表中,则添加 @@ -290,7 +303,7 @@ class LearningChat: @staticmethod async def add_ban(data: Union[ChatMessage, ChatContext, ChatAnswer]): - if ban_word := await ChatBlackList.get_or_none(keywords=data.keywords): + if ban_word := await ChatBlackList.filter(keywords=data.keywords).first(): # 如果已有屏蔽记录 if isinstance(data, ChatMessage): if data.group_id not in ban_word.ban_group_id: @@ -360,7 +373,9 @@ class LearningChat: continue config = config_manager.get_group_config(group_id) - ban_words = set(chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record', '[CQ:share']) + ban_words = set( + chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record', + '[CQ:share']) # 是否开启了主动发言 if not config.speak_enable: @@ -400,7 +415,7 @@ class LearningChat: speak_list.append(message) while random.random() < config.speak_continuously_probability and len( speak_list) < config.speak_continuously_max_len: - if (follow_context := await ChatContext.get_or_none(keywords=answer.keywords)) and ( + if (follow_context := await ChatContext.filter(keywords=answer.keywords).first()) and ( follow_answers := await ChatAnswer.filter( group_id=group_id, context=follow_context, @@ -432,13 +447,13 @@ class LearningChat: return None async def _set_answer(self, message: ChatMessage): - if context := await ChatContext.get_or_none(keywords=message.keywords): + if context := await ChatContext.filter(keywords=message.keywords).first(): if context.count < chat_config.learn_max_count: context.count += 1 context.time = self.data.time - if answer := await ChatAnswer.get_or_none(keywords=self.data.keywords, - group_id=self.data.group_id, - context=context): + if answer := await ChatAnswer.filter(keywords=self.data.keywords, + group_id=self.data.group_id, + context=context).first(): if answer.count < chat_config.learn_max_count: answer.count += 1 answer.time = self.data.time @@ -476,7 +491,7 @@ class LearningChat: if raw_message.startswith('[') and raw_message.endswith(']'): # logger.debug('群聊学习', f'➤检验{keywords}不通过') return False - if ban_word := await ChatBlackList.get_or_none(keywords=message.keywords): + if ban_word := await ChatBlackList.filter(keywords=message.keywords).first(): if ban_word.global_ban or message.group_id in ban_word.ban_group_id: # logger.debug('群聊学习', f'➤检验{keywords}不通过') return False diff --git a/LittlePaimon/plugins/plugin_manager/__init__.py b/LittlePaimon/plugins/plugin_manager/__init__.py index 30ebc48..40fa763 100644 --- a/LittlePaimon/plugins/plugin_manager/__init__.py +++ b/LittlePaimon/plugins/plugin_manager/__init__.py @@ -1,8 +1,5 @@ -import asyncio -from nonebot import on_regex, on_command, on_notice -from nonebot import plugin as nb_plugin +from nonebot import on_regex, on_command from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent, PrivateMessageEvent, MessageEvent -from nonebot.adapters.onebot.v11 import NoticeEvent, FriendAddNoticeEvent, GroupIncreaseNoticeEvent from nonebot.params import RegexDict, CommandArg from nonebot.permission import SUPERUSER from nonebot.plugin import PluginMetadata @@ -10,8 +7,8 @@ from nonebot.rule import Rule from nonebot.typing import T_State from LittlePaimon import SUPERUSERS -from LittlePaimon.config import ConfigManager, PluginManager, HIDDEN_PLUGINS -from LittlePaimon.database import PluginPermission +from LittlePaimon.config import ConfigManager, PluginManager +from LittlePaimon.database import PluginDisable from LittlePaimon.utils import logger from LittlePaimon.utils.message import CommandObjectID from .draw_help import draw_help @@ -27,19 +24,12 @@ __plugin_meta__ = PluginMetadata( ) -def notice_rule(event: NoticeEvent) -> bool: - if isinstance(event, FriendAddNoticeEvent): - return True - elif isinstance(event, GroupIncreaseNoticeEvent): - return event.user_id == event.self_id - - def fullmatch(msg: Message = CommandArg()) -> bool: return not bool(msg) manage_cmd = on_regex( - r'^pm (?Pban|unban) (?P([\w ]*)|all|全部) ?(-g (?P[\d ]*) ?)?(-u (?P[\d ]*) ?)?(?P-r)?', + r'^pm (?Pban|unban) (?P([\w ]*)|all|全部) ?(-g (?P[\d ]*) ?)?(-u (?P[\d ]*) ?)?', priority=1, block=True, state={ 'pm_name': 'pm-ban|unban', 'pm_description': '禁用|取消禁用插件的群|用户使用权限', @@ -58,11 +48,6 @@ set_config_cmd = on_command('pm set', priority=1, permission=SUPERUSER, block=Tr 'pm_usage': 'pm set<配置名> <值>', 'pm_priority': 2 }) -notices = on_notice(priority=1, rule=Rule(notice_rule), block=True, state={ - 'pm_name': 'pm-new-group-user', - 'pm_description': '为新加入的群|用户添加插件使用权限', - 'pm_show': False -}) cache_help = {} @@ -73,21 +58,26 @@ async def _(event: GroupMessageEvent, state: T_State, match: dict = RegexDict(), await manage_cmd.finish('你没有权限使用该命令', at_sender=True) state['session_id'] = session_id state['bool'] = match['func'] == 'unban' - state['plugin'] = [] state['plugin_no_exist'] = [] - for plugin in match['plugin'].strip().split(' '): - if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']: - state['plugin'].append(plugin) - elif module_name := list( - filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())): - state['plugin'].append(module_name[0]) - else: - state['plugin_no_exist'].append(plugin) + if any(w in match['plugin'] for w in {'all', '全部'}): + state['is_all'] = True + state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager'] + else: + state['is_all'] = False + state['plugin'] = [] + for plugin in match['plugin'].strip().split(' '): + if plugin in PluginManager.plugins.keys(): + state['plugin'].append(plugin) + elif module_name := list( + filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())): + state['plugin'].append(module_name[0]) + else: + state['plugin_no_exist'].append(plugin) if not match['group'] or event.user_id not in SUPERUSERS: state['group'] = [event.group_id] else: state['group'] = [int(group) for group in match['group'].strip().split(' ')] - state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else [] + state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None @manage_cmd.handle() @@ -96,18 +86,23 @@ async def _(event: PrivateMessageEvent, state: T_State, match: dict = RegexDict( await manage_cmd.finish('你没有权限使用该命令', at_sender=True) state['session_id'] = session_id state['bool'] = match['func'] == 'unban' - state['plugin'] = [] state['plugin_no_exist'] = [] - for plugin in match['plugin'].strip().split(' '): - if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']: - state['plugin'].append(plugin) - elif module_name := list( - filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())): - state['plugin'].append(module_name[0]) - else: - state['plugin_no_exist'].append(plugin) - state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else [] - state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else [] + if any(w in match['plugin'] for w in {'all', '全部'}): + state['is_all'] = True + state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager'] + else: + state['is_all'] = False + state['plugin'] = [] + for plugin in match['plugin'].strip().split(' '): + if plugin in PluginManager.plugins.keys(): + state['plugin'].append(plugin) + elif module_name := list( + filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())): + state['plugin'].append(module_name[0]) + else: + state['plugin_no_exist'].append(plugin) + state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else None + state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None @manage_cmd.got('bool') @@ -119,45 +114,40 @@ async def _(state: T_State): if not state['plugin'] and state['plugin_no_exist']: await manage_cmd.finish(f'没有叫{" ".join(state["plugin_no_exist"])}的插件') extra_msg = f',但没有叫{" ".join(state["plugin_no_exist"])}的插件。' if state['plugin_no_exist'] else '。' - if state['group'] and not state['user']: - for group_id in state['group']: - if 'all' in state['plugin']: - await PluginPermission.filter(session_id=group_id, session_type='group').update(status=state['bool']) - else: - await PluginPermission.filter(name__in=state['plugin'], session_id=group_id, - session_type='group').update( - status=state['bool']) - logger.info('插件管理器', - f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"])}使用权限') - await manage_cmd.finish( - f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}') - elif state['user'] and not state['group']: - for user_id in state['user']: - if 'all' in state['plugin']: - await PluginPermission.filter(session_id=user_id, session_type='user').update(status=state['bool']) - else: - await PluginPermission.filter(name__in=state['plugin'], session_id=user_id, session_type='user').update( - status=state['bool']) - logger.info('插件管理器', - f'已{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限') - await manage_cmd.finish( - f'已{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}') + filter_arg = {} + if state['group']: + filter_arg['group_id__in'] = state['group'] + if state['user']: + filter_arg['user_id__in'] = state['user'] + logger.info('插件管理器', + f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限') + msg = f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}' + else: + filter_arg['user_id'] = None + logger.info('插件管理器', + f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限') + msg = f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}' else: - for group_id in state['group']: - if 'all' in state['plugin']: - plugin_list = await PluginPermission.filter(session_id=group_id, session_type='group').all() - else: - plugin_list = await PluginPermission.filter(name__in=state['plugin'], session_id=group_id, - session_type='group').all() - if plugin_list: - for plugin in plugin_list: - plugin.ban = list(set(plugin.ban) - set(state['user'])) if state['bool'] else list( - set(plugin.ban) | set(state['user'])) - await plugin.save() + filter_arg['user_id__in'] = state['user'] logger.info('插件管理器', - f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限') - await manage_cmd.finish( - f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}') + f'已{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限') + msg = f'已{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}' + if state['bool']: + await PluginDisable.filter(name__in=state['plugin'], **filter_arg).delete() + else: + for plugin in state['plugin']: + if state['group']: + for group in state['group']: + if state['user']: + for user in state['user']: + await PluginDisable.update_or_create(name=plugin, group_id=group, user_id=user) + else: + await PluginDisable.update_or_create(name=plugin, group_id=group) + else: + for user in state['user']: + await PluginDisable.update_or_create(name=plugin, user_id=user) + + await manage_cmd.finish(msg) @help_cmd.handle() @@ -181,15 +171,3 @@ async def _(event: MessageEvent, msg: Message = CommandArg()): else: result = ConfigManager.set_config(msg[0], msg[1]) await set_config_cmd.finish(result) - - -@notices.handle() -async def _(event: NoticeEvent): - plugin_list = nb_plugin.get_loaded_plugins() - if isinstance(event, FriendAddNoticeEvent): - await asyncio.gather(*[PluginPermission.update_or_create(name=plugin, session_id=event.user_id, session_type='user') for plugin - in plugin_list if plugin not in HIDDEN_PLUGINS]) - elif isinstance(event, GroupIncreaseNoticeEvent): - await asyncio.gather( - *[PluginPermission.update_or_create(name=plugin, session_id=event.group_id, session_type='group') for plugin - in plugin_list if plugin not in HIDDEN_PLUGINS]) diff --git a/LittlePaimon/utils/__init__.py b/LittlePaimon/utils/__init__.py index a060431..baa3b99 100644 --- a/LittlePaimon/utils/__init__.py +++ b/LittlePaimon/utils/__init__.py @@ -4,7 +4,7 @@ from nonebot import get_driver from .logger import logger from .scheduler import scheduler -__version__ = '3.0.0rc3' +__version__ = '3.0.0rc4' DRIVER = get_driver() try: diff --git a/LittlePaimon/web/api/plugin.py b/LittlePaimon/web/api/plugin.py index d38a86c..ed4391b 100644 --- a/LittlePaimon/web/api/plugin.py +++ b/LittlePaimon/web/api/plugin.py @@ -5,7 +5,7 @@ from fastapi import APIRouter from fastapi.responses import JSONResponse from LittlePaimon.config import ConfigManager, PluginManager, PluginInfo -from LittlePaimon.database import PluginPermission +from LittlePaimon.database import PluginDisable from .utils import authentication @@ -27,28 +27,31 @@ async def get_plugins(): @route.post('/set_plugin_status', response_class=JSONResponse, dependencies=[authentication()]) async def set_plugin_status(data: dict): - module_name = data.get('plugin') - status = data.get('status') + module_name: str = data.get('plugin') + status: bool = data.get('status') try: from LittlePaimon.plugins.plugin_manager import cache_help cache_help.clear() except Exception: pass - await PluginPermission.filter(name=module_name).update(status=status) + if status: + await PluginDisable.filter(name=module_name, global_disable=True).delete() + else: + await PluginDisable.create(name=module_name, global_disable=True) return {'status': 0, 'msg': f'成功设置{module_name}插件状态为{status}'} @route.get('/get_plugin_bans', response_class=JSONResponse, dependencies=[authentication()]) async def get_plugin_status(module_name: str): result = [] - bans = await PluginPermission.filter(name=module_name).all() + bans = await PluginDisable.filter(name=module_name).all() for ban in bans: - if ban.session_type == 'group': - result.extend(f'群{ban.session_id}.{b}' for b in ban.ban) - if not ban.status: - result.append(f'群{ban.session_id}') - elif ban.session_type == 'user' and not ban.status: - result.append(f'{ban.session_id}') + if ban.user_id and ban.group_id: + result.append(f'群{ban.group_id}.{ban.user_id}') + elif ban.group_id and not ban.user_id: + result.append(f'群{ban.group_id}') + elif ban.user_id and not ban.group_id: + result.append(f'{ban.user_id}') return { 'status': 0, 'msg': 'ok', @@ -63,20 +66,17 @@ async def get_plugin_status(module_name: str): async def set_plugin_bans(data: dict): bans = data['bans'] name = data['module_name'] - await PluginPermission.filter(name=name).update(status=True, ban=[]) + await PluginDisable.filter(name=name, global_disable=False).delete() for ban in bans: if ban.startswith('群'): if '.' in ban: group_id = int(ban.split('.')[0][1:]) user_id = int(ban.split('.')[1]) - plugin = await PluginPermission.filter(name=name, session_type='group', session_id=group_id).first() - plugin.ban.append(user_id) - await plugin.save() + await PluginDisable.create(name=name, group_id=group_id, user_id=user_id) else: - await PluginPermission.filter(name=name, session_type='group', session_id=int(ban[1:])).update( - status=False) + await PluginDisable.create(name=name, group_id=int(ban[1:])) else: - await PluginPermission.filter(name=name, session_type='user', session_id=int(ban)).update(status=False) + await PluginDisable.create(name=name, user_id=int(ban)) try: from LittlePaimon.plugins.plugin_manager import cache_help cache_help.clear()