Coverage for src / loman / nodekey.py: 96%

164 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-02 23:34 +0000

1"""Node key implementation for computation graph navigation.""" 

2 

3import json 

4import re 

5from collections.abc import Iterable 

6from dataclasses import dataclass 

7from typing import Optional, Union 

8 

9from loman.util import as_iterable 

10 

# A node name accepted by this module's helpers: a "/"-separated path string,
# an existing NodeKey, or any other object (treated as a single opaque part).
Name = Union[str, "NodeKey", object]
# A list of node names.
Names = list[Name]

13 

14 

class PathNotFoundError(Exception):
    """Raised when a node path cannot be resolved.

    For example, asking the root NodeKey for its parent raises this error,
    because the root has no parent.
    """

19 

20 

def quote_part(part: object) -> str:
    """Render one node-key part as text for a "/"-joined path.

    Non-string parts are rendered with ``str()``. String parts are returned
    unchanged unless plain rendering would be ambiguous when the path is
    parsed back, in which case they are JSON-quoted. That happens when:

    * the part contains "/" (it would otherwise be split into two parts),
    * the part starts with '"' (the path parser reads a leading quote as the
      start of a JSON string literal),
    * the part is empty (it would otherwise disappear, e.g. ``a/`` parses
      back as just ``a``).

    Args:
        part: One element of ``NodeKey.parts``.

    Returns:
        The textual representation of ``part``.
    """
    if isinstance(part, str):
        if part == "" or "/" in part or part.startswith('"'):
            return json.dumps(part)
        return part
    return str(part)

29 

30 

@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key identifying a node in the computation graph hierarchy.

    A key is a tuple of path ``parts``; the empty tuple is the root key. Parts
    are usually strings but may be arbitrary objects. Instances are frozen,
    and the dataclass-generated hash is based on ``parts``, so keys can be
    used as dict keys and set members.
    """

    # Path components from the top of the hierarchy downwards; () is the root.
    parts: tuple

    def __str__(self):
        """Return the path form: each part rendered by quote_part, joined with "/"."""
        return "/".join([quote_part(part) for part in self.parts])

    @property
    def name(self) -> "Name":
        """Return the user-facing name of this node.

        * root key -> ``""``
        * a single non-string part -> that object itself
        * only string parts (one or more) -> the parts rendered with
          quote_part and joined with "/", i.e. the same text as ``str(self)``
        * several parts of mixed types -> the key itself
        """
        n_parts = len(self.parts)
        if n_parts == 0:
            return ""
        if n_parts == 1:
            (part,) = self.parts
            # Quote single string parts as well, so that a part such as "a/b"
            # is not re-split into two parts when the name is parsed again.
            return quote_part(part) if isinstance(part, str) else part
        if all(isinstance(part, str) for part in self.parts):
            return "/".join(quote_part(part) for part in self.parts)
        return self

    @property
    def label(self) -> str:
        """Return a display label: ``str()`` of the last part, or "" for the root."""
        if len(self.parts) == 0:
            return ""
        return str(self.parts[-1])

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Return this key expressed relative to ``root``.

        Args:
            root: Prefix to remove, converted with ``to_nodekey``. ``None``
                means "no root" and returns ``self`` unchanged.

        Returns:
            A key with the leading ``root`` parts removed when this key is a
            strict descendant of ``root``; otherwise ``None`` (this includes
            the case where this key equals ``root``).
        """
        if root is None:
            return self
        root = to_nodekey(root)
        if self.is_descendent_of(root):
            return NodeKey(self.parts[len(root.parts) :])
        return None

    def join(self, *others: "Name") -> "NodeKey":
        """Append each name in ``others`` (converted with to_nodekey) to this key.

        ``None`` entries in ``others`` are skipped.
        """
        result = self
        for other in others:
            if other is None:
                continue
            result = result.join_parts(*to_nodekey(other).parts)
        return result

    def join_parts(self, *parts) -> "NodeKey":
        """Return a key with the raw ``parts`` appended; ``self`` if none are given."""
        if len(parts) == 0:
            return self
        return NodeKey(self.parts + tuple(parts))

    def __truediv__(self, other: "Name") -> "NodeKey":
        """Support ``key / name`` as shorthand for ``key.join(name)``."""
        return self.join(other)

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Return True if ``other.parts`` is a strict prefix of this key's parts.

        A key is not considered a descendant of itself.
        """
        n_other_parts = len(other.parts)
        return len(self.parts) > n_other_parts and self.parts[:n_other_parts] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Return the key with the last part removed.

        Raises:
            PathNotFoundError: if this is the root key, which has no parent.
        """
        if len(self.parts) == 0:
            raise PathNotFoundError()
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return a key made of ``nk``'s parts followed by this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return ``NodeKey('<path>')`` using the path form produced by ``__str__``."""
        return f"{self.__class__.__name__}({str(self)!r})"

    def __eq__(self, other) -> bool:
        """Compare keys by ``parts``.

        ``None`` compares unequal; other non-NodeKey types return
        NotImplemented so Python can try the reflected comparison.
        """
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Lazily created shared root instance. It is unannotated, so it is a plain
    # class attribute rather than a dataclass field.
    _ROOT = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Return the shared root key (empty parts), creating it on first use."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self) -> bool:
        """Return True if this key has no parts."""
        return len(self.parts) == 0

    @staticmethod
    def common_parent(nodekey1: "Name", nodekey2: "Name") -> "NodeKey":
        """Return the longest common leading path of two names (the root if none)."""
        nodekey1 = to_nodekey(nodekey1)
        nodekey2 = to_nodekey(nodekey2)
        parts = []
        for p1, p2 in zip(nodekey1.parts, nodekey2.parts):
            if p1 != p2:
                break
            parts.append(p1)
        return NodeKey(tuple(parts))

    def ancestors(self) -> list["NodeKey"]:
        """Return this key followed by each of its ancestors, nearest first.

        The list starts with ``self`` and ends with the root key, e.g. the key
        ``a/b`` gives ``[a/b, a, <root>]``. For the root key it is ``[root]``.
        """
        result = [self]
        current = self
        while not current.is_root:
            current = current.parent
            result.append(current)
        return result

