Coverage for src / loman / computeengine.py: 99%

1042 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 21:24 +0000

1"""Core computation engine for dependency-aware calculation graphs.""" 

2 

3import functools 

4import inspect 

5import logging 

6import traceback 

7import types 

8import warnings 

9from collections import defaultdict 

10from collections.abc import Callable, Iterable, Mapping 

11from concurrent.futures import FIRST_COMPLETED, Executor, ThreadPoolExecutor, wait 

12from dataclasses import dataclass, field 

13from datetime import UTC, datetime 

14from enum import Enum 

15from typing import TYPE_CHECKING, Any, BinaryIO, TextIO, TypeVar, overload 

16 

17if TYPE_CHECKING: 

18 from .serialization.computation import ComputationSerializer 

19 

20import decorator 

21import dill # nosec B403 

22import networkx as nx 

23import pandas as pd 

24 

25from .compat import get_signature 

26from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags 

27from .exception import ( 

28 CannotInsertToPlaceholderNodeException, 

29 ComputationError, 

30 LoopDetectedException, 

31 MapException, 

32 NodeAlreadyExistsException, 

33 NonExistentNodeException, 

34 ValidationError, 

35) 

36from .graph_utils import topological_sort 

37from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey 

38from .util import AttributeView, apply1, apply_n, as_iterable, value_eq 

39from .visualization import GraphView, NodeFormatter 

40 

41LOG = logging.getLogger("loman.computeengine") 

42 

43F = TypeVar("F", bound=Callable[..., Any]) 

44 

45 

@dataclass
class Error:
    """Error information captured during computation.

    Pairs an exception object with a traceback string.
    """

    exception: Exception  # the exception instance
    traceback: str  # traceback text associated with the exception

52 

53 

@dataclass
class NodeData:
    """State and value held for a single computation node."""

    state: States  # node state, a member of ``States``
    value: object  # the value stored for the node

60 

61 

@dataclass
class TimingData:
    """Start time, end time and duration of a computation run."""

    start: datetime  # when execution started
    end: datetime  # when execution finished
    duration: float  # elapsed time (units are not fixed by this class)

69 

70 

class _ParameterType(Enum):
    """Internal marker telling positional parameters apart from keyword ones."""

    ARG = 1  # positional parameter
    KWD = 2  # keyword parameter

76 

77 

@dataclass
class _ParameterItem:
    """Internal record describing one parameter during computation."""

    type: _ParameterType  # positional (ARG) or keyword (KWD)
    name: int | str  # presumably an index for ARG and a name for KWD -- confirm with callers
    value: object  # the value carried for the parameter

85 

86 

def _node(func: Callable[..., Any], *args: Any, **kws: Any) -> Any:  # pragma: no cover
    """Call ``func`` with the supplied arguments and hand back its result.

    This is the caller passed to ``decorator.decorate`` by the ``node`` decorator.
    """
    outcome = func(*args, **kws)
    return outcome

90 

91 

def node(comp: "Computation", name: Name | None = None, *args: Any, **kw: Any) -> Callable[[F], F]:
    """Return a decorator that registers the decorated function as a node of ``comp``.

    :param comp: Computation to add the node to.
    :param name: Node name; when None the function's ``__name__`` is used.
    :param args: Extra positional arguments forwarded to ``comp.add_node``.
    :param kw: Extra keyword arguments forwarded to ``comp.add_node``.
    :return: A decorator returning ``decorator.decorate(f, _node)`` for the function ``f``.
    """

    def inner(f: F) -> F:
        """Register ``f`` on the computation and return the decorated function."""
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        decorated: F = decorator.decorate(f, _node)
        return decorated

    return inner

105 

106 

@dataclass
class ConstantValue:
    """Wrapper marking a value as a constant rather than a node reference."""

    value: object  # the wrapped constant


# Short alias for ConstantValue.
C = ConstantValue

115 

116 

class Node:
    """Base class for objects that can add themselves to a computation graph."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add this node to ``comp`` under ``name``; subclasses must override.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

123 

124 

@dataclass
class InputNode(Node):
    """Node describing an input of the computation graph."""

    args: tuple[Any, ...] = field(default_factory=tuple)
    kwds: dict[str, Any] = field(default_factory=dict)

    def __init__(self, *args: Any, **kwds: Any) -> None:
        """Store the positional and keyword arguments given at declaration time."""
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add an input node called ``name`` to ``comp``.

        Only the stored keyword arguments are forwarded to ``comp.add_node``;
        the stored positional arguments, ``obj`` and ``ignore_self`` are not used.
        """
        comp.add_node(name, **self.kwds)


# Lower-case alias for InputNode.
input_node = InputNode

143 

144 

@dataclass
class CalcNode(Node):
    """Node describing a calculation of the computation graph."""

    f: Callable[..., Any]
    kwds: dict[str, Any] = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add a calculation node called ``name`` to ``comp``.

        An ``ignore_self`` entry in the stored keyword arguments is combined with
        the ``ignore_self`` argument and is never forwarded to ``comp.add_node``.
        When the combined flag is set and the function's first parameter is named
        ``self``, the function is bound to ``obj`` before being registered.
        """
        options = self.kwds.copy()
        # Pop the flag so that it is not passed on to add_node.
        flag_from_kwds = options.pop("ignore_self", False)
        bind_to_obj = ignore_self or flag_from_kwds
        target = self.f
        if bind_to_obj:
            signature = get_signature(self.f)
            if len(signature.kwd_params) > 0 and signature.kwd_params[0] == "self":
                target = target.__get__(obj, obj.__class__)  # type: ignore[attr-defined]
        comp.add_node(name, target, **options)

164 

165 

@overload
def calc_node(f: F, **kwds: Any) -> F: ...


@overload
def calc_node(f: None = None, **kwds: Any) -> Callable[[F], F]: ...


def calc_node(f: F | None = None, **kwds: Any) -> F | Callable[[F], F]:
    """Mark a function as a calculation node.

    Works both bare (``@calc_node``) and with keyword arguments
    (``@calc_node(...)``). The function itself is returned unchanged apart from
    a ``_loman_node_info`` attribute holding a ``CalcNode`` built from the
    function and ``kwds``.
    """

    def attach(func: F) -> F:
        """Attach the ``CalcNode`` description to ``func`` and return ``func``."""
        func._loman_node_info = CalcNode(func, kwds)
        return func

    return attach if f is None else attach(f)

185 

186 

@dataclass
class Block(Node):
    """Node describing a block, i.e. a sub-computation of the graph."""

    block: "Callable[[], Computation] | Computation"
    args: tuple[Any, ...] = field(default_factory=tuple)
    kwds: dict[str, Any] = field(default_factory=dict)

    def __init__(self, block: "Callable[[], Computation] | Computation", *args: Any, **kwds: Any) -> None:
        """Store the block (a Computation or a zero-argument factory) and extra arguments."""
        self.block = block
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add the block to ``comp`` under ``name`` via ``comp.add_block``.

        A ``Computation`` is used as is; any other callable is called with no
        arguments to produce the block.

        :raises TypeError: If the stored block is neither a Computation nor callable.
        """
        if isinstance(self.block, Computation):
            sub_computation = self.block
        elif callable(self.block):
            sub_computation = self.block()
        else:
            msg = f"Block {self.block} must be callable or Computation"
            raise TypeError(msg)
        comp.add_block(name, sub_computation, *self.args, **self.kwds)


# Lower-case alias for Block.
block = Block

214 

215 

