Coverage for src/loman/visualization.py: 81%

366 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 05:36 +0000

1"""Visualization tools for computation graphs using Graphviz.""" 

2 

3import os 

4import sys 

5import tempfile 

6from abc import ABC, abstractmethod 

7from collections import defaultdict 

8from dataclasses import dataclass, field 

9 

10import matplotlib as mpl 

11import networkx as nx 

12import numpy as np 

13import pandas as pd 

14import pydotplus 

15from matplotlib.colors import Colormap 

16 

17import loman 

18 

19from .consts import NodeAttributes, NodeTransformations, States 

20from .graph_utils import contract_node 

21from .nodekey import Name, NodeKey, is_pattern, match_pattern, to_nodekey 

22 

23 

@dataclass
class Node:
    """One computation node as handed to a visualization formatter.

    Pairs the key under which the node is drawn with the key it has in the
    computation DAG (the two differ when the view is re-rooted or the node
    is folded into a collapsed group), plus that node's attribute mapping.
    """

    # Key the node is displayed under in the structural/visual graph.
    nodekey: NodeKey
    # Key of the corresponding node in the computation's own DAG.
    original_nodekey: NodeKey
    # Attribute mapping taken from the computation DAG for original_nodekey.
    data: dict

31 

32 

class NodeFormatter(ABC):
    """Interface for objects that compute Graphviz attributes for visualized nodes.

    Subclasses must implement :meth:`format`; :meth:`calibrate` is an optional
    hook that receives every node before any node is formatted.
    """

    def calibrate(self, nodes: list[Node]) -> None:
        """Inspect all nodes ahead of formatting; the default does nothing."""

    @abstractmethod
    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return Graphviz attributes for the node drawn as ``name``, or None.

        ``nodes`` are the computation nodes represented by ``name`` and
        ``is_composite`` is true when ``name`` stands for a collapsed group.
        """

    @staticmethod
    def create(cmap: dict | Colormap | None = None, colors: str = "state", shapes: str | None = None):
        """Build the standard composite formatter.

        Args:
            cmap: Color mapping forwarded to the chosen coloring formatter.
            colors: ``"state"`` or ``"timing"`` (case-insensitive).
            shapes: ``None`` or ``"type"`` (case-insensitive).

        Returns:
            A ``CompositeNodeFormatter`` chaining label, group, optional shape,
            color, styling-override and block formatters, in that order.

        Raises:
            ValueError: If ``shapes`` or ``colors`` is not a recognized option.
        """
        chain: list[NodeFormatter] = [StandardLabel(), StandardGroup()]

        shape_option = shapes.lower() if isinstance(shapes, str) else shapes
        if shape_option is not None:
            if shape_option != "type":
                raise ValueError(f"{shape_option} is not a valid loman shapes parameter for visualization")
            chain.append(ShapeByType())

        color_option = colors.lower()
        color_classes = {"state": ColorByState, "timing": ColorByTiming}
        if color_option not in color_classes:
            raise ValueError(f"{color_option} is not a valid loman colors parameter for visualization")
        chain.append(color_classes[color_option](cmap))

        chain.extend([StandardStylingOverrides(), RectBlocks()])
        return CompositeNodeFormatter(chain)

71 

72 

class ColorByState(NodeFormatter):
    """Node formatter that colors nodes based on their computation state."""

    DEFAULT_STATE_COLORS = {
        None: "#ffffff",  # xkcd white
        States.PLACEHOLDER: "#f97306",  # xkcd orange
        States.UNINITIALIZED: "#0343df",  # xkcd blue
        States.STALE: "#ffff14",  # xkcd yellow
        States.COMPUTABLE: "#9dff00",  # xkcd bright yellow green
        States.UPTODATE: "#15b01a",  # xkcd green
        States.ERROR: "#e50000",  # xkcd red
        States.PINNED: "#bf77f6",  # xkcd light purple
    }

    def __init__(self, state_colors=None):
        """Initialize with an optional state -> color mapping.

        Args:
            state_colors: Mapping from a ``States`` member (or None, meaning
                "mixed/unknown") to a Graphviz color. States missing from a
                custom mapping fall back to ``DEFAULT_STATE_COLORS``. Defaults
                to a copy of ``DEFAULT_STATE_COLORS``.
        """
        if state_colors is None:
            state_colors = self.DEFAULT_STATE_COLORS.copy()
        self.state_colors = state_colors

    @staticmethod
    def _aggregate_state(states):
        """Reduce the states of the nodes behind one drawn node to a single state.

        ERROR takes precedence over STALE, which takes precedence over anything
        else. Otherwise the shared state is returned when all states agree, and
        None when they differ or when there are no states at all.
        """
        if not states:
            return None
        if any(state == States.ERROR for state in states):
            return States.ERROR
        if any(state == States.STALE for state in states):
            return States.STALE
        first = states[0]
        return first if all(state == first for state in states) else None

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Fill the node with the color of its (aggregated) computation state."""
        states = [node.data.get(NodeAttributes.STATE, None) for node in nodes]
        state = self._aggregate_state(states)
        try:
            color = self.state_colors[state]
        except KeyError:
            # A partial custom mapping should not crash the rendering; fall back to
            # the default palette (None if the state is unknown there too, in which
            # case create_viz_dag drops the attribute).
            color = self.DEFAULT_STATE_COLORS.get(state)
        return {"style": "filled", "fillcolor": color}

