Skip to content

Commit ca6a271

Browse files
Uziel Silvathameezb
authored andcommitted
fix(main): Make local proxy to accept multiple connections (WIP)
1 parent b663bbe commit ca6a271

7 files changed

Lines changed: 169 additions & 130 deletions

File tree

google/cloud/sql/connector/connector.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
from google.cloud.sql.connector.exceptions import ConnectorLoopError
3939
from google.cloud.sql.connector.instance import RefreshAheadCache
4040
from google.cloud.sql.connector.lazy import LazyRefreshCache
41+
import google.cloud.sql.connector.local_unix_socket as local_unix_socket
4142
from google.cloud.sql.connector.monitored_cache import MonitoredCache
4243
import google.cloud.sql.connector.pg8000 as pg8000
43-
import google.cloud.sql.connector.proxy as proxy
44-
import google.cloud.sql.connector.psycopg as psycopg
44+
from google.cloud.sql.connector.proxy import Proxy
4545
import google.cloud.sql.connector.pymysql as pymysql
4646
import google.cloud.sql.connector.pytds as pytds
4747
from google.cloud.sql.connector.resolver import DefaultResolver
@@ -52,7 +52,6 @@
5252
logger = logging.getLogger(name=__name__)
5353

5454
ASYNC_DRIVERS = ["asyncpg"]
55-
LOCAL_PROXY_DRIVERS = ["psycopg"]
5655
SERVER_PROXY_PORT = 3307
5756
_DEFAULT_SCHEME = "https://"
5857
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
@@ -160,7 +159,7 @@ def __init__(
160159
self._cache: dict[tuple[str, bool], MonitoredCache] = {}
161160
self._client: Optional[CloudSQLClient] = None
162161
self._closed: bool = False
163-
self._proxy: Optional[asyncio.Task] = None
162+
self._proxies: Optional[Proxy] = None
164163

165164
# initialize credentials
166165
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
@@ -221,6 +220,29 @@ def __init__(
221220
def universe_domain(self) -> str:
222221
return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN
223222

223+
def start_unix_socket_proxy_async(
224+
self,
225+
instance_connection_name: str,
226+
local_socket_path: str,
227+
**kwargs: Any
228+
) -> None:
229+
"""Creates a new Proxy instance and stores it to properly disposal
230+
231+
Args:
232+
instance_connection_string (str): The instance connection name of the
233+
Cloud SQL instance to connect to. Takes the form of
234+
"project-id:region:instance-name"
235+
236+
Example: "my-project:us-central1:my-instance"
237+
238+
local_socket_path (str): A string representing the location of the local socket.
239+
240+
**kwargs: Any driver-specific arguments to pass to the underlying
241+
driver .connect call.
242+
"""
243+
# TODO: validates the local socket path is not the same as other invocation
244+
self._proxies.append(new Proxy(self, instance_connection_name, local_socket_path, self.loop, **kwargs))
245+
224246
def connect(
225247
self, instance_connection_string: str, driver: str, **kwargs: Any
226248
) -> Any:
@@ -238,7 +260,7 @@ def connect(
238260
Example: "my-project:us-central1:my-instance"
239261
240262
driver (str): A string representing the database driver to connect
241-
with. Supported drivers are pymysql, pg8000, psycopg, and pytds.
263+
with. Supported drivers are pymysql, pg8000, local_unix_socket, and pytds.
242264
243265
**kwargs: Any driver-specific arguments to pass to the underlying
244266
driver .connect call.
@@ -280,7 +302,7 @@ async def connect_async(
280302
Example: "my-project:us-central1:my-instance"
281303
282304
driver (str): A string representing the database driver to connect
283-
with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and
305+
with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket, and
284306
pytds.
285307
286308
**kwargs: Any driver-specific arguments to pass to the underlying
@@ -293,7 +315,7 @@ async def connect_async(
293315
ValueError: Connection attempt with built-in database authentication
294316
and then subsequent attempt with IAM database authentication.
295317
KeyError: Unsupported database driver Must be one of pymysql, asyncpg,
296-
pg8000, psycopg, and pytds.
318+
pg8000, local_unix_socket, and pytds.
297319
RuntimeError: Connector has been closed. Cannot connect using a closed
298320
Connector.
299321
"""
@@ -362,7 +384,7 @@ async def connect_async(
362384
connect_func = {
363385
"pymysql": pymysql.connect,
364386
"pg8000": pg8000.connect,
365-
"psycopg": psycopg.connect,
387+
"local_unix_socket": local_unix_socket.connect,
366388
"asyncpg": asyncpg.connect,
367389
"pytds": pytds.connect,
368390
}
@@ -449,17 +471,6 @@ async def connect_async(
449471
server_hostname=ip_address,
450472
)
451473

452-
host = ip_address
453-
# start local proxy if driver needs it
454-
if driver in LOCAL_PROXY_DRIVERS:
455-
local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket")
456-
host = local_socket_path
457-
self._proxy = proxy.start_local_proxy(
458-
sock,
459-
socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}",
460-
loop=self._loop
461-
)
462-
463474
# If this connection was opened using a domain name, then store it
464475
# for later in case we need to forcibly close it on failover.
465476
if conn_info.conn_name.domain_name:
@@ -543,6 +554,7 @@ async def close_async(self) -> None:
543554
await self._client.close()
544555
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
545556

557+
546558
async def create_async_connector(
547559
ip_type: str | IPTypes = IPTypes.PUBLIC,
548560
enable_iam_auth: bool = False,

google/cloud/sql/connector/enums.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class DriverMapping(Enum):
6262

6363
ASYNCPG = "POSTGRES"
6464
PG8000 = "POSTGRES"
65-
PSYCOPG = "POSTGRES"
65+
LOCAL_UNIX_SOCKET = "ANY"
6666
PYMYSQL = "MYSQL"
6767
PYTDS = "SQLSERVER"
6868

@@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None:
7979
the given engine.
8080
"""
8181
mapping = DriverMapping[driver.upper()]
82-
if not engine_version.startswith(mapping.value):
82+
if not mapping.value == "ANY" and not engine_version.startswith(mapping.value):
8383
raise IncompatibleDriverError(
8484
f"Database driver '{driver}' is incompatible with database "
8585
f"version '{engine_version}'. Given driver can "

google/cloud/sql/connector/psycopg.py renamed to google/cloud/sql/connector/local_unix_socket.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,44 +19,19 @@
1919

2020
SERVER_PROXY_PORT = 3307
2121

22-
if TYPE_CHECKING:
23-
import psycopg
24-
25-
2622
def connect(
2723
host: str, sock: ssl.SSLSocket, **kwargs: Any
28-
) -> "psycopg.Connection":
29-
"""Helper function to create a psycopg DB-API connection object.
24+
) -> "ssl.SSLSocket":
25+
"""Helper function to retrieve the socket for local UNIX sockets.
3026
3127
Args:
3228
host (str): A string containing the socket path used by the local proxy.
3329
sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL
3430
server CA cert and ephemeral cert.
35-
kwargs: Additional arguments to pass to the psycopg connect method.
31+
kwargs: Additional arguments to pass to the local UNIX socket connect method.
3632
3733
Returns:
38-
psycopg.Connection: A psycopg connection to the Cloud SQL
39-
instance.
40-
41-
Raises:
42-
ImportError: The psycopg module cannot be imported.
34+
ssl.SSLSocket: The same socket
4335
"""
44-
try:
45-
from psycopg import Connection
46-
except ImportError:
47-
raise ImportError(
48-
'Unable to import module "psycopg." Please install and try again.'
49-
)
50-
51-
user = kwargs.pop("user")
52-
db = kwargs.pop("db")
53-
passwd = kwargs.pop("password", None)
54-
55-
kwargs.pop("timeout", None)
56-
57-
conn = Connection.connect(
58-
f"host={host} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require",
59-
**kwargs
60-
)
61-
62-
return conn
36+
37+
return sock

google/cloud/sql/connector/proxy.py

Lines changed: 114 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -18,73 +18,125 @@
1818
import os
1919
from pathlib import Path
2020
import socket
21+
import selectors
2122
import ssl
2223

24+
from google.cloud.sql.connector import Connector
2325
from google.cloud.sql.connector.exceptions import LocalProxyStartupError
2426

2527
SERVER_PROXY_PORT = 3307
2628
LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760
2729

28-
def start_local_proxy(
29-
ssl_sock: ssl.SSLSocket,
30-
socket_path: str,
31-
loop: asyncio.AbstractEventLoop
32-
) -> asyncio.Task:
33-
"""Helper function to start a UNIX based local proxy for
34-
transport messages through the SSL Socket.
35-
36-
Args:
37-
ssl_sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL
38-
server CA cert and ephemeral cert.
39-
socket_path: A system path that is going to be used to store the socket.
40-
loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks.
41-
42-
Returns:
43-
asyncio.Task: The asyncio task containing the proxy server process.
44-
45-
Raises:
46-
LocalProxyStartupError: Local UNIX socket based proxy was not able to
47-
get started.
48-
"""
49-
unix_socket = None
50-
51-
try:
52-
path_parts = socket_path.rsplit('/', 1)
53-
parent_directory = '/'.join(path_parts[:-1])
54-
55-
desired_path = Path(parent_directory)
56-
desired_path.mkdir(parents=True, exist_ok=True)
57-
58-
if os.path.exists(socket_path):
59-
os.remove(socket_path)
60-
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
61-
62-
unix_socket.bind(socket_path)
63-
unix_socket.listen(1)
64-
unix_socket.setblocking(False)
65-
os.chmod(socket_path, 0o600)
66-
except Exception:
67-
raise LocalProxyStartupError(
68-
'Local UNIX socket based proxy was not able to get started.'
69-
)
70-
71-
return loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop))
72-
73-
74-
async def local_communication(
75-
unix_socket, ssl_sock, socket_path, loop
76-
):
77-
client, _ = await loop.sock_accept(unix_socket)
78-
79-
try:
30+
31+
class Proxy:
32+
"""Creates an "accept loop" async task which will open the unix server socket and listen for new connections."""
33+
34+
def __init__(
35+
self,
36+
connector: Connector,
37+
instance_connection_string: str,
38+
socket_path: str,
39+
loop: asyncio.AbstractEventLoop,
40+
**kwargs: Any
41+
) -> None:
42+
"""Keeps track of all the async tasks and starts the accept loop for new connections.
43+
44+
Args:
45+
connector (Connector): The instance where this Proxy class was created.
46+
47+
instance_connection_string (str): The instance connection name of the
48+
Cloud SQL instance to connect to. Takes the form of
49+
"project-id:region:instance-name"
50+
51+
Example: "my-project:us-central1:my-instance"
52+
53+
socket_path (str): A system path that is going to be used to store the socket.
54+
55+
loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks.
56+
57+
**kwargs: Any driver-specific arguments to pass to the underlying
58+
driver .connect call.
59+
"""
60+
self._connection_tasks = []
61+
self._addr = instance_connection_string
62+
self._kwargs = kwargs
63+
self._connector = connector
64+
self._task = loop.create_task(accept_loop(socket_path, loop, **kwargs))
65+
66+
async def accept_loop(
67+
self
68+
socket_path: str,
69+
loop: asyncio.AbstractEventLoop
70+
) -> asyncio.Task:
71+
"""Starts a UNIX based local proxy for transporting messages through
72+
the SSL Socket, and waits until there is a new connection to accept, to register it
73+
and keep track of it.
74+
75+
Args:
76+
socket_path: A system path that is going to be used to store the socket.
77+
78+
loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks.
79+
80+
Raises:
81+
LocalProxyStartupError: Local UNIX socket based proxy was not able to
82+
get started.
83+
"""
84+
unix_socket = None
85+
sel = selectors.DefaultSelector()
86+
87+
try:
88+
path_parts = socket_path.rsplit('/', 1)
89+
parent_directory = '/'.join(path_parts[:-1])
90+
91+
desired_path = Path(parent_directory)
92+
desired_path.mkdir(parents=True, exist_ok=True)
93+
94+
if os.path.exists(socket_path):
95+
os.remove(socket_path)
96+
97+
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
98+
99+
unix_socket.bind(socket_path)
100+
unix_socket.listen(1)
101+
unix_socket.setblocking(False)
102+
os.chmod(socket_path, 0o600)
103+
104+
sel.register(unix_socket, selectors.EVENT_READ, data=None)
105+
106+
except Exception:
107+
raise LocalProxyStartupError(
108+
'Local UNIX socket based proxy was not able to get started.'
109+
)
110+
80111
while True:
81-
data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE)
82-
if not data:
83-
client.close()
84-
break
85-
ssl_sock.sendall(data)
86-
response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE)
87-
await loop.sock_sendall(client, response)
88-
finally:
89-
client.close()
90-
os.remove(socket_path) # Clean up the socket file
112+
client, _ = await loop.sock_accept(unix_socket)
113+
self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop)))
114+
115+
async def close_async(self):
116+
proxy_task = asyncio.gather(self._task)
117+
try:
118+
await asyncio.wait_for(proxy_task, timeout=0.1)
119+
except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError):
120+
pass # This task runs forever so it is expected to throw this exception
121+
122+
123+
async def client_socket(
124+
self, client, unix_socket, socket_path, loop
125+
):
126+
try:
127+
ssl_sock = self.connector.connect(
128+
self._addr,
129+
'local_unix_socket',
130+
**self._kwargs
131+
)
132+
while True:
133+
data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE)
134+
if not data:
135+
client.close()
136+
break
137+
ssl_sock.sendall(data)
138+
response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE)
139+
await loop.sock_sendall(client, response)
140+
finally:
141+
client.close()
142+
os.remove(socket_path) # Clean up the socket file

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ Changelog = "https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/b
5959
[project.optional-dependencies]
6060
pymysql = ["PyMySQL>=1.1.0"]
6161
pg8000 = ["pg8000>=1.31.1"]
62-
psycopg = ["psycopg>=3.2.9"]
6362
pytds = ["python-tds>=1.15.0"]
6463
asyncpg = ["asyncpg>=0.30.0"]
6564

0 commit comments

Comments
 (0)