diff --git a/SlashBot.py b/SlashBot.py index 5d8a7b1..ec63fd7 100644 --- a/SlashBot.py +++ b/SlashBot.py @@ -1,23 +1,22 @@ from __future__ import annotations +import asyncio import os import sys import re import html -import requests +import httpx import telegram from loguru import logger as _logger -from typing import Optional, Union, Any, Callable -from telegram.ext import Updater, MessageHandler, filters, Dispatcher -from functools import partial -from threading import Thread -from time import sleep +from typing import Optional, Union, Any, Callable, Final, Sequence, Iterable +from telegram.ext import Application, MessageHandler, filters +from functools import partial, wraps from itertools import product as _product from itertools import starmap from random import Random, SystemRandom from collections import deque, Counter - -Filters = filters.Filters +from contextvars import ContextVar +from http.cookiejar import CookieJar, DefaultCookiePolicy parser = re.compile( r'^(?P[\\/]_?\$?)' @@ -52,8 +51,36 @@ mentionParser = re.compile(r'@([a-zA-Z]\w{4,})') product = lambda a, b: tuple(map(','.join, _product(a, b))) -PUNCTUATION_TAIL = '.,?!;:~(' \ - '。,?!;:~(' +PUNCTUATION_TAIL = ( + '.,?!;:~(' + '。,?!;:~(' +) + +try: + random = SystemRandom() +except NotImplementedError: + random = Random() + _logger.warning('SystemRandom is not available, using Random instead') + +# env +TOKENS = re.compile(r'[^a-zA-Z\-_\d:]+').split(os.environ.get('TOKEN', '')) +if not TOKENS: + raise ValueError('no any valid token found') + +PROXY = os.environ.get('PROXY') +# Set proxy and disallow cookies +HTTPX_CLIENT = httpx.AsyncClient(http2=True, proxy=PROXY, cookies=CookieJar(DefaultCookiePolicy(allowed_domains=()))) + +_logger.remove() +_logger.add( + sys.stderr, + format="{time:YYYY-MM-DD HH:mm:ss.SSS}" + "|{level:^8}" + "|{extra[username]:^15}" + "|{message}", + level="DEBUG", +) +logger_var: ContextVar[_logger] = ContextVar("logger_var", default=_logger) class RandomizerMeta(type): @@ -125,41 +152,19 @@ class Stickers(Randomizer): return cls.__class_getitem__('stickers') -_logger.remove() -_logger.add(sys.stderr, - format="{time:YYYY-MM-DD HH:mm:ss.SSS}" - "|{level:^8}" - "|{extra[username]:^15}" - "|{message}", - level="DEBUG") +def log(func: Callable = None, verbose: bool = True): + if func is None: + return partial(log, verbose=verbose) - -def log(func: Callable): - def wrapper(update: telegram.Update, ctx: telegram.ext.CallbackContext): - logger = ctx.bot_data['logger'] + @wraps(func) + async def wrapper(update: telegram.Update, ctx: telegram.ext.CallbackContext): + logger = logger_var.get() logger.debug(str(update.to_dict())) - return func(update, ctx, logger) + return await func(update, ctx, logger) return wrapper -try: - random = SystemRandom() -except NotImplementedError: - random = Random() - _logger.warning('SystemRandom is not available, using Random instead') - -# Docker env -TOKENS = re.compile(r'[^a-zA-Z\-_\d:]+').split(os.environ.get('TOKEN', '')) -if not TOKENS: - raise ValueError('no any valid token found') - -TELEGRAM_PROXY = os.environ.get('PROXY', '') -REQUEST_PROXIES = {'all': TELEGRAM_PROXY} if TELEGRAM_PROXY else None - -_updaters: list[Updater] = [] - - class User: def __init__(self, uid: Optional[int] = None, username: Optional[str] = None, name: Optional[str] = None): if not (uid and name) and not username: @@ -167,11 +172,9 @@ class User: self.name = name self.uid = uid self.username = username - if not self.name and self.username: - self.__get_user_by_username() - def __get_user_by_username(self): - r = requests.get(f'https://t.me/{self.username}', proxies=REQUEST_PROXIES) + async def __get_user_by_username(self): + r = await HTTPX_CLIENT.get(f'https://t.me/{self.username}') og_t = re.search(r'(?<=).*(?=)', r.text, re.IGNORECASE).group(0) @@ -180,7 +183,9 @@ class User: elif name: self.name = name - def mention(self, mention_self: bool = False, pure: bool = False) -> str: + async def mention(self, mention_self: bool = False, pure: bool = False) -> str: + if not self.name and self.username: + await self.__get_user_by_username() if not self.name: return f'@{self.username}' @@ -195,8 +200,8 @@ class User: return ( type(self) == type(other) and ( - ((self.uid or other.uid) and self.uid == other.uid) or - ((self.username or other.username) and self.username == other.username) + (self.uid and other.uid and self.uid == other.uid) + or (self.username and other.username and self.username == other.username) ) ) @@ -252,10 +257,9 @@ def get_tail(tail_char: str) -> str: return '!' if halfwidth_mark else '!' -def get_text(user_from: User, user_rpl: User, command: dict): - rpl_self = user_from == user_rpl - mention_from = user_from.mention() - mention_rpl = user_rpl.mention(mention_self=rpl_self) +async def get_text(user_from: User, user_rpl: User, command: dict): + is_self_rpl = user_from == user_rpl + mention_from, mention_rpl = await asyncio.gather(user_from.mention(), user_rpl.mention(mention_self=is_self_rpl)) slash, predicate, complement, omit_le = \ command['slash'], command['predicate'], command['complement'], command['omit_le'] @@ -266,7 +270,7 @@ def get_text(user_from: User, user_rpl: User, command: dict): ret += get_tail((complement or user_from.mention(pure=True))[-1]) elif predicate == 'you': ret = f"{mention_rpl}{bool(complement) * ' '}{complement}" - ret += get_tail((complement or user_rpl.mention(mention_self=rpl_self, pure=True))[-1]) + ret += get_tail((complement or user_rpl.mention(mention_self=is_self_rpl, pure=True))[-1]) elif complement: ret = f"{mention_from} {predicate} {mention_rpl} {complement}" ret += get_tail(complement[-1]) @@ -278,12 +282,12 @@ def get_text(user_from: User, user_rpl: User, command: dict): return ret -def reply(update: telegram.Update, ctx: telegram.ext.CallbackContext): +@log(verbose=False) +async def reply(update: telegram.Update, ctx: telegram.ext.CallbackContext, logger: _logger = _logger): command = parse_command(ctx) if not command: return - logger = ctx.bot_data['logger'] logger.debug(str(update.to_dict())) msg = update.effective_message from_user, rpl_user = get_users(msg) @@ -306,95 +310,148 @@ def reply(update: telegram.Update, ctx: telegram.ext.CallbackContext): if command['swap'] and (not from_user == rpl_user): (from_user, rpl_user) = (rpl_user, from_user) - text = get_text(from_user, rpl_user, command) + text = await get_text(from_user, rpl_user, command) logger.info(text) - update.effective_message.reply_text('\u200e' + text, parse_mode='HTML') + await msg.reply_text('\u200e' + text, parse_mode='HTML') @log -def repeat(update: telegram.Update, _ctx: telegram.ext.CallbackContext, logger: _logger): +async def repeat(update: telegram.Update, _ctx: telegram.ext.CallbackContext, logger: _logger = _logger): chat = update.effective_chat msg = update.effective_message tid = msg.message_thread_id logger.info(msg.text) - ( - msg.copy - if msg.has_protected_content - else msg.forward - )(chat.id, message_thread_id=tid) + if msg.has_protected_content: + await msg.copy(chat.id, message_thread_id=tid) + else: + await msg.forward(chat.id, message_thread_id=tid) @log -def pin(update: telegram.Update, _ctx: telegram.ext.CallbackContext, logger: _logger): +async def pin(update: telegram.Update, _ctx: telegram.ext.CallbackContext, logger: _logger = _logger): msg = update.effective_message msg_to_pin = get_reply(msg) if not msg_to_pin: vegetable = f'{Vegetable["reject"]} (Reply to a message to use the command)' - msg.reply_text(vegetable) + await msg.reply_text(vegetable) logger.warning(vegetable) return try: - msg_to_pin.unpin() - msg_to_pin.pin(disable_notification=True) + await msg_to_pin.unpin() + await msg_to_pin.pin(disable_notification=True) logger.info(f'Pinned {msg_to_pin.text}') except telegram.error.BadRequest as e: vegetable = f'{Vegetable["permission_denied"]} ({e})' - msg_to_pin.reply_text(vegetable) + await msg_to_pin.reply_text(vegetable) logger.warning(vegetable) @log -def random_sticker(update: telegram.Update, _ctx: telegram.ext.CallbackContext, logger: _logger): +async def random_sticker(update: telegram.Update, _ctx: telegram.ext.CallbackContext, logger: _logger = _logger): msg = update.effective_message sticker = Stickers() logger.info(sticker) - msg.reply_sticker(sticker) + await msg.reply_sticker(sticker) -def start(token: str): - updater = Updater(token=token, use_context=True, request_kwargs={'proxy_url': TELEGRAM_PROXY}) - dp: Dispatcher = updater.dispatcher - dp.add_handler(MessageHandler(Filters.regex(ouenParser) & ~Filters.update.edited_message, repeat, run_async=True)) - dp.add_handler(MessageHandler(Filters.regex(randomStickerParser) & ~Filters.update.edited_message, random_sticker, - run_async=True)) - dp.add_handler(MessageHandler(Filters.regex(pinParser) & ~Filters.update.edited_message, pin, run_async=True)) - dp.add_handler(MessageHandler(Filters.regex(parser) & ~Filters.update.edited_message, reply, run_async=True)) - username = f'@{updater.bot.username}' - logger = _logger.bind(username=username) - dp.bot_data['delUsername'] = partial(re.compile(username, re.I).sub, '') - dp.bot_data['logger'] = logger +class App: + _apps: Final[set["App"]] = set() + # MessageHandler is stateless and reusable, so we can reuse the same instance for all handlers. + # Note: this is not always true for other handlers, e.g., ConversationHandler. + _handlers: Final[Sequence[MessageHandler]] = ( + MessageHandler( + filters.Regex(ouenParser) & ~filters.UpdateType.EDITED, + repeat, + block=False, + ), + MessageHandler( + filters.Regex(randomStickerParser) & ~filters.UpdateType.EDITED, + random_sticker, + block=False, + ), + MessageHandler( + filters.Regex(pinParser) & ~filters.UpdateType.EDITED, + pin, + block=False, + ), + MessageHandler( + filters.Regex(parser) & ~filters.UpdateType.EDITED, + reply, + block=False, + ), + ) - updater.start_polling() - logger.info('Started') + def __init__(self, token: str): + self.token = token + ab = Application.builder().token(token) + if PROXY: + ab = ab.proxy(PROXY).get_updates_proxy(PROXY) + self.application = ab.build() + self.application.add_handlers(self._handlers) - _updaters.append(updater) - # updater.idle() + async def start(self): + app = self.application + await app.initialize() + + username = f'@{app.bot.username}' + logger = _logger.bind(username=username) + logger_var.set(logger) + app.bot_data['delUsername'] = partial(re.compile(username, re.I).sub, '') + + if app.post_init: + await app.post_init(app) + + await app.updater.start_polling() + await app.start() + + logger.info('Started') + self._apps.add(self) + + async def shutdown(self): + app = self.application + logger = logger_var.get() + + await app.updater.stop() + await app.stop() + if app.post_stop: + await app.post_stop(app) + + await app.shutdown() + if app.post_shutdown: + await app.post_shutdown(app) + + logger.info('Stopped') + self._apps.discard(self) + + @classmethod + async def start_all(cls, tokens: Iterable[str]): + await asyncio.gather(*(cls(token).start() for token in tokens)) + + @classmethod + async def shutdown_all(cls): + if cls._apps: + await asyncio.gather(*(app.shutdown() for app in cls._apps)) + assert not cls._apps, 'Not all apps were stopped' + + @classmethod + async def run(cls, tokens: Iterable[str]): + try: + # Initialize and reuse the HTTPX client for all instances, and shut it down on exit. + async with HTTPX_CLIENT: + await cls.start_all(tokens) + # The Event is never set to finish, so it is equivalent to asyncio.get_running_loop().run_forever(). + await asyncio.Event().wait() + except (KeyboardInterrupt, SystemExit): + pass + finally: + await cls.shutdown_all() def main(): - threads: list[Thread] = [] - for token in TOKENS: - thread = Thread(target=start, args=(token,), daemon=True) - threads.append(thread) - thread.start() - for thread in threads: - thread.join() - - try: - while True: - sleep(1) - except KeyboardInterrupt: - threads_and_logger: list[tuple[Thread, Any]] = [] - for updater in _updaters: - thread = Thread(target=updater.stop, daemon=True) - threads_and_logger.append((thread, updater.dispatcher.bot_data['logger'])) - thread.start() - for thread, logger in threads_and_logger: - thread.join() - logger.info('Stopped') + asyncio.run(App.run(TOKENS)) if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index ff36ec6..5ccc88f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -python-telegram-bot==13.15 -requests==2.31.0 -loguru==0.7.3 \ No newline at end of file +python-telegram-bot[socks,job-queue]>=22.0,<23.0 +httpx[http2]>=0.28.1,<0.29.0 +loguru>=0.7.3,<0.8.0 \ No newline at end of file