import atexit
import random
import re
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property, cmp_to_key
from typing import Generator, List, Optional, Union, Tuple, Dict, Any

import jieba_fast.analyse
import pymongo
import pypinyin
from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent
from nonebot.adapters.onebot.v11 import Message

from ...utils.config import config

# MongoDB connection shared by the whole module; all data lives in the
# 'PaimonChat' database.
mongo_client = pymongo.MongoClient(config.paimon_mongodb_url)
mongo_db = mongo_client['PaimonChat']

# Cached group messages flushed by Chat._sync, indexed on time (descending).
message_mongo = mongo_db['message']
message_mongo.create_index([('time', pymongo.DESCENDING)], name='time_index')

# Learned contexts: one document per previous-message keywords, holding the
# answers seen after it in each group.
context_mongo = mongo_db['context']
context_mongo.create_index([('keywords', pymongo.HASHED)], name='keywords_index')
context_mongo.create_index([('count', pymongo.DESCENDING)], name='count_index')
context_mongo.create_index([('time', pymongo.DESCENDING)], name='time_index')
# Text index over the answers; default_language 'none' disables
# language-specific stemming and stop words.
context_mongo.create_index([('answers.group_id', pymongo.TEXT),
                            ('answers.keywords', pymongo.TEXT)],
                           name='answers_index',
                           default_language='none')

# Per-group answer blacklists written by Chat._sync_blacklist.
blacklist_mongo = mongo_db['blacklist']
blacklist_mongo.create_index([('group_id', pymongo.HASHED)], name='group_index')


@dataclass
class ChatData:
    """One chat message, reduced to the fields the learning engine uses.

    Derived properties are computed lazily and cached per instance.
    """

    group_id: int
    user_id: int
    raw_message: str   # message text including CQ codes
    plain_text: str    # text-only part of the message
    time: int          # message timestamp
    bot_id: int

    # Number of keywords jieba extracts for ``keywords``. This is a dataclass
    # field (it appears in __init__), so a value passed per instance is honored.
    _keywords_size: int = 3

    @cached_property
    def is_plain_text(self) -> bool:
        """True when the message has text and contains no CQ code."""
        return '[CQ:' not in self.raw_message and len(self.plain_text) != 0

    @cached_property
    def is_image(self) -> bool:
        """True when the message contains an image or a face CQ code."""
        return '[CQ:image,' in self.raw_message or '[CQ:face,' in self.raw_message

    @cached_property
    def keywords(self) -> str:
        """Key used to look up and store contexts.

        Returns the raw message when there is no plain text. Returns the whole
        plain text when jieba extracts fewer than two keywords. Otherwise returns
        the extracted keywords joined by spaces.
        """
        # Empty text already implies "not plain text", so this is the only
        # condition that matters for the raw-message fallback.
        if len(self.plain_text) == 0:
            return self.raw_message

        keywords_list = jieba_fast.analyse.extract_tags(
            self.plain_text, topK=self._keywords_size)
        if len(keywords_list) < 2:
            return self.plain_text
        return ' '.join(keywords_list)

    @cached_property
    def keywords_pinyin(self) -> str:
        """Lower-cased, tone-less (NORMAL style) pinyin of ``keywords``."""
        return ''.join([item[0] for item in pypinyin.pinyin(
            self.keywords, style=pypinyin.NORMAL, errors='default')]).lower()

    @cached_property
    def to_me(self) -> bool:
        """True when the plain text starts with the bot's name."""
        return self.plain_text.startswith('派蒙')


