Coverage for src / loman / computeengine.py: 90%

992 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-02 23:34 +0000

1"""Core computation engine for dependency-aware calculation graphs.""" 

2 

3import inspect 

4import logging 

5import traceback 

6import types 

7import warnings 

8from collections import defaultdict 

9from collections.abc import Callable, Iterable, Mapping 

10from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait 

11from dataclasses import dataclass, field 

12from datetime import UTC, datetime 

13from enum import Enum 

14from typing import Any, Type, Union # noqa: UP035 

15 

16import decorator 

17import dill 

18import networkx as nx 

19import pandas as pd 

20 

21from .compat import get_signature 

22from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags 

23from .exception import ( 

24 CannotInsertToPlaceholderNodeException, 

25 ComputationError, 

26 LoopDetectedException, 

27 MapException, 

28 NodeAlreadyExistsException, 

29 NonExistentNodeException, 

30) 

31from .graph_utils import topological_sort 

32from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey 

33from .util import AttributeView, apply1, apply_n, as_iterable, value_eq 

34from .visualization import GraphView, NodeFormatter 

35 

36LOG = logging.getLogger("loman.computeengine") 

37 

38 

@dataclass
class Error:
    """Container for error information during computation.

    Pairs the exception that was raised with the traceback text captured at
    the point of failure.
    """

    # The exception instance that was raised.
    exception: Exception
    # Formatted traceback text; callers in this module pass the string
    # returned by ``traceback.format_exc()``.
    traceback: str

45 

46 

@dataclass
class NodeData:
    """State/value pair describing a single computation node."""

    # State of the node (a member of ``States``).
    state: States
    # Value held by the node.
    value: object

53 

54 

@dataclass
class TimingData:
    """Start time, end time and duration of one node evaluation."""

    # Moment the evaluation started.
    start: datetime
    # Moment the evaluation finished.
    end: datetime
    # Elapsed time between ``start`` and ``end``.
    duration: float

62 

63 

class _ParameterType(Enum):
    # ARG: the parameter is passed positionally.
    ARG = 1
    # KWD: the parameter is passed by keyword.
    KWD = 2

67 

68 

69@dataclass 

70class _ParameterItem: 

71 type: object 

72 name: int | str 

73 value: object 

74 

75 

def _node(func, *args, **kws):
    # Caller handed to ``decorator.decorate`` by ``node``: forwards the call
    # to ``func`` unchanged.
    result = func(*args, **kws)
    return result

78 

79 

def node(comp, name=None, *args, **kw):
    """Decorator that registers the decorated function as a node of ``comp``.

    The node is named ``name`` or, when ``name`` is None, after the function's
    ``__name__``. Extra positional and keyword arguments are forwarded to
    ``comp.add_node``.
    """

    def inner(f):
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        return decorator.decorate(f, _node)

    return inner

91 

92 

@dataclass
class ConstantValue:
    """Wrapper marking a value as a literal constant rather than a node name."""

    # The wrapped constant.
    value: object

98 

99 

100C = ConstantValue 

101 

102 

class Node:
    """Base class for objects that can add themselves to a computation."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add this node to ``comp`` under ``name``; subclasses must override."""
        raise NotImplementedError()

109 

110 

@dataclass
class InputNode(Node):
    """Declarative description of an input node of a computation."""

    # Positional arguments captured at construction (not forwarded by add_to_comp).
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Keyword arguments forwarded to ``Computation.add_node``.
    kwds: dict = field(default_factory=dict)

    def __init__(self, *args, **kwds):
        """Capture the positional and keyword arguments as given."""
        self.args = args
        self.kwds = kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Register the node on ``comp`` as ``name`` using the stored keyword arguments."""
        comp.add_node(name, **self.kwds)

126 

127 

128input_node = InputNode 

129 

130 

@dataclass
class CalcNode(Node):
    """Declarative description of a calculation node of a computation."""

    # Function that calculates the node.
    f: Callable
    # Keyword arguments forwarded to ``Computation.add_node``.
    kwds: dict = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Register the calculation on ``comp`` as ``name``.

        When ``ignore_self`` is requested (by the argument or by an
        ``ignore_self`` entry in the stored keyword arguments) and the
        function's first keyword parameter is called ``self``, the function is
        bound to ``obj`` before being registered.
        """
        options = self.kwds.copy()
        bind_to_obj = ignore_self or options.get("ignore_self", False)
        func = self.f
        if bind_to_obj:
            params = get_signature(self.f).kwd_params
            if len(params) > 0 and params[0] == "self":
                func = func.__get__(obj, obj.__class__)
        # ``ignore_self`` is consumed here and must not reach ``add_node``.
        options.pop("ignore_self", None)
        comp.add_node(name, func, **options)

150 

151 

def calc_node(f=None, **kwds):
    """Decorator marking a function as a calculation node.

    Usable bare (``@calc_node``) or with options (``@calc_node(...)``). The
    function is returned unchanged, with a ``CalcNode`` attached as its
    ``_loman_node_info`` attribute.
    """

    def wrap(func):
        func._loman_node_info = CalcNode(func, kwds)
        return func

    return wrap if f is None else wrap(f)

162 

163 

@dataclass
class Block(Node):
    """Declarative description of a sub-computation (block) of a computation."""

    # A Computation, or a zero-argument callable that returns the block to add.
    block: Union[Callable, "Computation"]
    # Extra positional arguments forwarded to ``Computation.add_block``.
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Extra keyword arguments forwarded to ``Computation.add_block``.
    kwds: dict = field(default_factory=dict)

    def __init__(self, block, *args, **kwds):
        """Capture the block and the arguments to pass along when it is added."""
        self.block = block
        self.args = args
        self.kwds = kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the block to ``comp`` as ``name``, calling it first if it is a factory.

        :raises TypeError: if the block is neither a Computation nor callable.
        """
        candidate = self.block
        if isinstance(candidate, Computation):
            resolved = candidate
        elif callable(candidate):
            resolved = candidate()
        else:
            raise TypeError(f"Block {self.block} must be callable or Computation")
        comp.add_block(name, resolved, *self.args, **self.kwds)

187 

188 

189block = Block 

190 

191 

def populate_computation_from_class(comp, cls, obj, ignore_self=True):
    """Add every node declared on ``cls`` to ``comp``.

    A class member counts as a declaration when it is a ``Node`` instance or
    carries a ``_loman_node_info`` attribute. Each declaration is added under
    the member's name.
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            declaration = member
        else:
            declaration = getattr(member, "_loman_node_info", None)
        if declaration is None:
            continue
        declaration.add_to_comp(comp, member_name, obj, ignore_self)

202 

203 

def computation_factory(maybe_cls=None, *, ignore_self=True) -> Type["Computation"]:  # noqa: UP006
    """Class decorator that turns a definition class into a computation factory.

    Usable bare or with ``ignore_self=...``. The decorated class is replaced
    by a function that instantiates the class, creates a ``Computation`` from
    the arguments it is called with, and populates it from the class's node
    declarations.
    """

    def wrap(cls):
        def create_computation(*args, **kwargs):
            definition = cls()
            computation = Computation(*args, **kwargs)
            computation._definition_object = definition
            populate_computation_from_class(computation, cls, definition, ignore_self)
            return computation

        return create_computation

    return wrap if maybe_cls is None else wrap(maybe_cls)

221 

222 

223def _eval_node(name, f, args, kwds, raise_exceptions): 

224 """To make multiprocessing work, this function must be standalone so that pickle works.""" 

225 exc, tb = None, None 

226 start_dt = datetime.now(UTC) 

227 try: 

228 logging.debug("Running " + str(name)) 

229 value = f(*args, **kwds) 

230 logging.debug("Completed " + str(name)) 

231 except Exception as e: 

232 value = None 

233 exc = e 

234 tb = traceback.format_exc() 

235 if raise_exceptions: 

236 raise 

237 end_dt = datetime.now(UTC) 

238 return value, exc, tb, start_dt, end_dt 

239 

240 

241_MISSING_VALUE_SENTINEL = object() 

242 

243 

