Coverage for src / loman / computeengine.py: 99%
1042 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 21:24 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 21:24 +0000
1"""Core computation engine for dependency-aware calculation graphs."""
3import functools
4import inspect
5import logging
6import traceback
7import types
8import warnings
9from collections import defaultdict
10from collections.abc import Callable, Iterable, Mapping
11from concurrent.futures import FIRST_COMPLETED, Executor, ThreadPoolExecutor, wait
12from dataclasses import dataclass, field
13from datetime import UTC, datetime
14from enum import Enum
15from typing import TYPE_CHECKING, Any, BinaryIO, TextIO, TypeVar, overload
17if TYPE_CHECKING:
18 from .serialization.computation import ComputationSerializer
20import decorator
21import dill # nosec B403
22import networkx as nx
23import pandas as pd
25from .compat import get_signature
26from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags
27from .exception import (
28 CannotInsertToPlaceholderNodeException,
29 ComputationError,
30 LoopDetectedException,
31 MapException,
32 NodeAlreadyExistsException,
33 NonExistentNodeException,
34 ValidationError,
35)
36from .graph_utils import topological_sort
37from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey
38from .util import AttributeView, apply1, apply_n, as_iterable, value_eq
39from .visualization import GraphView, NodeFormatter
# Module-level logger used by this module's Computation methods; its name is fixed to "loman.computeengine".
LOG = logging.getLogger("loman.computeengine")

# Type variable bound to callables; lets decorators in this module be typed as returning
# the same callable type they receive.
F = TypeVar("F", bound=Callable[..., Any])
@dataclass
class Error:
    """Record of a failure: the exception that was raised plus its traceback as text."""

    # The exception instance that was caught.
    exception: Exception
    # The traceback, already rendered to a string.
    traceback: str
@dataclass
class NodeData:
    """Pairs a node's state with the value it currently holds."""

    # One of the ``States`` enum members.
    state: States
    # The node's value; any object is allowed.
    value: object
@dataclass
class TimingData:
    """Start time, end time and duration recorded for one execution."""

    # When execution began.
    start: datetime
    # When execution finished.
    end: datetime
    # Elapsed time as a float (units are not fixed by this class).
    duration: float
class _ParameterType(Enum):
    """Private enum telling positional parameters apart from keyword parameters."""

    # Parameter supplied positionally.
    ARG = 1
    # Parameter supplied by keyword.
    KWD = 2
@dataclass
class _ParameterItem:
    """Private record describing one parameter: its kind, its name and its value."""

    # Whether the parameter is positional or keyword.
    type: _ParameterType
    # An int or a str, as the annotation allows; presumably the position for ARG
    # and the keyword name for KWD -- confirm against the code that builds these.
    name: int | str
    # The value carried for the parameter.
    value: object
def _node(func: Callable[..., Any], *args: Any, **kws: Any) -> Any:  # pragma: no cover
    """Call ``func`` with the given arguments; used as the caller passed to ``decorator.decorate``."""
    outcome = func(*args, **kws)
    return outcome
def node(comp: "Computation", name: Name | None = None, *args: Any, **kw: Any) -> Callable[[F], F]:
    """Build a decorator that registers the decorated function as a node of ``comp``.

    :param comp: Computation that receives the node.
    :param name: Node name; when ``None`` the function's ``__name__`` is used.
    :param args: Extra positional arguments forwarded to ``comp.add_node``.
    :param kw: Extra keyword arguments forwarded to ``comp.add_node``.
    :return: A decorator returning the function wrapped via ``decorator.decorate``.
    """

    def inner(f: F) -> F:
        """Register ``f`` on the computation, then return its decorated form."""
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        decorated: F = decorator.decorate(f, _node)
        return decorated

    return inner
@dataclass
class ConstantValue:
    """Wrapper marking a value as a literal constant rather than a reference to a node."""

    # The wrapped constant.
    value: object


# Short alias for ConstantValue.
C = ConstantValue
class Node:
    """Common base for objects that know how to add themselves to a computation."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Register this node on ``comp`` under ``name``; subclasses must override.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
@dataclass
class InputNode(Node):
    """Node definition for an input: a node added without a function."""

    args: tuple[Any, ...] = field(default_factory=tuple)
    kwds: dict[str, Any] = field(default_factory=dict)

    def __init__(self, *args: Any, **kwds: Any) -> None:
        """Store the positional and keyword arguments exactly as given."""
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add a node called ``name`` to ``comp``, forwarding only the stored keyword arguments.

        The stored positional arguments, ``obj`` and ``ignore_self`` are not used here.
        """
        options = self.kwds
        comp.add_node(name, **options)


# Lower-case alias for InputNode.
input_node = InputNode
@dataclass
class CalcNode(Node):
    """Node definition for a calculation: a function plus the options passed to ``add_node``."""

    f: Callable[..., Any]
    kwds: dict[str, Any] = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add the calculation to ``comp`` under ``name``.

        When self-binding is requested (through the ``ignore_self`` argument or an
        ``ignore_self`` entry in the stored options) and the function's first
        keyword parameter is called ``self``, the function is bound to ``obj``
        first. The ``ignore_self`` option is never forwarded to ``add_node``.
        """
        options = self.kwds.copy()
        # Popping both reads the flag and removes it from what is forwarded.
        option_flag = options.pop("ignore_self", False)
        bind_self = ignore_self or option_flag

        target = self.f
        if bind_self:
            params = get_signature(self.f).kwd_params
            if len(params) > 0 and params[0] == "self":
                target = target.__get__(obj, obj.__class__)  # type: ignore[attr-defined]

        comp.add_node(name, target, **options)
@overload
def calc_node(f: F, **kwds: Any) -> F: ...


@overload
def calc_node(f: None = None, **kwds: Any) -> Callable[[F], F]: ...


def calc_node(f: F | None = None, **kwds: Any) -> F | Callable[[F], F]:
    """Mark a function as a calculation node by attaching a ``CalcNode`` to it.

    Works both bare (``@calc_node``) and with options (``@calc_node(...)``).
    The function itself is returned unchanged apart from the new
    ``_loman_node_info`` attribute.
    """

    def wrap(func: F) -> F:
        """Attach the node description to ``func`` and hand it back."""
        func._loman_node_info = CalcNode(func, kwds)
        return func

    return wrap if f is None else wrap(f)
@dataclass
class Block(Node):
    """Node definition for a sub-computation, given directly or as a zero-argument factory."""

    block: "Callable[[], Computation] | Computation"
    args: tuple[Any, ...] = field(default_factory=tuple)
    kwds: dict[str, Any] = field(default_factory=dict)

    def __init__(self, block: "Callable[[], Computation] | Computation", *args: Any, **kwds: Any) -> None:
        """Store the block together with the extra arguments for ``add_block``."""
        self.block = block
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool) -> None:
        """Add the block to ``comp`` under ``name``.

        A ``Computation`` is used as-is; any other callable is called with no
        arguments to produce the block.

        :raises TypeError: If the stored block is neither a ``Computation`` nor callable.
        """
        resolved = self.block
        if not isinstance(resolved, Computation):
            if not callable(resolved):
                msg = f"Block {self.block} must be callable or Computation"
                raise TypeError(msg)
            resolved = resolved()
        comp.add_block(name, resolved, *self.args, **self.kwds)


# Lower-case alias for Block.
block = Block
def populate_computation_from_class(comp: "Computation", cls: type, obj: object, ignore_self: bool = True) -> None:
    """Add every node definition found among the members of ``cls`` to ``comp``.

    A member counts as a node definition when it is a ``Node`` instance or when it
    carries a non-``None`` ``_loman_node_info`` attribute. Members are visited in
    the order produced by ``inspect.getmembers``.
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            definition: Node | None = member
        else:
            definition = getattr(member, "_loman_node_info", None)
        if definition is None:
            continue
        definition.add_to_comp(comp, member_name, obj, ignore_self)
def computation_factory(
    maybe_cls: type | None = None, *, ignore_self: bool = True
) -> Callable[..., "Computation"] | Callable[[type], Callable[..., "Computation"]]:
    """Turn a class of node definitions into a function that builds a ``Computation``.

    Usable bare (``@computation_factory``) or with options
    (``@computation_factory(ignore_self=...)``).
    """

    def wrap(cls: type) -> Callable[..., "Computation"]:
        """Produce the factory function for ``cls``."""

        @functools.wraps(cls, updated=())
        def create_computation(*args: Any, **kwargs: Any) -> "Computation":
            """Instantiate ``cls``, build a Computation and populate it from the class."""
            definition = cls()
            built = Computation(*args, **kwargs)
            # Keep the definition instance reachable from the computation.
            built._definition_object = definition  # type: ignore[attr-defined]
            populate_computation_from_class(built, cls, definition, ignore_self)
            return built

        return create_computation

    return wrap if maybe_cls is None else wrap(maybe_cls)
def _eval_node(
    name: NodeKey,
    f: Callable[..., Any],
    args: list[Any],
    kwds: dict[str, Any],
    raise_exceptions: bool,
) -> tuple[Any, Exception | None, str | None, datetime, datetime]:
    """Run ``f(*args, **kwds)`` for node ``name`` and report the outcome.

    To make multiprocessing work, this function must be standalone so that pickle works.

    :param name: Key of the node being evaluated (used for logging only).
    :param f: Function to call.
    :param args: Positional arguments for ``f``.
    :param kwds: Keyword arguments for ``f``.
    :param raise_exceptions: If true, an exception raised by ``f`` is re-raised
        instead of being returned.
    :return: ``(value, exception, traceback_text, start, end)``. On failure
        ``value`` is ``None`` and the exception and traceback are populated; on
        success both are ``None``. ``start`` and ``end`` are UTC timestamps.
    """
    exc: Exception | None = None
    tb: str | None = None
    start_dt = datetime.now(UTC)
    try:
        # Use the module logger with lazy formatting: the root-level
        # ``logging.debug`` would implicitly configure the root logger.
        LOG.debug("Running %s", name)
        value = f(*args, **kwds)
        LOG.debug("Completed %s", name)
    except Exception as e:
        value = None
        exc = e
        tb = traceback.format_exc()
        if raise_exceptions:
            raise
    end_dt = datetime.now(UTC)
    return value, exc, tb, start_dt, end_dt