def populate_computation_from_class(comp: "Computation", cls: type, obj: object, ignore_self: bool = True) -> None:
    """Add every node declared on ``cls`` to ``comp``.

    A class member counts as a node declaration when it is a ``Node`` instance
    or carries a ``_loman_node_info`` attribute. Each declaration is asked to
    add itself under the member's name.

    :param comp: Computation to populate.
    :param cls: Class whose members are scanned with ``inspect.getmembers``.
    :param obj: Object passed through to each node's ``add_to_comp``.
    :param ignore_self: Flag passed through to each node's ``add_to_comp``.
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            declaration = member
        else:
            declaration = getattr(member, "_loman_node_info", None)
        if declaration is not None:
            declaration.add_to_comp(comp, member_name, obj, ignore_self)

226 

227 

def computation_factory(
    maybe_cls: type | None = None, *, ignore_self: bool = True
) -> Callable[..., "Computation"] | Callable[[type], Callable[..., "Computation"]]:
    """Turn a class definition into a function that builds computations.

    Usable bare (``@computation_factory``) or with the ``ignore_self`` keyword
    (``@computation_factory(ignore_self=...)``).

    :param maybe_cls: The class when used bare, otherwise None.
    :param ignore_self: Passed to ``populate_computation_from_class``.
    :return: The factory function, or a decorator producing it when ``maybe_cls`` is None.
    """

    def build_factory(cls: type) -> Callable[..., "Computation"]:
        """Create the computation-building function for ``cls``."""

        @functools.wraps(cls, updated=())
        def create_computation(*args: Any, **kwargs: Any) -> "Computation":
            """Instantiate ``cls``, build a Computation and populate it from the class."""
            definition = cls()
            computation = Computation(*args, **kwargs)
            computation._definition_object = definition  # type: ignore[attr-defined]
            populate_computation_from_class(computation, cls, definition, ignore_self)
            return computation

        return create_computation

    return build_factory if maybe_cls is None else build_factory(maybe_cls)

251 

252 

253def _eval_node( 

254 name: NodeKey, 

255 f: Callable[..., Any], 

256 args: list[Any], 

257 kwds: dict[str, Any], 

258 raise_exceptions: bool, 

259) -> tuple[Any, Exception | None, str | None, datetime, datetime]: 

260 """To make multiprocessing work, this function must be standalone so that pickle works.""" 

261 exc: Exception | None = None 

262 tb: str | None = None 

263 start_dt = datetime.now(UTC) 

264 try: 

265 logging.debug("Running " + str(name)) 

266 value = f(*args, **kwds) 

267 logging.debug("Completed " + str(name)) 

268 except Exception as e: 

269 value = None 

270 exc = e 

271 tb = traceback.format_exc() 

272 if raise_exceptions: 

273 raise 

274 end_dt = datetime.now(UTC) 

275 return value, exc, tb, start_dt, end_dt 

276 

277 

278_MISSING_VALUE_SENTINEL = object() 

279 

280 

class NullObject:
    """Debug helper: prints every access attempt and then raises.

    Attribute access, assignment and deletion raise AttributeError, calling
    raises TypeError, and item access or assignment raise KeyError.
    """

    def __getattr__(self, name: str) -> Any:
        """Print the attribute name, then raise AttributeError."""
        print(f"__getattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Print the attribute name, then raise AttributeError."""
        print(f"__setattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __delattr__(self, name: str) -> None:
        """Print the attribute name, then raise AttributeError."""
        print(f"__delattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Print the call arguments, then raise TypeError."""
        print(f"__call__: {args}, {kwargs}")
        raise TypeError("'NullObject' object is not callable")

    def __getitem__(self, key: Any) -> Any:
        """Print the key, then raise KeyError."""
        print(f"__getitem__: {key}")
        raise KeyError(f"'NullObject' object has no item with key '{key}'")

    def __setitem__(self, key: Any, value: Any) -> None:
        """Print the key, then raise KeyError."""
        print(f"__setitem__: {key}")
        raise KeyError(f"'NullObject' object cannot have items set with key '{key}'")

    def __repr__(self) -> str:
        """Print the instance dictionary and return ``"<NullObject>"``."""
        # Fetch __dict__ through object.__getattribute__, as the original did.
        instance_dict = object.__getattribute__(self, "__dict__")
        print(f"__repr__: {instance_dict}")
        return "<NullObject>"

324 

325 

def identity_function(x: Any) -> Any:
    """Return ``x`` itself, without copying or modifying it."""
    return x

329 

330 

331class Computation: 

332 """A computation graph that manages dependencies and calculations. 

333 

334 The Computation class provides a framework for building and executing 

335 computation graphs where nodes represent data or calculations, and edges 

336 represent dependencies between them. 

337 """ 

338 

339 def __init__( 

340 self, 

341 *, 

342 default_executor: Executor | None = None, 

343 executor_map: dict[str, Executor] | None = None, 

344 metadata: dict[str, Any] | None = None, 

345 ) -> None: 

346 """Initialize a new Computation. 

347 

348 :param default_executor: An executor 

349 :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1) 

350 """ 

351 if default_executor is None: 

352 self.default_executor: Executor = ThreadPoolExecutor(1) 

353 else: 

354 self.default_executor = default_executor 

355 if executor_map is None: 

356 self.executor_map: dict[str, Executor] = {} 

357 else: 

358 self.executor_map = executor_map 

359 self.dag: nx.DiGraph = nx.DiGraph() 

360 self._metadata: dict[NodeKey, Any] = {} 

361 if metadata is not None: 

362 self._metadata[NodeKey.root()] = metadata 

363 

364 self.v = self.get_attribute_view_for_path(NodeKey.root(), self._value_one, self.value) 

365 self.s = self.get_attribute_view_for_path(NodeKey.root(), self._state_one, self.state) 

366 self.i = self.get_attribute_view_for_path(NodeKey.root(), self._get_inputs_one_names, self.get_inputs) 

367 self.o = self.get_attribute_view_for_path(NodeKey.root(), self._get_outputs_one, self.get_outputs) 

368 self.t = self.get_attribute_view_for_path(NodeKey.root(), self._tag_one, self.tags) 

369 self.style = self.get_attribute_view_for_path(NodeKey.root(), self._style_one, self.styles) 

370 self.tim = self.get_attribute_view_for_path(NodeKey.root(), self._get_timing_one, self.get_timing) 

371 self.x = self.get_attribute_view_for_path( 

372 NodeKey.root(), self.compute_and_get_value, self.compute_and_get_value 

373 ) 

374 self.src = self.get_attribute_view_for_path(NodeKey.root(), self.print_source, self.print_source) 

375 self._tag_map: defaultdict[str, set[NodeKey]] = defaultdict(set) 

376 self._state_map: dict[States, set[NodeKey]] = {state: set() for state in States} 

377 

378 def get_attribute_view_for_path( 

379 self, nodekey: NodeKey, get_one_func: Callable[[Name], Any], get_many_func: Callable[[Name | Names], Any] 

380 ) -> AttributeView: 

381 """Create an attribute view for a specific node path.""" 

382 

383 def node_func() -> Iterable[str]: 

384 """Return list of child node names for this path.""" 

385 return [str(n) for n in self.get_tree_list_children(nodekey)] 

386 

387 def get_one_func_for_path(name: str) -> Any: 

388 """Get value for a single node at this path.""" 

389 nk = to_nodekey(name) 

390 new_nk = nk.prepend(nodekey) 

391 if self.has_node(new_nk): 

392 return get_one_func(new_nk) 

393 elif self.tree_has_path(new_nk): 

394 return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func) 

395 else: 

396 msg = f"Path {new_nk} does not exist" 

397 raise KeyError(msg) # pragma: no cover 

398 

399 def get_many_func_for_path(name: Name | Names) -> Any: 

400 """Get values for one or more nodes at this path.""" 

401 if isinstance(name, list): 

402 return [get_one_func_for_path(str(n)) for n in name] 

403 else: 

404 return get_one_func_for_path(str(name)) 

405 

406 return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path) 

407 

408 def _get_names_for_state(self, state: States) -> set[Name]: 

409 """Get node names that have a specific state.""" 

410 return set(node_keys_to_names(self._state_map[state])) 

411 

412 def _get_tags_for_state(self, tag: str) -> set[Name]: 

413 """Get node names that have a specific tag.""" 

414 return set(node_keys_to_names(self._tag_map[tag])) 

415 

416 def _process_function_args(self, node_key: NodeKey, node: dict[str, Any], args: list[Any] | None) -> int: 

417 """Process positional arguments for a function node.""" 

418 args_count = 0 

419 if args: 

420 args_count = len(args) 

421 for i, arg in enumerate(args): 

422 if isinstance(arg, ConstantValue): 

423 node[NodeAttributes.ARGS][i] = arg.value 

424 else: 

425 input_vertex_name = arg 

426 input_vertex_node_key = to_nodekey(input_vertex_name) 

427 if not self.dag.has_node(input_vertex_node_key): 

428 self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER}) 

429 self._state_map[States.PLACEHOLDER].add(input_vertex_node_key) 

430 self.dag.add_edge( 

431 input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)} 

432 ) 

433 return args_count 

434 

435 def _build_param_map( 

436 self, 

437 func: Callable[..., Any], 

438 node_key: NodeKey, 

439 args_count: int, 

440 kwds: dict[str, Any] | None, 

441 inspect: bool, 

442 ) -> tuple[dict[str, Any], list[str]]: 

443 """Build parameter map for function node.""" 

444 param_map: dict[str, Any] = {} 

445 default_names: list[str] = [] 

446 

447 if inspect: 

448 signature = get_signature(func) 

449 if not signature.has_var_args: 

450 for param_name in signature.kwd_params[args_count:]: 

451 if kwds is not None and param_name in kwds: 

452 param_source = kwds[param_name] 

453 else: 

454 param_source = node_key.parent.join_parts(param_name) 

455 param_map[param_name] = param_source 

456 if signature.has_var_kwds and kwds is not None: 

457 for param_name, param_source in kwds.items(): 

458 param_map[param_name] = param_source 

459 default_names = signature.default_params 

460 else: 

461 if kwds is not None: 

462 for param_name, param_source in kwds.items(): 

463 param_map[param_name] = param_source 

464 

465 return param_map, default_names 

466 

467 def _process_function_kwds( 

468 self, node_key: NodeKey, node: dict[str, Any], param_map: dict[str, Any], default_names: list[str] 

469 ) -> None: 

470 """Process keyword arguments for a function node.""" 

471 for param_name, param_source in param_map.items(): 

472 if isinstance(param_source, ConstantValue): 

473 node[NodeAttributes.KWDS][param_name] = param_source.value 

474 else: 

475 in_node_name = param_source 

476 in_node_key = to_nodekey(in_node_name) 

477 if not self.dag.has_node(in_node_key): 

478 if param_name in default_names: 

479 continue 

480 else: 

481 self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER}) 

482 self._state_map[States.PLACEHOLDER].add(in_node_key) 

483 self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)}) 

484 

485 def add_node( 

486 self, 

487 name: Name, 

488 func: Callable[..., Any] | None = None, 

489 *, 

490 args: list[Any] | None = None, 

491 kwds: dict[str, Any] | None = None, 

492 value: Any = _MISSING_VALUE_SENTINEL, 

493 converter: Callable[[Any], Any] | None = None, 

494 serialize: bool = True, 

495 inspect: bool = True, 

496 group: str | None = None, 

497 tags: Iterable[str] | None = None, 

498 style: str | None = None, 

499 executor: str | None = None, 

500 metadata: dict[str, Any] | None = None, 

501 ) -> None: 

502 """Adds or updates a node in a computation. 

503 

504 :param name: Name of the node to add. This may be any hashable object. 

505 :param func: Function to use to calculate the node if the node is a calculation node. By default, the input 

506 nodes to the function will be implied from the names of the function parameters. For example, a 

507 parameter called ``a`` would be taken from the node called ``a``. This can be modified with the 

508 ``kwds`` parameter. 

509 :type func: Function, default None 

510 :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally 

511 for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with 

512 three parameters, taken from the nodes 'a', 'b' and 'c' respectively. 

513 :type args: List, default None 

514 :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that 

515 parameter when calling the function for a calculation node. e.g. If args is ``{'x': 'a', 'y': 'b'}`` 

516 then the function would be called with parameters named 'x' and 'y', and their values would be taken 

517 from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter 

518 [key] from node [value]". 

519 :type kwds: Dictionary, default None 

520 :param value: If given, the value is inserted into the node, and the node state set to UPTODATE. 

521 :type value: default None 

522 :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which 

523 case, set serialize to False 

524 :type serialize: boolean, default True 

525 :param inspect: Whether to use introspection to determine the arguments of the function, which can be 

526 slow. If this is not set, kwds and args must be set for the function to obtain parameters. 

527 :type inspect: boolean, default True 

528 :param group: Subgraph to render node in 

529 :type group: default None 

530 :param tags: Set of tags to apply to node 

531 :type tags: Iterable 

532 :param styles: Style to apply to node 

533 :type styles: String, default None 

534 :param executor: Name of executor to run node on 

535 :type executor: string 

536 """ 

537 node_key = to_nodekey(name) 

538 LOG.debug(f"Adding node {node_key}") 

