Coverage for src/loman/computeengine.py: 89%
991 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 05:36 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 05:36 +0000
1"""Core computation engine for dependency-aware calculation graphs."""
3import inspect
4import logging
5import traceback
6import types
7import warnings
8from collections import defaultdict
9from collections.abc import Callable, Iterable, Mapping
10from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
11from dataclasses import dataclass, field
12from datetime import UTC, datetime
13from enum import Enum
14from typing import Any, Type, Union # noqa: UP035
16import decorator
17import dill
18import networkx as nx
19import pandas as pd
21from .compat import get_signature
22from .consts import EdgeAttributes, NodeAttributes, NodeTransformations, States, SystemTags
23from .exception import (
24 CannotInsertToPlaceholderNodeException,
25 ComputationError,
26 LoopDetectedException,
27 MapException,
28 NodeAlreadyExistsException,
29 NonExistentNodeException,
30)
31from .nodekey import Name, Names, NodeKey, names_to_node_keys, node_keys_to_names, to_nodekey
32from .util import AttributeView, apply1, apply_n, as_iterable, value_eq
33from .visualization import GraphView, NodeFormatter
# Logger shared by everything in the computation engine module.
LOG = logging.getLogger(name="loman.computeengine")
@dataclass
class Error:
    """Container for error information captured during computation.

    An ``Error`` is stored as the value of a node that is put into the ERROR state,
    e.g. when its value converter raises.
    """

    # The exception instance that was raised.
    exception: Exception
    # Formatted traceback text, as produced by ``traceback.format_exc()``.
    traceback: str
@dataclass
class NodeData:
    """Data associated with a computation node: its state together with its value."""

    # Current state of the node (a member of ``States``).
    state: States
    # Value held by the node.
    value: object
@dataclass
class TimingData:
    """Timing information for a single computation execution."""

    # When execution started. NOTE(review): presumably a UTC-aware datetime, as produced
    # by ``_eval_node`` via ``datetime.now(UTC)`` — confirm at the construction site.
    start: datetime
    # When execution finished (same caveat as ``start``).
    end: datetime
    # Duration of the execution. NOTE(review): units presumably seconds — confirm where
    # this object is constructed.
    duration: float
class _ParameterType(Enum):
    """How an input is passed to a node's function: positionally or by keyword."""

    # Positional argument; the accompanying parameter "name" is the integer position.
    ARG = 1
    # Keyword argument; the accompanying parameter name is the keyword string.
    KWD = 2
@dataclass
class _ParameterItem:
    """One resolved parameter for a call to a node's function."""

    # Whether the parameter is positional (ARG) or keyword (KWD).
    type: "_ParameterType"
    # Position (int) for ARG parameters; parameter name (str) for KWD parameters.
    name: int | str
    # Value to pass for this parameter.
    value: object
75def _node(func, *args, **kws):
76 return func(*args, **kws)
def node(comp, name=None, *args, **kw):
    """Build a decorator that registers the decorated function as a node of ``comp``.

    :param comp: Computation to add the node to.
    :param name: Node name; the function's ``__name__`` is used when omitted.
    :param args: Extra positional arguments forwarded to ``comp.add_node``.
    :param kw: Extra keyword arguments forwarded to ``comp.add_node``.
    :return: A decorator. It returns the function wrapped by ``decorator.decorate`` with
        :func:`_node` as caller, so the function stays callable as before.
    """

    def register(f):
        node_name = f.__name__ if name is None else name
        comp.add_node(node_name, f, *args, **kw)
        return decorator.decorate(f, _node)

    return register
@dataclass
class ConstantValue:
    """Marker wrapping a literal value, so it is not treated as a node reference.

    When used in the ``args`` or ``kwds`` of ``Computation.add_node``, the wrapped value is
    passed to the function as-is instead of being read from another node.
    """

    # The literal value to pass through.
    value: object


# Short alias, e.g. ``comp.add_node("y", f, kwds={"x": C(1)})``.
C = ConstantValue
class Node:
    """Abstract base for declarative node definitions used by class-based computations."""

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Register this definition on ``comp`` under ``name``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
@dataclass
class InputNode(Node):
    """Declarative definition of an input (value-only) node."""

    # Positional arguments captured at construction; not used by ``add_to_comp``.
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Keyword arguments forwarded to ``Computation.add_node``.
    kwds: dict = field(default_factory=dict)

    def __init__(self, *args, **kwds):
        """Capture the positional and keyword arguments of the definition."""
        self.args, self.kwds = args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add a function-less node called ``name`` to ``comp``, passing ``kwds`` through."""
        options = self.kwds
        comp.add_node(name, **options)


# Lower-case alias for use in class-based computation definitions.
input_node = InputNode
@dataclass
class CalcNode(Node):
    """Declarative definition of a calculation node, typically created by ``@calc_node``."""

    # The function that computes the node.
    f: Callable
    # Options forwarded to ``Computation.add_node``; the key ``ignore_self`` is consumed
    # by ``add_to_comp`` and never forwarded.
    kwds: dict = field(default_factory=dict)

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add ``f`` to ``comp`` as node ``name``.

        If ``ignore_self`` is true, either via the argument or via ``kwds["ignore_self"]``, and the
        first parameter of ``f`` is named ``self``, then ``f`` is bound to ``obj`` before being added.
        """
        options = dict(self.kwds)
        # Strip the option unconditionally so it is never passed through to ``add_node``.
        ignore_self_option = options.pop("ignore_self", False)
        func = self.f
        if ignore_self or ignore_self_option:
            params = get_signature(self.f).kwd_params
            if len(params) > 0 and params[0] == "self":
                func = func.__get__(obj, obj.__class__)
        comp.add_node(name, func, **options)
def calc_node(f=None, **kwds):
    """Mark a function as a calculation node for class-based computations.

    Works both bare (``@calc_node``) and with options (``@calc_node(group="g")``). The options
    are stored on a ``CalcNode`` attached to the function as ``_loman_node_info``.
    """

    def mark(func):
        func._loman_node_info = CalcNode(func, kwds)
        return func

    return mark if f is None else mark(f)
@dataclass
class Block(Node):
    """Declarative definition of a sub-computation embedded as a block."""

    # Either a ready-made Computation or a callable that builds one.
    block: Union[Callable, "Computation"]
    # Extra positional arguments forwarded to ``Computation.add_block``.
    args: tuple[Any, ...] = field(default_factory=tuple)
    # Extra keyword arguments forwarded to ``Computation.add_block``.
    kwds: dict = field(default_factory=dict)

    def __init__(self, block, *args, **kwds):
        """Capture the block (or its factory) and the arguments for ``add_block``."""
        self.block, self.args, self.kwds = block, args, kwds

    def add_to_comp(self, comp: "Computation", name: str, obj: object, ignore_self: bool):
        """Add the block to ``comp`` under ``name``, calling the factory first if needed.

        :raises TypeError: if ``block`` is neither a Computation nor callable.
        """
        if isinstance(self.block, Computation):
            sub_computation = self.block
        elif callable(self.block):
            sub_computation = self.block()
        else:
            raise TypeError(f"Block {self.block} must be callable or Computation")
        comp.add_block(name, sub_computation, *self.args, **self.kwds)


# Lower-case alias for use in class-based computation definitions.
block = Block
def populate_computation_from_class(comp, cls, obj, ignore_self=True):
    """Add to ``comp`` every node definition declared on ``cls``.

    A member counts as a definition if it is a ``Node`` instance, or if it carries a non-None
    ``_loman_node_info`` attribute (as set by ``@calc_node``). Each definition is added under
    its attribute name, with ``obj`` as the instance used for binding ``self``.
    """
    for attr_name, member in inspect.getmembers(cls):
        definition = member if isinstance(member, Node) else getattr(member, "_loman_node_info", None)
        if definition is not None:
            definition.add_to_comp(comp, attr_name, obj, ignore_self)
def computation_factory(maybe_cls=None, *, ignore_self=True) -> Type["Computation"]:  # noqa: UP006
    """Turn a class of node definitions into a factory for ``Computation`` objects.

    Usable bare (``@computation_factory``) or with options
    (``@computation_factory(ignore_self=False)``). Calling the resulting factory does the
    following: it instantiates the class, builds a ``Computation`` from the factory's
    arguments, stores the instance as ``_definition_object``, and populates the computation
    from the class.
    """

    def decorate(cls):
        def create_computation(*args, **kwargs):
            definition = cls()
            comp = Computation(*args, **kwargs)
            comp._definition_object = definition
            populate_computation_from_class(comp, cls, definition, ignore_self)
            return comp

        return create_computation

    return decorate if maybe_cls is None else decorate(maybe_cls)
def _eval_node(name, f, args, kwds, raise_exceptions):
    """Evaluate ``f(*args, **kwds)`` for node ``name``, capturing the outcome and timing.

    Kept as a standalone module-level function (not a method) so that it can be pickled
    and submitted to process-based executors.

    :param name: Node identifier; used only in log messages.
    :param f: Callable to run.
    :param args: Positional arguments for ``f``.
    :param kwds: Keyword arguments for ``f``.
    :param raise_exceptions: If true, an exception from ``f`` propagates instead of being captured.
    :return: ``(value, exc, tb, start_dt, end_dt)``. If ``f`` raised (and the exception was
        captured), ``value`` is ``None`` and ``exc``/``tb`` hold the exception and its formatted
        traceback; otherwise ``exc`` and ``tb`` are ``None``. Timestamps are UTC-aware datetimes.
    """
    exc, tb = None, None
    start_dt = datetime.now(UTC)
    try:
        # Log through the module logger rather than the root-level ``logging.debug`` helper,
        # which implicitly calls ``logging.basicConfig()`` when the root logger has no handlers.
        LOG.debug("Running %s", name)
        value = f(*args, **kwds)
        LOG.debug("Completed %s", name)
    except Exception as e:
        value = None
        exc = e
        tb = traceback.format_exc()
        if raise_exceptions:
            raise
    end_dt = datetime.now(UTC)
    return value, exc, tb, start_dt, end_dt
