Coverage for src / loman / graph_utils.py: 100%
25 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:30 +0000
1"""Graph utility functions for computation graph operations."""
3import functools
4from typing import Any
6import networkx as nx
8from loman.exception import LoopDetectedError
9from loman.util import apply_n
def contract_node_one(g: nx.DiGraph, n: Any) -> None:
    """Splice the single node ``n`` out of ``g`` in place.

    Every predecessor of ``n`` is linked directly to every successor of
    ``n``. Then ``n`` itself, together with its incident edges, is removed
    from the graph.

    Parameters:
        g : networkx.DiGraph
            Graph that is modified in place.
        n : Any
            Node to remove from ``g``.
    """
    # Snapshot both neighbourhoods before mutating the graph.
    upstream = list(g.predecessors(n))
    downstream = list(g.successors(n))
    # Bridge every upstream node to every downstream node in one bulk call.
    g.add_edges_from((src, dst) for src in upstream for dst in downstream)
    g.remove_node(n)
def contract_node(g: nx.DiGraph, ns: Any) -> None:
    """Contract the node(s) ``ns`` out of ``g`` in place.

    A copy of ``contract_node_one`` bound to ``g`` is handed to
    ``loman.util.apply_n`` together with ``ns``. How ``ns`` is interpreted
    is decided by ``apply_n``. Presumably it accepts either one node or an
    iterable of nodes; TODO confirm against ``apply_n``.
    """
    splice_out = functools.partial(contract_node_one, g)
    apply_n(splice_out, ns)
def topological_sort(g: nx.DiGraph) -> list[Any]:
    """Return the nodes of a directed graph in topological order.

    Thin wrapper around ``networkx.topological_sort``. If sorting fails
    because the graph has a cycle, a ``LoopDetectedError`` is raised whose
    message names the edges of one detected cycle, to help debugging.

    Parameters:
        g : networkx.DiGraph
            The directed graph to sort. Subclasses such as
            ``networkx.MultiDiGraph`` are accepted as well.

    Returns:
        list
            The nodes of ``g`` in topologically sorted order.

    Raises:
        LoopDetectedError
            If ``g`` contains a cycle. The message has the form
            ``"DAG cycle: a->b, b->a"``, listing the directed edges of the
            cycle. The original ``NetworkXUnfeasible`` is chained as the
            cause.
        NetworkXUnfeasible
            If topological sorting fails but no cycle can be found. The
            original exception is re-raised unchanged.
    """
    try:
        return list(nx.topological_sort(g))
    except nx.NetworkXUnfeasible as e:
        try:
            cycle = nx.find_cycle(g)
        except nx.NetworkXNoCycle:  # pragma: no cover
            # Unfeasible for some reason other than a cycle: surface the original error.
            raise e from None
        # find_cycle yields (u, v) edges for a DiGraph but (u, v, key) edges for a
        # MultiDiGraph, so index the endpoints rather than unpacking exactly two values.
        edges = ", ".join(f"{edge[0]}->{edge[1]}" for edge in cycle)
        raise LoopDetectedError(f"DAG cycle: {edges}") from e