1717import socket
1818from threading import Thread
1919from typing import Union
20+ from unittest .mock import AsyncMock
21+ from unittest .mock import MagicMock
2022
2123from aiohttp import ClientResponseError
2224from google .auth .credentials import Credentials
2830from google .cloud .sql .connector import IPTypes
2931from google .cloud .sql .connector .client import CloudSQLClient
3032from google .cloud .sql .connector .connection_name import ConnectionName
33+ from google .cloud .sql .connector .connector import ConnectorSocketFactory
3134from google .cloud .sql .connector .exceptions import ClosedConnectorError
3235from google .cloud .sql .connector .exceptions import CloudSQLIPTypeError
3336from google .cloud .sql .connector .exceptions import ConnectorLoopError
@@ -550,6 +553,57 @@ async def test_Connector_start_unix_socket_proxy_async(
550553 )
551554
552555
556+ @pytest .mark .asyncio
557+ async def test_Connector_start_unix_socket_proxy_async_rejects_duplicate_socket_path (
558+ fake_credentials : Credentials ,
559+ ) -> None :
560+ socket_path = "/tmp/cloudsql-test.sock"
561+ async with Connector (
562+ credentials = fake_credentials , loop = asyncio .get_running_loop ()
563+ ) as connector :
564+ existing_proxy = MagicMock ()
565+ existing_proxy .unix_socket_path = socket_path
566+ existing_proxy .close = AsyncMock ()
567+ connector ._proxies .append (existing_proxy )
568+
569+ with pytest .raises (ValueError ) as exc_info :
570+ await connector .start_unix_socket_proxy_async (
571+ "test-project:test-region:test-instance" ,
572+ socket_path ,
573+ )
574+
575+ assert (
576+ exc_info .value .args [0 ]
577+ == f"Proxy for socket path { socket_path } already exists."
578+ )
579+
580+
581+ @pytest .mark .asyncio
582+ async def test_Connector_close_async_closes_proxies_client_and_cache (
583+ fake_credentials : Credentials ,
584+ ) -> None :
585+ async with Connector (
586+ credentials = fake_credentials , loop = asyncio .get_running_loop ()
587+ ) as connector :
588+ proxy_instance = MagicMock ()
589+ proxy_instance .close = AsyncMock ()
590+ connector ._proxies .append (proxy_instance )
591+
592+ connector ._client = MagicMock ()
593+ connector ._client .close = AsyncMock ()
594+
595+ cached = MagicMock ()
596+ cached .close = AsyncMock ()
597+ connector ._cache [("test-project:test-region:test-instance" , False )] = cached
598+
599+ await connector .close_async ()
600+
601+ assert connector ._closed is True
602+ proxy_instance .close .assert_awaited_once ()
603+ connector ._client .close .assert_awaited_once ()
604+ cached .close .assert_awaited_once ()
605+
606+
553607def test_connect_closed_connector (
554608 fake_credentials : Credentials , fake_client : CloudSQLClient
555609) -> None :
@@ -678,7 +732,53 @@ async def test_Connector_connect_async_custom_dns_resolver_fallback(
678732 # Restore original IPs
679733 fake_client .instance .ip_addrs = original_ips
680734
681- class TestProtocol (asyncio .Protocol ):
735+
736+ @pytest .mark .asyncio
737+ async def test_Connector_get_cache_invalidates_bad_cached_entry (
738+ fake_credentials : Credentials ,
739+ ) -> None :
740+ connect_string = "test-project:test-region:test-instance"
741+ async with Connector (
742+ credentials = fake_credentials , loop = asyncio .get_running_loop ()
743+ ) as connector :
744+ monitored_cache = MagicMock ()
745+ monitored_cache .closed = False
746+ monitored_cache .close = AsyncMock ()
747+
748+ conn_info = MagicMock ()
749+ conn_info .get_preferred_ip .side_effect = RuntimeError ("invalid ip" )
750+ monitored_cache .connect_info = AsyncMock (return_value = conn_info )
751+ connector ._cache [(connect_string , False )] = monitored_cache
752+
753+ with (
754+ patch .object (
755+ connector ._resolver ,
756+ "resolve" ,
757+ AsyncMock (
758+ return_value = ConnectionName (
759+ "test-project" ,
760+ "test-region" ,
761+ "test-instance" ,
762+ )
763+ ),
764+ ),
765+ patch .object (
766+ connector , "_remove_cached" , AsyncMock ()
767+ ) as mock_remove_cached ,
768+ ):
769+ with pytest .raises (RuntimeError ) as exc_info :
770+ await connector ._get_cache (
771+ connect_string ,
772+ False ,
773+ IPTypes .PUBLIC ,
774+ None ,
775+ )
776+
777+ assert exc_info .value .args [0 ] == "invalid ip"
778+ mock_remove_cached .assert_awaited_once_with (connect_string , False )
779+
780+
781+ class SocketTestProtocol (asyncio .Protocol ):
682782 """
683783 A protocol to proxy data between two transports.
684784 """
@@ -738,7 +838,7 @@ async def test_Connector_connect_socket_async(
738838 ) as connector :
739839 logger .info ("client socket opening" )
740840 connector ._client = fake_client
741- p = TestProtocol ()
841+ p = SocketTestProtocol ()
742842
743843 # Open proxy connection
744844 # start the proxy server
@@ -759,3 +859,72 @@ async def test_Connector_connect_socket_async(
759859 logger .info ("client socket done" )
760860
761861 assert p .received .decode () == "world\n "
862+
863+
864+ @pytest .mark .asyncio
865+ async def test_Connector_connect_socket_async_invalidates_cache_on_connection_error (
866+ fake_credentials : Credentials ,
867+ ) -> None :
868+ connect_string = "test-project:test-region:test-instance"
869+ async with Connector (
870+ credentials = fake_credentials , loop = asyncio .get_running_loop ()
871+ ) as connector :
872+ monitored_cache = MagicMock ()
873+ conn_info = MagicMock ()
874+ conn_info .create_ssl_context = AsyncMock (return_value = object ())
875+ conn_info .get_preferred_ip .return_value = "127.0.0.1"
876+ monitored_cache .connect_info = AsyncMock (return_value = conn_info )
877+
878+ with (
879+ patch .object (
880+ connector , "_get_cache" , AsyncMock (return_value = monitored_cache )
881+ ),
882+ patch .object (
883+ connector ._loop ,
884+ "create_connection" ,
885+ AsyncMock (side_effect = RuntimeError ("boom" )),
886+ ),
887+ patch .object (
888+ connector , "_remove_cached" , AsyncMock ()
889+ ) as mock_remove_cached ,
890+ ):
891+ with pytest .raises (RuntimeError ) as exc_info :
892+ await connector .connect_socket_async (
893+ connect_string ,
894+ asyncio .Protocol ,
895+ driver = "asyncpg" ,
896+ )
897+
898+ assert exc_info .value .args [0 ] == "boom"
899+ mock_remove_cached .assert_awaited_once_with (connect_string , False )
900+
901+
902+ @pytest .mark .asyncio
903+ async def test_ConnectorSocketFactory_connect_forwards_arguments (
904+ fake_credentials : Credentials ,
905+ ) -> None :
906+ connect_string = "test-project:test-region:test-instance"
907+ async with Connector (
908+ credentials = fake_credentials , loop = asyncio .get_running_loop ()
909+ ) as connector :
910+ protocol_fn = MagicMock ()
911+ with patch .object (
912+ connector ,
913+ "connect_socket_async" ,
914+ AsyncMock (),
915+ ) as mock_connect_socket_async :
916+ factory = ConnectorSocketFactory (
917+ connector ,
918+ connect_string ,
919+ driver = "asyncpg" ,
920+ user = "my-user" ,
921+ )
922+
923+ await factory .connect (protocol_fn )
924+
925+ mock_connect_socket_async .assert_awaited_once_with (
926+ connect_string ,
927+ protocol_fn ,
928+ driver = "asyncpg" ,
929+ user = "my-user" ,
930+ )
0 commit comments