Coverage for src/loman/visualization.py: 81%
366 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 05:36 +0000
1"""Visualization tools for computation graphs using Graphviz."""
3import os
4import sys
5import tempfile
6from abc import ABC, abstractmethod
7from collections import defaultdict
8from dataclasses import dataclass, field
10import matplotlib as mpl
11import networkx as nx
12import numpy as np
13import pandas as pd
14import pydotplus
15from matplotlib.colors import Colormap
17import loman
19from .consts import NodeAttributes, NodeTransformations, States
20from .graph_utils import contract_node
21from .nodekey import Name, NodeKey, is_pattern, match_pattern, to_nodekey
@dataclass
class Node:
    """One computation node as handed to the node formatters.

    Instances are built by ``create_viz_dag``, one per underlying computation
    node that maps onto a node of the structural graph.
    """

    # Key of the node in the structural (possibly collapsed) graph.
    nodekey: NodeKey
    # Key of the underlying node in the computation graph.
    original_nodekey: NodeKey
    # Attribute dictionary stored on the underlying computation-graph node.
    data: dict
class NodeFormatter(ABC):
    """Abstract base class for node formatting in visualizations.

    A formatter turns a group of computation nodes into a dictionary of
    Graphviz attributes. Subclasses must implement :meth:`format` and may
    override :meth:`calibrate` when they need a pass over every node first.
    """

    def calibrate(self, nodes: list[Node]) -> None:
        """Inspect all nodes of the graph before formatting starts.

        The default implementation does nothing.
        """
        pass

    @abstractmethod
    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return Graphviz attributes for one visualized node.

        Args:
            name: Key of the node in the structural graph.
            nodes: Underlying computation nodes represented by ``name``.
            is_composite: Whether ``name`` stands for a collapsed group.

        Returns:
            A dict of attributes, or ``None`` when this formatter has nothing
            to contribute.
        """
        pass

    @staticmethod
    def create(
        cmap: dict | Colormap | None = None, colors: str = "state", shapes: str | None = None
    ) -> "CompositeNodeFormatter":
        """Build the standard composite formatter.

        Args:
            cmap: Passed to the colour formatter: ``ColorByState`` when
                ``colors`` is ``"state"``, ``ColorByTiming`` when it is
                ``"timing"``.
            colors: ``"state"`` or ``"timing"`` (case-insensitive).
            shapes: ``"type"`` (case-insensitive) to shape nodes by value
                type, or ``None`` for no shape formatter.

        Returns:
            A ``CompositeNodeFormatter`` applying, in order: label, group,
            optional shape, colour, styling overrides, composite rectangles.

        Raises:
            ValueError: If ``shapes`` or ``colors`` is not a supported value.
                This includes non-string ``colors``, which previously escaped
                as ``AttributeError`` from ``.lower()``.
        """
        # Validate both options before constructing anything, so a bad
        # argument always surfaces as ValueError.
        if isinstance(shapes, str):
            shapes = shapes.lower()
        if shapes is not None and shapes != "type":
            raise ValueError(f"{shapes} is not a valid loman shapes parameter for visualization")

        if isinstance(colors, str):
            colors = colors.lower()
        if colors not in ("state", "timing"):
            raise ValueError(f"{colors} is not a valid loman colors parameter for visualization")

        node_formatters: list[NodeFormatter] = [StandardLabel(), StandardGroup()]
        if shapes == "type":
            node_formatters.append(ShapeByType())

        if colors == "state":
            node_formatters.append(ColorByState(cmap))
        else:
            node_formatters.append(ColorByTiming(cmap))

        # Later formatters overwrite earlier ones, so overrides go last.
        node_formatters.append(StandardStylingOverrides())
        node_formatters.append(RectBlocks())

        return CompositeNodeFormatter(node_formatters)
class ColorByState(NodeFormatter):
    """Fill nodes with a colour chosen from their computation state."""

    DEFAULT_STATE_COLORS = {
        None: "#ffffff",  # xkcd white
        States.PLACEHOLDER: "#f97306",  # xkcd orange
        States.UNINITIALIZED: "#0343df",  # xkcd blue
        States.STALE: "#ffff14",  # xkcd yellow
        States.COMPUTABLE: "#9dff00",  # xkcd bright yellow green
        States.UPTODATE: "#15b01a",  # xkcd green
        States.ERROR: "#e50000",  # xkcd red
        States.PINNED: "#bf77f6",  # xkcd light purple
    }

    def __init__(self, state_colors=None):
        """Store the state-to-colour mapping, copying the defaults if none is given."""
        self.state_colors = self.DEFAULT_STATE_COLORS.copy() if state_colors is None else state_colors

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return a filled style whose colour reflects the nodes' state.

        A single node uses its own state. For several nodes, ERROR wins over
        STALE, which wins over everything else; otherwise the shared state is
        used when all nodes agree, and ``None`` (the "no state" colour) when
        they differ.
        """
        observed = [member.data.get(NodeAttributes.STATE, None) for member in nodes]
        if len(nodes) == 1:
            chosen = observed[0]
        else:
            for dominant in (States.ERROR, States.STALE):
                if any(candidate == dominant for candidate in observed):
                    chosen = dominant
                    break
            else:
                reference = observed[0]
                uniform = all(candidate == reference for candidate in observed)
                chosen = reference if uniform else None
        return {"style": "filled", "fillcolor": self.state_colors[chosen]}
