Coverage for src/loman/computeengine.py: 89%

991 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 05:36 +0000

1"""Core computation engine for dependency-aware calculation graphs.""" 

2 

3import inspect 

4import logging 

5import traceback 

6import types 

7import warnings 

8from collections import defaultdict 

9from collections.abc import Callable, Iterable, Mapping 

10from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait 

11from dataclasses import dataclass, field 

12from datetime import UTC, datetime 

13from enum import Enum 

14from typing import Any, Type, Union # noqa: UP035 

15 

16import decorator 

17import dill 

18import networkx as nx 

19import pandas as pd 

20 

21from .compat import get_signature 

22from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags 

23from .exception import ( 

24 CannotInsertToPlaceholderNodeException, 

25 ComputationError, 

26 LoopDetectedException, 

27 MapException, 

28 NodeAlreadyExistsException, 

29 NonExistentNodeException, 

30) 

31from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey 

32from .util import AttributeView, apply1, apply_n, as_iterable, value_eq 

33from .visualization import GraphView, NodeFormatter 

34 

# Module-level logger used by the computation engine for debug tracing.
LOG = logging.getLogger("loman.computeengine")

36 

37 

@dataclass
class Error:
    """Record of a failure raised while computing or converting a node value."""

    # The exception instance that was raised.
    exception: Exception
    # Traceback information captured with the exception.
    # NOTE(review): callers visible in this module pass the string returned by
    # ``traceback.format_exc()`` here - confirm whether the annotation should be ``str``.
    traceback: inspect.Traceback

44 

45 

@dataclass
class NodeData:
    """Pair of a node's state and the value it currently holds."""

    # State of the node (a member of ``States``).
    state: States
    # Value held by the node.
    value: object

52 

53 

@dataclass
class TimingData:
    """Start time, end time and duration recorded for a computation."""

    # When execution started.
    start: datetime
    # When execution finished.
    end: datetime
    # Elapsed time as a float (units are not established in this module section).
    duration: float

61 

62 

63class _ParameterType(Enum): 

64 ARG = 1 

65 KWD = 2 

66 

67 

68@dataclass 

69class _ParameterItem: 

70 type: object 

71 name: int | str 

72 value: object 

73 

74 

75def _node(func, *args, **kws): 

76 return func(*args, **kws) 

77 

78 

def node(comp, name=None, *args, **kw):
    """Return a decorator that registers the decorated function as a node of ``comp``.

    :param comp: Computation to add the node to (its ``add_node`` method is called).
    :param name: Node name; when None the function's ``__name__`` is used.
    :param args: Extra positional arguments forwarded to ``comp.add_node``.
    :param kw: Extra keyword arguments forwarded to ``comp.add_node``.
    """

    def inner(f):
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        # Hand back a wrapper produced by ``decorator.decorate`` around the function.
        return decorator.decorate(f, _node)

    return inner

90 

91 

@dataclass
class ConstantValue:
    """Wrapper marking a value as a constant rather than a node reference."""

    # The wrapped constant.
    value: object


# Short alias for ConstantValue.
C = ConstantValue

100 

101 

class Node:
    """Base class for declarative node definitions that can add themselves to a computation."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add this node to ``comp`` under ``name``.

        The base implementation always raises; subclasses provide the behaviour.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

108 

109 

@dataclass
class InputNode(Node):
    """Declarative definition of an input node."""

    # Positional arguments captured at construction (not used by add_to_comp).
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Keyword arguments forwarded to ``Computation.add_node``.
    kwds: dict = field(default_factory=dict)

    def __init__(self, *args, **kwds):
        """Capture the positional and keyword arguments for later use."""
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add a node called ``name`` to ``comp``, passing the stored keyword arguments."""
        node_options = self.kwds
        comp.add_node(name, **node_options)


# Lower-case alias for InputNode.
input_node = InputNode

128 

129 

@dataclass
class CalcNode(Node):
    """Declarative definition of a calculation node."""

    # Function that calculates the node.
    f: Callable
    # Keyword arguments for ``Computation.add_node``; may also carry ``ignore_self``.
    kwds: dict = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the calculation node to ``comp`` under ``name``.

        When ``ignore_self`` is requested (by argument or by an ``ignore_self`` entry
        in ``kwds``) and the function's first parameter is named ``self``, the
        function is bound to ``obj`` before being added.
        """
        node_kwds = self.kwds.copy()
        # "ignore_self" is an option for this method, not for add_node, so strip it.
        requested = node_kwds.pop("ignore_self", False)
        skip_self = ignore_self or requested
        func = self.f
        if skip_self:
            params = get_signature(self.f).kwd_params
            if len(params) > 0 and params[0] == "self":
                func = func.__get__(obj, obj.__class__)
        comp.add_node(name, func, **node_kwds)

149 

150 

def calc_node(f=None, **kwds):
    """Mark a function as a calculation node.

    Usable both bare (``@calc_node``) and with options (``@calc_node(...)``).
    The function is returned unchanged apart from a ``_loman_node_info``
    attribute holding a ``CalcNode`` built from the function and ``kwds``.
    """

    def wrap(func):
        func._loman_node_info = CalcNode(func, kwds)
        return func

    return wrap if f is None else wrap(f)

161 

162 

@dataclass
class Block(Node):
    """Declarative definition of a block (sub-computation) node."""

    # Either a Computation or a zero-argument callable producing the block to add.
    block: Union[Callable, "Computation"]
    # Extra positional arguments forwarded to ``Computation.add_block``.
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Extra keyword arguments forwarded to ``Computation.add_block``.
    kwds: dict = field(default_factory=dict)

    def __init__(self, block, *args, **kwds):
        """Store the block together with the extra arguments for ``add_block``."""
        self.block = block
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the block to ``comp`` under ``name``.

        A Computation is added as-is; any other callable is called with no
        arguments and its result is added.

        :raises TypeError: if the block is neither a Computation nor callable.
        """
        target = self.block
        if not isinstance(target, Computation):
            if not callable(target):
                raise TypeError(f"Block {self.block} must be callable or Computation")
            target = target()
        comp.add_block(name, target, *self.args, **self.kwds)


# Lower-case alias for Block.
block = Block

189 

190 

def populate_computation_from_class(comp, cls, obj, ignore_self=True):
    """Add every node declared on ``cls`` to ``comp``.

    A class member counts as a node declaration when it is a ``Node`` instance
    or carries a ``_loman_node_info`` attribute.

    :param comp: Computation to populate.
    :param cls: Class whose members are scanned with ``inspect.getmembers``.
    :param obj: Object passed through to each node's ``add_to_comp``.
    :param ignore_self: Passed through to each node's ``add_to_comp``.
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            declared = member
        else:
            declared = getattr(member, "_loman_node_info", None)
        if declared is None:
            continue
        declared.add_to_comp(comp, member_name, obj, ignore_self)

201 

202 

def computation_factory(maybe_cls=None, *, ignore_self=True) -> Type["Computation"]:  # noqa: UP006
    """Turn a class with node declarations into a function that builds a Computation.

    Usable both bare (``@computation_factory``) and with options
    (``@computation_factory(ignore_self=...)``). The returned function
    instantiates the class, creates a ``Computation`` from its own arguments,
    stores the instance as ``_definition_object`` and populates the computation
    from the class's node declarations.
    """

    def wrap(cls):
        def create_computation(*args, **kwargs):
            definition = cls()
            computation = Computation(*args, **kwargs)
            computation._definition_object = definition
            populate_computation_from_class(computation, cls, definition, ignore_self)
            return computation

        return create_computation

    return wrap if maybe_cls is None else wrap(maybe_cls)

220 

221 

222def _eval_node(name, f, args, kwds, raise_exceptions): 

223 """To make multiprocessing work, this function must be standalone so that pickle works.""" 

224 exc, tb = None, None 

225 start_dt = datetime.now(UTC) 

226 try: 

227 logging.debug("Running " + str(name)) 

228 value = f(*args, **kwds) 

229 logging.debug("Completed " + str(name)) 

230 except Exception as e: 

231 value = None 

232 exc = e 

233 tb = traceback.format_exc() 

234 if raise_exceptions: 

235 raise 

236 end_dt = datetime.now(UTC) 

237 return value, exc, tb, start_dt, end_dt 

238 

239 

240_MISSING_VALUE_SENTINEL = object() 

241 

242 

class NullObject:
    """Debugging stand-in that prints each access attempt and then rejects it."""

    def __getattr__(self, name):
        """Print the attribute name, then raise AttributeError."""
        print(f"__getattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __setattr__(self, name, value):
        """Print the attribute name, then raise AttributeError."""
        print(f"__setattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __delattr__(self, name):
        """Print the attribute name, then raise AttributeError."""
        print(f"__delattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __call__(self, *args, **kwargs):
        """Print the call arguments, then raise TypeError."""
        print(f"__call__: {args}, {kwargs}")
        message = "'NullObject' object is not callable"
        raise TypeError(message)

    def __getitem__(self, key):
        """Print the key, then raise KeyError."""
        print(f"__getitem__: {key}")
        message = f"'NullObject' object has no item with key '{key}'"
        raise KeyError(message)

    def __setitem__(self, key, value):
        """Print the key, then raise KeyError."""
        print(f"__setitem__: {key}")
        message = f"'NullObject' object cannot have items set with key '{key}'"
        raise KeyError(message)

    def __repr__(self):
        """Print the instance dictionary and return the fixed text ``<NullObject>``."""
        print(f"__repr__: {self.__dict__}")
        return "<NullObject>"

280 

281 

def identity_function(x):
    """Return ``x`` itself, untouched."""
    unchanged = x
    return unchanged

285 

286 

287class Computation: 

288 """A computation graph that manages dependencies and calculations. 

