Coverage for src / loman / computeengine.py: 90%
992 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
1"""Core computation engine for dependency-aware calculation graphs."""
3import inspect
4import logging
5import traceback
6import types
7import warnings
8from collections import defaultdict
9from collections.abc import Callable, Iterable, Mapping
10from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
11from dataclasses import dataclass, field
12from datetime import UTC, datetime
13from enum import Enum
14from typing import Any, Type, Union # noqa: UP035
16import decorator
17import dill
18import networkx as nx
19import pandas as pd
21from .compat import get_signature
22from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags
23from .exception import (
24 CannotInsertToPlaceholderNodeException,
25 ComputationError,
26 LoopDetectedException,
27 MapException,
28 NodeAlreadyExistsException,
29 NonExistentNodeException,
30)
31from .graph_utils import topological_sort
32from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey
33from .util import AttributeView, apply1, apply_n, as_iterable, value_eq
34from .visualization import GraphView, NodeFormatter
# Module-level logger used by the computation engine for debug tracing.
LOG = logging.getLogger("loman.computeengine")
@dataclass
class Error:
    """Container for error information captured when a node fails.

    Instances are stored as the value of a node that is placed in the
    ERROR state.
    """

    # The exception raised by the node function or by its converter.
    exception: Exception
    # Formatted traceback text; producers in this module pass the string
    # returned by ``traceback.format_exc()``, not a traceback object.
    traceback: str
@dataclass
class NodeData:
    """State and value held by a single computation node."""

    state: States  # current state of the node
    value: object  # value currently stored on the node
@dataclass
class TimingData:
    """Start time, end time and duration recorded for one node execution."""

    start: datetime  # when execution began
    end: datetime  # when execution finished
    duration: float  # elapsed time between start and end
64class _ParameterType(Enum):
65 ARG = 1
66 KWD = 2
69@dataclass
70class _ParameterItem:
71 type: object
72 name: int | str
73 value: object
76def _node(func, *args, **kws):
77 return func(*args, **kws)
def node(comp, name=None, *args, **kw):
    """Decorator factory that registers the decorated function on ``comp``.

    :param comp: Computation to add the node to.
    :param name: Node name; when None the function's ``__name__`` is used.
    :param args: Extra positional arguments forwarded to ``comp.add_node``.
    :param kw: Extra keyword arguments forwarded to ``comp.add_node``.
    :return: A decorator returning the function wrapped via ``decorator.decorate``.
    """

    def inner(f):
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        return decorator.decorate(f, _node)

    return inner
@dataclass
class ConstantValue:
    """Wrapper marking a value as a literal constant rather than a node name."""

    value: object  # the wrapped constant


# Short alias for ConstantValue.
C = ConstantValue
class Node:
    """Base class for declarative node definitions."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Register this definition on ``comp`` under ``name``.

        Subclasses must override this; the base implementation always raises.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError
@dataclass
class InputNode(Node):
    """Declarative definition of an input node."""

    args: tuple[Any, ...] = field(default_factory=tuple)
    kwds: dict = field(default_factory=dict)

    def __init__(self, *args, **kwds):
        """Store the positional and keyword arguments supplied at definition time."""
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the node to ``comp``, forwarding only the stored keyword arguments."""
        options = self.kwds
        comp.add_node(name, **options)


# Lower-case alias for InputNode.
input_node = InputNode
@dataclass
class CalcNode(Node):
    """Declarative definition of a calculation node backed by a function."""

    f: Callable
    kwds: dict = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the calculation to ``comp``.

        When self-handling is requested (via the argument or an ``ignore_self``
        entry in ``kwds``) and the function's first parameter is named ``self``,
        the function is bound to ``obj`` before being registered.
        """
        node_kwds = dict(self.kwds)
        # Always strip the option so it is never forwarded to add_node.
        requested_in_kwds = node_kwds.pop("ignore_self", False)
        func = self.f
        if ignore_self or requested_in_kwds:
            params = get_signature(self.f).kwd_params
            if len(params) > 0 and params[0] == "self":
                func = func.__get__(obj, obj.__class__)
        comp.add_node(name, func, **node_kwds)
def calc_node(f=None, **kwds):
    """Decorator tagging a function as a calculation node.

    Usable bare (``@calc_node``) or with options (``@calc_node(...)``). The
    function is returned unchanged apart from a ``_loman_node_info`` attribute.
    """

    def wrap(func):
        func._loman_node_info = CalcNode(func, kwds)
        return func

    return wrap if f is None else wrap(f)
@dataclass
class Block(Node):
    """Declarative definition of a nested block of computation."""

    block: Union[Callable, "Computation"]
    args: tuple[Any, ...] = field(default_factory=tuple)
    kwds: dict = field(default_factory=dict)

    def __init__(self, block, *args, **kwds):
        """Store the block (a Computation or a callable producing one) and extra arguments."""
        self.block, self.args, self.kwds = block, args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the block to ``comp`` via ``comp.add_block``.

        A callable block is invoked with no arguments and its result is added.

        :raises TypeError: if the block is neither a Computation nor callable.
        """
        target = self.block
        if not isinstance(target, Computation):
            if not callable(target):
                raise TypeError(f"Block {self.block} must be callable or Computation")
            target = target()
        comp.add_block(name, target, *self.args, **self.kwds)


# Lower-case alias for Block.
block = Block
def populate_computation_from_class(comp, cls, obj, ignore_self=True):
    """Add every node definition found on ``cls`` to ``comp``.

    A member counts as a definition when it is a ``Node`` instance or carries
    a ``_loman_node_info`` attribute.
    """
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, Node):
            definition = member
        else:
            definition = getattr(member, "_loman_node_info", None)
        if definition is None:
            continue
        definition.add_to_comp(comp, member_name, obj, ignore_self)
def computation_factory(maybe_cls=None, *, ignore_self=True) -> Type["Computation"]:  # noqa: UP006
    """Class decorator turning a definition class into a Computation factory.

    The decorated class is replaced by a function that instantiates the class,
    creates a Computation from the call arguments and populates it from the
    class's node definitions. Usable bare or with ``ignore_self=...``.
    """

    def wrap(cls):
        def create_computation(*args, **kwargs):
            definition = cls()
            comp = Computation(*args, **kwargs)
            comp._definition_object = definition
            populate_computation_from_class(comp, cls, definition, ignore_self)
            return comp

        return create_computation

    return wrap if maybe_cls is None else wrap(maybe_cls)