110 

111 

class ColorByTiming(NodeFormatter):
    """Node formatter that colors nodes based on their execution timing."""

    def __init__(self, cmap: Colormap | None = None):
        """Initialize with an optional colormap for timing visualization.

        Args:
            cmap: Colormap applied to normalized durations in [0, 1]. Defaults
                to a green -> yellow -> red blend (fast -> slow).
        """
        if cmap is None:
            cmap = mpl.colors.LinearSegmentedColormap.from_list("blend", ["#15b01a", "#ffff14", "#e50000"])
        self.cmap = cmap
        # Range of observed durations; NaN until calibrate() sees timed nodes.
        self.min_duration = np.nan
        self.max_duration = np.nan

    @staticmethod
    def _durations(nodes: list[Node]) -> list:
        """Return the durations of those nodes that carry timing data."""
        return [
            timing.duration
            for timing in (node.data.get(NodeAttributes.TIMING) for node in nodes)
            if timing is not None
        ]

    def calibrate(self, nodes: list[Node]) -> None:
        """Record the min/max duration over all timed nodes.

        When no node has timing data (e.g. nothing has been computed yet) the
        range is reset to NaN instead of failing on an empty ``min``/``max``.
        """
        durations = self._durations(nodes)
        if not durations:
            self.min_duration = np.nan
            self.max_duration = np.nan
            return
        self.max_duration = max(durations)
        self.min_duration = min(durations)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Format a node with timing-based coloring.

        Nodes without timing data are white. A drawn node that represents
        several computation nodes is colored by the slowest timed member.
        """
        durations = self._durations(nodes)
        if not durations:
            col = "#FFFFFF"
        else:
            duration = max(durations)
            # Guard against a zero-width range when all durations are equal.
            norm_duration: float = (duration - self.min_duration) / max(1e-8, self.max_duration - self.min_duration)
            col = mpl.colors.rgb2hex(self.cmap(norm_duration))
        return {"style": "filled", "fillcolor": col}

145 

146 

class ShapeByType(NodeFormatter):
    """Node formatter that picks a Graphviz shape from the type of a node's value."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return a shape for a single node holding a value, otherwise None.

        numpy arrays -> rect, DataFrames -> box3d, numpy scalars -> ellipse,
        lists/tuples -> double ellipse, dicts -> double house, nested
        Computations -> hexagon, anything else -> diamond.
        """
        if len(nodes) != 1:
            return None
        value = nodes[0].data.get(NodeAttributes.VALUE)
        if value is None:
            return None
        if isinstance(value, np.ndarray):
            return {"shape": "rect"}
        if isinstance(value, pd.DataFrame):
            return {"shape": "box3d"}
        if np.isscalar(value):
            return {"shape": "ellipse"}
        if isinstance(value, (list, tuple)):
            return {"shape": "ellipse", "peripheries": 2}
        if isinstance(value, dict):
            return {"shape": "house", "peripheries": 2}
        if isinstance(value, loman.Computation):
            return {"shape": "hexagon"}
        return {"shape": "diamond"}

171 

172 

class RectBlocks(NodeFormatter):
    """Node formatter that draws composite (collapsed) nodes as double-bordered rectangles."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Give composite nodes a double-outlined rectangle; contribute nothing otherwise."""
        return {"shape": "rect", "peripheries": 2} if is_composite else None

180 

181 

