Skip to content

Commit a0037fd

Browse files
Merge branch 'main' into fix/eval-skip-agent-only-invocations
2 parents bc8727d + 0acee31 commit a0037fd

9 files changed

Lines changed: 1359 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ test = [
123123
"a2a-sdk>=0.3.0,<0.4.0",
124124
"anthropic>=0.43.0", # For anthropic model tests
125125
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+
126+
"google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0",
126127
"google-cloud-parametermanager>=0.4.0, <1.0.0",
127128
"kubernetes>=29.0.0", # For GkeCodeExecutor
128129
"langchain-community>=0.3.17",
@@ -176,6 +177,10 @@ toolbox = ["toolbox-adk>=1.0.0, <2.0.0"]
176177

177178
slack = ["slack-bolt>=1.22.0"]
178179

180+
agent-identity = [
181+
"google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0",
182+
]
183+
179184
[tool.pyink]
180185
# Format py files following Google style-guide
181186
line-length = 80
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# GCP IAM Connector Auth
2+
3+
Manages the complete lifecycle of an access token using the Google Cloud
4+
Platform Agent Identity Credentials service.
5+
6+
## Usage
7+
8+
1. **Install Dependencies:**
9+
```bash
10+
pip install "google-adk[agent-identity]"
11+
```
12+
13+
2. **Register the provider:**
14+
Register the `GcpAuthProvider` with the `CredentialManager`. This is to be
15+
done one time.
16+
17+
``` py
18+
# user_agent_app.py
19+
from google.adk.auth.credential_manager import CredentialManager
20+
from google.adk.integrations.agent_identity import GcpAuthProvider
21+
22+
CredentialManager.register_auth_provider(GcpAuthProvider())
23+
```
24+
25+
3. **Configure the Auth provider:**
26+
Specify the Agent Identity provider configuration using the
27+
`GcpAuthProviderScheme`.
28+
``` py
29+
# user_agent_app.py
30+
from google.adk.integrations.agent_identity import GcpAuthProviderScheme
31+
32+
# Configures Toolset
33+
auth_scheme = GcpAuthProviderScheme(name="my-jira-auth_provider")
34+
mcp_toolset_jira = McpToolset(..., auth_scheme=auth_scheme)
35+
```
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .gcp_auth_provider import GcpAuthProvider
16+
from .gcp_auth_provider_scheme import GcpAuthProviderScheme
17+
18+
__all__ = [
19+
"GcpAuthProvider",
20+
"GcpAuthProviderScheme",
21+
]
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
import os
20+
import time
21+
22+
from google.adk.agents.callback_context import CallbackContext
23+
from google.adk.auth.auth_credential import AuthCredential
24+
from google.adk.auth.auth_credential import AuthCredentialTypes
25+
from google.adk.auth.auth_credential import HttpAuth
26+
from google.adk.auth.auth_credential import HttpCredentials
27+
from google.adk.auth.auth_credential import OAuth2Auth
28+
from google.adk.auth.auth_tool import AuthConfig
29+
from google.adk.auth.base_auth_provider import BaseAuthProvider
30+
from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
31+
from google.api_core.client_options import ClientOptions
32+
from google.cloud.iamconnectorcredentials_v1alpha import IAMConnectorCredentialsServiceClient as Client
33+
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsMetadata
34+
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest
35+
from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse
36+
from google.longrunning.operations_pb2 import Operation
37+
from typing_extensions import override
38+
39+
from .gcp_auth_provider_scheme import GcpAuthProviderScheme
40+
41+
# Notes on the current Agent Identity Credentials service implementation:
42+
# 1. The service does not yet support LROs, so even though the
43+
# retrieve_credentials method returns an Operation object, the methods like
44+
# operation.done() and operation.result() will not work yet.
45+
# 2. For API key flows, the returned Operation contains the credentials.
46+
# 3. For 2-legged OAuth flows, the returned Operation contains pending status,
47+
# client needs to retry the request until response with credentials is
48+
# returned or timeout occurs.
49+
# 4. For 3-legged OAuth flows, the returned Operation contains consent pending
50+
# status along with the authorization URI.
51+
52+
# TODO: Catch specific exceptions instead of generic ones.
53+
54+
logger = logging.getLogger("google_adk." + __name__)
55+
56+
NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC: float = 1.0
57+
NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC: float = 10.0
58+
59+
60+
def _construct_auth_credential(
61+
response: RetrieveCredentialsResponse,
62+
) -> AuthCredential:
63+
"""Constructs a simplified HTTP auth credential from the header-token tuple returned by the upstream service."""
64+
if not response.header or not response.token:
65+
raise ValueError(
66+
"Received either empty header or token from Agent Identity Credentials"
67+
" service."
68+
)
69+
70+
header_name, _, header_value = response.header.partition(":")
71+
if (
72+
header_name.strip().lower() == "authorization"
73+
and header_value.strip().lower().startswith("bearer")
74+
):
75+
return AuthCredential(
76+
auth_type=AuthCredentialTypes.HTTP,
77+
http=HttpAuth(
78+
scheme="bearer",
79+
credentials=HttpCredentials(token=response.token),
80+
),
81+
)
82+
83+
# Handle custom header.
84+
return AuthCredential(
85+
auth_type=AuthCredentialTypes.HTTP,
86+
http=HttpAuth(
87+
# For custom headers, scheme and credentials fields are not used.
88+
scheme="",
89+
credentials=HttpCredentials(),
90+
additional_headers={
91+
response.header: response.token,
92+
"X-GOOG-API-KEY": response.token,
93+
},
94+
),
95+
)
96+
97+
98+
class GcpAuthProvider(BaseAuthProvider):
99+
"""An auth provider that uses the Agent Identity Credentials service to generate access tokens."""
100+
101+
_client: Client | None = None
102+
103+
def __init__(self, client: Client | None = None):
104+
self._client = client
105+
106+
@property
107+
@override
108+
def supported_auth_schemes(self) -> tuple[type[GcpAuthProviderScheme], ...]:
109+
return (GcpAuthProviderScheme,)
110+
111+
def _get_client(self) -> Client:
112+
"""Lazy loads the client to avoid unnecessary setup on startup."""
113+
if self._client is None:
114+
client_options = None
115+
if host := os.environ.get("IAM_CONNECTOR_CREDENTIALS_TARGET_HOST"):
116+
client_options = ClientOptions(api_endpoint=host)
117+
self._client = Client(client_options=client_options, transport="rest")
118+
return self._client
119+
120+
async def _retrieve_credentials(
121+
self,
122+
user_id: str,
123+
auth_scheme: GcpAuthProviderScheme,
124+
) -> Operation:
125+
request = RetrieveCredentialsRequest(
126+
connector=auth_scheme.name,
127+
user_id=user_id,
128+
scopes=auth_scheme.scopes,
129+
continue_uri=auth_scheme.continue_uri or "",
130+
force_refresh=False,
131+
)
132+
# TODO: Use async client once available. Temporarily using threading to
133+
# prevent blocking the event loop.
134+
operation = await asyncio.to_thread(
135+
self._get_client().retrieve_credentials, request
136+
)
137+
return operation.operation
138+
139+
def _unpack_operation(
140+
self, operation: Operation
141+
) -> tuple[
142+
RetrieveCredentialsResponse | None, RetrieveCredentialsMetadata | None
143+
]:
144+
"""Deserializes the response and metadata from the operation."""
145+
response = None
146+
metadata = None
147+
if operation.response:
148+
response = RetrieveCredentialsResponse.deserialize(
149+
operation.response.value
150+
)
151+
if operation.metadata:
152+
metadata = RetrieveCredentialsMetadata.deserialize(
153+
operation.metadata.value
154+
)
155+
return response, metadata
156+
157+
async def _poll_credentials(
158+
self, user_id: str, auth_scheme: GcpAuthProviderScheme, timeout: float
159+
) -> Operation:
160+
end_time = time.time() + timeout
161+
while time.time() < end_time:
162+
operation = await self._retrieve_credentials(user_id, auth_scheme)
163+
if operation.done:
164+
return operation
165+
await asyncio.sleep(NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC)
166+
raise TimeoutError("Timeout waiting for credentials.")
167+
168+
@staticmethod
169+
def _is_consent_completed(context: CallbackContext) -> bool:
170+
"""Checks if the user consent flow is completed for the current function call."""
171+
if not context.function_call_id:
172+
return False
173+
174+
if not context.session:
175+
return False
176+
177+
events = context.session.events
178+
target_tool_call_id = context.function_call_id
179+
180+
# Find all relevant function calls and responses
181+
euc_calls = {}
182+
euc_responses = {}
183+
184+
for event in events:
185+
for call in event.get_function_calls():
186+
if call.name == REQUEST_EUC_FUNCTION_CALL_NAME:
187+
euc_calls[call.id] = call
188+
for response in event.get_function_responses():
189+
if response.name == REQUEST_EUC_FUNCTION_CALL_NAME:
190+
euc_responses[response.id] = response
191+
192+
# Check for a response that matches a call for the current tool invocation
193+
for call_id, _ in euc_responses.items():
194+
if call_id in euc_calls:
195+
call = euc_calls[call_id]
196+
if (
197+
call.args
198+
and call.args.get("function_call_id") == target_tool_call_id
199+
):
200+
return True
201+
return False
202+
203+
@override
204+
async def get_auth_credential(
205+
self,
206+
auth_config: AuthConfig,
207+
context: CallbackContext | None = None,
208+
) -> AuthCredential:
209+
"""Retrieves credentials using the Agent Identity Credentials service.
210+
211+
Args:
212+
auth_config: The authentication configuration.
213+
context: Optional context for the callback.
214+
215+
Returns:
216+
An AuthCredential instance.
217+
218+
Raises:
219+
ValueError: If auth_scheme is not a GcpAuthProviderScheme.
220+
RuntimeError: If credential retrieval or polling fails.
221+
"""
222+
223+
auth_scheme = auth_config.auth_scheme
224+
if not isinstance(auth_scheme, GcpAuthProviderScheme):
225+
raise ValueError(
226+
f"Expected GcpAuthProviderScheme, got {type(auth_scheme)}"
227+
)
228+
229+
if context is None or context.user_id is None:
230+
raise ValueError(
231+
"GcpAuthProvider requires a context with a valid user_id."
232+
)
233+
234+
user_id = context.user_id
235+
236+
try:
237+
operation = await self._retrieve_credentials(user_id, auth_scheme)
238+
except Exception as e:
239+
raise RuntimeError(
240+
f"Failed to retrieve credential for user '{user_id}' on connector"
241+
f" '{auth_scheme.name}'."
242+
) from e
243+
244+
response, metadata = self._unpack_operation(operation)
245+
246+
if operation.HasField("error"):
247+
raise RuntimeError(f"Operation failed: {operation.error.message}")
248+
249+
if operation.done:
250+
logger.debug("Auth credential obtained immediately.")
251+
return _construct_auth_credential(response)
252+
253+
if metadata and metadata.consent_pending:
254+
# Get 2-legged OAuth token. Allow enough time for token exchange.
255+
try:
256+
operation = await self._poll_credentials(
257+
user_id,
258+
auth_scheme,
259+
timeout=NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC,
260+
)
261+
if operation.HasField("error"):
262+
raise RuntimeError(f"Operation failed: {operation.error.message}")
263+
if operation.done:
264+
logger.debug("Auth credential obtained after polling.")
265+
response, _ = self._unpack_operation(operation)
266+
return _construct_auth_credential(response)
267+
except Exception as e:
268+
raise RuntimeError(
269+
f"Failed to retrieve credential for user '{user_id}' on connector"
270+
f" '{auth_scheme.name}'."
271+
) from e
272+
273+
if metadata is not None and metadata.uri_consent_required:
274+
if self._is_consent_completed(context):
275+
raise RuntimeError("Failed to retrieve consent based credential.")
276+
277+
# Return AuthCredential with only auth_uri to trigger user consent flow.
278+
return AuthCredential(
279+
auth_type=AuthCredentialTypes.OAUTH2,
280+
oauth2=OAuth2Auth(
281+
auth_uri=metadata.uri_consent_required.authorization_uri,
282+
nonce=metadata.uri_consent_required.consent_nonce,
283+
),
284+
)

0 commit comments

Comments
 (0)