Coverage for src / loman / nodekey.py: 100%

165 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:30 +0000

1"""Node key implementation for computation graph navigation.""" 

2 

3import json 

4import re 

5from collections.abc import Iterable 

6from dataclasses import dataclass 

7from typing import Any, ClassVar, Optional, Union 

8 

9from loman.util import as_iterable 

10 

# Anything that can address a node: a path string, an existing NodeKey, or an
# arbitrary object.
Name = Union[str, "NodeKey", object]
# A list of such names.
Names = list[Name]

13 

14 

class PathNotFoundError(Exception):
    """Signal that a requested node path could not be found."""

19 

20 

def quote_part(part: object) -> str:
    """Render a single node key part as text that is safe to embed in a path.

    Args:
        part: The part to render. Non-string parts are rendered with ``str``.

    Returns:
        The part unchanged when it is a plain string; a JSON string literal
        when the string contains ``/`` or begins with a double quote; otherwise
        ``str(part)``.
    """
    if not isinstance(part, str):
        return str(part)
    # "/" is the path separator, and a leading double quote marks the start of
    # a JSON-quoted part, so both cases must be quoted to stay unambiguous.
    if "/" in part or part.startswith('"'):
        return json.dumps(part)
    return part

29 

30 

@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key for identifying nodes in the computation graph hierarchy.

    A key is an ordered tuple of parts; the empty tuple is the root key.
    """

    # The path components, ordered from the outermost level to the node itself.
    parts: tuple[Any, ...]

    def __str__(self) -> str:
        """Return the parts joined with ``/``, each rendered by ``quote_part``."""
        return "/".join(quote_part(part) for part in self.parts)

    @property
    def name(self) -> Name:
        """Return the simplest name that stands for this key.

        Returns:
            ``""`` for the root key; the sole part itself (unquoted, whatever
            its type) for a one-part key; the ``/``-joined quoted path when
            every part is a string; otherwise this key itself.
        """
        parts = self.parts
        if not parts:
            return ""
        if len(parts) == 1:
            only: Name = parts[0]
            return only
        if all(isinstance(part, str) for part in parts):
            return "/".join(quote_part(part) for part in parts)
        return self

    @property
    def label(self) -> str:
        """Return ``str`` of the last part, or ``""`` for the root key."""
        return str(self.parts[-1]) if self.parts else ""

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Strip a leading ``root`` prefix from this key.

        Args:
            root: Prefix to remove, or ``None`` to leave the key untouched.

        Returns:
            This key when ``root`` is ``None``; a new key holding the parts
            after the prefix when this key is a strict descendant of ``root``;
            ``None`` otherwise (including when the two keys are equal).
        """
        if root is None:
            return self
        root_key = to_nodekey(root)
        if not self.is_descendent_of(root_key):
            return None
        return NodeKey(self.parts[len(root_key.parts) :])

    def join(self, *others: Name) -> "NodeKey":
        """Append the parts of each given name, in order, skipping ``None``.

        Args:
            *others: Names converted with ``to_nodekey`` before joining.

        Returns:
            The combined key (this key itself when nothing is appended).
        """
        result = self
        for other in others:
            if other is not None:
                result = result.join_parts(*to_nodekey(other).parts)
        return result

    def join_parts(self, *parts: Any) -> "NodeKey":
        """Append raw parts without any parsing.

        Returns:
            This key itself when no parts are given, otherwise a new key.
        """
        if not parts:
            return self
        return NodeKey(self.parts + parts)

    def __truediv__(self, other: Name) -> "NodeKey":
        """Support ``key / name`` as shorthand for ``key.join(name)``."""
        return self.join(other)

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Return True when ``other``'s parts are a strict prefix of this key's parts."""
        n_other = len(other.parts)
        return len(self.parts) > n_other and self.parts[:n_other] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Return the key with the last part removed.

        Raises:
            PathNotFoundError: If this is the root key, which has no parent.
        """
        if not self.parts:
            msg = "The root node key has no parent"
            raise PathNotFoundError(msg)
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return a key made of ``nk``'s parts followed by this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return ``ClassName('<path>')`` for debugging."""
        return f"{self.__class__.__name__}({str(self)!r})"

    def __eq__(self, other: object) -> bool:
        """Compare by parts; ``None`` is unequal, other types are not handled."""
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Lazily created root key, cached on the class by ``root()``.
    _ROOT: ClassVar["NodeKey | None"] = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Return the cached root key (no parts), creating it on first use."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self) -> bool:
        """Return True when this key has no parts."""
        return not self.parts

    @staticmethod
    def common_parent(nodekey1: Name, nodekey2: Name) -> "NodeKey":
        """Return the key formed by the longest common leading parts of two names.

        Args:
            nodekey1: First name, converted with ``to_nodekey``.
            nodekey2: Second name, converted with ``to_nodekey``.

        Returns:
            A key holding the shared prefix (the root key when nothing is shared).
        """
        nk1 = to_nodekey(nodekey1)
        nk2 = to_nodekey(nodekey2)
        shared: list[Any] = []
        for p1, p2 in zip(nk1.parts, nk2.parts, strict=False):
            if p1 != p2:
                break
            shared.append(p1)
        return NodeKey(tuple(shared))

    def ancestors(self) -> list["NodeKey"]:
        """Return this key followed by each successive parent, ending with the root.

        The list starts with this key itself and always ends with the root key,
        so the root key yields a single-element list.
        """
        chain = [self]
        while not chain[-1].is_root:
            chain.append(chain[-1].parent)
        return chain

