Skip to content

Commit 2d246d8

Browse files
committed
feat: refactor to support multiple concurrent queries
Major refactoring of the protocol handling logic, moving it to the Connection. The Connection object is now responsible for all interactions with the SQL session, query tracking, and the corresponding state machine. The Cursor becomes a much simpler affair, simply requesting a query execution from the Connection, providing a callback handler for results. Internally, a simple queue is used to allow cursor.fetch*() methods to block until results become available.
1 parent 67a5aa2 commit 2d246d8

4 files changed

Lines changed: 175 additions & 117 deletions

File tree

README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ $ pip install wherobots-python-dbapi-driver
2222
## Usage
2323

2424
```python
25-
from contextlib import closing
2625
import tabulate
2726

2827
from wherobots.db import connect
@@ -33,8 +32,8 @@ with connect(
3332
api_key='...',
3433
runtime=Runtime.SEDONA,
3534
region=Region.AWS_US_WEST_2) as conn:
36-
with closing(conn.cursor()) as curr:
37-
curr.execute("SHOW SCHEMAS IN wherobots_open_data")
38-
results = curr.fetchall()
39-
print(tabulate.tabulate(results, headers="keys", tablefmt="pretty"))
35+
curr = conn.cursor()
36+
curr.execute("SHOW SCHEMAS IN wherobots_open_data")
37+
results = curr.fetchall()
38+
print(tabulate.tabulate(results, headers="keys", tablefmt="pretty"))
4039
```

tests/smoke.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33
import argparse
44
import functools
55
import logging
6-
from contextlib import closing
6+
import sys
77

88
import shapely
9-
import sys
109
import tabulate
1110

1211
from wherobots.db import connect, connect_direct
13-
from wherobots.db.runtime import Runtime
1412
from wherobots.db.region import Region
15-
13+
from wherobots.db.runtime import Runtime
1614

1715
if __name__ == "__main__":
1816
parser = argparse.ArgumentParser()
@@ -60,9 +58,9 @@
6058
)
6159

6260
with conn_func() as conn:
63-
with closing(conn.cursor()) as cursor:
64-
cursor.execute(args.sql)
65-
results = cursor.fetchall()
61+
cursor = conn.cursor()
62+
cursor.execute(args.sql)
63+
results = cursor.fetchall()
6664

6765
for row in results:
6866
for key, value in row.items():

wherobots/db/connection.py

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1+
import json
2+
import logging
3+
import uuid
4+
from concurrent.futures import ThreadPoolExecutor
5+
from dataclasses import dataclass
6+
from typing import Callable, Any
7+
8+
from wherobots.db.constants import RequestKind, EventKind, ExecutionState
19
from wherobots.db.cursor import Cursor
2-
from wherobots.db.errors import NotSupportedError
10+
from wherobots.db.errors import NotSupportedError, DatabaseError, OperationalError
11+
12+
13+
@dataclass
14+
class Query:
15+
sql: str
16+
execution_id: str
17+
state: ExecutionState
18+
handler: Callable[[list[Any] | DatabaseError], None]
319

420