class StandardLabel(NodeFormatter):
    """Node formatter that labels each node with its key's display label."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Use ``name.label`` as the Graphviz label."""
        label = name.label
        return {"label": label}

188 

189 

def get_group_path(name: NodeKey, data: dict) -> NodeKey:
    """Return the cluster path a node should be drawn in.

    The path is the node's parent path combined (via ``NodeKey.join``) with
    the node's GROUP attribute, if ``data`` has one; ``join`` receives None
    when no group is set.
    """
    parent_path = name.parent
    group = data.get(NodeAttributes.GROUP)
    if group is None:
        return parent_path.join(None)
    return parent_path.join(NodeKey((group,)))

197 return group_path 

198 

199 

class StandardGroup(NodeFormatter):
    """Node formatter that assigns nodes to Graphviz clusters via a ``_group`` attribute."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Return ``{"_group": path}``, or None when the node sits at the root.

        A key backed by exactly one node is grouped by :func:`get_group_path`
        (parent path plus GROUP attribute); any other key is grouped by its
        parent path alone.
        """
        group_path = get_group_path(name, nodes[0].data) if len(nodes) == 1 else name.parent
        return None if group_path.is_root else {"_group": group_path}

213 

214 

215class StandardStylingOverrides(NodeFormatter): 

216 """Node formatter that applies standard styling overrides.""" 

217 

218 def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None: 

219 """Format a node with standard styling overrides.""" 

220 if len(nodes) == 1: 

221 data = nodes[0].data 

222 style = data.get(NodeAttributes.STYLE) 

223 if style is None: 

224 return 

225 if style == "small": 

226 return {"width": 0.3, "height": 0.2, "fontsize": 8} 

227 elif style == "dot": 

228 return {"shape": "point", "width": 0.1, "peripheries": 1} 

229 

230 

@dataclass
class CompositeNodeFormatter(NodeFormatter):
    """Formatter that chains several formatters, later ones overriding earlier keys."""

    formatters: list[NodeFormatter] = field(default_factory=list)

    def calibrate(self, nodes: list[Node]) -> None:
        """Forward calibration to each child formatter, in order."""
        for child in self.formatters:
            child.calibrate(nodes)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict | None:
        """Merge the attribute dicts produced by the child formatters.

        Children returning None contribute nothing; when two children set the
        same key the later one wins. Always returns a (possibly empty) dict.
        """
        merged: dict = {}
        for child in self.formatters:
            attrs = child.format(name, nodes, is_composite)
            if attrs is None:
                continue
            merged |= attrs
        return merged

250 

251 

