Coverage for src/loman/nodekey.py: 96%

162 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 05:36 +0000

1"""Node key implementation for computation graph navigation.""" 

2 

3import json 

4import re 

5from collections.abc import Iterable 

6from dataclasses import dataclass 

7from typing import Optional, Union 

8 

9from loman.util import as_iterable 

10 

# A single node identifier: a "/"-separated path string, an existing NodeKey,
# or any other object (to_nodekey wraps such objects as a one-part key).
Name = Union[str, "NodeKey", object]
# A list of node identifiers.
Names = list[Name]

13 

14 

class PathNotFoundError(Exception):
    """Signal that a requested node path does not exist.

    For example, asking the root NodeKey for its parent raises this error,
    because the root has no parent path.
    """

19 

20 

def quote_part(part: object) -> str:
    """Quote a node key part for safe representation in paths.

    String parts are emitted verbatim unless ``parse_nodekey`` would misread
    them. In that case they are JSON-encoded (double-quoted). That happens
    when the part:

    * contains ``/``, because it would otherwise be split into several parts;
    * starts with ``"``, because the parser would decode it as a quoted part
      and return different text;
    * is empty, because an empty first or last part would be dropped
      (leading and trailing ``/`` are not read as empty parts).

    Non-string parts are rendered with ``str()``. Parsing the result yields
    string parts, so such parts do not round-trip.

    Args:
        part: A single element of ``NodeKey.parts``.

    Returns:
        The text used for ``part`` in a ``/``-joined path string.
    """
    if isinstance(part, str):
        if part == "" or "/" in part or part.startswith('"'):
            return json.dumps(part)
        return part
    return str(part)

29 

30 

@dataclass(frozen=True, repr=False)
class NodeKey:
    """Immutable key for identifying nodes in the computation graph hierarchy.

    A key is a tuple of parts, ordered from the outermost level down to the
    node itself. The empty tuple is the root key. Keys compare equal when
    their ``parts`` are equal. Because the dataclass is frozen, it also gets
    a generated ``__hash__``, so keys are hashable provided every part is.
    """

    # Path components, outermost first. Parsed path strings yield string
    # parts; to_nodekey wraps any other object as a single part.
    parts: tuple

    def __str__(self):
        """Return the path form: each part passed through quote_part, joined by "/"."""
        return "/".join([quote_part(part) for part in self.parts])

    @property
    def name(self) -> "Name":
        """Get the name of this node.

        Returns:
            One of the following:

            * ``""`` for the root key;
            * the single part itself for a one-part key;
            * the ``/``-joined, quote_part-quoted path when every part is a
              string;
            * otherwise the key itself, since a path string could not
              faithfully carry the non-string parts.
        """
        if len(self.parts) == 0:
            return ""
        elif len(self.parts) == 1:
            return self.parts[0]
        elif all(isinstance(part, str) for part in self.parts):
            return "/".join(quote_part(part) for part in self.parts)
        else:
            return self

    @property
    def label(self) -> str:
        """Get a display label: ``str()`` of the last part, or ``""`` for the root."""
        if len(self.parts) == 0:
            return ""
        return str(self.parts[-1])

    def drop_root(self, root: Optional["Name"]) -> Optional["NodeKey"]:
        """Remove a root prefix from this node key if it matches.

        Args:
            root: Prefix to strip, converted with to_nodekey. ``None`` means
                "no prefix", and this key is returned unchanged.

        Returns:
            The key relative to ``root``. Returns ``None`` when this key is
            not a strict descendant of ``root``, which includes the case
            where it equals ``root``.
        """
        if root is None:
            return self
        root = to_nodekey(root)
        n_root_parts = len(root.parts)
        if self.is_descendent_of(root):
            parts = self.parts[n_root_parts:]
            return NodeKey(parts)
        else:
            return None

    def join(self, *others: "Name") -> "NodeKey":
        """Append each name in ``others`` to this key, in order.

        Each non-``None`` name is converted with to_nodekey, so a path string
        contributes all of its parts. ``None`` entries are skipped.
        """
        result = self
        for other in others:
            if other is None:
                continue
            other = to_nodekey(other)
            result = result.join_parts(*other.parts)
        return result

    def join_parts(self, *parts) -> "NodeKey":
        """Return a key with the raw ``parts`` appended, without parsing them.

        Returns ``self`` unchanged when no parts are given.
        """
        if len(parts) == 0:
            return self
        return NodeKey(self.parts + tuple(parts))

    def is_descendent_of(self, other: "NodeKey") -> bool:
        """Check whether this key lies strictly below ``other``.

        True when ``other.parts`` is a proper prefix of ``self.parts``. A key
        is not its own descendant.
        """
        n_self_parts = len(self.parts)
        n_other_parts = len(other.parts)
        return n_self_parts > n_other_parts and self.parts[:n_other_parts] == other.parts

    @property
    def parent(self) -> "NodeKey":
        """Get the parent key, i.e. all parts but the last.

        Raises:
            PathNotFoundError: If this is the root key, which has no parent.
        """
        if len(self.parts) == 0:
            raise PathNotFoundError()
        return NodeKey(self.parts[:-1])

    def prepend(self, nk: "NodeKey") -> "NodeKey":
        """Return a key made of ``nk``'s parts followed by this key's parts."""
        return nk.join_parts(*self.parts)

    def __repr__(self) -> str:
        """Return the class name wrapped around the repr of the path string, e.g. ``NodeKey('a/b')``."""
        path_str = str(self)
        quoted_path_str = repr(path_str)
        return f"{self.__class__.__name__}({quoted_path_str})"

    def __eq__(self, other) -> bool:
        """Compare by ``parts``. ``None`` is never equal; non-NodeKey values defer to the other operand."""
        if other is None:
            return False
        if not isinstance(other, NodeKey):
            return NotImplemented
        return self.parts == other.parts

    # Root key cached by root() on first use. It is unannotated, so the
    # dataclass machinery does not treat it as a field.
    _ROOT = None

    @classmethod
    def root(cls) -> "NodeKey":
        """Get the root key (empty ``parts``), creating and caching it on first call."""
        if cls._ROOT is None:
            cls._ROOT = cls(())
        return cls._ROOT

    @property
    def is_root(self):
        """Check if this is the root key (no parts)."""
        return len(self.parts) == 0

    @staticmethod
    def common_parent(nodekey1: "Name", nodekey2: "Name"):
        """Find the longest common prefix of two names.

        Both arguments are converted with to_nodekey. The result is the root
        key when the first parts already differ. It equals the shorter input
        when that input is a prefix of the other.
        """
        nodekey1 = to_nodekey(nodekey1)
        nodekey2 = to_nodekey(nodekey2)
        parts = []
        for p1, p2 in zip(nodekey1.parts, nodekey2.parts):
            if p1 != p2:
                break
            parts.append(p1)
        return NodeKey(tuple(parts))

    def ancestors(self) -> list["NodeKey"]:
        """Get this key followed by all of its ancestors.

        The list starts with this key itself and moves up one level at a
        time, ending with the root key. For the root key the list is just
        ``[root]``.
        """
        result = []
        x = self
        while True:
            result.append(x)
            if x.is_root:
                break
            x = x.parent
        return result

