Coverage for src/loman/computeengine.py: 99%

1019 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:30 +0000

1"""Core computation engine for dependency-aware calculation graphs.""" 

2 

3import inspect 

4import logging 

5import traceback 

6import types 

7import warnings 

8from collections import defaultdict 

9from collections.abc import Callable, Iterable, Mapping 

10from concurrent.futures import FIRST_COMPLETED, Executor, ThreadPoolExecutor, wait 

11from dataclasses import dataclass, field 

12from datetime import UTC, datetime 

13from enum import Enum 

14from typing import Any, BinaryIO 

15 

16import decorator 

17import dill 

18import networkx as nx 

19import pandas as pd 

20 

21from .compat import get_signature 

22from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags 

23from .exception import ( 

24 CannotInsertToPlaceholderNodeException, 

25 ComputationError, 

26 LoopDetectedException, 

27 MapException, 

28 NodeAlreadyExistsException, 

29 NonExistentNodeException, 

30 ValidationError, 

31) 

32from .graph_utils import topological_sort 

33from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey 

34from .util import AttributeView, apply1, apply_n, as_iterable, value_eq 

35from .visualization import GraphView, NodeFormatter 

36 

# Module-level logger used throughout for debug tracing of graph mutations.
LOG = logging.getLogger("loman.computeengine")

38 

39 

@dataclass
class Error:
    """Container for error information during computation."""

    # The exception instance raised while computing a node.
    exception: Exception
    # Formatted traceback text captured at raise time (presumably via
    # traceback.format_exc(); see the computation error paths below).
    traceback: str

46 

47 

@dataclass
class NodeData:
    """Data associated with a computation node: its state and current value."""

    # Lifecycle state of the node (UPTODATE, STALE, COMPUTABLE, ...).
    state: States
    # Last computed/inserted value; meaning depends on state.
    value: object

54 

55 

@dataclass
class TimingData:
    """Timing information for one node's computation execution."""

    # Wall-clock start of the node evaluation.
    start: datetime
    # Wall-clock end of the node evaluation.
    end: datetime
    # Elapsed time in seconds (end - start).
    duration: float

63 

64 

class _ParameterType(Enum):
    """Internal enum distinguishing how an edge feeds a function parameter.

    Stored in edge data as ``(type, name)``: ARG pairs with a positional
    index, KWD pairs with a keyword-parameter name.
    """

    ARG = 1
    KWD = 2

70 

71 

@dataclass
class _ParameterItem:
    """Internal container for one resolved parameter during computation."""

    # Whether the parameter is positional (ARG) or keyword (KWD).
    type: _ParameterType
    # Positional index (int) for ARG, or keyword name (str) for KWD.
    name: int | str
    # The value to pass for this parameter.
    value: object

79 

80 

def _node(func: Callable[..., Any], *args: Any, **kws: Any) -> Any:  # pragma: no cover
    """Caller shim for ``decorator.decorate`` in :func:`node`: forwards all
    arguments to the wrapped function unchanged."""
    return func(*args, **kws)

84 

85 

