Skip to content

Commit edd5515

Browse files
committed
Implement decompression lz4
1 parent 96f66ba commit edd5515

4 files changed

Lines changed: 59 additions & 49 deletions

File tree

src/sqlitecloud/driver.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,16 @@ def _internal_parse_buffer(
279279

280280
# check for compressed result
281281
if cmd == SQCLOUD_CMD.COMPRESSED.value:
282-
buffer = self._internal_uncompress_data(buffer, blen)
282+
buffer = self._internal_uncompress_data(buffer)
283283
if buffer is None:
284284
raise SQCloudException(
285285
f"An error occurred while decompressing the input buffer of len {blen}."
286286
)
287287

288+
# buffer after decompression
289+
blen = len(buffer)
290+
cmd = chr(buffer[0])
291+
288292
# first character contains command type
289293
if cmd in [
290294
SQCLOUD_CMD.ZEROSTRING.value,
@@ -338,7 +342,10 @@ def _internal_parse_buffer(
338342

339343
elif cmd in [SQCLOUD_CMD.ROWSET.value, SQCLOUD_CMD.ROWSET_CHUNK.value]:
340344
# CMD_ROWSET: *LEN 0:VERSION ROWS COLS DATA
345+
# - When decompressed, LEN for ROWSET is *0
346+
#
341347
# CMD_ROWSET_CHUNK: /LEN IDX:VERSION ROWS COLS DATA
348+
#
342349
rowset_signature = self._internal_parse_rowset_signature(buffer)
343350
if rowset_signature.start < 0:
344351
raise SQCloudException("Cannot parse rowset signature")
@@ -361,7 +368,7 @@ def _internal_parse_buffer(
361368
# continue parsing next chunk in the buffer
362369
sign_len = rowset_signature.len
363370
buffer = buffer[sign_len + len(f"/{sign_len} ") :]
364-
if buffer:
371+
if cmd == SQCLOUD_CMD.ROWSET_CHUNK.value and buffer:
365372
return self._internal_parse_buffer(connection, buffer, len(buffer))
366373

367374
return rowset
@@ -387,7 +394,7 @@ def _internal_parse_buffer(
387394
# TODO: exception here?
388395
return SQCloudResult(None)
389396

390-
def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[bytes]:
397+
def _internal_uncompress_data(self, buffer: bytes) -> Optional[bytes]:
391398
"""
392399
%LEN COMPRESSED UNCOMPRESSED BUFFER
393400
@@ -398,51 +405,34 @@ def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[bytes]
398405
Returns:
399406
str: The uncompressed data.
400407
"""
401-
tlen = 0 # total length
402-
clen = 0 # compressed length
403-
ulen = 0 # uncompressed length
404-
hlen = 0 # raw header length
405-
seek1 = 0
408+
space_index = buffer.index(b" ")
409+
buffer = buffer[space_index + 1 :]
406410

407-
start = 1
408-
counter = 0
409-
for i in range(blen):
410-
if chr(buffer[i]) != " ":
411-
continue
412-
counter += 1
411+
# extract compressed size
412+
space_index = buffer.index(b" ")
413+
compressed_size = int(buffer[:space_index].decode("utf-8"))
414+
buffer = buffer[space_index + 1 :]
413415

414-
data = buffer[start:i]
415-
start = i + 1
416-
417-
if counter == 1:
418-
tlen = int(data)
419-
seek1 = start
420-
elif counter == 2:
421-
clen = int(data)
422-
elif counter == 3:
423-
ulen = int(data)
424-
break
425-
426-
# sanity check header values
427-
if tlen == 0 or clen == 0 or ulen == 0 or start == 1 or seek1 == 0:
428-
return None
416+
# extract decompressed size
417+
space_index = buffer.index(b" ")
418+
uncompressed_size = int(buffer[:space_index].decode("utf-8"))
419+
buffer = buffer[space_index + 1 :]
429420

430-
# copy raw header
431-
hlen = start - seek1
432-
header = buffer[start : start + hlen]
421+
# extract data header
422+
header = buffer[:-compressed_size]
433423

434-
# compute index of the first compressed byte
435-
start += hlen
424+
# extract compressed data
425+
compressed_buffer = buffer[-compressed_size:]
436426

437-
# perform real decompression
438-
# clone = header + lz4.block.decompress(buffer[start:])
439-
clone = lz4decode(buffer, start, header)
427+
decompressed_buffer = header + lz4.block.decompress(
428+
compressed_buffer, uncompressed_size
429+
)
440430

441431
# sanity check result
442-
if len(clone) != ulen + hlen:
432+
if len(decompressed_buffer) != uncompressed_size + len(header):
443433
return None
444434

445-
return clone
435+
return decompressed_buffer
446436

447437
def _internal_parse_array(self, buffer: bytes) -> list:
448438
start = 0

src/sqlitecloud/resultset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,4 @@ def get_name(self, col: int) -> Optional[str]:
7575
return self._result.colname[col]
7676

7777
def get_result(self) -> Optional[any]:
78-
return self.get_value(0, 0)
78+
return self.get_value(0, 0)

src/tests/integration/test_client.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,7 @@ def test_download_database(self, sqlitecloud_connection):
584584
assert cursor.description[0][0] == "AlbumId"
585585
assert cursor.description[1][0] == "Title"
586586

587-
# TODO
588-
def test_compression(self):
587+
def test_compression_single_column(self):
589588
account = SqliteCloudAccount()
590589
account.hostname = os.getenv("SQLITE_HOST")
591590
account.apikey = os.getenv("SQLITE_API_KEY")
@@ -595,12 +594,34 @@ def test_compression(self):
595594
client.config.compression = True
596595

597596
# min compression size for rowset set by default to 20400 bytes
598-
rowset = client.exec_query("SELECT '" + "a" * (1024 * 20) + "' AS DDD")
597+
blob_size = 20 * 1024
598+
# rowset = client.exec_query("SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId")
599+
rowset = client.exec_query(
600+
f"SELECT hex(randomblob({blob_size})) AS 'someColumnName'"
601+
)
599602

600603
assert rowset.nrows == 1
601604
assert rowset.ncols == 1
602-
assert rowset.get_value(0, 0).startswith("aaaaa")
603-
assert len(rowset.get_value(0, 0)) == 100
605+
assert rowset.get_name(0) == "someColumnName"
606+
assert len(rowset.get_value(0, 0)) == blob_size * 2
607+
608+
def test_compression_multiple_columns(self):
609+
account = SqliteCloudAccount()
610+
account.hostname = os.getenv("SQLITE_HOST")
611+
account.apikey = os.getenv("SQLITE_API_KEY")
612+
account.database = os.getenv("SQLITE_DB")
613+
614+
client = SqliteCloudClient(cloud_account=account)
615+
client.config.compression = True
616+
617+
# min compression size for rowset set by default to 20400 bytes
618+
rowset = client.exec_query(
619+
"SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId"
620+
)
621+
622+
assert rowset.nrows > 0
623+
assert rowset.ncols > 0
624+
assert rowset.get_name(0) == "AlbumId"
604625

605626
# def test_send_blob(self, sqlitecloud_connection):
606627
# connection, client = sqlitecloud_connection
@@ -615,19 +636,18 @@ def test_compression(self):
615636
# blob = b""
616637
# result = client.sendblob(blob, connection)
617638
# assert result is not None
618-
# # Add additional assertions as needed
639+
#
619640

620641
# def test_send_large_blob(self, sqlitecloud_connection):
621642
# connection, client = sqlitecloud_connection
622643
# blob = b"A" * 1024 * 1024 # 1MB blob
623644
# result = client.sendblob(blob, connection)
624645
# assert result is not None
625-
# # Add additional assertions as needed
646+
#
626647

627648
# def test_send_blob_with_connection_closed(self, sqlitecloud_connection):
628649
# connection, client = sqlitecloud_connection
629650
# client.disconnect(connection)
630651
# blob = b"Hello, this is a test blob"
631652
# with pytest.raises(Exception):
632653
# client.sendblob(blob, connection)
633-
# Add additional assertions as needed

src/tests/unit/test_resultset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_get_value_array(self):
7878
result = SQCloudResult(result=[1, 2, 3])
7979
result_set = SqliteCloudResultSet(result)
8080

81-
assert [1,2,3] == result_set.get_value(0, 0)
81+
assert [1, 2, 3] == result_set.get_value(0, 0)
8282

8383
def test_get_colname(self):
8484
result = SQCloudResult()

0 commit comments

Comments
 (0)