Skip to content

Commit 5f9011a

Browse files
authored
Change psycopg dependencies (#107)
Fixed connection pool dependency for psycopg.
1 parent 8db75a2 commit 5f9011a

4 files changed

Lines changed: 63 additions & 64 deletions

File tree

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from psycopg import AsyncConnection
5252
from psycopg_pool import AsyncConnectionPool
5353

54-
from {{cookiecutter.project_name}}.db.dependencies import get_db_session
54+
from {{cookiecutter.project_name}}.db.dependencies import get_db_pool
5555
{%- elif cookiecutter.orm == "piccolo" %}
5656
{%- if cookiecutter.db_info.name == "postgresql" %}
5757
from piccolo.engine.postgres import PostgresEngine
@@ -245,13 +245,13 @@ async def create_tables(connection: AsyncConnection[Any]) -> None:
245245

246246

247247
@pytest.fixture
248-
async def dbsession() -> AsyncGenerator[AsyncConnection[Any], None]:
248+
async def dbpool() -> AsyncGenerator[AsyncConnectionPool, None]:
249249
"""
250-
Creates connection to some test database.
250+
Creates database connections pool to test database.
251251
252252
This connection must be used in tests and for application.
253253
254-
:yield: connection to database.
254+
:yield: database connections pool.
255255
"""
256256
await create_db()
257257
pool = AsyncConnectionPool(conninfo=str(settings.db_url))
@@ -261,8 +261,7 @@ async def dbsession() -> AsyncGenerator[AsyncConnection[Any], None]:
261261
await create_tables(create_conn)
262262

263263
try:
264-
async with pool.connection() as conn:
265-
yield conn
264+
yield pool
266265
finally:
267266
await pool.close()
268267
await drop_db()
@@ -443,7 +442,7 @@ def fastapi_app(
443442
{%- if cookiecutter.orm == "sqlalchemy" %}
444443
dbsession: AsyncSession,
445444
{%- elif cookiecutter.orm == "psycopg" %}
446-
dbsession: AsyncConnection[Any],
445+
dbpool: AsyncConnectionPool,
447446
{%- endif %}
448447
{% if cookiecutter.enable_redis == "True" -%}
449448
fake_redis: FakeRedis,
@@ -461,8 +460,10 @@ def fastapi_app(
461460
:return: fastapi app with mocked dependencies.
462461
"""
463462
application = get_app()
464-
{% if cookiecutter.orm in ["sqlalchemy", "psycopg"] -%}
463+
{%- if cookiecutter.orm == "sqlalchemy" %}
465464
application.dependency_overrides[get_db_session] = lambda: dbsession
465+
{%- elif cookiecutter.orm == "psycopg" %}
466+
application.dependency_overrides[get_db_pool] = lambda: dbpool
466467
{%- endif %}
467468
{%- if cookiecutter.enable_redis == "True" %}
468469
application.dependency_overrides[get_redis_connection] = lambda: fake_redis

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dao/dummy_dao.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
from typing import Any
44

55
from fastapi import Depends
6-
from psycopg import AsyncConnection
6+
from psycopg_pool import AsyncConnectionPool
77
from psycopg.rows import class_row
8-
from {{cookiecutter.project_name}}.db.dependencies import get_db_session
8+
from {{cookiecutter.project_name}}.db.dependencies import get_db_pool
99
from typing import List, Optional
1010

1111
class DummyDAO:
1212
"""Class for accessing dummy table."""
1313

1414
def __init__(
1515
self,
16-
connection: AsyncConnection[Any] = Depends(get_db_session),
16+
db_pool: AsyncConnectionPool = Depends(get_db_pool),
1717
):
18-
self.connection = connection
18+
self.db_pool = db_pool
1919

2020

2121
async def create_dummy_model(self, name: str) -> None:
@@ -24,15 +24,14 @@ async def create_dummy_model(self, name: str) -> None:
2424
2525
:param name: name of a dummy.
2626
"""
27-
async with self.connection.cursor(
28-
binary=True,
29-
) as cur:
30-
await cur.execute(
31-
"INSERT INTO dummy (name) VALUES (%(name)s);",
32-
params={
33-
"name": name,
34-
}
35-
)
27+
async with self.db_pool.connection() as connection:
28+
async with connection.cursor(binary=True) as cur:
29+
await cur.execute(
30+
"INSERT INTO dummy (name) VALUES (%(name)s);",
31+
params={
32+
"name": name,
33+
}
34+
)
3635

3736
async def get_all_dummies(self, limit: int, offset: int) -> List[DummyModel]:
3837
"""
@@ -42,18 +41,19 @@ async def get_all_dummies(self, limit: int, offset: int) -> List[DummyModel]:
4241
:param offset: offset of dummies.
4342
:return: stream of dummies.
4443
"""
45-
async with self.connection.cursor(
46-
binary=True,
47-
row_factory=class_row(DummyModel)
48-
) as cur:
49-
res = await cur.execute(
50-
"SELECT id, name FROM dummy LIMIT %(limit)s OFFSET %(offset)s;",
51-
params={
52-
"limit": limit,
53-
"offset": offset,
54-
}
55-
)
56-
return await res.fetchall()
44+
async with self.db_pool.connection() as connection:
45+
async with connection.cursor(
46+
binary=True,
47+
row_factory=class_row(DummyModel)
48+
) as cur:
49+
res = await cur.execute(
50+
"SELECT id, name FROM dummy LIMIT %(limit)s OFFSET %(offset)s;",
51+
params={
52+
"limit": limit,
53+
"offset": offset,
54+
}
55+
)
56+
return await res.fetchall()
5757

5858
async def filter(
5959
self,
@@ -65,17 +65,18 @@ async def filter(
6565
:param name: name of dummy instance.
6666
:return: dummy models.
6767
"""
68-
async with self.connection.cursor(
69-
binary=True,
70-
row_factory=class_row(DummyModel)
71-
) as cur:
72-
if name is not None:
73-
res = await cur.execute(
74-
"SELECT id, name FROM dummy WHERE name=%(name)s;",
75-
params={
76-
"name": name,
77-
}
78-
)
79-
else:
80-
res = await cur.execute("SELECT id, name FROM dummy;")
81-
return await res.fetchall()
68+
async with self.db_pool.connection() as connection:
69+
async with connection.cursor(
70+
binary=True,
71+
row_factory=class_row(DummyModel)
72+
) as cur:
73+
if name is not None:
74+
res = await cur.execute(
75+
"SELECT id, name FROM dummy WHERE name=%(name)s;",
76+
params={
77+
"name": name,
78+
}
79+
)
80+
else:
81+
res = await cur.execute("SELECT id, name FROM dummy;")
82+
return await res.fetchall()
Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,12 @@
1-
from typing import Any, AsyncGenerator
2-
3-
from psycopg import AsyncConnection
1+
from psycopg_pool import AsyncConnectionPool
42
from starlette.requests import Request
53

64

7-
async def get_db_session(
8-
request: Request,
9-
) -> AsyncGenerator[AsyncConnection[Any], None]:
5+
async def get_db_pool(request: Request) -> AsyncConnectionPool:
106
"""
11-
Create and get database connection.
7+
Return database connections pool.
128
139
:param request: current request.
14-
:yield: database connection.
10+
:returns: database connections pool.
1511
"""
16-
async with request.app.state.db_pool.connection() as conn:
17-
try:
18-
yield conn
19-
except Exception: # noqa: S110
20-
pass # noqa: WPS420
12+
return request.app.state.db_pool

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/tests/test_dummy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.ext.asyncio import AsyncSession
88
{%- elif cookiecutter.orm == 'psycopg' %}
99
from psycopg.connection_async import AsyncConnection
10+
from psycopg_pool import AsyncConnectionPool
1011
{%- endif %}
1112
from starlette import status
1213
from {{cookiecutter.project_name}}.db.models.dummy_model import DummyModel
@@ -19,7 +20,7 @@ async def test_creation(
1920
{%- if cookiecutter.orm == "sqlalchemy" %}
2021
dbsession: AsyncSession,
2122
{%- elif cookiecutter.orm == "psycopg" %}
22-
dbsession: AsyncConnection[Any],
23+
dbpool: AsyncConnectionPool,
2324
{%- endif %}
2425
) -> None:
2526
"""Tests dummy instance creation."""
@@ -43,8 +44,10 @@ async def test_creation(
4344
)
4445
{%- endif %}
4546
assert response.status_code == status.HTTP_200_OK
46-
{%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %}
47+
{%- if cookiecutter.orm == "sqlalchemy" %}
4748
dao = DummyDAO(dbsession)
49+
{%- elif cookiecutter.orm == "psycopg" %}
50+
dao = DummyDAO(dbpool)
4851
{%- elif cookiecutter.orm in ["tortoise", "ormar", "piccolo"] %}
4952
dao = DummyDAO()
5053
{%- endif %}
@@ -59,12 +62,14 @@ async def test_getting(
5962
{%- if cookiecutter.orm == "sqlalchemy" %}
6063
dbsession: AsyncSession,
6164
{%- elif cookiecutter.orm == "psycopg" %}
62-
dbsession: AsyncConnection[Any],
65+
dbpool: AsyncConnectionPool,
6366
{%- endif %}
6467
) -> None:
6568
"""Tests dummy instance retrieval."""
66-
{%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %}
69+
{%- if cookiecutter.orm == "sqlalchemy" %}
6770
dao = DummyDAO(dbsession)
71+
{%- elif cookiecutter.orm == "psycopg" %}
72+
dao = DummyDAO(dbpool)
6873
{%- elif cookiecutter.orm in ["tortoise", "ormar", "piccolo"] %}
6974
dao = DummyDAO()
7075
{%- endif %}

0 commit comments

Comments
 (0)