Skip to content

Commit d783d63

Browse files
feat: automatically reset connection on failover
1 parent 3b01b0d commit d783d63

8 files changed

Lines changed: 197 additions & 10 deletions

File tree

google/cloud/sql/connector/connection_info.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import abc
1718
from dataclasses import dataclass
1819
import logging
1920
import ssl
@@ -34,6 +35,27 @@
3435
logger = logging.getLogger(name=__name__)
3536

3637

38+
class ConnectionInfoCache(abc.ABC):
    """Interface shared by the Connector's connection info caches.

    Concrete caches must provide the three coroutines ``connect_info``,
    ``force_refresh`` and ``close`` plus the read-only ``closed`` property.
    """

    @abc.abstractmethod
    async def connect_info(self) -> ConnectionInfo:
        """Return the ``ConnectionInfo`` held by this cache."""

    @abc.abstractmethod
    async def force_refresh(self) -> None:
        """Force a refresh of the cached connection info."""

    @abc.abstractmethod
    async def close(self) -> None:
        """Close the cache."""

    @property
    @abc.abstractmethod
    def closed(self) -> bool:
        """Whether the cache has been closed."""
57+
58+
3759
@dataclass
3860
class ConnectionInfo:
3961
"""Contains all necessary information to connect securely to the

google/cloud/sql/connector/connection_name.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def __str__(self) -> str:
3838
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
3939
return f"{self.project}:{self.region}:{self.instance_name}"
4040

41+
def get_connection_string(self) -> str:
    """Return the instance connection string as ``project:region:instance``.

    The domain name is not included in the result.
    """
    return "{}:{}:{}".format(self.project, self.region, self.instance_name)
44+
4145

4246
def _parse_connection_name(connection_name: str) -> ConnectionName:
4347
return _parse_connection_name_with_domain_name(connection_name, "")

google/cloud/sql/connector/connector.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from google.cloud.sql.connector.enums import RefreshStrategy
3636
from google.cloud.sql.connector.instance import RefreshAheadCache
3737
from google.cloud.sql.connector.lazy import LazyRefreshCache
38+
from google.cloud.sql.connector.monitored_cache import MonitoredCache
3839
import google.cloud.sql.connector.pg8000 as pg8000
3940
import google.cloud.sql.connector.pymysql as pymysql
4041
import google.cloud.sql.connector.pytds as pytds
@@ -149,9 +150,7 @@ def __init__(
149150
)
150151
# initialize dict to store caches, key is a tuple consisting of instance
151152
# connection name string and enable_iam_auth boolean flag
152-
self._cache: dict[
153-
tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache]
154-
] = {}
153+
self._cache: dict[tuple[str, bool], MonitoredCache] = {}
155154
self._client: Optional[CloudSQLClient] = None
156155

157156
# initialize credentials
@@ -289,14 +288,14 @@ async def connect_async(
289288
)
290289
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
291290
if (instance_connection_string, enable_iam_auth) in self._cache:
292-
cache = self._cache[(instance_connection_string, enable_iam_auth)]
291+
monitored_cache = self._cache[(instance_connection_string, enable_iam_auth)]
293292
else:
294293
conn_name = await self._resolver.resolve(instance_connection_string)
295294
if self._refresh_strategy == RefreshStrategy.LAZY:
296295
logger.debug(
297296
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
298297
)
299-
cache = LazyRefreshCache(
298+
cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
300299
conn_name,
301300
self._client,
302301
self._keys,
@@ -312,8 +311,14 @@ async def connect_async(
312311
self._keys,
313312
enable_iam_auth,
314313
)
314+
# wrap cache as a MonitoredCache
315+
monitored_cache = MonitoredCache(
316+
cache,
317+
self._failover_period,
318+
self._resolver,
319+
)
315320
logger.debug(f"['{conn_name}']: Connection info added to cache")
316-
self._cache[(instance_connection_string, enable_iam_auth)] = cache
321+
self._cache[(instance_connection_string, enable_iam_auth)] = monitored_cache
317322

318323
connect_func = {
319324
"pymysql": pymysql.connect,
@@ -342,7 +347,7 @@ async def connect_async(
342347

343348
# attempt to get connection info for Cloud SQL instance
344349
try:
345-
conn_info = await cache.connect_info()
350+
conn_info = await monitored_cache.connect_info()
346351
# validate driver matches intended database engine
347352
DriverMapping.validate_engine(driver, conn_info.database_version)
348353
ip_address = conn_info.get_preferred_ip(ip_type)

google/cloud/sql/connector/instance.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from google.cloud.sql.connector.client import CloudSQLClient
2626
from google.cloud.sql.connector.connection_info import ConnectionInfo
27+
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
2728
from google.cloud.sql.connector.connection_name import ConnectionName
2829
from google.cloud.sql.connector.exceptions import RefreshNotValidError
2930
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
@@ -35,7 +36,7 @@
3536
APPLICATION_NAME = "cloud-sql-python-connector"
3637

3738

38-
class RefreshAheadCache:
39+
class RefreshAheadCache(ConnectionInfoCache):
3940
"""Cache that refreshes connection info in the background prior to expiration.
4041
4142
Background tasks are used to schedule refresh attempts to get a new
@@ -74,6 +75,15 @@ def __init__(
7475
self._refresh_in_progress = asyncio.locks.Event()
7576
self._current: asyncio.Task = self._schedule_refresh(0)
7677
self._next: asyncio.Task = self._current
78+
self._closed = False
79+
80+
@property
def conn_name(self) -> ConnectionName:
    """Read-only access to the connection name stored on this cache."""
    return self._conn_name
83+
84+
@property
def closed(self) -> bool:
    """Read-only view of the ``_closed`` flag of this cache."""
    return self._closed
7787

7888
async def force_refresh(self) -> None:
7989
"""
@@ -212,3 +222,4 @@ async def close(self) -> None:
212222
# gracefully wait for tasks to cancel
213223
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
214224
await asyncio.wait_for(tasks, timeout=2.0)
225+
self._closed = True

google/cloud/sql/connector/lazy.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121

2222
from google.cloud.sql.connector.client import CloudSQLClient
2323
from google.cloud.sql.connector.connection_info import ConnectionInfo
24+
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
2425
from google.cloud.sql.connector.connection_name import ConnectionName
2526
from google.cloud.sql.connector.refresh_utils import _refresh_buffer
2627

2728
logger = logging.getLogger(name=__name__)
2829

2930

30-
class LazyRefreshCache:
31+
class LazyRefreshCache(ConnectionInfoCache):
3132
"""Cache that refreshes connection info when a caller requests a connection.
3233
3334
Only refreshes the cache when a new connection is requested and the current
@@ -62,6 +63,15 @@ def __init__(
6263
self._lock = asyncio.Lock()
6364
self._cached: Optional[ConnectionInfo] = None
6465
self._needs_refresh = False
66+
self._closed = False
67+
68+
@property
def conn_name(self) -> ConnectionName:
    """Read-only access to the connection name stored on this cache."""
    return self._conn_name
71+
72+
@property
def closed(self) -> bool:
    """Read-only view of the ``_closed`` flag of this cache."""
    return self._closed
6575

6676
async def force_refresh(self) -> None:
6777
"""
@@ -121,4 +131,5 @@ async def close(self) -> None:
121131
"""Close is a no-op and provided purely for a consistent interface with
122132
other cache types.
123133
"""
124-
pass
134+
self._closed = True
135+
return
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import logging
17+
from typing import Any, Callable, Optional, Union
18+
19+
from google.cloud.sql.connector.connection_info import ConnectionInfo
20+
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
21+
from google.cloud.sql.connector.instance import RefreshAheadCache
22+
from google.cloud.sql.connector.lazy import LazyRefreshCache
23+
from google.cloud.sql.connector.resolver import DefaultResolver
24+
from google.cloud.sql.connector.resolver import DnsResolver
25+
26+
logger = logging.getLogger(name=__name__)
27+
28+
29+
class MonitoredCache(ConnectionInfoCache):
    """Connection info cache wrapper that polls a domain name for changes.

    ``connect_info``, ``force_refresh`` and ``closed`` are delegated to the
    wrapped cache. When the wrapped cache's connection name has a domain
    name, a background task re-resolves that domain name every
    ``failover_period`` seconds and closes this cache if the domain name now
    resolves to a different connection name.
    """

    def __init__(
        self,
        cache: Union[RefreshAheadCache, LazyRefreshCache],
        failover_period: int,
        resolver: Union[DefaultResolver, DnsResolver],
    ) -> None:
        """Wrap ``cache`` and start domain name polling when applicable.

        When the connection name has a domain name the polling task is
        started with ``asyncio.create_task``, so an event loop must be
        running at construction time in that case.

        Args:
            cache: The cache to wrap; its ``conn_name`` decides whether
                polling is configured.
            failover_period: Seconds to sleep between domain name checks.
            resolver: Resolver used to re-resolve the domain name.
        """
        self.resolver = resolver
        self.cache = cache
        # Polling task; stays None when the connection name has no domain name.
        self.domain_name_ticker: Optional[asyncio.Task] = None
        # Initialised to zero; not read or updated anywhere in this class.
        self.open_conns_count: int = 0

        if self.cache.conn_name.domain_name:
            self.domain_name_ticker = asyncio.create_task(
                ticker(failover_period, self._check_domain_name)
            )
            logger.debug(
                f"['{self.cache.conn_name}']: Configured polling of domain "
                f"name with failover period of {failover_period} seconds."
            )

    @property
    def closed(self) -> bool:
        """Whether the wrapped cache reports itself as closed."""
        return self.cache.closed

    async def _check_domain_name(self) -> None:
        """Re-resolve the domain name and close the cache if it changed.

        Failures are logged at debug level and never raised, so a failed
        check does not disturb the polling loop.
        """
        conn_name = self.cache.conn_name
        try:
            # Resolve domain name and see if Cloud SQL instance connection name
            # has changed. If it has, close all connections.
            new_conn_name = await self.resolver.resolve(conn_name.domain_name)
        except Exception as e:
            # Domain name checks should not be fatal, log error and continue.
            logger.debug(
                f"['{conn_name}']: Unable to check domain name, "
                f"domain name {conn_name.domain_name} did not "
                f"resolve: {e}"
            )
            return

        # Handled separately from resolution so that a failure while closing
        # is not reported as a resolution failure.
        try:
            if new_conn_name != conn_name:
                logger.debug(
                    f"['{conn_name}']: Cloud SQL instance changed "
                    f"from {conn_name.get_connection_string()} to "
                    f"{new_conn_name.get_connection_string()}, closing all "
                    "connections!"
                )
                await self.close()
        except Exception as e:
            logger.debug(
                f"['{conn_name}']: Unable to close cache after Cloud SQL "
                f"instance change: {e}"
            )

    async def connect_info(self) -> ConnectionInfo:
        """Return the connection info provided by the wrapped cache."""
        return await self.cache.connect_info()

    async def force_refresh(self) -> None:
        """Force a refresh on the wrapped cache."""
        return await self.cache.force_refresh()

    async def close(self) -> None:
        """Stop domain name polling and close the wrapped cache.

        Safe to call more than once: the wrapped cache is only closed if it
        is not already closed.
        """
        # Cancel domain name ticker task.
        polling_task = self.domain_name_ticker
        if polling_task:
            polling_task.cancel()
            # gather(return_exceptions=True) absorbs the ticker's own
            # CancelledError but still lets a cancellation of *this*
            # coroutine propagate, unlike a bare ``except CancelledError``.
            await asyncio.gather(polling_task, return_exceptions=True)
            logger.debug(
                f"['{self.cache.conn_name}']: Cancelled domain name polling task."
            )

        # If cache is already closed, no further work.
        if self.cache.closed:
            return
        await self.cache.close()
99+
100+
101+
# Strong references to tasks spawned by ``ticker``. The event loop only keeps
# weak references to tasks, so an otherwise unreferenced task could be
# garbage-collected before it finishes.
_ticker_tasks: set[asyncio.Task] = set()


async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None:
    """
    Ticker function to sleep for specified interval and then schedule call
    to given function.

    Runs until cancelled. Each call is scheduled as its own task, so a slow
    call does not delay the next tick.

    Args:
        interval: Seconds to sleep before each call is scheduled.
        function: Callable returning an awaitable; invoked as
            ``function(*args, **kwargs)``.
        *args: Positional arguments forwarded to ``function``.
        **kwargs: Keyword arguments forwarded to ``function``.
    """
    while True:
        # Sleep for interval and then schedule task
        await asyncio.sleep(interval)
        task = asyncio.create_task(function(*args, **kwargs))
        # Hold a reference until the task is done; tracked at module level so
        # in-flight calls stay referenced even after the ticker is cancelled.
        _ticker_tasks.add(task)
        task.add_done_callback(_ticker_tasks.discard)

tests/unit/test_connection_name.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def test_ConnectionName() -> None:
3030
assert conn_name.domain_name == ""
3131
# test ConnectionName str() method prints instance connection name
3232
assert str(conn_name) == "project:region:instance"
33+
# test ConnectionName.get_connection_string
34+
assert conn_name.get_connection_string() == "project:region:instance"
3335

3436

3537
def test_ConnectionName_with_domain_name() -> None:
@@ -41,6 +43,8 @@ def test_ConnectionName_with_domain_name() -> None:
4143
assert conn_name.domain_name == "db.example.com"
4244
# test ConnectionName str() method prints with domain name
4345
assert str(conn_name) == "db.example.com -> project:region:instance"
46+
# test ConnectionName.get_connection_string
47+
assert conn_name.get_connection_string() == "project:region:instance"
4448

4549

4650
@pytest.mark.parametrize(

tests/unit/test_lazy.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,27 @@
2121
from google.cloud.sql.connector.utils import generate_keys
2222

2323

24+
async def test_LazyRefreshCache_properties(fake_client: CloudSQLClient) -> None:
    """
    Check the ``conn_name`` and ``closed`` properties of LazyRefreshCache.
    """
    expected_name = ConnectionName("test-project", "test-region", "test-instance")
    key_task = asyncio.create_task(generate_keys())
    lazy_cache = LazyRefreshCache(
        expected_name,
        client=fake_client,
        keys=key_task,
        enable_iam_auth=False,
    )
    # conn_name reflects the name passed to the constructor
    assert lazy_cache.conn_name == expected_name
    # a fresh cache is not closed
    assert lazy_cache.closed is False
    # closing the cache flips the closed property
    await lazy_cache.close()
    assert lazy_cache.closed is True
43+
44+
2445
async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> None:
2546
"""
2647
Test that LazyRefreshCache.connect_info works as expected.

0 commit comments

Comments
 (0)