class ColorByTiming(NodeFormatter):
    """Fill nodes with a colour reflecting how long they took to execute."""

    def __init__(self, cmap: Colormap | None = None):
        """Store the colormap, defaulting to a green-yellow-red blend.

        The duration range starts as NaN until :meth:`calibrate` is called.
        """
        if cmap is None:
            cmap = mpl.colors.LinearSegmentedColormap.from_list("blend", ["#15b01a", "#ffff14", "#e50000"])
        self.cmap = cmap
        self.min_duration = np.nan
        self.max_duration = np.nan

    def calibrate(self, nodes: list[Node]) -> None:
        """Record the shortest and longest duration among ``nodes``.

        Nodes without timing data are ignored. If no node has timing data
        (for example a graph that has not been computed yet), the range is
        reset to NaN rather than raising ``ValueError`` from ``max()`` of an
        empty list.
        """
        durations = [
            timing.duration
            for node in nodes
            if (timing := node.data.get(NodeAttributes.TIMING)) is not None
        ]
        if not durations:
            self.min_duration = np.nan
            self.max_duration = np.nan
            return
        self.max_duration = max(durations)
        self.min_duration = min(durations)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return a filled style coloured by the node's duration.

        Only single-node groups are coloured; any other group yields
        ``None``. A node without timing data is filled white.
        """
        if len(nodes) != 1:
            return None

        timing_data = nodes[0].data.get(NodeAttributes.TIMING)
        if timing_data is None:
            col = "#FFFFFF"
        else:
            # The 1e-8 floor avoids division by zero when all durations match.
            span = max(1e-8, self.max_duration - self.min_duration)
            norm_duration: float = (timing_data.duration - self.min_duration) / span
            col = mpl.colors.rgb2hex(self.cmap(norm_duration))
        return {"style": "filled", "fillcolor": col}
class ShapeByType(NodeFormatter):
    """Pick a Graphviz shape from the Python type of a node's value."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return shape attributes for a single node that has a value.

        Groups of several nodes, and nodes whose value is ``None``, yield
        ``None``. The type checks run in a fixed order and the first match
        wins; anything unrecognised becomes a diamond.
        """
        if len(nodes) != 1:
            return None

        value = nodes[0].data.get(NodeAttributes.VALUE)
        if value is None:
            return None

        if isinstance(value, np.ndarray):
            attrs = {"shape": "rect"}
        elif isinstance(value, pd.DataFrame):
            attrs = {"shape": "box3d"}
        elif np.isscalar(value):
            attrs = {"shape": "ellipse"}
        elif isinstance(value, (list, tuple)):
            attrs = {"shape": "ellipse", "peripheries": 2}
        elif isinstance(value, dict):
            attrs = {"shape": "house", "peripheries": 2}
        elif isinstance(value, loman.Computation):
            # Looked up at call time, not at import time.
            attrs = {"shape": "hexagon"}
        else:
            attrs = {"shape": "diamond"}
        return attrs
class RectBlocks(NodeFormatter):
    """Draw composite nodes as double-bordered rectangles."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return the rectangle styling for composite nodes, else ``None``."""
        if not is_composite:
            return None
        return {"shape": "rect", "peripheries": 2}