@dataclass
class GraphView:
    """A view for visualizing computation graphs as graphical diagrams.

    Construction immediately calls :meth:`refresh`, which builds:

    * ``struct_dag`` - the computation DAG restricted to ``root`` with the
      collapse/contract transformations applied;
    * ``viz_dag`` - a DAG of short node names (``n0``, ``n1``, ...) carrying
      Graphviz attributes produced by ``node_formatter``;
    * ``viz_dot`` - the ``pydotplus`` graph used for rendering.
    """

    computation: "loman.Computation"
    # Subtree to display; forwarded to get_tree_descendents and NodeKey.drop_root.
    root: Name | None = None
    # Formatter for node attributes; NodeFormatter.create() is used when None.
    node_formatter: NodeFormatter | None = None
    # User rules mapping node names or patterns to NodeTransformations values.
    node_transformations: dict | None = None
    # When True, tree nodes under root that are not computation nodes start collapsed.
    collapse_all: bool = True

    # Graphviz attributes for the graph, and default node / edge attributes.
    graph_attr: dict | None = None
    node_attr: dict | None = None
    edge_attr: dict | None = None

    # Derived graphs, (re)built by refresh().
    struct_dag: nx.DiGraph | None = None
    viz_dag: nx.DiGraph | None = None
    viz_dot: pydotplus.Dot | None = None

    def __post_init__(self):
        """Initialize the graph view after dataclass construction."""
        self.refresh()

    @staticmethod
    def get_sub_block(dag, root, node_transformations: dict):
        """Project ``dag`` onto the view rooted at ``root`` with transformations applied.

        Nodes that ``drop_root(root)`` maps to None are omitted. A node lying
        under a COLLAPSE key is replaced by the highest such key. Edges whose
        endpoints map to the same key are dropped. CONTRACT keys are removed
        via ``contract_node``; CONTRACT keys that are not part of the view
        (unknown, outside ``root``) or whose mapped node was already removed
        are skipped.

        Returns:
            ``(dag_out, d_mapped_to_original, s_collapsed)``: the projected
            DAG; a mapping from each of its nodes to the list of original keys
            it represents; and the subset of its nodes that stand for
            collapsed subtrees.
        """
        d_transform_to_nodes = defaultdict(list)
        for nk, transform in node_transformations.items():
            d_transform_to_nodes[transform].append(nk)

        dag_out = nx.DiGraph()

        d_original_to_mapped = {}
        s_collapsed = set()

        for nk_original in dag.nodes():
            nk_mapped = nk_original.drop_root(root)
            if nk_mapped is None:
                continue
            nk_highest_collapse = nk_original
            is_collapsed = False
            for nk_collapse in d_transform_to_nodes[NodeTransformations.COLLAPSE]:
                if nk_highest_collapse.is_descendent_of(nk_collapse):
                    nk_highest_collapse = nk_collapse
                    is_collapsed = True
            nk_mapped = nk_highest_collapse.drop_root(root)
            if nk_mapped is None:
                continue
            d_original_to_mapped[nk_original] = nk_mapped
            if is_collapsed:
                s_collapsed.add(nk_mapped)

        for nk_mapped in d_original_to_mapped.values():
            dag_out.add_node(nk_mapped)

        for nk_u, nk_v in dag.edges():
            nk_mapped_u = d_original_to_mapped.get(nk_u)
            nk_mapped_v = d_original_to_mapped.get(nk_v)
            if nk_mapped_u is None or nk_mapped_v is None or nk_mapped_u == nk_mapped_v:
                continue
            dag_out.add_edge(nk_mapped_u, nk_mapped_v)

        for nk in d_transform_to_nodes[NodeTransformations.CONTRACT]:
            nk_mapped = d_original_to_mapped.pop(nk, None)
            # Previously an unknown / out-of-view key raised KeyError, and a
            # mapped node shared by two contracted keys was contracted twice.
            if nk_mapped is None or nk_mapped not in dag_out:
                continue
            contract_node(dag_out, nk_mapped)

        d_mapped_to_original = defaultdict(list)
        for nk_original, nk_mapped in d_original_to_mapped.items():
            if nk_mapped in dag_out.nodes:
                d_mapped_to_original[nk_mapped].append(nk_original)

        s_collapsed.intersection_update(dag_out.nodes)

        return dag_out, d_mapped_to_original, s_collapsed

    def _initialize_transforms(self):
        """Build the effective node -> transformation map (defaults, then user rules)."""
        node_transformations = {}
        if self.collapse_all:
            self._apply_default_collapse_transforms(node_transformations)
        self._apply_custom_transforms(node_transformations)
        return node_transformations

    def _apply_default_collapse_transforms(self, node_transformations):
        """Mark every tree node under ``root`` that is not a computation node as COLLAPSE."""
        for n in self.computation.get_tree_descendents(self.root):
            nk = to_nodekey(n)
            if not self.computation.has_node(nk):
                node_transformations[nk] = NodeTransformations.COLLAPSE

    def _apply_custom_transforms(self, node_transformations):
        """Overlay the user's ``node_transformations`` rules onto ``node_transformations``.

        A rule key is a concrete node name or a pattern; patterns are matched
        against the tree descendants of ``root``. Only concrete keys are
        written: storing the pattern key itself would make a CONTRACT pattern
        fail later in ``get_sub_block``. For EXPAND rules, the ancestors of each
        matched key are also set to EXPAND, walking ``nk.ancestors()`` until the
        tree root or ``self.root`` is reached.
        """
        if self.node_transformations is not None:
            for rule_name, transform in self.node_transformations.items():
                include_ancestors = transform == NodeTransformations.EXPAND
                rule_nk = to_nodekey(rule_name)
                if is_pattern(rule_nk):
                    apply_nodes = set()
                    for n in self.computation.get_tree_descendents(self.root):
                        nk = to_nodekey(n)
                        if match_pattern(rule_nk, nk):
                            apply_nodes.add(nk)
                else:
                    apply_nodes = {rule_nk}
                if include_ancestors:
                    for nk in apply_nodes:
                        for nk1 in nk.ancestors():
                            if nk1.is_root or nk1 == self.root:
                                break
                            node_transformations[nk1] = NodeTransformations.EXPAND
                for nk in apply_nodes:
                    node_transformations[nk] = transform
        return node_transformations

    def _create_visualization_dag(self, original_nodes, composite_nodes):
        """Build ``viz_dag`` from ``struct_dag`` using the configured (or default) formatter."""
        node_formatter = self.node_formatter
        if node_formatter is None:
            node_formatter = NodeFormatter.create()
        return create_viz_dag(self.struct_dag, self.computation.dag, node_formatter, original_nodes, composite_nodes)

    def _create_dot_graph(self):
        """Convert ``viz_dag`` into a ``pydotplus`` graph with the configured attributes."""
        return to_pydot(self.viz_dag, self.graph_attr, self.node_attr, self.edge_attr)

    def refresh(self):
        """Refresh the visualization by rebuilding the graph structure."""
        node_transformations = self._initialize_transforms()
        self.struct_dag, original_nodes, composite_nodes = self.get_sub_block(
            self.computation.dag, self.root, node_transformations
        )
        self.viz_dag = self._create_visualization_dag(original_nodes, composite_nodes)
        self.viz_dot = self._create_dot_graph()

    def svg(self) -> str | None:
        """Generate SVG representation of the visualization."""
        if self.viz_dot is None:
            return None
        return self.viz_dot.create_svg().decode("utf-8")

    def view(self):
        """Write the graph to a temporary PDF and open it in the system's default viewer.

        Uses ``os.startfile`` on Windows, ``open`` on macOS and ``xdg-open``
        elsewhere. The temporary file is not deleted so the viewer can still
        read it after this method returns. Does nothing if no dot graph exists.
        """
        import subprocess

        if self.viz_dot is None:
            return
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(self.viz_dot.create_pdf())
        # Launch the viewer only after the file is closed, so the PDF is fully
        # flushed. Passing an argument list avoids shell-quoting the path.
        if sys.platform == "win32":
            os.startfile(f.name)
        elif sys.platform == "darwin":
            subprocess.run(["open", f.name], check=False)
        else:
            subprocess.run(["xdg-open", f.name], check=False)

    def _repr_svg_(self):
        """Jupyter rich-display hook returning the SVG rendering."""
        return self.svg()

