3131from sqlalchemy .engine import create_engine
3232from {{cookiecutter .project_name }}.db .config import database
3333from {{cookiecutter .project_name }}.db .utils import create_database , drop_database
34+ {% - elif cookiecutter .orm == "psycopg" % }
35+ from psycopg import AsyncConnection
36+ from psycopg_pool import AsyncConnectionPool
37+
38+ from {{cookiecutter .project_name }}.db .dependencies import get_db_session
39+
3440{% - endif % }
3541
3642
@@ -146,6 +152,99 @@ async def initialize_db() -> AsyncGenerator[None, None]:
146152 await database .disconnect ()
147153 drop_database ()
148154
155+ {% - elif cookiecutter .orm == "psycopg" % }
156+
157+ async def drop_db () -> None :
158+ """Drops database after tests."""
159+ pool = AsyncConnectionPool (conninfo = str (settings .db_url .with_path ("/postgres" )))
160+ await pool .wait ()
161+ async with pool .connection () as conn :
162+ await conn .set_autocommit (True )
163+ await conn .execute (
164+ "SELECT pg_terminate_backend(pg_stat_activity.pid) " # noqa: S608
165+ "FROM pg_stat_activity "
166+ "WHERE pg_stat_activity.datname = %(dbname)s "
167+ "AND pid <> pg_backend_pid();" ,
168+ params = {
169+ "dbname" : settings .db_base ,
170+ }
171+ )
172+ await conn .execute (
173+ f"DROP DATABASE { settings .db_base } " ,
174+ )
175+ await pool .close ()
176+
177+
178+ async def create_db () -> None : # noqa: WPS217
179+ """Creates database for tests."""
180+ pool = AsyncConnectionPool (conninfo = str (settings .db_url .with_path ("/postgres" )))
181+ await pool .wait ()
182+ async with pool .connection () as conn_check :
183+ res = await conn_check .execute (
184+ "SELECT 1 FROM pg_database WHERE datname=%(dbname)s" ,
185+ params = {
186+ "dbname" : settings .db_base ,
187+ }
188+ )
189+ db_exists = False
190+ row = await res .fetchone ()
191+ if row is not None :
192+ db_exists = row [0 ]
193+
194+ if db_exists :
195+ await drop_db ()
196+
197+ async with pool .connection () as conn_create :
198+ await conn_create .set_autocommit (True )
199+ await conn_create .execute (
200+ f"CREATE DATABASE { settings .db_base } ;" ,
201+ )
202+ await pool .close ()
203+
204+
205+ async def create_tables (connection : AsyncConnection [Any ]) -> None :
206+ """
207+ Create tables for your database.
208+
209+ Since psycopg doesn't have migration tool,
210+ you must create your tables for tests.
211+
212+ :param connection: connection to database.
213+ """
214+ {% - if cookiecutter .add_dummy == 'True' % }
215+ await connection .execute (
216+ "CREATE TABLE dummy ("
217+ "id SERIAL primary key,"
218+ "name VARCHAR(200)"
219+ ");"
220+ )
221+ {% - endif % }
222+ pass # noqa: WPS420
223+
224+
225+ @pytest .fixture
226+ async def dbsession () -> AsyncGenerator [AsyncConnection [Any ], None ]:
227+ """
228+ Creates connection to some test database.
229+
230+ This connection must be used in tests and for application.
231+
232+ :yield: connection to database.
233+ """
234+ await create_db ()
235+ pool = AsyncConnectionPool (conninfo = str (settings .db_url ))
236+ await pool .wait ()
237+
238+ async with pool .connection () as create_conn :
239+ await create_tables (create_conn )
240+
241+ try :
242+ async with pool .connection () as conn :
243+ yield conn
244+ finally :
245+ await pool .close ()
246+ await drop_db ()
247+
149248{% - endif % }
150249
151250
@@ -167,6 +266,8 @@ async def fake_redis() -> AsyncGenerator[FakeRedis, None]:
167266def fastapi_app (
168267 {% - if cookiecutter .orm == "sqlalchemy" % }
169268 dbsession : AsyncSession ,
269+ {% - elif cookiecutter .orm == "psycopg" % }
270+ dbsession : AsyncConnection [Any ],
170271 {% - endif % }
171272 {% if cookiecutter .enable_redis == "True" - % }
172273 fake_redis : FakeRedis ,
@@ -178,7 +279,7 @@ def fastapi_app(
178279 :return: fastapi app with mocked dependencies.
179280 """
180281 application = get_app ()
181- {% if cookiecutter .orm == "sqlalchemy" - % }
282+ {% if cookiecutter .orm in [ "sqlalchemy" , "psycopg" ] - % }
182283 application .dependency_overrides [get_db_session ] = lambda : dbsession
183284 {% - endif % }
184285 {% - if cookiecutter .enable_redis == "True" % }
0 commit comments