-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathbase.py
More file actions
213 lines (172 loc) · 7.66 KB
/
base.py
File metadata and controls
213 lines (172 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import asyncio
import contextlib
import warnings
from collections import defaultdict
from collections.abc import AsyncGenerator, Awaitable, Sequence
from typing import (
Any,
Callable,
Optional,
)
from typing_extensions import Literal, Protocol, TypedDict
from weakref import WeakSet
from channels.consumer import AsyncConsumer
from channels.generic.websocket import AsyncWebsocketConsumer
class ChannelsMessage(TypedDict, total=False):
    """Minimal typed shape of a Channels message.

    Declared with ``total=False``, so every key is optional. ``type`` is the
    only key declared here; it is the value consumers use to decide how a
    message is dispatched.
    """

    type: str
class ChannelsLayer(Protocol):  # pragma: no cover
    """Structural type describing a Django Channels channel layer.

    Follows the channel layer specification:
    https://channels.readthedocs.io/en/stable/channel_layer_spec.html

    The core messaging methods are always expected. The group methods and
    ``group_expiry`` apply when the layer supports the ``groups`` extension,
    and ``flush`` applies when it supports the ``flush`` extension.
    """

    # ---- Core API ---------------------------------------------------------
    # Names of the optional extensions this layer implements.
    extensions: list[Literal["groups", "flush"]]

    async def send(self, channel: str, message: dict) -> None:
        ...

    async def receive(self, channel: str) -> dict:
        ...

    async def new_channel(self, prefix: str = ...) -> str:
        ...

    # ---- "groups" extension -----------------------------------------------
    group_expiry: int

    async def group_add(self, group: str, channel: str) -> None:
        ...

    async def group_discard(self, group: str, channel: str) -> None:
        ...

    async def group_send(self, group: str, message: dict) -> None:
        ...

    # ---- "flush" extension ------------------------------------------------
    async def flush(self) -> None:
        ...
class ChannelsConsumer(AsyncConsumer):
    """Base channels async consumer.

    Extends ``AsyncConsumer`` with a small listener mechanism. A message whose
    ``type`` is non-empty and is not an ``http.*`` / ``websocket.*`` protocol
    event is not routed to a handler method. It is put on every queue
    registered for that type in ``listen_queues``, and dropped if there is
    none. Resolvers consume those queues through ``listen_to_channel``, or
    through the deprecated ``channel_listen``.
    """

    channel_name: str
    channel_layer: Optional[ChannelsLayer]
    channel_receive: Callable[[], Awaitable[dict]]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Message type -> queues of the listeners currently waiting for it.
        # Queues are held weakly, so an abandoned listener's queue can still
        # be reclaimed by GC even if it was never explicitly discarded.
        self.listen_queues: defaultdict[str, WeakSet[asyncio.Queue]] = defaultdict(
            WeakSet
        )
        super().__init__(*args, **kwargs)

    async def dispatch(self, message: ChannelsMessage) -> None:
        """Route ``message`` to listener queues or to ``AsyncConsumer``.

        ``AsyncConsumer`` looks up a handler method named after
        ``message["type"]``, both for http/websocket events and for
        channel-layer messages. Any other non-empty type is instead pushed
        onto every queue listening for it, so it can be consumed by
        ``listen_to_channel`` / ``channel_listen``.
        """
        type_ = message.get("type", "")
        if type_ and not type_.startswith(("http.", "websocket.")):
            # ``.get`` rather than ``[]``: indexing the defaultdict would
            # leave an empty WeakSet behind for every unlistened type.
            for queue in self.listen_queues.get(type_, ()):
                queue.put_nowait(message)
            return

        await super().dispatch(message)

    def _discard_listen_queue(self, type_: str, queue: asyncio.Queue) -> None:
        """Stop delivering ``type_`` messages to ``queue``.

        The per-type set is removed once it is empty, so ``listen_queues``
        does not keep an entry for every type that was ever listened to.
        """
        listeners = self.listen_queues.get(type_)
        if listeners is None:
            return
        listeners.discard(queue)
        if not listeners:
            del self.listen_queues[type_]

    async def channel_listen(
        self,
        type: str,
        *,
        timeout: Optional[float] = None,
        groups: Sequence[str] = (),
    ) -> AsyncGenerator[Any, None]:
        """Listen for messages sent to this consumer.

        Deprecated: use ``listen_to_channel`` instead. A ``DeprecationWarning``
        is emitted when the generator starts running.

        Utility to listen for channels messages for this consumer inside
        a resolver (usually inside a subscription).

        Args:
            type:
                The type of the message to wait for.
            timeout:
                An optional timeout to wait for each subsequent message.
                When it expires, the generator finishes.
            groups:
                An optional sequence of groups to receive messages from.
                When passing this parameter, the groups will be registered
                using `self.channel_layer.group_add` at the beginning of the
                execution and then discarded using `self.channel_layer.group_discard`
                at the end of the execution.

        Yields:
            Each message of the requested ``type`` delivered via ``dispatch``.

        Raises:
            RuntimeError: If no channel layer is configured.
        """
        warnings.warn("Use listen_to_channel instead", DeprecationWarning, stacklevel=2)
        if self.channel_layer is None:
            raise RuntimeError(
                "Layers integration is required listening for channels.\n"
                "Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html "
                "for more information"
            )

        # This queue receives incoming messages for this generator instance.
        queue: asyncio.Queue = asyncio.Queue()
        self.listen_queues[type].add(queue)

        added_groups: list[str] = []
        try:
            for group in groups:
                await self.channel_layer.group_add(group, self.channel_name)
                added_groups.append(group)

            while True:
                awaitable = queue.get()
                if timeout is not None:
                    awaitable = asyncio.wait_for(awaitable, timeout)
                try:
                    yield await awaitable
                except asyncio.TimeoutError:
                    # TODO: shall we add log here and maybe in the suppress below?
                    return
        finally:
            # Unregister immediately rather than waiting for GC of the queue.
            self._discard_listen_queue(type, queue)
            for group in added_groups:
                with contextlib.suppress(Exception):
                    await self.channel_layer.group_discard(group, self.channel_name)

    @contextlib.asynccontextmanager
    async def listen_to_channel(
        self,
        type: str,
        *,
        timeout: Optional[float] = None,
        groups: Sequence[str] = (),
    ) -> AsyncGenerator[Any, None]:
        """Listen for messages sent to this consumer.

        Utility to listen for channels messages for this consumer inside
        a resolver (usually inside a subscription). Entering the context
        subscribes to ``groups`` and yields an async generator of messages.
        User code can therefore run after subscribing but before it starts
        waiting for messages. On exit, the queue is unregistered and the groups
        are discarded, so the generator should not be used after the
        ``async with`` block ends.

        Args:
            type:
                The type of the message to wait for.
            timeout:
                An optional timeout to wait for each subsequent message.
                When it expires, the yielded generator finishes.
            groups:
                An optional sequence of groups to receive messages from.
                When passing this parameter, the groups will be registered
                using `self.channel_layer.group_add` at the beginning of the
                execution and then discarded using `self.channel_layer.group_discard`
                at the end of the execution.

        Raises:
            RuntimeError: If no channel layer is configured.
        """
        if self.channel_layer is None:
            raise RuntimeError(
                "Layers integration is required listening for channels.\n"
                "Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html "
                "for more information"
            )

        # This queue receives incoming messages for this listener.
        queue: asyncio.Queue = asyncio.Queue()
        self.listen_queues[type].add(queue)

        added_groups: list[str] = []
        try:
            # Subscribe inside ``try`` so groups joined before a failing or
            # cancelled ``group_add`` are still discarded in ``finally``.
            for group in groups:
                await self.channel_layer.group_add(group, self.channel_name)
                added_groups.append(group)

            # Hand back the generator without awaiting it, so the caller can
            # run code after subscribing and before blocking on messages.
            yield self._listen_to_channel_generator(queue, timeout)
        finally:
            # Release resources (queue registration and group subscriptions).
            self._discard_listen_queue(type, queue)
            for group in added_groups:
                with contextlib.suppress(Exception):
                    await self.channel_layer.group_discard(group, self.channel_name)

    async def _listen_to_channel_generator(
        self, queue: asyncio.Queue, timeout: Optional[float]
    ) -> AsyncGenerator[Any, None]:
        """Yield messages from ``queue`` for ``listen_to_channel``.

        Separated so that user code can run after subscribing to channels and
        before blocking on incoming channel messages. If ``timeout`` is set
        and no message arrives within it, the generator finishes.
        """
        while True:
            awaitable = queue.get()
            if timeout is not None:
                awaitable = asyncio.wait_for(awaitable, timeout)
            try:
                yield await awaitable
            except asyncio.TimeoutError:
                # TODO: shall we add log here and maybe in the suppress below?
                return
class ChannelsWSConsumer(ChannelsConsumer, AsyncWebsocketConsumer):
    """WebSocket flavour of the base channels async consumer.

    Mixes ``ChannelsConsumer``'s listener-queue dispatching into channels'
    ``AsyncWebsocketConsumer``. ``ChannelsConsumer`` is listed first among
    the bases, so its methods (such as ``dispatch``) are resolved before
    ``AsyncWebsocketConsumer``'s.
    """
# Public API exported by ``from <module> import *``.
__all__ = [
    "ChannelsConsumer",
    "ChannelsWSConsumer",
]