class NullObject:
    """Debug helper: prints every access attempted on it, then rejects it."""

    def __getattr__(self, name):
        """Print the attribute name and raise AttributeError."""
        print(f"__getattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __setattr__(self, name, value):
        """Print the attribute name and raise AttributeError."""
        print(f"__setattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __delattr__(self, name):
        """Print the attribute name and raise AttributeError."""
        print(f"__delattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __call__(self, *args, **kwargs):
        """Print the call arguments and raise TypeError."""
        print(f"__call__: {args}, {kwargs}")
        raise TypeError("'NullObject' object is not callable")

    def __getitem__(self, key):
        """Print the key and raise KeyError."""
        print(f"__getitem__: {key}")
        message = f"'NullObject' object has no item with key '{key}'"
        raise KeyError(message)

    def __setitem__(self, key, value):
        """Print the key and raise KeyError."""
        print(f"__setitem__: {key}")
        message = f"'NullObject' object cannot have items set with key '{key}'"
        raise KeyError(message)

    def __repr__(self):
        """Print the instance dictionary and return a fixed representation."""
        print(f"__repr__: {self.__dict__}")
        return "<NullObject>"

281 

282 

def identity_function(x):
    """Identity mapping: hand back exactly the object that was passed in."""
    return x

286 

287 

288class Computation: 

289 """A computation graph that manages dependencies and calculations. 

290 

291 The Computation class provides a framework for building and executing 

292 computation graphs where nodes represent data or calculations, and edges 

293 represent dependencies between them. 

294 """ 

295 

def __init__(self, *, default_executor=None, executor_map=None, metadata=None):
    """Initialize a new, empty Computation.

    :param default_executor: Executor stored as ``self.default_executor``.
    :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1)
    :param executor_map: Mapping stored as ``self.executor_map`` (presumably executor name to
        executor - confirm against the code that reads it). Defaults to an empty dict.
    :param metadata: If given, stored as the metadata of the root node key.
    """
    self.default_executor = ThreadPoolExecutor(1) if default_executor is None else default_executor
    self.executor_map = {} if executor_map is None else executor_map
    self.dag = nx.DiGraph()
    self._metadata = {}
    if metadata is not None:
        self._metadata[NodeKey.root()] = metadata

    # Attribute views rooted at the top of the node tree; each pairs a
    # single-item accessor with a multi-item accessor.
    make_view = self.get_attribute_view_for_path
    self.v = make_view(NodeKey.root(), self._value_one, self.value)
    self.s = make_view(NodeKey.root(), self._state_one, self.state)
    self.i = make_view(NodeKey.root(), self._get_inputs_one_names, self.get_inputs)
    self.o = make_view(NodeKey.root(), self._get_outputs_one, self.get_outputs)
    self.t = make_view(NodeKey.root(), self._tag_one, self.tags)
    self.style = make_view(NodeKey.root(), self._style_one, self.styles)
    self.tim = make_view(NodeKey.root(), self._get_timing_one, self.get_timing)
    self.x = make_view(NodeKey.root(), self.compute_and_get_value, self.compute_and_get_value)
    self.src = make_view(NodeKey.root(), self.print_source, self.print_source)

    # Reverse indexes: tag -> node keys, and state -> node keys.
    self._tag_map = defaultdict(set)
    self._state_map = {state: set() for state in States}

328 

def get_attribute_view_for_path(self, nodekey: NodeKey, get_one_func: Callable, get_many_func: Callable):
    """Create an attribute view for the node-tree path ``nodekey``.

    :param nodekey: Path that names passed to the view are resolved against.
    :param get_one_func: Called with the full node key when the resolved path is a node.
    :param get_many_func: Not called here; only passed on to the views created for sub-paths.
    :return: An ``AttributeView`` built from three callables: one listing the children of
        ``nodekey``, one resolving a single name, and one resolving a name or a list of names.
    """

    def node_func():
        return self.get_tree_list_children(nodekey)

    def get_one_func_for_path(name: Name):
        new_nk = to_nodekey(name).prepend(nodekey)
        if self.has_node(new_nk):
            return get_one_func(new_nk)
        if self.tree_has_path(new_nk):
            # An intermediate path, not a node: return a nested view for it.
            return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func)
        raise KeyError(f"Path {new_nk} does not exist")

    def get_many_func_for_path(name: Name | Names):
        if isinstance(name, list):
            return [get_one_func_for_path(n) for n in name]
        return get_one_func_for_path(name)

    return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)

352 

def _get_names_for_state(self, state: States):
    # Names of all nodes currently recorded under ``state`` in the state index.
    node_keys = self._state_map[state]
    return set(node_keys_to_names(node_keys))

355 

def _get_tags_for_state(self, tag: str):
    # Names of all nodes currently recorded under ``tag`` in the tag index.
    node_keys = self._tag_map[tag]
    return set(node_keys_to_names(node_keys))

358 

def add_node(
    self,
    name: Name,
    func=None,
    *,
    args=None,
    kwds=None,
    value=_MISSING_VALUE_SENTINEL,
    converter=None,
    serialize=True,
    inspect=True,
    group=None,
    tags=None,
    style=None,
    executor=None,
    metadata=None,
):
    """Adds or updates a node in a computation.

    :param name: Name of the node to add. This may be any hashable object.
    :param func: Function to use to calculate the node if the node is a calculation node. By default, the input
        nodes to the function will be implied from the names of the function parameters. For example, a
        parameter called ``a`` would be taken from the node called ``a``. This can be modified with the
        ``kwds`` parameter.
    :type func: Function, default None
    :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally
        for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with
        three parameters, taken from the nodes 'a', 'b' and 'c' respectively. An entry wrapped in
        ``ConstantValue`` is passed as a literal value instead of being read from a node.
    :type args: List, default None
    :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
        parameter when calling the function for a calculation node. e.g. If kwds is ``{'x': 'a', 'y': 'b'}``
        then the function would be called with parameters named 'x' and 'y', and their values would be taken
        from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter
        [key] from node [value]". A value wrapped in ``ConstantValue`` is passed as a literal value.
    :type kwds: Dictionary, default None
    :param value: If given, the value is inserted into the node, and the node state set to UPTODATE.
    :type value: default is an internal "missing" sentinel, so ``None`` can be given as a real value
    :param converter: Stored on the node; when set, values placed into the node are passed through it first.
    :type converter: Callable, default None
    :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
        case, set serialize to False
    :type serialize: boolean, default True
    :param inspect: Whether to use introspection to determine the arguments of the function, which can be
        slow. If this is not set, kwds and args must be set for the function to obtain parameters.
    :type inspect: boolean, default True
    :param group: Subgraph to render node in
    :type group: default None
    :param tags: Set of tags to apply to node
    :type tags: Iterable
    :param style: Style to apply to node
    :type style: String, default None
    :param executor: Name of executor to run node on
    :type executor: string
    :param metadata: Metadata to store for the node. If None, any metadata already stored for the node is
        removed.
    :type metadata: default None
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Adding node {node_key}")
    # Record whether a value was supplied before the sentinel is replaced by None.
    has_value = value is not _MISSING_VALUE_SENTINEL
    if value is _MISSING_VALUE_SENTINEL:
        value = None
    if tags is None:
        tags = []

    # Adding a node that already exists redefines it: its incoming edges are
    # dropped here and rebuilt from args/kwds below.
    self.dag.add_node(node_key)
    pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
    self.dag.remove_edges_from(pred_edges)
    node = self.dag.nodes[node_key]

    if metadata is None:
        if node_key in self._metadata:
            del self._metadata[node_key]
    else:
        self._metadata[node_key] = metadata

    # require_old_state=False: a brand-new node has no previous state to remove from the state index.
    self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)

    node[NodeAttributes.TAG] = set()
    node[NodeAttributes.STYLE] = style
    node[NodeAttributes.GROUP] = group
    node[NodeAttributes.ARGS] = {}
    node[NodeAttributes.KWDS] = {}
    node[NodeAttributes.FUNC] = None
    node[NodeAttributes.EXECUTOR] = executor
    node[NodeAttributes.CONVERTER] = converter

    if func:
        node[NodeAttributes.FUNC] = func
        args_count = 0
        if args:
            args_count = len(args)
            for i, arg in enumerate(args):
                if isinstance(arg, ConstantValue):
                    # Literal positional argument: stored on the node, no edge created.
                    node[NodeAttributes.ARGS][i] = arg.value
                else:
                    input_vertex_name = arg
                    input_vertex_node_key = to_nodekey(input_vertex_name)
                    if not self.dag.has_node(input_vertex_node_key):
                        # Referenced node does not exist yet: create it as a PLACEHOLDER.
                        self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
                    self.dag.add_edge(
                        input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
                    )
        # param_map: parameter name -> source (a node name or a ConstantValue).
        param_map = {}
        if inspect:
            signature = get_signature(func)
            if not signature.has_var_args:
                # Parameters already covered positionally by ``args`` are skipped.
                for param_name in signature.kwd_params[args_count:]:
                    if kwds is not None and param_name in kwds:
                        param_source = kwds[param_name]
                    else:
                        # Default source: the node named after the parameter, under this node's parent path.
                        param_source = node_key.parent.join_parts(param_name)
                    param_map[param_name] = param_source
            if signature.has_var_kwds and kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = signature.default_params
        else:
            if kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = []
        for param_name, param_source in param_map.items():
            if isinstance(param_source, ConstantValue):
                # Literal keyword argument: stored on the node, no edge created.
                node[NodeAttributes.KWDS][param_name] = param_source.value
            else:
                in_node_name = param_source
                in_node_key = to_nodekey(in_node_name)
                if not self.dag.has_node(in_node_key):
                    if param_name in default_names:
                        # Parameter has a default and no source node exists: leave it unconnected.
                        continue
                    else:
                        self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(in_node_key)
                self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})
        self._set_descendents(node_key, States.STALE)

    if has_value:
        self._set_uptodate(node_key, value)
    if node[NodeAttributes.STATE] == States.UNINITIALIZED:
        self._try_set_computable(node_key)
    self.set_tag(node_key, tags)
    if serialize:
        self.set_tag(node_key, SystemTags.SERIALIZE)

499 

def _refresh_maps(self):
    # Rebuild the state and tag indexes from the attributes stored on the graph nodes.
    self._tag_map.clear()
    for known_state in States:
        self._state_map[known_state].clear()
    for node_key in self._node_keys():
        current_state = self.dag.nodes[node_key][NodeAttributes.STATE]
        self._state_map[current_state].add(node_key)
        # Nodes without a tag attribute contribute no tags.
        for tag in self.dag.nodes[node_key].get(NodeAttributes.TAG, set()):
            self._tag_map[tag].add(node_key)

510 

def _set_tag_one(self, name: Name, tag):
    # Record ``tag`` both on the node itself and in the tag index.
    key = to_nodekey(name)
    node_tags = self.dag.nodes[key][NodeAttributes.TAG]
    node_tags.add(tag)
    self._tag_map[tag].add(key)

515 

def set_tag(self, name: Names, tag):
    """Apply a tag to one or more nodes; a tag that is already present is left as is.

    :param name: Node or nodes to set tag for
    :param tag: Tag to set
    """
    setter = self._set_tag_one
    apply_n(setter, name, tag)

523 

def _clear_tag_one(self, name: Name, tag):
    # Remove ``tag`` from the node and from the tag index; absent tags are ignored.
    key = to_nodekey(name)
    node_tags = self.dag.nodes[key][NodeAttributes.TAG]
    node_tags.discard(tag)
    self._tag_map[tag].discard(key)

528 

def clear_tag(self, name: Names, tag):
    """Remove a tag from one or more nodes; a tag that is not present is ignored.

    :param name: Node or nodes to clear tags for
    :param tag: Tag to clear
    """
    clearer = self._clear_tag_one
    apply_n(clearer, name, tag)

536 

def _set_style_one(self, name: Name, style):
    # Store ``style`` as the style attribute of a single node.
    key = to_nodekey(name)
    attrs = self.dag.nodes[key]
    attrs[NodeAttributes.STYLE] = style

540 

def set_style(self, name: Names, style):
    """Assign a style to one or more nodes.

    :param name: Node or nodes to set style for
    :param style: Style to set
    """
    setter = self._set_style_one
    apply_n(setter, name, style)

548 

def _clear_style_one(self, name):
    # Reset the style attribute of a single node to None.
    key = to_nodekey(name)
    attrs = self.dag.nodes[key]
    attrs[NodeAttributes.STYLE] = None

552 

def clear_style(self, name):
    """Remove the style from one or more nodes.

    :param name: Node or nodes to clear styles for
    """
    clearer = self._clear_style_one
    apply_n(clearer, name)

559 

def metadata(self, name):
    """Return the metadata dictionary for a path, creating an empty one on first access.

    :param name: Name of the node or tree path to get metadata for.
    :raises NonExistentNodeException: if ``name`` is not a path in the node tree.
    """
    node_key = to_nodekey(name)
    if not self.tree_has_path(name):
        raise NonExistentNodeException(f"Node {node_key} does not exist.")
    return self._metadata.setdefault(node_key, {})

569 

def delete_node(self, name):
    """Delete a node from a computation.

    When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
    will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
    then the PLACEHOLDER node will also be deleted.

    :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
        be raised.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Deleting node {node_key}")

    if not self.dag.has_node(node_key):
        raise NonExistentNodeException(f"Node {node_key} does not exist")

    self._metadata.pop(node_key, None)

    if len(self.dag.succ[node_key]) == 0:
        node = self.dag.nodes[node_key]
        # Take a snapshot of the predecessors before the node is removed from the graph.
        preds = list(self.dag.predecessors(node_key))
        state = node[NodeAttributes.STATE]
        # Keep the tag index in step with the graph; nodes created as
        # placeholders carry no tag attribute.
        for tag in node.get(NodeAttributes.TAG, ()):
            self._tag_map[tag].discard(node_key)
        self.dag.remove_node(node_key)
        self._state_map[state].remove(node_key)
        for n in preds:
            # A predecessor may already have been removed by a recursive call.
            if self.dag.has_node(n) and self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
                self.delete_node(n)
    else:
        # Other nodes still depend on this one: keep it in the graph as a placeholder.
        self._set_state(node_key, States.PLACEHOLDER)

599 

def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None):
    """Rename a node in a computation.

    :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and
        new names as values
    :param new_name: New name for node.
    :raises ValueError: if ``old_name`` is a dictionary and ``new_name`` is also given.
    :raises NonExistentNodeException: if a single node is renamed and ``old_name`` does not exist.
    :raises NodeAlreadyExistsException: if a single node is renamed and ``new_name`` already exists.
    """
    # A bulk rename is recognised by ``items()``. Testing for ``__getitem__``
    # would also match tuple node names and then fail on ``.items()``.
    is_bulk_rename = isinstance(old_name, Mapping) or (
        hasattr(old_name, "items") and not isinstance(old_name, str)
    )
    if is_bulk_rename:
        for k, v in old_name.items():
            LOG.debug(f"Renaming node {k} to {v}")
        if new_name is not None:
            raise ValueError("new_name must not be set if rename_node is passed a dictionary")
        name_mapping = old_name
    else:
        LOG.debug(f"Renaming node {old_name} to {new_name}")
        old_node_key = to_nodekey(old_name)
        if not self.dag.has_node(old_node_key):
            raise NonExistentNodeException(f"Node {old_name} does not exist")
        new_node_key = to_nodekey(new_name)
        if self.dag.has_node(new_node_key):
            raise NodeAlreadyExistsException(f"Node {new_name} already exists")
        name_mapping = {old_name: new_name}

    node_key_mapping = {to_nodekey(old): to_nodekey(new) for old, new in name_mapping.items()}
    nx.relabel_nodes(self.dag, node_key_mapping, copy=False)

    # Move metadata along with the node; metadata left under the new key is
    # dropped when the renamed node has none.
    for old_node_key, new_node_key in node_key_mapping.items():
        if old_node_key in self._metadata:
            self._metadata[new_node_key] = self._metadata[old_node_key]
            del self._metadata[old_node_key]
        else:
            if new_node_key in self._metadata:
                del self._metadata[new_node_key]

    # The state and tag indexes are keyed by node key, so rebuild them.
    self._refresh_maps()

