Skip to content

Commit 173951e

Browse files
Uziel Silvathameezb
authored andcommitted
feat: Add proxy server and fix all unit tests
1 parent 6933d99 commit 173951e

8 files changed

Lines changed: 1049 additions & 289 deletions

File tree

google/cloud/sql/connector/connector.py

Lines changed: 154 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from functools import partial
2121
import logging
2222
import os
23-
import socket
2423
from threading import Thread
2524
from types import TracebackType
2625
from typing import Any, Callable, Optional, Union
@@ -38,10 +37,9 @@
3837
from google.cloud.sql.connector.exceptions import ConnectorLoopError
3938
from google.cloud.sql.connector.instance import RefreshAheadCache
4039
from google.cloud.sql.connector.lazy import LazyRefreshCache
41-
import google.cloud.sql.connector.local_unix_socket as local_unix_socket
4240
from google.cloud.sql.connector.monitored_cache import MonitoredCache
4341
import google.cloud.sql.connector.pg8000 as pg8000
44-
from google.cloud.sql.connector.proxy import Proxy
42+
import google.cloud.sql.connector.proxy as proxy
4543
import google.cloud.sql.connector.pymysql as pymysql
4644
import google.cloud.sql.connector.pytds as pytds
4745
from google.cloud.sql.connector.resolver import DefaultResolver
@@ -220,6 +218,108 @@ def __init__(
220218
def universe_domain(self) -> str:
221219
return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN
222220

221+
async def _get_cache(
222+
self,
223+
instance_connection_string: str,
224+
enable_iam_auth: bool,
225+
ip_type: IPTypes,
226+
driver: str | None,
227+
) -> MonitoredCache:
228+
"""Helper function to get instance's cache from Connector cache."""
229+
230+
# resolve instance connection name
231+
conn_name = await self._resolver.resolve(instance_connection_string)
232+
cache_key = (str(conn_name), enable_iam_auth)
233+
234+
# if cache entry doesn't exist or is closed, create it
235+
if cache_key not in self._cache or self._cache[cache_key].closed:
236+
# if lazy refresh, init keys now
237+
if self._refresh_strategy == RefreshStrategy.LAZY and self._keys is None:
238+
self._keys = asyncio.create_task(generate_keys())
239+
# create cache
240+
if self._refresh_strategy == RefreshStrategy.LAZY:
241+
logger.debug(
242+
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
243+
)
244+
cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
245+
conn_name,
246+
self._init_client(driver),
247+
self._keys, # type: ignore
248+
enable_iam_auth,
249+
)
250+
else:
251+
logger.debug(
252+
f"['{conn_name}']: Refresh strategy is set to background refresh"
253+
)
254+
cache = RefreshAheadCache(
255+
conn_name,
256+
self._init_client(driver),
257+
self._keys, # type: ignore
258+
enable_iam_auth,
259+
)
260+
# wrap cache as a MonitoredCache
261+
monitored_cache = MonitoredCache(
262+
cache,
263+
self._failover_period,
264+
self._resolver,
265+
)
266+
logger.debug(f"['{conn_name}']: Connection info added to cache")
267+
self._cache[cache_key] = monitored_cache
268+
269+
monitored_cache = self._cache[(str(conn_name), enable_iam_auth)]
270+
271+
# Check that the information is valid and matches the driver and db type
272+
try:
273+
conn_info = await monitored_cache.connect_info()
274+
# validate driver matches intended database engine
275+
if driver:
276+
DriverMapping.validate_engine(driver, conn_info.database_version)
277+
if ip_type:
278+
conn_info.get_preferred_ip(ip_type)
279+
except Exception:
280+
await self._remove_cached(str(conn_name), enable_iam_auth)
281+
raise
282+
283+
return monitored_cache
284+
285+
async def connect_socket_async(
286+
self,
287+
instance_connection_string: str,
288+
protocol_fn: Callable[[], asyncio.Protocol],
289+
**kwargs: Any,
290+
) -> tuple[asyncio.Transport, asyncio.Protocol]:
291+
"""Helper function to connect to a Cloud SQL instance and return a socket."""
292+
293+
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
294+
ip_type = kwargs.pop("ip_type", self._ip_type)
295+
driver = kwargs.pop("driver", None)
296+
# if ip_type is str, convert to IPTypes enum
297+
if isinstance(ip_type, str):
298+
ip_type = IPTypes._from_str(ip_type)
299+
300+
monitored_cache = await self._get_cache(
301+
instance_connection_string, enable_iam_auth, ip_type, driver
302+
)
303+
304+
try:
305+
conn_info = await monitored_cache.connect_info()
306+
ctx = await conn_info.create_ssl_context(enable_iam_auth)
307+
ip_address = conn_info.get_preferred_ip(ip_type)
308+
tx, p = await self._loop.create_connection(
309+
protocol_fn, host=ip_address, port=3307, ssl=ctx
310+
)
311+
except Exception as ex:
312+
logger.exception("exception starting tls protocol", exc_info=ex)
313+
# with an error from Cloud SQL Admin API call or IP type, invalidate
314+
# the cache and re-raise the error
315+
await self._remove_cached(
316+
instance_connection_string,
317+
enable_iam_auth,
318+
)
319+
raise
320+
321+
return tx, p
322+
223323
def connect(
224324
self, instance_connection_string: str, driver: str, **kwargs: Any
225325
) -> Any:
@@ -237,7 +337,7 @@ def connect(
237337
Example: "my-project:us-central1:my-instance"
238338
239339
driver (str): A string representing the database driver to connect
240-
with. Supported drivers are pymysql, pg8000, local_unix_socket, and pytds.
340+
with. Supported drivers are pymysql, pg8000, psycopg, and pytds.
241341
242342
**kwargs: Any driver-specific arguments to pass to the underlying
243343
driver .connect call.
@@ -261,6 +361,18 @@ def connect(
261361
)
262362
return connect_future.result()
263363

364+
def _init_client(self, driver: Optional[str]) -> CloudSQLClient:
365+
"""Lazy initialize the client, setting the driver name in the user agent string."""
366+
if self._client is None:
367+
self._client = CloudSQLClient(
368+
self._sqladmin_api_endpoint,
369+
self._quota_project,
370+
self._credentials,
371+
user_agent=self._user_agent,
372+
driver=driver,
373+
)
374+
return self._client
375+
264376
async def connect_async(
265377
self, instance_connection_string: str, driver: str, **kwargs: Any
266378
) -> Any:
@@ -279,7 +391,7 @@ async def connect_async(
279391
Example: "my-project:us-central1:my-instance"
280392
281393
driver (str): A string representing the database driver to connect
282-
with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket, and
394+
with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and
283395
pytds.
284396
285397
**kwargs: Any driver-specific arguments to pass to the underlying
@@ -358,15 +470,15 @@ async def connect_async(
358470
logger.debug(f"['{conn_name}']: Connection info added to cache")
359471
self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache
360472

473+
# Map drivers to connect functions
361474
connect_func = {
362475
"pymysql": pymysql.connect,
363476
"pg8000": pg8000.connect,
364-
"local_unix_socket": local_unix_socket.connect,
365477
"asyncpg": asyncpg.connect,
366478
"pytds": pytds.connect,
367479
}
368480

369-
# only accept supported database drivers
481+
# Only accept supported database drivers
370482
try:
371483
connector: Callable = connect_func[driver] # type: ignore
372484
except KeyError:
@@ -376,6 +488,7 @@ async def connect_async(
376488
# if ip_type is str, convert to IPTypes enum
377489
if isinstance(ip_type, str):
378490
ip_type = IPTypes._from_str(ip_type)
491+
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
379492
kwargs["timeout"] = kwargs.get("timeout", self._timeout)
380493

381494
# Host and ssl options come from the certificates and metadata, so we don't
@@ -384,7 +497,12 @@ async def connect_async(
384497
kwargs.pop("ssl", None)
385498
kwargs.pop("port", None)
386499

387-
# attempt to get connection info for Cloud SQL instance
500+
monitored_cache = await self._get_cache(
501+
instance_connection_string, enable_iam_auth, ip_type, driver
502+
)
503+
conn_info = await monitored_cache.connect_info()
504+
ip_address = conn_info.get_preferred_ip(ip_type)
505+
388506
try:
389507
conn_info = await monitored_cache.connect_info()
390508
# validate driver matches intended database engine
@@ -430,39 +548,31 @@ async def connect_async(
430548
)
431549
if formatted_user != kwargs["user"]:
432550
logger.debug(
433-
f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
551+
f"['{instance_connection_string}']: "
552+
"Truncated IAM database username from "
553+
f"{kwargs['user']} to {formatted_user}"
434554
)
435555
kwargs["user"] = formatted_user
436-
try:
556+
557+
ctx = await conn_info.create_ssl_context(enable_iam_auth)
437558
# async drivers are unblocking and can be awaited directly
438559
if driver in ASYNC_DRIVERS:
439-
return await connector(
440-
ip_address,
441-
await conn_info.create_ssl_context(enable_iam_auth),
442-
**kwargs,
560+
return await connector(ip_address, ctx, **kwargs)
561+
else:
562+
# Synchronous drivers are blocking and run using executor
563+
tx, _ = await self.connect_socket_async(
564+
instance_connection_string, asyncio.Protocol, **kwargs
443565
)
444-
# Create socket with SSLContext for sync drivers
445-
ctx = await conn_info.create_ssl_context(enable_iam_auth)
446-
sock = ctx.wrap_socket(
447-
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
448-
server_hostname=ip_address,
449-
)
450-
451-
# If this connection was opened using a domain name, then store it
452-
# for later in case we need to forcibly close it on failover.
453-
if conn_info.conn_name.domain_name:
454-
monitored_cache.sockets.append(sock)
455-
# Synchronous drivers are blocking and run using executor
456-
connect_partial = partial(
457-
connector,
458-
ip_address,
459-
sock,
460-
**kwargs,
461-
)
462-
return await self._loop.run_in_executor(None, connect_partial)
566+
# See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info
567+
sock = tx.get_extra_info("ssl_object")
568+
connect_partial = partial(connector, ip_address, sock, **kwargs)
569+
return await self._loop.run_in_executor(None, connect_partial)
463570

464571
except Exception:
465572
# with any exception, we attempt a force refresh, then throw the error
573+
monitored_cache = await self._get_cache(
574+
instance_connection_string, enable_iam_auth, ip_type, driver
575+
)
466576
await monitored_cache.force_refresh()
467577
raise
468578

@@ -501,7 +611,7 @@ async def start_unix_socket_proxy_async(
501611
)
502612
await proxy_instance.start()
503613
self._proxies.append(proxy_instance)
504-
614+
505615
async def _remove_cached(
506616
self, instance_connection_string: str, enable_iam_auth: bool
507617
) -> None:
@@ -567,7 +677,6 @@ async def close_async(self) -> None:
567677
await self._client.close()
568678
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
569679

570-
571680
async def create_async_connector(
572681
ip_type: str | IPTypes = IPTypes.PUBLIC,
573682
enable_iam_auth: bool = False,
@@ -657,3 +766,13 @@ async def create_async_connector(
657766
resolver=resolver,
658767
failover_period=failover_period,
659768
)
769+
770+
771+
class ConnectorSocketFactory(proxy.ServerConnectionFactory):
772+
def __init__(self, connector:Connector, instance_connection_string:str, **kwargs):
773+
self._connector = connector
774+
self._instance_connection_string = instance_connection_string
775+
self._connect_args=kwargs
776+
777+
async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]):
778+
await self._connector.connect_socket_async(self._instance_connection_string, protocol_fn, **self._connect_args)

google/cloud/sql/connector/enums.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class DriverMapping(Enum):
6262

6363
ASYNCPG = "POSTGRES"
6464
PG8000 = "POSTGRES"
65-
LOCAL_UNIX_SOCKET = "ANY"
65+
PSYCOPG = "POSTGRES"
6666
PYMYSQL = "MYSQL"
6767
PYTDS = "SQLSERVER"
6868

@@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None:
7979
the given engine.
8080
"""
8181
mapping = DriverMapping[driver.upper()]
82-
if not mapping.value == "ANY" and not engine_version.startswith(mapping.value):
82+
if not engine_version.startswith(mapping.value):
8383
raise IncompatibleDriverError(
8484
f"Database driver '{driver}' is incompatible with database "
8585
f"version '{engine_version}'. Given driver can "

0 commit comments

Comments
 (0)