Coverage for src / loman / graph_utils.py: 92%
26 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
1"""Graph utility functions for computation graph operations."""
3import functools
5import networkx as nx
7from loman.exception import LoopDetectedError
8from loman.util import apply_n
def contract_node_one(g, n):
    """Splice node ``n`` out of directed graph ``g``, in place.

    Every predecessor of ``n`` is wired directly to every successor of ``n``
    with ``g.add_edge``, and ``n`` is then deleted with ``g.remove_node``.
    Any path that previously went through ``n`` therefore still exists.
    Returns ``None``.
    """
    # Lazily pair each predecessor with each successor; successors(n) is
    # re-queried for every predecessor, mirroring a nested loop.
    bypass_edges = ((src, dst) for src in g.predecessors(n) for dst in g.successors(n))
    for src, dst in bypass_edges:
        g.add_edge(src, dst)
    g.remove_node(n)
def contract_node(g, ns):
    """Contract the nodes ``ns`` out of graph ``g``, in place.

    ``contract_node_one`` is partially applied to ``g`` and the resulting
    one-argument callable is handed to ``apply_n`` together with ``ns``.
    Presumably ``apply_n`` calls it once per node in ``ns`` — confirm against
    ``loman.util.apply_n``.
    """
    contract_in_g = functools.partial(contract_node_one, g)
    apply_n(contract_in_g, ns)
def topological_sort(g):
    """Return the nodes of directed graph ``g`` in topological order.

    Wraps ``networkx.topological_sort``. If sorting is infeasible because the
    graph contains a cycle, the failure is translated into a
    ``LoopDetectedError``. Its message lists the edges of one cycle located by
    ``networkx.find_cycle``, which makes the offending dependency loop easy to
    find when debugging.

    Parameters:
        g : networkx.DiGraph
            The directed graph to sort. A ``networkx.MultiDiGraph`` (a
            ``DiGraph`` subclass) is also accepted; edge keys are ignored when
            a cycle is reported.

    Returns:
        list
            All nodes of ``g`` in topologically sorted order.

    Raises:
        LoopDetectedError
            If ``g`` contains a cycle. The message has the form
            ``"DAG cycle: a->b, b->c, c->a"`` and the original
            ``NetworkXUnfeasible`` is chained as ``__cause__``.

        NetworkXUnfeasible
            Re-raised unchanged if sorting is infeasible but ``find_cycle``
            reports no cycle, i.e. the failure has some other cause.
    """
    try:
        return list(nx.topological_sort(g))
    except nx.NetworkXUnfeasible as e:
        try:
            cycle_lst = nx.find_cycle(g)
        except nx.NetworkXNoCycle:
            cycle_lst = None
        if cycle_lst is None:
            # There must be a non-cycle reason for NetworkXUnfeasible, so
            # propagate it as is. The bare ``raise`` sits outside the inner
            # handler, so the irrelevant NetworkXNoCycle is not chained onto
            # the re-raised error as its context.
            raise
        # find_cycle yields (u, v) edges for a DiGraph but (u, v, key) edges
        # for a MultiDiGraph, so only the two endpoints are unpacked.
        lst = [f"{n_src}->{n_tgt}" for n_src, n_tgt, *_ in cycle_lst]
        args = [f"DAG cycle: {', '.join(lst)}"] if lst else []
        raise LoopDetectedError(*args) from e