class StandardLabel(NodeFormatter):
    """Label every node with the ``label`` of its node key."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return a ``label`` attribute taken from ``name.label``."""
        label_text = name.label
        return {"label": label_text}
def get_group_path(name: NodeKey, data: dict) -> NodeKey:
    """Return the group a node belongs to.

    The result is the parent of ``name`` joined with the node's GROUP
    attribute, when ``data`` has one; without the attribute, ``None`` is
    passed to ``join``.
    """
    attribute_group = data.get(NodeAttributes.GROUP)
    if attribute_group is not None:
        group_suffix = NodeKey((attribute_group,))
    else:
        group_suffix = None
    return name.parent.join(group_suffix)
class StandardGroup(NodeFormatter):
    """Tag nodes with the group (cluster) they should be drawn in."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return a ``_group`` attribute unless the node sits in the root group.

        A single node uses ``get_group_path`` on its own data; any other
        group of nodes uses the parent of ``name``.
        """
        single = len(nodes) == 1
        group_path = get_group_path(name, nodes[0].data) if single else name.parent
        return None if group_path.is_root else {"_group": group_path}
class StandardStylingOverrides(NodeFormatter):
    """Apply the per-node STYLE attribute (``"small"`` or ``"dot"``)."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return override attributes for a single node with a known style.

        Groups of several nodes, nodes without a style, and unrecognised
        style values all yield ``None``.
        """
        if len(nodes) != 1:
            return None

        style = nodes[0].data.get(NodeAttributes.STYLE)
        if style is None:
            return None
        if style == "small":
            return {"width": 0.3, "height": 0.2, "fontsize": 8}
        if style == "dot":
            return {"shape": "point", "width": 0.1, "peripheries": 1}
        return None
