优化群聊学习复读表现和插件权限管理

This commit is contained in:
CMHopeSunshine 2022-11-26 18:25:11 +08:00
parent fab1cc3f85
commit 689a7701d3
8 changed files with 208 additions and 217 deletions

View File

@ -2,6 +2,7 @@ from pathlib import Path
from nonebot import load_plugins, logger
from LittlePaimon import database, web
from LittlePaimon.config import PluginManager
from LittlePaimon.utils import DRIVER, __version__, NICKNAME, SUPERUSERS
from LittlePaimon.utils.tool import check_resource
@ -35,6 +36,7 @@ logo = """<g>
async def startup():
logger.opt(colors=True).info(logo)
await database.connect()
await PluginManager.init()
await check_resource()

View File

@ -1,17 +1,18 @@
import asyncio
import contextlib
import datetime
from typing import Dict, List
from nonebot import plugin as nb_plugin
from nonebot import get_bot
from nonebot.matcher import Matcher
from nonebot.exception import IgnoredException
from nonebot.message import run_preprocessor
from nonebot.adapters.onebot.v11 import MessageEvent, PrivateMessageEvent, GroupMessageEvent
from LittlePaimon.utils import logger, DRIVER, SUPERUSERS
from LittlePaimon.utils.path import PLUGIN_CONFIG
from nonebot.exception import IgnoredException
from nonebot.matcher import Matcher
from nonebot.message import run_preprocessor
from tortoise.queryset import Q
from LittlePaimon.database.models import PluginPermission, PluginStatistics, PluginDisable
from LittlePaimon.utils import logger, SUPERUSERS
from LittlePaimon.utils.files import load_yaml, save_yaml
from LittlePaimon.database.models import PluginPermission, PluginStatistics
from LittlePaimon.utils.path import PLUGIN_CONFIG
from .model import MatcherInfo, PluginInfo
HIDDEN_PLUGINS = [
@ -43,42 +44,22 @@ class PluginManager:
@classmethod
async def init(cls):
plugin_list = nb_plugin.get_loaded_plugins()
group_list = await get_bot().get_group_list()
user_list = await get_bot().get_friend_list()
if not await PluginDisable.all().exists() and await PluginPermission.all().exists():
perms = await PluginPermission.filter(Q(status=False) | Q(ban__not=[])).all()
for perm in perms:
with contextlib.suppress(Exception):
if perm.session_type == 'group':
if not perm.status:
await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id)
for ban_user in perm.ban:
await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id,
user_id=ban_user)
else:
if not perm.status:
await PluginDisable.update_or_create(name=perm.name, user_id=perm.session_id)
await PluginPermission.all().delete()
await PluginDisable.filter(global_disable=False, group_id=None, user_id=None).delete()
for plugin in plugin_list:
if plugin.name not in HIDDEN_PLUGINS and PluginPermission._meta.default_connection is not None:
if group_list:
for group in group_list:
count = await PluginPermission.filter(
name=plugin.name, session_id=group['group_id'], session_type='group'
).count()
if count > 1:
first = await PluginPermission.filter(
name=plugin.name, session_id=group['group_id'], session_type='group'
).order_by('id').first()
await PluginPermission.filter(
name=plugin.name, session_id=group['group_id'], session_type='group'
).delete()
await first.save()
elif count == 0:
await PluginPermission.create(name=plugin.name, session_id=group['group_id'],
session_type='group')
if user_list:
for user in user_list:
count = await PluginPermission.filter(
name=plugin.name, session_id=user['user_id'], session_type='user'
).count()
if count > 1:
first = await PluginPermission.filter(
name=plugin.name, session_id=user['user_id'], session_type='user'
).order_by('id').first()
await PluginPermission.filter(
name=plugin.name, session_id=user['user_id'], session_type='user'
).delete()
await first.save()
elif count == 0:
await PluginPermission.create(name=plugin.name, session_id=user['user_id'],
session_type='user')
if plugin.name not in HIDDEN_PLUGINS:
if plugin.name not in cls.plugins:
if metadata := plugin.metadata:
@ -113,15 +94,23 @@ class PluginManager:
:param message_type: 消息类型
:param session_id: 消息ID
"""
load_plugins = nb_plugin.get_loaded_plugins()
load_plugins = [p.name for p in load_plugins]
load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()]
plugin_list = sorted(cls.plugins.values(), key=lambda x: x.priority).copy()
plugin_list = [p for p in plugin_list if p.show and p.module_name in load_plugins]
for plugin in plugin_list:
if message_type != 'guild':
plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id,
session_type=message_type)
plugin.status = True if plugin_info is None else plugin_info.status
if not await PluginDisable.filter(name=plugin.module_name, global_disable=True).exists():
if message_type != 'guild':
# plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id,
# session_type=message_type)
# plugin.status = True if plugin_info is None else plugin_info.status
if message_type == 'group':
plugin.status = not await PluginDisable.filter(name=plugin.module_name,
group_id=session_id).exists()
else:
plugin.status = not await PluginDisable.filter(name=plugin.module_name,
user_id=session_id).exists()
else:
plugin.status = True
else:
plugin.status = True
if plugin.matchers:
@ -134,57 +123,48 @@ class PluginManager:
"""
获取插件列表供Web UI使用
"""
load_plugins = nb_plugin.get_loaded_plugins()
load_plugins = [p.name for p in load_plugins]
load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()]
plugin_list = [p.dict(exclude={'status'}) for p in cls.plugins.values()]
for plugin in plugin_list:
plugin['matchers'].sort(key=lambda x: x['pm_priority'])
plugin['isLoad'] = plugin['module_name'] in load_plugins
plugin['status'] = await PluginPermission.filter(name=plugin['module_name'], status=True).exists()
plugin['status'] = not await PluginDisable.filter(name=plugin['module_name'], global_disable=True).exists()
plugin_list.sort(key=lambda x: (x['isLoad'], x['status'], -x['priority']), reverse=True)
return plugin_list
@DRIVER.on_bot_connect
async def _():
await PluginManager.init()
@run_preprocessor
async def _(event: MessageEvent, matcher: Matcher):
if event.user_id in SUPERUSERS:
return
if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS:
return
if isinstance(event, PrivateMessageEvent):
session_id = event.user_id
session_type = 'user'
elif isinstance(event, GroupMessageEvent):
session_id = event.group_id
session_type = 'group'
else:
return
try:
if event.user_id in SUPERUSERS:
return
if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS:
return
if not isinstance(event, (PrivateMessageEvent, GroupMessageEvent)):
return
# 权限检查
perm = await PluginPermission.get_or_none(name=matcher.plugin_name, session_id=session_id,
session_type=session_type)
if not perm:
await PluginPermission.create(name=matcher.plugin_name, session_id=session_id, session_type=session_type)
return
if not perm.status:
raise IgnoredException('插件使用权限已禁用')
if isinstance(event, GroupMessageEvent) and event.user_id in perm.ban:
raise IgnoredException('用户被禁止使用该插件')
# 权限检查
if await PluginDisable.get_or_none(name=matcher.plugin_name, global_disable=True):
raise IgnoredException('插件使用权限已禁用')
if await PluginDisable.get_or_none(name=matcher.plugin_name, user_id=event.user_id, group_id=None):
raise IgnoredException('插件使用权限已禁用')
elif isinstance(event, GroupMessageEvent) and (
perms := await PluginDisable.filter(name=matcher.plugin_name, group_id=event.group_id)):
user_ids = [p.user_id for p in perms]
if None in user_ids or event.user_id in user_ids:
raise IgnoredException('插件使用权限已禁用')
# 命令调用统计
if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state:
if matcher_info := list(filter(lambda x: x.pm_name == matcher.state['pm_name'],
PluginManager.plugins[matcher.plugin_name].matchers)):
matcher_info = matcher_info[0]
await PluginStatistics.create(plugin_name=matcher.plugin_name,
matcher_name=matcher_info.pm_name,
matcher_usage=matcher_info.pm_usage,
group_id=event.group_id if isinstance(event, GroupMessageEvent) else None,
user_id=event.user_id,
message_type=session_type,
time=datetime.datetime.now())
# 命令调用统计
if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state:
if matcher_info := list(filter(lambda x: x.pm_name == matcher.state['pm_name'],
PluginManager.plugins[matcher.plugin_name].matchers)):
matcher_info = matcher_info[0]
await PluginStatistics.create(plugin_name=matcher.plugin_name,
matcher_name=matcher_info.pm_name,
matcher_usage=matcher_info.pm_usage,
group_id=event.group_id if isinstance(event, GroupMessageEvent) else None,
user_id=event.user_id,
message_type=event.message_type,
time=datetime.datetime.now())
except Exception as e:
logger.info('插件管理器', f'插件权限检查<r>失败:{e}</r>')

