diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index b1a8156a45..5d2c03652f 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -3,6 +3,7 @@ import threading import time import uuid +from concurrent.futures import CancelledError as FutureCancelledError from pathlib import Path from typing import Literal, NoReturn, cast @@ -34,6 +35,18 @@ from ...register import register_platform_adapter from .dingtalk_event import DingtalkMessageEvent +DINGTALK_RECONNECT_INITIAL_DELAY = 10 +DINGTALK_RECONNECT_MAX_DELAY = 300 +DINGTALK_RECONNECT_STABLE_SECONDS = 300 + + +def _dingtalk_reconnect_delay(retry_count: int) -> int: + safe_retry_count = max(retry_count, 1) + return min( + DINGTALK_RECONNECT_INITIAL_DELAY * 2 ** (safe_retry_count - 1), + DINGTALK_RECONNECT_MAX_DELAY, + ) + class MyEventHandler(dingtalk_stream.EventHandler): async def process(self, event: dingtalk_stream.EventMessage): @@ -83,7 +96,8 @@ async def process(self, message: dingtalk_stream.CallbackMessage): self.client, ) self.client_ = client # 用于 websockets 的 client - self._shutdown_event: threading.Event | None = None + self._shutdown_event = threading.Event() + self._terminated_event = threading.Event() def _id_to_sid(self, dingtalk_id: str | None) -> str: if not dingtalk_id: @@ -750,56 +764,70 @@ async def run(self) -> None: # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 # SDK 内部已有 while True 重连循环,但需要监控 task 状态, # 如果 task 意外退出则重新启动。 - MAX_RETRIES = 5 - RETRY_INTERVAL = 10 def start_client(loop: asyncio.AbstractEventLoop) -> None: retry_count = 0 - def handle_retry(error_msg: str) -> bool: - """处理重试逻辑,返回 True 表示需要继续重试,False 表示放弃。""" + def handle_retry(error_msg: str, run_seconds: float) -> None: nonlocal retry_count logger.error(error_msg) + if run_seconds >= DINGTALK_RECONNECT_STABLE_SECONDS: + retry_count = 0 retry_count += 1 - if retry_count < 
MAX_RETRIES: - logger.info(f"钉钉适配器尝试重连 ({retry_count}/{MAX_RETRIES})...") - time.sleep(RETRY_INTERVAL) - return True - logger.error("钉钉适配器重连失败,已达最大重试次数") - return False - - while retry_count < MAX_RETRIES: + delay = _dingtalk_reconnect_delay(retry_count) + logger.info( + f"钉钉适配器将在 {delay} 秒后重连 (第 {retry_count} 次)...", + ) + self._terminated_event.wait(delay) + + while not self._terminated_event.is_set(): task = None + should_cancel_task = False + start_time = time.monotonic() try: - self._shutdown_event = threading.Event() - task = loop.create_task(self.client_.start()) + self._shutdown_event.clear() + if self._terminated_event.is_set(): + return + task = asyncio.run_coroutine_threadsafe(self.client_.start(), loop) # 当 task 完成时唤醒线程(无论是正常退出还是异常退出) task.add_done_callback(lambda _: self._shutdown_event.set()) + if self._terminated_event.is_set(): + should_cancel_task = True + self._shutdown_event.set() self._shutdown_event.wait() + if self._terminated_event.is_set(): + return if task.done(): try: exc = task.exception() - except asyncio.CancelledError: + except (asyncio.CancelledError, FutureCancelledError): logger.info("钉钉适配器 task 已取消") return if exc: if "Graceful shutdown" in str(exc): logger.info("钉钉适配器已被关闭") return - if handle_retry(f"钉钉 SDK task 异常退出: {exc}"): - continue - return + should_cancel_task = True + handle_retry( + f"钉钉 SDK task 异常退出: {exc}", + time.monotonic() - start_time, + ) + continue # task 仍在运行,shutdown_event 被设置(正常关闭) return except Exception as e: if "Graceful shutdown" in str(e): logger.info("钉钉适配器已被关闭") return - if not handle_retry(f"钉钉机器人启动失败: {e}"): - return + should_cancel_task = True + handle_retry( + f"钉钉机器人启动失败: {e}", + time.monotonic() - start_time, + ) + continue finally: # 仅在重试/失败路径取消 task,正常关闭不取消 - if task is not None and not task.done() and retry_count > 0: + if task is not None and not task.done() and should_cancel_task: task.cancel() loop = asyncio.get_running_loop() @@ -809,11 +837,11 @@ async def terminate(self) -> None: def 
"""Tests for the DingTalk adapter's reconnect back-off and shutdown wake-up."""