class Chat:
    """Group-chat learning engine.

    Learns "previous message -> answer" contexts from the messages it sees and
    answers from what it has learned. It speaks up unprompted in quiet groups
    and maintains per-group and global answer blacklists. All caches are
    class-level, so every instance shares them. Learned data is stored in the
    module-level MongoDB collections.
    """

    answer_threshold = config.paimon_answer_threshold  # Answer threshold: lower means chattier, higher means quieter
    answer_limit_threshold = config.paimon_answer_limit_threshold  # Upper limit: a normal context is hardly ever sent ~50 times; such answers usually come from other bots (not referenced in this module)
    cross_group_threshold = config.paimon_cross_group_threshold  # When N groups share the same answer, it is used across groups as a global answer
    repeat_threshold = config.paimon_repeat_threshold  # How many consecutive identical messages in a group make the bot repeat them
    speak_threshold = config.paimon_speak_threshold  # Threshold for speaking up unprompted: lower means chattier

    drunk_probability = config.paimon_drunk_probability  # Probability of being "drunk" when an answer misses the threshold (not referenced in this module)
    split_probability = 0.5  # Probability of splitting an answer on commas
    voice_probability = config.paimon_voice_probability  # Probability of answering with voice (plain-text answers only)
    speak_continuously_probability = config.paimon_speak_continuously_probability  # Probability of continuing to talk after speaking up
    speak_poke_probability = config.paimon_speak_poke_probability  # Probability of poking a random group member when speaking up
    speak_continuously_max_len = config.paimon_speak_continuously_max_len  # Maximum number of messages when talking continuously

    save_time_threshold = 3600  # Persist cached messages at least this often (seconds)
    save_count_threshold = 1000  # Also persist when one group caches more messages than this (OR-ed with the time threshold)

    # group_id -> answer keywords that are blacklisted / on a first strike.
    # The pseudo group _blacklist_flag holds the global blacklist.
    blacklist_answer = defaultdict(set)
    blacklist_answer_reserve = defaultdict(set)

    learningGroup = config.paimon_chat_group  # Groups used for learning (not referenced in this module)

    def __init__(self, data: Union[ChatData, GroupMessageEvent, PrivateMessageEvent]):
        """Wrap a message for learning or answering.

        Args:
            data: a ready-made ChatData, or a OneBot v11 group/private message
                event that is converted into one.

        Raises:
            TypeError: if ``data`` is of any other type.
        """
        if isinstance(data, ChatData):
            self.chat_data = data
            return

        if isinstance(data, GroupMessageEvent):
            group_id = data.group_id
        elif isinstance(data, PrivateMessageEvent):
            # NOTE(review): an older comment claimed a sign is deliberately
            # added here to tell private chats apart from group ids, but the
            # raw user_id is used. Persisted contexts are keyed by this value,
            # so confirm the intent before changing it.
            group_id = data.user_id
        else:
            raise TypeError(f'unsupported message type: {type(data).__name__}')

        self.chat_data = ChatData(
            group_id=group_id,
            user_id=data.user_id,
            # Drop the image subType field: the same image often arrives with
            # different subtypes, which would break message comparison.
            raw_message=re.sub(
                r',subType=\d+\]',
                r']',
                data.raw_message),
            plain_text=data.get_plaintext(),
            time=data.time,
            bot_id=data.self_id,
        )

    def learn(self) -> bool:
        """Learn this message.

        Learns two contexts when applicable:
        - the previous group message -> this message;
        - the sender's own message among the three most recent group
          messages -> this message.
        It then caches the message.

        Returns:
            False for a blank message, True otherwise.
        """
        if len(self.chat_data.raw_message.strip()) == 0:
            return False

        group_id = self.chat_data.group_id
        if group_id in Chat._message_dict:
            group_msgs = Chat._message_dict[group_id]
            group_pre_msg = group_msgs[-1] if group_msgs else None

            # The previous message in the group
            self._context_insert(group_pre_msg)

            user_id = self.chat_data.user_id
            if group_pre_msg and group_pre_msg['user_id'] != user_id:
                # This user's own previous message, searched among the three
                # most recent group messages (newest first)
                for msg in group_msgs[:-4:-1]:
                    if msg['user_id'] == user_id:
                        self._context_insert(msg)
                        break

        self._message_insert()
        return True

    def answer(self, with_limit: bool = True) -> Optional[Generator[Message, None, None]]:
        """Answer this message, possibly in several parts, or not at all.

        Args:
            with_limit: when True, stay silent if this bot replied in this
                group less than 6 seconds ago.

        Returns:
            A generator yielding the reply messages, or None when there is
            nothing to say. Every yielded item is recorded in the reply cache
            as it is produced.
        """
        group_id = self.chat_data.group_id
        bot_id = self.chat_data.bot_id
        group_bot_replies = Chat._reply_dict[group_id][bot_id]

        # Rate limit: at most one reply every 6 seconds per bot and group
        if with_limit and group_bot_replies \
                and int(time.time()) - group_bot_replies[-1]['time'] < 6:
            return None

        results = self._context_find()
        if not results:
            return None

        raw_message = self.chat_data.raw_message
        keywords = self.chat_data.keywords
        Chat._append_reply(group_bot_replies, {
            'time': int(time.time()),
            'pre_raw_message': raw_message,
            'pre_keywords': keywords,
            'reply': '[PaimonChat: Reply]',  # flag
            'reply_keywords': '[PaimonChat: Reply]',  # flag
        })

        def yield_results(results: Tuple[List[str], str]) -> Generator[Message, None, None]:
            answer_list, answer_keywords = results
            replies = Chat._reply_dict[group_id][bot_id]
            for item in answer_list:
                Chat._append_reply(replies, {
                    'time': int(time.time()),
                    'pre_raw_message': raw_message,
                    'pre_keywords': keywords,
                    'reply': item,
                    'reply_keywords': answer_keywords,
                })
                # Plain-text answers longer than one character are sometimes
                # sent as voice
                if '[CQ:' not in item and len(item) > 1 \
                        and random.random() < Chat.voice_probability:
                    yield Chat._text_to_speech(item)
                else:
                    yield Message(item)

        return yield_results(results)

    @staticmethod
    def speak() -> Optional[Tuple[int, int, List[Message]]]:
        """Speak up unprompted in a group that has gone quiet.

        Returns:
            ``(bot_id, group_id, messages)`` for the first suitable group, or
            None when no group qualifies.
        """
        basic_msgs_len = 10
        basic_delay = 600

        def group_popularity_cmp(lhs: Tuple[int, List[Dict[str, Any]]],
                                 rhs: Tuple[int, List[Dict[str, Any]]]) -> int:

            def cmp(a: Any, b: Any):
                return (a > b) - (a < b)

            lhs_group_id, lhs_msgs = lhs
            rhs_group_id, rhs_msgs = rhs

            lhs_len = len(lhs_msgs)
            rhs_len = len(rhs_msgs)

            # Drunkenness defaults to 0; add 1 so the product does not vanish
            lhs_drunkenness = Chat._drunkenness_dict[lhs_group_id] + 1
            rhs_drunkenness = Chat._drunkenness_dict[rhs_group_id] + 1

            if lhs_len < basic_msgs_len or rhs_len < basic_msgs_len:
                return cmp(lhs_len * lhs_drunkenness,
                           rhs_len * rhs_drunkenness)

            lhs_duration = lhs_msgs[-1]['time'] - lhs_msgs[0]['time']
            rhs_duration = rhs_msgs[-1]['time'] - rhs_msgs[0]['time']

            if not lhs_duration or not rhs_duration:
                return cmp(lhs_len, rhs_len)

            return cmp(lhs_len * lhs_drunkenness / lhs_duration,
                       rhs_len * rhs_drunkenness / rhs_duration)

        # Order groups by chat popularity.
        # NOTE(review): sorted() is ascending, so the least "popular" group is
        # tried first — confirm this is intended.
        popularity = sorted(Chat._message_dict.items(),
                            key=cmp_to_key(group_popularity_cmp))

        cur_time = time.time()
        for group_id, group_msgs in popularity:
            group_replies = Chat._reply_dict[group_id]
            if not group_replies or len(group_msgs) < basic_msgs_len:
                continue

            # All bots normally reply together, so their latest reply times
            # should match; any one of them ([0]) will do
            group_replies_front = list(group_replies.values())[0]
            if not group_replies_front or \
                    group_replies_front[-1]['time'] > group_msgs[-1]['time']:
                continue

            msgs_len = len(group_msgs)
            latest_time = group_msgs[-1]['time']
            duration = latest_time - group_msgs[0]['time']
            avg_interval = duration / msgs_len

            # Only speak up once nobody has talked for speak_threshold times
            # the average message interval (plus a base delay)
            if cur_time - latest_time < avg_interval * Chat.speak_threshold + basic_delay:
                continue

            # Append a flag so that a very active group without any usable
            # context is not queried again on every speak() call
            Chat._append_reply(group_replies_front, {
                'time': int(cur_time),
                'pre_raw_message': '[PaimonChat: Speak]',
                'pre_keywords': '[PaimonChat: Speak]',
                'reply': '[PaimonChat: Speak]',
                'reply_keywords': '[PaimonChat: Speak]',
            })

            available_time = cur_time - 24 * 3600
            speak_context = context_mongo.aggregate([
                {
                    '$match': {
                        'count': {
                            '$gt': Chat.answer_threshold
                        },
                        'time': {
                            '$gt': available_time
                        },
                        # The two conditions above only speed up the query;
                        # they do not change which documents match
                        'answers.group_id': group_id,
                        'answers.time': {
                            '$gt': available_time
                        },
                        'answers.count': {
                            '$gt': Chat.answer_threshold
                        }
                    }
                }, {
                    '$sample': {'size': 1}  # pick one at random
                }
            ])

            speak_context = list(speak_context)
            if not speak_context:
                continue

            ban_keywords = Chat._find_ban_keywords(
                context=speak_context[0], group_id=group_id)
            messages = [answer['messages']
                        for answer in speak_context[0]['answers']
                        if answer['count'] >= Chat.answer_threshold
                        and answer['keywords'] not in ban_keywords
                        and answer['group_id'] == group_id]

            if not messages:
                continue

            speak = random.choice(random.choice(messages))

            bot_ids = [bid for bid in group_replies.keys() if bid]
            if not bot_ids:
                # No real (non-zero) bot id has replied in this group
                continue
            bot_id = random.choice(bot_ids)
            Chat._append_reply(group_replies[bot_id], {
                'time': int(cur_time),
                'pre_raw_message': '[PaimonChat: Speak]',
                'pre_keywords': '[PaimonChat: Speak]',
                'reply': speak,
                'reply_keywords': '[PaimonChat: Speak]',
            })

            speak_list = [Message(speak), ]
            while random.random() < Chat.speak_continuously_probability \
                    and len(speak_list) < Chat.speak_continuously_max_len:
                pre_msg = str(speak_list[-1])
                answer = Chat(ChatData(group_id, 0, pre_msg,
                                       pre_msg, cur_time, 0)).answer(False)
                if not answer:
                    break
                speak_list.extend(answer)

            if random.random() < Chat.speak_poke_probability:
                target_id = random.choice(
                    Chat._message_dict[group_id])['user_id']
                speak_list.append(Message('[CQ:poke,qq={}]'.format(target_id)))

            return bot_id, group_id, speak_list

        return None

    @staticmethod
    def ban(group_id: int, bot_id: int, ban_raw_message: str, reason: str) -> bool:
        """Stop giving one of this bot's recent answers in this group.

        Args:
            group_id: group in which the answer is banned.
            bot_id: bot whose reply cache is searched.
            ban_raw_message: (part of) the reply to ban; empty bans the most
                recent reply.
            reason: free text stored with the ban record.

        Returns:
            True if a matching reply was found and banned, False otherwise.
        """
        if group_id not in Chat._reply_dict:
            return False

        reply_data = Chat._reply_dict[group_id][bot_id][::-1]  # newest first

        ban_reply = next((reply for reply in reply_data
                          if not ban_raw_message or ban_raw_message in reply['reply']),
                         None)

        # Some CQ codes differ between when the bot sent them and when they are
        # quoted back, so fall back to matching only the CQ type prefix
        if not ban_reply:
            search = re.search(r'(\[CQ:[a-zA-Z0-9_.-]+)', ban_raw_message)
            if search:
                type_keyword = search.group(1)
                ban_reply = next((reply for reply in reply_data
                                  if type_keyword in reply['reply']), None)

        if not ban_reply:
            return False

        pre_keywords = ban_reply['pre_keywords']
        keywords = ban_reply['reply_keywords']

        # Record the ban on the context document; _find_ban_keywords applies it
        # to this group, and globally once enough groups have banned it
        context_mongo.update_one({
            'keywords': pre_keywords
        }, {
            '$push': {
                'ban': {
                    'keywords': keywords,
                    'group_id': group_id,
                    'reason': reason,
                    'time': int(time.time())
                }
            }
        })
        # A first ban only puts the answer on the group's reserve list; a
        # second one blacklists it for the group (and globally when the
        # global reserve also holds it)
        if keywords in Chat.blacklist_answer_reserve[group_id]:
            Chat.blacklist_answer[group_id].add(keywords)
            if keywords in Chat.blacklist_answer_reserve[Chat._blacklist_flag]:
                Chat.blacklist_answer[Chat._blacklist_flag].add(
                    keywords)
        else:
            Chat.blacklist_answer_reserve[group_id].add(keywords)

        return True

    @staticmethod
    def drink(group_id: int) -> None:
        """Raise this group's drunkenness by one.

        While drunk, answers only need a count of 1 instead of answer_threshold.
        """
        Chat._drunkenness_dict[group_id] += 1

    @staticmethod
    def sober_up(group_id: int) -> bool:
        """Lower this group's drunkenness by one, never below zero.

        Returns:
            True when the group is now fully sober.
        """
        # Clamp at zero: a negative value would zero or negate the group's
        # weight in speak()'s popularity ranking
        Chat._drunkenness_dict[group_id] = max(0, Chat._drunkenness_dict[group_id] - 1)
        return Chat._drunkenness_dict[group_id] <= 0

    # private:
    _reply_dict = defaultdict(lambda: defaultdict(list))  # group_id -> bot_id -> recent replies (in memory only, not persisted)
    _message_dict = {}  # group_id -> cached group messages
    _drunkenness_dict = defaultdict(int)  # group_id -> drunkenness level

    _save_reserve_size = 100  # Messages/replies kept in memory per group after trimming
    _late_save_time = 0  # Time of the last message persistence (seconds)

    _reply_lock = threading.Lock()
    _message_lock = threading.Lock()
    _blacklist_flag = 114514  # Pseudo group id under which global blacklist entries are stored

    @staticmethod
    def _append_reply(replies: List[Dict[str, Any]], record: Dict[str, Any]) -> None:
        """Append a reply record and trim ``replies`` in place to the most recent _save_reserve_size entries."""
        with Chat._reply_lock:
            replies.append(record)
            del replies[:-Chat._save_reserve_size]

    def _message_insert(self):
        """Cache this message and persist the caches when the count or time threshold is exceeded."""
        group_id = self.chat_data.group_id

        with Chat._message_lock:
            if group_id not in Chat._message_dict:
                Chat._message_dict[group_id] = []

            Chat._message_dict[group_id].append({
                'group_id': group_id,
                'user_id': self.chat_data.user_id,
                'raw_message': self.chat_data.raw_message,
                'is_plain_text': self.chat_data.is_plain_text,
                'plain_text': self.chat_data.plain_text,
                'keywords': self.chat_data.keywords,
                'time': self.chat_data.time,
            })

        cur_time = self.chat_data.time
        if Chat._late_save_time == 0:
            # First message since start-up: only start the save clock, one
            # second back so this message is included in the next save
            Chat._late_save_time = cur_time - 1
            return

        if len(Chat._message_dict[group_id]) > Chat.save_count_threshold:
            Chat._sync(cur_time)

        elif cur_time - Chat._late_save_time > Chat.save_time_threshold:
            Chat._sync(cur_time)

    @staticmethod
    def _sync(cur_time: Optional[int] = None):
        """Persist cached messages newer than the last save, then trim each group's cache to _save_reserve_size messages.

        Args:
            cur_time: value recorded as the new last-save time. Defaults to the
                current time, evaluated on each call.
        """
        if cur_time is None:
            cur_time = int(time.time())

        with Chat._message_lock:
            # NOTE(review): messages whose time equals cur_time but arrive after
            # this sync are never saved (strict '>'), confirm acceptable
            save_list = [msg
                         for group_msgs in Chat._message_dict.values()
                         for msg in group_msgs
                         if msg['time'] > Chat._late_save_time]
            if not save_list:
                return

            Chat._message_dict = {group_id: group_msgs[-Chat._save_reserve_size:]
                                  for group_id, group_msgs in Chat._message_dict.items()}

            Chat._late_save_time = cur_time

        message_mongo.insert_many(save_list)

    def _context_insert(self, pre_msg):
        """Learn "pre_msg -> this message" as a context/answer pair.

        Nothing is learned in three cases:
        - there is no previous message;
        - this message repeats it;
        - this message is a quoted reply.
        Otherwise the context document keyed by pre_msg's keywords is updated,
        or created when missing.
        """
        if not pre_msg:
            return

        raw_message = self.chat_data.raw_message

        # Someone is repeating the message; don't learn it
        if pre_msg['raw_message'] == raw_message:
            return

        # A reply quoting someone else; don't learn it
        if '[CQ:reply,' in raw_message:
            return

        keywords = self.chat_data.keywords
        group_id = self.chat_data.group_id
        pre_keywords = pre_msg['keywords']
        cur_time = self.chat_data.time

        # A single upsert covering both "new answer" and "existing answer" is
        # hard to express, so use find + update-or-insert instead
        find_key = {'keywords': pre_keywords}
        context = context_mongo.find_one(find_key)
        if context:
            update_value = {
                '$set': {
                    'time': cur_time
                },
                '$inc': {'count': 1}
            }
            answer_index = next((idx for idx, answer in enumerate(context['answers'])
                                 if answer['group_id'] == group_id
                                 and answer['keywords'] == keywords), -1)
            if answer_index != -1:
                update_value['$inc'].update({
                    f'answers.{answer_index}.count': 1
                })
                update_value['$set'].update({
                    f'answers.{answer_index}.time': cur_time
                })
                # Non-plain-text messages are identical raw messages; no need
                # to store them again
                if self.chat_data.is_plain_text:
                    update_value['$push'] = {
                        f'answers.{answer_index}.messages': raw_message
                    }
            else:
                update_value['$push'] = {
                    'answers': {
                        'keywords': keywords,
                        'group_id': group_id,
                        'count': 1,
                        'time': cur_time,
                        'messages': [
                            raw_message
                        ]
                    }
                }

            context_mongo.update_one(find_key, update_value)
        else:
            context = {
                'keywords': pre_keywords,
                'time': cur_time,
                'count': 1,
                'answers': [
                    {
                        'keywords': keywords,
                        'group_id': group_id,
                        'count': 1,
                        'time': cur_time,
                        'messages': [
                            raw_message
                        ]
                    }
                ]
            }
            context_mongo.insert_one(context)

    def _context_find(self) -> Optional[Tuple[List[str], str]]:
        """Pick an answer for this message.

        Returns:
            ``(messages, answer_keywords)``, or None when nothing suitable is
            found. Repeat mode returns the message itself.
        """
        group_id = self.chat_data.group_id
        raw_message = self.chat_data.raw_message
        keywords = self.chat_data.keywords
        bot_id = self.chat_data.bot_id

        # Repeat mode
        if group_id in Chat._message_dict:
            group_msgs = Chat._message_dict[group_id]
            if len(group_msgs) >= Chat.repeat_threshold and \
                all(item['raw_message'] == raw_message
                    for item in group_msgs[:-Chat.repeat_threshold:-1]):
                # The group is currently repeating this message
                group_bot_replies = Chat._reply_dict[group_id][bot_id]
                if not group_bot_replies or group_bot_replies[-1]['reply'] != raw_message:
                    return [raw_message, ], keywords
                # Already repeated once; don't repeat the same message again
                return None

        context = context_mongo.find_one({'keywords': keywords})

        if not context:
            return None

        if Chat._drunkenness_dict[group_id] > 0:
            answer_count_threshold = 1
        else:
            answer_count_threshold = Chat.answer_threshold

        if self.chat_data.to_me:
            cross_group_threshold = 1
        else:
            cross_group_threshold = Chat.cross_group_threshold

        ban_keywords = Chat._find_ban_keywords(
            context=context, group_id=group_id)

        candidate_answers = {}
        other_group_cache = {}
        answers_count = defaultdict(int)

        def candidate_append(dst, answer):
            # Merge answers sharing the same keywords (counts and messages)
            answer_key = answer['keywords']
            if answer_key not in dst:
                dst[answer_key] = answer
            else:
                pre_answer = dst[answer_key]
                pre_answer['count'] += answer['count']
                pre_answer['messages'] += answer['messages']

        for answer in context['answers']:
            answer_key = answer['keywords']
            if answer_key in ban_keywords or answer['count'] < answer_count_threshold:
                continue

            sample_msg = answer['messages'][0]
            if self.chat_data.is_image and '[CQ:' not in sample_msg:
                # Don't answer images with plain text: images are mostly
                # stickers and the text that follows them is too random
                continue

            if answer['group_id'] == group_id:
                candidate_append(candidate_answers, answer)
            # Ignore @mentions learned in other groups
            elif '[CQ:at,qq=' in sample_msg:
                continue
            else:   # When N groups share the same answer, use it globally
                answers_count[answer_key] += 1
                cur_count = answers_count[answer_key]
                if cur_count < cross_group_threshold:      # below the threshold: cache it
                    candidate_append(other_group_cache, answer)
                elif cur_count == cross_group_threshold:   # threshold just reached: add the cache too
                    if cur_count > 1:
                        candidate_append(candidate_answers,
                                         other_group_cache[answer_key])
                    candidate_append(candidate_answers, answer)
                else:                                      # above the threshold: add directly
                    candidate_append(candidate_answers, answer)

        if not candidate_answers:
            return None

        final_answer = random.choices(list(candidate_answers.values()), weights=[
            # Cap the weight so one very frequent answer cannot starve the rest
            min(answer['count'], 10) for answer in candidate_answers.values()])[0]
        answer_str = random.choice(final_answer['messages'])
        answer_keywords = final_answer['keywords']

        # Sometimes split a short answer on commas into several messages.
        # Never split CQ codes (their parameters are comma-separated) and
        # never emit empty parts.
        if '[CQ:' not in answer_str and 0 < answer_str.count(',') <= 3 \
                and random.random() < Chat.split_probability:
            parts = [part for part in answer_str.split(',') if part.strip()]
            if parts:
                return parts, answer_keywords
        return [answer_str, ], answer_keywords

    @staticmethod
    def _text_to_speech(text: str) -> Optional[Message]:
        """Wrap ``text`` in a OneBot ``tts`` CQ code.

        The text is CQ-escaped so that '&', '[', ']' and ',' cannot break the
        code or its parameter list.
        """
        escaped = (text.replace('&', '&amp;')
                   .replace('[', '&#91;')
                   .replace(']', '&#93;')
                   .replace(',', '&#44;'))
        return Message(f'[CQ:tts,text={escaped}]')

    @staticmethod
    def update_global_blacklist() -> None:
        """Reload blacklists from MongoDB and promote answers blacklisted in cross_group_threshold groups to the global blacklist."""
        Chat._select_blacklist()

        keywords_dict = defaultdict(int)
        global_blacklist = set()
        for _, keywords_list in Chat.blacklist_answer.items():
            for keywords in keywords_list:
                keywords_dict[keywords] += 1
                if keywords_dict[keywords] == Chat.cross_group_threshold:
                    global_blacklist.add(keywords)

        Chat.blacklist_answer[Chat._blacklist_flag] |= global_blacklist

    @staticmethod
    def _select_blacklist() -> None:
        """Merge the persisted blacklists and reserve lists into memory."""
        all_blacklist = blacklist_mongo.find()

        for item in all_blacklist:
            group_id = item['group_id']
            if 'answers' in item:
                Chat.blacklist_answer[group_id] |= set(item['answers'])
            if 'answers_reserve' in item:
                Chat.blacklist_answer_reserve[group_id] |= set(
                    item['answers_reserve'])

    @staticmethod
    def _sync_blacklist() -> None:
        """Merge persisted blacklists into memory, then write each group's lists back.

        Reserve entries that are already blacklisted are not stored as reserve.
        """
        Chat._select_blacklist()

        for group_id, answers in Chat.blacklist_answer.items():
            if not answers:
                continue
            blacklist_mongo.update_one(
                {'group_id': group_id},
                {'$set': {'answers': list(answers)}},
                upsert=True)

        for group_id, answers in Chat.blacklist_answer_reserve.items():
            if not answers:
                continue
            if group_id in Chat.blacklist_answer:
                answers = answers - Chat.blacklist_answer[group_id]

            blacklist_mongo.update_one(
                {'group_id': group_id},
                {'$set': {'answers_reserve': list(answers)}},
                upsert=True)

    @staticmethod
    def clearup_context() -> None:
        """Clean up stale learned data.

        Deletes contexts last used more than 30 days ago whose count is below
        answer_threshold. For contexts with count > 100 not pruned in the last
        30 days, keeps only answers seen more than once or used within that
        window.
        """
        cur_time = int(time.time())
        expiration = cur_time - 30 * 24 * 3600  # thirty days ago

        context_mongo.delete_many({
            'time': {'$lt': expiration},
            'count': {'$lt': Chat.answer_threshold}    # $lt is strictly less-than
        })

        all_context = context_mongo.find({
            'count': {'$gt': 100},
            '$or': [
                # Legacy data: old documents have no clear_time field
                {"clear_time": {"$exists": False}},
                {"clear_time": {"$lt": expiration}}
            ]
        })
        for context in all_context:
            answers = [ans
                       for ans in context['answers']
                       # Legacy data: old answers have no time field
                       if ans['count'] > 1 or ('time' in ans and ans['time'] > expiration)]
            context_mongo.update_one({
                'keywords': context['keywords']
            }, {
                '$set': {
                    'answers': answers,
                    'clear_time': cur_time
                }
            })

    @staticmethod
    def completely_sober():
        """Reset the drunkenness of every known group to zero."""
        for key in Chat._drunkenness_dict.keys():
            Chat._drunkenness_dict[key] = 0

    @staticmethod
    def _find_ban_keywords(context, group_id) -> set:
        """Return the answer keywords that must not be used for ``context`` in ``group_id``.

        Combines three sources:
        - the global blacklist;
        - the group's blacklist;
        - the context's own ban records.
        A context-level ban issued by other groups applies once
        cross_group_threshold of them have issued it.
        """
        # Global and per-group blacklists ('|' builds a new set, so the cached
        # sets are not modified below)
        ban_keywords = Chat.blacklist_answer[Chat._blacklist_flag] | Chat.blacklist_answer[group_id]
        # Bans recorded on this particular context
        if 'ban' in context:
            ban_count = defaultdict(int)
            for ban in context['ban']:
                ban_key = ban['keywords']
                if ban['group_id'] == group_id or ban['group_id'] == Chat._blacklist_flag:
                    ban_keywords.add(ban_key)
                else:
                    # Banned in N other groups: ban it everywhere
                    ban_count[ban_key] += 1
                    if ban_count[ban_key] == Chat.cross_group_threshold:
                        ban_keywords.add(ban_key)
        return ban_keywords


