Skip to content

Commit e01e29a

Browse files
authored
Added psycopg support. (#70)
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent 05e3500 commit e01e29a

12 files changed

Lines changed: 301 additions & 9 deletions

File tree

fastapi_template/cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import re
22
from argparse import ArgumentParser
33
from operator import attrgetter
4+
from termcolor import cprint
45

56
from prompt_toolkit import prompt
67
from prompt_toolkit.document import Document
78
from prompt_toolkit.shortcuts import checkboxlist_dialog, radiolist_dialog
89
from prompt_toolkit.validation import ValidationError, Validator
910

1011
from fastapi_template.input_model import (
12+
SUPPORTED_ORMS,
13+
ORMS_WITHOUT_MIGRATIONS,
1114
ORM,
1215
BuilderContext,
1316
DB_INFO,
@@ -199,10 +202,19 @@ def read_user_input(current_context: BuilderContext) -> BuilderContext:
199202
current_context.orm = radiolist_dialog(
200203
"ORM",
201204
text="Which ORM do you want?",
202-
values=[(orm, orm.value) for orm in list(ORM) if orm != ORM.none],
205+
values=[(orm, orm.value) for orm in SUPPORTED_ORMS[current_context.db]],
203206
).run()
204207
if current_context.orm is None:
205208
raise KeyboardInterrupt()
209+
if (
210+
current_context.orm is not None
211+
and current_context.orm != ORM.none
212+
and current_context.orm not in SUPPORTED_ORMS.get(current_context.db, [])
213+
):
214+
cprint("This ORM is not supported by chosen database.", "red")
215+
raise KeyboardInterrupt()
216+
if current_context.orm in ORMS_WITHOUT_MIGRATIONS:
217+
current_context.enable_migrations = False
206218
if current_context.ci_type is None:
207219
current_context.ci_type = radiolist_dialog(
208220
"CI",

fastapi_template/input_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ORM(enum.Enum):
2424
ormar = "ormar"
2525
sqlalchemy = "sqlalchemy"
2626
tortoise = "tortoise"
27+
psycopg = "psycopg"
2728

2829

2930
class Database(BaseModel):
@@ -69,6 +70,28 @@ class Database(BaseModel):
6970
),
7071
}
7172

73+
SUPPORTED_ORMS = {
74+
DatabaseType.postgresql: [
75+
ORM.ormar,
76+
ORM.psycopg,
77+
ORM.tortoise,
78+
ORM.sqlalchemy,
79+
],
80+
DatabaseType.sqlite: [
81+
ORM.ormar,
82+
ORM.tortoise,
83+
ORM.sqlalchemy,
84+
],
85+
DatabaseType.mysql: [
86+
ORM.ormar,
87+
ORM.tortoise,
88+
ORM.sqlalchemy,
89+
]
90+
}
91+
92+
ORMS_WITHOUT_MIGRATIONS = [
93+
ORM.psycopg,
94+
]
7295

7396
class BuilderContext(BaseModel):
7497
"""Options for project generation."""

fastapi_template/template/{{cookiecutter.project_name}}/conditional_files.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
"{{cookiecutter.project_name}}/db_ormar/models/dummy_model.py",
8181
"{{cookiecutter.project_name}}/db_tortoise/dao",
8282
"{{cookiecutter.project_name}}/db_tortoise/models/dummy_model.py",
83+
"{{cookiecutter.project_name}}/db_psycopg/dao",
84+
"{{cookiecutter.project_name}}/db_psycopg/models/dummy_model.py",
8385
"{{cookiecutter.project_name}}/tests/test_dummy.py",
8486
"{{cookiecutter.project_name}}/db_sa/migrations/versions/2021-08-16-16-55_2b7380507a71.py",
8587
"{{cookiecutter.project_name}}/db_ormar/migrations/versions/2021-08-16-16-55_2b7380507a71.py",
@@ -114,6 +116,12 @@
114116
"{{cookiecutter.project_name}}/db_ormar"
115117
]
116118
},
119+
"PsycoPG": {
120+
"enabled": "{{cookiecutter.orm == 'psycopg'}}",
121+
"resources": [
122+
"{{cookiecutter.project_name}}/db_psycopg"
123+
]
124+
},
117125
"Postgresql DB": {
118126
"enabled": "{{cookiecutter.db_info.name == 'postgresql'}}",
119127
"resources": [

fastapi_template/template/{{cookiecutter.project_name}}/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ aioredis = {version = "^2.0.1", extras = ["hiredis"]}
6767
{%- if cookiecutter.self_hosted_swagger == 'True' %}
6868
aiofiles = "^0.8.0"
6969
{%- endif %}
70+
{%- if cookiecutter.orm == "psycopg" %}
71+
psycopg = { version = "^3.0.11", extras = ["binary", "pool"] }
72+
{%- endif %}
7073
httptools = "^0.3.0"
7174

7275
[tool.poetry.dev-dependencies]

fastapi_template/template/{{cookiecutter.project_name}}/replaceable_files.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"{{cookiecutter.project_name}}/db": [
33
"{{cookiecutter.project_name}}/db_sa",
44
"{{cookiecutter.project_name}}/db_ormar",
5-
"{{cookiecutter.project_name}}/db_tortoise"
5+
"{{cookiecutter.project_name}}/db_tortoise",
6+
"{{cookiecutter.project_name}}/db_psycopg"
67
]
78
}

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
from sqlalchemy.engine import create_engine
3232
from {{cookiecutter.project_name}}.db.config import database
3333
from {{cookiecutter.project_name}}.db.utils import create_database, drop_database
34+
{%- elif cookiecutter.orm == "psycopg" %}
35+
from psycopg import AsyncConnection
36+
from psycopg_pool import AsyncConnectionPool
37+
38+
from {{cookiecutter.project_name}}.db.dependencies import get_db_session
39+
3440
{%- endif %}
3541

3642

@@ -146,6 +152,99 @@ async def initialize_db() -> AsyncGenerator[None, None]:
146152
await database.disconnect()
147153
drop_database()
148154

155+
{%- elif cookiecutter.orm == "psycopg" %}
156+
157+
async def drop_db() -> None:
158+
"""Drops database after tests."""
159+
pool = AsyncConnectionPool(conninfo=str(settings.db_url.with_path("/postgres")))
160+
await pool.wait()
161+
async with pool.connection() as conn:
162+
await conn.set_autocommit(True)
163+
await conn.execute(
164+
"SELECT pg_terminate_backend(pg_stat_activity.pid) " # noqa: S608
165+
"FROM pg_stat_activity "
166+
"WHERE pg_stat_activity.datname = %(dbname)s "
167+
"AND pid <> pg_backend_pid();",
168+
params={
169+
"dbname": settings.db_base,
170+
}
171+
)
172+
await conn.execute(
173+
f"DROP DATABASE {settings.db_base}",
174+
)
175+
await pool.close()
176+
177+
178+
async def create_db() -> None: # noqa: WPS217
179+
"""Creates database for tests."""
180+
pool = AsyncConnectionPool(conninfo=str(settings.db_url.with_path("/postgres")))
181+
await pool.wait()
182+
async with pool.connection() as conn_check:
183+
res = await conn_check.execute(
184+
"SELECT 1 FROM pg_database WHERE datname=%(dbname)s",
185+
params={
186+
"dbname": settings.db_base,
187+
}
188+
)
189+
db_exists = False
190+
row = await res.fetchone()
191+
if row is not None:
192+
db_exists = row[0]
193+
194+
if db_exists:
195+
await drop_db()
196+
197+
async with pool.connection() as conn_create:
198+
await conn_create.set_autocommit(True)
199+
await conn_create.execute(
200+
f"CREATE DATABASE {settings.db_base};",
201+
)
202+
await pool.close()
203+
204+
205+
async def create_tables(connection: AsyncConnection[Any]) -> None:
206+
"""
207+
Create tables for your database.
208+
209+
Since psycopg doesn't have migration tool,
210+
you must create your tables for tests.
211+
212+
:param connection: connection to database.
213+
"""
214+
{%- if cookiecutter.add_dummy == 'True' %}
215+
await connection.execute(
216+
"CREATE TABLE dummy ("
217+
"id SERIAL primary key,"
218+
"name VARCHAR(200)"
219+
");"
220+
)
221+
{%- endif %}
222+
pass # noqa: WPS420
223+
224+
225+
@pytest.fixture
226+
async def dbsession() -> AsyncGenerator[AsyncConnection[Any], None]:
227+
"""
228+
Creates connection to some test database.
229+
230+
This connection must be used in tests and for application.
231+
232+
:yield: connection to database.
233+
"""
234+
await create_db()
235+
pool = AsyncConnectionPool(conninfo=str(settings.db_url))
236+
await pool.wait()
237+
238+
async with pool.connection() as create_conn:
239+
await create_tables(create_conn)
240+
241+
try:
242+
async with pool.connection() as conn:
243+
yield conn
244+
finally:
245+
await pool.close()
246+
await drop_db()
247+
149248
{%- endif %}
150249

151250

@@ -167,6 +266,8 @@ async def fake_redis() -> AsyncGenerator[FakeRedis, None]:
167266
def fastapi_app(
168267
{%- if cookiecutter.orm == "sqlalchemy" %}
169268
dbsession: AsyncSession,
269+
{%- elif cookiecutter.orm == "psycopg" %}
270+
dbsession: AsyncConnection[Any],
170271
{%- endif %}
171272
{% if cookiecutter.enable_redis == "True" -%}
172273
fake_redis: FakeRedis,
@@ -178,7 +279,7 @@ def fastapi_app(
178279
:return: fastapi app with mocked dependencies.
179280
"""
180281
application = get_app()
181-
{% if cookiecutter.orm == "sqlalchemy" -%}
282+
{% if cookiecutter.orm in ["sqlalchemy", "psycopg"] -%}
182283
application.dependency_overrides[get_db_session] = lambda: dbsession
183284
{%- endif %}
184285
{%- if cookiecutter.enable_redis == "True" %}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from termios import OFDEL
2+
from {{cookiecutter.project_name}}.db.models.dummy_model import DummyModel
3+
from typing import Any
4+
5+
from fastapi import Depends
6+
from psycopg import AsyncConnection
7+
from psycopg.rows import class_row
8+
from {{cookiecutter.project_name}}.db.dependencies import get_db_session
9+
from typing import List, Optional
10+
11+
class DummyDAO:
12+
"""Class for accessing dummy table."""
13+
14+
def __init__(
15+
self,
16+
connection: AsyncConnection[Any] = Depends(get_db_session),
17+
):
18+
self.connection = connection
19+
20+
21+
async def create_dummy_model(self, name: str) -> None:
22+
"""
23+
Creates new dummy in a database.
24+
25+
:param name: name of a dummy.
26+
"""
27+
async with self.connection.cursor(
28+
binary=True,
29+
) as cur:
30+
await cur.execute(
31+
"INSERT INTO dummy (name) VALUES (%(name)s);",
32+
params={
33+
"name": name,
34+
}
35+
)
36+
37+
async def get_all_dummies(self, limit: int, offset: int) -> List[DummyModel]:
38+
"""
39+
Get all dummy models with limit/offset pagination.
40+
41+
:param limit: limit of dummies.
42+
:param offset: offset of dummies.
43+
:return: stream of dummies.
44+
"""
45+
async with self.connection.cursor(
46+
binary=True,
47+
row_factory=class_row(DummyModel)
48+
) as cur:
49+
res = await cur.execute(
50+
"SELECT id, name FROM dummy LIMIT %(limit)s OFFSET %(offset)s;",
51+
params={
52+
"limit": limit,
53+
"offset": offset,
54+
}
55+
)
56+
return await res.fetchall()
57+
58+
async def filter(
59+
self,
60+
name: Optional[str] = None,
61+
) -> List[DummyModel]:
62+
"""
63+
Get specific dummy model.
64+
65+
:param name: name of dummy instance.
66+
:return: dummy models.
67+
"""
68+
async with self.connection.cursor(
69+
binary=True,
70+
row_factory=class_row(DummyModel)
71+
) as cur:
72+
if name is not None:
73+
res = await cur.execute(
74+
"SELECT id, name FROM dummy WHERE name=%(name)s;",
75+
params={
76+
"name": name,
77+
}
78+
)
79+
else:
80+
res = await cur.execute("SELECT id, name FROM dummy;")
81+
return await res.fetchall()
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Any, AsyncGenerator
2+
3+
from psycopg import AsyncConnection
4+
from starlette.requests import Request
5+
6+
7+
async def get_db_session(
8+
request: Request,
9+
) -> AsyncGenerator[AsyncConnection[Any], None]:
10+
"""
11+
Create and get database connection.
12+
13+
:param request: current request.
14+
:yield: database connection.
15+
"""
16+
async with request.app.state.db_pool.connection() as conn:
17+
try:
18+
yield conn
19+
except Exception: # noqa: S110
20+
pass # noqa: WPS420
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic import BaseModel
2+
3+
class DummyModel(BaseModel):
4+
"""Dummy model for database."""
5+
6+
id: int
7+
name: str

0 commit comments

Comments
 (0)