def node(
    comp: "Computation", name: Name | None = None, *args: Any, **kw: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory: register the decorated function as a node on *comp*.

    :param comp: computation to register the node on
    :param name: node name; defaults to the function's ``__name__``
    :param args: extra positional arguments forwarded to ``comp.add_node``
    :param kw: extra keyword arguments forwarded to ``comp.add_node``
    :return: a decorator that registers the function and returns a wrapper
    """

    def inner(f: Callable[..., Any]) -> Callable[..., Any]:
        """Register *f* on the computation, then wrap it transparently."""
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        # decorator.decorate preserves f's signature on the returned wrapper.
        wrapped: Callable[..., Any] = decorator.decorate(f, _node)
        return wrapped

    return inner

101 

102 

@dataclass
class ConstantValue:
    """Wraps a literal value so node definitions can distinguish a constant
    from a node-name reference (see ``_process_function_args``/``kwds``)."""

    # The constant to feed directly into the function parameter.
    value: object


# Short alias for terse node definitions, e.g. kwds={'x': C(1)}.
C = ConstantValue

111 

112 

class Node:
    """Base class for declarative node specifications.

    Subclasses in this module (InputNode, CalcNode, Block) describe how to
    register themselves on a Computation via ``add_to_comp``; see
    ``populate_computation_from_class``.
    """

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this node specification on *comp* under *name*.

        :param comp: computation to add the node to
        :param name: node name to register under
        :param obj: instance the specification was found on (used by
            subclasses that need to bind methods)
        :param ignore_self: whether a leading ``self`` parameter should be
            satisfied by binding rather than treated as a graph input
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError()

119 

120 

@dataclass
class InputNode(Node):
    """A node representing input data in the computation graph."""

    # Positional arguments captured at construction time.
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Keyword arguments forwarded to ``Computation.add_node``.
    kwds: dict[str, Any] = field(default_factory=dict)

    def __init__(self, *args: Any, **kwds: Any) -> None:
        """Capture positional and keyword arguments for later registration."""
        self.args = args
        self.kwds = kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this input node on *comp* under *name*.

        Only ``kwds`` are forwarded to ``comp.add_node``; ``args`` is stored
        but not used here.
        NOTE(review): confirm whether positional args should also be forwarded.
        """
        comp.add_node(name, **self.kwds)


# Lowercase alias for declarative class-body usage.
input_node = InputNode

139 

140 

@dataclass
class CalcNode(Node):
    """A node representing a calculation in the computation graph."""

    # The function to evaluate for this node.
    f: Callable[..., Any]
    # Extra keyword arguments forwarded to ``Computation.add_node``.
    kwds: dict[str, Any] = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register ``self.f`` as a calculation node named *name* on *comp*.

        :param obj: instance to bind ``f`` to when ``ignore_self`` applies
        :param ignore_self: if true (or ``kwds['ignore_self']`` is set), a
            leading ``self`` parameter is satisfied by binding ``f`` to *obj*
            instead of being treated as a graph input
        """
        kwds = self.kwds.copy()
        ignore_self = ignore_self or kwds.get("ignore_self", False)
        f = self.f
        if ignore_self:
            signature = get_signature(self.f)
            if len(signature.kwd_params) > 0 and signature.kwd_params[0] == "self":
                # Bind to obj so the dependency graph only sees the remaining
                # parameters, not 'self'.
                f = f.__get__(obj, obj.__class__)  # type: ignore[attr-defined]
        # 'ignore_self' is consumed here; add_node does not accept it.
        if "ignore_self" in kwds:
            del kwds["ignore_self"]
        comp.add_node(name, f, **kwds)

160 

161 

def calc_node(
    f: Callable[..., Any] | None = None, **kwds: Any
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Mark a function as a calculation node.

    Usable bare (``@calc_node``) or with options (``@calc_node(serialize=False)``).
    The function itself is returned unchanged; the node specification is
    attached as ``_loman_node_info`` for ``populate_computation_from_class``.
    """

    def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
        """Attach the CalcNode specification and return the function as-is."""
        func._loman_node_info = CalcNode(func, kwds)  # type: ignore[attr-defined]
        return func

    # Bare decorator form gets f directly; parameterized form returns wrap.
    return wrap if f is None else wrap(f)

175 

176 

@dataclass
class Block(Node):
    """A node representing a computational block (nested sub-computation)."""

    # Either an existing Computation or a zero-argument factory producing one.
    block: "Callable[..., Computation] | Computation"
    # Positional arguments forwarded to ``Computation.add_block``.
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Keyword arguments forwarded to ``Computation.add_block``.
    kwds: dict[str, Any] = field(default_factory=dict)

    def __init__(self, block: "Callable[..., Computation] | Computation", *args: Any, **kwds: Any) -> None:
        """Capture the block (or block factory) and its add_block arguments."""
        self.block = block
        self.args = args
        self.kwds = kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this block on *comp* under *name*.

        :raises TypeError: if ``block`` is neither a Computation nor callable
        """
        if isinstance(self.block, Computation):
            comp.add_block(name, self.block, *self.args, **self.kwds)
        elif callable(self.block):
            # Factory form: call it to build a fresh Computation instance.
            block0 = self.block()
            comp.add_block(name, block0, *self.args, **self.kwds)
        else:
            msg = f"Block {self.block} must be callable or Computation"
            raise TypeError(msg)


# Lowercase alias for declarative class-body usage.
block = Block

204 

205 

def populate_computation_from_class(comp: "Computation", cls: type, obj: object, ignore_self: bool = True) -> None:
    """Scan *cls* for node specifications and register them on *comp*.

    A member contributes a node if it is itself a :class:`Node` instance
    (e.g. ``input_node``/``block`` class attributes) or carries a
    ``_loman_node_info`` attribute set by :func:`calc_node`.

    :param obj: instance passed through to ``Node.add_to_comp`` for binding
    :param ignore_self: forwarded to each node's ``add_to_comp``
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            spec: Node | None = member
        else:
            spec = getattr(member, "_loman_node_info", None)
        if spec is not None:
            spec.add_to_comp(comp, member_name, obj, ignore_self)

216 

217 

def computation_factory(
    maybe_cls: type | None = None, *, ignore_self: bool = True
) -> Callable[..., "Computation"] | Callable[[type], Callable[..., "Computation"]]:
    """Turn a class of node declarations into a Computation factory.

    Usable bare (``@computation_factory``) or parameterized
    (``@computation_factory(ignore_self=False)``). The decorated class is
    replaced by a function that builds a fresh Computation populated from
    the class's node declarations on each call.
    """

    def wrap(cls: type) -> Callable[..., "Computation"]:
        """Build the factory function for *cls*."""

        def create_computation(*args: Any, **kwargs: Any) -> "Computation":
            """Instantiate *cls* and a Computation, then populate the latter."""
            obj = cls()
            comp = Computation(*args, **kwargs)
            # Keep the defining instance alive alongside the computation.
            comp._definition_object = obj  # type: ignore[attr-defined]
            populate_computation_from_class(comp, cls, obj, ignore_self)
            return comp

        return create_computation

    # Bare decorator form receives the class directly.
    return wrap if maybe_cls is None else wrap(maybe_cls)

240 

241 

242def _eval_node( 

243 name: NodeKey, 

244 f: Callable[..., Any], 

245 args: list[Any], 

246 kwds: dict[str, Any], 

247 raise_exceptions: bool, 

248) -> tuple[Any, Exception | None, str | None, datetime, datetime]: 

249 """To make multiprocessing work, this function must be standalone so that pickle works.""" 

250 exc: Exception | None = None 

251 tb: str | None = None 

252 start_dt = datetime.now(UTC) 

253 try: 

254 logging.debug("Running " + str(name)) 

255 value = f(*args, **kwds) 

256 logging.debug("Completed " + str(name)) 

257 except Exception as e: 

258 value = None 

259 exc = e 

260 tb = traceback.format_exc() 

261 if raise_exceptions: 

262 raise 

263 end_dt = datetime.now(UTC) 

264 return value, exc, tb, start_dt, end_dt 

265 

266 

# Unique sentinel distinguishing "no value supplied" from an explicit None.
_MISSING_VALUE_SENTINEL = object()

268 

269 

class NullObject:
    """Debug stand-in that fails loudly on any use.

    Every attribute/item access, assignment, deletion, and call prints which
    hook was hit and then raises, making stray uses easy to trace.
    """

    def __getattr__(self, name: str) -> Any:
        """Fail any attribute read."""
        print(f"__getattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Fail any attribute assignment."""
        print(f"__setattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __delattr__(self, name: str) -> None:
        """Fail any attribute deletion."""
        print(f"__delattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Fail any attempt to call the object."""
        print(f"__call__: {args}, {kwargs}")
        raise TypeError("'NullObject' object is not callable")

    def __getitem__(self, key: Any) -> Any:
        """Fail any item read."""
        print(f"__getitem__: {key}")
        raise KeyError(f"'NullObject' object has no item with key '{key}'")

    def __setitem__(self, key: Any, value: Any) -> None:
        """Fail any item assignment."""
        print(f"__setitem__: {key}")
        raise KeyError(f"'NullObject' object cannot have items set with key '{key}'")

    def __repr__(self) -> str:
        """Return the fixed placeholder representation."""
        print(f"__repr__: {object.__getattribute__(self, '__dict__')}")
        return "<NullObject>"

313 

314 

def identity_function(x: Any) -> Any:
    """Identity: hand back *x* exactly as received."""
    return x

318 

319 

320class Computation: 

321 """A computation graph that manages dependencies and calculations. 

322 

323 The Computation class provides a framework for building and executing 

324 computation graphs where nodes represent data or calculations, and edges 

325 represent dependencies between them. 

326 """ 

327 

    def __init__(
        self,
        *,
        default_executor: Executor | None = None,
        executor_map: dict[str, Executor] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Initialize a new Computation.

        :param default_executor: Executor used for nodes with no named executor
        :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1)
        :param executor_map: Mapping from executor name to Executor instance
        :param metadata: Metadata dictionary attached to the root node key
        """
        # Single worker => nodes are computed serially by default.
        # NOTE(review): this pool is created per instance and does not appear
        # to be shut down here — confirm lifecycle elsewhere.
        if default_executor is None:
            self.default_executor: Executor = ThreadPoolExecutor(1)
        else:
            self.default_executor = default_executor
        if executor_map is None:
            self.executor_map: dict[str, Executor] = {}
        else:
            self.executor_map = executor_map
        self.dag: nx.DiGraph = nx.DiGraph()
        # Node metadata is kept outside the graph, keyed by NodeKey.
        self._metadata: dict[NodeKey, Any] = {}
        if metadata is not None:
            self._metadata[NodeKey.root()] = metadata

        # Attribute-view shortcuts: v=value, s=state, i=inputs, o=outputs,
        # t=tags, style=styles, tim=timing, x=compute-then-get, src=source.
        self.v = self.get_attribute_view_for_path(NodeKey.root(), self._value_one, self.value)
        self.s = self.get_attribute_view_for_path(NodeKey.root(), self._state_one, self.state)
        self.i = self.get_attribute_view_for_path(NodeKey.root(), self._get_inputs_one_names, self.get_inputs)
        self.o = self.get_attribute_view_for_path(NodeKey.root(), self._get_outputs_one, self.get_outputs)
        self.t = self.get_attribute_view_for_path(NodeKey.root(), self._tag_one, self.tags)
        self.style = self.get_attribute_view_for_path(NodeKey.root(), self._style_one, self.styles)
        self.tim = self.get_attribute_view_for_path(NodeKey.root(), self._get_timing_one, self.get_timing)
        self.x = self.get_attribute_view_for_path(
            NodeKey.root(), self.compute_and_get_value, self.compute_and_get_value
        )
        self.src = self.get_attribute_view_for_path(NodeKey.root(), self.print_source, self.print_source)
        # Reverse indexes: tag -> node keys, state -> node keys.
        self._tag_map: defaultdict[str, set[NodeKey]] = defaultdict(set)
        self._state_map: dict[States, set[NodeKey]] = {state: set() for state in States}

366 

    def get_attribute_view_for_path(
        self, nodekey: NodeKey, get_one_func: Callable[[Name], Any], get_many_func: Callable[[Name | Names], Any]
    ) -> AttributeView:
        """Create an attribute view rooted at *nodekey*.

        The view resolves child names relative to *nodekey*: actual nodes are
        read through *get_one_func*; intermediate tree paths recurse into a
        nested AttributeView so dotted access can walk the hierarchy.

        :param get_one_func: accessor for a single existing node
        :param get_many_func: accessor used when a list of names is given
        """

        def node_func() -> Iterable[str]:
            """Return the names of this path's children, as strings."""
            return [str(n) for n in self.get_tree_list_children(nodekey)]

        def get_one_func_for_path(name: str) -> Any:
            """Resolve *name* relative to this path: node value or sub-view."""
            nk = to_nodekey(name)
            new_nk = nk.prepend(nodekey)
            if self.has_node(new_nk):
                return get_one_func(new_nk)
            elif self.tree_has_path(new_nk):
                # Interior path: recurse so chained attribute access works.
                return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func)
            else:
                msg = f"Path {new_nk} does not exist"
                raise KeyError(msg)  # pragma: no cover

        def get_many_func_for_path(name: Name | Names) -> Any:
            """Resolve a single name or a list of names relative to this path."""
            if isinstance(name, list):
                return [get_one_func_for_path(str(n)) for n in name]
            else:
                return get_one_func_for_path(str(name))

        return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)

396 

397 def _get_names_for_state(self, state: States) -> set[Name]: 

398 """Get node names that have a specific state.""" 

399 return set(node_keys_to_names(self._state_map[state])) 

400 

401 def _get_tags_for_state(self, tag: str) -> set[Name]: 

402 """Get node names that have a specific tag.""" 

403 return set(node_keys_to_names(self._tag_map[tag])) 

404 

    def _process_function_args(self, node_key: NodeKey, node: dict[str, Any], args: list[Any] | None) -> int:
        """Wire positional arguments for a function node.

        ConstantValue entries are baked into the node's ARGS dict; any other
        entry is treated as an input node name and becomes an incoming edge
        (creating a PLACEHOLDER node if the input does not exist yet).

        :return: the number of positional arguments consumed
        """
        args_count = 0
        if args:
            args_count = len(args)
            for i, arg in enumerate(args):
                if isinstance(arg, ConstantValue):
                    # Constant: stored on the node, no graph edge needed.
                    node[NodeAttributes.ARGS][i] = arg.value
                else:
                    input_vertex_name = arg
                    input_vertex_node_key = to_nodekey(input_vertex_name)
                    if not self.dag.has_node(input_vertex_node_key):
                        # Unknown input: create it as a PLACEHOLDER.
                        self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
                    # Edge records that this input feeds positional slot i.
                    self.dag.add_edge(
                        input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
                    )
        return args_count

423 

    def _build_param_map(
        self,
        func: Callable[..., Any],
        node_key: NodeKey,
        args_count: int,
        kwds: dict[str, Any] | None,
        inspect: bool,
    ) -> tuple[dict[str, Any], list[str]]:
        """Build the parameter-name -> source mapping for a function node.

        With introspection on, parameters beyond the first *args_count*
        positional slots are mapped from ``kwds`` when present, otherwise to
        the sibling node of the same name. Without introspection, only the
        explicit ``kwds`` entries are used.

        NOTE: the ``inspect`` parameter shadows the ``inspect`` module inside
        this method; it is part of the public add_node signature, so it is
        kept as-is.

        :return: ``(param_map, default_names)`` where default_names lists
            parameters that have defaults (safe to leave unconnected)
        """
        param_map: dict[str, Any] = {}
        default_names: list[str] = []

        if inspect:
            signature = get_signature(func)
            if not signature.has_var_args:
                # Skip the parameters already filled positionally.
                for param_name in signature.kwd_params[args_count:]:
                    if kwds is not None and param_name in kwds:
                        param_source = kwds[param_name]
                    else:
                        # Default: take the sibling node with the same name.
                        param_source = node_key.parent.join_parts(param_name)
                    param_map[param_name] = param_source
            if signature.has_var_kwds and kwds is not None:
                # **kwargs: any explicitly-mapped name can be passed through.
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = signature.default_params
        else:
            if kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source

        return param_map, default_names

455 

    def _process_function_kwds(
        self, node_key: NodeKey, node: dict[str, Any], param_map: dict[str, Any], default_names: list[str]
    ) -> None:
        """Wire keyword parameters for a function node.

        ConstantValue sources are baked into the node's KWDS dict; other
        sources become incoming edges. A missing source node is created as a
        PLACEHOLDER unless the parameter has a default, in which case it is
        simply left unconnected.
        """
        for param_name, param_source in param_map.items():
            if isinstance(param_source, ConstantValue):
                node[NodeAttributes.KWDS][param_name] = param_source.value
            else:
                in_node_name = param_source
                in_node_key = to_nodekey(in_node_name)
                if not self.dag.has_node(in_node_key):
                    if param_name in default_names:
                        # Parameter has a default; don't force a placeholder.
                        continue
                    else:
                        self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(in_node_key)
                self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})

