Skip to content

Commit 9444ea5

Browse files
test: update mock server used for tests
1 parent ccd414c commit 9444ea5

7 files changed

Lines changed: 96 additions & 73 deletions

File tree

tests/conftest.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,23 @@
1717
import asyncio
1818
import os
1919
import socket
20+
import ssl
2021
from threading import Thread
21-
from typing import Any, AsyncGenerator, Generator
22+
from typing import Any, AsyncGenerator
2223

24+
from aiofiles.tempfile import TemporaryDirectory
2325
from aiohttp import web
26+
from cryptography.hazmat.primitives import serialization
2427
import pytest # noqa F401 Needed to run the tests
28+
from unit.mocks import create_ssl_context # type: ignore
2529
from unit.mocks import FakeCredentials # type: ignore
2630
from unit.mocks import FakeCSQLInstance # type: ignore
2731

2832
from google.cloud.sql.connector.client import CloudSQLClient
2933
from google.cloud.sql.connector.connection_name import ConnectionName
3034
from google.cloud.sql.connector.instance import RefreshAheadCache
3135
from google.cloud.sql.connector.utils import generate_keys
36+
from google.cloud.sql.connector.utils import write_to_file
3237

3338
SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"]
3439

@@ -79,25 +84,58 @@ def fake_credentials() -> FakeCredentials:
7984
return FakeCredentials()
8085

8186

82-
def mock_server(server_sock: socket.socket) -> None:
83-
"""Create mock server listening on specified ip_address and port."""
87+
async def start_proxy_server(instance: FakeCSQLInstance) -> None:
88+
"""Run local proxy server capable of performing mTLS"""
8489
ip_address = "127.0.0.1"
8590
port = 3307
86-
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
87-
server_sock.bind((ip_address, port))
88-
server_sock.listen(0)
89-
server_sock.accept()
91+
# create socket
92+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
93+
# create SSL/TLS context
94+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
95+
context.minimum_version = ssl.TLSVersion.TLSv1_3
96+
# tmpdir and its contents are automatically deleted after the CA cert
97+
# and cert chain are loaded into the SSLcontext. The values
98+
# need to be written to files in order to be loaded by the SSLContext
99+
server_key_bytes = instance.server_key.private_bytes(
100+
encoding=serialization.Encoding.PEM,
101+
format=serialization.PrivateFormat.TraditionalOpenSSL,
102+
encryption_algorithm=serialization.NoEncryption(),
103+
)
104+
async with TemporaryDirectory() as tmpdir:
105+
server_filename, _, key_filename = await write_to_file(
106+
tmpdir, instance.server_cert_pem, "", server_key_bytes
107+
)
108+
context.load_cert_chain(server_filename, key_filename)
109+
# bind socket to Cloud SQL proxy server port on localhost
110+
sock.bind((ip_address, port))
111+
# listen for incoming connections
112+
sock.listen(5)
113+
114+
with context.wrap_socket(sock, server_side=True) as ssock:
115+
while True:
116+
conn, _ = ssock.accept()
117+
conn.close()
118+
119+
120+
@pytest.fixture(scope="session")
121+
def proxy_server(fake_instance: FakeCSQLInstance) -> None:
122+
"""Run local proxy server capable of performing mTLS"""
123+
thread = Thread(
124+
target=asyncio.run,
125+
args=(
126+
start_proxy_server(
127+
fake_instance,
128+
),
129+
),
130+
daemon=True,
131+
)
132+
thread.start()
133+
thread.join(1.0) # add a delay to allow the proxy server to start
90134

91135

92136
@pytest.fixture
93-
def server() -> Generator:
94-
"""Create thread with server listening on proper port"""
95-
server_sock = socket.socket()
96-
thread = Thread(target=mock_server, args=(server_sock,), daemon=True)
97-
thread.start()
98-
yield thread
99-
server_sock.close()
100-
thread.join()
137+
async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext:
138+
return await create_ssl_context(fake_instance)
101139

102140

