Coverage for src/loman/graph_utils.py: 92%

26 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-02 23:34 +0000

1"""Graph utility functions for computation graph operations.""" 

2 

3import functools 

4 

5import networkx as nx 

6 

7from loman.exception import LoopDetectedError 

8from loman.util import apply_n 

9 

10 

def contract_node_one(g, n):
    """Splice node ``n`` out of ``g``, bridging its in-neighbours to its out-neighbours.

    For every predecessor ``p`` and every successor ``s`` of ``n`` an edge
    ``p -> s`` is added, then ``n`` (and all of its incident edges) is removed.
    The graph is modified in place and nothing is returned.

    Parameters:
        g : graph exposing ``predecessors``, ``successors``, ``add_edge`` and
            ``remove_node`` (e.g. ``networkx.DiGraph``).
        n : the node to contract; looking up its neighbours raises whatever
            ``g`` raises for an unknown node.
    """
    upstream = list(g.predecessors(n))
    downstream = list(g.successors(n))
    # Cartesian product of in- and out-neighbours, in predecessor-major order.
    bridges = [(src, dst) for src in upstream for dst in downstream]
    for src, dst in bridges:
        g.add_edge(src, dst)
    g.remove_node(n)

17 

18 

def contract_node(g, ns):
    """Contract the node(s) ``ns`` out of ``g`` one at a time.

    Each node supplied by ``apply_n`` is passed to ``contract_node_one``, which
    removes it and wires its predecessors directly to its successors. The graph
    is modified in place.

    NOTE(review): how ``ns`` is iterated (single node vs. iterable) is decided
    by ``loman.util.apply_n``, which is not visible here — confirm there.
    """

    def _contract_from_g(node):
        # Bind ``g`` so apply_n only has to supply the node being contracted.
        contract_node_one(g, node)

    apply_n(_contract_from_g, ns)

22 

23 

def topological_sort(g):
    """Return the nodes of a directed graph in topological order.

    This function attempts to compute the topological order of the nodes in
    the given graph `g`. If the graph contains a cycle, it raises a
    `LoopDetectedError` naming the edges of one detected cycle, which makes
    the failure informative for debugging.

    Parameters:
        g : networkx.DiGraph
            A directed graph to be sorted. Multigraphs are also accepted;
            only the endpoints of cycle edges are reported.

    Returns:
        list
            A list of nodes in topologically sorted order, if the graph has
            no cycles.

    Raises:
        LoopDetectedError
            If the graph contains a cycle. The message has the form
            ``"DAG cycle: a->b, b->a"``, and the original
            ``NetworkXUnfeasible`` is chained as ``__cause__``.

        NetworkXUnfeasible
            Re-raised unchanged if topological sorting fails but
            ``networkx.find_cycle`` cannot locate a cycle.
    """
    try:
        return list(nx.topological_sort(g))
    except nx.NetworkXUnfeasible as e:
        try:
            cycle = nx.find_cycle(g)
        except nx.NetworkXNoCycle:
            cycle = None
        if not cycle:
            # There must be a non-cycle reason for NetworkXUnfeasible: re-raise
            # the original error untouched. A bare ``raise`` outside the inner
            # handler re-raises ``e`` without attaching NetworkXNoCycle as its
            # context.
            raise
        # find_cycle yields (u, v) for simple graphs but (u, v, key) for
        # multigraphs; only the endpoints are needed for the message.
        edges = [f"{edge[0]}->{edge[1]}" for edge in cycle]
        raise LoopDetectedError(f"DAG cycle: {', '.join(edges)}") from e