# Unique default for ``Computation.add_node(value=...)``. It lets ``add_node`` tell
# "no value given" apart from an explicit ``value=None``.
_MISSING_VALUE_SENTINEL = object()
class NullObject:
    """Debug helper that refuses every attribute, item and call operation.

    Each intercepted operation first prints a trace line (``"<hook>: <detail>"``) to stdout,
    then raises the exception Python would normally use for that kind of failure.
    """

    def __getattr__(self, name):
        """Trace and reject lookup of a missing attribute with ``AttributeError``."""
        message = f"'NullObject' object has no attribute '{name}'"
        print(f"__getattr__: {name}")
        raise AttributeError(message)

    def __setattr__(self, name, value):
        """Trace and reject attribute assignment with ``AttributeError``."""
        message = f"'NullObject' object has no attribute '{name}'"
        print(f"__setattr__: {name}")
        raise AttributeError(message)

    def __delattr__(self, name):
        """Trace and reject attribute deletion with ``AttributeError``."""
        message = f"'NullObject' object has no attribute '{name}'"
        print(f"__delattr__: {name}")
        raise AttributeError(message)

    def __call__(self, *args, **kwargs):
        """Trace and reject calls with ``TypeError``."""
        message = "'NullObject' object is not callable"
        print(f"__call__: {args}, {kwargs}")
        raise TypeError(message)

    def __getitem__(self, key):
        """Trace and reject item lookup with ``KeyError``."""
        message = f"'NullObject' object has no item with key '{key}'"
        print(f"__getitem__: {key}")
        raise KeyError(message)

    def __setitem__(self, key, value):
        """Trace and reject item assignment with ``KeyError``."""
        message = f"'NullObject' object cannot have items set with key '{key}'"
        print(f"__setitem__: {key}")
        raise KeyError(message)

    def __repr__(self):
        """Trace the instance dictionary and return ``"<NullObject>"``."""
        instance_state = self.__dict__
        print(f"__repr__: {instance_state}")
        return "<NullObject>"
def identity_function(x):
    """Identity: hand back ``x`` itself (the same object, not a copy)."""
    result = x
    return result
287class Computation:
288 """A computation graph that manages dependencies and calculations.
290 The Computation class provides a framework for building and executing
291 computation graphs where nodes represent data or calculations, and edges
292 represent dependencies between them.
293 """
def __init__(self, *, default_executor=None, executor_map=None, metadata=None):
    """Create an empty computation.

    :param default_executor: Executor used to run nodes that do not name an executor. When
        omitted, a new ``ThreadPoolExecutor(1)`` is created.
    :type default_executor: concurrent.futures.Executor, optional
    :param executor_map: Mapping from executor name to executor, used for nodes added with
        ``executor=<name>``. Defaults to an empty dict.
    :type executor_map: dict, optional
    :param metadata: Metadata to associate with the root of the computation.
    """
    self.default_executor = ThreadPoolExecutor(1) if default_executor is None else default_executor
    self.executor_map = {} if executor_map is None else executor_map
    self.dag = nx.DiGraph()
    # Per-path metadata keyed by NodeKey; the root key holds computation-level metadata.
    self._metadata = {} if metadata is None else {NodeKey.root(): metadata}

    # Path-aware accessor views, each pairing a single-node getter with a multi-node getter.
    root = NodeKey.root()
    make_view = self.get_attribute_view_for_path
    self.v = make_view(root, self._value_one, self.value)
    self.s = make_view(root, self._state_one, self.state)
    self.i = make_view(root, self._get_inputs_one_names, self.get_inputs)
    self.o = make_view(root, self._get_outputs_one, self.get_outputs)
    self.t = make_view(root, self._tag_one, self.tags)
    self.style = make_view(root, self._style_one, self.styles)
    self.tim = make_view(root, self._get_timing_one, self.get_timing)
    self.x = make_view(root, self.compute_and_get_value, self.compute_and_get_value)
    self.src = make_view(root, self.print_source, self.print_source)

    # Reverse indexes: tag -> node keys, and state -> node keys.
    self._tag_map = defaultdict(set)
    self._state_map = {state: set() for state in States}
def get_attribute_view_for_path(
    self, nodekey: NodeKey, get_one_func: Callable, get_many_func: Callable
) -> AttributeView:
    """Create an ``AttributeView`` rooted at ``nodekey``.

    Names looked up through the view are resolved relative to ``nodekey``:

    * if the resolved key is a node, ``get_one_func`` is applied to it;
    * if it is an intermediate tree path, a nested view rooted there is returned;
    * otherwise ``KeyError`` is raised.

    A list of names is resolved element by element. ``get_many_func`` is not called at this
    level; it is only passed on to nested views.

    :param nodekey: Path the view is rooted at.
    :param get_one_func: Getter applied to a single resolved node key.
    :param get_many_func: Multi-node getter carried along to nested views.
    """

    def node_func():
        return self.get_tree_list_children(nodekey)

    def get_one_func_for_path(name: Name):
        nk = to_nodekey(name)
        new_nk = nk.prepend(nodekey)
        if self.has_node(new_nk):
            return get_one_func(new_nk)
        elif self.tree_has_path(new_nk):
            return self.get_attribute_view_for_path(new_nk, get_one_func, get_many_func)
        else:
            raise KeyError(f"Path {new_nk} does not exist")

    def get_many_func_for_path(name: Name | Names):
        if isinstance(name, list):
            return [get_one_func_for_path(n) for n in name]
        else:
            return get_one_func_for_path(name)

    return AttributeView(node_func, get_one_func_for_path, get_many_func_for_path)
def _get_names_for_state(self, state: States):
    """Return the names of all nodes currently in ``state``, as a set."""
    keys_in_state = self._state_map[state]
    return set(node_keys_to_names(keys_in_state))
def _get_tags_for_state(self, tag: str):
    """Return the names of all nodes carrying ``tag``, as a set."""
    tagged_keys = self._tag_map[tag]
    return set(node_keys_to_names(tagged_keys))