473 

    def add_node(
        self,
        name: Name,
        func: Callable[..., Any] | None = None,
        *,
        args: list[Any] | None = None,
        kwds: dict[str, Any] | None = None,
        value: Any = _MISSING_VALUE_SENTINEL,
        converter: Callable[[Any], Any] | None = None,
        serialize: bool = True,
        inspect: bool = True,
        group: str | None = None,
        tags: Iterable[str] | None = None,
        style: str | None = None,
        executor: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Adds or updates a node in a computation.

        :param name: Name of the node to add. This may be any hashable object.
        :param func: Function to use to calculate the node if the node is a calculation node. By default, the input
                nodes to the function will be implied from the names of the function parameters. For example, a
                parameter called ``a`` would be taken from the node called ``a``. This can be modified with the
                ``kwds`` parameter.
        :type func: Function, default None
        :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally
                for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with
                three parameters, taken from the nodes 'a', 'b' and 'c' respectively.
        :type args: List, default None
        :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
                parameter when calling the function for a calculation node. e.g. If kwds is ``{'x': 'a', 'y': 'b'}``
                then the function would be called with parameters named 'x' and 'y', and their values would be taken
                from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter
                [key] from node [value]".
        :type kwds: Dictionary, default None
        :param value: If given, the value is inserted into the node, and the node state set to UPTODATE.
        :type value: default _MISSING_VALUE_SENTINEL (no value)
        :param converter: Optional callable applied to values before they are stored in this node.
        :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
                case, set serialize to False
        :type serialize: boolean, default True
        :param inspect: Whether to use introspection to determine the arguments of the function, which can be
                slow. If this is not set, kwds and args must be set for the function to obtain parameters.
        :type inspect: boolean, default True
        :param group: Subgraph to render node in
        :type group: default None
        :param tags: Set of tags to apply to node
        :type tags: Iterable
        :param style: Style to apply to node
        :type style: String, default None
        :param executor: Name of executor to run node on
        :type executor: string
        :param metadata: Metadata dictionary to attach to the node; None clears existing metadata.
        """
        node_key = to_nodekey(name)
        LOG.debug(f"Adding node {node_key}")
        has_value = value is not _MISSING_VALUE_SENTINEL
        if value is _MISSING_VALUE_SENTINEL:
            value = None
        if tags is None:
            tags = []

        # Re-adding a node severs its existing incoming edges; they are
        # rebuilt below from args/kwds if func is given.
        self.dag.add_node(node_key)
        pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
        self.dag.remove_edges_from(pred_edges)
        node = self.dag.nodes[node_key]

        # Replace (or clear) any previously-attached metadata.
        if metadata is None:
            if node_key in self._metadata:
                del self._metadata[node_key]
        else:
            self._metadata[node_key] = metadata

        self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)

        # Reset all node attributes to a clean slate.
        node[NodeAttributes.TAG] = set()
        node[NodeAttributes.STYLE] = style
        node[NodeAttributes.GROUP] = group
        node[NodeAttributes.ARGS] = {}
        node[NodeAttributes.KWDS] = {}
        node[NodeAttributes.FUNC] = None
        node[NodeAttributes.EXECUTOR] = executor
        node[NodeAttributes.CONVERTER] = converter

        if func:
            # Calculation node: wire inputs and mark downstream stale.
            node[NodeAttributes.FUNC] = func
            args_count = self._process_function_args(node_key, node, args)
            param_map, default_names = self._build_param_map(func, node_key, args_count, kwds, inspect)
            self._process_function_kwds(node_key, node, param_map, default_names)
            self._set_descendents(node_key, States.STALE)

        if has_value:
            self._set_uptodate(node_key, value)
        if node[NodeAttributes.STATE] == States.UNINITIALIZED:
            self._try_set_computable(node_key)
        self.set_tag(node_key, tags)
        if serialize:
            self.set_tag(node_key, SystemTags.SERIALIZE)

570 

571 def _refresh_maps(self) -> None: 

572 """Refresh internal tag and state maps from node data.""" 

573 self._tag_map.clear() 

574 for state in States: 

575 self._state_map[state].clear() 

576 for node_key in self._node_keys(): 

577 state = self.dag.nodes[node_key][NodeAttributes.STATE] 

578 self._state_map[state].add(node_key) 

579 tags = self.dag.nodes[node_key].get(NodeAttributes.TAG, set()) 

580 for tag in tags: 

581 self._tag_map[tag].add(node_key) 

582 

583 def _set_tag_one(self, name: Name, tag: str) -> None: 

584 """Set a single tag on a single node.""" 

585 node_key = to_nodekey(name) 

586 self.dag.nodes[node_key][NodeAttributes.TAG].add(tag) 

587 self._tag_map[tag].add(node_key) 

588 

    def set_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None:
        """Set tags on a node or nodes. Ignored if tags are already set.

        Accepts a single name or a list of names, and a single tag or an
        iterable of tags; every combination is applied.

        :param name: Node or nodes to set tag for
        :param tag: Tag or tags to set
        """
        apply_n(self._set_tag_one, name, tag)

596 

597 def _clear_tag_one(self, name: Name, tag: str) -> None: 

598 """Clear a single tag from a single node.""" 

599 node_key = to_nodekey(name) 

600 self.dag.nodes[node_key][NodeAttributes.TAG].discard(tag) 

601 self._tag_map[tag].discard(node_key) 

602 

    def clear_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None:
        """Clear tag on a node or nodes. Ignored if tags are not set.

        Accepts a single name or a list of names, and a single tag or an
        iterable of tags; every combination is cleared.

        :param name: Node or nodes to clear tags for
        :param tag: Tag or tags to clear
        """
        apply_n(self._clear_tag_one, name, tag)

610 

611 def _set_style_one(self, name: Name, style: str) -> None: 

612 """Set style on a single node.""" 

613 node_key = to_nodekey(name) 

614 self.dag.nodes[node_key][NodeAttributes.STYLE] = style 

615 

    def set_style(self, name: Name | Names, style: str) -> None:
        """Set styles on a node or nodes.

        :param name: Node or nodes to set style for
        :param style: Style to set
        """
        apply_n(self._set_style_one, name, style)

623 

624 def _clear_style_one(self, name: Name) -> None: 

625 """Clear style from a single node.""" 

626 node_key = to_nodekey(name) 

627 self.dag.nodes[node_key][NodeAttributes.STYLE] = None 

628 

    def clear_style(self, name: Name | Names) -> None:
        """Clear style on a node or nodes.

        :param name: Node or nodes to clear styles for
        """
        apply_n(self._clear_style_one, name)

635 

636 def metadata(self, name: Name) -> dict[str, Any]: 

637 """Get metadata for a node.""" 

638 node_key = to_nodekey(name) 

639 if self.tree_has_path(name): 

640 if node_key not in self._metadata: 

641 self._metadata[node_key] = {} 

642 result: dict[str, Any] = self._metadata[node_key] 

643 return result 

644 else: 

645 msg = f"Node {node_key} does not exist." 

646 raise NonExistentNodeException(msg) 

647 

    def delete_node(self, name: Name) -> None:
        """Delete a node from a computation.

        When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
        will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
        then the PLACEHOLDER node will also be deleted.

        :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
            be raised.
        """
        node_key = to_nodekey(name)
        LOG.debug(f"Deleting node {node_key}")

        if not self.dag.has_node(node_key):
            msg = f"Node {node_key} does not exist"
            raise NonExistentNodeException(msg)

        if node_key in self._metadata:
            del self._metadata[node_key]

        if len(self.dag.succ[node_key]) == 0:
            # No dependents: remove outright, then cascade-delete any
            # predecessors that were only placeholders.
            # NOTE(review): `preds` is a lazy iterator captured before
            # remove_node and consumed after it — this relies on networkx
            # keeping the underlying predecessor mapping alive; confirm.
            preds = self.dag.predecessors(node_key)
            state = self.dag.nodes[node_key][NodeAttributes.STATE]
            self.dag.remove_node(node_key)
            self._state_map[state].remove(node_key)
            for n in preds:
                if self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
                    self.delete_node(n)
        else:
            # Still depended on: demote to PLACEHOLDER instead of removing.
            self._set_state(node_key, States.PLACEHOLDER)

678 

    def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None) -> None:
        """Rename a node in a computation.

        :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and
            new names as values
        :param new_name: New name for node. Must be None when a dictionary is given.
        :raises ValueError: if a dictionary is given together with new_name
        :raises NonExistentNodeException: if the single old_name does not exist
        :raises NodeAlreadyExistsException: if the single new_name already exists
        """
        name_mapping: dict[Name, Name]
        if isinstance(old_name, Mapping) and not isinstance(old_name, str):
            # Bulk form: a mapping of old -> new names.
            for k, v in old_name.items():
                LOG.debug(f"Renaming node {k} to {v}")
            if new_name is not None:
                msg = "new_name must not be set if rename_node is passed a dictionary"
                raise ValueError(msg)
            else:
                name_mapping = dict(old_name)  # type: ignore[arg-type]
        else:
            # Single-rename form: validate existence and collision first.
            LOG.debug(f"Renaming node {old_name} to {new_name}")
            old_node_key = to_nodekey(old_name)
            if not self.dag.has_node(old_node_key):
                msg = f"Node {old_name} does not exist"
                raise NonExistentNodeException(msg)
            assert new_name is not None  # noqa: S101
            new_node_key = to_nodekey(new_name)
            if self.dag.has_node(new_node_key):
                msg = f"Node {new_name} already exists"
                raise NodeAlreadyExistsException(msg)
            name_mapping = {old_name: new_name}

        node_key_mapping = {to_nodekey(on): to_nodekey(nn) for on, nn in name_mapping.items()}
        nx.relabel_nodes(self.dag, node_key_mapping, copy=False)

        # Carry metadata across to the new keys.
        for old_nk, new_nk in node_key_mapping.items():
            if old_nk in self._metadata:
                self._metadata[new_nk] = self._metadata[old_nk]
                del self._metadata[old_nk]
            else:
                if new_nk in self._metadata:  # pragma: no cover
                    del self._metadata[new_nk]

        # Relabelling invalidates the reverse indexes; rebuild them.
        self._refresh_maps()

720 

    def repoint(self, old_name: Name, new_name: Name) -> None:
        """Changes all nodes that use old_name as an input to use new_name instead.

        Note that if old_name is an input to new_name, then that will not be changed, to try to avoid introducing
        circular dependencies, but other circular dependencies will not be checked.

        If new_name does not exist, then it will be created as a PLACEHOLDER node.

        :param old_name: Node whose outgoing edges should be moved
        :param new_name: Node that should become the input instead
        :return: None
        """
        old_node_key = to_nodekey(old_name)
        new_node_key = to_nodekey(new_name)
        if old_node_key == new_node_key:
            return

        # Snapshot the dependents before mutating edges (node keys, despite
        # the variable name).
        changed_names = list(self.dag.successors(old_node_key))

        if len(changed_names) > 0 and not self.dag.has_node(new_node_key):
            self.dag.add_node(new_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
            self._state_map[States.PLACEHOLDER].add(new_node_key)

        for name in changed_names:
            # Skip new_name itself to avoid creating a self-edge.
            if name == new_node_key:
                continue
            # Preserve the edge's parameter attributes while rewiring.
            edge_data = self.dag.get_edge_data(old_node_key, name)
            self.dag.add_edge(new_node_key, name, **edge_data)
            self.dag.remove_edge(old_node_key, name)

        # Rewired dependents must be recomputed.
        for n in changed_names:
            self.set_stale(n)

753 

    def insert(self, name: Name, value: Any, force: bool = False) -> None:
        """Insert a value into a node of a computation.

        Following insertion, the node will have state UPTODATE, and all its descendents will be COMPUTABLE or STALE.

        If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException``
        will be raised.

        :param name: Name of the node to add.
        :param value: The value to be inserted into the node.
        :param force: Whether to force recalculation of descendents if node value and state would not be changed
        """
        node_key = to_nodekey(name)
        LOG.debug(f"Inserting value into node {node_key}")

        if not self.dag.has_node(node_key):
            msg = f"Node {node_key} does not exist"
            raise NonExistentNodeException(msg)

        state = self._state_one(name)
        if state == States.PLACEHOLDER:
            msg = "Cannot insert into placeholder node. Use add_node to create the node first"
            raise CannotInsertToPlaceholderNodeException(msg)

        # Short-circuit: inserting an equal value into an up-to-date node
        # changes nothing, so skip invalidating descendents (unless forced).
        if not force and state == States.UPTODATE:
            current_value = self._value_one(name)
            if value_eq(value, current_value):
                return

        self._set_state_and_value(node_key, States.UPTODATE, value)
        self._set_descendents(node_key, States.STALE)
        # Direct successors may now have all inputs ready.
        for n in self.dag.successors(node_key):
            self._try_set_computable(n)

787 

    def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]) -> None:
        """Insert values into many nodes of a computation simultaneously.

        Following insertion, the nodes will have state UPTODATE, and all their descendents will be COMPUTABLE
        or STALE. In the case of inserting many nodes, some of which are descendents of others, this ensures that
        the inserted nodes have correct status, rather than being set as STALE when their ancestors are inserted.

        If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` will be
        raised, and none of the nodes will be inserted.

        :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
            insert the value into.
        :type name_value_pairs: List of tuples
        """
        node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
        LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")

        # Validate everything up front so the operation is all-or-nothing.
        for name, _value in node_key_value_pairs:
            if not self.dag.has_node(name):
                msg = f"Node {name} does not exist"
                raise NonExistentNodeException(msg)

        stale = set()
        computable = set()
        for name, value in node_key_value_pairs:
            self._set_state_and_value(name, States.UPTODATE, value)
            stale.update(nx.dag.descendants(self.dag, name))
            computable.update(self.dag.successors(name))
        # Inserted nodes must stay UPTODATE even if they descend from other
        # inserted nodes — remove them from the invalidation sets.
        names = {name for name, value in node_key_value_pairs}
        stale.difference_update(names)
        computable.difference_update(names)
        for name in stale:
            self._set_state(name, States.STALE)
        for name in computable:
            self._try_set_computable(name)

823 

824 def insert_from(self, other: "Computation", nodes: Iterable[Name] | None = None) -> None: 

825 """Insert values into another Computation object into this Computation object. 

826 

827 :param other: The computation object to take values from 

828 :type Computation: 

829 :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the 

830 other Computation object that have corresponding nodes in this Computation object will be inserted 

831 :type nodes: List, default None 

832 """ 

833 if nodes is None: 

834 nodes_set: set[Any] = set(self.dag.nodes) 

835 nodes_set.intersection_update(other.dag.nodes()) 

836 nodes = nodes_set 

837 name_value_pairs = [(name, other.value(name)) for name in nodes] 

838 self.insert_many(name_value_pairs) 

839 

840 def _set_state(self, node_key: NodeKey, state: States) -> None: 

841 """Set the state of a node without changing its value.""" 

842 node = self.dag.nodes[node_key] 

843 old_state = node[NodeAttributes.STATE] 

844 self._state_map[old_state].remove(node_key) 

845 node[NodeAttributes.STATE] = state 

846 self._state_map[state].add(node_key) 

847 

848 def _set_state_and_value( 

849 self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True 

850 ) -> None: 

851 """Set state and value of a node, applying any converter.""" 

852 node = self.dag.nodes[node_key] 

853 converter = node.get(NodeAttributes.CONVERTER) 

854 if converter is None: 

855 self._set_state_and_literal_value(node_key, state, value) 

856 else: 

857 try: 

858 converted_value = converter(value) 

859 self._set_state_and_literal_value(node_key, state, converted_value) 

860 except Exception as e: 

861 tb = traceback.format_exc() 

862 self._set_error(node_key, e, tb) 

863 if throw_conversion_exception: 

864 raise 

865 

866 def _set_state_and_literal_value( 

867 self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True 

868 ) -> None: 

869 """Set state and literal value of a node without conversion.""" 

870 node = self.dag.nodes[node_key] 

871 try: 

872 old_state = node[NodeAttributes.STATE] 

873 self._state_map[old_state].remove(node_key) 

874 except KeyError: 

875 if require_old_state: 

876 raise # pragma: no cover 

877 node[NodeAttributes.STATE] = state 

878 node[NodeAttributes.VALUE] = value 

879 self._state_map[state].add(node_key) 

880 

881 def _set_states(self, node_keys: Iterable[NodeKey], state: States) -> None: 

882 """Set the state of multiple nodes at once.""" 

883 for name in node_keys: 

884 node = self.dag.nodes[name] 

885 old_state = node[NodeAttributes.STATE] 

886 self._state_map[old_state].remove(name) 

887 node[NodeAttributes.STATE] = state 

888 self._state_map[state].update(node_keys) 

889 

890 def set_stale(self, name: Name) -> None: 

891 """Set the state of a node and all its dependencies to STALE. 

892 

893 :param name: Name of the node to set as STALE. 

894 """ 

895 node_key = to_nodekey(name) 

896 node_keys: list[NodeKey] = [node_key] 

897 node_keys.extend(nx.dag.descendants(self.dag, node_key)) 

898 self._set_states(node_keys, States.STALE) 

899 self._try_set_computable(node_key) 

900 

901 def pin(self, name: Name, value: Any = None) -> None: 

902 """Set the state of a node to PINNED. 

903 

904 :param name: Name of the node to set as PINNED. 

905 :param value: Value to pin to the node, if provided. 

906 :type value: default None 

907 """ 

908 node_key = to_nodekey(name) 

909 if value is not None: 

910 self.insert(node_key, value) 

911 self._set_states([node_key], States.PINNED) 

912 

913 def unpin(self, name: Name) -> None: 

914 """Unpin a node (state of node and all descendents will be set to STALE). 

915 

916 :param name: Name of the node to set as PINNED. 

917 """ 

918 node_key = to_nodekey(name) 

919 self.set_stale(node_key) 

920 

921 def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]: 

922 """Get all descendant nodes, optionally stopping at certain states.""" 

923 if stop_states is None: 

924 stop_states = set() 

925 if self.dag.nodes[node_key][NodeAttributes.STATE] in stop_states: 

926 return set() 

927 visited = set() 

928 to_visit = {node_key} 

929 while to_visit: 

930 n = to_visit.pop() 

931 visited.add(n) 

932 for n1 in self.dag.successors(n): 

933 if n1 in visited: 

934 continue 

935 if self.dag.nodes[n1][NodeAttributes.STATE] in stop_states: 

936 continue 

937 to_visit.add(n1) 

938 visited.remove(node_key) 

939 return visited 

940 

941 def _set_descendents(self, node_key: NodeKey, state: States) -> None: 

942 """Set the state of all descendant nodes.""" 

943 descendents = self._get_descendents(node_key, {States.PINNED}) 

944 self._set_states(descendents, state) 

945 

946 def _set_uninitialized(self, node_key: NodeKey) -> None: 

947 """Set a node to uninitialized state and clear its value.""" 

948 self._set_states([node_key], States.UNINITIALIZED) 

949 self.dag.nodes[node_key].pop(NodeAttributes.VALUE, None) 

950 

951 def _set_uptodate(self, node_key: NodeKey, value: object) -> None: 

952 """Set a node to up-to-date state with a value.""" 

953 self._set_state_and_value(node_key, States.UPTODATE, value) 

954 self._set_descendents(node_key, States.STALE) 

955 for n in self.dag.successors(node_key): 

956 self._try_set_computable(n) 

957 

958 def _set_error(self, node_key: NodeKey, exc: Exception, tb: str) -> None: 

959 """Set a node to error state with exception information.""" 

960 self._set_state_and_literal_value(node_key, States.ERROR, Error(exc, tb)) 

961 self._set_descendents(node_key, States.STALE) 

962 

    def _try_set_computable(self, node_key: NodeKey) -> None:
        """Promote a node to COMPUTABLE if it has a function and every predecessor is UPTODATE.

        PINNED nodes are never promoted; nodes without a function (pure inputs) are left as-is.
        """
        # A pinned node's state is frozen by definition.
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.PINNED:
            return
        # Only calculated nodes (those with a FUNC attribute) can become COMPUTABLE.
        if self.dag.nodes[node_key].get(NodeAttributes.FUNC) is not None:
            for n in self.dag.predecessors(node_key):
                if not self.dag.has_node(n):
                    return  # pragma: no cover
                # Any non-UPTODATE input blocks promotion.
                if self.dag.nodes[n][NodeAttributes.STATE] != States.UPTODATE:
                    return
            self._set_state(node_key, States.COMPUTABLE)

975 def _get_parameter_data(self, node_key: NodeKey) -> Iterable[_ParameterItem]: 

976 """Get all parameter data for a node's function call.""" 

977 for arg, value in self.dag.nodes[node_key][NodeAttributes.ARGS].items(): 

978 yield _ParameterItem(_ParameterType.ARG, arg, value) 

979 for param_name, value in self.dag.nodes[node_key][NodeAttributes.KWDS].items(): 

980 yield _ParameterItem(_ParameterType.KWD, param_name, value) 

981 for in_node_name in self.dag.predecessors(node_key): 

982 param_value = self.dag.nodes[in_node_name][NodeAttributes.VALUE] 

983 edge = self.dag[in_node_name][node_key] 

984 param_type, param_name = edge[EdgeAttributes.PARAM] 

985 yield _ParameterItem(param_type, param_name, param_value) 

986 

987 def _get_func_args_kwds( 

988 self, node_key: NodeKey 

989 ) -> tuple[Callable[..., Any], str | None, list[Any], dict[str, Any]]: 

990 """Get function, executor name, args and kwargs for a node.""" 

991 node0 = self.dag.nodes[node_key] 

992 f = node0[NodeAttributes.FUNC] 

993 executor_name = node0.get(NodeAttributes.EXECUTOR) 

994 args: list[Any] = [] 

995 kwds: dict[str, Any] = {} 

996 for param in self._get_parameter_data(node_key): 

997 if param.type == _ParameterType.ARG: 

998 idx = param.name 

999 assert isinstance(idx, int) # noqa: S101 

1000 while len(args) <= idx: 

1001 args.append(None) 

1002 args[idx] = param.value 

1003 elif param.type == _ParameterType.KWD: 

1004 assert isinstance(param.name, str) # noqa: S101 

1005 kwds[param.name] = param.value 

1006 else: # pragma: no cover 

1007 msg = f"Unexpected param type: {param.type}" 

1008 raise ValidationError(msg) 

1009 return f, executor_name, args, kwds 

1010 

    def get_definition_args_kwds(self, name: Name) -> tuple[list[Any], dict[str, Any]]:
        """Get the arguments and keyword arguments for a node's function definition.

        Constant parameters stored on the node are wrapped in ``C(...)``; parameters fed by
        predecessor nodes are represented by the predecessor's name. Positional slots are
        padded with ``None`` as needed.

        :param name: Name of the node to describe.
        :return: Tuple of (positional argument list, keyword argument dict).
        :raises ValidationError: if an edge carries an unrecognised parameter type.
        """
        res_args: list[Any] = []
        res_kwds: dict[str, Any] = {}
        node_key = to_nodekey(name)
        node_data = self.dag.nodes[node_key]
        # Constants bound positionally at definition time.
        if NodeAttributes.ARGS in node_data:
            for idx, value in node_data[NodeAttributes.ARGS].items():
                while len(res_args) <= idx:
                    res_args.append(None)
                res_args[idx] = C(value)
        # Constants bound by keyword at definition time.
        if NodeAttributes.KWDS in node_data:
            for param_name, value in node_data[NodeAttributes.KWDS].items():
                res_kwds[param_name] = C(value)
        # Parameters supplied by upstream nodes, as recorded on the connecting edges.
        for in_node_key in self.dag.predecessors(node_key):
            edge = self.dag[in_node_key][node_key]
            if EdgeAttributes.PARAM in edge:
                param_type, param_name = edge[EdgeAttributes.PARAM]
                if param_type == _ParameterType.ARG:
                    idx = param_name
                    assert isinstance(idx, int)  # noqa: S101
                    while len(res_args) <= idx:
                        res_args.append(None)
                    res_args[idx] = in_node_key.name
                elif param_type == _ParameterType.KWD:
                    res_kwds[param_name] = in_node_key.name
                else:  # pragma: no cover
                    msg = f"Unexpected param type: {param_type}"
                    raise ValidationError(msg)
        return res_args, res_kwds

1042 def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False) -> None: 

1043 """Compute multiple nodes, handling dependencies and parallel execution.""" 

1044 LOG.debug(f"Computing nodes {node_keys}") 

1045 

1046 futs: dict[Any, NodeKey] = {} 

1047 node_keys_set = set(node_keys) 

1048 

1049 def run(name: NodeKey) -> None: 

1050 """Submit a node computation to an executor.""" 

1051 f, executor_name, args, kwds = self._get_func_args_kwds(name) 

1052 executor = self.default_executor if executor_name is None else self.executor_map[executor_name] 

1053 fut = executor.submit(_eval_node, name, f, args, kwds, raise_exceptions) 

1054 futs[fut] = name 

1055 

1056 computed: set[NodeKey] = set() 

1057 

1058 for node_key in node_keys_set: 

1059 node0 = self.dag.nodes[node_key] 

1060 state = node0[NodeAttributes.STATE] 

1061 if state == States.COMPUTABLE: 

1062 run(node_key) 

1063 

1064 while len(futs) > 0: 

1065 done, _not_done = wait(futs.keys(), return_when=FIRST_COMPLETED) 

1066 for fut in done: 

1067 node_key = futs.pop(fut) 

1068 node0 = self.dag.nodes[node_key] 

1069 try: 

1070 value, exc, tb, start_dt, end_dt = fut.result() 

1071 except Exception as e: 

1072 exc = e 

1073 tb = traceback.format_exc() 

1074 self._set_error(node_key, exc, tb) 

1075 raise 

1076 delta = (end_dt - start_dt).total_seconds() 

1077 if exc is None: 

1078 self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False) 

1079 node0[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta) 

1080 self._set_descendents(node_key, States.STALE) 

1081 for n in self.dag.successors(node_key): 

1082 logging.debug(str(node_key) + " " + str(n) + " " + str(computed)) 

1083 if n in computed: 

1084 msg = f"Calculating {node_key} for the second time" 

1085 raise LoopDetectedException(msg) 

1086 self._try_set_computable(n) 

1087 node0 = self.dag.nodes[n] 

1088 state = node0[NodeAttributes.STATE] 

1089 if state == States.COMPUTABLE and n in node_keys_set: 

1090 run(n) 

1091 else: 

1092 assert tb is not None # noqa: S101 

1093 self._set_error(node_key, exc, tb) 

1094 computed.add(node_key) 

1095 

1096 def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]: 

1097 """Get node keys that need to be computed for a target node.""" 

1098 g = nx.DiGraph() 

1099 g.add_nodes_from(self.dag.nodes) 

1100 g.add_edges_from(self.dag.edges) 

1101 for n in nx.ancestors(g, node_key): 

1102 node = self.dag.nodes[n] 

1103 state = node[NodeAttributes.STATE] 

1104 if state == States.UPTODATE or state == States.PINNED: 

1105 g.remove_node(n) 

1106 

1107 ancestors = nx.ancestors(g, node_key) 

1108 for n in ancestors: 

1109 node = self.dag.nodes[n] 

1110 state = node[NodeAttributes.STATE] 

1111 

1112 if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0: 

1113 msg = f"Cannot compute {node_key} because {n} uninitialized" 

1114 raise ValidationError(msg) 

1115 if state == States.PLACEHOLDER: 

1116 msg = f"Cannot compute {node_key} because {n} is placeholder" 

1117 raise ValidationError(msg) 

1118 

1119 ancestors.add(node_key) 

1120 nodes_sorted = topological_sort(g) 

1121 return [n for n in nodes_sorted if n in ancestors] 

1122 

1123 def _get_calc_node_names(self, name: Name) -> Names: 

1124 """Get node names that need to be computed for a target node.""" 

1125 node_key = to_nodekey(name) 

1126 return node_keys_to_names(self._get_calc_node_keys(node_key)) 

1127 

    def compute(self, name: Name | Iterable[Name], raise_exceptions: bool = False) -> None:
        """Compute a node and all necessary predecessors.

        Following the computation, if successful, the target node, and all necessary ancestors that were not already
        UPTODATE will have been calculated and set to UPTODATE. Any node that did not need to be calculated will not
        have been recalculated.

        If any nodes raises an exception, then the state of that node will be set to ERROR, and its value set to an
        object containing the exception object, as well as a traceback. This will not halt the computation, which
        will proceed as far as it can, until no more nodes that would be required to calculate the target are
        COMPUTABLE.

        :param name: Name of the node to compute, or an iterable of names to compute together
        :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
        :type raise_exceptions: Boolean, default False
        """
        calc_nodes: set[NodeKey] | list[NodeKey]
        # NOTE(review): only generators and lists are treated as collections of names here;
        # a tuple or set argument falls through to the scalar branch. Presumably tuples can
        # themselves be node names — confirm before widening this check.
        if isinstance(name, (types.GeneratorType, list)):
            calc_nodes = set()
            for name0 in name:
                node_key = to_nodekey(name0)
                for n in self._get_calc_node_keys(node_key):
                    calc_nodes.add(n)
        else:
            node_key = to_nodekey(name)
            calc_nodes = self._get_calc_node_keys(node_key)
        self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)

1156 def compute_all(self, raise_exceptions: bool = False) -> None: 

1157 """Compute all nodes of a computation that can be computed. 

1158 

1159 Nodes that are already UPTODATE will not be recalculated. Following the computation, if successful, all 

1160 nodes will have state UPTODATE, except UNINITIALIZED input nodes and PLACEHOLDER nodes. 

1161 

1162 If any nodes raises an exception, then the state of that node will be set to ERROR, and its value set to an 

1163 object containing the exception object, as well as a traceback. This will not halt the computation, which 

1164 will proceed as far as it can, until no more nodes are COMPUTABLE. 

1165 

1166 :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller 

1167 :type raise_exceptions: Boolean, default False 

1168 """ 

1169 self._compute_nodes(self._node_keys(), raise_exceptions=raise_exceptions) 

1170 

1171 def _node_keys(self) -> list[NodeKey]: 

1172 """Get a list of nodes in this computation. 

1173 

1174 :return: List of nodes. 

1175 """ 

1176 return list(self.dag.nodes) 

1177 

1178 def nodes(self) -> list[Name]: 

1179 """Get a list of nodes in this computation. 

1180 

1181 :return: List of nodes. 

1182 """ 

1183 return [n.name for n in self.dag.nodes] 

1184 

    def get_tree_list_children(self, name: Name) -> set[Name]:
        """Get the immediate child path components under a node in the hierarchical tree.

        For every graph node that is a descendent of ``name``, the path component directly
        below ``name`` is collected.

        :param name: Name of the tree node to list children of.
        :return: Set of immediate child components (one level below ``name``).
        """
        node_key = to_nodekey(name)
        # Index of the first path component below the given node.
        idx = len(node_key.parts)
        result = set()
        for n in self.dag.nodes:
            if n.is_descendent_of(node_key):
                result.add(n.parts[idx])
        return result

1198 def has_node(self, name: Name) -> bool: 

1199 """Check if a node with the given name exists in the computation.""" 

1200 node_key = to_nodekey(name) 

1201 return node_key in self.dag.nodes 

1202 

1203 def tree_has_path(self, name: Name) -> bool: 

1204 """Check if a hierarchical path exists in the computation tree.""" 

1205 node_key = to_nodekey(name) 

1206 if node_key.is_root: 

1207 return True 

1208 if self.has_node(node_key): 

1209 return True 

1210 return any(n.is_descendent_of(node_key) for n in self.dag.nodes) 

1211 

1212 def get_tree_descendents( 

1213 self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False 

1214 ) -> set[Name]: 

1215 """Get a list of descendent blocks and nodes. 

1216 

1217 Returns blocks and nodes that are descendents of the input node, 

1218 e.g. for node 'foo', might return ['foo/bar', 'foo/baz']. 

1219 

1220 :param name: Name of node to get descendents for 

1221 :return: List of descendent node names 

1222 """ 

1223 node_key = NodeKey.root() if name is None else to_nodekey(name) 

1224 stemsize = len(node_key.parts) 

1225 result = set() 

1226 for n in self.dag.nodes: 

1227 if n.is_descendent_of(node_key): 

1228 nodes = [n] if graph_nodes_only else n.ancestors() 

1229 for n2 in nodes: 

1230 if n2.is_descendent_of(node_key): 

1231 nm = n2.name if include_stem else NodeKey(tuple(n2.parts[stemsize:])).name 

1232 result.add(nm) 

1233 return result 

1234 

1235 def _state_one(self, name: Name) -> States: 

1236 """Get the state of a single node.""" 

1237 node_key = to_nodekey(name) 

1238 state: States = self.dag.nodes[node_key][NodeAttributes.STATE] 

1239 return state 

1240 

1241 def state(self, name: Name | Names) -> Any: 

1242 """Get the state of a node. 

1243 

1244 This can also be accessed using the attribute-style accessor ``s`` if ``name`` is a valid Python 

1245 attribute name:: 

1246 

1247 >>> comp = Computation() 

1248 >>> comp.add_node('foo', value=1) 

1249 >>> comp.state('foo') 

1250 <States.UPTODATE: 4> 

1251 >>> comp.s.foo 

1252 <States.UPTODATE: 4> 

1253 

1254 :param name: Name or names of the node to get state for 

1255 :type name: Name or Names 

1256 """ 

1257 return apply1(self._state_one, name) 

1258 

1259 def _value_one(self, name: Name) -> Any: 

1260 """Get the value of a single node.""" 

1261 node_key = to_nodekey(name) 

1262 return self.dag.nodes[node_key][NodeAttributes.VALUE] 

1263 

1264 def value(self, name: Name | Names) -> Any: 

1265 """Get the current value of a node. 

1266 

1267 This can also be accessed using the attribute-style accessor ``v`` if ``name`` is a valid Python 

1268 attribute name:: 

1269 

1270 >>> comp = Computation() 

1271 >>> comp.add_node('foo', value=1) 

1272 >>> comp.value('foo') 

1273 1 

1274 >>> comp.v.foo 

1275 1 

1276 

1277 :param name: Name or names of the node to get the value of 

1278 :type name: Name or Names 

1279 """ 

1280 return apply1(self._value_one, name) 

1281 

1282 def compute_and_get_value(self, name: Name) -> Any: 

1283 """Get the current value of a node. 

1284 

1285 This can also be accessed using the attribute-style accessor ``v`` if ``name`` is a valid Python 

1286 attribute name:: 

1287 

1288 >>> comp = Computation() 

1289 >>> comp.add_node('foo', value=1) 

1290 >>> comp.add_node('bar', lambda foo: foo + 1) 

1291 >>> comp.compute_and_get_value('bar') 

1292 2 

1293 >>> comp.x.bar 

1294 2 

1295 

1296 :param name: Name or names of the node to get the value of 

1297 :type name: Name 

1298 """ 

1299 nk = to_nodekey(name) 

1300 if self.state(nk) == States.UPTODATE: 

1301 return self.value(nk) 

1302 self.compute(nk, raise_exceptions=True) 

1303 if self.state(nk) == States.UPTODATE: 

1304 return self.value(nk) 

1305 msg = f"Unable to compute node {nk}" 

1306 raise ComputationError(msg) 

1307 

1308 def _tag_one(self, name: Name) -> set[str]: 

1309 """Get the tags of a single node.""" 

1310 node_key = to_nodekey(name) 

1311 node = self.dag.nodes[node_key] 

1312 tags: set[str] = node[NodeAttributes.TAG] 

1313 return tags 

1314 

1315 def tags(self, name: Name | Names) -> Any: 

1316 """Get the tags associated with a node. 

1317 

1318 >>> comp = Computation() 

1319 >>> comp.add_node('a', tags=['foo', 'bar']) 

1320 >>> sorted(comp.t.a) 

1321 ['__serialize__', 'bar', 'foo'] 

1322 

1323 :param name: Name or names of the node to get the tags of 

1324 :return: 

1325 """ 

1326 return apply1(self._tag_one, name) 

1327 

1328 def nodes_by_tag(self, tag: str | Iterable[str]) -> set[Name]: 

1329 """Get the names of nodes with a particular tag or tags. 

1330 

1331 :param tag: Tag or tags for which to retrieve nodes 

1332 :return: Names of the nodes with those tags 

1333 """ 

1334 nodes: set[NodeKey] = set() 

1335 tags_to_check: Iterable[str] = [tag] if isinstance(tag, str) else tag 

1336 for tag1 in tags_to_check: 

1337 nodes1 = self._tag_map.get(tag1) 

1338 if nodes1 is not None: 

1339 nodes.update(nodes1) 

1340 return {n.name for n in nodes} 

1341 

1342 def _style_one(self, name: Name) -> str | None: 

1343 """Get the style of a single node.""" 

1344 node_key = to_nodekey(name) 

1345 node = self.dag.nodes[node_key] 

1346 style: str | None = node.get(NodeAttributes.STYLE) 

1347 return style 

1348 

    def styles(self, name: Name | Names) -> Any:
        """Get the style associated with a node, or with each of a collection of nodes.

        >>> comp = Computation()
        >>> comp.add_node('a', style='dot')
        >>> comp.style.a
        'dot'

        :param name: Name or names of the node to get the style of
        :return: The node's style string (or None), or a collection thereof
        """
        return apply1(self._style_one, name)

1362 def _get_item_one(self, name: Name) -> NodeData: 

1363 """Get state and value data for a single node.""" 

1364 node_key = to_nodekey(name) 

1365 node = self.dag.nodes[node_key] 

1366 return NodeData(node[NodeAttributes.STATE], node[NodeAttributes.VALUE]) 

1367 

1368 def __getitem__(self, name: Name | Names) -> Any: 

1369 """Get the state and current value of a node. 

1370 

1371 :param name: Name of the node to get the state and value of 

1372 """ 

1373 return apply1(self._get_item_one, name) 

1374 

1375 def _get_timing_one(self, name: Name) -> TimingData | None: 

1376 """Get timing data for a single node.""" 

1377 node_key = to_nodekey(name) 

1378 node = self.dag.nodes[node_key] 

1379 timing: TimingData | None = node.get(NodeAttributes.TIMING, None) 

1380 return timing 

1381 

1382 def get_timing(self, name: Name | Names) -> Any: 

1383 """Get the timing information for a node. 

1384 

1385 :param name: Name or names of the node to get the timing information of 

1386 :return: 

1387 """ 

1388 return apply1(self._get_timing_one, name) 

1389 

    def to_df(self) -> pd.DataFrame:
        """Get a dataframe containing the states and value of all nodes of computation.

        Rows are ordered topologically and indexed by node name; timing columns are
        left-joined in for nodes that have been computed.

        ::

            >>> import loman
            >>> comp = loman.Computation()
            >>> comp.add_node('foo', value=1)
            >>> comp.add_node('bar', value=2)
            >>> comp.to_df() # doctest: +NORMALIZE_WHITESPACE
                           state value
            foo  States.UPTODATE     1
            bar  States.UPTODATE     2
        """
        df = pd.DataFrame(index=topological_sort(self.dag))
        df[NodeAttributes.STATE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.STATE))
        df[NodeAttributes.VALUE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.VALUE))
        # NOTE(review): uses the literal "timing" rather than NodeAttributes.TIMING used
        # elsewhere — presumably they are the same string; confirm against consts.
        df_timing = pd.DataFrame.from_dict(nx.get_node_attributes(self.dag, "timing"), orient="index")
        df = pd.merge(df, df_timing, left_index=True, right_index=True, how="left")
        # Replace NodeKey index entries with their plain names for presentation.
        df.index = pd.Index([nk.name for nk in df.index])
        return df

1412 def to_dict(self) -> dict[NodeKey, Any]: 

1413 """Get a dictionary containing the values of all nodes of a computation. 

1414 

1415 :: 

1416 

1417 >>> import loman 

1418 >>> comp = loman.Computation() 

1419 >>> comp.add_node('foo', value=1) 

1420 >>> comp.add_node('bar', value=2) 

1421 >>> comp.to_dict() # doctest: +ELLIPSIS 

1422 {NodeKey('foo'): 1, NodeKey('bar'): 2} 

1423 """ 

1424 result: dict[NodeKey, Any] = nx.get_node_attributes(self.dag, NodeAttributes.VALUE) 

1425 return result 

1426 

1427 def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey | None]: 

1428 """Get input node keys for a single node.""" 

1429 args_dict: dict[int, NodeKey] = {} 

1430 kwds: list[NodeKey | None] = [] 

1431 max_arg_index = -1 

1432 for input_node in self.dag.predecessors(node_key): 

1433 input_edge = self.dag[input_node][node_key] 

1434 input_type, input_param = input_edge[EdgeAttributes.PARAM] 

1435 if input_type == _ParameterType.ARG: 

1436 idx = input_param 

1437 max_arg_index = max(max_arg_index, idx) 

1438 args_dict[idx] = input_node 

1439 elif input_type == _ParameterType.KWD: 

1440 kwds.append(input_node) 

1441 if max_arg_index >= 0: 

1442 args: list[NodeKey | None] = [None] * (max_arg_index + 1) 

1443 for idx, input_node in args_dict.items(): 

1444 args[idx] = input_node 

1445 result: list[NodeKey | None] = args + kwds 

1446 return result 

1447 else: 

1448 return kwds 

1449 

1450 def _get_inputs_one_names(self, name: Name) -> Names: 

1451 """Get input node names for a single node.""" 

1452 node_key = to_nodekey(name) 

1453 return node_keys_to_names([nk for nk in self._get_inputs_one_node_keys(node_key) if nk is not None]) 

1454 

1455 def get_inputs(self, name: Name | Names) -> Any: 

1456 """Get a list of the inputs for a node or set of nodes. 

1457 

1458 :param name: Name or names of nodes to get inputs for 

1459 :return: If name is scalar, return a list of upstream nodes used as input. If name is a list, return a 

1460 list of list of inputs. 

1461 """ 

1462 return apply1(self._get_inputs_one_names, name) 

1463 

1464 def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]: 

1465 """Get all ancestor node keys for a set of nodes.""" 

1466 ancestors: set[NodeKey] = set() 

1467 for n in node_keys: 

1468 if include_self: 

1469 ancestors.add(n) 

1470 for ancestor in nx.ancestors(self.dag, n): 

1471 ancestors.add(ancestor) 

1472 return ancestors 

1473 

1474 def get_ancestors(self, names: Name | Names, include_self: bool = True) -> Names: 

1475 """Get all ancestor nodes of the specified nodes.""" 

1476 node_keys = names_to_node_keys(names) 

1477 ancestor_node_keys = self._get_ancestors_node_keys(node_keys, include_self) 

1478 return node_keys_to_names(ancestor_node_keys) 

1479 

1480 def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> list[NodeKey]: 

1481 """Get original input node keys that have no computation function.""" 

1482 resolved_node_keys: Iterable[NodeKey] 

1483 resolved_node_keys = self._node_keys() if node_keys is None else self._get_ancestors_node_keys(node_keys) 

1484 return [n for n in resolved_node_keys if self.dag.nodes[n].get(NodeAttributes.FUNC) is None] 

1485 

1486 def get_original_inputs(self, names: Name | Names | None = None) -> Names: 

1487 """Get a list of the original non-computed inputs for a node or set of nodes. 

1488 

1489 :param names: Name or names of nodes to get inputs for 

1490 :return: Return a list of original non-computed inputs that are ancestors of the input nodes 

1491 """ 

1492 nks = None if names is None else names_to_node_keys(names) 

1493 

1494 result_nks = self._get_original_inputs_node_keys(nks) 

1495 

1496 return node_keys_to_names(result_nks) 

1497 

1498 def _get_outputs_one(self, name: Name) -> Names: 

1499 """Get output node names for a single node.""" 

1500 node_key = to_nodekey(name) 

1501 output_node_keys = list(self.dag.successors(node_key)) 

1502 return node_keys_to_names(output_node_keys) 

1503 

1504 def get_outputs(self, name: Name | Names) -> Any: 

1505 """Get a list of the outputs for a node or set of nodes. 

1506 

1507 :param name: Name or names of nodes to get outputs for 

1508 :return: If name is scalar, return a list of downstream nodes used as output. If name is a list, return a 

1509 list of list of outputs. 

1510 

1511 """ 

1512 return apply1(self._get_outputs_one, name) 

1513 

1514 def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]: 

1515 """Get all descendant node keys for a set of nodes.""" 

1516 descendent_node_keys: set[NodeKey] = set() 

1517 for node_key in node_keys: 

1518 if include_self: 

1519 descendent_node_keys.add(node_key) 

1520 for descendent in nx.descendants(self.dag, node_key): 

1521 descendent_node_keys.add(descendent) 

1522 return descendent_node_keys 

1523 

1524 def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names: 

1525 """Get all descendent nodes of the specified nodes.""" 

1526 node_keys = names_to_node_keys(names) 

1527 descendent_node_keys = self._get_descendents_node_keys(node_keys, include_self) 

1528 return node_keys_to_names(descendent_node_keys) 

1529 

1530 def get_final_outputs(self, names: Name | Names | None = None) -> Names: 

1531 """Get final output nodes (nodes with no descendants) from the specified nodes.""" 

1532 final_node_keys: Iterable[NodeKey] 

1533 if names is None: 

1534 final_node_keys = self._node_keys() 

1535 else: 

1536 nks = names_to_node_keys(names) 

1537 final_node_keys = self._get_descendents_node_keys(nks) 

1538 output_node_keys = [n for n in final_node_keys if len(nx.descendants(self.dag, n)) == 0] 

1539 return node_keys_to_names(output_node_keys) 

1540 

1541 def get_source(self, name: Name) -> str: 

1542 """Get the source code for a node.""" 

1543 node_key = to_nodekey(name) 

1544 func = self.dag.nodes[node_key].get(NodeAttributes.FUNC, None) 

1545 if func is not None: 

1546 file = inspect.getsourcefile(func) 

1547 _, lineno = inspect.getsourcelines(func) 

1548 source = inspect.getsource(func) 

1549 return f"{file}:{lineno}\n\n{source}" 

1550 else: 

1551 return "NOT A CALCULATED NODE" 

1552 

1553 def print_source(self, name: Name) -> None: 

1554 """Print the source code for a computation node.""" 

1555 print(self.get_source(name)) 

1556 

    def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None) -> None:
        """Restrict a computation to the ancestors of a set of output nodes.

        Excludes ancestors of a set of input nodes.

        If the set of input_names that is specified is not sufficient for the set of output_names then additional
        nodes that are ancestors of the output_names will be included, but the input nodes specified will be input
        nodes of the modified Computation.

        :param output_names: Name or names of the output nodes to keep (with their ancestors)
        :param input_names: Name or names of nodes to convert into plain inputs, severing their ancestry
        :return: None - modifies existing computation in place
        """
        if input_names is not None:
            for n in as_iterable(input_names):
                # Snapshot state/value, re-add the node as a plain (function-less) input,
                # then restore its state and value.
                nodedata = self._get_item_one(n)
                self.add_node(n)
                self._set_state_and_literal_value(to_nodekey(n), nodedata.state, nodedata.value)
        output_node_keys = names_to_node_keys(output_names)
        ancestor_node_keys = self._get_ancestors_node_keys(output_node_keys)
        # Drop everything not needed to compute the outputs.
        self.dag.remove_nodes_from([n for n in self.dag if n not in ancestor_node_keys])

1579 def __getstate__(self) -> dict[str, Any]: 

1580 """Prepare computation for serialization by removing non-serializable nodes.""" 

1581 node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG) 

1582 obj = self.copy() 

1583 for name, tags in node_serialize.items(): 

1584 if SystemTags.SERIALIZE not in tags: 

1585 obj._set_uninitialized(name) 

1586 return {"dag": obj.dag} 

1587 

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore computation from serialized state.

        Re-runs ``__init__`` first to reset the instance to a fresh default state,
        then installs the deserialized DAG and rebuilds the derived lookup maps.
        """
        self.__init__()
        self.dag = state["dag"]
        # Tag/state maps are derived from the DAG, so rebuild them after swapping it in.
        self._refresh_maps()

1593 

    def write_dill_old(self, file_: str | BinaryIO) -> None:
        """Serialize a computation to a file or file-like object.

        Deprecated: temporarily deletes the class-level ``__getstate__``/``__setstate__``
        hooks so dill pickles the object directly, then restores them in ``finally``.
        Prefer ``write_dill``.

        :param file_: If string, writes to a file
        :type file_: File-like object, or string
        """
        warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
        # Keep references to the custom pickling hooks so they can be reinstated below.
        original_getstate = self.__class__.__getstate__
        original_setstate = self.__class__.__setstate__

        try:
            # Remove the hooks class-wide so dill serializes the full object state.
            # NOTE(review): this mutates the class itself, affecting all instances —
            # confirm concurrent serialization is not expected.
            del self.__class__.__getstate__
            del self.__class__.__setstate__

            node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
            obj = self.copy()
            # Executors are dropped from the copy before dumping; presumably they are
            # not picklable — TODO confirm.
            obj.executor_map = None  # type: ignore[assignment]
            obj.default_executor = None  # type: ignore[assignment]
            # Reset any node not tagged for serialization so its value is not persisted.
            for name, tags in node_serialize.items():
                if SystemTags.SERIALIZE not in tags:
                    obj._set_uninitialized(name)

            if isinstance(file_, str):
                with open(file_, "wb") as f:
                    dill.dump(obj, f)
            else:
                dill.dump(obj, file_)
        finally:
            # Always reinstate the pickling hooks, even if dumping failed.
            self.__class__.__getstate__ = original_getstate  # type: ignore[method-assign]
            self.__class__.__setstate__ = original_setstate  # type: ignore[method-assign]

1624 

1625 def write_dill(self, file_: str | BinaryIO) -> None: 

1626 """Serialize a computation to a file or file-like object. 

1627 

1628 :param file_: If string, writes to a file 

1629 :type file_: File-like object, or string 

1630 """ 

1631 if isinstance(file_, str): 

1632 with open(file_, "wb") as f: 

1633 dill.dump(self, f) 

1634 else: 

1635 dill.dump(self, file_) 

1636 

1637 @staticmethod 

1638 def read_dill(file_: str | BinaryIO) -> "Computation": 

1639 """Deserialize a computation from a file or file-like object. 

1640 

1641 .. warning:: 

1642 This method uses dill.load() which can execute arbitrary code. 

1643 Only load files from trusted sources. Never load data from 

1644 untrusted or unauthenticated sources as it may lead to arbitrary 

1645 code execution. 

1646 

1647 :param file_: If string, writes to a file 

1648 :type file_: File-like object, or string 

1649 """ 

1650 if isinstance(file_, str): 

1651 with open(file_, "rb") as f: 

1652 obj = dill.load(f) # noqa: S301 

1653 else: 

1654 obj = dill.load(file_) # noqa: S301 

1655 if isinstance(obj, Computation): 

1656 return obj 

1657 else: 

1658 msg = "Loaded object is not a Computation" 

1659 raise ValidationError(msg) 

1660 

1661 def copy(self) -> "Computation": 

1662 """Create a copy of a computation. 

1663 

1664 The copy is shallow. Any values in the new Computation's DAG will be the same object as this Computation's 

1665 DAG. As new objects will be created by any further computations, this should not be an issue. 

1666 

1667 :rtype: Computation 

1668 """ 

1669 obj = Computation() 

1670 obj.dag = nx.DiGraph(self.dag) 

1671 obj._tag_map = defaultdict(set, {tag: nodes.copy() for tag, nodes in self._tag_map.items()}) 

1672 obj._state_map = {state: nodes.copy() for state, nodes in self._state_map.items()} 

1673 return obj 

1674 

    def add_named_tuple_expansion(self, name: Name, namedtuple_type: type, group: str | None = None) -> None:
        """Automatically add nodes to extract each element of a named tuple type.

        It is often convenient for a calculation to return multiple values, and it is polite to do this with a
        namedtuple rather than a regular tuple, so that later users have a name to identify elements of the tuple.
        It can also help make a computation clearer if a downstream computation depends on one element of such a
        tuple, rather than the entire tuple. This does not affect the computation per se, but it does make the
        intention clearer.

        To avoid having to create many boiler-plate node definitions to expand namedtuples, the
        ``add_named_tuple_expansion`` method automatically creates new nodes for each element of a tuple. The
        convention is that an element called 'element', in a node called 'node' will be expanded into a new node
        called 'node.element', and that this will be applied for each element.

        Example::

            >>> from collections import namedtuple
            >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
            >>> comp = Computation()
            >>> comp.add_node('c', value=Coordinate(1, 2))
            >>> comp.add_named_tuple_expansion('c', Coordinate)
            >>> comp.compute_all()
            >>> comp.value('c.x')
            1
            >>> comp.value('c.y')
            2

        :param name: Node to create expansion nodes for
        :param namedtuple_type: Expected type of the node
        :type namedtuple_type: namedtuple class
        :param group: Optional group to assign to the generated expansion nodes
        """

        def make_f(field_name: str) -> Callable[[Any], Any]:
            """Create a function to extract a field from a namedtuple.

            A factory is needed so each generated node captures its own field name
            rather than sharing a late-bound loop variable.
            """

            def get_field_value(tuple_val: Any) -> Any:
                """Extract field value from the namedtuple."""
                return getattr(tuple_val, field_name)

            return get_field_value

        for field_name in namedtuple_type._fields:  # type: ignore[attr-defined]
            node_name = f"{name}.{field_name}"
            self.add_node(node_name, make_f(field_name), kwds={"tuple_val": name}, group=group)
            # Tag expansion nodes so visualization can hide them (see draw()'s show_expansion).
            self.set_tag(node_name, SystemTags.EXPANSION)

1720 

    def add_map_node(
        self,
        result_node: Name,
        input_node: Name,
        subgraph: "Computation",
        subgraph_input_node: Name,
        subgraph_output_node: Name,
    ) -> None:
        """Apply a graph to each element of iterable.

        In turn, each element in the ``input_node`` of this graph will be inserted in turn into the subgraph's
        ``subgraph_input_node``, then the subgraph's ``subgraph_output_node`` calculated. The resultant list, with
        an element for each element in ``input_node``, will be inserted into ``result_node`` of this graph. In this
        way ``add_map_node`` is similar to ``map`` in functional programming.

        The generated node function raises ``MapException`` if any element fails to compute; in that case the
        partial results list carries a copy of the failed subgraph in place of each failed element's result.

        :param result_node: The node to place a list of results in **this** graph
        :param input_node: The node to get a list input values from **this** graph
        :param subgraph: The graph to use to perform calculation for each element
        :param subgraph_input_node: The node in **subgraph** to insert each element in turn
        :param subgraph_output_node: The node in **subgraph** to read the result for each element
        """

        def f(xs: Iterable[Any]) -> list[Any]:
            """Apply subgraph computation to each element in the input."""
            results: list[Any] = []
            is_error = False
            for x in xs:
                subgraph.insert(subgraph_input_node, x)
                subgraph.compute(subgraph_output_node)
                if subgraph.state(subgraph_output_node) == States.UPTODATE:
                    results.append(subgraph.value(subgraph_output_node))
                else:
                    # Keep a snapshot of the failed subgraph so callers can diagnose
                    # exactly which element failed and why.
                    is_error = True
                    results.append(subgraph.copy())
            if is_error:
                msg = f"Unable to calculate {result_node}"
                raise MapException(msg, results)
            return results

        self.add_node(result_node, f, kwds={"xs": input_node})

1761 

1762 def prepend_path(self, path: Name | ConstantValue, prefix_path: NodeKey) -> NodeKey | ConstantValue: 

1763 """Prepend a prefix path to a node path.""" 

1764 if isinstance(path, ConstantValue): 

1765 return path 

1766 nk = to_nodekey(path) 

1767 return prefix_path.join(nk) 

1768 

    def add_block(
        self,
        base_path: Name,
        block: "Computation",
        *,
        keep_values: bool | None = True,
        links: dict[str, Name] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Add a computation block as a subgraph to this computation.

        Every node of ``block`` is re-created under ``base_path``, with its args/kwds
        references rewritten to the prefixed names so the block's internal wiring is
        preserved.

        :param base_path: Path prefix under which the block's nodes are created
        :param block: The Computation whose nodes are copied in
        :param keep_values: If truthy, copy each node's existing state and value across
        :param links: Optional mapping of block-relative target name -> source node name
            in **this** graph; each pair is connected via ``link``
        :param metadata: Optional metadata stored for ``base_path``; when None, any
            existing metadata for that path is removed
        """
        base_path_nk = to_nodekey(base_path)
        for node_name in block.nodes():
            node_key = to_nodekey(node_name)
            node_data = block.dag.nodes[node_key]
            tags = node_data.get(NodeAttributes.TAG, None)
            style = node_data.get(NodeAttributes.STYLE, None)
            group = node_data.get(NodeAttributes.GROUP, None)
            args_def, kwds_def = block.get_definition_args_kwds(node_key)
            # Rewrite dependency references so they point at the prefixed node names.
            args_prepended = [self.prepend_path(arg, base_path_nk) for arg in args_def]
            kwds_prepended = {k: self.prepend_path(v, base_path_nk) for k, v in kwds_def.items()}
            func = node_data.get(NodeAttributes.FUNC, None)
            executor = node_data.get(NodeAttributes.EXECUTOR, None)
            converter = node_data.get(NodeAttributes.CONVERTER, None)
            new_node_name = self.prepend_path(node_name, base_path_nk)
            # inspect=False: arg/kwd wiring is given explicitly above, so the function
            # signature must not be re-inspected for dependencies.
            self.add_node(
                new_node_name,
                func,
                args=args_prepended,
                kwds=kwds_prepended,
                converter=converter,
                serialize=False,
                inspect=False,
                group=group,
                tags=tags,
                style=style,
                executor=executor,
            )
            if keep_values and NodeAttributes.VALUE in node_data:
                # Carry the source node's state and value across to the new node.
                new_node_key = to_nodekey(new_node_name)
                self._set_state_and_literal_value(
                    new_node_key, node_data[NodeAttributes.STATE], node_data[NodeAttributes.VALUE]
                )
        if links is not None:
            for target, source in links.items():
                self.link(base_path_nk.join_parts(target), source)
        if metadata is not None:
            self._metadata[base_path_nk] = metadata
        else:
            # No metadata supplied: clear any stale entry from a previous add_block.
            if base_path_nk in self._metadata:
                del self._metadata[base_path_nk]

1819 

1820 def link(self, target: Name, source: Name) -> None: 

1821 """Create a link between two nodes in the computation graph.""" 

1822 target_nk = to_nodekey(target) 

1823 source_nk = to_nodekey(source) 

1824 if target_nk == source_nk: 

1825 return 

1826 

1827 target_style = self._style_one(target_nk) if self.has_node(target_nk) else None 

1828 source_style = self._style_one(source_nk) if self.has_node(source_nk) else None 

1829 style = target_style if target_style else source_style 

1830 

1831 self.add_node(target_nk, identity_function, kwds={"x": source_nk}, style=style) 

1832 

1833 def _repr_svg_(self) -> str | None: 

1834 """Return SVG representation for Jupyter notebook display.""" 

1835 return GraphView(self).svg() 

1836 

    def draw(
        self,
        root: NodeKey | None = None,
        *,
        node_transformations: dict[Name, str] | None = None,
        cmap: Any = None,
        colors: str = "state",
        shapes: str | None = None,
        graph_attr: dict[str, Any] | None = None,
        node_attr: dict[str, Any] | None = None,
        edge_attr: dict[str, Any] | None = None,
        show_expansion: bool = False,
        collapse_all: bool = True,
    ) -> GraphView:
        """Draw a computation's current state using the GraphViz utility.

        :param root: Optional PathType. Sub-block to draw
        :param node_transformations: Optional mapping of node name to a NodeTransformations value
            applied when rendering; the mapping is copied, not mutated
        :param cmap: Default: None
        :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
        :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
        :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
            ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
        :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
        :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
        :param show_expansion: Whether to show namedtuple-expansion nodes; when False they are contracted.
        :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
        :return: A GraphView over this computation
        """
        node_formatter = NodeFormatter.create(cmap, colors, shapes)
        # Copy so the caller's dict is never mutated by the expansion handling below.
        node_transformations_copy: dict[Name, str] = (
            node_transformations.copy() if node_transformations is not None else {}
        )
        if not show_expansion:
            # Hide auto-generated namedtuple expansion nodes unless explicitly requested.
            for nodekey in self.nodes_by_tag(SystemTags.EXPANSION):
                node_transformations_copy[nodekey] = NodeTransformations.CONTRACT
        v = GraphView(
            self,
            root=root,
            node_formatter=node_formatter,
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
            node_transformations=node_transformations_copy,
            collapse_all=collapse_all,
        )
        return v

1881 

1882 def view(self, cmap: Any = None, colors: str = "state", shapes: str | None = None) -> None: 

1883 """Create and display a visualization of the computation graph.""" 

1884 node_formatter = NodeFormatter.create(cmap, colors, shapes) 

1885 v = GraphView(self, node_formatter=node_formatter) 

1886 v.view() 

1887 

1888 def print_errors(self) -> None: 

1889 """Print tracebacks for every node with state "ERROR" in a Computation.""" 

1890 for n in self.nodes(): 

1891 if self.s[n] == States.ERROR: 

1892 print(f"{n}") 

1893 print("=" * len(str(n))) 

1894 print() 

1895 print(self.v[n].traceback) 

1896 print() 

1897 

1898 @classmethod 

1899 def from_class(cls, definition_class: type, ignore_self: bool = True) -> "Computation": 

1900 """Create a computation from a class with decorated methods.""" 

1901 comp = cls() 

1902 obj = definition_class() 

1903 populate_computation_from_class(comp, definition_class, obj, ignore_self=ignore_self) 

1904 return comp 

1905 

1906 def inject_dependencies(self, dependencies: dict[Name, Any], *, force: bool = False) -> None: 

1907 """Injects dependencies into the nodes of the current computation where nodes are in a placeholder state. 

1908 

1909 (or all possible nodes when the 'force' parameter is set to True), using values 

1910 provided in the 'dependencies' dictionary. 

1911 

1912 Each key in the 'dependencies' dictionary corresponds to a node identifier, and the associated 

1913 value is the dependency object to inject. If the value is a callable, it will be added as a calc node. 

1914 

1915 :param dependencies: A dictionary where each key-value pair consists of a node identifier and 

1916 its corresponding dependency object or a callable that returns the dependency object. 

1917 :param force: A boolean flag that, when set to True, forces the replacement of existing node values 

1918 with the ones provided in 'dependencies', regardless of their current state. Defaults to False. 

1919 :return: None 

1920 """ 

1921 for n in self.nodes(): 

1922 if force or self.s[n] == States.PLACEHOLDER: 

1923 obj = dependencies.get(n) 

1924 if obj is None: 

1925 continue 

1926 if callable(obj): 

1927 self.add_node(n, obj) 

1928 else: 

1929 self.add_node(n, value=obj)