Skip to content

Commit 9ce2bec

Browse files
committed
Add Upload /downoload database
1 parent 25ae065 commit 9ce2bec

10 files changed

Lines changed: 292 additions & 45 deletions

File tree

README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,12 @@ You can install SqliteCloud Package using Python Package Index (PYPI):
1313
$ pip install SqliteCloud
1414
```
1515

16-
- Follow the instructions reported here https://github.com/sqlitecloud/sdk/tree/master/C to build the driver.
17-
18-
- Set SQLITECLOUD_DRIVER_PATH environment variable to the path of the driver file build.
19-
2016
## Usage
2117
<hr>
2218

2319
```python
24-
from sqlitecloud.client import SqliteCloudClient, SqliteCloudAccount
20+
from sqlitecloud.client import SqliteCloudClient
21+
from sqlitecloud.types import SqliteCloudAccount
2522
```
2623

2724
### _Init a connection_
@@ -45,9 +42,8 @@ conn = client.open_connection()
4542
### _Execute a query_
4643
You can bind values to parametric queries: you can pass parameters as positional values in an array
4744
```python
48-
result = client.exec_statement(
49-
"SELECT * FROM table_name WHERE id = ?",
50-
[1],
45+
result = client.exec_query(
46+
"SELECT * FROM table_name WHERE id = 1"
5147
conn=conn
5248
)
5349
```

src/sqlitecloud/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def sendblob(self, blob: bytes, conn: SQCloudConnect) -> SqliteCloudResultSet:
9292
blob (bytes): The blob to be sent to the database.
9393
conn (SQCloudConnect): The connection to the database.
9494
"""
95-
return self.driver.sendblob(blob, conn)
95+
return self.driver.send_blob(blob, conn)
9696

9797
def _parse_connection_string(self, connection_string) -> SQCloudConfig:
9898
# URL STRING FORMAT

src/sqlitecloud/download.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from io import BufferedWriter
2+
import logging
3+
4+
from sqlitecloud.driver import Driver
5+
from sqlitecloud.types import SQCloudConnect
6+
7+
8+
def xCallback(
9+
fd: BufferedWriter, data: bytes, blen: int, ntot: int, nprogress: int
10+
) -> None:
11+
fd.write(data)
12+
13+
if blen == 0:
14+
logging.log(logging.DEBUG, "DOWNLOAD COMPLETE")
15+
else:
16+
logging.log(logging.DEBUG, f"{(nprogress + blen) / ntot * 100:.2f}%")
17+
18+
19+
def download_db(connection: SQCloudConnect, dbname: str, filename: str) -> None:
20+
driver = Driver()
21+
22+
with open(filename, "wb") as fd:
23+
driver.download_database(connection, dbname, fd, xCallback, False)

src/sqlitecloud/driver.py

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from io import BufferedReader, BufferedWriter
12
import ssl
2-
from typing import Optional, Union
3+
from typing import Callable, Optional, Union
34
import lz4.block
45
from sqlitecloud.resultset import SQCloudResult
56
from sqlitecloud.types import (
67
SQCLOUD_CMD,
8+
SQCLOUD_DEFAULT,
79
SQCLOUD_INTERNAL_ERRCODE,
810
SQCLOUD_ROWSET,
911
SQCloudConfig,
@@ -17,6 +19,8 @@
1719

1820

1921
class Driver:
22+
SQCLOUD_DEFAULT_UPLOAD_SIZE = 512 * 1024
23+
2024
def __init__(self) -> None:
2125
# Used while parsing chunked rowset
2226
self._rowset: SQCloudResult = None
@@ -77,7 +81,7 @@ def disconnect(self, conn: SQCloudConnect):
7781
def execute(self, command: str, connection: SQCloudConnect) -> SQCloudResult:
7882
return self._internal_run_command(connection, command)
7983

80-
def sendblob(self, blob: bytes, conn: SQCloudConnect) -> SQCloudResult:
84+
def send_blob(self, blob: bytes, conn: SQCloudConnect) -> SQCloudResult:
8185
try:
8286
conn.isblob = True
8387
return self._internal_run_command(conn, blob)
@@ -90,6 +94,113 @@ def _internal_reconnect(self, buffer: bytes) -> bool:
9094
def _internal_setup_pubsub(self, buffer: bytes) -> bool:
9195
return True
9296

97+
def upload_database(
98+
self,
99+
connection: SQCloudConnect,
100+
dbname: str,
101+
key: Optional[str],
102+
is_file_transfer: bool,
103+
snapshot_id: int,
104+
is_internal_db: bool,
105+
fd: BufferedReader,
106+
dbsize: int,
107+
xCallback: Callable[[BufferedReader, int, int, int], bytes],
108+
) -> None:
109+
keyarg = "KEY " if key else ""
110+
keyvalue = key if key else ""
111+
112+
# prepare command to execute
113+
command = ""
114+
if is_file_transfer:
115+
internalarg = "INTERNAL" if is_internal_db else ""
116+
command = f"TRANSFER DATABASE '{dbname}' {keyarg}{keyvalue} SNAPSHOT {snapshot_id} {internalarg}"
117+
else:
118+
command = f"UPLOAD DATABASE '{dbname}' {keyarg}{keyvalue}"
119+
120+
# execute command on server side
121+
result = self._internal_run_command(connection, command)
122+
if not result.data[0]:
123+
raise SQCloudException(
124+
"An error occurred while initializing the upload of the database."
125+
)
126+
127+
buffer: bytes = b""
128+
blen = 0
129+
nprogress = 0
130+
try:
131+
while True:
132+
# execute callback to read buffer
133+
blen = SQCLOUD_DEFAULT.UPLOAD_SIZE.value
134+
try:
135+
buffer = xCallback(fd, blen, dbsize, nprogress)
136+
blen = len(buffer)
137+
except Exception as e:
138+
raise SQCloudException(
139+
"An error occurred while reading the file."
140+
) from e
141+
142+
try:
143+
# send also the final confirmation blob of zero bytes
144+
self.send_blob(buffer, connection)
145+
except Exception as e:
146+
raise SQCloudException(
147+
"An error occurred while uploading the file."
148+
) from e
149+
150+
# update progress
151+
nprogress += blen
152+
153+
if blen == 0:
154+
# Upload completed
155+
break
156+
except Exception as e:
157+
self._internal_run_command(connection, "UPLOAD ABORT")
158+
raise e
159+
160+
def download_database(
161+
self,
162+
connection: SQCloudConnect,
163+
dbname: str,
164+
fd: BufferedWriter,
165+
xCallback: Callable[[BufferedWriter, int, int, int], bytes],
166+
if_exists: bool,
167+
) -> None:
168+
exists_cmd = " IF EXISTS" if if_exists else ""
169+
result = self._internal_run_command(
170+
connection, f"DOWNLOAD DATABASE {dbname}{exists_cmd};"
171+
)
172+
173+
if result.nrows == 0:
174+
raise SQCloudException(
175+
"An error occurred while initializing the download of the database."
176+
)
177+
178+
# result is an ARRAY (database size, number of pages, raft_index)
179+
download_info = result.data[0]
180+
db_size = int(download_info[0])
181+
182+
# loop to download
183+
progress_size = 0
184+
185+
try:
186+
while progress_size < db_size:
187+
result = self._internal_run_command(connection, "DOWNLOAD STEP")
188+
189+
# res is BLOB, decode it
190+
data = result.data[0]
191+
data_len = len(data)
192+
193+
# execute callback (with progress_size updated)
194+
progress_size += data_len
195+
xCallback(fd, data, data_len, db_size, progress_size)
196+
197+
# check exit condition
198+
if data_len == 0:
199+
break
200+
except Exception as e:
201+
self._internal_run_command(connection, "DOWNLOAD ABORT")
202+
raise e
203+
93204
def _internal_config_apply(
94205
self, connection: SQCloudConnect, config: SQCloudConfig
95206
) -> None:
@@ -136,7 +247,7 @@ def _internal_config_apply(
136247

137248
def _internal_run_command(
138249
self, connection: SQCloudConnect, command: Union[str, bytes]
139-
) -> None:
250+
) -> SQCloudResult:
140251
self._internal_socket_write(connection, command)
141252
return self._internal_socket_read(connection)
142253

src/sqlitecloud/types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from typing import Optional
33
from enum import Enum
44

5+
class SQCLOUD_DEFAULT(Enum):
6+
PORT = 8860
7+
TIMEOUT = 12
8+
UPLOAD_SIZE = 512*1024
59

610
class SQCLOUD_CMD(Enum):
711
STRING = "+"
@@ -62,7 +66,7 @@ def __init__(
6266
password: Optional[str] = "",
6367
hostname: Optional[str] = "",
6468
dbname: Optional[str] = "",
65-
port: Optional[int] = 8860,
69+
port: Optional[int] = SQCLOUD_DEFAULT.PORT.value,
6670
apikey: Optional[str] = "",
6771
) -> None:
6872
# User name is required unless connectionstring is provided
@@ -98,7 +102,7 @@ def __init__(self) -> None:
98102
# Optional query timeout passed directly to TLS socket
99103
self.timeout = 0
100104
# Socket connection timeout
101-
self.connect_timeout = 20
105+
self.connect_timeout = SQCLOUD_DEFAULT.TIMEOUT.value
102106

103107
# Enable compression
104108
self.compression = False
@@ -132,7 +136,7 @@ def __init__(self) -> None:
132136

133137
class SQCloudException(Exception):
134138
def __init__(
135-
self, message: str, code: Optional[int] = -1, xerrcode: Optional[int] = 0
139+
self, message: str, code: int = -1, xerrcode: int = 0
136140
) -> None:
137141
self.errmsg = str(message)
138142
self.errcode = code

src/sqlitecloud/upload.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from io import BufferedReader
2+
import os
3+
from typing import Optional
4+
from sqlitecloud.driver import Driver
5+
from sqlitecloud.types import SQCloudConnect
6+
import logging
7+
8+
def xCallback(fd: BufferedReader, blen: int, ntot: int, nprogress: int) -> bytes:
9+
buffer = fd.read(blen)
10+
nread = len(buffer)
11+
12+
if nread == 0:
13+
logging.log(logging.DEBUG, "UPLOAD COMPLETE\n\n")
14+
else:
15+
logging.log(logging.DEBUG, f"{(nprogress + nread) / ntot * 100:.2f}%")
16+
17+
return buffer
18+
19+
20+
def upload_db(
21+
connection: SQCloudConnect, dbname: str, key: Optional[str], filename: str
22+
) -> bool:
23+
"""
24+
Uploads a SQLite database to the SQLite Cloud node using the provided connection.
25+
26+
Args:
27+
connection (SQCloudConnect): The connection object used to connect to the node.
28+
dbname (str): The name of the database in SQLite Cloud.
29+
key (Optional[str]): The encryption key for the database. If None, no encryption is used.
30+
filename (str): The path to the SQLite database file to be uploaded.
31+
32+
Returns:
33+
bool: True if the upload is successful, SQCloudException in case of errors.
34+
"""
35+
36+
# Create a driver object
37+
driver = Driver()
38+
39+
with open(filename, 'rb') as fd:
40+
dbsize = os.path.getsize(filename)
41+
42+
driver.upload_database(
43+
connection,
44+
dbname,
45+
key,
46+
False,
47+
0,
48+
False,
49+
fd,
50+
dbsize,
51+
xCallback,
52+
)
53+
54+
return True

src/tests/assets/test.db

16 KB
Binary file not shown.

src/tests/integration/test_client.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -554,36 +554,6 @@ def test_stress_test_20x_batched_selects(self, sqlitecloud_connection):
554554
query_ms < self.EXPECT_SPEED_MS
555555
), f"{num_queries}x batched selects, {query_ms}ms per query"
556556

557-
def test_download_database(self, sqlitecloud_connection):
558-
connection, client = sqlitecloud_connection
559-
560-
rowset = client.exec_query(
561-
"DOWNLOAD DATABASE " + os.getenv("SQLITE_DB"), connection
562-
)
563-
564-
result_array = rowset.get_result()
565-
566-
db_size = int(result_array[0])
567-
568-
tot_read = 0
569-
data: bytes = b""
570-
while tot_read < db_size:
571-
result = client.exec_query("DOWNLOAD STEP;", connection)
572-
573-
data += result.get_result()
574-
tot_read += len(data)
575-
576-
temp_file = tempfile.mkstemp(prefix="chinook")[1]
577-
with open(temp_file, "wb") as f:
578-
f.write(data)
579-
580-
db = sqlite3.connect(temp_file)
581-
cursor = db.execute("SELECT * FROM albums")
582-
rowset = cursor.fetchall()
583-
584-
assert cursor.description[0][0] == "AlbumId"
585-
assert cursor.description[1][0] == "Title"
586-
587557
def test_compression_single_column(self):
588558
account = SqliteCloudAccount()
589559
account.hostname = os.getenv("SQLITE_HOST")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
import sqlite3
3+
import tempfile
4+
5+
import pytest
6+
7+
from sqlitecloud import download
8+
from sqlitecloud.client import SqliteCloudClient
9+
from sqlitecloud.types import SQCloudConnect, SqliteCloudAccount
10+
11+
12+
class TestDownload:
13+
@pytest.fixture()
14+
def sqlitecloud_connection(self):
15+
account = SqliteCloudAccount()
16+
account.username = os.getenv("SQLITE_USER")
17+
account.password = os.getenv("SQLITE_PASSWORD")
18+
account.dbname = os.getenv("SQLITE_DB")
19+
account.hostname = os.getenv("SQLITE_HOST")
20+
account.port = 8860
21+
22+
client = SqliteCloudClient(cloud_account=account)
23+
24+
connection = client.open_connection()
25+
assert isinstance(connection, SQCloudConnect)
26+
27+
yield (connection, client)
28+
29+
client.disconnect(connection)
30+
31+
def test_download_database(self, sqlitecloud_connection):
32+
connection, _ = sqlitecloud_connection
33+
34+
temp_file = tempfile.mkstemp(prefix="chinook")[1]
35+
download.download_db(connection, "chinook.sqlite", temp_file)
36+
37+
db = sqlite3.connect(temp_file)
38+
cursor = db.execute("SELECT * FROM albums")
39+
40+
assert cursor.description[0][0] == "AlbumId"
41+
assert cursor.description[1][0] == "Title"

0 commit comments

Comments
 (0)