|
17 | 17 | import abc |
18 | 18 | from dataclasses import dataclass |
19 | 19 | import logging |
| 20 | +import socket |
20 | 21 | import ssl |
21 | | -from typing import Any, Optional, TYPE_CHECKING |
| 22 | +from typing import Any, Optional, TYPE_CHECKING, Union |
22 | 23 |
|
23 | 24 | from aiofiles.tempfile import TemporaryDirectory |
24 | 25 |
|
25 | 26 | from google.cloud.sql.connector.connection_name import ConnectionName |
| 27 | +from google.cloud.sql.connector.enums import IPTypes |
26 | 28 | from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError |
27 | 29 | from google.cloud.sql.connector.exceptions import TLSVersionError |
28 | 30 | from google.cloud.sql.connector.utils import write_to_file |
29 | 31 |
|
30 | 32 | if TYPE_CHECKING: |
31 | 33 | import datetime |
32 | 34 |
|
33 | | - from google.cloud.sql.connector.enums import IPTypes |
34 | 35 |
|
35 | 36 | logger = logging.getLogger(name=__name__) |
36 | 37 |
|
@@ -69,13 +70,21 @@ class ConnectionInfo: |
69 | 70 | database_version: str |
70 | 71 | expiration: datetime.datetime |
71 | 72 | context: Optional[ssl.SSLContext] = None |
| 73 | + sock: Optional[ssl.SSLSocket] = None |
72 | 74 |
|
73 | | - async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: |
| 75 | + async def create_ssl_context( |
| 76 | + self, enable_iam_auth: bool = False, return_socket: bool = False |
| 77 | + ) -> Union[ssl.SSLContext, ssl.SSLSocket]: |
74 | 78 | """Constructs a SSL/TLS context for the given connection info. |
75 | 79 |
|
76 | 80 | Cache the SSL context to ensure we don't read from disk repeatedly when |
77 | 81 | configuring a secure connection. |
78 | 82 | """ |
| 83 | + # Return socket if socket is cached and return_socket is set to True |
| 84 | + if self.sock is not None and return_socket: |
| 85 | + logger.debug("Socket in cache, returning it!") |
| 86 | + return self.sock |
| 87 | + |
79 | 88 | # if SSL context is cached, use it |
80 | 89 | if self.context is not None: |
81 | 90 | return self.context |
@@ -116,6 +125,15 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont |
116 | 125 | context.load_verify_locations(cafile=ca_filename) |
117 | 126 | # set class attribute to cache context for subsequent calls |
118 | 127 | self.context = context |
| 128 | + # If return_socket is True, cache socket and return it |
| 129 | + if return_socket: |
| 130 | + logger.debug("Returning socket instead of context!") |
| 131 | + sock = self.context.wrap_socket( |
| 132 | + socket.create_connection((self.get_preferred_ip(IPTypes.PUBLIC), 3307)), |
| 133 | + server_hostname="blah", |
| 134 | + ) |
| 135 | + self.sock = sock |
| 136 | + return sock |
119 | 137 | return context |
120 | 138 |
|
121 | 139 | def get_preferred_ip(self, ip_type: IPTypes) -> str: |
|
0 commit comments