Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 52 additions & 24 deletions astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import threading
import time
import uuid
from concurrent.futures import CancelledError as FutureCancelledError
from pathlib import Path
from typing import Literal, NoReturn, cast

Expand Down Expand Up @@ -34,6 +35,18 @@
from ...register import register_platform_adapter
from .dingtalk_event import DingtalkMessageEvent

DINGTALK_RECONNECT_INITIAL_DELAY = 10
DINGTALK_RECONNECT_MAX_DELAY = 300
DINGTALK_RECONNECT_STABLE_SECONDS = 300


def _dingtalk_reconnect_delay(retry_count: int) -> int:
safe_retry_count = max(retry_count, 1)
return min(
DINGTALK_RECONNECT_INITIAL_DELAY * 2 ** (safe_retry_count - 1),
DINGTALK_RECONNECT_MAX_DELAY,
)


class MyEventHandler(dingtalk_stream.EventHandler):
async def process(self, event: dingtalk_stream.EventMessage):
Expand Down Expand Up @@ -83,7 +96,8 @@ async def process(self, message: dingtalk_stream.CallbackMessage):
self.client,
)
self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
self._shutdown_event = threading.Event()
self._terminated_event = threading.Event()

def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id:
Expand Down Expand Up @@ -750,56 +764,70 @@ async def run(self) -> None:
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
# SDK 内部已有 while True 重连循环,但需要监控 task 状态,
# 如果 task 意外退出则重新启动。
MAX_RETRIES = 5
RETRY_INTERVAL = 10

def start_client(loop: asyncio.AbstractEventLoop) -> None:
retry_count = 0

def handle_retry(error_msg: str) -> bool:
"""处理重试逻辑,返回 True 表示需要继续重试,False 表示放弃。"""
def handle_retry(error_msg: str, run_seconds: float) -> None:
nonlocal retry_count
logger.error(error_msg)
if run_seconds >= DINGTALK_RECONNECT_STABLE_SECONDS:
retry_count = 0
retry_count += 1
if retry_count < MAX_RETRIES:
logger.info(f"钉钉适配器尝试重连 ({retry_count}/{MAX_RETRIES})...")
time.sleep(RETRY_INTERVAL)
return True
logger.error("钉钉适配器重连失败,已达最大重试次数")
return False

while retry_count < MAX_RETRIES:
delay = _dingtalk_reconnect_delay(retry_count)
logger.info(
f"钉钉适配器将在 {delay} 秒后重连 (第 {retry_count} 次)...",
)
self._terminated_event.wait(delay)

while not self._terminated_event.is_set():
task = None
Comment on lines +771 to 784
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using time.sleep(delay) inside the background thread can block the graceful shutdown of the application for up to 300 seconds (5 minutes) if a reconnect delay is active when terminate() is called. Additionally, re-creating self._shutdown_event = threading.Event() in every iteration of the while True loop can lead to race conditions where terminate() sets an old event, causing the new event to wait indefinitely.

We can resolve both issues elegantly by monkey-patching self.terminate to set a thread-safe terminated event, and using terminated.wait(delay) instead of time.sleep(delay). This allows the background thread to wake up and exit immediately upon shutdown.

            def handle_retry(error_msg: str, run_seconds: float) -> None:
                nonlocal retry_count
                logger.error(error_msg)
                if run_seconds >= DINGTALK_RECONNECT_STABLE_SECONDS:
                    retry_count = 0
                retry_count += 1
                delay = _dingtalk_reconnect_delay(retry_count)
                logger.info(
                    f"钉钉适配器将在 {delay} 秒后重连 (第 {retry_count} 次)...",
                )
                terminated.wait(delay)

            terminated = threading.Event()
            original_terminate = self.terminate
            async def new_terminate():
                terminated.set()
                await original_terminate()
            self.terminate = new_terminate

            while not terminated.is_set():
                task = None

