Skip to content

Commit 378f325

Browse files
author
Thameez Bodhanya
committed
fix: increase test coverage
1 parent c180871 commit 378f325

File tree

2 files changed

+239
-3
lines changed

2 files changed

+239
-3
lines changed

tests/unit/test_connector.py

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import socket
1818
from threading import Thread
1919
from typing import Union
20+
from unittest.mock import AsyncMock
21+
from unittest.mock import MagicMock
2022

2123
from aiohttp import ClientResponseError
2224
from google.auth.credentials import Credentials
@@ -28,6 +30,7 @@
2830
from google.cloud.sql.connector import IPTypes
2931
from google.cloud.sql.connector.client import CloudSQLClient
3032
from google.cloud.sql.connector.connection_name import ConnectionName
33+
from google.cloud.sql.connector.connector import ConnectorSocketFactory
3134
from google.cloud.sql.connector.exceptions import ClosedConnectorError
3235
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
3336
from google.cloud.sql.connector.exceptions import ConnectorLoopError
@@ -550,6 +553,57 @@ async def test_Connector_start_unix_socket_proxy_async(
550553
)
551554

552555

556+
@pytest.mark.asyncio
557+
async def test_Connector_start_unix_socket_proxy_async_rejects_duplicate_socket_path(
558+
fake_credentials: Credentials,
559+
) -> None:
560+
socket_path = "/tmp/cloudsql-test.sock"
561+
async with Connector(
562+
credentials=fake_credentials, loop=asyncio.get_running_loop()
563+
) as connector:
564+
existing_proxy = MagicMock()
565+
existing_proxy.unix_socket_path = socket_path
566+
existing_proxy.close = AsyncMock()
567+
connector._proxies.append(existing_proxy)
568+
569+
with pytest.raises(ValueError) as exc_info:
570+
await connector.start_unix_socket_proxy_async(
571+
"test-project:test-region:test-instance",
572+
socket_path,
573+
)
574+
575+
assert (
576+
exc_info.value.args[0]
577+
== f"Proxy for socket path {socket_path} already exists."
578+
)
579+
580+
581+
@pytest.mark.asyncio
582+
async def test_Connector_close_async_closes_proxies_client_and_cache(
583+
fake_credentials: Credentials,
584+
) -> None:
585+
async with Connector(
586+
credentials=fake_credentials, loop=asyncio.get_running_loop()
587+
) as connector:
588+
proxy_instance = MagicMock()
589+
proxy_instance.close = AsyncMock()
590+
connector._proxies.append(proxy_instance)
591+
592+
connector._client = MagicMock()
593+
connector._client.close = AsyncMock()
594+
595+
cached = MagicMock()
596+
cached.close = AsyncMock()
597+
connector._cache[("test-project:test-region:test-instance", False)] = cached
598+
599+
await connector.close_async()
600+
601+
assert connector._closed is True
602+
proxy_instance.close.assert_awaited_once()
603+
connector._client.close.assert_awaited_once()
604+
cached.close.assert_awaited_once()
605+
606+
553607
def test_connect_closed_connector(
554608
fake_credentials: Credentials, fake_client: CloudSQLClient
555609
) -> None:
@@ -678,7 +732,53 @@ async def test_Connector_connect_async_custom_dns_resolver_fallback(
678732
# Restore original IPs
679733
fake_client.instance.ip_addrs = original_ips
680734

681-
class TestProtocol(asyncio.Protocol):
735+
736+
@pytest.mark.asyncio
737+
async def test_Connector_get_cache_invalidates_bad_cached_entry(
738+
fake_credentials: Credentials,
739+
) -> None:
740+
connect_string = "test-project:test-region:test-instance"
741+
async with Connector(
742+
credentials=fake_credentials, loop=asyncio.get_running_loop()
743+
) as connector:
744+
monitored_cache = MagicMock()
745+
monitored_cache.closed = False
746+
monitored_cache.close = AsyncMock()
747+
748+
conn_info = MagicMock()
749+
conn_info.get_preferred_ip.side_effect = RuntimeError("invalid ip")
750+
monitored_cache.connect_info = AsyncMock(return_value=conn_info)
751+
connector._cache[(connect_string, False)] = monitored_cache
752+
753+
with (
754+
patch.object(
755+
connector._resolver,
756+
"resolve",
757+
AsyncMock(
758+
return_value=ConnectionName(
759+
"test-project",
760+
"test-region",
761+
"test-instance",
762+
)
763+
),
764+
),
765+
patch.object(
766+
connector, "_remove_cached", AsyncMock()
767+
) as mock_remove_cached,
768+
):
769+
with pytest.raises(RuntimeError) as exc_info:
770+
await connector._get_cache(
771+
connect_string,
772+
False,
773+
IPTypes.PUBLIC,
774+
None,
775+
)
776+
777+
assert exc_info.value.args[0] == "invalid ip"
778+
mock_remove_cached.assert_awaited_once_with(connect_string, False)
779+
780+
781+
class SocketTestProtocol(asyncio.Protocol):
682782
"""
683783
A protocol to proxy data between two transports.
684784
"""
@@ -738,7 +838,7 @@ async def test_Connector_connect_socket_async(
738838
) as connector:
739839
logger.info("client socket opening")
740840
connector._client = fake_client
741-
p = TestProtocol()
841+
p = SocketTestProtocol()
742842

743843
# Open proxy connection
744844
# start the proxy server
@@ -759,3 +859,72 @@ async def test_Connector_connect_socket_async(
759859
logger.info("client socket done")
760860

761861
assert p.received.decode() == "world\n"
862+
863+
864+
@pytest.mark.asyncio
865+
async def test_Connector_connect_socket_async_invalidates_cache_on_connection_error(
866+
fake_credentials: Credentials,
867+
) -> None:
868+
connect_string = "test-project:test-region:test-instance"
869+
async with Connector(
870+
credentials=fake_credentials, loop=asyncio.get_running_loop()
871+
) as connector:
872+
monitored_cache = MagicMock()
873+
conn_info = MagicMock()
874+
conn_info.create_ssl_context = AsyncMock(return_value=object())
875+
conn_info.get_preferred_ip.return_value = "127.0.0.1"
876+
monitored_cache.connect_info = AsyncMock(return_value=conn_info)
877+
878+
with (
879+
patch.object(
880+
connector, "_get_cache", AsyncMock(return_value=monitored_cache)
881+
),
882+
patch.object(
883+
connector._loop,
884+
"create_connection",
885+
AsyncMock(side_effect=RuntimeError("boom")),
886+
),
887+
patch.object(
888+
connector, "_remove_cached", AsyncMock()
889+
) as mock_remove_cached,
890+
):
891+
with pytest.raises(RuntimeError) as exc_info:
892+
await connector.connect_socket_async(
893+
connect_string,
894+
asyncio.Protocol,
895+
driver="asyncpg",
896+
)
897+
898+
assert exc_info.value.args[0] == "boom"
899+
mock_remove_cached.assert_awaited_once_with(connect_string, False)
900+
901+
902+
@pytest.mark.asyncio
903+
async def test_ConnectorSocketFactory_connect_forwards_arguments(
904+
fake_credentials: Credentials,
905+
) -> None:
906+
connect_string = "test-project:test-region:test-instance"
907+
async with Connector(
908+
credentials=fake_credentials, loop=asyncio.get_running_loop()
909+
) as connector:
910+
protocol_fn = MagicMock()
911+
with patch.object(
912+
connector,
913+
"connect_socket_async",
914+
AsyncMock(),
915+
) as mock_connect_socket_async:
916+
factory = ConnectorSocketFactory(
917+
connector,
918+
connect_string,
919+
driver="asyncpg",
920+
user="my-user",
921+
)
922+
923+
await factory.connect(protocol_fn)
924+
925+
mock_connect_socket_async.assert_awaited_once_with(
926+
connect_string,
927+
protocol_fn,
928+
driver="asyncpg",
929+
user="my-user",
930+
)

tests/unit/test_proxy.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222

2323
import pytest
2424

25+
from google.cloud.sql.connector.proxy import BaseProxyProtocol
26+
from google.cloud.sql.connector.proxy import ClientToServerProtocol
2527
from google.cloud.sql.connector.proxy import Proxy
28+
from google.cloud.sql.connector.proxy import ProxyClientConnection
2629
from google.cloud.sql.connector.proxy import ServerConnectionFactory
2730

2831

@@ -50,6 +53,70 @@ async def test_proxy_creates_folder_and_socket(short_tmpdir):
5053
await proxy.close()
5154

5255

56+
def test_base_proxy_protocol_flushes_cached_data_when_target_is_set():
57+
protocol = BaseProxyProtocol(MagicMock())
58+
target = MagicMock(spec=asyncio.Transport)
59+
60+
protocol.data_received(b"hello")
61+
protocol.data_received(b"world")
62+
protocol.set_target(target)
63+
64+
target.writelines.assert_called_once_with([b"hello", b"world"])
65+
assert protocol._cached == []
66+
67+
68+
def test_base_proxy_protocol_forwards_eof_and_close_to_target():
69+
protocol = BaseProxyProtocol(MagicMock())
70+
target = MagicMock(spec=asyncio.Transport)
71+
protocol.set_target(target)
72+
73+
protocol.eof_received()
74+
protocol.connection_lost(None)
75+
76+
target.write_eof.assert_called_once()
77+
target.close.assert_called_once()
78+
79+
80+
def test_proxy_client_connection_close_prefers_write_eof():
81+
client_transport = MagicMock(spec=asyncio.Transport)
82+
client_transport.is_closing.return_value = False
83+
client_transport.can_write_eof.return_value = True
84+
85+
server_transport = MagicMock(spec=asyncio.Transport)
86+
server_transport.is_closing.return_value = False
87+
server_transport.can_write_eof.return_value = True
88+
89+
connection = ProxyClientConnection(client_transport, MagicMock())
90+
connection.server_transport = server_transport
91+
92+
connection.close()
93+
94+
client_transport.write_eof.assert_called_once()
95+
server_transport.write_eof.assert_called_once()
96+
97+
98+
def test_proxy_client_connection_close_falls_back_to_close():
99+
client_transport = MagicMock(spec=asyncio.Transport)
100+
client_transport.is_closing.return_value = False
101+
client_transport.can_write_eof.return_value = False
102+
103+
connection = ProxyClientConnection(client_transport, MagicMock())
104+
105+
connection.close()
106+
107+
client_transport.close.assert_called_once()
108+
109+
110+
def test_client_to_server_protocol_opens_backend_connection_on_accept():
111+
proxy = MagicMock()
112+
protocol = ClientToServerProtocol(proxy)
113+
transport = MagicMock(spec=asyncio.Transport)
114+
115+
protocol.connection_made(transport)
116+
117+
proxy._handle_client_connection.assert_called_once_with(transport, protocol)
118+
119+
53120
# A mock ServerConnectionFactory for testing purposes.
54121
class MockServerConnectionFactory(ServerConnectionFactory):
55122
def __init__(self, loop):
@@ -373,4 +440,4 @@ async def test_tcp_proxy_client_closes_connection(tcp_proxy_server):
373440

374441
# Check that the server socket is closing
375442
await asyncio.sleep(0.01)
376-
assert connector.server_transport.is_closing()
443+
assert connector.server_transport.is_closing()

0 commit comments

Comments
 (0)