539 has_value = value is not _MISSING_VALUE_SENTINEL 

540 if value is _MISSING_VALUE_SENTINEL: 

541 value = None 

542 if tags is None: 

543 tags = [] 

544 

545 self.dag.add_node(node_key) 

546 pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)] 

547 self.dag.remove_edges_from(pred_edges) 

548 node = self.dag.nodes[node_key] 

549 

550 if metadata is None: 

551 if node_key in self._metadata: 

552 del self._metadata[node_key] 

553 else: 

554 self._metadata[node_key] = metadata 

555 

556 self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False) 

557 

558 node[NodeAttributes.TAG] = set() 

559 node[NodeAttributes.STYLE] = style 

560 node[NodeAttributes.GROUP] = group 

561 node[NodeAttributes.ARGS] = {} 

562 node[NodeAttributes.KWDS] = {} 

563 node[NodeAttributes.FUNC] = None 

564 node[NodeAttributes.EXECUTOR] = executor 

565 node[NodeAttributes.CONVERTER] = converter 

566 

567 if func: 

568 node[NodeAttributes.FUNC] = func 

569 args_count = self._process_function_args(node_key, node, args) 

570 param_map, default_names = self._build_param_map(func, node_key, args_count, kwds, inspect) 

571 self._process_function_kwds(node_key, node, param_map, default_names) 

572 self._set_descendents(node_key, States.STALE) 

573 

574 if has_value: 

575 self._set_uptodate(node_key, value) 

576 if node[NodeAttributes.STATE] == States.UNINITIALIZED: 

577 self._try_set_computable(node_key) 

578 self.set_tag(node_key, tags) 

579 if serialize: 

580 self.set_tag(node_key, SystemTags.SERIALIZE) 

581 

582 def _refresh_maps(self) -> None: 

583 """Refresh internal tag and state maps from node data.""" 

584 self._tag_map.clear() 

585 for state in States: 

586 self._state_map[state].clear() 

587 for node_key in self._node_keys(): 

588 state = self.dag.nodes[node_key][NodeAttributes.STATE] 

589 self._state_map[state].add(node_key) 

590 tags = self.dag.nodes[node_key].get(NodeAttributes.TAG, set()) 

591 for tag in tags: 

592 self._tag_map[tag].add(node_key) 

593 

594 def _set_tag_one(self, name: Name, tag: str) -> None: 

595 """Set a single tag on a single node.""" 

596 node_key = to_nodekey(name) 

597 self.dag.nodes[node_key][NodeAttributes.TAG].add(tag) 

598 self._tag_map[tag].add(node_key) 

599 

600 def set_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None: 

601 """Set tags on a node or nodes. Ignored if tags are already set. 

602 

603 :param name: Node or nodes to set tag for 

604 :param tag: Tag to set 

605 """ 

606 apply_n(self._set_tag_one, name, tag) 

607 

608 def _clear_tag_one(self, name: Name, tag: str) -> None: 

609 """Clear a single tag from a single node.""" 

610 node_key = to_nodekey(name) 

611 self.dag.nodes[node_key][NodeAttributes.TAG].discard(tag) 

612 self._tag_map[tag].discard(node_key) 

613 

614 def clear_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None: 

615 """Clear tag on a node or nodes. Ignored if tags are not set. 

616 

617 :param name: Node or nodes to clear tags for 

618 :param tag: Tag to clear 

619 """ 

620 apply_n(self._clear_tag_one, name, tag) 

621 

622 def _set_style_one(self, name: Name, style: str) -> None: 

623 """Set style on a single node.""" 

624 node_key = to_nodekey(name) 

625 self.dag.nodes[node_key][NodeAttributes.STYLE] = style 

626 

627 def set_style(self, name: Name | Names, style: str) -> None: 

628 """Set styles on a node or nodes. 

629 

630 :param name: Node or nodes to set style for 

631 :param style: Style to set 

632 """ 

633 apply_n(self._set_style_one, name, style) 

634 

635 def _clear_style_one(self, name: Name) -> None: 

636 """Clear style from a single node.""" 

637 node_key = to_nodekey(name) 

638 self.dag.nodes[node_key][NodeAttributes.STYLE] = None 

639 

640 def clear_style(self, name: Name | Names) -> None: 

641 """Clear style on a node or nodes. 

642 

643 :param name: Node or nodes to clear styles for 

644 """ 

645 apply_n(self._clear_style_one, name) 

646 

647 def metadata(self, name: Name) -> dict[str, Any]: 

648 """Get metadata for a node.""" 

649 node_key = to_nodekey(name) 

650 if self.tree_has_path(name): 

651 if node_key not in self._metadata: 

652 self._metadata[node_key] = {} 

653 result: dict[str, Any] = self._metadata[node_key] 

654 return result 

655 else: 

656 msg = f"Node {node_key} does not exist." 

657 raise NonExistentNodeException(msg) 

658 

659 def delete_node(self, name: Name) -> None: 

660 """Delete a node from a computation. 

661 

662 When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they 

663 will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted, 

664 then the PLACEHOLDER node will also be deleted. 

665 

666 :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will 

667 be raised. 

668 """ 

669 node_key = to_nodekey(name) 

670 LOG.debug(f"Deleting node {node_key}") 

671 

672 if not self.dag.has_node(node_key): 

673 msg = f"Node {node_key} does not exist" 

674 raise NonExistentNodeException(msg) 

675 

676 if node_key in self._metadata: 

677 del self._metadata[node_key] 

678 

679 if len(self.dag.succ[node_key]) == 0: 

680 preds = self.dag.predecessors(node_key) 

681 state = self.dag.nodes[node_key][NodeAttributes.STATE] 

682 self.dag.remove_node(node_key) 

683 self._state_map[state].remove(node_key) 

684 for n in preds: 

685 if self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER: 

686 self.delete_node(n) 

687 else: 

688 self._set_state(node_key, States.PLACEHOLDER) 

689 

690 def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None) -> None: 

691 """Rename a node in a computation. 

692 

693 :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and 

694 new names as values 

695 :param new_name: New name for node. 

696 """ 

697 name_mapping: dict[Name, Name] 

698 if isinstance(old_name, Mapping) and not isinstance(old_name, str): 

699 for k, v in old_name.items(): 

700 LOG.debug(f"Renaming node {k} to {v}") 

701 if new_name is not None: 

702 msg = "new_name must not be set if rename_node is passed a dictionary" 

703 raise ValueError(msg) 

704 else: 

705 name_mapping = dict(old_name) # type: ignore[arg-type] 

706 else: 

707 LOG.debug(f"Renaming node {old_name} to {new_name}") 

708 old_node_key = to_nodekey(old_name) 

709 if not self.dag.has_node(old_node_key): 

710 msg = f"Node {old_name} does not exist" 

711 raise NonExistentNodeException(msg) 

712 assert new_name is not None # noqa: S101 

713 new_node_key = to_nodekey(new_name) 

714 if self.dag.has_node(new_node_key): 

715 msg = f"Node {new_name} already exists" 

716 raise NodeAlreadyExistsException(msg) 

717 name_mapping = {old_name: new_name} 

718 

719 node_key_mapping = {to_nodekey(on): to_nodekey(nn) for on, nn in name_mapping.items()} 

720 nx.relabel_nodes(self.dag, node_key_mapping, copy=False) 

721 

722 for old_nk, new_nk in node_key_mapping.items(): 

723 if old_nk in self._metadata: 

724 self._metadata[new_nk] = self._metadata[old_nk] 

725 del self._metadata[old_nk] 

726 else: 

727 if new_nk in self._metadata: # pragma: no cover 

728 del self._metadata[new_nk] 

729 

730 self._refresh_maps() 

731 

732 def repoint(self, old_name: Name, new_name: Name) -> None: 

733 """Changes all nodes that use old_name as an input to use new_name instead. 

734 

735 Note that if old_name is an input to new_name, then that will not be changed, to try to avoid introducing 

736 circular dependencies, but other circular dependencies will not be checked. 

737 

738 If new_name does not exist, then it will be created as a PLACEHOLDER node. 

739 

740 :param old_name: 

741 :param new_name: 

742 :return: 

743 """ 

744 old_node_key = to_nodekey(old_name) 

745 new_node_key = to_nodekey(new_name) 

746 if old_node_key == new_node_key: 

747 return 

748 

749 changed_names = list(self.dag.successors(old_node_key)) 

750 

751 if len(changed_names) > 0 and not self.dag.has_node(new_node_key): 

752 self.dag.add_node(new_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER}) 

753 self._state_map[States.PLACEHOLDER].add(new_node_key) 

754 

755 for name in changed_names: 

756 if name == new_node_key: 

757 continue 

758 edge_data = self.dag.get_edge_data(old_node_key, name) 

759 self.dag.add_edge(new_node_key, name, **edge_data) 

760 self.dag.remove_edge(old_node_key, name) 

761 

762 for n in changed_names: 

763 self.set_stale(n) 

764 

765 def insert(self, name: Name, value: Any, force: bool = False) -> None: 

766 """Insert a value into a node of a computation. 

767 

768 Following insertation, the node will have state UPTODATE, and all its descendents will be COMPUTABLE or STALE. 

769 

770 If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` 

771 will be raised. 

772 

773 :param name: Name of the node to add. 

774 :param value: The value to be inserted into the node. 

775 :param force: Whether to force recalculation of descendents if node value and state would not be changed 

776 """ 

777 node_key = to_nodekey(name) 

778 LOG.debug(f"Inserting value into node {node_key}") 

779 

780 if not self.dag.has_node(node_key): 

781 msg = f"Node {node_key} does not exist" 

782 raise NonExistentNodeException(msg) 

783 

784 state = self._state_one(name) 

785 if state == States.PLACEHOLDER: 

786 msg = "Cannot insert into placeholder node. Use add_node to create the node first" 

787 raise CannotInsertToPlaceholderNodeException(msg) 

788 

789 if not force and state == States.UPTODATE: 

790 current_value = self._value_one(name) 

791 if value_eq(value, current_value): 

792 return 

793 

794 self._set_state_and_value(node_key, States.UPTODATE, value) 

795 self._set_descendents(node_key, States.STALE) 

796 for n in self.dag.successors(node_key): 

797 self._try_set_computable(n) 

798 