636 

def repoint(self, old_name: Name, new_name: Name):
    """Make every node that takes old_name as an input take new_name instead.

    If old_name is itself an input to new_name, that edge is left alone, to try to avoid introducing a
    circular dependency; no other check for circular dependencies is made.

    If new_name does not exist, it is created as a PLACEHOLDER node.

    :param old_name: Name of the node currently used as an input.
    :param new_name: Name of the node to use as the input instead.
    :return: None
    """
    source_key = to_nodekey(old_name)
    target_key = to_nodekey(new_name)
    if source_key == target_key:
        return

    consumers = list(self.dag.successors(source_key))

    if consumers and not self.dag.has_node(target_key):
        self.dag.add_node(target_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
        self._state_map[States.PLACEHOLDER].add(target_key)

    for consumer in consumers:
        if consumer == target_key:
            continue
        # Carry the edge attributes over to the replacement edge.
        edge_attrs = self.dag.get_edge_data(source_key, consumer)
        self.dag.add_edge(target_key, consumer, **edge_attrs)
        self.dag.remove_edge(source_key, consumer)

    for consumer in consumers:
        self.set_stale(consumer)

669 

def insert(self, name: Name, value, force=False):
    """Insert a value into a node of a computation.

    After insertion the node has state UPTODATE, and all its descendents are COMPUTABLE or STALE.

    :param name: Name of the node to insert into.
    :param value: The value to be inserted into the node.
    :param force: Whether to force recalculation of descendents if node value and state would not be changed
    :raises NonExistentNodeException: if the node does not exist.
    :raises CannotInsertToPlaceholderNodeException: if the node is a PLACEHOLDER.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Inserting value into node {node_key}")

    if not self.dag.has_node(node_key):
        raise NonExistentNodeException(f"Node {node_key} does not exist")

    state = self._state_one(name)
    if state == States.PLACEHOLDER:
        raise CannotInsertToPlaceholderNodeException(
            "Cannot insert into placeholder node. Use add_node to create the node first"
        )

    # Unless forced, inserting an equal value into an up-to-date node is a no-op.
    if not force and state == States.UPTODATE and value_eq(value, self._value_one(name)):
        return

    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    for successor in self.dag.successors(node_key):
        self._try_set_computable(successor)

704 

def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]):
    """Insert values into many nodes of a computation simultaneously.

    After insertion the nodes have state UPTODATE, and all their descendents are COMPUTABLE or STALE. When
    some of the inserted nodes are descendents of others, the inserted nodes keep the correct status rather
    than being set to STALE when their ancestors are inserted.

    If any named node does not exist, a ``NonExistentNodeException`` is raised and none of the nodes are
    inserted.

    :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
        insert the value into.
    :type name_value_pairs: List of tuples
    """
    pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
    joined_names = ", ".join(str(key) for key, _ in pairs)
    LOG.debug(f"Inserting value into nodes {joined_names}")

    # Validate every name before changing anything.
    for key, _ in pairs:
        if not self.dag.has_node(key):
            raise NonExistentNodeException(f"Node {key} does not exist")

    stale = set()
    computable = set()
    for key, value in pairs:
        self._set_state_and_value(key, States.UPTODATE, value)
        stale.update(nx.dag.descendants(self.dag, key))
        computable.update(self.dag.successors(key))

    # The inserted nodes themselves must stay UPTODATE.
    inserted = {key for key, _ in pairs}
    stale -= inserted
    computable -= inserted

    for key in stale:
        self._set_state(key, States.STALE)
    for key in computable:
        self._try_set_computable(key)

739 

def insert_from(self, other, nodes: Iterable[Name] | None = None):
    """Insert values from another Computation object into this Computation object.

    :param other: The computation object to take values from
    :type other: Computation
    :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the
        other Computation object that have corresponding nodes in this Computation object will be inserted
    :type nodes: List, default None
    """
    if nodes is None:
        # Default: every node present in both computations.
        nodes = set(self.dag.nodes)
        nodes.intersection_update(other.dag.nodes())
    pairs = []
    for node_name in nodes:
        pairs.append((node_name, other.value(node_name)))
    self.insert_many(pairs)

754 

def _set_state(self, node_key: NodeKey, state: States):
    # Change one node's state and move it between the buckets of the state index.
    attrs = self.dag.nodes[node_key]
    previous = attrs[NodeAttributes.STATE]
    self._state_map[previous].remove(node_key)
    attrs[NodeAttributes.STATE] = state
    self._state_map[state].add(node_key)

761 

def _set_state_and_value(
    self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
):
    # Set state and value, passing the value through the node's converter when it has one.
    # If conversion (or storing the converted value) fails, the node is put into the error
    # state, and the exception is re-raised unless throw_conversion_exception is False.
    converter = self.dag.nodes[node_key].get(NodeAttributes.CONVERTER)
    if converter is None:
        self._set_state_and_literal_value(node_key, state, value)
        return
    try:
        self._set_state_and_literal_value(node_key, state, converter(value))
    except Exception as e:
        self._set_error(node_key, e, traceback.format_exc())
        if throw_conversion_exception:
            raise e

778 

def _set_state_and_literal_value(
    self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
):
    # Store state and value on the node as given (no converter) and update the state index.
    # With require_old_state=False, a node that has no previous state, or is missing from
    # its state bucket, is tolerated.
    attrs = self.dag.nodes[node_key]
    try:
        previous = attrs[NodeAttributes.STATE]
        self._state_map[previous].remove(node_key)
    except KeyError:
        if require_old_state:
            raise
    attrs[NodeAttributes.STATE] = state
    attrs[NodeAttributes.VALUE] = value
    self._state_map[state].add(node_key)

792 

def _set_states(self, node_keys: Iterable[NodeKey], state: States):
    # Set the state of several nodes, keeping the state index in step.
    # Each node is added to the target bucket as soon as it leaves its old
    # one, so node_keys is iterated only once (one-shot iterables work), a
    # repeated key is harmless, and a failure part-way through leaves every
    # processed node indexed.
    target_bucket = self._state_map[state]
    for node_key in node_keys:
        node = self.dag.nodes[node_key]
        old_state = node[NodeAttributes.STATE]
        self._state_map[old_state].remove(node_key)
        node[NodeAttributes.STATE] = state
        target_bucket.add(node_key)

800 

def set_stale(self, name: Name):
    """Set the state of a node and all its descendants to STALE.

    The node itself is then made COMPUTABLE if it qualifies.

    :param name: Name of the node to set as STALE.
    """
    node_key = to_nodekey(name)
    affected = [node_key, *nx.dag.descendants(self.dag, node_key)]
    self._set_states(affected, States.STALE)
    self._try_set_computable(node_key)

811 

def pin(self, name: Name, value=None):
    """Set the state of a node to PINNED.

    :param name: Name of the node to set as PINNED.
    :param value: Value to insert into the node before pinning it. ``None`` means no value is inserted.
    :type value: default None
    """
    node_key = to_nodekey(name)
    if value is not None:
        self.insert(node_key, value)
    self._set_states([node_key], States.PINNED)

823 

def unpin(self, name):
    """Unpin a node (state of node and all descendents will be set to STALE).

    :param name: Name of the node to unpin.
    """
    self.set_stale(to_nodekey(name))

831 

def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
    # Collect the descendants of node_key, not descending into (or including)
    # nodes whose state is in stop_states. Returns an empty set when node_key
    # itself is in a stop state.
    # The default must be resolved before the first membership test, which
    # would otherwise raise TypeError on None.
    if stop_states is None:
        stop_states = set()
    if self.dag.nodes[node_key][NodeAttributes.STATE] in stop_states:
        return set()
    visited = set()
    to_visit = {node_key}
    while to_visit:
        n = to_visit.pop()
        visited.add(n)
        for n1 in self.dag.successors(n):
            if n1 in visited:
                continue
            if self.dag.nodes[n1][NodeAttributes.STATE] in stop_states:
                continue
            to_visit.add(n1)
    # The starting node is not its own descendant.
    visited.remove(node_key)
    return visited

850 

def _set_descendents(self, node_key: NodeKey, state):
    # Apply ``state`` to the descendants of node_key; PINNED nodes stop the traversal.
    reachable = self._get_descendents(node_key, {States.PINNED})
    self._set_states(reachable, state)

854 

def _set_uninitialized(self, node_key: NodeKey):
    """Reset a node to UNINITIALIZED and drop its stored value, if it has one."""
    self._set_states([node_key], States.UNINITIALIZED)
    node_attrs = self.dag.nodes[node_key]
    node_attrs.pop(NodeAttributes.VALUE, None)

858 

def _set_uptodate(self, node_key: NodeKey, value: object):
    """Store a value on a node as UPTODATE and propagate the change downstream.

    Descendents are marked STALE, then each direct successor is checked to
    see whether it can now become COMPUTABLE.
    """
    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    for successor in self.dag.successors(node_key):
        self._try_set_computable(successor)

864 

def _set_error(self, node_key: NodeKey, exc: Exception, tb: inspect.Traceback):
    """Put a node into the ERROR state, storing the exception and traceback as its value.

    Descendents of the node are marked STALE.
    """
    failure = Error(exc, tb)
    self._set_state_and_literal_value(node_key, States.ERROR, failure)
    self._set_descendents(node_key, States.STALE)

868 

def _try_set_computable(self, node_key: NodeKey):
    """Promote a node to COMPUTABLE when every predecessor is UPTODATE.

    Nothing happens if the node is PINNED, has no function attached, or has
    any predecessor that is missing from the graph or not UPTODATE.
    """
    node_attrs = self.dag.nodes
    if node_attrs[node_key][NodeAttributes.STATE] == States.PINNED:
        return
    if node_attrs[node_key].get(NodeAttributes.FUNC) is None:
        return
    for upstream in self.dag.predecessors(node_key):
        if not self.dag.has_node(upstream):
            return
        if node_attrs[upstream][NodeAttributes.STATE] != States.UPTODATE:
            return
    self._set_state(node_key, States.COMPUTABLE)

879 

def _get_parameter_data(self, node_key: NodeKey):
    """Yield one ``_ParameterItem`` per input of a node.

    Items come from the node's stored ARGS, then its stored KWDS, and
    finally the current value of each predecessor, tagged with the
    parameter recorded on the connecting edge.
    """
    node_attrs = self.dag.nodes[node_key]
    for position, constant in node_attrs[NodeAttributes.ARGS].items():
        yield _ParameterItem(_ParameterType.ARG, position, constant)
    for keyword, constant in node_attrs[NodeAttributes.KWDS].items():
        yield _ParameterItem(_ParameterType.KWD, keyword, constant)
    for upstream in self.dag.predecessors(node_key):
        upstream_value = self.dag.nodes[upstream][NodeAttributes.VALUE]
        kind, label = self.dag[upstream][node_key][EdgeAttributes.PARAM]
        yield _ParameterItem(kind, label, upstream_value)

890 

def _get_func_args_kwds(self, node_key: NodeKey):
    """Assemble everything needed to invoke a node's function.

    :return: Tuple ``(func, executor_name, args, kwds)``. ``args`` is a list
        indexed by positional slot, with unfilled slots left as ``None``.
    :raises Exception: If a parameter has an unrecognised type.
    """
    node_attrs = self.dag.nodes[node_key]
    func = node_attrs[NodeAttributes.FUNC]
    executor_name = node_attrs.get(NodeAttributes.EXECUTOR)
    positional, keyword = [], {}
    for item in self._get_parameter_data(node_key):
        if item.type == _ParameterType.ARG:
            slot = item.name
            # Grow the list just enough for ``slot`` to be a valid index.
            shortfall = slot + 1 - len(positional)
            if shortfall > 0:
                positional.extend([None] * shortfall)
            positional[slot] = item.value
        elif item.type == _ParameterType.KWD:
            keyword[item.name] = item.value
        else:
            raise Exception(f"Unexpected param type: {item.type}")
    return func, executor_name, positional, keyword

907 

def get_definition_args_kwds(self, name: Name) -> tuple[list, dict]:
    """Reconstruct the ``args`` and ``kwds`` that define a node.

    Stored constants are wrapped with ``C``; parameters fed by a predecessor
    are represented by that predecessor's name.

    :param name: Name of the node to inspect.
    :return: Tuple ``(args, kwds)``.
    :raises Exception: If an edge carries an unrecognised parameter type.
    """
    definition_args = []
    definition_kwds = {}

    def put_positional(position, item):
        # Pad with None so that ``position`` is a valid index, then store.
        shortfall = position + 1 - len(definition_args)
        if shortfall > 0:
            definition_args.extend([None] * shortfall)
        definition_args[position] = item

    node_key = to_nodekey(name)
    node_attrs = self.dag.nodes[node_key]
    if NodeAttributes.ARGS in node_attrs:
        for position, constant in node_attrs[NodeAttributes.ARGS].items():
            put_positional(position, C(constant))
    if NodeAttributes.KWDS in node_attrs:
        for keyword, constant in node_attrs[NodeAttributes.KWDS].items():
            definition_kwds[keyword] = C(constant)
    for upstream in self.dag.predecessors(node_key):
        edge = self.dag[upstream][node_key]
        if EdgeAttributes.PARAM not in edge:
            continue
        kind, label = edge[EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            put_positional(label, upstream.name)
        elif kind == _ParameterType.KWD:
            definition_kwds[label] = upstream.name
        else:
            raise Exception(f"Unexpected param type: {kind}")
    return definition_args, definition_kwds

936 

def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False):
    """Evaluate the COMPUTABLE nodes among ``node_keys``, following newly computable successors.

    Each COMPUTABLE node is submitted to its executor (the node's named
    executor, or the default one). As results arrive, the node is set to
    UPTODATE or ERROR, its descendents are marked STALE, and any successor
    that becomes COMPUTABLE and is in ``node_keys`` is submitted in turn.

    :param node_keys: Nodes that are allowed to be computed.
    :param raise_exceptions: Forwarded to ``_eval_node``.
    :raises LoopDetectedException: If a successor of a just-computed node
        has already been computed during this call.
    """
    LOG.debug(f"Computing nodes {node_keys}")

    # Materialise once: the collection is iterated for the initial submission
    # and then tested for membership for every successor. A set makes that
    # test O(1) and keeps it correct when the caller passes a one-shot iterable.
    node_keys = list(node_keys)
    requested = set(node_keys)

    futs = {}

    def run(name):
        f, executor_name, args, kwds = self._get_func_args_kwds(name)
        if executor_name is None:
            executor = self.default_executor
        else:
            executor = self.executor_map[executor_name]
        fut = executor.submit(_eval_node, name, f, args, kwds, raise_exceptions)
        futs[fut] = name

    computed = set()

    for node_key in node_keys:
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.COMPUTABLE:
            run(node_key)

    while len(futs) > 0:
        done, _not_done = wait(futs.keys(), return_when=FIRST_COMPLETED)
        for fut in done:
            node_key = futs.pop(fut)
            node0 = self.dag.nodes[node_key]
            try:
                value, exc, tb, start_dt, end_dt = fut.result()
            except Exception as e:
                # The future itself failed: record the error on the node
                # before letting the exception propagate to the caller.
                exc = e
                tb = traceback.format_exc()
                self._set_error(node_key, exc, tb)
                raise
            delta = (end_dt - start_dt).total_seconds()
            if exc is None:
                self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
                node0[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
                self._set_descendents(node_key, States.STALE)
                for n in self.dag.successors(node_key):
                    # Use the module logger with lazy formatting instead of the root logger.
                    LOG.debug("%s %s %s", node_key, n, computed)
                    if n in computed:
                        raise LoopDetectedException(f"Calculating {node_key} for the second time")
                    self._try_set_computable(n)
                    successor_state = self.dag.nodes[n][NodeAttributes.STATE]
                    if successor_state == States.COMPUTABLE and n in requested:
                        run(n)
            else:
                self._set_error(node_key, exc, tb)
            computed.add(node_key)

988 

def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """Work out which nodes must be calculated to bring ``node_key`` up to date.

    A scratch copy of the graph is pruned of ancestors that are already
    UPTODATE or PINNED. The ancestors that remain, plus the target itself,
    are returned in topological order.

    :param node_key: Target node.
    :return: Nodes to calculate, in dependency order.
    :raises Exception: If a remaining ancestor is an UNINITIALIZED node with
        no predecessors, or is a PLACEHOLDER.
    """
    g = nx.DiGraph()
    g.add_nodes_from(self.dag.nodes)
    g.add_edges_from(self.dag.edges)
    for n in nx.ancestors(g, node_key):
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UPTODATE or state == States.PINNED:
            g.remove_node(n)

    ancestors = nx.ancestors(g, node_key)
    for n in ancestors:
        # Look up the state of *this* ancestor. Previously the check reused
        # whatever ``state`` was left over from the pruning loop above, so
        # the outcome depended on set iteration order, not on ``n``.
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0:
            raise Exception(f"Cannot compute {node_key} because {n} uninitialized")
        if state == States.PLACEHOLDER:
            raise Exception(f"Cannot compute {node_key} because {n} is placeholder")

    ancestors.add(node_key)
    nodes_sorted = topological_sort(g)
    return [n for n in nodes_sorted if n in ancestors]

1009 

def _get_calc_node_names(self, name: Name) -> Names:
    """Name-based wrapper around ``_get_calc_node_keys``."""
    calc_keys = self._get_calc_node_keys(to_nodekey(name))
    return node_keys_to_names(calc_keys)

1013 

def compute(self, name: Name | Iterable[Name], raise_exceptions=False):
    """Compute a node and all necessary predecessors.

    After a successful computation, the target node and every ancestor that
    was not already UPTODATE will have been calculated and set to UPTODATE.
    Nodes that did not need calculating are not recalculated.

    If a node raises an exception, its state is set to ERROR and its value
    to an object holding the exception and a traceback. This does not halt
    the computation, which continues as far as it can, until no more nodes
    required for the target are COMPUTABLE.

    :param name: Name of the node to compute, or a list or generator of names
    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    if not isinstance(name, (types.GeneratorType, list)):
        calc_nodes = self._get_calc_node_keys(to_nodekey(name))
    else:
        calc_nodes = set()
        for single_name in name:
            calc_nodes.update(self._get_calc_node_keys(to_nodekey(single_name)))
    self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)

