Coverage for src / loman / serialization / computation.py: 93%
140 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 21:24 +0000
1"""Serialization for Computation graphs to/from JSON."""
3from __future__ import annotations
5import json
6from typing import TYPE_CHECKING, Any, ClassVar, TextIO
8from loman.consts import EdgeAttributes, NodeAttributes, States, SystemTags
9from loman.exception import SerializationError
10from loman.nodekey import parse_nodekey
12from .transformer import (
13 DataFrameTransformer,
14 DillFunctionTransformer,
15 EnumTransformer,
16 FunctionRefTransformer,
17 NdArrayTransformer,
18 NodeKeyTransformer,
19 SeriesTransformer,
20 Transformer,
21 UntransformableTypeError,
22)
24if TYPE_CHECKING:
25 pass
27# Serialization format version — bump when the schema changes.
28FORMAT_VERSION = 1
def default_computation_transformer() -> Transformer:
    """Build the stock Transformer used when serializing a Computation.

    The returned transformer has, in registration order, handlers for numeric
    arrays, the :class:`~loman.consts.States` enum, importable callables,
    pandas DataFrames and Series, and NodeKeys.
    """
    transformer = Transformer()

    # Node states are stored as enum members, so the States enum must be known
    # to the enum handler for them to roundtrip.
    state_enums = EnumTransformer()
    state_enums.register_enum(States)

    handlers = (
        NdArrayTransformer(),  # numeric arrays
        state_enums,  # enums (States)
        FunctionRefTransformer(),  # importable module-level callables; lambdas / closures raise
        DataFrameTransformer(),  # pandas
        SeriesTransformer(),  # pandas
        NodeKeyTransformer(),  # hierarchical node names
    )
    for handler in handlers:
        transformer.register(handler)

    return transformer
def dill_computation_transformer() -> Transformer:
    """Build a Transformer whose callables — lambdas and closures included — go through dill.

    The registrations mirror :func:`default_computation_transformer`, with one
    difference: :class:`~loman.serialization.transformer.DillFunctionTransformer`
    takes the place of
    :class:`~loman.serialization.transformer.FunctionRefTransformer`. As a
    result lambdas and locally-defined closures are written out as
    base64-encoded dill blobs instead of raising
    :class:`~loman.exception.SerializationError`.
    """
    result = Transformer()

    array_handler = NdArrayTransformer()
    result.register(array_handler)

    # The States enum has to be registered for node states to roundtrip.
    states_handler = EnumTransformer()
    states_handler.register_enum(States)
    result.register(states_handler)

    # Dill-based callable serializer — handles lambdas and closures.
    callable_handler = DillFunctionTransformer()
    result.register(callable_handler)

    # Pandas containers, then hierarchical node names.
    for handler_cls in (DataFrameTransformer, SeriesTransformer, NodeKeyTransformer):
        result.register(handler_cls())

    return result