def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]) -> None:
    """Insert values into many nodes of a computation simultaneously.

    Following insertion, the inserted nodes will have state UPTODATE, and their descendents will be COMPUTABLE
    or STALE. When some of the inserted nodes are descendents of others, the inserted nodes keep their UPTODATE
    state rather than being set to STALE when their ancestors are inserted.

    PINNED descendents are left untouched and are not traversed through, which matches the behaviour of
    ``insert`` (it invalidates descendents via ``_set_descendents``, which stops at PINNED nodes).

    If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` will be
    raised, and none of the nodes will be inserted.

    :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
        insert the value into.
    :type name_value_pairs: List of tuples
    :raises NonExistentNodeException: If any named node is not present in the graph.
    """
    node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
    LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")

    # Validate every target before mutating anything so a bad name leaves the computation unchanged.
    for node_key, _value in node_key_value_pairs:
        if not self.dag.has_node(node_key):
            msg = f"Node {node_key} does not exist"
            raise NonExistentNodeException(msg)

    inserted = {node_key for node_key, _value in node_key_value_pairs}
    stale: set[NodeKey] = set()
    computable: set[NodeKey] = set()
    for node_key, value in node_key_value_pairs:
        self._set_state_and_value(node_key, States.UPTODATE, value)
        # Stop at PINNED nodes: a pinned node (and anything only reachable through it) must keep its state.
        stale.update(self._get_descendents(node_key, {States.PINNED}))
        computable.update(self.dag.successors(node_key))

    # Nodes that were themselves inserted must stay UPTODATE even if an ancestor was also inserted.
    stale -= inserted
    computable -= inserted
    for node_key in stale:
        self._set_state(node_key, States.STALE)
    for node_key in computable:
        self._try_set_computable(node_key)

834 

def insert_from(self, other: "Computation", nodes: Iterable[Name] | None = None) -> None:
    """Copy node values from another Computation object into this Computation object.

    :param other: The computation object to take values from
    :type other: Computation
    :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the
        other Computation object that have corresponding nodes in this Computation object will be inserted
    :type nodes: List, default None
    """
    if nodes is None:
        # Default to every node key present in both graphs.
        shared: set[Any] = set(self.dag.nodes)
        shared.intersection_update(other.dag.nodes())
        nodes = shared
    self.insert_many([(node_name, other.value(node_name)) for node_name in nodes])

850 

def _set_state(self, node_key: NodeKey, state: States) -> None:
    """Move a node to ``state`` while leaving its stored value untouched."""
    attrs = self.dag.nodes[node_key]
    # Keep the state -> node-keys index in step with the node attribute.
    self._state_map[attrs[NodeAttributes.STATE]].remove(node_key)
    attrs[NodeAttributes.STATE] = state
    self._state_map[state].add(node_key)

858 

def _set_state_and_value(
    self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
) -> None:
    """Store a state and value on a node, passing the value through the node's converter if it has one.

    If converting or storing the value raises, the node is put into the error state with the exception and
    its traceback; the exception is re-raised only when ``throw_conversion_exception`` is true.
    """
    converter = self.dag.nodes[node_key].get(NodeAttributes.CONVERTER)
    if converter is None:
        self._set_state_and_literal_value(node_key, state, value)
        return
    try:
        self._set_state_and_literal_value(node_key, state, converter(value))
    except Exception as exc:
        self._set_error(node_key, exc, traceback.format_exc())
        if throw_conversion_exception:
            raise

876 

def _set_state_and_literal_value(
    self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
) -> None:
    """Store a state and value on a node exactly as given, bypassing any converter.

    A ``KeyError`` while de-registering the node's previous state is re-raised only when
    ``require_old_state`` is true; otherwise it is ignored.
    """
    attrs = self.dag.nodes[node_key]
    try:
        self._state_map[attrs[NodeAttributes.STATE]].remove(node_key)
    except KeyError:
        if require_old_state:
            raise  # pragma: no cover
    attrs[NodeAttributes.STATE] = state
    attrs[NodeAttributes.VALUE] = value
    self._state_map[state].add(node_key)

891 

def _set_states(self, node_keys: Iterable[NodeKey], state: States) -> None:
    """Set the state of multiple nodes at once, leaving their values untouched.

    The state index is updated node by node, so ``node_keys`` may be any iterable (including a one-shot
    iterator) and may contain the same key more than once.

    :param node_keys: Keys of the nodes to update.
    :param state: State to assign to every listed node.
    """
    for node_key in node_keys:
        attrs = self.dag.nodes[node_key]
        self._state_map[attrs[NodeAttributes.STATE]].remove(node_key)
        attrs[NodeAttributes.STATE] = state
        # Register immediately so the index never lags behind the node attribute.
        self._state_map[state].add(node_key)

900 

def set_stale(self, name: Name) -> None:
    """Set the state of a node and all its descendents to STALE.

    Afterwards the node itself is re-examined and may become COMPUTABLE (see ``_try_set_computable``).

    :param name: Name of the node to set as STALE.
    """
    node_key = to_nodekey(name)
    affected: list[NodeKey] = [node_key, *nx.dag.descendants(self.dag, node_key)]
    self._set_states(affected, States.STALE)
    self._try_set_computable(node_key)

911 

def pin(self, name: Name, value: Any = None) -> None:
    """Mark a node as PINNED, optionally inserting a value into it first.

    :param name: Name of the node to set as PINNED.
    :param value: Value to pin to the node, if provided. ``None`` means "keep the current value".
    :type value: default None
    """
    pinned_key = to_nodekey(name)
    if value is not None:
        self.insert(pinned_key, value)
    self._set_states([pinned_key], States.PINNED)

923 

def unpin(self, name: Name) -> None:
    """Unpin a node (state of node and all descendents will be set to STALE).

    :param name: Name of the node to unpin.
    """
    self.set_stale(to_nodekey(name))

931 

def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
    """Collect the nodes reachable downstream of ``node_key``, excluding ``node_key`` itself.

    Nodes whose state is in ``stop_states`` are neither returned nor traversed through. If ``node_key``
    itself is in one of those states the result is empty.
    """
    blocked = set() if stop_states is None else stop_states
    node_attrs = self.dag.nodes
    if node_attrs[node_key][NodeAttributes.STATE] in blocked:
        return set()

    reached: set[NodeKey] = set()
    frontier = [node_key]
    while frontier:
        current = frontier.pop()
        if current in reached:
            continue
        reached.add(current)
        for child in self.dag.successors(current):
            if child not in reached and node_attrs[child][NodeAttributes.STATE] not in blocked:
                frontier.append(child)
    # The start node seeds the walk but is not one of its own descendents.
    reached.remove(node_key)
    return reached

951 

def _set_descendents(self, node_key: NodeKey, state: States) -> None:
    """Assign ``state`` to every descendent of a node, stopping at PINNED nodes."""
    self._set_states(self._get_descendents(node_key, {States.PINNED}), state)

956 

def _set_uninitialized(self, node_key: NodeKey) -> None:
    """Reset a node to UNINITIALIZED and discard any value stored on it."""
    self._set_states([node_key], States.UNINITIALIZED)
    attrs = self.dag.nodes[node_key]
    attrs.pop(NodeAttributes.VALUE, None)

961 

def _set_uptodate(self, node_key: NodeKey, value: object) -> None:
    """Store ``value`` on a node as UPTODATE and invalidate everything downstream of it."""
    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    # Direct successors may now have all of their inputs available.
    for successor in self.dag.successors(node_key):
        self._try_set_computable(successor)

968 

def _set_error(self, node_key: NodeKey, exc: Exception, tb: str) -> None:
    """Put a node into the ERROR state, storing the exception and traceback as its value."""
    failure = Error(exc, tb)
    self._set_state_and_literal_value(node_key, States.ERROR, failure)
    self._set_descendents(node_key, States.STALE)

973 

def _try_set_computable(self, node_key: NodeKey) -> None:
    """Mark a node COMPUTABLE when it has a function and every predecessor is UPTODATE.

    PINNED nodes and nodes without a function are left as they are.
    """
    attrs = self.dag.nodes[node_key]
    if attrs[NodeAttributes.STATE] == States.PINNED:
        return
    if attrs.get(NodeAttributes.FUNC) is None:
        return
    for upstream in self.dag.predecessors(node_key):
        if not self.dag.has_node(upstream):
            return  # pragma: no cover
        if self.dag.nodes[upstream][NodeAttributes.STATE] != States.UPTODATE:
            return
    self._set_state(node_key, States.COMPUTABLE)

985 

def _get_parameter_data(self, node_key: NodeKey) -> Iterable[_ParameterItem]:
    """Yield every parameter for a node's function call.

    Literal positional arguments come first, then literal keyword arguments, then one item per
    predecessor node carrying that node's current value.
    """
    attrs = self.dag.nodes[node_key]
    for position, literal in attrs[NodeAttributes.ARGS].items():
        yield _ParameterItem(_ParameterType.ARG, position, literal)
    for keyword, literal in attrs[NodeAttributes.KWDS].items():
        yield _ParameterItem(_ParameterType.KWD, keyword, literal)
    for source_key in self.dag.predecessors(node_key):
        source_value = self.dag.nodes[source_key][NodeAttributes.VALUE]
        # The edge records which parameter of the function this predecessor feeds.
        param_type, param_name = self.dag[source_key][node_key][EdgeAttributes.PARAM]
        yield _ParameterItem(param_type, param_name, source_value)

997 

def _get_func_args_kwds(
    self, node_key: NodeKey
) -> tuple[Callable[..., Any], str | None, list[Any], dict[str, Any]]:
    """Assemble the function, executor name, positional args and keyword args for a node.

    Positional slots that no parameter fills are left as ``None``.
    """
    attrs = self.dag.nodes[node_key]
    func = attrs[NodeAttributes.FUNC]
    executor_name = attrs.get(NodeAttributes.EXECUTOR)
    positional: list[Any] = []
    keyword: dict[str, Any] = {}
    for item in self._get_parameter_data(node_key):
        if item.type == _ParameterType.KWD:
            assert isinstance(item.name, str)  # noqa: S101
            keyword[item.name] = item.value
        elif item.type == _ParameterType.ARG:
            slot = item.name
            assert isinstance(slot, int)  # noqa: S101
            # Grow the list so that ``slot`` is a valid index.
            shortfall = slot + 1 - len(positional)
            if shortfall > 0:
                positional.extend([None] * shortfall)
            positional[slot] = item.value
        else:  # pragma: no cover
            msg = f"Unexpected param type: {item.type}"
            raise ValidationError(msg)
    return func, executor_name, positional, keyword

