133 lines
4.4 KiB
Python
Raw Normal View History

2022-03-13 21:25:42 +08:00
import asyncio
import hashlib
import json
import os
import random
import re
from os import path
from typing import List, Union
import aiohttp
import filetype
import nonebot
from aiocqhttp.event import Event
from aiocqhttp.exceptions import ActionFailed
from nonebot import scheduler
from hoshino.log import new_logger
from hoshino.service import Service
# Module-wide logger, created through hoshino's new_logger under the name
# 'shebot' with debug=False.
logger = new_logger('shebot', debug=False)
# Bot handle fetched once at import time; the helpers visible in this module
# call nonebot.get_bot() themselves rather than reading this global.
bot = nonebot.get_bot()
async def download_async(url: str, save_path: str, save_name: str, auto_extension=False) -> str:
    """Download ``url`` and write the response body to a file in ``save_path``.

    Args:
        url: Address fetched with an HTTP GET request (30 second total timeout).
        save_path: Directory the file is written into.
        save_name: File name; used as the stem when ``auto_extension`` is true.
        auto_extension: When true, the extension is taken from the subtype of
            the MIME type that ``filetype`` detects in the downloaded bytes.

    Returns:
        The path of the file that was written.

    Raises:
        ValueError: ``auto_extension`` is true and no file type could be
            detected from the content.
    """
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as resp:
            content = await resp.read()

    if auto_extension:
        # No extension was supplied by the caller, so derive one from the data.
        # guess_mime() yields None for unrecognised content; check for that
        # explicitly instead of hiding it behind a bare ``except``.
        mime = filetype.guess_mime(content)
        if not mime or '/' not in mime:
            raise ValueError('不是有效文件类型')
        extension = mime.split('/')[1]
        file_name = f'{save_name}.{extension}'
    else:
        file_name = save_name

    abs_path = path.join(save_path, file_name)
    with open(abs_path, 'wb') as f:
        f.write(content)
    return abs_path
def get_random_file(path) -> str:
    """Return the name of one randomly chosen entry of directory ``path``.

    The value is an entry name as produced by ``os.listdir`` (not joined with
    ``path``). An empty directory makes ``random.choice`` raise ``IndexError``.
    """
    return random.choice(os.listdir(path))
def get_md5(val: Union[bytes, str]) -> str:
    """Return the hexadecimal MD5 digest of ``val``.

    A ``str`` argument is encoded as UTF-8 before hashing; any other value is
    passed to the hash unchanged.
    """
    data = val.encode('utf-8') if isinstance(val, str) else val
    return hashlib.md5(data).hexdigest()
async def broadcast(msg, groups=None, sv_name=None):
    """Send ``msg`` to a set of groups, one group at a time.

    Target selection:
        * ``groups`` given (non-empty): the message goes to exactly those
          groups and ``sv_name`` is not consulted.
        * ``groups`` not given: the message goes to the groups reported by
          ``get_enable_groups()`` of the loaded service named ``sv_name``.

    Args:
        msg: Message passed to ``bot.send_group_msg``.
        groups: Iterable of group ids, or ``None``.
        sv_name: Name of a loaded service, used only when ``groups`` is empty.

    Raises:
        ValueError: ``groups`` is empty and ``sv_name`` is not a loaded service.

    Delivery failures are logged per group and do not stop the broadcast.
    """
    bot = nonebot.get_bot()
    if groups:
        # Explicit targets take precedence. The service is not looked up, so
        # an unknown sv_name can no longer raise KeyError on this path.
        send_groups = groups
    else:
        svs = Service.get_loaded_services()
        if sv_name not in svs:
            raise ValueError(f'不存在服务 {sv_name}')
        enable_groups = await svs[sv_name].get_enable_groups()
        # Snapshot the keys: the loop below awaits between sends, so iterate
        # over a copy rather than a live view of the mapping.
        send_groups = list(enable_groups.keys())

    for gid in send_groups:
        try:
            await bot.send_group_msg(group_id=gid, message=msg)
            logger.info(f'{gid}投递消息成功')
            # Short pause after each successful send before the next group.
            await asyncio.sleep(0.5)
        except ActionFailed as e:
            logger.error(f'在群{gid}投递消息失败, retcode={e.retcode}')
        except Exception as e:
            logger.exception(e)
def extract_url_from_event(event: Event) -> List[str]:
    """Return all URL-like substrings found in the event's message.

    The message is converted with ``str()``. Each match starts at ``http`` and
    runs, non-greedily, up to and including the first ``term=`` followed by
    one digit.
    """
    message_text = str(event.message)
    url_pattern = re.compile(r'http.*?term=\d')
    return url_pattern.findall(message_text)
def save_config(config: dict, path: str):
    """Write ``config`` to ``path`` as JSON.

    The file is opened in UTF-8 and the JSON is written with a two-space
    indent and non-ASCII characters left unescaped.

    Returns:
        ``True`` when the file was written, ``False`` when any exception
        occurred (the exception is logged, not raised).
    """
    try:
        with open(path, 'w', encoding='utf8') as fp:
            json.dump(config, fp, ensure_ascii=False, indent=2)
    except Exception as err:
        logger.error(err)
        return False
    else:
        return True
def load_config(path):
    """Read the JSON document stored at ``path`` (UTF-8) and return it.

    Returns:
        The decoded JSON value, or an empty dict when opening or decoding the
        file raises; in that case the failure is logged instead of raised.
    """
    try:
        with open(path, mode='r', encoding='utf-8') as fp:
            loaded = json.load(fp)
    except Exception as err:
        logger.error(f'exception occured when loading config in {path} {err}')
        logger.exception(err)
        return {}
    return loaded
class RSS:
    """Base class that downloads a feed from ``base_url + route``.

    ``get()`` stores the raw response body in ``xml``; turning it into
    something useful is left to ``parse_xml``, which implementing classes
    are expected to provide.
    """

    def __init__(self):
        self.base_url = 'http://101.32.36.8:1200'
        # Path appended to base_url by get(); must be assigned before calling it.
        self.route: str = None
        # Raw response body, filled in by get().
        self.xml: bytes = None
        # ``filter`` selects the wanted content, ``filterout`` removes content.
        # Keys noted by the original author for ``filter``:
        #   filter:             match against title and description
        #   filter_title:       match against title
        #   filter_description: match against description
        #   filter_author:      match against author
        #   filter_time:        digits only, in seconds; returns content inside
        #                       that time range. Items without pubDate, or with
        #                       a malformed one, are not filtered.
        self.filter: dict = {}
        self.filterout: dict = {}
        # Whether filtering is case sensitive (default: yes).
        # NOTE(review): get() in this class never reads this flag - confirm
        # whether it should be sent as a query parameter.
        self.filter_case_sensitive = True
        # Maximum number of items, mainly for ranking-style feeds.
        self.limit = 10

    async def get(self):
        """Request the feed and keep the response bytes in ``self.xml``.

        Query parameters are the truthy entries of ``filter``, then the truthy
        entries of ``filterout``, then ``limit``.
        """
        url = self.base_url + self.route
        params = {name: value for name, value in self.filter.items() if value}
        params.update(
            {name: value for name, value in self.filterout.items() if value}
        )
        params['limit'] = self.limit
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as resp:
                self.xml = await resp.read()

    def parse_xml(self):
        """Parse ``self.xml``; to be written in the implementing class."""
        raise NotImplementedError