Coverage for src / loman / nodekey.py: 100%
165 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:30 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:30 +0000
1"""Node key implementation for computation graph navigation."""
3import json
4import re
5from collections.abc import Iterable
6from dataclasses import dataclass
7from typing import Any, ClassVar, Optional, Union
9from loman.util import as_iterable
# Anything accepted where a node is identified: a "/"-separated path string,
# a NodeKey, or any other object (to_nodekey wraps such an object as a
# single-part key).
Name = Union[str, "NodeKey", object]
# A list of node identifiers.
Names = list[Name]
class PathNotFoundError(Exception):
    """Signal that a requested node path does not exist.

    In this module it is raised by ``NodeKey.parent`` when asked for the
    parent of the root key.
    """
def quote_part(part: object) -> str:
    """Render one node-key part as a path segment that parse_nodekey can read back.

    Non-string parts are rendered with ``str()``. A string part is emitted
    verbatim unless the parser would misread it. That happens when it contains
    ``/`` (the path separator) or starts with ``"`` (which the parser treats as
    the start of a JSON-quoted segment). In those cases the part is JSON-encoded.

    NOTE(review): empty string parts are still emitted verbatim, so e.g.
    ``("", "a")`` renders as ``"/a"``. The parser skips leading slashes and
    reads that back as ``("a",)``.

    Args:
        part: A single element of ``NodeKey.parts``.

    Returns:
        The segment text for this part.
    """
    if not isinstance(part, str):
        return str(part)
    # A leading '"' must be quoted too; otherwise the parser would try to
    # decode the raw text as a JSON string and fail or mis-split it.
    if "/" in part or part.startswith('"'):
        return json.dumps(part)
    return part
@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key identifying a node in the computation graph hierarchy.

    A key is a tuple of path parts leading from the root down to the node.
    The empty tuple is the root. Parts are usually strings but may be
    arbitrary objects. Being a frozen dataclass, instances cannot be mutated
    and receive a field-based ``__hash__``. Equality compares ``parts``.
    """

    # Path components from the root down to this node; ``()`` is the root.
    parts: tuple[Any, ...]

    def __str__(self) -> str:
        """Return the path form: each part passed through quote_part, joined by '/'."""
        return "/".join([quote_part(part) for part in self.parts])

    @property
    def name(self) -> "Name":
        """Return a value identifying this node as a name.

        Returns:
            ``""`` for the root. For a one-part key, the part itself (possibly
            a non-string object). When every part is a string, the quoted,
            ``/``-joined path string. Otherwise, this NodeKey itself.
        """
        if len(self.parts) == 0:
            return ""
        if len(self.parts) == 1:
            return self.parts[0]
        if all(isinstance(part, str) for part in self.parts):
            return "/".join(quote_part(part) for part in self.parts)
        return self

    @property
    def label(self) -> str:
        """Return the display label: ``str`` of the last part, or ``""`` for the root."""
        if len(self.parts) == 0:
            return ""
        return str(self.parts[-1])

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Return this key expressed relative to ``root``.

        Args:
            root: Prefix to remove, converted with ``to_nodekey``. ``None``
                means "no prefix" and returns ``self`` unchanged.

        Returns:
            A NodeKey holding the parts after ``root`` when this key is a
            strict descendant of ``root``; otherwise ``None``. This includes
            the case where this key equals ``root``.
        """
        if root is None:
            return self
        root = to_nodekey(root)
        n_root_parts = len(root.parts)
        if self.is_descendent_of(root):
            return NodeKey(self.parts[n_root_parts:])
        return None

    def join(self, *others: "Name") -> "NodeKey":
        """Return a new key with each of ``others`` appended below this one.

        ``None`` entries are skipped. Every other entry is converted with
        ``to_nodekey``, so strings are parsed as ``/``-separated paths, and
        all of its parts are appended in order.
        """
        result = self
        for other in others:
            if other is None:
                continue
            result = result.join_parts(*to_nodekey(other).parts)
        return result

    def join_parts(self, *parts: Any) -> "NodeKey":
        """Return a new key with raw ``parts`` appended (no parsing); ``self`` if none given."""
        if len(parts) == 0:
            return self
        return NodeKey(self.parts + tuple(parts))

    def __truediv__(self, other: "Name") -> "NodeKey":
        """Support ``key / other`` as shorthand for ``key.join(other)``."""
        return self.join(other)

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Return True if ``other``'s parts are a strict prefix of this key's parts.

        A key is not considered a descendant of itself.
        """
        n_other_parts = len(other.parts)
        return len(self.parts) > n_other_parts and self.parts[:n_other_parts] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Return the key one level up.

        Raises:
            PathNotFoundError: If this is the root key, which has no parent.
        """
        if len(self.parts) == 0:
            raise PathNotFoundError()
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return a new key with ``nk``'s parts placed before this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return ``NodeKey('<path>')``, using the string form from ``__str__``."""
        path_str = str(self)
        quoted_path_str = repr(path_str)
        return f"{self.__class__.__name__}({quoted_path_str})"

    def __eq__(self, other: object) -> bool:
        """Compare by ``parts``.

        ``None`` is never equal. Other non-NodeKey objects return
        ``NotImplemented`` so Python can try the reflected comparison.
        """
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Lazily created, shared root instance handed out by root().
    _ROOT: ClassVar["NodeKey | None"] = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Return the root key (empty ``parts``), creating and caching it on first use."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self) -> bool:
        """True if this key has no parts, i.e. it is the root."""
        return len(self.parts) == 0

    @staticmethod
    def common_parent(nodekey1: "Name", nodekey2: "Name") -> "NodeKey":
        """Return the longest common leading prefix of two keys.

        Both arguments are converted with ``to_nodekey`` first. Keys that share
        no leading part yield an empty (root) key.
        """
        nk1 = to_nodekey(nodekey1)
        nk2 = to_nodekey(nodekey2)
        parts: list[Any] = []
        for p1, p2 in zip(nk1.parts, nk2.parts, strict=False):
            if p1 != p2:
                break
            parts.append(p1)
        return NodeKey(tuple(parts))

    def ancestors(self) -> list["NodeKey"]:
        """Return this key followed by each ancestor, ending with the root.

        The list starts with ``self`` (not with the root) and walks upward. For
        example, ``NodeKey(("a", "b")).ancestors()`` is
        ``[NodeKey(("a", "b")), NodeKey(("a",)), NodeKey(())]``. The root's
        result is just ``[root]``.
        """
        result = [self]
        node = self
        while not node.is_root:
            node = node.parent
            result.append(node)
        return result