289 

290 The Computation class provides a framework for building and executing 

291 computation graphs where nodes represent data or calculations, and edges 

292 represent dependencies between them. 

293 """ 

294 

def __init__(self, *, default_executor=None, executor_map=None, metadata=None):
    """Initialize a new, empty Computation.

    :param default_executor: Executor stored as ``default_executor``.
    :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1)
    :param executor_map: Mapping stored as ``executor_map``; an empty dict when None.
    :param metadata: If given, stored as the metadata of the root node key.
    """
    self.default_executor = ThreadPoolExecutor(1) if default_executor is None else default_executor
    self.executor_map = {} if executor_map is None else executor_map
    self.dag = nx.DiGraph()
    self._metadata = {}
    if metadata is not None:
        self._metadata[NodeKey.root()] = metadata

    # Attribute name -> (single-node getter, multi-node getter) for each
    # convenience view rooted at the top of the node tree.
    view_specs = (
        ("v", self._value_one, self.value),
        ("s", self._state_one, self.state),
        ("i", self._get_inputs_one_names, self.get_inputs),
        ("o", self._get_outputs_one, self.get_outputs),
        ("t", self._tag_one, self.tags),
        ("style", self._style_one, self.styles),
        ("tim", self._get_timing_one, self.get_timing),
        ("x", self.compute_and_get_value, self.compute_and_get_value),
        ("src", self.print_source, self.print_source),
    )
    for attr_name, get_one, get_many in view_specs:
        setattr(self, attr_name, self.get_attribute_view_for_path(NodeKey.root(), get_one, get_many))

    # Reverse indexes: tag -> node keys, and state -> node keys.
    self._tag_map = defaultdict(set)
    self._state_map = {state: set() for state in States}

327 

def get_attribute_view_for_path(self, nodekey: NodeKey, get_one_func: Callable, get_many_func: Callable):
    """Create an attribute view for the part of the node tree under ``nodekey``.

    :param nodekey: Path that names passed to the view are resolved against.
    :param get_one_func: Called with the full node key when the resolved path is a node.
    :param get_many_func: Passed on unchanged to nested views created for
        intermediate paths; it is not called directly here.
    :return: An ``AttributeView`` built from three closures: one listing the
        children of ``nodekey``, one resolving a single name, and one resolving
        a single name or a list of names.
    """

    def node_func():
        return self.get_tree_list_children(nodekey)

    def get_one_func_for_path(name: Name):
        full_key = to_nodekey(name).prepend(nodekey)
        if self.has_node(full_key):
            return get_one_func(full_key)
        if self.tree_has_path(full_key):
            # Intermediate path: hand back a nested view so lookups can be chained.
            return self.get_attribute_view_for_path(full_key, get_one_func, get_many_func)
        raise KeyError(f"Path {full_key} does not exist")

    def get_many_func_for_path(name: Name | Names):
        if isinstance(name, list):
            return [get_one_func_for_path(n) for n in name]
        return get_one_func_for_path(name)

    return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)

351 

def _get_names_for_state(self, state: States):
    """Return the set of names of the nodes currently recorded under ``state``."""
    keys_in_state = self._state_map[state]
    return set(node_keys_to_names(keys_in_state))

354 

def _get_tags_for_state(self, tag: str):
    """Return the set of names of the nodes currently recorded under ``tag``."""
    keys_with_tag = self._tag_map[tag]
    return set(node_keys_to_names(keys_with_tag))

357 

def add_node(
    self,
    name: Name,
    func=None,
    *,
    args=None,
    kwds=None,
    value=_MISSING_VALUE_SENTINEL,
    converter=None,
    serialize=True,
    inspect=True,
    group=None,
    tags=None,
    style=None,
    executor=None,
    metadata=None,
):
    """Adds or updates a node in a computation.

    :param name: Name of the node to add. This may be any hashable object.
    :param func: Function to use to calculate the node if the node is a calculation node. By default, the input
        nodes to the function will be implied from the names of the function parameters. For example, a
        parameter called ``a`` would be taken from the node called ``a``. This can be modified with the
        ``kwds`` parameter.
    :type func: Function, default None
    :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally
        for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with
        three parameters, taken from the nodes 'a', 'b' and 'c' respectively. Entries that are
        ``ConstantValue`` instances are passed as literal values instead of being read from a node.
    :type args: List, default None
    :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
        parameter when calling the function for a calculation node. e.g. If kwds is ``{'x': 'a', 'y': 'b'}``
        then the function would be called with parameters named 'x' and 'y', and their values would be taken
        from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter
        [key] from node [value]". Values that are ``ConstantValue`` instances are passed as literal values.
    :type kwds: Dictionary, default None
    :param value: If given (including an explicit None), the value is inserted into the node, and the node
        state set to UPTODATE.
    :param converter: Callable stored on the node; values set on the node are passed through it before
        being stored.
    :type converter: Callable, default None
    :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
        case, set serialize to False
    :type serialize: boolean, default True
    :param inspect: Whether to use introspection to determine the arguments of the function, which can be
        slow. If this is not set, kwds and args must be set for the function to obtain parameters.
    :type inspect: boolean, default True
    :param group: Subgraph to render node in
    :type group: default None
    :param tags: Set of tags to apply to node
    :type tags: Iterable
    :param style: Style to apply to node
    :type style: String, default None
    :param executor: Name of executor to run node on
    :type executor: string
    :param metadata: Metadata to store for the node. If None, any metadata already stored for the node
        is removed.
    :type metadata: default None
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Adding node {node_key}")
    has_value = value is not _MISSING_VALUE_SENTINEL
    if value is _MISSING_VALUE_SENTINEL:
        value = None
    if tags is None:
        tags = []

    # Create the node if needed and detach it from its previous inputs; the
    # edges are rebuilt below from the new definition.
    self.dag.add_node(node_key)
    pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
    self.dag.remove_edges_from(pred_edges)
    node = self.dag.nodes[node_key]

    if metadata is None:
        if node_key in self._metadata:
            del self._metadata[node_key]
    else:
        self._metadata[node_key] = metadata

    self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)

    # The node's tag set is reset just below. Remove the node from the tag
    # index for its previous tags first, so the index keeps mirroring the
    # per-node tag sets (the invariant _refresh_maps rebuilds).
    for old_tag in node.get(NodeAttributes.TAG, ()):
        self._tag_map[old_tag].discard(node_key)

    node[NodeAttributes.TAG] = set()
    node[NodeAttributes.STYLE] = style
    node[NodeAttributes.GROUP] = group
    node[NodeAttributes.ARGS] = {}
    node[NodeAttributes.KWDS] = {}
    node[NodeAttributes.FUNC] = None
    node[NodeAttributes.EXECUTOR] = executor
    node[NodeAttributes.CONVERTER] = converter

    if func:
        node[NodeAttributes.FUNC] = func
        args_count = 0
        if args:
            args_count = len(args)
            for i, arg in enumerate(args):
                if isinstance(arg, ConstantValue):
                    node[NodeAttributes.ARGS][i] = arg.value
                else:
                    input_vertex_node_key = to_nodekey(arg)
                    if not self.dag.has_node(input_vertex_node_key):
                        # Referenced input does not exist yet: create a placeholder for it.
                        self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
                    self.dag.add_edge(
                        input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
                    )
        # Map of function parameter name -> node name (or ConstantValue).
        param_map = {}
        if inspect:
            signature = get_signature(func)
            if not signature.has_var_args:
                # Parameters already filled positionally by ``args`` are skipped.
                for param_name in signature.kwd_params[args_count:]:
                    if kwds is not None and param_name in kwds:
                        param_source = kwds[param_name]
                    else:
                        param_source = node_key.parent.join_parts(param_name)
                    param_map[param_name] = param_source
            if signature.has_var_kwds and kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = signature.default_params
        else:
            if kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = []
        for param_name, param_source in param_map.items():
            if isinstance(param_source, ConstantValue):
                node[NodeAttributes.KWDS][param_name] = param_source.value
            else:
                in_node_key = to_nodekey(param_source)
                if not self.dag.has_node(in_node_key):
                    if param_name in default_names:
                        # The parameter has a default and no node supplies it: leave it to the default.
                        continue
                    else:
                        self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(in_node_key)
                self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})
    self._set_descendents(node_key, States.STALE)

    if has_value:
        self._set_uptodate(node_key, value)
    if node[NodeAttributes.STATE] == States.UNINITIALIZED:
        self._try_set_computable(node_key)
    self.set_tag(node_key, tags)
    if serialize:
        self.set_tag(node_key, SystemTags.SERIALIZE)