1021 

def get_definition_args_kwds(self, name: Name) -> tuple[list[Any], dict[str, Any]]:
    """Reconstruct the positional and keyword arguments a node was defined with.

    Literal arguments are wrapped with ``C``; arguments supplied by other nodes are represented by the
    name of the supplying node. Unfilled positional slots are ``None``.
    """
    positional: list[Any] = []
    keyword: dict[str, Any] = {}

    def put_positional(slot: int, item: Any) -> None:
        """Store ``item`` at ``slot``, padding the list with ``None`` as needed."""
        shortfall = slot + 1 - len(positional)
        if shortfall > 0:
            positional.extend([None] * shortfall)
        positional[slot] = item

    node_key = to_nodekey(name)
    attrs = self.dag.nodes[node_key]
    if NodeAttributes.ARGS in attrs:
        for slot, literal in attrs[NodeAttributes.ARGS].items():
            put_positional(slot, C(literal))
    if NodeAttributes.KWDS in attrs:
        for keyword_name, literal in attrs[NodeAttributes.KWDS].items():
            keyword[keyword_name] = C(literal)
    for source_key in self.dag.predecessors(node_key):
        edge = self.dag[source_key][node_key]
        if EdgeAttributes.PARAM not in edge:
            continue
        param_type, param_name = edge[EdgeAttributes.PARAM]
        if param_type == _ParameterType.ARG:
            assert isinstance(param_name, int)  # noqa: S101
            put_positional(param_name, source_key.name)
        elif param_type == _ParameterType.KWD:
            keyword[param_name] = source_key.name
        else:  # pragma: no cover
            msg = f"Unexpected param type: {param_type}"
            raise ValidationError(msg)
    return positional, keyword

1052 

