Skip to content

Commit 7d4affd

Browse files
Merge pull request #173 from adriangb/asgi-middleware
Replace BaseHTTPMiddleware with pure ASGI middleware
2 parents 221ccd6 + fe94892 commit 7d4affd

1 file changed

Lines changed: 21 additions & 15 deletions

File tree

backend/middleware.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
1-
import typing as t
2-
31
import ssl
42
from motor.motor_asyncio import AsyncIOMotorClient
5-
from starlette.middleware.base import BaseHTTPMiddleware
63
from starlette.requests import Request
7-
from starlette.responses import JSONResponse, Response
4+
from starlette.responses import JSONResponse
5+
from starlette.types import ASGIApp, Scope, Receive, Send
86

97
from backend.constants import DATABASE_URL, DOCS_PASSWORD, MONGO_DATABASE
108

119

12-
class DatabaseMiddleware(BaseHTTPMiddleware):
13-
async def dispatch(self, request: Request, call_next: t.Callable) -> Response:
10+
class DatabaseMiddleware:
11+
12+
def __init__(self, app: ASGIApp) -> None:
13+
self._app = app
14+
15+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
1416
client: AsyncIOMotorClient = AsyncIOMotorClient(
1517
DATABASE_URL,
1618
ssl_cert_reqs=ssl.CERT_NONE
1719
)
1820
db = client[MONGO_DATABASE]
19-
request.state.db = db
20-
response = await call_next(request)
21-
return response
21+
Request(scope).state.db = db
22+
await self._app(scope, receive, send)
2223

2324

24-
class ProtectedDocsMiddleware(BaseHTTPMiddleware):
25-
async def dispatch(self, request: Request, call_next: t.Callable) -> Response:
25+
class ProtectedDocsMiddleware:
26+
27+
def __init__(self, app: ASGIApp) -> None:
28+
self._app = app
29+
30+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
31+
request = Request(scope)
2632
if DOCS_PASSWORD and request.url.path.startswith("/docs"):
2733
if request.cookies.get("docs_password") != DOCS_PASSWORD:
28-
return JSONResponse({"status": "unauthorized"}, status_code=403)
29-
30-
resp = await call_next(request)
31-
return resp
34+
resp = JSONResponse({"status": "unauthorized"}, status_code=403)
35+
await resp(scope, receive, send)
36+
return
37+
await self._app(scope, receive, send)

0 commit comments

Comments
 (0)