Coverage for src / loman / visualization.py: 99%

403 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:30 +0000

1"""Visualization tools for computation graphs using Graphviz.""" 

2 

3import os 

4import subprocess 

5import sys 

6import tempfile 

7from abc import ABC, abstractmethod 

8from collections import defaultdict 

9from dataclasses import dataclass, field 

10from typing import TYPE_CHECKING, Any, ClassVar 

11 

12import matplotlib as mpl 

13import networkx as nx 

14import numpy as np 

15import pandas as pd 

16import pydotplus 

17from matplotlib.colors import Colormap 

18 

19import loman 

20 

21from .consts import NodeAttributes, NodeTransformations, States 

22from .graph_utils import contract_node 

23from .nodekey import Name, NodeKey, is_pattern, match_pattern, to_nodekey 

24 

25if TYPE_CHECKING: 

26 from .computeengine import Computation 

27 

28 

@dataclass
class Node:
    """One underlying computation node, as handed to node formatters."""

    # Key under which the node is displayed (after root-dropping / collapsing).
    nodekey: NodeKey
    # Key of the node in the full computation DAG.
    original_nodekey: NodeKey
    # The computation DAG's attribute dictionary for ``original_nodekey``.
    data: dict[str, Any]

36 

37 

class NodeFormatter(ABC):
    """Interface for objects that turn computation nodes into Graphviz attributes.

    Within this module, ``calibrate`` is called once with every node before
    ``format`` is called for each displayed node.
    """

    @abstractmethod
    def calibrate(self, nodes: list[Node]) -> None:
        """Inspect the complete set of nodes before any node is formatted."""
        pass  # pragma: no cover

    @abstractmethod
    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return Graphviz attributes for displayed node ``name``, or ``None`` to contribute nothing."""
        pass  # pragma: no cover

    @staticmethod
    def create(
        cmap: dict[States | None, str] | Colormap | None = None, colors: str = "state", shapes: str | None = None
    ) -> "CompositeNodeFormatter":
        """Assemble the standard stack of node formatters.

        Args:
            cmap: Color specification. A ``dict`` is forwarded when
                ``colors="state"`` and a matplotlib ``Colormap`` when
                ``colors="timing"``. Any other combination uses that
                formatter's default colors.
            colors: ``"state"`` or ``"timing"`` (case-insensitive).
            shapes: ``None`` or ``"type"`` (case-insensitive).

        Returns:
            A ``CompositeNodeFormatter`` applying, in order: labels, groups,
            optional type-based shapes, colors, style overrides and composite
            block shapes.

        Raises:
            ValueError: If ``shapes`` or ``colors`` is not a recognized option.
        """
        formatters: list[NodeFormatter] = [StandardLabel(), StandardGroup()]

        shape_option = shapes.lower() if isinstance(shapes, str) else shapes
        if shape_option == "type":
            formatters.append(ShapeByType())
        elif shape_option is not None:
            msg = f"{shape_option} is not a valid loman shapes parameter for visualization"
            raise ValueError(msg)

        color_option = colors.lower()
        if color_option == "state":
            formatters.append(ColorByState(cmap if isinstance(cmap, dict) else None))  # type: ignore[arg-type]
        elif color_option == "timing":
            formatters.append(ColorByTiming(cmap if isinstance(cmap, Colormap) else None))
        else:
            msg = f"{color_option} is not a valid loman colors parameter for visualization"
            raise ValueError(msg)

        # Later formatters override earlier ones when CompositeNodeFormatter merges results.
        formatters.extend([StandardStylingOverrides(), RectBlocks()])
        return CompositeNodeFormatter(formatters)

83 

84 

class ColorByState(NodeFormatter):
    """Node formatter that colors nodes based on their computation state."""

    DEFAULT_STATE_COLORS: ClassVar[dict[States | None, str]] = {
        None: "#ffffff",  # xkcd white
        States.PLACEHOLDER: "#f97306",  # xkcd orange
        States.UNINITIALIZED: "#0343df",  # xkcd blue
        States.STALE: "#ffff14",  # xkcd yellow
        States.COMPUTABLE: "#9dff00",  # xkcd bright yellow green
        States.UPTODATE: "#15b01a",  # xkcd green
        States.ERROR: "#e50000",  # xkcd red
        States.PINNED: "#bf77f6",  # xkcd light purple
    }

    def __init__(self, state_colors: dict[States | None, str] | None = None) -> None:
        """Initialize with an optional custom state color mapping.

        Args:
            state_colors: Overrides for ``DEFAULT_STATE_COLORS``. States missing
                from this mapping keep their default color, so a partial mapping
                does not fail when such a state is drawn. The ``None`` key is
                the color for nodes whose state is unknown or mixed.
        """
        self.state_colors: dict[States | None, str] = {**self.DEFAULT_STATE_COLORS, **(state_colors or {})}

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration is needed; colors depend only on each node's own state."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Fill the node with the color of its (aggregated) state.

        A single node uses its own state. For several underlying nodes, ERROR
        takes precedence over STALE, which takes precedence over a state shared
        by all of them. Otherwise, or when there are no underlying nodes, the
        ``None`` (mixed/unknown) color is used. A state with no configured color
        also falls back to the ``None`` color.
        """
        states = [node.data.get(NodeAttributes.STATE, None) for node in nodes]
        state: States | None
        if len(states) == 1:
            state = states[0]
        elif not states:
            # Nothing to aggregate: treat as unknown instead of indexing an empty list.
            state = None
        elif any(s == States.ERROR for s in states):
            state = States.ERROR
        elif any(s == States.STALE for s in states):
            state = States.STALE
        else:
            first = states[0]
            state = first if all(s == first for s in states) else None
        fallback = self.state_colors.get(None, self.DEFAULT_STATE_COLORS[None])
        return {"style": "filled", "fillcolor": self.state_colors.get(state, fallback)}

