Coverage for src / loman / visualization.py: 99%
403 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:30 +0000
1"""Visualization tools for computation graphs using Graphviz."""
3import os
4import subprocess
5import sys
6import tempfile
7from abc import ABC, abstractmethod
8from collections import defaultdict
9from dataclasses import dataclass, field
10from typing import TYPE_CHECKING, Any, ClassVar
12import matplotlib as mpl
13import networkx as nx
14import numpy as np
15import pandas as pd
16import pydotplus
17from matplotlib.colors import Colormap
19import loman
21from .consts import NodeAttributes, NodeTransformations, States
22from .graph_utils import contract_node
23from .nodekey import Name, NodeKey, is_pattern, match_pattern, to_nodekey
25if TYPE_CHECKING:
26 from .computeengine import Computation
@dataclass
class Node:
    """One entry handed to node formatters.

    Couples the key under which a node is displayed (after any collapsing) with
    the key it has in the computation's own DAG, plus that node's attribute dict.
    """

    nodekey: NodeKey  # key in the (possibly collapsed) visualization structure
    original_nodekey: NodeKey  # key of the underlying node in the computation DAG
    data: dict[str, Any]  # attribute dict of the underlying computation node
class NodeFormatter(ABC):
    """Interface for objects that decide how nodes are drawn.

    A formatter is first shown every node via :meth:`calibrate`, then asked for
    Graphviz attributes one displayed node at a time via :meth:`format`.
    """

    @abstractmethod
    def calibrate(self, nodes: list[Node]) -> None:
        """Inspect the complete set of nodes before any formatting happens."""
        pass  # pragma: no cover

    @abstractmethod
    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return Graphviz attributes for one displayed node, or ``None`` for no opinion."""
        pass  # pragma: no cover

    @staticmethod
    def create(
        cmap: dict[States | None, str] | Colormap | None = None, colors: str = "state", shapes: str | None = None
    ) -> "CompositeNodeFormatter":
        """Assemble the standard stack of formatters.

        Args:
            cmap: A ``{state: color}`` dict (used when ``colors == "state"``) or a
                matplotlib ``Colormap`` (used when ``colors == "timing"``). A value
                of the other kind is ignored and the formatter's default applies.
            colors: ``"state"`` or ``"timing"``, case-insensitive.
            shapes: ``None`` or ``"type"``, case-insensitive.

        Returns:
            A ``CompositeNodeFormatter`` that applies, in order: labels, grouping,
            optional shapes, coloring, style overrides, and composite-block shapes.

        Raises:
            ValueError: If ``shapes`` or ``colors`` is not a recognised option.
        """
        stack: list[NodeFormatter] = [StandardLabel(), StandardGroup()]

        shape_mode = shapes.lower() if isinstance(shapes, str) else shapes
        if shape_mode == "type":
            stack.append(ShapeByType())
        elif shape_mode is not None:
            msg = f"{shape_mode} is not a valid loman shapes parameter for visualization"
            raise ValueError(msg)

        color_mode = colors.lower()
        if color_mode == "state":
            stack.append(ColorByState(cmap if isinstance(cmap, dict) else None))  # type: ignore[arg-type]
        elif color_mode == "timing":
            stack.append(ColorByTiming(cmap if isinstance(cmap, Colormap) else None))
        else:
            msg = f"{color_mode} is not a valid loman colors parameter for visualization"
            raise ValueError(msg)

        stack.extend([StandardStylingOverrides(), RectBlocks()])
        return CompositeNodeFormatter(stack)
class ColorByState(NodeFormatter):
    """Node formatter that fills nodes with a color chosen by computation state."""

    DEFAULT_STATE_COLORS: ClassVar[dict[States | None, str]] = {
        None: "#ffffff",  # xkcd white
        States.PLACEHOLDER: "#f97306",  # xkcd orange
        States.UNINITIALIZED: "#0343df",  # xkcd blue
        States.STALE: "#ffff14",  # xkcd yellow
        States.COMPUTABLE: "#9dff00",  # xkcd bright yellow green
        States.UPTODATE: "#15b01a",  # xkcd green
        States.ERROR: "#e50000",  # xkcd red
        States.PINNED: "#bf77f6",  # xkcd light purple
    }

    def __init__(self, state_colors: dict[States | None, str] | None = None) -> None:
        """Initialize the state-to-color mapping.

        Args:
            state_colors: Optional color overrides keyed by state. The ``None`` key
                is the color used when no single state applies. Any state missing
                from this mapping falls back to ``DEFAULT_STATE_COLORS``, so a
                partial mapping no longer raises ``KeyError`` during formatting.
                The mapping is copied, not aliased.
        """
        merged = self.DEFAULT_STATE_COLORS.copy()
        if state_colors is not None:
            merged.update(state_colors)
        self.state_colors = merged

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration needed; colors depend only on each node's own state."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Fill the node with the color of its (aggregate) state.

        A single underlying node uses its own ``NodeAttributes.STATE`` (``None``
        when absent). For several underlying nodes, ERROR takes precedence over
        STALE. Otherwise the shared state is used if all nodes agree, and ``None``
        if they do not. An entry with no underlying nodes is treated as ``None``.
        """
        states = [node.data.get(NodeAttributes.STATE, None) for node in nodes]
        state: States | None
        if len(states) == 1:
            state = states[0]
        elif not states:
            # Previously raised IndexError on states[0].
            state = None
        elif any(s == States.ERROR for s in states):
            state = States.ERROR
        elif any(s == States.STALE for s in states):
            state = States.STALE
        else:
            first = states[0]
            state = first if all(s == first for s in states) else None
        fillcolor = self.state_colors.get(state)
        if fillcolor is None:
            # Unknown state: use the "no single state" color instead of raising.
            fillcolor = self.state_colors.get(None, self.DEFAULT_STATE_COLORS[None])
        return {"style": "filled", "fillcolor": fillcolor}
