Coverage for src/loman/nodekey.py: 96%
164 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
1"""Node key implementation for computation graph navigation."""
import functools
import json
import re
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union

from loman.util import as_iterable
# Ways a caller may refer to a node: a "/"-separated path string, an existing
# NodeKey, or any other object (to_nodekey wraps such an object as a
# single-part key).
Name = Union[str, "NodeKey", object]
# A list of such names.
Names = list[Name]
class PathNotFoundError(Exception):
    """Raised when a requested node path does not exist (e.g. the parent of the root key)."""
def quote_part(part: object) -> str:
    """Render a single node key part as a path segment.

    String parts are emitted verbatim unless that would make the path
    ambiguous to ``parse_nodekey``, in which case they are JSON-quoted.
    A string part is quoted when it:

    * contains ``/`` (it would otherwise be split into several parts),
    * starts with ``"`` (it would otherwise be read as the start of a quoted
      part), or
    * is empty (a bare empty segment at the start or end of a path is lost
      when the path is parsed).

    Non-string parts are rendered with ``str()``.
    """
    if not isinstance(part, str):
        return str(part)
    if part == "" or "/" in part or part.startswith('"'):
        return json.dumps(part)
    return part
@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key for identifying nodes in the computation graph hierarchy.

    A key is a tuple of ``parts``. ``str()`` renders it as a ``/``-separated
    path, which ``parse_nodekey`` reads back for string parts. The empty tuple
    is the root key.
    """

    parts: tuple

    def __str__(self) -> str:
        """Return the path form of this key, e.g. ``a/b/c``.

        Each part is rendered with ``quote_part``; the root key renders as ``""``.
        """
        return "/".join(quote_part(part) for part in self.parts)

    @property
    def name(self) -> "Name":
        """Get the name callers should use to refer to this node.

        Returns:
            ``""`` for the root key; the single part itself for a one-part key;
            the quoted ``/``-joined path when every part is a ``str``; otherwise
            the key itself, since non-string parts cannot be recovered from
            their ``str()`` form.
        """
        n_parts = len(self.parts)
        if n_parts == 0:
            return ""
        if n_parts == 1:
            return self.parts[0]
        if all(isinstance(part, str) for part in self.parts):
            return "/".join(quote_part(part) for part in self.parts)
        return self

    @property
    def label(self) -> str:
        """Get a display label: the last part as a string, or ``""`` for the root."""
        if not self.parts:
            return ""
        return str(self.parts[-1])

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Strip the prefix ``root`` from this key.

        Args:
            root: Prefix to remove. ``None`` means "no prefix" and returns ``self``.

        Returns:
            The remaining key when this key is a strict descendant of ``root``,
            otherwise ``None`` (including when this key equals ``root``).
        """
        if root is None:
            return self
        root = to_nodekey(root)
        if not self.is_descendent_of(root):
            return None
        return NodeKey(self.parts[len(root.parts) :])

    def join(self, *others: "Name") -> "NodeKey":
        """Append each of ``others`` (converted with ``to_nodekey``) to this key.

        ``None`` entries are skipped, so optional segments can be passed through.
        """
        result = self
        for other in others:
            if other is None:
                continue
            result = result.join_parts(*to_nodekey(other).parts)
        return result

    def join_parts(self, *parts) -> "NodeKey":
        """Append raw ``parts`` without parsing them; returns ``self`` if none are given."""
        if not parts:
            return self
        return NodeKey(self.parts + tuple(parts))

    def __truediv__(self, other: "Name") -> "NodeKey":
        """Support ``key / other`` as shorthand for ``key.join(other)``."""
        return self.join(other)

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Return True if ``other`` is a strict prefix of this key.

        A key is not considered a descendant of itself.
        """
        n_other = len(other.parts)
        return len(self.parts) > n_other and self.parts[:n_other] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Get the key with the last part removed.

        Raises:
            PathNotFoundError: if this is the root key, which has no parent.
        """
        if not self.parts:
            raise PathNotFoundError("The root node key has no parent")
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return ``nk`` followed by this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return ``NodeKey('<path>')`` for debugging."""
        return f"{self.__class__.__name__}({str(self)!r})"

    def __eq__(self, other) -> bool:
        """Compare parts with another NodeKey.

        Returns ``False`` for ``None`` and ``NotImplemented`` for other types so
        Python can try the reflected comparison.
        """
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Lazily created singleton returned by ``root()``. Deliberately left
    # unannotated so that ``dataclass`` does not treat it as a field.
    _ROOT = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Get the shared root key (the key with no parts)."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self) -> bool:
        """Return True if this key has no parts."""
        return len(self.parts) == 0

    @staticmethod
    def common_parent(nodekey1: "Name", nodekey2: "Name") -> "NodeKey":
        """Return the longest common leading path of two names.

        If one key is a prefix of the other, that key itself is returned. The
        result is the root key when the first parts already differ.
        """
        parts1 = to_nodekey(nodekey1).parts
        parts2 = to_nodekey(nodekey2).parts
        common = []
        for p1, p2 in zip(parts1, parts2):
            if p1 != p2:
                break
            common.append(p1)
        return NodeKey(tuple(common))

    def ancestors(self) -> list["NodeKey"]:
        """Return this key followed by each of its ancestors, ending with the root.

        The key itself is the first element. For example ``a/b`` yields
        ``[a/b, a, <root>]`` and the root key yields ``[<root>]``.
        """
        result = [self]
        current = self
        while not current.is_root:
            current = current.parent
            result.append(current)
        return result
def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Normalise one name, or a collection of names, to a list of NodeKeys.

    ``as_iterable`` presumably wraps a lone name so both forms iterate the
    same way (confirm in ``loman.util``); each item goes through ``to_nodekey``.
    """
    return list(map(to_nodekey, as_iterable(names)))