def add_node(
    self,
    name: Name,
    func=None,
    *,
    args=None,
    kwds=None,
    value=_MISSING_VALUE_SENTINEL,
    converter=None,
    serialize=True,
    inspect=True,
    group=None,
    tags=None,
    style=None,
    executor=None,
    metadata=None,
):
    """Add or update a node in a computation.

    Re-adding an existing node replaces its definition. Its incoming edges are removed, its
    state is reset to UNINITIALIZED and its attributes are rewritten. Outgoing edges (nodes
    that depend on it) are kept.

    :param name: Name of the node to add. This may be any hashable object.
    :param func: Function to use to calculate the node if the node is a calculation node. By default, the input
        nodes to the function will be implied from the names of the function parameters. For example, a
        parameter called ``a`` would be taken from the node called ``a``. This can be modified with the
        ``kwds`` parameter.
    :type func: Function, default None
    :param args: Specifies a list of nodes that will be used to populate arguments of the function positionally
        for a calculation node. e.g. If args is ``['a', 'b', 'c']`` then the function would be called with
        three parameters, taken from the nodes 'a', 'b' and 'c' respectively. An entry wrapped in
        ``ConstantValue`` (``C``) is passed as a literal value instead of a node reference.
    :type args: List, default None
    :param kwds: Specifies a mapping from parameter name to the node that should be used to populate that
        parameter when calling the function for a calculation node. e.g. If kwds is ``{'x': 'a', 'y': 'b'}``
        then the function would be called with parameters named 'x' and 'y', and their values would be taken
        from nodes 'a' and 'b' respectively. Each entry in the dictionary can be read as "take parameter
        [key] from node [value]". A value wrapped in ``ConstantValue`` (``C``) is passed literally.
    :type kwds: Dictionary, default None
    :param value: If given (including an explicit ``None``), the value is inserted into the node, and the
        node state set to UPTODATE. If omitted, no value is inserted.
    :param converter: Optional callable applied to values stored into this node via
        ``_set_state_and_value`` (e.g. on insert); if it raises, the node is put into the ERROR state.
    :param serialize: Whether the node should be serialized. Some objects cannot be serialized, in which
        case, set serialize to False
    :type serialize: boolean, default True
    :param inspect: Whether to use introspection to determine the arguments of the function, which can be
        slow. If this is not set, kwds and args must be set for the function to obtain parameters.
    :type inspect: boolean, default True
    :param group: Subgraph to render node in
    :type group: default None
    :param tags: Set of tags to apply to node
    :type tags: Iterable
    :param style: Style to apply to node
    :type style: String, default None
    :param executor: Name of executor to run node on (a key of ``executor_map``)
    :type executor: string
    :param metadata: Metadata to associate with the node. If ``None``, any existing metadata for the
        node is removed.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Adding node {node_key}")
    # The sentinel distinguishes "no value passed" from an explicit value=None.
    has_value = value is not _MISSING_VALUE_SENTINEL
    if value is _MISSING_VALUE_SENTINEL:
        value = None
    if tags is None:
        tags = []

    # Re-adding a node replaces its inputs: drop all incoming edges, keep outgoing ones.
    self.dag.add_node(node_key)
    pred_edges = [(p, node_key) for p in self.dag.predecessors(node_key)]
    self.dag.remove_edges_from(pred_edges)
    # NB: within this method, ``node`` shadows the module-level ``node`` decorator, and the
    # ``inspect`` parameter shadows the ``inspect`` module.
    node = self.dag.nodes[node_key]

    if metadata is None:
        if node_key in self._metadata:
            del self._metadata[node_key]
    else:
        self._metadata[node_key] = metadata

    # require_old_state=False: a brand-new node has no previous state to remove from the index.
    self._set_state_and_literal_value(node_key, States.UNINITIALIZED, None, require_old_state=False)

    node[NodeAttributes.TAG] = set()
    node[NodeAttributes.STYLE] = style
    node[NodeAttributes.GROUP] = group
    node[NodeAttributes.ARGS] = {}
    node[NodeAttributes.KWDS] = {}
    node[NodeAttributes.FUNC] = None
    node[NodeAttributes.EXECUTOR] = executor
    node[NodeAttributes.CONVERTER] = converter

    if func:
        node[NodeAttributes.FUNC] = func
        args_count = 0
        if args:
            args_count = len(args)
            for i, arg in enumerate(args):
                if isinstance(arg, ConstantValue):
                    # Constant positional arguments are stored on the node, not wired as edges.
                    node[NodeAttributes.ARGS][i] = arg.value
                else:
                    input_vertex_name = arg
                    input_vertex_node_key = to_nodekey(input_vertex_name)
                    # Inputs that do not exist yet are created as PLACEHOLDER nodes.
                    if not self.dag.has_node(input_vertex_node_key):
                        self.dag.add_node(input_vertex_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(input_vertex_node_key)
                    self.dag.add_edge(
                        input_vertex_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.ARG, i)}
                    )
        param_map = {}
        if inspect:
            signature = get_signature(func)
            if not signature.has_var_args:
                # Parameters already filled positionally by ``args`` are skipped.
                for param_name in signature.kwd_params[args_count:]:
                    if kwds is not None and param_name in kwds:
                        param_source = kwds[param_name]
                    else:
                        # By default, parameter ``p`` is read from the node named ``p`` under
                        # this node's parent path.
                        param_source = node_key.parent.join_parts(param_name)
                    param_map[param_name] = param_source
            if signature.has_var_kwds and kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = signature.default_params
        else:
            if kwds is not None:
                for param_name, param_source in kwds.items():
                    param_map[param_name] = param_source
            default_names = []
        for param_name, param_source in param_map.items():
            if isinstance(param_source, ConstantValue):
                node[NodeAttributes.KWDS][param_name] = param_source.value
            else:
                in_node_name = param_source
                in_node_key = to_nodekey(in_node_name)
                if not self.dag.has_node(in_node_key):
                    if param_name in default_names:
                        # Missing input for a parameter with a default: leave it unwired so the
                        # function's own default is used.
                        continue
                    else:
                        self.dag.add_node(in_node_key, **{NodeAttributes.STATE: States.PLACEHOLDER})
                        self._state_map[States.PLACEHOLDER].add(in_node_key)
                self.dag.add_edge(in_node_key, node_key, **{EdgeAttributes.PARAM: (_ParameterType.KWD, param_name)})
        self._set_descendents(node_key, States.STALE)

    if has_value:
        self._set_uptodate(node_key, value)
    if node[NodeAttributes.STATE] == States.UNINITIALIZED:
        self._try_set_computable(node_key)
    self.set_tag(node_key, tags)
    if serialize:
        self.set_tag(node_key, SystemTags.SERIALIZE)
def _refresh_maps(self):
    """Rebuild ``_tag_map`` and ``_state_map`` from the state and tags stored on each graph node.

    Used after bulk graph edits (e.g. relabelling nodes) that bypass the incremental index updates.
    """
    self._tag_map.clear()
    for state in States:
        self._state_map[state].clear()
    for nk in self._node_keys():
        attrs = self.dag.nodes[nk]
        self._state_map[attrs[NodeAttributes.STATE]].add(nk)
        for tag in attrs.get(NodeAttributes.TAG, set()):
            self._tag_map[tag].add(nk)
def _set_tag_one(self, name: Name, tag):
    """Add ``tag`` to a single node and record it in the tag index."""
    nk = to_nodekey(name)
    node_tags = self.dag.nodes[nk][NodeAttributes.TAG]
    node_tags.add(tag)
    self._tag_map[tag].add(nk)
def set_tag(self, name: Names, tag):
    """Apply a tag to one or more nodes; nodes that already carry it are unaffected.

    :param name: Node or nodes to tag
    :param tag: Tag to apply
    """
    apply_n(self._set_tag_one, name, tag)
def _clear_tag_one(self, name: Name, tag):
    """Remove ``tag`` from a single node (if present) and from the tag index."""
    nk = to_nodekey(name)
    node_tags = self.dag.nodes[nk][NodeAttributes.TAG]
    node_tags.discard(tag)
    self._tag_map[tag].discard(nk)
def clear_tag(self, name: Names, tag):
    """Remove a tag from one or more nodes; nodes without the tag are unaffected.

    :param name: Node or nodes to untag
    :param tag: Tag to remove
    """
    apply_n(self._clear_tag_one, name, tag)
def _set_style_one(self, name: Name, style):
    """Store ``style`` as the rendering style of a single node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    attrs[NodeAttributes.STYLE] = style
def set_style(self, name: Names, style):
    """Assign a rendering style to one or more nodes.

    :param name: Node or nodes to style
    :param style: Style to assign
    """
    apply_n(self._set_style_one, name, style)
def _clear_style_one(self, name):
    """Reset a single node's rendering style to ``None``."""
    self._set_style_one(name, None)
def clear_style(self, name):
    """Reset the rendering style of one or more nodes to ``None``.

    :param name: Node or nodes whose style should be cleared
    """
    apply_n(self._clear_style_one, name)
def metadata(self, name):
    """Return the mutable metadata dict for a node or tree path.

    An empty dict is created and stored on first access.

    :raises NonExistentNodeException: if ``name`` is not a path in the computation tree.
    """
    node_key = to_nodekey(name)
    if not self.tree_has_path(name):
        raise NonExistentNodeException(f"Node {node_key} does not exist.")
    return self._metadata.setdefault(node_key, {})
