Skip to content

Commit d5058dc

Browse files
Merge branch 'main' into custom-session-id-with-agent-engine
2 parents f1fa856 + 114deef commit d5058dc

14 files changed

Lines changed: 775 additions & 31 deletions

src/google/adk/auth/auth_provider_registry.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@ def register(
4242
"""
4343
self._providers[auth_scheme_type] = provider_instance
4444

45-
def get_provider(self, auth_scheme: AuthScheme) -> BaseAuthProvider | None:
45+
def get_provider(
46+
self, auth_scheme: AuthScheme | type[AuthScheme]
47+
) -> BaseAuthProvider | None:
4648
"""Get the provider instance for an auth scheme.
4749
4850
Args:
49-
auth_scheme: The auth scheme to get provider for.
51+
auth_scheme: The auth scheme or the auth scheme type to get the provider
52+
for.
5053
5154
Returns:
5255
The provider instance if registered, None otherwise.
5356
"""
57+
if isinstance(auth_scheme, type):
58+
return self._providers.get(auth_scheme)
5459
return self._providers.get(type(auth_scheme))

src/google/adk/auth/auth_schemes.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pydantic import Field
2828

2929
from ..utils.feature_decorator import experimental
30+
from .auth_credential import BaseModelWithConfig
3031

3132

3233
class OpenIdConnectWithConfig(SecurityBase):
@@ -42,8 +43,20 @@ class OpenIdConnectWithConfig(SecurityBase):
4243
scopes: Optional[List[str]] = None
4344

4445

45-
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0 and an extra flattened OpenIdConnectWithConfig.
46-
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig]
46+
class CustomAuthScheme(BaseModelWithConfig):
47+
"""A flexible model for custom authentication schemes.
48+
49+
The subclasses must define a `default` for the `type_` field, if using OAuth2
50+
user consent flow, to ensure correct rehydration.
51+
"""
52+
53+
type_: str = Field(alias="type")
54+
55+
56+
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0, an extra flattened
57+
# OpenIdConnectWithConfig, and supports external schemes that subclasses
58+
# CustomAuthScheme
59+
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig, CustomAuthScheme]
4760

4861

4962
class OAuthGrantType(str, Enum):

src/google/adk/auth/auth_tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@ def get_credential_key(self):
106106
if auth_scheme.model_extra:
107107
auth_scheme = auth_scheme.model_copy(deep=True)
108108
auth_scheme.model_extra.clear()
109+
110+
type_ = auth_scheme.type_
111+
type_name = type_.name if type_ and hasattr(type_, "name") else str(type_)
109112
scheme_name = (
110-
f"{auth_scheme.type_.name}_{_stable_model_digest(auth_scheme)}"
113+
f"{type_name}_{_stable_model_digest(auth_scheme)}"
111114
if auth_scheme
112115
else ""
113116
)

src/google/adk/auth/base_auth_provider.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
from abc import ABC
1818
from abc import abstractmethod
19+
from typing import TYPE_CHECKING
20+
21+
if TYPE_CHECKING:
22+
from .auth_schemes import AuthScheme
1923

2024
from ..agents.callback_context import CallbackContext
2125
from ..features import experimental
@@ -28,6 +32,15 @@
2832
class BaseAuthProvider(ABC):
2933
"""Abstract base class for custom authentication providers."""
3034

35+
@property
36+
def supported_auth_schemes(self) -> tuple[type[AuthScheme], ...]:
37+
"""The AuthScheme types supported by this provider.
38+
39+
Subclasses can override this to return a tuple of scheme types, enabling
40+
1-parameter registration.
41+
"""
42+
return ()
43+
3144
@abstractmethod
3245
async def get_auth_credential(
3346
self, auth_config: AuthConfig, context: CallbackContext

src/google/adk/auth/credential_manager.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
from collections.abc import Sequence
1718
import logging
19+
import threading
1820
from typing import Optional
1921

2022
from fastapi.openapi.models import OAuth2
@@ -25,10 +27,13 @@
2527
from .auth_credential import AuthCredential
2628
from .auth_credential import AuthCredentialTypes
2729
from .auth_provider_registry import AuthProviderRegistry
30+
from .auth_schemes import AuthScheme
2831
from .auth_schemes import AuthSchemeType
32+
from .auth_schemes import CustomAuthScheme
2933
from .auth_schemes import ExtendedOAuth2
3034
from .auth_schemes import OpenIdConnectWithConfig
3135
from .auth_tool import AuthConfig
36+
from .base_auth_provider import BaseAuthProvider
3237
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
3338
from .exchanger.base_credential_exchanger import ExchangeResult
3439
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
@@ -38,6 +43,25 @@
3843
logger = logging.getLogger("google_adk." + __name__)
3944

4045

46+
def _rehydrate_custom_scheme(
47+
scheme: CustomAuthScheme, supported_schemes: Sequence[type[AuthScheme]]
48+
) -> CustomAuthScheme:
49+
"""Rehydrate a CustomAuthScheme into one of the given supported_schemes."""
50+
incoming_type = scheme.type_
51+
for scheme_class in supported_schemes:
52+
type_field = scheme_class.model_fields.get("type_")
53+
# Custom AuthScheme classes must define a `default` for their `type_` field
54+
# to be rehydrated correctly.
55+
if type_field and type_field.default == incoming_type:
56+
data = scheme.model_dump(by_alias=True)
57+
if scheme.model_extra:
58+
data.update(scheme.model_extra)
59+
return scheme_class.model_validate(data)
60+
raise ValueError(
61+
f"Cannot rehydrate: no registered scheme matches type '{incoming_type}'"
62+
)
63+
64+
4165
@experimental
4266
class CredentialManager:
4367
"""Manages authentication credentials through a structured workflow.
@@ -77,12 +101,32 @@ class CredentialManager:
77101
```
78102
"""
79103

104+
_auth_provider_registry = AuthProviderRegistry()
105+
_registry_lock = threading.Lock()
106+
107+
@classmethod
108+
def register_auth_provider(cls, provider: BaseAuthProvider) -> None:
109+
"""Public API for developers to register custom auth providers."""
110+
with cls._registry_lock:
111+
for scheme_type in provider.supported_auth_schemes:
112+
existing_provider = cls._auth_provider_registry.get_provider(
113+
scheme_type
114+
)
115+
if existing_provider is not None:
116+
if existing_provider is not provider:
117+
logger.warning(
118+
"An auth provider is already registered for scheme %s. "
119+
"Ignoring the new provider.",
120+
scheme_type,
121+
)
122+
continue
123+
cls._auth_provider_registry.register(scheme_type, provider)
124+
80125
def __init__(
81126
self,
82127
auth_config: AuthConfig,
83128
):
84129
self._auth_config = auth_config
85-
self._auth_provider_registry = AuthProviderRegistry()
86130
self._exchanger_registry = CredentialExchangerRegistry()
87131
self._refresher_registry = CredentialRefresherRegistry()
88132
self._discovery_manager = OAuth2DiscoveryManager()
@@ -139,6 +183,20 @@ async def get_auth_credential(
139183
) -> Optional[AuthCredential]:
140184
"""Load and prepare authentication credential through a structured workflow."""
141185

186+
# Pydantic may have deserialized an unknown scheme into a generic
187+
# CustomAuthScheme. If so, rehydrate it first into a specific subclass.
188+
# Note: Custom authentication scheme classes must have been imported into
189+
# the Python runtime before get_auth_credential is called for their
190+
# subclasses to be registered. This is fine as developer will anyway import
191+
# them while registering the auth providers.
192+
# Note: `__subclasses__()` only returns immediate subclasses, if there is a
193+
# subclass of a subclass of CustomAuthScheme then it will not be returned.
194+
# pylint: disable=unidiomatic-typecheck Needs exact class matching.
195+
if type(self._auth_config.auth_scheme) is CustomAuthScheme:
196+
self._auth_config.auth_scheme = _rehydrate_custom_scheme(
197+
self._auth_config.auth_scheme,
198+
CustomAuthScheme.__subclasses__(),
199+
)
142200
# First, check if a registered auth provider is available before attempting
143201
# to retrieve tokens natively.
144202
provider = self._auth_provider_registry.get_provider(

src/google/adk/cli/cli_tools_click.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,7 @@ def cli_optimize(
10641064
from .cli_eval import _collect_eval_results
10651065
from .cli_eval import _collect_inferences
10661066
from .cli_eval import get_root_agent
1067+
10671068
except ModuleNotFoundError as mnf:
10681069
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) from mnf
10691070

@@ -1199,6 +1200,7 @@ def cli_add_eval_case(
11991200
from ..evaluation.eval_case import EvalCase
12001201
from ..evaluation.eval_case import SessionInput
12011202
from .cli_eval import get_eval_sets_manager
1203+
12021204
except ModuleNotFoundError as mnf:
12031205
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) from mnf
12041206

@@ -1247,6 +1249,127 @@ def cli_add_eval_case(
12471249
raise click.ClickException(f"Failed to add eval case(s): {e}") from e
12481250

12491251

1252+
@eval_set.command("generate_eval_cases", cls=HelpfulCommand)
1253+
@click.argument(
1254+
"agent_module_file_path",
1255+
type=click.Path(
1256+
exists=True, dir_okay=True, file_okay=False, resolve_path=True
1257+
),
1258+
)
1259+
@click.argument("eval_set_id", type=str, required=True)
1260+
@click.option(
1261+
"--user_simulation_config_file",
1262+
type=click.Path(
1263+
exists=True, dir_okay=False, file_okay=True, resolve_path=True
1264+
),
1265+
help=(
1266+
"A path to file containing JSON serialized "
1267+
"UserScenarioGenerationConfig dict."
1268+
),
1269+
required=True,
1270+
)
1271+
@eval_options()
1272+
def cli_generate_eval_cases(
1273+
agent_module_file_path: str,
1274+
eval_set_id: str,
1275+
user_simulation_config_file: str,
1276+
eval_storage_uri: Optional[str] = None,
1277+
log_level: str = "INFO",
1278+
):
1279+
"""Generates eval cases dynamically and adds them to the given eval set.
1280+
1281+
Uses Vertex AI Eval SDK to generate conversation scenarios based on an
1282+
Agent's info and definitions. It will automatically create the empty eval_set
1283+
if it has not been created in advance.
1284+
1285+
Args:
1286+
agent_module_file_path: The path to the agent module file.
1287+
eval_set_id: The id of the eval set to generate cases for.
1288+
user_simulation_config_file: The path to the user simulation config file.
1289+
eval_storage_uri: The eval storage uri.
1290+
log_level: The log level.
1291+
"""
1292+
logs.setup_adk_logger(getattr(logging, log_level.upper()))
1293+
try:
1294+
from ..evaluation._vertex_ai_scenario_generation_facade import ScenarioGenerator
1295+
from ..evaluation.conversation_scenarios import ConversationGenerationConfig
1296+
from ..evaluation.eval_case import EvalCase
1297+
from ..evaluation.eval_case import SessionInput
1298+
from .cli_eval import get_eval_sets_manager
1299+
from .cli_eval import get_root_agent
1300+
from .utils.state import create_empty_state
1301+
1302+
except ModuleNotFoundError as mnf:
1303+
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) from mnf
1304+
1305+
app_name = os.path.basename(agent_module_file_path)
1306+
agents_dir = os.path.dirname(agent_module_file_path)
1307+
1308+
try:
1309+
eval_sets_manager = get_eval_sets_manager(eval_storage_uri, agents_dir)
1310+
root_agent = get_root_agent(agent_module_file_path)
1311+
1312+
# Try to create if it doesn't already exist.
1313+
if (
1314+
eval_sets_manager.get_eval_set(
1315+
app_name=app_name, eval_set_id=eval_set_id
1316+
)
1317+
is None
1318+
):
1319+
eval_sets_manager.create_eval_set(
1320+
app_name=app_name, eval_set_id=eval_set_id
1321+
)
1322+
click.echo(f"Eval set '{eval_set_id}' created for app '{app_name}'.")
1323+
else:
1324+
click.echo(f"Eval set '{eval_set_id}' already exists.")
1325+
1326+
with open(user_simulation_config_file, "r") as f:
1327+
config = ConversationGenerationConfig.model_validate_json(f.read())
1328+
1329+
generator = ScenarioGenerator()
1330+
click.echo("Generating scenarios utilizing Vertex AI Eval SDK...")
1331+
scenarios = generator.generate_scenarios(root_agent, config)
1332+
1333+
# TODO(pthodoroff): Expose initial session state when simulation library
1334+
# supports it.
1335+
initial_session_state = create_empty_state(root_agent)
1336+
1337+
session_input = SessionInput(
1338+
app_name=app_name, user_id="test_user_id", state=initial_session_state
1339+
)
1340+
1341+
for scenario in scenarios:
1342+
scenario_str = json.dumps(scenario.model_dump(), sort_keys=True)
1343+
eval_id = hashlib.sha256(scenario_str.encode("utf-8")).hexdigest()[:8]
1344+
eval_case = EvalCase(
1345+
eval_id=eval_id,
1346+
conversation_scenario=scenario,
1347+
session_input=session_input,
1348+
creation_timestamp=datetime.now().timestamp(),
1349+
)
1350+
1351+
if (
1352+
eval_sets_manager.get_eval_case(
1353+
app_name=app_name, eval_set_id=eval_set_id, eval_case_id=eval_id
1354+
)
1355+
is None
1356+
):
1357+
eval_sets_manager.add_eval_case(
1358+
app_name=app_name, eval_set_id=eval_set_id, eval_case=eval_case
1359+
)
1360+
click.echo(
1361+
f"Eval case '{eval_case.eval_id}' added to eval set"
1362+
f" '{eval_set_id}'."
1363+
)
1364+
else:
1365+
click.echo(
1366+
f"Eval case '{eval_case.eval_id}' already exists in eval set"
1367+
f" '{eval_set_id}', skipped adding."
1368+
)
1369+
except Exception as e:
1370+
raise click.ClickException(f"Failed to generate eval case(s): {e}") from e
1371+
1372+
12501373
def web_options():
12511374
"""Decorator to add web UI options to click commands."""
12521375

0 commit comments

Comments
 (0)