Skip to content

Commit e8e661f

Browse files
committed
feat: add store support to cursor.execute() and connection
1 parent 5d585fe commit e8e661f

2 files changed

Lines changed: 75 additions & 14 deletions

File tree

wherobots/db/connection.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import threading
55
import uuid
66
from dataclasses import dataclass
7-
from typing import Any, Callable, Union, Dict
7+
from typing import Any, Callable, Dict
88

99
import pandas
1010
import pyarrow
@@ -24,6 +24,7 @@
2424
)
2525
from wherobots.db.cursor import Cursor
2626
from wherobots.db.errors import NotSupportedError, OperationalError
27+
from wherobots.db.store import Store, StoreResult
2728

2829

2930
@dataclass
@@ -32,6 +33,7 @@ class Query:
3233
execution_id: str
3334
state: ExecutionState
3435
handler: Callable[[Any], None]
36+
store: Store | None = None
3537

3638

3739
class Connection:
@@ -53,9 +55,9 @@ def __init__(
5355
self,
5456
ws: websockets.sync.client.ClientConnection,
5557
read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS,
56-
results_format: Union[ResultsFormat, None] = None,
57-
data_compression: Union[DataCompression, None] = None,
58-
geometry_representation: Union[GeometryRepresentation, None] = None,
58+
results_format: ResultsFormat | None = None,
59+
data_compression: DataCompression | None = None,
60+
geometry_representation: GeometryRepresentation | None = None,
5961
):
6062
self.__ws = ws
6163
self.__read_timeout = read_timeout
@@ -132,8 +134,27 @@ def __listen(self) -> None:
132134

133135
if query.state == ExecutionState.SUCCEEDED:
134136
# On a state_updated event telling us the query succeeded,
135-
# ask for results.
137+
# check if results are stored in cloud storage or need to be fetched.
136138
if kind == EventKind.STATE_UPDATED:
139+
result_uri = message.get("result_uri")
140+
if result_uri:
141+
# Results are stored in cloud storage
142+
store_result = StoreResult(
143+
result_uri=result_uri,
144+
size=message.get("size"),
145+
)
146+
logging.info(
147+
"Query %s results stored at: %s (size: %s)",
148+
execution_id,
149+
result_uri,
150+
store_result.size,
151+
)
152+
query.state = ExecutionState.COMPLETED
153+
# Return empty DataFrame with the StoreResult
154+
query.handler((pandas.DataFrame(), store_result))
155+
return
156+
157+
# No store configured, request results normally
137158
self.__request_results(execution_id)
138159
return
139160

@@ -200,7 +221,12 @@ def __recv(self) -> Dict[str, Any]:
200221
raise ValueError("Unexpected frame type received")
201222
return message
202223

203-
def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str:
224+
def __execute_sql(
225+
self,
226+
sql: str,
227+
handler: Callable[[Any], None],
228+
store: Store | None = None,
229+
) -> str:
204230
"""Triggers the execution of the given SQL query."""
205231
execution_id = str(uuid.uuid4())
206232
request = {
@@ -209,11 +235,19 @@ def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str:
209235
"statement": sql,
210236
}
211237

238+
if store:
239+
request["store"] = {
240+
"format": store.format.value if store.format else None,
241+
"single": store.single,
242+
"generate_presigned_url": store.generate_presigned_url,
243+
}
244+
212245
self.__queries[execution_id] = Query(
213246
sql=sql,
214247
execution_id=execution_id,
215248
state=ExecutionState.EXECUTION_REQUESTED,
216249
handler=handler,
250+
store=store,
217251
)
218252

219253
logging.info(

wherobots/db/cursor.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import queue
2-
from typing import Any, Optional, List, Tuple, Dict
2+
from typing import Any, List, Tuple, Dict
33

44
from .errors import DatabaseError, ProgrammingError
5+
from .store import Store, StoreResult
56

67
_TYPE_MAP = {
78
"object": "STRING",
@@ -20,20 +21,21 @@ def __init__(self, exec_fn, cancel_fn) -> None:
2021
self.__cancel_fn = cancel_fn
2122

2223
self.__queue: queue.Queue = queue.Queue()
23-
self.__results: Optional[list[Any]] = None
24-
self.__current_execution_id: Optional[str] = None
24+
self.__results: list[Any] | None = None
25+
self.__store_result: StoreResult | None = None
26+
self.__current_execution_id: str | None = None
2527
self.__current_row: int = 0
2628

2729
# Description and row count are set by the last executed operation.
2830
# Their default values are defined by PEP-0249.
29-
self.__description: Optional[List[Tuple]] = None
31+
self.__description: List[Tuple] | None = None
3032
self.__rowcount: int = -1
3133

3234
# Array-size is also defined by PEP-0249 and is expected to be read/writable.
3335
self.arraysize: int = 1
3436

3537
@property
36-
def description(self) -> Optional[List[Tuple]]:
38+
def description(self) -> List[Tuple] | None:
3739
return self.__description
3840

3941
@property
@@ -43,7 +45,7 @@ def rowcount(self) -> int:
4345
def __on_execution_result(self, result) -> None:
4446
self.__queue.put(result)
4547

46-
def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:
48+
def __get_results(self) -> List[Tuple[Any, ...]] | None:
4749
if not self.__current_execution_id:
4850
raise ProgrammingError("No query has been executed yet")
4951
if self.__results is not None:
@@ -53,6 +55,10 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:
5355
if isinstance(result, DatabaseError):
5456
raise result
5557

58+
# Unpack store result if present (result is a tuple of (DataFrame, StoreResult))
59+
if isinstance(result, tuple):
60+
result, self.__store_result = result
61+
5662
self.__rowcount = len(result)
5763
self.__results = result
5864
if not result.empty:
@@ -71,19 +77,40 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:
7177

7278
return self.__results
7379

74-
def execute(self, operation: str, parameters: Dict[str, Any] = None) -> None:
80+
def execute(
81+
self,
82+
operation: str,
83+
parameters: Dict[str, Any] | None = None,
84+
store: Store | None = None,
85+
) -> None:
7586
if self.__current_execution_id:
7687
self.__cancel_fn(self.__current_execution_id)
7788

7889
self.__results = None
90+
self.__store_result = None
7991
self.__current_row = 0
8092
self.__rowcount = -1
8193
self.__description = None
8294

8395
self.__current_execution_id = self.__exec_fn(
84-
operation % (parameters or {}), self.__on_execution_result
96+
operation % (parameters or {}), self.__on_execution_result, store
8597
)
8698

99+
def get_store_result(self) -> StoreResult | None:
100+
"""Get the store result for the last executed query.
101+
102+
Returns the StoreResult containing the URI and size of the stored
103+
results, or None if the query was not configured to store results.
104+
105+
This method blocks until the query completes.
106+
"""
107+
if not self.__current_execution_id:
108+
raise ProgrammingError("No query has been executed yet")
109+
110+
# Ensure we've waited for the result
111+
self.__get_results()
112+
return self.__store_result
113+
87114
def executemany(
88115
self, operation: str, seq_of_parameters: List[Dict[str, Any]]
89116
) -> None:

0 commit comments

Comments
 (0)