diff --git a/LittlePaimon/config/plugin/manage.py b/LittlePaimon/config/plugin/manage.py index 4e3f02c..50f5b32 100644 --- a/LittlePaimon/config/plugin/manage.py +++ b/LittlePaimon/config/plugin/manage.py @@ -1,3 +1,4 @@ +import asyncio import datetime from typing import Dict, List @@ -46,15 +47,15 @@ class PluginManager: user_list = await get_bot().get_friend_list() for plugin in plugin_list: if plugin.name not in HIDDEN_PLUGINS: - models = ([PluginPermission(name=plugin.name, session_id=group['group_id'], session_type='group') for - group in group_list] if group_list else []) + \ - ([PluginPermission(name=plugin.name, session_id=user['user_id'], session_type='user') for - user in user_list] if user_list else []) - if models: - await PluginPermission.bulk_create( - models, - ignore_conflicts=True - ) + # 将所有PluginPermission相同的,只保留一个 + if group_list: + await asyncio.gather( + *[PluginPermission.update_or_create(name=plugin.name, session_id=group['group_id'], + session_type='group') for group in group_list]) + if user_list: + await asyncio.gather( + *[PluginPermission.update_or_create(name=plugin.name, session_id=user['user_id'], + session_type='user') for user in user_list]) if plugin.name not in HIDDEN_PLUGINS: if plugin.name not in cls.plugins: if metadata := plugin.metadata: @@ -142,8 +143,8 @@ async def _(event: MessageEvent, matcher: Matcher): return # 权限检查 - perm = await PluginPermission.get_or_none(name=matcher.plugin_name, session_id=session_id, - session_type=session_type) + perm = await PluginPermission.filter(name=matcher.plugin_name, session_id=session_id, + session_type=session_type).first() if not perm: await PluginPermission.create(name=matcher.plugin_name, session_id=session_id, session_type=session_type) return diff --git a/LittlePaimon/plugins/plugin_manager/__init__.py b/LittlePaimon/plugins/plugin_manager/__init__.py index e89a7a5..30ebc48 100644 --- a/LittlePaimon/plugins/plugin_manager/__init__.py +++ b/LittlePaimon/plugins/plugin_manager/__init__.py @@ -1,3 +1,4 @@ +import asyncio from nonebot import on_regex, on_command, on_notice from nonebot import plugin as nb_plugin from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent, PrivateMessageEvent, MessageEvent @@ -26,8 +27,11 @@ __plugin_meta__ = PluginMetadata( ) -def notice_rule(event: NoticeEvent): - return isinstance(event, (FriendAddNoticeEvent, GroupIncreaseNoticeEvent)) +def notice_rule(event: NoticeEvent) -> bool: + if isinstance(event, FriendAddNoticeEvent): + return True + elif isinstance(event, GroupIncreaseNoticeEvent): + return event.user_id == event.self_id def fullmatch(msg: Message = CommandArg()) -> bool: @@ -183,18 +187,9 @@ async def _(event: MessageEvent, msg: Message = CommandArg()): async def _(event: NoticeEvent): plugin_list = nb_plugin.get_loaded_plugins() if isinstance(event, FriendAddNoticeEvent): - models = [ - PluginPermission(name=plugin, session_id=event.user_id, session_type='user') for plugin in plugin_list if plugin not in HIDDEN_PLUGINS - ] - elif isinstance(event, GroupIncreaseNoticeEvent) and event.user_id == event.self_id: - models = [ - PluginPermission(name=plugin, session_id=event.group_id, session_type='group') for plugin in plugin_list if - plugin not in HIDDEN_PLUGINS - ] - else: - return - if models: - await PluginPermission.bulk_create( - models, - ignore_conflicts=True - ) + await asyncio.gather(*[PluginPermission.update_or_create(name=plugin, session_id=event.user_id, session_type='user') for plugin + in plugin_list if plugin not in HIDDEN_PLUGINS]) + elif isinstance(event, GroupIncreaseNoticeEvent): + await asyncio.gather( + *[PluginPermission.update_or_create(name=plugin, session_id=event.group_id, session_type='group') for plugin + in plugin_list if plugin not in HIDDEN_PLUGINS])