159 

160 

def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Normalise a name, or a collection of names, into a list of NodeKeys.

    ``names`` is first passed through ``as_iterable`` (from loman.util), and
    each element it yields is converted with ``to_nodekey``, preserving order.
    """
    return list(map(to_nodekey, as_iterable(names)))

164 

165 

def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Return the ``name`` of every key in ``node_keys``, in iteration order."""
    names: list[Name] = []
    for key in node_keys:
        names.append(key.name)
    return names

169 

170 

# Matches one unquoted path part: a (possibly empty) run of non-"/" characters,
# captured as group 1, followed by an optional "/" separator that is consumed.
PART = re.compile(r"([^/]*)/?")

172 

173 

def _parse_nodekey(path_str: str, end: int) -> NodeKey:
    """Parse ``path_str``, starting at offset ``end``, into a NodeKey.

    Leading "/" characters are skipped, so "/a/b" and "a/b" parse the same.
    Parts are separated by "/". A part that starts with '"' is read as a JSON
    string literal (so it may itself contain "/"); any other part runs up to
    the next "/" or the end of the string. One trailing "/" is consumed
    without producing an extra part.

    Args:
        path_str: Path text, e.g. as produced by ``str(NodeKey)``.
        end: Offset at which parsing starts.

    Returns:
        The parsed NodeKey.

    Raises:
        ValueError: if a quoted part is followed by anything other than "/"
            or the end of the string. Malformed quoted strings surface as the
            errors raised by ``json.decoder.scanstring``.
    """
    parts = []

    while path_str[end : end + 1] == "/":
        end += 1
    while True:
        nextchar = path_str[end : end + 1]
        if nextchar == "":
            break
        if nextchar == '"':
            part, end = json.decoder.scanstring(path_str, end + 1)
            parts.append(part)
            nextchar = path_str[end : end + 1]
            # Validate explicitly rather than with `assert`: asserts are
            # stripped under `python -O`, which would silently misparse input.
            if nextchar not in ("", "/"):
                raise ValueError(
                    f"Invalid node path {path_str!r}: expected '/' or end of path "
                    f"after quoted part at position {end}"
                )
            if nextchar == "/":
                end += 1
        else:
            chunk = PART.match(path_str, end)
            end = chunk.end()
            parts.append(chunk.group(1))

    # Internal invariant: the loop only exits once the whole string is consumed.
    assert end == len(path_str)

    return NodeKey(tuple(parts))