223def _eval_node(name, f, args, kwds, raise_exceptions):
224 """To make multiprocessing work, this function must be standalone so that pickle works."""
225 exc, tb = None, None
226 start_dt = datetime.now(UTC)
227 try:
228 logging.debug("Running " + str(name))
229 value = f(*args, **kwds)
230 logging.debug("Completed " + str(name))
231 except Exception as e:
232 value = None
233 exc = e
234 tb = traceback.format_exc()
235 if raise_exceptions:
236 raise
237 end_dt = datetime.now(UTC)
238 return value, exc, tb, start_dt, end_dt
# Unique marker used as the default for add_node's ``value`` argument so that
# "no value supplied" can be told apart from an explicit ``None``.
_MISSING_VALUE_SENTINEL = object()
class NullObject:
    """Debugging aid: prints every access attempt and then rejects it."""

    def __getattr__(self, name):
        """Report the lookup and raise AttributeError."""
        print(f"__getattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __setattr__(self, name, value):
        """Report the assignment and raise AttributeError."""
        print(f"__setattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __delattr__(self, name):
        """Report the deletion and raise AttributeError."""
        print(f"__delattr__: {name}")
        message = f"'NullObject' object has no attribute '{name}'"
        raise AttributeError(message)

    def __call__(self, *args, **kwargs):
        """Report the call and raise TypeError."""
        print(f"__call__: {args}, {kwargs}")
        message = "'NullObject' object is not callable"
        raise TypeError(message)

    def __getitem__(self, key):
        """Report the item lookup and raise KeyError."""
        print(f"__getitem__: {key}")
        message = f"'NullObject' object has no item with key '{key}'"
        raise KeyError(message)

    def __setitem__(self, key, value):
        """Report the item assignment and raise KeyError."""
        print(f"__setitem__: {key}")
        message = f"'NullObject' object cannot have items set with key '{key}'"
        raise KeyError(message)

    def __repr__(self):
        """Report the instance dict and return a fixed representation."""
        print(f"__repr__: {self.__dict__}")
        text = "<NullObject>"
        return text
def identity_function(x):
    """Pass-through: hand back exactly the object that was given."""
    unchanged = x
    return unchanged
288class Computation:
289 """A computation graph that manages dependencies and calculations.
291 The Computation class provides a framework for building and executing
292 computation graphs where nodes represent data or calculations, and edges
293 represent dependencies between them.
294 """
296 def __init__(self, *, default_executor=None, executor_map=None, metadata=None):
297 """Initialize a new Computation.
299 :param default_executor: An executor
300 :type default_executor: concurrent.futures.Executor, default ThreadPoolExecutor(max_workers=1)
301 """
302 if default_executor is None:
303 self.default_executor = ThreadPoolExecutor(1)
304 else:
305 self.default_executor = default_executor
306 if executor_map is None:
307 self.executor_map = {}
308 else:
309 self.executor_map = executor_map
310 self.dag = nx.DiGraph()
311 self._metadata = {}
312 if metadata is not None:
313 self._metadata[NodeKey.root()] = metadata
315 self.v = self.get_attribute_view_for_path(NodeKey.root(), self._value_one, self.value)
316 self.s = self.get_attribute_view_for_path(NodeKey.root(), self._state_one, self.state)
317 self.i = self.get_attribute_view_for_path(NodeKey.root(), self._get_inputs_one_names, self.get_inputs)
318 self.o = self.get_attribute_view_for_path(NodeKey.root(), self._get_outputs_one, self.get_outputs)
319 self.t = self.get_attribute_view_for_path(NodeKey.root(), self._tag_one, self.tags)
320 self.style = self.get_attribute_view_for_path(NodeKey.root(), self._style_one, self.styles)
321 self.tim = self.get_attribute_view_for_path(NodeKey.root(), self._get_timing_one, self.get_timing)
322 self.x = self.get_attribute_view_for_path(
323 NodeKey.root(), self.compute_and_get_value, self.compute_and_get_value
324 )
325 self.src = self.get_attribute_view_for_path(NodeKey.root(), self.print_source, self.print_source)
326 self._tag_map = defaultdict(set)
327 self._state_map = {state: set() for state in States}
329 def get_attribute_view_for_path(self, nodekey: NodeKey, get_one_func: callable, get_many_func: callable):
330 """Create an attribute view for a specific node path."""
332 def node_func():
333 return self.get_tree_list_children(nodekey)
335 def get_one_func_for_path(name: Name):
336 nk = to_nodekey(name)
337 new_nk = nk.prepend(nodekey)
338 if self.has_node(new_nk):
339 return get_one_func(new_nk)
340 elif self.tree_has_path(new_nk):
341 return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func)
342 else:
343 raise KeyError(f"Path {new_nk} does not exist")
345 def get_many_func_for_path(name: Name | Names):
346 if isinstance(name, list):
347 return [get_one_func_for_path(n) for n in name]
348 else:
349 return get_one_func_for_path(name)
351 return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)
353 def _get_names_for_state(self, state: States):
354 return set(node_keys_to_names(self._state_map[state]))
356 def _get_tags_for_state(self, tag: str):
357 return set(node_keys_to_names(self._tag_map[tag]))
    def add_node(
        self,
        name: Name,
        func=None,
        *,
        args=None,
        kwds=None,
        value=_MISSING_VALUE_SENTINEL,
        converter=None,
        serialize=True,
        inspect=True,
        group=None,
        tags=None,
        style=None,
        executor=None,
        metadata=None,
    ):
        """Adds or updates a node in a computation.

        If the node already exists, its incoming edges and attributes are replaced, and its descendents are
        marked STALE.

        :param name: Name of the node to add. This may be any hashable object.
        :param func: Function to use to calculate the node if the node is a calculation node. By default, the input
            nodes to the function will be implied from the names of the function parameters. For example, a
            parameter called ``a`` would be taken from the node called ``a``. This can be modified with the
            ``kwds`` parameter.
        :type func: Function, default None
        :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally
            for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with
            three parameters, taken from the nodes 'a', 'b' and 'c' respectively. Entries that are
            ``ConstantValue`` instances are passed as literal values instead of being read from a node.
        :type args: List, default None
        :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
            parameter when calling the function for a calculation node. e.g. If kwds is ``{'x': 'a', 'y': 'b'}``
            then the function would be called with parameters named 'x' and 'y', and their values would be taken
            from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter
            [key] from node [value]". Values that are ``ConstantValue`` instances are passed as literals.
        :type kwds: Dictionary, default None
        :param value: If given (including an explicit ``None``), the value is inserted into the node, and the node
            state set to UPTODATE.
        :type value: default: not supplied
        :param converter: Stored on the node; when set, values inserted into the node are passed through it first.
        :type converter: Callable, default None
        :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
            case, set serialize to False
        :type serialize: boolean, default True
        :param inspect: Whether to use introspection to determine the arguments of the function, which can be
            slow. If this is not set, kwds and args must be set for the function to obtain parameters.
        :type inspect: boolean, default True
        :param group: Subgraph to render node in
        :type group: default None
        :param tags: Set of tags to apply to node
        :type tags: Iterable
        :param style: Style to apply to node
        :type style: String, default None
        :param executor: Name of executor to run node on
        :type executor: string
        :param metadata: Metadata to store for the node. If None, any metadata already stored for the node is
            removed.
        :type metadata: default None
        """
        node_key = to_nodekey(name)
        LOG.debug(f"Adding node {node_key}")
        # Record whether a value was supplied before normalising the sentinel to None.
        has_value = value is not _MISSING_VALUE_SENTINEL
        if value is _MISSING_VALUE_SENTINEL:
            value = None
        if tags is None:
            tags = []

        self.dag.add_node(node_key)
        # Drop existing incoming edges; they are rebuilt from args/kwds below.
        pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
        self.dag.remove_edges_from(pred_edges)
        node = self.dag.nodes[node_key]

        if metadata is None:
            if node_key in self._metadata:
                del self._metadata[node_key]
        else:
            self._metadata[node_key] = metadata

        # A brand-new node has no previous state to unregister, hence require_old_state=False.
        self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)

        node[NodeAttributes.TAG] = set()
        node[NodeAttributes.STYLE] = style
        node[NodeAttributes.GROUP] = group
        node[NodeAttributes.ARGS] = {}
        node[NodeAttributes.KWDS] = {}
        node[NodeAttributes.FUNC] = None
        node[NodeAttributes.EXECUTOR] = executor
        node[NodeAttributes.CONVERTER] = converter

        if func:
            node[NodeAttributes.FUNC] = func
            args_count = 0
            if args:
                args_count = len(args)
                for i, arg in enumerate(args):
                    if isinstance(arg, ConstantValue):
                        # Literal positional argument: stored on the node, no edge created.
                        node[NodeAttributes.ARGS][i] = arg.value
                    else:
                        input_vertex_name = arg
                        input_vertex_node_key = to_nodekey(input_vertex_name)
                        if not self.dag.has_node(input_vertex_node_key):
                            # Referenced node does not exist yet: create it as a placeholder.
                            self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                            self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
                        self.dag.add_edge(
                            input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
                        )
            # Maps parameter name -> source (a node name or a ConstantValue).
            param_map = {}
            if inspect:
                signature = get_signature(func)
                if not signature.has_var_args:
                    # Parameters already covered positionally by ``args`` are skipped.
                    for param_name in signature.kwd_params[args_count:]:
                        if kwds is not None and param_name in kwds:
                            param_source = kwds[param_name]
                        else:
                            # Default source: derived from the parameter name via the node's parent path.
                            param_source = node_key.parent.join_parts(param_name)
                        param_map[param_name] = param_source
                if signature.has_var_kwds and kwds is not None:
                    for param_name, param_source in kwds.items():
                        param_map[param_name] = param_source
                default_names = signature.default_params
            else:
                if kwds is not None:
                    for param_name, param_source in kwds.items():
                        param_map[param_name] = param_source
                default_names = []
            for param_name, param_source in param_map.items():
                if isinstance(param_source, ConstantValue):
                    # Literal keyword argument: stored on the node, no edge created.
                    node[NodeAttributes.KWDS][param_name] = param_source.value
                else:
                    in_node_name = param_source
                    in_node_key = to_nodekey(in_node_name)
                    if not self.dag.has_node(in_node_key):
                        if param_name in default_names:
                            # Parameter has a default and no source node exists: leave it unwired.
                            continue
                        else:
                            self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                            self._state_map[States.PLACEHOLDER].add(in_node_key)
                    self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})
        self._set_descendents(node_key, States.STALE)

        if has_value:
            self._set_uptodate(node_key, value)
        if node[NodeAttributes.STATE] == States.UNINITIALIZED:
            self._try_set_computable(node_key)
        self.set_tag(node_key, tags)
        if serialize:
            self.set_tag(node_key, SystemTags.SERIALIZE)
