Skip to content

Commit ce98fea

Browse files
committed
chore: cleanups, error handling, and types
1 parent 4afd77a commit ce98fea

4 files changed

Lines changed: 57 additions & 55 deletions

File tree

wherobots/db/cursor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import json
33
import logging
44
import uuid
5-
from typing import Any
5+
from typing import Any, Callable
66

77
from .constants import ExecutionState, RequestKind, EventKind
88
from .errors import ProgrammingError, OperationalError
99

1010

1111
class Cursor:
1212

13-
def __init__(self, send_func, recv_func):
13+
def __init__(self, send_func: Callable[[str], None], recv_func: Callable[[], str]):
1414
self.__send_func = send_func
1515
self.__recv_func = recv_func
1616

@@ -49,6 +49,7 @@ def execute(self, operation: str, parameters: dict[str, Any] = None):
4949
self.__current_execution_id = str(uuid.uuid4())
5050
self.__current_execution_state = ExecutionState.EXECUTION_REQUESTED
5151
self.__current_row = 0
52+
self.__rowcount = -1
5253

5354
def _execute(request):
5455
"""This function is executed in a separate thread to send the request, wait for, and gather the results."""
@@ -133,6 +134,7 @@ def __gather_results(self):
133134
f"Unsupported results format {results_format}"
134135
)
135136
# TODO: When full results are sent, we won't need to wrap them in a list to simulate a row.
137+
self.__rowcount = 1
136138
return [json.loads(response["results"])]
137139

138140
def fetchone(self):

wherobots/db/driver.py

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
DEFAULT_RUNTIME,
1616
DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
1717
)
18-
from .cursor import Cursor
1918
from .errors import (
2019
InterfaceError,
21-
NotSupportedError,
2220
OperationalError,
2321
)
2422
from .region import Region
2523
from .runtime import Runtime
24+
from .session import Session
2625

2726
apilevel = "2.0"
2827
threadsafety = 1
@@ -36,7 +35,7 @@ def connect(
3635
runtime: Runtime = DEFAULT_RUNTIME,
3736
region: Region = DEFAULT_REGION,
3837
wait_timeout_seconds: int = DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
39-
):
38+
) -> Session:
4039
if not token and not api_key:
4140
raise ValueError("At least one of `token` or `api_key` is required")
4241
if token and api_key:
@@ -60,13 +59,16 @@ def connect(
6059
if not host.startswith("http:"):
6160
host = f"https://{host}"
6261

63-
resp = requests.post(
64-
url=f"{host}/sql/session",
65-
params={"region": region.value},
66-
json={"runtimeId": runtime.value},
67-
headers=headers,
68-
)
69-
resp.raise_for_status()
62+
try:
63+
resp = requests.post(
64+
url=f"{host}/sql/session",
65+
params={"region": region.value},
66+
json={"runtimeId": runtime.value},
67+
headers=headers,
68+
)
69+
resp.raise_for_status()
70+
except requests.HTTPError as e:
71+
raise InterfaceError("Failed to create SQL session!", e)
7072

7173
# At this point we've been redirected to /sql/session/{session_id}, which we'll need to keep polling until the
7274
# session is in READY state.
@@ -79,7 +81,7 @@ def connect(
7981
(requests.HTTPError, OperationalError)
8082
),
8183
)
82-
def get_session_uri():
84+
def get_session_uri() -> str:
8385
r = requests.get(session_id_url, headers=headers)
8486
r.raise_for_status()
8587
payload = r.json()
@@ -103,45 +105,19 @@ def get_session_uri():
103105

104106

105107
def http_to_ws(uri: str) -> str:
108+
"""Converts an HTTP URI to a WebSocket URI."""
106109
parsed = urllib.parse.urlparse(uri)
107110
for from_scheme, to_scheme in [("http", "ws"), ("https", "wss")]:
108111
if parsed.scheme == from_scheme:
109112
parsed = parsed._replace(scheme=to_scheme)
110113
return str(urllib.parse.urlunparse(parsed))
111114

112115

113-
def connect_direct(uri: str, headers: dict[str, str] = None):
116+
def connect_direct(uri: str, headers: dict[str, str] = None) -> Session:
114117
logging.info("Connecting to SQL session at %s ...", uri)
115-
connection = websockets.sync.client.connect(uri=uri, additional_headers=headers)
116-
session = Session(ws=connection)
117-
return session
118-
119-
120-
class Session:
121-
"""
122-
A PEP-0249 compatible Session object for Wherobots DB.
123-
124-
The session is backed by the WebSocket connection to the Wherobots SQL session.
125-
Transactions are not supported, so commit() and rollback() raise NotSupportedError.
126-
"""
127-
128-
def __init__(self, ws):
129-
self.__ws = ws
130-
131-
def __enter__(self):
132-
return self
133-
134-
def __exit__(self, exc_type, exc_val, exc_tb):
135-
self.close()
136-
137-
def close(self):
138-
self.__ws.close()
139-
140-
def commit(self):
141-
raise NotSupportedError
142-
143-
def rollback(self):
144-
raise NotSupportedError
145-
146-
def cursor(self):
147-
return Cursor(self.__ws.send, self.__ws.recv)
118+
try:
119+
connection = websockets.sync.client.connect(uri=uri, additional_headers=headers)
120+
session = Session(ws=connection)
121+
return session
122+
except Exception as e:
123+
raise InterfaceError("Failed to connect to SQL session!", e)

wherobots/db/errors.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,5 @@ class ProgrammingError(DatabaseError):
2222
pass
2323

2424

25-
class IntegrityError(DatabaseError):
26-
pass
27-
28-
29-
class DataError(DatabaseError):
30-
pass
31-
32-
3325
class NotSupportedError(DatabaseError):
3426
pass

wherobots/db/session.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from wherobots.db.cursor import Cursor
2+
from wherobots.db.errors import NotSupportedError
3+
4+
5+
class Session:
6+
"""
7+
A PEP-0249 compatible Session object for Wherobots DB.
8+
9+
The session is backed by the WebSocket connection to the Wherobots SQL session.
10+
Transactions are not supported, so commit() and rollback() raise NotSupportedError.
11+
"""
12+
13+
def __init__(self, ws):
14+
self.__ws = ws
15+
16+
def __enter__(self):
17+
return self
18+
19+
def __exit__(self, exc_type, exc_val, exc_tb):
20+
self.close()
21+
22+
def close(self):
23+
self.__ws.close()
24+
25+
def commit(self):
26+
raise NotSupportedError
27+
28+
def rollback(self):
29+
raise NotSupportedError
30+
31+
def cursor(self) -> Cursor:
32+
return Cursor(self.__ws.send, self.__ws.recv)

0 commit comments

Comments
 (0)