diff --git a/bot/bot.py b/bot/bot.py index 35dbd1ba4e..c271f57e0e 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,19 +1,23 @@ import asyncio import contextlib +import types from sys import exception import aiohttp from discord.errors import Forbidden +from discord.ext import commands from pydis_core import BotBase +from pydis_core.utils import scheduling +from pydis_core.utils._extensions import walk_extensions from pydis_core.utils.error_handling import handle_forbidden_from_block from sentry_sdk import new_scope, start_transaction from bot import constants, exts from bot.log import get_logger +from bot.utils.startup_reporting import StartupFailureReporter log = get_logger("bot") - class StartupError(Exception): """Exception class for startup errors.""" @@ -26,9 +30,13 @@ class Bot(BotBase): """A subclass of `pydis_core.BotBase` that implements bot-specific functions.""" def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + # Track extension load failures and tasks so we can report them after all attempts have completed + self.extension_load_failures: dict[str, BaseException] = {} + self._extension_load_tasks: dict[str, asyncio.Task] = {} + self._startup_failure_reporter = StartupFailureReporter() + async def load_extension(self, name: str, *args, **kwargs) -> None: """Extend D.py's load_extension function to also record sentry performance stats.""" with start_transaction(op="cog-load", name=name): @@ -77,3 +85,53 @@ async def on_error(self, event: str, *args, **kwargs) -> None: scope.set_extra("kwargs", kwargs) log.exception(f"Unhandled exception in {event}.") + + async def add_cog(self, cog: commands.Cog) -> None: + """ + Add a cog to the bot with exception handling. + + Override of `BotBase.add_cog` to capture and log any exceptions raised during cog loading, + including the extension name if available. + """ + extension = cog.__module__ + + try: + await super().add_cog(cog) + log.info(f"Cog successfully loaded: {cog.qualified_name}") + + except BaseException as e: + key = extension or f"(unknown)::{cog.qualified_name}" + self.extension_load_failures[key] = e + + log.exception( + f"Failed during add_cog (extension={extension}, cog={cog.qualified_name})" + ) + # Propagate error + raise + + async def _load_extensions(self, module: types.ModuleType) -> None: + """Load extensions for the bot.""" + await self.wait_until_guild_available() + + self.all_extensions = walk_extensions(module) + + async def _load_one(extension: str) -> None: + try: + await self.load_extension(extension) + log.info(f"Extension successfully loaded: {extension}") + + except BaseException as e: + self.extension_load_failures[extension] = e + log.exception(f"Failed to load extension: {extension}") + raise + + for extension in self.all_extensions: + task = scheduling.create_task(_load_one(extension)) + self._extension_load_tasks[extension] = task + + # Wait for all load tasks to complete so we can report any failures together + await asyncio.gather(*self._extension_load_tasks.values(), return_exceptions=True) + + # Send a Discord message to moderators if any extensions failed to load + if self.extension_load_failures : + await self._startup_failure_reporter.notify(self, self.extension_load_failures) diff --git a/bot/constants.py b/bot/constants.py index fc4fa7beed..30e6d18cb5 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -462,6 +462,9 @@ class _URLs(_BaseURLs): connect_max_retries: int = 3 connect_cooldown: int = 5 + # Back-off in cog_load + connect_initial_backoff: int = 1 + site_logs_view: str = "https://pythondiscord.com/staff/bot/logs" diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 210ae3fb05..719e0ad796 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -1,3 +1,4 @@ +import asyncio import datetime import io import json @@ -24,7 +25,7 @@ import bot.exts.filtering._ui.filter as filters_ui from bot import constants from bot.bot import Bot -from bot.constants import BaseURLs, Channels, Guild, MODERATION_ROLES, Roles +from bot.constants import BaseURLs, Channels, Guild, MODERATION_ROLES, Roles, URLs from bot.exts.backend.branding._repository import HEADERS, PARAMS from bot.exts.filtering._filter_context import Event, FilterContext from bot.exts.filtering._filter_lists import FilterList, ListType, ListTypeConverter, filter_list_types @@ -55,6 +56,7 @@ from bot.utils.channel import is_mod_channel from bot.utils.lock import lock_arg from bot.utils.message_cache import MessageCache +from bot.utils.retry import is_retryable_api_error log = get_logger(__name__) @@ -108,7 +110,31 @@ async def cog_load(self) -> None: await self.bot.wait_until_guild_available() log.trace("Loading filtering information from the database.") - raw_filter_lists = await self.bot.api_client.get("bot/filter/filter_lists") + for attempt in range(1, URLs.connect_max_retries + 1): + try: + raw_filter_lists = await self.bot.api_client.get("bot/filter/filter_lists") + break + except Exception as error: + is_retryable = is_retryable_api_error(error) + is_last_attempt = attempt == URLs.connect_max_retries + + if not is_retryable: + raise + + if is_last_attempt: + log.exception("Failed to load filtering data after %d attempts.", URLs.connect_max_retries) + raise + + backoff_seconds = URLs.connect_initial_backoff * (2 ** (attempt - 1)) + log.warning( + "Failed to load filtering data (attempt %d/%d). Retrying in %d second(s): %s", + attempt, + URLs.connect_max_retries, + backoff_seconds, + error + ) + await asyncio.sleep(backoff_seconds) + example_list = None for raw_filter_list in raw_filter_lists: loaded_list = self._load_raw_filter_list(raw_filter_list) diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py index c786a9d192..437e44cd38 100644 --- a/bot/exts/info/python_news.py +++ b/bot/exts/info/python_news.py @@ -1,3 +1,4 @@ +import asyncio import re import typing as t from datetime import UTC, datetime, timedelta @@ -12,7 +13,9 @@ from bot import constants from bot.bot import Bot +from bot.constants import URLs from bot.log import get_logger +from bot.utils.retry import is_retryable_api_error from bot.utils.webhooks import send_webhook PEPS_RSS_URL = "https://peps.python.org/peps.rss" @@ -46,19 +49,45 @@ def __init__(self, bot: Bot): async def cog_load(self) -> None: """Load all existing seen items from db and create any missing mailing lists.""" - with sentry_sdk.start_span(description="Fetch mailing lists from site"): - response = await self.bot.api_client.get("bot/mailing-lists") - - for mailing_list in response: - self.seen_items[mailing_list["name"]] = set(mailing_list["seen_items"]) - - with sentry_sdk.start_span(description="Update site with new mailing lists"): - for mailing_list in ("pep", *constants.PythonNews.mail_lists): - if mailing_list not in self.seen_items: - await self.bot.api_client.post("bot/mailing-lists", json={"name": mailing_list}) - self.seen_items[mailing_list] = set() - - self.fetch_new_media.start() + for attempt in range(1, URLs.connect_max_retries + 1): + try: + with sentry_sdk.start_span(description="Fetch mailing lists from site"): + response = await self.bot.api_client.get("bot/mailing-lists") + + # Rebuild state on each successful fetch (avoid partial state across retries) + self.seen_items = {} + for mailing_list in response: + self.seen_items[mailing_list["name"]] = set(mailing_list["seen_items"]) + + with sentry_sdk.start_span(description="Update site with new mailing lists"): + for mailing_list in ("pep", *constants.PythonNews.mail_lists): + if mailing_list not in self.seen_items: + await self.bot.api_client.post("bot/mailing-lists", json={"name": mailing_list}) + self.seen_items[mailing_list] = set() + + self.fetch_new_media.start() + return + + except Exception as error: + if not is_retryable_api_error(error): + raise + + if attempt == URLs.connect_max_retries: + log.exception( + "Failed to load PythonNews mailing lists after %d attempt(s).", + URLs.connect_max_retries, + ) + raise + + backoff_seconds = URLs.connect_initial_backoff * (2 ** (attempt - 1)) + log.warning( + "Failed to load PythonNews mailing lists (attempt %d/%d). Retrying in %d second(s). Error: %s", + attempt, + URLs.connect_max_retries, + backoff_seconds, + error, + ) + await asyncio.sleep(backoff_seconds) async def cog_unload(self) -> None: """Stop news posting tasks on cog unload.""" diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 006334755d..01481d1f68 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -1,3 +1,4 @@ +import asyncio import json import random import textwrap @@ -10,6 +11,7 @@ from bot import constants from bot.bot import Bot +from bot.constants import URLs from bot.converters import Duration, DurationOrExpiry from bot.decorators import ensure_future_timestamp from bot.exts.moderation.infraction import _utils @@ -17,6 +19,7 @@ from bot.log import get_logger from bot.utils import time from bot.utils.messages import format_user +from bot.utils.retry import is_retryable_api_error log = get_logger(__name__) NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" @@ -43,9 +46,7 @@ async def on_member_update(self, before: Member, after: Member) -> None: f"{after.display_name}. Checking if the user is in superstar-prison..." ) - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ + active_superstarifies = await self._fetch_with_retries(params={ "active": "true", "type": "superstar", "user__id": str(before.id) @@ -84,9 +85,7 @@ async def on_member_update(self, before: Member, after: Member) -> None: @Cog.listener() async def on_member_join(self, member: Member) -> None: """Reapply active superstar infractions for returning members.""" - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ + active_superstarifies = await self._fetch_with_retries(params={ "active": "true", "type": "superstar", "user__id": member.id @@ -238,6 +237,22 @@ async def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx) + async def _fetch_with_retries(self, + retries: int = URLs.connect_max_retries, + params: dict[str, str] | None = None) -> list[dict]: + """Fetch infractions from the API with retries and exponential backoff.""" + if retries < 1: + raise ValueError("retries must be at least 1") + + for attempt in range(1, retries + 1): + try: + return await self.bot.api_client.get("bot/infractions", params=params) + except Exception as e: + if attempt == retries or not is_retryable_api_error(e): + raise + await asyncio.sleep(URLs.connect_initial_backoff * (2 ** (attempt - 1))) + return None + async def setup(bot: Bot) -> None: """Load the Superstarify cog.""" diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 1b386ec000..e116dcf2ae 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -1,3 +1,4 @@ +import asyncio import random import textwrap import typing as t @@ -23,6 +24,7 @@ POSITIVE_REPLIES, Roles, STAFF_AND_COMMUNITY_ROLES, + URLs, ) from bot.converters import Duration, UnambiguousUser from bot.errors import LockedResourceError @@ -224,13 +226,25 @@ async def cog_unload(self) -> None: async def cog_load(self) -> None: """Get all current reminders from the API and reschedule them.""" await self.bot.wait_until_guild_available() - response = await self.bot.api_client.get( - "bot/reminders", - params={"active": "true"} - ) - + # retry fetching reminders with exponential backoff + for attempt in range(1, URLs.connect_max_retries + 1): + try: + # response either throws, or is a list of reminders (possibly empty) + response = await self.bot.api_client.get( + "bot/reminders", + params={"active": "true"} + ) + break + except Exception as e: + if not self._check_error_is_retriable(e): + log.error(f"Failed to load reminders due to non-retryable error: {e}") + raise + log.warning(f"Attempt {attempt} - Failed to fetch reminders from the API: {e}") + if attempt == URLs.connect_max_retries: + log.error("Max retry attempts reached. Failed to load reminders.") + raise + await asyncio.sleep(URLs.connect_initial_backoff * (2 ** (attempt - 1))) # Exponential backoff now = datetime.now(UTC) - for reminder in response: is_valid, *_ = self.ensure_valid_reminder(reminder) if not is_valid: @@ -244,6 +258,13 @@ async def cog_load(self) -> None: else: self.schedule_reminder(reminder) + def _check_error_is_retriable(self, error: Exception) -> bool: + """Return whether loading filter lists failed due to some temporary error, thus retrying could help.""" + if isinstance(error, ResponseCodeError): + return error.status in (408, 429) or error.status >= 500 + + return isinstance(error, (TimeoutError, OSError)) + def ensure_valid_reminder(self, reminder: dict) -> tuple[bool, discord.TextChannel]: """Ensure reminder channel can be fetched otherwise delete the reminder.""" channel = self.bot.get_channel(reminder["channel_id"]) diff --git a/bot/utils/retry.py b/bot/utils/retry.py new file mode 100644 index 0000000000..342897f381 --- /dev/null +++ b/bot/utils/retry.py @@ -0,0 +1,9 @@ +from pydis_core.site_api import ResponseCodeError + + +def is_retryable_api_error(error: Exception) -> bool: + """Return whether an API error is temporary and worth retrying.""" + if isinstance(error, ResponseCodeError): + return error.status in (408, 429) or error.status >= 500 + + return isinstance(error, (TimeoutError, OSError)) diff --git a/bot/utils/startup_reporting.py b/bot/utils/startup_reporting.py new file mode 100644 index 0000000000..f7713d1ca4 --- /dev/null +++ b/bot/utils/startup_reporting.py @@ -0,0 +1,55 @@ +from collections.abc import Mapping +from typing import TYPE_CHECKING + +import discord + +from bot.constants import Channels, Icons +from bot.log import get_logger + +log = get_logger(__name__) + +if TYPE_CHECKING: + from bot.bot import Bot + +class StartupFailureReporter: + """Formats and sends one aggregated startup failure alert to moderators.""" + + async def notify(self, bot: Bot, failures: Mapping[str, BaseException], channel_id: int = Channels.mod_log) -> None: + """Notify moderators of startup failures.""" + if not failures: + return + + if bot.get_channel(channel_id) is None: + # Can't send a message if the channel doesn't exist, so log instead + log.warning("Failed to send startup failure report: mod_log channel not found.") + return + + try: + # Local import avoids circular dependency + from bot.utils.modlog import send_log_message + + text = self.render(failures) + + await send_log_message( + bot, + icon_url=Icons.token_removed, + colour=discord.Colour.red(), + title="Startup: Some extensions failed to load", + text=text, + ping_everyone=True, + channel_id=channel_id + ) + except Exception as exception: + log.exception(f"Failed to send startup failure report: {exception}") + + def render(self, failures: Mapping[str, BaseException]) -> str: + """Render a human-readable message from the given failures.""" + failure_keys = sorted(failures.keys()) + + lines = [] + lines.append("The following extension(s) failed to load:") + for failure_key in failure_keys: + exception = failures[failure_key] + lines.append(f"- **{failure_key}** - `{type(exception).__name__}: {exception}`") + + return "\n".join(lines) diff --git a/tests/bot/exts/filtering/test_filtering_cog.py b/tests/bot/exts/filtering/test_filtering_cog.py new file mode 100644 index 0000000000..736c4bf1aa --- /dev/null +++ b/tests/bot/exts/filtering/test_filtering_cog.py @@ -0,0 +1,77 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from bot.exts.filtering.filtering import Filtering + + +class FilteringCogLoadTests(unittest.IsolatedAsyncioTestCase): + """Test startup behavior of the Filtering cog (`cog_load`).""" + + def setUp(self) -> None: + """Set up a Filtering cog with a mocked bot and stubbed startup dependencies.""" + self.bot = MagicMock() + self.bot.wait_until_guild_available = AsyncMock() + + self.bot.api_client = MagicMock() + self.bot.api_client.get = AsyncMock() + + self.cog = Filtering(self.bot) + + # Stub internals that are not relevant to this unit test. + self.cog.collect_loaded_types = MagicMock() + self.cog.schedule_offending_messages_deletion = AsyncMock() + self.cog._fetch_or_generate_filtering_webhook = AsyncMock(return_value=MagicMock()) + + # `weekly_auto_infraction_report_task` is a discord task loop; patch its start method. + self.start_patcher = patch.object(self.cog.weekly_auto_infraction_report_task, "start") + self.mock_weekly_task_start = self.start_patcher.start() + self.addCleanup(self.start_patcher.stop) + + async def test_cog_load_retries_then_succeeds(self): + """`cog_load` should retry temporary failures and complete startup after a successful fetch.""" + self.bot.api_client.get.side_effect = [ + OSError("temporary outage"), + TimeoutError("temporary timeout"), + [], + ] + + with patch("bot.exts.filtering.filtering.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await self.cog.cog_load() + + self.bot.wait_until_guild_available.assert_awaited_once() + self.assertEqual(self.bot.api_client.get.await_count, 3) + self.bot.api_client.get.assert_awaited_with("bot/filter/filter_lists") + self.assertEqual(mock_sleep.await_count, 2) + self.cog._fetch_or_generate_filtering_webhook.assert_awaited_once() + self.cog.collect_loaded_types.assert_called_once_with(None) + self.cog.schedule_offending_messages_deletion.assert_awaited_once() + self.mock_weekly_task_start.assert_called_once() + + async def test_retries_three_times_fails_and_reraises(self): + """`cog_load` should retry and re-raise when all retry attempts fail.""" + self.bot.api_client.get.side_effect = OSError( + "Simulated site/API outage during cog_load" + ) + + with patch( + "bot.exts.filtering.filtering.asyncio.sleep", + new_callable=AsyncMock, + ) as mock_sleep, self.assertRaises(OSError) as ctx: + await self.cog.cog_load() + + self.assertIs(ctx.exception, self.bot.api_client.get.side_effect) + + # Waited for guild availability + self.bot.wait_until_guild_available.assert_awaited_once() + + # 3 attempts + self.assertEqual(self.bot.api_client.get.await_count, 3) + self.bot.api_client.get.assert_awaited_with("bot/filter/filter_lists") + + # Backoff between attempts (attempts - 1) + self.assertEqual(mock_sleep.await_count, 2) + + # Startup should stop before later steps. + self.cog._fetch_or_generate_filtering_webhook.assert_not_awaited() + self.cog.schedule_offending_messages_deletion.assert_not_awaited() + self.mock_weekly_task_start.assert_not_called() diff --git a/tests/bot/exts/info/test_python_news.py b/tests/bot/exts/info/test_python_news.py new file mode 100644 index 0000000000..75c9e386aa --- /dev/null +++ b/tests/bot/exts/info/test_python_news.py @@ -0,0 +1,100 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from pydis_core.site_api import ResponseCodeError + +from bot.constants import URLs +from bot.exts.info.python_news import PythonNews + + +class PythonNewsCogLoadTests(unittest.IsolatedAsyncioTestCase): + """Test startup behavior of the PythonNews cog (`cog_load`).""" + + def setUp(self) -> None: + """Set up a PythonNews cog with a mocked bot and stubbed startup dependencies.""" + self.bot = MagicMock() + self.bot.wait_until_guild_available = AsyncMock() + + self.bot.api_client = MagicMock() + self.bot.api_client.get = AsyncMock() + self.bot.api_client.post = AsyncMock() + + # Required by `fetch_new_media` later, but not used in these tests. + self.bot.http_session = MagicMock() + + self.cog = PythonNews(self.bot) + + # Stub out task-loop start, so it doesn't actually schedule anything. + self.start_patcher = patch.object(self.cog.fetch_new_media, "start") + self.mock_fetch_new_media_start = self.start_patcher.start() + self.addCleanup(self.start_patcher.stop) + + async def test_cog_load_retries_then_succeeds(self): + """`cog_load` should retry temporary failures and complete startup after a successful fetch.""" + # First two attempts fail with retryable errors, third succeeds. + self.bot.api_client.get.side_effect = [ + OSError("temporary outage"), + TimeoutError("temporary timeout"), + [ + {"name": "pep", "seen_items": ["1", "2"]}, + ], + ] + + # Ensure no missing mailing lists need creating in this test. + with ( + patch("bot.exts.info.python_news.constants.PythonNews.mail_lists", new=()), + patch("bot.exts.info.python_news.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + await self.cog.cog_load() + + self.assertEqual(self.bot.api_client.get.await_count, 3) + self.bot.api_client.get.assert_awaited_with("bot/mailing-lists") + + # Sleep should have been awaited for the two failed attempts. + self.assertEqual(mock_sleep.await_count, 2) + + # Task should start after successful load. + self.mock_fetch_new_media_start.assert_called_once() + + # State should be populated. + self.assertIn("pep", self.cog.seen_items) + self.assertEqual(self.cog.seen_items["pep"], {"1", "2"}) + + # No posts should happen because no missing lists. + self.bot.api_client.post.assert_not_awaited() + + async def test_retries_max_times_fails_and_reraises(self): + """`cog_load` should re-raise when all retry attempts fail.""" + self.bot.api_client.get.side_effect = OSError("Simulated site/API outage during cog_load") + + with ( + patch("bot.exts.info.python_news.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + self.assertRaises(OSError), + ): + await self.cog.cog_load() + + # Should try exactly MAX_ATTEMPTS times. + + self.assertEqual(self.bot.api_client.get.await_count, URLs.connect_max_retries) + self.bot.api_client.get.assert_awaited_with("bot/mailing-lists") + + # Sleeps happen between attempts, so MAX_ATTEMPTS - 1 times. + self.assertEqual(mock_sleep.await_count, URLs.connect_max_retries - 1) + + # Task should never start if load fails. + self.mock_fetch_new_media_start.assert_not_called() + + async def test_cog_load_does_not_retry_non_retryable_error(self): + """`cog_load` should not retry when the error is non-retryable.""" + # 404 should be considered non-retryable by your predicate. + self.bot.api_client.get.side_effect = ResponseCodeError(MagicMock(status=404)) + + with ( + patch("bot.exts.info.python_news.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + self.assertRaises(ResponseCodeError), + ): + await self.cog.cog_load() + + self.assertEqual(self.bot.api_client.get.await_count, 1) + self.assertEqual(mock_sleep.await_count, 0) + self.mock_fetch_new_media_start.assert_not_called() diff --git a/tests/bot/exts/moderation/infraction/test_superstarify_cog.py b/tests/bot/exts/moderation/infraction/test_superstarify_cog.py new file mode 100644 index 0000000000..54473c7064 --- /dev/null +++ b/tests/bot/exts/moderation/infraction/test_superstarify_cog.py @@ -0,0 +1,105 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from bot.exts.moderation.infraction.superstarify import Superstarify +from tests.helpers import MockBot + + +class TestSuperstarify(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.bot = MockBot() + + self.cog = Superstarify(self.bot) + + self.bot.api_client = MagicMock() + self.bot.api_client.get = AsyncMock() + + self.cog._check_error_is_retriable = MagicMock(return_value=True) + + async def test_fetch_from_api_success(self): + """API succeeds on first attempt.""" + expected = [{"id": 1}] + self.bot.api_client.get.return_value = expected + + result = await self.cog._fetch_with_retries( + params={"user__id": "123"} + ) + self.assertEqual(result, expected) + + self.bot.api_client.get.assert_awaited_once_with( + "bot/infractions", + params={"user__id": "123"}, + ) + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_fetch_retries_then_succeeds(self, _): + self.bot.api_client.get.side_effect = [ + OSError("temporary failure"), + [{"id": 42}], + ] + + result = await self.cog._fetch_with_retries( + params={"user__id": "123"} + ) + + self.assertEqual(result, [{"id": 42}]) + self.assertEqual(self.bot.api_client.get.await_count, 2) + + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_fetch_fails_after_max_retries(self, _): + error = OSError("API down") + + self.bot.api_client.get.side_effect = error + + with self.assertRaises(OSError): + await self.cog._fetch_with_retries( + retries=3, + params={"user__id": "123"}, + ) + + self.assertEqual(self.bot.api_client.get.await_count, 3) + + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_non_retriable_error_stops_immediately(self, _): + error = ValueError("bad request") + + self.bot.api_client.get.side_effect = error + self.cog._check_error_is_retriable.return_value = False + + with self.assertRaises(ValueError): + await self.cog._fetch_with_retries() + + # only one attempt + self.bot.api_client.get.assert_awaited_once() + + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_member_update_recovers_from_api_failure(self, _): + before = MagicMock(display_name="Old", id=123) + after = MagicMock(display_name="New", id=123) + after.edit = AsyncMock() + + self.bot.api_client.get.side_effect = [ + OSError(), + [{"id": 42}], + ] + + self.cog.get_nick = MagicMock(return_value="Taylor Swift") + + with patch( + "bot.exts.moderation.infraction._utils.notify_infraction", + new=AsyncMock(return_value=True), + ): + await self.cog.on_member_update(before, after) + + after.edit.assert_awaited_once() + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_alert_triggered_after_total_failure(self, _): + self.bot.api_client.get.side_effect = OSError("down") + + with self.assertRaises(OSError): + await self.cog._fetch_with_retries(retries=3) diff --git a/tests/bot/exts/test_extensions.py b/tests/bot/exts/test_extensions.py new file mode 100644 index 0000000000..2546873883 --- /dev/null +++ b/tests/bot/exts/test_extensions.py @@ -0,0 +1,157 @@ +import asyncio +import contextlib +import importlib +import sys +import unittest +import unittest.mock +from pathlib import Path +from tempfile import TemporaryDirectory + +import discord + +from bot.bot import Bot + + +class ExtensionLoadingTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + self.http_session = unittest.mock.MagicMock(name="http_session") + + # Set up a Bot instance with minimal configuration for testing extension loading. + self.bot = Bot( + command_prefix="!", + guild_id=123456789012345678, + allowed_roles=[], + http_session=self.http_session, + intents=discord.Intents.none() + ) + + # Avoid blocking in _load_extensions() + async def _instant() -> None: + return None + self.bot.wait_until_guild_available = _instant + + # Ensure clean state + self.bot.extension_load_failures = {} + self.bot._extension_load_tasks = {} + + # Temporary importable package: tmp_root/testexts/__init__.py + modules + self._temp_dir = TemporaryDirectory() + self.addCleanup(self._temp_dir.cleanup) + self.tmp_root = Path(self._temp_dir.name) + + self.pkg_name = "testexts" + self.pkg_dir = self.tmp_root / self.pkg_name + self.pkg_dir.mkdir(parents=True, exist_ok=True) + (self.pkg_dir / "__init__.py").write_text("", encoding="utf-8") + + sys.path.insert(0, str(self.tmp_root)) + self.addCleanup(self._remove_tmp_from_syspath) + self._purge_modules(self.pkg_name) + + # Ensure scheduled tasks execute during tests + self._create_task_patcher = unittest.mock.patch( + "pydis_core.utils.scheduling.create_task", + side_effect=lambda coro, *a, **k: asyncio.create_task(coro), + ) + self._create_task_patcher.start() + self.addCleanup(self._create_task_patcher.stop) + + def _remove_tmp_from_syspath(self) -> None: + """Remove the temporary directory from sys.path.""" + with contextlib.suppress(ValueError): + sys.path.remove(str(self.tmp_root)) + + def _purge_modules(self, prefix: str) -> None: + """Remove modules from sys.modules that match the given prefix.""" + for name in list(sys.modules.keys()): + if name == prefix or name.startswith(prefix + "."): + del sys.modules[name] + + def _write_ext(self, module_name: str, source: str) -> str: + """Write an extension module with the given source code and + return its full import path.""" + (self.pkg_dir / f"{module_name}.py").write_text(source, encoding="utf-8") + full = f"{self.pkg_name}.{module_name}" + self._purge_modules(full) + return full + + async def _run_loader(self) -> None: + """Run the extension loader on the package containing our test extensions.""" + module = importlib.import_module(self.pkg_name) + + await self.bot._load_extensions(module) + + # Wait for all extension load tasks to complete so that exceptions are recorded in the bot's state + tasks = getattr(self.bot, "_extension_load_tasks", {}) or {} + if tasks: + await asyncio.gather(*tasks.values(), return_exceptions=True) + + def _assert_failure_recorded_for_extension(self, ext: str) -> None: + """Assert that a failure is recorded for the given extension.""" + if ext in self.bot.extension_load_failures: + return + for key in self.bot.extension_load_failures: + if key.startswith(ext): + return + self.fail( + f"Expected a failure recorded for {ext!r}. " + f"Recorded keys: {sorted(self.bot.extension_load_failures.keys())}" + ) + + async def test_setup_failure_is_captured(self) -> None: + ext = self._write_ext( + "ext_setup_fail", + """ +async def setup(bot): + raise RuntimeError("setup failed") +""", + ) + await self._run_loader() + self._assert_failure_recorded_for_extension(ext) + + async def test_cog_load_failure_is_captured(self) -> None: + ext = self._write_ext( + "ext_cogload_fail", + """ +from discord.ext import commands + +class BadCog(commands.Cog): + async def cog_load(self): + raise RuntimeError("cog_load failed") + +async def setup(bot): + await bot.add_cog(BadCog()) +""", + ) + await self._run_loader() + self._assert_failure_recorded_for_extension(ext) + + async def test_add_cog_failure_is_captured(self) -> None: + ext = self._write_ext( + "ext_addcog_fail", + """ +from discord.ext import commands + +class DupCog(commands.Cog): + pass + +async def setup(bot): + await bot.add_cog(DupCog()) + await bot.add_cog(DupCog()) +""", + ) + await self._run_loader() + self._assert_failure_recorded_for_extension(ext) + + async def test_import_failure_is_captured(self) -> None: + ext = self._write_ext( + "ext_import_fail", + """ +raise RuntimeError("import failed before setup()") + +async def setup(bot): + pass +""", + ) + await self._run_loader() + self._assert_failure_recorded_for_extension(ext) diff --git a/tests/bot/exts/utils/test_reminders.py b/tests/bot/exts/utils/test_reminders.py new file mode 100644 index 0000000000..eb1d903876 --- /dev/null +++ b/tests/bot/exts/utils/test_reminders.py @@ -0,0 +1,57 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from pydis_core.site_api import ResponseCodeError + +from bot.constants import URLs +from bot.exts.utils.reminders import Reminders +from tests.helpers import MockBot + + +class RemindersCogLoadTests(unittest.IsolatedAsyncioTestCase): + """ Tests startup behaviour of the Reminders cog. """ + + def setUp(self): + self.bot = MockBot() + self.bot.wait_until_guild_available = AsyncMock() + self.cog = Reminders(self.bot) + + self.cog._alert_mods_if_loading_failed = AsyncMock() + self.cog.ensure_valid_reminder = MagicMock(return_value=(False, None)) + self.cog.schedule_reminder = MagicMock() + self.cog._alert_mods_if_loading_failed = AsyncMock() + + self.bot.api_client = MagicMock() + self.bot.api_client.get = AsyncMock() + + async def test_reminders_cog_loads_correctly(self): + """ Tests if the Reminders cog loads without error if the GET requests works. """ + self.bot.api_client.get.return_value = [] + try: + with patch("bot.exts.utils.reminders.asyncio.sleep", new_callable=AsyncMock): + await self.cog.cog_load() + except Exception as e: + self.fail(f"Reminders cog failed to load with exception: {e}") + + async def test_reminders_cog_load_retries_after_initial_exception(self): + """ Tests if the Reminders cog loads after retrying on initial exception. """ + self.bot.api_client.get.side_effect = [OSError("fail1"), OSError("fail2"), []] + try: + with patch("bot.exts.utils.reminders.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await self.cog.cog_load() + except Exception as e: + self.fail(f"Reminders cog failed to load after retrying with exception: {e}") + self.assertEqual(mock_sleep.await_count, 2) + self.bot.api_client.get.assert_called() + + async def test_reminders_cog_load_fails_after_max_retries(self): + """ Tests if the Reminders cog fails to load after max retries. """ + self.bot.api_client.get.side_effect = ResponseCodeError(response=MagicMock(status=500), + response_text="Internal Server Error") + with patch("bot.exts.utils.reminders.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, \ + self.assertRaises(ResponseCodeError): + await self.cog.cog_load() + + # Should have retried MAX_RETRY_ATTEMPTS - 1 times before failing + self.assertEqual(mock_sleep.await_count, URLs.connect_max_retries - 1) + self.bot.api_client.get.assert_called() diff --git a/tests/bot/utils/test_retry.py b/tests/bot/utils/test_retry.py new file mode 100644 index 0000000000..8ce6f6db0c --- /dev/null +++ b/tests/bot/utils/test_retry.py @@ -0,0 +1,29 @@ +import unittest +from unittest.mock import MagicMock + +from pydis_core.site_api import ResponseCodeError + +from bot.utils.retry import is_retryable_api_error + + +class RetryTests(unittest.TestCase): + """Tests for retry classification helpers.""" + + def test_is_retryable_api_error(self): + """`is_retryable_api_error` should classify temporary failures as retryable.""" + test_cases = ( + (ResponseCodeError(MagicMock(status=408)), True), + (ResponseCodeError(MagicMock(status=429)), True), + (ResponseCodeError(MagicMock(status=500)), True), + (ResponseCodeError(MagicMock(status=503)), True), + (ResponseCodeError(MagicMock(status=400)), False), + (ResponseCodeError(MagicMock(status=404)), False), + (TimeoutError("timeout"), True), + (OSError("os error"), True), + (AttributeError("attr"), False), + (ValueError("value"), False), + ) + + for error, expected_retryable in test_cases: + with self.subTest(error=error): + self.assertEqual(is_retryable_api_error(error), expected_retryable)