class ColorByTiming(NodeFormatter):
    """Node formatter that fills nodes with a color reflecting compute duration."""

    def __init__(self, cmap: Colormap | None = None) -> None:
        """Store the colormap; the default blends green, yellow and red."""
        self.cmap = (
            cmap
            if cmap is not None
            else mpl.colors.LinearSegmentedColormap.from_list("blend", ["#15b01a", "#ffff14", "#e50000"])
        )
        # NaN until calibrate() sees at least one node carrying timing data.
        self.min_duration: float = float("nan")
        self.max_duration: float = float("nan")

    def calibrate(self, nodes: list[Node]) -> None:
        """Record the shortest and longest durations among nodes with timing data."""
        timings = (node.data.get(NodeAttributes.TIMING) for node in nodes)
        durations: list[float] = [t.duration for t in timings if t is not None]
        if not durations:
            return
        self.max_duration = max(durations)
        self.min_duration = min(durations)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Color a single node by its duration relative to the calibrated range.

        Nodes without timing data are filled white. Entries that stand for zero or
        several underlying nodes get no color (``None``).
        """
        if len(nodes) != 1:
            return None
        timing_data = nodes[0].data.get(NodeAttributes.TIMING)
        if timing_data is None:
            return {"style": "filled", "fillcolor": "#FFFFFF"}
        # Guard against a zero-width range when all durations are equal.
        span = max(1e-8, self.max_duration - self.min_duration)
        fraction: float = (timing_data.duration - self.min_duration) / span
        return {"style": "filled", "fillcolor": mpl.colors.rgb2hex(self.cmap(fraction))}
class ShapeByType(NodeFormatter):
    """Node formatter that picks a Graphviz shape from the type of a node's value."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration needed; shapes depend only on each node's value."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return a shape for a single node with a value, else ``None``.

        The checks are ordered: numpy arrays, DataFrames, scalars (``np.isscalar``),
        lists/tuples, dicts, nested loman Computations, then anything else.
        """
        if len(nodes) != 1:
            return None
        value = nodes[0].data.get(NodeAttributes.VALUE)
        if value is None:
            return None
        if isinstance(value, np.ndarray):
            return {"shape": "rect"}
        if isinstance(value, pd.DataFrame):
            return {"shape": "box3d"}
        if np.isscalar(value):
            return {"shape": "ellipse"}
        if isinstance(value, (list, tuple)):
            return {"shape": "ellipse", "peripheries": 2}
        if isinstance(value, dict):
            return {"shape": "house", "peripheries": 2}
        if isinstance(value, loman.Computation):
            return {"shape": "hexagon"}
        return {"shape": "diamond"}
class RectBlocks(NodeFormatter):
    """Node formatter that draws composite (collapsed) entries as double-bordered boxes."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration needed; output depends only on ``is_composite``."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return a double-bordered rectangle for composites, otherwise ``None``."""
        return {"shape": "rect", "peripheries": 2} if is_composite else None
class StandardLabel(NodeFormatter):
    """Node formatter that labels each node with its key's ``label``."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration needed; the label comes straight from the key."""

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return the Graphviz ``label`` attribute for ``name``."""
        label = name.label
        return {"label": label}
def get_group_path(name: NodeKey, data: dict[str, Any]) -> NodeKey:
    """Return the cluster path a node should be drawn in.

    This is the node's parent in the name hierarchy, joined (via
    ``NodeKey.join``) with a one-element key built from the node's
    ``NodeAttributes.GROUP`` attribute. When that attribute is absent,
    ``None`` is passed to ``join``.
    """
    parent_path = name.parent
    group_attr = data.get(NodeAttributes.GROUP)
    extra = NodeKey((group_attr,)) if group_attr is not None else None
    return parent_path.join(extra)
class StandardGroup(NodeFormatter):
    """Node formatter that records which cluster a node belongs in (``_group``)."""

    def calibrate(self, nodes: list[Node]) -> None:
        """No calibration needed; grouping is decided per node."""
        pass

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Return ``{"_group": path}``, or ``None`` when the node sits at the root.

        A single underlying node also honours its ``GROUP`` attribute (see
        ``get_group_path``). Entries for several nodes use the name's parent only.
        """
        group_path = get_group_path(name, nodes[0].data) if len(nodes) == 1 else name.parent
        return None if group_path.is_root else {"_group": group_path}