# Unique default for ``Computation.add_node``'s ``value`` parameter, so that an
# explicitly supplied ``None`` can be told apart from "no value supplied".
_MISSING_VALUE_SENTINEL = object()
class NullObject:
    """Debugging aid: every attribute, item or call access is printed and then rejected."""

    def __getattr__(self, name: str) -> Any:
        """Print the attribute name, then raise AttributeError."""
        print(f"__getattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Print the attribute name, then raise AttributeError."""
        print(f"__setattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __delattr__(self, name: str) -> None:
        """Print the attribute name, then raise AttributeError."""
        print(f"__delattr__: {name}")
        raise AttributeError(f"'NullObject' object has no attribute '{name}'")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Print the call arguments, then raise TypeError."""
        print(f"__call__: {args}, {kwargs}")
        raise TypeError("'NullObject' object is not callable")

    def __getitem__(self, key: Any) -> Any:
        """Print the key, then raise KeyError."""
        print(f"__getitem__: {key}")
        raise KeyError(f"'NullObject' object has no item with key '{key}'")

    def __setitem__(self, key: Any, value: Any) -> None:
        """Print the key, then raise KeyError."""
        print(f"__setitem__: {key}")
        raise KeyError(f"'NullObject' object cannot have items set with key '{key}'")

    def __repr__(self) -> str:
        """Print the instance ``__dict__`` and return a fixed marker string."""
        # object.__getattribute__ bypasses this class's own attribute hooks.
        print(f"__repr__: {object.__getattribute__(self, '__dict__')}")
        return "<NullObject>"
def identity_function(x: Any) -> Any:
    """Hand back exactly the object that was passed in."""
    return x
331class Computation:
332 """A computation graph that manages dependencies and calculations.
334 The Computation class provides a framework for building and executing
335 computation graphs where nodes represent data or calculations, and edges
336 represent dependencies between them.
337 """
339 def __init__(
340 self,
341 *,
342 default_executor: Executor | None = None,
343 executor_map: dict[str, Executor] | None = None,
344 metadata: dict[str, Any] | None = None,
345 ) -> None:
346 """Initialize a new Computation.
348 :param default_executor: An executor
349 :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1)
350 """
351 if default_executor is None:
352 self.default_executor: Executor = ThreadPoolExecutor(1)
353 else:
354 self.default_executor = default_executor
355 if executor_map is None:
356 self.executor_map: dict[str, Executor] = {}
357 else:
358 self.executor_map = executor_map
359 self.dag: nx.DiGraph = nx.DiGraph()
360 self._metadata: dict[NodeKey, Any] = {}
361 if metadata is not None:
362 self._metadata[NodeKey.root()] = metadata
364 self.v = self.get_attribute_view_for_path(NodeKey.root(), self._value_one, self.value)
365 self.s = self.get_attribute_view_for_path(NodeKey.root(), self._state_one, self.state)
366 self.i = self.get_attribute_view_for_path(NodeKey.root(), self._get_inputs_one_names, self.get_inputs)
367 self.o = self.get_attribute_view_for_path(NodeKey.root(), self._get_outputs_one, self.get_outputs)
368 self.t = self.get_attribute_view_for_path(NodeKey.root(), self._tag_one, self.tags)
369 self.style = self.get_attribute_view_for_path(NodeKey.root(), self._style_one, self.styles)
370 self.tim = self.get_attribute_view_for_path(NodeKey.root(), self._get_timing_one, self.get_timing)
371 self.x = self.get_attribute_view_for_path(
372 NodeKey.root(), self.compute_and_get_value, self.compute_and_get_value
373 )
374 self.src = self.get_attribute_view_for_path(NodeKey.root(), self.print_source, self.print_source)
375 self._tag_map: defaultdict[str, set[NodeKey]] = defaultdict(set)
376 self._state_map: dict[States, set[NodeKey]] = {state: set() for state in States}
378 def get_attribute_view_for_path(
379 self, nodekey: NodeKey, get_one_func: Callable[[Name], Any], get_many_func: Callable[[Name | Names], Any]
380 ) -> AttributeView:
381 """Create an attribute view for a specific node path."""
383 def node_func() -> Iterable[str]:
384 """Return list of child node names for this path."""
385 return [str(n) for n in self.get_tree_list_children(nodekey)]
387 def get_one_func_for_path(name: str) -> Any:
388 """Get value for a single node at this path."""
389 nk = to_nodekey(name)
390 new_nk = nk.prepend(nodekey)
391 if self.has_node(new_nk):
392 return get_one_func(new_nk)
393 elif self.tree_has_path(new_nk):
394 return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func)
395 else:
396 msg = f"Path {new_nk} does not exist"
397 raise KeyError(msg) # pragma: no cover
399 def get_many_func_for_path(name: Name | Names) -> Any:
400 """Get values for one or more nodes at this path."""
401 if isinstance(name, list):
402 return [get_one_func_for_path(str(n)) for n in name]
403 else:
404 return get_one_func_for_path(str(name))
406 return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)
408 def _get_names_for_state(self, state: States) -> set[Name]:
409 """Get node names that have a specific state."""
410 return set(node_keys_to_names(self._state_map[state]))
412 def _get_tags_for_state(self, tag: str) -> set[Name]:
413 """Get node names that have a specific tag."""
414 return set(node_keys_to_names(self._tag_map[tag]))
416 def _process_function_args(self, node_key: NodeKey, node: dict[str, Any], args: list[Any] | None) -> int:
417 """Process positional arguments for a function node."""
418 args_count = 0
419 if args:
420 args_count = len(args)
421 for i, arg in enumerate(args):
422 if isinstance(arg, ConstantValue):
423 node[NodeAttributes.ARGS][i] = arg.value
424 else:
425 input_vertex_name = arg
426 input_vertex_node_key = to_nodekey(input_vertex_name)
427 if not self.dag.has_node(input_vertex_node_key):
428 self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
429 self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
430 self.dag.add_edge(
431 input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
432 )
433 return args_count
435 def _build_param_map(
436 self,
437 func: Callable[..., Any],
438 node_key: NodeKey,
439 args_count: int,
440 kwds: dict[str, Any] | None,
441 inspect: bool,
442 ) -> tuple[dict[str, Any], list[str]]:
443 """Build parameter map for function node."""
444 param_map: dict[str, Any] = {}
445 default_names: list[str] = []
447 if inspect:
448 signature = get_signature(func)
449 if not signature.has_var_args:
450 for param_name in signature.kwd_params[args_count:]:
451 if kwds is not None and param_name in kwds:
452 param_source = kwds[param_name]
453 else:
454 param_source = node_key.parent.join_parts(param_name)
455 param_map[param_name] = param_source
456 if signature.has_var_kwds and kwds is not None:
457 for param_name, param_source in kwds.items():
458 param_map[param_name] = param_source
459 default_names = signature.default_params
460 else:
461 if kwds is not None:
462 for param_name, param_source in kwds.items():
463 param_map[param_name] = param_source
465 return param_map, default_names
467 def _process_function_kwds(
468 self, node_key: NodeKey, node: dict[str, Any], param_map: dict[str, Any], default_names: list[str]
469 ) -> None:
470 """Process keyword arguments for a function node."""
471 for param_name, param_source in param_map.items():
472 if isinstance(param_source, ConstantValue):
473 node[NodeAttributes.KWDS][param_name] = param_source.value
474 else:
475 in_node_name = param_source
476 in_node_key = to_nodekey(in_node_name)
477 if not self.dag.has_node(in_node_key):
478 if param_name in default_names:
479 continue
480 else:
481 self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
482 self._state_map[States.PLACEHOLDER].add(in_node_key)
483 self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})
485 def add_node(
486 self,
487 name: Name,
488 func: Callable[..., Any] | None = None,
489 *,
490 args: list[Any] | None = None,
491 kwds: dict[str, Any] | None = None,
492 value: Any = _MISSING_VALUE_SENTINEL,
493 converter: Callable[[Any], Any] | None = None,
494 serialize: bool = True,
495 inspect: bool = True,
496 group: str | None = None,
497 tags: Iterable[str] | None = None,
498 style: str | None = None,
499 executor: str | None = None,
500 metadata: dict[str, Any] | None = None,
501 ) -> None:
502 """Adds or updates a node in a computation.
504 :param name: Name of the node to add. This may be any hashable object.
505 :param func: Function to use to calculate the node if the node is a calculation node. By default, the input
506 nodes to the function will be implied from the names of the function parameters. For example, a
507 parameter called ``a`` would be taken from the node called ``a``. This can be modified with the
508 ``kwds`` parameter.
509 :type func: Function, default None
510 :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally
511 for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with
512 three parameters, taken from the nodes 'a', 'b' and 'c' respectively.
513 :type args: List, default None
514 :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
515 parameter when calling the function for a calculation node. e.g. If args is ``{'x': 'a', 'y': 'b'}``
516 then the function would be called with parameters named 'x' and 'y', and their values would be taken
517 from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter
518 [key] from node [value]".
519 :type kwds: Dictionary, default None
520 :param value: If given, the value is inserted into the node, and the node state set to UPTODATE.
521 :type value: default None
522 :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
523 case, set serialize to False
524 :type serialize: boolean, default True
525 :param inspect: Whether to use introspection to determine the arguments of the function, which can be
526 slow. If this is not set, kwds and args must be set for the function to obtain parameters.
527 :type inspect: boolean, default True
528 :param group: Subgraph to render node in
529 :type group: default None
530 :param tags: Set of tags to apply to node
531 :type tags: Iterable
532 :param styles: Style to apply to node
533 :type styles: String, default None
534 :param executor: Name of executor to run node on
535 :type executor: string
536 """
537 node_key = to_nodekey(name)
538 LOG.debug(f"Adding node {node_key}")
539 has_value = value is not _MISSING_VALUE_SENTINEL
540 if value is _MISSING_VALUE_SENTINEL:
541 value = None
542 if tags is None:
543 tags = []
545 self.dag.add_node(node_key)
546 pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
547 self.dag.remove_edges_from(pred_edges)
548 node = self.dag.nodes[node_key]
550 if metadata is None:
551 if node_key in self._metadata:
552 del self._metadata[node_key]
553 else:
554 self._metadata[node_key] = metadata
556 self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)
558 node[NodeAttributes.TAG] = set()
559 node[NodeAttributes.STYLE] = style
560 node[NodeAttributes.GROUP] = group
561 node[NodeAttributes.ARGS] = {}
562 node[NodeAttributes.KWDS] = {}
563 node[NodeAttributes.FUNC] = None
564 node[NodeAttributes.EXECUTOR] = executor
565 node[NodeAttributes.CONVERTER] = converter
567 if func:
568 node[NodeAttributes.FUNC] = func
569 args_count = self._process_function_args(node_key, node, args)
570 param_map, default_names = self._build_param_map(func, node_key, args_count, kwds, inspect)
571 self._process_function_kwds(node_key, node, param_map, default_names)
572 self._set_descendents(node_key, States.STALE)
574 if has_value:
575 self._set_uptodate(node_key, value)
576 if node[NodeAttributes.STATE] == States.UNINITIALIZED:
577 self._try_set_computable(node_key)
578 self.set_tag(node_key, tags)
579 if serialize:
580 self.set_tag(node_key, SystemTags.SERIALIZE)
582 def _refresh_maps(self) -> None:
583 """Refresh internal tag and state maps from node data."""
584 self._tag_map.clear()
585 for state in States:
586 self._state_map[state].clear()
587 for node_key in self._node_keys():
588 state = self.dag.nodes[node_key][NodeAttributes.STATE]
589 self._state_map[state].add(node_key)
590 tags = self.dag.nodes[node_key].get(NodeAttributes.TAG, set())
591 for tag in tags:
592 self._tag_map[tag].add(node_key)
594 def _set_tag_one(self, name: Name, tag: str) -> None:
595 """Set a single tag on a single node."""
596 node_key = to_nodekey(name)
597 self.dag.nodes[node_key][NodeAttributes.TAG].add(tag)
598 self._tag_map[tag].add(node_key)
600 def set_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None:
601 """Set tags on a node or nodes. Ignored if tags are already set.
603 :param name: Node or nodes to set tag for
604 :param tag: Tag to set
605 """
606 apply_n(self._set_tag_one, name, tag)
608 def _clear_tag_one(self, name: Name, tag: str) -> None:
609 """Clear a single tag from a single node."""
610 node_key = to_nodekey(name)
611 self.dag.nodes[node_key][NodeAttributes.TAG].discard(tag)
612 self._tag_map[tag].discard(node_key)
614 def clear_tag(self, name: Name | Names, tag: str | Iterable[str]) -> None:
615 """Clear tag on a node or nodes. Ignored if tags are not set.
617 :param name: Node or nodes to clear tags for
618 :param tag: Tag to clear
619 """
620 apply_n(self._clear_tag_one, name, tag)
622 def _set_style_one(self, name: Name, style: str) -> None:
623 """Set style on a single node."""
624 node_key = to_nodekey(name)
625 self.dag.nodes[node_key][NodeAttributes.STYLE] = style
627 def set_style(self, name: Name | Names, style: str) -> None:
628 """Set styles on a node or nodes.
630 :param name: Node or nodes to set style for
631 :param style: Style to set
632 """
633 apply_n(self._set_style_one, name, style)
635 def _clear_style_one(self, name: Name) -> None:
636 """Clear style from a single node."""
637 node_key = to_nodekey(name)
638 self.dag.nodes[node_key][NodeAttributes.STYLE] = None
640 def clear_style(self, name: Name | Names) -> None:
641 """Clear style on a node or nodes.
643 :param name: Node or nodes to clear styles for
644 """
645 apply_n(self._clear_style_one, name)
647 def metadata(self, name: Name) -> dict[str, Any]:
648 """Get metadata for a node."""
649 node_key = to_nodekey(name)
650 if self.tree_has_path(name):
651 if node_key not in self._metadata:
652 self._metadata[node_key] = {}
653 result: dict[str, Any] = self._metadata[node_key]
654 return result
655 else:
656 msg = f"Node {node_key} does not exist."
657 raise NonExistentNodeException(msg)
659 def delete_node(self, name: Name) -> None:
660 """Delete a node from a computation.
662 When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
663 will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
664 then the PLACEHOLDER node will also be deleted.
666 :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
667 be raised.
668 """
669 node_key = to_nodekey(name)
670 LOG.debug(f"Deleting node {node_key}")
672 if not self.dag.has_node(node_key):
673 msg = f"Node {node_key} does not exist"
674 raise NonExistentNodeException(msg)
676 if node_key in self._metadata:
677 del self._metadata[node_key]
679 if len(self.dag.succ[node_key]) == 0:
680 preds = self.dag.predecessors(node_key)
681 state = self.dag.nodes[node_key][NodeAttributes.STATE]
682 self.dag.remove_node(node_key)
683 self._state_map[state].remove(node_key)
684 for n in preds:
685 if self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
686 self.delete_node(n)
687 else:
688 self._set_state(node_key, States.PLACEHOLDER)
690 def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None) -> None:
691 """Rename a node in a computation.
693 :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and
694 new names as values
695 :param new_name: New name for node.
696 """
697 name_mapping: dict[Name, Name]
698 if isinstance(old_name, Mapping) and not isinstance(old_name, str):
699 for k, v in old_name.items():
700 LOG.debug(f"Renaming node {k} to {v}")
701 if new_name is not None:
702 msg = "new_name must not be set if rename_node is passed a dictionary"
703 raise ValueError(msg)
704 else:
705 name_mapping = dict(old_name) # type: ignore[arg-type]
706 else:
707 LOG.debug(f"Renaming node {old_name} to {new_name}")
708 old_node_key = to_nodekey(old_name)
709 if not self.dag.has_node(old_node_key):
710 msg = f"Node {old_name} does not exist"
711 raise NonExistentNodeException(msg)
712 assert new_name is not None # noqa: S101
713 new_node_key = to_nodekey(new_name)
714 if self.dag.has_node(new_node_key):
715 msg = f"Node {new_name} already exists"
716 raise NodeAlreadyExistsException(msg)
717 name_mapping = {old_name: new_name}
719 node_key_mapping = {to_nodekey(on): to_nodekey(nn) for on, nn in name_mapping.items()}
720 nx.relabel_nodes(self.dag, node_key_mapping, copy=False)
722 for old_nk, new_nk in node_key_mapping.items():
723 if old_nk in self._metadata:
724 self._metadata[new_nk] = self._metadata[old_nk]
725 del self._metadata[old_nk]
726 else:
727 if new_nk in self._metadata: # pragma: no cover
728 del self._metadata[new_nk]
730 self._refresh_maps()
732 def repoint(self, old_name: Name, new_name: Name) -> None:
733 """Changes all nodes that use old_name as an input to use new_name instead.
735 Note that if old_name is an input to new_name, then that will not be changed, to try to avoid introducing
736 circular dependencies, but other circular dependencies will not be checked.
738 If new_name does not exist, then it will be created as a PLACEHOLDER node.
740 :param old_name:
741 :param new_name:
742 :return:
743 """
744 old_node_key = to_nodekey(old_name)
745 new_node_key = to_nodekey(new_name)
746 if old_node_key == new_node_key:
747 return
749 changed_names = list(self.dag.successors(old_node_key))
751 if len(changed_names) > 0 and not self.dag.has_node(new_node_key):
752 self.dag.add_node(new_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
753 self._state_map[States.PLACEHOLDER].add(new_node_key)
755 for name in changed_names:
756 if name == new_node_key:
757 continue
758 edge_data = self.dag.get_edge_data(old_node_key, name)
759 self.dag.add_edge(new_node_key, name, **edge_data)
760 self.dag.remove_edge(old_node_key, name)
762 for n in changed_names:
763 self.set_stale(n)
765 def insert(self, name: Name, value: Any, force: bool = False) -> None:
766 """Insert a value into a node of a computation.
768 Following insertation, the node will have state UPTODATE, and all its descendents will be COMPUTABLE or STALE.
770 If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException``
771 will be raised.
773 :param name: Name of the node to add.
774 :param value: The value to be inserted into the node.
775 :param force: Whether to force recalculation of descendents if node value and state would not be changed
776 """
777 node_key = to_nodekey(name)
778 LOG.debug(f"Inserting value into node {node_key}")
780 if not self.dag.has_node(node_key):
781 msg = f"Node {node_key} does not exist"
782 raise NonExistentNodeException(msg)
784 state = self._state_one(name)
785 if state == States.PLACEHOLDER:
786 msg = "Cannot insert into placeholder node. Use add_node to create the node first"
787 raise CannotInsertToPlaceholderNodeException(msg)
789 if not force and state == States.UPTODATE:
790 current_value = self._value_one(name)
791 if value_eq(value, current_value):
792 return
794 self._set_state_and_value(node_key, States.UPTODATE, value)
795 self._set_descendents(node_key, States.STALE)
796 for n in self.dag.successors(node_key):
797 self._try_set_computable(n)
799 def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]) -> None:
800 """Insert values into many nodes of a computation simultaneously.
802 Following insertation, the nodes will have state UPTODATE, and all their descendents will be COMPUTABLE
803 or STALE. In the case of inserting many nodes, some of which are descendents of others, this ensures that
804 the inserted nodes have correct status, rather than being set as STALE when their ancestors are inserted.
806 If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` will be
807 raised, and none of the nodes will be inserted.
809 :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
810 insert the value into.
811 :type name_value_pairs: List of tuples
812 """
813 node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
814 LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")
816 for name, _value in node_key_value_pairs:
817 if not self.dag.has_node(name):
818 msg = f"Node {name} does not exist"
819 raise NonExistentNodeException(msg)
821 stale = set()
822 computable = set()
823 for name, value in node_key_value_pairs:
824 self._set_state_and_value(name, States.UPTODATE, value)
825 stale.update(nx.dag.descendants(self.dag, name))
826 computable.update(self.dag.successors(name))
827 names = {name for name, value in node_key_value_pairs}
828 stale.difference_update(names)
829 computable.difference_update(names)
830 for name in stale:
831 self._set_state(name, States.STALE)
832 for name in computable:
833 self._try_set_computable(name)
835 def insert_from(self, other: "Computation", nodes: Iterable[Name] | None = None) -> None:
836 """Insert values into another Computation object into this Computation object.
838 :param other: The computation object to take values from
839 :type Computation:
840 :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the
841 other Computation object that have corresponding nodes in this Computation object will be inserted
842 :type nodes: List, default None
843 """
844 if nodes is None:
845 nodes_set: set[Any] = set(self.dag.nodes)
846 nodes_set.intersection_update(other.dag.nodes())
847 nodes = nodes_set
848 name_value_pairs = [(name, other.value(name)) for name in nodes]
849 self.insert_many(name_value_pairs)
851 def _set_state(self, node_key: NodeKey, state: States) -> None:
852 """Set the state of a node without changing its value."""
853 node = self.dag.nodes[node_key]
854 old_state = node[NodeAttributes.STATE]
855 self._state_map[old_state].remove(node_key)
856 node[NodeAttributes.STATE] = state
857 self._state_map[state].add(node_key)
859 def _set_state_and_value(
860 self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
861 ) -> None:
862 """Set state and value of a node, applying any converter."""
863 node = self.dag.nodes[node_key]
864 converter = node.get(NodeAttributes.CONVERTER)
865 if converter is None:
866 self._set_state_and_literal_value(node_key, state, value)
867 else:
868 try:
869 converted_value = converter(value)
870 self._set_state_and_literal_value(node_key, state, converted_value)
871 except Exception as e:
872 tb = traceback.format_exc()
873 self._set_error(node_key, e, tb)
874 if throw_conversion_exception:
875 raise
def _set_state_and_literal_value(
    self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
) -> None:
    """Set state and literal value of a node without conversion.

    :param node_key: Key of the node to update
    :param state: State to assign to the node
    :param value: Value to store as-is
    :param require_old_state: If True, a ``KeyError`` raised while removing the node from the set of its
        previous state is propagated; if False it is ignored
    """
    node = self.dag.nodes[node_key]
    try:
        # Drop the node from the set tracking its previous state, if it has one.
        old_state = node[NodeAttributes.STATE]
        self._state_map[old_state].remove(node_key)
    except KeyError:
        if require_old_state:
            raise  # pragma: no cover
    node[NodeAttributes.STATE] = state
    node[NodeAttributes.VALUE] = value
    self._state_map[state].add(node_key)
def _set_states(self, node_keys: Iterable[NodeKey], state: States) -> None:
    """Set the state of multiple nodes at once.

    Each node is removed from the set tracking its previous state and added to the set tracking
    ``state``. ``node_keys`` is traversed exactly once, so single-pass iterables such as generators are
    handled correctly, and a key that appears more than once is simply processed again.

    :param node_keys: Keys of the nodes to update
    :param state: State to assign to every node
    """
    target_bucket = self._state_map[state]
    for name in node_keys:
        node = self.dag.nodes[name]
        old_state = node[NodeAttributes.STATE]
        self._state_map[old_state].remove(name)
        node[NodeAttributes.STATE] = state
        # Register the node under its new state straight away so the index never lags behind the
        # node attribute (a second pass over ``node_keys`` would be empty for a generator).
        target_bucket.add(name)
def set_stale(self, name: Name) -> None:
    """Set the state of a node and all of its descendants to STALE.

    Once everything is marked, the node itself is offered to ``_try_set_computable``.

    :param name: Name of the node to set as STALE.
    """
    node_key = to_nodekey(name)
    affected: list[NodeKey] = [node_key, *nx.dag.descendants(self.dag, node_key)]
    self._set_states(affected, States.STALE)
    self._try_set_computable(node_key)
def pin(self, name: Name, value: Any = None) -> None:
    """Put a node into the PINNED state, optionally inserting a value first.

    :param name: Name of the node to set as PINNED.
    :param value: Value to insert into the node before pinning. When ``None`` (the default) no value is
        inserted.
    :type value: default None
    """
    key = to_nodekey(name)
    has_value = value is not None
    if has_value:
        self.insert(key, value)
    self._set_states([key], States.PINNED)
def unpin(self, name: Name) -> None:
    """Unpin a node (state of node and all descendents will be set to STALE).

    Delegates to :meth:`set_stale`.

    :param name: Name of the node to unpin.
    """
    node_key = to_nodekey(name)
    self.set_stale(node_key)
def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
    """Collect the nodes reachable downstream of ``node_key``.

    Traversal does not enter (or continue through) any node whose state is in ``stop_states``. If the
    starting node itself is in one of those states the result is empty. The starting node is never part
    of the result.
    """
    blocked = set() if stop_states is None else stop_states
    node_attrs = self.dag.nodes
    if node_attrs[node_key][NodeAttributes.STATE] in blocked:
        return set()
    seen = set()
    frontier = {node_key}
    while frontier:
        current = frontier.pop()
        seen.add(current)
        frontier.update(
            child
            for child in self.dag.successors(current)
            if child not in seen and node_attrs[child][NodeAttributes.STATE] not in blocked
        )
    seen.remove(node_key)
    return seen
def _set_descendents(self, node_key: NodeKey, state: States) -> None:
    """Set the state of all descendant nodes.

    PINNED nodes are used as stop states: they are not changed and traversal does not continue
    through them.

    :param node_key: Key of the node whose descendants are updated
    :param state: State to assign to the descendants
    """
    descendents = self._get_descendents(node_key, {States.PINNED})
    self._set_states(descendents, state)
def _set_uninitialized(self, node_key: NodeKey) -> None:
    """Set a node to uninitialized state and clear its value.

    :param node_key: Key of the node to reset
    """
    self._set_states([node_key], States.UNINITIALIZED)
    # Removing the attribute entirely (rather than storing None) means the node has no value at all.
    self.dag.nodes[node_key].pop(NodeAttributes.VALUE, None)
def _set_uptodate(self, node_key: NodeKey, value: object) -> None:
    """Set a node to up-to-date state with a value.

    Afterwards all (non-pinned) descendants are marked STALE and each direct successor is offered to
    ``_try_set_computable``.

    :param node_key: Key of the node to update
    :param value: Value to store (passed through the node's converter, if any)
    """
    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    for n in self.dag.successors(node_key):
        self._try_set_computable(n)
def _set_error(self, node_key: NodeKey, exc: Exception, tb: str) -> None:
    """Set a node to error state with exception information.

    The node's value becomes an ``Error`` holding the exception and traceback, and its (non-pinned)
    descendants are marked STALE.

    :param node_key: Key of the node that failed
    :param exc: Exception to record
    :param tb: Formatted traceback text to record
    """
    self._set_state_and_literal_value(node_key, States.ERROR, Error(exc, tb))
    self._set_descendents(node_key, States.STALE)
def _try_set_computable(self, node_key: NodeKey) -> None:
    """Mark a node COMPUTABLE when it has a function and every predecessor is UPTODATE.

    PINNED nodes and nodes without a function are left as they are.
    """
    node_attrs = self.dag.nodes
    target = node_attrs[node_key]
    if target[NodeAttributes.STATE] == States.PINNED:
        return
    if target.get(NodeAttributes.FUNC) is None:
        return
    for pred in self.dag.predecessors(node_key):
        if not self.dag.has_node(pred):
            return  # pragma: no cover
        if node_attrs[pred][NodeAttributes.STATE] != States.UPTODATE:
            return
    self._set_state(node_key, States.COMPUTABLE)
def _get_parameter_data(self, node_key: NodeKey) -> Iterable[_ParameterItem]:
    """Yield every parameter that feeds a node's function call.

    Constant positional arguments come first, then constant keyword arguments, then one item per
    predecessor node carrying that node's current value.
    """
    attrs = self.dag.nodes[node_key]
    yield from (
        _ParameterItem(_ParameterType.ARG, position, constant)
        for position, constant in attrs[NodeAttributes.ARGS].items()
    )
    yield from (
        _ParameterItem(_ParameterType.KWD, keyword, constant)
        for keyword, constant in attrs[NodeAttributes.KWDS].items()
    )
    for source_key in self.dag.predecessors(node_key):
        source_value = self.dag.nodes[source_key][NodeAttributes.VALUE]
        kind, target = self.dag[source_key][node_key][EdgeAttributes.PARAM]
        yield _ParameterItem(kind, target, source_value)
def _get_func_args_kwds(
    self, node_key: NodeKey
) -> tuple[Callable[..., Any], str | None, list[Any], dict[str, Any]]:
    """Assemble everything needed to call a node's function.

    :return: Tuple of the function, the name of its executor (or None), the positional argument list and
        the keyword argument dictionary
    """
    attrs = self.dag.nodes[node_key]
    func = attrs[NodeAttributes.FUNC]
    executor_name = attrs.get(NodeAttributes.EXECUTOR)
    positional: list[Any] = []
    keyword: dict[str, Any] = {}
    for item in self._get_parameter_data(node_key):
        if item.type == _ParameterType.KWD:
            assert isinstance(item.name, str)  # noqa: S101
            keyword[item.name] = item.value
        elif item.type == _ParameterType.ARG:
            slot = item.name
            assert isinstance(slot, int)  # noqa: S101
            # Pad with None so that the slot exists before assigning into it.
            shortfall = slot + 1 - len(positional)
            if shortfall > 0:
                positional.extend([None] * shortfall)
            positional[slot] = item.value
        else:  # pragma: no cover
            msg = f"Unexpected param type: {item.type}"
            raise ValidationError(msg)
    return func, executor_name, positional, keyword
def get_definition_args_kwds(self, name: Name) -> tuple[list[Any], dict[str, Any]]:
    """Reconstruct the positional and keyword arguments a node was defined with.

    Constant arguments stored on the node are wrapped in ``C``; arguments supplied by other nodes are
    represented by the name of the supplying node.

    :param name: Name of the node to inspect
    :return: Tuple of the positional argument list and the keyword argument dictionary
    """
    positional: list[Any] = []
    keyword: dict[str, Any] = {}

    def put_positional(slot: int, item: Any) -> None:
        # Grow the list with None placeholders until ``slot`` is addressable.
        if slot >= len(positional):
            positional.extend([None] * (slot + 1 - len(positional)))
        positional[slot] = item

    key = to_nodekey(name)
    attrs = self.dag.nodes[key]
    if NodeAttributes.ARGS in attrs:
        for slot, constant in attrs[NodeAttributes.ARGS].items():
            put_positional(slot, C(constant))
    if NodeAttributes.KWDS in attrs:
        for kw_name, constant in attrs[NodeAttributes.KWDS].items():
            keyword[kw_name] = C(constant)
    for source_key in self.dag.predecessors(key):
        edge = self.dag[source_key][key]
        if EdgeAttributes.PARAM not in edge:
            continue
        kind, target = edge[EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            assert isinstance(target, int)  # noqa: S101
            put_positional(target, source_key.name)
        elif kind == _ParameterType.KWD:
            keyword[target] = source_key.name
        else:  # pragma: no cover
            msg = f"Unexpected param type: {kind}"
            raise ValidationError(msg)
    return positional, keyword
def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False) -> None:
    """Compute multiple nodes, handling dependencies and parallel execution.

    Every requested node that is currently COMPUTABLE is submitted to its executor. As each future
    completes, the result is stored on the node, its descendants are marked STALE, and any successor
    that has become COMPUTABLE and is among the requested nodes is submitted in turn. A node whose
    evaluation reports an exception is put into the error state.

    :param node_keys: Keys of the nodes that may be computed
    :param raise_exceptions: Forwarded to ``_eval_node``. An exception raised by a future is recorded on
        the node and then re-raised to the caller.
    :raises LoopDetectedException: If a successor of a freshly computed node was already computed during
        this call
    """
    LOG.debug("Computing nodes %s", node_keys)

    futs: dict[Any, NodeKey] = {}
    node_keys_set = set(node_keys)

    def run(name: NodeKey) -> None:
        """Submit a node computation to an executor."""
        f, executor_name, args, kwds = self._get_func_args_kwds(name)
        executor = self.default_executor if executor_name is None else self.executor_map[executor_name]
        fut = executor.submit(_eval_node, name, f, args, kwds, raise_exceptions)
        futs[fut] = name

    computed: set[NodeKey] = set()

    # Seed the work queue with everything that can run immediately.
    for node_key in node_keys_set:
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.COMPUTABLE:
            run(node_key)

    while len(futs) > 0:
        done, _not_done = wait(futs.keys(), return_when=FIRST_COMPLETED)
        for fut in done:
            node_key = futs.pop(fut)
            node0 = self.dag.nodes[node_key]
            try:
                value, exc, tb, start_dt, end_dt = fut.result()
            except Exception as e:
                # The future itself raised: record the failure on the node, then propagate.
                exc = e
                tb = traceback.format_exc()
                self._set_error(node_key, exc, tb)
                raise
            delta = (end_dt - start_dt).total_seconds()
            if exc is None:
                # A failing converter puts the node into the error state instead of raising here.
                self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
                node0[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
                self._set_descendents(node_key, States.STALE)
                for n in self.dag.successors(node_key):
                    # Log through the module logger (not the root logger) with lazy formatting.
                    LOG.debug("%s %s %s", node_key, n, computed)
                    if n in computed:
                        msg = f"Calculating {node_key} for the second time"
                        raise LoopDetectedException(msg)
                    self._try_set_computable(n)
                    successor_state = self.dag.nodes[n][NodeAttributes.STATE]
                    if successor_state == States.COMPUTABLE and n in node_keys_set:
                        run(n)
            else:
                assert tb is not None  # noqa: S101
                self._set_error(node_key, exc, tb)
            computed.add(node_key)
def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """Work out, in topological order, which nodes must be computed to obtain ``node_key``.

    Ancestors that are already UPTODATE or PINNED are cut out of a scratch copy of the graph, so nothing
    that is only reachable through them is included.

    :raises ValidationError: If a required ancestor is an UNINITIALIZED node without predecessors, or a
        PLACEHOLDER node
    """
    settled = (States.UPTODATE, States.PINNED)
    pruned = nx.DiGraph()
    pruned.add_nodes_from(self.dag.nodes)
    pruned.add_edges_from(self.dag.edges)
    pruned.remove_nodes_from(
        [anc for anc in nx.ancestors(pruned, node_key) if self.dag.nodes[anc][NodeAttributes.STATE] in settled]
    )

    required = nx.ancestors(pruned, node_key)
    for anc in required:
        anc_state = self.dag.nodes[anc][NodeAttributes.STATE]

        if anc_state == States.UNINITIALIZED and len(self.dag.pred[anc]) == 0:
            msg = f"Cannot compute {node_key} because {anc} uninitialized"
            raise ValidationError(msg)
        if anc_state == States.PLACEHOLDER:
            msg = f"Cannot compute {node_key} because {anc} is placeholder"
            raise ValidationError(msg)

    required.add(node_key)
    return [candidate for candidate in topological_sort(pruned) if candidate in required]
def _get_calc_node_names(self, name: Name) -> Names:
    """Get node names that need to be computed for a target node.

    Name-based counterpart of ``_get_calc_node_keys``.

    :param name: Name of the target node
    :return: Names of the nodes to compute
    """
    node_key = to_nodekey(name)
    return node_keys_to_names(self._get_calc_node_keys(node_key))
def compute(self, name: Name | Iterable[Name], raise_exceptions: bool = False) -> None:
    """Compute a node and all necessary predecessors.

    After a successful computation the target node, together with every ancestor that was needed and was
    not already UPTODATE, has been calculated and set to UPTODATE. Nodes that did not need calculating are
    not recalculated.

    When a node raises an exception, that node's state becomes ERROR and its value becomes an object
    holding the exception and a traceback. The computation is not halted: it proceeds as far as it can,
    until none of the nodes required for the target are COMPUTABLE.

    :param name: Name of the node to compute. A list or generator of names computes all of them.
    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    calc_nodes: set[NodeKey] | list[NodeKey]
    if isinstance(name, (types.GeneratorType, list)):
        calc_nodes = set()
        for single_name in name:
            calc_nodes.update(self._get_calc_node_keys(to_nodekey(single_name)))
    else:
        calc_nodes = self._get_calc_node_keys(to_nodekey(name))
    self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)
def compute_all(self, raise_exceptions: bool = False) -> None:
    """Compute all nodes of a computation that can be computed.

    Nodes that are already UPTODATE will not be recalculated. Following the computation, if successful, all
    nodes will have state UPTODATE, except UNINITIALIZED input nodes and PLACEHOLDER nodes.

    If any node raises an exception, then the state of that node will be set to ERROR, and its value set to an
    object containing the exception object, as well as a traceback. This will not halt the computation, which
    will proceed as far as it can, until no more nodes are COMPUTABLE.

    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    # Every node in the graph is a candidate.
    self._compute_nodes(self._node_keys(), raise_exceptions=raise_exceptions)
def _node_keys(self) -> list[NodeKey]:
    """Get a list of the node keys in this computation.

    :return: List of node keys, as a new list built from the graph's nodes.
    """
    return list(self.dag.nodes)
def nodes(self) -> list[Name]:
    """Get a list of the names of the nodes in this computation.

    :return: List of node names (the ``name`` of each node key in the graph).
    """
    return [n.name for n in self.dag.nodes]
def get_tree_list_children(self, name: Name) -> set[Name]:
    """Get the immediate children of a path in the node tree.

    For every node in the computation that is a descendent of ``name``, the path component directly
    below ``name`` is collected.

    :param name: Name of the parent path
    :return: Set of child path components
    """
    parent = to_nodekey(name)
    depth = len(parent.parts)
    return {key.parts[depth] for key in self.dag.nodes if key.is_descendent_of(parent)}
def has_node(self, name: Name) -> bool:
    """Check if a node with the given name exists in the computation.

    :param name: Name of the node to look for
    :return: True if the corresponding node key is in the graph
    """
    node_key = to_nodekey(name)
    return node_key in self.dag.nodes
def tree_has_path(self, name: Name) -> bool:
    """Check if a hierarchical path exists in the computation tree.

    A path exists when it is the root, when it is itself a node, or when at least one node in the
    computation is a descendent of it.

    :param name: Path to check
    """
    key = to_nodekey(name)
    if key.is_root or self.has_node(key):
        return True
    return any(candidate.is_descendent_of(key) for candidate in self.dag.nodes)
def get_tree_descendents(
    self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
) -> set[Name]:
    """Get the descendent blocks and nodes of a path.

    Returns blocks and nodes that are descendents of the input node,
    e.g. for node 'foo', might return ['foo/bar', 'foo/baz'].

    :param name: Name of node to get descendents for. ``None`` means the root.
    :param include_stem: If True, report full names; otherwise strip the leading parts belonging to
        ``name`` from each result
    :param graph_nodes_only: If True, only consider the graph's nodes themselves; otherwise also consider
        each node's ``ancestors()``
    :return: Set of descendent node names
    """
    stem = NodeKey.root() if name is None else to_nodekey(name)
    stem_length = len(stem.parts)
    found = set()
    for key in self.dag.nodes:
        if not key.is_descendent_of(stem):
            continue
        candidates = [key] if graph_nodes_only else key.ancestors()
        for candidate in candidates:
            if not candidate.is_descendent_of(stem):
                continue
            if include_stem:
                found.add(candidate.name)
            else:
                found.add(NodeKey(tuple(candidate.parts[stem_length:])).name)
    return found
def _state_one(self, name: Name) -> States:
    """Get the state of a single node.

    :param name: Name of the node
    :return: The node's state attribute
    """
    node_key = to_nodekey(name)
    # The annotated local gives the untyped graph attribute a static type.
    state: States = self.dag.nodes[node_key][NodeAttributes.STATE]
    return state
@overload
def state(self, name: Name) -> States: ...

@overload
def state(self, name: Names) -> list[States]: ...

def state(self, name: Name | Names) -> States | list[States]:
    """Get the state of a node.

    This can also be accessed using the attribute-style accessor ``s`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.state('foo')
        <States.UPTODATE: 4>
        >>> comp.s.foo
        <States.UPTODATE: 4>

    :param name: Name or names of the node to get state for
    :type name: Name or Names
    :return: A single state for a single name, or a list of states for a list of names
    """
    return apply1(self._state_one, name)
def _value_one(self, name: Name) -> Any:
    """Get the value of a single node.

    :param name: Name of the node
    :return: The node's value attribute
    :raises KeyError: If the node has no value attribute
    """
    node_key = to_nodekey(name)
    return self.dag.nodes[node_key][NodeAttributes.VALUE]
@overload
def value(self, name: Name) -> Any: ...

@overload
def value(self, name: Names) -> list[Any]: ...

def value(self, name: Name | Names) -> Any | list[Any]:
    """Get the current value of a node.

    This can also be accessed using the attribute-style accessor ``v`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.value('foo')
        1
        >>> comp.v.foo
        1

    :param name: Name or names of the node to get the value of
    :type name: Name or Names
    :return: A single value for a single name, or a list of values for a list of names
    """
    return apply1(self._value_one, name)
def compute_and_get_value(self, name: Name) -> Any:
    """Return the value of a node, computing it first if it is not UPTODATE.

    This can also be accessed using the attribute-style accessor ``x`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

    :param name: Name of the node to get the value of
    :type name: Name
    :raises ComputationError: If the node is still not UPTODATE after computing
    """
    nk = to_nodekey(name)
    if self.state(nk) != States.UPTODATE:
        # Exceptions from node functions are propagated to the caller.
        self.compute(nk, raise_exceptions=True)
        if self.state(nk) != States.UPTODATE:
            msg = f"Unable to compute node {nk}"
            raise ComputationError(msg)
    return self.value(nk)
def _tag_one(self, name: Name) -> set[str]:
    """Get the tags of a single node.

    :param name: Name of the node
    :return: The tag set stored on the node (the stored object itself, not a copy)
    """
    node_key = to_nodekey(name)
    node = self.dag.nodes[node_key]
    tags: set[str] = node[NodeAttributes.TAG]
    return tags
@overload
def tags(self, name: Name) -> set[str]: ...

@overload
def tags(self, name: Names) -> list[set[str]]: ...

def tags(self, name: Name | Names) -> set[str] | list[set[str]]:
    """Get the tags associated with a node.

    >>> comp = Computation()
    >>> comp.add_node('a', tags=['foo', 'bar'])
    >>> sorted(comp.t.a)
    ['__serialize__', 'bar', 'foo']

    :param name: Name or names of the node to get the tags of
    :return: A set of tags for a single name, or a list of tag sets for a list of names
    """
    return apply1(self._tag_one, name)
def nodes_by_tag(self, tag: str | Iterable[str]) -> set[Name]:
    """Get the names of nodes carrying any of the given tags.

    :param tag: Tag or tags for which to retrieve nodes
    :return: Names of the nodes with those tags
    """
    wanted: Iterable[str] = [tag] if isinstance(tag, str) else tag
    matched: set[NodeKey] = set()
    for single_tag in wanted:
        # Tags with no entry in the tag map contribute nothing.
        if (tagged := self._tag_map.get(single_tag)) is not None:
            matched.update(tagged)
    return {key.name for key in matched}
def _style_one(self, name: Name) -> str | None:
    """Get the style of a single node.

    :param name: Name of the node
    :return: The node's style attribute, or None if the node has none
    """
    node_key = to_nodekey(name)
    node = self.dag.nodes[node_key]
    style: str | None = node.get(NodeAttributes.STYLE)
    return style
@overload
def styles(self, name: Name) -> str | None: ...

@overload
def styles(self, name: Names) -> list[str | None]: ...

def styles(self, name: Name | Names) -> str | None | list[str | None]:
    """Get the style associated with a node.

    >>> comp = Computation()
    >>> comp.add_node('a', style='dot')
    >>> comp.style.a
    'dot'

    :param name: Name or names of the node to get the style of
    :return: A style (or None) for a single name, or a list of styles for a list of names
    """
    return apply1(self._style_one, name)
def _get_item_one(self, name: Name) -> NodeData:
    """Get state and value data for a single node.

    :param name: Name of the node
    :return: ``NodeData`` built from the node's state and value attributes
    :raises KeyError: If the node lacks a state or value attribute
    """
    node_key = to_nodekey(name)
    node = self.dag.nodes[node_key]
    return NodeData(node[NodeAttributes.STATE], node[NodeAttributes.VALUE])
@overload
def __getitem__(self, name: Name) -> NodeData: ...

@overload
def __getitem__(self, name: Names) -> list[NodeData]: ...

def __getitem__(self, name: Name | Names) -> NodeData | list[NodeData]:
    """Get the state and current value of a node.

    :param name: Name or names of the node to get the state and value of
    :return: A ``NodeData`` for a single name, or a list of them for a list of names
    """
    return apply1(self._get_item_one, name)
def _get_timing_one(self, name: Name) -> TimingData | None:
    """Get timing data for a single node.

    :param name: Name of the node
    :return: The node's timing attribute, or None if no timing has been recorded on it
    """
    node_key = to_nodekey(name)
    node = self.dag.nodes[node_key]
    timing: TimingData | None = node.get(NodeAttributes.TIMING, None)
    return timing
@overload
def get_timing(self, name: Name) -> TimingData | None: ...

@overload
def get_timing(self, name: Names) -> list[TimingData | None]: ...

def get_timing(self, name: Name | Names) -> TimingData | None | list[TimingData | None]:
    """Get the timing information for a node.

    :param name: Name or names of the node to get the timing information of
    :return: Timing data (or None where none is recorded) for a single name, or a list of them for a
        list of names
    """
    return apply1(self._get_timing_one, name)
def to_df(self) -> pd.DataFrame:
    """Build a dataframe of the state, value and timing of every node.

    Rows follow the topological order of the graph and are labelled with node names.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_df() # doctest: +NORMALIZE_WHITESPACE
                       state  value
        foo  States.UPTODATE      1
        bar  States.UPTODATE      2
    """
    frame = pd.DataFrame(index=topological_sort(self.dag))
    for column in (NodeAttributes.STATE, NodeAttributes.VALUE):
        frame[column] = pd.Series(nx.get_node_attributes(self.dag, column))
    timings = pd.DataFrame.from_dict(nx.get_node_attributes(self.dag, "timing"), orient="index")
    # A left join keeps nodes that have no timing recorded.
    frame = frame.merge(timings, left_index=True, right_index=True, how="left")
    frame.index = pd.Index([key.name for key in frame.index])
    return frame
def to_dict(self) -> dict[NodeKey, Any]:
    """Get a dictionary containing the values of all nodes of a computation.

    Only nodes that have a value attribute appear in the result; keys are node keys.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_dict() # doctest: +ELLIPSIS
        {NodeKey('foo'): 1, NodeKey('bar'): 2}
    """
    result: dict[NodeKey, Any] = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    return result
def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey | None]:
    """List the node keys feeding a node: positional inputs by index, then keyword inputs.

    Positional slots that no predecessor fills are left as ``None``.
    """
    positional: dict[int, NodeKey] = {}
    keyword: list[NodeKey | None] = []
    for source in self.dag.predecessors(node_key):
        kind, target = self.dag[source][node_key][EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            positional[target] = source
        elif kind == _ParameterType.KWD:
            keyword.append(source)
    highest = max(positional, default=-1)
    if highest < 0:
        return keyword
    slots: list[NodeKey | None] = [None] * (highest + 1)
    for index, source in positional.items():
        slots[index] = source
    return slots + keyword
def _get_inputs_one_names(self, name: Name) -> Names:
    """Get input node names for a single node.

    Unfilled positional slots (``None`` entries) are dropped before converting keys to names.

    :param name: Name of the node
    :return: Names of the node's inputs
    """
    node_key = to_nodekey(name)
    return node_keys_to_names([nk for nk in self._get_inputs_one_node_keys(node_key) if nk is not None])
@overload
def get_inputs(self, name: Name) -> Names: ...

@overload
def get_inputs(self, name: Names) -> list[Names]: ...

def get_inputs(self, name: Name | Names) -> Names | list[Names]:
    """Get a list of the inputs for a node or set of nodes.

    :param name: Name or names of nodes to get inputs for
    :return: If name is scalar, return a list of upstream nodes used as input. If name is a list, return a
        list of lists of inputs.
    """
    return apply1(self._get_inputs_one_names, name)
def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
    """Collect the union of the ancestors of the given nodes.

    :param node_keys: Keys of the nodes to start from
    :param include_self: If True, the starting nodes are part of the result as well
    """
    collected: set[NodeKey] = set()
    for key in node_keys:
        if include_self:
            collected.add(key)
        collected.update(nx.ancestors(self.dag, key))
    return collected
def get_ancestors(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get all ancestor nodes of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: If True, the starting nodes are included in the result
    :return: Names of the ancestor nodes
    """
    node_keys = names_to_node_keys(names)
    ancestor_node_keys = self._get_ancestors_node_keys(node_keys, include_self)
    return node_keys_to_names(ancestor_node_keys)
def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> list[NodeKey]:
    """Find the nodes that have no computation function.

    :param node_keys: Restrict the search to these nodes and their ancestors. ``None`` searches every
        node in the computation.
    """
    candidates: Iterable[NodeKey]
    if node_keys is None:
        candidates = self._node_keys()
    else:
        candidates = self._get_ancestors_node_keys(node_keys)
    node_attrs = self.dag.nodes
    return [key for key in candidates if node_attrs[key].get(NodeAttributes.FUNC) is None]
def get_original_inputs(self, names: Name | Names | None = None) -> Names:
    """Get a list of the original non-computed inputs for a node or set of nodes.

    :param names: Name or names of nodes to get inputs for. ``None`` considers every node in the
        computation.
    :return: Return a list of original non-computed inputs that are ancestors of the input nodes
    """
    nks = None if names is None else names_to_node_keys(names)

    result_nks = self._get_original_inputs_node_keys(nks)

    return node_keys_to_names(result_nks)
def _get_outputs_one(self, name: Name) -> Names:
    """Get output node names for a single node.

    :param name: Name of the node
    :return: Names of the node's direct successors in the graph
    """
    node_key = to_nodekey(name)
    output_node_keys = list(self.dag.successors(node_key))
    return node_keys_to_names(output_node_keys)
@overload
def get_outputs(self, name: Name) -> Names: ...

@overload
def get_outputs(self, name: Names) -> list[Names]: ...

def get_outputs(self, name: Name | Names) -> Names | list[Names]:
    """Get a list of the outputs for a node or set of nodes.

    :param name: Name or names of nodes to get outputs for
    :return: If name is scalar, return a list of the downstream nodes that directly use it as input. If
        name is a list, return a list of lists of outputs.
    """
    return apply1(self._get_outputs_one, name)
def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
    """Collect the union of the descendants of the given nodes.

    :param node_keys: Keys of the nodes to start from
    :param include_self: If True, the starting nodes are part of the result as well
    """
    collected: set[NodeKey] = set()
    for key in node_keys:
        if include_self:
            collected.add(key)
        collected.update(nx.descendants(self.dag, key))
    return collected
def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get all descendent nodes of the specified nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: If True, the starting nodes are included in the result
    :return: Names of the descendent nodes
    """
    node_keys = names_to_node_keys(names)
    descendent_node_keys = self._get_descendents_node_keys(node_keys, include_self)
    return node_keys_to_names(descendent_node_keys)
def get_final_outputs(self, names: Name | Names | None = None) -> Names:
    """Get final output nodes (nodes with no descendants) from the specified nodes.

    :param names: Name or names of the nodes to start from; they and their descendants are the
        candidates. ``None`` considers every node in the computation.
    :return: Names of the candidate nodes that have no descendants
    """
    final_node_keys: Iterable[NodeKey]
    if names is None:
        final_node_keys = self._node_keys()
    else:
        nks = names_to_node_keys(names)
        final_node_keys = self._get_descendents_node_keys(nks)
    # A node has no descendants exactly when it has no successor other than itself, so a look at its
    # direct successors is enough; walking the whole downstream graph per candidate is not needed.
    output_node_keys = [n for n in final_node_keys if all(succ == n for succ in self.dag.successors(n))]
    return node_keys_to_names(output_node_keys)
def get_source(self, name: Name) -> str:
    """Describe where a node's function is defined and show its source code.

    :param name: Name of the node
    :return: ``"<file>:<line>"`` followed by a blank line and the function's source, or the text
        ``"NOT A CALCULATED NODE"`` when the node has no function
    """
    key = to_nodekey(name)
    func = self.dag.nodes[key].get(NodeAttributes.FUNC, None)
    if func is None:
        return "NOT A CALCULATED NODE"
    file = inspect.getsourcefile(func)
    _, lineno = inspect.getsourcelines(func)
    source = inspect.getsource(func)
    return f"{file}:{lineno}\n\n{source}"
def print_source(self, name: Name) -> None:
    """Print the source code for a computation node.

    Prints whatever :meth:`get_source` returns for ``name``.

    :param name: Name of the node
    """
    print(self.get_source(name))
def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None) -> None:
    """Restrict a computation to the ancestors of a set of output nodes.

    Excludes ancestors of a set of input nodes.

    If the set of input nodes that is specified is not sufficient for the set of output nodes then
    additional nodes that are ancestors of the output nodes will be included, but the input nodes specified
    will be input nodes of the modified Computation.

    :param output_names: Name or names of the nodes whose ancestors are kept
    :param input_names: Name or names of nodes to turn into input nodes; each is re-added via
        ``add_node`` and then given back its previous state and value
    :return: None - modifies existing computation in place
    """
    if input_names is not None:
        for n in as_iterable(input_names):
            # Capture state/value first: re-adding the node is what detaches it from its own inputs.
            nodedata = self._get_item_one(n)
            self.add_node(n)
            self._set_state_and_literal_value(to_nodekey(n), nodedata.state, nodedata.value)
    output_node_keys = names_to_node_keys(output_names)
    ancestor_node_keys = self._get_ancestors_node_keys(output_node_keys)
    self.dag.remove_nodes_from([n for n in self.dag if n not in ancestor_node_keys])
    # Nodes were removed from the graph directly, bypassing the state/tag lookup maps, so rebuild those
    # maps from the graph (the same step __setstate__ performs after replacing the graph).
    # NOTE(review): assumes _refresh_maps rebuilds the maps from self.dag - confirm.
    self._refresh_maps()
def __getstate__(self) -> dict[str, Any]:
    """Prepare computation for serialization by removing non-serializable nodes.

    The changes are made on the object returned by ``self.copy()``: every node whose tags do not include
    ``SystemTags.SERIALIZE`` is reset to uninitialized, which drops its value.

    :return: Dictionary holding the copy's graph under the key ``"dag"``
    """
    # Mapping of node key -> tags, read from this computation's graph.
    node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
    obj = self.copy()
    for name, tags in node_serialize.items():
        if SystemTags.SERIALIZE not in tags:
            obj._set_uninitialized(name)
    return {"dag": obj.dag}
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore computation from serialized state.

    :param state: Dictionary with the graph stored under the key ``"dag"``
    """
    # Run the normal initializer first, then swap in the restored graph.
    self.__init__()
    self.dag = state["dag"]
    # The lookup maps must be refreshed to match the graph that was just installed.
    self._refresh_maps()
def write_dill_old(self, file_: str | BinaryIO) -> None:
    """Serialize a computation to a file or file-like object.

    Deprecated: emits a ``DeprecationWarning`` pointing to :meth:`write_dill`.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
    original_getstate = self.__class__.__getstate__
    original_setstate = self.__class__.__setstate__

    try:
        # Temporarily remove the custom pickling hooks from the class so that they are not used while
        # dumping; they are put back in the ``finally`` block whatever happens.
        # NOTE(review): this mutates the class itself, so it affects every instance while the dump runs.
        del self.__class__.__getstate__
        del self.__class__.__setstate__

        node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
        obj = self.copy()
        # Executors are cleared on the copy before it is dumped.
        obj.executor_map = None  # type: ignore[assignment]
        obj.default_executor = None  # type: ignore[assignment]
        for name, tags in node_serialize.items():
            if SystemTags.SERIALIZE not in tags:
                obj._set_uninitialized(name)

        if isinstance(file_, str):
            with open(file_, "wb") as f:
                dill.dump(obj, f)
        else:
            dill.dump(obj, file_)
    finally:
        self.__class__.__getstate__ = original_getstate  # type: ignore[method-assign]
        self.__class__.__setstate__ = original_setstate
def write_dill(self, file_: str | BinaryIO) -> None:
    """Serialize a computation to a file or file-like object.

    .. deprecated::
        Use :meth:`write_json` instead. dill-based serialization will be
        removed in a future release.

    :param file_: If string, it is treated as a path and the file is opened for binary writing;
        otherwise the object is written to directly
    :type file_: File-like object, or string
    """
    warnings.warn(
        "write_dill is deprecated and will be removed in a future release. Use write_json instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if isinstance(file_, str):
        with open(file_, "wb") as f:
            dill.dump(self, f)
    else:
        dill.dump(self, file_)
@staticmethod
def read_dill(file_: str | BinaryIO) -> "Computation":
    """Deserialize a computation from a file or file-like object.

    .. deprecated::
        Use :meth:`read_json` instead. dill-based serialization will be
        removed in a future release.

    .. warning::
        This method uses dill.load() which can execute arbitrary code.
        Only load files from trusted sources. Never load data from
        untrusted or unauthenticated sources as it may lead to arbitrary
        code execution.

    :param file_: If string, it is treated as a path and the file is opened for binary reading;
        otherwise the object is read from directly
    :type file_: File-like object, or string
    :return: The loaded computation
    :raises ValidationError: If the loaded object is not a ``Computation``
    """
    warnings.warn(
        "read_dill is deprecated and will be removed in a future release. Use read_json instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if isinstance(file_, str):
        with open(file_, "rb") as f:
            obj = dill.load(f)  # noqa: S301 # nosec B301
    else:
        obj = dill.load(file_)  # noqa: S301 # nosec B301
    # The type check runs only after loading, so it does not protect against malicious payloads.
    if isinstance(obj, Computation):
        return obj
    else:
        msg = "Loaded object is not a Computation"
        raise ValidationError(msg)
1738 def write_json(self, file_: str | TextIO, *, serializer: "ComputationSerializer | None" = None) -> None:
1739 """Serialize a computation to a JSON file or file-like object.
1741 Custom types can be supported by passing a custom *serializer* —
1742 either a :class:`~loman.serialization.computation.ComputationSerializer`
1743 instance with extra transformers registered, or a subclass that
1744 overrides the transformer factory.
1746 :param file_: Destination file path (str) or text-mode file-like object.
1747 :param serializer: Optional custom serializer. If ``None`` the default
1748 :class:`~loman.serialization.computation.ComputationSerializer` is used.
1749 """
1750 from .serialization.computation import ComputationSerializer
1752 s = serializer if serializer is not None else ComputationSerializer()
1753 if isinstance(file_, str):
1754 with open(file_, "w") as f:
1755 s.dump(self, f)
1756 else:
1757 s.dump(self, file_)
1759 @staticmethod
1760 def read_json(file_: str | TextIO, *, serializer: "ComputationSerializer | None" = None) -> "Computation":
1761 """Deserialize a computation from a JSON file or file-like object.
1763 :param file_: Source file path (str) or text-mode file-like object.
1764 :param serializer: Optional custom serializer. If ``None`` the default
1765 :class:`~loman.serialization.computation.ComputationSerializer` is used.
1766 :rtype: Computation
1767 """
1768 from .serialization.computation import ComputationSerializer
1770 s = serializer if serializer is not None else ComputationSerializer()
1771 if isinstance(file_, str):
1772 with open(file_) as f:
1773 return s.load(f)
1774 else:
1775 return s.load(file_)
def copy(self) -> "Computation":
    """Return a new Computation that shares node values with this one.

    The copy is shallow: a fresh graph object and fresh tag/state index sets are built, but the values held
    in the nodes are the same objects as in this Computation. Since further computations create new value
    objects rather than mutating existing ones, sharing them should not be an issue.

    Only ``dag``, ``_tag_map`` and ``_state_map`` are carried over.
    NOTE(review): ``_metadata`` is not propagated to the copy — confirm this is intended.

    :rtype: Computation
    """
    duplicate = Computation()
    duplicate.dag = nx.DiGraph(self.dag)

    # Copy each node set so later tag changes on one computation do not leak into the other.
    tag_index = defaultdict(set)
    for tag, tagged_nodes in self._tag_map.items():
        tag_index[tag] = tagged_nodes.copy()
    duplicate._tag_map = tag_index

    state_index = {}
    for state, nodes_in_state in self._state_map.items():
        state_index[state] = nodes_in_state.copy()
    duplicate._state_map = state_index

    return duplicate
def add_named_tuple_expansion(self, name: Name, namedtuple_type: type, group: str | None = None) -> None:
    """Add one node per field of a namedtuple-valued node.

    Calculations often return several values at once, and a namedtuple is a friendlier container for that
    than a plain tuple because each element has a name. Downstream calculations are also clearer when they
    depend on a single element rather than on the whole tuple.

    To save writing that boilerplate by hand, this method creates, for every field ``element`` of
    ``namedtuple_type``, a node called ``'<name>.<element>'`` whose value is that attribute of the value of
    node ``name``. Each created node is tagged with ``SystemTags.EXPANSION``.

    Example::

        >>> from collections import namedtuple
        >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
        >>> comp = Computation()
        >>> comp.add_node('c', value=Coordinate(1, 2))
        >>> comp.add_named_tuple_expansion('c', Coordinate)
        >>> comp.compute_all()
        >>> comp.value('c.x')
        1
        >>> comp.value('c.y')
        2

    :param name: Node holding the namedtuple to expand
    :param namedtuple_type: Expected type of the node; its ``_fields`` determine which nodes are created
    :type namedtuple_type: namedtuple class
    :param group: Passed through as the ``group`` argument of every created node
    """

    def _make_getter(field_name: str) -> Callable[[Any], Any]:
        """Build an extractor bound to a single field name."""

        # A factory is needed so each extractor closes over its own field name
        # instead of the loop variable.
        def get_field_value(tuple_val: Any) -> Any:
            """Extract field value from the namedtuple."""
            return getattr(tuple_val, field_name)

        return get_field_value

    for field_name in namedtuple_type._fields:  # type: ignore[attr-defined]
        expanded_name = f"{name}.{field_name}"
        self.add_node(expanded_name, _make_getter(field_name), kwds={"tuple_val": name}, group=group)
        self.set_tag(expanded_name, SystemTags.EXPANSION)
def add_map_node(
    self,
    result_node: Name,
    input_node: Name,
    subgraph: "Computation",
    subgraph_input_node: Name,
    subgraph_output_node: Name,
) -> None:
    """Add a node that maps a sub-computation over an iterable.

    When the new node is calculated, each element of ``input_node``'s value is inserted into the
    subgraph's ``subgraph_input_node`` in turn, and ``subgraph_output_node`` is then computed. The results,
    one per input element, are collected into a list that becomes the value of ``result_node``. This makes
    ``add_map_node`` similar to ``map`` in functional programming.

    If the subgraph's output is not up to date for some element, a copy of the subgraph is stored in that
    position instead, the remaining elements are still processed, and a ``MapException`` carrying the
    mixed list is raised at the end.

    :param result_node: The node to place a list of results in **this** graph
    :param input_node: The node to get a list input values from **this** graph
    :param subgraph: The graph to use to perform calculation for each element
    :param subgraph_input_node: The node in **subgraph** to insert each element in turn
    :param subgraph_output_node: The node in **subgraph** to read the result for each element
    """

    def f(xs: Iterable[Any]) -> list[Any]:
        """Run the subgraph once per element and gather the outputs."""
        outputs: list[Any] = []
        any_failed = False
        for element in xs:
            subgraph.insert(subgraph_input_node, element)
            subgraph.compute(subgraph_output_node)
            if subgraph.state(subgraph_output_node) == States.UPTODATE:
                outputs.append(subgraph.value(subgraph_output_node))
                continue
            # Keep a snapshot of the subgraph for this element in place of a value.
            any_failed = True
            outputs.append(subgraph.copy())
        if not any_failed:
            return outputs
        msg = f"Unable to calculate {result_node}"
        raise MapException(msg, outputs)

    self.add_node(result_node, f, kwds={"xs": input_node})
def prepend_path(self, path: Name | ConstantValue, prefix_path: NodeKey) -> NodeKey | ConstantValue:
    """Return ``path`` relocated under ``prefix_path``.

    :param path: A node name, or a ``ConstantValue`` which is returned untouched.
    :param prefix_path: Key that the converted node key is joined onto.
    :return: ``prefix_path`` joined with the node key for ``path``, or ``path`` itself for a constant.
    """
    if not isinstance(path, ConstantValue):
        return prefix_path.join(to_nodekey(path))
    # Constants are not node references, so there is nothing to re-root.
    return path
def add_block(
    self,
    base_path: Name,
    block: "Computation",
    *,
    keep_values: bool | None = True,
    links: dict[str, Name] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Graft every node of ``block`` into this computation underneath ``base_path``.

    Each node of ``block`` is re-added here with ``base_path`` prepended to its name and to the node
    references in its argument definitions; its function, converter, executor, group, style and tags are
    carried across.

    :param base_path: Prefix under which the block's nodes are placed.
    :param block: Computation whose nodes are copied in.
    :param keep_values: If truthy, nodes of ``block`` that have a value also get their state and value
        copied, and are tagged ``SystemTags.SERIALIZE``.
    :param links: Optional mapping of target (relative to ``base_path``) to source; each pair is passed
        to :meth:`link`.
    :param metadata: Stored for ``base_path`` when given; when ``None`` any existing entry for
        ``base_path`` is removed.
    """
    prefix = to_nodekey(base_path)

    for block_node_name in block.nodes():
        block_node_key = to_nodekey(block_node_name)
        attrs = block.dag.nodes[block_node_key]

        # Drop the serialize tag inherited from the block's node: the node is added
        # below with serialize=False, meaning "don't serialize the function".
        # Whether the value is serialized is decided separately by keep_values.
        block_tags = attrs.get(NodeAttributes.TAG, None)
        tags = None if block_tags is None else block_tags - {SystemTags.SERIALIZE}

        # Re-root every node reference in the definition under the prefix.
        args_def, kwds_def = block.get_definition_args_kwds(block_node_key)
        rebased_args = [self.prepend_path(arg, prefix) for arg in args_def]
        rebased_kwds = {param: self.prepend_path(ref, prefix) for param, ref in kwds_def.items()}

        target_name = self.prepend_path(block_node_name, prefix)
        self.add_node(
            target_name,
            attrs.get(NodeAttributes.FUNC, None),
            args=rebased_args,
            kwds=rebased_kwds,
            converter=attrs.get(NodeAttributes.CONVERTER, None),
            serialize=False,
            inspect=False,
            group=attrs.get(NodeAttributes.GROUP, None),
            tags=tags,
            style=attrs.get(NodeAttributes.STYLE, None),
            executor=attrs.get(NodeAttributes.EXECUTOR, None),
        )

        if not keep_values or NodeAttributes.VALUE not in attrs:
            continue
        target_key = to_nodekey(target_name)
        self._set_state_and_literal_value(target_key, attrs[NodeAttributes.STATE], attrs[NodeAttributes.VALUE])
        # The node has a concrete value — mark it serializable so the
        # value survives a JSON roundtrip even though the function is not.
        self._set_tag_one(target_key, SystemTags.SERIALIZE)

    if links is not None:
        for link_target, link_source in links.items():
            self.link(prefix.join_parts(link_target), link_source)

    if metadata is not None:
        self._metadata[prefix] = metadata
    elif prefix in self._metadata:
        # No metadata supplied: make sure a stale entry for this path does not linger.
        del self._metadata[prefix]
def link(self, target: Name, source: Name) -> None:
    """Define ``target`` as a node computed by ``identity_function`` from ``source``.

    Nothing happens when both names resolve to the same node key. The style given to the target node is
    the target's existing style if it has a truthy one, otherwise the source's (``None`` for a node that
    does not exist).

    :param target: Name of the node to (re)define.
    :param source: Name of the node the target takes its input ``x`` from.
    """
    target_nk, source_nk = to_nodekey(target), to_nodekey(source)
    if target_nk == source_nk:
        # Linking a node to itself would be meaningless; leave the graph alone.
        return

    target_style = None
    if self.has_node(target_nk):
        target_style = self._style_one(target_nk)
    source_style = None
    if self.has_node(source_nk):
        source_style = self._style_one(source_nk)

    self.add_node(
        target_nk,
        identity_function,
        kwds={"x": source_nk},
        style=target_style or source_style,
    )
def _repr_svg_(self) -> str | None:
    """Render this computation as SVG through a default ``GraphView`` (Jupyter rich-display hook)."""
    graph_view = GraphView(self)
    return graph_view.svg()
def draw(
    self,
    root: NodeKey | None = None,
    *,
    node_transformations: dict[Name, str] | None = None,
    cmap: Any = None,
    colors: str = "state",
    shapes: str | None = None,
    graph_attr: dict[str, Any] | None = None,
    node_attr: dict[str, Any] | None = None,
    edge_attr: dict[str, Any] | None = None,
    show_expansion: bool = False,
    collapse_all: bool = True,
) -> GraphView:
    """Build a ``GraphView`` of the computation's current state for drawing with GraphViz.

    :param root: Optional PathType. Sub-block to draw
    :param node_transformations: Optional mapping of node name to transformation. It is copied, so the
        caller's dict is never modified.
    :param cmap: Default: None
    :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
    :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
    :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
        ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
    :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
    :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
    :param show_expansion: When False (the default), every node tagged ``SystemTags.EXPANSION`` is given
        the ``NodeTransformations.CONTRACT`` transformation, replacing any entry supplied for it in
        ``node_transformations``.
    :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
    :return: The configured ``GraphView``.
    """
    transformations: dict[Name, str] = {} if node_transformations is None else node_transformations.copy()
    if not show_expansion:
        # Applied after the copy, so these entries win over caller-supplied ones.
        transformations.update(
            dict.fromkeys(self.nodes_by_tag(SystemTags.EXPANSION), NodeTransformations.CONTRACT)
        )
    return GraphView(
        self,
        root=root,
        node_formatter=NodeFormatter.create(cmap, colors, shapes),
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        node_transformations=transformations,
        collapse_all=collapse_all,
    )
def view(self, cmap: Any = None, colors: str = "state", shapes: str | None = None) -> None:
    """Build a ``GraphView`` of this computation and call its ``view()`` method.

    :param cmap: Passed to ``NodeFormatter.create``.
    :param colors: Passed to ``NodeFormatter.create``. Default: 'state'.
    :param shapes: Passed to ``NodeFormatter.create``. Default: None.
    """
    formatter = NodeFormatter.create(cmap, colors, shapes)
    GraphView(self, node_formatter=formatter).view()
def print_errors(self) -> None:
    """Write the traceback of each node whose state is ``States.ERROR`` to standard output.

    For every such node the output is the node name, an ``=`` underline of the same length, a blank line,
    the stored traceback and another blank line.
    """
    # Lazy filter: each node's state is looked up just before it is (possibly) printed.
    errored = (node for node in self.nodes() if self.s[node] == States.ERROR)
    for node in errored:
        underline = "=" * len(str(node))
        print(f"{node}")
        print(underline)
        print()
        print(self.v[node].traceback)
        print()
@classmethod
def from_class(cls, definition_class: type, ignore_self: bool = True) -> "Computation":
    """Build a computation from a definition class.

    A new, empty instance of ``cls`` and a no-argument instance of ``definition_class`` are created and
    handed to ``populate_computation_from_class``.

    :param definition_class: Class describing the computation; it is instantiated with no arguments.
    :param ignore_self: Forwarded to ``populate_computation_from_class``.
    :return: The populated computation.
    """
    computation = cls()
    populate_computation_from_class(
        computation,
        definition_class,
        definition_class(),
        ignore_self=ignore_self,
    )
    return computation
def inject_dependencies(self, dependencies: dict[Name, Any], *, force: bool = False) -> None:
    """Fill nodes of this computation from a mapping of dependencies.

    By default only nodes currently in the placeholder state are considered; with ``force=True`` every
    node is. For each considered node, the entry of ``dependencies`` under that node's name is looked up:

    * no entry, or an entry whose value is ``None``: the node is left untouched;
    * a callable: the node is redefined as a calculation node using that callable;
    * anything else: the node is redefined with that object as its value.

    :param dependencies: Mapping of node identifier to the dependency object, or to a callable that
        produces it.
    :param force: When True, existing nodes are replaced regardless of their current state. Defaults to
        False.
    :return: None
    """
    for node in self.nodes():
        eligible = force or self.s[node] == States.PLACEHOLDER
        if not eligible:
            continue
        dependency = dependencies.get(node)
        if dependency is None:
            # Covers both "not supplied" and an explicit None value.
            continue
        if callable(dependency):
            self.add_node(node, dependency)
        else:
            self.add_node(node, value=dependency)