500 def _refresh_maps(self):
501 self._tag_map.clear()
502 for state in States:
503 self._state_map[state].clear()
504 for node_key in self._node_keys():
505 state = self.dag.nodes[node_key][NodeAttributes.STATE]
506 self._state_map[state].add(node_key)
507 tags = self.dag.nodes[node_key].get(NodeAttributes.TAG, set())
508 for tag in tags:
509 self._tag_map[tag].add(node_key)
511 def _set_tag_one(self, name: Name, tag):
512 node_key = to_nodekey(name)
513 self.dag.nodes[node_key][NodeAttributes.TAG].add(tag)
514 self._tag_map[tag].add(node_key)
516 def set_tag(self, name: Names, tag):
517 """Set tags on a node or nodes. Ignored if tags are already set.
519 :param name: Node or nodes to set tag for
520 :param tag: Tag to set
521 """
522 apply_n(self._set_tag_one, name, tag)
524 def _clear_tag_one(self, name: Name, tag):
525 node_key = to_nodekey(name)
526 self.dag.nodes[node_key][NodeAttributes.TAG].discard(tag)
527 self._tag_map[tag].discard(node_key)
529 def clear_tag(self, name: Names, tag):
530 """Clear tag on a node or nodes. Ignored if tags are not set.
532 :param name: Node or nodes to clear tags for
533 :param tag: Tag to clear
534 """
535 apply_n(self._clear_tag_one, name, tag)
537 def _set_style_one(self, name: Name, style):
538 node_key = to_nodekey(name)
539 self.dag.nodes[node_key][NodeAttributes.STYLE] = style
541 def set_style(self, name: Names, style):
542 """Set styles on a node or nodes.
544 :param name: Node or nodes to set style for
545 :param style: Style to set
546 """
547 apply_n(self._set_style_one, name, style)
549 def _clear_style_one(self, name):
550 node_key = to_nodekey(name)
551 self.dag.nodes[node_key][NodeAttributes.STYLE] = None
553 def clear_style(self, name):
554 """Clear style on a node or nodes.
556 :param name: Node or nodes to clear styles for
557 """
558 apply_n(self._clear_style_one, name)
560 def metadata(self, name):
561 """Get metadata for a node."""
562 node_key = to_nodekey(name)
563 if self.tree_has_path(name):
564 if node_key not in self._metadata:
565 self._metadata[node_key] = {}
566 return self._metadata[node_key]
567 else:
568 raise NonExistentNodeException(f"Node {node_key} does not exist.")
570 def delete_node(self, name):
571 """Delete a node from a computation.
573 When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
574 will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
575 then the PLACEHOLDER node will also be deleted.
577 :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
578 be raised.
579 """
580 node_key = to_nodekey(name)
581 LOG.debug(f"Deleting node {node_key}")
583 if not self.dag.has_node(node_key):
584 raise NonExistentNodeException(f"Node {node_key} does not exist")
586 if node_key in self._metadata:
587 del self._metadata[node_key]
589 if len(self.dag.succ[node_key]) == 0:
590 preds = self.dag.predecessors(node_key)
591 state = self.dag.nodes[node_key][NodeAttributes.STATE]
592 self.dag.remove_node(node_key)
593 self._state_map[state].remove(node_key)
594 for n in preds:
595 if self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
596 self.delete_node(n)
597 else:
598 self._set_state(node_key, States.PLACEHOLDER)
600 def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None):
601 """Rename a node in a computation.
603 :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and
604 new names as values
605 :param new_name: New name for node.
606 """
607 if hasattr(old_name, "__getitem__") and not isinstance(old_name, str):
608 for k, v in old_name.items():
609 LOG.debug(f"Renaming node {k} to {v}")
610 if new_name is not None:
611 raise ValueError("new_name must not be set if rename_node is passed a dictionary")
612 else:
613 name_mapping = old_name
614 else:
615 LOG.debug(f"Renaming node {old_name} to {new_name}")
616 old_node_key = to_nodekey(old_name)
617 if not self.dag.has_node(old_node_key):
618 raise NonExistentNodeException(f"Node {old_name} does not exist")
619 new_node_key = to_nodekey(new_name)
620 if self.dag.has_node(new_node_key):
621 raise NodeAlreadyExistsException(f"Node {new_name} already exists")
622 name_mapping = {old_name: new_name}
624 node_key_mapping = {to_nodekey(old_name): to_nodekey(new_name) for old_name, new_name in name_mapping.items()}
625 nx.relabel_nodes(self.dag, node_key_mapping, copy=False)
627 for old_node_key, new_node_key in node_key_mapping.items():
628 if old_node_key in self._metadata:
629 self._metadata[new_node_key] = self._metadata[old_node_key]
630 del self._metadata[old_node_key]
631 else:
632 if new_node_key in self._metadata:
633 del self._metadata[new_node_key]
635 self._refresh_maps()
637 def repoint(self, old_name: Name, new_name: Name):
638 """Changes all nodes that use old_name as an input to use new_name instead.
640 Note that if old_name is an input to new_name, then that will not be changed, to try to avoid introducing
641 circular dependencies, but other circular dependencies will not be checked.
643 If new_name does not exist, then it will be created as a PLACEHOLDER node.
645 :param old_name:
646 :param new_name:
647 :return:
648 """
649 old_node_key = to_nodekey(old_name)
650 new_node_key = to_nodekey(new_name)
651 if old_node_key == new_node_key:
652 return
654 changed_names = list(self.dag.successors(old_node_key))
656 if len(changed_names) > 0 and not self.dag.has_node(new_node_key):
657 self.dag.add_node(new_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
658 self._state_map[States.PLACEHOLDER].add(new_node_key)
660 for name in changed_names:
661 if name == new_node_key:
662 continue
663 edge_data = self.dag.get_edge_data(old_node_key, name)
664 self.dag.add_edge(new_node_key, name, **edge_data)
665 self.dag.remove_edge(old_node_key, name)
667 for name in changed_names:
668 self.set_stale(name)
670 def insert(self, name: Name, value, force=False):
671 """Insert a value into a node of a computation.
673 Following insertation, the node will have state UPTODATE, and all its descendents will be COMPUTABLE or STALE.
675 If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException``
676 will be raised.
678 :param name: Name of the node to add.
679 :param value: The value to be inserted into the node.
680 :param force: Whether to force recalculation of descendents if node value and state would not be changed
681 """
682 node_key = to_nodekey(name)
683 LOG.debug(f"Inserting value into node {node_key}")
685 if not self.dag.has_node(node_key):
686 raise NonExistentNodeException(f"Node {node_key} does not exist")
688 state = self._state_one(name)
689 if state == States.PLACEHOLDER:
690 raise CannotInsertToPlaceholderNodeException(
691 "Cannot insert into placeholder node. Use add_node to create the node first"
692 )
694 if not force:
695 if state == States.UPTODATE:
696 current_value = self._value_one(name)
697 if value_eq(value, current_value):
698 return
700 self._set_state_and_value(node_key, States.UPTODATE, value)
701 self._set_descendents(node_key, States.STALE)
702 for n in self.dag.successors(node_key):
703 self._try_set_computable(n)
705 def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]):
706 """Insert values into many nodes of a computation simultaneously.
708 Following insertation, the nodes will have state UPTODATE, and all their descendents will be COMPUTABLE
709 or STALE. In the case of inserting many nodes, some of which are descendents of others, this ensures that
710 the inserted nodes have correct status, rather than being set as STALE when their ancestors are inserted.
712 If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` will be
713 raised, and none of the nodes will be inserted.
715 :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
716 insert the value into.
717 :type name_value_pairs: List of tuples
718 """
719 node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
720 LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")
722 for name, value in node_key_value_pairs:
723 if not self.dag.has_node(name):
724 raise NonExistentNodeException(f"Node {name} does not exist")
726 stale = set()
727 computable = set()
728 for name, value in node_key_value_pairs:
729 self._set_state_and_value(name, States.UPTODATE, value)
730 stale.update(nx.dag.descendants(self.dag, name))
731 computable.update(self.dag.successors(name))
732 names = set([name for name, value in node_key_value_pairs])
733 stale.difference_update(names)
734 computable.difference_update(names)
735 for name in stale:
736 self._set_state(name, States.STALE)
737 for name in computable:
738 self._try_set_computable(name)
740 def insert_from(self, other, nodes: Iterable[Name] | None = None):
741 """Insert values into another Computation object into this Computation object.
743 :param other: The computation object to take values from
744 :type Computation:
745 :param nodes: Only populate the nodes with the names provided in this list. By default, all nodes from the
746 other Computation object that have corresponding nodes in this Computation object will be inserted
747 :type nodes: List, default None
748 """
749 if nodes is None:
750 nodes = set(self.dag.nodes)
751 nodes.intersection_update(other.dag.nodes())
752 name_value_pairs = [(name, other.value(name)) for name in nodes]
753 self.insert_many(name_value_pairs)
755 def _set_state(self, node_key: NodeKey, state: States):
756 node = self.dag.nodes[node_key]
757 old_state = node[NodeAttributes.STATE]
758 self._state_map[old_state].remove(node_key)
759 node[NodeAttributes.STATE] = state
760 self._state_map[state].add(node_key)
762 def _set_state_and_value(
763 self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
764 ):
765 node = self.dag.nodes[node_key]
766 converter = node.get(NodeAttributes.CONVERTER)
767 if converter is None:
768 self._set_state_and_literal_value(node_key, state, value)
769 else:
770 try:
771 converted_value = converter(value)
772 self._set_state_and_literal_value(node_key, state, converted_value)
773 except Exception as e:
774 tb = traceback.format_exc()
775 self._set_error(node_key, e, tb)
776 if throw_conversion_exception:
777 raise e
779 def _set_state_and_literal_value(
780 self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
781 ):
782 node = self.dag.nodes[node_key]
783 try:
784 old_state = node[NodeAttributes.STATE]
785 self._state_map[old_state].remove(node_key)
786 except KeyError:
787 if require_old_state:
788 raise
789 node[NodeAttributes.STATE] = state
790 node[NodeAttributes.VALUE] = value
791 self._state_map[state].add(node_key)
793 def _set_states(self, node_keys: Iterable[NodeKey], state: States):
794 for name in node_keys:
795 node = self.dag.nodes[name]
796 old_state = node[NodeAttributes.STATE]
797 self._state_map[old_state].remove(name)
798 node[NodeAttributes.STATE] = state
799 self._state_map[state].update(node_keys)
801 def set_stale(self, name: Name):
802 """Set the state of a node and all its dependencies to STALE.
804 :param name: Name of the node to set as STALE.
805 """
806 node_key = to_nodekey(name)
807 node_keys = [node_key]
808 node_keys.extend(nx.dag.descendants(self.dag, node_key))
809 self._set_states(node_keys, States.STALE)
810 self._try_set_computable(node_key)
812 def pin(self, name: Name, value=None):
813 """Set the state of a node to PINNED.
815 :param name: Name of the node to set as PINNED.
816 :param value: Value to pin to the node, if provided.
817 :type value: default None
818 """
819 node_key = to_nodekey(name)
820 if value is not None:
821 self.insert(node_key, value)
822 self._set_states([node_key], States.PINNED)
824 def unpin(self, name):
825 """Unpin a node (state of node and all descendents will be set to STALE).
827 :param name: Name of the node to set as PINNED.
828 """
829 node_key = to_nodekey(name)
830 self.set_stale(node_key)
832 def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
833 if self.dag.nodes[node_key][NodeAttributes.STATE] in stop_states:
834 return set()
835 if stop_states is None:
836 stop_states = []
837 visited = set()
838 to_visit = {node_key}
839 while to_visit:
840 n = to_visit.pop()
841 visited.add(n)
842 for n1 in self.dag.successors(n):
843 if n1 in visited:
844 continue
845 if self.dag.nodes[n1][NodeAttributes.STATE] in stop_states:
846 continue
847 to_visit.add(n1)
848 visited.remove(node_key)
849 return visited
851 def _set_descendents(self, node_key: NodeKey, state):
852 descendents = self._get_descendents(node_key, {States.PINNED})
853 self._set_states(descendents, state)
855 def _set_uninitialized(self, node_key: NodeKey):
856 self._set_states([node_key], States.UNINITIALIZED)
857 self.dag.nodes[node_key].pop(NodeAttributes.VALUE, None)
859 def _set_uptodate(self, node_key: NodeKey, value: object):
860 self._set_state_and_value(node_key, States.UPTODATE, value)
861 self._set_descendents(node_key, States.STALE)
862 for n in self.dag.successors(node_key):
863 self._try_set_computable(n)
865 def _set_error(self, node_key: NodeKey, exc: Exception, tb: inspect.Traceback):
866 self._set_state_and_literal_value(node_key, States.ERROR, Error(exc, tb))
867 self._set_descendents(node_key, States.STALE)
869 def _try_set_computable(self, node_key: NodeKey):
870 if self.dag.nodes[node_key][NodeAttributes.STATE] == States.PINNED:
871 return
872 if self.dag.nodes[node_key].get(NodeAttributes.FUNC) is not None:
873 for n in self.dag.predecessors(node_key):
874 if not self.dag.has_node(n):
875 return
876 if self.dag.nodes[n][NodeAttributes.STATE] != States.UPTODATE:
877 return
878 self._set_state(node_key, States.COMPUTABLE)
880 def _get_parameter_data(self, node_key: NodeKey):
881 for arg, value in self.dag.nodes[node_key][NodeAttributes.ARGS].items():
882 yield _ParameterItem(_ParameterType.ARG, arg, value)
883 for param_name, value in self.dag.nodes[node_key][NodeAttributes.KWDS].items():
884 yield _ParameterItem(_ParameterType.KWD, param_name, value)
885 for in_node_name in self.dag.predecessors(node_key):
886 param_value = self.dag.nodes[in_node_name][NodeAttributes.VALUE]
887 edge = self.dag[in_node_name][node_key]
888 param_type, param_name = edge[EdgeAttributes.PARAM]
889 yield _ParameterItem(param_type, param_name, param_value)
def _get_func_args_kwds(self, node_key: NodeKey):
    """Assemble everything needed to call the function of ``node_key``.

    :param node_key: Key of the node to prepare a call for
    :return: Tuple ``(func, executor_name, args, kwds)``; ``args`` is a list
        padded with ``None`` for any positional slot that was not supplied
    :raises Exception: If a parameter is neither positional nor keyword
    """
    attrs = self.dag.nodes[node_key]
    func = attrs[NodeAttributes.FUNC]
    executor_name = attrs.get(NodeAttributes.EXECUTOR)
    positional = []
    keyword = {}
    for item in self._get_parameter_data(node_key):
        if item.type == _ParameterType.ARG:
            slot = item.name
            # Grow the list so that ``slot`` is a valid index.
            positional.extend([None] * (slot + 1 - len(positional)))
            positional[slot] = item.value
        elif item.type == _ParameterType.KWD:
            keyword[item.name] = item.value
        else:
            raise Exception(f"Unexpected param type: {item.type}")
    return func, executor_name, positional, keyword
