Skip to content

Commit 2e15072

Browse files
committed
feat: switch to a simple, true daemon thread for connection loop
1 parent a114595 commit 2e15072

1 file changed

Lines changed: 78 additions & 75 deletions

File tree

wherobots/db/connection.py

Lines changed: 78 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import json
22
import logging
3+
import threading
34
import uuid
4-
from concurrent.futures import ThreadPoolExecutor
55
from dataclasses import dataclass
66
from typing import Callable, Any
77

88
import cbor2
99
import pyarrow
10+
from websockets.protocol import State
11+
from websockets.sync.client import ClientConnection
1012

1113
from wherobots.db.constants import (
1214
RequestKind,
@@ -16,8 +18,7 @@
1618
DataCompression,
1719
)
1820
from wherobots.db.cursor import Cursor
19-
from wherobots.db.errors import NotSupportedError, DatabaseError, OperationalError
20-
21+
from wherobots.db.errors import NotSupportedError, OperationalError
2122

2223
_DEFAULT_RESULTS_FORMAT = ResultsFormat.ARROW
2324
_DEFAULT_DATA_COMPRESSION = DataCompression.BROTLI
@@ -44,27 +45,24 @@ class Connection:
4445
4546
A background thread listens for events from the SQL session, and handles update to the
4647
corresponding query state. Queries are tracked by their unique execution ID.
47-
48-
Note: the Connection object MUST be used as a context manager.
4948
"""
5049

51-
def __init__(self, ws):
50+
def __init__(self, ws: ClientConnection):
5251
self.__ws = ws
5352
self.__queries: dict[str, Query] = {}
53+
self.__thread = threading.Thread(
54+
target=self.__main_loop, daemon=True, name="wherobots-connection"
55+
)
56+
self.__thread.start()
5457

5558
def __enter__(self):
56-
self.__executor = ThreadPoolExecutor(
57-
max_workers=1, thread_name_prefix="wherobots-sql-connection"
58-
)
59-
self.__executor.submit(self.__listen)
6059
return self
6160

6261
def __exit__(self, exc_type, exc_val, exc_tb):
6362
self.close()
6463

6564
def close(self):
6665
self.__ws.close()
67-
self.__executor.shutdown(wait=True)
6866

6967
def commit(self):
7068
raise NotSupportedError
@@ -75,78 +73,83 @@ def rollback(self):
7573
def cursor(self) -> Cursor:
7674
return Cursor(self.__execute_sql, self.__cancel_query)
7775

76+
def __main_loop(self):
77+
"""Main background loop listening for messages from the SQL session."""
78+
while self.__ws.protocol.state < State.CLOSING:
79+
try:
80+
self.__listen()
81+
except Exception as e:
82+
logging.exception("Error handling message from SQL session", e)
83+
7884
def __listen(self):
79-
"""Main background loop listening for messages from the SQL session.
85+
"""Waits for the next message from the SQL session and processes it.
8086
8187
The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
8288
"""
83-
while True:
84-
message = self.__recv()
89+
message = self.__recv()
90+
kind = message.get("kind")
91+
execution_id = message.get("execution_id")
92+
if not kind or not execution_id:
93+
# Invalid event.
94+
return
8595

86-
execution_id = message.get("execution_id")
87-
if not execution_id:
88-
continue
96+
query = self.__queries.get(execution_id)
97+
if not query:
98+
logging.warning(
99+
"Received %s event for unknown execution ID %s", kind, execution_id
100+
)
101+
return
89102

90-
query = self.__queries.get(execution_id)
91-
if not query:
92-
logging.warning(
93-
"Received %s event for unknown execution ID %s", kind, execution_id
103+
match kind:
104+
case EventKind.STATE_UPDATED:
105+
try:
106+
query.state = ExecutionState[message["state"].upper()]
107+
logging.info("Query %s is now %s.", execution_id, query.state)
108+
except KeyError:
109+
logging.warning("Invalid state update message for %s", execution_id)
110+
return
111+
112+
# Incoming state transitions are handled here.
113+
match query.state:
114+
case ExecutionState.SUCCEEDED:
115+
self.__request_results(execution_id)
116+
case ExecutionState.FAILED:
117+
query.handler(OperationalError("Query execution failed"))
118+
119+
case EventKind.EXECUTION_RESULT:
120+
results = message.get("results")
121+
if not results or not isinstance(results, dict):
122+
logging.warning("Got no results back from %s.", execution_id)
123+
return
124+
125+
result_bytes = results.get("result_bytes")
126+
result_format = results.get("format")
127+
result_compression = results.get("compression")
128+
logging.info(
129+
"Received %d bytes of %s-compressed %s results from %s.",
130+
len(result_bytes),
131+
result_compression,
132+
result_format,
133+
execution_id,
94134
)
95-
continue
96-
97-
kind = message.get("kind")
98-
match kind:
99-
case EventKind.STATE_UPDATED:
100-
try:
101-
query.state = ExecutionState[message["state"].upper()]
102-
logging.info("Query %s is now %s.", execution_id, query.state)
103-
except KeyError:
104-
logging.warning(
105-
"Invalid state update message for %s", execution_id
106-
)
107-
continue
108-
109-
# Incoming state transitions are handled here.
110-
match query.state:
111-
case ExecutionState.SUCCEEDED:
112-
self.__request_results(execution_id)
113-
case ExecutionState.FAILED:
114-
query.handler(OperationalError("Query execution failed"))
115-
116-
case EventKind.EXECUTION_RESULT:
117-
results = message.get("results")
118-
if not results or not isinstance(results, dict):
119-
logging.warning("Got no results back from %s.", execution_id)
120-
continue
121-
122-
result_bytes = results.get("result_bytes")
123-
result_format = results.get("format")
124-
result_compression = results.get("compression")
125-
logging.info(
126-
"Received %d bytes of %s-compressed %s results from %s.",
127-
len(result_bytes),
128-
result_compression,
129-
result_format,
130-
execution_id,
131-
)
132-
133-
query.state = ExecutionState.COMPLETED
134-
match result_format:
135-
case ResultsFormat.JSON:
136-
query.handler(json.loads(result_bytes.decode("utf-8")))
137-
case ResultsFormat.ARROW:
138-
buffer = pyarrow.py_buffer(result_bytes)
139-
stream = pyarrow.input_stream(buffer, result_compression)
140-
with pyarrow.ipc.open_stream(stream) as reader:
141-
query.handler(reader.read_pandas())
142-
case _:
143-
query.handler(
144-
OperationalError(
145-
f"Unsupported results format {result_format}"
146-
)
135+
136+
query.state = ExecutionState.COMPLETED
137+
match result_format:
138+
case ResultsFormat.JSON:
139+
query.handler(json.loads(result_bytes.decode("utf-8")))
140+
case ResultsFormat.ARROW:
141+
buffer = pyarrow.py_buffer(result_bytes)
142+
stream = pyarrow.input_stream(buffer, result_compression)
143+
with pyarrow.ipc.open_stream(stream) as reader:
144+
query.handler(reader.read_pandas())
145+
case _:
146+
query.handler(
147+
OperationalError(
148+
f"Unsupported results format {result_format}"
147149
)
148-
case _:
149-
logging.warning("Received unknown %s event!", kind)
150+
)
151+
case _:
152+
logging.warning("Received unknown %s event!", kind)
150153

151154
def __send(self, message: dict[str, Any]) -> None:
152155
logging.debug("Sending %s", message)

0 commit comments

Comments
 (0)