160 

161 

def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Turn one name, or an iterable of names, into a list of NodeKey objects.

    The input is normalised with ``as_iterable`` and every element is converted
    with ``to_nodekey``.
    """
    return list(map(to_nodekey, as_iterable(names)))

165 

166 

def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Collect the ``name`` of every given node key, preserving order."""
    names: list[Name] = []
    for node_key in node_keys:
        names.append(node_key.name)
    return names

170 

171 

# One unquoted path part: captures a (possibly empty) run of non-slash
# characters and also consumes the single "/" that follows it, if present.
PART = re.compile(r"([^/]*)/?")

173 

174 

def _parse_nodekey(path_str: str, end: int) -> NodeKey:
    """Parse a path string into a NodeKey starting from the given position.

    Leading ``/`` characters are skipped. Each following part is either a JSON
    string literal (when it begins with a double quote) or a run of characters
    up to the next ``/``.

    Args:
        path_str: The path text to parse.
        end: Index in ``path_str`` at which parsing starts.

    Returns:
        A NodeKey holding the parsed parts, in order.

    Raises:
        ValueError: If a quoted part is not a valid JSON string (raised by the
            JSON scanner as ``json.JSONDecodeError``), or if a quoted part is
            followed by anything other than ``/`` or the end of the string.
    """
    parts: list[str] = []

    # Leading slashes contribute no parts of their own.
    while path_str[end : end + 1] == "/":
        end += 1

    while True:
        nextchar = path_str[end : end + 1]
        if nextchar == "":
            break
        if nextchar == '"':
            # scanstring expects the index just past the opening quote and
            # returns the decoded text plus the index after the closing quote.
            part, end = json.decoder.scanstring(path_str, end + 1)  # type: ignore[attr-defined]
            parts.append(part)
            nextchar = path_str[end : end + 1]
            if nextchar == "/":
                end += 1
            elif nextchar != "":
                # Raised explicitly rather than asserted: an assert would be
                # stripped under ``-O`` and the stray character silently skipped.
                msg = f"Expected '/' or end of path after quoted part at position {end} in {path_str!r}"
                raise ValueError(msg)
        else:
            chunk = PART.match(path_str, end)
            # PART can match the empty string, so a match is always found.
            assert chunk is not None  # noqa: S101
            end = chunk.end()
            parts.append(chunk.group(1))

    assert end == len(path_str)  # noqa: S101

    return NodeKey(tuple(parts))

