Skip to content

Commit 7a1812a

Browse files
chore: attempt moving socket into ConnectionInfo
1 parent f92fd88 commit 7a1812a

15 files changed

Lines changed: 45 additions & 33 deletions

google/cloud/sql/connector/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2019 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");

google/cloud/sql/connector/connection_info.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@
1717
import abc
1818
from dataclasses import dataclass
1919
import logging
20+
import socket
2021
import ssl
21-
from typing import Any, Optional, TYPE_CHECKING
22+
from typing import Any, Optional, TYPE_CHECKING, Union
2223

2324
from aiofiles.tempfile import TemporaryDirectory
2425

2526
from google.cloud.sql.connector.connection_name import ConnectionName
27+
from google.cloud.sql.connector.enums import IPTypes
2628
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
2729
from google.cloud.sql.connector.exceptions import TLSVersionError
2830
from google.cloud.sql.connector.utils import write_to_file
2931

3032
if TYPE_CHECKING:
3133
import datetime
3234

33-
from google.cloud.sql.connector.enums import IPTypes
3435

3536
logger = logging.getLogger(name=__name__)
3637

@@ -69,13 +70,21 @@ class ConnectionInfo:
6970
database_version: str
7071
expiration: datetime.datetime
7172
context: Optional[ssl.SSLContext] = None
73+
sock: Optional[ssl.SSLSocket] = None
7274

73-
async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
75+
async def create_ssl_context(
76+
self, enable_iam_auth: bool = False, return_socket: bool = False
77+
) -> Union[ssl.SSLContext, ssl.SSLSocket]:
7478
"""Constructs a SSL/TLS context for the given connection info.
7579
7680
Cache the SSL context to ensure we don't read from disk repeatedly when
7781
configuring a secure connection.
7882
"""
83+
# Return socket if socket is cached and return_socket is set to True
84+
if self.sock is not None and return_socket:
85+
logger.debug("Socket in cache, returning it!")
86+
return self.sock
87+
7988
# if SSL context is cached, use it
8089
if self.context is not None:
8190
return self.context
@@ -116,6 +125,15 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont
116125
context.load_verify_locations(cafile=ca_filename)
117126
# set class attribute to cache context for subsequent calls
118127
self.context = context
128+
# If return_socket is True, cache socket and return it
129+
if return_socket:
130+
logger.debug("Returning socket instead of context!")
131+
sock = self.context.wrap_socket(
132+
socket.create_connection((self.get_preferred_ip(IPTypes.PUBLIC), 3307)),
133+
server_hostname="blah",
134+
)
135+
self.sock = sock
136+
return sock
119137
return context
120138

121139
def get_preferred_ip(self, ip_type: IPTypes) -> str:

google/cloud/sql/connector/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ async def connect_async(
390390

391391
except Exception:
392392
# with any exception, we attempt a force refresh, then throw the error
393-
await cache.force_refresh()
393+
await monitored_cache.force_refresh()
394394
raise
395395

396396
async def _remove_cached(

google/cloud/sql/connector/monitored_cache.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
self.resolver = resolver
3737
self.cache = cache
3838
self.domain_name_ticker: Optional[asyncio.Task] = None
39-
self.open_conns_count: int = 0
39+
self.open_conns: int = 0
4040

4141
if self.cache.conn_name.domain_name:
4242
self.domain_name_ticker = asyncio.create_task(
@@ -66,6 +66,12 @@ async def _check_domain_name(self) -> None:
6666
"connections!"
6767
)
6868
await self.close()
69+
conn_info = await self.connect_info()
70+
if conn_info.sock:
71+
logger.debug(f"Socket type: {type(conn_info.sock)}")
72+
conn_info.sock.close()
73+
else:
74+
logger.debug("Domain name mapping has not changed!")
6975

7076
except Exception as e:
7177
# Domain name checks should not be fatal, log error and continue.

google/cloud/sql/connector/pymysql.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
limitations under the License.
1515
"""
1616

17-
import socket
1817
import ssl
1918
from typing import Any, TYPE_CHECKING
2019

@@ -25,16 +24,16 @@
2524

2625

2726
def connect(
28-
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
27+
ip_address: str, sock: ssl.SSLSocket, **kwargs: Any
2928
) -> "pymysql.connections.Connection":
3029
"""Helper function to create a pymysql DB-API connection object.
3130
3231
:type ip_address: str
3332
:param ip_address: A string containing an IP address for the Cloud SQL
3433
instance.
3534
36-
:type ctx: ssl.SSLContext
37-
:param ctx: An SSLContext object created from the Cloud SQL server CA
35+
:type sock: ssl.SSLSocket
36+
:param sock: An SSLSocket object created from the Cloud SQL server CA
3837
cert and ephemeral cert.
3938
4039
:rtype: pymysql.Connection
@@ -50,11 +49,6 @@ def connect(
5049
# allow automatic IAM database authentication to not require password
5150
kwargs["password"] = kwargs["password"] if "password" in kwargs else None
5251

53-
# Create socket and wrap with context.
54-
sock = ctx.wrap_socket(
55-
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
56-
server_hostname=ip_address,
57-
)
5852
# pop timeout as timeout arg is called 'connect_timeout' for pymysql
5953
timeout = kwargs.pop("timeout")
6054
kwargs["connect_timeout"] = kwargs.get("connect_timeout", timeout)

google/cloud/sql/connector/pytds.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""
1616

1717
import platform
18-
import socket
1918
import ssl
2019
from typing import Any, TYPE_CHECKING
2120

@@ -27,15 +26,15 @@
2726
import pytds
2827

2928

30-
def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Connection":
29+
def connect(ip_address: str, sock: ssl.SSLSocket, **kwargs: Any) -> "pytds.Connection":
3130
"""Helper function to create a pytds DB-API connection object.
3231
3332
:type ip_address: str
3433
:param ip_address: A string containing an IP address for the Cloud SQL
3534
instance.
3635
37-
:type ctx: ssl.SSLContext
38-
:param ctx: An SSLContext object created from the Cloud SQL server CA
36+
:type sock: ssl.SSLSocket
37+
:param sock: An SSLSocket object created from the Cloud SQL server CA
3938
cert and ephemeral cert.
4039
4140
@@ -51,11 +50,6 @@ def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Conne
5150

5251
db = kwargs.pop("db", None)
5352

54-
# Create socket and wrap with context.
55-
sock = ctx.wrap_socket(
56-
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
57-
server_hostname=ip_address,
58-
)
5953
if kwargs.pop("active_directory_auth", False):
6054
if platform.system() == "Windows":
6155
# Ignore username and password if using active directory auth

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2021 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");

tests/system/test_connector_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2021 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");

tests/system/test_ip_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2021 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");

tests/system/test_pymysql_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2021 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");

0 commit comments

Comments
 (0)