Coverage for src/loman/serialization/transformer.py: 100%

234 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:30 +0000

1"""Object serialization and transformation framework.""" 

2 

3import dataclasses 

4import graphlib 

5from abc import ABC, abstractmethod 

6from collections.abc import Iterable 

7from typing import Any 

8 

9import numpy as np 

10 

11try: 

12 import attrs 

13 

14 HAS_ATTRS = True 

15except ImportError: # pragma: no cover 

16 HAS_ATTRS = False 

17 

# Keys of the tagged dicts emitted by the serializer.
KEY_TYPE: str = "type"
KEY_CLASS: str = "class"
KEY_VALUES: str = "values"
KEY_DATA: str = "data"

# Built-in tag values stored under KEY_TYPE. Custom transformers use their
# own ``name`` as the tag instead.
TYPENAME_DICT: str = "dict"
TYPENAME_TUPLE: str = "tuple"
TYPENAME_TRANSFORMABLE: str = "transformable"
TYPENAME_ATTRS: str = "attrs"
TYPENAME_DATACLASS: str = "dataclass"

28 

29 

class UntransformableTypeError(Exception):
    """Raised when an object's type has no known serialization.

    ``Transformer.to_dict`` raises this in strict mode for objects that no
    registered custom transformer can handle.
    """

34 

35 

class UnrecognizedTypeError(Exception):
    """Raised when a serialized value cannot be mapped back to a type.

    ``Transformer.from_dict`` raises this in strict mode when a tag or class
    name has not been registered (or attrs support is unavailable).
    """

40 

41 

class MissingObject:
    """Placeholder for a value that could not be deserialized.

    Non-strict ``Transformer.from_dict`` returns a fresh instance wherever a
    value's type is unknown. Instances are not a shared singleton, so test
    for them with ``isinstance`` rather than ``is``.
    """

    def __repr__(self) -> str:
        """Render the placeholder as the fixed text ``Missing``."""
        text = "Missing"
        return text

48 

49 

def order_classes(classes: Iterable[type]) -> list[type]:
    """Order classes so that every subclass precedes its base classes.

    A topological sort over the ``issubclass`` relation among the given
    classes. Scanning the result with ``isinstance`` therefore hits the most
    specific matching class first.

    Args:
        classes: Classes to order. Any iterable is accepted, including a
            one-shot iterator; it is consumed exactly once.

    Returns:
        The distinct input classes, each subclass placed before any of its
        superclasses that also appear in ``classes``.
    """
    # Materialize once: the relation below needs nested passes over the
    # input, and a second pass over an iterator would silently see nothing.
    unique = list(dict.fromkeys(classes))
    # Maps each class to its predecessors: a base class must come after
    # every one of its subclasses.
    graph: dict[type, set[type]] = {cls: set() for cls in unique}
    for sub in unique:
        for base in unique:
            if sub is not base and issubclass(sub, base):
                graph[base].add(sub)
    return list(graphlib.TopologicalSorter(graph).static_order())

59 

60 

class CustomTransformer(ABC):
    """Interface for pluggable per-type serializers used by ``Transformer``.

    A subclass supplies three things:

    * a unique ``name``, which is written as the ``"type"`` tag of its output;
    * a ``to_dict``/``from_dict`` pair;
    * the types it covers, via ``supported_direct_types`` (exact type match)
      and/or ``supported_subtypes`` (``isinstance`` match).
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier stored alongside objects this transformer encodes."""
        ...  # pragma: no cover

    @abstractmethod
    def to_dict(self, transformer: "Transformer", o: object) -> dict[str, Any]:
        """Encode ``o`` as a dict, using ``transformer`` for nested values."""
        ...  # pragma: no cover

    @abstractmethod
    def from_dict(self, transformer: "Transformer", d: dict[str, Any]) -> object:
        """Decode ``d`` back into an object, using ``transformer`` for nested values."""
        ...  # pragma: no cover

    @property
    def supported_direct_types(self) -> Iterable[type]:
        """Types matched exactly (by ``type(obj)``); none by default."""
        return []

    @property
    def supported_subtypes(self) -> Iterable[type]:
        """Base types whose instances (including subclasses) are handled; none by default."""
        return []

