diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 4391d1701..6b376845a 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from google.cloud.spanner_v1 import Client -from google.cloud.spanner_v1.pool import AbstractSessionPool, FixedSizePool +from google.cloud.spanner_v1.pool import AbstractSessionPool, BurstyPool, FixedSizePool, PingingPool from typing_extensions import NotRequired from sqlspec.adapters.spanner._typing import SpannerConnection @@ -18,10 +18,17 @@ if TYPE_CHECKING: from collections.abc import Callable + from logging import Logger from types import TracebackType + from google.api_core.client_info import ClientInfo + from google.api_core.client_options import ClientOptions + from google.api_core.retry import Retry from google.auth.credentials import Credentials + from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig + from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest from google.cloud.spanner_v1.database import Database + from google.cloud.spanner_v1.transaction import DefaultTransactionOptions from sqlspec.config import ExtensionConfigs from sqlspec.core import StatementConfig @@ -37,15 +44,73 @@ :meth:`SpannerSyncConfig.provide_read_session`. Pulled into a module-level constant so an eventual ``SpannerAsyncConfig`` can import the same default.""" +_CLIENT_CONFIG_FIELDS = frozenset({ + "project", + "credentials", + "client_info", + "client_options", + "query_options", + "route_to_leader_enabled", + "directed_read_options", + "observability_options", + "default_transaction_options", + "experimental_host", + "disable_builtin_metrics", + "client_context", + "use_plain_text", + "ca_certificate", + "client_certificate", + "client_key", + "instance_type", +}) +_INSTANCE_CONFIG_FIELDS = frozenset({"configuration_name", "display_name", "node_count", "processing_units"}) +_DATABASE_CONFIG_FIELDS = frozenset({ + "ddl_statements", + "logger", + "encryption_config", + "database_dialect", + "database_role", + "enable_drop_protection", + "enable_interceptors_in_tests", + "proto_descriptors", +}) + class SpannerConnectionParams(TypedDict): """Spanner connection parameters.""" project: "NotRequired[str]" + credentials: "NotRequired[Credentials]" + client_info: "NotRequired[ClientInfo]" + client_options: "NotRequired[ClientOptions | dict[str, Any]]" + query_options: "NotRequired[ExecuteSqlRequest.QueryOptions]" + route_to_leader_enabled: "NotRequired[bool]" + directed_read_options: "NotRequired[DirectedReadOptions]" + observability_options: "NotRequired[Any]" + default_transaction_options: "NotRequired[DefaultTransactionOptions]" + experimental_host: "NotRequired[str]" + disable_builtin_metrics: "NotRequired[bool]" + client_context: "NotRequired[dict[str, str]]" + use_plain_text: "NotRequired[bool]" + ca_certificate: "NotRequired[str]" + client_certificate: "NotRequired[str]" + client_key: "NotRequired[str]" + instance_type: "NotRequired[str]" instance_id: "NotRequired[str]" + configuration_name: "NotRequired[str]" + display_name: "NotRequired[str]" + node_count: "NotRequired[int]" + processing_units: "NotRequired[int]" + instance_labels: "NotRequired[dict[str, str]]" database_id: "NotRequired[str]" - credentials: "NotRequired[Credentials]" - client_options: "NotRequired[dict[str, Any]]" + ddl_statements: "NotRequired[tuple[str, ...] | list[str]]" + logger: "NotRequired[Logger]" + encryption_config: "NotRequired[EncryptionConfig | dict[str, Any]]" + database_dialect: "NotRequired[DatabaseDialect]" + database_role: "NotRequired[str]" + enable_drop_protection: "NotRequired[bool]" + enable_interceptors_in_tests: "NotRequired[bool]" + proto_descriptors: "NotRequired[bytes]" extra: "NotRequired[dict[str, Any]]" @@ -53,10 +118,14 @@ class SpannerPoolParams(SpannerConnectionParams): """Session pool configuration.""" pool_type: "NotRequired[type[AbstractSessionPool]]" - min_sessions: "NotRequired[int]" + size: "NotRequired[int]" + target_size: "NotRequired[int]" max_sessions: "NotRequired[int]" + default_timeout: "NotRequired[int | float]" + session_labels: "NotRequired[dict[str, str]]" labels: "NotRequired[dict[str, str]]" ping_interval: "NotRequired[int]" + max_age_minutes: "NotRequired[int]" class SpannerDriverFeatures(TypedDict): @@ -66,7 +135,10 @@ class SpannerDriverFeatures(TypedDict): enable_uuid_conversion: Enable automatic UUID string conversion. json_serializer: Custom JSON serializer for parameter conversion. json_deserializer: Custom JSON deserializer for result conversion. - session_labels: Labels to apply to Spanner sessions. + retry: Per-request retry policy passed to execute_sql(), execute_update(), and batch_update(). + timeout: Per-request timeout in seconds passed to execute_sql(), execute_update(), and batch_update(). + session_labels: Deprecated compatibility alias for pool session labels. + Prefer ``connection_config["session_labels"]``. enable_events: Enable database event channel support. Defaults to True when extension_config["events"] is configured. events_backend: Backend type for event handling. @@ -76,6 +148,8 @@ class SpannerDriverFeatures(TypedDict): enable_uuid_conversion: "NotRequired[bool]" json_serializer: "NotRequired[Callable[[Any], str]]" json_deserializer: "NotRequired[Callable[[str], Any]]" + retry: "NotRequired[Retry | None]" + timeout: "NotRequired[float | None]" session_labels: "NotRequired[dict[str, str]]" enable_events: "NotRequired[bool]" events_backend: "NotRequired[str]" @@ -190,12 +264,23 @@ def __init__( **kwargs: Any, ) -> None: self.connection_config = normalize_connection_config(connection_config) + if "min_sessions" in self.connection_config: + msg = "Spanner session pools do not support 'min_sessions'; use 'size' or 'target_size'." + raise ImproperConfigurationError(msg) + + raw_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} + legacy_session_labels = raw_driver_features.pop("session_labels", None) + if ( + legacy_session_labels is not None + and "session_labels" not in self.connection_config + and "labels" not in self.connection_config + ): + self.connection_config["session_labels"] = legacy_session_labels - self.connection_config.setdefault("min_sessions", 1) - self.connection_config.setdefault("max_sessions", 10) + self.connection_config.setdefault("size", self.connection_config.pop("max_sessions", 10)) self.connection_config.setdefault("pool_type", FixedSizePool) - driver_features = apply_driver_features(driver_features) + driver_features = apply_driver_features(raw_driver_features) statement_config = statement_config or default_statement_config @@ -216,11 +301,8 @@ def __init__( def _get_client(self) -> Client: if self._client is None: - self._client = Client( - project=self.connection_config.get("project"), - credentials=self.connection_config.get("credentials"), - client_options=self.connection_config.get("client_options"), - ) + client_kwargs = self._resolve_kwargs(_CLIENT_CONFIG_FIELDS) + self._client = Client(**client_kwargs) return self._client def get_database(self) -> "Database": @@ -235,7 +317,15 @@ def get_database(self) -> "Database": if self._database is None: client = self._get_client() - self._database = client.instance(instance_id).database(database_id, pool=self.connection_instance) # type: ignore[no-untyped-call] + instance_kwargs = self._resolve_kwargs(_INSTANCE_CONFIG_FIELDS) + instance_labels = self.connection_config.get("instance_labels") + if instance_labels is not None: + instance_kwargs["labels"] = instance_labels + database_kwargs = self._resolve_kwargs(_DATABASE_CONFIG_FIELDS) + database_kwargs["pool"] = self.connection_instance + self._database = client.instance(instance_id, **instance_kwargs).database( # type: ignore[no-untyped-call] + database_id, **database_kwargs + ) return self._database def create_connection(self) -> SpannerConnection: @@ -250,23 +340,38 @@ def _create_pool(self) -> AbstractSessionPool: pool_type = cast("type[AbstractSessionPool]", self.connection_config.get("pool_type", FixedSizePool)) - pool_kwargs: dict[str, Any] = {} - if pool_type is FixedSizePool: - if "size" in self.connection_config: - pool_kwargs["size"] = self.connection_config["size"] - elif "max_sessions" in self.connection_config: - pool_kwargs["size"] = self.connection_config["max_sessions"] - if "labels" in self.connection_config: - pool_kwargs["labels"] = self.connection_config["labels"] + labels = self.connection_config.get("session_labels", self.connection_config.get("labels")) + pool_kwargs: dict[str, Any] = self._resolve_pool_base_kwargs(labels=cast("dict[str, str] | None", labels)) + if issubclass(pool_type, PingingPool): + pool_kwargs.update(self._resolve_kwargs({"size", "default_timeout", "ping_interval"})) + elif issubclass(pool_type, FixedSizePool): + pool_kwargs.update(self._resolve_kwargs({"size", "default_timeout", "max_age_minutes"})) + elif issubclass(pool_type, BurstyPool): + target_size = self.connection_config.get("target_size", self.connection_config.get("size")) + if target_size is not None: + pool_kwargs["target_size"] = target_size else: - valid_pool_keys = {"size", "labels", "ping_interval"} - pool_kwargs = {k: v for k, v in self.connection_config.items() if k in valid_pool_keys and v is not None} - if "size" not in pool_kwargs and "max_sessions" in self.connection_config: - pool_kwargs["size"] = self.connection_config["max_sessions"] + pool_kwargs.update( + self._resolve_kwargs({"size", "target_size", "default_timeout", "ping_interval", "max_age_minutes"}) + ) pool_factory = cast("Callable[..., AbstractSessionPool]", pool_type) return pool_factory(**pool_kwargs) + def _resolve_pool_base_kwargs(self, *, labels: "dict[str, str] | None") -> dict[str, Any]: + pool_kwargs: dict[str, Any] = {} + if labels is not None: + pool_kwargs["labels"] = labels + database_role = self.connection_config.get("database_role") + if database_role is not None: + pool_kwargs["database_role"] = database_role + return pool_kwargs + + def _resolve_kwargs(self, fields: "frozenset[str] | set[str]") -> dict[str, Any]: + return { + field: self.connection_config[field] for field in fields if self.connection_config.get(field) is not None + } + def _close_pool(self) -> None: if self.connection_instance and supports_close(self.connection_instance): self.connection_instance.close() diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index b39fa58da..8a35850e8 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -63,7 +63,11 @@ def __iter__(self) -> Iterator[Any]: ... class _SpannerReadProtocol(Protocol): def execute_sql( - self, sql: str, params: "dict[str, Any] | None" = None, param_types: "dict[str, Any] | None" = None + self, + sql: str, + params: "dict[str, Any] | None" = None, + param_types: "dict[str, Any] | None" = None, + **kwargs: Any, ) -> _SpannerResultSetProtocol: ... @@ -71,11 +75,15 @@ class _SpannerWriteProtocol(_SpannerReadProtocol, Protocol): committed: "Any | None" def execute_update( - self, sql: str, params: "dict[str, Any] | None" = None, param_types: "dict[str, Any] | None" = None + self, + sql: str, + params: "dict[str, Any] | None" = None, + param_types: "dict[str, Any] | None" = None, + **kwargs: Any, ) -> int: ... def batch_update( - self, batch: "list[tuple[str, dict[str, Any] | None, dict[str, Any]]]" + self, batch: "list[tuple[str, dict[str, Any] | None, dict[str, Any]]]", **kwargs: Any ) -> "tuple[Any, list[int]]": ... def commit(self) -> None: ... @@ -138,10 +146,11 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe params = cast("dict[str, Any] | None", params) coerced_params = self._coerce_params(params) param_types_map = self._infer_param_types(coerced_params) + execute_kwargs = self._execute_kwargs() if statement.returns_rows(): reader = cast("_SpannerReadProtocol", cursor) - result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map) + result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs) rows = list(result_set) try: metadata = result_set.metadata @@ -165,7 +174,7 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe if supports_write(cursor): writer = cast("_SpannerWriteProtocol", cursor) - row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map) + row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs) return self.create_execution_result(cursor, rowcount_override=row_count) raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE) @@ -183,6 +192,7 @@ def dispatch_execute_many(self, cursor: "SpannerConnection", statement: "SQL") - _coerce = self._coerce_params _infer = self._infer_param_types + execute_kwargs = self._execute_kwargs() param_types_cache: dict[tuple[tuple[str, type[Any]], ...], dict[str, Any]] = {} empty_param_types: dict[str, Any] = {} batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = [] @@ -200,7 +210,7 @@ def dispatch_execute_many(self, cursor: "SpannerConnection", statement: "SQL") - append_batch_arg((sql, coerced_params, param_types)) writer = cast("_SpannerWriteProtocol", cursor) - _status, row_counts = writer.batch_update(batch_args) + _status, row_counts = writer.batch_update(batch_args, **execute_kwargs) total_rows = sum(row_counts) if row_counts else 0 return self.create_execution_result(cursor, rowcount_override=total_rows, is_many_result=True) @@ -215,6 +225,7 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") script_params = cast("dict[str, Any] | None", params) coerced_params = self._coerce_params(script_params) param_types_map = self._infer_param_types(coerced_params) + execute_kwargs = self._execute_kwargs() for stmt in statements: try: parsed = _sqlglot.parse_one(stmt) @@ -225,9 +236,9 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE) if not is_select and is_transaction: writer = cast("_SpannerWriteProtocol", cursor) - writer.execute_update(stmt, params=coerced_params, param_types=param_types_map) + writer.execute_update(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs) else: - _ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map)) + _ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs)) count += 1 return self.create_execution_result( @@ -259,6 +270,9 @@ def with_cursor(self, connection: "SpannerConnection") -> "SpannerSyncCursor": def handle_database_exceptions(self) -> "SpannerExceptionHandler": return SpannerExceptionHandler() + def _execute_kwargs(self) -> dict[str, Any]: + return {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features} + # ───────────────────────────────────────────────────────────────────────────── # ARROW API METHODS # ───────────────────────────────────────────────────────────────────────────── diff --git a/tests/integration/adapters/spanner/conftest.py b/tests/integration/adapters/spanner/conftest.py index 2b26ac491..bd9a63499 100644 --- a/tests/integration/adapters/spanner/conftest.py +++ b/tests/integration/adapters/spanner/conftest.py @@ -96,8 +96,7 @@ def spanner_config( "database_id": spanner_service.database_name, "credentials": spanner_service.credentials, "client_options": {"api_endpoint": api_endpoint}, - "min_sessions": 1, - "max_sessions": 5, + "size": 5, } ) try: diff --git a/tests/integration/adapters/spanner/extensions/adk/conftest.py b/tests/integration/adapters/spanner/extensions/adk/conftest.py index 57ad9bace..1b697c64a 100644 --- a/tests/integration/adapters/spanner/extensions/adk/conftest.py +++ b/tests/integration/adapters/spanner/extensions/adk/conftest.py @@ -22,8 +22,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab "database_id": spanner_service.database_name, "credentials": spanner_service.credentials, "client_options": {"api_endpoint": api_endpoint}, - "min_sessions": 1, - "max_sessions": 5, + "size": 5, }, extension_config={"adk": {"session_table": "adk_sessions", "events_table": "adk_events"}}, ) diff --git a/tests/integration/adapters/spanner/extensions/events/conftest.py b/tests/integration/adapters/spanner/extensions/events/conftest.py index 6a7d28b7c..bd7d4370c 100644 --- a/tests/integration/adapters/spanner/extensions/events/conftest.py +++ b/tests/integration/adapters/spanner/extensions/events/conftest.py @@ -26,8 +26,7 @@ def spanner_events_config(spanner_service: SpannerService, spanner_database: "Da "database_id": spanner_service.database_name, "credentials": spanner_service.credentials, "client_options": {"api_endpoint": api_endpoint}, - "min_sessions": 1, - "max_sessions": 5, + "size": 5, }, extension_config={"events": {"queue_table": "sqlspec_event_queue"}}, ) diff --git a/tests/integration/adapters/spanner/extensions/litestar/conftest.py b/tests/integration/adapters/spanner/extensions/litestar/conftest.py index c56dd7765..71085ea05 100644 --- a/tests/integration/adapters/spanner/extensions/litestar/conftest.py +++ b/tests/integration/adapters/spanner/extensions/litestar/conftest.py @@ -22,8 +22,7 @@ def spanner_litestar_config(spanner_service: SpannerService, spanner_database: " "database_id": spanner_service.database_name, "credentials": spanner_service.credentials, "client_options": {"api_endpoint": api_endpoint}, - "min_sessions": 1, - "max_sessions": 5, + "size": 5, }, extension_config={"litestar": {"session_table": "litestar_sessions"}}, ) diff --git a/tests/unit/adapters/test_spanner/test_config.py b/tests/unit/adapters/test_spanner/test_config.py index 2226a1f26..52e5aeaa6 100644 --- a/tests/unit/adapters/test_spanner/test_config.py +++ b/tests/unit/adapters/test_spanner/test_config.py @@ -1,11 +1,22 @@ +from datetime import timedelta +from typing import TYPE_CHECKING, Any, cast + import pytest +from google.cloud.spanner_v1.pool import AbstractSessionPool, BurstyPool, FixedSizePool -from sqlspec.adapters.spanner.config import SpannerSyncConfig +from sqlspec.adapters.spanner import config as config_module +from sqlspec.adapters.spanner.config import SpannerConnectionParams, SpannerPoolParams, SpannerSyncConfig from sqlspec.adapters.spanner.core import default_statement_config from sqlspec.driver import SyncDriverAdapterBase from sqlspec.exceptions import ImproperConfigurationError from tests.conftest import requires_interpreted +if TYPE_CHECKING: + from google.api_core.client_info import ClientInfo + from google.auth.credentials import Credentials + from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest + from google.cloud.spanner_v1.transaction import DefaultTransactionOptions + pytestmark = requires_interpreted @@ -37,8 +48,14 @@ def test_config_defaults() -> None: """Test default values.""" config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) assert config.connection_config is not None - assert config.connection_config["min_sessions"] == 1 - assert config.connection_config["max_sessions"] == 10 + assert "min_sessions" not in config.connection_config + assert config.connection_config["size"] == 10 + + +def test_min_sessions_is_rejected() -> None: + """Spanner's current session pool classes do not support min_sessions.""" + with pytest.raises(ImproperConfigurationError, match="min_sessions"): + SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d", "min_sessions": 1}) def test_improper_configuration() -> None: @@ -55,6 +72,254 @@ def test_driver_features_defaults() -> None: assert config.driver_features["json_serializer"] is not None +def test_driver_feature_session_labels_are_routed_to_pool() -> None: + """Legacy driver feature session labels should configure pool session labels.""" + labels = {"workload": "analytics"} + config = SpannerSyncConfig( + connection_config={"project": "p", "instance_id": "i", "database_id": "d"}, + driver_features={"session_labels": labels}, + ) + + pool = config.provide_pool() + + assert pool.labels == labels + assert "session_labels" not in config.driver_features + + +def test_fixed_size_pool_routes_current_session_controls() -> None: + """FixedSizePool should receive the declared Spanner session pool settings.""" + labels = {"service": "api"} + config = SpannerSyncConfig( + connection_config={ + "project": "p", + "instance_id": "i", + "database_id": "d", + "pool_type": FixedSizePool, + "size": 4, + "default_timeout": 6, + "session_labels": labels, + "database_role": "reader", + "max_age_minutes": 23, + } + ) + + pool = config.provide_pool() + + assert isinstance(pool, FixedSizePool) + assert pool.size == 4 + assert pool.default_timeout == 6 + assert pool.labels == labels + assert pool.database_role == "reader" + assert pool._max_age == timedelta(minutes=23) + + +def test_bursty_pool_uses_target_size() -> None: + """BurstyPool should receive target_size instead of the FixedSize size key.""" + config = SpannerSyncConfig( + connection_config={ + "project": "p", + "instance_id": "i", + "database_id": "d", + "pool_type": BurstyPool, + "target_size": 7, + } + ) + + pool = config.provide_pool() + + assert isinstance(pool, BurstyPool) + assert pool.target_size == 7 + + +def test_get_database_routes_client_instance_and_database_settings(monkeypatch: pytest.MonkeyPatch) -> None: + """Current Google client, instance, and database settings should be forwarded.""" + + class _FakeDatabase: + def __init__(self, database_id: str, kwargs: dict[str, Any]) -> None: + self.database_id = database_id + self.kwargs = kwargs + + class _FakeInstance: + def __init__(self, instance_id: str, kwargs: dict[str, Any]) -> None: + self.instance_id = instance_id + self.kwargs = kwargs + self.database_calls: list[tuple[str, dict[str, Any]]] = [] + + def database(self, database_id: str, **kwargs: Any) -> "_FakeDatabase": + self.database_calls.append((database_id, kwargs)) + return _FakeDatabase(database_id, kwargs) + + class _FakeClient: + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + self.instances: list[_FakeInstance] = [] + created_clients.append(self) + + def instance(self, instance_id: str, **kwargs: Any) -> _FakeInstance: + instance = _FakeInstance(instance_id, kwargs) + self.instances.append(instance) + return instance + + created_clients: list[_FakeClient] = [] + monkeypatch.setattr(config_module, "Client", _FakeClient) + + client_info = object() + query_options = object() + directed_read_options = object() + observability_options = object() + default_transaction_options = object() + client_context = object() + client_options = {"api_endpoint": "spanner.example.test"} + credentials = object() + pool = cast(AbstractSessionPool, object()) + logger = object() + encryption_config = {"kms_key_name": "projects/p/locations/l/keyRings/r/cryptoKeys/k"} + + config = SpannerSyncConfig( + connection_config={ + "project": "p", + "credentials": credentials, + "client_options": client_options, + "client_info": client_info, + "query_options": query_options, + "route_to_leader_enabled": False, + "directed_read_options": directed_read_options, + "observability_options": observability_options, + "default_transaction_options": default_transaction_options, + "disable_builtin_metrics": True, + "client_context": client_context, + "use_plain_text": True, + "ca_certificate": "ca", + "client_certificate": "cert", + "client_key": "key", + "instance_type": "cloud", + "instance_id": "instance", + "configuration_name": "regional-us-central1", + "display_name": "Instance", + "node_count": 1, + "instance_labels": {"env": "test"}, + "database_id": "database", + "logger": logger, + "encryption_config": encryption_config, + "database_role": "reader", + "enable_drop_protection": True, + "enable_interceptors_in_tests": True, + "proto_descriptors": b"proto", + }, + connection_instance=pool, + ) + + database = cast(_FakeDatabase, config.get_database()) + + client = created_clients[0] + assert client.kwargs == { + "project": "p", + "credentials": credentials, + "client_options": client_options, + "client_info": client_info, + "query_options": query_options, + "route_to_leader_enabled": False, + "directed_read_options": directed_read_options, + "observability_options": observability_options, + "default_transaction_options": default_transaction_options, + "disable_builtin_metrics": True, + "client_context": client_context, + "use_plain_text": True, + "ca_certificate": "ca", + "client_certificate": "cert", + "client_key": "key", + "instance_type": "cloud", + } + instance = client.instances[0] + assert instance.instance_id == "instance" + assert instance.kwargs == { + "configuration_name": "regional-us-central1", + "display_name": "Instance", + "node_count": 1, + "labels": {"env": "test"}, + } + assert database.database_id == "database" + assert database.kwargs == { + "pool": pool, + "logger": logger, + "encryption_config": encryption_config, + "database_role": "reader", + "enable_drop_protection": True, + "enable_interceptors_in_tests": True, + "proto_descriptors": b"proto", + } + + +def test_spanner_params_type_current_client_database_and_pool_settings() -> None: + """Static type check coverage for modern client, database, and pool options.""" + connection_config: SpannerConnectionParams = { + "project": "p", + "credentials": cast("Credentials", object()), + "client_options": {"api_endpoint": "spanner.example.test"}, + "client_info": cast("ClientInfo", object()), + "query_options": cast("ExecuteSqlRequest.QueryOptions", object()), + "route_to_leader_enabled": False, + "directed_read_options": cast("DirectedReadOptions", object()), + "observability_options": object(), + "default_transaction_options": cast("DefaultTransactionOptions", object()), + "disable_builtin_metrics": True, + "client_context": {"trace": "enabled"}, + "use_plain_text": True, + "ca_certificate": "ca", + "client_certificate": "cert", + "client_key": "key", + "instance_type": "cloud", + "instance_id": "i", + "configuration_name": "regional-us-central1", + "display_name": "Instance", + "node_count": 1, + "instance_labels": {"env": "test"}, + "database_id": "d", + "database_role": "reader", + "enable_drop_protection": True, + "enable_interceptors_in_tests": True, + "proto_descriptors": b"proto", + } + pool_config: SpannerPoolParams = { + "project": "p", + "credentials": cast("Credentials", object()), + "client_options": {"api_endpoint": "spanner.example.test"}, + "client_info": cast("ClientInfo", object()), + "query_options": cast("ExecuteSqlRequest.QueryOptions", object()), + "route_to_leader_enabled": False, + "directed_read_options": cast("DirectedReadOptions", object()), + "observability_options": object(), + "default_transaction_options": cast("DefaultTransactionOptions", object()), + "disable_builtin_metrics": True, + "client_context": {"trace": "enabled"}, + "use_plain_text": True, + "ca_certificate": "ca", + "client_certificate": "cert", + "client_key": "key", + "instance_type": "cloud", + "instance_id": "i", + "configuration_name": "regional-us-central1", + "display_name": "Instance", + "node_count": 1, + "instance_labels": {"env": "test"}, + "database_id": "d", + "database_role": "reader", + "enable_drop_protection": True, + "enable_interceptors_in_tests": True, + "proto_descriptors": b"proto", + "pool_type": FixedSizePool, + "size": 4, + "target_size": 4, + "default_timeout": 6, + "session_labels": {"service": "api"}, + "max_age_minutes": 23, + "ping_interval": 300, + } + + assert connection_config["project"] == "p" + assert pool_config["size"] == 4 + + def test_provide_connection_batch_and_snapshot() -> None: """Ensure provide_connection selects snapshot vs transaction correctly.""" snap_obj = object() diff --git a/tests/unit/adapters/test_spanner/test_driver.py b/tests/unit/adapters/test_spanner/test_driver.py index 9ac98997f..85f1446a5 100644 --- a/tests/unit/adapters/test_spanner/test_driver.py +++ b/tests/unit/adapters/test_spanner/test_driver.py @@ -67,6 +67,25 @@ def test_execute_statement_select(mock_connection: MagicMock) -> None: assert result.selected_data[1] == (2, "Bob") +def test_execute_statement_select_forwards_retry_and_timeout(mock_connection: MagicMock) -> None: + retry = object() + driver = SpannerSyncDriver(mock_connection, driver_features={"retry": retry, "timeout": 12.5}) + + mock_result = MagicMock(spec=StreamedResultSet) + field = Mock() + field.name = "id" + mock_result.metadata.row_type.fields = [field] + mock_result.__iter__.return_value = iter([(1,)]) + mock_connection.execute_sql.return_value = mock_result + + statement = driver.prepare_statement("SELECT id FROM users", statement_config=driver.statement_config) + driver.dispatch_execute(mock_connection, statement) # type: ignore[protected-access] + + mock_connection.execute_sql.assert_called_once_with( + "SELECT id FROM users", params=None, param_types={}, retry=retry, timeout=12.5 + ) + + def test_execute_statement_dml_in_transaction(mock_transaction: MagicMock) -> None: driver = SpannerSyncDriver(mock_transaction) mock_transaction.execute_update.return_value = 10 @@ -78,6 +97,19 @@ def test_execute_statement_dml_in_transaction(mock_transaction: MagicMock) -> No mock_transaction.execute_update.assert_called_once() +def test_execute_statement_dml_forwards_retry_and_timeout(mock_transaction: MagicMock) -> None: + retry = object() + driver = SpannerSyncDriver(mock_transaction, driver_features={"retry": retry, "timeout": 12.5}) + mock_transaction.execute_update.return_value = 10 + + statement = driver.prepare_statement("UPDATE users SET name = 'Bob'", statement_config=driver.statement_config) + driver.dispatch_execute(mock_transaction, statement) # type: ignore[protected-access] + + mock_transaction.execute_update.assert_called_once_with( + "UPDATE users SET name = 'Bob'", params=None, param_types={}, retry=retry, timeout=12.5 + ) + + def test_insert_requires_transaction_or_update_method(mock_connection: MagicMock) -> None: driver = SpannerSyncDriver(mock_connection) # If connection doesn't have execute_update, DML should fail (Snapshot)