import asyncio
import threading

import pytest

from astrbot.core.platform.sources.dingtalk import dingtalk_adapter
from astrbot.core.platform.sources.dingtalk.dingtalk_adapter import (
    DINGTALK_RECONNECT_INITIAL_DELAY,
    DINGTALK_RECONNECT_MAX_DELAY,
    DingtalkPlatformAdapter,
    _dingtalk_reconnect_delay,
)


def test_dingtalk_reconnect_delay_uses_exponential_backoff():
    """The delay doubles with every consecutive retry."""
    assert [_dingtalk_reconnect_delay(i) for i in range(1, 5)] == [
        10,
        20,
        40,
        80,
    ]


def test_dingtalk_reconnect_delay_has_minimum_delay():
    """Retry counts below 1 fall back to the initial delay."""
    assert _dingtalk_reconnect_delay(0) == DINGTALK_RECONNECT_INITIAL_DELAY
    assert _dingtalk_reconnect_delay(-1) == DINGTALK_RECONNECT_INITIAL_DELAY


def test_dingtalk_reconnect_delay_is_capped():
    """A large retry count yields exactly the maximum delay."""
    assert _dingtalk_reconnect_delay(20) == DINGTALK_RECONNECT_MAX_DELAY


def test_dingtalk_reconnect_delay_is_bounded_and_monotonic():
    """Across a long retry streak the delay never shrinks or exceeds the cap."""
    delays = [_dingtalk_reconnect_delay(i) for i in range(1, 65)]
    assert delays[0] == DINGTALK_RECONNECT_INITIAL_DELAY
    assert delays[-1] == DINGTALK_RECONNECT_MAX_DELAY
    assert all(delay <= DINGTALK_RECONNECT_MAX_DELAY for delay in delays)
    assert delays == sorted(delays)


@pytest.mark.asyncio
async def test_dingtalk_reconnect_delay_wakes_on_terminate(monkeypatch):
    """terminate() must interrupt a pending reconnect delay so run() returns."""

    class ObservedEvent:
        """Event stand-in that records when, and with which timeout, wait() runs."""

        def __init__(self) -> None:
            self._event = threading.Event()
            # Set as soon as wait() is entered, so the test can synchronise on it.
            self.wait_started = threading.Event()
            # Timeout passed to the most recent wait() call.
            self.wait_timeout: float | None = None

        def is_set(self) -> bool:
            return self._event.is_set()

        def set(self) -> None:
            self._event.set()

        def wait(self, timeout: float | None = None) -> bool:
            self.wait_timeout = timeout
            self.wait_started.set()
            return self._event.wait(timeout)

    class FailingClient:
        """Client whose start() always fails, pushing the adapter into a retry."""

        websocket = None

        async def start(self) -> None:
            raise RuntimeError("connect failed")

    terminated_event = ObservedEvent()
    # Bypass __init__ and set only the attributes this test provides.
    adapter = DingtalkPlatformAdapter.__new__(DingtalkPlatformAdapter)
    adapter.client_ = FailingClient()
    adapter._shutdown_event = threading.Event()
    adapter._terminated_event = terminated_event

    # Use a delay far longer than the test timeouts: run() can only finish in
    # time if terminate() wakes the waiting thread.
    monkeypatch.setattr(dingtalk_adapter, "_dingtalk_reconnect_delay", lambda _: 60)

    run_task = asyncio.create_task(adapter.run())
    try:
        # Wait off the event loop so run() can keep making progress.
        wait_started = await asyncio.to_thread(terminated_event.wait_started.wait, 1)
        assert wait_started
        assert terminated_event.wait_timeout == 60

        await adapter.terminate()
        await asyncio.wait_for(run_task, timeout=1)
    finally:
        # Clean up even when an assertion failed before terminate() was reached.
        if not run_task.done():
            await adapter.terminate()
            run_task.cancel()
            await asyncio.gather(run_task, return_exceptions=True)