def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Map each NodeKey to its ``.name`` (see ``NodeKey.name`` for the possible forms)."""
    return [key.name for key in node_keys]
# Matches one unquoted path segment: everything up to the next "/" (group 1),
# plus that "/" if present, so the scanner's match end moves past the separator.
PART = re.compile("([^/]*)/?")
def _parse_nodekey(path_str: str, end: int) -> NodeKey:
    """Parse ``path_str[end:]`` into a NodeKey.

    Leading ``/`` characters are skipped. The remainder is split on ``/``.
    A part that starts with ``"`` is decoded as a JSON string literal, which is
    how ``quote_part`` encodes parts containing ``/``.

    Raises:
        ValueError: if a quoted part is followed by anything other than ``/``
            or the end of the string. Malformed JSON strings raise
            ``json.JSONDecodeError``, a ``ValueError`` subclass.
    """
    parts = []
    n = len(path_str)

    while end < n and path_str[end] == "/":
        end += 1

    while end < n:
        if path_str[end] == '"':
            # scanstring starts just after the opening quote and returns the
            # index just past the closing quote.
            part, end = json.decoder.scanstring(path_str, end + 1)
            parts.append(part)
            if end < n:
                # Explicit check (not assert) so validation survives `python -O`.
                if path_str[end] != "/":
                    raise ValueError(
                        f"Expected '/' or end of string after quoted part at position {end} in {path_str!r}"
                    )
                end += 1
        else:
            # PART always consumes at least one character here (the segment
            # or its trailing "/"), so the loop makes progress.
            chunk = PART.match(path_str, end)
            parts.append(chunk.group(1))
            end = chunk.end()

    return NodeKey(tuple(parts))
def parse_nodekey(path_str: str) -> NodeKey:
    """Parse a ``/``-separated path string (the ``str()`` form of a NodeKey) into a NodeKey."""
    start = 0
    return _parse_nodekey(path_str, start)
def to_nodekey(name: Name) -> NodeKey:
    """Coerce ``name`` to a NodeKey.

    * ``str`` values are parsed as paths with ``parse_nodekey``.
    * ``NodeKey`` values are returned unchanged.
    * Any other value (including ``None``) is wrapped as a single opaque part.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    # Every value is an ``object``, so there is no failure case to report here.
    return NodeKey((name,))
def nodekey_join(*names: Name) -> NodeKey:
    """Build a NodeKey by joining ``names`` onto the root key; ``None`` entries are skipped."""
    root_key = NodeKey.root()
    return root_key.join(*names)
224def _match_pattern_recursive(pattern: NodeKey, target: NodeKey, p_idx: int, t_idx: int) -> bool:
225 """Recursively match pattern parts against target parts.
227 Args:
228 pattern: The pattern NodeKey to match against
229 target: The target NodeKey to match
230 p_idx: Current index in pattern parts
231 t_idx: Current index in target parts
233 Returns:
234 bool: True if pattern matches target, False otherwise
235 """
236 if p_idx == len(pattern.parts) and t_idx == len(target.parts):
237 return True
238 if p_idx == len(pattern.parts):
239 return False
240 if t_idx == len(target.parts):
241 return all(p == "**" for p in pattern.parts[p_idx:])
243 if pattern.parts[p_idx] == "**":
244 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx) or _match_pattern_recursive(
245 pattern, target, p_idx, t_idx + 1
246 )
247 elif pattern.parts[p_idx] == "*":
248 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx + 1)
249 else:
250 if pattern.parts[p_idx] == target.parts[t_idx]:
251 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx + 1)
252 return False
def is_pattern(nodekey: "NodeKey") -> bool:
    """Return True if any string part of ``nodekey`` contains a ``*`` wildcard.

    Non-string parts (arbitrary objects wrapped by ``to_nodekey``) are never
    treated as wildcards; testing them with ``in`` would raise ``TypeError``
    for values such as ints.

    NOTE(review): this is a substring test, so a part like ``"a*b"`` counts as
    a pattern, whereas ``match_pattern`` only treats parts exactly equal to
    ``"*"`` or ``"**"`` as wildcards. Kept for backward compatibility; confirm
    the intended semantics.
    """
    return any(isinstance(part, str) and "*" in part for part in nodekey.parts)
def match_pattern(pattern: NodeKey, target: NodeKey) -> bool:
    """Return whether ``target`` matches the wildcard ``pattern`` in full.

    Wildcard parts:
        * - stands for exactly one part
        ** - stands for any number of parts, including none

    Any other pattern part must equal the target part at the same position.

    Args:
        pattern: Key whose parts may contain the wildcards above
        target: Concrete key being tested

    Returns:
        bool: True when the whole of ``target`` is matched by ``pattern``
    """
    pattern_start = target_start = 0
    return _match_pattern_recursive(pattern, target, pattern_start, target_start)