Coverage for src / loman / computeengine.py: 99%
1019 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:30 +0000
1"""Core computation engine for dependency-aware calculation graphs."""
3import inspect
4import logging
5import traceback
6import types
7import warnings
8from collections import defaultdict
9from collections.abc import Callable, Iterable, Mapping
10from concurrent.futures import FIRST_COMPLETED, Executor, ThreadPoolExecutor, wait
11from dataclasses import dataclass, field
12from datetime import UTC, datetime
13from enum import Enum
14from typing import Any, BinaryIO
16import decorator
17import dill
18import networkx as nx
19import pandas as pd
21from .compat import get_signature
22from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags
23from .exception import (
24 CannotInsertToPlaceholderNodeException,
25 ComputationError,
26 LoopDetectedException,
27 MapException,
28 NodeAlreadyExistsException,
29 NonExistentNodeException,
30 ValidationError,
31)
32from .graph_utils import topological_sort
33from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey
34from .util import AttributeView, apply1, apply_n, as_iterable, value_eq
35from .visualization import GraphView, NodeFormatter
# Module-level logger for the compute engine; configure via the "loman.computeengine" logger name.
LOG = logging.getLogger("loman.computeengine")
@dataclass
class Error:
    """Container for error information captured when a node's computation raises."""

    exception: Exception  # the exception instance raised by the node's function
    traceback: str  # formatted traceback text captured at the raise site
@dataclass
class NodeData:
    """State/value pair associated with a computation node."""

    state: States  # current lifecycle state of the node
    value: object  # the node's value; interpretation depends on state
@dataclass
class TimingData:
    """Timing information for a single computation execution."""

    start: datetime  # when evaluation began (UTC timestamps per _eval_node)
    end: datetime  # when evaluation finished
    duration: float  # elapsed time — presumably (end - start) in seconds; confirm units at the caller
class _ParameterType(Enum):
    """Internal enum distinguishing positional from keyword parameters."""

    ARG = 1  # positional argument, identified by index (see _process_function_args)
    KWD = 2  # keyword argument, identified by parameter name (see _process_function_kwds)
@dataclass
class _ParameterItem:
    """Internal container for parameter information during computation."""

    type: _ParameterType  # ARG or KWD
    name: int | str  # positional index for ARG, parameter name for KWD
    value: object  # value to supply for this parameter
81def _node(func: Callable[..., Any], *args: Any, **kws: Any) -> Any: # pragma: no cover
82 """Internal wrapper function for node decorator."""
83 return func(*args, **kws)
def node(
    comp: "Computation", name: Name | None = None, *args: Any, **kw: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that registers the decorated function as a node on *comp*.

    :param comp: Computation to register the node on.
    :param name: Node name; defaults to the decorated function's ``__name__``.
    """

    def inner(f: Callable[..., Any]) -> Callable[..., Any]:
        """Register *f* on the computation and return a decorated wrapper."""
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)  # type: ignore[attr-defined]
        wrapped: Callable[..., Any] = decorator.decorate(f, _node)
        return wrapped

    return inner
@dataclass()
class ConstantValue:
    """Marks a literal value in ``args``/``kwds`` so it is not treated as a node name.

    See ``_process_function_args`` / ``_process_function_kwds``, which store the
    wrapped value directly on the node instead of creating an input edge.
    """

    value: object  # the constant supplied to the function, verbatim
# Short alias so callers can write C(42) instead of ConstantValue(42).
C = ConstantValue
class Node:
    """Base class for declarative computation-graph node definitions."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add this node definition to *comp* under *name*.

        :param comp: Target computation.
        :param name: Node name to register.
        :param obj: Instance of the defining class (used by subclasses to bind methods).
        :param ignore_self: Whether a leading ``self`` parameter should be bound away.
        :raises NotImplementedError: Subclasses must override.
        """
        raise NotImplementedError()
@dataclass
class InputNode(Node):
    """Declarative definition of an input (value) node.

    NOTE(review): ``args`` is captured but never forwarded by ``add_to_comp``
    (only ``kwds`` reaches ``add_node``) — confirm whether positional
    arguments were ever intended to be used.
    """

    args: tuple[Any, ...] = field(default_factory=tuple)  # captured positional args (currently unused)
    kwds: dict[str, Any] = field(default_factory=dict)  # keyword args forwarded to add_node

    def __init__(self, *args: Any, **kwds: Any) -> None:
        """Capture arguments for later registration via ``add_to_comp``."""
        self.args = args
        self.kwds = kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this input node on *comp* under *name* (obj/ignore_self unused)."""
        comp.add_node(name, **self.kwds)
# Lowercase alias for use as a declarative marker in class bodies.
input_node = InputNode
@dataclass
class CalcNode(Node):
    """Declarative definition of a calculation node."""

    f: Callable[..., Any]  # the calculation function
    kwds: dict[str, Any] = field(default_factory=dict)  # options forwarded to add_node

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this calculation node on *comp* under *name*.

        If *ignore_self* is set (via the argument or an ``ignore_self`` entry in
        ``kwds``) and the function's first keyword parameter is ``self``, the
        function is bound to *obj* so ``self`` is not treated as a graph input.
        """
        kwds = self.kwds.copy()
        ignore_self = ignore_self or kwds.get("ignore_self", False)
        f = self.f
        if ignore_self:
            signature = get_signature(self.f)
            if len(signature.kwd_params) > 0 and signature.kwd_params[0] == "self":
                # Bind as a method so 'self' disappears from the visible signature.
                f = f.__get__(obj, obj.__class__)  # type: ignore[attr-defined]
        if "ignore_self" in kwds:
            # Internal option only — must not leak through to add_node.
            del kwds["ignore_self"]
        comp.add_node(name, f, **kwds)
def calc_node(
    f: Callable[..., Any] | None = None, **kwds: Any
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Mark a function as a calculation node for ``computation_factory`` classes.

    Usable bare (``@calc_node``) or with options (``@calc_node(serialize=False)``).
    """

    def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
        """Attach CalcNode info to *func* and hand it back unchanged."""
        func._loman_node_info = CalcNode(func, kwds)  # type: ignore[attr-defined]
        return func

    return wrap if f is None else wrap(f)
@dataclass
class Block(Node):
    """Declarative definition of a nested computation block (subgraph)."""

    block: "Callable[..., Computation] | Computation"  # a Computation, or a zero-arg factory returning one
    args: tuple[Any, ...] = field(default_factory=tuple)  # positional args for add_block
    kwds: dict[str, Any] = field(default_factory=dict)  # keyword args for add_block

    def __init__(self, block: "Callable[..., Computation] | Computation", *args: Any, **kwds: Any) -> None:
        """Capture the block and the arguments to pass to ``add_block``."""
        self.block = block
        self.args = args
        self.kwds = kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this block on *comp* under *name*.

        A ``Computation`` is added directly; a callable is invoked with no
        arguments to produce the ``Computation`` first.

        :raises TypeError: if ``block`` is neither callable nor a Computation.
        """
        if isinstance(self.block, Computation):
            comp.add_block(name, self.block, *self.args, **self.kwds)
        elif callable(self.block):
            block0 = self.block()
            comp.add_block(name, block0, *self.args, **self.kwds)
        else:
            msg = f"Block {self.block} must be callable or Computation"
            raise TypeError(msg)
# Lowercase alias for use as a declarative marker in class bodies.
block = Block
def populate_computation_from_class(comp: "Computation", cls: type, obj: object, ignore_self: bool = True) -> None:
    """Add every Node-flavoured member of *cls* to *comp*.

    Members may be Node instances themselves, or functions carrying a
    ``_loman_node_info`` attribute (set by the ``calc_node`` decorator).
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            found: Node | None = member
        else:
            found = getattr(member, "_loman_node_info", None)
        if found is not None:
            found.add_to_comp(comp, member_name, obj, ignore_self)
def computation_factory(
    maybe_cls: type | None = None, *, ignore_self: bool = True
) -> Callable[..., "Computation"] | Callable[[type], Callable[..., "Computation"]]:
    """Class decorator producing a factory that builds a Computation from a class.

    Usable bare (``@computation_factory``) or parameterized
    (``@computation_factory(ignore_self=False)``).
    """

    def wrap(cls: type) -> Callable[..., "Computation"]:
        """Turn *cls* into a Computation factory function."""

        def create_computation(*args: Any, **kwargs: Any) -> "Computation":
            """Instantiate *cls* and populate a fresh Computation from it."""
            instance = cls()
            computation = Computation(*args, **kwargs)
            computation._definition_object = instance  # type: ignore[attr-defined]
            populate_computation_from_class(computation, cls, instance, ignore_self)
            return computation

        return create_computation

    return wrap if maybe_cls is None else wrap(maybe_cls)
242def _eval_node(
243 name: NodeKey,
244 f: Callable[..., Any],
245 args: list[Any],
246 kwds: dict[str, Any],
247 raise_exceptions: bool,
248) -> tuple[Any, Exception | None, str | None, datetime, datetime]:
249 """To make multiprocessing work, this function must be standalone so that pickle works."""
250 exc: Exception | None = None
251 tb: str | None = None
252 start_dt = datetime.now(UTC)
253 try:
254 logging.debug("Running " + str(name))
255 value = f(*args, **kwds)
256 logging.debug("Completed " + str(name))
257 except Exception as e:
258 value = None
259 exc = e
260 tb = traceback.format_exc()
261 if raise_exceptions:
262 raise
263 end_dt = datetime.now(UTC)
264 return value, exc, tb, start_dt, end_dt
# Sentinel distinguishing "no value supplied" from an explicit None in add_node.
_MISSING_VALUE_SENTINEL = object()
class NullObject:
    """Debug helper that loudly rejects every attribute/item access and call.

    Each dunder prints the attempted operation (for tracing) and then raises
    the exception normal Python would raise for a missing member. Useful as a
    stand-in value to locate unintended uses of a node's value.
    """

    def __getattr__(self, name: str) -> Any:
        """Print the access and raise AttributeError."""
        print(f"__getattr__: {name}")
        msg = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(msg)

    def __setattr__(self, name: str, value: Any) -> None:
        """Print the assignment and raise AttributeError."""
        print(f"__setattr__: {name}")
        msg = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(msg)

    def __delattr__(self, name: str) -> None:
        """Print the deletion and raise AttributeError."""
        print(f"__delattr__: {name}")
        msg = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(msg)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Print the call and raise TypeError."""
        print(f"__call__: {args}, {kwargs}")
        msg = "'NullObject' object is not callable"
        raise TypeError(msg)

    def __getitem__(self, key: Any) -> Any:
        """Print the lookup and raise KeyError."""
        print(f"__getitem__: {key}")
        msg = f"'NullObject' object has no item with key '{key}'"
        raise KeyError(msg)

    def __setitem__(self, key: Any, value: Any) -> None:
        """Print the assignment and raise KeyError."""
        print(f"__setitem__: {key}")
        msg = f"'NullObject' object cannot have items set with key '{key}'"
        raise KeyError(msg)

    def __repr__(self) -> str:
        """Print the instance dict (via object.__getattribute__) and return a fixed tag."""
        print(f"__repr__: {object.__getattribute__(self, '__dict__')}")
        return "<NullObject>"
def identity_function(x: Any) -> Any:
    """Identity: give back *x* untouched."""
    return x
320class Computation:
321 """A computation graph that manages dependencies and calculations.
323 The Computation class provides a framework for building and executing
324 computation graphs where nodes represent data or calculations, and edges
325 represent dependencies between them.
326 """
    def __init__(
        self,
        *,
        default_executor: Executor | None = None,
        executor_map: dict[str, Executor] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Initialize a new Computation.

        :param default_executor: Executor used for nodes without a named executor.
        :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1)
        :param executor_map: Mapping from executor name to Executor instance.
        :param metadata: Optional metadata dict stored under the root node key.
        """
        if default_executor is None:
            self.default_executor: Executor = ThreadPoolExecutor(1)
        else:
            self.default_executor = default_executor
        if executor_map is None:
            self.executor_map: dict[str, Executor] = {}
        else:
            self.executor_map = executor_map
        # The computation graph: nodes keyed by NodeKey, edges carry parameter wiring.
        self.dag: nx.DiGraph = nx.DiGraph()
        self._metadata: dict[NodeKey, Any] = {}
        if metadata is not None:
            self._metadata[NodeKey.root()] = metadata

        # Attribute views rooted at the tree root:
        # v=values, s=states, i=inputs, o=outputs, t=tags, style=styles,
        # tim=timings, x=compute-then-get, src=print source.
        self.v = self.get_attribute_view_for_path(NodeKey.root(), self._value_one, self.value)
        self.s = self.get_attribute_view_for_path(NodeKey.root(), self._state_one, self.state)
        self.i = self.get_attribute_view_for_path(NodeKey.root(), self._get_inputs_one_names, self.get_inputs)
        self.o = self.get_attribute_view_for_path(NodeKey.root(), self._get_outputs_one, self.get_outputs)
        self.t = self.get_attribute_view_for_path(NodeKey.root(), self._tag_one, self.tags)
        self.style = self.get_attribute_view_for_path(NodeKey.root(), self._style_one, self.styles)
        self.tim = self.get_attribute_view_for_path(NodeKey.root(), self._get_timing_one, self.get_timing)
        self.x = self.get_attribute_view_for_path(
            NodeKey.root(), self.compute_and_get_value, self.compute_and_get_value
        )
        self.src = self.get_attribute_view_for_path(NodeKey.root(), self.print_source, self.print_source)
        # Reverse indexes kept in sync with node attributes: tag -> keys, state -> keys.
        self._tag_map: defaultdict[str, set[NodeKey]] = defaultdict(set)
        self._state_map: dict[States, set[NodeKey]] = {state: set() for state in States}
    def get_attribute_view_for_path(
        self, nodekey: NodeKey, get_one_func: Callable[[Name], Any], get_many_func: Callable[[Name | Names], Any]
    ) -> AttributeView:
        """Create an attribute view rooted at *nodekey*.

        Names are resolved relative to *nodekey*: a name that is a real node is
        fetched with *get_one_func*; a name that is only an intermediate tree
        path yields a nested view; anything else raises KeyError.

        :param nodekey: Root path of the view.
        :param get_one_func: Accessor for a single node.
        :param get_many_func: Accessor reused when building nested views.
        """

        def node_func() -> Iterable[str]:
            """Return the names of *nodekey*'s children in the node tree."""
            return [str(n) for n in self.get_tree_list_children(nodekey)]

        def get_one_func_for_path(name: str) -> Any:
            """Resolve *name* relative to *nodekey*; fetch a node or descend a level."""
            nk = to_nodekey(name)
            new_nk = nk.prepend(nodekey)
            if self.has_node(new_nk):
                return get_one_func(new_nk)
            elif self.tree_has_path(new_nk):
                # Intermediate path: nested view allows further attribute access.
                return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func)
            else:
                msg = f"Path {new_nk} does not exist"
                raise KeyError(msg)  # pragma: no cover

        def get_many_func_for_path(name: Name | Names) -> Any:
            """Fetch a single name, or each element when given a list of names."""
            if isinstance(name, list):
                return [get_one_func_for_path(str(n)) for n in name]
            else:
                return get_one_func_for_path(str(name))

        return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)