1040 

def compute_all(self, raise_exceptions=False):
    """Compute every node of the computation that can be computed.

    Nodes that are already UPTODATE are not recalculated. After a successful
    computation all nodes will be UPTODATE, except UNINITIALIZED input nodes
    and PLACEHOLDER nodes.

    If a node raises an exception, its state is set to ERROR and its value
    to an object holding the exception and a traceback. This does not halt
    the computation, which continues as far as it can, until no more nodes
    are COMPUTABLE.

    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    every_node = self._node_keys()
    self._compute_nodes(every_node, raise_exceptions=raise_exceptions)

1055 

def _node_keys(self) -> list[NodeKey]:
    """Return the keys of all nodes in this computation.

    :return: List of node keys.
    """
    return [*self.dag.nodes]

1062 

def nodes(self) -> list[Name]:
    """Return the names of all nodes in this computation.

    :return: List of node names.
    """
    return [node_key.name for node_key in self.dag.nodes]

1069 

def get_tree_list_children(self, name: Name) -> set[Name]:
    """Get the path components found immediately below a path.

    For every graph node that is a descendent of ``name``, the path part
    directly after ``name``'s own parts is collected.

    :param name: Path whose immediate children are wanted.
    :return: Set of child path parts.
    """
    node_key = to_nodekey(name)
    depth = len(node_key.parts)
    return {n.parts[depth] for n in self.dag.nodes if n.is_descendent_of(node_key)}

1082 

def has_node(self, name: Name):
    """Return whether the computation contains a node with the given name."""
    return to_nodekey(name) in self.dag.nodes

1087 

def tree_has_path(self, name: Name):
    """Return whether a hierarchical path exists in the computation tree.

    A path exists if it is the root, is itself a node, or is a prefix of at
    least one node.
    """
    node_key = to_nodekey(name)
    if node_key.is_root or self.has_node(node_key):
        return True
    return any(n.is_descendent_of(node_key) for n in self.dag.nodes)

1099 

def get_tree_descendents(
    self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
) -> set[Name]:
    """Get the blocks and nodes found below a path.

    For node 'foo' this might return ['foo/bar', 'foo/baz'].

    :param name: Name of node to get descendents for; ``None`` means the root
    :param include_stem: If False, strip the parts of ``name`` from the front of each result
    :param graph_nodes_only: If True, report only actual graph nodes rather
        than also expanding each node with ``ancestors()``
    :return: Set of descendent names
    """
    node_key = to_nodekey(name) if name is not None else NodeKey.root()
    stemsize = len(node_key.parts)
    found = set()
    for graph_node in self.dag.nodes:
        if not graph_node.is_descendent_of(node_key):
            continue
        candidates = [graph_node] if graph_nodes_only else graph_node.ancestors()
        for candidate in candidates:
            if not candidate.is_descendent_of(node_key):
                continue
            if include_stem:
                found.add(candidate.name)
            else:
                found.add(NodeKey(tuple(candidate.parts[stemsize:])).name)
    return found

1128 

def _state_one(self, name: Name):
    """Return the state stored on a single node."""
    return self.dag.nodes[to_nodekey(name)][NodeAttributes.STATE]

1132 

def state(self, name: Name | Names):
    """Get the state of one or more nodes.

    The attribute-style accessor ``s`` offers the same lookup when ``name``
    is a valid Python attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.state('foo')
        <States.UPTODATE: 4>
        >>> comp.s.foo
        <States.UPTODATE: 4>

    :param name: Name or names of the node to get state for
    :type name: Name or Names
    """
    lookup = self._state_one
    return apply1(lookup, name)

1150 

def _value_one(self, name: Name):
    """Return the value stored on a single node."""
    return self.dag.nodes[to_nodekey(name)][NodeAttributes.VALUE]

1154 

def value(self, name: Name | Names):
    """Get the current value of one or more nodes.

    The attribute-style accessor ``v`` offers the same lookup when ``name``
    is a valid Python attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.value('foo')
        1
        >>> comp.v.foo
        1

    :param name: Name or names of the node to get the value of
    :type name: Name or Names
    """
    lookup = self._value_one
    return apply1(lookup, name)

1172 

def compute_and_get_value(self, name: Name):
    """Get the value of a node, computing it first if it is not UPTODATE.

    The attribute-style accessor ``x`` offers the same behaviour when
    ``name`` is a valid Python attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

    :param name: Name of the node to get the value of
    :type name: Name
    :raises ComputationError: If the node is still not UPTODATE after computing.
    """
    name = to_nodekey(name)
    already_current = self.state(name) == States.UPTODATE
    if not already_current:
        self.compute(name, raise_exceptions=True)
        if not self.state(name) == States.UPTODATE:
            raise ComputationError(f"Unable to compute node {name}")
    return self.value(name)

1197 

def _tag_one(self, name: Name):
    """Return the tags stored on a single node."""
    node_attrs = self.dag.nodes[to_nodekey(name)]
    return node_attrs[NodeAttributes.TAG]

1202 

def tags(self, name: Name | Names):
    """Get the tags attached to one or more nodes.

    >>> comp = Computation()
    >>> comp.add_node('a', tags=['foo', 'bar'])
    >>> sorted(comp.t.a)
    ['__serialize__', 'bar', 'foo']

    :param name: Name or names of the node to get the tags of
    :return: Tags of the node, or one result per name when several names are given
    """
    lookup = self._tag_one
    return apply1(lookup, name)

1215 

def nodes_by_tag(self, tag) -> set[Name]:
    """Get the names of nodes carrying any of the given tags.

    :param tag: Tag or tags for which to retrieve nodes
    :return: Names of the nodes with those tags
    """
    matched = set()
    for single_tag in as_iterable(tag):
        tagged = self._tag_map.get(single_tag)
        if tagged is None:
            # Unknown tag: contributes nothing.
            continue
        matched.update(tagged)
    return {node_key.name for node_key in matched}

1228 

def _style_one(self, name: Name):
    """Return the style stored on a single node, or ``None`` if it has none."""
    node_attrs = self.dag.nodes[to_nodekey(name)]
    return node_attrs.get(NodeAttributes.STYLE)

1233 

def styles(self, name: Name | Names):
    """Get the style attached to one or more nodes.

    >>> comp = Computation()
    >>> comp.add_node('a', style='dot')
    >>> comp.style.a
    'dot'

    :param name: Name or names of the node to get the style of
    :return: Style of the node, or one result per name when several names are given
    """
    lookup = self._style_one
    return apply1(lookup, name)

1246 

def _get_item_one(self, name: Name):
    """Return a ``NodeData`` holding the state and value of a single node."""
    node_attrs = self.dag.nodes[to_nodekey(name)]
    current_state = node_attrs[NodeAttributes.STATE]
    current_value = node_attrs[NodeAttributes.VALUE]
    return NodeData(current_state, current_value)

1251 

def __getitem__(self, name: Name | Names):
    """Get the state and current value of one or more nodes.

    :param name: Name or names of the node to get the state and value of
    """
    lookup = self._get_item_one
    return apply1(lookup, name)

1258 

def _get_timing_one(self, name: Name):
    """Return the timing data stored on a single node, or ``None`` if it has none."""
    node_attrs = self.dag.nodes[to_nodekey(name)]
    return node_attrs.get(NodeAttributes.TIMING, None)

1263 

def get_timing(self, name: Name | Names):
    """Get the timing information recorded for one or more nodes.

    :param name: Name or names of the node to get the timing information of
    :return: Timing data of the node (``None`` where none is stored), or one
        result per name when several names are given
    """
    lookup = self._get_timing_one
    return apply1(lookup, name)

1271 

def to_df(self):
    """Build a dataframe of the state, value and timing of every node.

    Rows follow the topological order of the graph and are labelled with
    node names.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_df() # doctest: +NORMALIZE_WHITESPACE
                       state  value
        foo  States.UPTODATE      1
        bar  States.UPTODATE      2
    """
    states_by_node = nx.get_node_attributes(self.dag, NodeAttributes.STATE)
    values_by_node = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    timing_by_node = nx.get_node_attributes(self.dag, "timing")

    df = pd.DataFrame(index=topological_sort(self.dag))
    df[NodeAttributes.STATE] = pd.Series(states_by_node)
    df[NodeAttributes.VALUE] = pd.Series(values_by_node)
    df_timing = pd.DataFrame.from_dict(timing_by_node, orient="index")
    # Left join keeps nodes that have never been timed.
    df = df.merge(df_timing, left_index=True, right_index=True, how="left")
    df.index = pd.Index([nk.name for nk in df.index])
    return df

1293 

def to_dict(self):
    """Get a dictionary mapping each node that holds a value to that value.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_dict() # doctest: +ELLIPSIS
        {NodeKey('foo'): 1, NodeKey('bar'): 2}
    """
    values_by_node = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    return values_by_node

1307 

def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """List the predecessors of a node: positional inputs first, then keyword inputs.

    Positional inputs sit at their argument index, with ``None`` in any
    unfilled slot. Keyword inputs follow in predecessor iteration order.
    """
    positional = {}
    keyword = []
    for upstream in self.dag.predecessors(node_key):
        kind, label = self.dag[upstream][node_key][EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            positional[label] = upstream
        elif kind == _ParameterType.KWD:
            keyword.append(upstream)
    highest = max(positional, default=-1)
    if highest < 0:
        return keyword
    ordered = [None] * (highest + 1)
    for slot, upstream in positional.items():
        ordered[slot] = upstream
    return ordered + keyword

1328 

def _get_inputs_one_names(self, name: Name) -> Names:
    """Name-based wrapper around ``_get_inputs_one_node_keys``."""
    input_keys = self._get_inputs_one_node_keys(to_nodekey(name))
    return node_keys_to_names(input_keys)

1332 

def get_inputs(self, name: Name | Names) -> list[Names]:
    """Get the direct inputs of a node or set of nodes.

    :param name: Name or names of nodes to get inputs for
    :return: If name is scalar, a list of the upstream nodes used as input.
        If name is a list, a list of such lists.
    """
    lookup = self._get_inputs_one_names
    return apply1(lookup, name)

1341 

def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self=True) -> set[NodeKey]:
    """Return the union of the ancestors of the given nodes.

    :param node_keys: Nodes to start from.
    :param include_self: Whether the starting nodes are part of the result.
    """
    collected = set()
    for start in node_keys:
        if include_self:
            collected.add(start)
        collected.update(nx.ancestors(self.dag, start))
    return collected

1350 

def get_ancestors(self, names: Name | Names, include_self=True) -> Names:
    """Get the names of all ancestors of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes are part of the result
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_ancestors_node_keys(start_keys, include_self))

