Coverage for src / loman / nodekey.py: 98%
171 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 21:24 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 21:24 +0000
1"""Node key implementation for computation graph navigation."""
3import json
4import re
5from collections.abc import Hashable, Iterable
6from dataclasses import dataclass
7from typing import ClassVar, Optional, Union
9from loman.util import as_iterable
# Anything that can identify a node; see to_nodekey for how each kind is converted.
Name = Union[
    str,  # a "/"-separated path string, parsed into parts
    "NodeKey",  # an already-built key, used as-is
    Hashable,  # any other hashable object, used as a single part
]
# A list of node names.
Names = list[Name]
class PathNotFoundError(Exception):
    """Raised when a requested node path does not exist, e.g. asking the root key for its parent."""
def quote_part(part: Hashable) -> str:
    """Render a single node key part as one component of a path string.

    Non-string parts are rendered with ``str()``. String parts are returned
    unchanged unless the path parser would misread them. A part that contains
    the ``/`` separator, or that starts with ``"`` (which the parser treats as
    the opening of a JSON-quoted part), is emitted as a JSON string literal.
    That lets it be recovered intact by ``parse_nodekey``.

    Args:
        part: The part to render.

    Returns:
        The text to place between ``/`` separators for this part.
    """
    if not isinstance(part, str):
        return str(part)
    # A leading '"' must be quoted too; otherwise the parser starts decoding a
    # JSON string here and either fails or splits the part incorrectly.
    if "/" in part or part.startswith('"'):
        return json.dumps(part)
    return part
@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key for identifying nodes in the computation graph hierarchy.

    A key is a tuple of hashable ``parts`` read from the top of the hierarchy
    down to the node; the empty tuple is the root key. ``str()`` renders the key
    as a ``/``-separated path with each part formatted by ``quote_part``.
    """

    # Path components, outermost first. Empty for the root key.
    parts: tuple[Hashable, ...]

    def __str__(self) -> str:
        """Return the path form: each part passed through ``quote_part``, joined by ``/``."""
        return "/".join([quote_part(part) for part in self.parts])

    @property
    def name(self) -> Name:
        """Return the name that identifies this node.

        Returns:
            ``""`` for the root key; the raw part itself for a single-part key;
            the ``/``-joined path (parts quoted by ``quote_part``) when there
            are several parts and all are strings; otherwise this ``NodeKey``.
        """
        if len(self.parts) == 0:
            return ""
        elif len(self.parts) == 1:
            return self.parts[0]
        elif all(isinstance(part, str) for part in self.parts):
            return "/".join(quote_part(part) for part in self.parts)
        else:
            # Mixed/non-string parts cannot be written as a plain path string.
            return self

    @property
    def label(self) -> str:
        """Return ``str()`` of the last part for display, or ``""`` for the root key."""
        if len(self.parts) == 0:
            return ""
        return str(self.parts[-1])

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Strip the prefix ``root`` from this key.

        Args:
            root: Prefix to remove; converted with ``to_nodekey``. ``None``
                means "no prefix" and returns this key unchanged.

        Returns:
            The remaining parts as a new key if this key is a strict descendant
            of ``root``; otherwise ``None`` (this includes the case where this
            key equals ``root``, since ``is_descendent_of`` is strict).
        """
        if root is None:
            return self
        root = to_nodekey(root)
        n_root_parts = len(root.parts)
        if self.is_descendent_of(root):
            parts = self.parts[n_root_parts:]
            return NodeKey(parts)
        else:
            return None

    def join(self, *others: Name) -> "NodeKey":
        """Return a new key with each of ``others`` appended in order.

        Each name is converted with ``to_nodekey`` (so strings are parsed as
        ``/``-separated paths); ``None`` entries are skipped.
        """
        result = self
        for other in others:
            if other is None:
                continue
            other = to_nodekey(other)
            result = result.join_parts(*other.parts)
        return result

    def join_parts(self, *parts: Hashable) -> "NodeKey":
        """Return a new key with the raw ``parts`` appended (no parsing); ``self`` if none given."""
        if len(parts) == 0:
            return self
        return NodeKey(self.parts + tuple(parts))

    def __truediv__(self, other: Name) -> "NodeKey":
        """Support ``key / other`` as shorthand for ``key.join(other)``."""
        return self.join(other)

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Return True if ``other``'s parts are a proper prefix of this key's parts.

        A key is not considered a descendant of itself.
        """
        n_self_parts = len(self.parts)
        n_other_parts = len(other.parts)
        return n_self_parts > n_other_parts and self.parts[:n_other_parts] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Return the key with the last part removed.

        Raises:
            PathNotFoundError: if this is the root key (it has no parent).
        """
        if len(self.parts) == 0:
            raise PathNotFoundError()
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return a new key made of ``nk``'s parts followed by this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return ``ClassName('<path>')`` using the ``str()`` path form."""
        path_str = str(self)
        quoted_path_str = repr(path_str)
        return f"{self.__class__.__name__}({quoted_path_str})"

    def __eq__(self, other: object) -> bool:
        """Compare keys by ``parts``.

        ``None`` is never equal; other non-NodeKey objects return
        ``NotImplemented`` so Python can try the reflected comparison.
        """
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Lazily created root instance shared by calls to root().
    _ROOT: ClassVar["NodeKey | None"] = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Return the root key (no parts), creating it on first use and caching it on ``cls``."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self) -> bool:
        """Return True if this key has no parts."""
        return len(self.parts) == 0

    @staticmethod
    def common_parent(nodekey1: Name, nodekey2: Name) -> "NodeKey":
        """Return the longest common leading run of parts of two names.

        Both arguments are converted with ``to_nodekey``. The result is the
        root key when the first parts already differ.
        """
        nk1 = to_nodekey(nodekey1)
        nk2 = to_nodekey(nodekey2)
        parts: list[Hashable] = []
        for p1, p2 in zip(nk1.parts, nk2.parts, strict=False):
            if p1 != p2:
                break
            parts.append(p1)
        return NodeKey(tuple(parts))

    def ancestors(self) -> list["NodeKey"]:
        """Return this key followed by each successive parent, ending with the root.

        For example, the key ``a/b`` yields ``[a/b, a, <root>]``; the root key
        yields a one-element list containing itself.
        """
        result = []
        x = self
        while True:
            result.append(x)
            if x.is_root:
                break
            x = x.parent
        return result
