Skip to content

Commit c878d34

Browse files
committed
feat: implement new results handling logic
Add support for the new encoding of results as a CBOR2-encoded binary WebSocket data frame of a results envelope, containing a JSON or Arrow encoding of a Pandas DataFrame result. Plus some nicer output in the smoke test script.
1 parent 9cb7dc6 commit c878d34

6 files changed

Lines changed: 354 additions & 39 deletions

File tree

poetry.lock

Lines changed: 280 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ python = "^3.11"
1616
requests = "^2.31.0"
1717
websockets = "^12.0"
1818
tenacity = "^8.2.3"
19+
pyarrow = "^15.0.2"
20+
cbor2 = "^5.6.3"
21+
pandas = "^2.2.2"
1922

2023
[tool.poetry.group.dev.dependencies]
2124
mypy = "^1.8.0"
@@ -24,7 +27,7 @@ black = "^24.2.0"
2427
pre-commit = "^3.6.2"
2528
conventional-pre-commit = "^3.1.0"
2629
shapely = "^2.0.3"
27-
tabulate = "^0.9.0"
30+
rich = "^13.7.1"
2831

2932
[build-system]
3033
requires = ["poetry-core"]

tests/smoke.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import logging
66
import sys
77

8-
import shapely
9-
import tabulate
8+
import pandas
9+
from rich.console import Console
10+
from rich.table import Table
1011

1112
from wherobots.db import connect, connect_direct
1213
from wherobots.db.region import Region
@@ -24,6 +25,9 @@
2425
default=logging.INFO,
2526
)
2627
parser.add_argument("--ws-url", help="Direct URL to connect to")
28+
parser.add_argument(
29+
"--wide", help="Enable wide output", action="store_const", const=80, default=30
30+
)
2731
parser.add_argument("sql", help="SQL query to execute")
2832
args = parser.parse_args()
2933

@@ -60,11 +64,13 @@
6064
with conn_func() as conn:
6165
cursor = conn.cursor()
6266
cursor.execute(args.sql)
63-
results = cursor.fetchall()
64-
65-
for row in results:
66-
for key, value in row.items():
67-
if "geometry" in key:
68-
row[key] = shapely.to_geojson(shapely.from_wkt(value))
67+
results: pandas.DataFrame = cursor.fetchall()
6968

70-
print(tabulate.tabulate(results, headers="keys", tablefmt="rounded_outline"))
69+
table = Table()
70+
table.add_column("#")
71+
for column in results.columns:
72+
table.add_column(column, max_width=args.wide, no_wrap=True)
73+
for row in results.itertuples(name=None):
74+
r = [str(x) for x in row]
75+
table.add_row(*r)
76+
Console().print(table)

wherobots/db/connection.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,30 @@
55
from dataclasses import dataclass
66
from typing import Callable, Any
77

8-
from wherobots.db.constants import RequestKind, EventKind, ExecutionState
8+
import cbor2
9+
import pyarrow
10+
11+
from wherobots.db.constants import (
12+
RequestKind,
13+
EventKind,
14+
ExecutionState,
15+
ResultsFormat,
16+
DataCompression,
17+
)
918
from wherobots.db.cursor import Cursor
1019
from wherobots.db.errors import NotSupportedError, DatabaseError, OperationalError
1120

1221

22+
_DEFAULT_RESULTS_FORMAT = ResultsFormat.ARROW
23+
_DEFAULT_DATA_COMPRESSION = DataCompression.BROTLI
24+
25+
1326
@dataclass
1427
class Query:
1528
sql: str
1629
execution_id: str
1730
state: ExecutionState
18-
handler: Callable[[list[Any] | DatabaseError], None]
31+
handler: Callable[[Any], None]
1932

2033

2134
class Connection:
@@ -101,23 +114,35 @@ def __listen(self):
101114
query.handler(OperationalError("Query execution failed"))
102115

103116
case EventKind.EXECUTION_RESULT:
104-
results_format = message.get("results_format")
117+
results = message.get("results")
118+
if not results or not isinstance(results, dict):
119+
logging.warning("Got no results back from %s.", execution_id)
120+
continue
121+
122+
result_bytes = results.get("result_bytes")
123+
result_format = results.get("format")
124+
result_compression = results.get("compression")
105125
logging.info(
106-
"Received %s results from %s.",
107-
results_format,
126+
"Received %d bytes of %s-compressed %s results from %s.",
127+
len(result_bytes),
128+
result_compression,
129+
result_format,
108130
execution_id,
109131
)
110132

111133
query.state = ExecutionState.COMPLETED
112-
match results_format:
113-
case "json":
114-
query.handler(json.loads(message.get("results")))
115-
case "arrow":
116-
pass
134+
match result_format:
135+
case ResultsFormat.JSON:
136+
query.handler(json.loads(result_bytes.decode("utf-8")))
137+
case ResultsFormat.ARROW:
138+
buffer = pyarrow.py_buffer(result_bytes)
139+
stream = pyarrow.input_stream(buffer, result_compression)
140+
with pyarrow.ipc.open_stream(stream) as reader:
141+
query.handler(reader.read_pandas())
117142
case _:
118143
query.handler(
119144
OperationalError(
120-
f"Unsupported results format {results_format}"
145+
f"Unsupported results format {result_format}"
121146
)
122147
)
123148
case _:
@@ -129,15 +154,16 @@ def __send(self, message: dict[str, Any]) -> None:
129154

130155
def __recv(self) -> dict[str, Any]:
131156
frame = self.__ws.recv()
132-
if isinstance(frame, bytes):
133-
frame = frame.decode("utf-8")
134-
message = json.loads(frame)
157+
if isinstance(frame, str):
158+
message = json.loads(frame)
159+
elif isinstance(frame, bytes):
160+
message = cbor2.loads(frame)
161+
else:
162+
raise ValueError("Unexpected frame type received")
135163
logging.debug("Received message: %s", message)
136164
return message
137165

138-
def __execute_sql(
139-
self, sql: str, handler: Callable[[list[Any] | DatabaseError], None]
140-
) -> str:
166+
def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str:
141167
"""Triggers the execution of the given SQL query."""
142168
execution_id = str(uuid.uuid4())
143169
request = {
@@ -164,7 +190,8 @@ def __request_results(self, execution_id: str) -> None:
164190
request = {
165191
"kind": RequestKind.RETRIEVE_RESULTS.value,
166192
"execution_id": execution_id,
167-
"results_format": "json",
193+
"format": _DEFAULT_RESULTS_FORMAT.value,
194+
"compression": _DEFAULT_DATA_COMPRESSION.value,
168195
}
169196
query.state = ExecutionState.RESULTS_REQUESTED
170197
logging.info("Requesting results from %s ...", execution_id)

wherobots/db/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,12 @@ class RequestKind(StrEnum):
4545
class EventKind(StrEnum):
4646
STATE_UPDATED = auto()
4747
EXECUTION_RESULT = auto()
48+
49+
50+
class ResultsFormat(StrEnum):
51+
JSON = auto()
52+
ARROW = auto()
53+
54+
55+
class DataCompression(StrEnum):
56+
BROTLI = auto()

wherobots/db/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def description(self) -> str | None:
3131
def rowcount(self) -> int:
3232
return self.__rowcount
3333

34-
def __on_execution_result(self, result: list[Any] | DatabaseError) -> None:
34+
def __on_execution_result(self, result) -> None:
3535
self.__queue.put(result)
3636

3737
def __get_results(self) -> list[Any] | None:

0 commit comments

Comments
 (0)