def delete_node(self, name):
    """Delete a node from a computation.

    When nodes are explicitly deleted with ``delete_node``, but are still depended on by other nodes, then they
    will be set to PLACEHOLDER status. In this case, if the nodes that depend on a PLACEHOLDER node are deleted,
    then the PLACEHOLDER node will also be deleted.

    :param name: Name of the node to delete. If the node does not exist, a ``NonExistentNodeException`` will
        be raised.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Deleting node {node_key}")

    if not self.dag.has_node(node_key):
        raise NonExistentNodeException(f"Node {node_key} does not exist")

    if node_key in self._metadata:
        del self._metadata[node_key]

    if len(self.dag.succ[node_key]) == 0:
        # Snapshot the predecessors before removing the node: the graph is mutated below,
        # including by the recursive deletes.
        preds = list(self.dag.predecessors(node_key))
        state = self.dag.nodes[node_key][NodeAttributes.STATE]
        self.dag.remove_node(node_key)
        self._state_map[state].remove(node_key)
        for n in preds:
            # An earlier recursive delete may already have removed this predecessor, e.g. a
            # PLACEHOLDER that also fed another PLACEHOLDER predecessor of node_key.
            if not self.dag.has_node(n):
                continue
            if self.dag.nodes[n][NodeAttributes.STATE] == States.PLACEHOLDER:
                self.delete_node(n)
    else:
        self._set_state(node_key, States.PLACEHOLDER)
def rename_node(self, old_name: Name | Mapping[Name, Name], new_name: Name | None = None):
    """Rename a node in a computation, or several nodes at once.

    Node metadata moves with the renamed nodes, and the state/tag indexes are rebuilt.

    :param old_name: Node to rename, or a dictionary of nodes to rename, with existing names as keys, and
        new names as values
    :param new_name: New name for node. Must be ``None`` when ``old_name`` is a dictionary.
    :raises ValueError: if ``old_name`` is a dictionary and ``new_name`` is also given.
    :raises NonExistentNodeException: (single rename) if ``old_name`` does not exist.
    :raises NodeAlreadyExistsException: (single rename) if ``new_name`` already exists.
    """
    if hasattr(old_name, "__getitem__") and not isinstance(old_name, str):
        # Validate once, up front, so that an empty mapping is a harmless no-op instead of
        # leaving ``name_mapping`` unbound.
        if new_name is not None:
            raise ValueError("new_name must not be set if rename_node is passed a dictionary")
        name_mapping = old_name
        for k, v in name_mapping.items():
            LOG.debug(f"Renaming node {k} to {v}")
    else:
        LOG.debug(f"Renaming node {old_name} to {new_name}")
        old_node_key = to_nodekey(old_name)
        if not self.dag.has_node(old_node_key):
            raise NonExistentNodeException(f"Node {old_name} does not exist")
        new_node_key = to_nodekey(new_name)
        if self.dag.has_node(new_node_key):
            raise NodeAlreadyExistsException(f"Node {new_name} already exists")
        name_mapping = {old_name: new_name}

    node_key_mapping = {to_nodekey(src): to_nodekey(dst) for src, dst in name_mapping.items()}
    nx.relabel_nodes(self.dag, node_key_mapping, copy=False)

    # Detach the metadata of every renamed node before writing any entry. Moving entries one
    # at a time would give ``c`` the metadata of ``a`` for chained renames such as
    # ``{a: b, b: c}``, and would drop it entirely for identity renames such as ``{a: a}``.
    moved_metadata = {
        src_key: self._metadata.pop(src_key) for src_key in node_key_mapping if src_key in self._metadata
    }
    for src_key, dst_key in node_key_mapping.items():
        if src_key in moved_metadata:
            self._metadata[dst_key] = moved_metadata[src_key]
        else:
            # The renamed node had no metadata: drop anything left over under its new name.
            self._metadata.pop(dst_key, None)

    self._refresh_maps()
def repoint(self, old_name: Name, new_name: Name):
    """Rewire every node that takes ``old_name`` as an input to take ``new_name`` instead.

    Rewired edges keep their parameter binding. If ``new_name`` is itself a dependent of
    ``old_name``, that edge is left alone to avoid a self-dependency; other cycles are not
    checked. If ``new_name`` does not exist and there is something to rewire, it is created as
    a PLACEHOLDER node. Every former dependent of ``old_name`` is then marked STALE.

    :param old_name: Node whose dependents should be moved.
    :param new_name: Node those dependents should read from instead.
    """
    source = to_nodekey(old_name)
    target = to_nodekey(new_name)
    if source == target:
        return

    dependents = list(self.dag.successors(source))

    if dependents and not self.dag.has_node(target):
        self.dag.add_node(target, **{NodeAttributes.STATE: States.PLACEHOLDER})
        self._state_map[States.PLACEHOLDER].add(target)

    for dependent in dependents:
        if dependent == target:
            continue
        edge_attrs = self.dag.get_edge_data(source, dependent)
        self.dag.add_edge(target, dependent, **edge_attrs)
        self.dag.remove_edge(source, dependent)

    for dependent in dependents:
        self.set_stale(dependent)
def insert(self, name: Name, value, force=False):
    """Insert a value into a node of a computation.

    Afterwards the node is UPTODATE. Its descendants become STALE (PINNED nodes, and anything
    reachable only through them, are left alone). Direct dependents whose inputs are then all
    UPTODATE become COMPUTABLE.

    :param name: Name of the node to insert into.
    :param value: The value to be inserted into the node.
    :param force: Re-insert and re-propagate even if the node is already UPTODATE with an equal value.
    :raises NonExistentNodeException: if the node does not exist.
    :raises CannotInsertToPlaceholderNodeException: if the node is a PLACEHOLDER.
    """
    node_key = to_nodekey(name)
    LOG.debug(f"Inserting value into node {node_key}")

    if not self.dag.has_node(node_key):
        raise NonExistentNodeException(f"Node {node_key} does not exist")

    current_state = self._state_one(name)
    if current_state == States.PLACEHOLDER:
        raise CannotInsertToPlaceholderNodeException(
            "Cannot insert into placeholder node. Use add_node to create the node first"
        )

    # Nothing to do if the node already holds an equal, up-to-date value.
    unchanged = not force and current_state == States.UPTODATE and value_eq(value, self._value_one(name))
    if unchanged:
        return

    self._set_uptodate(node_key, value)
def insert_many(self, name_value_pairs: Iterable[tuple[Name, object]]):
    """Insert values into many nodes of a computation simultaneously.

    Afterwards the nodes are UPTODATE, and their descendants are COMPUTABLE or STALE. When some
    inserted nodes are descendants of others, this ensures the inserted nodes end up with the
    correct state rather than being marked STALE when their ancestors are inserted. As with
    ``insert``, PINNED descendants (and anything reachable only through them) keep their state.

    If an attempt is made to insert a value into a node that does not exist, a ``NonExistentNodeException`` will be
    raised, and none of the nodes will be inserted.

    :param name_value_pairs: Each tuple should be a pair (name, value), where name is the name of the node to
        insert the value into.
    :type name_value_pairs: List of tuples
    """
    node_key_value_pairs = [(to_nodekey(name), value) for name, value in name_value_pairs]
    LOG.debug(f"Inserting value into nodes {', '.join(str(name) for name, value in node_key_value_pairs)}")

    for node_key, _ in node_key_value_pairs:
        if not self.dag.has_node(node_key):
            raise NonExistentNodeException(f"Node {node_key} does not exist")

    stale = set()
    computable = set()
    for node_key, value in node_key_value_pairs:
        self._set_state_and_value(node_key, States.UPTODATE, value)
        # Stop at PINNED nodes, consistent with ``insert`` (via ``_set_descendents``); the
        # previous full-descendant sweep silently un-pinned them.
        stale.update(self._get_descendents(node_key, {States.PINNED}))
        computable.update(self.dag.successors(node_key))
    inserted = {node_key for node_key, _ in node_key_value_pairs}
    stale -= inserted
    computable -= inserted
    for node_key in stale:
        self._set_state(node_key, States.STALE)
    for node_key in computable:
        self._try_set_computable(node_key)
def insert_from(self, other, nodes: Iterable[Name] | None = None):
    """Copy node values from another Computation into this one (via ``insert_many``).

    :param other: The computation object to take values from
    :type other: Computation
    :param nodes: Only copy the nodes named here. By default, every node that exists in both
        computations is copied.
    :type nodes: List, default None
    """
    if nodes is None:
        shared = set(self.dag.nodes)
        shared.intersection_update(other.dag.nodes())
        nodes = shared
    pairs = [(node_name, other.value(node_name)) for node_name in nodes]
    self.insert_many(pairs)
def _set_state(self, node_key: NodeKey, state: States):
    """Move one node to ``state``, keeping ``_state_map`` in sync."""
    attrs = self.dag.nodes[node_key]
    self._state_map[attrs[NodeAttributes.STATE]].remove(node_key)
    attrs[NodeAttributes.STATE] = state
    self._state_map[state].add(node_key)
def _set_state_and_value(
    self, node_key: NodeKey, state: States, value: object, *, throw_conversion_exception: bool = True
):
    """Store ``state`` and ``value`` on a node, passing the value through its converter if it has one.

    If the converter (or the store) raises, the node is put into the ERROR state. The
    exception is then re-raised unless ``throw_conversion_exception`` is false.
    """
    converter = self.dag.nodes[node_key].get(NodeAttributes.CONVERTER)
    if converter is None:
        self._set_state_and_literal_value(node_key, state, value)
        return
    try:
        self._set_state_and_literal_value(node_key, state, converter(value))
    except Exception as e:
        tb = traceback.format_exc()
        self._set_error(node_key, e, tb)
        if throw_conversion_exception:
            raise e
def _set_state_and_literal_value(
    self, node_key: NodeKey, state: States, value: object, require_old_state: bool = True
):
    """Store ``state`` and ``value`` on a node as-is (no converter), updating ``_state_map``.

    :param require_old_state: If false, a node with no recorded previous state (as when
        ``add_node`` first creates it) is accepted instead of raising ``KeyError``.
    """
    attrs = self.dag.nodes[node_key]
    try:
        self._state_map[attrs[NodeAttributes.STATE]].remove(node_key)
    except KeyError:
        if require_old_state:
            raise
    attrs[NodeAttributes.STATE] = state
    attrs[NodeAttributes.VALUE] = value
    self._state_map[state].add(node_key)
def _set_states(self, node_keys: Iterable[NodeKey], state: States):
    """Move every node in ``node_keys`` to ``state``, keeping ``_state_map`` in sync.

    :param node_keys: Any iterable of node keys, including a one-shot iterator; duplicates are ignored.
    """
    # Materialise once (de-duplicated, order kept). The keys are traversed twice, so a
    # generator would otherwise never reach the new state's index. A repeated key would also
    # fail its second bucket removal, because keys are only added to the new bucket at the end.
    keys = list(dict.fromkeys(node_keys))
    for key in keys:
        attrs = self.dag.nodes[key]
        self._state_map[attrs[NodeAttributes.STATE]].remove(key)
        attrs[NodeAttributes.STATE] = state
    self._state_map[state].update(keys)
def set_stale(self, name: Name):
    """Mark a node and all of its descendants STALE (PINNED ones included).

    Afterwards the node itself is re-checked and becomes COMPUTABLE if all its inputs are UPTODATE.

    :param name: Name of the node to set as STALE.
    """
    node_key = to_nodekey(name)
    affected = [node_key, *nx.dag.descendants(self.dag, node_key)]
    self._set_states(affected, States.STALE)
    self._try_set_computable(node_key)
def pin(self, name: Name, value=None):
    """Set the state of a node to PINNED, optionally inserting a value first.

    :param name: Name of the node to pin.
    :param value: If not ``None``, inserted into the node (via ``insert``) before pinning. With
        ``None``, the node is pinned with whatever value it currently holds.
    :type value: default None
    """
    node_key = to_nodekey(name)
    has_new_value = value is not None
    if has_new_value:
        self.insert(node_key, value)
    self._set_states([node_key], States.PINNED)
def unpin(self, name):
    """Unpin a node: it and all of its descendants are set to STALE.

    :param name: Name of the node to unpin.
    """
    self.set_stale(to_nodekey(name))
def _get_descendents(self, node_key: NodeKey, stop_states: set[States] | None = None) -> set[NodeKey]:
    """Return the descendants of ``node_key``, without traversing into nodes in ``stop_states``.

    Nodes whose state is in ``stop_states`` are neither returned nor traversed through. If
    ``node_key`` itself is in such a state, the result is empty. ``node_key`` is never part of
    the result.

    :param stop_states: States that block traversal; ``None`` means no state blocks it.
    """
    # Resolve the default before first use; checking ``in stop_states`` while it was still
    # ``None`` raised TypeError.
    stop = frozenset() if stop_states is None else stop_states
    if self.dag.nodes[node_key][NodeAttributes.STATE] in stop:
        return set()
    seen = {node_key}
    stack = [node_key]
    while stack:
        current = stack.pop()
        for successor in self.dag.successors(current):
            if successor in seen:
                continue
            if self.dag.nodes[successor][NodeAttributes.STATE] in stop:
                continue
            seen.add(successor)
            stack.append(successor)
    seen.discard(node_key)
    return seen
def _set_descendents(self, node_key: NodeKey, state):
    """Set ``state`` on every descendant of ``node_key``, without traversing into PINNED nodes."""
    self._set_states(self._get_descendents(node_key, {States.PINNED}), state)
def _set_uninitialized(self, node_key: NodeKey):
    """Mark a node UNINITIALIZED and discard its stored value, if any."""
    self._set_states([node_key], States.UNINITIALIZED)
    attrs = self.dag.nodes[node_key]
    attrs.pop(NodeAttributes.VALUE, None)
def _set_uptodate(self, node_key: NodeKey, value: object):
    """Store ``value`` as UPTODATE (through the converter, if any) and propagate downstream.

    Descendants become STALE, without traversing into PINNED nodes. Direct successors whose
    inputs are all UPTODATE are then promoted to COMPUTABLE.
    """
    self._set_state_and_value(node_key, States.UPTODATE, value)
    self._set_descendents(node_key, States.STALE)
    for child in self.dag.successors(node_key):
        self._try_set_computable(child)
def _set_error(self, node_key: NodeKey, exc: Exception, tb: str):
    """Put a node into the ERROR state, storing an ``Error`` as its value, and mark descendants STALE.

    :param exc: The exception that was raised.
    :param tb: Formatted traceback text, as produced by ``traceback.format_exc()`` (not an
        ``inspect.Traceback``, as the previous annotation claimed).
    """
    self._set_state_and_literal_value(node_key, States.ERROR, Error(exc, tb))
    self._set_descendents(node_key, States.STALE)
def _try_set_computable(self, node_key: NodeKey):
    """Promote a calculation node to COMPUTABLE once every input is UPTODATE.

    The node is left unchanged if it is PINNED, if it has no function, or if any input is not
    UPTODATE.
    """
    attrs = self.dag.nodes[node_key]
    if attrs[NodeAttributes.STATE] == States.PINNED:
        return
    if attrs.get(NodeAttributes.FUNC) is None:
        return
    inputs_ready = all(
        self.dag.has_node(pred) and self.dag.nodes[pred][NodeAttributes.STATE] == States.UPTODATE
        for pred in self.dag.predecessors(node_key)
    )
    if inputs_ready:
        self._set_state(node_key, States.COMPUTABLE)
def _get_parameter_data(self, node_key: NodeKey):
    """Yield every parameter of the node's function call as a ``_ParameterItem``.

    Yields, in this order:

    1. constant positional arguments;
    2. constant keyword arguments;
    3. one item per input edge, carrying the upstream node's current value.
    """
    attrs = self.dag.nodes[node_key]
    yield from (_ParameterItem(_ParameterType.ARG, pos, val) for pos, val in attrs[NodeAttributes.ARGS].items())
    yield from (_ParameterItem(_ParameterType.KWD, key, val) for key, val in attrs[NodeAttributes.KWDS].items())
    for upstream in self.dag.predecessors(node_key):
        upstream_value = self.dag.nodes[upstream][NodeAttributes.VALUE]
        kind, param = self.dag[upstream][node_key][EdgeAttributes.PARAM]
        yield _ParameterItem(kind, param, upstream_value)
890 def _get_func_args_kwds(self, node_key: NodeKey):
891 node0 = self.dag.nodes[node_key]
892 f = node0[NodeAttributes.FUNC]
893 executor_name = node0.get(NodeAttributes.EXECUTOR)
894 args, kwds = [], {}
895 for param in self._get_parameter_data(node_key):
896 if param.type == _ParameterType.ARG:
897 idx = param.name
898 while len(args) <= idx:
899 args.append(None)
900 args[idx] = param.value
901 elif param.type == _ParameterType.KWD:
902 kwds[param.name] = param.value
903 else:
904 raise Exception(f"Unexpected param type: {param.type}")
905 return f, executor_name, args, kwds
def get_definition_args_kwds(self, name: Name) -> tuple[list, dict]:
    """Reconstruct the ``args`` and ``kwds`` a node was defined with.

    Literal parameters are wrapped with ``C(...)``. Parameters fed by another node are given as that
    node's name. Unfilled positional slots are ``None``.

    :param name: Name of the node to inspect
    :return: ``(args, kwds)``
    """
    node_key = to_nodekey(name)
    attrs = self.dag.nodes[node_key]
    args = []
    kwds = {}

    def put_positional(position, item):
        if position >= len(args):
            args.extend([None] * (position + 1 - len(args)))
        args[position] = item

    for position, literal in attrs.get(NodeAttributes.ARGS, {}).items():
        put_positional(position, C(literal))
    for keyword, literal in attrs.get(NodeAttributes.KWDS, {}).items():
        kwds[keyword] = C(literal)
    for upstream in self.dag.predecessors(node_key):
        edge = self.dag[upstream][node_key]
        if EdgeAttributes.PARAM not in edge:
            continue
        kind, slot = edge[EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            put_positional(slot, upstream.name)
        elif kind == _ParameterType.KWD:
            kwds[slot] = upstream.name
        else:
            raise Exception(f"Unexpected param type: {kind}")
    return args, kwds
def _compute_nodes(self, node_keys: Iterable[NodeKey], raise_exceptions: bool = False):
    """Evaluate the COMPUTABLE nodes among ``node_keys`` and follow on downstream.

    Each COMPUTABLE candidate is submitted to its executor: the node's named entry in
    ``executor_map``, or ``default_executor`` if it names none. When a result arrives, the node is
    set UPTODATE (with timing data) or ERROR. Its successors are then re-checked, and any successor
    that became COMPUTABLE and is itself a candidate is submitted too.

    :param node_keys: Candidate nodes to compute
    :param raise_exceptions: Passed through to ``_eval_node``. If retrieving a future's result raises,
        the node is marked ERROR and the exception is re-raised.
    :raises LoopDetectedException: If a successor of a just-computed node was already computed in this pass
    """
    LOG.debug("Computing nodes %s", node_keys)
    # Materialise once. The candidates are iterated here and then membership-tested for every
    # successor. A set keeps that check O(1), and a generator argument is not exhausted by the first pass.
    node_keys = list(node_keys)
    candidates = set(node_keys)

    futs = {}

    def run(node_key):
        f, executor_name, args, kwds = self._get_func_args_kwds(node_key)
        executor = self.default_executor if executor_name is None else self.executor_map[executor_name]
        fut = executor.submit(_eval_node, node_key, f, args, kwds, raise_exceptions)
        futs[fut] = node_key

    computed = set()

    for node_key in node_keys:
        if self.dag.nodes[node_key][NodeAttributes.STATE] == States.COMPUTABLE:
            run(node_key)

    while futs:
        done, _ = wait(futs.keys(), return_when=FIRST_COMPLETED)
        for fut in done:
            node_key = futs.pop(fut)
            node_attrs = self.dag.nodes[node_key]
            try:
                value, exc, tb, start_dt, end_dt = fut.result()
            except Exception as e:
                self._set_error(node_key, e, traceback.format_exc())
                raise
            if exc is None:
                self._set_state_and_value(node_key, States.UPTODATE, value, throw_conversion_exception=False)
                delta = (end_dt - start_dt).total_seconds()
                node_attrs[NodeAttributes.TIMING] = TimingData(start_dt, end_dt, delta)
                self._set_descendents(node_key, States.STALE)
                for n in self.dag.successors(node_key):
                    # Use the module logger (not the root logger) with lazy formatting.
                    LOG.debug("%s %s %s", node_key, n, computed)
                    if n in computed:
                        raise LoopDetectedException(f"Calculating {node_key} for the second time")
                    self._try_set_computable(n)
                    if self.dag.nodes[n][NodeAttributes.STATE] == States.COMPUTABLE and n in candidates:
                        run(n)
            else:
                self._set_error(node_key, exc, tb)
            computed.add(node_key)
def _get_calc_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """Return, in topological order, the nodes to compute in order to bring ``node_key`` up to date.

    Ancestors that are UPTODATE or PINNED are pruned, along with anything reachable only through them.
    The result always ends with ``node_key`` itself.

    :param node_key: Target node
    :return: Keys of the target and its remaining ancestors, topologically sorted
    :raises Exception: If a remaining ancestor is an UNINITIALIZED node with no inputs, or is a PLACEHOLDER
    """
    # Work on a structure-only copy so the pruning never touches the real DAG.
    g = nx.DiGraph()
    g.add_nodes_from(self.dag.nodes)
    g.add_edges_from(self.dag.edges)
    for n in nx.ancestors(g, node_key):
        if self.dag.nodes[n][NodeAttributes.STATE] in (States.UPTODATE, States.PINNED):
            g.remove_node(n)

    ancestors = nx.ancestors(g, node_key)
    for n in ancestors:
        # Read each ancestor's own state. The previous code reused whatever state the
        # pruning loop happened to see last, so the check depended on set iteration order.
        state = self.dag.nodes[n][NodeAttributes.STATE]
        if state == States.UNINITIALIZED and len(self.dag.pred[n]) == 0:
            raise Exception(f"Cannot compute {node_key} because {n} uninitialized")
        if state == States.PLACEHOLDER:
            raise Exception(f"Cannot compute {node_key} because {n} is placeholder")

    ancestors.add(node_key)
    return [n for n in nx.topological_sort(g) if n in ancestors]
def _get_calc_node_names(self, name: Name) -> Names:
    """Name-based counterpart of ``_get_calc_node_keys``."""
    calc_keys = self._get_calc_node_keys(to_nodekey(name))
    return node_keys_to_names(calc_keys)
def compute(self, name: Name | Iterable[Name], raise_exceptions=False):
    """Compute a node, or a list of nodes, together with any predecessors that need it.

    On success the target node and every necessary ancestor that was not already UPTODATE will have
    been calculated and set to UPTODATE. Nodes that did not need calculating are not recalculated.

    If a node raises an exception, its state is set to ERROR and its value to an object holding the
    exception and a traceback. This does not halt the computation: it proceeds as far as it can, until
    no more nodes needed for the target are COMPUTABLE.

    :param name: Name of the node to compute (or a list/generator of names)
    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    if isinstance(name, (types.GeneratorType, list)):
        calc_nodes = set()
        for target in name:
            calc_nodes.update(self._get_calc_node_keys(to_nodekey(target)))
    else:
        calc_nodes = self._get_calc_node_keys(to_nodekey(name))
    self._compute_nodes(calc_nodes, raise_exceptions=raise_exceptions)
