Skip to content

Commit 58e963f

Browse files
Merge pull request #795 from tiran/tracking-topo
feat: add TrackingTopologicalSorter
2 parents 2918a78 + 128202b commit 58e963f

2 files changed

Lines changed: 260 additions & 1 deletion

File tree

src/fromager/dependency_graph.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import graphlib
45
import json
56
import logging
67
import pathlib
8+
import threading
79
import typing
810

911
from packaging.requirements import Requirement
@@ -379,3 +381,137 @@ def _depth_first_traversal(
379381
yield from self._depth_first_traversal(
380382
edge.destination_node.children, visited, match_dep_types
381383
)
384+
385+
386+
class TrackingTopologicalSorter:
387+
"""A thread-safe topological sorter that tracks nodes in progress
388+
389+
``TopologicalSorter.get_ready()`` returns each node only once. The
390+
tracking topological sorter keeps track which nodes are marked as done.
391+
The ``get_available()`` method returns nodes again and again, until
392+
they are marked as done. The graph is active until all nodes are marked
393+
as done.
394+
395+
Individual nodes can be marked as exclusive nodes. ``get_available``
396+
treats exclusive nodes special and returns:
397+
398+
1. one or more non-exclusive nodes
399+
2. exactly one exclusive node that is a predecessor of another node
400+
3. exactly one exclusive node
401+
402+
The class uses a lock for ``is_active`, ``get_available`, and ``done``,
403+
so the methods can be used from threading pool and future callback.
404+
"""
405+
406+
__slots__ = (
407+
"_dep_nodes",
408+
"_exclusive_nodes",
409+
"_in_progress_nodes",
410+
"_lock",
411+
"_topo",
412+
)
413+
414+
def __init__(
415+
self,
416+
graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
417+
| None = None,
418+
) -> None:
419+
self._topo: graphlib.TopologicalSorter[DependencyNode] = (
420+
graphlib.TopologicalSorter()
421+
)
422+
# set of nodes that are not done, yet
423+
self._in_progress_nodes: set[DependencyNode] = set()
424+
# set of nodes that are predecessors of other nodes
425+
self._dep_nodes: set[DependencyNode] = set()
426+
# dict of nodes -> priority; dependency: -1, leaf: +1
427+
self._exclusive_nodes: dict[DependencyNode, int] = {}
428+
self._lock = threading.Lock()
429+
if graph is not None:
430+
for node, predecessors in graph.items():
431+
self.add(node, *predecessors)
432+
433+
@property
434+
def dependency_nodes(self) -> set[DependencyNode]:
435+
"""Nodes that other nodes depend on"""
436+
return self._dep_nodes.copy()
437+
438+
@property
439+
def exclusive_nodes(self) -> set[DependencyNode]:
440+
"""Nodes that are marked as exclusive"""
441+
return set(self._exclusive_nodes)
442+
443+
def add(
444+
self,
445+
node: DependencyNode,
446+
*predecessors: DependencyNode,
447+
exclusive: bool = False,
448+
) -> None:
449+
"""Add new node
450+
451+
Can be called multiple times for a node to add more predecessors or
452+
to mark a node as exclusive. Exclusive nodes cannot be unmarked.
453+
"""
454+
self._topo.add(node, *predecessors)
455+
self._dep_nodes.update(predecessors)
456+
if exclusive:
457+
self._exclusive_nodes[node] = 1
458+
459+
def prepare(self) -> None:
460+
"""Prepare and check for cyclic dependencies"""
461+
self._topo.prepare()
462+
for node in self._exclusive_nodes:
463+
if node in self._dep_nodes:
464+
# give dependency nodes a higher priority
465+
self._exclusive_nodes[node] = -1
466+
467+
def is_active(self) -> bool:
468+
with self._lock:
469+
return bool(self._in_progress_nodes) or self._topo.is_active()
470+
471+
def __bool__(self) -> bool:
472+
return self.is_active()
473+
474+
def get_available(self) -> set[DependencyNode]:
475+
"""Get available nodes
476+
477+
A node can be returned multiple times until it is marked as 'done'.
478+
"""
479+
with self._lock:
480+
# get ready nodes, update in progress nodes.
481+
ready = self._topo.get_ready()
482+
self._in_progress_nodes.update(ready)
483+
484+
if not self._in_progress_nodes:
485+
# API misuse, user did not check "is_active"
486+
raise ValueError("topology is not active")
487+
488+
# get and prefer non-exclusive nodes. Exclusive nodes are
489+
# 'heavy' nodes, that that a long time to build. Start with
490+
# 'light' nodes first.
491+
exclusive_nodes = self._exclusive_nodes
492+
non_exclusive = self._in_progress_nodes.difference(exclusive_nodes)
493+
if non_exclusive:
494+
# set.difference() returns a new set object
495+
return non_exclusive
496+
497+
# return a single exclusive node, prefer nodes that are a
498+
# dependency of other nodes.
499+
exclusive = self._in_progress_nodes.intersection(exclusive_nodes)
500+
exclusive_list = sorted(
501+
exclusive,
502+
key=lambda node: (exclusive_nodes[node], node),
503+
)
504+
return {exclusive_list[0]}
505+
506+
def done(self, *nodes: DependencyNode) -> None:
507+
"""Mark nodes as done"""
508+
with self._lock:
509+
self._in_progress_nodes.difference_update(nodes)
510+
self._topo.done(*nodes)
511+
512+
def static_batches(self) -> typing.Iterable[set[DependencyNode]]:
513+
self.prepare()
514+
while self.is_active():
515+
nodes = self.get_available()
516+
yield nodes
517+
self.done(*nodes)

tests/test_dependency_graph.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import dataclasses
22
import graphlib
3+
import typing
34

45
import pytest
56
from packaging.requirements import Requirement
67
from packaging.utils import canonicalize_name
78
from packaging.version import Version
89

9-
from fromager.dependency_graph import DependencyNode
10+
from fromager.dependency_graph import DependencyNode, TrackingTopologicalSorter
1011
from fromager.requirements_file import RequirementType
1112

1213

@@ -172,3 +173,125 @@ def test_pr759_discussion() -> None:
172173
assert sorted(d.iter_install_requirements()) == [e]
173174
assert sorted(e.iter_install_requirements()) == []
174175
assert sorted(f.iter_install_requirements()) == []
176+
177+
178+
def test_tracking_topology_sorter() -> None:
179+
a = mknode("a")
180+
b = mknode("b")
181+
c = mknode("c")
182+
d = mknode("d")
183+
e = mknode("e")
184+
f = mknode("f")
185+
186+
graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
187+
graph = {
188+
a: [b, c],
189+
b: [c, d],
190+
d: [e],
191+
f: [d],
192+
}
193+
194+
topo = TrackingTopologicalSorter(graph)
195+
topo.prepare()
196+
197+
assert topo.dependency_nodes == {b, c, d, e}
198+
assert topo.exclusive_nodes == set()
199+
# properties return new objects
200+
assert topo.dependency_nodes is not topo.dependency_nodes
201+
assert topo.exclusive_nodes is not topo.exclusive_nodes
202+
203+
processed: list[DependencyNode] = []
204+
while topo.is_active():
205+
ready = sorted(topo.get_available())
206+
r0 = ready[0]
207+
processed.append(r0)
208+
topo.done(r0)
209+
# c and e have no dependency
210+
# d depends on e
211+
# b after d
212+
# f after d, but sorting pushes it after a
213+
# a on b
214+
assert processed == [c, e, d, b, a, f]
215+
216+
topo = TrackingTopologicalSorter(graph)
217+
assert topo.dependency_nodes == {b, c, d, e}
218+
assert topo.exclusive_nodes == set()
219+
batches = list(topo.static_batches())
220+
assert batches == [
221+
{c, e},
222+
{d},
223+
{b, f},
224+
{a},
225+
]
226+
227+
topo = TrackingTopologicalSorter(graph)
228+
# mark b as exclusive
229+
topo.add(b, exclusive=True)
230+
assert topo.dependency_nodes == {b, c, d, e}
231+
assert topo.exclusive_nodes == {b}
232+
batches = list(topo.static_batches())
233+
assert batches == [
234+
{c, e},
235+
{d},
236+
{f},
237+
{b},
238+
{a},
239+
]
240+
241+
# call get_available() multiple times
242+
topo = TrackingTopologicalSorter(graph)
243+
topo.prepare()
244+
assert topo.get_available() == {c, e}
245+
assert topo.get_available() == {c, e}
246+
assert topo.get_available() == {c, e}
247+
topo.done(c, e)
248+
assert topo.get_available() == {d}
249+
250+
251+
def test_tracking_topology_sorter_cyclic_error() -> None:
252+
# cyclic graph
253+
a = mknode("a")
254+
b = mknode("b")
255+
256+
graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
257+
graph = {
258+
a: [b],
259+
b: [a],
260+
}
261+
262+
topo = TrackingTopologicalSorter(graph)
263+
with pytest.raises(graphlib.CycleError):
264+
topo.prepare()
265+
266+
267+
def test_tracking_topology_sorter_not_passed_out_error() -> None:
268+
# mark node as ready before it was passed out
269+
a = mknode("a")
270+
b = mknode("b")
271+
graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
272+
graph = {
273+
a: [b],
274+
b: [],
275+
}
276+
topo = TrackingTopologicalSorter(graph)
277+
topo.prepare()
278+
with pytest.raises(ValueError) as excinfo:
279+
topo.done(a)
280+
assert "was not passed out" in str(excinfo.value)
281+
282+
283+
def test_tracking_topology_sorter_not_active_error() -> None:
284+
# call get_available without checking is_active
285+
a = mknode("a")
286+
graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
287+
graph = {
288+
a: [],
289+
}
290+
topo = TrackingTopologicalSorter(graph)
291+
topo.prepare()
292+
done = topo.get_available()
293+
topo.done(*done)
294+
assert not topo.is_active()
295+
with pytest.raises(ValueError) as excinfo:
296+
topo.get_available()
297+
assert "topology is not active" in str(excinfo.value)

0 commit comments

Comments
 (0)