Coverage for src/loman/nodekey.py: 96%
162 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 05:36 +0000
1"""Node key implementation for computation graph navigation."""
3import json
4import re
5from collections.abc import Iterable
6from dataclasses import dataclass
7from typing import Optional, Union
9from loman.util import as_iterable
# Anything accepted where a node is named: a "/"-separated path string, an
# existing NodeKey, or any other object (which becomes a single key part).
Name = Union[str, "NodeKey", object]
# A list of such names.
Names = list[Name]
class PathNotFoundError(Exception):
    """Raised when a requested node path does not exist (e.g. the parent of the root key)."""
def quote_part(part: object) -> str:
    """Render one node key part for use inside a ``/``-separated path string.

    Non-string parts are rendered with ``str()``. String parts are returned
    unchanged unless that would break parsing the path back, in which case
    they are JSON-quoted:

    - parts containing ``/`` (otherwise they would be split in two);
    - parts starting with ``"`` (the parser treats a leading ``"`` as the
      start of a JSON string literal);
    - empty parts (otherwise a leading or trailing empty part disappears,
      because the parser skips leading ``/`` and ignores a trailing one).
    """
    if not isinstance(part, str):
        return str(part)
    if part == "" or "/" in part or part.startswith('"'):
        return json.dumps(part)
    return part
@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key identifying a node in the computation graph hierarchy.

    A key is a tuple of *parts*, outermost first; the empty tuple is the root.
    ``str()`` joins the parts with ``/``, with ``quote_part`` JSON-quoting any
    string part that needs it (e.g. one containing ``/``).
    """

    parts: tuple  # path components, outermost first; () denotes the root

    def __str__(self) -> str:
        """Return the ``/``-separated path form of this key, e.g. ``a/b/c``."""
        return "/".join(quote_part(part) for part in self.parts)

    @property
    def name(self) -> "Name":
        """Return the simplest user-facing name for this key.

        - the root gives ``""``;
        - a single-part key gives that part itself (not necessarily a string);
        - a key made only of string parts gives its path string;
        - any other key gives the key itself, since no plain value represents it.
        """
        if len(self.parts) == 0:
            return ""
        if len(self.parts) == 1:
            return self.parts[0]
        if all(isinstance(part, str) for part in self.parts):
            return "/".join(quote_part(part) for part in self.parts)
        return self

    @property
    def label(self) -> str:
        """Return a display label: ``str()`` of the last part, or ``""`` for the root."""
        if len(self.parts) == 0:
            return ""
        return str(self.parts[-1])

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Return this key expressed relative to ``root``.

        Args:
            root: Prefix to strip; converted with ``to_nodekey``. None means
                "no prefix".

        Returns:
            ``self`` when ``root`` is None; the remaining parts when this key is
            a strict descendant of ``root``; otherwise None (this includes the
            case where this key equals ``root``).
        """
        if root is None:
            return self
        root_key = to_nodekey(root)
        if not self.is_descendent_of(root_key):
            return None
        return NodeKey(self.parts[len(root_key.parts):])

    def join(self, *others: "Name") -> "NodeKey":
        """Return a new key with each of ``others`` appended in order.

        Each entry is converted with ``to_nodekey`` (so a path string such as
        ``"a/b"`` contributes two parts); None entries are skipped.
        """
        result = self
        for other in others:
            if other is None:
                continue
            result = result.join_parts(*to_nodekey(other).parts)
        return result

    def join_parts(self, *parts) -> "NodeKey":
        """Return a new key with the raw ``parts`` appended (no parsing); ``self`` if none given."""
        if len(parts) == 0:
            return self
        return NodeKey(self.parts + tuple(parts))

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Return True if ``other`` is a strict prefix of this key.

        A key is not considered a descendant of itself.
        """
        n_other = len(other.parts)
        return len(self.parts) > n_other and self.parts[:n_other] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Return the key with the last part removed.

        Raises:
            PathNotFoundError: if this is the root key, which has no parent.
        """
        if len(self.parts) == 0:
            raise PathNotFoundError("The root node key has no parent")
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return a new key made of ``nk``'s parts followed by this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return ``NodeKey('<path>')``, using the subclass name when subclassed."""
        return f"{self.__class__.__name__}({str(self)!r})"

    # Defining __eq__ in the class body does not stop dataclass(frozen=True)
    # from generating __hash__ over ``parts``, so keys remain hashable.
    def __eq__(self, other) -> bool:
        """Compare parts with another NodeKey.

        None is never equal. Other types return NotImplemented so Python can
        try the reflected comparison.
        """
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Lazily created root key. Deliberately left unannotated so that
    # @dataclass treats it as a plain class attribute, not a field.
    _ROOT = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Return the shared root key (empty parts), creating it on first use."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self) -> bool:
        """Return True if this key has no parts."""
        return len(self.parts) == 0

    @staticmethod
    def common_parent(nodekey1: "Name", nodekey2: "Name") -> "NodeKey":
        """Return the longest common prefix of two keys.

        Both arguments are converted with ``to_nodekey`` first. The result is
        the root when the keys share no leading parts, and the key itself when
        both arguments are equal.
        """
        parts1 = to_nodekey(nodekey1).parts
        parts2 = to_nodekey(nodekey2).parts
        common = []
        for p1, p2 in zip(parts1, parts2):
            if p1 != p2:
                break
            common.append(p1)
        return NodeKey(tuple(common))

    def ancestors(self) -> list["NodeKey"]:
        """Return this key followed by each of its ancestors, ending with the root.

        For ``a/b`` this is ``[a/b, a, <root>]``; for the root it is ``[<root>]``.
        """
        result = [self]
        current = self
        while not current.is_root:
            current = current.parent
            result.append(current)
        return result
