-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_network.py
More file actions
134 lines (105 loc) · 4.19 KB
/
_network.py
File metadata and controls
134 lines (105 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import ssl
import types
import typing
import trio
import truststore
from ._streams import Stream
# Public names re-exported by ``from <this module> import *``.
__all__ = [
    "NetworkBackend",
    "NetworkStream",
    "timeout",
]
class NetworkStream(Stream):
    """A byte stream backed by a trio stream (plain TCP or TLS).

    Provides ``read``/``write``/``close`` and async context-manager support.
    An instance that is garbage collected before ``close()`` has been awaited
    emits a warning.
    """

    def __init__(
        self, trio_stream: trio.abc.Stream, address: str = ''
    ) -> None:
        """
        Args:
            trio_stream: The underlying trio stream used for all I/O.
            address: A ``"host:port"`` label; only used by ``repr()``.
        """
        self._trio_stream = trio_stream
        self._address = address
        self._closed = False

    async def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes.

        A negative ``size`` (the default) reads up to 64 KiB. ``size == 0``
        returns ``b""`` immediately, because trio's ``receive_some`` rejects
        ``max_bytes < 1`` with ``ValueError``. Per trio's ``receive_some``
        contract, an empty result for ``size != 0`` signals end-of-file.
        """
        if size == 0:
            return b""
        if size < 0:
            size = 64 * 1024
        return await self._trio_stream.receive_some(size)

    async def write(self, buffer: bytes) -> None:
        """Write all of ``buffer`` to the underlying stream (``send_all``)."""
        await self._trio_stream.send_all(buffer)

    async def close(self) -> None:
        """Close the stream.

        If the stream is already closed this is a checkpointed no-op.
        ``_closed`` is set even if ``aclose()`` raises (e.g. on cancellation),
        so the garbage-collection warning is not emitted afterwards.
        """
        try:
            await self._trio_stream.aclose()
        finally:
            self._closed = True

    def __repr__(self) -> str:
        state = " CLOSED" if self._closed else ""
        return f"<NetworkStream [{self._address}{state}]>"

    def __del__(self) -> None:
        # Use getattr: if __init__ never ran (e.g. the constructor was called
        # with bad arguments), ``_closed`` does not exist, and raising from
        # __del__ would only print an "Exception ignored" traceback.
        if not getattr(self, "_closed", True):
            import warnings
            warnings.warn(f"{self!r} was garbage collected without being closed.")

    # Context managed usage...
    async def __aenter__(self) -> "NetworkStream":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        """Close the stream on exit; exceptions are never suppressed."""
        await self.close()
class NetworkServer:
    """A bound TCP server; entering ``async with`` starts accepting connections.

    ``__aenter__`` opens a trio nursery in the current task and starts
    ``trio.serve_listeners(handler, listeners)`` inside it. ``__aexit__``
    cancels that nursery, which stops the listeners and any handler tasks
    running beneath it, then waits for the nursery to finish.
    """

    def __init__(self, host: str, port: int, handler, listeners: list[trio.SocketListener]):
        """
        Args:
            host: The host the listeners were bound to (informational).
            port: The port the listeners were bound to (informational).
            handler: Async callable invoked by ``trio.serve_listeners`` with
                each accepted trio stream.
            listeners: Already-bound listeners to serve.
        """
        self.host = host
        self.port = port
        self._handler = handler
        self._listeners = listeners

    # Context managed usage...
    async def __aenter__(self) -> "NetworkServer":
        # The nursery is entered manually so its lifetime can span this
        # object's __aenter__/__aexit__ pair; trio requires both to run in
        # the same task.
        self._nursery_manager = trio.open_nursery()
        self._nursery = await self._nursery_manager.__aenter__()
        self._nursery.start_soon(trio.serve_listeners, self._handler, self._listeners)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> bool | None:
        """Stop serving and wait for all server tasks to finish.

        Returns whatever the nursery's ``__aexit__`` returns. A truthy value
        means the nursery handled the in-flight exception (e.g. a ``Cancelled``
        raised by its own cancel scope). Returning it tells the caller's
        ``async with`` to suppress that exception instead of leaking it past
        the scope that owns it.
        """
        self._nursery.cancel_scope.cancel()
        return await self._nursery_manager.__aexit__(exc_type, exc_value, traceback)
class NetworkBackend:
    """Trio-based factory for client connections and TCP servers.

    TLS connections use the ``ssl_ctx`` passed to the constructor, or, when
    none is given, the context returned by ``create_default_context()``.
    """

    def __init__(self, ssl_ctx: ssl.SSLContext | None = None):
        self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx

    def create_default_context(self) -> ssl.SSLContext:
        """Return a client-side TLS context built with ``truststore``."""
        return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    async def connect(self, host: str, port: int) -> NetworkStream:
        """Open a plain TCP connection to ``host:port``.

        Returns:
            A ``NetworkStream`` labelled with ``"host:port"``.
        """
        trio_stream = await trio.open_tcp_stream(host, port)
        return NetworkStream(trio_stream, address=f"{host}:{port}")

    async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
        """Open a TCP connection to ``host:port`` and complete a TLS handshake.

        Args:
            host: Address to connect to.
            port: Port to connect to.
            hostname: Server name passed to ``SSLStream`` as
                ``server_hostname``; defaults to ``host`` when empty.

        Returns:
            A ``NetworkStream`` wrapping the established TLS stream.

        If building the TLS stream or the handshake fails (or is cancelled),
        the TCP connection is closed before the exception propagates.
        """
        address = f"{host}:{port}"
        trio_stream = await trio.open_tcp_stream(host, port)
        try:
            ssl_stream = trio.SSLStream(
                trio_stream,
                ssl_context=self._ssl_ctx,
                server_hostname=hostname or host,
            )
            await ssl_stream.do_handshake()
        except BaseException:
            # Without this, e.g. a certificate-verification failure would
            # leave the socket open until garbage collection.
            await trio.aclose_forcefully(trio_stream)
            raise
        return NetworkStream(ssl_stream, address=address)

    async def serve(
        self,
        host: str,
        port: int,
        handler: typing.Callable[[NetworkStream], typing.Awaitable[None]],
    ) -> NetworkServer:
        """Bind TCP listeners on ``host:port`` and return a server for them.

        The returned ``NetworkServer`` starts accepting connections when it is
        entered with ``async with``. Each connection is wrapped in a
        ``NetworkStream`` and passed to ``handler``. The stream is closed when
        ``handler`` returns or raises.
        """
        async def callback(trio_stream):
            # NOTE: the label is the listening address, not the peer's.
            stream = NetworkStream(trio_stream, address=f"{host}:{port}")
            try:
                await handler(stream)
            finally:
                await stream.close()

        listeners = await trio.open_tcp_listeners(port=port, host=host)
        return NetworkServer(host, port, callback, listeners)

    def __repr__(self) -> str:
        return "<NetworkBackend [trio]>"
# Module-level aliases for trio primitives (presumably so callers can use
# them without importing trio directly — confirm against callers).
# NOTE: ``timeout`` is ``trio.move_on_after``. When the deadline expires, the
# enclosed block is cancelled and execution resumes after it; no exception
# is raised.
Semaphore, Lock = trio.Semaphore, trio.Lock
timeout, sleep = trio.move_on_after, trio.sleep