diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index de09e58521..2e3dd1c9c2 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -1,14 +1,71 @@ import asyncio +import contextlib import os +import signal import sys import traceback +from collections.abc import Callable from pathlib import Path +from typing import Any import click from filelock import FileLock, Timeout from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root +ShutdownCallback = Callable[[signal.Signals], None] + + +def _install_shutdown_signal_handlers( + loop: asyncio.AbstractEventLoop, + callback: ShutdownCallback, +) -> Callable[[], None]: + """Install SIGINT/SIGTERM handlers and return a cleanup callback.""" + handled_signals = (signal.SIGINT, signal.SIGTERM) + previous_handlers: dict[signal.Signals, Any] = {} + installed: list[signal.Signals] = [] + + for signum in handled_signals: + try: + previous_handlers[signum] = signal.getsignal(signum) + except ValueError: + previous_handlers[signum] = None + try: + loop.add_signal_handler(signum, callback, signum) + installed.append(signum) + except (NotImplementedError, RuntimeError, ValueError): + + def fallback_handler(received_signum, frame): + _ = frame + if not loop.is_closed(): + try: + loop.call_soon_threadsafe( + callback, signal.Signals(received_signum) + ) + except RuntimeError: + pass + + try: + signal.signal(signum, fallback_handler) + installed.append(signum) + except ValueError: + pass + + def cleanup() -> None: + for signum in installed: + try: + loop.remove_signal_handler(signum) + except (NotImplementedError, RuntimeError, ValueError): + pass + previous_handler = previous_handlers.get(signum) + if previous_handler is not None: + try: + signal.signal(signum, previous_handler) + except (TypeError, ValueError): + pass + + return cleanup + async def run_astrbot(astrbot_root: Path) -> None: """Run AstrBot""" @@ -23,7 +80,46 @@ async def run_astrbot(astrbot_root: Path) -> None: core_lifecycle = InitialLoader(db, log_broker) - await core_lifecycle.start() + loop = asyncio.get_running_loop() + shutdown_requested = asyncio.Event() + shutdown_signal: signal.Signals | None = None + + def request_shutdown(signum: signal.Signals) -> None: + nonlocal shutdown_signal + shutdown_signal = signum + shutdown_requested.set() + + cleanup_signal_handlers = _install_shutdown_signal_handlers(loop, request_shutdown) + runner_task = asyncio.create_task(core_lifecycle.start(), name="astrbot") + shutdown_task = asyncio.create_task( + shutdown_requested.wait(), name="astrbot_shutdown" + ) + + try: + done, _ = await asyncio.wait( + {runner_task, shutdown_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + shutdown_requested_by_signal = shutdown_task in done + if shutdown_requested_by_signal and not runner_task.done(): + signal_name = shutdown_signal.name if shutdown_signal else "unknown" + logger.info(f"Received {signal_name}; stopping AstrBot...") + runner_task.cancel() + try: + await runner_task + except asyncio.CancelledError: + if not shutdown_requested_by_signal: + raise + finally: + cleanup_signal_handlers() + if not runner_task.done(): + runner_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await runner_task + if not shutdown_task.done(): + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task @click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins") diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index 3f836a4c42..0d185d3343 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -8,7 +8,7 @@ import asyncio import traceback -from astrbot.core import LogBroker, logger +from astrbot.core import LogBroker, LogManager, logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.dashboard.server import AstrBotDashboard @@ -25,33 +25,48 @@ def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None: async def start(self) -> None: core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) + initialized = False try: - await core_lifecycle.initialize() - except Exception as e: - logger.critical(traceback.format_exc()) - logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!") - return - - core_task = core_lifecycle.start() - - webui_dir = self.webui_dir - - self.dashboard_server = AstrBotDashboard( - core_lifecycle, - self.db, - core_lifecycle.dashboard_shutdown_event, - webui_dir, - ) - - coro = self.dashboard_server.run() - if coro: - # 启动核心任务和仪表板服务器 - task = asyncio.gather(core_task, coro) - else: - task = core_task - try: + try: + await core_lifecycle.initialize() + initialized = True + except Exception as e: + logger.critical(traceback.format_exc()) + logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!") + return + + core_task = core_lifecycle.start() + + webui_dir = self.webui_dir + + self.dashboard_server = AstrBotDashboard( + core_lifecycle, + self.db, + core_lifecycle.dashboard_shutdown_event, + webui_dir, + ) + + coro = self.dashboard_server.run() + if coro: + # 启动核心任务和仪表板服务器 + task = asyncio.gather(core_task, coro) + else: + task = core_task await task # 整个AstrBot在这里运行 except asyncio.CancelledError: logger.info("🌈 正在关闭 AstrBot...") - await core_lifecycle.stop() + if initialized: + await core_lifecycle.stop() + except Exception: + if initialized: + try: + await core_lifecycle.stop() + except Exception: + logger.error( + "AstrBot shutdown during runtime-error handling failed", + exc_info=True, + ) + raise + finally: + await LogManager.shutdown() diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 3dd0719b11..d00718b23b 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -415,3 +415,15 @@ def configure_trace_logger(cls, config: dict | None) -> None: backup_count=3, trace=True, ) + + @classmethod + async def shutdown(cls) -> None: + """Flush and remove loguru sinks during process shutdown.""" + try: + await _loguru.complete() + finally: + cls._remove_sink(cls._trace_sink_id) + cls._trace_sink_id = None + cls._remove_sink(cls._file_sink_id) + cls._file_sink_id = None + cls._configured = False diff --git a/tests/unit/test_cmd_run_shutdown.py b/tests/unit/test_cmd_run_shutdown.py new file mode 100644 index 0000000000..574a57cb0c --- /dev/null +++ b/tests/unit/test_cmd_run_shutdown.py @@ -0,0 +1,519 @@ +from __future__ import annotations + +import asyncio +import signal +from collections.abc import Callable +from pathlib import Path +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import astrbot.cli.commands.cmd_run as cmd_run +import astrbot.core as core_module +import astrbot.core.initial_loader as initial_loader_module + + +@pytest.mark.asyncio +async def test_run_astrbot_stops_gracefully_on_sigterm(monkeypatch): + set_queue_handler_mock = MagicMock() + check_dashboard_mock = AsyncMock(return_value=None) + signal_restore_mock = MagicMock() + pending_signals: dict[signal.Signals, Callable[[], None]] = {} + removed_signals: list[signal.Signals] = [] + previous_handlers = { + signal.SIGINT: object(), + signal.SIGTERM: object(), + } + + class FakeLoop: + def add_signal_handler(self, signum, callback, *args): + pending_signals[signum] = lambda: callback(*args) + + def remove_signal_handler(self, signum): + removed_signals.append(signum) + pending_signals.pop(signum, None) + return True + + started = asyncio.Event() + cancelled = asyncio.Event() + + class FakeLoader: + def __init__(self, db, log_broker): + self.db = db + self.log_broker = log_broker + + async def start(self): + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancelled.set() + return + + monkeypatch.setattr(initial_loader_module, "InitialLoader", FakeLoader) + monkeypatch.setattr(cmd_run, "check_dashboard", check_dashboard_mock) + fake_loop = FakeLoop() + monkeypatch.setattr(cmd_run.asyncio, "get_running_loop", lambda: fake_loop) + monkeypatch.setattr( + cmd_run.signal, "getsignal", lambda signum: previous_handlers[signum] + ) + monkeypatch.setattr(cmd_run.signal, "signal", signal_restore_mock) + monkeypatch.setattr( + core_module.LogManager, "set_queue_handler", set_queue_handler_mock + ) + + awaitable = asyncio.create_task(cmd_run.run_astrbot(Path("/tmp/astrbot-root"))) + await started.wait() + + assert signal.SIGTERM in pending_signals + pending_signals[signal.SIGTERM]() + + await awaitable + + assert cancelled.is_set() + assert set(removed_signals) == {signal.SIGINT, signal.SIGTERM} + assert signal_restore_mock.call_count == 2 + check_dashboard_mock.assert_awaited_once() + set_queue_handler_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_run_astrbot_suppresses_signal_cancelled_runner(monkeypatch): + check_dashboard_mock = AsyncMock(return_value=None) + signal_restore_mock = MagicMock() + pending_signals: dict[signal.Signals, Callable[[], None]] = {} + previous_handlers = { + signal.SIGINT: object(), + signal.SIGTERM: object(), + } + + class FakeLoop: + def add_signal_handler(self, signum, callback, *args): + pending_signals[signum] = lambda: callback(*args) + + def remove_signal_handler(self, signum): + pending_signals.pop(signum, None) + return True + + started = asyncio.Event() + + class FakeLoader: + def __init__(self, db, log_broker): + self.db = db + self.log_broker = log_broker + + async def start(self): + started.set() + await asyncio.Event().wait() + + monkeypatch.setattr(initial_loader_module, "InitialLoader", FakeLoader) + monkeypatch.setattr(cmd_run, "check_dashboard", check_dashboard_mock) + monkeypatch.setattr(cmd_run.asyncio, "get_running_loop", lambda: FakeLoop()) + monkeypatch.setattr( + cmd_run.signal, "getsignal", lambda signum: previous_handlers[signum] + ) + monkeypatch.setattr(cmd_run.signal, "signal", signal_restore_mock) + monkeypatch.setattr(core_module.LogManager, "set_queue_handler", MagicMock()) + + awaitable = asyncio.create_task(cmd_run.run_astrbot(Path("/tmp/astrbot-root"))) + await started.wait() + + pending_signals[signal.SIGTERM]() + + await awaitable + + check_dashboard_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_run_astrbot_cancels_runner_when_parent_is_cancelled(monkeypatch): + check_dashboard_mock = AsyncMock(return_value=None) + signal_restore_mock = MagicMock() + previous_handlers = { + signal.SIGINT: object(), + signal.SIGTERM: object(), + } + + class FakeLoop: + def add_signal_handler(self, _signum, _callback, *_args): + _ = (_signum, _callback, _args) + + def remove_signal_handler(self, _signum): + _ = _signum + return True + + started = asyncio.Event() + cancelled = asyncio.Event() + shutdown_complete = asyncio.Event() + + class FakeLoader: + def __init__(self, db, log_broker): + self.db = db + self.log_broker = log_broker + + async def start(self): + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancelled.set() + raise + finally: + await asyncio.sleep(0) + shutdown_complete.set() + + monkeypatch.setattr(initial_loader_module, "InitialLoader", FakeLoader) + monkeypatch.setattr(cmd_run, "check_dashboard", check_dashboard_mock) + monkeypatch.setattr(cmd_run.asyncio, "get_running_loop", lambda: FakeLoop()) + monkeypatch.setattr( + cmd_run.signal, "getsignal", lambda signum: previous_handlers[signum] + ) + monkeypatch.setattr(cmd_run.signal, "signal", signal_restore_mock) + monkeypatch.setattr(core_module.LogManager, "set_queue_handler", MagicMock()) + + awaitable = asyncio.create_task(cmd_run.run_astrbot(Path("/tmp/astrbot-root"))) + await started.wait() + + awaitable.cancel() + + with pytest.raises(asyncio.CancelledError): + await awaitable + + assert cancelled.is_set() + assert shutdown_complete.is_set() + check_dashboard_mock.assert_awaited_once() + assert signal_restore_mock.call_count == 2 + + +def test_install_shutdown_signal_handlers_falls_back_and_restores(monkeypatch): + restored_handlers: list[tuple[signal.Signals, Any]] = [] + installed_handlers: dict[signal.Signals, Callable[[int, object], object]] = {} + previous_handlers = { + signal.SIGINT: object(), + signal.SIGTERM: object(), + } + callback = MagicMock() + scheduled_callbacks: list[tuple[Callable[..., object], tuple[object, ...]]] = [] + + class FakeLoop: + def add_signal_handler(self, _signum, _callback, *_args): + _ = (_signum, _callback, _args) + raise NotImplementedError + + def remove_signal_handler(self, _signum): + _ = _signum + raise NotImplementedError + + def is_closed(self): + return False + + def call_soon_threadsafe(self, callback, *args): + scheduled_callbacks.append((callback, args)) + + def fake_signal(signum: signal.Signals, handler: Any) -> object: + if callable(handler): + installed_handlers[signum] = cast(Callable[[int, object], object], handler) + else: + restored_handlers.append((signum, handler)) + return previous_handlers[signum] + + monkeypatch.setattr( + cmd_run.signal, "getsignal", lambda signum: previous_handlers[signum] + ) + monkeypatch.setattr(cmd_run.signal, "signal", fake_signal) + + cleanup = cmd_run._install_shutdown_signal_handlers(cast(Any, FakeLoop()), callback) + + installed_handlers[signal.SIGTERM](signal.SIGTERM, None) + callback.assert_not_called() + assert scheduled_callbacks == [(callback, (signal.SIGTERM,))] + + scheduled_callback, scheduled_args = scheduled_callbacks.pop() + scheduled_callback(*scheduled_args) + callback.assert_called_once_with(signal.SIGTERM) + + cleanup() + + assert restored_handlers == [ + (signal.SIGINT, previous_handlers[signal.SIGINT]), + (signal.SIGTERM, previous_handlers[signal.SIGTERM]), + ] + + +def test_fallback_signal_handler_ignores_closed_loop(monkeypatch): + installed_handlers: dict[signal.Signals, Callable[[int, object], object]] = {} + previous_handlers = { + signal.SIGINT: object(), + signal.SIGTERM: object(), + } + callback = MagicMock() + + class FakeLoop: + def add_signal_handler(self, _signum, _callback, *_args): + _ = (_signum, _callback, _args) + raise NotImplementedError + + def remove_signal_handler(self, _signum): + _ = _signum + raise NotImplementedError + + def is_closed(self): + return True + + def call_soon_threadsafe(self, _callback, *_args): + raise AssertionError("closed loop should not schedule callbacks") + + def fake_signal(signum: signal.Signals, handler: Any) -> object: + if callable(handler): + installed_handlers[signum] = cast(Callable[[int, object], object], handler) + return previous_handlers[signum] + + monkeypatch.setattr( + cmd_run.signal, "getsignal", lambda signum: previous_handlers[signum] + ) + monkeypatch.setattr(cmd_run.signal, "signal", fake_signal) + + cleanup = cmd_run._install_shutdown_signal_handlers(cast(Any, FakeLoop()), callback) + + installed_handlers[signal.SIGTERM](signal.SIGTERM, None) + + callback.assert_not_called() + cleanup() + + +def test_fallback_signal_handler_ignores_call_soon_runtime_error(monkeypatch): + installed_handlers: dict[signal.Signals, Callable[[int, object], object]] = {} + previous_handlers = { + signal.SIGINT: object(), + signal.SIGTERM: object(), + } + callback = MagicMock() + + class FakeLoop: + def add_signal_handler(self, _signum, _callback, *_args): + _ = (_signum, _callback, _args) + raise NotImplementedError + + def remove_signal_handler(self, _signum): + _ = _signum + raise NotImplementedError + + def is_closed(self): + return False + + def call_soon_threadsafe(self, _callback, *_args): + raise RuntimeError("event loop is closing") + + def fake_signal(signum: signal.Signals, handler: Any) -> object: + if callable(handler): + installed_handlers[signum] = cast(Callable[[int, object], object], handler) + return previous_handlers[signum] + + monkeypatch.setattr( + cmd_run.signal, "getsignal", lambda signum: previous_handlers[signum] + ) + monkeypatch.setattr(cmd_run.signal, "signal", fake_signal) + + cleanup = cmd_run._install_shutdown_signal_handlers(cast(Any, FakeLoop()), callback) + + installed_handlers[signal.SIGTERM](signal.SIGTERM, None) + + callback.assert_not_called() + cleanup() + + +def test_install_shutdown_signal_handlers_skips_unavailable_signal_api(monkeypatch): + callback = MagicMock() + + class FakeLoop: + def add_signal_handler(self, _signum, _callback, *_args): + _ = (_signum, _callback, _args) + raise ValueError("signal only works in main thread") + + def remove_signal_handler(self, _signum): + raise AssertionError("no handlers should be installed") + + monkeypatch.setattr( + cmd_run.signal, + "getsignal", + MagicMock(side_effect=ValueError("signal only works in main thread")), + ) + monkeypatch.setattr( + cmd_run.signal, + "signal", + MagicMock(side_effect=ValueError("signal only works in main thread")), + ) + + cleanup = cmd_run._install_shutdown_signal_handlers(cast(Any, FakeLoop()), callback) + + cleanup() + callback.assert_not_called() + + +def test_cleanup_signal_handlers_skips_none_previous_handler(monkeypatch): + restored_handlers: list[tuple[signal.Signals, Any]] = [] + callback = MagicMock() + + class FakeLoop: + def add_signal_handler(self, _signum, _callback, *_args): + _ = (_signum, _callback, _args) + + def remove_signal_handler(self, _signum): + _ = _signum + return True + + def fake_signal(signum: signal.Signals, handler: Any) -> object: + restored_handlers.append((signum, handler)) + return object() + + monkeypatch.setattr(cmd_run.signal, "getsignal", lambda _signum: None) + monkeypatch.setattr(cmd_run.signal, "signal", fake_signal) + + cleanup = cmd_run._install_shutdown_signal_handlers(cast(Any, FakeLoop()), callback) + + cleanup() + assert restored_handlers == [] + + +@pytest.mark.asyncio +async def test_initial_loader_shutdowns_logs_on_initialize_failure(monkeypatch): + shutdown_mock = AsyncMock(return_value=None) + lifecycle_instances: list[MagicMock] = [] + + class FakeLifecycle: + def __init__(self, log_broker, db): + _ = (log_broker, db) + lifecycle = MagicMock() + lifecycle.initialize = AsyncMock(side_effect=RuntimeError("boom")) + lifecycle.stop = AsyncMock() + lifecycle.start = AsyncMock() + lifecycle.dashboard_shutdown_event = asyncio.Event() + lifecycle.astrbot_config = MagicMock() + lifecycle_instances.append(lifecycle) + self.__dict__.update(lifecycle.__dict__) + + monkeypatch.setattr(initial_loader_module, "AstrBotCoreLifecycle", FakeLifecycle) + monkeypatch.setattr(initial_loader_module.LogManager, "shutdown", shutdown_mock) + + loader = initial_loader_module.InitialLoader(MagicMock(), MagicMock()) + + await loader.start() + + assert len(lifecycle_instances) == 1 + lifecycle_instances[0].stop.assert_not_awaited() + shutdown_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_initial_loader_handles_cancellation_during_initialize(monkeypatch): + shutdown_mock = AsyncMock(return_value=None) + initialize_started = asyncio.Event() + lifecycle_instances: list[MagicMock] = [] + + class FakeLifecycle: + def __init__(self, log_broker, db): + _ = (log_broker, db) + lifecycle = MagicMock() + + async def initialize(): + initialize_started.set() + await asyncio.Event().wait() + + lifecycle.initialize = AsyncMock(side_effect=initialize) + lifecycle.stop = AsyncMock() + lifecycle.start = AsyncMock() + lifecycle.dashboard_shutdown_event = asyncio.Event() + lifecycle.astrbot_config = MagicMock() + lifecycle_instances.append(lifecycle) + self.__dict__.update(lifecycle.__dict__) + + monkeypatch.setattr(initial_loader_module, "AstrBotCoreLifecycle", FakeLifecycle) + monkeypatch.setattr(initial_loader_module.LogManager, "shutdown", shutdown_mock) + + loader = initial_loader_module.InitialLoader(MagicMock(), MagicMock()) + task = asyncio.create_task(loader.start()) + await initialize_started.wait() + + task.cancel() + await task + + assert len(lifecycle_instances) == 1 + lifecycle_instances[0].stop.assert_not_awaited() + shutdown_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_initial_loader_stops_core_on_runtime_exception(monkeypatch): + shutdown_mock = AsyncMock(return_value=None) + lifecycle_instances: list[MagicMock] = [] + + class FakeLifecycle: + def __init__(self, log_broker, db): + _ = (log_broker, db) + lifecycle = MagicMock() + lifecycle.initialize = AsyncMock(return_value=None) + lifecycle.stop = AsyncMock() + lifecycle.start = AsyncMock(side_effect=RuntimeError("run boom")) + lifecycle.dashboard_shutdown_event = asyncio.Event() + lifecycle.astrbot_config = MagicMock() + lifecycle_instances.append(lifecycle) + self.__dict__.update(lifecycle.__dict__) + + class FakeDashboard: + def __init__(self, *args, **kwargs): + _ = (args, kwargs) + + def run(self): + return None + + monkeypatch.setattr(initial_loader_module, "AstrBotCoreLifecycle", FakeLifecycle) + monkeypatch.setattr(initial_loader_module, "AstrBotDashboard", FakeDashboard) + monkeypatch.setattr(initial_loader_module.LogManager, "shutdown", shutdown_mock) + + loader = initial_loader_module.InitialLoader(MagicMock(), MagicMock()) + + with pytest.raises(RuntimeError, match="run boom"): + await loader.start() + + assert len(lifecycle_instances) == 1 + lifecycle_instances[0].stop.assert_awaited_once() + shutdown_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_initial_loader_preserves_runtime_error_if_stop_fails(monkeypatch): + shutdown_mock = AsyncMock(return_value=None) + lifecycle_instances: list[MagicMock] = [] + + class FakeLifecycle: + def __init__(self, log_broker, db): + _ = (log_broker, db) + lifecycle = MagicMock() + lifecycle.initialize = AsyncMock(return_value=None) + lifecycle.stop = AsyncMock(side_effect=RuntimeError("stop boom")) + lifecycle.start = AsyncMock(side_effect=RuntimeError("run boom")) + lifecycle.dashboard_shutdown_event = asyncio.Event() + lifecycle.astrbot_config = MagicMock() + lifecycle_instances.append(lifecycle) + self.__dict__.update(lifecycle.__dict__) + + class FakeDashboard: + def __init__(self, *args, **kwargs): + _ = (args, kwargs) + + def run(self): + return None + + monkeypatch.setattr(initial_loader_module, "AstrBotCoreLifecycle", FakeLifecycle) + monkeypatch.setattr(initial_loader_module, "AstrBotDashboard", FakeDashboard) + monkeypatch.setattr(initial_loader_module.LogManager, "shutdown", shutdown_mock) + + loader = initial_loader_module.InitialLoader(MagicMock(), MagicMock()) + + with pytest.raises(RuntimeError, match="run boom"): + await loader.start() + + assert len(lifecycle_instances) == 1 + lifecycle_instances[0].stop.assert_awaited_once() + shutdown_mock.assert_awaited_once() diff --git a/tests/unit/test_log_manager_shutdown.py b/tests/unit/test_log_manager_shutdown.py new file mode 100644 index 0000000000..f493d3ea37 --- /dev/null +++ b/tests/unit/test_log_manager_shutdown.py @@ -0,0 +1,42 @@ +from unittest.mock import AsyncMock, MagicMock, call + +import pytest + +import astrbot.core.log as log_module +from astrbot.core.log import LogManager + + +@pytest.mark.asyncio +async def test_shutdown_completes_and_removes_queued_file_sinks(monkeypatch): + fake_loguru = MagicMock() + fake_loguru.complete = AsyncMock(return_value=None) + fake_loguru.remove = MagicMock() + monkeypatch.setattr(log_module, "_loguru", fake_loguru) + + original_state = ( + LogManager._trace_sink_id, + LogManager._file_sink_id, + LogManager._console_sink_id, + LogManager._configured, + ) + LogManager._trace_sink_id = 22 + LogManager._file_sink_id = 11 + LogManager._console_sink_id = 33 + LogManager._configured = True + + try: + await LogManager.shutdown() + + fake_loguru.complete.assert_awaited_once() + assert fake_loguru.remove.call_args_list == [call(22), call(11)] + assert LogManager._trace_sink_id is None + assert LogManager._file_sink_id is None + assert LogManager._console_sink_id == 33 + assert LogManager._configured is False + finally: + ( + LogManager._trace_sink_id, + LogManager._file_sink_id, + LogManager._console_sink_id, + LogManager._configured, + ) = original_state