Skip to content

Commit 949ccd0

Browse files
authored
Optimize parallel workers start-up (#21203)
This is a micro-optimization that gives a fixed 50-100ms performance improvement. It is mostly important for "simple" runs, where one only modifies a few files. Such runs are often sub-second, where this will be a visible win. Two ideas here: * Start workers in the background, but don't wait for them to become ready until they are actually needed. * Broadcast graph/SCCs data to all workers in parallel. I am actually not 100% sure about the second one, but I guess it should help in big code-bases. IIUC sockets are I/O, so it should be possible to speed it up using threads. I am also adding a bit more logging, so that we get more insight into communication overhead.
1 parent fed3593 commit 949ccd0

File tree

5 files changed

+97
-34
lines changed

5 files changed

+97
-34
lines changed

mypy/build.py

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from collections.abc import Callable, Iterator, Mapping, Sequence, Set as AbstractSet
3030
from heapq import heappop, heappush
3131
from textwrap import dedent
32+
from threading import Thread
3233
from typing import (
3334
TYPE_CHECKING,
3435
Any,
@@ -371,6 +372,7 @@ def default_flush_errors(
371372
extra_plugins = extra_plugins or []
372373

373374
workers = []
375+
connect_threads = []
374376
if options.num_workers > 0:
375377
# TODO: switch to something more efficient than pickle (also in the daemon).
376378
pickled_options = pickle.dumps(options.snapshot())
@@ -383,10 +385,17 @@ def default_flush_errors(
383385
buf = WriteBuffer()
384386
sources_message.write(buf)
385387
sources_data = buf.getvalue()
388+
389+
def connect(wc: WorkerClient, data: bytes) -> None:
390+
# Start loading sources in each worker as soon as it is up.
391+
wc.connect()
392+
wc.conn.write_bytes(data)
393+
394+
# We don't wait for workers to be ready until they are actually needed.
386395
for worker in workers:
387-
# Start loading graph in each worker as soon as it is up.
388-
worker.connect()
389-
worker.conn.write_bytes(sources_data)
396+
thread = Thread(target=connect, args=(worker, sources_data))
397+
thread.start()
398+
connect_threads.append(thread)
390399

391400
try:
392401
result = build_inner(
@@ -399,6 +408,7 @@ def default_flush_errors(
399408
stderr,
400409
extra_plugins,
401410
workers,
411+
connect_threads,
402412
)
403413
result.errors = messages
404414
return result
@@ -412,6 +422,10 @@ def default_flush_errors(
412422
e.messages = messages
413423
raise
414424
finally:
425+
# In case of an early crash it is better to wait for workers to become ready, and
426+
# shut them down cleanly. Otherwise, they will linger until connection timeout.
427+
for thread in connect_threads:
428+
thread.join()
415429
for worker in workers:
416430
try:
417431
send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={}))
@@ -431,6 +445,7 @@ def build_inner(
431445
stderr: TextIO,
432446
extra_plugins: Sequence[Plugin],
433447
workers: list[WorkerClient],
448+
connect_threads: list[Thread],
434449
) -> BuildResult:
435450
if platform.python_implementation() == "CPython":
436451
# Run gc less frequently, as otherwise we can spend a large fraction of
@@ -486,7 +501,7 @@ def build_inner(
486501

487502
reset_global_state()
488503
try:
489-
graph = dispatch(sources, manager, stdout)
504+
graph = dispatch(sources, manager, stdout, connect_threads)
490505
if not options.fine_grained_incremental:
491506
type_state.reset_all_subtype_caches()
492507
if options.timing_stats is not None:
@@ -496,9 +511,7 @@ def build_inner(
496511
warn_unused_configs(options, flush_errors)
497512
return BuildResult(manager, graph)
498513
finally:
499-
t0 = time.time()
500-
manager.metastore.commit()
501-
manager.add_stats(cache_commit_time=time.time() - t0)
514+
manager.commit()
502515
manager.log(
503516
"Build finished in %.3f seconds with %d modules, and %d errors"
504517
% (
@@ -1119,6 +1132,11 @@ def report_file(
11191132
if self.reports is not None and self.source_set.is_source(file):
11201133
self.reports.file(file, self.modules, type_map, options)
11211134

1135+
def commit(self) -> None:
1136+
t0 = time.time()
1137+
self.metastore.commit()
1138+
self.add_stats(cache_commit_time=time.time() - t0)
1139+
11221140
def verbosity(self) -> int:
11231141
return self.options.verbosity
11241142

@@ -1156,6 +1174,24 @@ def add_stats(self, **kwds: Any) -> None:
11561174
def stats_summary(self) -> Mapping[str, object]:
11571175
return self.stats
11581176

1177+
def broadcast(self, message: bytes) -> None:
1178+
"""Broadcast same message to all workers in parallel."""
1179+
t0 = time.time()
1180+
threads = []
1181+
for worker in self.workers:
1182+
thread = Thread(target=worker.conn.write_bytes, args=(message,))
1183+
thread.start()
1184+
threads.append(thread)
1185+
for thread in threads:
1186+
thread.join()
1187+
self.add_stats(broadcast_time=time.time() - t0)
1188+
1189+
def wait_ack(self) -> None:
1190+
"""Wait for an ack from all workers."""
1191+
for worker in self.workers:
1192+
buf = receive(worker.conn)
1193+
assert read_tag(buf) == ACK_MESSAGE
1194+
11591195
def submit(self, graph: Graph, sccs: list[SCC]) -> None:
11601196
"""Submit a stale SCC for processing in current process or parallel workers."""
11611197
if self.workers:
@@ -1176,6 +1212,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
11761212
for mod_id in scc.mod_ids
11771213
if (path := graph[mod_id].xpath) in self.errors.recorded
11781214
}
1215+
t0 = time.time()
11791216
send(
11801217
self.workers[idx].conn,
11811218
SccRequestMessage(
@@ -1193,6 +1230,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
11931230
},
11941231
),
11951232
)
1233+
self.add_stats(scc_send_time=time.time() - t0)
11961234

11971235
def wait_for_done(
11981236
self, graph: Graph
@@ -1221,7 +1259,10 @@ def wait_for_done_workers(
12211259

12221260
done_sccs = []
12231261
results = {}
1224-
for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT):
1262+
t0 = time.time()
1263+
ready = ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT)
1264+
t1 = time.time()
1265+
for idx in ready:
12251266
buf = receive(self.workers[idx].conn)
12261267
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
12271268
data = SccResponseMessage.read(buf)
@@ -1232,6 +1273,7 @@ def wait_for_done_workers(
12321273
assert data.result is not None
12331274
results.update(data.result)
12341275
done_sccs.append(self.scc_by_id[scc_id])
1276+
self.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1)
12351277
self.submit_to_workers(graph) # advance after some workers are free.
12361278
return (
12371279
done_sccs,
@@ -3685,7 +3727,12 @@ def log_configuration(manager: BuildManager, sources: list[BuildSource]) -> None
36853727
# The driver
36863728

36873729

3688-
def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO) -> Graph:
3730+
def dispatch(
3731+
sources: list[BuildSource],
3732+
manager: BuildManager,
3733+
stdout: TextIO,
3734+
connect_threads: list[Thread],
3735+
) -> Graph:
36893736
log_configuration(manager, sources)
36903737

36913738
t0 = time.time()
@@ -3742,7 +3789,7 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
37423789
dump_graph(graph, stdout)
37433790
return graph
37443791

3745-
# Fine grained dependencies that didn't have an associated module in the build
3792+
# Fine-grained dependencies that didn't have an associated module in the build
37463793
# are serialized separately, so we read them after we load the graph.
37473794
# We need to read them both for running in daemon mode and if we are generating
37483795
# a fine-grained cache (so that we can properly update them incrementally).
@@ -3755,25 +3802,28 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
37553802
if fg_deps_meta is not None:
37563803
manager.fg_deps_meta = fg_deps_meta
37573804
elif manager.stats.get("fresh_metas", 0) > 0:
3758-
# Clear the stats so we don't infinite loop because of positive fresh_metas
3805+
# Clear the stats, so we don't infinite loop because of positive fresh_metas
37593806
manager.stats.clear()
37603807
# There were some cache files read, but no fine-grained dependencies loaded.
37613808
manager.log("Error reading fine-grained dependencies cache -- aborting cache load")
37623809
manager.cache_enabled = False
37633810
manager.log("Falling back to full run -- reloading graph...")
3764-
return dispatch(sources, manager, stdout)
3811+
return dispatch(sources, manager, stdout, connect_threads)
37653812

37663813
# If we are loading a fine-grained incremental mode cache, we
37673814
# don't want to do a real incremental reprocess of the
37683815
# graph---we'll handle it all later.
37693816
if not manager.use_fine_grained_cache():
3817+
# Wait for workers since they may be needed at this point.
3818+
for thread in connect_threads:
3819+
thread.join()
37703820
process_graph(graph, manager)
37713821
# Update plugins snapshot.
37723822
write_plugins_snapshot(manager)
37733823
manager.old_plugins_snapshot = manager.plugins_snapshot
37743824
if manager.options.cache_fine_grained or manager.options.fine_grained_incremental:
3775-
# If we are running a daemon or are going to write cache for further fine grained use,
3776-
# then we need to collect fine grained protocol dependencies.
3825+
# If we are running a daemon or are going to write cache for further fine-grained use,
3826+
# then we need to collect fine-grained protocol dependencies.
37773827
# Since these are a global property of the program, they are calculated after we
37783828
# processed the whole graph.
37793829
type_state.add_all_protocol_deps(manager.fg_deps)
@@ -4166,10 +4216,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
41664216
buf = WriteBuffer()
41674217
graph_message.write(buf)
41684218
graph_data = buf.getvalue()
4169-
for worker in manager.workers:
4170-
buf = receive(worker.conn)
4171-
assert read_tag(buf) == ACK_MESSAGE
4172-
worker.conn.write_bytes(graph_data)
4219+
manager.wait_ack()
4220+
manager.broadcast(graph_data)
41734221

41744222
sccs = sorted_components(graph)
41754223
manager.log(
@@ -4187,13 +4235,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
41874235
buf = WriteBuffer()
41884236
sccs_message.write(buf)
41894237
sccs_data = buf.getvalue()
4190-
for worker in manager.workers:
4191-
buf = receive(worker.conn)
4192-
assert read_tag(buf) == ACK_MESSAGE
4193-
worker.conn.write_bytes(sccs_data)
4194-
for worker in manager.workers:
4195-
buf = receive(worker.conn)
4196-
assert read_tag(buf) == ACK_MESSAGE
4238+
manager.wait_ack()
4239+
manager.broadcast(sccs_data)
4240+
manager.wait_ack()
41974241

41984242
manager.free_workers = set(range(manager.options.num_workers))
41994243

mypy/build_worker/worker.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@
4545
process_stale_scc,
4646
)
4747
from mypy.cache import Tag, read_int_opt
48-
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT
48+
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT, WORKER_IDLE_TIMEOUT
4949
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
5050
from mypy.fscache import FileSystemCache
51-
from mypy.ipc import IPCException, IPCServer, receive, send
51+
from mypy.ipc import IPCException, IPCServer, ready_to_read, receive, send
5252
from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths
5353
from mypy.nodes import FileRawData
5454
from mypy.options import Options
@@ -170,9 +170,13 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
170170
# Notify coordinator we are ready to start processing SCCs.
171171
send(server, AckMessage())
172172
while True:
173+
t0 = time.time()
174+
ready_to_read([server], WORKER_IDLE_TIMEOUT)
175+
t1 = time.time()
173176
buf = receive(server)
174177
assert read_tag(buf) == SCC_REQUEST_MESSAGE
175178
scc_message = SccRequestMessage.read(buf)
179+
manager.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1)
176180
scc_id = scc_message.scc_id
177181
if scc_id is None:
178182
manager.dump_stats()
@@ -193,11 +197,13 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
193197
gc.enable()
194198
result = process_stale_scc(graph, scc, manager, from_cache=graph_data.from_cache)
195199
# We must commit after each SCC, otherwise we break --sqlite-cache.
196-
manager.metastore.commit()
200+
manager.commit()
197201
except CompileError as blocker:
198202
send(server, SccResponseMessage(scc_id=scc_id, blocker=blocker))
199203
else:
204+
t1 = time.time()
200205
send(server, SccResponseMessage(scc_id=scc_id, result=result))
206+
manager.add_stats(scc_send_time=time.time() - t1)
201207
manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1)
202208

203209

mypy/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,5 @@
4848
WORKER_START_INTERVAL: Final = 0.01
4949
WORKER_START_TIMEOUT: Final = 3
5050
WORKER_CONNECTION_TIMEOUT: Final = 10
51+
WORKER_IDLE_TIMEOUT: Final = 600
5152
WORKER_DONE_TIMEOUT: Final = 600

mypy/ipc.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sys
1414
import tempfile
1515
from abc import abstractmethod
16-
from collections.abc import Callable
16+
from collections.abc import Callable, Sequence
1717
from select import select
1818
from types import TracebackType
1919
from typing import Final
@@ -38,7 +38,11 @@
3838
_IPCHandle = socket.socket
3939

4040
# Size of the message packed as !L, i.e. 4 bytes in network order (big-endian).
41-
HEADER_SIZE = 4
41+
HEADER_SIZE: Final = 4
42+
43+
# This is Linux default socket buffer size (for 64 bit), so we will not
44+
# introduce an additional obstacle when exchanging a large IPC message.
45+
MAX_READ: Final = 212992
4246

4347

4448
# TODO: we should make sure consistent exceptions are raised on different platforms.
@@ -80,10 +84,10 @@ def frame_from_buffer(self) -> bytes | None:
8084
self.message_size = None
8185
return bytes(bdata)
8286

83-
def read(self, size: int = 100000) -> str:
87+
def read(self, size: int = MAX_READ) -> str:
8488
return self.read_bytes(size).decode("utf-8")
8589

86-
def read_bytes(self, size: int = 100000) -> bytes:
90+
def read_bytes(self, size: int = MAX_READ) -> bytes:
8791
"""Read bytes from an IPC connection until we have a full frame."""
8892
if sys.platform == "win32":
8993
while True:
@@ -215,6 +219,10 @@ def __init__(self, name: str, timeout: float | None) -> None:
215219
)
216220
else:
217221
self.connection = socket.socket(socket.AF_UNIX)
222+
# This is already default on Linux, we set same buffer size
223+
# for macOS vs Linux consistency to simplify reasoning.
224+
self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, MAX_READ)
225+
self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, MAX_READ)
218226
self.connection.settimeout(timeout)
219227
self.connection.connect(name)
220228

@@ -291,6 +299,10 @@ def __enter__(self) -> IPCServer:
291299
else:
292300
try:
293301
self.connection, _ = self.sock.accept()
302+
# This is already default on Linux, we set same buffer size
303+
# for macOS vs Linux consistency to simplify reasoning.
304+
self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, MAX_READ)
305+
self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, MAX_READ)
294306
except TimeoutError as e:
295307
raise IPCException("The socket timed out") from e
296308
return self
@@ -361,7 +373,7 @@ def read_status(status_file: str) -> dict[str, object]:
361373
return data
362374

363375

364-
def ready_to_read(conns: list[IPCClient], timeout: float | None = None) -> list[int]:
376+
def ready_to_read(conns: Sequence[IPCBase], timeout: float | None = None) -> list[int]:
365377
"""Wait until some connections are readable.
366378
367379
Return index of each readable connection in the original list.

mypy/metastore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def close(self) -> None:
157157
def connect_db(db_file: str) -> sqlite3.Connection:
158158
import sqlite3.dbapi2
159159

160-
db = sqlite3.dbapi2.connect(db_file)
160+
db = sqlite3.dbapi2.connect(db_file, check_same_thread=False)
161161
# This is a bit unfortunate (as we may get corrupt cache after e.g. Ctrl + C),
162162
# but without this flag, commits are *very* slow, especially when using HDDs,
163163
# see https://www.sqlite.org/faq.html#q19 for details.

0 commit comments

Comments
 (0)