521
class Connection:
@@ -8,19 +24,34 @@ class Connection:
824
925
The connection is backed by the WebSocket connected to the Wherobots SQL session instance.
1026
Transactions are not supported, so commit() and rollback() raise NotSupportedError.
27+
28+
This class handles all the interactions with the remote SQL session, and the details of the
29+
Wherobots Spatial SQL API protocol. It supports multiple concurrent cursors, each one executing
30+
a single query at a time.
31+
32+
A background thread listens for events from the SQL session, and handles update to the
33+
corresponding query state. Queries are tracked by their unique execution ID.
34+
35+
Note: the Connection object MUST be used as a context manager.
1136
"""
1237

1338
def __init__(self, ws):
1439
self.__ws = ws
40+
self.__queries: dict[str, Query] = {}
1541

1642
def __enter__(self):
43+
self.__executor = ThreadPoolExecutor(
44+
max_workers=1, thread_name_prefix="wherobots-sql-connection"
45+
)
46+
self.__executor.submit(self.__listen)
1747
return self
1848

1949
def __exit__(self, exc_type, exc_val, exc_tb):
2050
self.close()
2151

2252
def close(self):
2353
self.__ws.close()
54+
self.__executor.shutdown(wait=True)
2455

2556
def commit(self):
2657
raise NotSupportedError
@@ -29,4 +60,110 @@ def rollback(self):
2960
raise NotSupportedError
3061

3162
def cursor(self) -> Cursor:
32-
return Cursor(self.__ws.send, self.__ws.recv)
63+
return Cursor(self.__execute_sql, self.__cancel_query)
64+
65+
def __listen(self):
66+
"""Main background loop listening for messages from the SQL session.
67+
68+
The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
69+
"""
70+
while True:
71+
message = json.loads(self.__ws.recv())
72+
logging.debug("Received message: %s", message)
73+
74+
execution_id = message.get("execution_id")
75+
if not execution_id:
76+
continue
77+
78+
query = self.__queries.get(execution_id)
79+
if not query:
80+
logging.warning(
81+
"Received %s event for unknown execution ID %s", kind, execution_id
82+
)
83+
continue
84+
85+
kind = message.get("kind")
86+
match kind:
87+
case EventKind.STATE_UPDATED:
88+
try:
89+
query.state = ExecutionState[message["state"].upper()]
90+
logging.info("Query %s is now %s.", execution_id, query.state)
91+
except KeyError:
92+
logging.warning(
93+
"Invalid state update message for %s", execution_id
94+
)
95+
continue
96+
97+
# Incoming state transitions are handled here.
98+
match query.state:
99+
case ExecutionState.SUCCEEDED:
100+
self.__request_results(execution_id)
101+
case ExecutionState.FAILED:
102+
query.handler(OperationalError("Query execution failed"))
103+
104+
case EventKind.EXECUTION_RESULT:
105+
results_format = message.get("results_format")
106+
107+
# TODO: Support other results formats.
108+
if results_format != "json":
109+
query.handler(
110+
OperationalError(
111+
f"Unsupported results format {results_format}"
112+
)
113+
)
114+
continue
115+
116+
logging.info(
117+
"Received %s results from %s.",
118+
results_format,
119+
execution_id,
120+
)
121+
query.state = ExecutionState.COMPLETED
122+
query.handler([json.loads(message.get("results"))])
123+
case _:
124+
logging.warning("Received unknown %s event!", kind)
125+
126+
def __send(self, message: dict[str, Any]) -> None:
127+
logging.debug("Sending %s", message)
128+
self.__ws.send(json.dumps(message))
129+
130+
def __execute_sql(
131+
self, sql: str, handler: Callable[[list[Any] | DatabaseError], None]
132+
) -> str:
133+
"""Triggers the execution of the given SQL query."""
134+
execution_id = str(uuid.uuid4())
135+
request = {
136+
"kind": RequestKind.EXECUTE_SQL.value,
137+
"execution_id": execution_id,
138+
"statement": sql,
139+
}
140+
141+
self.__queries[execution_id] = Query(
142+
sql=sql,
143+
execution_id=execution_id,
144+
state=ExecutionState.EXECUTION_REQUESTED,
145+
handler=handler,
146+
)
147+
self.__send(request)
148+
return execution_id
149+
150+
def __request_results(self, execution_id: str) -> None:
151+
query = self.__queries.get(execution_id)
152+
if not query:
153+
return
154+
155+
# TODO: Switch to Arrow encoding of results when supported.
156+
request = {
157+
"kind": RequestKind.RETRIEVE_RESULTS.value,
158+
"execution_id": execution_id,
159+
"results_format": "json",
160+
}
161+
query.state = ExecutionState.RESULTS_REQUESTED
162+
logging.info("Requesting results from %s ...", execution_id)
163+
self.__send(request)
164+
165+
def __cancel_query(self, execution_id: str) -> None:
166+
query = self.__queries.pop(execution_id)
167+
if query:
168+
logging.info("Cancelled query %s.", execution_id)
169+
# TODO: when protocol supports it, send cancellation request.

wherobots/db/cursor.py

Lines changed: 27 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,20 @@
1-
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
2-
import json
3-
import logging
4-
import uuid
5-
from typing import Any, Callable
1+
import queue
2+
from typing import Any
63

7-
from .constants import ExecutionState, RequestKind, EventKind
8-
from .errors import ProgrammingError, OperationalError
4+
from .errors import ProgrammingError, DatabaseError
95

106

117
class Cursor:
128

13-
def __init__(self, send_func: Callable[[str], None], recv_func: Callable[[], str]):
14-
self.__send_func = send_func
15-
self.__recv_func = recv_func
9+
def __init__(self, exec_fn, cancel_fn):
10+
self.__exec_fn = exec_fn
11+
self.__cancel_fn = cancel_fn
1612

13+
self.__queue: queue.Queue = queue.Queue()
14+
self.__results: list[Any] | None = None
1715
self.__current_execution_id: str | None = None
18-
self.__current_execution_state: ExecutionState = ExecutionState.IDLE
19-
self.__current_execution_results: Future[list[Any]] | None = None
2016
self.__current_row: int = 0
2117

22-
self.__executor = ThreadPoolExecutor(
23-
max_workers=1, thread_name_prefix="wherobots-sql-cursor"
24-
)
25-
2618
# Description and row count are set by the last executed operation.
2719
# Their default values are defined by PEP-0249.
2820
self.__description: str | None = None
@@ -39,104 +31,36 @@ def description(self) -> str | None:
3931
def rowcount(self) -> int:
4032
return self.__rowcount
4133

42-
def close(self):
43-
self.__executor.shutdown()
34+
def __on_execution_result(self, result: list[Any] | DatabaseError) -> None:
35+
self.__queue.put(result)
4436

45-
def __send_request(self, request):
46-
return self.__send_func(json.dumps(request))
37+
def __get_results(self) -> list[Any] | None:
38+
if not self.__current_execution_id:
39+
raise ProgrammingError("No query has been executed yet")
40+
if self.__results is not None:
41+
return self.__results
42+
43+
result = self.__queue.get()
44+
if isinstance(result, DatabaseError):
45+
raise result
46+
self.__rowcount = len(result)
47+
self.__results = result
48+
return self.__results
4749

4850
def execute(self, operation: str, parameters: dict[str, Any] = None):
49-
self.__current_execution_id = str(uuid.uuid4())
50-
self.__current_execution_state = ExecutionState.EXECUTION_REQUESTED
51+
if self.__current_execution_id:
52+
self.__cancel_fn(self.__current_execution_id)
53+
54+
self.__results = None
5155
self.__current_row = 0
5256
self.__rowcount = -1
5357

54-
def _execute(request):
55-
"""This function is executed in a separate thread to send the request, wait for, and gather the results."""
56-
logging.info("Executing SQL: %s", request["statement"])
57-
self.__send_request(request)
58-
return self.__gather_results()
59-
6058
sql = operation.format(**(parameters or {}))
61-
exec_request = {
62-
"kind": RequestKind.EXECUTE_SQL.value,
63-
"execution_id": self.__current_execution_id,
64-
"statement": sql,
65-
}
66-
67-
if self.__current_execution_results:
68-
self.__current_execution_results.cancel()
69-
self.__current_execution_results = self.__executor.submit(
70-
_execute, exec_request
71-
)
59+
self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result)
7260

7361
def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]):
7462
raise NotImplementedError
7563

76-
def __get_results(self) -> list[Any] | None:
77-
if not self.__current_execution_results:
78-
raise ProgrammingError("No query has been executed yet")
79-
try:
80-
return self.__current_execution_results.result()
81-
except CancelledError:
82-
raise ProgrammingError(
83-
"Query execution was cancelled while waiting for results"
84-
)
85-
86-
def __gather_results(self):
87-
"""
88-
Reads from the SQL session for updates on the current execution state. Once the query has completed
89-
successfully, requests and processes the results.
90-
"""
91-
while not self.__current_execution_state.is_terminal_state():
92-
response = json.loads(self.__recv_func())
93-
logging.debug("Received response: %s", response)
94-
95-
kind = EventKind[response["kind"].upper()]
96-
match kind:
97-
case EventKind.STATE_UPDATED:
98-
self.__current_execution_state = ExecutionState[
99-
response["state"].upper()
100-
]
101-
logging.info(
102-
"Query %s is %s.",
103-
self.__current_execution_id,
104-
self.__current_execution_state,
105-
)
106-
107-
match self.__current_execution_state:
108-
case ExecutionState.SUCCEEDED:
109-
results_request = {
110-
"kind": RequestKind.RETRIEVE_RESULTS.value,
111-
"execution_id": self.__current_execution_id,
112-
"results_format": "json",
113-
}
114-
logging.info(
115-
"Requesting results from %s ...",
116-
self.__current_execution_id,
117-
)
118-
self.__send_request(results_request)
119-
self.__current_execution_state = (
120-
ExecutionState.RESULTS_REQUESTED
121-
)
122-
case ExecutionState.FAILED:
123-
raise OperationalError("Execution failed")
124-
case EventKind.EXECUTION_RESULT:
125-
self.__current_execution_state = ExecutionState.COMPLETED
126-
results_format = response["results_format"]
127-
logging.info(
128-
"Received %s results from %s.",
129-
results_format,
130-
self.__current_execution_id,
131-
)
132-
if results_format != "json":
133-
raise OperationalError(
134-
f"Unsupported results format {results_format}"
135-
)
136-
# TODO: When full results are sent, we won't need to wrap them in a list to simulate a row.
137-
self.__rowcount = 1
138-
return [json.loads(response["results"])]
139-
14064
def fetchone(self):
14165
results = self.__get_results()[self.__current_row :]
14266
if not results:

0 commit comments

Comments
 (0)