diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index dec98692bc..ea9b6147dd 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1174,6 +1174,34 @@ "gm_thinking_config": {"budget": 0, "level": "HIGH"}, "proxy": "", }, + "Google Vertex AI": { + "id": "google_vertex_ai", + "provider": "google-vertex-ai", + "type": "googlegenai_chat_completion", + "provider_type": "chat_completion", + "hint": "provider_group.provider.google_vertex_ai.hint", + "enable": True, + "key": [], + "api_base": "https://aiplatform.googleapis.com/v1", + "timeout": 120, + "gm_resp_image_modal": False, + "gm_native_search": False, + "gm_native_coderunner": False, + "gm_url_context": False, + "gm_safety_settings": { + "harassment": "BLOCK_MEDIUM_AND_ABOVE", + "hate_speech": "BLOCK_MEDIUM_AND_ABOVE", + "sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE", + "dangerous_content": "BLOCK_MEDIUM_AND_ABOVE", + }, + "gm_thinking_config": {"budget": 0, "level": "HIGH"}, + "proxy": "", + "vertex_ai_auth_type": "json", + "vertex_ai_project_id": "", + "vertex_ai_location": "global", + "vertex_ai_credentials_path": "", + "vertex_ai_credentials_json": "", + }, "Anthropic": { "id": "anthropic", "provider": "anthropic", @@ -1986,6 +2014,55 @@ "items": {}, "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。", }, + "vertex_ai_auth_type": { + "description": "密钥格式", + "type": "string", + "options": ["json", "api_key"], + "hint": "provider_group.provider.vertex_ai_auth_type.hint", + "condition": {"provider": "google-vertex-ai"}, + }, + "vertex_ai_api_key": { + "description": "API Key", + "type": "list", + "items": {"type": "string"}, + "hint": "provider_group.provider.vertex_ai_api_key.hint", + "condition": { + "provider": "google-vertex-ai", + "vertex_ai_auth_type": "api_key", + }, + "button_text": "API Key", + "dialog_title": "API Key", + "prefer_single_item": True, + }, + "vertex_ai_project_id": { + "description": "Vertex AI Project ID", + "type": "string", + "hint": "provider_group.provider.vertex_ai_project_id.hint", + "invisible": True, + }, + "vertex_ai_location": { + "description": "Vertex AI Location", + "type": "string", + "hint": "provider_group.provider.vertex_ai_location.hint", + "condition": {"provider": "google-vertex-ai"}, + }, + "vertex_ai_credentials_path": { + "description": "Service Account JSON Path", + "type": "string", + "hint": "provider_group.provider.vertex_ai_credentials_path.hint", + "invisible": True, + }, + "vertex_ai_credentials_json": { + "description": "Service Account JSON", + "type": "text", + "editor_mode": True, + "editor_language": "json", + "hint": "provider_group.provider.vertex_ai_credentials_json.hint", + "condition": { + "provider": "google-vertex-ai", + "vertex_ai_auth_type": "json", + }, + }, "ollama_disable_thinking": { "description": "关闭思考模式", "type": "bool", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 5aa452bbd5..a35862381f 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -21,6 +21,7 @@ TTSProvider, ) from .register import llm_tools, provider_cls_map +from .sources.vertex_ai import normalize_vertex_ai_provider_config @runtime_checkable @@ -515,7 +516,7 @@ def get_merged_provider_config(self, provider_config: dict) -> dict: # 保持 id 为 provider 的 id,而不是 source 的 id merged_config["id"] = pc["id"] pc = merged_config - return pc + return normalize_vertex_ai_provider_config(pc) def get_provider_config_by_id( self, diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index f38fcfc359..9093a9c80a 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -26,8 +26,26 @@ from astrbot.core.utils.io import download_file, download_image_by_url from astrbot.core.utils.media_utils import ensure_wav from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings from ..register import register_provider_adapter +from .vertex_ai import ( + GOOGLE_CLOUD_PLATFORM_SCOPE, + VERTEX_AI_API_KEY_AUTH, + VERTEX_AI_API_VERSION, + VERTEX_AI_DEFAULT_API_BASE, + VERTEX_AI_JSON_AUTH, + VERTEX_AI_MODEL_NAME_SEPARATOR, + build_vertex_ai_publisher_models_url, + extract_vertex_ai_credentials_json, + fetch_vertex_ai_publisher_models, + make_vertex_ai_refresh_request, + normalize_vertex_ai_auth_type, + normalize_vertex_ai_location, + normalize_vertex_ai_provider_config, + resolve_vertex_ai_project_id, + to_vertex_ai_genai_model_name, +) class SuppressNonTextPartsWarning(logging.Filter): @@ -64,6 +82,8 @@ def __init__( provider_config, provider_settings, ) -> None: + if provider_config.get("provider") == "google-vertex-ai": + provider_config = normalize_vertex_ai_provider_config(provider_config) super().__init__( provider_config, provider_settings, @@ -85,17 +105,24 @@ def __init__( def _init_client(self) -> None: """初始化Gemini客户端""" proxy = self.provider_config.get("proxy", "") - http_options = types.HttpOptions( - base_url=self.api_base, - timeout=self.timeout * 1000, # 毫秒 - ) + is_vertex_ai = self._is_vertex_ai_config() + base_url = self._get_vertex_ai_sdk_base_url() if is_vertex_ai else self.api_base + + http_options_kwargs = { + "base_url": base_url, + "timeout": self.timeout * 1000, # 毫秒 + } + if is_vertex_ai: + http_options_kwargs["api_version"] = VERTEX_AI_API_VERSION + http_options = types.HttpOptions(**http_options_kwargs) # 强制使用 httpx 作为异步 HTTP 后端,避免 aiohttp 响应类型兼容问题 (#7564) # httpx.AsyncClient 的 timeout 单位为秒(与 HttpOptions 的毫秒不同) async_client_kwargs: dict = { - "base_url": self.api_base, "timeout": self.timeout, } + if base_url: + async_client_kwargs["base_url"] = base_url if proxy: async_client_kwargs["proxy"] = proxy async_client_kwargs["trust_env"] = False @@ -112,10 +139,129 @@ def _init_client(self) -> None: self._http_client = httpx.AsyncClient(**async_client_kwargs) http_options.httpx_async_client = self._http_client - self.client = genai.Client( - api_key=self.chosen_api_key, - http_options=http_options, - ).aio + client_kwargs = self._get_genai_client_kwargs(http_options) + self.client = genai.Client(**client_kwargs).aio + + def _is_vertex_ai_config(self) -> bool: + return self.provider_config.get("provider") == "google-vertex-ai" + + def _get_vertex_ai_auth_type(self) -> str: + return normalize_vertex_ai_auth_type( + self.provider_config.get("vertex_ai_auth_type") + ) + + def _get_vertex_ai_sdk_base_url(self) -> str: + configured_base_url = str(self.api_base or "").strip().rstrip("/") + location = normalize_vertex_ai_location( + self.provider_config.get("vertex_ai_location") + ) + + if configured_base_url in { + "", + VERTEX_AI_DEFAULT_API_BASE, + f"{VERTEX_AI_DEFAULT_API_BASE}/{VERTEX_AI_API_VERSION}", + }: + if location == "global": + return VERTEX_AI_DEFAULT_API_BASE + return f"https://{location}-aiplatform.googleapis.com" + + version_suffix = f"/{VERTEX_AI_API_VERSION}" + if configured_base_url.endswith(version_suffix): + return ( + configured_base_url[: -len(version_suffix)] + or VERTEX_AI_DEFAULT_API_BASE + ) + return configured_base_url + + def _get_genai_client_kwargs(self, http_options: types.HttpOptions) -> dict: + if not self._is_vertex_ai_config(): + return { + "api_key": self.chosen_api_key, + "http_options": http_options, + } + + auth_type = self._get_vertex_ai_auth_type() + # API-key auth is used only when no service-account JSON is provided. A + # pasted service-account JSON always takes precedence (it enables OAuth + # and Model Garden discovery), even if the auth type is set to api_key. + if auth_type == VERTEX_AI_API_KEY_AUTH and not self._has_vertex_ai_key_json(): + client_kwargs = { + "vertexai": True, + "http_options": http_options, + } + if self.chosen_api_key: + client_kwargs["api_key"] = self.chosen_api_key + return client_kwargs + + if auth_type in {VERTEX_AI_API_KEY_AUTH, VERTEX_AI_JSON_AUTH}: + return { + "credentials": self._load_vertex_ai_service_account_credentials(), + "project": self._resolve_vertex_ai_project_id_required(), + "location": normalize_vertex_ai_location( + self.provider_config.get("vertex_ai_location") + ), + "vertexai": True, + "http_options": http_options, + } + + raise ValueError( + "Vertex AI key format must be api_key or json for Google GenAI provider." + ) + + def _has_vertex_ai_key_json(self) -> bool: + return bool(extract_vertex_ai_credentials_json(self.provider_config).strip()) + + def _resolve_vertex_ai_project_id_required(self) -> str: + project_id = resolve_vertex_ai_project_id(self.provider_config) + if not project_id: + raise ValueError( + "Vertex AI project id is required for service account auth." + ) + return project_id + + def _load_vertex_ai_service_account_credentials(self): + try: + from google.oauth2 import service_account + except ImportError as exc: + raise RuntimeError( + "google-auth is required for Vertex AI service account auth." + ) from exc + + scopes = [GOOGLE_CLOUD_PLATFORM_SCOPE] + credentials_json = extract_vertex_ai_credentials_json(self.provider_config) + credentials_path = str( + self.provider_config.get("vertex_ai_credentials_path") or "" + ) + + if credentials_json.strip(): + return service_account.Credentials.from_service_account_info( + json.loads(credentials_json), + scopes=scopes, + ) + + if credentials_path.strip(): + return service_account.Credentials.from_service_account_file( + str(Path(credentials_path).expanduser()), + scopes=scopes, + ) + + raise ValueError( + "Vertex AI service account auth requires a JSON file path or pasted JSON." + ) + + def _ensure_vertex_ai_api_key_for_generation(self) -> None: + if ( + self._is_vertex_ai_config() + and self._get_vertex_ai_auth_type() == VERTEX_AI_API_KEY_AUTH + and not self._has_vertex_ai_key_json() + and not self.chosen_api_key + ): + raise ValueError("Vertex AI API key is required for chat generation.") + + def _get_request_model_name(self, model: str) -> str: + if self._is_vertex_ai_config(): + return to_vertex_ai_genai_model_name(model) + return model def _init_safety_settings(self) -> None: """初始化安全设置""" @@ -605,12 +751,13 @@ def _process_content_parts( async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: """非流式请求 Gemini API""" + self._ensure_vertex_ai_api_key_for_generation() system_instruction = next( (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), None, ) - model = payloads.get("model", self.get_model()) + model = self._get_request_model_name(payloads.get("model", self.get_model())) modalities = ["TEXT"] if self.provider_config.get("gm_resp_image_modal", False): @@ -635,7 +782,12 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: contents=cast(types.ContentListUnion, conversation), config=config, ) - logger.debug(f"genai result: {result}") + logger.debug( + "[Gemini] generate_content result: " + f"candidates={len(result.candidates or [])}, " + f"model_version={getattr(result, 'model_version', None)}, " + f"response_id={getattr(result, 'response_id', None)}" + ) if not result.candidates: logger.error(f"请求失败, 返回的 candidates 为空: {result}") @@ -694,11 +846,12 @@ async def _query_stream( tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式请求 Gemini API""" + self._ensure_vertex_ai_api_key_for_generation() system_instruction = next( (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), None, ) - model = payloads.get("model", self.get_model()) + model = self._get_request_model_name(payloads.get("model", self.get_model())) conversation = self._prepare_conversation(payloads) result = None @@ -939,18 +1092,69 @@ async def text_chat_stream( break async def get_models(self): + if self._is_vertex_ai_config(): + return await self._get_vertex_ai_models() + try: models = await self.client.models.list() - return [ - m.name.replace("models/", "") - for m in models - if m.supported_actions - and "generateContent" in m.supported_actions - and m.name - ] + model_ids = [] + for model in models: + if ( + not model.supported_actions + or "generateContent" not in model.supported_actions + or not model.name + ): + continue + model_ids.append(self._normalize_model_name(model.name)) + return normalize_and_dedupe_strings(model_ids) except APIError as e: raise Exception(f"获取模型列表失败: {e.message}") + async def _get_vertex_ai_models(self) -> list[str]: + auth_type = self._get_vertex_ai_auth_type() + if auth_type == VERTEX_AI_API_KEY_AUTH and not self._has_vertex_ai_key_json(): + raise ValueError( + "Vertex AI API Key 无法获取模型列表。请将密钥格式切换为 json " + "并填写 Google Cloud 服务账号 JSON 后获取模型列表,或手动添加模型。" + ) + + credentials = self._load_vertex_ai_service_account_credentials() + if not getattr(credentials, "valid", False) or not getattr( + credentials, "token", None + ): + # credentials.refresh performs blocking network I/O; run it off the + # event loop so model discovery does not stall the async runtime. + await asyncio.to_thread( + credentials.refresh, + make_vertex_ai_refresh_request(self.provider_config), + ) + + token = getattr(credentials, "token", None) + if not token: + raise RuntimeError("Failed to refresh Vertex AI access token.") + + url = build_vertex_ai_publisher_models_url(self.provider_config) + project_id = self._resolve_vertex_ai_project_id_required() + headers = { + "Authorization": f"Bearer {token}", + "Accept-Encoding": "gzip, deflate", + "x-goog-user-project": project_id, + } + assert self._http_client is not None + return await fetch_vertex_ai_publisher_models(self._http_client, url, headers) + + def _normalize_model_name(self, name: str) -> str: + name = name.strip() + if self._is_vertex_ai_config(): + if VERTEX_AI_MODEL_NAME_SEPARATOR in name: + name = name.rsplit(VERTEX_AI_MODEL_NAME_SEPARATOR, 1)[1] + elif name.startswith("models/"): + name = name.removeprefix("models/") + if not name.startswith("google/"): + name = f"google/{name}" + return name + return name.replace("models/", "") + def get_current_key(self) -> str: return self.chosen_api_key diff --git a/astrbot/core/provider/sources/vertex_ai.py b/astrbot/core/provider/sources/vertex_ai.py new file mode 100644 index 0000000000..1adefe2a1a --- /dev/null +++ b/astrbot/core/provider/sources/vertex_ai.py @@ -0,0 +1,286 @@ +import json +from pathlib import Path +from typing import Any +from urllib.parse import urlparse, urlunparse + +from astrbot.core import logger +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings + +GOOGLE_CLOUD_PLATFORM_SCOPE = "https://www.googleapis.com/auth/cloud-platform" +VERTEX_AI_DEFAULT_LOCATION = "global" +VERTEX_AI_DEFAULT_API_BASE = "https://aiplatform.googleapis.com" +VERTEX_AI_API_VERSION = "v1" +VERTEX_AI_MODEL_GARDEN_API_VERSION = "v1beta1" +VERTEX_AI_API_KEY_AUTH = "api_key" +VERTEX_AI_JSON_AUTH = "json" +VERTEX_AI_SERVICE_ACCOUNT_AUTH = "service_account" +VERTEX_AI_GOOGLE_GENAI_TYPE = "googlegenai_chat_completion" +VERTEX_AI_API_KEY_FIELD = "vertex_ai_api_key" +VERTEX_AI_MODEL_NAME_SEPARATOR = "/models/" +VERTEX_AI_GOOGLE_MODEL_PREFIX = "google/" +VERTEX_AI_GOOGLE_PUBLISHER_MODEL_PREFIX = "publishers/google/models/" +VERTEX_AI_MODEL_LIST_PAGE_SIZE = "300" +VERTEX_AI_NON_CHAT_MODEL_ID_PARTS = ( + "embedding", + "-tts", + "gemini-live", + "native-audio", +) + + +def normalize_vertex_ai_auth_type(auth_type: Any) -> str: + normalized = str(auth_type or VERTEX_AI_JSON_AUTH).strip().lower() + if normalized == VERTEX_AI_SERVICE_ACCOUNT_AUTH: + return VERTEX_AI_JSON_AUTH + return normalized + + +def normalize_vertex_ai_provider_config( + provider_config: dict[str, Any], +) -> dict[str, Any]: + """Normalize Vertex AI source configs for runtime compatibility. + + The dashboard stores Vertex AI API keys in a provider-specific field so the + service-account JSON field can remain separate. Provider implementations + still use the common ``key`` list for key rotation, so API-key configs are + mirrored into ``key`` at runtime. Legacy configs that stored either an API + key or a pasted service-account JSON in ``key`` remain readable. + """ + + if provider_config.get("provider") != "google-vertex-ai": + return provider_config + + provider_config = dict(provider_config) + auth_type = normalize_vertex_ai_auth_type( + provider_config.get("vertex_ai_auth_type") + ) + provider_config["vertex_ai_auth_type"] = auth_type + + explicit_api_key = provider_config.get(VERTEX_AI_API_KEY_FIELD) + legacy_key = provider_config.get("key") + if explicit_api_key is None: + provider_config[VERTEX_AI_API_KEY_FIELD] = ( + [] + if _key_value_looks_like_json(legacy_key) + else _normalize_key_list(legacy_key) + ) + + if auth_type == VERTEX_AI_API_KEY_AUTH: + provider_config["type"] = VERTEX_AI_GOOGLE_GENAI_TYPE + provider_config["key"] = _normalize_key_list( + provider_config.get(VERTEX_AI_API_KEY_FIELD) + ) + return provider_config + + +def normalize_vertex_ai_provider_source_config( + provider_config: dict[str, Any], +) -> dict[str, Any]: + """Normalize a Vertex AI provider source before persisting it.""" + + normalized = normalize_vertex_ai_provider_config(provider_config) + if normalized.get("provider") == "google-vertex-ai": + normalized.pop("key", None) + return normalized + + +def _normalize_key_list(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, list): + return [str(item) for item in value if str(item).strip()] + text = str(value).strip() + return [text] if text else [] + + +def _key_value_looks_like_json(value: Any) -> bool: + keys = _normalize_key_list(value) + return bool(keys and keys[0].lstrip().startswith("{")) + + +def extract_vertex_ai_credentials_json(provider_config: dict[str, Any]) -> str: + """Return service-account JSON from explicit config or the common key field.""" + + credentials_json = str(provider_config.get("vertex_ai_credentials_json") or "") + if credentials_json.strip(): + return credentials_json + + key = provider_config.get("key") or "" + if isinstance(key, list): + key = key[0] if key else "" + key = str(key).strip() + if key.startswith("{"): + return key + return "" + + +def to_vertex_ai_genai_model_name(model: Any) -> str: + """Convert AstrBot-facing Vertex AI model ids to google-genai request ids.""" + + model_name = str(model or "").strip() + if not model_name: + return model_name + if model_name.startswith(VERTEX_AI_GOOGLE_PUBLISHER_MODEL_PREFIX): + return model_name + if model_name.startswith(VERTEX_AI_GOOGLE_MODEL_PREFIX): + return ( + f"{VERTEX_AI_GOOGLE_PUBLISHER_MODEL_PREFIX}" + f"{model_name.removeprefix(VERTEX_AI_GOOGLE_MODEL_PREFIX)}" + ) + if model_name.startswith("models/"): + return f"{VERTEX_AI_GOOGLE_PUBLISHER_MODEL_PREFIX}{model_name.removeprefix('models/')}" + if model_name.startswith("gemini-"): + return f"{VERTEX_AI_GOOGLE_PUBLISHER_MODEL_PREFIX}{model_name}" + return model_name + + +def make_vertex_ai_refresh_request(provider_config: dict[str, Any]) -> Any: + try: + from google.auth.transport.requests import Request + except ImportError as exc: + raise RuntimeError( + "google-auth requests transport is required for Vertex AI auth." + ) from exc + + proxy = str(provider_config.get("proxy") or "").strip() + if not proxy: + return Request() + + try: + import requests + except ImportError: + logger.warning( + "Vertex AI proxy is configured but the 'requests' package is " + "unavailable; the proxy will be ignored for OAuth token refresh." + ) + return Request() + + session = requests.Session() + session.proxies.update({"http": proxy, "https": proxy}) + session.trust_env = False + return Request(session=session) + + +def normalize_vertex_ai_location(location: Any) -> str: + normalized = str(location or "").strip() + return normalized or VERTEX_AI_DEFAULT_LOCATION + + +def _normalize_base_url(base_url: str) -> str: + return base_url.strip().rstrip("/") + + +def _default_vertex_ai_api_base(location: str) -> str: + if location == VERTEX_AI_DEFAULT_LOCATION: + return VERTEX_AI_DEFAULT_API_BASE + return f"https://{location}-aiplatform.googleapis.com" + + +def _is_default_vertex_ai_api_base(base_url: str) -> bool: + parsed = urlparse(base_url) + return ( + parsed.scheme.lower() == "https" + and parsed.netloc.lower() == "aiplatform.googleapis.com" + and parsed.path.strip("/") in {"", VERTEX_AI_API_VERSION} + ) + + +def _read_service_account_info(provider_config: dict[str, Any]) -> dict[str, Any]: + credentials_json = extract_vertex_ai_credentials_json(provider_config) + if credentials_json.strip(): + return json.loads(credentials_json) + + credentials_path = str(provider_config.get("vertex_ai_credentials_path") or "") + if credentials_path.strip(): + return json.loads( + Path(credentials_path).expanduser().read_text(encoding="utf-8") + ) + + return {} + + +def resolve_vertex_ai_project_id(provider_config: dict[str, Any]) -> str: + project_id = str(provider_config.get("vertex_ai_project_id") or "").strip() + if project_id: + return project_id + + try: + credentials_json = extract_vertex_ai_credentials_json(provider_config) + if credentials_json.strip(): + info = json.loads(credentials_json) + else: + info = _read_service_account_info(provider_config) + except Exception: + return "" + return str(info.get("project_id") or "").strip() + + +def build_vertex_ai_publisher_models_url(provider_config: dict[str, Any]) -> str: + """Build the native Vertex AI publisher models URL for model discovery.""" + + location = normalize_vertex_ai_location(provider_config.get("vertex_ai_location")) + configured_base_url = _normalize_base_url( + str(provider_config.get("api_base") or "") + ) + if not configured_base_url or _is_default_vertex_ai_api_base(configured_base_url): + base_url = _default_vertex_ai_api_base(location) + else: + base_url = configured_base_url + + parsed = urlparse(base_url) + path_parts = [ + part + for part in parsed.path.strip("/").split("/") + if part and part != VERTEX_AI_API_VERSION + ] + path = "/".join([VERTEX_AI_MODEL_GARDEN_API_VERSION, *path_parts]) + path = f"/{path}/publishers/google/models" + return urlunparse(parsed._replace(path=path, query="", fragment="")) + + +async def fetch_vertex_ai_publisher_models( + http_client: Any, + url: str, + headers: dict[str, str], +) -> list[str]: + """Fetch Gemini publisher model IDs from Vertex AI Model Garden.""" + + params = { + "pageSize": VERTEX_AI_MODEL_LIST_PAGE_SIZE, + } + model_ids: list[str] = [] + next_page_token = None + + while True: + if next_page_token: + params["pageToken"] = next_page_token + else: + params.pop("pageToken", None) + + response = await http_client.get(url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + + for model in data.get("publisherModels", data.get("models", [])): + name = str(model.get("name") or "").strip() + if not name: + continue + if VERTEX_AI_MODEL_NAME_SEPARATOR in name: + name = name.rsplit(VERTEX_AI_MODEL_NAME_SEPARATOR, 1)[1] + elif name.startswith("models/"): + name = name.removeprefix("models/") + if not _is_vertex_ai_chat_model_id(name): + continue + model_ids.append(f"{VERTEX_AI_GOOGLE_MODEL_PREFIX}{name}") + + next_page_token = data.get("nextPageToken") or data.get("next_page_token") + if not next_page_token: + break + + return sorted(normalize_and_dedupe_strings(model_ids)) + + +def _is_vertex_ai_chat_model_id(model_id: str) -> bool: + if not model_id.startswith("gemini-"): + return False + return not any(part in model_id for part in VERTEX_AI_NON_CHAT_MODEL_ID_PARTS) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 9ec24d254d..25a4ee8647 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -22,6 +22,10 @@ from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.provider import Provider from astrbot.core.provider.register import provider_registry +from astrbot.core.provider.sources.vertex_ai import ( + normalize_vertex_ai_provider_config, + normalize_vertex_ai_provider_source_config, +) from astrbot.core.star.star import StarMetadata, star_registry from astrbot.core.utils.astrbot_path import ( get_astrbot_plugin_data_path, @@ -445,6 +449,9 @@ async def update_provider_source(self): # 确保配置中有 id 字段 if not new_source_config.get("id"): new_source_config["id"] = original_id + new_source_config = normalize_vertex_ai_provider_source_config( + new_source_config + ) provider_sources = self.config.get("provider_sources", []) @@ -924,7 +931,9 @@ async def get_provider_source_models(self): provider_source = None for ps in provider_sources: if ps.get("id") == provider_source_id: - provider_source = ps + provider_source = normalize_vertex_ai_provider_config( + copy.deepcopy(ps) + ) break if not provider_source: @@ -971,15 +980,17 @@ async def get_provider_source_models(self): # 临时实例化 provider inst = cls_type(provider_source, {}) + try: + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() - # 如果有 initialize 方法,调用它 - init_fn = getattr(inst, "initialize", None) - if inspect.iscoroutinefunction(init_fn): - await init_fn() - - # 获取模型列表 - models = await inst.get_models() - models = models or [] + models = await inst.get_models() + models = models or [] + finally: + terminate_fn = getattr(inst, "terminate", None) + if inspect.iscoroutinefunction(terminate_fn): + await terminate_fn() metadata_map = {} for model_id in models: @@ -987,11 +998,6 @@ async def get_provider_source_models(self): if meta: metadata_map[model_id] = meta - # 销毁实例(如果有 terminate 方法) - terminate_fn = getattr(inst, "terminate", None) - if inspect.iscoroutinefunction(terminate_fn): - await terminate_fn() - return ( Response() .ok({"models": models, "model_metadata": metadata_map}) diff --git a/dashboard/src/components/shared/ConfigItemRenderer.vue b/dashboard/src/components/shared/ConfigItemRenderer.vue index 5211f8a2ec..afd2fc3bce 100644 --- a/dashboard/src/components/shared/ConfigItemRenderer.vue +++ b/dashboard/src/components/shared/ConfigItemRenderer.vue @@ -206,6 +206,9 @@ v-else-if="itemMeta?.type === 'list'" :model-value="modelValue" @update:model-value="emitUpdate" + :button-text="itemMeta?.button_text || ''" + :dialog-title="itemMeta?.dialog_title || ''" + :prefer-single-item="itemMeta?.prefer_single_item ?? true" class="config-field" /> diff --git a/dashboard/src/composables/useProviderSources.ts b/dashboard/src/composables/useProviderSources.ts index a85ef2ae52..aa2afcd454 100644 --- a/dashboard/src/composables/useProviderSources.ts +++ b/dashboard/src/composables/useProviderSources.ts @@ -184,7 +184,9 @@ export function useProviderSources(options: UseProviderSourcesOptions) { const basicSourceConfig = computed(() => { if (!editableProviderSource.value) return null - const fields = ['id', 'key', 'api_base'] + const fields = editableProviderSource.value.provider === 'google-vertex-ai' + ? ['id', 'api_base'] + : ['id', 'key', 'api_base'] const basic: Record = {} fields.forEach((field) => { @@ -202,13 +204,45 @@ export function useProviderSources(options: UseProviderSourcesOptions) { return basic }) + const vertexHiddenFields = new Set([ + 'vertex_ai_project_id', + 'vertex_ai_credentials_path' + ]) + + const vertexAdvancedSourceFields = [ + 'vertex_ai_auth_type', + 'vertex_ai_api_key', + 'vertex_ai_credentials_json', + 'vertex_ai_location', + 'timeout', + 'proxy', + 'gm_resp_image_modal', + 'gm_native_search', + 'gm_native_coderunner', + 'gm_url_context', + 'gm_safety_settings', + 'gm_thinking_config' + ] + const advancedSourceConfig = computed(() => { if (!editableProviderSource.value) return null const excluded = new Set(['id', 'key', 'api_base', 'enable', 'type', 'provider_type', 'provider']) + if (editableProviderSource.value.provider === 'google-vertex-ai') { + for (const field of vertexHiddenFields) { + excluded.add(field) + } + } const advanced: Record = {} - - for (const key of Object.keys(editableProviderSource.value)) { + const sourceKeys = Object.keys(editableProviderSource.value) + const keys = editableProviderSource.value.provider === 'google-vertex-ai' + ? [ + ...vertexAdvancedSourceFields.filter((field) => sourceKeys.includes(field)), + ...sourceKeys.filter((field) => !vertexAdvancedSourceFields.includes(field)) + ] + : sourceKeys + + for (const key of keys) { Object.defineProperty(advanced, key, { get() { return editableProviderSource.value![key] @@ -372,6 +406,21 @@ export function useProviderSources(options: UseProviderSourcesOptions) { source.ollama_disable_thinking = false } + if (source.provider === 'google-vertex-ai') { + if (!source.vertex_ai_auth_type || source.vertex_ai_auth_type === 'service_account') { + source.vertex_ai_auth_type = 'json' + } + if (source.vertex_ai_api_key === undefined) { + source.vertex_ai_api_key = source.vertex_ai_auth_type === 'api_key' ? (source.key || []) : [] + } + if (source.vertex_ai_credentials_json === undefined) { + source.vertex_ai_credentials_json = '' + } + if (!source.vertex_ai_location) { + source.vertex_ai_location = 'global' + } + } + return source } diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 6363b71e31..c431b1c315 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -1188,6 +1188,33 @@ "description": "Custom request headers", "hint": "Key/value pairs added here are merged into the OpenAI SDK default_headers for custom HTTP headers. Values must be strings." }, + "google_vertex_ai": { + "hint": "Vertex AI through the native Gemini API. The default API Base URL is https://aiplatform.googleapis.com/v1 and the default location is global. Use either a Vertex AI API key or a service account JSON." + }, + "vertex_ai_auth_type": { + "description": "Key format", + "hint": "Use either a Vertex AI API key or a service account JSON. Service account JSON supports model discovery; API key is generation-only." + }, + "vertex_ai_api_key": { + "description": "API Key", + "hint": "Vertex AI API key. Used when key format is API key; leave empty when using service account JSON." + }, + "vertex_ai_project_id": { + "description": "Vertex AI project ID", + "hint": "Google Cloud project ID. If omitted, AstrBot tries to read it from the service account JSON." + }, + "vertex_ai_location": { + "description": "Vertex AI location", + "hint": "Model location, for example global, us-central1. Default: global." + }, + "vertex_ai_credentials_path": { + "description": "Service account JSON path", + "hint": "Absolute path to a Google Cloud service account JSON file. Leave empty if you paste JSON below." + }, + "vertex_ai_credentials_json": { + "description": "Service account JSON", + "hint": "Google Cloud service account JSON can be created at https://console.cloud.google.com/iam-admin/serviceaccounts" + }, "ollama_disable_thinking": { "description": "Disable thinking mode", "hint": "Close Ollama thinking mode." diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 028bff8675..4346204394 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -1189,6 +1189,33 @@ "description": "Заголовки запроса", "hint": "Пары ключ/значение будут добавлены в заголовки запроса (default_headers). Значения должны быть строками." }, + "google_vertex_ai": { + "hint": "Vertex AI через нативный Gemini API. API Base URL по умолчанию: https://aiplatform.googleapis.com/v1, location по умолчанию: global. Используйте либо Vertex AI API key, либо JSON сервисного аккаунта." + }, + "vertex_ai_auth_type": { + "description": "Формат ключа", + "hint": "Используйте либо Vertex AI API key, либо JSON сервисного аккаунта. JSON поддерживает получение списка моделей; API key только для генерации." + }, + "vertex_ai_api_key": { + "description": "API Key", + "hint": "Vertex AI API key. Используется, когда выбран формат API key; оставьте пустым при использовании JSON сервисного аккаунта." + }, + "vertex_ai_project_id": { + "description": "Vertex AI project ID", + "hint": "ID проекта Google Cloud. Если пусто, AstrBot попробует прочитать его из JSON сервисного аккаунта." + }, + "vertex_ai_location": { + "description": "Vertex AI location", + "hint": "Location модели, например global, us-central1. По умолчанию: global." + }, + "vertex_ai_credentials_path": { + "description": "Путь к JSON сервисного аккаунта", + "hint": "Абсолютный путь к JSON файлу сервисного аккаунта Google Cloud. Оставьте пустым, если JSON вставлен ниже." + }, + "vertex_ai_credentials_json": { + "description": "JSON сервисного аккаунта", + "hint": "JSON сервисного аккаунта Google Cloud можно создать здесь: https://console.cloud.google.com/iam-admin/serviceaccounts" + }, "custom_extra_body": { "description": "Параметры тела запроса", "hint": "Добавление дополнительных параметров в запрос (temperature, top_p и др.).", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 70f4fa5c79..c7fb0fb953 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1190,6 +1190,33 @@ "description": "自定义请求头", "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。" }, + "google_vertex_ai": { + "hint": "通过 Vertex AI 原生 Gemini API 调用 Google 模型。默认 API Base URL 为 https://aiplatform.googleapis.com/v1,地区默认为 global。在 Vertex AI API Key 与服务账号 JSON 选择其一使用即可。" + }, + "vertex_ai_auth_type": { + "description": "密钥格式", + "hint": "在 Vertex AI API Key、服务账号 JSON 选择其一使用即可。" + }, + "vertex_ai_api_key": { + "description": "API Key", + "hint": "在 Vertex AI API Key、服务账号 JSON 选择其一使用即可。" + }, + "vertex_ai_project_id": { + "description": "Vertex AI 项目 ID", + "hint": "Google Cloud 项目 ID。留空时 AstrBot 会尝试从服务账号 JSON 中读取。" + }, + "vertex_ai_location": { + "description": "Vertex AI 地区", + "hint": "模型所在地区,例如 global、us-central1。默认:global。" + }, + "vertex_ai_credentials_path": { + "description": "服务账号 JSON 路径", + "hint": "Google Cloud 服务账号 JSON 文件的绝对路径。如果在下方粘贴 JSON,可留空。" + }, + "vertex_ai_credentials_json": { + "description": "服务账号 JSON", + "hint": "Google Cloud 服务账号 JSON 可在该网址创建获得 https://console.cloud.google.com/iam-admin/serviceaccounts" + }, "ollama_disable_thinking": { "description": "关闭思考模式", "hint": "关闭 Ollama 思考模式。" diff --git a/dashboard/src/utils/providerUtils.js b/dashboard/src/utils/providerUtils.js index dbf09b83a3..4584ea5df6 100644 --- a/dashboard/src/utils/providerUtils.js +++ b/dashboard/src/utils/providerUtils.js @@ -15,6 +15,7 @@ export function getProviderIcon(type) { 'anthropic': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/anthropic.svg', 'ollama': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/ollama.svg', 'google': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/gemini-color.svg', + 'google-vertex-ai': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/vertexai-color.svg', 'deepseek': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/deepseek.svg', 'modelscope': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/modelscope.svg', 'zhipu': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/zhipu.svg', diff --git a/pyproject.toml b/pyproject.toml index c3ff88b741..e5239e0737 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "faiss-cpu>=1.12.0", "filelock>=3.18.0", "google-genai>=1.56.0", + "google-auth[requests]>=2.41.1", "httpx[socks]>=0.28.1", "lark-oapi>=1.4.15", "mcp>=1.8.0", diff --git a/requirements.txt b/requirements.txt index e667c67559..ae0bf9366e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ docstring-parser>=0.16 faiss-cpu>=1.12.0 filelock>=3.18.0 google-genai>=1.56.0 +google-auth[requests]>=2.41.1 httpx[socks]>=0.28.1 lark-oapi>=1.4.15 mcp>=1.8.0 diff --git a/tests/test_vertex_ai.py b/tests/test_vertex_ai.py new file mode 100644 index 0000000000..ce94804c92 --- /dev/null +++ b/tests/test_vertex_ai.py @@ -0,0 +1,744 @@ +import json +from pathlib import Path + +import httpx +import pytest +from google.auth import credentials as google_auth_credentials + +from astrbot.core.config.default import CONFIG_METADATA_2 +from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI +from astrbot.core.provider.sources.vertex_ai import ( + build_vertex_ai_publisher_models_url, + make_vertex_ai_refresh_request, + normalize_vertex_ai_provider_config, + normalize_vertex_ai_provider_source_config, + resolve_vertex_ai_project_id, + to_vertex_ai_genai_model_name, +) + + +def _vertex_config(**overrides): + config = { + "id": "vertex-test", + "provider": "google-vertex-ai", + "type": "googlegenai_chat_completion", + "provider_type": "chat_completion", + "model": "google/gemini-3-flash-preview", + "key": [], + "api_base": "", + "timeout": 120, + "proxy": "", + "custom_headers": {}, + "vertex_ai_auth_type": "json", + "vertex_ai_project_id": "demo-project", + "vertex_ai_location": "global", + "vertex_ai_credentials_path": "", + "vertex_ai_credentials_json": "", + } + config.update(overrides) + return config + + +def _vertex_api_key_config(**overrides): + config = _vertex_config( + type="googlegenai_chat_completion", + vertex_ai_api_key=["test-api-key"], + api_base="https://aiplatform.googleapis.com/v1", + custom_headers={}, + vertex_ai_auth_type="api_key", + vertex_ai_project_id="", + vertex_ai_credentials_path="", + vertex_ai_credentials_json="", + ) + config.update(overrides) + return config + + +def _vertex_key_json_config(**overrides): + config = _vertex_config( + type="googlegenai_chat_completion", + api_base="https://aiplatform.googleapis.com/v1", + vertex_ai_credentials_json=json.dumps( + { + "project_id": "demo-project", + "client_email": "vertex@example.iam.gserviceaccount.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nkey\n-----END PRIVATE KEY-----\n", + "token_uri": "https://oauth2.googleapis.com/token", + } + ), + ) + config.update(overrides) + return config + + +class FakeVertexCredentials(google_auth_credentials.Credentials): + def __init__(self): + super().__init__() + self.token = "ya29.fake-token" + self.refresh_count = 0 + + def refresh(self, request): + self.refresh_count += 1 + self.token = "ya29.refreshed-token" + + +def test_vertex_ai_publisher_models_url_uses_native_vertex_endpoint(): + config = _vertex_config( + api_base="https://aiplatform.googleapis.com/v1", + vertex_ai_project_id="demo-project", + vertex_ai_location="global", + ) + + assert build_vertex_ai_publisher_models_url(config) == ( + "https://aiplatform.googleapis.com/v1beta1/publishers/google/models" + ) + + +@pytest.mark.parametrize( + ("model", "expected"), + [ + ( + "google/gemini-3-flash-preview", + "publishers/google/models/gemini-3-flash-preview", + ), + ("models/gemini-2.5-pro", "publishers/google/models/gemini-2.5-pro"), + ("gemini-2.5-flash", "publishers/google/models/gemini-2.5-flash"), + ( + "publishers/google/models/gemini-3-flash-preview", + "publishers/google/models/gemini-3-flash-preview", + ), + ], +) +def test_vertex_ai_genai_model_name_normalization(model, expected): + assert to_vertex_ai_genai_model_name(model) == expected + + +def test_vertex_ai_project_id_can_be_loaded_from_service_account_file(tmp_path): + credentials_path = tmp_path / "service-account.json" + credentials_path.write_text( + json.dumps({"project_id": "file-project"}), + encoding="utf-8", + ) + + assert ( + resolve_vertex_ai_project_id( + _vertex_config( + vertex_ai_project_id="", + vertex_ai_credentials_path=str(credentials_path), + ) + ) + == "file-project" + ) + + +def test_vertex_ai_refresh_request_uses_provider_proxy(monkeypatch): + captured = {} + + class FakeSession: + def __init__(self): + self.proxies = {} + self.trust_env = True + + class FakeRequest: + def __init__(self, session=None): + captured["session"] = session + + monkeypatch.setattr("google.auth.transport.requests.Request", FakeRequest) + monkeypatch.setattr("requests.Session", FakeSession) + + assert make_vertex_ai_refresh_request({"proxy": "http://127.0.0.1:7890"}) + session = captured["session"] + assert session.proxies == { + "http": "http://127.0.0.1:7890", + "https": "http://127.0.0.1:7890", + } + assert session.trust_env is False + + +def test_vertex_ai_api_key_config_is_normalized_to_google_genai_provider(): + config = _vertex_config( + type="openai_chat_completion", + vertex_ai_auth_type="api_key", + vertex_ai_api_key=["vertex-api-key"], + ) + + normalized = normalize_vertex_ai_provider_config(config) + assert normalized["type"] == "googlegenai_chat_completion" + assert normalized["key"] == ["vertex-api-key"] + + +def test_vertex_ai_runtime_normalization_does_not_mutate_source_config(): + config = _vertex_config( + type="openai_chat_completion", + vertex_ai_auth_type="api_key", + vertex_ai_api_key=["vertex-api-key"], + vertex_ai_credentials_json='{"project_id":"demo-project"}', + ) + + normalized = normalize_vertex_ai_provider_config(config) + assert normalized is not config + assert normalized["key"] == ["vertex-api-key"] + assert normalized["vertex_ai_credentials_json"] == '{"project_id":"demo-project"}' + assert "key" in config + assert config["key"] == [] + assert config["type"] == "openai_chat_completion" + assert config["vertex_ai_credentials_json"] == '{"project_id":"demo-project"}' + + +def test_vertex_ai_api_key_normalization_migrates_legacy_key_field(): + config = _vertex_config( + type="openai_chat_completion", + vertex_ai_auth_type="api_key", + key=["legacy-api-key"], + ) + + normalized = normalize_vertex_ai_provider_config(config) + assert normalized["vertex_ai_api_key"] == ["legacy-api-key"] + assert normalized["key"] == ["legacy-api-key"] + + +def test_vertex_ai_source_normalization_does_not_persist_runtime_key(): + config = _vertex_config( + type="googlegenai_chat_completion", + vertex_ai_auth_type="api_key", + vertex_ai_api_key=["vertex-api-key"], + key=["stale-runtime-key"], + ) + + normalized = normalize_vertex_ai_provider_source_config(config) + assert normalized["vertex_ai_api_key"] == ["vertex-api-key"] + assert "key" not in normalized + assert config["key"] == ["stale-runtime-key"] + + +def test_vertex_ai_source_normalization_preserves_json_and_api_key_fields(): + config = _vertex_config( + type="googlegenai_chat_completion", + vertex_ai_auth_type="json", + vertex_ai_api_key=["vertex-api-key"], + vertex_ai_credentials_json='{"project_id":"demo-project"}', + key=["stale-runtime-key"], + ) + + normalized = normalize_vertex_ai_provider_source_config(config) + assert normalized["vertex_ai_auth_type"] == "json" + assert normalized["vertex_ai_api_key"] == ["vertex-api-key"] + assert normalized["vertex_ai_credentials_json"] == '{"project_id":"demo-project"}' + assert "key" not in normalized + + +def test_vertex_ai_missing_auth_type_defaults_to_json(): + config = _vertex_config(vertex_ai_auth_type="") + + assert normalize_vertex_ai_provider_config(config)["vertex_ai_auth_type"] == "json" + + +def test_vertex_ai_legacy_service_account_auth_type_normalizes_to_json(): + config = _vertex_config(vertex_ai_auth_type="service_account") + + assert normalize_vertex_ai_provider_config(config)["vertex_ai_auth_type"] == "json" + + +@pytest.mark.asyncio +async def test_vertex_ai_api_key_provider_uses_native_vertex_endpoint(): + provider = ProviderGoogleGenAI(_vertex_api_key_config(), provider_settings={}) + + try: + api_client = provider.client._api_client + request = api_client._build_request( + "post", + "publishers/google/models/gemini-3-flash-preview:generateContent", + {"contents": []}, + None, + ) + + assert api_client.vertexai is True + assert api_client._http_options.api_version == "v1" + assert ( + api_client._http_options.base_url.rstrip("/") + == "https://aiplatform.googleapis.com" + ) + assert api_client._http_options.headers["x-goog-api-key"] == "test-api-key" + assert str(request.url) == ( + "https://aiplatform.googleapis.com/v1/publishers/google/" + "models/gemini-3-flash-preview:generateContent" + ) + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_api_key_text_chat_hits_native_vertex_endpoint(): + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response( + 200, + json={ + "candidates": [ + { + "content": { + "role": "model", + "parts": [{"text": "pong"}], + }, + "finishReason": "STOP", + } + ], + "usageMetadata": { + "promptTokenCount": 1, + "candidatesTokenCount": 1, + }, + "responseId": "vertex-api-key-response", + }, + ) + + transport = httpx.MockTransport(handler) + provider = ProviderGoogleGenAI(_vertex_api_key_config(), provider_settings={}) + await provider._http_client.aclose() + provider._http_client = httpx.AsyncClient( + transport=transport, + base_url=provider._get_vertex_ai_sdk_base_url(), + timeout=provider.timeout, + ) + provider.client._api_client._async_httpx_client = provider._http_client + + try: + response = await provider.text_chat(prompt="ping") + + assert response.completion_text == "pong" + assert len(requests) == 1 + assert str(requests[0].url) == ( + "https://aiplatform.googleapis.com/v1/publishers/google/" + "models/gemini-3-flash-preview:generateContent" + ) + assert requests[0].headers["x-goog-api-key"] == "test-api-key" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_api_key_text_chat_normalizes_astrbot_model_id(): + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response( + 200, + json={ + "candidates": [ + { + "content": {"role": "model", "parts": [{"text": "pong"}]}, + "finishReason": "STOP", + } + ], + "usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 1}, + "responseId": "vertex-normalized-model-response", + }, + ) + + provider = ProviderGoogleGenAI( + _vertex_api_key_config(model="google/gemini-3-flash-preview"), + provider_settings={}, + ) + await provider._http_client.aclose() + provider._http_client = httpx.AsyncClient( + transport=httpx.MockTransport(handler), + base_url=provider._get_vertex_ai_sdk_base_url(), + timeout=provider.timeout, + ) + provider.client._api_client._async_httpx_client = provider._http_client + + try: + response = await provider.text_chat(prompt="ping") + + assert response.completion_text == "pong" + assert str(requests[0].url) == ( + "https://aiplatform.googleapis.com/v1/publishers/google/" + "models/gemini-3-flash-preview:generateContent" + ) + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_api_key_provider_omits_empty_api_key_header(): + provider = ProviderGoogleGenAI( + _vertex_api_key_config(vertex_ai_api_key=[]), + provider_settings={}, + ) + + try: + headers = provider.client._api_client._http_options.headers + assert "x-goog-api-key" not in headers + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_api_key_provider_requires_key_for_text_chat(): + provider = ProviderGoogleGenAI( + _vertex_api_key_config(vertex_ai_api_key=[]), + provider_settings={}, + ) + + try: + with pytest.raises(ValueError, match="Vertex AI API key is required"): + await provider.text_chat(prompt="ping") + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_service_account_google_genai_get_models_uses_publisher_endpoint( + monkeypatch, +): + credentials = FakeVertexCredentials() + calls = [] + responses = [ + { + "models": [ + {"name": "publishers/google/models/gemini-3-flash-preview"}, + { + "name": ( + "projects/demo-project/locations/global/" + "publishers/google/models/gemini-2.5-pro" + ) + }, + ], + "nextPageToken": "page-2", + }, + {"models": [{"name": "gemini-3-flash-preview"}]}, + ] + + def fake_client(**kwargs): + assert kwargs["credentials"] is credentials + assert kwargs["project"] == "demo-project" + assert kwargs["location"] == "global" + assert kwargs["vertexai"] is True + return type( + "FakeGenAIClient", + (), + { + "aio": type( + "FakeAsyncClient", + (), + { + "models": type( + "FakeModels", + (), + {"list": lambda self: pytest.fail("models.list called")}, + )(), + "aclose": lambda self: None, + }, + )() + }, + )() + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(request) + return httpx.Response(200, json=responses.pop(0)) + + monkeypatch.setattr( + ProviderGoogleGenAI, + "_load_vertex_ai_service_account_credentials", + lambda self: credentials, + ) + monkeypatch.setattr( + "astrbot.core.provider.sources.gemini_source.genai.Client", + fake_client, + ) + + transport = httpx.MockTransport(handler) + provider = ProviderGoogleGenAI( + _vertex_config(type="googlegenai_chat_completion"), + provider_settings={}, + ) + await provider._http_client.aclose() + provider._http_client = httpx.AsyncClient( + transport=transport, + base_url=provider._get_vertex_ai_sdk_base_url(), + timeout=provider.timeout, + ) + + try: + assert await provider.get_models() == [ + "google/gemini-2.5-pro", + "google/gemini-3-flash-preview", + ] + assert len(calls) == 2 + assert calls[0].url.path == "/v1beta1/publishers/google/models" + assert calls[0].url.params["pageSize"] == "300" + assert calls[1].url.path == "/v1beta1/publishers/google/models" + assert calls[1].url.params["pageToken"] == "page-2" + assert calls[0].headers["authorization"] == "Bearer ya29.fake-token" + assert calls[0].headers["x-goog-user-project"] == "demo-project" + assert calls[0].headers["accept-encoding"] == "gzip, deflate" + assert credentials.refresh_count == 0 + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_api_key_provider_rejects_model_list_without_json( + monkeypatch, +): + provider = ProviderGoogleGenAI(_vertex_api_key_config(), provider_settings={}) + + async def fail_list(): + raise AssertionError("Vertex AI API-key mode must not call models.list()") + + monkeypatch.setattr(provider.client.models, "list", fail_list) + + try: + with pytest.raises(ValueError, match="API Key 无法获取模型列表"): + await provider.get_models() + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_key_json_provider_fetches_model_list(monkeypatch): + credentials = FakeVertexCredentials() + calls = [] + + def fake_client(**kwargs): + assert kwargs["credentials"] is credentials + assert kwargs["project"] == "demo-project" + assert kwargs["location"] == "global" + assert kwargs["vertexai"] is True + return type( + "FakeGenAIClient", + (), + { + "aio": type( + "FakeAsyncClient", + (), + { + "models": type( + "FakeModels", + (), + {"list": lambda self: pytest.fail("models.list called")}, + )(), + "aclose": lambda self: None, + }, + )() + }, + )() + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(request) + return httpx.Response( + 200, + json={ + "publisherModels": [ + {"name": "publishers/google/models/gemini-3-flash-preview"}, + ] + }, + ) + + from google.oauth2 import service_account + + captured_info = {} + + def fake_from_service_account_info(info, scopes=None): + captured_info["info"] = info + captured_info["scopes"] = scopes + return credentials + + monkeypatch.setattr( + service_account.Credentials, + "from_service_account_info", + staticmethod(fake_from_service_account_info), + ) + original_load_credentials = ( + ProviderGoogleGenAI._load_vertex_ai_service_account_credentials + ) + + def fake_load_credentials(self): + assert ( + "vertex@example.iam.gserviceaccount.com" + in self.provider_config["vertex_ai_credentials_json"] + ) + return original_load_credentials(self) + + monkeypatch.setattr( + ProviderGoogleGenAI, + "_load_vertex_ai_service_account_credentials", + fake_load_credentials, + ) + monkeypatch.setattr( + "astrbot.core.provider.sources.gemini_source.genai.Client", + fake_client, + ) + + provider = ProviderGoogleGenAI(_vertex_key_json_config(), provider_settings={}) + await provider._http_client.aclose() + provider._http_client = httpx.AsyncClient( + transport=httpx.MockTransport(handler), + base_url=provider._get_vertex_ai_sdk_base_url(), + timeout=provider.timeout, + ) + + try: + assert captured_info["info"]["project_id"] == "demo-project" + assert await provider.get_models() == ["google/gemini-3-flash-preview"] + assert calls[0].url.path == "/v1beta1/publishers/google/models" + assert calls[0].headers["authorization"] == "Bearer ya29.fake-token" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_vertex_ai_get_models_refreshes_expired_credentials(monkeypatch): + class RefreshNeededCredentials(google_auth_credentials.Credentials): + def __init__(self): + super().__init__() + self.token = None + self.refresh_count = 0 + + def refresh(self, request): + self.refresh_count += 1 + self.token = "ya29.refreshed-token" + + credentials = RefreshNeededCredentials() + calls = [] + + def fake_client(**kwargs): + return type( + "FakeGenAIClient", + (), + { + "aio": type( + "FakeAsyncClient", + (), + { + "models": type( + "FakeModels", + (), + {"list": lambda self: pytest.fail("models.list called")}, + )(), + "aclose": lambda self: None, + }, + )() + }, + )() + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(request) + return httpx.Response( + 200, + json={ + "publisherModels": [ + {"name": "publishers/google/models/gemini-3-flash-preview"}, + ] + }, + ) + + monkeypatch.setattr( + ProviderGoogleGenAI, + "_load_vertex_ai_service_account_credentials", + lambda self: credentials, + ) + monkeypatch.setattr( + "astrbot.core.provider.sources.gemini_source.genai.Client", + fake_client, + ) + + provider = ProviderGoogleGenAI( + _vertex_config(type="googlegenai_chat_completion"), + provider_settings={}, + ) + await provider._http_client.aclose() + provider._http_client = httpx.AsyncClient( + transport=httpx.MockTransport(handler), + base_url=provider._get_vertex_ai_sdk_base_url(), + timeout=provider.timeout, + ) + + try: + assert await provider.get_models() == ["google/gemini-3-flash-preview"] + assert credentials.refresh_count == 1 + assert calls[0].headers["authorization"] == "Bearer ya29.refreshed-token" + finally: + await provider.terminate() + + +def test_vertex_ai_config_template_defaults_to_json_and_global(): + provider_metadata = CONFIG_METADATA_2["provider_group"]["metadata"]["provider"] + template = provider_metadata["config_template"]["Google Vertex AI"] + assert template["provider"] == "google-vertex-ai" + assert template["type"] == "googlegenai_chat_completion" + assert template["vertex_ai_auth_type"] == "json" + assert template["vertex_ai_location"] == "global" + assert template["api_base"] == "https://aiplatform.googleapis.com/v1" + assert template["gm_safety_settings"]["harassment"] == "BLOCK_MEDIUM_AND_ABOVE" + assert template["gm_thinking_config"] == {"budget": 0, "level": "HIGH"} + + items = provider_metadata["items"] + assert items["vertex_ai_auth_type"]["options"] == ["json", "api_key"] + assert items["vertex_ai_api_key"]["type"] == "list" + assert items["vertex_ai_api_key"]["button_text"] == "API Key" + assert items["vertex_ai_api_key"]["dialog_title"] == "API Key" + assert items["vertex_ai_api_key"]["prefer_single_item"] is True + assert items["vertex_ai_api_key"]["condition"] == { + "provider": "google-vertex-ai", + "vertex_ai_auth_type": "api_key", + } + assert items["vertex_ai_location"]["condition"] == { + "provider": "google-vertex-ai", + } + assert items["vertex_ai_project_id"]["invisible"] is True + assert items["vertex_ai_credentials_path"]["invisible"] is True + assert items["vertex_ai_credentials_json"]["condition"] == { + "provider": "google-vertex-ai", + "vertex_ai_auth_type": "json", + } + + +def test_vertex_ai_config_metadata_i18n_keys_exist_for_all_locales(): + locales_dir = ( + Path(__file__).resolve().parents[1] / "dashboard" / "src" / "i18n" / "locales" + ) + expected_keys = { + "google_vertex_ai", + "vertex_ai_auth_type", + "vertex_ai_api_key", + "vertex_ai_project_id", + "vertex_ai_location", + "vertex_ai_credentials_path", + "vertex_ai_credentials_json", + } + + for locale in ("zh-CN", "en-US", "ru-RU"): + metadata_path = locales_dir / locale / "features" / "config-metadata.json" + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + provider_translations = metadata["provider_group"]["provider"] + + assert expected_keys <= provider_translations.keys() + + +def test_vertex_ai_chinese_i18n_matches_requested_copy(): + metadata_path = ( + Path(__file__).resolve().parents[1] + / "dashboard" + / "src" + / "i18n" + / "locales" + / "zh-CN" + / "features" + / "config-metadata.json" + ) + provider_translations = json.loads(metadata_path.read_text(encoding="utf-8"))[ + "provider_group" + ]["provider"] + + assert provider_translations["vertex_ai_auth_type"]["description"] == "密钥格式" + assert provider_translations["vertex_ai_credentials_json"]["hint"] == ( + "Google Cloud 服务账号 JSON 可在该网址创建获得 " + "https://console.cloud.google.com/iam-admin/serviceaccounts" + ) + assert ( + provider_translations["vertex_ai_api_key"]["hint"] + == "在 Vertex AI API Key、服务账号 JSON 选择其一使用即可。" + )