Skip to content

Commit 6e1c4c1

Browse files
committed
Add execution progress handler following sqlite3 set_progress_handler pattern
Implement connection-level progress reporting as a PEP 249 vendor extension: - ProgressInfo NamedTuple and ProgressHandler type alias in connection.py - set_progress_handler(handler) method on Connection - Automatically sends enable_progress_events: true when handler is set - Handles execution_progress WebSocket events independently of the query state machine, with exception isolation - Export ProgressInfo from wherobots.db - Add --progress flag to tests/smoke.py for manual testing with rich output
1 parent a5fd480 commit 6e1c4c1

File tree

5 files changed

+86
-4
lines changed

5 files changed

+86
-4
lines changed

tests/smoke.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from rich.console import Console
1111
from rich.table import Table
1212

13-
from wherobots.db import connect, connect_direct, errors
14-
from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE
13+
from wherobots.db import connect, connect_direct, errors, ProgressInfo
1514
from wherobots.db.connection import Connection
15+
from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE
1616
from wherobots.db.region import Region
1717
from wherobots.db.runtime import Runtime
1818
from wherobots.db.session_type import SessionType
@@ -54,6 +54,11 @@
5454
parser.add_argument(
5555
"--wide", help="Enable wide output", action="store_const", const=80, default=30
5656
)
57+
parser.add_argument(
58+
"--progress",
59+
help="Enable execution progress reporting",
60+
action="store_true",
61+
)
5762
parser.add_argument("sql", nargs="+", help="SQL query to execute")
5863
args = parser.parse_args()
5964

@@ -134,6 +139,26 @@ def execute(conn: Connection, sql: str) -> pandas.DataFrame | StoreResult:
134139

135140
try:
136141
with conn_func() as conn:
142+
if args.progress:
143+
console = Console(stderr=True)
144+
145+
def _on_progress(info: ProgressInfo) -> None:
146+
pct = (
147+
f"{info.tasks_completed / info.tasks_total * 100:.0f}%"
148+
if info.tasks_total
149+
else "?"
150+
)
151+
console.print(
152+
f" [dim]\\[progress][/dim] "
153+
f"[bold]{pct}[/bold] "
154+
f"{info.tasks_completed}/{info.tasks_total} tasks "
155+
f"[dim]({info.tasks_active} active)[/dim] "
156+
f"[dim]{info.execution_id[:8]}[/dim]",
157+
highlight=False,
158+
)
159+
160+
conn.set_progress_handler(_on_progress)
161+
137162
with concurrent.futures.ThreadPoolExecutor() as pool:
138163
futures = [pool.submit(execute, conn, s) for s in args.sql]
139164
for future in concurrent.futures.as_completed(futures):

wherobots/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
ProgrammingError,
1111
NotSupportedError,
1212
)
13-
from .models import Store, StoreResult
13+
from .models import ProgressInfo, Store, StoreResult
1414
from .region import Region
1515
from .runtime import Runtime
1616
from .types import StorageFormat
1717

1818
__all__ = [
1919
"Connection",
2020
"Cursor",
21+
"ProgressInfo",
2122
"connect",
2223
"connect_direct",
2324
"Error",

wherobots/db/connection.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .constants import DEFAULT_READ_TIMEOUT_SECONDS
1717
from .cursor import Cursor
1818
from .errors import NotSupportedError, OperationalError
19-
from .models import ExecutionResult, Store, StoreResult
19+
from .models import ExecutionResult, ProgressInfo, Store, StoreResult
2020
from .types import (
2121
RequestKind,
2222
EventKind,
@@ -27,6 +27,10 @@
2727
)
2828

2929

30+
ProgressHandler = Callable[[ProgressInfo], None]
31+
"""A callable invoked with a :class:`ProgressInfo` on every progress event."""
32+
33+
3034
@dataclass
3135
class Query:
3236
sql: str
@@ -64,6 +68,7 @@ def __init__(
6468
self.__results_format = results_format
6569
self.__data_compression = data_compression
6670
self.__geometry_representation = geometry_representation
71+
self.__progress_handler: ProgressHandler | None = None
6772

6873
self.__queries: dict[str, Query] = {}
6974
self.__thread = threading.Thread(
@@ -89,6 +94,21 @@ def rollback(self) -> None:
8994
def cursor(self) -> Cursor:
9095
return Cursor(self.__execute_sql, self.__cancel_query)
9196

97+
def set_progress_handler(self, handler: ProgressHandler | None) -> None:
98+
"""Register a callback invoked for execution progress events.
99+
100+
When a handler is set, every ``execute_sql`` request automatically
101+
includes ``enable_progress_events: true`` so the SQL session streams
102+
progress updates for running queries.
103+
104+
Pass ``None`` to disable progress reporting.
105+
106+
This follows the `sqlite3 Connection.set_progress_handler()
107+
<https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_progress_handler>`_
108+
pattern (PEP 249 vendor extension).
109+
"""
110+
self.__progress_handler = handler
111+
92112
def __main_loop(self) -> None:
93113
"""Main background loop listening for messages from the SQL session."""
94114
logging.info("Starting background connection handling loop...")
@@ -116,6 +136,25 @@ def __listen(self) -> None:
116136
# Invalid event.
117137
return
118138

139+
# Progress events are independent of the query state machine and don't
140+
# require a tracked query — the handler is connection-level.
141+
if kind == EventKind.EXECUTION_PROGRESS:
142+
handler = self.__progress_handler
143+
if handler is None:
144+
return
145+
try:
146+
handler(
147+
ProgressInfo(
148+
execution_id=execution_id,
149+
tasks_total=message.get("tasks_total", 0),
150+
tasks_completed=message.get("tasks_completed", 0),
151+
tasks_active=message.get("tasks_active", 0),
152+
)
153+
)
154+
except Exception:
155+
logging.exception("Progress handler raised an exception")
156+
return
157+
119158
query = self.__queries.get(execution_id)
120159
if not query:
121160
logging.warning(
@@ -236,6 +275,9 @@ def __execute_sql(
236275
"statement": sql,
237276
}
238277

278+
if self.__progress_handler is not None:
279+
request["enable_progress_events"] = True
280+
239281
if store:
240282
request["store"] = {
241283
"format": store.format.value,

wherobots/db/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,16 @@ class ExecutionResult:
7878
results: pandas.DataFrame | None = None
7979
error: Exception | None = None
8080
store_result: StoreResult | None = None
81+
82+
83+
@dataclass(frozen=True)
84+
class ProgressInfo:
85+
"""Progress information for a running query.
86+
87+
Mirrors the ``execution_progress`` event sent by the SQL session.
88+
"""
89+
90+
execution_id: str
91+
tasks_total: int
92+
tasks_completed: int
93+
tasks_active: int

wherobots/db/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class EventKind(LowercaseStrEnum):
4545
STATE_UPDATED = auto()
4646
EXECUTION_RESULT = auto()
4747
ERROR = auto()
48+
EXECUTION_PROGRESS = auto()
4849

4950

5051
class ResultsFormat(LowercaseStrEnum):

0 commit comments

Comments
 (0)