def get_definition_args_kwds(self, name: Name) -> tuple[list, dict]:
    """Reconstruct the ``args``/``kwds`` that define a node.

    Stored constants are wrapped with ``C``; inputs fed by other nodes are
    represented by the name of the source node.

    :param name: Name of the node to describe
    :return: Tuple of (positional list, keyword dict); unused positional slots
        are ``None``
    :raises Exception: If an incoming edge carries an unknown parameter type
    """

    def put_positional(slot, entry):
        # Pad with None so that ``slot`` becomes a valid index, then store.
        res_args.extend([None] * (slot + 1 - len(res_args)))
        res_args[slot] = entry

    res_args = []
    res_kwds = {}
    node_key = to_nodekey(name)
    node_data = self.dag.nodes[node_key]
    for slot, constant in node_data.get(NodeAttributes.ARGS, {}).items():
        put_positional(slot, C(constant))
    for keyword, constant in node_data.get(NodeAttributes.KWDS, {}).items():
        res_kwds[keyword] = C(constant)
    for upstream in self.dag.predecessors(node_key):
        edge = self.dag[upstream][node_key]
        if EdgeAttributes.PARAM not in edge:
            continue
        param_type, param_name = edge[EdgeAttributes.PARAM]
        if param_type == _ParameterType.ARG:
            put_positional(param_name, upstream.name)
        elif param_type == _ParameterType.KWD:
            res_kwds[param_name] = upstream.name
        else:
            raise Exception(f"Unexpected param type: {param_type}")
    return res_args, res_kwds
