Coverage for src / loman / serialization / computation.py: 93%

140 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 21:24 +0000

1"""Serialization for Computation graphs to/from JSON.""" 

2 

3from __future__ import annotations 

4 

5import json 

6from typing import TYPE_CHECKING, Any, ClassVar, TextIO 

7 

8from loman.consts import EdgeAttributes, NodeAttributes, States, SystemTags 

9from loman.exception import SerializationError 

10from loman.nodekey import parse_nodekey 

11 

12from .transformer import ( 

13 DataFrameTransformer, 

14 DillFunctionTransformer, 

15 EnumTransformer, 

16 FunctionRefTransformer, 

17 NdArrayTransformer, 

18 NodeKeyTransformer, 

19 SeriesTransformer, 

20 Transformer, 

21 UntransformableTypeError, 

22) 

23 

24if TYPE_CHECKING: 

25 pass 

26 

27# Serialization format version — bump when the schema changes. 

28FORMAT_VERSION = 1 

29 

30 

def default_computation_transformer() -> Transformer:
    """Build the Transformer used by default for Computation serialization.

    Handlers are registered in this order: ndarray, enum (with ``States``
    registered), importable function references, DataFrame, Series, NodeKey.
    """
    transformer = Transformer()

    # Numeric arrays come first.
    transformer.register(NdArrayTransformer())

    # The States enum must be known to the enum handler so that node states
    # survive a round trip.
    states_handler = EnumTransformer()
    states_handler.register_enum(States)
    transformer.register(states_handler)

    # Remaining handlers, each constructed immediately before registration:
    # importable module-level callables (lambdas / closures raise), the two
    # pandas containers, and hierarchical node names.
    for handler_cls in (
        FunctionRefTransformer,
        DataFrameTransformer,
        SeriesTransformer,
        NodeKeyTransformer,
    ):
        transformer.register(handler_cls())

    return transformer

54 

55 

def dill_computation_transformer() -> Transformer:
    """Build a Transformer whose callables are all serialized through dill.

    The registrations mirror :func:`default_computation_transformer`, with one
    substitution: :class:`~loman.serialization.transformer.DillFunctionTransformer`
    takes the place of
    :class:`~loman.serialization.transformer.FunctionRefTransformer`. Lambdas and
    locally-defined closures are therefore written out as base64-encoded dill
    blobs instead of raising :class:`~loman.exception.SerializationError`.
    """
    transformer = Transformer()

    transformer.register(NdArrayTransformer())

    # Node states are enum members; register States so they roundtrip.
    states_handler = EnumTransformer()
    states_handler.register_enum(States)
    transformer.register(states_handler)

    # Dill-backed callable handler (covers lambdas and closures), followed by
    # the pandas and NodeKey handlers.
    for handler_cls in (
        DillFunctionTransformer,
        DataFrameTransformer,
        SeriesTransformer,
        NodeKeyTransformer,
    ):
        transformer.register(handler_cls())

    return transformer

81 

82 