249class StandardStylingOverrides(NodeFormatter):
250 """Node formatter that applies standard styling overrides."""
252 def calibrate(self, nodes: list[Node]) -> None:
253 """Calibrate formatter based on all nodes in the graph."""
254 pass
256 def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
257 """Format a node with standard styling overrides."""
258 if len(nodes) == 1:
259 data = nodes[0].data
260 style = data.get(NodeAttributes.STYLE)
261 if style is None:
262 return None
263 if style == "small":
264 return {"width": 0.3, "height": 0.2, "fontsize": 8}
265 elif style == "dot":
266 return {"shape": "point", "width": 0.1, "peripheries": 1}
267 return None
@dataclass
class CompositeNodeFormatter(NodeFormatter):
    """Formatter that chains several formatters; later ones override earlier keys."""

    formatters: list[NodeFormatter] = field(default_factory=list)

    def calibrate(self, nodes: list[Node]) -> None:
        """Forward the full node list to every child formatter, in order."""
        for child in self.formatters:
            child.calibrate(nodes)

    def format(self, name: NodeKey, nodes: list[Node], is_composite: bool) -> dict[str, Any] | None:
        """Merge each child's attributes in order, skipping children that return ``None``."""
        merged: dict[str, Any] = {}
        for child in self.formatters:
            child_attrs = child.format(name, nodes, is_composite)
            if child_attrs is None:
                continue
            merged.update(child_attrs)
        return merged
