11import json
22import logging
3+ import threading
34import uuid
4- from concurrent .futures import ThreadPoolExecutor
55from dataclasses import dataclass
66from typing import Callable , Any
77
88import cbor2
99import pyarrow
10+ from websockets .protocol import State
11+ from websockets .sync .client import ClientConnection
1012
1113from wherobots .db .constants import (
1214 RequestKind ,
1618 DataCompression ,
1719)
1820from wherobots .db .cursor import Cursor
19- from wherobots .db .errors import NotSupportedError , DatabaseError , OperationalError
20-
21+ from wherobots .db .errors import NotSupportedError , OperationalError
2122
2223_DEFAULT_RESULTS_FORMAT = ResultsFormat .ARROW
2324_DEFAULT_DATA_COMPRESSION = DataCompression .BROTLI
@@ -44,27 +45,24 @@ class Connection:
4445
4546 A background thread listens for events from the SQL session, and handles update to the
4647 corresponding query state. Queries are tracked by their unique execution ID.
47-
48- Note: the Connection object MUST be used as a context manager.
4948 """
5049
51- def __init__ (self , ws ):
50+ def __init__ (self , ws : ClientConnection ):
5251 self .__ws = ws
5352 self .__queries : dict [str , Query ] = {}
53+ self .__thread = threading .Thread (
54+ target = self .__main_loop , daemon = True , name = "wherobots-connection"
55+ )
56+ self .__thread .start ()
5457
5558 def __enter__ (self ):
56- self .__executor = ThreadPoolExecutor (
57- max_workers = 1 , thread_name_prefix = "wherobots-sql-connection"
58- )
59- self .__executor .submit (self .__listen )
6059 return self
6160
6261 def __exit__ (self , exc_type , exc_val , exc_tb ):
6362 self .close ()
6463
6564 def close (self ):
6665 self .__ws .close ()
67- self .__executor .shutdown (wait = True )
6866
6967 def commit (self ):
7068 raise NotSupportedError
@@ -75,78 +73,83 @@ def rollback(self):
7573 def cursor (self ) -> Cursor :
7674 return Cursor (self .__execute_sql , self .__cancel_query )
7775
76+ def __main_loop (self ):
77+ """Main background loop listening for messages from the SQL session."""
78+ while self .__ws .protocol .state < State .CLOSING :
79+ try :
80+ self .__listen ()
81+ except Exception as e :
82+ logging .exception ("Error handling message from SQL session" , e )
83+
7884 def __listen (self ):
79- """Main background loop listening for messages from the SQL session.
85+ """Waits for the next message from the SQL session and processes it .
8086
8187 The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
8288 """
83- while True :
84- message = self .__recv ()
89+ message = self .__recv ()
90+ kind = message .get ("kind" )
91+ execution_id = message .get ("execution_id" )
92+ if not kind or not execution_id :
93+ # Invalid event.
94+ return
8595
86- execution_id = message .get ("execution_id" )
87- if not execution_id :
88- continue
96+ query = self .__queries .get (execution_id )
97+ if not query :
98+ logging .warning (
99+ "Received %s event for unknown execution ID %s" , kind , execution_id
100+ )
101+ return
89102
90- query = self .__queries .get (execution_id )
91- if not query :
92- logging .warning (
93- "Received %s event for unknown execution ID %s" , kind , execution_id
103+ match kind :
104+ case EventKind .STATE_UPDATED :
105+ try :
106+ query .state = ExecutionState [message ["state" ].upper ()]
107+ logging .info ("Query %s is now %s." , execution_id , query .state )
108+ except KeyError :
109+ logging .warning ("Invalid state update message for %s" , execution_id )
110+ return
111+
112+ # Incoming state transitions are handled here.
113+ match query .state :
114+ case ExecutionState .SUCCEEDED :
115+ self .__request_results (execution_id )
116+ case ExecutionState .FAILED :
117+ query .handler (OperationalError ("Query execution failed" ))
118+
119+ case EventKind .EXECUTION_RESULT :
120+ results = message .get ("results" )
121+ if not results or not isinstance (results , dict ):
122+ logging .warning ("Got no results back from %s." , execution_id )
123+ return
124+
125+ result_bytes = results .get ("result_bytes" )
126+ result_format = results .get ("format" )
127+ result_compression = results .get ("compression" )
128+ logging .info (
129+ "Received %d bytes of %s-compressed %s results from %s." ,
130+ len (result_bytes ),
131+ result_compression ,
132+ result_format ,
133+ execution_id ,
94134 )
95- continue
96-
97- kind = message .get ("kind" )
98- match kind :
99- case EventKind .STATE_UPDATED :
100- try :
101- query .state = ExecutionState [message ["state" ].upper ()]
102- logging .info ("Query %s is now %s." , execution_id , query .state )
103- except KeyError :
104- logging .warning (
105- "Invalid state update message for %s" , execution_id
106- )
107- continue
108-
109- # Incoming state transitions are handled here.
110- match query .state :
111- case ExecutionState .SUCCEEDED :
112- self .__request_results (execution_id )
113- case ExecutionState .FAILED :
114- query .handler (OperationalError ("Query execution failed" ))
115-
116- case EventKind .EXECUTION_RESULT :
117- results = message .get ("results" )
118- if not results or not isinstance (results , dict ):
119- logging .warning ("Got no results back from %s." , execution_id )
120- continue
121-
122- result_bytes = results .get ("result_bytes" )
123- result_format = results .get ("format" )
124- result_compression = results .get ("compression" )
125- logging .info (
126- "Received %d bytes of %s-compressed %s results from %s." ,
127- len (result_bytes ),
128- result_compression ,
129- result_format ,
130- execution_id ,
131- )
132-
133- query .state = ExecutionState .COMPLETED
134- match result_format :
135- case ResultsFormat .JSON :
136- query .handler (json .loads (result_bytes .decode ("utf-8" )))
137- case ResultsFormat .ARROW :
138- buffer = pyarrow .py_buffer (result_bytes )
139- stream = pyarrow .input_stream (buffer , result_compression )
140- with pyarrow .ipc .open_stream (stream ) as reader :
141- query .handler (reader .read_pandas ())
142- case _:
143- query .handler (
144- OperationalError (
145- f"Unsupported results format { result_format } "
146- )
135+
136+ query .state = ExecutionState .COMPLETED
137+ match result_format :
138+ case ResultsFormat .JSON :
139+ query .handler (json .loads (result_bytes .decode ("utf-8" )))
140+ case ResultsFormat .ARROW :
141+ buffer = pyarrow .py_buffer (result_bytes )
142+ stream = pyarrow .input_stream (buffer , result_compression )
143+ with pyarrow .ipc .open_stream (stream ) as reader :
144+ query .handler (reader .read_pandas ())
145+ case _:
146+ query .handler (
147+ OperationalError (
148+ f"Unsupported results format { result_format } "
147149 )
148- case _:
149- logging .warning ("Received unknown %s event!" , kind )
150+ )
151+ case _:
152+ logging .warning ("Received unknown %s event!" , kind )
150153
151154 def __send (self , message : dict [str , Any ]) -> None :
152155 logging .debug ("Sending %s" , message )
0 commit comments