397 def _get_names_for_state(self, state: States) -> set[Name]:
398 """Get node names that have a specific state."""
399 return set(node_keys_to_names(self._state_map[state]))
401 def _get_tags_for_state(self, tag: str) -> set[Name]:
402 """Get node names that have a specific tag."""
403 return set(node_keys_to_names(self._tag_map[tag]))
    def _process_function_args(self, node_key: NodeKey, node: dict[str, Any], args: list[Any] | None) -> int:
        """Wire positional arguments of a function node into the DAG.

        ``ConstantValue`` entries are stored directly on the node's ARGS map;
        anything else is treated as an input node name, created as a
        PLACEHOLDER node if absent, and connected with an (ARG, index) edge.

        :param node_key: Key of the function node being wired.
        :param node: Attribute dict of that node.
        :param args: Positional argument sources, or None.
        :return: The number of positional arguments processed.
        """
        args_count = 0
        if args:
            args_count = len(args)
            for i, arg in enumerate(args):
                if isinstance(arg, ConstantValue):
                    # Constants are baked into the node; no edge required.
                    node[NodeAttributes.ARGS][i] = arg.value
                else:
                    input_vertex_name = arg
                    input_vertex_node_key = to_nodekey(input_vertex_name)
                    if not self.dag.has_node(input_vertex_node_key):
                        self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
                    self.dag.add_edge(
                        input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
                    )
        return args_count
    def _build_param_map(
        self,
        func: Callable[..., Any],
        node_key: NodeKey,
        args_count: int,
        kwds: dict[str, Any] | None,
        inspect: bool,
    ) -> tuple[dict[str, Any], list[str]]:
        """Map function parameter names to their source nodes (or constants).

        When *inspect* is true, the function signature is examined: parameters
        beyond the first *args_count* positionals default to sibling nodes of
        the same name unless overridden in *kwds*. When false, only explicit
        *kwds* entries are used.

        NOTE: the *inspect* parameter shadows the module-level ``inspect``
        import inside this method; the name mirrors add_node's public API.

        :return: ``(param_map, default_names)`` where *default_names* lists
            parameters with defaults (these may be left unconnected).
        """
        param_map: dict[str, Any] = {}
        default_names: list[str] = []

        if inspect:
            signature = get_signature(func)
            if not signature.has_var_args:
                for param_name in signature.kwd_params[args_count:]:
                    if kwds is not None and param_name in kwds:
                        param_source = kwds[param_name]
                    else:
                        # Default wiring: take the input from a sibling node of the same name.
                        param_source = node_key.parent.join_parts(param_name)
                    param_map[param_name] = param_source
            if signature.has_var_kwds and kwds is not None:
                # **kwargs soak up any explicitly provided sources.
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = signature.default_params
        else:
            if kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source

        return param_map, default_names
    def _process_function_kwds(
        self, node_key: NodeKey, node: dict[str, Any], param_map: dict[str, Any], default_names: list[str]
    ) -> None:
        """Wire keyword parameters of a function node into the DAG.

        ``ConstantValue`` sources are stored on the node's KWDS map; other
        sources are treated as input node names and connected with
        (KWD, param_name) edges. A missing source node is created as a
        PLACEHOLDER, except for parameters with defaults, which are skipped so
        the function default can apply.
        """
        for param_name, param_source in param_map.items():
            if isinstance(param_source, ConstantValue):
                node[NodeAttributes.KWDS][param_name] = param_source.value
            else:
                in_node_name = param_source
                in_node_key = to_nodekey(in_node_name)
                if not self.dag.has_node(in_node_key):
                    if param_name in default_names:
                        # Missing input with a default: let the function default apply.
                        continue
                    else:
                        self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(in_node_key)
                self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})
    def add_node(
        self,
        name: Name,
        func: Callable[..., Any] | None = None,
        *,
        args: list[Any] | None = None,
        kwds: dict[str, Any] | None = None,
        value: Any = _MISSING_VALUE_SENTINEL,
        converter: Callable[[Any], Any] | None = None,
        serialize: bool = True,
        inspect: bool = True,
        group: str | None = None,
        tags: Iterable[str] | None = None,
        style: str | None = None,
        executor: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Adds or updates a node in a computation.

        :param name: Name of the node to add. This may be any hashable object.
        :param func: Function to use to calculate the node if the node is a calculation node. By default, the
            input nodes to the function will be implied from the names of the function parameters. For
            example, a parameter called ``a`` would be taken from the node called ``a``. This can be modified
            with the ``kwds`` parameter.
        :type func: Function, default None
        :param args: Specifies a list of nodes that will be used to populate arguments of the function
            positionally for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would
            be called with three parameters, taken from the nodes 'a', 'b' and 'c' respectively. Wrap an
            entry in ``ConstantValue`` (alias ``C``) to pass a literal instead of a node.
        :type args: List, default None
        :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
            parameter when calling the function for a calculation node. e.g. If kwds is ``{'x': 'a', 'y': 'b'}``
            then the function would be called with parameters named 'x' and 'y', and their values would be
            taken from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take
            parameter [key] from node [value]".
        :type kwds: Dictionary, default None
        :param value: If given, the value is inserted into the node, and the node state set to UPTODATE.
        :type value: default _MISSING_VALUE_SENTINEL (no value)
        :param converter: Optional callable applied to any value before it is stored on the node.
        :type converter: Function, default None
        :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
            case, set serialize to False
        :type serialize: boolean, default True
        :param inspect: Whether to use introspection to determine the arguments of the function, which can be
            slow. If this is not set, kwds and args must be set for the function to obtain parameters.
        :type inspect: boolean, default True
        :param group: Subgraph to render node in
        :type group: default None
        :param tags: Set of tags to apply to node
        :type tags: Iterable
        :param style: Style to apply to node
        :type style: String, default None
        :param executor: Name of executor to run node on
        :type executor: string
        :param metadata: Metadata dict stored for this node; replaces (or, if None, clears) any existing entry.
        :type metadata: Dictionary, default None
        """
        node_key = to_nodekey(name)
        LOG.debug(f"Adding node {node_key}")
        has_value = value is not _MISSING_VALUE_SENTINEL
        if value is _MISSING_VALUE_SENTINEL:
            value = None
        if tags is None:
            tags = []

        self.dag.add_node(node_key)
        # Re-adding an existing node severs its old input edges; wiring is rebuilt below.
        pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
        self.dag.remove_edges_from(pred_edges)
        node = self.dag.nodes[node_key]

        if metadata is None:
            if node_key in self._metadata:
                del self._metadata[node_key]
        else:
            self._metadata[node_key] = metadata

        # require_old_state=False: a brand-new node has no previous STATE attribute.
        self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)

        node[NodeAttributes.TAG] = set()
        node[NodeAttributes.STYLE] = style
        node[NodeAttributes.GROUP] = group
        node[NodeAttributes.ARGS] = {}
        node[NodeAttributes.KWDS] = {}
        node[NodeAttributes.FUNC] = None
        node[NodeAttributes.EXECUTOR] = executor
        node[NodeAttributes.CONVERTER] = converter

        if func:
            node[NodeAttributes.FUNC] = func
            args_count = self._process_function_args(node_key, node, args)
            param_map, default_names = self._build_param_map(func, node_key, args_count, kwds, inspect)
            self._process_function_kwds(node_key, node, param_map, default_names)
            # A (re)defined calculation invalidates everything downstream.
            self._set_descendents(node_key, States.STALE)

        if has_value:
            self._set_uptodate(node_key, value)
        if node[NodeAttributes.STATE] == States.UNINITIALIZED:
            self._try_set_computable(node_key)
        self.set_tag(node_key, tags)
        if serialize:
            self.set_tag(node_key, SystemTags.SERIALIZE)
571 def _refresh_maps(self) -> None:
572 """Refresh internal tag and state maps from node data."""
573 self._tag_map.clear()
574 for state in States:
575 self._state_map[state].clear()
576 for node_key in self._node_keys():
577 state = self.dag.nodes[node_key][NodeAttributes.STATE]
578 self._state_map[state].add(node_key)
579 tags = self.dag.nodes[node_key].get(NodeAttributes.TAG, set())
580 for tag in tags:
581 self._tag_map[tag].add(node_key)
583 def _set_tag_one(self, name: Name, tag: str) -> None:
584 """Set a single tag on a single node."""
585 node_key = to_nodekey(name)
586 self.dag.nodes[node_key][NodeAttributes.TAG].add(tag)
587 self._tag_map[tag].add(node_key)
589 def set_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None:
590 """Set tags on a node or nodes. Ignored if tags are already set.
592 :param name: Node or nodes to set tag for
593 :param tag: Tag to set
594 """
595 apply_n(self._set_tag_one, name, tag)
597 def _clear_tag_one(self, name: Name, tag: str) -> None:
598 """Clear a single tag from a single node."""
599 node_key = to_nodekey(name)
600 self.dag.nodes[node_key][NodeAttributes.TAG].discard(tag)
601 self._tag_map[tag].discard(node_key)
603 def clear_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None:
604 """Clear tag on a node or nodes. Ignored if tags are not set.
606 :param name: Node or nodes to clear tags for
607 :param tag: Tag to clear
608 """
609 apply_n(self._clear_tag_one, name, tag)
611 def _set_style_one(self, name: Name, style: str) -> None:
612 """Set style on a single node."""
613 node_key = to_nodekey(name)
614 self.dag.nodes[node_key][NodeAttributes.STYLE] = style
616 def set_style(self, name: Name | Names, style: str) -> None:
617 """Set styles on a node or nodes.
619 :param name: Node or nodes to set style for
620 :param style: Style to set
621 """
622 apply_n(self._set_style_one, name, style)
624 def _clear_style_one(self, name: Name) -> None:
625 """Clear style from a single node."""
626 node_key = to_nodekey(name)
627 self.dag.nodes[node_key][NodeAttributes.STYLE] = None
629 def clear_style(self, name: Name | Names) -> None:
630 """Clear style on a node or nodes.
632 :param name: Node or nodes to clear styles for
633 """
634 apply_n(self._clear_style_one, name)
636 def metadata(self, name: Name) -> dict[str, Any]:
637 """Get metadata for a node."""
638 node_key = to_nodekey(name)
639 if self.tree_has_path(name):
640 if node_key not in self._metadata:
641 self._metadata[node_key] = {}
642 result: dict[str, Any] = self._metadata[node_key]
643 return result
644 else:
645 msg = f"Node {node_key} does not exist."
646 raise NonExistentNodeException(msg)
648 def delete_node(self, name: Name) -> None:
649 """Delete a node from a computation.
651 When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
652 will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
653 then the PLACEHOLDER node will also be deleted.
655 :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
656 be raised.
657 """
658 node_key = to_nodekey(name)
659 LOG.debug(f"Deleting node {node_key}")
661 if not self.dag.has_node(node_key):
662 msg = f"Node {node_key} does not exist"
663 raise NonExistentNodeException(msg)
665 if node_key in self._metadata:
666 del self._metadata[node_key]
668 if len(self.dag.succ[node_key]) == 0:
669 preds = self.dag.predecessors(node_key)
670 state = self.dag.nodes[node_key][NodeAttributes.STATE]
671 self.dag.remove_node(node_key)
672 self._state_map[state].remove(node_key)
673 for n in preds:
674 if self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
675 self.delete_node(n)
676 else:
677 self._set_state(node_key, States.PLACEHOLDER)
    def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None) -> None:
        """Rename a node in a computation.

        Accepts either a single (old_name, new_name) pair, or a mapping of old
        names to new names.

        :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys,
            and new names as values
        :param new_name: New name for node. Must be None when *old_name* is a mapping.
        :raises ValueError: if *new_name* is given together with a mapping.
        :raises NonExistentNodeException: if the single *old_name* does not exist.
        :raises NodeAlreadyExistsException: if the single *new_name* already exists.
        """
        name_mapping: dict[Name, Name]
        if isinstance(old_name, Mapping) and not isinstance(old_name, str):
            for k, v in old_name.items():
                LOG.debug(f"Renaming node {k} to {v}")
            if new_name is not None:
                msg = "new_name must not be set if rename_node is passed a dictionary"
                raise ValueError(msg)
            else:
                name_mapping = dict(old_name)  # type: ignore[arg-type]
        else:
            LOG.debug(f"Renaming node {old_name} to {new_name}")
            old_node_key = to_nodekey(old_name)
            if not self.dag.has_node(old_node_key):
                msg = f"Node {old_name} does not exist"
                raise NonExistentNodeException(msg)
            assert new_name is not None  # noqa: S101
            new_node_key = to_nodekey(new_name)
            if self.dag.has_node(new_node_key):
                msg = f"Node {new_name} already exists"
                raise NodeAlreadyExistsException(msg)
            name_mapping = {old_name: new_name}

        node_key_mapping = {to_nodekey(on): to_nodekey(nn) for on, nn in name_mapping.items()}
        nx.relabel_nodes(self.dag, node_key_mapping, copy=False)

        # Carry metadata across the rename; drop any stale metadata at the target key.
        for old_nk, new_nk in node_key_mapping.items():
            if old_nk in self._metadata:
                self._metadata[new_nk] = self._metadata[old_nk]
                del self._metadata[old_nk]
            else:
                if new_nk in self._metadata:  # pragma: no cover
                    del self._metadata[new_nk]

        # Tag/state indexes are keyed by NodeKey, so rebuild them wholesale.
        self._refresh_maps()
    def repoint(self, old_name: Name, new_name: Name) -> None:
        """Changes all nodes that use old_name as an input to use new_name instead.

        Note that if old_name is an input to new_name, then that will not be changed, to try to avoid
        introducing circular dependencies, but other circular dependencies will not be checked.

        If new_name does not exist, then it will be created as a PLACEHOLDER node.

        :param old_name: Name whose dependents should be re-wired.
        :param new_name: Name the dependents should read from instead.
        :return: None
        """
        old_node_key = to_nodekey(old_name)
        new_node_key = to_nodekey(new_name)
        if old_node_key == new_node_key:
            return

        changed_names = list(self.dag.successors(old_node_key))

        # Only create the placeholder target if something actually points at old_name.
        if len(changed_names) > 0 and not self.dag.has_node(new_node_key):
            self.dag.add_node(new_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
            self._state_map[States.PLACEHOLDER].add(new_node_key)

        for name in changed_names:
            if name == new_node_key:
                # Skip re-wiring new_name onto itself (would create a self-loop).
                continue
            # Move the edge, preserving its parameter-wiring attributes.
            edge_data = self.dag.get_edge_data(old_node_key, name)
            self.dag.add_edge(new_node_key, name, **edge_data)
            self.dag.remove_edge(old_node_key, name)

        # Everything re-wired must be recomputed from its new input.
        for n in changed_names:
            self.set_stale(n)
    def insert(self, name: Name, value: Any, force: bool = False) -> None:
        """Insert a value into a node of a computation.

        Following insertion, the node will have state UPTODATE, and all its descendents will be COMPUTABLE
        or STALE.

        If an attempt is made to insert a value into a node that does not exist, a
        ``NonExistentNodeException`` will be raised.

        :param name: Name of the node to insert into.
        :param value: The value to be inserted into the node.
        :param force: Whether to force recalculation of descendents if node value and state would not be changed
        """
        node_key = to_nodekey(name)
        LOG.debug(f"Inserting value into node {node_key}")

        if not self.dag.has_node(node_key):
            msg = f"Node {node_key} does not exist"
            raise NonExistentNodeException(msg)

        state = self._state_one(name)
        if state == States.PLACEHOLDER:
            msg = "Cannot insert into placeholder node. Use add_node to create the node first"
            raise CannotInsertToPlaceholderNodeException(msg)

        # Unless forced, inserting an equal value into an up-to-date node is a no-op,
        # so downstream nodes are not needlessly invalidated.
        if not force and state == States.UPTODATE:
            current_value = self._value_one(name)
            if value_eq(value, current_value):
                return

        self._set_state_and_value(node_key, States.UPTODATE, value)
        self._set_descendents(node_key, States.STALE)
        for n in self.dag.successors(node_key):
            self._try_set_computable(n)
    def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]) -> None:
        """Insert values into many nodes of a computation simultaneously.

        Following insertion, the nodes will have state UPTODATE, and all their descendents will be
        COMPUTABLE or STALE. In the case of inserting many nodes, some of which are descendents of others,
        this ensures that the inserted nodes have correct status, rather than being set as STALE when their
        ancestors are inserted.

        If an attempt is made to insert a value into a node that does not exist, a
        ``NonExistentNodeException`` will be raised, and none of the nodes will be inserted.

        :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the
            node to insert the value into.
        :type name_value_pairs: List of tuples
        """
        node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
        LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")

        # Validate every name first so a failure leaves the computation unchanged.
        for name, _value in node_key_value_pairs:
            if not self.dag.has_node(name):
                msg = f"Node {name} does not exist"
                raise NonExistentNodeException(msg)

        stale = set()
        computable = set()
        for name, value in node_key_value_pairs:
            self._set_state_and_value(name, States.UPTODATE, value)
            stale.update(nx.dag.descendants(self.dag, name))
            computable.update(self.dag.successors(name))
        # Nodes that were themselves inserted stay UPTODATE — never STALE/COMPUTABLE.
        names = {name for name, value in node_key_value_pairs}
        stale.difference_update(names)
        computable.difference_update(names)
        for name in stale:
            self._set_state(name, States.STALE)
        for name in computable:
            self._try_set_computable(name)
824 def insert_from(self, other: "Computation", nodes: Iterable[Name] | None = None) -> None:
825 """Insert values into another Computation object into this Computation object.
827 :param other: The computation object to take values from
828 :type Computation:
829 :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the
830 other Computation object that have corresponding nodes in this Computation object will be inserted
831 :type nodes: List, default None
832 """
833 if nodes is None:
834 nodes_set: set[Any] = set(self.dag.nodes)
835 nodes_set.intersection_update(other.dag.nodes())
836 nodes = nodes_set
837 name_value_pairs = [(name, other.value(name)) for name in nodes]
838 self.insert_many(name_value_pairs)
840 def _set_state(self, node_key: NodeKey, state: States) -> None:
841 """Set the state of a node without changing its value."""
842 node = self.dag.nodes[node_key]
843 old_state = node[NodeAttributes.STATE]
844 self._state_map[old_state].remove(node_key)
845 node[NodeAttributes.STATE] = state
846 self._state_map[state].add(node_key)
    def _set_state_and_value(
        self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
    ) -> None:
        """Set state and value of a node, passing the value through any converter.

        If the node has a CONVERTER and it raises, the node is put into an error
        state; the conversion exception is re-raised unless
        *throw_conversion_exception* is False.
        """
        node = self.dag.nodes[node_key]
        converter = node.get(NodeAttributes.CONVERTER)
        if converter is None:
            self._set_state_and_literal_value(node_key, state, value)
        else:
            try:
                converted_value = converter(value)
                self._set_state_and_literal_value(node_key, state, converted_value)
            except Exception as e:
                # Conversion failure is recorded on the node like a calculation error.
                tb = traceback.format_exc()
                self._set_error(node_key, e, tb)
                if throw_conversion_exception:
                    raise
    def _set_state_and_literal_value(
        self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
    ) -> None:
        """Set state and value of a node directly, bypassing any converter.

        :param require_old_state: When False, a missing STATE attribute (a
            brand-new node) is tolerated instead of re-raising KeyError.
        """
        node = self.dag.nodes[node_key]
        try:
            old_state = node[NodeAttributes.STATE]
            self._state_map[old_state].remove(node_key)
        except KeyError:
            if require_old_state:
                raise  # pragma: no cover
        node[NodeAttributes.STATE] = state
        node[NodeAttributes.VALUE] = value
        self._state_map[state].add(node_key)
881 def _set_states(self, node_keys: Iterable[NodeKey], state: States) -> None:
882 """Set the state of multiple nodes at once."""
883 for name in node_keys:
884 node = self.dag.nodes[name]
885 old_state = node[NodeAttributes.STATE]
886 self._state_map[old_state].remove(name)
887 node[NodeAttributes.STATE] = state
888 self._state_map[state].update(node_keys)
890 def set_stale(self, name: Name) -> None:
891 """Set the state of a node and all its dependencies to STALE.
893 :param name: Name of the node to set as STALE.
894 """
895 node_key = to_nodekey(name)
896 node_keys: list[NodeKey] = [node_key]
897 node_keys.extend(nx.dag.descendants(self.dag, node_key))
898 self._set_states(node_keys, States.STALE)
899 self._try_set_computable(node_key)
901 def pin(self, name: Name, value: Any = None) -> None:
902 """Set the state of a node to PINNED.
904 :param name: Name of the node to set as PINNED.
905 :param value: Value to pin to the node, if provided.
906 :type value: default None
907 """
908 node_key = to_nodekey(name)
909 if value is not None:
910 self.insert(node_key, value)
911 self._set_states([node_key], States.PINNED)
    def unpin(self, name: Name) -> None:
        """Unpin a node (state of node and all descendents will be set to STALE).

        :param name: Name of the node to unpin.
        """
        node_key = to_nodekey(name)
        self.set_stale(node_key)
921 def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
922 """Get all descendant nodes, optionally stopping at certain states."""
923 if stop_states is None:
924 stop_states = set()
925 if self.dag.nodes[node_key][NodeAttributes.STATE] in stop_states:
926 return set()
927 visited = set()
928 to_visit = {node_key}
929 while to_visit:
930 n = to_visit.pop()
931 visited.add(n)
932 for n1 in self.dag.successors(n):
933 if n1 in visited:
934 continue
935 if self.dag.nodes[n1][NodeAttributes.STATE] in stop_states:
936 continue
937 to_visit.add(n1)
938 visited.remove(node_key)
939 return visited
941 def _set_descendents(self, node_key: NodeKey, state: States) -> None:
942 """Set the state of all descendant nodes."""
943 descendents = self._get_descendents(node_key, {States.PINNED})
944 self._set_states(descendents, state)
946 def _set_uninitialized(self, node_key: NodeKey) -> None:
947 """Set a node to uninitialized state and clear its value."""
948 self._set_states([node_key], States.UNINITIALIZED)
949 self.dag.nodes[node_key].pop(NodeAttributes.VALUE, None)
951 def _set_uptodate(self, node_key: NodeKey, value: object) -> None:
952 """Set a node to up-to-date state with a value."""
953 self._set_state_and_value(node_key, States.UPTODATE, value)
954 self._set_descendents(node_key, States.STALE)
955 for n in self.dag.successors(node_key):
956 self._try_set_computable(n)
958 def _set_error(self, node_key: NodeKey, exc: Exception, tb: str) -> None:
959 """Set a node to error state with exception information."""
960 self._set_state_and_literal_value(node_key, States.ERROR, Error(exc, tb))
961 self._set_descendents(node_key, States.STALE)
963 def _try_set_computable(self, node_key: NodeKey) -> None:
964 """Set node to computable if all predecessors are up-to-date."""
965 if self.dag.nodes[node_key][NodeAttributes.STATE] == States.PINNED:
966 return
967 if self.dag.nodes[node_key].get(NodeAttributes.FUNC) is not None:
968 for n in self.dag.predecessors(node_key):
969 if not self.dag.has_node(n):
970 return # pragma: no cover
971 if self.dag.nodes[n][NodeAttributes.STATE] != States.UPTODATE:
972 return
973 self._set_state(node_key, States.COMPUTABLE)
975 def _get_parameter_data(self, node_key: NodeKey) -> Iterable[_ParameterItem]:
976 """Get all parameter data for a node's function call."""
977 for arg, value in self.dag.nodes[node_key][NodeAttributes.ARGS].items():
978 yield _ParameterItem(_ParameterType.ARG, arg, value)
979 for param_name, value in self.dag.nodes[node_key][NodeAttributes.KWDS].items():
980 yield _ParameterItem(_ParameterType.KWD, param_name, value)
981 for in_node_name in self.dag.predecessors(node_key):
982 param_value = self.dag.nodes[in_node_name][NodeAttributes.VALUE]
983 edge = self.dag[in_node_name][node_key]
984 param_type, param_name = edge[EdgeAttributes.PARAM]
985 yield _ParameterItem(param_type, param_name, param_value)
987 def _get_func_args_kwds(
988 self, node_key: NodeKey
989 ) -> tuple[Callable[..., Any], str | None, list[Any], dict[str, Any]]:
990 """Get function, executor name, args and kwargs for a node."""
991 node0 = self.dag.nodes[node_key]
992 f = node0[NodeAttributes.FUNC]
993 executor_name = node0.get(NodeAttributes.EXECUTOR)
994 args: list[Any] = []
995 kwds: dict[str, Any] = {}
996 for param in self._get_parameter_data(node_key):
997 if param.type == _ParameterType.ARG:
998 idx = param.name
999 assert isinstance(idx, int) # noqa: S101
1000 while len(args) <= idx:
1001 args.append(None)
1002 args[idx] = param.value
1003 elif param.type == _ParameterType.KWD:
1004 assert isinstance(param.name, str) # noqa: S101
1005 kwds[param.name] = param.value
1006 else: # pragma: no cover
1007 msg = f"Unexpected param type: {param.type}"
1008 raise ValidationError(msg)
1009 return f, executor_name, args, kwds
    def get_definition_args_kwds(self, name: Name) -> tuple[list[Any], dict[str, Any]]:
        """Get the arguments and keyword arguments for a node's function definition.

        Stored constant args/kwds are wrapped in ``C(...)`` (presumably a
        constant marker defined elsewhere in the package — confirm there);
        inputs bound via graph edges are represented by the source node name.

        :param name: Name of the node to describe
        :return: (positional args padded with None, keyword args) as defined
        :raises ValidationError: if an edge carries an unknown parameter type
        """
        res_args: list[Any] = []
        res_kwds: dict[str, Any] = {}
        node_key = to_nodekey(name)
        node_data = self.dag.nodes[node_key]
        if NodeAttributes.ARGS in node_data:
            for idx, value in node_data[NodeAttributes.ARGS].items():
                # Grow the list so a sparse positional index can be assigned.
                while len(res_args) <= idx:
                    res_args.append(None)
                res_args[idx] = C(value)
        if NodeAttributes.KWDS in node_data:
            for param_name, value in node_data[NodeAttributes.KWDS].items():
                res_kwds[param_name] = C(value)
        for in_node_key in self.dag.predecessors(node_key):
            edge = self.dag[in_node_key][node_key]
            if EdgeAttributes.PARAM in edge:
                param_type, param_name = edge[EdgeAttributes.PARAM]
                if param_type == _ParameterType.ARG:
                    idx = param_name
                    assert isinstance(idx, int)  # noqa: S101
                    while len(res_args) <= idx:
                        res_args.append(None)
                    res_args[idx] = in_node_key.name
                elif param_type == _ParameterType.KWD:
                    res_kwds[param_name] = in_node_key.name
                else:  # pragma: no cover
                    msg = f"Unexpected param type: {param_type}"
                    raise ValidationError(msg)
        return res_args, res_kwds
1042 def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False) -> None:
1043 """Compute multiple nodes, handling dependencies and parallel execution."""
1044 LOG.debug(f"Computing nodes {node_keys}")
1046 futs: dict[Any, NodeKey] = {}
1047 node_keys_set = set(node_keys)
1049 def run(name: NodeKey) -> None:
1050 """Submit a node computation to an executor."""
1051 f, executor_name, args, kwds = self._get_func_args_kwds(name)
1052 executor = self.default_executor if executor_name is None else self.executor_map[executor_name]
1053 fut = executor.submit(_eval_node, name, f, args, kwds, raise_exceptions)
1054 futs[fut] = name
1056 computed: set[NodeKey] = set()
1058 for node_key in node_keys_set:
1059 node0 = self.dag.nodes[node_key]
1060 state = node0[NodeAttributes.STATE]
1061 if state == States.COMPUTABLE:
1062 run(node_key)
1064 while len(futs) > 0:
1065 done, _not_done = wait(futs.keys(), return_when=FIRST_COMPLETED)
1066 for fut in done:
1067 node_key = futs.pop(fut)
1068 node0 = self.dag.nodes[node_key]
1069 try:
1070 value, exc, tb, start_dt, end_dt = fut.result()
1071 except Exception as e:
1072 exc = e
1073 tb = traceback.format_exc()
1074 self._set_error(node_key, exc, tb)
1075 raise
1076 delta = (end_dt - start_dt).total_seconds()
1077 if exc is None:
1078 self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
1079 node0[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
1080 self._set_descendents(node_key, States.STALE)
1081 for n in self.dag.successors(node_key):
1082 logging.debug(str(node_key) + " " + str(n) + " " + str(computed))
1083 if n in computed:
1084 msg = f"Calculating {node_key} for the second time"
1085 raise LoopDetectedException(msg)
1086 self._try_set_computable(n)
1087 node0 = self.dag.nodes[n]
1088 state = node0[NodeAttributes.STATE]
1089 if state == States.COMPUTABLE and n in node_keys_set:
1090 run(n)
1091 else:
1092 assert tb is not None # noqa: S101
1093 self._set_error(node_key, exc, tb)
1094 computed.add(node_key)
1096 def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
1097 """Get node keys that need to be computed for a target node."""
1098 g = nx.DiGraph()
1099 g.add_nodes_from(self.dag.nodes)
1100 g.add_edges_from(self.dag.edges)
1101 for n in nx.ancestors(g, node_key):
1102 node = self.dag.nodes[n]
1103 state = node[NodeAttributes.STATE]
1104 if state == States.UPTODATE or state == States.PINNED:
1105 g.remove_node(n)
1107 ancestors = nx.ancestors(g, node_key)
1108 for n in ancestors:
1109 node = self.dag.nodes[n]
1110 state = node[NodeAttributes.STATE]
1112 if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0:
1113 msg = f"Cannot compute {node_key} because {n} uninitialized"
1114 raise ValidationError(msg)
1115 if state == States.PLACEHOLDER:
1116 msg = f"Cannot compute {node_key} because {n} is placeholder"
1117 raise ValidationError(msg)
1119 ancestors.add(node_key)
1120 nodes_sorted = topological_sort(g)
1121 return [n for n in nodes_sorted if n in ancestors]
1123 def _get_calc_node_names(self, name: Name) -> Names:
1124 """Get node names that need to be computed for a target node."""
1125 node_key = to_nodekey(name)
1126 return node_keys_to_names(self._get_calc_node_keys(node_key))
1128 def compute(self, name: Name | Iterable[Name], raise_exceptions: bool = False) -> None:
1129 """Compute a node and all necessary predecessors.
1131 Following the computation, if successful, the target node, and all necessary ancestors that were not already
1132 UPTODATE will have been calculated and set to UPTODATE. Any node that did not need to be calculated will not
1133 have been recalculated.
1135 If any nodes raises an exception, then the state of that node will be set to ERROR, and its value set to an
1136 object containing the exception object, as well as a traceback. This will not halt the computation, which
1137 will proceed as far as it can, until no more nodes that would be required to calculate the target are
1138 COMPUTABLE.
1140 :param name: Name of the node to compute
1141 :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
1142 :type raise_exceptions: Boolean, default False
1143 """
1144 calc_nodes: set[NodeKey] | list[NodeKey]
1145 if isinstance(name, (types.GeneratorType, list)):
1146 calc_nodes = set()
1147 for name0 in name:
1148 node_key = to_nodekey(name0)
1149 for n in self._get_calc_node_keys(node_key):
1150 calc_nodes.add(n)
1151 else:
1152 node_key = to_nodekey(name)
1153 calc_nodes = self._get_calc_node_keys(node_key)
1154 self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)
1156 def compute_all(self, raise_exceptions: bool = False) -> None:
1157 """Compute all nodes of a computation that can be computed.
1159 Nodes that are already UPTODATE will not be recalculated. Following the computation, if successful, all
1160 nodes will have state UPTODATE, except UNINITIALIZED input nodes and PLACEHOLDER nodes.
1162 If any nodes raises an exception, then the state of that node will be set to ERROR, and its value set to an
1163 object containing the exception object, as well as a traceback. This will not halt the computation, which
1164 will proceed as far as it can, until no more nodes are COMPUTABLE.
1166 :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
1167 :type raise_exceptions: Boolean, default False
1168 """
1169 self._compute_nodes(self._node_keys(), raise_exceptions=raise_exceptions)
1171 def _node_keys(self) -> list[NodeKey]:
1172 """Get a list of nodes in this computation.
1174 :return: List of nodes.
1175 """
1176 return list(self.dag.nodes)
1178 def nodes(self) -> list[Name]:
1179 """Get a list of nodes in this computation.
1181 :return: List of nodes.
1182 """
1183 return [n.name for n in self.dag.nodes]
1185 def get_tree_list_children(self, name: Name) -> set[Name]:
1186 """Get a list of nodes in this computation.
1188 :return: List of nodes.
1189 """
1190 node_key = to_nodekey(name)
1191 idx = len(node_key.parts)
1192 result = set()
1193 for n in self.dag.nodes:
1194 if n.is_descendent_of(node_key):
1195 result.add(n.parts[idx])
1196 return result
1198 def has_node(self, name: Name) -> bool:
1199 """Check if a node with the given name exists in the computation."""
1200 node_key = to_nodekey(name)
1201 return node_key in self.dag.nodes
1203 def tree_has_path(self, name: Name) -> bool:
1204 """Check if a hierarchical path exists in the computation tree."""
1205 node_key = to_nodekey(name)
1206 if node_key.is_root:
1207 return True
1208 if self.has_node(node_key):
1209 return True
1210 return any(n.is_descendent_of(node_key) for n in self.dag.nodes)
1212 def get_tree_descendents(
1213 self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
1214 ) -> set[Name]:
1215 """Get a list of descendent blocks and nodes.
1217 Returns blocks and nodes that are descendents of the input node,
1218 e.g. for node 'foo', might return ['foo/bar', 'foo/baz'].
1220 :param name: Name of node to get descendents for
1221 :return: List of descendent node names
1222 """
1223 node_key = NodeKey.root() if name is None else to_nodekey(name)
1224 stemsize = len(node_key.parts)
1225 result = set()
1226 for n in self.dag.nodes:
1227 if n.is_descendent_of(node_key):
1228 nodes = [n] if graph_nodes_only else n.ancestors()
1229 for n2 in nodes:
1230 if n2.is_descendent_of(node_key):
1231 nm = n2.name if include_stem else NodeKey(tuple(n2.parts[stemsize:])).name
1232 result.add(nm)
1233 return result
1235 def _state_one(self, name: Name) -> States:
1236 """Get the state of a single node."""
1237 node_key = to_nodekey(name)
1238 state: States = self.dag.nodes[node_key][NodeAttributes.STATE]
1239 return state
1241 def state(self, name: Name | Names) -> Any:
1242 """Get the state of a node.
1244 This can also be accessed using the attribute-style accessor ``s`` if ``name`` is a valid Python
1245 attribute name::
1247 >>> comp = Computation()
1248 >>> comp.add_node('foo', value=1)
1249 >>> comp.state('foo')
1250 <States.UPTODATE: 4>
1251 >>> comp.s.foo
1252 <States.UPTODATE: 4>
1254 :param name: Name or names of the node to get state for
1255 :type name: Name or Names
1256 """
1257 return apply1(self._state_one, name)
1259 def _value_one(self, name: Name) -> Any:
1260 """Get the value of a single node."""
1261 node_key = to_nodekey(name)
1262 return self.dag.nodes[node_key][NodeAttributes.VALUE]
1264 def value(self, name: Name | Names) -> Any:
1265 """Get the current value of a node.
1267 This can also be accessed using the attribute-style accessor ``v`` if ``name`` is a valid Python
1268 attribute name::
1270 >>> comp = Computation()
1271 >>> comp.add_node('foo', value=1)
1272 >>> comp.value('foo')
1273 1
1274 >>> comp.v.foo
1275 1
1277 :param name: Name or names of the node to get the value of
1278 :type name: Name or Names
1279 """
1280 return apply1(self._value_one, name)
    def compute_and_get_value(self, name: Name) -> Any:
        """Get the value of a node, computing it (and required ancestors) first if necessary.

        This can also be accessed using the attribute-style accessor ``x`` if ``name`` is a valid Python
        attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

        :param name: Name of the node to get the value of
        :type name: Name
        :raises ComputationError: if the node cannot be brought up to date
        """
        nk = to_nodekey(name)
        if self.state(nk) == States.UPTODATE:
            return self.value(nk)
        self.compute(nk, raise_exceptions=True)
        if self.state(nk) == States.UPTODATE:
            return self.value(nk)
        msg = f"Unable to compute node {nk}"
        raise ComputationError(msg)
