mirror of
https://github.com/xuthus83/LittlePaimon.git
synced 2024-12-16 13:40:53 +08:00
✨ 优化群聊学习
复读表现和插件权限管理
This commit is contained in:
parent
fab1cc3f85
commit
689a7701d3
@ -2,6 +2,7 @@ from pathlib import Path
|
||||
|
||||
from nonebot import load_plugins, logger
|
||||
from LittlePaimon import database, web
|
||||
from LittlePaimon.config import PluginManager
|
||||
from LittlePaimon.utils import DRIVER, __version__, NICKNAME, SUPERUSERS
|
||||
from LittlePaimon.utils.tool import check_resource
|
||||
|
||||
@ -35,6 +36,7 @@ logo = """<g>
|
||||
async def startup():
|
||||
logger.opt(colors=True).info(logo)
|
||||
await database.connect()
|
||||
await PluginManager.init()
|
||||
await check_resource()
|
||||
|
||||
|
||||
|
@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import datetime
|
||||
from typing import Dict, List
|
||||
|
||||
from nonebot import plugin as nb_plugin
|
||||
from nonebot import get_bot
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot.adapters.onebot.v11 import MessageEvent, PrivateMessageEvent, GroupMessageEvent
|
||||
from LittlePaimon.utils import logger, DRIVER, SUPERUSERS
|
||||
from LittlePaimon.utils.path import PLUGIN_CONFIG
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_preprocessor
|
||||
from tortoise.queryset import Q
|
||||
|
||||
from LittlePaimon.database.models import PluginPermission, PluginStatistics, PluginDisable
|
||||
from LittlePaimon.utils import logger, SUPERUSERS
|
||||
from LittlePaimon.utils.files import load_yaml, save_yaml
|
||||
from LittlePaimon.database.models import PluginPermission, PluginStatistics
|
||||
from LittlePaimon.utils.path import PLUGIN_CONFIG
|
||||
from .model import MatcherInfo, PluginInfo
|
||||
|
||||
HIDDEN_PLUGINS = [
|
||||
@ -43,42 +44,22 @@ class PluginManager:
|
||||
@classmethod
|
||||
async def init(cls):
|
||||
plugin_list = nb_plugin.get_loaded_plugins()
|
||||
group_list = await get_bot().get_group_list()
|
||||
user_list = await get_bot().get_friend_list()
|
||||
if not await PluginDisable.all().exists() and await PluginPermission.all().exists():
|
||||
perms = await PluginPermission.filter(Q(status=False) | Q(ban__not=[])).all()
|
||||
for perm in perms:
|
||||
with contextlib.suppress(Exception):
|
||||
if perm.session_type == 'group':
|
||||
if not perm.status:
|
||||
await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id)
|
||||
for ban_user in perm.ban:
|
||||
await PluginDisable.update_or_create(name=perm.name, group_id=perm.session_id,
|
||||
user_id=ban_user)
|
||||
else:
|
||||
if not perm.status:
|
||||
await PluginDisable.update_or_create(name=perm.name, user_id=perm.session_id)
|
||||
await PluginPermission.all().delete()
|
||||
await PluginDisable.filter(global_disable=False, group_id=None, user_id=None).delete()
|
||||
for plugin in plugin_list:
|
||||
if plugin.name not in HIDDEN_PLUGINS and PluginPermission._meta.default_connection is not None:
|
||||
if group_list:
|
||||
for group in group_list:
|
||||
count = await PluginPermission.filter(
|
||||
name=plugin.name, session_id=group['group_id'], session_type='group'
|
||||
).count()
|
||||
if count > 1:
|
||||
first = await PluginPermission.filter(
|
||||
name=plugin.name, session_id=group['group_id'], session_type='group'
|
||||
).order_by('id').first()
|
||||
await PluginPermission.filter(
|
||||
name=plugin.name, session_id=group['group_id'], session_type='group'
|
||||
).delete()
|
||||
await first.save()
|
||||
elif count == 0:
|
||||
await PluginPermission.create(name=plugin.name, session_id=group['group_id'],
|
||||
session_type='group')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
count = await PluginPermission.filter(
|
||||
name=plugin.name, session_id=user['user_id'], session_type='user'
|
||||
).count()
|
||||
if count > 1:
|
||||
first = await PluginPermission.filter(
|
||||
name=plugin.name, session_id=user['user_id'], session_type='user'
|
||||
).order_by('id').first()
|
||||
await PluginPermission.filter(
|
||||
name=plugin.name, session_id=user['user_id'], session_type='user'
|
||||
).delete()
|
||||
await first.save()
|
||||
elif count == 0:
|
||||
await PluginPermission.create(name=plugin.name, session_id=user['user_id'],
|
||||
session_type='user')
|
||||
if plugin.name not in HIDDEN_PLUGINS:
|
||||
if plugin.name not in cls.plugins:
|
||||
if metadata := plugin.metadata:
|
||||
@ -113,15 +94,23 @@ class PluginManager:
|
||||
:param message_type: 消息类型
|
||||
:param session_id: 消息ID
|
||||
"""
|
||||
load_plugins = nb_plugin.get_loaded_plugins()
|
||||
load_plugins = [p.name for p in load_plugins]
|
||||
load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()]
|
||||
plugin_list = sorted(cls.plugins.values(), key=lambda x: x.priority).copy()
|
||||
plugin_list = [p for p in plugin_list if p.show and p.module_name in load_plugins]
|
||||
for plugin in plugin_list:
|
||||
if message_type != 'guild':
|
||||
plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id,
|
||||
session_type=message_type)
|
||||
plugin.status = True if plugin_info is None else plugin_info.status
|
||||
if not await PluginDisable.filter(name=plugin.module_name, global_disable=True).exists():
|
||||
if message_type != 'guild':
|
||||
# plugin_info = await PluginPermission.get_or_none(name=plugin.module_name, session_id=session_id,
|
||||
# session_type=message_type)
|
||||
# plugin.status = True if plugin_info is None else plugin_info.status
|
||||
if message_type == 'group':
|
||||
plugin.status = not await PluginDisable.filter(name=plugin.module_name,
|
||||
group_id=session_id).exists()
|
||||
else:
|
||||
plugin.status = not await PluginDisable.filter(name=plugin.module_name,
|
||||
user_id=session_id).exists()
|
||||
else:
|
||||
plugin.status = True
|
||||
else:
|
||||
plugin.status = True
|
||||
if plugin.matchers:
|
||||
@ -134,57 +123,48 @@ class PluginManager:
|
||||
"""
|
||||
获取插件列表(供Web UI使用)
|
||||
"""
|
||||
load_plugins = nb_plugin.get_loaded_plugins()
|
||||
load_plugins = [p.name for p in load_plugins]
|
||||
load_plugins = [p.name for p in nb_plugin.get_loaded_plugins()]
|
||||
plugin_list = [p.dict(exclude={'status'}) for p in cls.plugins.values()]
|
||||
for plugin in plugin_list:
|
||||
plugin['matchers'].sort(key=lambda x: x['pm_priority'])
|
||||
plugin['isLoad'] = plugin['module_name'] in load_plugins
|
||||
plugin['status'] = await PluginPermission.filter(name=plugin['module_name'], status=True).exists()
|
||||
plugin['status'] = not await PluginDisable.filter(name=plugin['module_name'], global_disable=True).exists()
|
||||
plugin_list.sort(key=lambda x: (x['isLoad'], x['status'], -x['priority']), reverse=True)
|
||||
return plugin_list
|
||||
|
||||
|
||||
@DRIVER.on_bot_connect
|
||||
async def _():
|
||||
await PluginManager.init()
|
||||
|
||||
|
||||
@run_preprocessor
|
||||
async def _(event: MessageEvent, matcher: Matcher):
|
||||
if event.user_id in SUPERUSERS:
|
||||
return
|
||||
if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS:
|
||||
return
|
||||
if isinstance(event, PrivateMessageEvent):
|
||||
session_id = event.user_id
|
||||
session_type = 'user'
|
||||
elif isinstance(event, GroupMessageEvent):
|
||||
session_id = event.group_id
|
||||
session_type = 'group'
|
||||
else:
|
||||
return
|
||||
try:
|
||||
if event.user_id in SUPERUSERS:
|
||||
return
|
||||
if not matcher.plugin_name or matcher.plugin_name in HIDDEN_PLUGINS:
|
||||
return
|
||||
if not isinstance(event, (PrivateMessageEvent, GroupMessageEvent)):
|
||||
return
|
||||
|
||||
# 权限检查
|
||||
perm = await PluginPermission.get_or_none(name=matcher.plugin_name, session_id=session_id,
|
||||
session_type=session_type)
|
||||
if not perm:
|
||||
await PluginPermission.create(name=matcher.plugin_name, session_id=session_id, session_type=session_type)
|
||||
return
|
||||
if not perm.status:
|
||||
raise IgnoredException('插件使用权限已禁用')
|
||||
if isinstance(event, GroupMessageEvent) and event.user_id in perm.ban:
|
||||
raise IgnoredException('用户被禁止使用该插件')
|
||||
# 权限检查
|
||||
if await PluginDisable.get_or_none(name=matcher.plugin_name, global_disable=True):
|
||||
raise IgnoredException('插件使用权限已禁用')
|
||||
if await PluginDisable.get_or_none(name=matcher.plugin_name, user_id=event.user_id, group_id=None):
|
||||
raise IgnoredException('插件使用权限已禁用')
|
||||
elif isinstance(event, GroupMessageEvent) and (
|
||||
perms := await PluginDisable.filter(name=matcher.plugin_name, group_id=event.group_id)):
|
||||
user_ids = [p.user_id for p in perms]
|
||||
if None in user_ids or event.user_id in user_ids:
|
||||
raise IgnoredException('插件使用权限已禁用')
|
||||
|
||||
# 命令调用统计
|
||||
if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state:
|
||||
if matcher_info := list(filter(lambda x: x.pm_name == matcher.state['pm_name'],
|
||||
PluginManager.plugins[matcher.plugin_name].matchers)):
|
||||
matcher_info = matcher_info[0]
|
||||
await PluginStatistics.create(plugin_name=matcher.plugin_name,
|
||||
matcher_name=matcher_info.pm_name,
|
||||
matcher_usage=matcher_info.pm_usage,
|
||||
group_id=event.group_id if isinstance(event, GroupMessageEvent) else None,
|
||||
user_id=event.user_id,
|
||||
message_type=session_type,
|
||||
time=datetime.datetime.now())
|
||||
# 命令调用统计
|
||||
if matcher.plugin_name in PluginManager.plugins and 'pm_name' in matcher.state:
|
||||
if matcher_info := list(filter(lambda x: x.pm_name == matcher.state['pm_name'],
|
||||
PluginManager.plugins[matcher.plugin_name].matchers)):
|
||||
matcher_info = matcher_info[0]
|
||||
await PluginStatistics.create(plugin_name=matcher.plugin_name,
|
||||
matcher_name=matcher_info.pm_name,
|
||||
matcher_usage=matcher_info.pm_usage,
|
||||
group_id=event.group_id if isinstance(event, GroupMessageEvent) else None,
|
||||
user_id=event.user_id,
|
||||
message_type=event.message_type,
|
||||
time=datetime.datetime.now())
|
||||
except Exception as e:
|
||||
logger.info('插件管理器', f'插件权限检查<r>失败:{e}</r>')
|
||||
|
@ -7,6 +7,7 @@ from tortoise.models import Model
|
||||
|
||||
|
||||
class PluginPermission(Model):
|
||||
"""将在N个版本后废弃"""
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
name: str = fields.TextField()
|
||||
"""插件名称"""
|
||||
@ -25,6 +26,21 @@ class PluginPermission(Model):
|
||||
table = 'plugin_permission'
|
||||
|
||||
|
||||
class PluginDisable(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
name: str = fields.TextField()
|
||||
"""插件名称"""
|
||||
global_disable: bool = fields.BooleanField(default=False)
|
||||
"""全局禁用"""
|
||||
user_id: int = fields.IntField(null=True)
|
||||
"""用户id"""
|
||||
group_id: int = fields.IntField(null=True)
|
||||
"""群组id"""
|
||||
|
||||
class Meta:
|
||||
table = 'plugin_disable'
|
||||
|
||||
|
||||
class PluginStatistics(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
plugin_name: str = fields.TextField()
|
||||
|
@ -45,7 +45,6 @@ learning_chat = on_message(priority=99, block=False, rule=Rule(ChatRule), permis
|
||||
@learning_chat.handle()
|
||||
async def _(event: GroupMessageEvent, answers=Arg('answers')):
|
||||
for answer in answers:
|
||||
await asyncio.sleep(random.randint(1, 2))
|
||||
try:
|
||||
logger.info('群聊学习', f'{NICKNAME}将向群<m>{event.group_id}</m>回复<m>"{answer}"</m>')
|
||||
msg = await learning_chat.send(Message(answer))
|
||||
@ -56,6 +55,7 @@ async def _(event: GroupMessageEvent, answers=Arg('answers')):
|
||||
raw_message=answer,
|
||||
time=int(time.time()),
|
||||
plain_text=Message(answer).extract_plain_text())
|
||||
await asyncio.sleep(random.random() + 0.5)
|
||||
except ActionFailed:
|
||||
logger.info('群聊学习', f'{NICKNAME}向群<m>{event.group_id}</m>的回复<m>"{answer}"</m>发送<r>失败,可能处于风控中</r>')
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import random
|
||||
import re
|
||||
@ -90,7 +91,7 @@ class LearningChat:
|
||||
return Result.Pass
|
||||
elif self.reply:
|
||||
# 如果是回复消息
|
||||
if not (message := await ChatMessage.get_or_none(message_id=self.reply.message_id)):
|
||||
if not (message := await ChatMessage.filter(message_id=self.reply.message_id).first()):
|
||||
# 回复的消息在数据库中有记录
|
||||
logger.debug('群聊学习', '➤回复的消息不在数据库中,跳过')
|
||||
return Result.Pass
|
||||
@ -167,10 +168,17 @@ class LearningChat:
|
||||
elif result == Result.Pass:
|
||||
# 跳过
|
||||
return None
|
||||
elif result == Result.Repeat and (messages := await ChatMessage.filter(group_id=self.data.group_id,
|
||||
time__gte=self.data.time - 3600).limit(
|
||||
self.config.repeat_threshold)):
|
||||
# 如果达到阈值,且bot没有回复过,且不是全都为同一个人在说,则进行复读
|
||||
elif result == Result.Repeat:
|
||||
query_set = ChatMessage.filter(group_id=self.data.group_id, time__gte=self.data.time - 3600)
|
||||
if await query_set.limit(self.config.repeat_threshold + 5).filter(
|
||||
user_id=self.bot_id, message=self.data.message).exists():
|
||||
# 如果在阈值+5条消息内,bot已经回复过这句话,则跳过
|
||||
logger.debug('群聊学习', f'➤➤已经复读过了,跳过')
|
||||
return None
|
||||
if not (messages := await query_set.limit(
|
||||
self.config.repeat_threshold + 5)):
|
||||
return None
|
||||
# 如果达到阈值,且不是全都为同一个人在说,则进行复读
|
||||
if len(messages) >= self.config.repeat_threshold and all(
|
||||
message.message == self.data.message and message.user_id != self.bot_id
|
||||
for message in messages) and not all(
|
||||
@ -181,12 +189,13 @@ class LearningChat:
|
||||
else:
|
||||
logger.debug('群聊学习', f'➤➤达到复读阈值,复读<m>{messages[0].message}</m>')
|
||||
return [self.data.message]
|
||||
return None
|
||||
else:
|
||||
# 回复
|
||||
if self.data.is_plain_text and len(self.data.plain_text) <= 1:
|
||||
logger.debug('群聊学习', '➤➤消息过短,不回复')
|
||||
return None
|
||||
if not (context := await ChatContext.get_or_none(keywords=self.data.keywords)):
|
||||
if not (context := await ChatContext.filter(keywords=self.data.keywords).first()):
|
||||
logger.debug('群聊学习', '➤➤尚未有已学习的回复,不回复')
|
||||
return None
|
||||
|
||||
@ -204,7 +213,8 @@ class LearningChat:
|
||||
else:
|
||||
answer_count_threshold = 1
|
||||
cross_group_threshold = 1
|
||||
logger.debug('群聊学习', f'➤➤本次回复阈值为<m>{answer_count_threshold}</m>,跨群阈值为<m>{cross_group_threshold}</m>')
|
||||
logger.debug('群聊学习',
|
||||
f'➤➤本次回复阈值为<m>{answer_count_threshold}</m>,跨群阈值为<m>{cross_group_threshold}</m>')
|
||||
# 获取满足跨群条件的回复
|
||||
answers_cross = await ChatAnswer.filter(context=context, count__gte=answer_count_threshold,
|
||||
keywords__in=await ChatAnswer.annotate(
|
||||
@ -241,6 +251,7 @@ class LearningChat:
|
||||
return None
|
||||
result_message = random.choice(result.messages)
|
||||
logger.debug('群聊学习', f'➤➤将回复<m>{result_message}</m>')
|
||||
await asyncio.sleep(random.random() + 0.5)
|
||||
return [result_message]
|
||||
|
||||
async def _ban(self, message_id: Optional[int] = None) -> bool:
|
||||
@ -248,7 +259,9 @@ class LearningChat:
|
||||
bot = get_bot()
|
||||
if message_id:
|
||||
# 如果有指定消息ID,则屏蔽该消息
|
||||
if (message := await ChatMessage.get_or_none(message_id=message_id)) and message.message not in ALL_WORDS:
|
||||
if (
|
||||
message := await ChatMessage.filter(
|
||||
message_id=message_id).first()) and message.message not in ALL_WORDS:
|
||||
keywords = message.keywords
|
||||
try:
|
||||
await bot.delete_msg(message_id=message_id)
|
||||
@ -266,7 +279,7 @@ class LearningChat:
|
||||
logger.info('群聊学习', f'待禁用消息<m>{last_reply.message_id}</m>尝试撤回<r>失败</r>')
|
||||
else:
|
||||
return False
|
||||
if ban_word := await ChatBlackList.get_or_none(keywords=keywords):
|
||||
if ban_word := await ChatBlackList.filter(keywords=keywords).first():
|
||||
# 如果已有屏蔽记录
|
||||
if self.data.group_id not in ban_word.ban_group_id:
|
||||
# 如果不在屏蔽群列表中,则添加
|
||||
@ -290,7 +303,7 @@ class LearningChat:
|
||||
|
||||
@staticmethod
|
||||
async def add_ban(data: Union[ChatMessage, ChatContext, ChatAnswer]):
|
||||
if ban_word := await ChatBlackList.get_or_none(keywords=data.keywords):
|
||||
if ban_word := await ChatBlackList.filter(keywords=data.keywords).first():
|
||||
# 如果已有屏蔽记录
|
||||
if isinstance(data, ChatMessage):
|
||||
if data.group_id not in ban_word.ban_group_id:
|
||||
@ -360,7 +373,9 @@ class LearningChat:
|
||||
continue
|
||||
|
||||
config = config_manager.get_group_config(group_id)
|
||||
ban_words = set(chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record', '[CQ:share'])
|
||||
ban_words = set(
|
||||
chat_config.ban_words + config.ban_words + ['[CQ:xml', '[CQ:json', '[CQ:at', '[CQ:video', '[CQ:record',
|
||||
'[CQ:share'])
|
||||
|
||||
# 是否开启了主动发言
|
||||
if not config.speak_enable:
|
||||
@ -400,7 +415,7 @@ class LearningChat:
|
||||
speak_list.append(message)
|
||||
while random.random() < config.speak_continuously_probability and len(
|
||||
speak_list) < config.speak_continuously_max_len:
|
||||
if (follow_context := await ChatContext.get_or_none(keywords=answer.keywords)) and (
|
||||
if (follow_context := await ChatContext.filter(keywords=answer.keywords).first()) and (
|
||||
follow_answers := await ChatAnswer.filter(
|
||||
group_id=group_id,
|
||||
context=follow_context,
|
||||
@ -432,13 +447,13 @@ class LearningChat:
|
||||
return None
|
||||
|
||||
async def _set_answer(self, message: ChatMessage):
|
||||
if context := await ChatContext.get_or_none(keywords=message.keywords):
|
||||
if context := await ChatContext.filter(keywords=message.keywords).first():
|
||||
if context.count < chat_config.learn_max_count:
|
||||
context.count += 1
|
||||
context.time = self.data.time
|
||||
if answer := await ChatAnswer.get_or_none(keywords=self.data.keywords,
|
||||
group_id=self.data.group_id,
|
||||
context=context):
|
||||
if answer := await ChatAnswer.filter(keywords=self.data.keywords,
|
||||
group_id=self.data.group_id,
|
||||
context=context).first():
|
||||
if answer.count < chat_config.learn_max_count:
|
||||
answer.count += 1
|
||||
answer.time = self.data.time
|
||||
@ -476,7 +491,7 @@ class LearningChat:
|
||||
if raw_message.startswith('[') and raw_message.endswith(']'):
|
||||
# logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>')
|
||||
return False
|
||||
if ban_word := await ChatBlackList.get_or_none(keywords=message.keywords):
|
||||
if ban_word := await ChatBlackList.filter(keywords=message.keywords).first():
|
||||
if ban_word.global_ban or message.group_id in ban_word.ban_group_id:
|
||||
# logger.debug('群聊学习', f'➤检验<m>{keywords}</m><r>不通过</r>')
|
||||
return False
|
||||
|
@ -1,8 +1,5 @@
|
||||
import asyncio
|
||||
from nonebot import on_regex, on_command, on_notice
|
||||
from nonebot import plugin as nb_plugin
|
||||
from nonebot import on_regex, on_command
|
||||
from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent, PrivateMessageEvent, MessageEvent
|
||||
from nonebot.adapters.onebot.v11 import NoticeEvent, FriendAddNoticeEvent, GroupIncreaseNoticeEvent
|
||||
from nonebot.params import RegexDict, CommandArg
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
@ -10,8 +7,8 @@ from nonebot.rule import Rule
|
||||
from nonebot.typing import T_State
|
||||
|
||||
from LittlePaimon import SUPERUSERS
|
||||
from LittlePaimon.config import ConfigManager, PluginManager, HIDDEN_PLUGINS
|
||||
from LittlePaimon.database import PluginPermission
|
||||
from LittlePaimon.config import ConfigManager, PluginManager
|
||||
from LittlePaimon.database import PluginDisable
|
||||
from LittlePaimon.utils import logger
|
||||
from LittlePaimon.utils.message import CommandObjectID
|
||||
from .draw_help import draw_help
|
||||
@ -27,19 +24,12 @@ __plugin_meta__ = PluginMetadata(
|
||||
)
|
||||
|
||||
|
||||
def notice_rule(event: NoticeEvent) -> bool:
|
||||
if isinstance(event, FriendAddNoticeEvent):
|
||||
return True
|
||||
elif isinstance(event, GroupIncreaseNoticeEvent):
|
||||
return event.user_id == event.self_id
|
||||
|
||||
|
||||
def fullmatch(msg: Message = CommandArg()) -> bool:
|
||||
return not bool(msg)
|
||||
|
||||
|
||||
manage_cmd = on_regex(
|
||||
r'^pm (?P<func>ban|unban) (?P<plugin>([\w ]*)|all|全部) ?(-g (?P<group>[\d ]*) ?)?(-u (?P<user>[\d ]*) ?)?(?P<reserve>-r)?',
|
||||
r'^pm (?P<func>ban|unban) (?P<plugin>([\w ]*)|all|全部) ?(-g (?P<group>[\d ]*) ?)?(-u (?P<user>[\d ]*) ?)?',
|
||||
priority=1, block=True, state={
|
||||
'pm_name': 'pm-ban|unban',
|
||||
'pm_description': '禁用|取消禁用插件的群|用户使用权限',
|
||||
@ -58,11 +48,6 @@ set_config_cmd = on_command('pm set', priority=1, permission=SUPERUSER, block=Tr
|
||||
'pm_usage': 'pm set<配置名> <值>',
|
||||
'pm_priority': 2
|
||||
})
|
||||
notices = on_notice(priority=1, rule=Rule(notice_rule), block=True, state={
|
||||
'pm_name': 'pm-new-group-user',
|
||||
'pm_description': '为新加入的群|用户添加插件使用权限',
|
||||
'pm_show': False
|
||||
})
|
||||
|
||||
cache_help = {}
|
||||
|
||||
@ -73,21 +58,26 @@ async def _(event: GroupMessageEvent, state: T_State, match: dict = RegexDict(),
|
||||
await manage_cmd.finish('你没有权限使用该命令', at_sender=True)
|
||||
state['session_id'] = session_id
|
||||
state['bool'] = match['func'] == 'unban'
|
||||
state['plugin'] = []
|
||||
state['plugin_no_exist'] = []
|
||||
for plugin in match['plugin'].strip().split(' '):
|
||||
if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']:
|
||||
state['plugin'].append(plugin)
|
||||
elif module_name := list(
|
||||
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
|
||||
state['plugin'].append(module_name[0])
|
||||
else:
|
||||
state['plugin_no_exist'].append(plugin)
|
||||
if any(w in match['plugin'] for w in {'all', '全部'}):
|
||||
state['is_all'] = True
|
||||
state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager']
|
||||
else:
|
||||
state['is_all'] = False
|
||||
state['plugin'] = []
|
||||
for plugin in match['plugin'].strip().split(' '):
|
||||
if plugin in PluginManager.plugins.keys():
|
||||
state['plugin'].append(plugin)
|
||||
elif module_name := list(
|
||||
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
|
||||
state['plugin'].append(module_name[0])
|
||||
else:
|
||||
state['plugin_no_exist'].append(plugin)
|
||||
if not match['group'] or event.user_id not in SUPERUSERS:
|
||||
state['group'] = [event.group_id]
|
||||
else:
|
||||
state['group'] = [int(group) for group in match['group'].strip().split(' ')]
|
||||
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else []
|
||||
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None
|
||||
|
||||
|
||||
@manage_cmd.handle()
|
||||
@ -96,18 +86,23 @@ async def _(event: PrivateMessageEvent, state: T_State, match: dict = RegexDict(
|
||||
await manage_cmd.finish('你没有权限使用该命令', at_sender=True)
|
||||
state['session_id'] = session_id
|
||||
state['bool'] = match['func'] == 'unban'
|
||||
state['plugin'] = []
|
||||
state['plugin_no_exist'] = []
|
||||
for plugin in match['plugin'].strip().split(' '):
|
||||
if plugin in PluginManager.plugins.keys() or plugin in ['all', '全部']:
|
||||
state['plugin'].append(plugin)
|
||||
elif module_name := list(
|
||||
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
|
||||
state['plugin'].append(module_name[0])
|
||||
else:
|
||||
state['plugin_no_exist'].append(plugin)
|
||||
state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else []
|
||||
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else []
|
||||
if any(w in match['plugin'] for w in {'all', '全部'}):
|
||||
state['is_all'] = True
|
||||
state['plugin'] = [p for p in PluginManager.plugins.keys() if p != 'plugin_manager']
|
||||
else:
|
||||
state['is_all'] = False
|
||||
state['plugin'] = []
|
||||
for plugin in match['plugin'].strip().split(' '):
|
||||
if plugin in PluginManager.plugins.keys():
|
||||
state['plugin'].append(plugin)
|
||||
elif module_name := list(
|
||||
filter(lambda x: PluginManager.plugins[x].name == plugin, PluginManager.plugins.keys())):
|
||||
state['plugin'].append(module_name[0])
|
||||
else:
|
||||
state['plugin_no_exist'].append(plugin)
|
||||
state['group'] = [int(group) for group in match['group'].strip().split(' ')] if match['group'] else None
|
||||
state['user'] = [int(user) for user in match['user'].strip().split(' ')] if match['user'] else None
|
||||
|
||||
|
||||
@manage_cmd.got('bool')
|
||||
@ -119,45 +114,40 @@ async def _(state: T_State):
|
||||
if not state['plugin'] and state['plugin_no_exist']:
|
||||
await manage_cmd.finish(f'没有叫{" ".join(state["plugin_no_exist"])}的插件')
|
||||
extra_msg = f',但没有叫{" ".join(state["plugin_no_exist"])}的插件。' if state['plugin_no_exist'] else '。'
|
||||
if state['group'] and not state['user']:
|
||||
for group_id in state['group']:
|
||||
if 'all' in state['plugin']:
|
||||
await PluginPermission.filter(session_id=group_id, session_type='group').update(status=state['bool'])
|
||||
else:
|
||||
await PluginPermission.filter(name__in=state['plugin'], session_id=group_id,
|
||||
session_type='group').update(
|
||||
status=state['bool'])
|
||||
logger.info('插件管理器',
|
||||
f'已{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限')
|
||||
await manage_cmd.finish(
|
||||
f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
|
||||
elif state['user'] and not state['group']:
|
||||
for user_id in state['user']:
|
||||
if 'all' in state['plugin']:
|
||||
await PluginPermission.filter(session_id=user_id, session_type='user').update(status=state['bool'])
|
||||
else:
|
||||
await PluginPermission.filter(name__in=state['plugin'], session_id=user_id, session_type='user').update(
|
||||
status=state['bool'])
|
||||
logger.info('插件管理器',
|
||||
f'已{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限')
|
||||
await manage_cmd.finish(
|
||||
f'已{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
|
||||
filter_arg = {}
|
||||
if state['group']:
|
||||
filter_arg['group_id__in'] = state['group']
|
||||
if state['user']:
|
||||
filter_arg['user_id__in'] = state['user']
|
||||
logger.info('插件管理器',
|
||||
f'已{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>中用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
|
||||
msg = f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
|
||||
else:
|
||||
filter_arg['user_id'] = None
|
||||
logger.info('插件管理器',
|
||||
f'已{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
|
||||
msg = f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
|
||||
else:
|
||||
for group_id in state['group']:
|
||||
if 'all' in state['plugin']:
|
||||
plugin_list = await PluginPermission.filter(session_id=group_id, session_type='group').all()
|
||||
else:
|
||||
plugin_list = await PluginPermission.filter(name__in=state['plugin'], session_id=group_id,
|
||||
session_type='group').all()
|
||||
if plugin_list:
|
||||
for plugin in plugin_list:
|
||||
plugin.ban = list(set(plugin.ban) - set(state['user'])) if state['bool'] else list(
|
||||
set(plugin.ban) | set(state['user']))
|
||||
await plugin.save()
|
||||
filter_arg['user_id__in'] = state['user']
|
||||
logger.info('插件管理器',
|
||||
f'已{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}群<m>{" ".join(map(str, state["group"]))}</m>中用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"])}</m>使用权限')
|
||||
await manage_cmd.finish(
|
||||
f'已{"启用" if state["bool"] else "禁用"}群{" ".join(map(str, state["group"]))}中用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"])}使用权限{extra_msg}')
|
||||
f'已{"<g>启用</g>" if state["bool"] else "<r>禁用</r>"}用户<m>{" ".join(map(str, state["user"]))}</m>的插件<m>{" ".join(state["plugin"]) if not state["is_all"] else "全部"}</m>使用权限')
|
||||
msg = f'已{"启用" if state["bool"] else "禁用"}用户{" ".join(map(str, state["user"]))}的插件{" ".join(state["plugin"]) if not state["is_all"] else "全部"}使用权限{extra_msg}'
|
||||
if state['bool']:
|
||||
await PluginDisable.filter(name__in=state['plugin'], **filter_arg).delete()
|
||||
else:
|
||||
for plugin in state['plugin']:
|
||||
if state['group']:
|
||||
for group in state['group']:
|
||||
if state['user']:
|
||||
for user in state['user']:
|
||||
await PluginDisable.update_or_create(name=plugin, group_id=group, user_id=user)
|
||||
else:
|
||||
await PluginDisable.update_or_create(name=plugin, group_id=group)
|
||||
else:
|
||||
for user in state['user']:
|
||||
await PluginDisable.update_or_create(name=plugin, user_id=user)
|
||||
|
||||
await manage_cmd.finish(msg)
|
||||
|
||||
|
||||
@help_cmd.handle()
|
||||
@ -181,15 +171,3 @@ async def _(event: MessageEvent, msg: Message = CommandArg()):
|
||||
else:
|
||||
result = ConfigManager.set_config(msg[0], msg[1])
|
||||
await set_config_cmd.finish(result)
|
||||
|
||||
|
||||
@notices.handle()
|
||||
async def _(event: NoticeEvent):
|
||||
plugin_list = nb_plugin.get_loaded_plugins()
|
||||
if isinstance(event, FriendAddNoticeEvent):
|
||||
await asyncio.gather(*[PluginPermission.update_or_create(name=plugin, session_id=event.user_id, session_type='user') for plugin
|
||||
in plugin_list if plugin not in HIDDEN_PLUGINS])
|
||||
elif isinstance(event, GroupIncreaseNoticeEvent):
|
||||
await asyncio.gather(
|
||||
*[PluginPermission.update_or_create(name=plugin, session_id=event.group_id, session_type='group') for plugin
|
||||
in plugin_list if plugin not in HIDDEN_PLUGINS])
|
||||
|
@ -4,7 +4,7 @@ from nonebot import get_driver
|
||||
from .logger import logger
|
||||
from .scheduler import scheduler
|
||||
|
||||
__version__ = '3.0.0rc3'
|
||||
__version__ = '3.0.0rc4'
|
||||
|
||||
DRIVER = get_driver()
|
||||
try:
|
||||
|
@ -5,7 +5,7 @@ from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from LittlePaimon.config import ConfigManager, PluginManager, PluginInfo
|
||||
from LittlePaimon.database import PluginPermission
|
||||
from LittlePaimon.database import PluginDisable
|
||||
|
||||
from .utils import authentication
|
||||
|
||||
@ -27,28 +27,31 @@ async def get_plugins():
|
||||
|
||||
@route.post('/set_plugin_status', response_class=JSONResponse, dependencies=[authentication()])
|
||||
async def set_plugin_status(data: dict):
|
||||
module_name = data.get('plugin')
|
||||
status = data.get('status')
|
||||
module_name: str = data.get('plugin')
|
||||
status: bool = data.get('status')
|
||||
try:
|
||||
from LittlePaimon.plugins.plugin_manager import cache_help
|
||||
cache_help.clear()
|
||||
except Exception:
|
||||
pass
|
||||
await PluginPermission.filter(name=module_name).update(status=status)
|
||||
if status:
|
||||
await PluginDisable.filter(name=module_name, global_disable=True).delete()
|
||||
else:
|
||||
await PluginDisable.create(name=module_name, global_disable=True)
|
||||
return {'status': 0, 'msg': f'成功设置{module_name}插件状态为{status}'}
|
||||
|
||||
|
||||
@route.get('/get_plugin_bans', response_class=JSONResponse, dependencies=[authentication()])
|
||||
async def get_plugin_status(module_name: str):
|
||||
result = []
|
||||
bans = await PluginPermission.filter(name=module_name).all()
|
||||
bans = await PluginDisable.filter(name=module_name).all()
|
||||
for ban in bans:
|
||||
if ban.session_type == 'group':
|
||||
result.extend(f'群{ban.session_id}.{b}' for b in ban.ban)
|
||||
if not ban.status:
|
||||
result.append(f'群{ban.session_id}')
|
||||
elif ban.session_type == 'user' and not ban.status:
|
||||
result.append(f'{ban.session_id}')
|
||||
if ban.user_id and ban.group_id:
|
||||
result.append(f'群{ban.group_id}.{ban.user_id}')
|
||||
elif ban.group_id and not ban.user_id:
|
||||
result.append(f'群{ban.group_id}')
|
||||
elif ban.user_id and not ban.group_id:
|
||||
result.append(f'{ban.user_id}')
|
||||
return {
|
||||
'status': 0,
|
||||
'msg': 'ok',
|
||||
@ -63,20 +66,17 @@ async def get_plugin_status(module_name: str):
|
||||
async def set_plugin_bans(data: dict):
|
||||
bans = data['bans']
|
||||
name = data['module_name']
|
||||
await PluginPermission.filter(name=name).update(status=True, ban=[])
|
||||
await PluginDisable.filter(name=name, global_disable=False).delete()
|
||||
for ban in bans:
|
||||
if ban.startswith('群'):
|
||||
if '.' in ban:
|
||||
group_id = int(ban.split('.')[0][1:])
|
||||
user_id = int(ban.split('.')[1])
|
||||
plugin = await PluginPermission.filter(name=name, session_type='group', session_id=group_id).first()
|
||||
plugin.ban.append(user_id)
|
||||
await plugin.save()
|
||||
await PluginDisable.create(name=name, group_id=group_id, user_id=user_id)
|
||||
else:
|
||||
await PluginPermission.filter(name=name, session_type='group', session_id=int(ban[1:])).update(
|
||||
status=False)
|
||||
await PluginDisable.create(name=name, group_id=int(ban[1:]))
|
||||
else:
|
||||
await PluginPermission.filter(name=name, session_type='user', session_id=int(ban)).update(status=False)
|
||||
await PluginDisable.create(name=name, user_id=int(ban))
|
||||
try:
|
||||
from LittlePaimon.plugins.plugin_manager import cache_help
|
||||
cache_help.clear()
|
||||
|
Loading…
x
Reference in New Issue
Block a user