|
1 | 1 | import contextvars |
2 | 2 | from dataclasses import dataclass |
3 | | -from typing import Any, Dict, List, Mapping, Optional, Sequence |
| 3 | +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union |
4 | 4 | from piccolo.engine.base import Engine, validate_savepoint_name |
5 | 5 | from psqlpy import ConnectionPool, Connection, Transaction, Cursor |
6 | 6 | from psqlpy.exceptions import RustPSQLDriverPyBaseError |
@@ -63,6 +63,37 @@ async def __aexit__(self, exception_type, exception, traceback): |
63 | 63 | return exception is not None |
64 | 64 |
|
65 | 65 |
|
| 66 | +class Atomic: |
| 67 | + __slots__ = ("engine", "queries") |
| 68 | + |
| 69 | + def __init__(self, engine: "PSQLPyEngine") -> None: |
| 70 | + self.engine = engine |
| 71 | + self.queries: List[Union[Query, DDL]] = [] |
| 72 | + |
| 73 | + def add(self, *query: Union[Query, DDL]): |
| 74 | + self.queries += list(query) |
| 75 | + |
| 76 | + async def run(self): |
| 77 | + from piccolo.query.methods.objects import Create, GetOrCreate |
| 78 | + |
| 79 | + try: |
| 80 | + async with self.engine.transaction(): |
| 81 | + for query in self.queries: |
| 82 | + if isinstance(query, (Query, DDL, Create, GetOrCreate)): |
| 83 | + await query.run() |
| 84 | + else: |
| 85 | + raise ValueError("Unrecognised query") |
| 86 | + self.queries = [] |
| 87 | + except Exception as exception: |
| 88 | + self.queries = [] |
| 89 | + raise exception from exception |
| 90 | + |
| 91 | + def run_sync(self): |
| 92 | + return run_sync(self.run()) |
| 93 | + |
| 94 | + def __await__(self): |
| 95 | + return self.run().__await__() |
| 96 | + |
66 | 97 | class Savepoint: |
67 | 98 | def __init__(self, name: str, transaction: "PostgresTransaction"): |
68 | 99 | self.name = name |
@@ -370,8 +401,8 @@ async def run_ddl(self, ddl: str, in_pool: bool = True): |
370 | 401 |
|
371 | 402 | return response |
372 | 403 |
|
373 | | - def atomic(self) -> str: |
374 | | - return "123" |
| 404 | + def atomic(self) -> Atomic: |
| 405 | + return Atomic(engine=self) |
375 | 406 |
|
376 | 407 | def transaction(self, allow_nested: bool = True) -> PostgresTransaction: |
377 | 408 | return PostgresTransaction(engine=self, allow_nested=allow_nested) |
0 commit comments