124 

125 

class ColorByTiming(NodeFormatter):
    """Node formatter that colors nodes based on their execution timing."""

    def __init__(self, cmap: Colormap | None = None) -> None:
        """Initialize with an optional colormap for timing visualization.

        Args:
            cmap: Matplotlib colormap evaluated on the normalized duration.
                Defaults to a green -> yellow -> red blend.
        """
        if cmap is None:
            cmap = mpl.colors.LinearSegmentedColormap.from_list("blend", ["#15b01a", "#ffff14", "#e50000"])
        self.cmap = cmap
        # Duration range from the most recent calibrate() call; NaN until calibrated.
        self.min_duration: float = float("nan")
        self.max_duration: float = float("nan")

    def calibrate(self, nodes: list[Node]) -> None:
        """Record the min/max duration over all nodes that carry timing data.

        The range is recomputed from scratch on every call. A formatter reused
        across refreshes therefore never normalizes against durations left over
        from a previous graph. When no node has timing data, both bounds are NaN.
        """
        durations: list[float] = []
        for node in nodes:
            timing = node.data.get(NodeAttributes.TIMING)
            if timing is not None:
                durations.append(timing.duration)
        self.min_duration = min(durations, default=float("nan"))
        self.max_duration = max(durations, default=float("nan"))

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Fill a single node with a color proportional to its duration.

        Nodes without timing data are filled white. When ``nodes`` does not hold
        exactly one node (e.g. collapsed groups), this formatter contributes nothing.
        """
        if len(nodes) != 1:
            return None
        timing_data = nodes[0].data.get(NodeAttributes.TIMING)
        if timing_data is None:
            col = "#FFFFFF"
        else:
            # A tiny floor on the range avoids dividing by zero when all durations are equal.
            span = max(1e-8, self.max_duration - self.min_duration)
            norm_duration: float = (timing_data.duration - self.min_duration) / span
            col = mpl.colors.rgb2hex(self.cmap(norm_duration))
        return {"style": "filled", "fillcolor": col}

160 return None 

161 

162 

class ShapeByType(NodeFormatter):
    """Node formatter that chooses a Graphviz shape from the type of a node's value."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration is needed; the shape depends only on each node's value."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return a shape for a single node holding a non-``None`` value.

        The checks are ordered: numpy arrays, DataFrames, numpy scalars,
        lists/tuples, dicts, nested Computations, then anything else.
        """
        if len(nodes) != 1:
            return None
        value = nodes[0].data.get(NodeAttributes.VALUE)
        if value is None:
            return None
        if isinstance(value, np.ndarray):
            return {"shape": "rect"}
        if isinstance(value, pd.DataFrame):
            return {"shape": "box3d"}
        if np.isscalar(value):
            return {"shape": "ellipse"}
        if isinstance(value, (list, tuple)):
            return {"shape": "ellipse", "peripheries": 2}
        if isinstance(value, dict):
            return {"shape": "house", "peripheries": 2}
        if isinstance(value, loman.Computation):
            return {"shape": "hexagon"}
        return {"shape": "diamond"}

192 

193 

class RectBlocks(NodeFormatter):
    """Node formatter that draws composite (collapsed) nodes as double-bordered rectangles."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration is needed for this formatter."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return a rectangle with two peripheries for composite nodes, ``None`` otherwise."""
        return {"shape": "rect", "peripheries": 2} if is_composite else None

206 

207 

class StandardLabel(NodeFormatter):
    """Node formatter that labels each node with its node key's label."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration is needed for this formatter."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return the ``label`` attribute taken from ``name.label``."""
        label = name.label
        return {"label": label}