89 

90 

class Transformable(ABC):
    """Mixin for classes that know how to serialize themselves.

    Instances are encoded by calling ``to_dict`` on the instance. They are
    rebuilt by calling the ``from_dict`` classmethod on the class.
    """

    @abstractmethod
    def to_dict(self, transformer: "Transformer") -> dict[str, Any]:
        """Encode this instance as a dict, using ``transformer`` for nested values."""
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def from_dict(cls, transformer: "Transformer", d: dict[str, Any]) -> object:
        """Build an instance from ``d``, using ``transformer`` for nested values."""
        ...  # pragma: no cover

104 

105 

class Transformer:
    """Serialize objects to primitive trees and reconstruct them.

    ``to_dict`` turns an object into nested ``str``/``int``/``float``/``bool``/
    ``None``/``list``/``dict`` values; ``from_dict`` reverses the process.
    Values without a direct primitive form are written as dicts tagged with a
    ``"type"`` key:

    * tuples (``"tuple"``), and dicts that already contain a ``"type"`` key
      (``"dict"``), so that they are not mistaken for tagged values;
    * ``Transformable`` instances (``"transformable"``);
    * attrs instances (``"attrs"``) and dataclass instances (``"dataclass"``);
    * anything handled by a registered ``CustomTransformer`` (tagged with its
      ``name``).

    Encoding ``Transformable``/attrs/dataclass instances works for any class.
    Decoding requires the class to have been registered, because it is looked
    up by ``__name__``.
    """

    def __init__(self, *, strict: bool = True) -> None:
        """Create a transformer with no registrations.

        Args:
            strict: When true (the default), unencodable objects raise
                ``UntransformableTypeError`` and unknown tags raise
                ``UnrecognizedTypeError``. When false, ``to_dict`` emits
                ``None`` for objects no transformer handles, and ``from_dict``
                emits ``MissingObject`` for unknown tags.
        """
        self.strict = strict

        # Custom transformers keyed by the exact type they handle.
        self._direct_type_map: dict[type, CustomTransformer] = {}
        # Keys of _subtype_map, sorted so subclasses are tried before bases.
        self._subtype_order: list[type] = []
        # Custom transformers keyed by the base type they handle.
        self._subtype_map: dict[type, CustomTransformer] = {}
        # Custom transformers keyed by name (the "type" tag they emit).
        self._transformers: dict[str, CustomTransformer] = {}
        # Registered classes keyed by __name__, consulted when decoding.
        self._transformable_types: dict[str, type[Transformable]] = {}
        self._attrs_types: dict[str, type] = {}
        self._dataclass_types: dict[str, type] = {}

    def register(self, t: CustomTransformer | type[Transformable] | type) -> None:
        """Register a transformer, transformable type, or regular type.

        Dispatches to the matching ``register_*`` method. The candidates are
        checked in this order:

        1. a ``CustomTransformer`` instance;
        2. a ``Transformable`` subclass;
        3. an attrs class (when attrs is installed);
        4. a dataclass.

        Raises:
            ValueError: If ``t`` is none of the above.
        """
        if isinstance(t, CustomTransformer):
            self.register_transformer(t)
        elif isinstance(t, type) and issubclass(t, Transformable):
            self.register_transformable(t)
        elif HAS_ATTRS and isinstance(t, type) and attrs.has(t):
            self.register_attrs(t)
        elif isinstance(t, type) and dataclasses.is_dataclass(t):
            self.register_dataclass(t)
        else:
            msg = f"Unable to register {t}"
            raise ValueError(msg)

    def register_transformer(self, transformer: CustomTransformer) -> None:
        """Register a custom transformer for specific types.

        Args:
            transformer: Transformer to add. Its ``name`` and every type it
                lists must not already be registered. Duplicates are checked
                with ``assert``, so the check is skipped under ``python -O``.
        """
        # Read each property once so validation and registration see the
        # same types even if the property returns a one-shot iterable.
        direct_types = list(transformer.supported_direct_types)
        subtypes = list(transformer.supported_subtypes)

        assert transformer.name not in self._transformers  # noqa: S101
        for type_ in direct_types:
            assert type_ not in self._direct_type_map  # noqa: S101
        for type_ in subtypes:
            assert type_ not in self._subtype_map  # noqa: S101

        self._transformers[transformer.name] = transformer
        for type_ in direct_types:
            self._direct_type_map[type_] = transformer
        for type_ in subtypes:
            self._subtype_map[type_] = transformer
        if subtypes:
            # Re-sort so the most specific base type is matched first.
            self._subtype_order = order_classes(self._subtype_map.keys())

    def register_transformable(self, transformable_type: type[Transformable]) -> None:
        """Register a ``Transformable`` class for decoding, keyed by ``__name__``."""
        name = transformable_type.__name__
        assert name not in self._transformable_types  # noqa: S101
        self._transformable_types[name] = transformable_type

    def register_attrs(self, attrs_type: type) -> None:
        """Register an attrs class for decoding, keyed by ``__name__``."""
        name = attrs_type.__name__
        assert name not in self._attrs_types  # noqa: S101
        self._attrs_types[name] = attrs_type

    def register_dataclass(self, dataclass_type: type) -> None:
        """Register a dataclass for decoding, keyed by ``__name__``."""
        name = dataclass_type.__name__
        assert name not in self._dataclass_types  # noqa: S101
        self._dataclass_types[name] = dataclass_type

    def get_transformer_for_obj(self, obj: object) -> CustomTransformer | None:
        """Return the custom transformer for ``obj``, or ``None`` if there is none.

        An exact-type registration wins. Otherwise the registered base types
        are tried most-specific first with ``isinstance``.
        """
        transformer = self._direct_type_map.get(type(obj))
        if transformer is not None:
            return transformer
        for tp in self._subtype_order:
            if isinstance(obj, tp):
                return self._subtype_map[tp]
        return None

    def get_transformer_for_name(self, name: str) -> CustomTransformer | None:
        """Return the custom transformer registered under ``name``, or ``None``."""
        return self._transformers.get(name)

    def to_dict(self, o: object) -> Any:
        """Convert an object to a serializable dictionary representation.

        Args:
            o: Object to encode.

        Returns:
            * Primitives (``str``, ``int``, ``float``, ``bool``, ``None``) are
              returned unchanged.
            * Lists are encoded element-wise.
            * Dicts are encoded value-wise; their keys are left untouched.
            * Everything else becomes a ``"type"``-tagged dict.
            * In non-strict mode, objects no transformer handles become
              ``None``.

        Raises:
            UntransformableTypeError: In strict mode, if no encoding exists.
        """
        if isinstance(o, str) or o is None or o is True or o is False or isinstance(o, (int, float)):
            return o
        elif isinstance(o, tuple):
            return {KEY_TYPE: TYPENAME_TUPLE, KEY_VALUES: [self.to_dict(x) for x in o]}
        elif isinstance(o, list):
            return [self.to_dict(x) for x in o]
        elif isinstance(o, dict):
            return self._dict_to_dict(o)
        elif isinstance(o, Transformable):
            return {KEY_TYPE: TYPENAME_TRANSFORMABLE, KEY_CLASS: type(o).__name__, KEY_DATA: o.to_dict(self)}
        elif HAS_ATTRS and attrs.has(type(o)):
            return self._attrs_to_dict(o)
        elif dataclasses.is_dataclass(o) and not isinstance(o, type):
            # is_dataclass is also true for dataclass *classes*; only encode instances.
            return self._dataclass_to_dict(o)
        else:
            return self._to_dict_transformer(o)

    def _dict_to_dict(self, o: dict[Any, Any]) -> dict[str, Any]:
        """Encode a dict's values, wrapping it if it contains a ``"type"`` key."""
        d = {k: self.to_dict(v) for k, v in o.items()}
        if KEY_TYPE in o:
            # Wrap so from_dict does not mistake this plain dict for a tagged value.
            return {KEY_TYPE: TYPENAME_DICT, KEY_DATA: d}
        return d

    def _fields_to_tagged_dict(self, typename: str, o: object, names: Iterable[str]) -> dict[str, Any]:
        """Build ``{"type", "class"[, "data"]}`` from the named attributes of ``o``."""
        res: dict[str, Any] = {KEY_TYPE: typename, KEY_CLASS: type(o).__name__}
        data = {name: self.to_dict(getattr(o, name)) for name in names}
        if data:
            # "data" is omitted entirely for classes with no fields.
            res[KEY_DATA] = data
        return res

    def _attrs_to_dict(self, o: object) -> dict[str, Any]:
        """Convert an attrs object to serializable dictionary form."""
        names = (a.name for a in o.__attrs_attrs__)  # type: ignore[attr-defined]
        return self._fields_to_tagged_dict(TYPENAME_ATTRS, o, names)

    def _dataclass_to_dict(self, o: object) -> dict[str, Any]:
        """Convert a dataclass object to serializable dictionary form."""
        names = (f.name for f in dataclasses.fields(o))  # type: ignore[arg-type]
        return self._fields_to_tagged_dict(TYPENAME_DATACLASS, o, names)

    def _to_dict_transformer(self, o: object) -> dict[str, Any] | None:
        """Encode ``o`` with its custom transformer, tagging the result with its name.

        Raises:
            UntransformableTypeError: In strict mode, if no transformer applies.
        """
        transformer = self.get_transformer_for_obj(o)
        if transformer is None:
            if self.strict:
                msg = f"Could not transform object of type {type(o).__name__}"
                raise UntransformableTypeError(msg)
            return None
        d = transformer.to_dict(self, o)
        # Build a new dict instead of mutating the transformer's result. Any
        # "type" key the transformer emitted is overwritten by its name.
        return {**d, KEY_TYPE: transformer.name}

    def from_dict(self, d: Any) -> Any:
        """Convert a dictionary representation back to the original object.

        Args:
            d: A value previously produced by ``to_dict``.

        Returns:
            The decoded object. In non-strict mode, values with unknown tags
            decode to ``MissingObject`` instances.

        Raises:
            UnrecognizedTypeError: In strict mode, for unknown tags/classes.
            ValueError: If ``d`` is not a primitive, list or dict.
        """
        if isinstance(d, str) or d is None or d is True or d is False or isinstance(d, (int, float)):
            return d
        elif isinstance(d, list):
            return [self.from_dict(x) for x in d]
        elif isinstance(d, dict):
            type_ = d.get(KEY_TYPE)
            if type_ is None:
                return {k: self.from_dict(v) for k, v in d.items()}
            elif type_ == TYPENAME_TUPLE:
                return tuple(self.from_dict(x) for x in d[KEY_VALUES])
            elif type_ == TYPENAME_DICT:
                return {k: self.from_dict(v) for k, v in d[KEY_DATA].items()}
            elif type_ == TYPENAME_TRANSFORMABLE:
                return self._from_dict_transformable(d)
            elif type_ == TYPENAME_ATTRS:
                return self._from_attrs(d)
            elif type_ == TYPENAME_DATACLASS:
                return self._from_dataclass(d)
            else:
                return self._from_dict_transformer(type_, d)
        else:
            msg = "Unable to determine object type from dictionary"
            raise ValueError(msg)

    def _from_dict_transformable(self, d: dict[str, Any]) -> object:
        """Rebuild a registered ``Transformable`` via its ``from_dict`` classmethod."""
        classname = d[KEY_CLASS]
        cls = self._transformable_types.get(classname)
        if cls is None:
            if self.strict:
                msg = f"Unable to transform Transformable object of class {classname}"
                raise UnrecognizedTypeError(msg)
            return MissingObject()
        return cls.from_dict(self, d[KEY_DATA])

    def _from_registered_class(self, registry: dict[str, type], d: dict[str, Any], kind: str) -> object:
        """Instantiate ``registry[d["class"]]`` with the decoded ``d["data"]`` fields as kwargs."""
        classname = d[KEY_CLASS]
        cls = registry.get(classname)
        if cls is None:
            if self.strict:
                # Name the unknown class; the looked-up ``cls`` is None here.
                msg = f"Unable to create {kind} object of type {classname}"
                raise UnrecognizedTypeError(msg)
            return MissingObject()
        # The encoder omits "data" for classes without fields.
        kwargs = {key: self.from_dict(value) for key, value in d.get(KEY_DATA, {}).items()}
        return cls(**kwargs)

    def _from_attrs(self, d: dict[str, Any]) -> object:
        """Reconstruct an attrs object from dictionary form."""
        if not HAS_ATTRS:  # pragma: no cover
            if self.strict:
                msg = "attrs package not installed"
                raise UnrecognizedTypeError(msg)
            return MissingObject()
        return self._from_registered_class(self._attrs_types, d, "attrs")

    def _from_dataclass(self, d: dict[str, Any]) -> object:
        """Reconstruct a dataclass object from dictionary form."""
        return self._from_registered_class(self._dataclass_types, d, "dataclass")

    def _from_dict_transformer(self, type_: str, d: dict[str, Any]) -> object:
        """Decode ``d`` with the custom transformer named ``type_``."""
        transformer = self.get_transformer_for_name(type_)
        if transformer is None:
            if self.strict:
                msg = f"Unable to transform object of type {type_}"
                raise UnrecognizedTypeError(msg)
            return MissingObject()
        return transformer.from_dict(self, d)