def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Normalize a single name or a collection of names into a list of NodeKeys.

    ``names`` is first passed through ``as_iterable``. Each resulting item is
    then converted with ``to_nodekey``, preserving order.
    """
    return list(map(to_nodekey, as_iterable(names)))
def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Map each NodeKey to its ``name`` property, preserving input order."""
    names: list[Name] = []
    for key in node_keys:
        names.append(key.name)
    return names
# One unquoted path segment: everything up to the next "/", plus that "/"
# if present. Because it can match the empty string, it always matches.
PART = re.compile(r"([^/]*)/?")


def _parse_nodekey(path_str: str, end: int) -> "NodeKey":
    """Parse ``path_str`` from index ``end`` onward into a NodeKey.

    Leading ``/`` characters are skipped. The remainder is split on ``/``.
    A segment beginning with ``"`` is decoded as a JSON string literal, so it
    may contain ``/``, and must be followed by ``/`` or the end of input. An
    unquoted segment runs to the next ``/``. Its terminating ``/`` is
    consumed with it, so one trailing ``/`` adds no empty final part.

    Args:
        path_str: The path text to parse.
        end: Index in ``path_str`` at which parsing starts.

    Returns:
        A NodeKey whose parts are the decoded segments.

    Raises:
        AssertionError: If a quoted segment is followed by anything other
            than ``/`` or the end of input.
        Errors from ``json.decoder.scanstring`` (for example, for an
        unterminated quoted segment) propagate unchanged.
    """
    parts: list[str] = []
    parts_append = parts.append

    # Leading separators are ignored, so "/a/b" parses like "a/b".
    while path_str[end : end + 1] == "/":
        end = end + 1
    while True:
        nextchar = path_str[end : end + 1]
        if nextchar == "":
            break
        if nextchar == '"':
            # scanstring starts just after the opening quote and returns the
            # decoded text plus the index just past the closing quote.
            part, end = json.decoder.scanstring(path_str, end + 1)  # type: ignore[attr-defined]
            parts_append(part)
            nextchar = path_str[end : end + 1]
            # Raised explicitly rather than via `assert` so malformed input
            # such as '"a"b' is still rejected when Python runs with -O.
            if nextchar not in ("", "/"):
                msg = f"Expected '/' or end of path after quoted part at index {end} in {path_str!r}"
                raise AssertionError(msg)
            if nextchar != "":
                end = end + 1
        else:
            chunk = PART.match(path_str, end)
            # Internal invariant for type narrowing: PART can match "" so it never fails.
            assert chunk is not None  # noqa: S101
            end = chunk.end()
            (part,) = chunk.groups()
            parts_append(part)

    # Internal invariant: the loop only exits once all input is consumed.
    assert end == len(path_str)  # noqa: S101

    return NodeKey(tuple(parts))