@dataclass
class CompositeNodeFormatter(NodeFormatter):
    """Run several formatters in sequence and merge their attributes."""

    # Formatters in application order; later ones overwrite earlier keys.
    formatters: list[NodeFormatter] = field(default_factory=list)

    def calibrate(self, nodes: list[Node]) -> None:
        """Forward calibration to every contained formatter."""
        for member in self.formatters:
            member.calibrate(nodes)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Merge the attribute dicts of all contained formatters.

        Formatters returning ``None`` are skipped. The result is always a
        dict, possibly empty.
        """
        merged = {}
        for member in self.formatters:
            contribution = member.format(name, nodes, is_composite)
            if contribution is None:
                continue
            merged.update(contribution)
        return merged
@dataclass
class GraphView:
    """A view for visualizing computation graphs as graphical diagrams.

    Construction immediately calls :meth:`refresh`, which fills in
    ``struct_dag``, ``viz_dag`` and ``viz_dot``.
    """

    # Computation whose graph is displayed.
    computation: "loman.Computation"
    # Root passed to drop_root / get_tree_descendents to restrict the view.
    root: Name | None = None
    # Formatter for nodes; NodeFormatter.create() is used when None.
    node_formatter: NodeFormatter | None = None
    # User rules mapping a node name or pattern to a NodeTransformations value.
    node_transformations: dict | None = None
    # When True, default COLLAPSE transforms are registered before user rules.
    collapse_all: bool = True

    # Attributes forwarded to the root Graphviz graph / node / edge defaults.
    graph_attr: dict | None = None
    node_attr: dict | None = None
    edge_attr: dict | None = None

    # Outputs of refresh(): structural graph, formatted graph, pydot graph.
    struct_dag: nx.DiGraph | None = None
    viz_dag: nx.DiGraph | None = None
    viz_dot: pydotplus.Dot | None = None

    def __post_init__(self):
        """Build the view right after dataclass construction."""
        self.refresh()

    @staticmethod
    def get_sub_block(dag, root, node_transformations: dict):
        """Extract the sub-graph under ``root`` and apply node transformations.

        Args:
            dag: Graph whose nodes are node keys.
            root: Root used with ``drop_root``; nodes for which it returns
                ``None`` are left out.
            node_transformations: Mapping of node key to transformation.
                Only COLLAPSE and CONTRACT entries are read here.

        Returns:
            A tuple ``(dag_out, d_mapped_to_original, s_collapsed)``: the
            transformed graph, a mapping from each of its nodes to the
            original node keys it represents, and the set of its nodes that
            stand for a collapsed group.
        """
        d_transform_to_nodes = defaultdict(list)
        for nk, transform in node_transformations.items():
            d_transform_to_nodes[transform].append(nk)

        dag_out = nx.DiGraph()

        d_original_to_mapped = {}
        s_collapsed = set()

        for nk_original in dag.nodes():
            nk_mapped = nk_original.drop_root(root)
            if nk_mapped is None:
                continue
            # Walk up to the highest collapsed key this node descends from.
            nk_highest_collapse = nk_original
            is_collapsed = False
            for nk_collapse in d_transform_to_nodes[NodeTransformations.COLLAPSE]:
                if nk_highest_collapse.is_descendent_of(nk_collapse):
                    nk_highest_collapse = nk_collapse
                    is_collapsed = True
            nk_mapped = nk_highest_collapse.drop_root(root)
            if nk_mapped is None:
                continue
            d_original_to_mapped[nk_original] = nk_mapped
            if is_collapsed:
                s_collapsed.add(nk_mapped)

        for nk_mapped in d_original_to_mapped.values():
            dag_out.add_node(nk_mapped)

        for nk_u, nk_v in dag.edges():
            nk_mapped_u = d_original_to_mapped.get(nk_u)
            nk_mapped_v = d_original_to_mapped.get(nk_v)
            # Skip edges that leave the view or fall inside one collapsed node.
            if nk_mapped_u is None or nk_mapped_v is None or nk_mapped_u == nk_mapped_v:
                continue
            dag_out.add_edge(nk_mapped_u, nk_mapped_v)

        for nk in d_transform_to_nodes[NodeTransformations.CONTRACT]:
            contract_node(dag_out, d_original_to_mapped[nk])
            del d_original_to_mapped[nk]

        d_mapped_to_original = defaultdict(list)
        for nk_original, nk_mapped in d_original_to_mapped.items():
            if nk_mapped in dag_out.nodes:
                d_mapped_to_original[nk_mapped].append(nk_original)

        # Keep only collapsed keys that are still nodes of the output graph.
        s_collapsed.intersection_update(dag_out.nodes)

        return dag_out, d_mapped_to_original, s_collapsed

    def _initialize_transforms(self):
        """Return the transformation mapping: defaults first, then user rules."""
        node_transformations = {}
        if self.collapse_all:
            self._apply_default_collapse_transforms(node_transformations)
        self._apply_custom_transforms(node_transformations)
        return node_transformations

    def _apply_default_collapse_transforms(self, node_transformations):
        """Mark every tree descendant of ``root`` that is not a node as COLLAPSE."""
        for n in self.computation.get_tree_descendents(self.root):
            nk = to_nodekey(n)
            if not self.computation.has_node(nk):
                node_transformations[nk] = NodeTransformations.COLLAPSE

    def _apply_custom_transforms(self, node_transformations):
        """Apply the user's ``node_transformations`` rules in place.

        A pattern rule is expanded to every matching tree descendant of
        ``root``; any other rule applies to its own key. For EXPAND rules the
        ancestors of each target, up to but excluding the root, are also set
        to EXPAND. Returns the mutated mapping.
        """
        if self.node_transformations is not None:
            for rule_name, transform in self.node_transformations.items():
                include_ancestors = transform == NodeTransformations.EXPAND
                rule_nk = to_nodekey(rule_name)
                if is_pattern(rule_nk):
                    apply_nodes = set()
                    for n in self.computation.get_tree_descendents(self.root):
                        nk = to_nodekey(n)
                        if match_pattern(rule_nk, nk):
                            apply_nodes.add(nk)
                else:
                    apply_nodes = {rule_nk}
                    # Register the concrete key before its ancestors.
                    node_transformations[rule_nk] = transform
                if include_ancestors:
                    for nk in apply_nodes:
                        for nk1 in nk.ancestors():
                            if nk1.is_root or nk1 == self.root:
                                break
                            node_transformations[nk1] = NodeTransformations.EXPAND
                # Targets are written last so they win over ancestor entries.
                for target_nk in apply_nodes:
                    node_transformations[target_nk] = transform
        return node_transformations

    def _create_visualization_dag(self, original_nodes, composite_nodes):
        """Build the formatted graph, using the default formatter if none is set."""
        node_formatter = self.node_formatter
        if node_formatter is None:
            node_formatter = NodeFormatter.create()
        return create_viz_dag(self.struct_dag, self.computation.dag, node_formatter, original_nodes, composite_nodes)

    def _create_dot_graph(self):
        """Convert ``viz_dag`` into a pydot graph using the stored attributes."""
        return to_pydot(self.viz_dag, self.graph_attr, self.node_attr, self.edge_attr)

    def refresh(self):
        """Rebuild ``struct_dag``, ``viz_dag`` and ``viz_dot`` from the computation."""
        node_transformations = self._initialize_transforms()
        self.struct_dag, original_nodes, composite_nodes = self.get_sub_block(
            self.computation.dag, self.root, node_transformations
        )
        self.viz_dag = self._create_visualization_dag(original_nodes, composite_nodes)
        self.viz_dot = self._create_dot_graph()

    def svg(self) -> str | None:
        """Return the graph rendered as an SVG string, or ``None`` if unbuilt."""
        if self.viz_dot is None:
            return None
        return self.viz_dot.create_svg().decode("utf-8")

    def view(self):
        """Render the graph to a temporary PDF and open it in a viewer.

        The PDF is written and closed before the viewer starts, so the viewer
        never sees a half-written file. The temporary file is not deleted.
        Does nothing when the view has not been built (``viz_dot`` is
        ``None``), mirroring :meth:`svg`.

        The launcher is ``os.startfile`` on Windows, ``open`` on macOS, and
        ``xdg-open`` elsewhere (falling back to ``open`` if ``xdg-open`` is
        not on ``PATH``). ``os.startfile`` exists only on Windows, so it is
        never called on other platforms.
        """
        if self.viz_dot is None:
            return None

        # Local imports: these modules are only needed by this method.
        import shutil
        import subprocess

        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(self.viz_dot.create_pdf())
            pdf_path = f.name

        if sys.platform == "win32":
            os.startfile(pdf_path)
            return None

        if sys.platform == "darwin":
            opener = "open"
        else:
            opener = shutil.which("xdg-open") or "open"
        # Argument list rather than a shell string: no quoting needed.
        subprocess.run([opener, pdf_path], check=False)
        return None

    def _repr_svg_(self):
        """Return the SVG rendering under the ``_repr_svg_`` hook name."""
        return self.svg()
def create_viz_dag(
    struct_dag, comp_dag, node_formatter: NodeFormatter, original_nodes: dict, composite_nodes: set
) -> nx.DiGraph:
    """Build the formatted graph that is later converted to Graphviz.

    Args:
        struct_dag: Structural graph whose nodes are node keys.
        comp_dag: Computation graph supplying the data of each original node.
        node_formatter: Formatter producing node attributes. It may be
            ``None``, in which case nodes get no attributes.
        original_nodes: Mapping from each structural node to the original
            node keys it represents.
        composite_nodes: Structural nodes that stand for a collapsed group.

    Returns:
        A graph whose nodes are named ``n0``, ``n1``, ... in ``struct_dag``
        order and carry the formatter's non-``None`` attributes. Edges carry
        a ``_group`` attribute when their endpoints share a non-root group.
    """

    def wrap(nodekey):
        # One Node per original computation node behind ``nodekey``.
        return [Node(nodekey, original, comp_dag.nodes[original]) for original in original_nodes[nodekey]]

    has_formatter = node_formatter is not None
    if has_formatter:
        every_node = [node for nodekey in struct_dag.nodes for node in wrap(nodekey)]
        node_formatter.calibrate(every_node)

    viz_dag = nx.DiGraph()
    short_names = {}
    for index, nodekey in enumerate(struct_dag.nodes):
        short_name = f"n{index}"
        attrs = None
        if has_formatter:
            attrs = node_formatter.format(nodekey, wrap(nodekey), nodekey in composite_nodes)
        kept = {} if attrs is None else {key: value for key, value in attrs.items() if value is not None}
        viz_dag.add_node(short_name, **kept)
        short_names[nodekey] = short_name

    for source, target in struct_dag.edges():
        source_short = short_names[source]
        target_short = short_names[target]
        shared_group = NodeKey.common_parent(
            get_group_path(source, struct_dag.nodes[source]),
            get_group_path(target, struct_dag.nodes[target]),
        )
        edge_attrs = {} if shared_group.is_root else {"_group": shared_group}
        viz_dag.add_edge(source_short, target_short, **edge_attrs)

    return viz_dag
def to_pydot(viz_dag, graph_attr=None, node_attr=None, edge_attr=None) -> pydotplus.Dot:
    """Convert a visualization DAG to a PyDot graph for rendering.

    Nodes and edges are placed in the graph or subgraph named by their
    ``_group`` attribute (the root graph when absent). A subgraph is created
    for every group and for each missing ancestor group, and each subgraph is
    nested under its nearest existing ancestor.

    Args:
        viz_dag: Graph whose node and edge data may carry a ``_group`` key.
            Attribute keys starting with ``_`` are not passed to Graphviz.
        graph_attr: Attributes for the root graph.
        node_attr: Default node attributes for the root graph.
        edge_attr: Default edge attributes for the root graph.

    Returns:
        The root ``pydotplus.Dot`` graph.
    """
    root = NodeKey.root()

    node_groups = {}
    for name, data in viz_dag.nodes(data=True):
        group = data.get("_group", root)
        node_groups.setdefault(group, []).append(name)

    edge_groups = {}
    for name1, name2, data in viz_dag.edges(data=True):
        group = data.get("_group", root)
        edge_groups.setdefault(group, []).append((name1, name2))

    subgraphs = {root: create_root_graph(graph_attr, node_attr, edge_attr)}

    for group, names in node_groups.items():
        # Use ``is_root``, not identity: an explicit root-valued ``_group``
        # must not replace the root graph with a cluster.
        c = subgraphs[root] if group.is_root else create_subgraph(group)

        for name in names:
            node = pydotplus.Node(name)
            for k, v in viz_dag.nodes[name].items():
                if not k.startswith("_"):
                    node.set(k, v)
            c.add_node(node)

        subgraphs[group] = c

    # Create any missing ancestor groups. Snapshot the keys because the dict
    # grows inside the loop.
    groups = list(subgraphs.keys())
    for group in groups:
        group1 = group
        while True:
            if group1.is_root:
                break
            group1 = group1.parent
            if group1 in subgraphs:
                break
            subgraphs[group1] = create_subgraph(group1)

    # Attach every non-root subgraph to its nearest existing ancestor.
    for group, subgraph in subgraphs.items():
        if group.is_root:
            continue
        parent = group
        while True:
            parent = parent.parent
            if parent in subgraphs or parent.is_root:
                break
        subgraphs[parent].add_subgraph(subgraph)

    for group, edges in edge_groups.items():
        c = subgraphs[group]
        for name1, name2 in edges:
            edge = pydotplus.Edge(name1, name2)
            c.add_edge(edge)

    return subgraphs[root]
def create_root_graph(graph_attr, node_attr, edge_attr):
    """Create the top-level Graphviz graph.

    ``graph_attr`` entries are set on the graph itself; ``node_attr`` and
    ``edge_attr`` become the graph's node and edge defaults. Each argument is
    skipped only when it is ``None``.
    """
    dot = pydotplus.Dot()

    if graph_attr is not None:
        for attr_name, attr_value in graph_attr.items():
            dot.set(attr_name, attr_value)

    # Tested against None, not truthiness: an empty dict still registers a
    # defaults entry.
    if node_attr is not None:
        dot.set_node_defaults(**node_attr)
    if edge_attr is not None:
        dot.set_edge_defaults(**edge_attr)

    return dot
def create_subgraph(group: NodeKey):
    """Create the Graphviz subgraph for ``group``.

    The subgraph is named ``cluster_<group>`` and labelled with the string
    form of ``group``.
    """
    title = str(group)
    cluster = pydotplus.Subgraph("cluster_" + title)
    cluster.obj_dict["attributes"]["label"] = title
    return cluster