should_cancel_task = False
start_time = time.monotonic()
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
self._shutdown_event.clear()
if self._terminated_event.is_set():
return
task = asyncio.run_coroutine_threadsafe(self.client_.start(), loop)
# 当 task 完成时唤醒线程(无论是正常退出还是异常退出)
task.add_done_callback(lambda _: self._shutdown_event.set())
if self._terminated_event.is_set():
should_cancel_task = True
self._shutdown_event.set()
self._shutdown_event.wait()
if self._terminated_event.is_set():
return
if task.done():
try:
exc = task.exception()
except asyncio.CancelledError:
except (asyncio.CancelledError, FutureCancelledError):
logger.info("钉钉适配器 task 已取消")
return
if exc:
if "Graceful shutdown" in str(exc):
logger.info("钉钉适配器已被关闭")
return
if handle_retry(f"钉钉 SDK task 异常退出: {exc}"):
continue
return
should_cancel_task = True
handle_retry(
f"钉钉 SDK task 异常退出: {exc}",
time.monotonic() - start_time,
)
continue
# task 仍在运行,shutdown_event 被设置(正常关闭)
return
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被关闭")
return
if not handle_retry(f"钉钉机器人启动失败: {e}"):
return
should_cancel_task = True
handle_retry(
f"钉钉机器人启动失败: {e}",
time.monotonic() - start_time,
)
continue
finally:
# 仅在重试/失败路径取消 task,正常关闭不取消
if task is not None and not task.done() and retry_count > 0:
if task is not None and not task.done() and should_cancel_task:
task.cancel()

loop = asyncio.get_running_loop()
Expand All @@ -809,11 +837,11 @@ async def terminate(self) -> None:
def monkey_patch_close() -> NoReturn:
raise KeyboardInterrupt("Graceful shutdown")

self._terminated_event.set()
self._shutdown_event.set()
if self.client_.websocket is not None:
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set()

def get_client(self):
return self.client
78 changes: 78 additions & 0 deletions tests/test_dingtalk_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import asyncio
import threading

import pytest

from astrbot.core.platform.sources.dingtalk import dingtalk_adapter
from astrbot.core.platform.sources.dingtalk.dingtalk_adapter import (
DINGTALK_RECONNECT_INITIAL_DELAY,
DINGTALK_RECONNECT_MAX_DELAY,
DingtalkPlatformAdapter,
_dingtalk_reconnect_delay,
)


def test_dingtalk_reconnect_delay_uses_exponential_backoff():
assert [_dingtalk_reconnect_delay(i) for i in range(1, 5)] == [
10,
20,
40,
80,
]


def test_dingtalk_reconnect_delay_has_minimum_delay():
assert _dingtalk_reconnect_delay(0) == DINGTALK_RECONNECT_INITIAL_DELAY
assert _dingtalk_reconnect_delay(-1) == DINGTALK_RECONNECT_INITIAL_DELAY


def test_dingtalk_reconnect_delay_is_capped():
assert _dingtalk_reconnect_delay(20) == DINGTALK_RECONNECT_MAX_DELAY


@pytest.mark.asyncio
async def test_dingtalk_reconnect_delay_wakes_on_terminate(monkeypatch):
class ObservedEvent:
def __init__(self) -> None:
self._event = threading.Event()
self.wait_started = threading.Event()
self.wait_timeout: float | None = None

def is_set(self) -> bool:
return self._event.is_set()

def set(self) -> None:
self._event.set()

def wait(self, timeout: float | None = None) -> bool:
self.wait_timeout = timeout
self.wait_started.set()
return self._event.wait(timeout)

class FailingClient:
websocket = None

async def start(self) -> None:
raise RuntimeError("connect failed")

terminated_event = ObservedEvent()
adapter = DingtalkPlatformAdapter.__new__(DingtalkPlatformAdapter)
adapter.client_ = FailingClient()
adapter._shutdown_event = threading.Event()
adapter._terminated_event = terminated_event

monkeypatch.setattr(dingtalk_adapter, "_dingtalk_reconnect_delay", lambda _: 60)

run_task = asyncio.create_task(adapter.run())
try:
wait_started = await asyncio.to_thread(terminated_event.wait_started.wait, 1)
assert wait_started
assert terminated_event.wait_timeout == 60

await adapter.terminate()
await asyncio.wait_for(run_task, timeout=1)
finally:
if not run_task.done():
await adapter.terminate()
run_task.cancel()
await asyncio.gather(run_task, return_exceptions=True)
Loading