support multiple comma-separated values in ADMIN_ID

This commit is contained in:
arĉi 2022-10-29 18:52:08 +06:00
parent 30ab7c84b4
commit afc5389520
3 changed files with 15 additions and 14 deletions

View File

@ -25,7 +25,7 @@ async def init_database():
async def init_olgram(): async def init_olgram():
from olgram.router import bot, dp from olgram.router import bot, dp
dp.setup_middleware(AccessMiddleware(OlgramSettings.admin_id())) dp.setup_middleware(AccessMiddleware(OlgramSettings.admin_ids()))
from aiogram.types import BotCommand from aiogram.types import BotCommand
await bot.set_my_commands( await bot.set_my_commands(
[ [

View File

@ -45,9 +45,9 @@ class OlgramSettings(AbstractSettings):
@classmethod @classmethod
@lru_cache @lru_cache
def admin_id(cls): def admin_ids(cls):
_id = cls._get_env("ADMIN_ID", True) _ids = cls._get_env("ADMIN_ID", True)
return int(_id) if _id else None return set(map(int, _ids.split(","))) if _ids else None
@classmethod @classmethod
@lru_cache @lru_cache

View File

@ -1,6 +1,7 @@
import aiogram.types as types import aiogram.types as types
from aiogram.dispatcher.handler import CancelHandler, current_handler from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware from aiogram.dispatcher.middlewares import BaseMiddleware
from collections.abc import Container
from locales.locale import _ from locales.locale import _
@ -19,8 +20,8 @@ def public():
class AccessMiddleware(BaseMiddleware): class AccessMiddleware(BaseMiddleware):
def __init__(self, access_chat_id: int): def __init__(self, access_chat_ids: Container[int]):
self._access_chat_id = access_chat_id self._access_chat_ids = access_chat_ids
super(AccessMiddleware, self).__init__() super(AccessMiddleware, self).__init__()
@classmethod @classmethod
@ -29,25 +30,25 @@ class AccessMiddleware(BaseMiddleware):
return handler and getattr(handler, "access_public", False) return handler and getattr(handler, "access_public", False)
async def on_process_message(self, message: types.Message, data: dict): async def on_process_message(self, message: types.Message, data: dict):
admin_id = self._access_chat_id admin_ids = self._access_chat_ids
if not admin_id: if not admin_ids:
return # Администратор бота вообще не указан return # Администраторы бота вообще не указаны
if self._is_public_command(): # Эта команда разрешена всем пользователям if self._is_public_command(): # Эта команда разрешена всем пользователям
return return
if message.chat.id != admin_id: if message.chat.id not in admin_ids:
await message.answer(_("Владелец бота ограничил доступ к этому функционалу 😞")) await message.answer(_("Владелец бота ограничил доступ к этому функционалу 😞"))
raise CancelHandler() raise CancelHandler()
async def on_process_callback_query(self, call: types.CallbackQuery, data: dict): async def on_process_callback_query(self, call: types.CallbackQuery, data: dict):
admin_id = self._access_chat_id admin_ids = self._access_chat_ids
if not admin_id: if not admin_ids:
return # Администратор бота вообще не указан return # Администраторы бота вообще не указаны
if self._is_public_command(): # Эта команда разрешена всем пользователям if self._is_public_command(): # Эта команда разрешена всем пользователям
return return
if call.message.chat.id != admin_id: if call.message.chat.id not in admin_ids:
await call.answer(_("Владелец бота ограничил доступ к этому функционалу😞")) await call.answer(_("Владелец бота ограничил доступ к этому функционалу😞"))
raise CancelHandler() raise CancelHandler()