203 

204 

def parse_nodekey(path_str: str) -> NodeKey:
    """Build a NodeKey from its string form, parsing from the first character."""
    start = 0
    return _parse_nodekey(path_str, start)

208 

209 

def to_nodekey(name: Name) -> NodeKey:
    """Convert a name to a NodeKey object.

    Args:
        name: A path string, an existing NodeKey, or any other object.

    Returns:
        The parsed key for a string; the same object for a NodeKey; otherwise
        a one-part key whose only part is ``name``.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    # Every remaining value is an ``object``, so no further type check (or
    # error branch) is needed: it simply becomes the single part of a new key.
    return NodeKey((name,))

221 

222 

def nodekey_join(*names: Name) -> NodeKey:
    """Combine the given names, in order, into one NodeKey starting from the root."""
    root_key = NodeKey.root()
    return root_key.join(*names)

226 

227 

def _match_pattern_recursive(pattern: NodeKey, target: NodeKey, p_idx: int, t_idx: int) -> bool:
    """Recursively match pattern parts against target parts.

    A pattern part equal to ``"**"`` matches zero or more target parts, a part
    equal to ``"*"`` matches exactly one target part, and any other part must
    compare equal to the target part at the same position.

    Results are memoised per ``(pattern index, target index)`` pair so that
    patterns with several ``"**"`` parts do not revisit the same sub-problem.

    Args:
        pattern: The pattern NodeKey to match against
        target: The target NodeKey to match
        p_idx: Current index in pattern parts
        t_idx: Current index in target parts

    Returns:
        bool: True if pattern matches target, False otherwise
    """
    pattern_parts = pattern.parts
    target_parts = target.parts
    n_pattern = len(pattern_parts)
    n_target = len(target_parts)
    memo: dict[tuple[int, int], bool] = {}

    def matches(p: int, t: int) -> bool:
        if p == n_pattern:
            # Pattern exhausted: a match only if the target is exhausted too.
            return t == n_target
        if t == n_target:
            # Target exhausted: the leftover pattern must be able to match nothing.
            return all(part == "**" for part in pattern_parts[p:])

        state = (p, t)
        cached = memo.get(state)
        if cached is not None:
            return cached

        part = pattern_parts[p]
        if part == "**":
            # Either "**" matches nothing, or it absorbs one more target part.
            outcome = matches(p + 1, t) or matches(p, t + 1)
        elif part == "*":
            outcome = matches(p + 1, t + 1)
        elif part == target_parts[t]:
            outcome = matches(p + 1, t + 1)
        else:
            outcome = False

        memo[state] = outcome
        return outcome

    return matches(p_idx, t_idx)

257 

258 

def is_pattern(nodekey: NodeKey) -> bool:
    """Check if a node key contains wildcard patterns.

    Args:
        nodekey: The key whose parts are inspected.

    Returns:
        bool: True if any string part contains a ``*`` character (which also
        covers ``**``). Non-string parts are never treated as wildcards.
    """
    # The isinstance guard keeps non-string parts (e.g. ints) from raising
    # TypeError on the ``in`` test.
    return any(isinstance(part, str) and "*" in part for part in nodekey.parts)

262 

263 

def match_pattern(pattern: NodeKey, target: NodeKey) -> bool:
    """Report whether ``target`` is matched by ``pattern``.

    Wildcard parts recognised in the pattern:
        * - stands for exactly one part
        ** - stands for zero or more parts

    Args:
        pattern: The NodeKey acting as the pattern
        target: The NodeKey being tested

    Returns:
        bool: True when the whole pattern matches the whole target
    """
    # Matching always starts at the first part of both keys.
    return _match_pattern_recursive(pattern, target, p_idx=0, t_idx=0)