103141
@pytest.fixture
@@ -107,7 +145,7 @@ def kwargs() -> Any:
107145
return kwargs
108146

109147

110-
@pytest.fixture
148+
@pytest.fixture(scope="session")
111149
def fake_instance() -> FakeCSQLInstance:
112150
return FakeCSQLInstance()
113151

tests/unit/mocks.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
""" "
22
Copyright 2022 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +16,8 @@
1616

1717
# file containing all mocks used for Cloud SQL Python Connector unit tests
1818

19+
from __future__ import annotations
20+
1921
import datetime
2022
import json
2123
import ssl
@@ -184,28 +186,28 @@ def client_key_signed_cert(
184186
.not_valid_after(cert_expiration) # type: ignore
185187
)
186188
return (
187-
cert.sign(priv_key, hashes.SHA256(), default_backend())
189+
cert.sign(priv_key, hashes.SHA256())
188190
.public_bytes(encoding=serialization.Encoding.PEM)
189191
.decode("UTF-8")
190192
)
191193

192194

193-
async def create_ssl_context() -> ssl.SSLContext:
195+
async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext:
194196
"""Helper method to build an ssl.SSLContext for tests"""
195-
# generate keys and certs for test
196-
cert, private_key = generate_cert("my-project", "my-instance")
197-
server_ca_cert = self_signed_cert(cert, private_key)
198197
client_private, client_bytes = await generate_keys()
199198
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
200-
client_bytes.encode("UTF-8"), default_backend()
199+
client_bytes.encode("UTF-8"),
201200
) # type: ignore
202-
ephemeral_cert = client_key_signed_cert(cert, private_key, client_key)
203-
# build default ssl.SSLContext
204-
context = ssl.create_default_context()
201+
ephemeral_cert = client_key_signed_cert(
202+
instance.server_ca, instance.server_key, client_key
203+
)
204+
# create SSL/TLS context
205+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
206+
context.check_hostname = False
205207
# load ssl.SSLContext with certs
206208
async with TemporaryDirectory() as tmpdir:
207209
ca_filename, cert_filename, key_filename = await write_to_file(
208-
tmpdir, server_ca_cert, ephemeral_cert, client_private
210+
tmpdir, instance.server_cert_pem, ephemeral_cert, client_private
209211
)
210212
context.load_cert_chain(cert_filename, keyfile=key_filename)
211213
context.load_verify_locations(cafile=ca_filename)
@@ -279,8 +281,8 @@ async def generate_ephemeral(self, request: Any) -> web.Response:
279281
body = await request.json()
280282
pub_key = body["public_key"]
281283
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
282-
pub_key.encode("UTF-8"), default_backend()
283-
) # type: ignore
284+
pub_key.encode("UTF-8"),
285+
)
284286
ephemeral_cert = client_key_signed_cert(
285287
self.server_ca,
286288
self.server_key,

tests/unit/test_instance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,17 +186,18 @@ async def test_RefreshAheadCache_close(cache: RefreshAheadCache) -> None:
186186
@pytest.mark.asyncio
187187
async def test_perform_refresh(
188188
cache: RefreshAheadCache,
189-
fake_instance: mocks.FakeCSQLInstance,
190189
) -> None:
191190
"""
192191
Test that _perform_refresh returns valid ConnectionInfo object.
193192
"""
194193
instance_metadata = await cache._perform_refresh()
195-
196194
# verify instance metadata object is returned
197195
assert isinstance(instance_metadata, ConnectionInfo)
198196
# verify instance metadata expiration
199-
assert fake_instance.server_cert.not_valid_after_utc == instance_metadata.expiration
197+
assert (
198+
cache._client.instance.cert_expiration.replace(microsecond=0)
199+
== instance_metadata.expiration
200+
)
200201

201202

202203
@pytest.mark.asyncio

tests/unit/test_monitored_cache.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
import asyncio
1616
import socket
17+
import ssl
1718

1819
import dns.message
1920
import dns.rdataclass
2021
import dns.rdatatype
2122
import dns.resolver
2223
from mock import patch
23-
from mocks import create_ssl_context
2424
import pytest
2525

2626
from google.cloud.sql.connector.client import CloudSQLClient
@@ -149,8 +149,10 @@ async def test_MonitoredCache_with_disabled_failover(
149149
assert monitored_cache.closed is True
150150

151151

152-
@pytest.mark.usefixtures("server")
153-
async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None:
152+
@pytest.mark.usefixtures("proxy_server")
153+
async def test_MonitoredCache_check_domain_name(
154+
context: ssl.SSLContext, fake_client: CloudSQLClient
155+
) -> None:
154156
"""
155157
Test that MonitoredCache is closed when _check_domain_name has domain change.
156158
"""
@@ -177,11 +179,9 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
177179

178180
# configure a local socket
179181
ip_addr = "127.0.0.1"
180-
context = await create_ssl_context()
181182
sock = context.wrap_socket(
182183
socket.create_connection((ip_addr, 3307)),
183184
server_hostname=ip_addr,
184-
do_handshake_on_connect=False,
185185
)
186186
# verify socket is open
187187
assert sock.fileno() != -1
@@ -198,8 +198,10 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
198198
assert sock.fileno() == -1
199199

200200

201-
@pytest.mark.usefixtures("server")
202-
async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None:
201+
@pytest.mark.usefixtures("proxy_server")
202+
async def test_MonitoredCache_purge_closed_sockets(
203+
context: ssl.SSLContext, fake_client: CloudSQLClient
204+
) -> None:
203205
"""
204206
Test that MonitoredCache._purge_closed_sockets removes closed sockets from
205207
cache.
@@ -215,11 +217,9 @@ async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient)
215217
)
216218
# configure a local socket
217219
ip_addr = "127.0.0.1"
218-
context = await create_ssl_context()
219220
sock = context.wrap_socket(
220221
socket.create_connection((ip_addr, 3307)),
221222
server_hostname=ip_addr,
222-
do_handshake_on_connect=False,
223223
)
224224