def parse_nodekey(path_str: str) -> NodeKey:
    """Parse a whole '/'-separated path string (JSON-quoted segments allowed) into a NodeKey.

    Parsing always begins at the first character of ``path_str``.
    """
    return _parse_nodekey(path_str, end=0)
def to_nodekey(name: Name) -> NodeKey:
    """Coerce any accepted node identifier into a NodeKey.

    Strings are parsed as paths via ``parse_nodekey``. NodeKeys are returned
    unchanged. Any other object becomes a one-part key containing it.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    if not isinstance(name, object):  # pragma: no cover
        # Defensive only: every Python value is an instance of object.
        msg = f"Unexpected error creating node key for name {name}"
        raise TypeError(msg)
    return NodeKey((name,))
def nodekey_join(*names: Name) -> NodeKey:
    """Build a NodeKey by joining ``names`` onto the root key.

    Delegates to ``NodeKey.join``. ``None`` entries are skipped, and each
    remaining name is converted with ``to_nodekey``.
    """
    start = NodeKey.root()
    return start.join(*names)
228def _match_pattern_recursive(pattern: NodeKey, target: NodeKey, p_idx: int, t_idx: int) -> bool:
229 """Recursively match pattern parts against target parts.
231 Args:
232 pattern: The pattern NodeKey to match against
233 target: The target NodeKey to match
234 p_idx: Current index in pattern parts
235 t_idx: Current index in target parts
237 Returns:
238 bool: True if pattern matches target, False otherwise
239 """
240 if p_idx == len(pattern.parts) and t_idx == len(target.parts):
241 return True
242 if p_idx == len(pattern.parts):
243 return False
244 if t_idx == len(target.parts):
245 return all(p == "**" for p in pattern.parts[p_idx:])
247 if pattern.parts[p_idx] == "**":
248 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx) or _match_pattern_recursive(
249 pattern, target, p_idx, t_idx + 1
250 )
251 elif pattern.parts[p_idx] == "*":
252 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx + 1)
253 else:
254 if pattern.parts[p_idx] == target.parts[t_idx]:
255 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx + 1)
256 return False
def is_pattern(nodekey: "NodeKey") -> bool:
    """Return True if any string part of ``nodekey`` contains a ``*`` wildcard.

    This covers both ``*`` and ``**``. Non-string parts (ints, None, arbitrary
    objects) are never wildcards and are skipped. Testing them with ``in``
    would raise TypeError, or misfire for container parts. A string such as
    ``"a*b"`` still counts, matching the previous substring behaviour.
    """
    return any(isinstance(part, str) and "*" in part for part in nodekey.parts)
def match_pattern(pattern: NodeKey, target: NodeKey) -> bool:
    """Report whether ``target`` is matched by the wildcard ``pattern``.

    Parts are compared left to right. A pattern part ``*`` matches exactly
    one target part, and ``**`` matches zero or more target parts. Every
    other pattern part must equal the target part at the same position.

    Args:
        pattern: The pattern to match against.
        target: The target to match.

    Returns:
        bool: True if the whole of ``target`` is matched by the whole of
        ``pattern``, False otherwise.
    """
    return _match_pattern_recursive(pattern, target, p_idx=0, t_idx=0)