Skip to content

Commit df65186

Browse files
committed
Continue implementing PSQLPy engine for piccolo
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent 7851baa commit df65186

10 files changed

Lines changed: 394 additions & 83 deletions

poetry.lock

Lines changed: 84 additions & 82 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
185 Bytes
Binary file not shown.
179 Bytes
Binary file not shown.
179 Bytes
Binary file not shown.
13.8 KB
Binary file not shown.
8.91 KB
Binary file not shown.
8.92 KB
Binary file not shown.

psqlpy_piccolo/engine.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
import contextvars
2+
from dataclasses import dataclass
3+
from typing import Any, Dict, List, Mapping, Optional, Sequence
4+
from piccolo.engine.base import Engine
5+
from psqlpy import ConnectionPool, Connection, Transaction, Cursor
6+
from psqlpy.exceptions import RustPSQLDriverPyBaseError
7+
from piccolo.utils.warnings import Level, colored_warning
8+
from piccolo.query.base import DDL, Query
9+
from piccolo.engine.base import Batch
10+
from piccolo.querystring import QueryString
11+
from piccolo.utils.sync import run_sync
12+
13+
14+
@dataclass
15+
class AsyncBatch(Batch):
16+
connection: Connection
17+
query: Query
18+
batch_size: int
19+
20+
# Set internally
21+
_transaction: Optional[Transaction] = None
22+
_cursor: Optional[Cursor] = None
23+
24+
@property
25+
def cursor(self) -> Cursor:
26+
if not self._cursor:
27+
raise ValueError("_cursor not set")
28+
return self._cursor
29+
30+
async def next(self) -> List[Dict]:
31+
data = await self.cursor.fetch(self.batch_size)
32+
return await self.query._process_results(data.result())
33+
34+
def __aiter__(self):
35+
return self
36+
37+
async def __anext__(self):
38+
response = await self.next()
39+
if response == []:
40+
raise StopAsyncIteration()
41+
return response
42+
43+
async def __aenter__(self):
44+
transaction = self.connection.transaction()
45+
self._transaction = transaction
46+
await self._transaction.begin()
47+
querystring = self.query.querystrings[0]
48+
template, template_args = querystring.compile_string()
49+
50+
self._cursor = await transaction.cursor(
51+
querystring=template,
52+
parameters=template_args,
53+
)
54+
return self
55+
56+
async def __aexit__(self, exception_type, exception, traceback):
57+
if exception:
58+
await self._transaction.rollback()
59+
else:
60+
await self._transaction.commit()
61+
62+
return exception is not None
63+
64+
65+
66+
class PSQLPyEngine(Engine):
67+
68+
engine_type = "postgres"
69+
min_version_number = 10
70+
71+
def __init__(
72+
self,
73+
config: Dict[str, Any],
74+
extensions: Sequence[str] = ("uuid-ossp",),
75+
log_queries: bool = False,
76+
log_responses: bool = False,
77+
extra_nodes: Optional[Mapping[str, "PSQLPyEngine"]] = None,
78+
) -> None:
79+
if extra_nodes is None:
80+
extra_nodes = {}
81+
82+
self.config = config
83+
self.extensions = extensions
84+
self.log_queries = log_queries
85+
self.log_responses = log_responses
86+
self.extra_nodes = extra_nodes
87+
self.pool: Optional[ConnectionPool] = None
88+
database_name = config.get("database", "Unknown")
89+
self.current_transaction = contextvars.ContextVar(
90+
f"pg_current_transaction_{database_name}",
91+
default=None,
92+
)
93+
super().__init__()
94+
95+
@staticmethod
96+
def _parse_raw_version_string(version_string: str) -> float:
97+
"""
98+
The format of the version string isn't always consistent. Sometimes
99+
it's just the version number e.g. '9.6.18', and sometimes
100+
it contains specific build information e.g.
101+
'12.4 (Ubuntu 12.4-0ubuntu0.20.04.1)'. Just extract the major and
102+
minor version numbers.
103+
"""
104+
version_segment = version_string.split(" ")[0]
105+
major, minor = version_segment.split(".")[:2]
106+
return float(f"{major}.{minor}")
107+
108+
async def get_version(self) -> float:
109+
"""
110+
Returns the version of Postgres being run.
111+
"""
112+
try:
113+
response: Sequence[Dict] = await self._run_in_new_connection(
114+
"SHOW server_version"
115+
)
116+
except ConnectionRefusedError as exception:
117+
# Suppressing the exception, otherwise importing piccolo_conf.py
118+
# containing an engine will raise an ImportError.
119+
colored_warning(f"Unable to connect to database - {exception}")
120+
return 0.0
121+
else:
122+
version_string = response[0]["server_version"]
123+
return self._parse_raw_version_string(
124+
version_string=version_string
125+
)
126+
127+
def get_version_sync(self) -> float:
128+
return run_sync(self.get_version())
129+
130+
async def prep_database(self):
131+
for extension in self.extensions:
132+
try:
133+
await self._run_in_new_connection(
134+
f'CREATE EXTENSION IF NOT EXISTS "{extension}"',
135+
)
136+
except RustPSQLDriverPyBaseError:
137+
colored_warning(
138+
f"=> Unable to create {extension} extension - some "
139+
"functionality may not behave as expected. Make sure "
140+
"your database user has permission to create "
141+
"extensions, or add it manually using "
142+
f'`CREATE EXTENSION "{extension}";`',
143+
level=Level.medium,
144+
)
145+
146+
async def start_connnection_pool(self, **kwargs) -> None:
147+
colored_warning(
148+
"`start_connnection_pool` is a typo - please change it to "
149+
"`start_connection_pool`.",
150+
category=DeprecationWarning,
151+
)
152+
return await self.start_connection_pool()
153+
154+
async def close_connnection_pool(self, **kwargs) -> None:
155+
colored_warning(
156+
"`close_connnection_pool` is a typo - please change it to "
157+
"`close_connection_pool`.",
158+
category=DeprecationWarning,
159+
)
160+
return await self.close_connection_pool()
161+
162+
async def start_connection_pool(self, **kwargs) -> None:
163+
if self.pool:
164+
colored_warning(
165+
"A pool already exists - close it first if you want to create "
166+
"a new pool.",
167+
)
168+
else:
169+
config = dict(self.config)
170+
config.update(**kwargs)
171+
self.pool = ConnectionPool(**config)
172+
173+
async def close_connection_pool(self) -> None:
174+
if self.pool:
175+
self.pool.close()
176+
self.pool = None
177+
else:
178+
colored_warning("No pool is running.")
179+
180+
async def get_new_connection(self) -> Connection:
181+
"""
182+
Returns a new connection - doesn't retrieve it from the pool.
183+
"""
184+
return await (ConnectionPool(**dict(self.config))).connection()
185+
186+
async def batch(
187+
self,
188+
query: Query,
189+
batch_size: int = 100,
190+
node: Optional[str] = None,
191+
) -> AsyncBatch:
192+
"""
193+
:param query:
194+
The database query to run.
195+
:param batch_size:
196+
How many rows to fetch on each iteration.
197+
:param node:
198+
Which node to run the query on (see ``extra_nodes``). If not
199+
specified, it runs on the main Postgres node.
200+
"""
201+
engine: Any = self.extra_nodes.get(node) if node else self
202+
connection = await engine.get_new_connection()
203+
return AsyncBatch(
204+
connection=connection, query=query, batch_size=batch_size
205+
)
206+
207+
async def _run_in_pool(self, query: str, args: Optional[Sequence[Any]] = None):
208+
if not self.pool:
209+
raise ValueError("A pool isn't currently running.")
210+
211+
connection = await self.pool.connection()
212+
response = await connection.execute(
213+
querystring=query,
214+
parameters=args,
215+
)
216+
217+
return response.result()
218+
219+
async def _run_in_new_connection(
220+
self, query: str, args: Optional[Sequence[Any]] = None
221+
):
222+
if args is None:
223+
args = []
224+
connection = await self.get_new_connection()
225+
226+
try:
227+
results = await connection.execute(query, *args)
228+
except RustPSQLDriverPyBaseError as exception:
229+
raise exception
230+
231+
return results.result()
232+
233+
async def run_querystring(
234+
self, querystring: QueryString, in_pool: bool = True
235+
):
236+
query, query_args = querystring.compile_string(
237+
engine_type=self.engine_type
238+
)
239+
240+
query_id = self.get_query_id()
241+
242+
if self.log_queries:
243+
self.print_query(query_id=query_id, query=querystring.__str__())
244+
245+
# If running inside a transaction:
246+
current_transaction = self.current_transaction.get()
247+
if current_transaction:
248+
response = await current_transaction.connection.fetch(
249+
query, *query_args
250+
)
251+
elif in_pool and self.pool:
252+
response = await self._run_in_pool(query, query_args)
253+
else:
254+
response = await self._run_in_new_connection(query, query_args)
255+
256+
if self.log_responses:
257+
self.print_response(query_id=query_id, response=response)
258+
259+
return response
260+
261+
async def run_ddl(self, ddl: str, in_pool: bool = True):
262+
query_id = self.get_query_id()
263+
264+
if self.log_queries:
265+
self.print_query(query_id=query_id, query=ddl)
266+
267+
# If running inside a transaction:
268+
current_transaction = self.current_transaction.get()
269+
if current_transaction:
270+
response = await current_transaction.connection.fetch(ddl)
271+
elif in_pool and self.pool:
272+
response = await self._run_in_pool(ddl)
273+
else:
274+
response = await self._run_in_new_connection(ddl)
275+
276+
if self.log_responses:
277+
self.print_response(query_id=query_id, response=response)
278+
279+
return response
280+
281+
def atomic(self) -> str:
282+
return "123"
283+
284+
def transaction(self, allow_nested: bool = True) -> str:
285+
return "str"

psqlpy_piccolo/main.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import asyncio
2+
from psqlpy_piccolo.engine import PSQLPyEngine
3+
from piccolo.querystring import QueryString
4+
5+
6+
engine = PSQLPyEngine(
7+
config={
8+
"dsn": "postgres://postgres:postgres@localhost:5432/postgres",
9+
},
10+
)
11+
12+
13+
async def main() -> None:
14+
await engine.start_connection_pool()
15+
qs = QueryString(
16+
"SELECT * FROM users WHERE id = {}",
17+
3,
18+
)
19+
20+
res = await engine.run_querystring(qs)
21+
print(res)
22+
23+
24+
asyncio.run(main())

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ readme = "README.md"
77

88
[tool.poetry.dependencies]
99
python = "^3.8"
10-
psqlpy = "^0.5.5"
10+
psqlpy = "^0.6.0"
1111
piccolo = "^1.5.0"
1212

1313

0 commit comments

Comments
 (0)