225225
# set failover to 0 to disable polling

tests/unit/test_pg8000.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,22 @@
1515
"""
1616

1717
import socket
18+
import ssl
1819
from typing import Any
1920

2021
from mock import patch
21-
from mocks import create_ssl_context
2222
import pytest
2323

2424
from google.cloud.sql.connector.pg8000 import connect
2525

2626

27-
@pytest.mark.usefixtures("server")
28-
@pytest.mark.asyncio
29-
async def test_pg8000(kwargs: Any) -> None:
27+
@pytest.mark.usefixtures("proxy_server")
28+
async def test_pg8000(context: ssl.SSLContext, kwargs: Any) -> None:
3029
"""Test to verify that pg8000 gets to proper connection call."""
3130
ip_addr = "127.0.0.1"
32-
# build ssl.SSLContext
33-
context = await create_ssl_context()
3431
sock = context.wrap_socket(
3532
socket.create_connection((ip_addr, 3307)),
3633
server_hostname=ip_addr,
37-
do_handshake_on_connect=False,
3834
)
3935
with patch("pg8000.dbapi.connect") as mock_connect:
4036
mock_connect.return_value = True

tests/unit/test_pymysql.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Any
2020

2121
from mock import patch
22-
from mocks import create_ssl_context
2322
import pytest
2423

2524
from google.cloud.sql.connector.pymysql import connect as pymysql_connect
@@ -33,17 +32,14 @@ def connect(sock: ssl.SSLSocket) -> None: # type: ignore
3332
assert isinstance(sock, ssl.SSLSocket)
3433

3534

36-
@pytest.mark.usefixtures("server")
35+
@pytest.mark.usefixtures("proxy_server")
3736
@pytest.mark.asyncio
38-
async def test_pymysql(kwargs: Any) -> None:
37+
async def test_pymysql(context: ssl.SSLContext, kwargs: Any) -> None:
3938
"""Test to verify that pymysql gets to proper connection call."""
4039
ip_addr = "127.0.0.1"
41-
# build ssl.SSLContext
42-
context = await create_ssl_context()
4340
sock = context.wrap_socket(
4441
socket.create_connection((ip_addr, 3307)),
4542
server_hostname=ip_addr,
46-
do_handshake_on_connect=False,
4743
)
4844
kwargs["timeout"] = 30
4945
with patch("pymysql.Connection") as mock_connect:

tests/unit/test_pytds.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
import platform
1818
import socket
19+
import ssl
1920
from typing import Any
2021

2122
from mock import patch
22-
from mocks import create_ssl_context
2323
import pytest
2424

2525
from google.cloud.sql.connector.exceptions import PlatformNotSupportedError
@@ -36,17 +36,13 @@ def stub_platform_windows() -> str:
3636
return "Windows"
3737

3838

39-
@pytest.mark.usefixtures("server")
40-
@pytest.mark.asyncio
41-
async def test_pytds(kwargs: Any) -> None:
39+
@pytest.mark.usefixtures("proxy_server")
40+
async def test_pytds(context: ssl.SSLContext, kwargs: Any) -> None:
4241
"""Test to verify that pytds gets to proper connection call."""
4342
ip_addr = "127.0.0.1"
44-
# build ssl.SSLContext
45-
context = await create_ssl_context()
4643
sock = context.wrap_socket(
4744
socket.create_connection((ip_addr, 3307)),
4845
server_hostname=ip_addr,
49-
do_handshake_on_connect=False,
5046
)
5147

5248
with patch("pytds.connect") as mock_connect:
@@ -57,20 +53,16 @@ async def test_pytds(kwargs: Any) -> None:
5753
assert mock_connect.assert_called_once
5854

5955

60-
@pytest.mark.usefixtures("server")
61-
@pytest.mark.asyncio
62-
async def test_pytds_platform_error(kwargs: Any) -> None:
56+
@pytest.mark.usefixtures("proxy_server")
57+
async def test_pytds_platform_error(context: ssl.SSLContext, kwargs: Any) -> None:
6358
"""Test to verify that pytds.connect throws proper PlatformNotSupportedError."""
6459
ip_addr = "127.0.0.1"
6560
# stub operating system to Linux
6661
setattr(platform, "system", stub_platform_linux)
6762
assert platform.system() == "Linux"
68-
# build ssl.SSLContext
69-
context = await create_ssl_context()
7063
sock = context.wrap_socket(
7164
socket.create_connection((ip_addr, 3307)),
7265
server_hostname=ip_addr,
73-
do_handshake_on_connect=False,
7466
)
7567
# add active_directory_auth to kwargs
7668
kwargs["active_directory_auth"] = True
@@ -79,9 +71,10 @@ async def test_pytds_platform_error(kwargs: Any) -> None:
7971
connect(ip_addr, sock, **kwargs)
8072

8173

82-
@pytest.mark.usefixtures("server")
83-
@pytest.mark.asyncio
84-
async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None:
74+
@pytest.mark.usefixtures("proxy_server")
75+
async def test_pytds_windows_active_directory_auth(
76+
context: ssl.SSLContext, kwargs: Any
77+
) -> None:
8578
"""
8679
Test to verify that pytds gets to connection call on Windows with
8780
active_directory_auth arg set.
@@ -90,12 +83,9 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None:
9083
# stub operating system to Windows
9184
setattr(platform, "system", stub_platform_windows)
9285
assert platform.system() == "Windows"
93-
# build ssl.SSLContext
94-
context = await create_ssl_context()
9586
sock = context.wrap_socket(
9687
socket.create_connection((ip_addr, 3307)),
9788
server_hostname=ip_addr,
98-
do_handshake_on_connect=False,
9989
)
10090
# add active_directory_auth and server_name to kwargs
10191
kwargs["active_directory_auth"] = True

0 commit comments

Comments
 (0)