218 

219 

def get_group_path(name: NodeKey, data: dict[str, Any]) -> NodeKey:
    """Return the group (cluster) path of a node.

    The path is the node's parent in the name hierarchy, extended by the node's
    ``NodeAttributes.GROUP`` attribute when one is set.
    """
    group_attr = data.get(NodeAttributes.GROUP)
    suffix = NodeKey((group_attr,)) if group_attr is not None else None
    return name.parent.join(suffix)

228 

229 

class StandardGroup(NodeFormatter):
    """Node formatter that assigns each node to a ``_group`` used for clustering."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration is needed for this formatter."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return ``{"_group": path}``, or ``None`` when the node sits at the root.

        A single node's group also honours its GROUP attribute. Aggregated
        nodes are grouped by their name's parent only.
        """
        group_path = get_group_path(name, nodes[0].data) if len(nodes) == 1 else name.parent
        return None if group_path.is_root else {"_group": group_path}

247 

248 

class StandardStylingOverrides(NodeFormatter):
    """Node formatter that maps a node's STYLE attribute to Graphviz size/shape attributes."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration is needed for this formatter."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return attributes for the ``"small"`` and ``"dot"`` styles of a single node.

        Aggregated nodes, nodes without a style, and unknown styles get ``None``.
        """
        if len(nodes) != 1:
            return None
        style = nodes[0].data.get(NodeAttributes.STYLE)
        if style == "small":
            return {"width": 0.3, "height": 0.2, "fontsize": 8}
        if style == "dot":
            return {"shape": "point", "width": 0.1, "peripheries": 1}
        return None

268 

269 