1308 def _tag_one(self, name: Name) -> set[str]:
1309 """Get the tags of a single node."""
1310 node_key = to_nodekey(name)
1311 node = self.dag.nodes[node_key]
1312 tags: set[str] = node[NodeAttributes.TAG]
1313 return tags
1315 def tags(self, name: Name | Names) -> Any:
1316 """Get the tags associated with a node.
1318 >>> comp = Computation()
1319 >>> comp.add_node('a', tags=['foo', 'bar'])
1320 >>> sorted(comp.t.a)
1321 ['__serialize__', 'bar', 'foo']
1323 :param name: Name or names of the node to get the tags of
1324 :return:
1325 """
1326 return apply1(self._tag_one, name)
1328 def nodes_by_tag(self, tag: str | Iterable[str]) -> set[Name]:
1329 """Get the names of nodes with a particular tag or tags.
1331 :param tag: Tag or tags for which to retrieve nodes
1332 :return: Names of the nodes with those tags
1333 """
1334 nodes: set[NodeKey] = set()
1335 tags_to_check: Iterable[str] = [tag] if isinstance(tag, str) else tag
1336 for tag1 in tags_to_check:
1337 nodes1 = self._tag_map.get(tag1)
1338 if nodes1 is not None:
1339 nodes.update(nodes1)
1340 return {n.name for n in nodes}
1342 def _style_one(self, name: Name) -> str | None:
1343 """Get the style of a single node."""
1344 node_key = to_nodekey(name)
1345 node = self.dag.nodes[node_key]
1346 style: str | None = node.get(NodeAttributes.STYLE)
1347 return style
    def styles(self, name: Name | Names) -> Any:
        """Get the style associated with a node (or each node in a list of names).

        >>> comp = Computation()
        >>> comp.add_node('a', style='dot')
        >>> comp.style.a
        'dot'

        :param name: Name or names of the node to get the style of
        :return: Style string (or None), or a collection of them
        """
        return apply1(self._style_one, name)