155 

156 

def names_to_node_keys(names: Name | Names) -> list[NodeKey]:
    """Turn a name, or a collection of names, into a list of NodeKeys.

    ``names`` is normalised with ``as_iterable``. Presumably this wraps a
    lone name as a one-item sequence; confirm against loman.util. Each
    resulting element is converted with to_nodekey, preserving order.
    """
    keys: list[NodeKey] = []
    for item in as_iterable(names):
        keys.append(to_nodekey(item))
    return keys

160 

161 

def node_keys_to_names(node_keys: Iterable[NodeKey]) -> list[Name]:
    """Map each NodeKey to its ``name`` property, preserving order."""
    names: list[Name] = []
    for key in node_keys:
        names.append(key.name)
    return names

165 

166 

# Matches one unquoted path part at a given position. It captures the run of
# non-"/" characters (possibly empty) and consumes, without capturing, at
# most one following "/" separator. Every piece is optional, so it always
# matches.
PART = re.compile(r"([^/]*)/?")

168 

169 

def _parse_nodekey(path_str: str, end: int) -> NodeKey:
    """Parse the path string ``path_str`` starting at index ``end``.

    Leading ``/`` characters are skipped. Parts are then read left to right:

    * A part beginning with ``"`` is decoded as a JSON string literal. This
      is the form quote_part emits for parts containing ``/``. It must be
      followed by ``/`` or the end of the input.
    * Any other part runs up to the next ``/``. Consecutive separators
      produce empty parts.

    Args:
        path_str: The text to parse.
        end: Index at which parsing starts.

    Returns:
        A NodeKey holding the parsed parts.

    Raises:
        ValueError: If a quoted part is malformed (for example unterminated),
            or is followed by something other than ``/`` or the end of input.
    """
    parts = []
    parts_append = parts.append
    length = len(path_str)

    # Skip leading separators so "/a/b" parses like "a/b".
    while path_str[end : end + 1] == "/":
        end = end + 1
    while end < length:
        if path_str[end] == '"':
            # scanstring starts just past the opening quote. It returns the
            # decoded text and the index just past the closing quote, and
            # raises json.JSONDecodeError (a ValueError) on malformed input.
            part, end = json.decoder.scanstring(path_str, end + 1)
            parts_append(part)
            nextchar = path_str[end : end + 1]
            # Explicit check instead of `assert`, which is stripped under -O
            # and would let malformed input be silently mis-parsed.
            if nextchar not in ("", "/"):
                raise ValueError(
                    f"Expected '/' or end of path after quoted part at index {end} in {path_str!r}"
                )
            if nextchar == "/":
                end = end + 1
        else:
            # PART always matches; it consumes the part and its trailing "/".
            chunk = PART.match(path_str, end)
            end = chunk.end()
            (part,) = chunk.groups()
            parts_append(part)

    return NodeKey(tuple(parts))