def compute_all(self, raise_exceptions=False):
    """Compute every node of the computation that can be computed.

    Nodes already UPTODATE are not recalculated. After a successful run every node is UPTODATE,
    except UNINITIALIZED input nodes and PLACEHOLDER nodes.

    A node that raises has its state set to ERROR and its value set to an object holding the exception
    and a traceback. The computation continues as far as it can, until no more nodes are COMPUTABLE.

    :param raise_exceptions: Whether to pass exceptions raised by node computations back to the caller
    :type raise_exceptions: Boolean, default False
    """
    every_node = self._node_keys()
    self._compute_nodes(every_node, raise_exceptions=raise_exceptions)
def _node_keys(self) -> list[NodeKey]:
    """Return the keys of every node in the underlying DAG, as a new list."""
    return [*self.dag.nodes]
def nodes(self) -> list[Name]:
    """Return the names of all nodes in this computation.

    :return: List of node names
    """
    return [node_key.name for node_key in self.dag.nodes]
def get_tree_list_children(self, name: Name) -> set[Name]:
    """Get the immediate children of a path in the node hierarchy.

    For every graph node that is a descendent of ``name``, the path part directly below ``name`` is
    collected. Each child therefore appears once, whether it is a node or a block containing nodes.

    :param name: Path whose children to list
    :return: Set of child path parts, relative to ``name``
    """
    node_key = to_nodekey(name)
    depth = len(node_key.parts)
    return {n.parts[depth] for n in self.dag.nodes if n.is_descendent_of(node_key)}
