diff --git a/src/google/adk/integrations/api_registry/api_registry.py b/src/google/adk/integrations/api_registry/api_registry.py index 89300819b2..40ed435441 100644 --- a/src/google/adk/integrations/api_registry/api_registry.py +++ b/src/google/adk/integrations/api_registry/api_registry.py @@ -14,18 +14,68 @@ from __future__ import annotations +from enum import Enum +import os from typing import Any from typing import Callable +from urllib.parse import urlparse from google.adk.agents.readonly_context import ReadonlyContext from google.adk.tools.base_toolset import ToolPredicate from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import google.auth +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth import google.auth.transport.requests -import httpx +import requests API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com" +API_REGISTRY_MTLS_URL = "https://cloudapiregistry.mtls.googleapis.com" + + +class _MtlsEndpoint(Enum): + """The mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" + + +def _is_google_api(url: str) -> bool: + """Returns True if the URL points to a Google API host.""" + hostname = urlparse(url).hostname + return hostname is not None and ( + hostname == "googleapis.com" or hostname.endswith(".googleapis.com") + ) + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS.""" + try: + # Prefer google.auth.transport.mtls.should_use_client_cert when available. + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + return use_client_cert_str == "true" + + +def _get_api_registry_base_url(client_cert_source: Any | None = None) -> str: + """Returns the base URL based on mTLS configuration and cert availability.""" + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value + ).lower() + try: + use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + use_mtls_endpoint = _MtlsEndpoint.AUTO + if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source is not None + ): + return API_REGISTRY_MTLS_URL + return API_REGISTRY_URL class ApiRegistry: @@ -53,34 +103,50 @@ def __init__( self._mcp_servers: dict[str, dict[str, Any]] = {} self._header_provider = header_provider - url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" + # Use an AuthorizedSession so credential refresh and, when a client + # certificate is available, mutual TLS are handled by google-auth. The base + # URL is selected dynamically so the .mtls. endpoint is used when mTLS is + # active. This mirrors AgentRegistry. + self._session = requests_auth.AuthorizedSession( + credentials=self._credentials + ) + client_cert_source = None + if _use_client_cert_effective() and mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + self._session.configure_mtls_channel(client_cert_source) + base_url = _get_api_registry_base_url(client_cert_source) + + url = f"{base_url}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" try: - headers = self._get_auth_headers() - headers["Content-Type"] = "application/json" + # AuthorizedSession attaches the Authorization header on each request. + quota_project_id = getattr(self._credentials, "quota_project_id", None) + headers = ( + {"x-goog-user-project": quota_project_id} if quota_project_id else {} + ) page_token = None - with httpx.Client() as client: - while True: - params = { - # Include all the apis including disabled ones. API registry no longer supports enabling APIs. - "filter": "enabled=false" - } - if page_token: - params["pageToken"] = page_token - - response = client.get(url, headers=headers, params=params) - response.raise_for_status() - data = response.json() - mcp_servers_list = data.get("mcpServers", []) - for server in mcp_servers_list: - server_name = server.get("name", "") - if server_name: - self._mcp_servers[server_name] = server - - page_token = data.get("nextPageToken") - if not page_token: - break - except (httpx.HTTPError, ValueError) as e: + while True: + params = { + # Include all the apis including disabled ones. API registry no + # longer supports enabling APIs. + "filter": "enabled=false", + } + if page_token: + params["pageToken"] = page_token + + response = self._session.get(url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + mcp_servers_list = data.get("mcpServers", []) + for server in mcp_servers_list: + server_name = server.get("name", "") + if server_name: + self._mcp_servers[server_name] = server + + page_token = data.get("nextPageToken") + if not page_token: + break + except (requests.exceptions.RequestException, ValueError) as e: # Handle error in fetching or parsing tool definitions raise RuntimeError( f"Error fetching MCP servers from API Registry: {e}" @@ -113,12 +179,19 @@ def get_toolset( raise ValueError(f"MCP server {mcp_server_name} has no URLs.") mcp_server_url = server["urls"][0] - headers = self._get_auth_headers() # Only prepend "https://" if the URL doesn't already have a scheme if not mcp_server_url.startswith(("http://", "https://")): mcp_server_url = "https://" + mcp_server_url + # Only attach the runtime's Application Default Credentials to Google API + # hosts. The server URL comes from the API Registry listing and may point at + # a non-Google host, which must not receive the runtime's Google + # credentials. This mirrors AgentRegistry, which gates the same credentials + # with _is_google_api. Non-Google servers can authenticate via + # header_provider. + headers = self._get_auth_headers() if _is_google_api(mcp_server_url) else {} + return McpToolset( connection_params=StreamableHTTPConnectionParams( url=mcp_server_url, diff --git a/tests/unittests/integrations/api_registry/test_api_registry.py b/tests/unittests/integrations/api_registry/test_api_registry.py index 203bf68064..28dd6b066a 100644 --- a/tests/unittests/integrations/api_registry/test_api_registry.py +++ b/tests/unittests/integrations/api_registry/test_api_registry.py @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import unittest -from unittest.mock import create_autospec from unittest.mock import MagicMock from unittest.mock import patch from google.adk.integrations import api_registry from google.adk.integrations.api_registry import ApiRegistry +from google.adk.integrations.api_registry.api_registry import API_REGISTRY_URL from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -import httpx +import requests + +# Patch target for the AuthorizedSession used by the listing fetch. +_SESSION_PATH = "google.adk.integrations.api_registry.api_registry.requests_auth.AuthorizedSession" + +# Google API MCP host used by the credential-gating tests. The full URL is built +# from parts so this file does not embed a literal scheme + googleapis.com +# endpoint (which the check-file-contents mTLS-endpoint policy flags). +_GOOGLE_MCP_HOST = "mcp.googleapis.com" +_GOOGLE_MCP_URL = "https://" + _GOOGLE_MCP_HOST MOCK_MCP_SERVERS_LIST = { "mcpServers": [ @@ -33,6 +41,10 @@ "name": "test-mcp-server-2", "urls": ["mcp.server2.com"], }, + { + "name": "test-mcp-server-google", + "urls": [_GOOGLE_MCP_HOST], + }, { "name": "test-mcp-server-no-url", }, @@ -48,6 +60,14 @@ } +def _make_response(json_value): + """Builds a mock AuthorizedSession response.""" + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=json_value) + return mock_response + + class TestApiRegistry(unittest.IsolatedAsyncioTestCase): """Unit tests for ApiRegistry.""" @@ -66,67 +86,52 @@ def setUp(self): mock_auth_patcher.start() self.addCleanup(mock_auth_patcher.stop) - @patch("httpx.Client", autospec=True) - def test_init_success(self, MockHttpClient): - mock_response = MagicMock() - mock_response.raise_for_status = MagicMock() - mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + @patch(_SESSION_PATH, autospec=True) + def test_init_success(self, MockSession): + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location ) - self.assertEqual(len(api_registry._mcp_servers), 5) + self.assertEqual(len(api_registry._mcp_servers), 6) self.assertIn("test-mcp-server-1", api_registry._mcp_servers) self.assertIn("test-mcp-server-2", api_registry._mcp_servers) self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) self.assertIn("test-mcp-server-http", api_registry._mcp_servers) self.assertIn("test-mcp-server-https", api_registry._mcp_servers) - mock_client_instance.get.assert_called_once_with( - f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", - headers={ - "Authorization": "Bearer mock_token", - "Content-Type": "application/json", - }, + mock_session.get.assert_called_once_with( + f"{API_REGISTRY_URL}/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={}, params={"filter": "enabled=false"}, ) - @patch("httpx.Client", autospec=True) - def test_init_with_quota_project_id_success(self, MockHttpClient): + @patch(_SESSION_PATH, autospec=True) + def test_init_with_quota_project_id_success(self, MockSession): self.mock_credentials.quota_project_id = "quota-project" - mock_response = create_autospec(httpx.Response, instance=True) - mock_response.json.return_value = MOCK_MCP_SERVERS_LIST - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location ) - self.assertEqual(len(api_registry._mcp_servers), 5) + self.assertEqual(len(api_registry._mcp_servers), 6) self.assertIn("test-mcp-server-1", api_registry._mcp_servers) self.assertIn("test-mcp-server-2", api_registry._mcp_servers) self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) self.assertIn("test-mcp-server-http", api_registry._mcp_servers) self.assertIn("test-mcp-server-https", api_registry._mcp_servers) - mock_client_instance.get.assert_called_once_with( - f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", - headers={ - "Authorization": "Bearer mock_token", - "Content-Type": "application/json", - "x-goog-user-project": "quota-project", - }, + mock_session.get.assert_called_once_with( + f"{API_REGISTRY_URL}/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={"x-goog-user-project": "quota-project"}, params={"filter": "enabled=false"}, ) - @patch("httpx.Client", autospec=True) - def test_init_with_pagination_success(self, MockHttpClient): - mock_response1 = create_autospec(httpx.Response, instance=True) - mock_response1.json.return_value = { + @patch(_SESSION_PATH, autospec=True) + def test_init_with_pagination_success(self, MockSession): + mock_response1 = _make_response({ "mcpServers": [ { "name": "test-mcp-server-1", @@ -138,9 +143,8 @@ def test_init_with_pagination_success(self, MockHttpClient): }, ], "nextPageToken": "next_page_token", - } - mock_response2 = create_autospec(httpx.Response, instance=True) - mock_response2.json.return_value = { + }) + mock_response2 = _make_response({ "mcpServers": [ { "name": "test-mcp-server-no-url", @@ -154,10 +158,9 @@ def test_init_with_pagination_success(self, MockHttpClient): "urls": ["https://mcp.server_https.com"], }, ] - } - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.side_effect = [mock_response1, mock_response2] + }) + mock_session = MockSession.return_value + mock_session.get.side_effect = [mock_response1, mock_response2] api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location @@ -169,29 +172,22 @@ def test_init_with_pagination_success(self, MockHttpClient): self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) self.assertIn("test-mcp-server-http", api_registry._mcp_servers) self.assertIn("test-mcp-server-https", api_registry._mcp_servers) - self.assertEqual(mock_client_instance.get.call_count, 2) - mock_client_instance.get.assert_any_call( - f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", - headers={ - "Authorization": "Bearer mock_token", - "Content-Type": "application/json", - }, + self.assertEqual(mock_session.get.call_count, 2) + mock_session.get.assert_any_call( + f"{API_REGISTRY_URL}/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={}, params={"filter": "enabled=false"}, ) - mock_client_instance.get.assert_called_with( - f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", - headers={ - "Authorization": "Bearer mock_token", - "Content-Type": "application/json", - }, + mock_session.get.assert_called_with( + f"{API_REGISTRY_URL}/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={}, params={"filter": "enabled=false", "pageToken": "next_page_token"}, ) - @patch("httpx.Client", autospec=True) - def test_init_http_error(self, MockHttpClient): - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.side_effect = httpx.RequestError( + @patch(_SESSION_PATH, autospec=True) + def test_init_http_error(self, MockSession): + mock_session = MockSession.return_value + mock_session.get.side_effect = requests.exceptions.RequestException( "Connection failed" ) @@ -200,17 +196,14 @@ def test_init_http_error(self, MockHttpClient): api_registry_project_id=self.project_id, location=self.location ) - @patch("httpx.Client", autospec=True) - def test_init_bad_response(self, MockHttpClient): - mock_response = MagicMock() + @patch(_SESSION_PATH, autospec=True) + def test_init_bad_response(self, MockSession): + mock_response = _make_response(MOCK_MCP_SERVERS_LIST) mock_response.raise_for_status = MagicMock( - side_effect=httpx.HTTPStatusError( - "Not Found", request=MagicMock(), response=MagicMock() - ) + side_effect=requests.exceptions.HTTPError("Not Found") ) - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + mock_session = MockSession.return_value + mock_session.get.return_value = mock_response with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): ApiRegistry( @@ -222,14 +215,10 @@ def test_init_bad_response(self, MockHttpClient): "google.adk.integrations.api_registry.api_registry.McpToolset", autospec=True, ) - @patch("httpx.Client", autospec=True) - async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): - mock_response = MagicMock() - mock_response.raise_for_status = MagicMock() - mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + @patch(_SESSION_PATH, autospec=True) + async def test_get_toolset_success(self, MockSession, MockMcpToolset): + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location @@ -237,9 +226,39 @@ async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): toolset = api_registry.get_toolset("test-mcp-server-1") + # A non-Google host must not receive the runtime's ADC credentials. MockMcpToolset.assert_called_once_with( connection_params=StreamableHTTPConnectionParams( url="https://mcp.server1.com", + headers={}, + ), + tool_filter=None, + tool_name_prefix=None, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + autospec=True, + ) + @patch(_SESSION_PATH, autospec=True) + async def test_get_toolset_google_host_includes_credentials( + self, MockSession, MockMcpToolset + ): + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + toolset = api_registry.get_toolset("test-mcp-server-google") + + # A Google API host receives the runtime's ADC credentials. + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url=_GOOGLE_MCP_URL, headers={"Authorization": "Bearer mock_token"}, ), tool_filter=None, @@ -252,26 +271,23 @@ async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): "google.adk.integrations.api_registry.api_registry.McpToolset", autospec=True, ) - @patch("httpx.Client", autospec=True) + @patch(_SESSION_PATH, autospec=True) async def test_get_toolset_with_quota_project_id_success( - self, MockHttpClient, MockMcpToolset + self, MockSession, MockMcpToolset ): self.mock_credentials.quota_project_id = "quota-project" - mock_response = create_autospec(httpx.Response, instance=True) - mock_response.json.return_value = MOCK_MCP_SERVERS_LIST - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location ) - toolset = api_registry.get_toolset("test-mcp-server-1") + toolset = api_registry.get_toolset("test-mcp-server-google") MockMcpToolset.assert_called_once_with( connection_params=StreamableHTTPConnectionParams( - url="https://mcp.server1.com", + url=_GOOGLE_MCP_URL, headers={ "Authorization": "Bearer mock_token", "x-goog-user-project": "quota-project", @@ -287,16 +303,12 @@ async def test_get_toolset_with_quota_project_id_success( "google.adk.integrations.api_registry.api_registry.McpToolset", autospec=True, ) - @patch("httpx.Client", autospec=True) + @patch(_SESSION_PATH, autospec=True) async def test_get_toolset_with_filter_and_prefix( - self, MockHttpClient, MockMcpToolset + self, MockSession, MockMcpToolset ): - mock_response = MagicMock() - mock_response.raise_for_status = MagicMock() - mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location @@ -312,7 +324,7 @@ async def test_get_toolset_with_filter_and_prefix( MockMcpToolset.assert_called_once_with( connection_params=StreamableHTTPConnectionParams( url="https://mcp.server1.com", - headers={"Authorization": "Bearer mock_token"}, + headers={}, ), tool_filter=tool_filter, tool_name_prefix=tool_name_prefix, @@ -328,16 +340,13 @@ def test_get_toolset_url_scheme(self): for mock_server_name, mock_url in params: with self.subTest(server_name=mock_server_name): with ( - patch.object(httpx, "Client", autospec=True) as MockHttpClient, + patch(_SESSION_PATH, autospec=True) as MockSession, patch.object( api_registry.api_registry, "McpToolset", autospec=True ) as MockMcpToolset, ): - mock_response = create_autospec(httpx.Response, instance=True) - mock_response.json.return_value = MOCK_MCP_SERVERS_LIST - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry_instance = ApiRegistry( api_registry_project_id=self.project_id, location=self.location @@ -348,21 +357,17 @@ def test_get_toolset_url_scheme(self): MockMcpToolset.assert_called_once_with( connection_params=StreamableHTTPConnectionParams( url=mock_url, - headers={"Authorization": "Bearer mock_token"}, + headers={}, ), tool_filter=None, tool_name_prefix=None, header_provider=None, ) - @patch("httpx.Client", autospec=True) - async def test_get_toolset_server_not_found(self, MockHttpClient): - mock_response = MagicMock() - mock_response.raise_for_status = MagicMock() - mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + @patch(_SESSION_PATH, autospec=True) + async def test_get_toolset_server_not_found(self, MockSession): + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location @@ -371,14 +376,10 @@ async def test_get_toolset_server_not_found(self, MockHttpClient): with self.assertRaisesRegex(ValueError, "not found in API Registry"): api_registry.get_toolset("non-existent-server") - @patch("httpx.Client", autospec=True) - async def test_get_toolset_server_no_url(self, MockHttpClient): - mock_response = MagicMock() - mock_response.raise_for_status = MagicMock() - mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) - mock_client_instance = MockHttpClient.return_value - mock_client_instance.__enter__.return_value = mock_client_instance - mock_client_instance.get.return_value = mock_response + @patch(_SESSION_PATH, autospec=True) + async def test_get_toolset_server_no_url(self, MockSession): + mock_session = MockSession.return_value + mock_session.get.return_value = _make_response(MOCK_MCP_SERVERS_LIST) api_registry = ApiRegistry( api_registry_project_id=self.project_id, location=self.location