diff --git a/astrbot/builtin_stars/astrbot/group_chat_context.py b/astrbot/builtin_stars/astrbot/group_chat_context.py index 7fee3c0df9..22880acd5d 100644 --- a/astrbot/builtin_stars/astrbot/group_chat_context.py +++ b/astrbot/builtin_stars/astrbot/group_chat_context.py @@ -12,6 +12,10 @@ from astrbot.api.provider import Provider, ProviderRequest from astrbot.core.agent.message import TextPart from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.utils.image_caption_cache import ( + image_caption_cache, + resolve_image_caption_cache_ttl, +) """ Group chat context awareness. @@ -67,6 +71,9 @@ def cfg(self, event: AstrMessageEvent): "image_caption": image_caption, "image_caption_prompt": image_caption_prompt, "image_caption_provider_id": image_caption_provider_id, + "image_caption_cache_ttl": resolve_image_caption_cache_ttl( + cfg.get("provider_settings", {}) + ), "enable_active_reply": enable_active_reply, "ar_method": ar_method, "ar_possibility": ar_possibility, @@ -79,17 +86,46 @@ async def get_image_caption( image_url: str, image_caption_provider_id: str, image_caption_prompt: str, + cache_ttl: int = 0, ) -> str: if not image_caption_provider_id: provider = self.context.get_using_provider() else: provider = self.context.get_provider_by_id(image_caption_provider_id) if not provider: - raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商") + raise Exception( + f"Provider `{image_caption_provider_id}` was not found." + ) + if not isinstance(provider, Provider): - raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述") - response = await provider.text_chat( + raise Exception( + f"Provider type is invalid for image captioning: {type(provider)}." + ) + provider_id = _resolve_provider_cache_identity( + provider, + configured_provider_id=image_caption_provider_id, + ) + + return await image_caption_cache.get_or_create( + provider_id=provider_id, prompt=image_caption_prompt, + image_urls=[image_url], + ttl_seconds=cache_ttl, + caption_factory=lambda: self._fetch_image_caption( + provider, + image_caption_prompt, + image_url, + ), + ) + + async def _fetch_image_caption( + self, + provider: Provider, + prompt: str, + image_url: str, + ) -> str: + response = await provider.text_chat( + prompt=prompt, session_id=uuid.uuid4().hex, image_urls=[image_url], persist=False, @@ -195,15 +231,16 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str: try: url = comp.url if comp.url else comp.file if not url: - raise Exception("图片 URL 为空") + raise Exception("Image URL is empty.") caption = await self.get_image_caption( url, cfg["image_caption_provider_id"], cfg["image_caption_prompt"], + cfg["image_caption_cache_ttl"], ) parts.append(f" [Image: {caption}]") except Exception as e: - logger.error(f"获取图片描述失败: {e}") + logger.error(f"Failed to get image caption: {e}") else: parts.append(" [Image]") elif isinstance(comp, At): @@ -212,7 +249,7 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str: "all", ) if is_at_self: - parts.insert(1, "⚠️[DIRECTED AT YOU] ") + parts.insert(1, "[DIRECTED AT YOU] ") parts.append(f" [At: {comp.name}]") return "".join(parts) @@ -239,3 +276,27 @@ def _trim_left( def _format_group_history_block(records: list[str]) -> str: return GROUP_HISTORY_HEADER + "\n".join(records) + GROUP_HISTORY_FOOTER + + +def _resolve_provider_cache_identity( + provider: Provider, + configured_provider_id: str, +) -> str: + if configured_provider_id: + return configured_provider_id + + provider_config = provider.provider_config or {} + provider_id = provider_config.get("id", "") + if isinstance(provider_id, str) and provider_id: + return provider_id + + provider_type = provider_config.get("type", "") + model = provider.get_model() + return ":".join( + [ + provider.__class__.__module__, + provider.__class__.__qualname__, + "" if provider_type is None else str(provider_type), + "" if model is None else str(model), + ] + ) diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 1c4fd400a0..542d3ec5c6 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -96,6 +96,10 @@ get_astrbot_workspaces_path, ) from astrbot.core.utils.file_extract import extract_file_moonshotai +from astrbot.core.utils.image_caption_cache import ( + image_caption_cache, + resolve_image_caption_cache_ttl, +) from astrbot.core.utils.llm_metadata import LLM_METADATAS from astrbot.core.utils.media_utils import ( IMAGE_COMPRESS_DEFAULT_MAX_SIZE, @@ -578,11 +582,41 @@ async def _ensure_persona_and_skills( pass +async def _request_img_caption_with_provider( + prov: Provider, + provider_id: str, + image_urls: list[str], + prompt: str, + cache_ttl: int | None = None, +) -> str: + if cache_ttl is None: + cache_ttl = resolve_image_caption_cache_ttl( + prov.provider_config if isinstance(prov.provider_config, dict) else None + ) + logger.debug("Processing image caption with provider: %s", provider_id) + + async def _caption_factory() -> str: + llm_resp = await prov.text_chat( + prompt=prompt, + image_urls=image_urls, + ) + return llm_resp.completion_text + + return await image_caption_cache.get_or_create( + provider_id=provider_id, + prompt=prompt, + image_urls=image_urls, + ttl_seconds=cache_ttl, + caption_factory=_caption_factory, + ) + + async def _request_img_caption( provider_id: str, cfg: dict, image_urls: list[str], plugin_context: Context, + prompt: str | None = None, ) -> str: prov = plugin_context.get_provider_by_id(provider_id) if prov is None: @@ -594,16 +628,18 @@ async def _request_img_caption( f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.", ) - img_cap_prompt = cfg.get( + img_cap_prompt = prompt or cfg.get( "image_caption_prompt", "Please describe the image.", ) - logger.debug("Processing image caption with provider: %s", provider_id) - llm_resp = await prov.text_chat( - prompt=img_cap_prompt, + cache_ttl = resolve_image_caption_cache_ttl(cfg) + return await _request_img_caption_with_provider( + prov=prov, + provider_id=provider_id, image_urls=image_urls, + prompt=img_cap_prompt, + cache_ttl=cache_ttl, ) - return llm_resp.completion_text async def _ensure_img_caption( @@ -808,13 +844,23 @@ async def _process_quote_message( ) if path and _is_generated_compressed_image_path(path, compress_path): event.track_temporary_local_file(compress_path) - llm_resp = await prov.text_chat( - prompt="Please describe the image content.", + provider_config = ( + prov.provider_config + if isinstance(prov.provider_config, dict) + else {} + ) + caption = await _request_img_caption_with_provider( + prov=prov, + provider_id=provider_config.get("id", img_cap_prov_id or ""), image_urls=[compress_path], + prompt="Please describe the image content.", + cache_ttl=resolve_image_caption_cache_ttl( + config.provider_settings if config else None + ), ) - if llm_resp.completion_text: + if caption: content_parts.append( - f"[Image Caption in quoted message]: {llm_resp.completion_text}" + f"[Image Caption in quoted message]: {caption}" ) else: logger.warning("No provider found for image captioning in quote.") diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 22a53bb446..cbdf78a2e3 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -103,6 +103,7 @@ "fallback_chat_models": [], "default_image_caption_provider_id": "", "image_caption_prompt": "Please describe the image using Chinese.", + "image_caption_cache_ttl": 600, "provider_pool": ["*"], # "*" 表示使用所有可用的提供者 "wake_prefix": "", "web_search": False, @@ -3193,6 +3194,11 @@ "description": "图片转述提示词", "type": "text", }, + "provider_settings.image_caption_cache_ttl": { + "description": "图片转述缓存时长(秒)", + "type": "int", + "hint": "在缓存时间内再次收到相同图片时,直接复用已缓存的视觉识别结果;设为 0 表示禁用缓存", + }, }, "condition": { "provider_settings.enable": True, diff --git a/astrbot/core/utils/image_caption_cache.py b/astrbot/core/utils/image_caption_cache.py new file mode 100644 index 0000000000..4d160d2933 --- /dev/null +++ b/astrbot/core/utils/image_caption_cache.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import unquote, urlparse + +from astrbot.core import logger + +DEFAULT_IMAGE_CAPTION_CACHE_TTL = 600 + + +def resolve_image_caption_cache_ttl(config: dict | None) -> int: + raw = (config or {}).get( + "image_caption_cache_ttl", + DEFAULT_IMAGE_CAPTION_CACHE_TTL, + ) + if isinstance(raw, bool): + return DEFAULT_IMAGE_CAPTION_CACHE_TTL + try: + return max(int(raw), 0) + except (TypeError, ValueError): + return DEFAULT_IMAGE_CAPTION_CACHE_TTL + + +@dataclass(slots=True) +class _ImageCaptionCacheEntry: + caption: str + expires_at: float + + +class ImageCaptionCache: + def __init__(self) -> None: + self._entries: dict[str, _ImageCaptionCacheEntry] = {} + self._locks: dict[str, asyncio.Lock] = {} + + def clear(self) -> None: + self._entries.clear() + self._locks.clear() + + async def get_or_create( + self, + *, + provider_id: str, + prompt: str, + image_urls: list[str], + ttl_seconds: int, + caption_factory: Callable[[], Awaitable[str]], + ) -> str: + if ttl_seconds <= 0: + return await caption_factory() + + cache_key = await self._build_cache_key( + provider_id=provider_id, + prompt=prompt, + image_urls=image_urls, + ) + cached_caption = self._get(cache_key) + if cached_caption is not None: + logger.debug( + "Using cached image caption. provider=%s", + provider_id or "", + ) + return cached_caption + + lock = self._get_lock(cache_key) + async with lock: + cached_caption = self._get(cache_key) + if cached_caption is not None: + logger.debug( + "Using cached image caption after lock wait. provider=%s", + provider_id or "", + ) + return cached_caption + + caption = await caption_factory() + self._entries[cache_key] = _ImageCaptionCacheEntry( + caption=caption, + expires_at=time.monotonic() + ttl_seconds, + ) + self._cleanup_expired_entries() + return caption + + def _get(self, cache_key: str) -> str | None: + entry = self._entries.get(cache_key) + if entry is None: + return None + if entry.expires_at <= time.monotonic(): + self._entries.pop(cache_key, None) + return None + return entry.caption + + def _cleanup_expired_entries(self) -> None: + now = time.monotonic() + expired_keys = [ + key for key, entry in self._entries.items() if entry.expires_at <= now + ] + for key in expired_keys: + self._entries.pop(key, None) + self._locks.pop(key, None) + + def _get_lock(self, cache_key: str) -> asyncio.Lock: + lock = self._locks.get(cache_key) + if lock is None: + lock = asyncio.Lock() + self._locks[cache_key] = lock + return lock + + async def _build_cache_key( + self, + *, + provider_id: str, + prompt: str, + image_urls: list[str], + ) -> str: + image_fingerprints = [] + for image_url in image_urls: + image_fingerprints.append(await self._fingerprint_image(image_url)) + + joined = "\n".join([provider_id, prompt, *image_fingerprints]) + return hashlib.sha256(joined.encode("utf-8")).hexdigest() + + async def _fingerprint_image(self, image_url: str) -> str: + if image_url.startswith("base64://"): + return self._fingerprint_base64_image(image_url) + + if image_url.startswith("data:image"): + return self._fingerprint_data_uri_image(image_url) + + if image_url.startswith(("http://", "https://")): + return self._fingerprint_remote_image(image_url) + + return await self._fingerprint_local_image(image_url) + + def _fingerprint_base64_image(self, image_url: str) -> str: + raw_base64 = image_url.removeprefix("base64://") + try: + image_bytes = base64.b64decode(raw_base64) + except Exception: + return self._reference_fingerprint(image_url) + return self._hash_bytes(image_bytes) + + def _fingerprint_data_uri_image(self, image_url: str) -> str: + try: + _, encoded = image_url.split(",", 1) + image_bytes = base64.b64decode(encoded) + except Exception: + return self._reference_fingerprint(image_url) + return self._hash_bytes(image_bytes) + + def _fingerprint_remote_image(self, image_url: str) -> str: + return f"url:{image_url}" + + async def _fingerprint_local_image(self, image_url: str) -> str: + local_path = self._to_local_path(image_url) + if local_path and local_path.is_file(): + image_bytes = await asyncio.to_thread(local_path.read_bytes) + return self._hash_bytes(image_bytes) + + return self._reference_fingerprint(image_url) + + def _to_local_path(self, image_url: str) -> Path | None: + if image_url.startswith("file://"): + parsed = urlparse(image_url) + parsed_path = unquote(parsed.path) + if ( + parsed_path.startswith("/") + and len(parsed_path) >= 3 + and parsed_path[2] == ":" + ): + parsed_path = parsed_path[1:] + return Path(parsed_path) + + if image_url.startswith(("http://", "https://", "base64://", "data:image")): + return None + + return Path(image_url) + + def _hash_bytes(self, payload: bytes) -> str: + return hashlib.sha256(payload).hexdigest() + + def _reference_fingerprint(self, image_url: str) -> str: + return f"ref:{image_url}" + + +image_caption_cache = ImageCaptionCache() diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css index 8e5fd76cd1..be565ba238 100644 --- a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css +++ b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css @@ -1,4 +1,4 @@ -/* Auto-generated MDI subset – 271 icons */ +/* Auto-generated MDI subset – 272 icons */ /* Do not edit manually. Run: pnpm run subset-icons */ @font-face { @@ -464,6 +464,10 @@ content: "\F1036"; } +.mdi-file-search-outline::before { + content: "\F0C7D"; +} + .mdi-file-upload::before { content: "\F0A4D"; } diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff index 181ce7f861..8e57a70b4f 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff differ diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 index 931625c6d7..107a267095 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 differ diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 618b95bac4..fe45870099 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -49,6 +49,10 @@ "description": "Default Image Caption Model", "hint": "Leave empty to disable; useful for non-multimodal models" }, + "image_caption_cache_ttl": { + "description": "Image caption cache TTL (seconds)", + "hint": "Reuse the cached vision result when the same image is received again within this period; set to 0 to disable caching" + }, "image_caption_prompt": { "description": "Image Caption Prompt" } diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index c42a3313a5..bbe5720423 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -49,6 +49,10 @@ "description": "Модель описания изображений", "hint": "Оставьте пустым для отключения; полезно для моделей без поддержки мультимодальности" }, + "image_caption_cache_ttl": { + "description": "TTL кеша описания изображений (секунды)", + "hint": "При повторном получении одного и того же изображения в пределах этого времени будет использоваться кэшированный результат распознавания; установите 0 для отключения" + }, "image_caption_prompt": { "description": "Промпт для описания изображений" } diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 200b3b9fe1..090166acd8 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -49,6 +49,10 @@ "description": "默认图片转述模型", "hint": "留空代表不使用,可用于非多模态模型" }, + "image_caption_cache_ttl": { + "description": "图片转述缓存时长(秒)", + "hint": "在缓存时间内再次收到相同图片时,直接复用已缓存的视觉识别结果;设为 0 表示禁用缓存" + }, "image_caption_prompt": { "description": "图片转述提示词" } diff --git a/tests/test_group_chat_context.py b/tests/test_group_chat_context.py new file mode 100644 index 0000000000..64150b3ba1 --- /dev/null +++ b/tests/test_group_chat_context.py @@ -0,0 +1,237 @@ +import asyncio +import base64 +import hashlib +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.api.provider import Provider +from astrbot.builtin_stars.astrbot.group_chat_context import ( + GroupChatContext, + _resolve_provider_cache_identity, +) +from astrbot.core.utils.image_caption_cache import ( + ImageCaptionCache, + image_caption_cache, +) + + +@pytest.mark.asyncio +async def test_group_chat_context_reuses_cached_image_caption(tmp_path): + image_caption_cache.clear() + image_path = tmp_path / "same-image.png" + image_path.write_bytes(b"same-image-bytes") + + provider = MagicMock(spec=Provider) + provider.provider_config = {"id": "caption-provider"} + provider.text_chat = AsyncMock( + return_value=MagicMock(completion_text="cached caption") + ) + + context = MagicMock() + context.get_provider_by_id.return_value = provider + + group_chat_context = GroupChatContext(MagicMock(), context) + + caption1 = await group_chat_context.get_image_caption( + str(image_path), + "caption-provider", + "Please describe the image using Chinese.", + 600, + ) + caption2 = await group_chat_context.get_image_caption( + str(image_path), + "caption-provider", + "Please describe the image using Chinese.", + 600, + ) + + assert caption1 == "cached caption" + assert caption2 == "cached caption" + provider.text_chat.assert_awaited_once() + image_caption_cache.clear() + + +@pytest.mark.asyncio +async def test_image_caption_cache_reuses_per_key_lock_after_waiters_complete(): + cache = ImageCaptionCache() + started = asyncio.Event() + release = asyncio.Event() + calls = 0 + + async def caption_factory() -> str: + nonlocal calls + calls += 1 + started.set() + await release.wait() + return "cached caption" + + task1 = asyncio.create_task( + cache.get_or_create( + provider_id="caption-provider", + prompt="Please describe the image using Chinese.", + image_urls=["same-image.png"], + ttl_seconds=600, + caption_factory=caption_factory, + ) + ) + await started.wait() + task2 = asyncio.create_task( + cache.get_or_create( + provider_id="caption-provider", + prompt="Please describe the image using Chinese.", + image_urls=["same-image.png"], + ttl_seconds=600, + caption_factory=caption_factory, + ) + ) + + await asyncio.sleep(0) + assert len(cache._locks) == 1 + + release.set() + + assert await task1 == "cached caption" + assert await task2 == "cached caption" + assert calls == 1 + assert len(cache._locks) == 1 + + +@pytest.mark.asyncio +async def test_image_caption_cache_accepts_lambda_caption_factory(): + cache = ImageCaptionCache() + calls = 0 + + async def fetch_caption() -> str: + nonlocal calls + calls += 1 + return "cached caption" + + caption1 = await cache.get_or_create( + provider_id="caption-provider", + prompt="Please describe the image using Chinese.", + image_urls=["same-image.png"], + ttl_seconds=600, + caption_factory=lambda: fetch_caption(), + ) + caption2 = await cache.get_or_create( + provider_id="caption-provider", + prompt="Please describe the image using Chinese.", + image_urls=["same-image.png"], + ttl_seconds=600, + caption_factory=lambda: fetch_caption(), + ) + + assert caption1 == "cached caption" + assert caption2 == "cached caption" + assert calls == 1 + + +@pytest.mark.asyncio +async def test_image_caption_cache_fingerprints_supported_image_reference_types( + tmp_path, +): + cache = ImageCaptionCache() + image_bytes = b"same-image-bytes" + expected_hash = hashlib.sha256(image_bytes).hexdigest() + image_path = tmp_path / "same-image.png" + image_path.write_bytes(image_bytes) + encoded = base64.b64encode(image_bytes).decode("ascii") + + assert await cache._fingerprint_image(f"base64://{encoded}") == expected_hash + assert ( + await cache._fingerprint_image(f"data:image/png;base64,{encoded}") + == expected_hash + ) + assert await cache._fingerprint_image(str(image_path)) == expected_hash + assert ( + await cache._fingerprint_image("https://example.com/image.png") + == "url:https://example.com/image.png" + ) + assert ( + await cache._fingerprint_image("missing-image.png") == "ref:missing-image.png" + ) + + +@pytest.mark.asyncio +async def test_group_chat_context_default_provider_cache_identity_is_stable_per_provider( + tmp_path, +): + image_caption_cache.clear() + image_path = tmp_path / "same-image.png" + image_path.write_bytes(b"same-image-bytes") + + provider1 = MagicMock(spec=Provider) + provider1.provider_config = {"type": "openai_chat_completion"} + provider1.get_model.return_value = "gpt-4o" + provider1.text_chat = AsyncMock( + return_value=MagicMock(completion_text="caption from provider one") + ) + + provider2 = MagicMock(spec=Provider) + provider2.provider_config = {"type": "google_genai"} + provider2.get_model.return_value = "gemini-2.5-pro" + provider2.text_chat = AsyncMock( + return_value=MagicMock(completion_text="caption from provider two") + ) + + context = MagicMock() + context.get_using_provider.side_effect = [provider1, provider2] + + group_chat_context = GroupChatContext(MagicMock(), context) + + caption1 = await group_chat_context.get_image_caption( + str(image_path), + "", + "Please describe the image using Chinese.", + 600, + ) + caption2 = await group_chat_context.get_image_caption( + str(image_path), + "", + "Please describe the image using Chinese.", + 600, + ) + + assert caption1 == "caption from provider one" + assert caption2 == "caption from provider two" + provider1.text_chat.assert_awaited_once() + provider2.text_chat.assert_awaited_once() + image_caption_cache.clear() + + +def test_resolve_provider_cache_identity_prefers_configured_provider_id(): + provider = MagicMock(spec=Provider) + provider.provider_config = {"id": "provider-config-id", "type": "google_genai"} + provider.get_model.return_value = "gemini-2.5-pro" + + assert ( + _resolve_provider_cache_identity( + provider, + configured_provider_id="configured-provider-id", + ) + == "configured-provider-id" + ) + + +def test_resolve_provider_cache_identity_uses_provider_config_id_as_fallback(): + provider = MagicMock(spec=Provider) + provider.provider_config = {"id": "provider-config-id", "type": "google_genai"} + provider.get_model.return_value = "gemini-2.5-pro" + + assert ( + _resolve_provider_cache_identity(provider, configured_provider_id="") + == "provider-config-id" + ) + + +def test_resolve_provider_cache_identity_uses_deterministic_string_when_ids_absent(): + provider = MagicMock(spec=Provider) + provider.provider_config = {"type": "openai_chat_completion"} + provider.get_model.return_value = "gpt-4o" + + assert _resolve_provider_cache_identity(provider, configured_provider_id="") == ( + f"{provider.__class__.__module__}:" + f"{provider.__class__.__qualname__}:" + "openai_chat_completion:gpt-4o" + ) diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index db729a23ba..0b06075f93 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -16,6 +16,7 @@ from astrbot.core.provider.entities import ProviderRequest from astrbot.core.skills.skill_manager import SkillInfo from astrbot.core.star.star import StarMetadata +from astrbot.core.utils.image_caption_cache import image_caption_cache @pytest.fixture @@ -1173,6 +1174,111 @@ async def test_build_main_agent_skips_caption_when_main_provider_supports_images ) mock_provider.text_chat.assert_not_called() + @pytest.mark.asyncio + async def test_request_img_caption_reuses_cached_result( + self, tmp_path, mock_context + ): + """Test repeated image caption requests reuse the cached vision result.""" + module = ama + image_caption_cache.clear() + + image_path = tmp_path / "same-image.png" + image_path.write_bytes(b"same-image") + + caption_provider = MagicMock(spec=Provider) + caption_provider.text_chat = AsyncMock( + return_value=MagicMock(completion_text="cached caption") + ) + mock_context.get_provider_by_id.return_value = caption_provider + + cfg = { + "image_caption_prompt": "Please describe the image using Chinese.", + "image_caption_cache_ttl": 600, + } + + caption1 = await module._request_img_caption( + "caption-provider", + cfg, + [str(image_path)], + mock_context, + ) + caption2 = await module._request_img_caption( + "caption-provider", + cfg, + [str(image_path)], + mock_context, + ) + + assert caption1 == "cached caption" + assert caption2 == "cached caption" + caption_provider.text_chat.assert_awaited_once() + image_caption_cache.clear() + + @pytest.mark.asyncio + async def test_process_quote_message_uses_provider_instance_for_image_caption( + self, + tmp_path, + mock_event, + mock_context, + ): + """Test quoted image captions use the resolved provider instance directly.""" + module = ama + image_caption_cache.clear() + + image_path = tmp_path / "quoted-image.png" + image_path.write_bytes(b"quoted-image") + + quoted_image = Image(file=f"file:///{image_path.as_posix()}") + quoted_reply = Reply( + id="reply-1", + chain=[Plain(text="quoted text"), quoted_image], + sender_nickname="", + message_str="quoted text", + ) + mock_event.message_obj.message = [quoted_reply] + + caption_provider = MagicMock(spec=Provider) + caption_provider.provider_config = {"id": "provider-config-id"} + caption_provider.text_chat = AsyncMock( + return_value=MagicMock(completion_text="quoted caption") + ) + mock_context.get_provider_by_id.return_value = caption_provider + + req = ProviderRequest(prompt="Hello") + + with ( + patch( + "astrbot.core.astr_main_agent.extract_quoted_message_text", + AsyncMock(return_value="quoted text"), + ), + patch.object( + Image, + "convert_to_file_path", + AsyncMock(return_value=str(image_path)), + ), + patch( + "astrbot.core.astr_main_agent._compress_image_for_provider", + AsyncMock(return_value=str(image_path)), + ), + ): + await module._process_quote_message( + mock_event, + req, + "caption-provider", + mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={"image_caption_cache_ttl": 600}, + ), + ) + + assert any( + "[Image Caption in quoted message]: quoted caption" in part.text + for part in req.extra_user_content_parts + ) + caption_provider.text_chat.assert_awaited_once() + image_caption_cache.clear() + @pytest.mark.asyncio async def test_build_main_agent_uses_image_fallback_provider( self, mock_event, mock_context