Skip to content

Commit ccd3b46

Browse files
committed
Continue implementing PSQLPy engine for piccolo
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent df65186 commit ccd3b46

1 file changed

Lines changed: 95 additions & 3 deletions

File tree

psqlpy_piccolo/engine.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import contextvars
22
from dataclasses import dataclass
33
from typing import Any, Dict, List, Mapping, Optional, Sequence
4-
from piccolo.engine.base import Engine
4+
from piccolo.engine.base import Engine, validate_savepoint_name
55
from psqlpy import ConnectionPool, Connection, Transaction, Cursor
66
from psqlpy.exceptions import RustPSQLDriverPyBaseError
77
from piccolo.utils.warnings import Level, colored_warning
88
from piccolo.query.base import DDL, Query
99
from piccolo.engine.base import Batch
1010
from piccolo.querystring import QueryString
1111
from piccolo.utils.sync import run_sync
12+
from piccolo.engine.exceptions import TransactionError
1213

1314

1415
@dataclass
@@ -62,6 +63,97 @@ async def __aexit__(self, exception_type, exception, traceback):
6263
return exception is not None
6364

6465

66+
class Savepoint:
67+
def __init__(self, name: str, transaction: "PostgresTransaction"):
68+
self.name = name
69+
self.transaction = transaction
70+
71+
async def rollback_to(self):
72+
validate_savepoint_name(self.name)
73+
await self.transaction.connection.execute(
74+
f"ROLLBACK TO SAVEPOINT {self.name}"
75+
)
76+
77+
async def release(self):
78+
validate_savepoint_name(self.name)
79+
await self.transaction.connection.execute(
80+
f"RELEASE SAVEPOINT {self.name}"
81+
)
82+
83+
84+
class PostgresTransaction:
85+
def __init__(self, engine: "PSQLPyEngine", allow_nested: bool = True) -> None:
86+
self.engine = engine
87+
current_transaction = self.engine.current_transaction.get()
88+
89+
self._savepoint_id = 0
90+
self._parent = None
91+
self._committed = False
92+
self._rolled_back = False
93+
94+
if current_transaction:
95+
if allow_nested:
96+
self._parent = current_transaction
97+
else:
98+
raise TransactionError(
99+
"A transaction is already active - nested transactions "
100+
"aren't allowed."
101+
)
102+
103+
async def __aenter__(self) -> "PostgresTransaction":
104+
if self._parent is not None:
105+
return self._parent
106+
107+
self.connection = await self.get_connection()
108+
self.transaction = self.connection.transaction()
109+
await self.begin()
110+
self.context = self.engine.current_transaction.set(self)
111+
return self
112+
113+
async def get_connection(self):
114+
if self.engine.pool:
115+
return await self.engine.pool.connection()
116+
else:
117+
return await self.engine.get_new_connection()
118+
119+
async def begin(self):
120+
await self.transaction.begin()
121+
122+
async def commit(self):
123+
await self.transaction.commit()
124+
self._committed = True
125+
126+
async def rollback(self):
127+
await self.transaction.rollback()
128+
self._rolled_back = True
129+
130+
def get_savepoint_id(self) -> int:
131+
self._savepoint_id += 1
132+
return self._savepoint_id
133+
134+
async def savepoint(self, name: Optional[str] = None) -> Savepoint:
135+
savepoint_name = name or f"savepoint_{self.get_savepoint_id()}"
136+
validate_savepoint_name(savepoint_name)
137+
await self.transaction.create_savepoint(savepoint_name=savepoint_name)
138+
return Savepoint(name=savepoint_name, transaction=self)
139+
140+
async def __aexit__(self, exception_type, exception, traceback):
141+
if self._parent:
142+
return exception is None
143+
144+
if exception:
145+
# The user may have manually rolled it back.
146+
if not self._rolled_back:
147+
await self.rollback()
148+
else:
149+
# The user may have manually committed it.
150+
if not self._committed and not self._rolled_back:
151+
await self.commit()
152+
153+
self.engine.current_transaction.reset(self.context)
154+
155+
return exception is None
156+
65157

66158
class PSQLPyEngine(Engine):
67159

@@ -281,5 +373,5 @@ async def run_ddl(self, ddl: str, in_pool: bool = True):
281373
def atomic(self) -> str:
282374
return "123"
283375

284-
def transaction(self, allow_nested: bool = True) -> str:
285-
return "str"
376+
def transaction(self, allow_nested: bool = True) -> PostgresTransaction:
377+
return PostgresTransaction(engine=self, allow_nested=allow_nested)

0 commit comments

Comments
 (0)