Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 131 additions & 26 deletions sqlspec/adapters/spanner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from google.cloud.spanner_v1 import Client
from google.cloud.spanner_v1.pool import AbstractSessionPool, FixedSizePool
from google.cloud.spanner_v1.pool import AbstractSessionPool, BurstyPool, FixedSizePool, PingingPool
from typing_extensions import NotRequired

from sqlspec.adapters.spanner._typing import SpannerConnection
Expand All @@ -18,10 +18,17 @@

if TYPE_CHECKING:
from collections.abc import Callable
from logging import Logger
from types import TracebackType

from google.api_core.client_info import ClientInfo
from google.api_core.client_options import ClientOptions
from google.api_core.retry import Retry
from google.auth.credentials import Credentials
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig
from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.transaction import DefaultTransactionOptions

from sqlspec.config import ExtensionConfigs
from sqlspec.core import StatementConfig
Expand All @@ -37,26 +44,88 @@
:meth:`SpannerSyncConfig.provide_read_session`. Pulled into a module-level
constant so an eventual ``SpannerAsyncConfig`` can import the same default."""

_CLIENT_CONFIG_FIELDS = frozenset({
"project",
"credentials",
"client_info",
"client_options",
"query_options",
"route_to_leader_enabled",
"directed_read_options",
"observability_options",
"default_transaction_options",
"experimental_host",
"disable_builtin_metrics",
"client_context",
"use_plain_text",
"ca_certificate",
"client_certificate",
"client_key",
"instance_type",
})
_INSTANCE_CONFIG_FIELDS = frozenset({"configuration_name", "display_name", "node_count", "processing_units"})
_DATABASE_CONFIG_FIELDS = frozenset({
"ddl_statements",
"logger",
"encryption_config",
"database_dialect",
"database_role",
"enable_drop_protection",
"enable_interceptors_in_tests",
"proto_descriptors",
})


class SpannerConnectionParams(TypedDict):
"""Spanner connection parameters."""

project: "NotRequired[str]"
credentials: "NotRequired[Credentials]"
client_info: "NotRequired[ClientInfo]"
client_options: "NotRequired[ClientOptions | dict[str, Any]]"
query_options: "NotRequired[ExecuteSqlRequest.QueryOptions]"
route_to_leader_enabled: "NotRequired[bool]"
directed_read_options: "NotRequired[DirectedReadOptions]"
observability_options: "NotRequired[Any]"
default_transaction_options: "NotRequired[DefaultTransactionOptions]"
experimental_host: "NotRequired[str]"
disable_builtin_metrics: "NotRequired[bool]"
client_context: "NotRequired[dict[str, str]]"
use_plain_text: "NotRequired[bool]"
ca_certificate: "NotRequired[str]"
client_certificate: "NotRequired[str]"
client_key: "NotRequired[str]"
instance_type: "NotRequired[str]"
instance_id: "NotRequired[str]"
configuration_name: "NotRequired[str]"
display_name: "NotRequired[str]"
node_count: "NotRequired[int]"
processing_units: "NotRequired[int]"
instance_labels: "NotRequired[dict[str, str]]"
database_id: "NotRequired[str]"
credentials: "NotRequired[Credentials]"
client_options: "NotRequired[dict[str, Any]]"
ddl_statements: "NotRequired[tuple[str, ...] | list[str]]"
logger: "NotRequired[Logger]"
encryption_config: "NotRequired[EncryptionConfig | dict[str, Any]]"
database_dialect: "NotRequired[DatabaseDialect]"
database_role: "NotRequired[str]"
enable_drop_protection: "NotRequired[bool]"
enable_interceptors_in_tests: "NotRequired[bool]"
proto_descriptors: "NotRequired[bytes]"
extra: "NotRequired[dict[str, Any]]"


class SpannerPoolParams(SpannerConnectionParams):
"""Session pool configuration."""

pool_type: "NotRequired[type[AbstractSessionPool]]"
min_sessions: "NotRequired[int]"
size: "NotRequired[int]"
target_size: "NotRequired[int]"
max_sessions: "NotRequired[int]"
default_timeout: "NotRequired[int | float]"
session_labels: "NotRequired[dict[str, str]]"
labels: "NotRequired[dict[str, str]]"
ping_interval: "NotRequired[int]"
max_age_minutes: "NotRequired[int]"


class SpannerDriverFeatures(TypedDict):
Expand All @@ -66,7 +135,10 @@ class SpannerDriverFeatures(TypedDict):
enable_uuid_conversion: Enable automatic UUID string conversion.
json_serializer: Custom JSON serializer for parameter conversion.
json_deserializer: Custom JSON deserializer for result conversion.
session_labels: Labels to apply to Spanner sessions.
retry: Per-request retry policy passed to execute_sql(), execute_update(), and batch_update().
timeout: Per-request timeout in seconds passed to execute_sql(), execute_update(), and batch_update().
session_labels: Deprecated compatibility alias for pool session labels.
Prefer ``connection_config["session_labels"]``.
enable_events: Enable database event channel support.
Defaults to True when extension_config["events"] is configured.
events_backend: Backend type for event handling.
Expand All @@ -76,6 +148,8 @@ class SpannerDriverFeatures(TypedDict):
enable_uuid_conversion: "NotRequired[bool]"
json_serializer: "NotRequired[Callable[[Any], str]]"
json_deserializer: "NotRequired[Callable[[str], Any]]"
retry: "NotRequired[Retry | None]"
timeout: "NotRequired[float | None]"
session_labels: "NotRequired[dict[str, str]]"
enable_events: "NotRequired[bool]"
events_backend: "NotRequired[str]"
Expand Down Expand Up @@ -190,12 +264,23 @@ def __init__(
**kwargs: Any,
) -> None:
self.connection_config = normalize_connection_config(connection_config)
if "min_sessions" in self.connection_config:
msg = "Spanner session pools do not support 'min_sessions'; use 'size' or 'target_size'."
raise ImproperConfigurationError(msg)

raw_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
legacy_session_labels = raw_driver_features.pop("session_labels", None)
if (
legacy_session_labels is not None
and "session_labels" not in self.connection_config
and "labels" not in self.connection_config
):
self.connection_config["session_labels"] = legacy_session_labels

self.connection_config.setdefault("min_sessions", 1)
self.connection_config.setdefault("max_sessions", 10)
self.connection_config.setdefault("size", self.connection_config.pop("max_sessions", 10))
self.connection_config.setdefault("pool_type", FixedSizePool)

driver_features = apply_driver_features(driver_features)
driver_features = apply_driver_features(raw_driver_features)

statement_config = statement_config or default_statement_config

Expand All @@ -216,11 +301,8 @@ def __init__(

def _get_client(self) -> Client:
if self._client is None:
self._client = Client(
project=self.connection_config.get("project"),
credentials=self.connection_config.get("credentials"),
client_options=self.connection_config.get("client_options"),
)
client_kwargs = self._resolve_kwargs(_CLIENT_CONFIG_FIELDS)
self._client = Client(**client_kwargs)
return self._client

def get_database(self) -> "Database":
Expand All @@ -235,7 +317,15 @@ def get_database(self) -> "Database":

if self._database is None:
client = self._get_client()
self._database = client.instance(instance_id).database(database_id, pool=self.connection_instance) # type: ignore[no-untyped-call]
instance_kwargs = self._resolve_kwargs(_INSTANCE_CONFIG_FIELDS)
instance_labels = self.connection_config.get("instance_labels")
if instance_labels is not None:
instance_kwargs["labels"] = instance_labels
database_kwargs = self._resolve_kwargs(_DATABASE_CONFIG_FIELDS)
database_kwargs["pool"] = self.connection_instance
self._database = client.instance(instance_id, **instance_kwargs).database( # type: ignore[no-untyped-call]
database_id, **database_kwargs
)
return self._database

def create_connection(self) -> SpannerConnection:
Expand All @@ -250,23 +340,38 @@ def _create_pool(self) -> AbstractSessionPool:

pool_type = cast("type[AbstractSessionPool]", self.connection_config.get("pool_type", FixedSizePool))

pool_kwargs: dict[str, Any] = {}
if pool_type is FixedSizePool:
if "size" in self.connection_config:
pool_kwargs["size"] = self.connection_config["size"]
elif "max_sessions" in self.connection_config:
pool_kwargs["size"] = self.connection_config["max_sessions"]
if "labels" in self.connection_config:
pool_kwargs["labels"] = self.connection_config["labels"]
labels = self.connection_config.get("session_labels", self.connection_config.get("labels"))
pool_kwargs: dict[str, Any] = self._resolve_pool_base_kwargs(labels=cast("dict[str, str] | None", labels))
if issubclass(pool_type, PingingPool):
pool_kwargs.update(self._resolve_kwargs({"size", "default_timeout", "ping_interval"}))
elif issubclass(pool_type, FixedSizePool):
pool_kwargs.update(self._resolve_kwargs({"size", "default_timeout", "max_age_minutes"}))
elif issubclass(pool_type, BurstyPool):
target_size = self.connection_config.get("target_size", self.connection_config.get("size"))
if target_size is not None:
pool_kwargs["target_size"] = target_size
else:
valid_pool_keys = {"size", "labels", "ping_interval"}
pool_kwargs = {k: v for k, v in self.connection_config.items() if k in valid_pool_keys and v is not None}
if "size" not in pool_kwargs and "max_sessions" in self.connection_config:
pool_kwargs["size"] = self.connection_config["max_sessions"]
pool_kwargs.update(
self._resolve_kwargs({"size", "target_size", "default_timeout", "ping_interval", "max_age_minutes"})
)

pool_factory = cast("Callable[..., AbstractSessionPool]", pool_type)
return pool_factory(**pool_kwargs)

def _resolve_pool_base_kwargs(self, *, labels: "dict[str, str] | None") -> dict[str, Any]:
pool_kwargs: dict[str, Any] = {}
if labels is not None:
pool_kwargs["labels"] = labels
database_role = self.connection_config.get("database_role")
if database_role is not None:
pool_kwargs["database_role"] = database_role
return pool_kwargs

def _resolve_kwargs(self, fields: "frozenset[str] | set[str]") -> dict[str, Any]:
return {
field: self.connection_config[field] for field in fields if self.connection_config.get(field) is not None
}

def _close_pool(self) -> None:
if self.connection_instance and supports_close(self.connection_instance):
self.connection_instance.close()
Expand Down
30 changes: 22 additions & 8 deletions sqlspec/adapters/spanner/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,27 @@ def __iter__(self) -> Iterator[Any]: ...

class _SpannerReadProtocol(Protocol):
def execute_sql(
self, sql: str, params: "dict[str, Any] | None" = None, param_types: "dict[str, Any] | None" = None
self,
sql: str,
params: "dict[str, Any] | None" = None,
param_types: "dict[str, Any] | None" = None,
**kwargs: Any,
) -> _SpannerResultSetProtocol: ...


class _SpannerWriteProtocol(_SpannerReadProtocol, Protocol):
committed: "Any | None"

def execute_update(
self, sql: str, params: "dict[str, Any] | None" = None, param_types: "dict[str, Any] | None" = None
self,
sql: str,
params: "dict[str, Any] | None" = None,
param_types: "dict[str, Any] | None" = None,
**kwargs: Any,
) -> int: ...

def batch_update(
self, batch: "list[tuple[str, dict[str, Any] | None, dict[str, Any]]]"
self, batch: "list[tuple[str, dict[str, Any] | None, dict[str, Any]]]", **kwargs: Any
) -> "tuple[Any, list[int]]": ...

def commit(self) -> None: ...
Expand Down Expand Up @@ -138,10 +146,11 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe
params = cast("dict[str, Any] | None", params)
coerced_params = self._coerce_params(params)
param_types_map = self._infer_param_types(coerced_params)
execute_kwargs = self._execute_kwargs()

if statement.returns_rows():
reader = cast("_SpannerReadProtocol", cursor)
result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map)
result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs)
rows = list(result_set)
try:
metadata = result_set.metadata
Expand All @@ -165,7 +174,7 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe

if supports_write(cursor):
writer = cast("_SpannerWriteProtocol", cursor)
row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map)
row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs)
return self.create_execution_result(cursor, rowcount_override=row_count)

raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE)
Expand All @@ -183,6 +192,7 @@ def dispatch_execute_many(self, cursor: "SpannerConnection", statement: "SQL") -

_coerce = self._coerce_params
_infer = self._infer_param_types
execute_kwargs = self._execute_kwargs()
param_types_cache: dict[tuple[tuple[str, type[Any]], ...], dict[str, Any]] = {}
empty_param_types: dict[str, Any] = {}
batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = []
Expand All @@ -200,7 +210,7 @@ def dispatch_execute_many(self, cursor: "SpannerConnection", statement: "SQL") -
append_batch_arg((sql, coerced_params, param_types))

writer = cast("_SpannerWriteProtocol", cursor)
_status, row_counts = writer.batch_update(batch_args)
_status, row_counts = writer.batch_update(batch_args, **execute_kwargs)
total_rows = sum(row_counts) if row_counts else 0

return self.create_execution_result(cursor, rowcount_override=total_rows, is_many_result=True)
Expand All @@ -215,6 +225,7 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL")
script_params = cast("dict[str, Any] | None", params)
coerced_params = self._coerce_params(script_params)
param_types_map = self._infer_param_types(coerced_params)
execute_kwargs = self._execute_kwargs()
for stmt in statements:
try:
parsed = _sqlglot.parse_one(stmt)
Expand All @@ -225,9 +236,9 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL")
raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE)
if not is_select and is_transaction:
writer = cast("_SpannerWriteProtocol", cursor)
writer.execute_update(stmt, params=coerced_params, param_types=param_types_map)
writer.execute_update(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs)
else:
_ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map))
_ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs))
count += 1

return self.create_execution_result(
Expand Down Expand Up @@ -259,6 +270,9 @@ def with_cursor(self, connection: "SpannerConnection") -> "SpannerSyncCursor":
def handle_database_exceptions(self) -> "SpannerExceptionHandler":
return SpannerExceptionHandler()

def _execute_kwargs(self) -> dict[str, Any]:
return {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features}

# ─────────────────────────────────────────────────────────────────────────────
# ARROW API METHODS
# ─────────────────────────────────────────────────────────────────────────────
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/adapters/spanner/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def spanner_config(
"database_id": spanner_service.database_name,
"credentials": spanner_service.credentials,
"client_options": {"api_endpoint": api_endpoint},
"min_sessions": 1,
"max_sessions": 5,
"size": 5,
}
)
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab
"database_id": spanner_service.database_name,
"credentials": spanner_service.credentials,
"client_options": {"api_endpoint": api_endpoint},
"min_sessions": 1,
"max_sessions": 5,
"size": 5,
},
extension_config={"adk": {"session_table": "adk_sessions", "events_table": "adk_events"}},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def spanner_events_config(spanner_service: SpannerService, spanner_database: "Da
"database_id": spanner_service.database_name,
"credentials": spanner_service.credentials,
"client_options": {"api_endpoint": api_endpoint},
"min_sessions": 1,
"max_sessions": 5,
"size": 5,
},
extension_config={"events": {"queue_table": "sqlspec_event_queue"}},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def spanner_litestar_config(spanner_service: SpannerService, spanner_database: "
"database_id": spanner_service.database_name,
"credentials": spanner_service.credentials,
"client_options": {"api_endpoint": api_endpoint},
"min_sessions": 1,
"max_sessions": 5,
"size": 5,
},
extension_config={"litestar": {"session_table": "litestar_sessions"}},
)
Expand Down
Loading
Loading