397 

398 

def create_viz_dag(
    struct_dag, comp_dag, node_formatter: NodeFormatter, original_nodes: dict, composite_nodes: set
) -> nx.DiGraph:
    """Create a visualization DAG from the computation structure.

    Each node of ``struct_dag`` becomes a node named ``n0``, ``n1``, ... (in
    ``struct_dag.nodes`` order) carrying the attributes produced by
    ``node_formatter``; attributes whose value is None are dropped. Each edge
    gets a ``_group`` attribute holding the common parent of its endpoints'
    group paths, unless that common parent is the root.

    Args:
        struct_dag: Structural DAG keyed by (mapped) NodeKeys.
        comp_dag: The computation DAG, supplying each original node's attributes.
        node_formatter: Formatter for node attributes; if None, nodes get no attributes.
        original_nodes: Maps each ``struct_dag`` node to the original
            computation node keys it represents.
        composite_nodes: ``struct_dag`` nodes that stand for collapsed groups.

    Returns:
        A new ``nx.DiGraph`` suitable for :func:`to_pydot`.
    """
    # Build the Node wrappers once and reuse them for both calibration and
    # formatting (previously they were constructed twice).
    nodes_by_key: dict = {}
    if node_formatter is not None:
        nodes_by_key = {
            nodekey: [
                Node(nodekey, original_nodekey, comp_dag.nodes[original_nodekey])
                for original_nodekey in original_nodes[nodekey]
            ]
            for nodekey in struct_dag.nodes
        }
        node_formatter.calibrate([node for nodes in nodes_by_key.values() for node in nodes])

    viz_dag = nx.DiGraph()
    node_index_map = {}
    for i, nodekey in enumerate(struct_dag.nodes):
        short_name = f"n{i}"
        attr_dict = None
        if node_formatter is not None:
            is_composite = nodekey in composite_nodes
            attr_dict = node_formatter.format(nodekey, nodes_by_key[nodekey], is_composite)
        attr_dict = {k: v for k, v in (attr_dict or {}).items() if v is not None}

        viz_dag.add_node(short_name, **attr_dict)
        node_index_map[nodekey] = short_name

    for name1, name2 in struct_dag.edges():
        short_name_1 = node_index_map[name1]
        short_name_2 = node_index_map[name2]

        group_path1 = get_group_path(name1, struct_dag.nodes[name1])
        group_path2 = get_group_path(name2, struct_dag.nodes[name2])
        group_path = NodeKey.common_parent(group_path1, group_path2)

        attr_dict = {}
        if not group_path.is_root:
            attr_dict["_group"] = group_path

        viz_dag.add_edge(short_name_1, short_name_2, **attr_dict)

    return viz_dag