1356 

def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> Names:
    """Return the nodes without a function among the given nodes and their ancestors.

    With ``node_keys`` set to ``None``, every node in the computation is considered.
    """
    candidates = self._node_keys() if node_keys is None else self._get_ancestors_node_keys(node_keys)
    return [key for key in candidates if self.dag.nodes[key].get(NodeAttributes.FUNC) is None]

1363 

def get_original_inputs(self, names: Name | Names | None = None) -> Names:
    """Get the original non-computed inputs for a node or set of nodes.

    :param names: Name or names of nodes to get inputs for; ``None`` considers the whole computation
    :return: List of original non-computed inputs that are ancestors of the input nodes
    """
    start_keys = None if names is None else names_to_node_keys(names)
    input_keys = self._get_original_inputs_node_keys(start_keys)
    return node_keys_to_names(input_keys)

1378 

def _get_outputs_one(self, name: Name) -> Names:
    """Return the names of the direct successors of a single node."""
    successor_keys = [*self.dag.successors(to_nodekey(name))]
    return node_keys_to_names(successor_keys)

1383 

def get_outputs(self, name: Name | Names) -> Names | list[Names]:
    """Get the direct outputs of a node or set of nodes.

    :param name: Name or names of nodes to get outputs for
    :return: If name is scalar, a list of the downstream nodes that use it.
        If name is a list, a list of such lists.
    """
    lookup = self._get_outputs_one
    return apply1(lookup, name)

