2020from functools import partial
2121import logging
2222import os
23- import socket
2423from threading import Thread
2524from types import TracebackType
2625from typing import Any , Callable , Optional , Union
3837from google .cloud .sql .connector .exceptions import ConnectorLoopError
3938from google .cloud .sql .connector .instance import RefreshAheadCache
4039from google .cloud .sql .connector .lazy import LazyRefreshCache
41- import google .cloud .sql .connector .local_unix_socket as local_unix_socket
4240from google .cloud .sql .connector .monitored_cache import MonitoredCache
4341import google .cloud .sql .connector .pg8000 as pg8000
44- from google .cloud .sql .connector .proxy import Proxy
42+ import google .cloud .sql .connector .proxy as proxy
4543import google .cloud .sql .connector .pymysql as pymysql
4644import google .cloud .sql .connector .pytds as pytds
4745from google .cloud .sql .connector .resolver import DefaultResolver
@@ -220,6 +218,108 @@ def __init__(
220218 def universe_domain (self ) -> str :
221219 return self ._universe_domain or _DEFAULT_UNIVERSE_DOMAIN
222220
221+ async def _get_cache (
222+ self ,
223+ instance_connection_string : str ,
224+ enable_iam_auth : bool ,
225+ ip_type : IPTypes ,
226+ driver : str | None ,
227+ ) -> MonitoredCache :
228+ """Helper function to get instance's cache from Connector cache."""
229+
230+ # resolve instance connection name
231+ conn_name = await self ._resolver .resolve (instance_connection_string )
232+ cache_key = (str (conn_name ), enable_iam_auth )
233+
234+ # if cache entry doesn't exist or is closed, create it
235+ if cache_key not in self ._cache or self ._cache [cache_key ].closed :
236+ # if lazy refresh, init keys now
237+ if self ._refresh_strategy == RefreshStrategy .LAZY and self ._keys is None :
238+ self ._keys = asyncio .create_task (generate_keys ())
239+ # create cache
240+ if self ._refresh_strategy == RefreshStrategy .LAZY :
241+ logger .debug (
242+ f"['{ conn_name } ']: Refresh strategy is set to lazy refresh"
243+ )
244+ cache : Union [LazyRefreshCache , RefreshAheadCache ] = LazyRefreshCache (
245+ conn_name ,
246+ self ._init_client (driver ),
247+ self ._keys , # type: ignore
248+ enable_iam_auth ,
249+ )
250+ else :
251+ logger .debug (
252+ f"['{ conn_name } ']: Refresh strategy is set to background refresh"
253+ )
254+ cache = RefreshAheadCache (
255+ conn_name ,
256+ self ._init_client (driver ),
257+ self ._keys , # type: ignore
258+ enable_iam_auth ,
259+ )
260+ # wrap cache as a MonitoredCache
261+ monitored_cache = MonitoredCache (
262+ cache ,
263+ self ._failover_period ,
264+ self ._resolver ,
265+ )
266+ logger .debug (f"['{ conn_name } ']: Connection info added to cache" )
267+ self ._cache [cache_key ] = monitored_cache
268+
269+ monitored_cache = self ._cache [(str (conn_name ), enable_iam_auth )]
270+
271+ # Check that the information is valid and matches the driver and db type
272+ try :
273+ conn_info = await monitored_cache .connect_info ()
274+ # validate driver matches intended database engine
275+ if driver :
276+ DriverMapping .validate_engine (driver , conn_info .database_version )
277+ if ip_type :
278+ conn_info .get_preferred_ip (ip_type )
279+ except Exception :
280+ await self ._remove_cached (str (conn_name ), enable_iam_auth )
281+ raise
282+
283+ return monitored_cache
284+
285+ async def connect_socket_async (
286+ self ,
287+ instance_connection_string : str ,
288+ protocol_fn : Callable [[], asyncio .Protocol ],
289+ ** kwargs : Any ,
290+ ) -> tuple [asyncio .Transport , asyncio .Protocol ]:
291+ """Helper function to connect to a Cloud SQL instance and return a socket."""
292+
293+ enable_iam_auth = kwargs .pop ("enable_iam_auth" , self ._enable_iam_auth )
294+ ip_type = kwargs .pop ("ip_type" , self ._ip_type )
295+ driver = kwargs .pop ("driver" , None )
296+ # if ip_type is str, convert to IPTypes enum
297+ if isinstance (ip_type , str ):
298+ ip_type = IPTypes ._from_str (ip_type )
299+
300+ monitored_cache = await self ._get_cache (
301+ instance_connection_string , enable_iam_auth , ip_type , driver
302+ )
303+
304+ try :
305+ conn_info = await monitored_cache .connect_info ()
306+ ctx = await conn_info .create_ssl_context (enable_iam_auth )
307+ ip_address = conn_info .get_preferred_ip (ip_type )
308+ tx , p = await self ._loop .create_connection (
309+ protocol_fn , host = ip_address , port = 3307 , ssl = ctx
310+ )
311+ except Exception as ex :
312+ logger .exception ("exception starting tls protocol" , exc_info = ex )
313+ # with an error from Cloud SQL Admin API call or IP type, invalidate
314+ # the cache and re-raise the error
315+ await self ._remove_cached (
316+ instance_connection_string ,
317+ enable_iam_auth ,
318+ )
319+ raise
320+
321+ return tx , p
322+
223323 def connect (
224324 self , instance_connection_string : str , driver : str , ** kwargs : Any
225325 ) -> Any :
@@ -237,7 +337,7 @@ def connect(
237337 Example: "my-project:us-central1:my-instance"
238338
239339 driver (str): A string representing the database driver to connect
240- with. Supported drivers are pymysql, pg8000, local_unix_socket , and pytds.
340+ with. Supported drivers are pymysql, pg8000, psycopg , and pytds.
241341
242342 **kwargs: Any driver-specific arguments to pass to the underlying
243343 driver .connect call.
@@ -261,6 +361,18 @@ def connect(
261361 )
262362 return connect_future .result ()
263363
364+ def _init_client (self , driver : Optional [str ]) -> CloudSQLClient :
365+ """Lazy initialize the client, setting the driver name in the user agent string."""
366+ if self ._client is None :
367+ self ._client = CloudSQLClient (
368+ self ._sqladmin_api_endpoint ,
369+ self ._quota_project ,
370+ self ._credentials ,
371+ user_agent = self ._user_agent ,
372+ driver = driver ,
373+ )
374+ return self ._client
375+
264376 async def connect_async (
265377 self , instance_connection_string : str , driver : str , ** kwargs : Any
266378 ) -> Any :
@@ -279,7 +391,7 @@ async def connect_async(
279391 Example: "my-project:us-central1:my-instance"
280392
281393 driver (str): A string representing the database driver to connect
282- with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket , and
394+ with. Supported drivers are pymysql, asyncpg, pg8000, psycopg , and
283395 pytds.
284396
285397 **kwargs: Any driver-specific arguments to pass to the underlying
@@ -358,15 +470,15 @@ async def connect_async(
358470 logger .debug (f"['{ conn_name } ']: Connection info added to cache" )
359471 self ._cache [(str (conn_name ), enable_iam_auth )] = monitored_cache
360472
473+ # Map drivers to connect functions
361474 connect_func = {
362475 "pymysql" : pymysql .connect ,
363476 "pg8000" : pg8000 .connect ,
364- "local_unix_socket" : local_unix_socket .connect ,
365477 "asyncpg" : asyncpg .connect ,
366478 "pytds" : pytds .connect ,
367479 }
368480
369- # only accept supported database drivers
481+ # Only accept supported database drivers
370482 try :
371483 connector : Callable = connect_func [driver ] # type: ignore
372484 except KeyError :
@@ -376,6 +488,7 @@ async def connect_async(
376488 # if ip_type is str, convert to IPTypes enum
377489 if isinstance (ip_type , str ):
378490 ip_type = IPTypes ._from_str (ip_type )
491+ enable_iam_auth = kwargs .pop ("enable_iam_auth" , self ._enable_iam_auth )
379492 kwargs ["timeout" ] = kwargs .get ("timeout" , self ._timeout )
380493
381494 # Host and ssl options come from the certificates and metadata, so we don't
@@ -384,7 +497,12 @@ async def connect_async(
384497 kwargs .pop ("ssl" , None )
385498 kwargs .pop ("port" , None )
386499
387- # attempt to get connection info for Cloud SQL instance
500+ monitored_cache = await self ._get_cache (
501+ instance_connection_string , enable_iam_auth , ip_type , driver
502+ )
503+ conn_info = await monitored_cache .connect_info ()
504+ ip_address = conn_info .get_preferred_ip (ip_type )
505+
388506 try :
389507 conn_info = await monitored_cache .connect_info ()
390508 # validate driver matches intended database engine
@@ -430,39 +548,31 @@ async def connect_async(
430548 )
431549 if formatted_user != kwargs ["user" ]:
432550 logger .debug (
433- f"['{ instance_connection_string } ']: Truncated IAM database username from { kwargs ['user' ]} to { formatted_user } "
551+ f"['{ instance_connection_string } ']: "
552+ "Truncated IAM database username from "
553+ f"{ kwargs ['user' ]} to { formatted_user } "
434554 )
435555 kwargs ["user" ] = formatted_user
436- try :
556+
557+ ctx = await conn_info .create_ssl_context (enable_iam_auth )
437558 # async drivers are unblocking and can be awaited directly
438559 if driver in ASYNC_DRIVERS :
439- return await connector (
440- ip_address ,
441- await conn_info .create_ssl_context (enable_iam_auth ),
442- ** kwargs ,
560+ return await connector (ip_address , ctx , ** kwargs )
561+ else :
562+ # Synchronous drivers are blocking and run using executor
563+ tx , _ = await self .connect_socket_async (
564+ instance_connection_string , asyncio .Protocol , ** kwargs
443565 )
444- # Create socket with SSLContext for sync drivers
445- ctx = await conn_info .create_ssl_context (enable_iam_auth )
446- sock = ctx .wrap_socket (
447- socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
448- server_hostname = ip_address ,
449- )
450-
451- # If this connection was opened using a domain name, then store it
452- # for later in case we need to forcibly close it on failover.
453- if conn_info .conn_name .domain_name :
454- monitored_cache .sockets .append (sock )
455- # Synchronous drivers are blocking and run using executor
456- connect_partial = partial (
457- connector ,
458- ip_address ,
459- sock ,
460- ** kwargs ,
461- )
462- return await self ._loop .run_in_executor (None , connect_partial )
566+ # See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info
567+ sock = tx .get_extra_info ("ssl_object" )
568+ connect_partial = partial (connector , ip_address , sock , ** kwargs )
569+ return await self ._loop .run_in_executor (None , connect_partial )
463570
464571 except Exception :
465572 # with any exception, we attempt a force refresh, then throw the error
573+ monitored_cache = await self ._get_cache (
574+ instance_connection_string , enable_iam_auth , ip_type , driver
575+ )
466576 await monitored_cache .force_refresh ()
467577 raise
468578
@@ -501,7 +611,7 @@ async def start_unix_socket_proxy_async(
501611 )
502612 await proxy_instance .start ()
503613 self ._proxies .append (proxy_instance )
504-
614+
505615 async def _remove_cached (
506616 self , instance_connection_string : str , enable_iam_auth : bool
507617 ) -> None :
@@ -567,7 +677,6 @@ async def close_async(self) -> None:
567677 await self ._client .close ()
568678 await asyncio .gather (* [cache .close () for cache in self ._cache .values ()])
569679
570-
571680async def create_async_connector (
572681 ip_type : str | IPTypes = IPTypes .PUBLIC ,
573682 enable_iam_auth : bool = False ,
@@ -657,3 +766,13 @@ async def create_async_connector(
657766 resolver = resolver ,
658767 failover_period = failover_period ,
659768 )
769+
770+
771+ class ConnectorSocketFactory (proxy .ServerConnectionFactory ):
772+ def __init__ (self , connector :Connector , instance_connection_string :str , ** kwargs ):
773+ self ._connector = connector
774+ self ._instance_connection_string = instance_connection_string
775+ self ._connect_args = kwargs
776+
777+ async def connect (self , protocol_fn : Callable [[], asyncio .Protocol ]):
778+ await self ._connector .connect_socket_async (self ._instance_connection_string , protocol_fn , ** self ._connect_args )
0 commit comments