Skip to content

Commit fb3547b

Browse files
authored
Feature/pools (#110)
* Redis changed to pool. Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent 5a85264 commit fb3547b

10 files changed

Lines changed: 119 additions & 78 deletions

File tree

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from unittest.mock import Mock
1111

1212
{%- if cookiecutter.enable_redis == "True" %}
13-
from fakeredis.aioredis import FakeRedis
14-
from {{cookiecutter.project_name}}.services.redis.dependency import get_redis_connection
13+
from fakeredis import FakeServer
14+
from fakeredis.aioredis import FakeConnection
15+
from redis.asyncio import ConnectionPool
16+
from {{cookiecutter.project_name}}.services.redis.dependency import get_redis_pool
1517
{%- endif %}
1618
{%- if cookiecutter.enable_rmq == "True" %}
1719
from aio_pika import Channel
@@ -425,15 +427,19 @@ async def test_kafka_producer() -> AsyncGenerator[AIOKafkaProducer, None]:
425427

426428
{% if cookiecutter.enable_redis == "True" -%}
427429
@pytest.fixture
428-
async def fake_redis() -> AsyncGenerator[FakeRedis, None]:
430+
async def fake_redis_pool() -> AsyncGenerator[ConnectionPool, None]:
429431
"""
430432
Get instance of a fake redis.
431433
432434
:yield: FakeRedis instance.
433435
"""
434-
redis = FakeRedis(decode_responses=True)
435-
yield redis
436-
await redis.close()
436+
server = FakeServer()
437+
server.connected = True
438+
pool = ConnectionPool(connection_class=FakeConnection, server=server)
439+
440+
yield pool
441+
442+
await pool.disconnect()
437443

438444
{%- endif %}
439445

@@ -445,7 +451,7 @@ def fastapi_app(
445451
dbpool: AsyncConnectionPool,
446452
{%- endif %}
447453
{% if cookiecutter.enable_redis == "True" -%}
448-
fake_redis: FakeRedis,
454+
fake_redis_pool: ConnectionPool,
449455
{%- endif %}
450456
{%- if cookiecutter.enable_rmq == 'True' %}
451457
test_rmq_pool: Pool[Channel],
@@ -466,7 +472,7 @@ def fastapi_app(
466472
application.dependency_overrides[get_db_pool] = lambda: dbpool
467473
{%- endif %}
468474
{%- if cookiecutter.enable_redis == "True" %}
469-
application.dependency_overrides[get_redis_connection] = lambda: fake_redis
475+
application.dependency_overrides[get_redis_pool] = lambda: fake_redis_pool
470476
{%- endif %}
471477
{%- if cookiecutter.enable_rmq == 'True' %}
472478
application.dependency_overrides[get_rmq_channel_pool] = lambda: test_rmq_pool

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/services/redis/dependency.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
from starlette.requests import Request
55

66

7-
async def get_redis_connection(request: Request) -> AsyncGenerator[Redis, None]: # pragma: no cover
7+
async def get_redis_pool(request: Request) -> AsyncGenerator[Redis, None]: # pragma: no cover
88
"""
9-
Get redis client.
9+
Returns connection pool.
1010
11-
This dependency aquires connection from pool.
11+
You can use it like this:
12+
13+
>>> from redis.asyncio import ConnectionPool, Redis
14+
>>>
15+
>>> async def handler(redis_pool: ConnectionPool = Depends(get_redis_pool)):
16+
>>> async with Redis(connection_pool=redis_pool) as redis:
17+
>>> await redis.get('key')
18+
19+
I use pools so you don't acquire connection till the end of the handler.
1220
1321
:param request: current request.
14-
:yield: redis client.
22+
:returns: redis connection pool.
1523
"""
16-
redis_client = Redis(connection_pool=request.app.state.redis_pool)
17-
18-
try: # noqa: WPS501
19-
yield redis_client
20-
finally:
21-
await redis_client.close()
24+
return request.app.state.redis_pool

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/tests/test_redis.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@
55
from fastapi import FastAPI
66
from httpx import AsyncClient
77
from starlette import status
8-
from fakeredis.aioredis import FakeRedis
8+
from redis.asyncio import ConnectionPool, Redis
99

1010

1111
@pytest.mark.anyio
1212
async def test_setting_value(
1313
fastapi_app: FastAPI,
14-
fake_redis: FakeRedis,
14+
fake_redis_pool: ConnectionPool,
1515
client: AsyncClient,
1616
) -> None:
1717
"""
1818
Tests that you can set value in redis.
1919
2020
:param fastapi_app: current application fixture.
21-
:param fake_redis: fake redis instance.
21+
:param fake_redis_pool: fake redis pool.
2222
:param client: client fixture.
2323
"""
2424
{%- if cookiecutter.api_type == 'rest' %}
@@ -53,26 +53,28 @@ async def test_setting_value(
5353
{%- endif %}
5454

5555
assert response.status_code == status.HTTP_200_OK
56-
actual_value = await fake_redis.get(test_key)
57-
assert actual_value == test_val
56+
async with Redis(connection_pool=fake_redis_pool) as redis:
57+
actual_value = await redis.get(test_key)
58+
assert actual_value.decode() == test_val
5859

5960

6061
@pytest.mark.anyio
6162
async def test_getting_value(
6263
fastapi_app: FastAPI,
63-
fake_redis: FakeRedis,
64+
fake_redis_pool: ConnectionPool,
6465
client: AsyncClient,
6566
) -> None:
6667
"""
6768
Tests that you can get value from redis by key.
6869
6970
:param fastapi_app: current application fixture.
70-
:param fake_redis: fake redis instance.
71+
:param fake_redis_pool: fake redis pool.
7172
:param client: client fixture.
7273
"""
7374
test_key = uuid.uuid4().hex
7475
test_val = uuid.uuid4().hex
75-
await fake_redis.set(test_key, test_val)
76+
async with Redis(connection_pool=fake_redis_pool) as redis:
77+
await redis.set(test_key, test_val)
7678

7779
{%- if cookiecutter.api_type == 'rest' %}
7880
url = fastapi_app.url_path_for('get_redis_value')
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from redis.asyncio import Redis
1+
from redis.asyncio import ConnectionPool, Redis
22
from fastapi import APIRouter
33
from fastapi.param_functions import Depends
44

5-
from {{cookiecutter.project_name}}.services.redis.dependency import get_redis_connection
5+
from {{cookiecutter.project_name}}.services.redis.dependency import get_redis_pool
66
from {{cookiecutter.project_name}}.web.api.redis.schema import RedisValueDTO
77

88
router = APIRouter()
@@ -11,16 +11,17 @@
1111
@router.get("/", response_model=RedisValueDTO)
1212
async def get_redis_value(
1313
key: str,
14-
redis: Redis = Depends(get_redis_connection),
14+
redis_pool: ConnectionPool = Depends(get_redis_pool),
1515
) -> RedisValueDTO:
1616
"""
1717
Get value from redis.
1818
1919
:param key: redis key, to get data from.
20-
:param redis: redis connection.
20+
:param redis_pool: redis connection pool.
2121
:returns: information from redis.
2222
"""
23-
redis_value = await redis.get(key)
23+
async with Redis(connection_pool=redis_pool) as redis:
24+
redis_value = await redis.get(key)
2425
return RedisValueDTO(
2526
key=key,
2627
value=redis_value,
@@ -30,13 +31,14 @@ async def get_redis_value(
3031
@router.put("/")
3132
async def set_redis_value(
3233
redis_value: RedisValueDTO,
33-
redis: Redis = Depends(get_redis_connection),
34+
redis_pool: ConnectionPool = Depends(get_redis_pool),
3435
) -> None:
3536
"""
3637
Set value in redis.
3738
3839
:param redis_value: new value data.
39-
:param redis: redis connection.
40+
:param redis_pool: redis connection pool.
4041
"""
4142
if redis_value.value is not None:
42-
await redis.set(name=redis_value.key, value=redis_value.value)
43+
async with Redis(connection_pool=redis_pool) as redis:
44+
await redis.set(name=redis_value.key, value=redis_value.value)

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/gql/context.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from strawberry.fastapi import BaseContext
33

44
{%- if cookiecutter.enable_redis == "True" %}
5-
from redis.asyncio import Redis
6-
from {{cookiecutter.project_name}}.services.redis.dependency import get_redis_connection
5+
from redis.asyncio import ConnectionPool
6+
from {{cookiecutter.project_name}}.services.redis.dependency import get_redis_pool
77
{%- endif %}
88

99
{%- if cookiecutter.enable_rmq == "True" %}
@@ -18,14 +18,12 @@
1818
{%- endif %}
1919

2020

21-
{%- if cookiecutter.db_info.name != 'none' %}
22-
from {{cookiecutter.project_name}}.db.dependencies import get_db_session
23-
{%- endif %}
24-
2521
{%- if cookiecutter.orm == "sqlalchemy" %}
2622
from sqlalchemy.ext.asyncio import AsyncSession
23+
from {{cookiecutter.project_name}}.db.dependencies import get_db_session
2724
{%- elif cookiecutter.orm == "psycopg" %}
28-
from psycopg import AsyncConnection
25+
from psycopg_pool import AsyncConnectionPool
26+
from {{cookiecutter.project_name}}.db.dependencies import get_db_pool
2927
{%- endif %}
3028

3129

@@ -35,29 +33,32 @@ class Context(BaseContext):
3533
def __init__(
3634
self,
3735
{%- if cookiecutter.enable_redis == "True" %}
38-
redis: Redis = Depends(get_redis_connection),
36+
redis_pool: ConnectionPool = Depends(get_redis_pool),
3937
{%- endif %}
4038
{%- if cookiecutter.enable_rmq == "True" %}
4139
rabbit: Pool[Channel] = Depends(get_rmq_channel_pool),
4240
{%- endif %}
4341
{%- if cookiecutter.orm == "sqlalchemy" %}
4442
db_connection: AsyncSession = Depends(get_db_session),
4543
{%- elif cookiecutter.orm == "psycopg" %}
46-
db_connection: AsyncConnection[Any] = Depends(get_db_session),
44+
db_pool: AsyncConnectionPool = Depends(get_db_pool),
4745
{%- endif %}
4846
{%- if cookiecutter.enable_kafka == "True" %}
4947
kafka_producer: AIOKafkaProducer = Depends(get_kafka_producer),
5048
{%- endif %}
5149
) -> None:
5250
{%- if cookiecutter.enable_redis == "True" %}
53-
self.redis = redis
51+
self.redis_pool = redis_pool
5452
{%- endif %}
5553
{%- if cookiecutter.enable_rmq == "True" %}
5654
self.rabbit = rabbit
5755
{%- endif %}
58-
{%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %}
56+
{%- if cookiecutter.orm == "sqlalchemy" %}
5957
self.db_connection = db_connection
6058
{%- endif %}
59+
{%- if cookiecutter.orm == "psycopg" %}
60+
self.db_pool = db_pool
61+
{%- endif %}
6162
{%- if cookiecutter.enable_kafka == "True" %}
6263
self.kafka_producer = kafka_producer
6364
{%- endif %}

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/gql/dummy/mutation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ async def create_dummy_model(
2626
:param name: name of a dummy.
2727
:return: name of a dummt model.
2828
"""
29-
{%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %}
29+
{%- if cookiecutter.orm == "sqlalchemy" %}
3030
dao = DummyDAO(info.context.db_connection)
31+
{%- elif cookiecutter.orm == "psycopg" %}
32+
dao = DummyDAO(info.context.db_pool)
3133
{%- else %}
3234
dao = DummyDAO()
3335
{%- endif %}

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/gql/dummy/query.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ async def get_dummy_models(
3131
:param offset: offset of dummy objects, defaults to 0.
3232
:return: list of dummy obbjects from database.
3333
"""
34-
{%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %}
34+
{%- if cookiecutter.orm == "sqlalchemy" %}
3535
dao = DummyDAO(info.context.db_connection)
36+
{%- elif cookiecutter.orm == "psycopg" %}
37+
dao = DummyDAO(info.context.db_pool)
3638
{%- else %}
3739
dao = DummyDAO()
3840
{%- endif %}

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/gql/redis/mutation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from {{cookiecutter.project_name}}.web.gql.context import Context
55
from {{cookiecutter.project_name}}.web.gql.redis.schema import RedisDTO, RedisDTOInput
6-
6+
from redis.asyncio import Redis
77

88
@strawberry.type
99
class Mutation:
@@ -22,5 +22,6 @@ async def set_redis_value(
2222
:param info: connection info.
2323
:return: key and value.
2424
"""
25-
await info.context.redis.set(name=data.key, value=data.value)
25+
async with Redis(connection_pool=info.context.redis_pool) as redis:
26+
await redis.set(name=data.key, value=data.value)
2627
return RedisDTO(key=data.key, value=data.value) # type: ignore

fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/gql/redis/query.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from {{cookiecutter.project_name}}.web.gql.context import Context
55
from {{cookiecutter.project_name}}.web.gql.redis.schema import RedisDTO
6-
6+
from redis.asyncio import Redis
77

88
@strawberry.type
99
class Query:
@@ -18,7 +18,8 @@ async def get_redis_value(self, key: str, info: Info[Context, None]) -> RedisDTO
1818
:param info: resolver context.
1919
:return: information from redis.
2020
"""
21-
val = await info.context.redis.get(name=key)
21+
async with Redis(connection_pool=info.context.redis_pool) as redis:
22+
val = await redis.get(name=key)
2223
if isinstance(val, bytes):
2324
val = val.decode("utf-8")
2425
return RedisDTO(key=key, value=val) # type: ignore

0 commit comments

Comments
 (0)