498 

def _refresh_maps(self):
    """Rebuild the tag and state indexes from the attributes stored on the graph nodes."""
    self._tag_map.clear()
    for state in States:
        self._state_map[state].clear()
    for node_key in self._node_keys():
        attrs = self.dag.nodes[node_key]
        self._state_map[attrs[NodeAttributes.STATE]].add(node_key)
        # Nodes without a tag attribute contribute nothing to the tag index.
        for tag in attrs.get(NodeAttributes.TAG, set()):
            self._tag_map[tag].add(node_key)

509 

def _set_tag_one(self, name: Name, tag):
    """Add ``tag`` to one node and record the node in the tag index."""
    key = to_nodekey(name)
    node_tags = self.dag.nodes[key][NodeAttributes.TAG]
    node_tags.add(tag)
    self._tag_map[tag].add(key)

514 

def set_tag(self, name: Names, tag):
    """Set tags on a node or nodes. Tags that are already set are left as they are.

    :param name: Node or nodes to set tag for
    :param tag: Tag to set
    """
    setter = self._set_tag_one
    apply_n(setter, name, tag)

522 

def _clear_tag_one(self, name: Name, tag):
    """Remove ``tag`` from one node and from the tag index; absent tags are ignored."""
    key = to_nodekey(name)
    node_tags = self.dag.nodes[key][NodeAttributes.TAG]
    node_tags.discard(tag)
    self._tag_map[tag].discard(key)

527 

def clear_tag(self, name: Names, tag):
    """Clear tag on a node or nodes. Tags that are not set are ignored.

    :param name: Node or nodes to clear tags for
    :param tag: Tag to clear
    """
    clearer = self._clear_tag_one
    apply_n(clearer, name, tag)

535 

def _set_style_one(self, name: Name, style):
    """Store ``style`` as the style attribute of one node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    attrs[NodeAttributes.STYLE] = style

539 

def set_style(self, name: Names, style):
    """Set the style of a node or nodes.

    :param name: Node or nodes to set style for
    :param style: Style to set
    """
    setter = self._set_style_one
    apply_n(setter, name, style)

547 

def _clear_style_one(self, name):
    """Reset the style attribute of one node to None."""
    attrs = self.dag.nodes[to_nodekey(name)]
    attrs[NodeAttributes.STYLE] = None

551 

def clear_style(self, name):
    """Clear the style of a node or nodes.

    :param name: Node or nodes to clear styles for
    """
    clearer = self._clear_style_one
    apply_n(clearer, name)

558 

def metadata(self, name):
    """Return the metadata dictionary for a node or tree path.

    An empty dictionary is created and stored the first time metadata is
    requested for an existing path.

    :param name: Node or path to get metadata for.
    :raises NonExistentNodeException: if the path does not exist in the tree.
    """
    node_key = to_nodekey(name)
    if not self.tree_has_path(name):
        raise NonExistentNodeException(f"Node {node_key} does not exist.")
    return self._metadata.setdefault(node_key, {})

568 

def delete_node(self, name):
    """Delete a node from a computation.

    When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
    will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
    then the PLACEHOLDER node will also be deleted.

    :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
        be raised.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Deleting node {node_key}")

    if not self.dag.has_node(node_key):
        raise NonExistentNodeException(f"Node {node_key} does not exist")

    if node_key in self._metadata:
        del self._metadata[node_key]

    if len(self.dag.succ[node_key]) == 0:
        # Snapshot everything needed from the node before it is removed from the graph.
        preds = list(self.dag.predecessors(node_key))
        attrs = self.dag.nodes[node_key]
        state = attrs[NodeAttributes.STATE]
        tags = tuple(attrs.get(NodeAttributes.TAG, ()))

        self.dag.remove_node(node_key)
        self._state_map[state].remove(node_key)
        # Keep the tag index in step with the graph: a removed node must not
        # remain listed under its tags.
        for tag in tags:
            self._tag_map[tag].discard(node_key)

        for n in preds:
            # An earlier recursive call may already have removed this
            # predecessor (when it was also an input of another placeholder
            # that was just cleaned up), so check it still exists.
            if self.dag.has_node(n) and self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
                self.delete_node(n)
    else:
        # Still depended on by other nodes: keep it in the graph as a placeholder.
        self._set_state(node_key, States.PLACEHOLDER)

598 

def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None):
    """Rename one node, or several nodes at once.

    :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and
        new names as values
    :param new_name: New name for node. Must be left as None when a dictionary is passed.
    :raises ValueError: if a dictionary is passed together with ``new_name``.
    :raises NonExistentNodeException: single rename only, if ``old_name`` does not exist.
    :raises NodeAlreadyExistsException: single rename only, if ``new_name`` already exists.
    """
    is_mapping = hasattr(old_name, "__getitem__") and not isinstance(old_name, str)
    if is_mapping:
        for k, v in old_name.items():
            LOG.debug(f"Renaming node {k} to {v}")
        if new_name is not None:
            raise ValueError("new_name must not be set if rename_node is passed a dictionary")
        name_mapping = old_name
    else:
        LOG.debug(f"Renaming node {old_name} to {new_name}")
        if not self.dag.has_node(to_nodekey(old_name)):
            raise NonExistentNodeException(f"Node {old_name} does not exist")
        if self.dag.has_node(to_nodekey(new_name)):
            raise NodeAlreadyExistsException(f"Node {new_name} already exists")
        name_mapping = {old_name: new_name}

    node_key_mapping = {to_nodekey(src): to_nodekey(dst) for src, dst in name_mapping.items()}
    nx.relabel_nodes(self.dag, node_key_mapping, copy=False)

    # Move metadata along with each renamed node; a renamed node without
    # metadata must not inherit metadata stored under the new key.
    for src_key, dst_key in node_key_mapping.items():
        if src_key in self._metadata:
            self._metadata[dst_key] = self._metadata[src_key]
            del self._metadata[src_key]
        elif dst_key in self._metadata:
            del self._metadata[dst_key]

    # Node keys changed, so the tag and state indexes are rebuilt from the graph.
    self._refresh_maps()

635 