1362 def _get_item_one(self, name: Name) -> NodeData:
1363 """Get state and value data for a single node."""
1364 node_key = to_nodekey(name)
1365 node = self.dag.nodes[node_key]
1366 return NodeData(node[NodeAttributes.STATE], node[NodeAttributes.VALUE])
1368 def __getitem__(self, name: Name | Names) -> Any:
1369 """Get the state and current value of a node.
1371 :param name: Name of the node to get the state and value of
1372 """
1373 return apply1(self._get_item_one, name)
1375 def _get_timing_one(self, name: Name) -> TimingData | None:
1376 """Get timing data for a single node."""
1377 node_key = to_nodekey(name)
1378 node = self.dag.nodes[node_key]
1379 timing: TimingData | None = node.get(NodeAttributes.TIMING, None)
1380 return timing
1382 def get_timing(self, name: Name | Names) -> Any:
1383 """Get the timing information for a node.
1385 :param name: Name or names of the node to get the timing information of
1386 :return:
1387 """
1388 return apply1(self._get_timing_one, name)
1390 def to_df(self) -> pd.DataFrame:
1391 """Get a dataframe containing the states and value of all nodes of computation.
1393 ::
1395 >>> import loman
1396 >>> comp = loman.Computation()
1397 >>> comp.add_node('foo', value=1)
1398 >>> comp.add_node('bar', value=2)
1399 >>> comp.to_df() # doctest: +NORMALIZE_WHITESPACE
1400 state value
1401 foo States.UPTODATE 1
1402 bar States.UPTODATE 2
1403 """
1404 df = pd.DataFrame(index=topological_sort(self.dag))
1405 df[NodeAttributes.STATE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.STATE))
1406 df[NodeAttributes.VALUE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.VALUE))
1407 df_timing = pd.DataFrame.from_dict(nx.get_node_attributes(self.dag, "timing"), orient="index")
1408 df = pd.merge(df, df_timing, left_index=True, right_index=True, how="left")
1409 df.index = pd.Index([nk.name for nk in df.index])
1410 return df
1412 def to_dict(self) -> dict[NodeKey, Any]:
1413 """Get a dictionary containing the values of all nodes of a computation.
1415 ::
1417 >>> import loman
1418 >>> comp = loman.Computation()
1419 >>> comp.add_node('foo', value=1)
1420 >>> comp.add_node('bar', value=2)
1421 >>> comp.to_dict() # doctest: +ELLIPSIS
1422 {NodeKey('foo'): 1, NodeKey('bar'): 2}
1423 """
1424 result: dict[NodeKey, Any] = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
1425 return result
1427 def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey | None]:
1428 """Get input node keys for a single node."""
1429 args_dict: dict[int, NodeKey] = {}
1430 kwds: list[NodeKey | None] = []
1431 max_arg_index = -1
1432 for input_node in self.dag.predecessors(node_key):
1433 input_edge = self.dag[input_node][node_key]
1434 input_type, input_param = input_edge[EdgeAttributes.PARAM]
1435 if input_type == _ParameterType.ARG:
1436 idx = input_param
1437 max_arg_index = max(max_arg_index, idx)
1438 args_dict[idx] = input_node
1439 elif input_type == _ParameterType.KWD:
1440 kwds.append(input_node)
1441 if max_arg_index >= 0:
1442 args: list[NodeKey | None] = [None] * (max_arg_index + 1)
1443 for idx, input_node in args_dict.items():
1444 args[idx] = input_node
1445 result: list[NodeKey | None] = args + kwds
1446 return result
1447 else:
1448 return kwds
1450 def _get_inputs_one_names(self, name: Name) -> Names:
1451 """Get input node names for a single node."""
1452 node_key = to_nodekey(name)
1453 return node_keys_to_names([nk for nk in self._get_inputs_one_node_keys(node_key) if nk is not None])
1455 def get_inputs(self, name: Name | Names) -> Any:
1456 """Get a list of the inputs for a node or set of nodes.
1458 :param name: Name or names of nodes to get inputs for
1459 :return: If name is scalar, return a list of upstream nodes used as input. If name is a list, return a
1460 list of list of inputs.
1461 """
1462 return apply1(self._get_inputs_one_names, name)
1464 def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
1465 """Get all ancestor node keys for a set of nodes."""
1466 ancestors: set[NodeKey] = set()
1467 for n in node_keys:
1468 if include_self:
1469 ancestors.add(n)
1470 for ancestor in nx.ancestors(self.dag, n):
1471 ancestors.add(ancestor)
1472 return ancestors
1474 def get_ancestors(self, names: Name | Names, include_self: bool = True) -> Names:
1475 """Get all ancestor nodes of the specified nodes."""
1476 node_keys = names_to_node_keys(names)
1477 ancestor_node_keys = self._get_ancestors_node_keys(node_keys, include_self)
1478 return node_keys_to_names(ancestor_node_keys)
1480 def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> list[NodeKey]:
1481 """Get original input node keys that have no computation function."""
1482 resolved_node_keys: Iterable[NodeKey]
1483 resolved_node_keys = self._node_keys() if node_keys is None else self._get_ancestors_node_keys(node_keys)
1484 return [n for n in resolved_node_keys if self.dag.nodes[n].get(NodeAttributes.FUNC) is None]
1486 def get_original_inputs(self, names: Name | Names | None = None) -> Names:
1487 """Get a list of the original non-computed inputs for a node or set of nodes.
1489 :param names: Name or names of nodes to get inputs for
1490 :return: Return a list of original non-computed inputs that are ancestors of the input nodes
1491 """
1492 nks = None if names is None else names_to_node_keys(names)
1494 result_nks = self._get_original_inputs_node_keys(nks)
1496 return node_keys_to_names(result_nks)
1498 def _get_outputs_one(self, name: Name) -> Names:
1499 """Get output node names for a single node."""
1500 node_key = to_nodekey(name)
1501 output_node_keys = list(self.dag.successors(node_key))
1502 return node_keys_to_names(output_node_keys)
1504 def get_outputs(self, name: Name | Names) -> Any:
1505 """Get a list of the outputs for a node or set of nodes.
1507 :param name: Name or names of nodes to get outputs for
1508 :return: If name is scalar, return a list of downstream nodes used as output. If name is a list, return a
1509 list of list of outputs.
1511 """
1512 return apply1(self._get_outputs_one, name)
1514 def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
1515 """Get all descendant node keys for a set of nodes."""
1516 descendent_node_keys: set[NodeKey] = set()
1517 for node_key in node_keys:
1518 if include_self:
1519 descendent_node_keys.add(node_key)
1520 for descendent in nx.descendants(self.dag, node_key):
1521 descendent_node_keys.add(descendent)
1522 return descendent_node_keys
1524 def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
1525 """Get all descendent nodes of the specified nodes."""
1526 node_keys = names_to_node_keys(names)
1527 descendent_node_keys = self._get_descendents_node_keys(node_keys, include_self)
1528 return node_keys_to_names(descendent_node_keys)
1530 def get_final_outputs(self, names: Name | Names | None = None) -> Names:
1531 """Get final output nodes (nodes with no descendants) from the specified nodes."""
1532 final_node_keys: Iterable[NodeKey]
1533 if names is None:
1534 final_node_keys = self._node_keys()
1535 else:
1536 nks = names_to_node_keys(names)
1537 final_node_keys = self._get_descendents_node_keys(nks)
1538 output_node_keys = [n for n in final_node_keys if len(nx.descendants(self.dag, n)) == 0]
1539 return node_keys_to_names(output_node_keys)
1541 def get_source(self, name: Name) -> str:
1542 """Get the source code for a node."""
1543 node_key = to_nodekey(name)
1544 func = self.dag.nodes[node_key].get(NodeAttributes.FUNC, None)
1545 if func is not None:
1546 file = inspect.getsourcefile(func)
1547 _, lineno = inspect.getsourcelines(func)
1548 source = inspect.getsource(func)
1549 return f"{file}:{lineno}\n\n{source}"
1550 else:
1551 return "NOT A CALCULATED NODE"
1553 def print_source(self, name: Name) -> None:
1554 """Print the source code for a computation node."""
1555 print(self.get_source(name))
    def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None) -> None:
        """Restrict a computation to the ancestors of a set of output nodes.

        Excludes ancestors of a set of input nodes.

        If the set of input_names that is specified is not sufficient for the set of output_names then additional
        nodes that are ancestors of the output_names will be included, but the input nodes specified will be input
        nodes of the modified Computation.

        :param output_names: Name or names of the outputs to keep (with their ancestors)
        :param input_names: Name or names of nodes to convert into plain inputs
        :return: None - modifies existing computation in place
        """
        if input_names is not None:
            for n in as_iterable(input_names):
                # Snapshot state/value, then re-add the node as a plain input
                # (dropping its function) and restore the snapshot.
                nodedata = self._get_item_one(n)
                self.add_node(n)
                self._set_state_and_literal_value(to_nodekey(n), nodedata.state, nodedata.value)
        output_node_keys = names_to_node_keys(output_names)
        ancestor_node_keys = self._get_ancestors_node_keys(output_node_keys)
        self.dag.remove_nodes_from([n for n in self.dag if n not in ancestor_node_keys])