class ComputationSerializer:
    """Serialize and deserialize a :class:`~loman.computeengine.Computation` graph to JSON.

    The serialized format is a JSON object with the following top-level keys:

    - ``version``: integer format version
    - ``nodes``: list of node objects
    - ``edges``: list of edge objects

    Each **node** object has:

    - ``key``: string representation of the NodeKey
    - ``state``: name of the :class:`~loman.consts.States` enum member (or ``null``)
    - ``value``: transformer-encoded value (or ``null`` when absent / not serialized)
    - ``has_value``: bool — false when the node has no meaningful value to restore
    - ``func``: transformer-encoded callable (or ``null``)
    - ``serialize``: bool — whether the node has the ``__serialize__`` tag
    - ``tags``: sorted list of non-system tags

    Each **edge** object has:

    - ``src``: string key of the source node
    - ``dst``: string key of the destination node
    - ``param_type``: ``"arg"`` or ``"kwd"``
    - ``param``: positional index (int) for args, parameter name (str) for kwds

    On load, a ``version`` that is present but differs from ``FORMAT_VERSION``
    is rejected with :class:`~loman.exception.SerializationError`; a payload
    without a ``version`` key is read as the current format.

    Parameters
    ----------
    transformer:
        Custom :class:`~loman.serialization.transformer.Transformer` instance.
        If ``None``, a default transformer is built based on *use_dill_for_functions*.
    use_dill_for_functions:
        When ``True``, lambdas and closures are serialized as base64-encoded dill
        blobs rather than raising :class:`~loman.exception.SerializationError`.
        Has no effect when a custom *transformer* is supplied. Defaults to ``False``.
    """

    # States whose nodes carry a meaningful value that should be preserved.
    _VALUE_STATES: ClassVar[set[States]] = {States.UPTODATE, States.PINNED, States.ERROR}

    def __init__(
        self,
        transformer: Transformer | None = None,
        *,
        use_dill_for_functions: bool = False,
    ) -> None:
        """Initialise with an optional custom transformer."""
        if transformer is None:
            transformer = (
                dill_computation_transformer() if use_dill_for_functions else default_computation_transformer()
            )
        self._t = transformer
        # Kept separately from the transformer: it also controls whether
        # lambdas are rejected up front in ``_serialize_node_func``.
        self._use_dill_for_functions = use_dill_for_functions

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def dump(self, comp: Any, fp: TextIO) -> None:
        """Serialize *comp* to *fp* (a text-mode file-like object)."""
        json.dump(self._to_dict(comp), fp)

    def dumps(self, comp: Any) -> str:
        """Serialize *comp* and return a JSON string."""
        return json.dumps(self._to_dict(comp))

    def _serialize_node_value(self, node_key: Any, state: States | None, node_data: dict[str, Any]) -> tuple[Any, bool]:
        """Return ``(encoded_value, has_value)`` for a node that should be serialized.

        Nodes whose state is not in ``_VALUE_STATES`` yield ``(None, False)``.
        An ``Error`` value on an ERROR node is flattened to a plain dict marked
        with ``"__loman_error__"``; every other value goes through the
        transformer.

        Raises :class:`~loman.exception.SerializationError` if the value cannot
        be encoded.
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import — confirm).
        from loman.computeengine import Error

        if state not in self._VALUE_STATES:
            return None, False

        raw_value = node_data.get(NodeAttributes.VALUE)
        if state == States.ERROR and isinstance(raw_value, Error):
            encoded_error = {
                "__loman_error__": True,
                "exception_type": type(raw_value.exception).__name__,
                "exception_str": str(raw_value.exception),
                "traceback": raw_value.traceback,
            }
            return encoded_error, True

        try:
            return self._t.to_dict(raw_value), True
        except (UntransformableTypeError, ValueError) as exc:
            msg = f"Cannot serialize value of node {node_key!r}: {exc}"
            raise SerializationError(msg) from exc

    def _serialize_node_func(self, node_key: Any, raw_func: Any) -> Any:
        """Return the encoded function for a node, or ``None`` if it cannot be serialized.

        Lambdas raise :class:`~loman.exception.SerializationError` unless
        ``use_dill_for_functions`` is enabled. Other non-importable callables
        (e.g. framework closures from ``add_block``) are silently stored as ``null``.
        """
        # Callables without a ``__qualname__`` are treated as "not a lambda".
        qualname = getattr(raw_func, "__qualname__", "") or ""
        if not self._use_dill_for_functions and "<lambda>" in qualname:
            msg = (
                f"Cannot serialize lambda function on node {node_key!r}. "
                "Use a module-level importable function, serialize=False, "
                "or ComputationSerializer(use_dill_for_functions=True)."
            )
            raise SerializationError(msg)
        try:
            return self._t.to_dict(raw_func)
        except (UntransformableTypeError, ValueError, TypeError):
            # Non-importable callable (e.g. framework closure) — store null.
            return None

    def _serialize_node(self, node_key: Any, node_data: dict[str, Any]) -> dict[str, Any]:
        """Return the serialized dict for a single node.

        A node without the ``SystemTags.SERIALIZE`` tag is written as
        UNINITIALIZED with no value and no function; its user tags are still
        emitted.
        """
        state: States | None = node_data.get(NodeAttributes.STATE)
        tags: set[str] = node_data.get(NodeAttributes.TAG, set())
        serialize_flag: bool = SystemTags.SERIALIZE in tags

        if serialize_flag:
            serialized_state = state
            encoded_value, has_value = self._serialize_node_value(node_key, state, node_data)
        else:
            serialized_state = States.UNINITIALIZED
            encoded_value, has_value = None, False

        raw_func = node_data.get(NodeAttributes.FUNC)
        encoded_func = (
            self._serialize_node_func(node_key, raw_func) if raw_func is not None and serialize_flag else None
        )

        # Tags live in a set; sort them so the same graph always produces the
        # same JSON. Names starting with "__" are system tags and are omitted.
        user_tags = sorted(tag for tag in tags if not tag.startswith("__"))

        return {
            "key": str(node_key),
            "state": serialized_state.name if serialized_state is not None else None,
            "value": encoded_value,
            "has_value": has_value,
            "func": encoded_func,
            "serialize": serialize_flag,
            "tags": user_tags,
        }

    def _serialize_edge(self, src: Any, dst: Any, edge_data: dict[str, Any]) -> dict[str, Any]:
        """Return the serialized dict for a single edge.

        An edge with no ``EdgeAttributes.PARAM`` entry is written with
        ``param_type`` and ``param`` both ``None``.
        """
        param = edge_data.get(EdgeAttributes.PARAM)
        if param is None:
            return {"src": str(src), "dst": str(dst), "param_type": None, "param": None}

        from loman.computeengine import _ParameterType

        param_type, param_val = param
        return {
            "src": str(src),
            "dst": str(dst),
            "param_type": "kwd" if param_type == _ParameterType.KWD else "arg",
            "param": param_val,
        }

    def _to_dict(self, comp: Any) -> dict[str, Any]:
        """Convert a Computation to a JSON-serializable dict."""
        nodes_out = [self._serialize_node(k, comp.dag.nodes[k]) for k in comp.dag.nodes()]
        edges_out = [self._serialize_edge(src, dst, data) for src, dst, data in comp.dag.edges(data=True)]
        return {"version": FORMAT_VERSION, "nodes": nodes_out, "edges": edges_out}

    # ------------------------------------------------------------------
    # Deserialization
    # ------------------------------------------------------------------

    def load(self, fp: TextIO) -> Any:
        """Deserialize a Computation from *fp* (a text-mode file-like object)."""
        return self._from_dict(json.load(fp))

    def loads(self, s: str) -> Any:
        """Deserialize a Computation from a JSON string."""
        return self._from_dict(json.loads(s))

    @staticmethod
    def _check_version(data: Any) -> None:
        """Reject payloads written with a different format version.

        A payload without a ``version`` key is accepted as the current format.
        Non-dict payloads are not checked here; they fail later when the
        ``nodes`` key is accessed.

        Raises :class:`~loman.exception.SerializationError` when a version is
        present and differs from ``FORMAT_VERSION``.
        """
        if not isinstance(data, dict):
            return
        version = data.get("version", FORMAT_VERSION)
        if version != FORMAT_VERSION:
            msg = (
                f"Unsupported serialization format version {version!r}; "
                f"this serializer reads version {FORMAT_VERSION}."
            )
            raise SerializationError(msg)

    def _deserialize_node_value(self, node_info: dict[str, Any]) -> Any:
        """Decode the value stored in *node_info*, or return ``None`` if it has none.

        A dict carrying a truthy ``"__loman_error__"`` entry is rebuilt as an
        ``Error`` wrapping a plain ``Exception`` with the stored message and
        traceback; the stored ``exception_type`` is not used. Anything else is
        decoded by the transformer.
        """
        from loman.computeengine import Error

        encoded_value = node_info.get("value")
        if not node_info.get("has_value", False) or encoded_value is None:
            return None

        if isinstance(encoded_value, dict) and encoded_value.get("__loman_error__"):
            return Error(
                exception=Exception(encoded_value["exception_str"]),
                traceback=encoded_value["traceback"],
            )
        return self._t.from_dict(encoded_value)

    def _restore_node(self, comp: Any, node_info: dict[str, Any]) -> None:
        """Add the node described by *node_info* to ``comp.dag`` with all its attributes."""
        node_key = parse_nodekey(node_info["key"])
        state_name = node_info["state"]
        # A null state is restored as UNINITIALIZED.
        state = States[state_name] if state_name is not None else States.UNINITIALIZED
        serialize_flag: bool = node_info.get("serialize", True)
        user_tags: list[str] = node_info.get("tags", [])

        encoded_func = node_info.get("func")
        func = self._t.from_dict(encoded_func) if encoded_func is not None else None
        value = self._deserialize_node_value(node_info)

        comp.dag.add_node(node_key)
        node_data = comp.dag.nodes[node_key]
        node_data[NodeAttributes.STATE] = state
        node_data[NodeAttributes.VALUE] = value
        node_data[NodeAttributes.FUNC] = func
        # These attributes are not part of the serialized format and are
        # reset to empty / None on load.
        node_data[NodeAttributes.ARGS] = {}
        node_data[NodeAttributes.KWDS] = {}
        node_data[NodeAttributes.STYLE] = None
        node_data[NodeAttributes.GROUP] = None
        node_data[NodeAttributes.EXECUTOR] = None
        node_data[NodeAttributes.CONVERTER] = None

        restored_tags = set(user_tags)
        if serialize_flag:
            restored_tags.add(SystemTags.SERIALIZE)
        node_data[NodeAttributes.TAG] = restored_tags

    @staticmethod
    def _restore_edge(comp: Any, edge_info: dict[str, Any]) -> None:
        """Add the edge described by *edge_info* to ``comp.dag``."""
        from loman.computeengine import _ParameterType

        src_key = parse_nodekey(edge_info["src"])
        dst_key = parse_nodekey(edge_info["dst"])
        param_type_str = edge_info.get("param_type")
        param_val = edge_info.get("param")

        if param_type_str is None:
            comp.dag.add_edge(src_key, dst_key)
            return

        # Anything other than "kwd" is read back as a positional argument.
        param_type = _ParameterType.KWD if param_type_str == "kwd" else _ParameterType.ARG
        comp.dag.add_edge(src_key, dst_key, **{EdgeAttributes.PARAM: (param_type, param_val)})

    def _from_dict(self, data: dict[str, Any]) -> Any:
        """Reconstruct a Computation from a deserialized dict.

        Nodes are restored before edges, then ``comp._refresh_maps()`` is
        called on the rebuilt Computation.

        Raises :class:`~loman.exception.SerializationError` if *data* declares
        a format version other than ``FORMAT_VERSION``.
        """
        from loman.computeengine import Computation

        self._check_version(data)

        comp = Computation()

        for node_info in data["nodes"]:
            self._restore_node(comp, node_info)

        for edge_info in data["edges"]:
            self._restore_edge(comp, edge_info)

        comp._refresh_maps()

        return comp