def repoint(self, old_name: Name, new_name: Name):
    """Changes all nodes that use old_name as an input to use new_name instead.

    Note that if old_name is an input to new_name, then that will not be changed, to try to avoid introducing
    circular dependencies, but other circular dependencies will not be checked.

    If new_name does not exist, then it will be created as a PLACEHOLDER node.

    :param old_name: Node whose dependants should be re-pointed.
    :param new_name: Node the dependants should read from instead.
    :return: None
    """
    source_key = to_nodekey(old_name)
    target_key = to_nodekey(new_name)
    if source_key == target_key:
        return

    dependants = list(self.dag.successors(source_key))

    if len(dependants) > 0 and not self.dag.has_node(target_key):
        self.dag.add_node(target_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
        self._state_map[States.PLACEHOLDER].add(target_key)

    for dependant in dependants:
        if dependant == target_key:
            # Leave the edge old -> new alone rather than creating new -> new.
            continue
        # Carry the edge attributes across to the replacement edge.
        edge_attrs = self.dag.get_edge_data(source_key, dependant)
        self.dag.add_edge(target_key, dependant, **edge_attrs)
        self.dag.remove_edge(source_key, dependant)

    for dependant in dependants:
        self.set_stale(dependant)

668 

def insert(self, name: Name, value, force=False):
    """Insert a value into a node of a computation.

    Following insertion, the node will have state UPTODATE, and all its descendents will be COMPUTABLE or STALE.

    :param name: Name of the node to insert into.
    :param value: The value to be inserted into the node.
    :param force: Whether to force recalculation of descendents if node value and state would not be changed
    :raises NonExistentNodeException: if the node does not exist.
    :raises CannotInsertToPlaceholderNodeException: if the node is a placeholder.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Inserting value into node {node_key}")

    if not self.dag.has_node(node_key):
        raise NonExistentNodeException(f"Node {node_key} does not exist")

    state = self._state_one(name)
    if state == States.PLACEHOLDER:
        raise CannotInsertToPlaceholderNodeException(
            "Cannot insert into placeholder node. Use add_node to create the node first"
        )

    # Without force, re-inserting an equal value into an up-to-date node is a no-op.
    if not force and state == States.UPTODATE and value_eq(value, self._value_one(name)):
        return

    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    for successor in self.dag.successors(node_key):
        self._try_set_computable(successor)

703 

def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]):
    """Insert values into many nodes of a computation simultaneously.

    Following insertion, the nodes will have state UPTODATE, and all their descendents will be COMPUTABLE
    or STALE. In the case of inserting many nodes, some of which are descendents of others, this ensures that
    the inserted nodes have correct status, rather than being set as STALE when their ancestors are inserted.

    If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` will be
    raised, and none of the nodes will be inserted.

    :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
        insert the value into.
    :type name_value_pairs: List of tuples
    """
    node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
    LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")

    # Validate every target before touching any of them.
    for key, _ in node_key_value_pairs:
        if not self.dag.has_node(key):
            raise NonExistentNodeException(f"Node {key} does not exist")

    stale = set()
    computable = set()
    for key, new_value in node_key_value_pairs:
        self._set_state_and_value(key, States.UPTODATE, new_value)
        stale.update(nx.dag.descendants(self.dag, key))
        computable.update(self.dag.successors(key))

    # The inserted nodes themselves keep their UPTODATE state.
    inserted = {key for key, _ in node_key_value_pairs}
    stale -= inserted
    computable -= inserted

    for key in stale:
        self._set_state(key, States.STALE)
    for key in computable:
        self._try_set_computable(key)

738 

def insert_from(self, other, nodes: Iterable[Name] | None = None):
    """Insert values from another Computation object into this Computation object.

    :param other: The computation object to take values from
    :type other: Computation
    :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the
        other Computation object that have corresponding nodes in this Computation object will be inserted
    :type nodes: List, default None
    """
    if nodes is None:
        # Default: every node present in both graphs.
        nodes = set(self.dag.nodes).intersection(other.dag.nodes())
    pairs = [(node_name, other.value(node_name)) for node_name in nodes]
    self.insert_many(pairs)

753 

def _set_state(self, node_key: NodeKey, state: States):
    """Change a node's state and move it between the per-state index sets."""
    attrs = self.dag.nodes[node_key]
    previous = attrs[NodeAttributes.STATE]
    self._state_map[previous].remove(node_key)
    attrs[NodeAttributes.STATE] = state
    self._state_map[state].add(node_key)

760 

def _set_state_and_value(
    self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
):
    """Set a node's state and value, passing the value through the node's converter if it has one.

    If conversion or storing the converted value raises, the node is put into
    the error state with the exception and formatted traceback, and the
    exception is re-raised unless ``throw_conversion_exception`` is False.
    """
    converter = self.dag.nodes[node_key].get(NodeAttributes.CONVERTER)
    if converter is None:
        self._set_state_and_literal_value(node_key, state, value)
        return
    try:
        self._set_state_and_literal_value(node_key, state, converter(value))
    except Exception as e:
        self._set_error(node_key, e, traceback.format_exc())
        if throw_conversion_exception:
            raise e

777 

def _set_state_and_literal_value(
    self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
):
    """Store ``state`` and ``value`` on a node as given and update the state index.

    :param require_old_state: When True, a KeyError raised while looking up or
        un-indexing the previous state propagates; when False it is ignored
        (used for nodes that have no state yet).
    """
    attrs = self.dag.nodes[node_key]
    try:
        self._state_map[attrs[NodeAttributes.STATE]].remove(node_key)
    except KeyError:
        if require_old_state:
            raise
    attrs[NodeAttributes.STATE] = state
    attrs[NodeAttributes.VALUE] = value
    self._state_map[state].add(node_key)

791 

def _set_states(self, node_keys: Iterable[NodeKey], state: States):
    """Set the state of several nodes and keep the per-state index in step.

    Each node is moved from its old state's index set to the new one within a
    single pass, so ``node_keys`` may be any iterable (including a one-shot
    iterator) and may contain repeated keys.

    :param node_keys: Keys of the nodes to update.
    :param state: State to assign to every listed node.
    """
    target_set = self._state_map[state]
    for key in node_keys:
        attrs = self.dag.nodes[key]
        self._state_map[attrs[NodeAttributes.STATE]].remove(key)
        attrs[NodeAttributes.STATE] = state
        target_set.add(key)

799 

def set_stale(self, name: Name):
    """Set the state of a node and all its descendants to STALE.

    The node itself is then given the chance to become COMPUTABLE.

    :param name: Name of the node to set as STALE.
    """
    node_key = to_nodekey(name)
    affected = [node_key, *nx.dag.descendants(self.dag, node_key)]
    self._set_states(affected, States.STALE)
    self._try_set_computable(node_key)

810 

def pin(self, name: Name, value=None):
    """Set the state of a node to PINNED.

    :param name: Name of the node to set as PINNED.
    :param value: Value to insert into the node before pinning; skipped when None.
    :type value: default None
    """
    node_key = to_nodekey(name)
    should_insert = value is not None
    if should_insert:
        self.insert(node_key, value)
    self._set_states([node_key], States.PINNED)

822 

def unpin(self, name):
    """Unpin a node (state of node and all descendents will be set to STALE).

    :param name: Name of the node to unpin.
    """
    self.set_stale(to_nodekey(name))

830 

def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
    """Return the descendants of ``node_key``, not traversing into nodes in a stop state.

    :param node_key: Node to start from; it is never part of the result.
    :param stop_states: States that block traversal. A node in one of these
        states is neither returned nor explored further. None means no state
        blocks traversal.
    :return: Set of reachable descendant node keys; empty if the starting node
        itself is in a stop state.
    """
    # Resolve the default before the first membership test; testing
    # ``in None`` would raise TypeError.
    if stop_states is None:
        stop_states = ()
    if self.dag.nodes[node_key][NodeAttributes.STATE] in stop_states:
        return set()
    visited = set()
    to_visit = {node_key}
    while to_visit:
        current = to_visit.pop()
        visited.add(current)
        for successor in self.dag.successors(current):
            if successor in visited:
                continue
            if self.dag.nodes[successor][NodeAttributes.STATE] in stop_states:
                continue
            to_visit.add(successor)
    visited.remove(node_key)
    return visited

849 

def _set_descendents(self, node_key: NodeKey, state):
    """Assign ``state`` to the descendants of a node, stopping at PINNED nodes."""
    reachable = self._get_descendents(node_key, {States.PINNED})
    self._set_states(reachable, state)

853 

def _set_uninitialized(self, node_key: NodeKey):
    """Reset a node to UNINITIALIZED and drop its stored value, if it has one."""
    self._set_states([node_key], States.UNINITIALIZED)
    attrs = self.dag.nodes[node_key]
    attrs.pop(NodeAttributes.VALUE, None)

857 

def _set_uptodate(self, node_key: NodeKey, value: object):
    """Store ``value`` on a node as UPTODATE and propagate the change downstream.

    Descendents are marked STALE, then each direct successor is checked to see whether it has become
    COMPUTABLE.
    """
    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    for child in self.dag.successors(node_key):
        self._try_set_computable(child)

863 

def _set_error(self, node_key: NodeKey, exc: Exception, tb: inspect.Traceback):
    """Put a node into the ERROR state, storing an ``Error`` as its value, and mark its descendents STALE."""
    failure = Error(exc, tb)
    self._set_state_and_literal_value(node_key, States.ERROR, failure)
    self._set_descendents(node_key, States.STALE)

867 

def _try_set_computable(self, node_key: NodeKey):
    """Promote a node to COMPUTABLE when it has a function and all of its predecessors are UPTODATE.

    PINNED nodes and nodes without a function are left untouched.
    """
    attrs = self.dag.nodes[node_key]
    if attrs[NodeAttributes.STATE] == States.PINNED:
        return
    if attrs.get(NodeAttributes.FUNC) is None:
        return
    for parent in self.dag.predecessors(node_key):
        if not self.dag.has_node(parent):
            return
        parent_state = self.dag.nodes[parent][NodeAttributes.STATE]
        if parent_state != States.UPTODATE:
            return
    self._set_state(node_key, States.COMPUTABLE)

878 

def _get_parameter_data(self, node_key: NodeKey):
    """Yield a ``_ParameterItem`` for every input of a node's function.

    Items from the node's ARGS attribute are yielded first, then items from its KWDS attribute, then one
    item per predecessor node carrying that predecessor's current value.
    """
    for position, stored in self.dag.nodes[node_key][NodeAttributes.ARGS].items():
        yield _ParameterItem(_ParameterType.ARG, position, stored)
    for keyword, stored in self.dag.nodes[node_key][NodeAttributes.KWDS].items():
        yield _ParameterItem(_ParameterType.KWD, keyword, stored)
    for source_key in self.dag.predecessors(node_key):
        source_value = self.dag.nodes[source_key][NodeAttributes.VALUE]
        # The edge records which parameter of this node the predecessor feeds
        kind, target = self.dag[source_key][node_key][EdgeAttributes.PARAM]
        yield _ParameterItem(kind, target, source_value)

889 

def _get_func_args_kwds(self, node_key: NodeKey):
    """Assemble everything needed to call a node's function.

    :return: Tuple of (function, executor name or None, positional argument list, keyword argument dict)
    :raises Exception: If a parameter has a type other than ARG or KWD
    """
    attrs = self.dag.nodes[node_key]
    func = attrs[NodeAttributes.FUNC]
    executor_name = attrs.get(NodeAttributes.EXECUTOR)
    positional = []
    keyword = {}
    for item in self._get_parameter_data(node_key):
        if item.type == _ParameterType.KWD:
            keyword[item.name] = item.value
        elif item.type == _ParameterType.ARG:
            position = item.name
            # Pad with None so that ``position`` is a valid index before assigning
            while len(positional) <= position:
                positional.append(None)
            positional[position] = item.value
        else:
            raise Exception(f"Unexpected param type: {item.type}")
    return func, executor_name, positional, keyword

906 

def get_definition_args_kwds(self, name: Name) -> tuple[list, dict]:
    """Get the arguments and keyword arguments for a node's function definition.

    Values held in the node's ARGS/KWDS attributes are returned wrapped in ``C``; parameters supplied by
    other nodes are returned as the name of the supplying node.

    :param name: Name of the node to inspect
    :return: Tuple of (positional argument list, keyword argument dict)
    :raises Exception: If an incoming edge carries a parameter type other than ARG or KWD
    """
    positional = []
    keyword = {}

    def ensure_slot(position):
        # Grow the list with None placeholders until ``position`` is a valid index
        while len(positional) <= position:
            positional.append(None)

    key = to_nodekey(name)
    attrs = self.dag.nodes[key]
    if NodeAttributes.ARGS in attrs:
        for position, stored in attrs[NodeAttributes.ARGS].items():
            ensure_slot(position)
            positional[position] = C(stored)
    if NodeAttributes.KWDS in attrs:
        for keyword_name, stored in attrs[NodeAttributes.KWDS].items():
            keyword[keyword_name] = C(stored)
    for source_key in self.dag.predecessors(key):
        edge = self.dag[source_key][key]
        if EdgeAttributes.PARAM not in edge:
            continue
        kind, target = edge[EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            ensure_slot(target)
            positional[target] = source_key.name
        elif kind == _ParameterType.KWD:
            keyword[target] = source_key.name
        else:
            raise Exception(f"Unexpected param type: {kind}")
    return positional, keyword

935 

def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False):
    """Evaluate the COMPUTABLE nodes among ``node_keys``, continuing until no submitted work remains.

    Every node in ``node_keys`` that is currently COMPUTABLE is submitted to its executor. Each time a node
    finishes successfully it is set UPTODATE, its timing is recorded, its descendents are marked STALE, and
    any successor that has become COMPUTABLE and belongs to ``node_keys`` is submitted in turn. A node whose
    evaluation reports an exception is put into the ERROR state.

    :param node_keys: Keys of the nodes that may be computed
    :param raise_exceptions: Forwarded to ``_eval_node`` for each submitted node
    :raises LoopDetectedException: If a successor of a freshly computed node was already computed in this run
    """
    LOG.debug("Computing nodes %s", node_keys)

    # Materialise once: the keys are traversed for the initial submissions and then probed for membership
    # after every completed node. A set keeps each probe O(1) and a one-shot iterable is not exhausted.
    node_keys = list(node_keys)
    requested = set(node_keys)

    futs = {}  # maps each in-flight future to the node key it is computing

    def run(key):
        f, executor_name, args, kwds = self._get_func_args_kwds(key)
        executor = self.default_executor if executor_name is None else self.executor_map[executor_name]
        fut = executor.submit(_eval_node, key, f, args, kwds, raise_exceptions)
        futs[fut] = key

    computed = set()

    for node_key in node_keys:
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.COMPUTABLE:
            run(node_key)

    while futs:
        done, _not_done = wait(futs.keys(), return_when=FIRST_COMPLETED)
        for fut in done:
            node_key = futs.pop(fut)
            node_attrs = self.dag.nodes[node_key]
            try:
                value, exc, tb, start_dt, end_dt = fut.result()
            except Exception as e:
                # The future itself raised: record the failure on the node, then propagate to the caller
                self._set_error(node_key, e, traceback.format_exc())
                raise
            delta = (end_dt - start_dt).total_seconds()
            if exc is None:
                self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
                node_attrs[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
                self._set_descendents(node_key, States.STALE)
                for successor in self.dag.successors(node_key):
                    # Logged through the module logger so it is filtered with the rest of this module's output
                    LOG.debug("%s %s %s", node_key, successor, computed)
                    if successor in computed:
                        raise LoopDetectedException(f"Calculating {node_key} for the second time")
                    self._try_set_computable(successor)
                    successor_state = self.dag.nodes[successor][NodeAttributes.STATE]
                    if successor_state == States.COMPUTABLE and successor in requested:
                        run(successor)
            else:
                self._set_error(node_key, exc, tb)
            computed.add(node_key)

987 

def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """Work out which nodes must be calculated, and in what order, to compute ``node_key``.

    A structural copy of the DAG is made and every ancestor of ``node_key`` that is UPTODATE or PINNED is
    removed from it. Whatever is still upstream of ``node_key`` in the pruned copy needs calculating.

    :param node_key: Key of the target node
    :return: The remaining ancestors plus ``node_key`` itself, in topological order
    :raises Exception: If a remaining ancestor is an UNINITIALIZED node with no predecessors, or a PLACEHOLDER
    """
    g = nx.DiGraph()
    g.add_nodes_from(self.dag.nodes)
    g.add_edges_from(self.dag.edges)
    for n in nx.ancestors(g, node_key):
        if self.dag.nodes[n][NodeAttributes.STATE] in (States.UPTODATE, States.PINNED):
            g.remove_node(n)

    ancestors = nx.ancestors(g, node_key)
    for n in ancestors:
        # Read the state of *this* ancestor. Reusing a value left over from the pruning loop above would
        # make the check depend on iteration order rather than on the node being validated.
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0:
            raise Exception(f"Cannot compute {node_key} because {n} uninitialized")
        if state == States.PLACEHOLDER:
            raise Exception(f"Cannot compute {node_key} because {n} is placeholder")

    ancestors.add(node_key)
    return [n for n in nx.topological_sort(g) if n in ancestors]

1008 

def _get_calc_node_names(self, name: Name) -> Names:
    """Name-based wrapper around ``_get_calc_node_keys``: converts ``name`` to a key and the result to names."""
    calc_keys = self._get_calc_node_keys(to_nodekey(name))
    return node_keys_to_names(calc_keys)

1012 

def compute(self, name: Name | Iterable[Name], raise_exceptions=False):
    """Compute a node and all necessary predecessors.

    Following the computation, if successful, the target node, and all necessary ancestors that were not already
    UPTODATE will have been calculated and set to UPTODATE. Any node that did not need to be calculated will not
    have been recalculated.

    If any nodes raises an exception, then the state of that node will be set to ERROR, and its value set to an
    object containing the exception object, as well as a traceback. This will not halt the computation, which
    will proceed as far as it can, until no more nodes that would be required to calculate the target are
    COMPUTABLE.

    :param name: Name of the node to compute. A list or generator of names computes each of them.
    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    if not isinstance(name, (types.GeneratorType, list)):
        targets = self._get_calc_node_keys(to_nodekey(name))
    else:
        # Several targets: take the union of the nodes each one requires
        targets = set()
        for single_name in name:
            targets.update(self._get_calc_node_keys(to_nodekey(single_name)))
    self._compute_nodes(targets, raise_exceptions=raise_exceptions)

1039 

def compute_all(self, raise_exceptions=False):
    """Compute all nodes of a computation that can be computed.

    Nodes that are already UPTODATE will not be recalculated. Following the computation, if successful, all
    nodes will have state UPTODATE, except UNINITIALIZED input nodes and PLACEHOLDER nodes.

    If any nodes raises an exception, then the state of that node will be set to ERROR, and its value set to an
    object containing the exception object, as well as a traceback. This will not halt the computation, which
    will proceed as far as it can, until no more nodes are COMPUTABLE.

    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    every_key = self._node_keys()
    self._compute_nodes(every_key, raise_exceptions=raise_exceptions)

1054 

def _node_keys(self) -> list[NodeKey]:
    """Get the keys of all nodes in this computation.

    :return: List of node keys.
    """
    return [*self.dag.nodes]

1061 

def nodes(self) -> list[Name]:
    """Get the names of all nodes in this computation.

    :return: List of node names.
    """
    return [key.name for key in self.dag.nodes]

1068 

def get_tree_list_children(self, name: Name) -> set[Name]:
    """Get the path components found directly beneath ``name`` among the nodes that descend from it.

    :param name: Path to look beneath
    :return: Set holding, for each node that is a descendent of ``name``, the path part immediately after
        ``name``'s own parts.
    """
    parent = to_nodekey(name)
    depth = len(parent.parts)
    return {key.parts[depth] for key in self.dag.nodes if key.is_descendent_of(parent)}

1081 

def has_node(self, name: Name):
    """Check if a node with the given name exists in the computation."""
    return to_nodekey(name) in self.dag.nodes

1086 

def tree_has_path(self, name: Name):
    """Check if a hierarchical path exists in the computation tree.

    A path exists if it is the root, if it is itself a node, or if at least one node is a descendent of it.
    """
    key = to_nodekey(name)
    if key.is_root or self.has_node(key):
        return True
    return any(candidate.is_descendent_of(key) for candidate in self.dag.nodes)

1098 

def get_tree_descendents(
    self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
) -> set[Name]:
    """Get the set of descendent blocks and nodes.

    Returns blocks and nodes that are descendents of the input node,
    e.g. for node 'foo', might return ['foo/bar', 'foo/baz'].

    :param name: Name of node to get descendents for. ``None`` means the root.
    :param include_stem: If True, results are full names. If False, the parts belonging to ``name`` are
        stripped from the front of each result.
    :param graph_nodes_only: If True, only nodes present in the graph are considered. If False, the keys
        returned by each node's ``ancestors()`` are considered as well.
    :return: Set of descendent names
    """
    stem = NodeKey.root() if name is None else to_nodekey(name)
    stem_length = len(stem.parts)
    found = set()
    for key in self.dag.nodes:
        if not key.is_descendent_of(stem):
            continue
        candidates = [key] if graph_nodes_only else key.ancestors()
        for candidate in candidates:
            if not candidate.is_descendent_of(stem):
                continue
            if include_stem:
                found.add(candidate.name)
            else:
                found.add(NodeKey(tuple(candidate.parts[stem_length:])).name)
    return found

1127 

def _state_one(self, name: Name):
    """Return the state stored on a single node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.STATE]

1131 

def state(self, name: Name | Names):
    """Get the state of a node.

    This can also be accessed using the attribute-style accessor ``s`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.state('foo')
        <States.UPTODATE: 4>
        >>> comp.s.foo
        <States.UPTODATE: 4>

    :param name: Name or names of the node to get state for
    :type name: Name or Names
    """
    lookup = self._state_one
    return apply1(lookup, name)

1149 

def _value_one(self, name: Name):
    """Return the value stored on a single node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.VALUE]