@dataclass
class CompositeNodeFormatter(NodeFormatter):
    """Formatter that runs a sequence of formatters and merges their results.

    When two formatters set the same attribute, the later one in
    ``formatters`` wins.
    """

    formatters: list[NodeFormatter] = field(default_factory=list)

    def calibrate(self, nodes: list[Node]) -> None:
        """Forward calibration to every contained formatter, in order."""
        for sub_formatter in self.formatters:
            sub_formatter.calibrate(nodes)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Merge the attribute dicts of all contained formatters; ``None`` results are skipped."""
        merged: dict[str, Any] = {}
        for sub_formatter in self.formatters:
            merged.update(sub_formatter.format(name, nodes, is_composite) or {})
        return merged

289 

290 

@dataclass
class GraphView:
    """A view for visualizing computation graphs as graphical diagrams.

    The view is built eagerly: ``__post_init__`` calls ``refresh``, which fills
    in ``struct_dag`` (structure after root selection, collapse and contract
    rules), ``viz_dag`` (formatted nodes under short ids) and ``viz_dot`` (the
    pydotplus graph used for rendering).
    """

    # Computation whose DAG is visualized.
    computation: "Computation"
    # Optional subtree to display; nodes outside it are dropped via NodeKey.drop_root.
    root: Name | None = None
    # Formatter used to style nodes; NodeFormatter.create() is used when None.
    node_formatter: NodeFormatter | None = None
    # User rules mapping names (or name patterns) to NodeTransformations values.
    node_transformations: dict[Name, str] | None = None
    # When True, tree descendents of ``root`` that are not DAG nodes start collapsed.
    collapse_all: bool = True

    # Default Graphviz graph/node/edge attributes passed to create_root_graph.
    graph_attr: dict[str, Any] | None = None
    node_attr: dict[str, Any] | None = None
    edge_attr: dict[str, Any] | None = None

    # Outputs populated by refresh().
    struct_dag: nx.DiGraph | None = None
    viz_dag: nx.DiGraph | None = None
    viz_dot: pydotplus.Dot | None = None

    def __post_init__(self) -> None:
        """Build the visualization immediately after dataclass construction."""
        self.refresh()

    @staticmethod
    def get_sub_block(
        dag: nx.DiGraph, root: Name | None, node_transformations: dict[NodeKey, str]
    ) -> tuple[nx.DiGraph, defaultdict[NodeKey, list[NodeKey]], set[NodeKey]]:
        """Extract the structural subgraph under ``root`` with collapse/contract rules applied.

        Args:
            dag: Full computation DAG keyed by NodeKey.
            root: Subtree to keep; node keys are made relative to it.
            node_transformations: NodeKey -> NodeTransformations value.

        Returns:
            ``(dag_out, mapped_to_original, collapsed)``. ``dag_out`` is the
            displayed structure. ``mapped_to_original`` maps each displayed key
            to the original keys it represents. ``collapsed`` is the set of
            displayed keys that stand for a collapsed group.
        """
        d_transform_to_nodes: defaultdict[str, list[NodeKey]] = defaultdict(list)
        for nk, transform in node_transformations.items():
            d_transform_to_nodes[transform].append(nk)

        dag_out: nx.DiGraph = nx.DiGraph()

        d_original_to_mapped: dict[NodeKey, NodeKey] = {}
        s_collapsed: set[NodeKey] = set()

        for nk_original in dag.nodes():
            nk_mapped = nk_original.drop_root(root)
            if nk_mapped is None:
                # Outside the requested root: not displayed.
                continue
            # Move up to the highest collapsed ancestor, if any collapse rule applies.
            nk_highest_collapse = nk_original
            is_collapsed = False
            for nk_collapse in d_transform_to_nodes[NodeTransformations.COLLAPSE]:
                if nk_highest_collapse.is_descendent_of(nk_collapse):
                    nk_highest_collapse = nk_collapse
                    is_collapsed = True
            nk_mapped = nk_highest_collapse.drop_root(root)
            if nk_mapped is None:  # pragma: no cover
                continue
            d_original_to_mapped[nk_original] = nk_mapped
            if is_collapsed:
                s_collapsed.add(nk_mapped)

        for nk_mapped in d_original_to_mapped.values():
            dag_out.add_node(nk_mapped)

        for nk_u, nk_v in dag.edges():
            nk_mapped_u = d_original_to_mapped.get(nk_u)
            nk_mapped_v = d_original_to_mapped.get(nk_v)
            # Skip edges leaving the view and edges internal to one collapsed block.
            if nk_mapped_u is None or nk_mapped_v is None or nk_mapped_u == nk_mapped_v:
                continue
            dag_out.add_edge(nk_mapped_u, nk_mapped_v)

        for nk in d_transform_to_nodes[NodeTransformations.CONTRACT]:
            nk_mapped = d_original_to_mapped.pop(nk, None)
            # A contract rule may name a node outside ``root``, a pattern-matched tree
            # group that is not a DAG node, or a member of a collapsed block that was
            # already contracted; there is nothing (left) to contract in those cases.
            if nk_mapped is None or nk_mapped not in dag_out:
                continue
            contract_node(dag_out, nk_mapped)

        d_mapped_to_original: defaultdict[NodeKey, list[NodeKey]] = defaultdict(list)
        for nk_original, nk_mapped in d_original_to_mapped.items():
            if nk_mapped in dag_out.nodes:
                d_mapped_to_original[nk_mapped].append(nk_original)

        s_collapsed.intersection_update(dag_out.nodes)

        return dag_out, d_mapped_to_original, s_collapsed

    def _initialize_transforms(self) -> dict[NodeKey, str]:
        """Build the effective node transformations: default collapses, then user rules."""
        node_transformations: dict[NodeKey, str] = {}
        if self.collapse_all:
            self._apply_default_collapse_transforms(node_transformations)
        self._apply_custom_transforms(node_transformations)
        return node_transformations

    def _apply_default_collapse_transforms(self, node_transformations: dict[NodeKey, str]) -> None:
        """Mark every tree descendent of ``root`` that is not itself a DAG node as collapsed."""
        for n in self.computation.get_tree_descendents(self.root):
            nk = to_nodekey(n)
            if not self.computation.has_node(nk):
                node_transformations[nk] = NodeTransformations.COLLAPSE

    def _apply_custom_transforms(self, node_transformations: dict[NodeKey, str]) -> dict[NodeKey, str]:
        """Apply user-specified transformation rules, in order, on top of ``node_transformations``.

        Pattern rules apply to every matching tree descendent of ``root``. EXPAND
        rules also expand the ancestors of each affected key, stopping at ``root``.
        """
        if self.node_transformations is not None:
            for rule_name, transform in self.node_transformations.items():
                include_ancestors = transform == NodeTransformations.EXPAND
                rule_nk = to_nodekey(rule_name)
                if is_pattern(rule_nk):
                    apply_nodes: set[NodeKey] = set()
                    for n in self.computation.get_tree_descendents(self.root):
                        nk = to_nodekey(n)
                        if match_pattern(rule_nk, nk):
                            apply_nodes.add(nk)
                else:
                    apply_nodes = {rule_nk}
                    node_transformations[rule_nk] = transform
                if include_ancestors:
                    for nk in apply_nodes:
                        for nk1 in nk.ancestors():
                            if nk1.is_root or nk1 == self.root:
                                break
                            node_transformations[nk1] = NodeTransformations.EXPAND
                for r_nk in apply_nodes:
                    node_transformations[r_nk] = transform
        return node_transformations

    def _create_visualization_dag(
        self, original_nodes: defaultdict[NodeKey, list[NodeKey]], composite_nodes: set[NodeKey]
    ) -> nx.DiGraph:
        """Format ``struct_dag`` into the visualization DAG using the configured formatter."""
        node_formatter = self.node_formatter
        if node_formatter is None:
            node_formatter = NodeFormatter.create()
        assert self.struct_dag is not None  # noqa: S101
        return create_viz_dag(self.struct_dag, self.computation.dag, node_formatter, original_nodes, composite_nodes)

    def _create_dot_graph(self) -> pydotplus.Dot:
        """Convert ``viz_dag`` into a pydotplus graph using the configured default attributes."""
        return to_pydot(self.viz_dag, self.graph_attr, self.node_attr, self.edge_attr)

    def refresh(self) -> None:
        """Rebuild ``struct_dag``, ``viz_dag`` and ``viz_dot`` from the current computation."""
        node_transformations = self._initialize_transforms()
        self.struct_dag, original_nodes, composite_nodes = self.get_sub_block(
            self.computation.dag, self.root, node_transformations
        )
        self.viz_dag = self._create_visualization_dag(original_nodes, composite_nodes)
        self.viz_dot = self._create_dot_graph()

    def svg(self) -> str | None:
        """Render the graph as an SVG string, or ``None`` if no dot graph has been built."""
        if self.viz_dot is None:
            return None
        svg_bytes: bytes = self.viz_dot.create_svg()  # type: ignore[attr-defined]
        return svg_bytes.decode("utf-8")

    def view(self) -> None:  # pragma: no cover
        """Render the graph to a temporary PDF and open it with the platform's default viewer.

        The temporary file is not deleted afterwards (``delete=False``) so the
        viewer can still read it after this method returns.
        """
        assert self.viz_dot is not None  # noqa: S101
        pdf_bytes = self.viz_dot.create_pdf()  # type: ignore[attr-defined]
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(pdf_bytes)
        # The viewer is launched only after the file is closed (and so fully flushed).
        if sys.platform == "win32":
            os.startfile(f.name)  # pragma: no cover # nosec B606 # noqa: S606
        elif sys.platform == "darwin":
            subprocess.run(["open", f.name], check=False)  # pragma: no cover # nosec B603 B607 # noqa: S603, S607
        else:
            # Linux / other freedesktop systems; "open" is not a file opener there.
            subprocess.run(["xdg-open", f.name], check=False)  # pragma: no cover # nosec B603 B607 # noqa: S603, S607

    def _repr_svg_(self) -> str | None:
        """Return the SVG rendering for Jupyter notebook display."""
        return self.svg()

450 

451 

def create_viz_dag(
    struct_dag: nx.DiGraph,
    comp_dag: nx.DiGraph,
    node_formatter: NodeFormatter,
    original_nodes: defaultdict[NodeKey, list[NodeKey]],
    composite_nodes: set[NodeKey],
) -> nx.DiGraph:
    """Create a visualization DAG from the computation structure.

    Args:
        struct_dag: Structural DAG whose nodes are the displayed node keys.
        comp_dag: Full computation DAG; its node attribute dicts are passed to the formatter.
        node_formatter: Formatter calibrated on all nodes and then asked for per-node
            Graphviz attributes. ``None`` is tolerated and yields unstyled nodes.
        original_nodes: For each displayed key, the computation-DAG keys it represents.
        composite_nodes: Displayed keys that stand for a collapsed group.

    Returns:
        A DiGraph whose nodes are short ids (``"n0"``, ``"n1"``, ...) carrying the
        formatter's attributes with ``None`` values removed. Each edge carries a
        ``_group`` attribute when its endpoints' common parent group is not the root.
    """
    # Wrap each displayed node's originals once. The same wrappers serve both the
    # whole-graph calibration and the per-node formatting below.
    nodes_by_key: dict[NodeKey, list[Node]] = {}
    if node_formatter is not None:
        for nodekey in struct_dag.nodes:
            nodes_by_key[nodekey] = [
                Node(nodekey, original_nodekey, comp_dag.nodes[original_nodekey])
                for original_nodekey in original_nodes[nodekey]
            ]
        node_formatter.calibrate([node for group in nodes_by_key.values() for node in group])

    viz_dag: nx.DiGraph = nx.DiGraph()
    node_index_map: dict[NodeKey, str] = {}
    for i, nodekey in enumerate(struct_dag.nodes):
        short_name = f"n{i}"
        attr_dict: dict[str, Any] | None = None
        if node_formatter is not None:
            attr_dict = node_formatter.format(nodekey, nodes_by_key[nodekey], nodekey in composite_nodes)
        # Drop attributes a formatter left as None so they are not emitted.
        clean_attrs = {k: v for k, v in (attr_dict or {}).items() if v is not None}
        viz_dag.add_node(short_name, **clean_attrs)
        node_index_map[nodekey] = short_name

    for name1, name2 in struct_dag.edges():
        short_name_1 = node_index_map[name1]
        short_name_2 = node_index_map[name2]

        group_path1 = get_group_path(name1, struct_dag.nodes[name1])
        group_path2 = get_group_path(name2, struct_dag.nodes[name2])
        group_path = NodeKey.common_parent(group_path1, group_path2)

        edge_attr_dict: dict[str, Any] = {}
        if not group_path.is_root:
            edge_attr_dict["_group"] = group_path

        viz_dag.add_edge(short_name_1, short_name_2, **edge_attr_dict)

    return viz_dag

507 

508 

def _group_nodes_and_edges(
    viz_dag: nx.DiGraph,
) -> tuple[NodeKey, dict[NodeKey, list[str]], dict[NodeKey, list[tuple[str, str]]]]:
    """Bucket visualization nodes and edges by their ``_group`` attribute.

    Items without a ``_group`` go into the bucket of ``NodeKey.root()``, which
    is also returned so callers can identify the root bucket.
    """
    root_key = NodeKey.root()

    nodes_per_group: defaultdict[NodeKey, list[str]] = defaultdict(list)
    for node_name, node_data in viz_dag.nodes(data=True):
        nodes_per_group[node_data.get("_group", root_key)].append(node_name)

    edges_per_group: defaultdict[NodeKey, list[tuple[str, str]]] = defaultdict(list)
    for src, dst, edge_data in viz_dag.edges(data=True):
        edges_per_group[edge_data.get("_group", root_key)].append((src, dst))

    return root_key, dict(nodes_per_group), dict(edges_per_group)

526 

527 

def _create_pydot_nodes(
    viz_dag: nx.DiGraph,
    node_groups: dict[NodeKey, list[str]],
    subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph],
    root: NodeKey,
) -> None:
    """Create pydotplus nodes and place them in the root graph or a cluster per group.

    Attributes whose names start with ``_`` (internal markers such as
    ``_group``) are not copied to the pydotplus node. ``subgraphs`` is updated
    in place with one entry per group.
    """
    for group, names in node_groups.items():
        # Test for the root structurally, not by identity: a formatter may supply
        # its own root-equivalent NodeKey as ``_group``. An identity test would then
        # create a cluster that replaces the root Dot.
        container = subgraphs[root] if group.is_root else create_subgraph(group)

        for name in names:
            node = pydotplus.Node(name)
            for key, value in viz_dag.nodes[name].items():
                if not key.startswith("_"):
                    node.set(key, value)
            container.add_node(node)

        subgraphs[group] = container

546 

547 

def _ensure_parent_subgraphs(subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph]) -> None:
    """Create any missing ancestor clusters so every group has a parent chain.

    From each group present at call time, walk up the hierarchy and create
    clusters until an existing entry (or the root) is reached.
    """
    for start in list(subgraphs):
        ancestor = start
        while not ancestor.is_root:
            ancestor = ancestor.parent
            if ancestor in subgraphs:
                break
            subgraphs[ancestor] = create_subgraph(ancestor)

560 

561 

def _link_subgraphs(subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph]) -> None:
    """Nest every non-root cluster inside its nearest registered ancestor.

    The walk stops at the root key if no intermediate ancestor is registered.
    """
    for group, child in subgraphs.items():
        if group.is_root:
            continue
        container_key = group.parent
        while container_key not in subgraphs and not container_key.is_root:
            container_key = container_key.parent
        subgraphs[container_key].add_subgraph(child)

573 

574 

def _add_edges_to_subgraphs(
    edge_groups: dict[NodeKey, list[tuple[str, str]]], subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph]
) -> None:
    """Add each edge to the subgraph of its group.

    Edge groups are derived from the name hierarchy (in ``create_viz_dag``)
    independently of the ``_group`` a formatter assigns to nodes. If no cluster
    exists for an edge's group (e.g. a custom formatter that does not group
    nodes), the edge goes to the nearest existing ancestor instead of raising
    ``KeyError``.
    """
    for group, edges in edge_groups.items():
        target = group
        while target not in subgraphs and not target.is_root:
            target = target.parent
        container = subgraphs[target]
        for name1, name2 in edges:
            container.add_edge(pydotplus.Edge(name1, name2))

584 

585 

def to_pydot(
    viz_dag: nx.DiGraph | None,
    graph_attr: dict[str, Any] | None = None,
    node_attr: dict[str, Any] | None = None,
    edge_attr: dict[str, Any] | None = None,
) -> pydotplus.Dot:
    """Convert a visualization DAG into a pydotplus graph with nested clusters.

    Args:
        viz_dag: Visualization DAG produced by ``create_viz_dag``; must not be ``None``.
        graph_attr: Graph-level Graphviz attributes for the root graph.
        node_attr: Default node attributes for the root graph.
        edge_attr: Default edge attributes for the root graph.

    Returns:
        The root ``pydotplus.Dot`` containing all nodes, clusters and edges.
    """
    assert viz_dag is not None  # noqa: S101
    root, node_groups, edge_groups = _group_nodes_and_edges(viz_dag)

    subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph] = {}
    subgraphs[root] = create_root_graph(graph_attr, node_attr, edge_attr)

    # Order matters: clusters must exist (and be nested) before edges are placed.
    _create_pydot_nodes(viz_dag, node_groups, subgraphs, root)
    _ensure_parent_subgraphs(subgraphs)
    _link_subgraphs(subgraphs)
    _add_edges_to_subgraphs(edge_groups, subgraphs)

    root_graph = subgraphs[root]
    assert isinstance(root_graph, pydotplus.Dot)  # noqa: S101
    return root_graph

608 

609 

def create_root_graph(
    graph_attr: dict[str, Any] | None, node_attr: dict[str, Any] | None, edge_attr: dict[str, Any] | None
) -> pydotplus.Dot:
    """Create root Graphviz graph with specified attributes.

    Notes:
        Graphviz attributes like size expect a quoted string when containing
        commas (e.g. "10,8"). Some pydotplus setters don't auto-quote, which
        can produce a DOT syntax error near ',' if we pass a raw string.
        We defensively double-quote string values that contain commas or any
        whitespace, escaping embedded double quotes. DOT has no single-quoted
        strings, so a value wrapped in single quotes is re-quoted with double
        quotes.
    """

    def _quote(text: str) -> str:
        """Wrap ``text`` in DOT double quotes, escaping embedded double quotes."""
        escaped = text.replace('"', '\\"')
        return f'"{escaped}"'

    def _normalize_attr_value(v: Any) -> Any:
        """Normalize attribute values for Graphviz, quoting strings as needed."""
        # Keep numeric values as-is
        if isinstance(v, (int, float)):
            return v
        s = str(v)
        # Already a DOT double-quoted string: keep
        if len(s) >= 2 and s[0] == '"' and s[-1] == '"':
            return s
        # Single quotes are not DOT syntax; convert to the intended double-quoted string
        if len(s) >= 2 and s[0] == "'" and s[-1] == "'":
            return _quote(s[1:-1])
        # Quote if empty, or contains a comma, any whitespace, or a double quote
        if s == "" or "," in s or '"' in s or any(c.isspace() for c in s):
            return _quote(s)
        return s

    root_graph = pydotplus.Dot()
    if graph_attr is not None:
        for k, v in graph_attr.items():
            root_graph.set(k, _normalize_attr_value(v))
    if node_attr is not None:
        # For node/edge defaults, normalize each value too
        node_defaults = {k: _normalize_attr_value(v) for k, v in node_attr.items()}
        root_graph.set_node_defaults(**node_defaults)
    if edge_attr is not None:
        edge_defaults = {k: _normalize_attr_value(v) for k, v in edge_attr.items()}
        root_graph.set_edge_defaults(**edge_defaults)
    return root_graph

648 

649 

def create_subgraph(group: NodeKey) -> pydotplus.Subgraph:
    """Return a Graphviz cluster for ``group``, labelled with the group's string form.

    The ``cluster_`` name prefix is what makes Graphviz draw the subgraph as a box.
    """
    group_name = str(group)
    cluster = pydotplus.Subgraph(f"cluster_{group_name}")
    cluster.obj_dict["attributes"]["label"] = group_name
    return cluster