def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Convert ``names`` to a list of NodeKeys.

    ``names`` is first normalised with ``as_iterable`` (presumably wrapping a
    single name in a sequence, given the ``Name | Names`` signature — confirm
    against ``loman.util``). Each entry is then converted with ``to_nodekey``.
    """
    return list(map(to_nodekey, as_iterable(names)))
def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Map each NodeKey to its ``name`` property, preserving input order."""
    result: list[Name] = []
    for key in node_keys:
        result.append(key.name)
    return result
# Matches one unquoted path part starting at a given position.
PART = re.compile(
    r"([^/]*)"  # group 1: the part text, possibly empty
    r"/?"  # also consume the separator that follows it, if any
)
def _parse_nodekey(path_str: str, end: int) -> NodeKey:
    """Parse ``path_str`` from offset ``end`` into a NodeKey.

    Any leading ``/`` characters are skipped. Each part is then either a
    JSON-quoted string (which may itself contain ``/``), or a run of
    non-``/`` characters matched by ``PART``. Parts are separated by a single
    ``/``, so ``a//b`` yields an empty middle part, and a trailing ``/`` is
    consumed without producing a part.

    Raises:
        ValueError: if a quoted part is malformed, is followed by anything
            other than ``/`` or the end of input, or if parsing stops before
            the end of the string.
    """
    pieces: list[str] = []
    pos = end
    # Drop leading separators so "/a/b" parses the same as "a/b".
    while path_str[pos : pos + 1] == "/":
        pos += 1
    # One-character lookahead via slicing; "" means we have run off the end.
    while (head := path_str[pos : pos + 1]) != "":
        if head != '"':
            # Unquoted part: PART also swallows the following "/", if present.
            m = PART.match(path_str, pos)
            if m is None:
                msg = f"Failed to match node key part at position {pos} in {path_str!r}"
                raise ValueError(msg)
            pos = m.end()
            pieces.append(m.group(1))
            continue
        # Quoted part: decode the JSON string literal whose body starts after the quote.
        decoded, pos = json.decoder.scanstring(path_str, pos + 1)  # type: ignore[attr-defined]
        pieces.append(decoded)
        follow = path_str[pos : pos + 1]
        if follow not in ("", "/"):
            msg = f"Expected end of string or '/' after quoted part, got {follow!r} in {path_str!r}"
            raise ValueError(msg)
        if follow == "/":
            pos += 1

    if pos != len(path_str):
        msg = f"Unexpected trailing content at position {pos} in {path_str!r}"
        raise ValueError(msg)

    return NodeKey(tuple(pieces))