1153 

def value(self, name: Name | Names):
    """Get the current value of a node.

    This can also be accessed using the attribute-style accessor ``v`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.value('foo')
        1
        >>> comp.v.foo
        1

    :param name: Name or names of the node to get the value of
    :type name: Name or Names
    """
    lookup = self._value_one
    return apply1(lookup, name)

1171 

def compute_and_get_value(self, name: Name):
    """Get the value of a node, computing it first if it is not already UPTODATE.

    This can also be accessed using the attribute-style accessor ``x`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

    :param name: Name of the node to get the value of
    :type name: Name
    :raises ComputationError: If the node is still not UPTODATE after computing
    """
    key = to_nodekey(name)
    if self.state(key) != States.UPTODATE:
        # Exceptions raised while computing are passed straight back to the caller
        self.compute(key, raise_exceptions=True)
        if self.state(key) != States.UPTODATE:
            raise ComputationError(f"Unable to compute node {key}")
    return self.value(key)

1196 

def _tag_one(self, name: Name):
    """Return the tags stored on a single node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.TAG]

1201 

def tags(self, name: Name | Names):
    """Get the tags associated with a node.

    >>> comp = Computation()
    >>> comp.add_node('a', tags=['foo', 'bar'])
    >>> comp.t.a
    {'__serialize__', 'bar', 'foo'}

    :param name: Name or names of the node to get the tags of
    :return: Tags of the node, or of each node when several names are given
    """
    lookup = self._tag_one
    return apply1(lookup, name)

1213 

def nodes_by_tag(self, tag) -> set[Name]:
    """Get the names of nodes with a particular tag or tags.

    :param tag: Tag or tags for which to retrieve nodes
    :return: Names of the nodes carrying any of those tags
    """
    matched = set()
    for single_tag in as_iterable(tag):
        tagged = self._tag_map.get(single_tag)
        if tagged is None:
            # Unknown tags contribute nothing
            continue
        matched.update(tagged)
    return {key.name for key in matched}

1226 

def _style_one(self, name: Name):
    """Return the style stored on a single node, or None if it has none."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs.get(NodeAttributes.STYLE)

