|
1 | | -import typing as t |
2 | | - |
3 | 1 | import ssl |
4 | 2 | from motor.motor_asyncio import AsyncIOMotorClient |
5 | | -from starlette.middleware.base import BaseHTTPMiddleware |
6 | 3 | from starlette.requests import Request |
7 | | -from starlette.responses import JSONResponse, Response |
| 4 | +from starlette.responses import JSONResponse |
| 5 | +from starlette.types import ASGIApp, Scope, Receive, Send |
8 | 6 |
|
9 | 7 | from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE |
10 | 8 |
|
11 | 9 |
|
12 | | -class DatabaseMiddleware(BaseHTTPMiddleware): |
13 | | - async def dispatch(self, request: Request, call_next: t.Callable) -> Response: |
| 10 | +class DatabaseMiddleware: |
| 11 | + |
| 12 | + def __init__(self, app: ASGIApp) -> None: |
| 13 | + self._app = app |
| 14 | + |
| 15 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
14 | 16 | client: AsyncIOMotorClient = AsyncIOMotorClient( |
15 | 17 | DATABASE_URL, |
16 | 18 | ssl_cert_reqs=ssl.CERT_NONE |
17 | 19 | ) |
18 | 20 | db = client[MONGO_DATABASE] |
19 | | - request.state.db = db |
20 | | - response = await call_next(request) |
21 | | - return response |
| 21 | + Request(scope).state.db = db |
| 22 | + await self._app(scope, receive, send) |
22 | 23 |
|
23 | 24 |
|
24 | | -class ProtectedDocsMiddleware(BaseHTTPMiddleware): |
25 | | - async def dispatch(self, request: Request, call_next: t.Callable) -> Response: |
| 25 | +class ProtectedDocsMiddleware: |
| 26 | + |
| 27 | + def __init__(self, app: ASGIApp) -> None: |
| 28 | + self._app = app |
| 29 | + |
| 30 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 31 | + request = Request(scope) |
26 | 32 | if DOCS_PASSWORD and request.url.path.startswith("/docs"): |
27 | 33 | if request.cookies.get("docs_password") != DOCS_PASSWORD: |
28 | | - return JSONResponse({"status": "unauthorized"}, status_code=403) |
29 | | - |
30 | | - resp = await call_next(request) |
31 | | - return resp |
| 34 | + resp = JSONResponse({"status": "unauthorized"}, status_code=403) |
| 35 | + await resp(scope, receive, send) |
| 36 | + return |
| 37 | + await self._app(scope, receive, send) |
0 commit comments