def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False):
    """Evaluate the COMPUTABLE nodes among ``node_keys`` until none remain.

    Each COMPUTABLE node is submitted to its executor (the node's named
    executor, or the default one). As results arrive the node is set UPTODATE
    or ERROR, its descendents are marked STALE, and any successor that has
    become COMPUTABLE and belongs to ``node_keys`` is submitted in turn.

    :param node_keys: Keys of the nodes that may be computed
    :param raise_exceptions: Forwarded to ``_eval_node``
    :raises LoopDetectedException: If a successor of a freshly computed node
        has already been computed during this call
    """
    LOG.debug("Computing nodes %s", node_keys)

    # Materialise once: the keys are iterated for the initial submission and
    # probed for membership afterwards, which must work for one-shot iterables
    # and should not be a linear scan per successor.
    node_keys = list(node_keys)
    requested = set(node_keys)

    futs = {}

    def run(name):
        f, executor_name, args, kwds = self._get_func_args_kwds(name)
        if executor_name is None:
            executor = self.default_executor
        else:
            executor = self.executor_map[executor_name]
        fut = executor.submit(_eval_node, name, f, args, kwds, raise_exceptions)
        futs[fut] = name

    computed = set()

    for node_key in node_keys:
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.COMPUTABLE:
            run(node_key)

    while futs:
        done, _ = wait(futs.keys(), return_when=FIRST_COMPLETED)
        for fut in done:
            node_key = futs.pop(fut)
            node0 = self.dag.nodes[node_key]
            try:
                value, exc, tb, start_dt, end_dt = fut.result()
            except Exception as e:
                # The future itself failed: record the failure, then propagate.
                self._set_error(node_key, e, traceback.format_exc())
                raise
            delta = (end_dt - start_dt).total_seconds()
            if exc is None:
                self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
                node0[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
                self._set_descendents(node_key, States.STALE)
                for n in self.dag.successors(node_key):
                    # Use the module logger (not the root logger) with lazy formatting.
                    LOG.debug("%s %s %s", node_key, n, computed)
                    if n in computed:
                        raise LoopDetectedException(f"Calculating {node_key} for the second time")
                    self._try_set_computable(n)
                    if self.dag.nodes[n][NodeAttributes.STATE] == States.COMPUTABLE and n in requested:
                        run(n)
            else:
                self._set_error(node_key, exc, tb)
            computed.add(node_key)
def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """List, in topological order, the nodes to evaluate to compute ``node_key``.

    Ancestors that are already UPTODATE or PINNED are pruned from a scratch
    copy of the graph, so nothing above them is included. The result always
    contains ``node_key`` itself.

    :param node_key: Key of the target node
    :return: Remaining ancestors plus the target, topologically sorted
    :raises Exception: If a remaining ancestor is an UNINITIALIZED node without
        predecessors, or is a PLACEHOLDER
    """
    g = nx.DiGraph()
    g.add_nodes_from(self.dag.nodes)
    g.add_edges_from(self.dag.edges)
    for n in nx.ancestors(g, node_key):
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UPTODATE or state == States.PINNED:
            g.remove_node(n)

    ancestors = nx.ancestors(g, node_key)
    for n in ancestors:
        # Look the state up per ancestor; reusing the variable left over from
        # the pruning loop would test an unrelated node.
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0:
            raise Exception(f"Cannot compute {node_key} because {n} uninitialized")
        if state == States.PLACEHOLDER:
            raise Exception(f"Cannot compute {node_key} because {n} is placeholder")

    ancestors.add(node_key)
    nodes_sorted = topological_sort(g)
    return [n for n in nodes_sorted if n in ancestors]
def _get_calc_node_names(self, name: Name) -> Names:
    """Name-based wrapper around ``_get_calc_node_keys``.

    :param name: Name of the target node
    :return: Names of the nodes that ``_get_calc_node_keys`` returns
    """
    calc_keys = self._get_calc_node_keys(to_nodekey(name))
    return node_keys_to_names(calc_keys)
def compute(self, name: Name | Iterable[Name], raise_exceptions=False):
    """Compute one or more nodes together with the predecessors they need.

    After a successful run the target node(s) and every required ancestor that
    was not already UPTODATE are UPTODATE. Nodes that did not need calculating
    are left untouched.

    A node whose function raises is put in the ERROR state, and its value
    becomes an object holding the exception and a traceback. The run is not
    halted: it continues until none of the nodes needed for the target are
    COMPUTABLE any more.

    :param name: Name of the node to compute, or a list/generator of names
    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    if isinstance(name, (types.GeneratorType, list)):
        # Several targets: take the union of what each one requires.
        calc_nodes = set()
        for target in name:
            calc_nodes.update(self._get_calc_node_keys(to_nodekey(target)))
    else:
        calc_nodes = self._get_calc_node_keys(to_nodekey(name))
    self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)
def compute_all(self, raise_exceptions=False):
    """Compute every node of the computation that can be computed.

    UPTODATE nodes are not recalculated. After a successful run all nodes are
    UPTODATE, apart from UNINITIALIZED input nodes and PLACEHOLDER nodes.

    A node whose function raises is put in the ERROR state, and its value
    becomes an object holding the exception and a traceback. The run is not
    halted: it continues until no node is COMPUTABLE any more.

    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    every_key = self._node_keys()
    self._compute_nodes(every_key, raise_exceptions=raise_exceptions)
def _node_keys(self) -> list[NodeKey]:
    """Return the keys of all nodes in this computation.

    :return: New list holding every node key of the underlying graph.
    """
    return [*self.dag.nodes]
def nodes(self) -> list[Name]:
    """Return the names of all nodes in this computation.

    :return: List with the ``name`` of every node key in the graph.
    """
    return [node_key.name for node_key in self.dag.nodes]
def get_tree_list_children(self, name: Name) -> set[Name]:
    """Get the path components directly beneath ``name`` in the node tree.

    For every graph node that is a descendent of ``name``, the path part
    immediately following the parts of ``name`` is collected.

    :param name: Path whose children are wanted
    :return: Set of child path parts
    """
    parent = to_nodekey(name)
    depth = len(parent.parts)
    return {candidate.parts[depth] for candidate in self.dag.nodes if candidate.is_descendent_of(parent)}
def has_node(self, name: Name):
    """Tell whether the computation contains a node called ``name``.

    :param name: Name of the node to look for
    :return: True if the node is in the graph, False otherwise
    """
    return to_nodekey(name) in self.dag.nodes
def tree_has_path(self, name: Name):
    """Tell whether ``name`` is a valid path in the computation tree.

    A path exists if it is the root, if it is itself a node, or if at least
    one node of the graph is a descendent of it.

    :param name: Path to test
    :return: True if the path exists, False otherwise
    """
    node_key = to_nodekey(name)
    if node_key.is_root or self.has_node(node_key):
        return True
    return any(candidate.is_descendent_of(node_key) for candidate in self.dag.nodes)
def get_tree_descendents(
    self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
) -> set[Name]:
    """Get the blocks and nodes located beneath a path.

    For example, for node 'foo' this might return ['foo/bar', 'foo/baz'].

    :param name: Path to get descendents for; the root is used when None
    :param include_stem: If False, the parts of ``name`` are stripped from the
        front of each returned name
    :param graph_nodes_only: If True, only actual graph nodes are considered;
        otherwise the ``ancestors()`` of each graph node are considered too
    :return: Set of descendent names
    """
    stem = NodeKey.root() if name is None else to_nodekey(name)
    stem_length = len(stem.parts)
    found = set()
    for graph_key in self.dag.nodes:
        if not graph_key.is_descendent_of(stem):
            continue
        candidates = [graph_key] if graph_nodes_only else graph_key.ancestors()
        for candidate in candidates:
            if not candidate.is_descendent_of(stem):
                continue
            if include_stem:
                found.add(candidate.name)
            else:
                found.add(NodeKey(tuple(candidate.parts[stem_length:])).name)
    return found
def _state_one(self, name: Name):
    """Return the state stored on the single node ``name``."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.STATE]
def state(self, name: Name | Names):
    """Look up the state of one or several nodes.

    The attribute-style accessor ``s`` offers the same lookup when ``name`` is
    a valid Python attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.state('foo')
        <States.UPTODATE: 4>
        >>> comp.s.foo
        <States.UPTODATE: 4>

    :param name: Name or names of the node to get state for
    :type name: Name or Names
    """
    lookup = self._state_one
    return apply1(lookup, name)
def _value_one(self, name: Name):
    """Return the value stored on the single node ``name``."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.VALUE]
def value(self, name: Name | Names):
    """Look up the current value of one or several nodes.

    The attribute-style accessor ``v`` offers the same lookup when ``name`` is
    a valid Python attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.value('foo')
        1
        >>> comp.v.foo
        1

    :param name: Name or names of the node to get the value of
    :type name: Name or Names
    """
    lookup = self._value_one
    return apply1(lookup, name)
def compute_and_get_value(self, name: Name):
    """Return the value of a node, computing it first if it is not UPTODATE.

    The attribute-style accessor ``x`` offers the same behaviour when ``name``
    is a valid Python attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

    :param name: Name of the node to get the value of
    :type name: Name
    :raises ComputationError: If the node is still not UPTODATE after computing
    """
    node_key = to_nodekey(name)
    if self.state(node_key) != States.UPTODATE:
        self.compute(node_key, raise_exceptions=True)
        if self.state(node_key) != States.UPTODATE:
            raise ComputationError(f"Unable to compute node {node_key}")
    return self.value(node_key)
def _tag_one(self, name: Name):
    """Return the tags stored on the single node ``name``."""
    return self.dag.nodes[to_nodekey(name)][NodeAttributes.TAG]
def tags(self, name: Name | Names):
    """Look up the tags attached to one or several nodes.

        >>> comp = Computation()
        >>> comp.add_node('a', tags=['foo', 'bar'])
        >>> sorted(comp.t.a)
        ['__serialize__', 'bar', 'foo']

    :param name: Name or names of the node to get the tags of
    :return: Tags of the node, or one result per name when several are given
    """
    lookup = self._tag_one
    return apply1(lookup, name)
def nodes_by_tag(self, tag) -> set[Name]:
    """Collect the names of all nodes carrying any of the given tags.

    :param tag: Tag or tags for which to retrieve nodes
    :return: Names of the nodes with those tags
    """
    matching_keys = set()
    for single_tag in as_iterable(tag):
        tagged = self._tag_map.get(single_tag)
        if tagged is None:
            # Unknown tag: contributes nothing.
            continue
        matching_keys.update(tagged)
    return {node_key.name for node_key in matching_keys}
def _style_one(self, name: Name):
    """Return the style stored on the single node ``name``, or None if unset."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs.get(NodeAttributes.STYLE)
def styles(self, name: Name | Names):
    """Look up the style attached to one or several nodes.

        >>> comp = Computation()
        >>> comp.add_node('a', style='dot')
        >>> comp.style.a
        'dot'

    :param name: Name or names of the node to get the style of
    :return: Style of the node, or one result per name when several are given
    """
    lookup = self._style_one
    return apply1(lookup, name)
def _get_item_one(self, name: Name):
    """Bundle the state and value of the single node ``name`` into a ``NodeData``."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return NodeData(attrs[NodeAttributes.STATE], attrs[NodeAttributes.VALUE])
def __getitem__(self, name: Name | Names):
    """Return the state and current value of one or several nodes.

    :param name: Name or names of the node to get the state and value of
    :return: ``NodeData`` for the node, or one result per name when several are given
    """
    lookup = self._get_item_one
    return apply1(lookup, name)
def _get_timing_one(self, name: Name):
    """Return the timing data stored on the single node ``name``, or None if absent."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs.get(NodeAttributes.TIMING, None)
def get_timing(self, name: Name | Names):
    """Look up the timing information of one or several nodes.

    :param name: Name or names of the node to get the timing information of
    :return: Timing data (None where a node has none), or one result per name
        when several are given
    """
    lookup = self._get_timing_one
    return apply1(lookup, name)
def to_df(self):
    """Build a dataframe of the state, value and timing of every node.

    Rows follow the topological order of the graph and are labelled by node name.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_df() # doctest: +NORMALIZE_WHITESPACE
                       state  value
        foo  States.UPTODATE      1
        bar  States.UPTODATE      2
    """
    frame = pd.DataFrame(index=topological_sort(self.dag))
    for column in (NodeAttributes.STATE, NodeAttributes.VALUE):
        frame[column] = pd.Series(nx.get_node_attributes(self.dag, column))
    # One row per node that has timing data; the left join keeps the others.
    timing_frame = pd.DataFrame.from_dict(nx.get_node_attributes(self.dag, "timing"), orient="index")
    frame = frame.merge(timing_frame, left_index=True, right_index=True, how="left")
    frame.index = pd.Index([node_key.name for node_key in frame.index])
    return frame
def to_dict(self):
    """Build a dictionary mapping each node key to its value.

    ::

        >>> import loman
        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_dict() # doctest: +ELLIPSIS
        {NodeKey('foo'): 1, NodeKey('bar'): 2}
    """
    values_by_key = nx.get_node_attributes(self.dag, NodeAttributes.VALUE)
    return values_by_key
def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """List the predecessor keys of ``node_key``, positional inputs first.

    Positional inputs sit at their argument index (gaps are None); keyword
    inputs follow in the order the predecessors are visited.

    :param node_key: Key of the node whose inputs are wanted
    :return: List of input node keys
    """
    by_position = {}
    keyword_inputs = []
    highest_index = -1
    for upstream in self.dag.predecessors(node_key):
        kind, label = self.dag[upstream][node_key][EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            highest_index = max(highest_index, label)
            by_position[label] = upstream
        elif kind == _ParameterType.KWD:
            keyword_inputs.append(upstream)
    if highest_index < 0:
        return keyword_inputs
    positional_inputs = [None] * (highest_index + 1)
    for position, upstream in by_position.items():
        positional_inputs[position] = upstream
    return positional_inputs + keyword_inputs
def _get_inputs_one_names(self, name: Name) -> Names:
    """Name-based wrapper around ``_get_inputs_one_node_keys`` for a single node."""
    input_keys = self._get_inputs_one_node_keys(to_nodekey(name))
    return node_keys_to_names(input_keys)
def get_inputs(self, name: Name | Names) -> list[Names]:
    """Get the direct inputs of a node or of several nodes.

    :param name: Name or names of nodes to get inputs for
    :return: If name is scalar, return a list of upstream nodes used as input. If name is a list, return a
        list of list of inputs.
    """
    lookup = self._get_inputs_one_names
    return apply1(lookup, name)
def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self=True) -> set[NodeKey]:
    """Collect the union of the ancestors of every key in ``node_keys``.

    :param node_keys: Keys whose ancestors are wanted
    :param include_self: Whether the given keys themselves are part of the result
    :return: Set of ancestor node keys
    """
    collected = set()
    for start in node_keys:
        if include_self:
            collected.add(start)
        collected.update(nx.ancestors(self.dag, start))
    return collected
def get_ancestors(self, names: Name | Names, include_self=True) -> Names:
    """Get the names of all ancestors of the given node or nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes are part of the result
    :return: Names of the ancestor nodes
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_ancestors_node_keys(start_keys, include_self))
def _get_original_inputs_node_keys(self, node_keys: list[NodeKey] | None) -> Names:
    """Pick out the nodes that have no function attached.

    :param node_keys: Keys to start from; their ancestors (and the keys
        themselves) are searched. When None, every node of the graph is searched.
    :return: List of the searched keys whose node has no function
    """
    if node_keys is None:
        candidates = self._node_keys()
    else:
        candidates = self._get_ancestors_node_keys(node_keys)
    return [key for key in candidates if self.dag.nodes[key].get(NodeAttributes.FUNC) is None]