def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Return one NodeKey per name, converting each with ``to_nodekey``.

    ``names`` is first passed through ``as_iterable`` (presumably so that a
    single name is handled like a one-element list — confirm in loman.util).
    """
    return list(map(to_nodekey, as_iterable(names)))
def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Return the ``name`` of every key in ``node_keys``, preserving order."""
    return [key.name for key in node_keys]
167PART = re.compile(r"([^/]*)/?")
170def _parse_nodekey(path_str: str, end: int) -> NodeKey:
171 parts = []
172 parts_append = parts.append
174 while path_str[end : end + 1] == "/":
175 end = end + 1
176 while True:
177 nextchar = path_str[end : end + 1]
178 if nextchar == "":
179 break
180 if nextchar == '"':
181 part, end = json.decoder.scanstring(path_str, end + 1)
182 parts_append(part)
183 nextchar = path_str[end : end + 1]
184 assert nextchar == "" or nextchar == "/"
185 if nextchar != "":
186 end = end + 1
187 else:
188 chunk = PART.match(path_str, end)
189 end = chunk.end()
190 (part,) = chunk.groups()
191 parts_append(part)
193 assert end == len(path_str)
195 return NodeKey(tuple(parts))
def parse_nodekey(path_str: str) -> NodeKey:
    """Parse a ``/``-separated path string into a NodeKey, starting at its first character."""
    start = 0
    return _parse_nodekey(path_str, start)
def to_nodekey(name: Name) -> NodeKey:
    """Convert ``name`` to a NodeKey.

    - ``str``: parsed as a ``/``-separated path with ``parse_nodekey``;
    - ``NodeKey``: returned unchanged;
    - anything else (including None): wrapped as a single-part key.

    Every Python value is an ``object``, so no input is rejected. The previous
    ``ValueError`` branch after an ``isinstance(name, object)`` check could
    never run and has been removed.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    return NodeKey((name,))
def nodekey_join(*names: Name) -> NodeKey:
    """Build a NodeKey by joining ``names`` onto the root key via ``NodeKey.join`` (None entries are skipped)."""
    root_key = NodeKey.root()
    return root_key.join(*names)
220def _match_pattern_recursive(pattern: NodeKey, target: NodeKey, p_idx: int, t_idx: int) -> bool:
221 """Recursively match pattern parts against target parts.
223 Args:
224 pattern: The pattern NodeKey to match against
225 target: The target NodeKey to match
226 p_idx: Current index in pattern parts
227 t_idx: Current index in target parts
229 Returns:
230 bool: True if pattern matches target, False otherwise
231 """
232 if p_idx == len(pattern.parts) and t_idx == len(target.parts):
233 return True
234 if p_idx == len(pattern.parts):
235 return False
236 if t_idx == len(target.parts):
237 return all(p == "**" for p in pattern.parts[p_idx:])
239 if pattern.parts[p_idx] == "**":
240 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx) or _match_pattern_recursive(
241 pattern, target, p_idx, t_idx + 1
242 )
243 elif pattern.parts[p_idx] == "*":
244 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx + 1)
245 else:
246 if pattern.parts[p_idx] == target.parts[t_idx]:
247 return _match_pattern_recursive(pattern, target, p_idx + 1, t_idx + 1)
248 return False
def is_pattern(nodekey: "NodeKey") -> bool:
    """Return True if any string part of ``nodekey`` contains a ``*`` character.

    Non-string parts (ints, tuples, ...) are never treated as wildcards.
    Previously ``"*" in part`` raised TypeError for e.g. int parts and did a
    membership test on tuple parts. Note that ``match_pattern`` gives special
    meaning only to parts exactly equal to ``"*"`` or ``"**"``; a part such as
    ``"a*b"`` is reported here but is matched literally.
    """
    # "**" contains "*", so a single substring test covers both wildcards.
    return any(isinstance(part, str) and "*" in part for part in nodekey.parts)
def match_pattern(pattern: NodeKey, target: NodeKey) -> bool:
    """Return True if ``target`` matches ``pattern`` from the first part onwards.

    Wildcards:
        ``*``  matches exactly one part
        ``**`` matches zero or more parts
    Every other pattern part must equal the corresponding target part.

    Args:
        pattern: The pattern to match against
        target: The target to match

    Returns:
        bool: True if pattern matches target, False otherwise
    """
    return _match_pattern_recursive(pattern, target, p_idx=0, t_idx=0)