196 

197 

def parse_nodekey(path_str: str) -> NodeKey:
    """Build a NodeKey from its ``/``-separated string form.

    The whole string is parsed, beginning at index 0, by ``_parse_nodekey``.
    """
    start = 0
    return _parse_nodekey(path_str, end=start)

201 

202 

def to_nodekey(name: Name) -> NodeKey:
    """Convert a name to a NodeKey object.

    Args:
        name: One of the following:

            * a path string, parsed with parse_nodekey (so ``"a/b"`` gives a
              two-part key);
            * an existing NodeKey, returned as-is;
            * any other object, which becomes a single-part key.

    Returns:
        The corresponding NodeKey.
    """
    if isinstance(name, str):
        return parse_nodekey(name)
    if isinstance(name, NodeKey):
        return name
    # Every Python value is an instance of ``object``, so no further type
    # check is needed. Anything else, including None, becomes one opaque part.
    return NodeKey((name,))

213 

214 

def nodekey_join(*names: Name) -> NodeKey:
    """Combine ``names`` into one NodeKey, starting from the root key.

    Each name is appended via ``NodeKey.join``, so ``None`` entries are
    skipped and path strings contribute all of their parts.
    """
    base = NodeKey.root()
    return base.join(*names)

218 

219 

def _match_pattern_recursive(pattern: "NodeKey", target: "NodeKey", p_idx: int, t_idx: int) -> bool:
    """Recursively match pattern parts against target parts.

    A pattern part equal to ``"*"`` matches exactly one target part. A part
    equal to ``"**"`` matches zero or more target parts. Any other part must
    equal the target part at the same position.

    Results are memoised per (pattern index, target index) pair. Without
    this, several ``"**"`` parts cause exponential backtracking; with it,
    the work is bounded by len(pattern.parts) * len(target.parts).

    Args:
        pattern: The pattern NodeKey to match against
        target: The target NodeKey to match
        p_idx: Index in pattern parts at which matching starts
        t_idx: Index in target parts at which matching starts

    Returns:
        bool: True if pattern.parts[p_idx:] matches target.parts[t_idx:], False otherwise
    """
    pat = pattern.parts
    tgt = target.parts
    n_pat = len(pat)
    n_tgt = len(tgt)
    memo: dict[tuple[int, int], bool] = {}

    def match(i: int, j: int) -> bool:
        # Does pat[i:] match tgt[j:]?
        key = (i, j)
        cached = memo.get(key)
        if cached is not None:
            return cached
        if i == n_pat:
            # Pattern exhausted: succeed only if the target is exhausted too.
            result = j == n_tgt
        elif j == n_tgt:
            # Target exhausted: only trailing "**" parts (zero-width) can remain.
            result = all(p == "**" for p in pat[i:])
        elif pat[i] == "**":
            # Either "**" matches nothing here, or it absorbs one more target part.
            result = match(i + 1, j) or match(i, j + 1)
        elif pat[i] == "*":
            result = match(i + 1, j + 1)
        elif pat[i] == tgt[j]:
            result = match(i + 1, j + 1)
        else:
            result = False
        memo[key] = result
        return result

    return match(p_idx, t_idx)

249 

250 

def is_pattern(nodekey: "NodeKey") -> bool:
    """Check if a node key contains wildcard patterns.

    A key counts as a pattern when any of its string parts contains ``*``.
    Containing ``*`` also covers ``**``. Non-string parts are never treated
    as wildcards. They are skipped rather than tested with ``in``, because
    ``"*" in part`` raises TypeError for parts such as integers.
    """
    return any(isinstance(part, str) and "*" in part for part in nodekey.parts)

254 

255 

def match_pattern(pattern: NodeKey, target: NodeKey) -> bool:
    """Report whether ``target`` matches the wildcard ``pattern``.

    Parts are compared position by position. Two whole-part values in the
    pattern are special:
        * - matches exactly one part
        ** - matches zero or more parts
    Every other pattern part must equal the corresponding target part.

    Args:
        pattern: The pattern to match against
        target: The target to match

    Returns:
        bool: True if pattern matches target, False otherwise
    """
    return _match_pattern_recursive(pattern, target, p_idx=0, t_idx=0)