def get_original_inputs(self, names: Name | Names | None = None) -> Names:
    """Get the original non-computed inputs for a node or set of nodes.

    :param names: Name or names of nodes to get inputs for; None means the whole computation
    :return: Return a list of original non-computed inputs that are ancestors of the input nodes
    """
    start_keys = None if names is None else names_to_node_keys(names)
    input_keys = self._get_original_inputs_node_keys(start_keys)
    return node_keys_to_names(input_keys)
def _get_outputs_one(self, name: Name) -> Names:
    """Return the names of the direct successors of the single node ``name``."""
    successor_keys = [*self.dag.successors(to_nodekey(name))]
    return node_keys_to_names(successor_keys)
def get_outputs(self, name: Name | Names) -> Names | list[Names]:
    """Get the direct outputs of a node or of several nodes.

    :param name: Name or names of nodes to get outputs for
    :return: If name is scalar, return a list of downstream nodes used as output. If name is a list, return a
        list of list of outputs.
    """
    lookup = self._get_outputs_one
    return apply1(lookup, name)
def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> Names:
    """Collect the union of the descendents of every key in ``node_keys``.

    :param node_keys: Keys whose descendents are wanted
    :param include_self: Whether the given keys themselves are part of the result
    :return: Set of descendent node keys
    """
    collected = set()
    for start in node_keys:
        if include_self:
            collected.add(start)
        collected.update(nx.descendants(self.dag, start))
    return collected