450 

451 

def to_pydot(viz_dag, graph_attr=None, node_attr=None, edge_attr=None) -> pydotplus.Dot:
    """Convert a visualization DAG to a PyDot graph for rendering.

    Nodes and edges are placed into Graphviz cluster subgraphs according to
    their ``_group`` attribute (a NodeKey; absent means the root graph).
    Clusters are created for every node group and all its ancestors and are
    nested by path. An edge whose group has no cluster goes into the nearest
    enclosing cluster, ultimately the root graph. Attributes whose names start
    with ``_`` are internal and are not passed to Graphviz.

    Args:
        viz_dag: DAG produced by :func:`create_viz_dag`.
        graph_attr: Graph-level Graphviz attributes.
        node_attr: Default node attributes.
        edge_attr: Default edge attributes.

    Returns:
        The root ``pydotplus.Dot`` graph.
    """
    root = NodeKey.root()

    node_groups = {}
    for name, data in viz_dag.nodes(data=True):
        group = data.get("_group", root)
        node_groups.setdefault(group, []).append(name)

    edge_groups = {}
    for name1, name2, data in viz_dag.edges(data=True):
        group = data.get("_group", root)
        edge_groups.setdefault(group, []).append((name1, name2))

    root_graph = create_root_graph(graph_attr, node_attr, edge_attr)
    subgraphs = {root: root_graph}

    for group, names in node_groups.items():
        # Compare by value, not identity: a root key supplied by a formatter need
        # not be the same object as `root`, and must not become a cluster that
        # replaces the root graph.
        c = root_graph if group.is_root else create_subgraph(group)

        for name in names:
            node = pydotplus.Node(name)
            for k, v in viz_dag.nodes[name].items():
                if not k.startswith("_"):
                    node.set(k, v)
            c.add_node(node)

        subgraphs[group] = c

    # Ensure every ancestor of a cluster exists so clusters can be nested.
    groups = list(subgraphs.keys())
    for group in groups:
        group1 = group
        while True:
            if group1.is_root:
                break
            group1 = group1.parent
            if group1 in subgraphs:
                break
            subgraphs[group1] = create_subgraph(group1)

    for group, subgraph in subgraphs.items():
        if group.is_root:
            continue
        parent = group
        while True:
            parent = parent.parent
            if parent in subgraphs or parent.is_root:
                break
        subgraphs[parent].add_subgraph(subgraph)

    for group, edges in edge_groups.items():
        # A custom formatter may not set node groups, so the edge's group may
        # have no cluster; use the closest existing enclosing one instead.
        target = group
        while not target.is_root and target not in subgraphs:
            target = target.parent
        c = subgraphs.get(target, root_graph)
        for name1, name2 in edges:
            edge = pydotplus.Edge(name1, name2)
            for k, v in viz_dag.edges[name1, name2].items():
                if not k.startswith("_"):
                    edge.set(k, v)
            c.add_edge(edge)

    return root_graph

508 

509 

def create_root_graph(graph_attr, node_attr, edge_attr):
    """Build the top-level ``pydotplus.Dot`` graph.

    ``graph_attr`` entries are set on the graph itself. ``node_attr`` and
    ``edge_attr``, when not None (even if empty), are passed as keyword
    arguments to ``set_node_defaults`` / ``set_edge_defaults``.
    """
    dot = pydotplus.Dot()
    if graph_attr is not None:
        for attr_name, attr_value in graph_attr.items():
            dot.set(attr_name, attr_value)
    for defaults, apply_defaults in ((node_attr, dot.set_node_defaults), (edge_attr, dot.set_edge_defaults)):
        if defaults is not None:
            apply_defaults(**defaults)
    return dot

521 

522 

def create_subgraph(group: NodeKey):
    """Return a Graphviz cluster subgraph for ``group``, labelled with its path.

    The ``cluster_`` name prefix is what makes Graphviz draw the subgraph as
    a boxed cluster.
    """
    group_text = str(group)
    cluster = pydotplus.Subgraph(f"cluster_{group_text}")
    cluster.obj_dict["attributes"]["label"] = group_text
    return cluster