def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False) -> None:
    """Compute the given nodes, submitting each to its executor once it is COMPUTABLE.

    Nodes in ``node_keys`` that are already COMPUTABLE are submitted immediately. As each future completes,
    its result is stored on the node, the node's descendents are marked STALE, and any successor that
    thereby becomes COMPUTABLE and belongs to ``node_keys`` is submitted in turn.

    :param node_keys: Keys of the nodes that may be computed in this pass.
    :param raise_exceptions: Forwarded to ``_eval_node``. An exception surfacing from a future is recorded
        on the node as an error and then re-raised to the caller.
    :raises LoopDetectedException: If a successor of a freshly computed node was already computed during
        this pass.
    """
    LOG.debug(f"Computing nodes {node_keys}")

    futs: dict[Any, NodeKey] = {}
    node_keys_set = set(node_keys)

    def run(name: NodeKey) -> None:
        """Submit a node computation to an executor."""
        f, executor_name, args, kwds = self._get_func_args_kwds(name)
        executor = self.default_executor if executor_name is None else self.executor_map[executor_name]
        fut = executor.submit(_eval_node, name, f, args, kwds, raise_exceptions)
        futs[fut] = name

    computed: set[NodeKey] = set()

    # Seed the pool with everything that can run straight away.
    for node_key in node_keys_set:
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.COMPUTABLE:
            run(node_key)

    while len(futs) > 0:
        done, _not_done = wait(futs.keys(), return_when=FIRST_COMPLETED)
        for fut in done:
            node_key = futs.pop(fut)
            node0 = self.dag.nodes[node_key]
            try:
                value, exc, tb, start_dt, end_dt = fut.result()
            except Exception as e:
                exc = e
                tb = traceback.format_exc()
                self._set_error(node_key, exc, tb)
                raise
            delta = (end_dt - start_dt).total_seconds()
            if exc is None:
                # A converter failure here is recorded on the node rather than raised.
                self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
                node0[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
                self._set_descendents(node_key, States.STALE)
                for n in self.dag.successors(node_key):
                    # Use the module logger (not the root logger) and defer string formatting.
                    LOG.debug("%s %s %s", node_key, n, computed)
                    if n in computed:
                        msg = f"Calculating {node_key} for the second time"
                        raise LoopDetectedException(msg)
                    self._try_set_computable(n)
                    successor_state = self.dag.nodes[n][NodeAttributes.STATE]
                    if successor_state == States.COMPUTABLE and n in node_keys_set:
                        run(n)
            else:
                assert tb is not None  # noqa: S101
                self._set_error(node_key, exc, tb)
            computed.add(node_key)

1106 

def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """List, in topological order, the nodes that must be computed to obtain ``node_key``.

    Ancestors that are already UPTODATE or PINNED are dropped from a scratch copy of the graph, which also
    cuts off anything only reachable through them. The target node itself is always included.

    :raises ValidationError: If a remaining ancestor is an UNINITIALIZED node without predecessors, or a
        PLACEHOLDER node.
    """
    scratch = nx.DiGraph()
    scratch.add_nodes_from(self.dag.nodes)
    scratch.add_edges_from(self.dag.edges)

    settled = (States.UPTODATE, States.PINNED)
    already_available = [
        n for n in nx.ancestors(scratch, node_key) if self.dag.nodes[n][NodeAttributes.STATE] in settled
    ]
    scratch.remove_nodes_from(already_available)

    required = nx.ancestors(scratch, node_key)
    for n in required:
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0:
            msg = f"Cannot compute {node_key} because {n} uninitialized"
            raise ValidationError(msg)
        if state == States.PLACEHOLDER:
            msg = f"Cannot compute {node_key} because {n} is placeholder"
            raise ValidationError(msg)

    required.add(node_key)
    return [n for n in topological_sort(scratch) if n in required]

1133 

def _get_calc_node_names(self, name: Name) -> Names:
    """Return the names of the nodes that must be computed to obtain the named node."""
    required_keys = self._get_calc_node_keys(to_nodekey(name))
    return node_keys_to_names(required_keys)

1138 

def compute(self, name: Name | Iterable[Name], raise_exceptions: bool = False) -> None:
    """Compute a node and all necessary predecessors.

    Following the computation, if successful, the target node, and all necessary ancestors that were not already
    UPTODATE will have been calculated and set to UPTODATE. Any node that did not need to be calculated will not
    have been recalculated.

    If any node raises an exception, then the state of that node will be set to ERROR, and its value set to an
    object containing the exception object, as well as a traceback. This will not halt the computation, which
    will proceed as far as it can, until no more nodes that would be required to calculate the target are
    COMPUTABLE.

    :param name: Name of the node to compute. A list or generator is treated as several targets.
    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    calc_nodes: set[NodeKey] | list[NodeKey]
    if isinstance(name, (types.GeneratorType, list)):
        # Several targets: take the union of what each of them needs.
        calc_nodes = set()
        for target in name:
            calc_nodes.update(self._get_calc_node_keys(to_nodekey(target)))
    else:
        calc_nodes = self._get_calc_node_keys(to_nodekey(name))
    self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)

1166 

def compute_all(self, raise_exceptions: bool = False) -> None:
    """Compute every node of the computation that can be computed.

    Nodes that are already UPTODATE will not be recalculated. Following the computation, if successful, all
    nodes will have state UPTODATE, except UNINITIALIZED input nodes and PLACEHOLDER nodes.

    If any node raises an exception, then the state of that node will be set to ERROR, and its value set to an
    object containing the exception object, as well as a traceback. This will not halt the computation, which
    will proceed as far as it can, until no more nodes are COMPUTABLE.

    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    every_node_key = self._node_keys()
    self._compute_nodes(every_node_key, raise_exceptions=raise_exceptions)

1181 

def _node_keys(self) -> list[NodeKey]:
    """Return the keys of all nodes in this computation.

    :return: List of node keys, in the graph's iteration order.
    """
    return [*self.dag.nodes]

1188 

def nodes(self) -> list[Name]:
    """Return the names of all nodes in this computation.

    :return: List of node names, in the graph's iteration order.
    """
    return [node_key.name for node_key in self.dag.nodes]

1195 

def get_tree_list_children(self, name: Name) -> set[Name]:
    """Get the path components found directly beneath ``name`` in the node hierarchy.

    For every graph node that is a descendent of ``name``, the component of its path immediately after
    the ``name`` prefix is collected.

    :param name: Name of the parent path.
    :return: Set of child path components.
    """
    parent_key = to_nodekey(name)
    depth = len(parent_key.parts)
    return {n.parts[depth] for n in self.dag.nodes if n.is_descendent_of(parent_key)}

1208 

def has_node(self, name: Name) -> bool:
    """Report whether the computation contains a node with the given name."""
    return to_nodekey(name) in self.dag.nodes

1213 

def tree_has_path(self, name: Name) -> bool:
    """Report whether a hierarchical path exists in the computation tree.

    The root path always exists; any other path exists if it is a node itself or if some node is a
    descendent of it.
    """
    path_key = to_nodekey(name)
    if path_key.is_root or self.has_node(path_key):
        return True
    for candidate in self.dag.nodes:
        if candidate.is_descendent_of(path_key):
            return True
    return False

1222 

def get_tree_descendents(
    self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
) -> set[Name]:
    """Get the set of descendent blocks and nodes.

    Returns blocks and nodes that are descendents of the input node,
    e.g. for node 'foo', might return ['foo/bar', 'foo/baz'].

    :param name: Name of node to get descendents for; ``None`` means the root
    :param include_stem: If false, the leading ``name`` components are stripped from each result
    :param graph_nodes_only: If true, only actual graph nodes are reported; otherwise each node's
        ``ancestors()`` (presumably its intermediate path prefixes - confirm in ``NodeKey``) are considered too
    :return: Set of descendent node names
    """
    stem_key = NodeKey.root() if name is None else to_nodekey(name)
    stem_length = len(stem_key.parts)
    found = set()
    for graph_key in self.dag.nodes:
        if not graph_key.is_descendent_of(stem_key):
            continue
        candidates = [graph_key] if graph_nodes_only else graph_key.ancestors()
        for candidate in candidates:
            if not candidate.is_descendent_of(stem_key):
                continue
            if include_stem:
                found.add(candidate.name)
            else:
                found.add(NodeKey(tuple(candidate.parts[stem_length:])).name)
    return found

1245 

def _state_one(self, name: Name) -> States:
    """Look up the state of one node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    current: States = attrs[NodeAttributes.STATE]
    return current

1251 

@overload
def state(self, name: Name) -> States: ...

@overload
def state(self, name: Names) -> list[States]: ...

def state(self, name: Name | Names) -> States | list[States]:
    """Get the state of a node, or of each node in a list of names.

    This can also be accessed using the attribute-style accessor ``s`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.state('foo')
        <States.UPTODATE: 4>
        >>> comp.s.foo
        <States.UPTODATE: 4>

    :param name: Name or names of the node to get state for
    :type name: Name or Names
    """
    lookup = self._state_one
    return apply1(lookup, name)

1275 

def _value_one(self, name: Name) -> Any:
    """Look up the stored value of one node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.VALUE]

1280 

@overload
def value(self, name: Name) -> Any: ...

@overload
def value(self, name: Names) -> list[Any]: ...

def value(self, name: Name | Names) -> Any | list[Any]:
    """Get the current value of a node, or of each node in a list of names.

    This can also be accessed using the attribute-style accessor ``v`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.value('foo')
        1
        >>> comp.v.foo
        1

    :param name: Name or names of the node to get the value of
    :type name: Name or Names
    """
    lookup = self._value_one
    return apply1(lookup, name)

1304 

def compute_and_get_value(self, name: Name) -> Any:
    """Return the value of a node, computing it first if it is not already UPTODATE.

    This can also be accessed using the attribute-style accessor ``x`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

    :param name: Name of the node to get the value of
    :type name: Name
    :raises ComputationError: If the node is still not UPTODATE after computing.
    """
    nk = to_nodekey(name)
    if self.state(nk) != States.UPTODATE:
        self.compute(nk, raise_exceptions=True)
        if self.state(nk) != States.UPTODATE:
            msg = f"Unable to compute node {nk}"
            raise ComputationError(msg)
    return self.value(nk)

1330 

def _tag_one(self, name: Name) -> set[str]:
    """Look up the tag set of one node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    node_tags: set[str] = attrs[NodeAttributes.TAG]
    return node_tags

1337 

@overload
def tags(self, name: Name) -> set[str]: ...

@overload
def tags(self, name: Names) -> list[set[str]]: ...

def tags(self, name: Name | Names) -> set[str] | list[set[str]]:
    """Get the tags associated with a node, or with each node in a list of names.

    >>> comp = Computation()
    >>> comp.add_node('a', tags=['foo', 'bar'])
    >>> sorted(comp.t.a)
    ['__serialize__', 'bar', 'foo']

    :param name: Name or names of the node to get the tags of
    :return: The tag set, or a list of tag sets when several names are given
    """
    lookup = self._tag_one
    return apply1(lookup, name)

1356 

def nodes_by_tag(self, tag: str | Iterable[str]) -> set[Name]:
    """Get the names of nodes carrying a particular tag, or any of several tags.

    :param tag: Tag or tags for which to retrieve nodes
    :return: Names of the nodes with those tags
    """
    wanted: Iterable[str] = [tag] if isinstance(tag, str) else tag
    matched: set[NodeKey] = set()
    for one_tag in wanted:
        tagged = self._tag_map.get(one_tag)
        if tagged is None:
            # Unknown tags simply contribute nothing.
            continue
        matched.update(tagged)
    return {node_key.name for node_key in matched}

1370 

def _style_one(self, name: Name) -> str | None:
    """Look up the style of one node, or ``None`` if it has no style attribute."""
    attrs = self.dag.nodes[to_nodekey(name)]
    node_style: str | None = attrs.get(NodeAttributes.STYLE)
    return node_style

1377 

@overload
def styles(self, name: Name) -> str | None: ...

@overload
def styles(self, name: Names) -> list[str | None]: ...

def styles(self, name: Name | Names) -> str | None | list[str | None]:
    """Get the style associated with a node, or with each node in a list of names.

    >>> comp = Computation()
    >>> comp.add_node('a', style='dot')
    >>> comp.style.a
    'dot'

    :param name: Name or names of the node to get the style of
    :return: The style, or a list of styles when several names are given
    """
    lookup = self._style_one
    return apply1(lookup, name)

1396 

def _get_item_one(self, name: Name) -> NodeData:
    """Bundle the state and value of one node into a ``NodeData``."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return NodeData(attrs[NodeAttributes.STATE], attrs[NodeAttributes.VALUE])

1402 

@overload
def __getitem__(self, name: Name) -> NodeData: ...

@overload
def __getitem__(self, name: Names) -> list[NodeData]: ...

def __getitem__(self, name: Name | Names) -> NodeData | list[NodeData]:
    """Get the state and current value of a node, or of each node in a list of names.

    :param name: Name or names of the node to get the state and value of
    """
    lookup = self._get_item_one
    return apply1(lookup, name)

1415 

def _get_timing_one(self, name: Name) -> TimingData | None:
    """Look up the timing data of one node, or ``None`` if none has been recorded."""
    attrs = self.dag.nodes[to_nodekey(name)]
    recorded: TimingData | None = attrs.get(NodeAttributes.TIMING, None)
    return recorded

1422 

@overload
def get_timing(self, name: Name) -> TimingData | None: ...

@overload
def get_timing(self, name: Names) -> list[TimingData | None]: ...

def get_timing(self, name: Name | Names) -> TimingData | None | list[TimingData | None]:
    """Get the timing information for a node, or for each node in a list of names.

    :param name: Name or names of the node to get the timing information of
    :return: The timing data (``None`` where nothing is recorded), or a list of them for several names
    """
    lookup = self._get_timing_one
    return apply1(lookup, name)

1436 

def to_df(self) -> pd.DataFrame:
    """Get a dataframe containing the states and value of all nodes of computation.

    Rows follow the topological order of the graph and are labelled by node name. Timing information,
    where recorded, is merged in as additional columns.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_df() # doctest: +NORMALIZE_WHITESPACE
                       state  value
        foo  States.UPTODATE      1
        bar  States.UPTODATE      2
    """
    df = pd.DataFrame(index=topological_sort(self.dag))
    df[NodeAttributes.STATE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.STATE))
    df[NodeAttributes.VALUE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.VALUE))
    # Read timing under the same attribute key that the rest of the class stores it under.
    timing_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TIMING)
    df_timing = pd.DataFrame.from_dict(timing_by_node, orient="index")
    # Left join keeps every node, including those that have never been timed.
    df = pd.merge(df, df_timing, left_index=True, right_index=True, how="left")
    df.index = pd.Index([nk.name for nk in df.index])
    return df

1458 

def to_dict(self) -> dict[NodeKey, Any]:
    """Get a dictionary mapping node keys to the values of all nodes that hold a value.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_dict() # doctest: +ELLIPSIS
        {NodeKey('foo'): 1, NodeKey('bar'): 2}
    """
    values_by_key: dict[NodeKey, Any] = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    return values_by_key

1473 

def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey | None]:
    """List the predecessor keys feeding a node: positional inputs by index, then keyword inputs.

    Positional slots not supplied by any predecessor appear as ``None``.
    """
    by_position: dict[int, NodeKey] = {}
    keyword_inputs: list[NodeKey | None] = []
    for source_key in self.dag.predecessors(node_key):
        kind, param = self.dag[source_key][node_key][EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            by_position[param] = source_key
        elif kind == _ParameterType.KWD:
            keyword_inputs.append(source_key)

    highest_index = max(by_position, default=-1)
    if highest_index < 0:
        return keyword_inputs
    positional_inputs: list[NodeKey | None] = [None] * (highest_index + 1)
    for position, source_key in by_position.items():
        positional_inputs[position] = source_key
    return positional_inputs + keyword_inputs

1496 

def _get_inputs_one_names(self, name: Name) -> Names:
    """Return the names of the nodes feeding one node, skipping unfilled positional slots."""
    input_keys = self._get_inputs_one_node_keys(to_nodekey(name))
    present_keys = [input_key for input_key in input_keys if input_key is not None]
    return node_keys_to_names(present_keys)

1501 

@overload
def get_inputs(self, name: Name) -> Names: ...

@overload
def get_inputs(self, name: Names) -> list[Names]: ...

def get_inputs(self, name: Name | Names) -> Names | list[Names]:
    """Get the direct inputs of a node or of each node in a list.

    :param name: Name or names of nodes to get inputs for
    :return: If name is scalar, return a list of upstream nodes used as input. If name is a list, return a
        list of list of inputs.
    """
    lookup = self._get_inputs_one_names
    return apply1(lookup, name)

1516 

def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
    """Collect the union of the ancestors of the given nodes, optionally with the nodes themselves."""
    collected: set[NodeKey] = set()
    for start_key in node_keys:
        if include_self:
            collected.add(start_key)
        collected.update(nx.ancestors(self.dag, start_key))
    return collected

1526 

def get_ancestors(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get the names of all ancestors of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes themselves are part of the result
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_ancestors_node_keys(start_keys, include_self))

1532 

def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> list[NodeKey]:
    """Return the keys of nodes that have no computation function.

    With ``node_keys`` given, only those nodes and their ancestors are considered; with ``None``, every
    node in the computation is considered.
    """
    candidates: Iterable[NodeKey]
    if node_keys is None:
        candidates = self._node_keys()
    else:
        candidates = self._get_ancestors_node_keys(node_keys)
    return [n for n in candidates if self.dag.nodes[n].get(NodeAttributes.FUNC) is None]

1538 

def get_original_inputs(self, names: Name | Names | None = None) -> Names:
    """Get the original non-computed inputs for a node or set of nodes.

    :param names: Name or names of nodes to get inputs for; ``None`` considers the whole computation
    :return: Return a list of original non-computed inputs that are ancestors of the input nodes
    """
    if names is None:
        start_keys = None
    else:
        start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_original_inputs_node_keys(start_keys))

1550 

def _get_outputs_one(self, name: Name) -> Names:
    """Return the names of the direct successors of one node."""
    successors = list(self.dag.successors(to_nodekey(name)))
    return node_keys_to_names(successors)

1556 

@overload
def get_outputs(self, name: Name) -> Names: ...

@overload
def get_outputs(self, name: Names) -> list[Names]: ...

def get_outputs(self, name: Name | Names) -> Names | list[Names]:
    """Get the direct outputs of a node or of each node in a list.

    :param name: Name or names of nodes to get outputs for
    :return: If name is scalar, return a list of downstream nodes used as output. If name is a list, return a
        list of list of outputs.

    """
    lookup = self._get_outputs_one
    return apply1(lookup, name)

1572 

def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
    """Collect the union of the descendants of the given nodes, optionally with the nodes themselves."""
    collected: set[NodeKey] = set()
    for start_key in node_keys:
        if include_self:
            collected.add(start_key)
        collected.update(nx.descendants(self.dag, start_key))
    return collected

1582 

def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get the names of all descendents of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes themselves are part of the result
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_descendents_node_keys(start_keys, include_self))

1588 

def get_final_outputs(self, names: Name | Names | None = None) -> Names:
    """Get final output nodes (nodes with no descendants) from the specified nodes.

    :param names: Name or names of nodes to start from. Their descendants (and the nodes themselves) are
        searched for final outputs. ``None`` searches the whole computation.
    :return: Names of the nodes that nothing else depends on
    """
    candidate_keys: Iterable[NodeKey]
    if names is None:
        candidate_keys = self._node_keys()
    else:
        candidate_keys = self._get_descendents_node_keys(names_to_node_keys(names))
    # In an acyclic graph a node has no descendants exactly when it has no outgoing edges, so the
    # out-degree check avoids building a full descendant set for every candidate.
    output_node_keys = [n for n in candidate_keys if self.dag.out_degree(n) == 0]
    return node_keys_to_names(output_node_keys)

1599 

def get_source(self, name: Name) -> str:
    """Get the source code for a node's function, prefixed by its file and line number.

    Nodes without a function yield the fixed text ``"NOT A CALCULATED NODE"``.
    """
    func = self.dag.nodes[to_nodekey(name)].get(NodeAttributes.FUNC, None)
    if func is None:
        return "NOT A CALCULATED NODE"
    file = inspect.getsourcefile(func)
    _, lineno = inspect.getsourcelines(func)
    source = inspect.getsource(func)
    return f"{file}:{lineno}\n\n{source}"

1611 

def print_source(self, name: Name) -> None:
    """Write the text returned by :meth:`get_source` for ``name`` to standard output."""
    source_text = self.get_source(name)
    print(source_text)

1615 

def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None) -> None:
    """Restrict a computation to the ancestors of a set of output nodes.

    Excludes ancestors of a set of input nodes.

    If the set of input names that is specified is not sufficient for the set of output names then
    additional nodes that are ancestors of the output nodes will be included, but the input nodes
    specified will be input nodes of the modified Computation.

    :param output_names: Name(s) of the nodes whose ancestors are kept.
    :param input_names: Optional name(s) of nodes that are re-added with ``add_node`` and then given
        back the state and value they held beforehand.
    :return: None - modifies existing computation in place
    """
    if input_names is not None:
        for input_name in as_iterable(input_names):
            # Capture the node's data first; it is written back after the node is re-added.
            previous = self._get_item_one(input_name)
            self.add_node(input_name)
            self._set_state_and_literal_value(to_nodekey(input_name), previous.state, previous.value)
    kept_node_keys = self._get_ancestors_node_keys(names_to_node_keys(output_names))
    discarded_node_keys = [node_key for node_key in self.dag if node_key not in kept_node_keys]
    self.dag.remove_nodes_from(discarded_node_keys)

1637 

def __getstate__(self) -> dict[str, Any]:
    """Build the state used for serialization.

    Works on a copy of the computation: every node carrying a tag attribute that does not include
    ``SystemTags.SERIALIZE`` is set to uninitialized in the copy.

    :return: A dict with the single key ``"dag"`` holding the copy's graph.
    """
    tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
    clone = self.copy()
    for node_key, tags in tags_by_node.items():
        if SystemTags.SERIALIZE in tags:
            continue
        clone._set_uninitialized(node_key)
    return {"dag": clone.dag}

1646 

def __setstate__(self, state: dict[str, Any]) -> None:
    """Rebuild the computation from a state mapping that has a ``"dag"`` entry.

    :param state: Mapping as produced by ``__getstate__``.
    """
    # Re-run the initializer so the instance starts from its default attributes.
    self.__init__()
    restored_dag = state["dag"]
    self.dag = restored_dag
    self._refresh_maps()

1652 

def write_dill_old(self, file_: str | BinaryIO) -> None:
    """Dump a stripped copy of the computation with dill (deprecated).

    The class-level ``__getstate__``/``__setstate__`` hooks are removed while the dump runs and are
    assigned back afterwards, whether or not the dump succeeds.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
    cls = self.__class__
    saved_getstate = cls.__getstate__
    saved_setstate = cls.__setstate__

    try:
        del cls.__getstate__
        del cls.__setstate__

        tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
        stripped = self.copy()
        stripped.executor_map = None  # type: ignore[assignment]
        stripped.default_executor = None  # type: ignore[assignment]
        for node_key, tags in tags_by_node.items():
            if SystemTags.SERIALIZE not in tags:
                stripped._set_uninitialized(node_key)

        if not isinstance(file_, str):
            dill.dump(stripped, file_)
        else:
            with open(file_, "wb") as f:
                dill.dump(stripped, f)
    finally:
        cls.__getstate__ = saved_getstate  # type: ignore[method-assign]
        cls.__setstate__ = saved_setstate

1683 

def write_dill(self, file_: str | BinaryIO) -> None:
    """Serialize this computation with dill to a path or a binary file-like object.

    .. deprecated::
        Use :meth:`write_json` instead. dill-based serialization will be
        removed in a future release.

    :param file_: If string, it is opened as a file path in binary write mode; otherwise it is
        used directly as the destination.
    :type file_: File-like object, or string
    """
    warnings.warn(
        "write_dill is deprecated and will be removed in a future release. Use write_json instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if not isinstance(file_, str):
        dill.dump(self, file_)
        return
    with open(file_, "wb") as f:
        dill.dump(self, f)

1704 

@staticmethod
def read_dill(file_: str | BinaryIO) -> "Computation":
    """Deserialize a computation with dill from a path or a binary file-like object.

    .. deprecated::
        Use :meth:`read_json` instead. dill-based serialization will be
        removed in a future release.

    .. warning::
        This method uses dill.load() which can execute arbitrary code.
        Only load files from trusted sources. Never load data from
        untrusted or unauthenticated sources as it may lead to arbitrary
        code execution.

    :param file_: If string, it is opened as a file path in binary read mode; otherwise it is
        read from directly.
    :type file_: File-like object, or string
    :raises ValidationError: If the loaded object is not a ``Computation``.
    """
    warnings.warn(
        "read_dill is deprecated and will be removed in a future release. Use read_json instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if isinstance(file_, str):
        with open(file_, "rb") as f:
            loaded = dill.load(f)  # noqa: S301  # nosec B301
    else:
        loaded = dill.load(file_)  # noqa: S301  # nosec B301
    if not isinstance(loaded, Computation):
        msg = "Loaded object is not a Computation"
        raise ValidationError(msg)
    return loaded

1737 

def write_json(self, file_: str | TextIO, *, serializer: "ComputationSerializer | None" = None) -> None:
    """Write this computation as JSON to a path or a text-mode file-like object.

    To support custom types, pass your own *serializer*: either a
    :class:`~loman.serialization.computation.ComputationSerializer` instance with extra
    transformers registered, or a subclass that overrides the transformer factory.

    :param file_: Destination file path (str) or text-mode file-like object.
    :param serializer: Optional custom serializer. If ``None`` the default
        :class:`~loman.serialization.computation.ComputationSerializer` is used.
    """
    from .serialization.computation import ComputationSerializer

    active_serializer = ComputationSerializer() if serializer is None else serializer
    if not isinstance(file_, str):
        active_serializer.dump(self, file_)
        return
    with open(file_, "w") as f:
        active_serializer.dump(self, f)

1758 

@staticmethod
def read_json(file_: str | TextIO, *, serializer: "ComputationSerializer | None" = None) -> "Computation":
    """Load a computation from JSON held in a path or a text-mode file-like object.

    :param file_: Source file path (str) or text-mode file-like object.
    :param serializer: Optional custom serializer. If ``None`` the default
        :class:`~loman.serialization.computation.ComputationSerializer` is used.
    :return: Whatever the serializer's ``load`` returns for the source.
    :rtype: Computation
    """
    from .serialization.computation import ComputationSerializer

    active_serializer = ComputationSerializer() if serializer is None else serializer
    if not isinstance(file_, str):
        return active_serializer.load(file_)
    with open(file_) as f:
        return active_serializer.load(f)

1776 

def copy(self) -> "Computation":
    """Create a copy of a computation.

    The copy is shallow: values held in the new Computation's DAG are the same objects as in this
    Computation's DAG. Since further computations create new objects, this should not be an issue.
    The tag map and state map are rebuilt with copies of their node collections.

    :rtype: Computation
    """
    clone = Computation()
    clone.dag = nx.DiGraph(self.dag)
    copied_tag_sets = {tag: members.copy() for tag, members in self._tag_map.items()}
    clone._tag_map = defaultdict(set, copied_tag_sets)
    clone._state_map = {state: members.copy() for state, members in self._state_map.items()}
    return clone

1790 

def add_named_tuple_expansion(self, name: Name, namedtuple_type: type, group: str | None = None) -> None:
    """Automatically add nodes to extract each element of a named tuple type.

    A calculation often needs to return several values. Returning them as a namedtuple rather than
    a plain tuple gives later users a name for each element, and letting a downstream computation
    depend on a single element instead of the whole tuple makes its intention clearer without
    affecting the computation itself.

    This method saves writing one boiler-plate node definition per element: for an element called
    'element' in a node called 'node', a new node called 'node.element' is created, and this is
    done for every field of the namedtuple type. Each created node is tagged with
    ``SystemTags.EXPANSION``.

    Example::

    >>> from collections import namedtuple
    >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
    >>> comp = Computation()
    >>> comp.add_node('c', value=Coordinate(1, 2))
    >>> comp.add_named_tuple_expansion('c', Coordinate)
    >>> comp.compute_all()
    >>> comp.value('c.x')
    1
    >>> comp.value('c.y')
    2

    :param name: Node holding the named tuple to expand
    :param namedtuple_type: Expected type of the node
    :type namedtuple_type: namedtuple class
    :param group: Passed as ``group`` to ``add_node`` for every created node
    """

    def make_f(field_name: str) -> Callable[[Any], Any]:
        """Build an accessor bound to one field name."""

        def get_field_value(tuple_val: Any) -> Any:
            """Read the bound field from the given tuple value."""
            return getattr(tuple_val, field_name)

        return get_field_value

    for element in namedtuple_type._fields:  # type: ignore[attr-defined]
        expanded_name = f"{name}.{element}"
        accessor = make_f(element)
        # The keyword must match the accessor's parameter name so the tuple node feeds it.
        self.add_node(expanded_name, accessor, kwds={"tuple_val": name}, group=group)
        self.set_tag(expanded_name, SystemTags.EXPANSION)

1836 

def add_map_node(
    self,
    result_node: Name,
    input_node: Name,
    subgraph: "Computation",
    subgraph_input_node: Name,
    subgraph_output_node: Name,
) -> None:
    """Apply a graph to each element of an iterable.

    Each element of this graph's ``input_node`` is inserted in turn into the subgraph's
    ``subgraph_input_node``, and the subgraph's ``subgraph_output_node`` is then computed. The
    resulting list, with one entry for each element of ``input_node``, becomes the value of
    ``result_node`` in this graph. In this way ``add_map_node`` is similar to ``map`` in
    functional programming.

    If the subgraph output is not up to date for some element, a copy of the subgraph is stored in
    that element's position and, once all elements are processed, a ``MapException`` carrying the
    list is raised.

    :param result_node: The node to place a list of results in **this** graph
    :param input_node: The node to get a list input values from **this** graph
    :param subgraph: The graph to use to perform calculation for each element
    :param subgraph_input_node: The node in **subgraph** to insert each element in turn
    :param subgraph_output_node: The node in **subgraph** to read the result for each element
    """

    def f(xs: Iterable[Any]) -> list[Any]:
        """Run the subgraph once per input element and gather the outcomes."""
        outcomes: list[Any] = []
        any_failed = False
        for element in xs:
            subgraph.insert(subgraph_input_node, element)
            subgraph.compute(subgraph_output_node)
            if subgraph.state(subgraph_output_node) != States.UPTODATE:
                # Keep a snapshot of the subgraph in place of the missing result.
                any_failed = True
                outcomes.append(subgraph.copy())
                continue
            outcomes.append(subgraph.value(subgraph_output_node))
        if not any_failed:
            return outcomes
        msg = f"Unable to calculate {result_node}"
        raise MapException(msg, outcomes)

    self.add_node(result_node, f, kwds={"xs": input_node})

1877 

def prepend_path(self, path: Name | ConstantValue, prefix_path: NodeKey) -> NodeKey | ConstantValue:
    """Prefix a node path with ``prefix_path``.

    :param path: A node name, or a ``ConstantValue`` which is handed back untouched.
    :param prefix_path: Node key that the converted ``path`` is joined onto.
    :return: ``prefix_path`` joined with the node key of ``path``, or ``path`` itself when it is
        a ``ConstantValue``.
    """
    if not isinstance(path, ConstantValue):
        return prefix_path.join(to_nodekey(path))
    return path

1884 

def add_block(
    self,
    base_path: Name,
    block: "Computation",
    *,
    keep_values: bool | None = True,
    links: dict[str, Name] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Add a computation block as a subgraph to this computation.

    Every node of ``block`` is re-created in this computation under ``base_path``, with its
    argument and keyword definitions re-pointed to the prefixed paths.

    :param base_path: Path that is prepended to every node name of ``block``.
    :param block: Computation whose nodes are added.
    :param keep_values: When truthy, nodes of ``block`` that have a value attribute get their state
        and value carried over.
    :param links: Optional mapping of target (relative to ``base_path``) to source; each pair is
        passed to :meth:`link`.
    :param metadata: Stored under ``base_path`` when given; otherwise any metadata already stored
        under ``base_path`` is deleted.
    """
    prefix = to_nodekey(base_path)
    for node_name in block.nodes():
        node_key = to_nodekey(node_name)
        attrs = block.dag.nodes[node_key]

        # strip the serialize tag from the original node: add_block explicitly
        # sets serialize=False, meaning "don't serialize the function".
        # Value serialization is controlled separately via keep_values below.
        tags = attrs.get(NodeAttributes.TAG, None)
        if tags is not None:
            tags = tags - {SystemTags.SERIALIZE}

        args_def, kwds_def = block.get_definition_args_kwds(node_key)
        new_node_name = self.prepend_path(node_name, prefix)
        self.add_node(
            new_node_name,
            attrs.get(NodeAttributes.FUNC, None),
            args=[self.prepend_path(arg, prefix) for arg in args_def],
            kwds={key: self.prepend_path(value, prefix) for key, value in kwds_def.items()},
            converter=attrs.get(NodeAttributes.CONVERTER, None),
            serialize=False,
            inspect=False,
            group=attrs.get(NodeAttributes.GROUP, None),
            tags=tags,
            style=attrs.get(NodeAttributes.STYLE, None),
            executor=attrs.get(NodeAttributes.EXECUTOR, None),
        )

        if not keep_values or NodeAttributes.VALUE not in attrs:
            continue
        new_node_key = to_nodekey(new_node_name)
        self._set_state_and_literal_value(new_node_key, attrs[NodeAttributes.STATE], attrs[NodeAttributes.VALUE])
        # The node has a concrete value — mark it serializable so the
        # value survives a JSON roundtrip even though the function is not.
        self._set_tag_one(new_node_key, SystemTags.SERIALIZE)

    if links is not None:
        for target, source in links.items():
            self.link(prefix.join_parts(target), source)

    if metadata is not None:
        self._metadata[prefix] = metadata
    elif prefix in self._metadata:
        del self._metadata[prefix]

1943 

def link(self, target: Name, source: Name) -> None:
    """Make ``target`` a node computed from ``source`` via ``identity_function``.

    Does nothing when both names resolve to the same node key. The style of the new node is the
    target's existing style if it is truthy, otherwise the source's style; a node that does not
    exist contributes no style.

    :param target: Name of the node to (re)define.
    :param source: Name of the node whose value is passed through.
    """
    target_nk = to_nodekey(target)
    source_nk = to_nodekey(source)
    if target_nk == source_nk:
        return

    target_style = None
    if self.has_node(target_nk):
        target_style = self._style_one(target_nk)
    source_style = None
    if self.has_node(source_nk):
        source_style = self._style_one(source_nk)

    self.add_node(target_nk, identity_function, kwds={"x": source_nk}, style=target_style or source_style)

1956 

def _repr_svg_(self) -> str | None:
    """Return the SVG produced by a default ``GraphView`` of this computation (Jupyter display hook)."""
    graph_view = GraphView(self)
    return graph_view.svg()

1960 

def draw(
    self,
    root: NodeKey | None = None,
    *,
    node_transformations: dict[Name, str] | None = None,
    cmap: Any = None,
    colors: str = "state",
    shapes: str | None = None,
    graph_attr: dict[str, Any] | None = None,
    node_attr: dict[str, Any] | None = None,
    edge_attr: dict[str, Any] | None = None,
    show_expansion: bool = False,
    collapse_all: bool = True,
) -> GraphView:
    """Draw a computation's current state using the GraphViz utility.

    :param root: Optional PathType. Sub-block to draw
    :param node_transformations: Optional mapping of node to transformation. It is copied, not
        modified, before being handed to the view.
    :param cmap: Default: None
    :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
    :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
    :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
        ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
    :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
    :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
    :param show_expansion: When False, every node tagged ``SystemTags.EXPANSION`` is given the
        ``NodeTransformations.CONTRACT`` transformation.
    :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
    :return: The ``GraphView`` built from these settings.
    """
    formatter = NodeFormatter.create(cmap, colors, shapes)
    transformations: dict[Name, str] = {} if node_transformations is None else node_transformations.copy()
    if not show_expansion:
        for expansion_node in self.nodes_by_tag(SystemTags.EXPANSION):
            transformations[expansion_node] = NodeTransformations.CONTRACT
    return GraphView(
        self,
        root=root,
        node_formatter=formatter,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        node_transformations=transformations,
        collapse_all=collapse_all,
    )

2005 

def view(self, cmap: Any = None, colors: str = "state", shapes: str | None = None) -> None:
    """Build a ``GraphView`` of the computation and call its ``view`` method.

    :param cmap: Passed to ``NodeFormatter.create``.
    :param colors: Passed to ``NodeFormatter.create``. Default: 'state'.
    :param shapes: Passed to ``NodeFormatter.create``. Default: None.
    """
    formatter = NodeFormatter.create(cmap, colors, shapes)
    GraphView(self, node_formatter=formatter).view()

2011 

def print_errors(self) -> None:
    """Print the name and stored traceback of each node whose state is ``States.ERROR``."""
    for node_name in self.nodes():
        if self.s[node_name] != States.ERROR:
            continue
        print(f"{node_name}")
        # Underline the name with a rule of the same width.
        print("=" * len(str(node_name)))
        print()
        print(self.v[node_name].traceback)
        print()

2021 

@classmethod
def from_class(cls, definition_class: type, ignore_self: bool = True) -> "Computation":
    """Build a computation from a definition class.

    A new computation and a no-argument instance of ``definition_class`` are created and handed to
    ``populate_computation_from_class``.

    :param definition_class: Class describing the computation; instantiated with no arguments.
    :param ignore_self: Forwarded to ``populate_computation_from_class``.
    :return: The populated computation.
    """
    computation = cls()
    definition_instance = definition_class()
    populate_computation_from_class(computation, definition_class, definition_instance, ignore_self=ignore_self)
    return computation

2029 

def inject_dependencies(self, dependencies: dict[Name, Any], *, force: bool = False) -> None:
    """Fill placeholder nodes of this computation from a mapping of dependencies.

    Every node that is in the placeholder state (or every node at all when ``force`` is True) is
    looked up in ``dependencies``. Nodes with no entry, or whose entry is ``None``, are left alone.
    A callable entry is added as the node's function; any other entry is added as the node's value.

    :param dependencies: A dictionary where each key-value pair consists of a node identifier and
        its corresponding dependency object or a callable that returns the dependency object.
    :param force: A boolean flag that, when set to True, replaces nodes with the entries provided in
        'dependencies' regardless of their current state. Defaults to False.
    :return: None
    """
    for node_name in self.nodes():
        if not (force or self.s[node_name] == States.PLACEHOLDER):
            continue
        dependency = dependencies.get(node_name)
        if dependency is None:
            continue
        if callable(dependency):
            self.add_node(node_name, dependency)
        else:
            self.add_node(node_name, value=dependency)