1231 

def styles(self, name: Name | Names):
    """Get the style associated with a node.

    >>> comp = Computation()
    >>> comp.add_node('a', styles='dot')
    >>> comp.style.a
    'dot'

    :param name: Name or names of the node to get the style of
    :return: Style of the node, or of each node when several names are given
    """
    lookup = self._style_one
    return apply1(lookup, name)

1243 

def _get_item_one(self, name: Name):
    """Bundle the state and value of a single node into a ``NodeData``."""
    attrs = self.dag.nodes[to_nodekey(name)]
    node_state = attrs[NodeAttributes.STATE]
    node_value = attrs[NodeAttributes.VALUE]
    return NodeData(node_state, node_value)

1248 

def __getitem__(self, name: Name | Names):
    """Get the state and current value of a node.

    :param name: Name or names of the node to get the state and value of
    """
    lookup = self._get_item_one
    return apply1(lookup, name)

1255 

def _get_timing_one(self, name: Name):
    """Return the timing data stored on a single node, or None if none has been recorded."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs.get(NodeAttributes.TIMING)

1260 

def get_timing(self, name: Name | Names):
    """Get the timing information for a node.

    :param name: Name or names of the node to get the timing information of
    :return: Timing data of the node (None if none is recorded), or of each node when several names are given
    """
    lookup = self._get_timing_one
    return apply1(lookup, name)

1268 

def to_df(self):
    """Get a dataframe containing the states and value of all nodes of computation.

    Rows follow a topological ordering of the graph and are labelled with node names. Timing attributes
    are left-joined onto the state and value columns.

    ::

        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_df()
                       state  value  is_expansion
        bar  States.UPTODATE      2           NaN
        foo  States.UPTODATE      1           NaN
    """
    frame = pd.DataFrame(index=nx.topological_sort(self.dag))
    states = nx.get_node_attributes(self.dag, NodeAttributes.STATE)
    frame[NodeAttributes.STATE] = pd.Series(states)
    values = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    frame[NodeAttributes.VALUE] = pd.Series(values)
    timings = pd.DataFrame.from_dict(nx.get_node_attributes(self.dag, "timing"), orient="index")
    frame = pd.merge(frame, timings, left_index=True, right_index=True, how="left")
    # Present rows by node name rather than by NodeKey
    frame.index = pd.Index([key.name for key in frame.index])
    return frame

1289 

def to_dict(self):
    """Get a dictionary containing the values of all nodes of a computation.

    ::

        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_dict()
        {'bar': 2, 'foo': 1}
    """
    values_by_node = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    return values_by_node

1302 

def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """List the predecessor nodes feeding a node: positional inputs by index, then keyword inputs.

    Positional slots that no predecessor feeds are filled with None.
    """
    by_position = {}
    keyword_inputs = []
    highest = -1
    for source in self.dag.predecessors(node_key):
        kind, target = self.dag[source][node_key][EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            highest = max(highest, target)
            by_position[target] = source
        elif kind == _ParameterType.KWD:
            keyword_inputs.append(source)
    if highest < 0:
        # No positional inputs at all
        return keyword_inputs
    positional = [None] * (highest + 1)
    for position, source in by_position.items():
        positional[position] = source
    return positional + keyword_inputs

1323 

def _get_inputs_one_names(self, name: Name) -> Names:
    """Name-based wrapper around ``_get_inputs_one_node_keys``."""
    input_keys = self._get_inputs_one_node_keys(to_nodekey(name))
    return node_keys_to_names(input_keys)

1327 

def get_inputs(self, name: Name | Names) -> list[Names]:
    """Get a list of the inputs for a node or set of nodes.

    :param name: Name or names of nodes to get inputs for
    :return: If name is scalar, return a list of upstream nodes used as input. If name is a list, return a
        list of list of inputs.
    """
    lookup = self._get_inputs_one_names
    return apply1(lookup, name)

1336 

def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self=True) -> set[NodeKey]:
    """Return the union of the ancestors of ``node_keys``, including the keys themselves if ``include_self``."""
    collected = set()
    for key in node_keys:
        if include_self:
            collected.add(key)
        collected.update(nx.ancestors(self.dag, key))
    return collected

1345 

def get_ancestors(self, names: Name | Names, include_self=True) -> Names:
    """Get all ancestor nodes of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes themselves are part of the result
    :return: Names of the ancestor nodes
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_ancestors_node_keys(start_keys, include_self))

1351 

def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> Names:
    """Return the nodes without a function among the ancestors of ``node_keys`` (the keys included).

    When ``node_keys`` is None every node in the computation is considered.
    """
    candidates = self._node_keys() if node_keys is None else self._get_ancestors_node_keys(node_keys)
    return [key for key in candidates if self.dag.nodes[key].get(NodeAttributes.FUNC) is None]

1358 

def get_original_inputs(self, names: Name | Names | None = None) -> Names:
    """Get a list of the original non-computed inputs for a node or set of nodes.

    :param names: Name or names of nodes to get inputs for. ``None`` considers the whole computation.
    :return: Return a list of original non-computed inputs that are ancestors of the input nodes
    """
    requested = None if names is None else names_to_node_keys(names)
    input_keys = self._get_original_inputs_node_keys(requested)
    return node_keys_to_names(input_keys)