334 

335 

class NdArrayTransformer(CustomTransformer):
    """Transformer for NumPy ndarray objects.

    An array is stored as its shape, its dtype string and a flat list of its
    elements. Each element is itself passed through the owning
    ``Transformer``. Only exact ``np.ndarray`` instances are matched, because
    registration is via ``supported_direct_types``.
    """

    @property
    def name(self) -> str:
        """Return transformer name."""
        return "ndarray"

    def to_dict(self, transformer: "Transformer", o: object) -> dict[str, Any]:
        """Convert numpy array to dictionary with shape, dtype, and data."""
        assert isinstance(o, np.ndarray)  # noqa: S101
        return {
            "shape": list(o.shape),
            # NOTE(review): dtype.str keeps byte order (e.g. "<f8"), but for
            # structured dtypes it appears to be only a void type ("|V..."),
            # so field layout may not round-trip -- confirm if needed.
            "dtype": o.dtype.str,
            # Elements in C order, converted to Python scalars by tolist().
            "data": transformer.to_dict(o.ravel().tolist()),  # type: ignore[arg-type]
        }

    def from_dict(self, transformer: "Transformer", d: dict[str, Any]) -> object:
        """Reconstruct numpy array from dictionary."""
        values = transformer.from_dict(d["data"])
        dtype = np.dtype(d["dtype"])
        if dtype.hasobject:
            # np.array() would treat list/tuple elements of an object array as
            # extra dimensions and break the reshape; place each element by hand.
            arr = np.empty(len(values), dtype=dtype)
            for i, value in enumerate(values):
                arr[i] = value
        else:
            arr = np.array(values, dtype=dtype)
        return arr.reshape(d["shape"])

    @property
    def supported_direct_types(self) -> Iterable[type]:
        """Return supported numpy array types."""
        return [np.ndarray]