200 

201 

def parse_nodekey(path_str: str) -> NodeKey:
    """Build a NodeKey from its "/"-separated string form, parsing from offset 0."""
    return _parse_nodekey(path_str, end=0)

205 

206 

def to_nodekey(name: Name) -> NodeKey:
    """Coerce a name into a NodeKey.

    Args:
        name: One of:

            * a path string, which is parsed with ``parse_nodekey``;
            * an existing NodeKey, which is returned as-is;
            * any other object (including ``None``), which becomes a
              single-part key.

    Returns:
        The corresponding NodeKey.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    # Every remaining value is some object, so wrap it as one opaque part.
    return NodeKey((name,))

217 

218 

def nodekey_join(*names: Name) -> NodeKey:
    """Combine ``names`` into a single NodeKey, starting from the root key.

    Each name is converted and appended in order by ``NodeKey.join``, which
    skips ``None`` entries.
    """
    base = NodeKey.root()
    return base.join(*names)

222 

223 

def _match_pattern_recursive(pattern: "NodeKey", target: "NodeKey", p_idx: int, t_idx: int) -> bool:
    """Match pattern parts from ``p_idx`` on against target parts from ``t_idx`` on.

    The name is kept for compatibility with existing callers. The matching is
    done iteratively with dynamic programming over suffixes. This takes
    O(len(pattern.parts) * len(target.parts)) time and avoids both the
    exponential blow-up of naive backtracking when a pattern has several "**"
    parts and Python's recursion limit on long keys.

    Args:
        pattern: The pattern NodeKey to match against.
        target: The target NodeKey to match.
        p_idx: Index in ``pattern.parts`` at which matching starts.
        t_idx: Index in ``target.parts`` at which matching starts.

    Returns:
        bool: True if ``pattern.parts[p_idx:]`` matches ``target.parts[t_idx:]``.
    """
    p_parts = pattern.parts
    t_parts = target.parts
    n_p = len(p_parts)
    n_t = len(t_parts)

    # next_row[j] answers: does p_parts[i + 1:] match t_parts[j:]?
    # Start from the empty pattern suffix, which only matches the empty target suffix.
    next_row = [False] * (n_t + 1)
    next_row[n_t] = True

    for i in range(n_p - 1, p_idx - 1, -1):
        part = p_parts[i]
        row = [False] * (n_t + 1)
        if part == "**":
            # "**" either matches zero parts (move past it) or absorbs one
            # more target part (stay on it).
            row[n_t] = next_row[n_t]
            for j in range(n_t - 1, t_idx - 1, -1):
                row[j] = next_row[j] or row[j + 1]
        elif part == "*":
            # "*" absorbs exactly one target part.
            for j in range(n_t - 1, t_idx - 1, -1):
                row[j] = next_row[j + 1]
        else:
            # A literal part must compare equal to the target part.
            for j in range(n_t - 1, t_idx - 1, -1):
                row[j] = next_row[j + 1] and bool(part == t_parts[j])
        next_row = row

    return next_row[t_idx]


def is_pattern(nodekey: "NodeKey") -> bool:
    """Return True if any string part of ``nodekey`` contains "*".

    Only string parts are inspected. Non-string parts (ints, tuples, arbitrary
    objects) are never wildcards, and applying ``in`` to them could raise
    TypeError. Note that ``match_pattern`` treats only parts equal to exactly
    "*" or "**" as wildcards.
    """
    return any(isinstance(part, str) and "*" in part for part in nodekey.parts)


def match_pattern(pattern: "NodeKey", target: "NodeKey") -> bool:
    """Match a pattern against a target NodeKey.

    Supports wildcards:
        * - matches exactly one part
        ** - matches zero or more parts

    Every other pattern part must compare equal (``==``) to the corresponding
    target part, and the whole target must be consumed.

    Args:
        pattern: The pattern to match against
        target: The target to match

    Returns:
        bool: True if pattern matches target, False otherwise
    """
    return _match_pattern_recursive(pattern, target, 0, 0)