class ComputationSerializer:
    """Serialize and deserialize a :class:`~loman.computeengine.Computation` graph to JSON.

    The serialized format is a JSON object with the following top-level keys:

    - ``version``: integer format version
    - ``nodes``: list of node objects
    - ``edges``: list of edge objects

    Each **node** object has:

    - ``key``: string representation of the NodeKey
    - ``state``: name of the :class:`~loman.consts.States` enum member (or ``null``)
    - ``value``: transformer-encoded value (or ``null`` when absent / not serialized)
    - ``has_value``: bool — false when the node has no meaningful value to restore
    - ``func``: transformer-encoded callable (or ``null``)
    - ``serialize``: bool — whether the node has the ``__serialize__`` tag
    - ``tags``: sorted list of non-system tags

    Each **edge** object has:

    - ``src``: string key of the source node
    - ``dst``: string key of the destination node
    - ``param_type``: ``"arg"`` or ``"kwd"`` (``null`` when the edge carries no parameter)
    - ``param``: positional index (int) for args, parameter name (str) for kwds
      (``null`` when the edge carries no parameter)

    Parameters
    ----------
    transformer:
        Custom :class:`~loman.serialization.transformer.Transformer` instance.
        If ``None``, a default transformer is built based on *use_dill_for_functions*.
    use_dill_for_functions:
        When ``True``, lambdas and closures are serialized as base64-encoded dill
        blobs rather than raising :class:`~loman.exception.SerializationError`.
        When a custom *transformer* is supplied this flag does not alter that
        transformer, but it still decides whether lambdas are rejected up front:
        with the default ``False`` a lambda raises before the transformer is
        consulted. Defaults to ``False``.
    """

    # States whose nodes carry a meaningful value that should be preserved.
    _VALUE_STATES: ClassVar[set[States]] = {States.UPTODATE, States.PINNED, States.ERROR}

    def __init__(
        self,
        transformer: Transformer | None = None,
        *,
        use_dill_for_functions: bool = False,
    ) -> None:
        """Initialise with an optional custom transformer."""
        if transformer is None:
            transformer = (
                dill_computation_transformer() if use_dill_for_functions else default_computation_transformer()
            )
        self._t = transformer
        # Kept separately from the transformer: it gates the lambda check in
        # _serialize_node_func regardless of which transformer is in use.
        self._use_dill_for_functions = use_dill_for_functions

    def dump(self, comp: Any, fp: TextIO) -> None:
        """Serialize *comp* to *fp* (a text-mode file-like object)."""
        json.dump(self._to_dict(comp), fp)

    def dumps(self, comp: Any) -> str:
        """Serialize *comp* and return a JSON string."""
        return json.dumps(self._to_dict(comp))

    def _serialize_node_value(self, node_key: Any, state: States | None, node_data: dict[str, Any]) -> tuple[Any, bool]:
        """Return ``(encoded_value, has_value)`` for a node that should be serialized.

        Nodes whose state is not one of ``_VALUE_STATES`` yield ``(None, False)``.
        An ``Error`` value on an ERROR node is flattened to a plain dict holding
        the exception type name, its string form and the traceback.

        Raises :class:`~loman.exception.SerializationError` if the value cannot
        be encoded.
        """
        # Imported locally, as in the rest of this class, rather than at module level.
        from loman.computeengine import Error

        if state not in self._VALUE_STATES:
            return None, False

        raw_value = node_data.get(NodeAttributes.VALUE)
        if state == States.ERROR and isinstance(raw_value, Error):
            encoded_error = {
                "__loman_error__": True,
                "exception_type": type(raw_value.exception).__name__,
                "exception_str": str(raw_value.exception),
                "traceback": raw_value.traceback,
            }
            return encoded_error, True

        try:
            return self._t.to_dict(raw_value), True
        except (UntransformableTypeError, ValueError) as exc:
            msg = f"Cannot serialize value of node {node_key!r}: {exc}"
            raise SerializationError(msg) from exc

    def _serialize_node_func(self, node_key: Any, raw_func: Any) -> Any:
        """Return the encoded function for a node, or ``None`` if it cannot be serialized.

        Lambdas raise :class:`~loman.exception.SerializationError` unless
        ``use_dill_for_functions`` is enabled. Other non-importable callables
        (e.g. framework closures from ``add_block``) are silently stored as ``null``.
        """
        qualname = getattr(raw_func, "__qualname__", "") or ""
        if not self._use_dill_for_functions and "<lambda>" in qualname:
            msg = (
                f"Cannot serialize lambda function on node {node_key!r}. "
                "Use a module-level importable function, serialize=False, "
                "or ComputationSerializer(use_dill_for_functions=True)."
            )
            raise SerializationError(msg)
        try:
            return self._t.to_dict(raw_func)
        except (UntransformableTypeError, ValueError, TypeError):
            # Non-importable callable (e.g. framework closure) — store null.
            return None

    def _serialize_node(self, node_key: Any, node_data: dict[str, Any]) -> dict[str, Any]:
        """Return the serialized dict for a single node.

        Nodes without the serialize system tag are written as UNINITIALIZED
        with neither value nor function. User tags (those not starting with
        ``"__"``) are emitted sorted so the output does not depend on set
        iteration order.
        """
        state: States | None = node_data.get(NodeAttributes.STATE)
        tags: set[str] = node_data.get(NodeAttributes.TAG, set())
        serialize_flag: bool = SystemTags.SERIALIZE in tags

        serialized_state: States | None = States.UNINITIALIZED
        encoded_value: Any = None
        has_value = False
        encoded_func: Any = None

        if serialize_flag:
            serialized_state = state
            # Value first, then function: keeps the order in which errors surface.
            encoded_value, has_value = self._serialize_node_value(node_key, state, node_data)
            raw_func = node_data.get(NodeAttributes.FUNC)
            if raw_func is not None:
                encoded_func = self._serialize_node_func(node_key, raw_func)

        # Sets of strings iterate in a hash-dependent order that varies between
        # interpreter runs; sorting makes the serialized form reproducible.
        user_tags = sorted(tag for tag in tags if not tag.startswith("__"))

        return {
            "key": str(node_key),
            "state": serialized_state.name if serialized_state is not None else None,
            "value": encoded_value,
            "has_value": has_value,
            "func": encoded_func,
            "serialize": serialize_flag,
            "tags": user_tags,
        }

    def _serialize_edge(self, src: Any, dst: Any, edge_data: dict[str, Any]) -> dict[str, Any]:
        """Return the serialized dict for a single edge.

        An edge with no parameter attribute is written with ``param_type`` and
        ``param`` both ``None``.
        """
        encoded: dict[str, Any] = {"src": str(src), "dst": str(dst), "param_type": None, "param": None}

        param = edge_data.get(EdgeAttributes.PARAM)
        if param is None:
            return encoded

        from loman.computeengine import _ParameterType

        param_type, param_val = param
        # Anything that is not a keyword binding is recorded as positional.
        encoded["param_type"] = "kwd" if param_type == _ParameterType.KWD else "arg"
        encoded["param"] = param_val
        return encoded

    def _to_dict(self, comp: Any) -> dict[str, Any]:
        """Convert a Computation to a JSON-serializable dict."""
        dag = comp.dag
        nodes_out = [self._serialize_node(key, dag.nodes[key]) for key in dag.nodes()]
        edges_out = [self._serialize_edge(src, dst, attrs) for src, dst, attrs in dag.edges(data=True)]
        return {"version": FORMAT_VERSION, "nodes": nodes_out, "edges": edges_out}

    def load(self, fp: TextIO) -> Any:
        """Deserialize a Computation from *fp* (a text-mode file-like object)."""
        return self._from_dict(json.load(fp))

    def loads(self, s: str) -> Any:
        """Deserialize a Computation from a JSON string."""
        return self._from_dict(json.loads(s))

    def _from_dict(self, data: dict[str, Any]) -> Any:
        """Reconstruct a Computation from a deserialized dict.

        Raises :class:`~loman.exception.SerializationError` when *data* declares
        a ``version`` other than ``FORMAT_VERSION``. A payload without a
        ``version`` key is treated as the current format.

        Serialized errors are restored as an ``Error`` wrapping a plain
        ``Exception`` that carries the original message; the recorded
        ``exception_type`` is not used to rebuild the original exception class.
        """
        from loman.computeengine import Computation, Error, _ParameterType

        # Refuse payloads written under a different schema instead of
        # mis-reading them silently.
        version = data.get("version", FORMAT_VERSION)
        if version != FORMAT_VERSION:
            msg = f"Unsupported serialization format version {version!r}; expected {FORMAT_VERSION}."
            raise SerializationError(msg)

        comp = Computation()

        for node_info in data["nodes"]:
            node_key = parse_nodekey(node_info["key"])
            state_name = node_info["state"]
            state = States[state_name] if state_name is not None else States.UNINITIALIZED
            serialize_flag: bool = node_info.get("serialize", True)
            has_value: bool = node_info.get("has_value", False)

            encoded_func = node_info.get("func")
            func = self._t.from_dict(encoded_func) if encoded_func is not None else None

            encoded_value = node_info.get("value")
            value: Any = None
            if has_value and encoded_value is not None:
                if isinstance(encoded_value, dict) and encoded_value.get("__loman_error__"):
                    value = Error(
                        exception=Exception(encoded_value["exception_str"]),
                        traceback=encoded_value["traceback"],
                    )
                else:
                    value = self._t.from_dict(encoded_value)

            tags = set(node_info.get("tags", []))
            if serialize_flag:
                tags.add(SystemTags.SERIALIZE)

            comp.dag.add_node(node_key)
            node_data = comp.dag.nodes[node_key]
            node_data[NodeAttributes.STATE] = state
            node_data[NodeAttributes.VALUE] = value
            node_data[NodeAttributes.FUNC] = func
            node_data[NodeAttributes.ARGS] = {}
            node_data[NodeAttributes.KWDS] = {}
            node_data[NodeAttributes.TAG] = tags
            # Style, group, executor and converter are not part of the
            # serialized format, so they are reset to None on load.
            node_data[NodeAttributes.STYLE] = None
            node_data[NodeAttributes.GROUP] = None
            node_data[NodeAttributes.EXECUTOR] = None
            node_data[NodeAttributes.CONVERTER] = None

        for edge_info in data["edges"]:
            src_key = parse_nodekey(edge_info["src"])
            dst_key = parse_nodekey(edge_info["dst"])
            param_type_str = edge_info.get("param_type")

            if param_type_str is None:
                comp.dag.add_edge(src_key, dst_key)
                continue

            # Any param_type other than "kwd" is read back as positional.
            param_type = _ParameterType.KWD if param_type_str == "kwd" else _ParameterType.ARG
            comp.dag.add_edge(src_key, dst_key, **{EdgeAttributes.PARAM: (param_type, edge_info.get("param"))})

        # The graph was populated directly through comp.dag, so the
        # computation's derived maps have to be rebuilt afterwards.
        comp._refresh_maps()

        return comp