1393 

def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> Names:
    """Return the union of the descendents of the given nodes.

    :param node_keys: Nodes to start from.
    :param include_self: Whether the starting nodes are part of the result.
    """
    collected = set()
    for start in node_keys:
        if include_self:
            collected.add(start)
        collected.update(nx.descendants(self.dag, start))
    return collected

1402 

def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get the names of all descendents of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes are part of the result
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_descendents_node_keys(start_keys, include_self))

1408 

def get_final_outputs(self, names: Name | Names | None = None):
    """Get the final output nodes (nodes with no descendants).

    :param names: Name or names to start from; their descendents are
        searched. ``None`` searches every node in the computation.
    """
    if names is not None:
        candidates = self._get_descendents_node_keys(names_to_node_keys(names))
    else:
        candidates = self._node_keys()
    terminal_keys = [key for key in candidates if len(nx.descendants(self.dag, key)) == 0]
    return node_keys_to_names(terminal_keys)

1418 

def get_source(self, name: Name) -> str:
    """Get the source code of the function behind a node.

    :return: ``"<file>:<line>"``, a blank line, then the function source; or
        a fixed marker string if the node has no function.
    """
    node_key = to_nodekey(name)
    func = self.dag.nodes[node_key].get(NodeAttributes.FUNC, None)
    if func is None:
        return "NOT A CALCULATED NODE"
    file = inspect.getsourcefile(func)
    _, lineno = inspect.getsourcelines(func)
    source = inspect.getsource(func)
    return f"{file}:{lineno}\n\n{source}"