1579 def __getstate__(self) -> dict[str, Any]:
1580 """Prepare computation for serialization by removing non-serializable nodes."""
1581 node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
1582 obj = self.copy()
1583 for name, tags in node_serialize.items():
1584 if SystemTags.SERIALIZE not in tags:
1585 obj._set_uninitialized(name)
1586 return {"dag": obj.dag}
    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore computation from serialized state.

        Re-runs ``__init__`` to establish defaults, installs the deserialized
        DAG, then rebuilds the derived tag/state lookup maps from it.
        """
        self.__init__()
        self.dag = state["dag"]
        self._refresh_maps()
    def write_dill_old(self, file_: str | BinaryIO) -> None:
        """Serialize a computation to a file or file-like object (legacy format).

        Temporarily removes the class-level ``__getstate__``/``__setstate__``
        so dill pickles the object the pre-protocol way, then restores them.

        :param file_: If string, writes to a file
        :type file_: File-like object, or string
        """
        warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
        original_getstate = self.__class__.__getstate__
        original_setstate = self.__class__.__setstate__

        try:
            del self.__class__.__getstate__
            del self.__class__.__setstate__

            # Work on a copy: executors are not picklable, and nodes not
            # tagged for serialization are blanked before dumping.
            node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
            obj = self.copy()
            obj.executor_map = None  # type: ignore[assignment]
            obj.default_executor = None  # type: ignore[assignment]
            for name, tags in node_serialize.items():
                if SystemTags.SERIALIZE not in tags:
                    obj._set_uninitialized(name)

            if isinstance(file_, str):
                with open(file_, "wb") as f:
                    dill.dump(obj, f)
            else:
                dill.dump(obj, file_)
        finally:
            # Always reinstate the state protocol, even if dumping failed.
            self.__class__.__getstate__ = original_getstate  # type: ignore[method-assign]
            self.__class__.__setstate__ = original_setstate
1625 def write_dill(self, file_: str | BinaryIO) -> None:
1626 """Serialize a computation to a file or file-like object.
1628 :param file_: If string, writes to a file
1629 :type file_: File-like object, or string
1630 """
1631 if isinstance(file_, str):
1632 with open(file_, "wb") as f:
1633 dill.dump(self, f)
1634 else:
1635 dill.dump(self, file_)
1637 @staticmethod
1638 def read_dill(file_: str | BinaryIO) -> "Computation":
1639 """Deserialize a computation from a file or file-like object.
1641 .. warning::
1642 This method uses dill.load() which can execute arbitrary code.
1643 Only load files from trusted sources. Never load data from
1644 untrusted or unauthenticated sources as it may lead to arbitrary
1645 code execution.
1647 :param file_: If string, writes to a file
1648 :type file_: File-like object, or string
1649 """
1650 if isinstance(file_, str):
1651 with open(file_, "rb") as f:
1652 obj = dill.load(f) # noqa: S301
1653 else:
1654 obj = dill.load(file_) # noqa: S301
1655 if isinstance(obj, Computation):
1656 return obj
1657 else:
1658 msg = "Loaded object is not a Computation"
1659 raise ValidationError(msg)
1661 def copy(self) -> "Computation":
1662 """Create a copy of a computation.
1664 The copy is shallow. Any values in the new Computation's DAG will be the same object as this Computation's
1665 DAG. As new objects will be created by any further computations, this should not be an issue.
1667 :rtype: Computation
1668 """
1669 obj = Computation()
1670 obj.dag = nx.DiGraph(self.dag)
1671 obj._tag_map = defaultdict(set, {tag: nodes.copy() for tag, nodes in self._tag_map.items()})
1672 obj._state_map = {state: nodes.copy() for state, nodes in self._state_map.items()}
1673 return obj
    def add_named_tuple_expansion(self, name: Name, namedtuple_type: type, group: str | None = None) -> None:
        """Automatically add nodes to extract each element of a named tuple type.

        It is often convenient for a calculation to return multiple values, and it is polite to do this a namedtuple
        rather than a regular tuple, so that later users have same name to identify elements of the tuple. It can
        also help make a computation clearer if a downstream computation depends on one element of such a tuple,
        rather than the entire tuple. This does not affect the computation per se, but it does make the intention
        clearer.

        To avoid having to create many boiler-plate node definitions to expand namedtuples, the
        ``add_named_tuple_expansion`` method automatically creates new nodes for each element of a tuple. The
        convention is that an element called 'element', in a node called 'node' will be expanded into a new node
        called 'node.element', and that this will be applied for each element.

        Example::

            >>> from collections import namedtuple
            >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
            >>> comp = Computation()
            >>> comp.add_node('c', value=Coordinate(1, 2))
            >>> comp.add_named_tuple_expansion('c', Coordinate)
            >>> comp.compute_all()
            >>> comp.value('c.x')
            1
            >>> comp.value('c.y')
            2

        :param name: Node containing the namedtuple value to expand
        :param namedtuple_type: Expected type of the node
        :type namedtuple_type: namedtuple class
        :param group: Optional group assigned to the generated expansion nodes
        """

        def make_f(field_name: str) -> Callable[[Any], Any]:
            """Create an extractor closure bound to *field_name* (avoids late-binding in the loop)."""

            def get_field_value(tuple_val: Any) -> Any:
                """Extract field value from the namedtuple."""
                return getattr(tuple_val, field_name)

            return get_field_value

        for field_name in namedtuple_type._fields:  # type: ignore[attr-defined]
            node_name = f"{name}.{field_name}"
            self.add_node(node_name, make_f(field_name), kwds={"tuple_val": name}, group=group)
            self.set_tag(node_name, SystemTags.EXPANSION)
1721 def add_map_node(
1722 self,
1723 result_node: Name,
1724 input_node: Name,
1725 subgraph: "Computation",
1726 subgraph_input_node: Name,
1727 subgraph_output_node: Name,
1728 ) -> None:
1729 """Apply a graph to each element of iterable.
1731 In turn, each element in the ``input_node`` of this graph will be inserted in turn into the subgraph's
1732 ``subgraph_input_node``, then the subgraph's ``subgraph_output_node`` calculated. The resultant list, with
1733 an element or each element in ``input_node``, will be inserted into ``result_node`` of this graph. In this
1734 way ``add_map_node`` is similar to ``map`` in functional programming.
1736 :param result_node: The node to place a list of results in **this** graph
1737 :param input_node: The node to get a list input values from **this** graph
1738 :param subgraph: The graph to use to perform calculation for each element
1739 :param subgraph_input_node: The node in **subgraph** to insert each element in turn
1740 :param subgraph_output_node: The node in **subgraph** to read the result for each element
1741 """
1743 def f(xs: Iterable[Any]) -> list[Any]:
1744 """Apply subgraph computation to each element in the input."""
1745 results: list[Any] = []
1746 is_error = False
1747 for x in xs:
1748 subgraph.insert(subgraph_input_node, x)
1749 subgraph.compute(subgraph_output_node)
1750 if subgraph.state(subgraph_output_node) == States.UPTODATE:
1751 results.append(subgraph.value(subgraph_output_node))
1752 else:
1753 is_error = True
1754 results.append(subgraph.copy())
1755 if is_error:
1756 msg = f"Unable to calculate {result_node}"
1757 raise MapException(msg, results)
1758 return results
1760 self.add_node(result_node, f, kwds={"xs": input_node})
1762 def prepend_path(self, path: Name | ConstantValue, prefix_path: NodeKey) -> NodeKey | ConstantValue:
1763 """Prepend a prefix path to a node path."""
1764 if isinstance(path, ConstantValue):
1765 return path
1766 nk = to_nodekey(path)
1767 return prefix_path.join(nk)
1769 def add_block(
1770 self,
1771 base_path: Name,
1772 block: "Computation",
1773 *,
1774 keep_values: bool | None = True,
1775 links: dict[str, Name] | None = None,
1776 metadata: dict[str, Any] | None = None,
1777 ) -> None:
1778 """Add a computation block as a subgraph to this computation."""
1779 base_path_nk = to_nodekey(base_path)
1780 for node_name in block.nodes():
1781 node_key = to_nodekey(node_name)
1782 node_data = block.dag.nodes[node_key]
1783 tags = node_data.get(NodeAttributes.TAG, None)
1784 style = node_data.get(NodeAttributes.STYLE, None)
1785 group = node_data.get(NodeAttributes.GROUP, None)
1786 args_def, kwds_def = block.get_definition_args_kwds(node_key)
1787 args_prepended = [self.prepend_path(arg, base_path_nk) for arg in args_def]
1788 kwds_prepended = {k: self.prepend_path(v, base_path_nk) for k, v in kwds_def.items()}
1789 func = node_data.get(NodeAttributes.FUNC, None)
1790 executor = node_data.get(NodeAttributes.EXECUTOR, None)
1791 converter = node_data.get(NodeAttributes.CONVERTER, None)
1792 new_node_name = self.prepend_path(node_name, base_path_nk)
1793 self.add_node(
1794 new_node_name,
1795 func,
1796 args=args_prepended,
1797 kwds=kwds_prepended,
1798 converter=converter,
1799 serialize=False,
1800 inspect=False,
1801 group=group,
1802 tags=tags,
1803 style=style,
1804 executor=executor,
1805 )
1806 if keep_values and NodeAttributes.VALUE in node_data:
1807 new_node_key = to_nodekey(new_node_name)
1808 self._set_state_and_literal_value(
1809 new_node_key, node_data[NodeAttributes.STATE], node_data[NodeAttributes.VALUE]
1810 )
1811 if links is not None:
1812 for target, source in links.items():
1813 self.link(base_path_nk.join_parts(target), source)
1814 if metadata is not None:
1815 self._metadata[base_path_nk] = metadata
1816 else:
1817 if base_path_nk in self._metadata:
1818 del self._metadata[base_path_nk]
1820 def link(self, target: Name, source: Name) -> None:
1821 """Create a link between two nodes in the computation graph."""
1822 target_nk = to_nodekey(target)
1823 source_nk = to_nodekey(source)
1824 if target_nk == source_nk:
1825 return
1827 target_style = self._style_one(target_nk) if self.has_node(target_nk) else None
1828 source_style = self._style_one(source_nk) if self.has_node(source_nk) else None
1829 style = target_style if target_style else source_style
1831 self.add_node(target_nk, identity_function, kwds={"x": source_nk}, style=style)
1833 def _repr_svg_(self) -> str | None:
1834 """Return SVG representation for Jupyter notebook display."""
1835 return GraphView(self).svg()
1837 def draw(
1838 self,
1839 root: NodeKey | None = None,
1840 *,
1841 node_transformations: dict[Name, str] | None = None,
1842 cmap: Any = None,
1843 colors: str = "state",
1844 shapes: str | None = None,
1845 graph_attr: dict[str, Any] | None = None,
1846 node_attr: dict[str, Any] | None = None,
1847 edge_attr: dict[str, Any] | None = None,
1848 show_expansion: bool = False,
1849 collapse_all: bool = True,
1850 ) -> GraphView:
1851 """Draw a computation's current state using the GraphViz utility.
1853 :param root: Optional PathType. Sub-block to draw
1854 :param cmap: Default: None
1855 :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
1856 :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
1857 :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
1858 ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
1859 :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
1860 :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
1861 :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
1862 """
1863 node_formatter = NodeFormatter.create(cmap, colors, shapes)
1864 node_transformations_copy: dict[Name, str] = (
1865 node_transformations.copy() if node_transformations is not None else {}
1866 )
1867 if not show_expansion:
1868 for nodekey in self.nodes_by_tag(SystemTags.EXPANSION):
1869 node_transformations_copy[nodekey] = NodeTransformations.CONTRACT
1870 v = GraphView(
1871 self,
1872 root=root,
1873 node_formatter=node_formatter,
1874 graph_attr=graph_attr,
1875 node_attr=node_attr,
1876 edge_attr=edge_attr,
1877 node_transformations=node_transformations_copy,
1878 collapse_all=collapse_all,
1879 )
1880 return v
1882 def view(self, cmap: Any = None, colors: str = "state", shapes: str | None = None) -> None:
1883 """Create and display a visualization of the computation graph."""
1884 node_formatter = NodeFormatter.create(cmap, colors, shapes)
1885 v = GraphView(self, node_formatter=node_formatter)
1886 v.view()
1888 def print_errors(self) -> None:
1889 """Print tracebacks for every node with state "ERROR" in a Computation."""
1890 for n in self.nodes():
1891 if self.s[n] == States.ERROR:
1892 print(f"{n}")
1893 print("=" * len(str(n)))
1894 print()
1895 print(self.v[n].traceback)
1896 print()
1898 @classmethod
1899 def from_class(cls, definition_class: type, ignore_self: bool = True) -> "Computation":
1900 """Create a computation from a class with decorated methods."""
1901 comp = cls()
1902 obj = definition_class()
1903 populate_computation_from_class(comp, definition_class, obj, ignore_self=ignore_self)
1904 return comp
1906 def inject_dependencies(self, dependencies: dict[Name, Any], *, force: bool = False) -> None:
1907 """Injects dependencies into the nodes of the current computation where nodes are in a placeholder state.
1909 (or all possible nodes when the 'force' parameter is set to True), using values
1910 provided in the 'dependencies' dictionary.
1912 Each key in the 'dependencies' dictionary corresponds to a node identifier, and the associated
1913 value is the dependency object to inject. If the value is a callable, it will be added as a calc node.
1915 :param dependencies: A dictionary where each key-value pair consists of a node identifier and
1916 its corresponding dependency object or a callable that returns the dependency object.
1917 :param force: A boolean flag that, when set to True, forces the replacement of existing node values
1918 with the ones provided in 'dependencies', regardless of their current state. Defaults to False.
1919 :return: None
1920 """
1921 for n in self.nodes():
1922 if force or self.s[n] == States.PLACEHOLDER:
1923 obj = dependencies.get(n)
1924 if obj is None:
1925 continue
1926 if callable(obj):
1927 self.add_node(n, obj)
1928 else:
1929 self.add_node(n, value=obj)