@dataclass
class GraphView:
    """Graphviz view of a computation's dependency graph (or part of it).

    Construction calls :meth:`refresh`, which fills in three derived fields:

    * ``struct_dag`` -- the computation DAG restricted to ``root``, with
      collapse / contract transformations applied (nodes are ``NodeKey``s);
    * ``viz_dag`` -- the same structure with short node ids (``n0``, ``n1``, ...)
      carrying the attributes produced by ``node_formatter``;
    * ``viz_dot`` -- the ``pydotplus.Dot`` built from ``viz_dag``.
    """

    computation: "Computation"
    root: Name | None = None
    node_formatter: NodeFormatter | None = None  # None -> NodeFormatter.create()
    node_transformations: dict[Name, str] | None = None  # user rules: name or pattern -> transform
    collapse_all: bool = True  # collapse tree-only (non-computation) nodes under root by default

    graph_attr: dict[str, Any] | None = None
    node_attr: dict[str, Any] | None = None
    edge_attr: dict[str, Any] | None = None

    struct_dag: nx.DiGraph | None = None
    viz_dag: nx.DiGraph | None = None
    viz_dot: pydotplus.Dot | None = None

    def __post_init__(self) -> None:
        """Build all derived graphs as soon as the view is constructed."""
        self.refresh()

    @staticmethod
    def get_sub_block(
        dag: nx.DiGraph, root: Name | None, node_transformations: dict[NodeKey, str]
    ) -> tuple[nx.DiGraph, defaultdict[NodeKey, list[NodeKey]], set[NodeKey]]:
        """Restrict ``dag`` to ``root`` and apply collapse/contract transformations.

        Args:
            dag: The computation DAG, keyed by ``NodeKey``.
            root: Namespace to keep. Nodes for which ``drop_root(root)`` returns
                ``None`` are dropped; the rest are re-keyed by ``drop_root``.
            node_transformations: ``NodeKey -> transform`` rules. Only
                ``NodeTransformations.COLLAPSE`` and ``NodeTransformations.CONTRACT``
                are acted on here.

        Returns:
            ``(dag_out, mapped_to_original, collapsed)``:

            * ``dag_out`` is the transformed graph.
            * ``mapped_to_original`` lists, for each node of ``dag_out``, the
              original keys it stands for.
            * ``collapsed`` is the set of ``dag_out`` nodes that represent a
              collapsed block.
        """
        transform_to_nodes: defaultdict[str, list[NodeKey]] = defaultdict(list)
        for nk, transform in node_transformations.items():
            transform_to_nodes[transform].append(nk)

        dag_out: nx.DiGraph = nx.DiGraph()
        original_to_mapped: dict[NodeKey, NodeKey] = {}
        collapsed: set[NodeKey] = set()

        for nk_original in dag.nodes():
            if nk_original.drop_root(root) is None:
                continue  # outside the requested root
            # Map the node onto the outermost collapsed ancestor, if any.
            nk_highest_collapse = nk_original
            is_collapsed = False
            for nk_collapse in transform_to_nodes[NodeTransformations.COLLAPSE]:
                if nk_highest_collapse.is_descendent_of(nk_collapse):
                    nk_highest_collapse = nk_collapse
                    is_collapsed = True
            nk_mapped = nk_highest_collapse.drop_root(root)
            if nk_mapped is None:  # pragma: no cover
                continue
            original_to_mapped[nk_original] = nk_mapped
            if is_collapsed:
                collapsed.add(nk_mapped)

        for nk_mapped in original_to_mapped.values():
            dag_out.add_node(nk_mapped)

        for nk_u, nk_v in dag.edges():
            mapped_u = original_to_mapped.get(nk_u)
            mapped_v = original_to_mapped.get(nk_v)
            # Skip edges touching dropped nodes and edges internal to one mapped node.
            if mapped_u is None or mapped_v is None or mapped_u == mapped_v:
                continue
            dag_out.add_edge(mapped_u, mapped_v)

        for nk in transform_to_nodes[NodeTransformations.CONTRACT]:
            # A contract rule can name a node that was never mapped (e.g. outside
            # ``root``), or one whose mapped node is no longer in the graph. Direct
            # indexing used to raise KeyError for the former; skip both cases.
            nk_mapped = original_to_mapped.pop(nk, None)
            if nk_mapped is None or nk_mapped not in dag_out:
                continue
            contract_node(dag_out, nk_mapped)

        mapped_to_original: defaultdict[NodeKey, list[NodeKey]] = defaultdict(list)
        for nk_original, nk_mapped in original_to_mapped.items():
            if nk_mapped in dag_out.nodes:
                mapped_to_original[nk_mapped].append(nk_original)

        collapsed.intersection_update(dag_out.nodes)

        return dag_out, mapped_to_original, collapsed

    def _initialize_transforms(self) -> dict[NodeKey, str]:
        """Build the effective ``NodeKey -> transform`` rules for this view."""
        node_transformations: dict[NodeKey, str] = {}
        if self.collapse_all:
            self._apply_default_collapse_transforms(node_transformations)
        self._apply_custom_transforms(node_transformations)
        return node_transformations

    def _apply_default_collapse_transforms(self, node_transformations: dict[NodeKey, str]) -> None:
        """Mark every tree descendant of ``root`` that is not a computation node as COLLAPSE."""
        for n in self.computation.get_tree_descendents(self.root):
            nk = to_nodekey(n)
            if not self.computation.has_node(nk):
                node_transformations[nk] = NodeTransformations.COLLAPSE

    def _apply_custom_transforms(self, node_transformations: dict[NodeKey, str]) -> dict[NodeKey, str]:
        """Apply the user's ``node_transformations`` rules on top of the defaults.

        Pattern rules are matched against the tree descendants of ``root``. For
        EXPAND rules, the ancestors of each matched node (up to, but not
        including, the view root) are also marked EXPAND so the node is visible.
        """
        if self.node_transformations is not None:
            # Compare ancestors with the view root in NodeKey form; ``self.root`` may
            # be a plain Name rather than a NodeKey.
            root_nk = None if self.root is None else to_nodekey(self.root)
            for rule_name, transform in self.node_transformations.items():
                include_ancestors = transform == NodeTransformations.EXPAND
                rule_nk = to_nodekey(rule_name)
                if is_pattern(rule_nk):
                    apply_nodes: set[NodeKey] = set()
                    for n in self.computation.get_tree_descendents(self.root):
                        nk = to_nodekey(n)
                        if match_pattern(rule_nk, nk):
                            apply_nodes.add(nk)
                else:
                    apply_nodes = {rule_nk}
                    node_transformations[rule_nk] = transform
                if include_ancestors:
                    for nk in apply_nodes:
                        for nk1 in nk.ancestors():
                            if nk1.is_root or nk1 == root_nk:
                                break
                            node_transformations[nk1] = NodeTransformations.EXPAND
                for r_nk in apply_nodes:
                    node_transformations[r_nk] = transform
        return node_transformations

    def _create_visualization_dag(
        self, original_nodes: defaultdict[NodeKey, list[NodeKey]], composite_nodes: set[NodeKey]
    ) -> nx.DiGraph:
        """Format ``struct_dag`` into the attribute-carrying visualization DAG."""
        node_formatter = self.node_formatter
        if node_formatter is None:
            node_formatter = NodeFormatter.create()
        assert self.struct_dag is not None  # noqa: S101
        return create_viz_dag(self.struct_dag, self.computation.dag, node_formatter, original_nodes, composite_nodes)

    def _create_dot_graph(self) -> pydotplus.Dot:
        """Convert ``viz_dag`` into a pydotplus graph using the view's Graphviz attributes."""
        return to_pydot(self.viz_dag, self.graph_attr, self.node_attr, self.edge_attr)

    def refresh(self) -> None:
        """Rebuild ``struct_dag``, ``viz_dag`` and ``viz_dot`` from the computation."""
        node_transformations = self._initialize_transforms()
        self.struct_dag, original_nodes, composite_nodes = self.get_sub_block(
            self.computation.dag, self.root, node_transformations
        )
        self.viz_dag = self._create_visualization_dag(original_nodes, composite_nodes)
        self.viz_dot = self._create_dot_graph()

    def svg(self) -> str | None:
        """Return the rendered graph as an SVG string, or ``None`` if not yet built."""
        if self.viz_dot is None:
            return None
        svg_bytes: bytes = self.viz_dot.create_svg()  # type: ignore[attr-defined]
        return svg_bytes.decode("utf-8")

    def view(self) -> None:  # pragma: no cover
        """Render the graph to a temporary PDF and open it with the platform's viewer.

        The PDF is fully written and the file closed before the viewer starts, so
        the viewer never reads a partially flushed file. The temporary file is not
        deleted, because the viewer may still be reading it.
        """
        assert self.viz_dot is not None  # noqa: S101
        pdf_bytes = self.viz_dot.create_pdf()  # type: ignore[attr-defined]
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
            f.write(pdf_bytes)
            path = f.name
        if sys.platform == "win32":
            os.startfile(path)  # nosec B606 # noqa: S606
        elif sys.platform == "darwin":
            subprocess.run(["open", path], check=False)  # nosec B603 B607 # noqa: S603, S607
        else:
            # `open` is the macOS launcher; freedesktop systems (Linux, BSD) use xdg-open.
            subprocess.run(["xdg-open", path], check=False)  # nosec B603 B607 # noqa: S603, S607

    def _repr_svg_(self) -> str | None:
        """Return the SVG rendering for Jupyter rich display."""
        return self.svg()
