1+ import contextvars
2+ from dataclasses import dataclass
3+ from typing import Any , Dict , List , Mapping , Optional , Sequence
4+ from piccolo .engine .base import Engine
5+ from psqlpy import ConnectionPool , Connection , Transaction , Cursor
6+ from psqlpy .exceptions import RustPSQLDriverPyBaseError
7+ from piccolo .utils .warnings import Level , colored_warning
8+ from piccolo .query .base import DDL , Query
9+ from piccolo .engine .base import Batch
10+ from piccolo .querystring import QueryString
11+ from piccolo .utils .sync import run_sync
12+
13+
14+ @dataclass
15+ class AsyncBatch (Batch ):
16+ connection : Connection
17+ query : Query
18+ batch_size : int
19+
20+ # Set internally
21+ _transaction : Optional [Transaction ] = None
22+ _cursor : Optional [Cursor ] = None
23+
24+ @property
25+ def cursor (self ) -> Cursor :
26+ if not self ._cursor :
27+ raise ValueError ("_cursor not set" )
28+ return self ._cursor
29+
30+ async def next (self ) -> List [Dict ]:
31+ data = await self .cursor .fetch (self .batch_size )
32+ return await self .query ._process_results (data .result ())
33+
34+ def __aiter__ (self ):
35+ return self
36+
37+ async def __anext__ (self ):
38+ response = await self .next ()
39+ if response == []:
40+ raise StopAsyncIteration ()
41+ return response
42+
43+ async def __aenter__ (self ):
44+ transaction = self .connection .transaction ()
45+ self ._transaction = transaction
46+ await self ._transaction .begin ()
47+ querystring = self .query .querystrings [0 ]
48+ template , template_args = querystring .compile_string ()
49+
50+ self ._cursor = await transaction .cursor (
51+ querystring = template ,
52+ parameters = template_args ,
53+ )
54+ return self
55+
56+ async def __aexit__ (self , exception_type , exception , traceback ):
57+ if exception :
58+ await self ._transaction .rollback ()
59+ else :
60+ await self ._transaction .commit ()
61+
62+ return exception is not None
63+
64+
65+
66+ class PSQLPyEngine (Engine ):
67+
68+ engine_type = "postgres"
69+ min_version_number = 10
70+
71+ def __init__ (
72+ self ,
73+ config : Dict [str , Any ],
74+ extensions : Sequence [str ] = ("uuid-ossp" ,),
75+ log_queries : bool = False ,
76+ log_responses : bool = False ,
77+ extra_nodes : Optional [Mapping [str , "PSQLPyEngine" ]] = None ,
78+ ) -> None :
79+ if extra_nodes is None :
80+ extra_nodes = {}
81+
82+ self .config = config
83+ self .extensions = extensions
84+ self .log_queries = log_queries
85+ self .log_responses = log_responses
86+ self .extra_nodes = extra_nodes
87+ self .pool : Optional [ConnectionPool ] = None
88+ database_name = config .get ("database" , "Unknown" )
89+ self .current_transaction = contextvars .ContextVar (
90+ f"pg_current_transaction_{ database_name } " ,
91+ default = None ,
92+ )
93+ super ().__init__ ()
94+
95+ @staticmethod
96+ def _parse_raw_version_string (version_string : str ) -> float :
97+ """
98+ The format of the version string isn't always consistent. Sometimes
99+ it's just the version number e.g. '9.6.18', and sometimes
100+ it contains specific build information e.g.
101+ '12.4 (Ubuntu 12.4-0ubuntu0.20.04.1)'. Just extract the major and
102+ minor version numbers.
103+ """
104+ version_segment = version_string .split (" " )[0 ]
105+ major , minor = version_segment .split ("." )[:2 ]
106+ return float (f"{ major } .{ minor } " )
107+
108+ async def get_version (self ) -> float :
109+ """
110+ Returns the version of Postgres being run.
111+ """
112+ try :
113+ response : Sequence [Dict ] = await self ._run_in_new_connection (
114+ "SHOW server_version"
115+ )
116+ except ConnectionRefusedError as exception :
117+ # Suppressing the exception, otherwise importing piccolo_conf.py
118+ # containing an engine will raise an ImportError.
119+ colored_warning (f"Unable to connect to database - { exception } " )
120+ return 0.0
121+ else :
122+ version_string = response [0 ]["server_version" ]
123+ return self ._parse_raw_version_string (
124+ version_string = version_string
125+ )
126+
127+ def get_version_sync (self ) -> float :
128+ return run_sync (self .get_version ())
129+
130+ async def prep_database (self ):
131+ for extension in self .extensions :
132+ try :
133+ await self ._run_in_new_connection (
134+ f'CREATE EXTENSION IF NOT EXISTS "{ extension } "' ,
135+ )
136+ except RustPSQLDriverPyBaseError :
137+ colored_warning (
138+ f"=> Unable to create { extension } extension - some "
139+ "functionality may not behave as expected. Make sure "
140+ "your database user has permission to create "
141+ "extensions, or add it manually using "
142+ f'`CREATE EXTENSION "{ extension } ";`' ,
143+ level = Level .medium ,
144+ )
145+
146+ async def start_connnection_pool (self , ** kwargs ) -> None :
147+ colored_warning (
148+ "`start_connnection_pool` is a typo - please change it to "
149+ "`start_connection_pool`." ,
150+ category = DeprecationWarning ,
151+ )
152+ return await self .start_connection_pool ()
153+
154+ async def close_connnection_pool (self , ** kwargs ) -> None :
155+ colored_warning (
156+ "`close_connnection_pool` is a typo - please change it to "
157+ "`close_connection_pool`." ,
158+ category = DeprecationWarning ,
159+ )
160+ return await self .close_connection_pool ()
161+
162+ async def start_connection_pool (self , ** kwargs ) -> None :
163+ if self .pool :
164+ colored_warning (
165+ "A pool already exists - close it first if you want to create "
166+ "a new pool." ,
167+ )
168+ else :
169+ config = dict (self .config )
170+ config .update (** kwargs )
171+ self .pool = ConnectionPool (** config )
172+
173+ async def close_connection_pool (self ) -> None :
174+ if self .pool :
175+ self .pool .close ()
176+ self .pool = None
177+ else :
178+ colored_warning ("No pool is running." )
179+
180+ async def get_new_connection (self ) -> Connection :
181+ """
182+ Returns a new connection - doesn't retrieve it from the pool.
183+ """
184+ return await (ConnectionPool (** dict (self .config ))).connection ()
185+
186+ async def batch (
187+ self ,
188+ query : Query ,
189+ batch_size : int = 100 ,
190+ node : Optional [str ] = None ,
191+ ) -> AsyncBatch :
192+ """
193+ :param query:
194+ The database query to run.
195+ :param batch_size:
196+ How many rows to fetch on each iteration.
197+ :param node:
198+ Which node to run the query on (see ``extra_nodes``). If not
199+ specified, it runs on the main Postgres node.
200+ """
201+ engine : Any = self .extra_nodes .get (node ) if node else self
202+ connection = await engine .get_new_connection ()
203+ return AsyncBatch (
204+ connection = connection , query = query , batch_size = batch_size
205+ )
206+
207+ async def _run_in_pool (self , query : str , args : Optional [Sequence [Any ]] = None ):
208+ if not self .pool :
209+ raise ValueError ("A pool isn't currently running." )
210+
211+ connection = await self .pool .connection ()
212+ response = await connection .execute (
213+ querystring = query ,
214+ parameters = args ,
215+ )
216+
217+ return response .result ()
218+
219+ async def _run_in_new_connection (
220+ self , query : str , args : Optional [Sequence [Any ]] = None
221+ ):
222+ if args is None :
223+ args = []
224+ connection = await self .get_new_connection ()
225+
226+ try :
227+ results = await connection .execute (query , * args )
228+ except RustPSQLDriverPyBaseError as exception :
229+ raise exception
230+
231+ return results .result ()
232+
233+ async def run_querystring (
234+ self , querystring : QueryString , in_pool : bool = True
235+ ):
236+ query , query_args = querystring .compile_string (
237+ engine_type = self .engine_type
238+ )
239+
240+ query_id = self .get_query_id ()
241+
242+ if self .log_queries :
243+ self .print_query (query_id = query_id , query = querystring .__str__ ())
244+
245+ # If running inside a transaction:
246+ current_transaction = self .current_transaction .get ()
247+ if current_transaction :
248+ response = await current_transaction .connection .fetch (
249+ query , * query_args
250+ )
251+ elif in_pool and self .pool :
252+ response = await self ._run_in_pool (query , query_args )
253+ else :
254+ response = await self ._run_in_new_connection (query , query_args )
255+
256+ if self .log_responses :
257+ self .print_response (query_id = query_id , response = response )
258+
259+ return response
260+
261+ async def run_ddl (self , ddl : str , in_pool : bool = True ):
262+ query_id = self .get_query_id ()
263+
264+ if self .log_queries :
265+ self .print_query (query_id = query_id , query = ddl )
266+
267+ # If running inside a transaction:
268+ current_transaction = self .current_transaction .get ()
269+ if current_transaction :
270+ response = await current_transaction .connection .fetch (ddl )
271+ elif in_pool and self .pool :
272+ response = await self ._run_in_pool (ddl )
273+ else :
274+ response = await self ._run_in_new_connection (ddl )
275+
276+ if self .log_responses :
277+ self .print_response (query_id = query_id , response = response )
278+
279+ return response
280+
281+ def atomic (self ) -> str :
282+ return "123"
283+
284+ def transaction (self , allow_nested : bool = True ) -> str :
285+ return "str"
0 commit comments