11import contextvars
22from dataclasses import dataclass
33from typing import Any , Dict , List , Mapping , Optional , Sequence
4- from piccolo .engine .base import Engine
4+ from piccolo .engine .base import Engine , validate_savepoint_name
55from psqlpy import ConnectionPool , Connection , Transaction , Cursor
66from psqlpy .exceptions import RustPSQLDriverPyBaseError
77from piccolo .utils .warnings import Level , colored_warning
88from piccolo .query .base import DDL , Query
99from piccolo .engine .base import Batch
1010from piccolo .querystring import QueryString
1111from piccolo .utils .sync import run_sync
12+ from piccolo .engine .exceptions import TransactionError
1213
1314
1415@dataclass
@@ -62,6 +63,97 @@ async def __aexit__(self, exception_type, exception, traceback):
6263 return exception is not None
6364
6465
66+ class Savepoint :
67+ def __init__ (self , name : str , transaction : "PostgresTransaction" ):
68+ self .name = name
69+ self .transaction = transaction
70+
71+ async def rollback_to (self ):
72+ validate_savepoint_name (self .name )
73+ await self .transaction .connection .execute (
74+ f"ROLLBACK TO SAVEPOINT { self .name } "
75+ )
76+
77+ async def release (self ):
78+ validate_savepoint_name (self .name )
79+ await self .transaction .connection .execute (
80+ f"RELEASE SAVEPOINT { self .name } "
81+ )
82+
83+
84+ class PostgresTransaction :
85+ def __init__ (self , engine : "PSQLPyEngine" , allow_nested : bool = True ) -> None :
86+ self .engine = engine
87+ current_transaction = self .engine .current_transaction .get ()
88+
89+ self ._savepoint_id = 0
90+ self ._parent = None
91+ self ._committed = False
92+ self ._rolled_back = False
93+
94+ if current_transaction :
95+ if allow_nested :
96+ self ._parent = current_transaction
97+ else :
98+ raise TransactionError (
99+ "A transaction is already active - nested transactions "
100+ "aren't allowed."
101+ )
102+
103+ async def __aenter__ (self ) -> "PostgresTransaction" :
104+ if self ._parent is not None :
105+ return self ._parent
106+
107+ self .connection = await self .get_connection ()
108+ self .transaction = self .connection .transaction ()
109+ await self .begin ()
110+ self .context = self .engine .current_transaction .set (self )
111+ return self
112+
113+ async def get_connection (self ):
114+ if self .engine .pool :
115+ return await self .engine .pool .connection ()
116+ else :
117+ return await self .engine .get_new_connection ()
118+
119+ async def begin (self ):
120+ await self .transaction .begin ()
121+
122+ async def commit (self ):
123+ await self .transaction .commit ()
124+ self ._committed = True
125+
126+ async def rollback (self ):
127+ await self .transaction .rollback ()
128+ self ._rolled_back = True
129+
130+ def get_savepoint_id (self ) -> int :
131+ self ._savepoint_id += 1
132+ return self ._savepoint_id
133+
134+ async def savepoint (self , name : Optional [str ] = None ) -> Savepoint :
135+ savepoint_name = name or f"savepoint_{ self .get_savepoint_id ()} "
136+ validate_savepoint_name (savepoint_name )
137+ await self .transaction .create_savepoint (savepoint_name = savepoint_name )
138+ return Savepoint (name = savepoint_name , transaction = self )
139+
140+ async def __aexit__ (self , exception_type , exception , traceback ):
141+ if self ._parent :
142+ return exception is None
143+
144+ if exception :
145+ # The user may have manually rolled it back.
146+ if not self ._rolled_back :
147+ await self .rollback ()
148+ else :
149+ # The user may have manually committed it.
150+ if not self ._committed and not self ._rolled_back :
151+ await self .commit ()
152+
153+ self .engine .current_transaction .reset (self .context )
154+
155+ return exception is None
156+
65157
66158class PSQLPyEngine (Engine ):
67159
@@ -281,5 +373,5 @@ async def run_ddl(self, ddl: str, in_pool: bool = True):
281373 def atomic (self ) -> str :
282374 return "123"
283375
284- def transaction (self , allow_nested : bool = True ) -> str :
285- return "str"
376+ def transaction (self , allow_nested : bool = True ) -> PostgresTransaction :
377+ return PostgresTransaction ( engine = self , allow_nested = allow_nested )
0 commit comments