def create_viz_dag(
    struct_dag: nx.DiGraph,
    comp_dag: nx.DiGraph,
    node_formatter: NodeFormatter,
    original_nodes: defaultdict[NodeKey, list[NodeKey]],
    composite_nodes: set[NodeKey],
) -> nx.DiGraph:
    """Build the visualization DAG: short node ids plus formatter-produced attributes.

    Args:
        struct_dag: Transformed structure graph keyed by ``NodeKey``.
        comp_dag: The computation's DAG, used to look up node attribute data.
        node_formatter: Formatter producing Graphviz attributes per displayed node.
        original_nodes: For each ``struct_dag`` node, the computation keys it stands for.
        composite_nodes: ``struct_dag`` nodes that represent collapsed blocks.

    Returns:
        A ``DiGraph`` whose nodes are named ``n0``, ``n1``, ... (in ``struct_dag``
        node order) and carry the non-``None`` formatter attributes. Edges carry
        a ``_group`` attribute unless their common group is the root.
    """
    # Build each displayed node's list of underlying Nodes once. The same lists are
    # passed to calibrate() and to format(); previously they were built twice.
    # .get() avoids inserting keys into the caller's defaultdict.
    nodes_by_key: dict[NodeKey, list[Node]] = {}
    if node_formatter is not None:
        for nodekey in struct_dag.nodes:
            nodes_by_key[nodekey] = [
                Node(nodekey, original_nodekey, comp_dag.nodes[original_nodekey])
                for original_nodekey in original_nodes.get(nodekey, [])
            ]
        node_formatter.calibrate([node for group in nodes_by_key.values() for node in group])

    viz_dag: nx.DiGraph = nx.DiGraph()
    node_index_map: dict[NodeKey, str] = {}
    for i, nodekey in enumerate(struct_dag.nodes):
        short_name = f"n{i}"
        attr_dict: dict[str, Any] | None = None
        if node_formatter is not None:
            attr_dict = node_formatter.format(nodekey, nodes_by_key[nodekey], nodekey in composite_nodes)
        if attr_dict is None:
            attr_dict = {}
        # Graphviz has no representation for None; drop such attributes.
        attr_dict = {k: v for k, v in attr_dict.items() if v is not None}
        viz_dag.add_node(short_name, **attr_dict)
        node_index_map[nodekey] = short_name

    for name1, name2 in struct_dag.edges():
        # Place the edge in the innermost group containing both endpoints.
        # NOTE(review): this uses struct_dag's node data; if those nodes carry no
        # GROUP attribute, this reduces to the endpoints' namespace parents.
        group_path1 = get_group_path(name1, struct_dag.nodes[name1])
        group_path2 = get_group_path(name2, struct_dag.nodes[name2])
        group_path = NodeKey.common_parent(group_path1, group_path2)

        edge_attr_dict: dict[str, Any] = {}
        if not group_path.is_root:
            edge_attr_dict["_group"] = group_path

        viz_dag.add_edge(node_index_map[name1], node_index_map[name2], **edge_attr_dict)

    return viz_dag