def has_node(self, name: Name):
    """Return True if a node called ``name`` exists in the computation graph."""
    return to_nodekey(name) in self.dag.nodes
def tree_has_path(self, name: Name):
    """Return True if ``name`` is a path in the hierarchy.

    A path exists if it is the root, is itself a node, or has at least one node beneath it.
    """
    node_key = to_nodekey(name)
    if node_key.is_root or self.has_node(node_key):
        return True
    return any(n.is_descendent_of(node_key) for n in self.dag.nodes)
def get_tree_descendents(
    self, name: Name | None = None, *, include_stem: bool = True, graph_nodes_only: bool = False
) -> set[Name]:
    """Get the names of the blocks and nodes lying below a path.

    For example, for node 'foo' this might return ``{'foo/bar', 'foo/baz'}``.

    :param name: Path to get descendents for; ``None`` means the root (the whole computation)
    :param include_stem: If True, return full names. If False, return names relative to ``name``.
    :param graph_nodes_only: If True, only return graph nodes. If False, also return the paths given
        by each graph node's ``ancestors()`` that lie below ``name`` (i.e. the enclosing blocks).
    :return: Set of descendent names
    """
    node_key = NodeKey.root() if name is None else to_nodekey(name)
    stem_size = len(node_key.parts)
    result = set()
    for n in self.dag.nodes:
        if not n.is_descendent_of(node_key):
            continue
        candidates = [n] if graph_nodes_only else n.ancestors()
        for candidate in candidates:
            if not candidate.is_descendent_of(node_key):
                continue
            if include_stem:
                result.add(candidate.name)
            else:
                # Strip the leading parts that make up ``name``.
                result.add(NodeKey(tuple(candidate.parts[stem_size:])).name)
    return result
def _state_one(self, name: Name):
    """Return the state stored on a single node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.STATE]
def state(self, name: Name | Names):
    """Return the state of a node, or of several nodes.

    When ``name`` is a valid Python attribute name, the attribute-style accessor ``s`` gives the
    same result::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.state('foo')
        <States.UPTODATE: 4>
        >>> comp.s.foo
        <States.UPTODATE: 4>

    :param name: Name or names of the node(s) whose state is wanted
    :type name: Name or Names
    """
    single = self._state_one
    return apply1(single, name)
def _value_one(self, name: Name):
    """Return the value stored on a single node."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return attrs[NodeAttributes.VALUE]
def value(self, name: Name | Names):
    """Return the current value of a node, or of several nodes.

    When ``name`` is a valid Python attribute name, the attribute-style accessor ``v`` gives the
    same result::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.value('foo')
        1
        >>> comp.v.foo
        1

    :param name: Name or names of the node(s) whose value is wanted
    :type name: Name or Names
    """
    single = self._value_one
    return apply1(single, name)
def compute_and_get_value(self, name: Name):
    """Compute a node if necessary, then return its value.

    If the node is already UPTODATE, its value is returned without recomputation. Otherwise the node
    is computed with ``raise_exceptions=True``, so an exception raised by a node calculation
    propagates to the caller.

    This can also be accessed using the attribute-style accessor ``x`` if ``name`` is a valid Python
    attribute name::

        >>> comp = Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', lambda foo: foo + 1)
        >>> comp.compute_and_get_value('bar')
        2
        >>> comp.x.bar
        2

    :param name: Name of the node to compute and get the value of
    :type name: Name
    :raises ComputationError: If the node is still not UPTODATE after computing
    """
    node_key = to_nodekey(name)
    if self.state(node_key) != States.UPTODATE:
        self.compute(node_key, raise_exceptions=True)
        if self.state(node_key) != States.UPTODATE:
            raise ComputationError(f"Unable to compute node {node_key}")
    return self.value(node_key)
def _tag_one(self, name: Name):
    """Return the tags stored on a single node."""
    return self.dag.nodes[to_nodekey(name)][NodeAttributes.TAG]
def tags(self, name: Name | Names):
    """Return the tags of a node, or of several nodes.

    >>> comp = Computation()
    >>> comp.add_node('a', tags=['foo', 'bar'])
    >>> comp.t.a
    {'__serialize__', 'bar', 'foo'}

    :param name: Name or names of the node(s) whose tags are wanted
    :return: The node's tags, or one entry per name when several are given
    """
    single = self._tag_one
    return apply1(single, name)
def nodes_by_tag(self, tag) -> set[Name]:
    """Return the names of the nodes that carry any of the given tag(s).

    :param tag: A tag, or an iterable of tags
    :return: Set of node names carrying at least one of the tags
    """
    matched = set()
    for single_tag in as_iterable(tag):
        tagged = self._tag_map.get(single_tag)
        if tagged is None:
            continue
        matched.update(tagged)
    return {node_key.name for node_key in matched}
def _style_one(self, name: Name):
    """Return the style of a single node, or None if it has none."""
    return self.dag.nodes[to_nodekey(name)].get(NodeAttributes.STYLE)
def styles(self, name: Name | Names):
    """Get the style associated with a node, or with several nodes.

    >>> comp = Computation()
    >>> comp.add_node('a', style='dot')
    >>> comp.styles('a')
    'dot'

    :param name: Name or names of the node(s) whose style is wanted
    :return: The node's style (``None`` if it has none), or one entry per name when several are given
    """
    return apply1(self._style_one, name)
def _get_item_one(self, name: Name):
    """Return a single node's state and value bundled as ``NodeData``."""
    attrs = self.dag.nodes[to_nodekey(name)]
    return NodeData(state=attrs[NodeAttributes.STATE], value=attrs[NodeAttributes.VALUE])
def __getitem__(self, name: Name | Names):
    """Support ``comp[name]``: return the state and current value of a node (or nodes) as ``NodeData``.

    :param name: Name or names of the node(s) to look up
    """
    single = self._get_item_one
    return apply1(single, name)
def _get_timing_one(self, name: Name):
    """Return a single node's timing record, or None if it has never been computed."""
    return self.dag.nodes[to_nodekey(name)].get(NodeAttributes.TIMING)
def get_timing(self, name: Name | Names):
    """Return the timing information recorded for a node, or for several nodes.

    :param name: Name or names of the node(s) whose timing is wanted
    :return: ``TimingData`` (or None when absent), or one entry per name when several are given
    """
    single = self._get_timing_one
    return apply1(single, name)
def to_df(self):
    """Get a dataframe containing the states and values of all nodes of the computation.

    Rows are in topological order and indexed by node name. Timing data recorded for computed nodes
    is left-joined as extra columns when present.

    ::

        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_df()
                   state  value  is_expansion
        bar  States.UPTODATE      2           NaN
        foo  States.UPTODATE      1           NaN
    """
    df = pd.DataFrame(index=nx.topological_sort(self.dag))
    df[NodeAttributes.STATE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.STATE))
    df[NodeAttributes.VALUE] = pd.Series(nx.get_node_attributes(self.dag, NodeAttributes.VALUE))
    # Read timings under the same attribute constant that _compute_nodes writes them to,
    # rather than a hard-coded "timing" string.
    timings = nx.get_node_attributes(self.dag, NodeAttributes.TIMING)
    df_timing = pd.DataFrame.from_dict(timings, orient="index")
    df = pd.merge(df, df_timing, left_index=True, right_index=True, how="left")
    df.index = pd.Index([nk.name for nk in df.index])
    return df
def to_dict(self):
    """Get a dictionary containing the values of all nodes of a computation.

    Only nodes that have a value attribute are included. The keys are the DAG's node keys.
    NOTE(review): the example shows plain string keys; confirm how NodeKey renders.

    ::

        >>> comp = loman.Computation()
        >>> comp.add_node('foo', value=1)
        >>> comp.add_node('bar', value=2)
        >>> comp.to_dict()
        {'bar': 2, 'foo': 1}
    """
    value_attr = NodeAttributes.VALUE
    return {node_key: attrs[value_attr] for node_key, attrs in self.dag.nodes.items() if value_attr in attrs}
def _get_inputs_one_node_keys(self, node_key: NodeKey) -> list[NodeKey]:
    """Return the upstream nodes of ``node_key`` in parameter order.

    Positional inputs come first, ordered by argument index, with ``None`` for any index not fed by a
    node. Keyword inputs follow, in predecessor order.
    """
    by_position = {}
    keyword_inputs = []
    highest = -1
    for upstream in self.dag.predecessors(node_key):
        kind, slot = self.dag[upstream][node_key][EdgeAttributes.PARAM]
        if kind == _ParameterType.ARG:
            highest = max(highest, slot)
            by_position[slot] = upstream
        elif kind == _ParameterType.KWD:
            keyword_inputs.append(upstream)
    if highest < 0:
        return keyword_inputs
    ordered = [None] * (highest + 1)
    for slot, upstream in by_position.items():
        ordered[slot] = upstream
    return ordered + keyword_inputs
def _get_inputs_one_names(self, name: Name) -> Names:
    """Name-based counterpart of ``_get_inputs_one_node_keys``."""
    input_keys = self._get_inputs_one_node_keys(to_nodekey(name))
    return node_keys_to_names(input_keys)