1430 

def print_source(self, name: Name):
    """Print the text returned by ``get_source`` for a node."""
    source_text = self.get_source(name)
    print(source_text)

1434 

def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None):
    """Restrict a computation to the ancestors of a set of output nodes.

    Excludes ancestors of a set of input nodes.

    If the input nodes given are not sufficient for the output nodes, then
    additional ancestors of the output nodes are kept, but the given input
    nodes become input nodes of the modified Computation.

    :param output_names: Name or names of the nodes whose ancestors are kept
    :param input_names: Name or names of nodes to turn into plain input
        nodes; their current state and value are preserved
    :return: None - modifies existing computation in place
    """
    if input_names is not None:
        # A single name is allowed by the signature. Wrap it so that a string
        # such as "foo" is not iterated character by character.
        if isinstance(input_names, (str, NodeKey)):
            input_names = [input_names]
        for name in input_names:
            nodedata = self._get_item_one(name)
            # Re-adding the node without a function makes it an input node ...
            self.add_node(name)
            # ... after which its previous state and value are put back.
            self._set_state_and_literal_value(to_nodekey(name), nodedata.state, nodedata.value)
    output_node_keys = names_to_node_keys(output_names)
    ancestor_node_keys = self._get_ancestors_node_keys(output_node_keys)
    self.dag.remove_nodes_from([n for n in self.dag if n not in ancestor_node_keys])

1456 

def __getstate__(self):
    """Build the pickled state: a copy's DAG with non-serializable nodes reset.

    Nodes whose tags lack ``SystemTags.SERIALIZE`` are set to UNINITIALIZED
    on the copy, so this computation is left untouched.
    """
    tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
    snapshot = self.copy()
    transient = [key for key, node_tags in tags_by_node.items() if SystemTags.SERIALIZE not in node_tags]
    for key in transient:
        snapshot._set_uninitialized(key)
    return {"dag": snapshot.dag}

1465 

def __setstate__(self, state):
    """Restore a computation from pickled state.

    The instance is re-initialised, the saved DAG is attached, and the
    lookup maps are rebuilt from it.
    """
    self.__init__()
    restored_dag = state["dag"]
    self.dag = restored_dag
    self._refresh_maps()

1471 

def write_dill_old(self, file_):
    """Serialize a computation to a file or file-like object (deprecated format).

    The class-level ``__getstate__``/``__setstate__`` hooks are removed for
    the duration of the dump and reinstated afterwards, even on failure.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
    saved_getstate = self.__class__.__getstate__
    saved_setstate = self.__class__.__setstate__

    try:
        del self.__class__.__getstate__
        del self.__class__.__setstate__

        tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
        snapshot = self.copy()
        snapshot.executor_map = None
        snapshot.default_executor = None
        for node_key, node_tags in tags_by_node.items():
            if SystemTags.SERIALIZE not in node_tags:
                snapshot._set_uninitialized(node_key)

        if not isinstance(file_, str):
            dill.dump(snapshot, file_)
        else:
            with open(file_, "wb") as f:
                dill.dump(snapshot, f)
    finally:
        self.__class__.__getstate__ = saved_getstate
        self.__class__.__setstate__ = saved_setstate

1502 

def write_dill(self, file_):
    """Serialize a computation with dill to a path or file-like object.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    if not isinstance(file_, str):
        dill.dump(self, file_)
        return
    with open(file_, "wb") as f:
        dill.dump(self, f)

1514 

@staticmethod
def read_dill(file_):
    """Deserialize a computation from a file or file-like object.

    NOTE: ``dill.load`` can execute arbitrary code while unpickling. Only
    load data from sources you trust.

    :param file_: If string, reads from the file at that path
    :type file_: File-like object, or string
    :return: The deserialized computation
    :rtype: Computation
    :raises Exception: If the loaded object is not a ``Computation``
    """
    if isinstance(file_, str):
        with open(file_, "rb") as f:
            obj = dill.load(f)
    else:
        obj = dill.load(file_)
    if not isinstance(obj, Computation):
        # Same exception type as before, now with a message naming the actual type.
        raise Exception(f"Expected a Computation, but loaded an object of type {type(obj).__name__}")
    return obj

1531 

def copy(self):
    """Create a copy of a computation.

    The copy is shallow: values in the new Computation's DAG are the same
    objects as in this Computation's DAG. Further computations create new
    objects, so this should not be an issue.

    :rtype: Computation
    """
    clone = Computation()
    clone.dag = nx.DiGraph(self.dag)
    # Copy each member collection so the two computations' maps are independent.
    clone._tag_map = {tag: members.copy() for tag, members in self._tag_map.items()}
    clone._state_map = {state: members.copy() for state, members in self._state_map.items()}
    return clone

1545 

def add_named_tuple_expansion(self, name, namedtuple_type, group=None):
    """Automatically add nodes to extract each element of a named tuple type.

    A calculation often returns several values, and a namedtuple is
    preferable to a regular tuple for this, because later users can identify
    elements by name. A computation is also clearer when a downstream node
    depends on one element of such a tuple rather than on the whole tuple.

    To avoid boiler-plate node definitions, this method creates one new node
    per field. A field called 'element' in a node called 'node' is expanded
    into a new node called 'node.element'.

    Example::

        >>> from collections import namedtuple
        >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
        >>> comp = Computation()
        >>> comp.add_node('c', value=Coordinate(1, 2))
        >>> comp.add_named_tuple_expansion('c', Coordinate)
        >>> comp.compute_all()
        >>> comp.value('c.x')
        1
        >>> comp.value('c.y')
        2

    :param name: Node holding the named tuple to expand
    :param namedtuple_type: Expected type of the node
    :type namedtuple_type: namedtuple class
    :param group: Passed as ``group`` to ``add_node`` for each created node
    """

    def build_getter(attr):
        # The parameter must be called ``tuple``: it is bound via kwds={"tuple": name} below.
        def get_field_value(tuple):
            return getattr(tuple, attr)

        return get_field_value

    for field_name in namedtuple_type._fields:
        expanded_name = f"{name}.{field_name}"
        getter = build_getter(field_name)
        self.add_node(expanded_name, getter, kwds={"tuple": name}, group=group)
        self.set_tag(expanded_name, SystemTags.EXPANSION)