def _group_nodes_and_edges(
    viz_dag: nx.DiGraph,
) -> tuple[NodeKey, dict[NodeKey, list[str]], dict[NodeKey, list[tuple[str, str]]]]:
    """Bucket the nodes and edges of ``viz_dag`` by their ``_group`` attribute.

    Items without a ``_group`` attribute are filed under the root key.

    Returns:
        ``(root, node_groups, edge_groups)``: the root key, a mapping from group
        to node ids, and a mapping from group to ``(source, target)`` id pairs.
    """
    root = NodeKey.root()
    node_groups: dict[NodeKey, list[str]] = {}
    edge_groups: dict[NodeKey, list[tuple[str, str]]] = {}

    for node_id, attrs in viz_dag.nodes(data=True):
        bucket = node_groups.setdefault(attrs.get("_group", root), [])
        bucket.append(node_id)

    for src, dst, attrs in viz_dag.edges(data=True):
        bucket_e = edge_groups.setdefault(attrs.get("_group", root), [])
        bucket_e.append((src, dst))

    return root, node_groups, edge_groups
def _create_pydot_nodes(
    viz_dag: nx.DiGraph,
    node_groups: dict[NodeKey, list[str]],
    subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph],
    root: NodeKey,
) -> None:
    """Add a pydotplus node for every viz node to its group's graph.

    A group that already has a graph in ``subgraphs`` (for example the root
    graph) reuses it. Otherwise a new cluster is created with
    ``create_subgraph`` and registered. Attributes whose names start with ``_``
    are internal and are not passed to Graphviz.
    """
    for group, names in node_groups.items():
        # Look groups up by key equality. The old ``group is root`` identity test
        # turned an equal-but-distinct root key into a new cluster, which then
        # overwrote the root Dot in ``subgraphs``.
        container = subgraphs.get(group)
        if container is None:
            container = create_subgraph(group)

        for name in names:
            node = pydotplus.Node(name)
            for k, v in viz_dag.nodes[name].items():
                if not k.startswith("_"):
                    node.set(k, v)
            container.add_node(node)

        subgraphs[group] = container