def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
    """Get the names of all descendents of the given node or nodes.

    :param names: Name or names of the nodes to start from
    :param include_self: Whether the starting nodes are part of the result
    :return: Names of the descendent nodes
    """
    start_keys = names_to_node_keys(names)
    return node_keys_to_names(self._get_descendents_node_keys(start_keys, include_self))
def get_final_outputs(self, names: Name | Names | None = None):
    """Get the final outputs, i.e. nodes that have no descendants.

    :param names: Name or names of nodes to start from; the search covers them
        and their descendents. None means the whole computation.
    :return: Names of the searched nodes that have no descendants
    """
    if names is None:
        candidates = self._node_keys()
    else:
        candidates = self._get_descendents_node_keys(names_to_node_keys(names))
    terminal_keys = [key for key in candidates if len(nx.descendants(self.dag, key)) == 0]
    return node_keys_to_names(terminal_keys)
def get_source(self, name: Name) -> str:
    """Get the source code of the function behind a node.

    :param name: Name of the node
    :return: ``"<file>:<line>"``, a blank line and the function source, or the
        text ``"NOT A CALCULATED NODE"`` when the node has no function
    """
    func = self.dag.nodes[to_nodekey(name)].get(NodeAttributes.FUNC, None)
    if func is None:
        return "NOT A CALCULATED NODE"
    file = inspect.getsourcefile(func)
    lineno = inspect.getsourcelines(func)[1]
    source = inspect.getsource(func)
    return f"{file}:{lineno}\n\n{source}"
def print_source(self, name: Name):
    """Write the result of ``get_source(name)`` to standard output.

    :param name: Name of the node whose source is printed
    """
    source_text = self.get_source(name)
    print(source_text)
def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None):
    """Restrict a computation to the ancestors of a set of output nodes.

    Excludes ancestors of a set of input nodes.

    If the set of input nodes that is specified is not sufficient for the set of output nodes then additional
    nodes that are ancestors of the output nodes will be included, but the input nodes specified will be input
    nodes of the modified Computation.

    :param output_names: Name or names of the nodes to keep, together with their ancestors
    :param input_names: Optional name or names of nodes to turn into plain input nodes; each keeps its
        current state and value
    :return: None - modifies existing computation in place
    """
    if input_names is not None:
        # Normalise first so that a single name is treated as one node rather
        # than being iterated element by element (e.g. character by character).
        for input_key in names_to_node_keys(input_names):
            nodedata = self._get_item_one(input_key)
            # Re-adding the node without a function makes it an input; its
            # previous state and value are then restored.
            self.add_node(input_key)
            self._set_state_and_literal_value(input_key, nodedata.state, nodedata.value)
    output_node_keys = names_to_node_keys(output_names)
    ancestor_node_keys = self._get_ancestors_node_keys(output_node_keys)
    self.dag.remove_nodes_from([n for n in self.dag if n not in ancestor_node_keys])
def __getstate__(self):
    """Build the serializable state of the computation.

    Works on a copy: every node whose tags lack ``SystemTags.SERIALIZE`` is
    reset to uninitialized there, and only the copy's graph is returned.

    :return: Dictionary with the single key ``"dag"``
    """
    tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
    snapshot = self.copy()
    skipped = [node for node, node_tags in tags_by_node.items() if SystemTags.SERIALIZE not in node_tags]
    for node in skipped:
        snapshot._set_uninitialized(node)
    return {"dag": snapshot.dag}
def __setstate__(self, state):
    """Rebuild the computation from a state produced by ``__getstate__``.

    :param state: Dictionary holding the graph under the key ``"dag"``
    """
    # Start from a freshly initialised object, then attach the stored graph
    # and refresh the maps that depend on it.
    self.__init__()
    restored_dag = state["dag"]
    self.dag = restored_dag
    self._refresh_maps()
def write_dill_old(self, file_):
    """Serialize a computation to a file or file-like object (deprecated).

    The class-level ``__getstate__``/``__setstate__`` hooks are removed while
    dumping and are always put back afterwards.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
    cls = self.__class__
    saved_getstate = cls.__getstate__
    saved_setstate = cls.__setstate__

    try:
        del cls.__getstate__
        del cls.__setstate__

        tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
        snapshot = self.copy()
        snapshot.executor_map = None
        snapshot.default_executor = None
        for node, node_tags in tags_by_node.items():
            if SystemTags.SERIALIZE in node_tags:
                continue
            snapshot._set_uninitialized(node)

        if not isinstance(file_, str):
            dill.dump(snapshot, file_)
        else:
            with open(file_, "wb") as f:
                dill.dump(snapshot, f)
    finally:
        cls.__getstate__ = saved_getstate
        cls.__setstate__ = saved_setstate
def write_dill(self, file_):
    """Serialize a computation to a file or file-like object.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    if not isinstance(file_, str):
        dill.dump(self, file_)
        return
    with open(file_, "wb") as handle:
        dill.dump(self, handle)
@staticmethod
def read_dill(file_):
    """Deserialize a computation from a file or file-like object.

    NOTE(security): ``dill.load`` can execute arbitrary code while loading, so
    only read data from sources you trust.

    :param file_: If string, reads from the file at that path
    :type file_: File-like object, or string
    :return: The deserialized computation
    :rtype: Computation
    :raises TypeError: If the loaded object is not a ``Computation``
    """
    if isinstance(file_, str):
        with open(file_, "rb") as f:
            obj = dill.load(f)
    else:
        obj = dill.load(file_)
    if not isinstance(obj, Computation):
        # Say what was found instead of raising a bare, message-less exception.
        raise TypeError(f"Expected a Computation, but loaded an object of type {type(obj).__name__}")
    return obj
def copy(self):
    """Create a shallow copy of a computation.

    The graph structure and the tag/state maps are duplicated, but values in
    the new Computation's DAG are the same objects as in this one. As further
    computations create new objects, this should not be an issue.

    :rtype: Computation
    """
    duplicate = Computation()
    duplicate.dag = nx.DiGraph(self.dag)
    duplicate._tag_map = {tag: members.copy() for tag, members in self._tag_map.items()}
    duplicate._state_map = {state: members.copy() for state, members in self._state_map.items()}
    return duplicate
def add_named_tuple_expansion(self, name, namedtuple_type, group=None):
    """Automatically add nodes to extract each element of a named tuple type.

    It is often convenient for a calculation to return multiple values, and it is polite to do this with a
    namedtuple rather than a regular tuple, so that later users have names to identify elements of the tuple.
    It can also make a computation clearer if a downstream computation depends on one element of such a tuple
    rather than the entire tuple. This does not affect the computation per se, but it does make the intention
    clearer.

    To avoid many boiler-plate node definitions, this method creates one new node per field of the tuple. An
    element called 'element' in a node called 'node' is expanded into a new node called 'node.element'. Each
    new node is tagged with ``SystemTags.EXPANSION``.

    Example::

        >>> from collections import namedtuple
        >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
        >>> comp = Computation()
        >>> comp.add_node('c', value=Coordinate(1, 2))
        >>> comp.add_named_tuple_expansion('c', Coordinate)
        >>> comp.compute_all()
        >>> comp.value('c.x')
        1
        >>> comp.value('c.y')
        2

    :param name: Name of the node holding the named tuple to expand
    :param namedtuple_type: Expected type of the node
    :type namedtuple_type: namedtuple class
    :param group: Passed as ``group`` to every node that is added
    """

    def make_f(attribute):
        # A factory gives each getter its own binding of the field name.
        def get_field_value(tuple):
            return getattr(tuple, attribute)

        return get_field_value

    for field_name in namedtuple_type._fields:
        expanded_name = f"{name}.{field_name}"
        getter = make_f(field_name)
        self.add_node(expanded_name, getter, kwds={"tuple": name}, group=group)
        self.set_tag(expanded_name, SystemTags.EXPANSION)