# Load persisted blacklists (and rebuild the global one) when the module is imported.
Chat.update_global_blacklist()


def _chat_sync():
    """Flush cached messages, then blacklists, to MongoDB (run at interpreter exit)."""
    for flush in (Chat._sync, Chat._sync_blacklist):
        flush()


# Persist everything when the program exits.
atexit.register(_chat_sync)


if __name__ == '__main__':
    # Manual smoke test (needs the configured MongoDB). It answers and then
    # learns the same message twice. Chat.clearup_context() and Chat.speak()
    # can also be called here for manual checks.

    test_data: ChatData = ChatData(
        group_id=1234567,
        user_id=1111111,
        raw_message='完了又有新bug',
        plain_text='完了又有新bug',
        time=int(time.time()),
        bot_id=0,
    )

    test_chat: Chat = Chat(test_data)

    # answer() returns a lazy generator (or None); materialize it to see the replies
    reply = test_chat.answer()
    print(list(reply) if reply is not None else None)
    test_chat.learn()

    test_answer_data: ChatData = ChatData(
        group_id=1234567,
        user_id=1111111,
        raw_message='完了又有新bug',
        plain_text='完了又有新bug',
        time=int(time.time()),
        bot_id=0,
    )

    test_answer: Chat = Chat(test_answer_data)
    reply = test_answer.answer()
    print(list(reply) if reply is not None else None)
    test_answer.learn()