1588 

def add_map_node(self, result_node, input_node, subgraph, subgraph_input_node, subgraph_output_node):
    """Apply a graph to each element of an iterable.

    Each element of this graph's ``input_node`` is inserted in turn into the
    subgraph's ``subgraph_input_node``, and the subgraph's
    ``subgraph_output_node`` is then computed. The resulting list, with one
    entry per input element, becomes the value of ``result_node`` in this
    graph. This is similar to ``map`` in functional programming.

    If any element fails, the node function raises ``MapException`` with the
    results; a failed element's entry is a copy of the subgraph.

    :param result_node: The node to place a list of results in **this** graph
    :param input_node: The node to get a list input values from **this** graph
    :param subgraph: The graph to use to perform calculation for each element
    :param subgraph_input_node: The node in **subgraph** to insert each element in turn
    :param subgraph_output_node: The node in **subgraph** to read the result for each element
    """

    def f(xs):
        outputs = []
        any_failed = False
        for element in xs:
            subgraph.insert(subgraph_input_node, element)
            subgraph.compute(subgraph_output_node)
            if subgraph.state(subgraph_output_node) == States.UPTODATE:
                outputs.append(subgraph.value(subgraph_output_node))
                continue
            # Keep a snapshot of the failed subgraph in place of the result.
            any_failed = True
            outputs.append(subgraph.copy())
        if any_failed:
            raise MapException(f"Unable to calculate {result_node}", outputs)
        return outputs

    self.add_node(result_node, f, kwds={"xs": input_node})

1620 

def prepend_path(self, path, prefix_path: NodeKey):
    """Prepend a prefix path to a node path.

    ``ConstantValue`` instances are returned unchanged.
    """
    if not isinstance(path, ConstantValue):
        return prefix_path.join(to_nodekey(path))
    return path

1627 

def add_block(
    self,
    base_path: Name,
    block: "Computation",
    *,
    keep_values: bool | None = True,
    links: dict | None = None,
    metadata: dict | None = None,
):
    """Add a computation block as a subgraph to this computation.

    Every node of ``block`` is re-created under ``base_path``, with its
    node-name arguments re-pointed beneath the same prefix.

    :param base_path: Path under which the block's nodes are placed
    :param block: Computation whose nodes are copied in
    :param keep_values: If truthy, nodes holding a value keep their state and value
    :param links: Mapping of target (relative to ``base_path``) to source, each passed to ``link``
    :param metadata: Metadata to store for ``base_path``; ``None`` removes any existing entry
    """
    base_path = to_nodekey(base_path)
    for node_name in block.nodes():
        node_key = to_nodekey(node_name)
        node_data = block.dag.nodes[node_key]
        raw_args, raw_kwds = block.get_definition_args_kwds(node_key)
        new_node_name = self.prepend_path(node_name, base_path)
        self.add_node(
            new_node_name,
            node_data.get(NodeAttributes.FUNC, None),
            args=[self.prepend_path(arg, base_path) for arg in raw_args],
            kwds={k: self.prepend_path(v, base_path) for k, v in raw_kwds.items()},
            converter=node_data.get(NodeAttributes.CONVERTER, None),
            serialize=False,
            inspect=False,
            group=node_data.get(NodeAttributes.GROUP, None),
            tags=node_data.get(NodeAttributes.TAG, None),
            style=node_data.get(NodeAttributes.STYLE, None),
            executor=node_data.get(NodeAttributes.EXECUTOR, None),
        )
        if keep_values and NodeAttributes.VALUE in node_data:
            self._set_state_and_literal_value(
                to_nodekey(new_node_name), node_data[NodeAttributes.STATE], node_data[NodeAttributes.VALUE]
            )
    if links is not None:
        for target, source in links.items():
            self.link(base_path.join_parts(target), source)
    if metadata is not None:
        self._metadata[base_path] = metadata
    elif base_path in self._metadata:
        del self._metadata[base_path]

1678 

def link(self, target: Name, source: Name):
    """Make ``target`` an identity node that takes its value from ``source``.

    Linking a node to itself does nothing. The node added for ``target`` uses the
    style of the existing ``target`` node when that style is truthy, and otherwise
    the style of the ``source`` node (``None`` when that node does not exist).

    :param target: Name of the node to (re)define as a link.
    :param source: Name of the node the link reads from.
    """
    target_key = to_nodekey(target)
    source_key = to_nodekey(source)
    if target_key == source_key:
        return

    def existing_style(key):
        # Style of a node that is already present, else None.
        return self._style_one(key) if self.has_node(key) else None

    target_style = existing_style(target_key)
    source_style = existing_style(source_key)

    self.add_node(target_key, identity_function, kwds={"x": source_key}, style=target_style or source_style)

1691 

def _repr_svg_(self):
    """Return the SVG produced by a default ``GraphView`` of this computation.

    NOTE(review): presumably the IPython/Jupyter rich-display hook - confirm.
    """
    graph_view = GraphView(self)
    return graph_view.svg()

1694 

def draw(
    self,
    root: NodeKey | None = None,
    *,
    node_transformations: dict | None = None,
    cmap=None,
    colors="state",
    shapes=None,
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    show_expansion=False,
    collapse_all=True,
):
    """Build a GraphViz-based view of the computation's current state.

    :param root: Optional sub-block to draw.
    :param node_transformations: Optional mapping of node key to a ``NodeTransformations``
        value. The mapping is copied; the caller's dict is not modified.
    :param cmap: Default: None
    :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
    :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
    :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
        ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
    :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
    :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
    :param show_expansion: When false (the default), every node tagged
        ``SystemTags.EXPANSION`` is assigned ``NodeTransformations.CONTRACT``.
    :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
    :return: The ``GraphView`` created for this computation.
    """
    formatter = NodeFormatter.create(cmap, colors, shapes)
    transformations = {} if node_transformations is None else node_transformations.copy()
    if not show_expansion:
        for expansion_key in self.nodes_by_tag(SystemTags.EXPANSION):
            transformations[expansion_key] = NodeTransformations.CONTRACT
    return GraphView(
        self,
        root=root,
        node_formatter=formatter,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        node_transformations=transformations,
        collapse_all=collapse_all,
    )

1737 

def view(self, cmap=None, colors="state", shapes=None):
    """Build a ``GraphView`` of this computation and call its ``view`` method.

    :param cmap: Forwarded to ``NodeFormatter.create``.
    :param colors: Forwarded to ``NodeFormatter.create``. Default: 'state'.
    :param shapes: Forwarded to ``NodeFormatter.create``.
    """
    GraphView(self, node_formatter=NodeFormatter.create(cmap, colors, shapes)).view()

1743 

def print_errors(self):
    """Print tracebacks for every node with state "ERROR" in a Computation.

    For each such node the output is: the node name, an ``=`` underline of the
    same width as the printed name, a blank line, the traceback stored on the
    node's value, and a trailing blank line.
    """
    for n in self.nodes():
        if self.s[n] == States.ERROR:
            # Measure the rendered text rather than the name object itself:
            # a non-string name may have no len() or a different length.
            heading = f"{n}"
            print(heading)
            print("=" * len(heading))
            print()
            print(self.v[n].traceback)
            print()

1753 

@classmethod
def from_class(cls, definition_class, ignore_self=True):
    """Build a new computation populated from ``definition_class``.

    A fresh ``cls()`` and a no-argument instance of ``definition_class`` are handed
    to ``populate_computation_from_class``.

    :param definition_class: Class describing the computation; it is instantiated
        with no arguments.
    :param ignore_self: Forwarded to ``populate_computation_from_class``.
    :return: The newly created computation.
    """
    computation = cls()
    definition = definition_class()
    populate_computation_from_class(computation, definition_class, definition, ignore_self=ignore_self)
    return computation

1761 

def inject_dependencies(self, dependencies: dict, *, force: bool = False):
    """Inject dependencies into the placeholder nodes of this computation.

    Every node that is in the placeholder state (or every node, when ``force`` is
    true) is looked up in ``dependencies``. Nodes with no entry, or whose entry is
    ``None``, are left untouched. A callable entry is added as the node's
    calculation; any other entry is added as the node's value.

    :param dependencies: A dictionary mapping node identifiers to the dependency
        object to inject, or to a callable that returns the dependency object.
    :param force: When True, nodes are replaced regardless of their current state.
        Defaults to False.
    :return: None
    """
    for name in self.nodes():
        if not (force or self.s[name] == States.PLACEHOLDER):
            continue
        dependency = dependencies.get(name)
        if dependency is None:
            continue
        if callable(dependency):
            self.add_node(name, dependency)
        else:
            self.add_node(name, value=dependency)