Coverage for src / loman / graph_utils.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:30 +0000

1"""Graph utility functions for computation graph operations.""" 

2 

3import functools 

4from typing import Any 

5 

6import networkx as nx 

7 

8from loman.exception import LoopDetectedError 

9from loman.util import apply_n 

10 

11 

def contract_node_one(g: nx.DiGraph, n: Any) -> None:
    """Splice a single node out of the graph, bypassing it with direct edges.

    Every predecessor of ``n`` is linked to every successor of ``n``; then
    ``n`` itself is removed from ``g``. The graph is modified in place.

    Parameters:
        g : networkx.DiGraph
            The graph to modify.
        n : Any
            The node to contract.
    """
    # Build the full predecessor x successor product before mutating the graph.
    bypass_edges = [
        (src, dst)
        for src in g.predecessors(n)
        for dst in g.successors(n)
    ]
    g.add_edges_from(bypass_edges)
    g.remove_node(n)

18 

19 

def contract_node(g: nx.DiGraph, ns: Any) -> None:
    """Contract one or more nodes out of ``g`` in place.

    Each node yielded by ``apply_n`` is handed to ``contract_node_one``, so its
    predecessors are linked directly to its successors before it is removed.

    Parameters:
        g : networkx.DiGraph
            The graph to modify.
        ns : Any
            The node(s) to contract, in whatever form ``loman.util.apply_n``
            accepts (presumably a single node or an iterable of nodes; confirm
            against ``apply_n``).
    """

    def _contract(node: Any) -> None:
        # Bind the graph so apply_n only needs to supply the node.
        contract_node_one(g, node)

    apply_n(_contract, ns)

23 

24 

def topological_sort(g: nx.DiGraph) -> list[Any]:
    """Return the nodes of a directed graph in topological order.

    If the graph contains a cycle, a ``LoopDetectedError`` is raised whose
    message lists the edges of one detected cycle, to aid debugging.

    Parameters:
        g : networkx.DiGraph
            The directed graph to sort. Subclasses such as
            ``networkx.MultiDiGraph`` are also accepted.

    Returns:
        list
            The nodes of ``g`` in topologically sorted order.

    Raises:
        LoopDetectedError
            If the graph contains a cycle. The message has the form
            ``"DAG cycle: a->b, b->a"``, listing the directed edges of the
            cycle found by ``networkx.find_cycle``.

        NetworkXUnfeasible
            If topological sorting fails but no cycle can be found in ``g``.
    """
    try:
        return list(nx.topological_sort(g))
    except nx.NetworkXUnfeasible as e:
        try:
            cycle_edges = nx.find_cycle(g)
        except nx.NetworkXNoCycle:  # pragma: no cover
            # The sort failed for a reason other than a cycle: re-raise the
            # original error without chaining the failed cycle lookup onto it.
            raise e from None
        # find_cycle yields (u, v) for DiGraph but (u, v, key) for
        # MultiDiGraph, so index the endpoints instead of unpacking a pair.
        edge_strs = [f"{edge[0]}->{edge[1]}" for edge in cycle_edges]
        args: list[str] = []
        if edge_strs:
            args = [f"DAG cycle: {', '.join(edge_strs)}"]
        raise LoopDetectedError(*args) from e