def get_inputs(self, name: Name | Names) -> Names | list[Names]:
    """Get a list of the inputs for a node or set of nodes.

    :param name: Name or names of nodes to get inputs for
    :return: If name is scalar, a list of the upstream nodes used as input. If name is a list, a list
        of lists of inputs.
    """
    return apply1(self._get_inputs_one_names, name)
def _get_ancestors_node_keys(self, node_keys: Iterable[NodeKey], include_self=True) -> set[NodeKey]:
    """Return the union of the ancestors of ``node_keys``, optionally including the nodes themselves."""
    collected = set()
    for key in node_keys:
        if include_self:
            collected.add(key)
        collected.update(nx.ancestors(self.dag, key))
    return collected
def get_ancestors(self, names: Name | Names, include_self=True) -> Names:
    """Return the names of all ancestors of the given node(s), optionally including the nodes themselves."""
    keys = self._get_ancestors_node_keys(names_to_node_keys(names), include_self)
    return node_keys_to_names(keys)
def _get_original_inputs_node_keys(self, node_keys: Iterable[NodeKey] | None) -> list[NodeKey]:
    """Return the function-less (input) nodes among ``node_keys`` and their ancestors.

    :param node_keys: Nodes whose ancestry to search, or ``None`` to search the whole computation
    :return: Keys of the nodes that have no function attached
    """
    candidates = self._node_keys() if node_keys is None else self._get_ancestors_node_keys(node_keys)
    return [n for n in candidates if self.dag.nodes[n].get(NodeAttributes.FUNC) is None]
def get_original_inputs(self, names: Name | Names | None = None) -> Names:
    """Return the original, non-computed inputs feeding the given node(s).

    :param names: Name or names of nodes to get inputs for; ``None`` means the whole computation
    :return: Names of the function-less nodes among those nodes and their ancestors
    """
    start_keys = None if names is None else names_to_node_keys(names)
    input_keys = self._get_original_inputs_node_keys(start_keys)
    return node_keys_to_names(input_keys)
def _get_outputs_one(self, name: Name) -> Names:
    """Return the names of the direct successors of a single node."""
    successor_keys = [*self.dag.successors(to_nodekey(name))]
    return node_keys_to_names(successor_keys)
def get_outputs(self, name: Name | Names) -> Names | list[Names]:
    """Return the nodes that directly consume a node's value.

    :param name: Name or names of nodes to get outputs for
    :return: For a single name, a list of its downstream nodes. For a list of names, a list of such
        lists.
    """
    single = self._get_outputs_one
    return apply1(single, name)
def _get_descendents_node_keys(self, node_keys: Iterable[NodeKey], include_self: bool = True) -> set[NodeKey]:
    """Return the union of the descendants of ``node_keys``, optionally including the nodes themselves.

    :param node_keys: Nodes whose descendants to collect
    :param include_self: Whether the given nodes are included in the result
    :return: Set of node keys
    """
    descendent_node_keys = set()
    for node_key in node_keys:
        if include_self:
            descendent_node_keys.add(node_key)
        descendent_node_keys.update(nx.descendants(self.dag, node_key))
    return descendent_node_keys
def get_descendents(self, names: Name | Names, include_self: bool = True) -> Names:
    """Return the names of all descendents of the given node(s), optionally including the nodes themselves."""
    keys = self._get_descendents_node_keys(names_to_node_keys(names), include_self)
    return node_keys_to_names(keys)
def get_final_outputs(self, names: Name | Names | None = None) -> Names:
    """Get the final output nodes (nodes with no descendants) reachable from the specified nodes.

    :param names: Name or names of nodes to start from (themselves included); ``None`` means the whole
        computation
    :return: Names of the nodes in that set that have no successors
    """
    if names is None:
        node_keys = self._node_keys()
    else:
        node_keys = self._get_descendents_node_keys(names_to_node_keys(names))
    # A node has no descendants exactly when it has no direct successors. Checking the out-degree
    # avoids a full graph traversal for every node.
    output_node_keys = [n for n in node_keys if self.dag.out_degree(n) == 0]
    return node_keys_to_names(output_node_keys)
def get_source(self, name: Name) -> str:
    """Return ``"<file>:<line>"`` followed by the source of a node's function.

    Nodes without a function yield ``"NOT A CALCULATED NODE"``.
    """
    func = self.dag.nodes[to_nodekey(name)].get(NodeAttributes.FUNC, None)
    if func is None:
        return "NOT A CALCULATED NODE"
    file = inspect.getsourcefile(func)
    source_lines, lineno = inspect.getsourcelines(func)
    source = "".join(source_lines)
    return f"{file}:{lineno}\n\n{source}"
def print_source(self, name: Name):
    """Print the output of ``get_source`` for a node."""
    text = self.get_source(name)
    print(text)
def restrict(self, output_names: Name | Names, input_names: Name | Names | None = None):
    """Restrict a computation to the ancestors of a set of output nodes.

    Ancestors of the given input nodes are excluded. Each input node is re-added as a plain input
    node that keeps its current state and value.

    If the input nodes given are not sufficient for the output nodes, other ancestors of the output
    nodes are still kept. The specified input nodes remain input nodes of the modified Computation.

    :param output_names: Name or names of the nodes to keep, together with their ancestors
    :param input_names: Name or names of nodes to turn into input nodes of the restricted computation
    :return: None - modifies the existing computation in place
    """
    if input_names is not None:
        # names_to_node_keys accepts a single name as well as a list. Iterating input_names directly
        # would split a single string name into its characters.
        for input_key in names_to_node_keys(input_names):
            nodedata = self._get_item_one(input_key)
            self.add_node(input_key)
            self._set_state_and_literal_value(input_key, nodedata.state, nodedata.value)
    output_node_keys = names_to_node_keys(output_names)
    keep = self._get_ancestors_node_keys(output_node_keys)
    self.dag.remove_nodes_from([n for n in self.dag if n not in keep])
    # The DAG was edited directly, so rebuild the lookup maps (as __setstate__ does after replacing
    # the DAG) to avoid stale entries for the removed nodes.
    self._refresh_maps()
def __getstate__(self):
    """Pickling hook: return the DAG of a copy in which nodes lacking the SERIALIZE tag are uninitialized."""
    tags_by_node = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
    snapshot = self.copy()
    for node_key, node_tags in tags_by_node.items():
        if SystemTags.SERIALIZE in node_tags:
            continue
        snapshot._set_uninitialized(node_key)
    return {"dag": snapshot.dag}
def __setstate__(self, state):
    """Unpickling hook: re-initialise, install the saved DAG, then rebuild the lookup maps."""
    self.__init__()
    saved_dag = state["dag"]
    self.dag = saved_dag
    self._refresh_maps()
def write_dill_old(self, file_):
    """Serialize a computation to a file or file-like object (deprecated; use ``write_dill``).

    This temporarily removes the class's ``__getstate__``/``__setstate__`` so that dill pickles a
    copy of the object with its default mechanism. On that copy the executors are cleared and
    nodes lacking the SERIALIZE tag are reset to uninitialized.

    :param file_: If string, writes to a file
    :type file_: File-like object, or string
    """
    warnings.warn("write_dill_old is deprecated, use write_dill instead", DeprecationWarning, stacklevel=2)
    # Remember the custom pickling hooks so the finally block can restore them whatever happens.
    original_getstate = self.__class__.__getstate__
    original_setstate = self.__class__.__setstate__

    try:
        # Removing the hooks makes dill fall back to pickling the instance attributes directly.
        # NOTE(review): this mutates the class itself, affecting all instances while the dump runs.
        # Also, `del` on a subclass that does not define these hooks would raise; confirm that
        # subclasses never call this.
        del self.__class__.__getstate__
        del self.__class__.__setstate__

        node_serialize = nx.get_node_attributes(self.dag, NodeAttributes.TAG)
        obj = self.copy()
        # Executors are dropped from the copy, presumably because they cannot be pickled.
        obj.executor_map = None
        obj.default_executor = None
        for name, tags in node_serialize.items():
            if SystemTags.SERIALIZE not in tags:
                obj._set_uninitialized(name)

        if isinstance(file_, str):
            with open(file_, "wb") as f:
                dill.dump(obj, f)
        else:
            dill.dump(obj, file_)
    finally:
        # Reinstate the pickling hooks on the class.
        self.__class__.__getstate__ = original_getstate
        self.__class__.__setstate__ = original_setstate
def write_dill(self, file_):
    """Serialize the computation with dill.

    :param file_: A path (the file is created/overwritten) or a binary file-like object
    :type file_: File-like object, or string
    """
    if not isinstance(file_, str):
        dill.dump(self, file_)
        return
    with open(file_, "wb") as handle:
        dill.dump(self, handle)
@staticmethod
def read_dill(file_):
    """Deserialize a computation from a file or file-like object.

    Warning: dill, like pickle, can execute arbitrary code while loading. Only read data from
    trusted sources.

    :param file_: If string, the path of a file to read from; otherwise a binary file-like object
    :type file_: File-like object, or string
    :return: The deserialized computation
    :rtype: Computation
    :raises Exception: If the loaded object is not a Computation
    """
    if isinstance(file_, str):
        with open(file_, "rb") as f:
            obj = dill.load(f)
    else:
        obj = dill.load(file_)
    if not isinstance(obj, Computation):
        raise Exception(f"Expected a serialized Computation, got {type(obj).__name__}")
    return obj
def copy(self):
    """Create a shallow copy of the computation.

    The graph structure and each node's attribute dict are duplicated, but node values are the same
    objects as in this computation. Since further computations create new value objects, this should
    not be an issue. The tag and state maps are copied set by set.

    :rtype: Computation
    """
    duplicate = Computation()
    duplicate.dag = nx.DiGraph(self.dag)
    duplicate._tag_map = {tag: members.copy() for tag, members in self._tag_map.items()}
    duplicate._state_map = {state: members.copy() for state, members in self._state_map.items()}
    return duplicate