1373 

def _get_outputs_one(self, name: Name) -> Names:
    """Return the names of the direct successors of a single node."""
    successors = list(self.dag.successors(to_nodekey(name)))
    return node_keys_to_names(successors)

1378 

def get_outputs(self, name: Name | Names) -> Names | list[Names]:
    """Get a list of the outputs for a node or set of nodes.

    :param name: Name or names of nodes to get outputs for
    :return: If name is scalar, return a list of downstream nodes used as output. If name is a list, return a
        list of list of outputs.
    """
    lookup = self._get_outputs_one
    return apply1(lookup, name)

1388 

def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> Names:
    """Return the union of the descendents of ``node_keys``, including the keys themselves if ``include_self``."""
    collected = set()
    for key in node_keys:
        if include_self:
            collected.add(key)
        collected.update(nx.descendants(self.dag, key))
    return collected

1397 

def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get all descendent nodes of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes themselves are part of the result
    :return: Names of the descendent nodes
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_descendents_node_keys(start_keys, include_self))

1403 

def get_final_outputs(self, names: Name | Names | None = None):
    """Get final output nodes (nodes with no descendants) from the specified nodes.

    :param names: Name or names of nodes whose descendents (and themselves) are considered. ``None``
        considers every node in the computation.
    :return: Names of the considered nodes that have no descendants
    """
    if names is None:
        candidates = self._node_keys()
    else:
        candidates = self._get_descendents_node_keys(names_to_node_keys(names))
    terminal = [key for key in candidates if not nx.descendants(self.dag, key)]
    return node_keys_to_names(terminal)

1413 

def get_source(self, name: Name) -> str:
    """Get the source code for a node.

    :param name: Name of the node
    :return: ``"<file>:<line>"``, a blank line and the function's source, or a fixed marker string when the
        node has no function
    """
    func = self.dag.nodes[to_nodekey(name)].get(NodeAttributes.FUNC, None)
    if func is None:
        return "NOT A CALCULATED NODE"
    file = inspect.getsourcefile(func)
    _, lineno = inspect.getsourcelines(func)
    source = inspect.getsource(func)
    return f"{file}:{lineno}\n\n{source}"

1425 

def print_source(self, name: Name):
    """Print the source code for a computation node, as returned by ``get_source``."""
    text = self.get_source(name)
    print(text)

1429 

def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None):
    """Restrict a computation to the ancestors of a set of output nodes.

    Excludes ancestors of a set of input nodes.

    If the set of input nodes that is specified is not sufficient for the set of output nodes then additional
    nodes that are ancestors of the output nodes will be included, but the input nodes specified will be input
    nodes of the modified Computation.

    :param output_names: Name or names of the nodes whose ancestors are kept
    :param input_names: Name or names of nodes to turn into inputs. Each is re-added without a function and
        has its previous state and value restored.
    :return: None - modifies existing computation in place
    """
    if input_names is not None:
        # Normalise exactly as ``output_names`` is normalised below, so that a single name is treated as
        # one node rather than being iterated element by element.
        for input_key in names_to_node_keys(input_names):
            nodedata = self._get_item_one(input_key)
            self.add_node(input_key)
            self._set_state_and_literal_value(input_key, nodedata.state, nodedata.value)
    output_node_keys = names_to_node_keys(output_names)
    ancestor_node_keys = self._get_ancestors_node_keys(output_node_keys)
    # NOTE(review): nodes are removed from the graph directly; confirm whether _state_map/_tag_map need
    # refreshing afterwards.
    self.dag.remove_nodes_from([n for n in self.dag if n not in ancestor_node_keys])

1451 

def __getstate__(self):
    """Prepare computation for serialization by removing non-serializable nodes.

    Works on a copy: every node whose tags lack ``SystemTags.SERIALIZE`` is reset to UNINITIALIZED there,
    and only the copy's DAG is returned as state.
    """
    tags_by_key = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
    clone = self.copy()
    excluded = [key for key, tags in tags_by_key.items() if SystemTags.SERIALIZE not in tags]
    for key in excluded:
        clone._set_uninitialized(key)
    return {"dag": clone.dag}

1460 

def __setstate__(self, state):
    """Restore computation from serialized state.

    :param state: Mapping with a ``"dag"`` entry, as produced by ``__getstate__``
    """
    # Start from a freshly initialised object, then attach the stored graph
    self.__init__()
    restored_dag = state["dag"]
    self.dag = restored_dag
    self._refresh_maps()

1466 

def write_dill_old(self, file_):
    """Serialize a computation to a file or file-like object (deprecated; use ``write_dill``).

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
    cls = self.__class__
    saved_getstate = cls.__getstate__
    saved_setstate = cls.__setstate__

    try:
        # Temporarily strip the custom state hooks from the class while dumping
        del cls.__getstate__
        del cls.__setstate__

        tags_by_key = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
        snapshot = self.copy()
        snapshot.executor_map = None
        snapshot.default_executor = None
        for key, tags in tags_by_key.items():
            if SystemTags.SERIALIZE not in tags:
                snapshot._set_uninitialized(key)

        if not isinstance(file_, str):
            dill.dump(snapshot, file_)
        else:
            with open(file_, "wb") as f:
                dill.dump(snapshot, f)
    finally:
        # Always put the hooks back, even if dumping failed
        cls.__getstate__ = saved_getstate
        cls.__setstate__ = saved_setstate

1497 

def write_dill(self, file_):
    """Serialize a computation to a file or file-like object.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    if not isinstance(file_, str):
        dill.dump(self, file_)
        return
    with open(file_, "wb") as f:
        dill.dump(self, f)

1509 

@staticmethod
def read_dill(file_):
    """Deserialize a computation from a file or file-like object.

    :param file_: If string, the path of the file to read from
    :type file_: File-like object, or string
    :return: The deserialized Computation
    :raises Exception: If the deserialized object is not a Computation
    """
    # SECURITY: dill, like pickle, can execute arbitrary code while loading. Only read trusted files.
    if isinstance(file_, str):
        with open(file_, "rb") as f:
            obj = dill.load(f)
    else:
        obj = dill.load(file_)
    if not isinstance(obj, Computation):
        raise Exception(f"Expected a Computation but loaded an object of type {type(obj).__name__}")
    return obj

1526 

def copy(self):
    """Create a copy of a computation.

    The copy is shallow. Any values in the new Computation's DAG will be the same object as this Computation's
    DAG. As new objects will be created by any further computations, this should not be an issue.

    :rtype: Computation
    """
    clone = Computation()
    clone.dag = nx.DiGraph(self.dag)
    # Each map entry is copied so that the two computations do not share its container
    clone._tag_map = {tag: members.copy() for tag, members in self._tag_map.items()}
    clone._state_map = {state: members.copy() for state, members in self._state_map.items()}
    return clone

1540 

def add_named_tuple_expansion(self, name, namedtuple_type, group=None):
    """Automatically add nodes to extract each element of a named tuple type.

    It is often convenient for a calculation to return multiple values, and it is polite to do this a namedtuple
    rather than a regular tuple, so that later users have same name to identify elements of the tuple. It can
    also help make a computation clearer if a downstream computation depends on one element of such a tuple,
    rather than the entire tuple. This does not affect the computation per se, but it does make the intention
    clearer.

    To avoid having to create many boiler-plate node definitions to expand namedtuples, the
    ``add_named_tuple_expansion`` method automatically creates new nodes for each element of a tuple. The
    convention is that an element called 'element', in a node called 'node' will be expanded into a new node
    called 'node.element', and that this will be applied for each element.

    Example::

        >>> from collections import namedtuple
        >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
        >>> comp = Computation()
        >>> comp.add_node('c', value=Coordinate(1, 2))
        >>> comp.add_named_tuple_expansion('c', Coordinate)
        >>> comp.compute_all()
        >>> comp.value('c.x')
        1
        >>> comp.value('c.y')
        2

    :param name: Name of the node holding the named tuple to expand
    :param namedtuple_type: Expected type of the node
    :type namedtuple_type: namedtuple class
    :param group: Passed to ``add_node`` as the group of every node created
    """

    def field_getter(field):
        # A factory gives every getter its own ``field`` rather than sharing the loop variable
        def get_field_value(tuple):  # the parameter name is bound through ``kwds`` below
            return getattr(tuple, field)

        return get_field_value

    for field in namedtuple_type._fields:
        expanded_name = f"{name}.{field}"
        self.add_node(expanded_name, field_getter(field), kwds={"tuple": name}, group=group)
        self.set_tag(expanded_name, SystemTags.EXPANSION)

1583 

