Skip to content

Commit 663dac2

Browse files
committed
Fix wrong parsing position
1 parent ed343dc commit 663dac2

3 files changed

Lines changed: 59 additions & 63 deletions

File tree

src/sqlitecloud/driver.py

Lines changed: 50 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import socket
1717
import sys
1818

19+
1920
class Driver:
2021
def __init__(self) -> None:
2122
# used for parsing chunked rowset
@@ -54,7 +55,7 @@ def connect(
5455
sock.connect((hostname, port))
5556
except Exception as e:
5657
errmsg = f"An error occurred while initializing the socket."
57-
raise SQCloudException(errmsg, -1) from e
58+
raise SQCloudException(errmsg) from e
5859

5960
connection = SQCloudConnect()
6061
connection.socket = sock
@@ -119,15 +120,15 @@ def _internal_config_apply(
119120
if len(buffer) > 0:
120121
self._internal_run_command(connection, buffer)
121122

122-
def _internal_run_command(self, connection: SQCloudConnect, buffer: str) -> None:
123-
self._internal_socket_write(connection, buffer)
123+
def _internal_run_command(self, connection: SQCloudConnect, command: str) -> None:
124+
self._internal_socket_write(connection, command)
124125
return self._internal_socket_read(connection)
125126

126-
def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> None:
127+
def _internal_socket_write(self, connection: SQCloudConnect, command: str) -> None:
127128
# compute header
128129
delimit = "$" if connection.isblob else "+"
129-
bytebuffer = buffer.encode()
130-
buffer_len = len(bytebuffer)
130+
buffer = command.encode()
131+
buffer_len = len(buffer)
131132
header = f"{delimit}{buffer_len} "
132133

133134
# write header
@@ -143,39 +144,38 @@ def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> Non
143144
if buffer_len == 0:
144145
return
145146
try:
146-
connection.socket.sendall(buffer.encode())
147+
connection.socket.sendall(buffer)
147148
except Exception as exc:
148149
raise SQCloudException(
149150
"An error occurred while writing data.",
150151
SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK,
151152
) from exc
152153

153154
def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult:
154-
buffer = ""
155+
buffer = b""
155156
buffer_size = 8192
156157
nread = 0
157-
bytebuffer = b""
158+
158159
try:
159160
while True:
160161
data = connection.socket.recv(buffer_size)
161162
if not data:
162-
raise SQCloudException('Incomplete response from server.', -1)
163+
raise SQCloudException("Incomplete response from server.")
163164

164-
# the expected data length to read
165+
# the expected data length to read
165166
# matches the string size before decoding it
166167
nread += len(data)
167168
# update buffers
168-
buffer += data.decode()
169-
bytebuffer += data
169+
buffer += data
170170

171-
c = buffer[0]
171+
c = chr(buffer[0])
172172

173173
if (
174174
c == SQCLOUD_CMD.INT.value
175175
or c == SQCLOUD_CMD.FLOAT.value
176176
or c == SQCLOUD_CMD.NULL.value
177177
):
178-
if not buffer.endswith(' '):
178+
if not buffer.endswith(b" "):
179179
continue
180180
elif c == SQCLOUD_CMD.ROWSET_CHUNK.value:
181181
isEndOfChunk = buffer.endswith(SQCLOUD_ROWSET.CHUNKS_END.value)
@@ -202,7 +202,7 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult:
202202
SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK,
203203
) from exc
204204

205-
def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber:
205+
def _internal_parse_number(self, buffer: bytes, index: int = 1) -> SQCloudNumber:
206206
sqcloud_number = SQCloudNumber()
207207
sqcloud_number.value = 0
208208
extvalue = 0
@@ -211,7 +211,7 @@ def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber:
211211

212212
# from 1 to skip the first command type character
213213
for i in range(index, blen):
214-
c = buffer[i]
214+
c = chr(buffer[i])
215215

216216
# check for optional extended error code (ERRCODE:EXTERRCODE)
217217
if c == ":":
@@ -235,7 +235,7 @@ def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber:
235235
sqcloud_number.value = 0
236236
return sqcloud_number
237237

238-
def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult:
238+
def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult:
239239
# possible return values:
240240
# True => OK
241241
# False => error
@@ -247,18 +247,17 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult:
247247
# None
248248

249249
# check OK value
250-
if buffer == "+2 OK":
250+
if buffer == b"+2 OK":
251251
return SQCloudResult(True)
252252

253-
cmd = buffer[0]
253+
cmd = chr(buffer[0])
254254

255255
# check for compressed result
256256
if cmd == SQCLOUD_CMD.COMPRESSED.value:
257257
buffer = self._internal_uncompress_data(buffer, blen)
258258
if buffer is None:
259259
raise SQCloudException(
260-
f"An error occurred while decompressing the input buffer of len {blen}.",
261-
-1,
260+
f"An error occurred while decompressing the input buffer of len {blen}."
262261
)
263262

264263
# first character contains command type
@@ -360,9 +359,10 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult:
360359
# TODO: isn't implemented in C?
361360
return SQCloudResult(None)
362361

363-
return None
362+
# TODO: exception here?
363+
return SQCloudResult(None)
364364

365-
def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]:
365+
def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]:
366366
"""
367367
%LEN COMPRESSED UNCOMPRESSED BUFFER
368368
@@ -385,7 +385,7 @@ def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]:
385385
start = 1
386386
counter = 0
387387
for i in range(blen):
388-
if buffer[i] != " ":
388+
if buffer[i] != b" ":
389389
continue
390390
counter += 1
391391

@@ -413,18 +413,18 @@ def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]:
413413
start += hlen
414414

415415
# perform real decompression
416-
clone = header + str(lz4.block.decompress(buffer[start:]))
416+
clone = header + lz4.block.decompress(buffer[start:])
417417

418418
# sanity check result
419419
if len(clone) != ulen + hlen:
420420
return None
421421

422422
return clone
423423

424-
def _internal_reconnect(self, buffer: str) -> bool:
424+
def _internal_reconnect(self, buffer: bytes) -> bool:
425425
return True
426426

427-
def _internal_parse_array(self, buffer: str) -> list:
427+
def _internal_parse_array(self, buffer: bytes) -> list:
428428
start = 0
429429
sqlite_number = self._internal_parse_number(buffer, start)
430430
n = sqlite_number.value
@@ -438,13 +438,14 @@ def _internal_parse_array(self, buffer: str) -> list:
438438

439439
return r
440440

441-
def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue:
441+
def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQCloudValue:
442442
sqcloud_value = SQCloudValue()
443443
len = 0
444444
cellsize = 0
445445

446446
# handle special NULL value case
447-
if buffer is None or buffer[index] == SQCLOUD_CMD.NULL.value:
447+
c = chr(buffer[index])
448+
if buffer is None or c == SQCLOUD_CMD.NULL.value:
448449
len = 0
449450
if cellsize is not None:
450451
cellsize = 2
@@ -460,36 +461,36 @@ def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue:
460461

461462
# handle decimal/float cases
462463
if (
463-
buffer[index] == SQCLOUD_CMD.INT.value
464-
or buffer[index] == SQCLOUD_CMD.FLOAT.value
464+
c == SQCLOUD_CMD.INT.value
465+
or c == SQCLOUD_CMD.FLOAT.value
465466
):
466467
nlen = cstart - index
467468
len = nlen - 2
468469
cellsize = nlen
469470

470-
sqcloud_value.value = buffer[index + 1 : index + 1 + len]
471+
sqcloud_value.value = (buffer[index + 1 : index + 1 + len]).decode()
471472
sqcloud_value.len
472473
sqcloud_value.cellsize = cellsize
473474

474475
return sqcloud_value
475476

476-
len = blen - 1 if buffer[index] == SQCLOUD_CMD.ZEROSTRING.value else blen
477+
len = blen - 1 if c == SQCLOUD_CMD.ZEROSTRING.value else blen
477478
cellsize = blen + cstart - index
478479

479-
sqcloud_value.value = buffer[cstart : cstart + len]
480+
sqcloud_value.value = (buffer[cstart : cstart + len]).decode()
480481
sqcloud_value.len = len
481482
sqcloud_value.cellsize = cellsize
482483

483484
return sqcloud_value
484485

485-
def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignature:
486+
def _internal_parse_rowset_signature(self, buffer: bytes) -> SQCloudRowsetSignature:
486487
# ROWSET: *LEN 0:VERS NROWS NCOLS DATA
487488
# ROWSET in CHUNK: /LEN IDX:VERS NROWS NCOLS DATA
488489

489490
signature = SQCloudRowsetSignature()
490491

491492
# check for end-of-chunk condition
492-
if buffer == SQCLOUD_ROWSET.CHUNKS_END:
493+
if buffer == SQCLOUD_ROWSET.CHUNKS_END.value:
493494
signature.version = 0
494495
signature.start = 0
495496
return signature
@@ -498,11 +499,11 @@ def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignatur
498499
counter = 0
499500
n = len(buffer)
500501
for i in range(n):
501-
if buffer[i] != " ":
502+
if chr(buffer[i]) != " ":
502503
continue
503504
counter += 1
504505

505-
data = buffer[start:i]
506+
data = (buffer[start:i]).decode()
506507
start = i + 1
507508

508509
if counter == 1:
@@ -525,11 +526,11 @@ def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignatur
525526
return SQCloudRowsetSignature()
526527

527528
def _internal_parse_rowset(
528-
self, buffer: str, start: int, idx: int, version: int, nrows: int, ncols: int
529+
self, buffer: bytes, start: int, idx: int, version: int, nrows: int, ncols: int
529530
) -> SQCloudResult:
530531
rowset = None
531532
n = start
532-
ischunk = buffer[0] == SQCLOUD_CMD.ROWSET_CHUNK.value
533+
ischunk = chr(buffer[0]) == SQCLOUD_CMD.ROWSET_CHUNK.value
533534

534535
# idx == 0 means first (and only) chunk for rowset
535536
# idx == 1 means first chunk for chunked rowset
@@ -555,7 +556,7 @@ def _internal_parse_rowset(
555556
return rowset
556557

557558
def _internal_parse_rowset_header(
558-
self, rowset: SQCloudResult, buffer: str, start: int
559+
self, rowset: SQCloudResult, buffer: bytes, start: int
559560
) -> int:
560561
ncols = rowset.ncols
561562

@@ -566,16 +567,14 @@ def _internal_parse_rowset_header(
566567
number_len = sqcloud_number.value
567568
cstart = sqcloud_number.cstart
568569
value = buffer[cstart : cstart + number_len]
569-
rowset.colname.append(value)
570+
rowset.colname.append(value.decode())
570571
start = cstart + number_len
571572

572573
if rowset.version == 1:
573574
return start
574575

575576
if rowset.version != 2:
576-
raise SQCloudException(
577-
f"Rowset version {rowset.version} is not supported.", -1
578-
)
577+
raise SQCloudException(f"Rowset version {rowset.version} is not supported.")
579578

580579
# parse declared types
581580
rowset.decltype = []
@@ -584,7 +583,7 @@ def _internal_parse_rowset_header(
584583
number_len = sqcloud_number.value
585584
cstart = sqcloud_number.cstart
586585
value = buffer[cstart : cstart + number_len]
587-
rowset.decltype.append(value)
586+
rowset.decltype.append(value.decode())
588587
start = cstart + number_len
589588

590589
# parse database names
@@ -594,7 +593,7 @@ def _internal_parse_rowset_header(
594593
number_len = sqcloud_number.value
595594
cstart = sqcloud_number.cstart
596595
value = buffer[cstart : cstart + number_len]
597-
rowset.dbname.append(value)
596+
rowset.dbname.append(value.decode())
598597
start = cstart + number_len
599598

600599
# parse table names
@@ -604,7 +603,7 @@ def _internal_parse_rowset_header(
604603
number_len = sqcloud_number.value
605604
cstart = sqcloud_number.cstart
606605
value = buffer[cstart : cstart + number_len]
607-
rowset.tblname.append(value)
606+
rowset.tblname.append(value.decode())
608607
start = cstart + number_len
609608

610609
# parse column original names
@@ -614,7 +613,7 @@ def _internal_parse_rowset_header(
614613
number_len = sqcloud_number.value
615614
cstart = sqcloud_number.cstart
616615
value = buffer[cstart : cstart + number_len]
617-
rowset.origname.append(value)
616+
rowset.origname.append(value.decode())
618617
start = cstart + number_len
619618

620619
# parse not null flags
@@ -641,7 +640,7 @@ def _internal_parse_rowset_header(
641640
return start
642641

643642
def _internal_parse_rowset_values(
644-
self, rowset: SQCloudResult, buffer: str, start: int, bound: int
643+
self, rowset: SQCloudResult, buffer: bytes, start: int, bound: int
645644
):
646645
# loop to parse each individual value
647646
for i in range(bound):

src/sqlitecloud/types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class SQCLOUD_CMD(Enum):
4444

4545

4646
class SQCLOUD_ROWSET(Enum):
47-
CHUNKS_END = "/6 0 0 0 "
47+
CHUNKS_END = b"/6 0 0 0 "
4848

4949

5050
class SQCLOUD_INTERNAL_ERRCODE(Enum):
@@ -112,9 +112,6 @@ def __init__(self):
112112

113113
# callback: SQCloudPubSubCB
114114

115-
# todo: which is the proper type?
116-
self.data: any
117-
118115

119116
class SQCloudConfig:
120117
def __init__(self) -> None:
@@ -157,7 +154,7 @@ def __init__(self) -> None:
157154

158155
class SQCloudException(Exception):
159156
def __init__(
160-
self, message: str, code: int, xerrcode=0
157+
self, message: str, code: Optional[int] = -1, xerrcode: Optional[int] = 0
161158
) -> None:
162159
self.errmsg = str(message)
163160
self.errcode = code

0 commit comments

Comments
 (0)