Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 100 additions & 27 deletions src/google/adk/integrations/api_registry/api_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,68 @@

from __future__ import annotations

from enum import Enum
import os
from typing import Any
from typing import Callable
from urllib.parse import urlparse

from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.base_toolset import ToolPredicate
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
import google.auth
from google.auth.transport import mtls
from google.auth.transport import requests as requests_auth
import google.auth.transport.requests
import httpx
import requests

API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com"
API_REGISTRY_MTLS_URL = "https://cloudapiregistry.mtls.googleapis.com"


class _MtlsEndpoint(Enum):
"""The mTLS endpoint setting."""

AUTO = "auto"
ALWAYS = "always"
NEVER = "never"


def _is_google_api(url: str) -> bool:
"""Returns True if the URL points to a Google API host."""
hostname = urlparse(url).hostname
return hostname is not None and (
hostname == "googleapis.com" or hostname.endswith(".googleapis.com")
)


def _use_client_cert_effective() -> bool:
"""Returns whether client certificate should be used for mTLS."""
try:
# Prefer google.auth.transport.mtls.should_use_client_cert when available.
return bool(mtls.should_use_client_cert())
except (ImportError, AttributeError):
use_client_cert_str = os.getenv(
"GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
).lower()
return use_client_cert_str == "true"


def _get_api_registry_base_url(client_cert_source: Any | None = None) -> str:
"""Returns the base URL based on mTLS configuration and cert availability."""
use_mtls_endpoint_str = os.getenv(
"GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value
).lower()
try:
use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str)
except ValueError:
use_mtls_endpoint = _MtlsEndpoint.AUTO
if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or (
use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source is not None
):
return API_REGISTRY_MTLS_URL
return API_REGISTRY_URL


class ApiRegistry:
Expand Down Expand Up @@ -53,34 +103,50 @@ def __init__(
self._mcp_servers: dict[str, dict[str, Any]] = {}
self._header_provider = header_provider

url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"
# Use an AuthorizedSession so credential refresh and, when a client
# certificate is available, mutual TLS are handled by google-auth. The base
# URL is selected dynamically so the .mtls. endpoint is used when mTLS is
# active. This mirrors AgentRegistry.
self._session = requests_auth.AuthorizedSession(
credentials=self._credentials
)
client_cert_source = None
if _use_client_cert_effective() and mtls.has_default_client_cert_source():
client_cert_source = mtls.default_client_cert_source()
self._session.configure_mtls_channel(client_cert_source)
base_url = _get_api_registry_base_url(client_cert_source)

url = f"{base_url}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"

try:
headers = self._get_auth_headers()
headers["Content-Type"] = "application/json"
# AuthorizedSession attaches the Authorization header on each request.
quota_project_id = getattr(self._credentials, "quota_project_id", None)
headers = (
{"x-goog-user-project": quota_project_id} if quota_project_id else {}
)
page_token = None
with httpx.Client() as client:
while True:
params = {
# Include all the apis including disabled ones. API registry no longer supports enabling APIs.
"filter": "enabled=false"
}
if page_token:
params["pageToken"] = page_token

response = client.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
mcp_servers_list = data.get("mcpServers", [])
for server in mcp_servers_list:
server_name = server.get("name", "")
if server_name:
self._mcp_servers[server_name] = server

page_token = data.get("nextPageToken")
if not page_token:
break
except (httpx.HTTPError, ValueError) as e:
while True:
params = {
# Include all the apis including disabled ones. API registry no
# longer supports enabling APIs.
"filter": "enabled=false",
}
if page_token:
params["pageToken"] = page_token

response = self._session.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
mcp_servers_list = data.get("mcpServers", [])
for server in mcp_servers_list:
server_name = server.get("name", "")
if server_name:
self._mcp_servers[server_name] = server

page_token = data.get("nextPageToken")
if not page_token:
break
except (requests.exceptions.RequestException, ValueError) as e:
# Handle error in fetching or parsing tool definitions
raise RuntimeError(
f"Error fetching MCP servers from API Registry: {e}"
Expand Down Expand Up @@ -113,12 +179,19 @@ def get_toolset(
raise ValueError(f"MCP server {mcp_server_name} has no URLs.")

mcp_server_url = server["urls"][0]
headers = self._get_auth_headers()

# Only prepend "https://" if the URL doesn't already have a scheme
if not mcp_server_url.startswith(("http://", "https://")):
mcp_server_url = "https://" + mcp_server_url

# Only attach the runtime's Application Default Credentials to Google API
# hosts. The server URL comes from the API Registry listing and may point at
# a non-Google host, which must not receive the runtime's Google
# credentials. This mirrors AgentRegistry, which gates the same credentials
# with _is_google_api. Non-Google servers can authenticate via
# header_provider.
headers = self._get_auth_headers() if _is_google_api(mcp_server_url) else {}

return McpToolset(
connection_params=StreamableHTTPConnectionParams(
url=mcp_server_url,
Expand Down
Loading