def add_map_node(self, result_node, input_node, subgraph, subgraph_input_node, subgraph_output_node):
    """Apply a graph to each element of an iterable.

    Each element of ``input_node`` in this graph is inserted in turn into ``subgraph_input_node`` of the
    subgraph, and the subgraph's ``subgraph_output_node`` is then computed. The list of results, one per
    element, becomes the value of ``result_node`` in this graph. In this way ``add_map_node`` is similar to
    ``map`` in functional programming.

    If any element cannot be computed, the node function raises ``MapException``; the results it carries hold
    a copy of the subgraph in place of each failed element.

    :param result_node: The node to place a list of results in **this** graph
    :param input_node: The node to get a list input values from **this** graph
    :param subgraph: The graph to use to perform calculation for each element
    :param subgraph_input_node: The node in **subgraph** to insert each element in turn
    :param subgraph_output_node: The node in **subgraph** to read the result for each element
    """

    def f(xs):
        collected = []
        any_failed = False
        for element in xs:
            subgraph.insert(subgraph_input_node, element)
            subgraph.compute(subgraph_output_node)
            if subgraph.state(subgraph_output_node) != States.UPTODATE:
                any_failed = True
                # Keep a snapshot of the failed subgraph for inspection.
                collected.append(subgraph.copy())
            else:
                collected.append(subgraph.value(subgraph_output_node))
        if any_failed:
            raise MapException(f"Unable to calculate {result_node}", collected)
        return collected

    self.add_node(result_node, f, kwds={"xs": input_node})
def prepend_path(self, path, prefix_path: NodeKey):
    """Put ``prefix_path`` in front of a node path.

    :param path: Node path to prefix; a ``ConstantValue`` is returned unchanged
    :param prefix_path: Key to join the path onto
    :return: The joined key, or ``path`` itself if it is a ``ConstantValue``
    """
    if not isinstance(path, ConstantValue):
        return prefix_path.join(to_nodekey(path))
    return path
def add_block(
    self,
    base_path: Name,
    block: "Computation",
    *,
    keep_values: bool | None = True,
    links: dict | None = None,
    metadata: dict | None = None,
):
    """Add a computation block as a subgraph to this computation.

    Every node of ``block`` is re-created under ``base_path``, with its
    argument references re-rooted there as well.

    :param base_path: Path under which the nodes of the block are placed
    :param block: Computation whose nodes are copied
    :param keep_values: If true, nodes that have a value in the block keep
        their state and value
    :param links: Optional mapping of target (relative to ``base_path``) to
        source; each pair is passed to ``link``
    :param metadata: Metadata stored for ``base_path``; when None, any
        metadata already stored for that path is removed
    """
    base_path = to_nodekey(base_path)
    for node_name in block.nodes():
        source_key = to_nodekey(node_name)
        source_attrs = block.dag.nodes[source_key]
        args, kwds = block.get_definition_args_kwds(source_key)
        rebased_args = [self.prepend_path(arg, base_path) for arg in args]
        rebased_kwds = {key: self.prepend_path(ref, base_path) for key, ref in kwds.items()}
        new_node_name = self.prepend_path(node_name, base_path)
        self.add_node(
            new_node_name,
            source_attrs.get(NodeAttributes.FUNC, None),
            args=rebased_args,
            kwds=rebased_kwds,
            converter=source_attrs.get(NodeAttributes.CONVERTER, None),
            serialize=False,
            inspect=False,
            group=source_attrs.get(NodeAttributes.GROUP, None),
            tags=source_attrs.get(NodeAttributes.TAG, None),
            style=source_attrs.get(NodeAttributes.STYLE, None),
            executor=source_attrs.get(NodeAttributes.EXECUTOR, None),
        )
        if keep_values and NodeAttributes.VALUE in source_attrs:
            self._set_state_and_literal_value(
                to_nodekey(new_node_name),
                source_attrs[NodeAttributes.STATE],
                source_attrs[NodeAttributes.VALUE],
            )
    if links is not None:
        for target, source in links.items():
            self.link(base_path.join_parts(target), source)
    if metadata is not None:
        self._metadata[base_path] = metadata
    elif base_path in self._metadata:
        del self._metadata[base_path]
def link(self, target: Name, source: Name):
    """Make ``target`` a node that takes its value from ``source``.

    The target is (re)defined with ``identity_function`` applied to the source.
    Its style is the target's existing style if that is truthy, otherwise the
    source's style. Linking a node to itself does nothing.

    :param target: Name of the node to define
    :param source: Name of the node to read from
    """
    target = to_nodekey(target)
    source = to_nodekey(source)
    if target == source:
        return

    def style_of(node_key):
        # Nodes that do not exist yet have no style.
        return self._style_one(node_key) if self.has_node(node_key) else None

    target_style = style_of(target)
    source_style = style_of(source)
    self.add_node(target, identity_function, kwds={"x": source}, style=target_style or source_style)
def _repr_svg_(self):
    """Return the SVG produced by a default ``GraphView`` of this computation."""
    graph_view = GraphView(self)
    return graph_view.svg()
def draw(
    self,
    root: NodeKey | None = None,
    *,
    node_transformations: dict | None = None,
    cmap=None,
    colors="state",
    shapes=None,
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    show_expansion=False,
    collapse_all=True,
):
    """Draw a computation's current state using the GraphViz utility.

    :param root: Optional PathType. Sub-block to draw
    :param node_transformations: Optional mapping of node to transformation; it is copied, not modified
    :param cmap: Default: None
    :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
    :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
    :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
        ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
    :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
    :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
    :param show_expansion: If False, nodes tagged ``SystemTags.EXPANSION`` are given the CONTRACT transformation.
    :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
    :return: The ``GraphView`` that was built
    """
    node_formatter = NodeFormatter.create(cmap, colors, shapes)
    if node_transformations is None:
        transformations = {}
    else:
        transformations = node_transformations.copy()
    if not show_expansion:
        for expansion_node in self.nodes_by_tag(SystemTags.EXPANSION):
            transformations[expansion_node] = NodeTransformations.CONTRACT
    return GraphView(
        self,
        root=root,
        node_formatter=node_formatter,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        node_transformations=transformations,
        collapse_all=collapse_all,
    )
def view(self, cmap=None, colors="state", shapes=None):
    """Build a ``GraphView`` of the computation and call its ``view()``.

    :param cmap: Passed to ``NodeFormatter.create``
    :param colors: Passed to ``NodeFormatter.create``. Default: 'state'
    :param shapes: Passed to ``NodeFormatter.create``
    """
    formatter = NodeFormatter.create(cmap, colors, shapes)
    GraphView(self, node_formatter=formatter).view()
def print_errors(self):
    """Print tracebacks for every node with state "ERROR" in a Computation."""
    for node_name in self.nodes():
        if self.s[node_name] != States.ERROR:
            continue
        # Heading: the node name underlined with '=' characters.
        print(f"{node_name}")
        print("=" * len(node_name))
        print()
        print(self.v[node_name].traceback)
        print()
@classmethod
def from_class(cls, definition_class, ignore_self=True):
    """Create a computation from a class with decorated methods.

    :param definition_class: Class describing the computation; it is instantiated without arguments
    :param ignore_self: Forwarded to ``populate_computation_from_class``
    :return: The new, populated computation
    """
    computation = cls()
    definition = definition_class()
    populate_computation_from_class(computation, definition_class, definition, ignore_self=ignore_self)
    return computation
def inject_dependencies(self, dependencies: dict, *, force: bool = False):
    """Inject dependencies into the placeholder nodes of the current computation.

    With ``force=True`` every node is considered rather than only those in the
    PLACEHOLDER state. A node is replaced only if ``dependencies`` holds an
    entry for it that is not None.

    Each key of ``dependencies`` is a node identifier and the associated value
    is the object to inject. A callable value is added as a calc node; any
    other value is added as the node's value.

    :param dependencies: A dictionary where each key-value pair consists of a node identifier and
        its corresponding dependency object or a callable that returns the dependency object.
    :param force: A boolean flag that, when set to True, forces the replacement of existing node values
        with the ones provided in 'dependencies', regardless of their current state. Defaults to False.
    :return: None
    """
    for node_name in self.nodes():
        if not (force or self.s[node_name] == States.PLACEHOLDER):
            continue
        dependency = dependencies.get(node_name)
        if dependency is None:
            continue
        if callable(dependency):
            self.add_node(node_name, dependency)
        else:
            self.add_node(node_name, value=dependency)