def _ensure_parent_subgraphs(subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph]) -> None:
    """Create a cluster for every missing ancestor of each existing group.

    Walks up from each group registered at call time, adding clusters until
    it reaches an ancestor that is already registered or reaches the root.
    """
    for group in list(subgraphs):
        ancestor = group
        while not ancestor.is_root:
            ancestor = ancestor.parent
            if ancestor in subgraphs:
                break
            subgraphs[ancestor] = create_subgraph(ancestor)
def _link_subgraphs(subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph]) -> None:
    """Nest every non-root cluster inside its nearest registered ancestor.

    If no ancestor below the root is registered, the root key's graph is used.
    """
    for group, subgraph in subgraphs.items():
        if group.is_root:
            continue
        container_key = group.parent
        while container_key not in subgraphs and not container_key.is_root:
            container_key = container_key.parent
        subgraphs[container_key].add_subgraph(subgraph)
def _add_edges_to_subgraphs(
    edge_groups: dict[NodeKey, list[tuple[str, str]]], subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph]
) -> None:
    """Add each edge to the graph of its group, or of the nearest existing ancestor.

    Edge groups are computed from the endpoints' keys (see ``create_viz_dag``),
    while clusters exist only for groups that some node was assigned to via
    ``_group``. With a formatter that assigns no groups, an edge's group may
    have no cluster; previously that raised ``KeyError``. The edge is then placed
    in the closest enclosing graph that does exist (ultimately the root).
    """
    for group, edges in edge_groups.items():
        target = group
        while target not in subgraphs and not target.is_root:
            target = target.parent
        container = subgraphs[target]
        for name1, name2 in edges:
            container.add_edge(pydotplus.Edge(name1, name2))