def parse_nodekey(path_str: str) -> NodeKey:
    """Parse a ``/``-separated path string (parts may be JSON-quoted) into a NodeKey.

    Raises:
        ValueError: if ``path_str`` is not a well-formed path.
    """
    return _parse_nodekey(path_str, end=0)
def to_nodekey(name: Name) -> NodeKey:
    """Convert a name to a NodeKey.

    * ``str``: parsed as a ``/``-separated path with ``parse_nodekey``, so
      ``"a/b"`` becomes a two-part key.
    * ``NodeKey``: returned unchanged.
    * anything else: wrapped as a single-part key.

    Raises:
        ValueError: if a string name is not a well-formed path.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    # Every other object (ints, tuples, None, ...) becomes a single part.
    # There is no failure case here: every value is an ``object``.
    return NodeKey((name,))
def nodekey_join(*names: Name) -> NodeKey:
    """Build a single NodeKey by joining ``names``, in order, onto the root key.

    Each name goes through ``NodeKey.join``; ``None`` entries are skipped.
    """
    start = NodeKey.root()
    return start.join(*names)
def _match_pattern_recursive(pattern: NodeKey, target: NodeKey, p_idx: int, t_idx: int) -> bool:
    """Match ``pattern.parts[p_idx:]`` against ``target.parts[t_idx:]``.

    A pattern part equal to ``"*"`` matches exactly one target part, and
    ``"**"`` matches zero or more target parts. Any other pattern part must
    compare equal to the corresponding target part.

    Results are memoized per ``(pattern index, target index)`` state. Without
    memoization, each ``"**"`` branches two ways and patterns with several
    ``"**"`` parts take exponential time on long targets; with it, the work is
    bounded by ``len(pattern.parts) * len(target.parts)`` states.

    Args:
        pattern: The pattern NodeKey to match against.
        target: The target NodeKey to match.
        p_idx: Current index in pattern parts.
        t_idx: Current index in target parts.

    Returns:
        bool: True if pattern matches target, False otherwise.
    """
    return _match_parts(pattern.parts, target.parts, p_idx, t_idx, {})


def _match_parts(
    p_parts: tuple[Hashable, ...],
    t_parts: tuple[Hashable, ...],
    i: int,
    j: int,
    memo: dict[tuple[int, int], bool],
) -> bool:
    """Memoized core of ``_match_pattern_recursive`` over raw part tuples."""
    key = (i, j)
    if key in memo:
        return memo[key]

    if i == len(p_parts):
        # Pattern exhausted: success only if the target is exhausted too.
        result = j == len(t_parts)
    elif j == len(t_parts):
        # Target exhausted: only trailing "**" parts (matching nothing) remain valid.
        result = all(p == "**" for p in p_parts[i:])
    else:
        part = p_parts[i]
        if part == "**":
            # Either "**" matches nothing more, or it absorbs one more target part.
            result = _match_parts(p_parts, t_parts, i + 1, j, memo) or _match_parts(
                p_parts, t_parts, i, j + 1, memo
            )
        elif part == "*" or part == t_parts[j]:
            result = _match_parts(p_parts, t_parts, i + 1, j + 1, memo)
        else:
            result = False

    memo[key] = result
    return result
def is_pattern(nodekey: NodeKey) -> bool:
    """Return True if any string part of ``nodekey`` contains a ``*`` character.

    NOTE: this flags any part containing ``*``, whereas ``match_pattern`` only
    treats parts that are exactly ``"*"`` or ``"**"`` as wildcards; a part
    like ``"a*"`` is reported here but matched literally there.
    """
    for part in nodekey.parts:
        # "**" contains "*", so a single substring test covers both wildcards.
        if isinstance(part, str) and "*" in part:
            return True
    return False
def match_pattern(pattern: NodeKey, target: NodeKey) -> bool:
    """Return whether ``target`` matches ``pattern`` as a whole.

    Wildcard parts in ``pattern``:
        ``*``  matches exactly one part.
        ``**`` matches zero or more parts.
    Every other pattern part must equal the corresponding target part.

    Args:
        pattern: The pattern to match against.
        target: The target to match.

    Returns:
        bool: True if every part of ``target`` is accounted for by ``pattern``.
    """
    return _match_pattern_recursive(pattern, target, p_idx=0, t_idx=0)