View File

@ -7,6 +7,7 @@ from tortoise.models import Model
class PluginPermission(Model):
"""将在N个版本后废弃"""
id = fields.IntField(pk=True, generated=True, auto_increment=True)
name: str = fields.TextField()
"""插件名称"""
@ -25,6 +26,21 @@ class PluginPermission(Model):
table = 'plugin_permission'
class PluginDisable(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
name: str = fields.TextField()
"""插件名称"""
global_disable: bool = fields.BooleanField(default=False)
"""全局禁用"""
user_id: int = fields.IntField(null=True)
"""用户id"""
group_id: int = fields.IntField(null=True)
"""群组id"""
class Meta:
table = 'plugin_disable'
class PluginStatistics(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
plugin_name: str = fields.TextField()

View File

@ -45,7 +45,6 @@ learning_chat = on_message(priority=99, block=False, rule=Rule(ChatRule), permis
@learning_chat.handle()
async def _(event: GroupMessageEvent, answers=Arg('answers')):
for answer in answers:
await asyncio.sleep(random.randint(1, 2))
try:
logger.info('群聊学习', f'{NICKNAME}将向群<m>{event.group_id}</m>回复<m>"{answer}"</m>')
msg = await learning_chat.send(Message(answer))
@ -56,6 +55,7 @@ async def _(event: GroupMessageEvent, answers=Arg('answers')):
raw_message=answer,
time=int(time.time()),
plain_text=Message(answer).extract_plain_text())
await asyncio.sleep(random.random() + 0.5)
except ActionFailed:
logger.info('群聊学习', f'{NICKNAME}向群<m>{event.group_id}</m>的回复<m>"{answer}"</m>发送<r>失败,可能处于风控中</r>')

View File

@ -1,3 +1,4 @@
import asyncio
import datetime
import random
import re
@ -90,7 +91,7 @@ class LearningChat:
return Result.Pass
elif self.reply:
# 如果是回复消息
if not (message := await ChatMessage.get_or_none(message_id=self.reply.message_id)):
if not (message := await ChatMessage.filter(message_id=self.reply.message_id).first()):
# 回复的消息在数据库中有记录
logger.debug('群聊学习', '➤回复的消息不在数据库中,跳过')
return Result.Pass
@ -167,10 +168,17 @@ class LearningChat:
elif result == Result.Pass:
# 跳过
return None
elif result == Result.Repeat and (messages := await ChatMessage.filter(group_id=self.data.group_id,
time__gte=self.data.time - 3600).limit(
self.config.repeat_threshold)):
# 如果达到阈值且bot没有回复过且不是全都为同一个人在说则进行复读
elif result == Result.Repeat:
query_set = ChatMessage.filter(group_id=self.data.group_id, time__gte=self.data.time - 3600)
if await query_set.limit(self.config.repeat_threshold + 5).filter(
user_id=self.bot_id, message=self.data.message).exists():
# 如果在阈值+5条消息内bot已经回复过这句话则跳过
logger.debug('群聊学习', f'➤➤已经复读过了,跳过')
return None
if not (messages := await query_set.limit(
self.config.repeat_threshold + 5)):
return None
# 如果达到阈值,且不是全都为同一个人在说,则进行复读
if len(messages) >= self.config.repeat_threshold and all(
message.message == self.data.message and message.user_id != self.bot_id
for message in messages) and not all(
@ -181,12 +189,13 @@ class LearningChat:
else:
logger.debug('群聊学习', f'➤➤达到复读阈值,复读<m>{messages[0].message}</m>')
return [self.data.message]
return None
else:
# 回复
if self.data.is_plain_text and len(self.data.plain_text) <= 1:
logger.debug('群聊学习', '➤➤消息过短,不回复')
return None
if not (context := await ChatContext.get_or_none(keywords=self.data.keywords)):
if not (context := await ChatContext.filter(keywords=self.data.keywords).first()):
logger.debug('群聊学习', '➤➤尚未有已学习的回复,不回复')
return None
@ -204,7 +213,8 @@ class LearningChat:
else:
answer_count_threshold = 1
cross_group_threshold = 1
logger.debug('群聊学习', f'➤➤本次回复阈值为<m>{answer_count_threshold}</m>,跨群阈值为<m>{cross_group_threshold}</m>')
logger.debug('群聊学习',
f'➤➤本次回复阈值为<m>{answer_count_threshold}</m>,跨群阈值为<m>{cross_group_threshold}</m>')
# 获取满足跨群条件的回复
answers_cross = await ChatAnswer.filter(context=context, count__gte=answer_count_threshold,
keywords__in=await ChatAnswer.annotate(
@ -241,6 +251,7 @@ class LearningChat:
return None
result_message = random.choice(result.messages)
logger.debug('群聊学习', f'➤➤将回复<m>{result_message}</m>')
await asyncio.sleep(random.random() + 0.5)
return [result_message]
async def _ban(self, message_id: Optional[int] = None) -> bool:
@ -248,7 +259,9 @@ class LearningChat:
bot = get_bot()
if message_id:
# 如果有指定消息ID则屏蔽该消息
if (message := await ChatMessage.get_or_none(message_id=message_id)) and message.message not in ALL_WORDS:
if (
message := await ChatMessage.filter(
message_id=message_id).first()) and message.message not in ALL_WORDS:
keywords = message.keywords
try:
await bot.delete_msg(message_id=message_id)
@ -266,7 +279,7 @@ class LearningChat:
logger.info('群聊学习', f'待禁用消息<m>{last_reply.message_id}</m>尝试撤回<r>失败</r>')
else:
return False
if ban_word := await ChatBlackList.get_or_none(keywords=keywords):
if ban_word := await ChatBlackList.filter(keywords=keywords).first():
# 如果已有屏蔽记录
if self.data.group_id not in ban_word.ban_group_id:
# 如果不在屏蔽群列表中,则添加
@ -290,7 +303,7 @@ class LearningChat:
@staticmethod
async def add_ban(data: Union[ChatMessage, ChatContext, ChatAnswer]):
if ban_word := await ChatBlackList.get_or_none(keywords=data.keywords):
if ban_word := await ChatBlackList.filter(keywords=data.keywords).first():
# 如果已有屏蔽记录
if isinstance(data, ChatMessage):
if data.group_id not in ban_word.ban_group_id:
@ -360,7 +373,9 @@ class LearningChat:
continue
config = config_manager.get_group_config(group_id)
ban_words = set(chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record', '[CQ:share'])
ban_words = set(
chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record',
'[CQ:share'])
# 是否开启了主动发言
if not config.speak_enable:
@ -400,7 +415,7 @@ class LearningChat:
speak_list.append(message)
while random.random() < config.speak_continuously_probability and len(
speak_list) < config.speak_continuously_max_len:
if (follow_context := await ChatContext.get_or_none(keywords=answer.keywords)) and (
if (follow_context := await ChatContext.filter(keywords=answer.keywords).first()) and (
follow_answers := await ChatAnswer.filter(
group_id=group_id,
context=follow_context,
@ -432,13 +447,13 @@ class LearningChat:
return None
async def _set_answer(self, message: ChatMessage):
if context := await ChatContext.get_or_none(keywords=message.keywords):
if context := await ChatContext.filter(keywords=message.keywords).first():
if context.count < chat_config.learn_max_count:
context.count += 1
context.time = self.data.time
if answer := await ChatAnswer.get_or_none(keywords=self.data.keywords,
group_id=self.data.group_id,
context=context):
if answer := await ChatAnswer.filter(keywords=self.data.keywords,
group_id=self.data.group_id,
context=context).first():
if answer.count < chat_config.learn_max_count:
answer.count += 1
answer.time = self.data.time
@ -476,7 +491,7 @@ class LearningChat:
if raw_message.startswith('&#91;') and raw_message.endswith('&#93;'):
# logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>')
return False
if ban_word := await ChatBlackList.get_or_none(keywords=message.keywords):
if ban_word := await ChatBlackList.filter(keywords=message.keywords).first():
if ban_word.global_ban or message.group_id in ban_word.ban_group_id:
# logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>')
return False

View File

@ -1,8 +1,5 @@
import asyncio
from nonebot import on_regex, on_command, on_notice
from nonebot import plugin as nb_plugin
from nonebot import on_regex, on_command
from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent, PrivateMessageEvent, MessageEvent
from nonebot.adapters.onebot.v11 import NoticeEvent, FriendAddNoticeEvent, GroupIncreaseNoticeEvent
from nonebot.params import RegexDict, CommandArg
from nonebot.permission import SUPERUSER
from nonebot.plugin import PluginMetadata
@ -10,8 +7,8 @@ from nonebot.rule import Rule
from nonebot.typing import T_State
from LittlePaimon import SUPERUSERS
from LittlePaimon.config import ConfigManager, PluginManager, HIDDEN_PLUGINS
from LittlePaimon.database import PluginPermission
from LittlePaimon.config import ConfigManager, PluginManager
from LittlePaimon.database import PluginDisable
from LittlePaimon.utils import logger
from LittlePaimon.utils.message import CommandObjectID
from .draw_help import draw_help
@ -27,19 +24,12 @@ __plugin_meta__ = PluginMetadata(
)
def notice_rule(event: NoticeEvent) -> bool:
if isinstance(event, FriendAddNoticeEvent):
return True
elif isinstance(event, GroupIncreaseNoticeEvent):
return event.user_id == event.self_id
def fullmatch(msg: Message = CommandArg()) -> bool:
return not bool(msg)
manage_cmd = on_regex(
r'^pm (?P<func>ban|unban) (?P<plugin>([\w ]*)|all|全部) ?(-g (?P<group>[\d ]*) ?)?(-u (?P<user>[\d ]*) ?)?(?P<reserve>-r)?',
r'^pm (?P<func>ban|unban) (?P<plugin>([\w ]*)|all|全部) ?(-g (?P<group>[\d ]*) ?)?(-u (?P<user>[\d ]*) ?)?',
priority=1, block=True, state={
'pm_name': 'pm-ban|unban',
'pm_description': '禁用|取消禁用插件的群|用户使用权限',
@ -58,11 +48,6 @@ set_config_cmd = on_command('pm set', priority=1, permission=SUPERUSER, block=Tr
'pm_usage': 'pm set<配置名> <值>',
'pm_priority': 2
})
notices = on_notice(priority=1, rule=Rule(notice_rule), block=True, state={
'pm_name': 'pm-new-group-user',
'pm_description': '为新加入的群|用户添加插件使用权限',
'pm_show': False
})
cache_help = {}
@ -73,21 +58,26 @@ async def _(event: GroupMessageEvent, state: T_State, match: dict = RegexDict(),
await manage_cmd.finish('你没有权限使用该命令', at_sender=True)
state['session_id'] = session_id
state['bool'] = match['func'] == 'unban'
state['plugin'] = []
state['plugin_no_exist'] = []
for plugin in match['plugin'].strip().split(' '):
if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']:
state['plugin'].append(plugin)
elif module_name := list(
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
state['plugin'].append(module_name[0])
else:
state['plugin_no_exist'].append(plugin)
if any(w in match['plugin'] for w in {'all', '全部'}):
state['is_all'] = True
state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager']
else:
state['is_all'] = False
state['plugin'] = []
for plugin in match['plugin'].strip().split(' '):
if plugin in PluginManager.plugins.keys():
state['plugin'].append(plugin)
elif module_name := list(
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
state['plugin'].append(module_name[0])
else:
state['plugin_no_exist'].append(plugin)
if not match['group'] or event.user_id not in SUPERUSERS:
state['group'] = [event.group_id]
else:
state['group'] = [int(group) for group in match['group'].strip().split(' ')]
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else []
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None
@manage_cmd.handle()
@ -96,18 +86,23 @@ async def _(event: PrivateMessageEvent, state: T_State, match: dict = RegexDict(
await manage_cmd.finish('你没有权限使用该命令', at_sender=True)
state['session_id'] = session_id
state['bool'] = match['func'] == 'unban'
state['plugin'] = []
state['plugin_no_exist'] = []
for plugin in match['plugin'].strip().split(' '):
if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']:
state['plugin'].append(plugin)
elif module_name := list(
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
state['plugin'].append(module_name[0])
else:
state['plugin_no_exist'].append(plugin)
state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else []
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else []
if any(w in match['plugin'] for w in {'all', '全部'}):
state['is_all'] = True
state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager']
else:
state['is_all'] = False
state['plugin'] = []
for plugin in match['plugin'].strip().split(' '):
if plugin in PluginManager.plugins.keys():
state['plugin'].append(plugin)
elif module_name := list(
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
state['plugin'].append(module_name[0])
else:
state['plugin_no_exist'].append(plugin)
state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else None
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None
@manage_cmd.got('bool')
@ -119,45 +114,40 @@ async def _(state: T_State):
if not state['plugin'] and state['plugin_no_exist']:
await manage_cmd.finish(f'没有叫{" ".join(state["plugin_no_exist"])}的插件')
extra_msg = f',但没有叫{" ".join(state["plugin_no_exist"])}的插件。' if state['plugin_no_exist'] else ''
if state['group'] and not state['user']:
for group_id in state['group']:
if 'all' in state['plugin']:
await PluginPermission.filter(session_id=group_id, session_type='group').update(status=state['bool'])
else:
await PluginPermission.filter(name__in=state['plugin'], session_id=group_id,
session_type='group').update(
status=state['bool'])
logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限')
await manage_cmd.finish(
f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
elif state['user'] and not state['group']:
for user_id in state['user']:
if 'all' in state['plugin']:
await PluginPermission.filter(session_id=user_id, session_type='user').update(status=state['bool'])
else:
await PluginPermission.filter(name__in=state['plugin'], session_id=user_id, session_type='user').update(
status=state['bool'])
logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限')
await manage_cmd.finish(
f'{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
filter_arg = {}
if state['group']:
filter_arg['group_id__in'] = state['group']
if state['user']:
filter_arg['user_id__in'] = state['user']
logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>中用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
msg = f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
else:
filter_arg['user_id'] = None
logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
msg = f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
else:
for group_id in state['group']:
if 'all' in state['plugin']:
plugin_list = await PluginPermission.filter(session_id=group_id, session_type='group').all()
else:
plugin_list = await PluginPermission.filter(name__in=state['plugin'], session_id=group_id,
session_type='group').all()
if plugin_list:
for plugin in plugin_list:
plugin.ban = list(set(plugin.ban) - set(state['user'])) if state['bool'] else list(
set(plugin.ban) | set(state['user']))
await plugin.save()
filter_arg['user_id__in'] = state['user']
logger.info('插件管理器',
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>中用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限')
await manage_cmd.finish(
f'{"启用" if state["bool"] else "禁用"}{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
f'{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
msg = f'{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
if state['bool']:
await PluginDisable.filter(name__in=state['plugin'], **filter_arg).delete()
else:
for plugin in state['plugin']:
if state['group']:
for group in state['group']:
if state['user']:
for user in state['user']:
await PluginDisable.update_or_create(name=plugin, group_id=group, user_id=user)
else:
await PluginDisable.update_or_create(name=plugin, group_id=group)
else:
for user in state['user']:
await PluginDisable.update_or_create(name=plugin, user_id=user)
await manage_cmd.finish(msg)
@help_cmd.handle()
@ -181,15 +171,3 @@ async def _(event: MessageEvent, msg: Message = CommandArg()):
else:
result = ConfigManager.set_config(msg[0], msg[1])
await set_config_cmd.finish(result)
@notices.handle()
async def _(event: NoticeEvent):
plugin_list = nb_plugin.get_loaded_plugins()
if isinstance(event, FriendAddNoticeEvent):
await asyncio.gather(*[PluginPermission.update_or_create(name=plugin, session_id=event.user_id, session_type='user') for plugin
in plugin_list if plugin not in HIDDEN_PLUGINS])
elif isinstance(event, GroupIncreaseNoticeEvent):
await asyncio.gather(
*[PluginPermission.update_or_create(name=plugin, session_id=event.group_id, session_type='group') for plugin
in plugin_list if plugin not in HIDDEN_PLUGINS])

View File

@ -4,7 +4,7 @@ from nonebot import get_driver
from .logger import logger
from .scheduler import scheduler
__version__ = '3.0.0rc3'
__version__ = '3.0.0rc4'
DRIVER = get_driver()
try:

View File

@ -5,7 +5,7 @@ from fastapi import APIRouter
from fastapi.responses import JSONResponse
from LittlePaimon.config import ConfigManager, PluginManager, PluginInfo
from LittlePaimon.database import PluginPermission
from LittlePaimon.database import PluginDisable
from .utils import authentication
@ -27,28 +27,31 @@ async def get_plugins():
@route.post('/set_plugin_status', response_class=JSONResponse, dependencies=[authentication()])
async def set_plugin_status(data: dict):
module_name = data.get('plugin')
status = data.get('status')
module_name: str = data.get('plugin')
status: bool = data.get('status')
try:
from LittlePaimon.plugins.plugin_manager import cache_help
cache_help.clear()
except Exception:
pass
await PluginPermission.filter(name=module_name).update(status=status)
if status:
await PluginDisable.filter(name=module_name, global_disable=True).delete()
else:
await PluginDisable.create(name=module_name, global_disable=True)
return {'status': 0, 'msg': f'成功设置{module_name}插件状态为{status}'}
@route.get('/get_plugin_bans', response_class=JSONResponse, dependencies=[authentication()])
async def get_plugin_status(module_name: str):
result = []
bans = await PluginPermission.filter(name=module_name).all()
bans = await PluginDisable.filter(name=module_name).all()
for ban in bans:
if ban.session_type == 'group':
result.extend(f'{ban.session_id}.{b}' for b in ban.ban)
if not ban.status:
result.append(f'{ban.session_id}')
elif ban.session_type == 'user' and not ban.status:
result.append(f'{ban.session_id}')
if ban.user_id and ban.group_id:
result.append(f'{ban.group_id}.{ban.user_id}')
elif ban.group_id and not ban.user_id:
result.append(f'{ban.group_id}')
elif ban.user_id and not ban.group_id:
result.append(f'{ban.user_id}')
return {
'status': 0,
'msg': 'ok',
@ -63,20 +66,17 @@ async def get_plugin_status(module_name: str):
async def set_plugin_bans(data: dict):
bans = data['bans']
name = data['module_name']
await PluginPermission.filter(name=name).update(status=True, ban=[])
await PluginDisable.filter(name=name, global_disable=False).delete()
for ban in bans:
if ban.startswith(''):
if '.' in ban:
group_id = int(ban.split('.')[0][1:])
user_id = int(ban.split('.')[1])
plugin = await PluginPermission.filter(name=name, session_type='group', session_id=group_id).first()
plugin.ban.append(user_id)
await plugin.save()
await PluginDisable.create(name=name, group_id=group_id, user_id=user_id)
else:
await PluginPermission.filter(name=name, session_type='group', session_id=int(ban[1:])).update(
status=False)
await PluginDisable.create(name=name, group_id=int(ban[1:]))
else:
await PluginPermission.filter(name=name, session_type='user', session_id=int(ban)).update(status=False)
await PluginDisable.create(name=name, user_id=int(ban))
try:
from LittlePaimon.plugins.plugin_manager import cache_help
cache_help.clear()