def add_map_node(self, result_node, input_node, subgraph, subgraph_input_node, subgraph_output_node):
    """Apply a graph to each element of iterable.

    In turn, each element in the ``input_node`` of this graph will be inserted in turn into the subgraph's
    ``subgraph_input_node``, then the subgraph's ``subgraph_output_node`` calculated. The resultant list, with
    an element or each element in ``input_node``, will be inserted into ``result_node`` of this graph. In this
    way ``add_map_node`` is similar to ``map`` in functional programming.

    If the subgraph's output is not UPTODATE for some element, a copy of the subgraph is recorded in place of
    that element's result and, once every element has been processed, a ``MapException`` carrying the full
    list is raised.

    :param result_node: The node to place a list of results in **this** graph
    :param input_node: The node to get a list input values from **this** graph
    :param subgraph: The graph to use to perform calculation for each element
    :param subgraph_input_node: The node in **subgraph** to insert each element in turn
    :param subgraph_output_node: The node in **subgraph** to read the result for each element
    """

    def f(xs):
        outcomes = []
        failed = False
        for element in xs:
            subgraph.insert(subgraph_input_node, element)
            subgraph.compute(subgraph_output_node)
            if subgraph.state(subgraph_output_node) != States.UPTODATE:
                # Keep a snapshot of the failed subgraph so the caller can inspect it
                failed = True
                outcomes.append(subgraph.copy())
                continue
            outcomes.append(subgraph.value(subgraph_output_node))
        if failed:
            raise MapException(f"Unable to calculate {result_node}", outcomes)
        return outcomes

    self.add_node(result_node, f, kwds={"xs": input_node})

1615 

def prepend_path(self, path, prefix_path: NodeKey):
    """Prepend a prefix path to a node path.

    ``ConstantValue`` instances are returned unchanged.
    """
    if isinstance(path, ConstantValue):
        return path
    return prefix_path.join(to_nodekey(path))

1622 

def add_block(
    self,
    base_path: Name,
    block: "Computation",
    *,
    keep_values: bool | None = True,
    links: dict | None = None,
    metadata: dict | None = None,
):
    """Add a computation block as a subgraph to this computation.

    Every node of ``block`` is re-created in this computation under ``base_path``, and the node names it takes
    as arguments are prefixed in the same way.

    :param base_path: Path under which the block's nodes are added
    :param block: Computation whose nodes are copied in
    :param keep_values: If true, nodes that have a value in ``block`` keep their state and value
    :param links: Optional mapping of target (relative to ``base_path``) to source, passed to ``link``
    :param metadata: Metadata stored for ``base_path``. If None, any stored metadata for it is removed.
    """
    base_path = to_nodekey(base_path)
    for node_name in block.nodes():
        source_key = to_nodekey(node_name)
        source_attrs = block.dag.nodes[source_key]
        args, kwds = block.get_definition_args_kwds(source_key)
        prefixed_args = [self.prepend_path(arg, base_path) for arg in args]
        prefixed_kwds = {k: self.prepend_path(v, base_path) for k, v in kwds.items()}
        new_node_name = self.prepend_path(node_name, base_path)
        self.add_node(
            new_node_name,
            source_attrs.get(NodeAttributes.FUNC, None),
            args=prefixed_args,
            kwds=prefixed_kwds,
            converter=source_attrs.get(NodeAttributes.CONVERTER, None),
            serialize=False,
            inspect=False,
            group=source_attrs.get(NodeAttributes.GROUP, None),
            tags=source_attrs.get(NodeAttributes.TAG, None),
            style=source_attrs.get(NodeAttributes.STYLE, None),
            executor=source_attrs.get(NodeAttributes.EXECUTOR, None),
        )
        if keep_values and NodeAttributes.VALUE in source_attrs:
            # Carry the block's state and value over verbatim
            self._set_state_and_literal_value(
                to_nodekey(new_node_name), source_attrs[NodeAttributes.STATE], source_attrs[NodeAttributes.VALUE]
            )
    if links is not None:
        for target, source in links.items():
            self.link(base_path.join_parts(target), source)
    if metadata is not None:
        self._metadata[base_path] = metadata
    elif base_path in self._metadata:
        del self._metadata[base_path]

1673 

def link(self, target: Name, source: Name):
    """Add ``target`` as a node computed from ``source`` by ``identity_function``.

    Nothing happens when both names resolve to the same node key. The new
    node's style is the target's existing style if it is truthy, otherwise
    the source's existing style; a node that does not exist contributes
    ``None``.

    :param target: Name or key of the node to (re)define.
    :param source: Name or key of the node the target takes its value from.
    """
    target_key = to_nodekey(target)
    source_key = to_nodekey(source)
    if target_key == source_key:
        return

    def current_style(key):
        # Only query the style of nodes that already exist.
        return self._style_one(key) if self.has_node(key) else None

    target_style = current_style(target_key)
    source_style = current_style(source_key)

    self.add_node(
        target_key,
        identity_function,
        kwds={"x": source_key},
        style=target_style or source_style,
    )

1686 

def _repr_svg_(self):
    """Return the SVG output of a default ``GraphView`` of this computation."""
    # NOTE(review): presumably serves as the rich-display SVG hook in notebooks — confirm.
    graph_view = GraphView(self)
    return graph_view.svg()

1689 

def draw(
    self,
    root: NodeKey | None = None,
    *,
    node_transformations: dict | None = None,
    cmap=None,
    colors="state",
    shapes=None,
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    show_expansion=False,
    collapse_all=True,
):
    """Draw a computation's current state using the GraphViz utility.

    :param root: Optional PathType. Sub-block to draw
    :param node_transformations: Optional mapping of node key to a
        ``NodeTransformations`` value. The mapping is copied, so the caller's
        dict is not modified.
    :param cmap: Default: None
    :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
    :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
    :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
        ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
    :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
    :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
    :param show_expansion: When false (the default), every node tagged
        ``SystemTags.EXPANSION`` is given the ``NodeTransformations.CONTRACT``
        transformation, overriding any entry supplied for that node.
    :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
    :return: The ``GraphView`` built from these settings.
    """
    node_formatter = NodeFormatter.create(cmap, colors, shapes)

    # Work on a private copy so the contractions added below never leak
    # back into the caller's mapping.
    if node_transformations is None:
        transformations = {}
    else:
        transformations = node_transformations.copy()

    if not show_expansion:
        for expansion_key in self.nodes_by_tag(SystemTags.EXPANSION):
            transformations[expansion_key] = NodeTransformations.CONTRACT

    return GraphView(
        self,
        root=root,
        node_formatter=node_formatter,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        node_transformations=transformations,
        collapse_all=collapse_all,
    )

1732 

def view(self, cmap=None, colors="state", shapes=None):
    """Build a ``GraphView`` of this computation and call its ``view()`` method.

    :param cmap: Passed to ``NodeFormatter.create``. Default: None.
    :param colors: Passed to ``NodeFormatter.create``. Default: 'state'.
    :param shapes: Passed to ``NodeFormatter.create``. Default: None.
    """
    formatter = NodeFormatter.create(cmap, colors, shapes)
    GraphView(self, node_formatter=formatter).view()

1738 

def print_errors(self):
    """Print tracebacks for every node with state "ERROR" in a Computation.

    For each such node the output is the node name, an ``=`` underline of
    the same width, a blank line, the stored traceback and another blank
    line.
    """
    for n in self.nodes():
        if self.s[n] == States.ERROR:
            # Underline the text that is actually printed: node names are
            # not necessarily strings, so len(n) may raise TypeError or
            # disagree with the width of the rendered heading.
            heading = f"{n}"
            print(heading)
            print("=" * len(heading))
            print()
            print(self.v[n].traceback)
            print()

1748 

@classmethod
def from_class(cls, definition_class, ignore_self=True):
    """Build a new computation populated from ``definition_class``.

    :param definition_class: Class describing the computation. It is
        instantiated with no arguments.
    :param ignore_self: Forwarded unchanged to
        ``populate_computation_from_class``.
    :return: The newly created, populated instance of ``cls``.
    """
    computation = cls()
    definition_instance = definition_class()
    populate_computation_from_class(
        computation,
        definition_class,
        definition_instance,
        ignore_self=ignore_self,
    )
    return computation

1756 

def inject_dependencies(self, dependencies: dict, *, force: bool = False):
    """Fill nodes of this computation from the ``dependencies`` dictionary.

    Only nodes currently in the placeholder state are considered, unless
    ``force`` is true, in which case every node is considered. A considered
    node is left untouched when ``dependencies`` has no entry for it or the
    entry is ``None``.

    :param dependencies: Dictionary keyed by node identifier. A callable
        value is added as the node's function (a calc node); any other value
        is set as the node's value.
    :param force: When True, nodes are replaced regardless of their current
        state. Defaults to False.
    :return: None
    """
    for node_name in self.nodes():
        # The state is only consulted when force is off.
        eligible = force or self.s[node_name] == States.PLACEHOLDER
        if not eligible:
            continue

        dependency = dependencies.get(node_name)
        if dependency is None:
            continue

        if callable(dependency):
            self.add_node(node_name, dependency)
        else:
            self.add_node(node_name, value=dependency)

1779 else: 

1780 self.add_node(n, value=obj)