def to_pydot(
    viz_dag: nx.DiGraph | None,
    graph_attr: dict[str, Any] | None = None,
    node_attr: dict[str, Any] | None = None,
    edge_attr: dict[str, Any] | None = None,
) -> pydotplus.Dot:
    """Render a visualization DAG as a pydotplus ``Dot`` graph with nested clusters.

    Args:
        viz_dag: DAG produced by ``create_viz_dag``; must not be ``None``.
        graph_attr: Graph-level Graphviz attributes.
        node_attr: Default node attributes.
        edge_attr: Default edge attributes.
    """
    assert viz_dag is not None  # noqa: S101
    root, node_groups, edge_groups = _group_nodes_and_edges(viz_dag)

    # One graph object per group; the root group's graph is the Dot itself.
    subgraphs: dict[NodeKey, pydotplus.Dot | pydotplus.Subgraph] = {}
    subgraphs[root] = create_root_graph(graph_attr, node_attr, edge_attr)

    _create_pydot_nodes(viz_dag, node_groups, subgraphs, root)  # nodes, plus clusters as needed
    _ensure_parent_subgraphs(subgraphs)  # fill in missing intermediate clusters
    _link_subgraphs(subgraphs)  # nest each cluster in its parent
    _add_edges_to_subgraphs(edge_groups, subgraphs)

    dot = subgraphs[root]
    assert isinstance(dot, pydotplus.Dot)  # noqa: S101
    return dot
def create_root_graph(
    graph_attr: dict[str, Any] | None, node_attr: dict[str, Any] | None, edge_attr: dict[str, Any] | None
) -> pydotplus.Dot:
    """Create the root Graphviz graph with the given graph/node/edge attributes.

    Args:
        graph_attr: Attributes set on the graph itself.
        node_attr: Default attributes for every node.
        edge_attr: Default attributes for every edge.

    Notes:
        Graphviz attributes like size expect a quoted string when containing
        commas (e.g. "10,8"). Some pydotplus setters don't auto-quote, which
        can produce a DOT syntax error near ',' if we pass a raw string.
        We therefore quote string values that contain commas or whitespace.
        Embedded double quotes are escaped as ``\\"`` so the quoted string stays
        valid DOT. HTML-like values (``<...>``) are passed through untouched,
        because quoting them would turn an HTML label into a plain string.
    """

    def _normalize_attr_value(v: Any) -> Any:
        """Normalize attribute values for Graphviz, quoting strings as needed."""
        # Keep numeric values as-is
        if isinstance(v, (int, float)):
            return v
        s = str(v)
        # Already quoted: keep as-is.
        if len(s) >= 2 and ((s[0] == '"' and s[-1] == '"') or (s[0] == "'" and s[-1] == "'")):
            return s
        # HTML-like string (DOT `<...>`): must not be wrapped in quotes.
        if len(s) >= 2 and s[0] == "<" and s[-1] == ">":
            return s
        # Quote if it contains a comma or whitespace, or is empty.
        if any(c in s for c in [",", " ", "\t", "\n"]) or s == "":
            # Normalize any pre-escaped quotes first so they are not double-escaped.
            escaped = s.replace('\\"', '"').replace('"', '\\"')
            return f'"{escaped}"'
        return s

    root_graph = pydotplus.Dot()
    if graph_attr is not None:
        for k, v in graph_attr.items():
            root_graph.set(k, _normalize_attr_value(v))
    if node_attr is not None:
        # For node/edge defaults, normalize each value too
        node_defaults = {k: _normalize_attr_value(v) for k, v in node_attr.items()}
        root_graph.set_node_defaults(**node_defaults)
    if edge_attr is not None:
        edge_defaults = {k: _normalize_attr_value(v) for k, v in edge_attr.items()}
        root_graph.set_edge_defaults(**edge_defaults)
    return root_graph
def create_subgraph(group: NodeKey) -> pydotplus.Subgraph:
    """Build a Graphviz cluster for ``group``, labelled with the group's string form.

    The ``cluster_`` name prefix is what makes Graphviz draw the subgraph as a
    boxed cluster.
    """
    group_text = str(group)
    cluster = pydotplus.Subgraph("cluster_" + group_text)
    cluster.obj_dict["attributes"]["label"] = group_text
    return cluster