def add_named_tuple_expansion(self, name, namedtuple_type, group=None):
    """Automatically add nodes to extract each element of a named tuple type.

    It is often convenient for a calculation to return multiple values. Returning them as a
    namedtuple rather than a regular tuple gives later users names to identify the elements. It can
    also make a computation clearer when a downstream calculation depends on one element of such a
    tuple rather than the entire tuple. This does not affect the computation per se, but it makes
    the intention clearer.

    To avoid writing many boiler-plate node definitions to expand namedtuples,
    ``add_named_tuple_expansion`` creates a new node for each element of the tuple. By convention,
    an element called 'element' in a node called 'node' is expanded into a new node called
    'node.element'. Each new node is tagged ``SystemTags.EXPANSION``.

    Example::

        >>> from collections import namedtuple
        >>> Coordinate = namedtuple('Coordinate', ['x', 'y'])
        >>> comp = Computation()
        >>> comp.add_node('c', value=Coordinate(1, 2))
        >>> comp.add_named_tuple_expansion('c', Coordinate)
        >>> comp.compute_all()
        >>> comp.value('c.x')
        1
        >>> comp.value('c.y')
        2

    :param name: Name of the node whose value is the namedtuple to expand
    :param namedtuple_type: Expected type of the node's value; its ``_fields`` determine the new nodes
    :type namedtuple_type: namedtuple class
    :param group: Passed as ``group`` to ``add_node`` for each new node
    """

    def make_f(field_name):
        # The parameter must be called "tuple": it is bound through kwds={"tuple": name} below.
        def get_field_value(tuple):
            return getattr(tuple, field_name)

        return get_field_value

    for field_name in namedtuple_type._fields:
        node_name = f"{name}.{field_name}"
        self.add_node(node_name, make_f(field_name), kwds={"tuple": name}, group=group)
        self.set_tag(node_name, SystemTags.EXPANSION)
def add_map_node(self, result_node, input_node, subgraph, subgraph_input_node, subgraph_output_node):
    """Apply a graph to each element of an iterable.

    Each element of ``input_node``'s value is inserted in turn into the subgraph's
    ``subgraph_input_node``, and the subgraph's ``subgraph_output_node`` is then computed. The
    resulting list, with one entry for each element of ``input_node``, becomes the value of
    ``result_node`` in this graph. In this way ``add_map_node`` is similar to ``map`` in functional
    programming.

    If any element fails to compute, a ``MapException`` is raised carrying the list of results.
    Each failed element's entry is a copy of the subgraph in its failed state. The same ``subgraph``
    object is reused for every element, so after evaluation it holds the state for the last element.

    :param result_node: The node to place a list of results in **this** graph
    :param input_node: The node to get a list of input values from **this** graph
    :param subgraph: The graph to use to perform the calculation for each element
    :param subgraph_input_node: The node in **subgraph** to insert each element in turn
    :param subgraph_output_node: The node in **subgraph** to read the result for each element
    """

    def f(xs):
        results = []
        is_error = False
        for x in xs:
            subgraph.insert(subgraph_input_node, x)
            subgraph.compute(subgraph_output_node)
            if subgraph.state(subgraph_output_node) == States.UPTODATE:
                results.append(subgraph.value(subgraph_output_node))
            else:
                is_error = True
                # Snapshot the failed subgraph so the error can be inspected later.
                results.append(subgraph.copy())
        if is_error:
            raise MapException(f"Unable to calculate {result_node}", results)
        return results

    self.add_node(result_node, f, kwds={"xs": input_node})
def prepend_path(self, path, prefix_path: NodeKey):
    """Join ``prefix_path`` in front of ``path``; constant values are returned untouched."""
    if isinstance(path, ConstantValue):
        return path
    return prefix_path.join(to_nodekey(path))
def add_block(
    self,
    base_path: Name,
    block: "Computation",
    *,
    keep_values: bool | None = True,
    links: dict | None = None,
    metadata: dict | None = None,
):
    """Copy every node of ``block`` into this computation underneath ``base_path``.

    Node names and node-valued parameters are prefixed with ``base_path``; constant parameters are
    kept as they are. Nodes are added with ``serialize=False`` and ``inspect=False``.

    :param base_path: Path under which the block's nodes are placed
    :param block: Computation whose nodes are copied
    :param keep_values: If truthy, copy each node's state and value where the block node has a value
    :param links: Mapping of target path (relative to ``base_path``) to source node; each pair is
        passed to ``link``
    :param metadata: Stored for ``base_path``; ``None`` removes any metadata previously stored there
    """
    base_path = to_nodekey(base_path)
    for node_name in block.nodes():
        node_key = to_nodekey(node_name)
        attrs = block.dag.nodes[node_key]
        def_args, def_kwds = block.get_definition_args_kwds(node_key)
        prefixed_args = [self.prepend_path(a, base_path) for a in def_args]
        prefixed_kwds = {k: self.prepend_path(v, base_path) for k, v in def_kwds.items()}
        new_node_name = self.prepend_path(node_name, base_path)
        self.add_node(
            new_node_name,
            attrs.get(NodeAttributes.FUNC, None),
            args=prefixed_args,
            kwds=prefixed_kwds,
            converter=attrs.get(NodeAttributes.CONVERTER, None),
            serialize=False,
            inspect=False,
            group=attrs.get(NodeAttributes.GROUP, None),
            tags=attrs.get(NodeAttributes.TAG, None),
            style=attrs.get(NodeAttributes.STYLE, None),
            executor=attrs.get(NodeAttributes.EXECUTOR, None),
        )
        if not keep_values or NodeAttributes.VALUE not in attrs:
            continue
        self._set_state_and_literal_value(
            to_nodekey(new_node_name), attrs[NodeAttributes.STATE], attrs[NodeAttributes.VALUE]
        )
    if links is not None:
        for target, source in links.items():
            self.link(base_path.join_parts(target), source)
    if metadata is not None:
        self._metadata[base_path] = metadata
    elif base_path in self._metadata:
        del self._metadata[base_path]
def link(self, target: Name, source: Name):
    """Make ``target`` an identity node that mirrors ``source``.

    Linking a node to itself does nothing. The target keeps its own style if it has one; otherwise
    it takes the source's style.
    """
    target = to_nodekey(target)
    source = to_nodekey(source)
    if target == source:
        return
    style = self._style_one(target) if self.has_node(target) else None
    if not style:
        style = self._style_one(source) if self.has_node(source) else None
    self.add_node(target, identity_function, kwds={"x": source}, style=style)
def _repr_svg_(self):
    """IPython rich-display hook: render the computation graph as SVG."""
    graph_view = GraphView(self)
    return graph_view.svg()
def draw(
    self,
    root: NodeKey | None = None,
    *,
    node_transformations: dict | None = None,
    cmap=None,
    colors="state",
    shapes=None,
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    show_expansion=False,
    collapse_all=True,
):
    """Draw a computation's current state using the GraphViz utility.

    :param root: Optional PathType. Sub-block to draw
    :param node_transformations: Optional mapping of node name to ``NodeTransformations`` value. It is
        copied, never modified in place.
    :param cmap: Default: None
    :param colors: 'state' - colors indicate state. 'timing' - colors indicate execution time. Default: 'state'.
    :param shapes: None - ovals. 'type' - shapes indicate type. Default: None.
    :param graph_attr: Mapping of (attribute, value) pairs for the graph. For example
        ``graph_attr={'size': '"10,8"'}`` can control the size of the output graph
    :param node_attr: Mapping of (attribute, value) pairs set for all nodes.
    :param edge_attr: Mapping of (attribute, value) pairs set for all edges.
    :param show_expansion: If False (default), nodes tagged ``SystemTags.EXPANSION`` are contracted.
    :param collapse_all: Whether to collapse all blocks that aren't explicitly expanded.
    :return: The ``GraphView`` for the drawing
    """
    node_formatter = NodeFormatter.create(cmap, colors, shapes)
    node_transformations = node_transformations.copy() if node_transformations is not None else {}
    if not show_expansion:
        # nodes_by_tag returns node names, so the mapping is keyed by name.
        for node_name in self.nodes_by_tag(SystemTags.EXPANSION):
            node_transformations[node_name] = NodeTransformations.CONTRACT
    return GraphView(
        self,
        root=root,
        node_formatter=node_formatter,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        node_transformations=node_transformations,
        collapse_all=collapse_all,
    )
def view(self, cmap=None, colors="state", shapes=None):
    """Build a ``GraphView`` of the computation with the requested formatting and call its ``view()``."""
    GraphView(self, node_formatter=NodeFormatter.create(cmap, colors, shapes)).view()
def print_errors(self):
    """Print, for each node in the ERROR state, its name, an underline, and the stored traceback."""
    for node_name in self.nodes():
        if self.s[node_name] != States.ERROR:
            continue
        print(f"{node_name}")
        print("=" * len(node_name))
        print()
        print(self.v[node_name].traceback)
        print()
@classmethod
def from_class(cls, definition_class, ignore_self=True):
    """Build a computation from a class whose methods define the nodes.

    A new computation and a fresh instance of ``definition_class`` are created and passed to
    ``populate_computation_from_class``.
    """
    computation = cls()
    instance = definition_class()
    populate_computation_from_class(computation, definition_class, instance, ignore_self=ignore_self)
    return computation
def inject_dependencies(self, dependencies: dict, *, force: bool = False):
    """Inject dependencies into the PLACEHOLDER nodes of the computation.

    Only nodes in the PLACEHOLDER state are considered, unless ``force`` is True. In that case every
    node whose name appears in ``dependencies`` is replaced. A callable value is added as a calculation
    node; any other value is added as a plain value node. Entries whose value is ``None`` are skipped.

    :param dependencies: Mapping from node name to the dependency object, or to a callable that
        computes it
    :param force: If True, replace matching nodes regardless of their current state. Defaults to False.
    :return: None
    """
    for n in self.nodes():
        if force or self.s[n] == States.PLACEHOLDER:
            obj = dependencies.get(n)
            if obj is None:
                continue
            if callable(obj):
                self.add_node(n, obj)
            else:
                self.add_node(n, value=obj)