Coverage for src/loman/serialization/transformer.py: 90%

236 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 05:36 +0000

1"""Object serialization and transformation framework.""" 

2 

3import graphlib 

4from abc import ABC, abstractmethod 

5from collections.abc import Iterable 

6 

7import numpy as np 

8 

9try: 

10 import attrs 

11 

12 HAS_ATTRS = True 

13except ImportError: 

14 HAS_ATTRS = False 

15 

16import dataclasses 

17 

# Keys that appear inside the dictionaries produced by Transformer.to_dict.
KEY_TYPE = "type"
KEY_CLASS = "class"
KEY_VALUES = "values"
KEY_DATA = "data"

# Values stored under KEY_TYPE for the encodings that Transformer handles itself.
TYPENAME_DICT = "dict"
TYPENAME_TUPLE = "tuple"
TYPENAME_TRANSFORMABLE = "transformable"
TYPENAME_ATTRS = "attrs"
TYPENAME_DATACLASS = "dataclass"

28 

29 

class UntransformableTypeError(Exception):
    """Signal that an object's type cannot be transformed for serialization.

    Raised by ``Transformer`` in strict mode when no registered transformer
    matches the object being serialized.
    """

34 

35 

class UnrecognizedTypeError(Exception):
    """Signal that a type named in serialized data is not recognized.

    Raised by ``Transformer`` in strict mode when serialized data refers to a
    class or transformer name that has not been registered.
    """

40 

41 

class MissingObject:
    """Placeholder standing in for a missing or unset value.

    ``Transformer`` returns an instance of this class in non-strict mode when
    serialized data cannot be turned back into an object.
    """

    def __repr__(self) -> str:
        """Render the placeholder as the fixed text ``Missing``."""
        return "Missing"

48 

49 

def order_classes(classes):
    """Order classes so that every subclass comes before its base classes.

    Args:
        classes: Iterable of classes. It may be a one-shot iterator; it is
            copied into a list before use.

    Returns:
        A list containing each distinct class once, topologically sorted so
        that a class never appears after one of its own base classes that is
        also in the input.
    """
    # Materialize once: the input is walked several times below, and a one-shot
    # iterator would otherwise be exhausted after building the node set,
    # silently dropping every inheritance edge.
    classes = list(classes)

    # TopologicalSorter takes a mapping of node -> predecessors. Recording the
    # subclasses of each class as its predecessors makes them come out first.
    graph = {cls: set() for cls in classes}
    for sub in classes:
        for base in classes:
            if issubclass(sub, base) and sub != base:
                graph[base].add(sub)

    return list(graphlib.TopologicalSorter(graph).static_order())

59 

60 

class CustomTransformer(ABC):
    """Interface for pluggable transformers that handle specific types.

    Subclasses must provide ``name``, ``to_dict`` and ``from_dict``. They may
    override ``supported_direct_types`` and ``supported_subtypes`` to declare
    which objects they handle; both default to an empty list.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the unique name that identifies this transformer."""

    @abstractmethod
    def to_dict(self, transformer: "Transformer", o: object) -> dict:
        """Build the dictionary representation of ``o``.

        Args:
            transformer: The ``Transformer`` driving the conversion.
            o: The object to convert.
        """

    @abstractmethod
    def from_dict(self, transformer: "Transformer", d: dict) -> object:
        """Rebuild an object from its dictionary representation ``d``.

        Args:
            transformer: The ``Transformer`` driving the conversion.
            d: The dictionary to convert back into an object.
        """

    @property
    def supported_direct_types(self) -> Iterable[type]:
        """Return the types this transformer handles directly (none by default)."""
        return []

    @property
    def supported_subtypes(self) -> Iterable[type]:
        """Return the base types whose subtypes this transformer handles (none by default)."""
        return []

89 

90 

class Transformable(ABC):
    """Interface for objects that know how to serialize themselves.

    Implementations provide an instance method ``to_dict`` and a class method
    ``from_dict`` that rebuilds an object from its dictionary representation.
    """

    @abstractmethod
    def to_dict(self, transformer: "Transformer") -> dict:
        """Return the dictionary representation of this object.

        Args:
            transformer: The ``Transformer`` driving the conversion.
        """

    @classmethod
    @abstractmethod
    def from_dict(cls, transformer: "Transformer", d: dict) -> object:
        """Rebuild an object from its dictionary representation ``d``.

        Args:
            transformer: The ``Transformer`` driving the conversion.
            d: The dictionary to convert back into an object.
        """

104 

105 

class Transformer:
    """Convert objects to and from a structure of plain Python containers.

    ``to_dict`` passes ``None``, ``str``, ``bool``, ``int`` and ``float``
    values through unchanged, converts lists element by element, and encodes
    tuples, ``Transformable`` objects, attrs instances, dataclass instances
    and objects handled by a registered ``CustomTransformer`` as dictionaries
    tagged under ``KEY_TYPE``. Plain dictionaries stay plain unless they
    already contain ``KEY_TYPE``, in which case they are wrapped so the tag
    cannot be misread. ``from_dict`` reverses the process using the types and
    transformers registered on this instance.

    In strict mode (the default) unsupported values raise
    ``UntransformableTypeError`` while serializing and
    ``UnrecognizedTypeError`` while deserializing. With ``strict=False`` they
    become ``None`` and ``MissingObject()`` respectively.
    """

    def __init__(self, *, strict: bool = True):
        """Create an empty transformer.

        Args:
            strict: When True, raise on values that cannot be handled instead
                of substituting ``None`` / ``MissingObject()``.
        """
        self.strict = strict

        # CustomTransformer lookup tables.
        self._direct_type_map = {}  # exact type -> transformer
        self._subtype_order = []  # base types, subclasses first
        self._subtype_map = {}  # base type -> transformer
        self._transformers = {}  # transformer name -> transformer

        # Registered classes, keyed by class __name__.
        self._transformable_types = {}
        self._attrs_types = {}
        self._dataclass_types = {}

    def register(self, t: CustomTransformer | type[Transformable] | type):
        """Register a transformer instance, Transformable type, attrs class or dataclass.

        The checks run in that order and the first match wins.

        Raises:
            ValueError: If ``t`` is a class that matches none of the above.
        """
        if isinstance(t, CustomTransformer):
            self.register_transformer(t)
        elif issubclass(t, Transformable):
            self.register_transformable(t)
        elif HAS_ATTRS and attrs.has(t):
            self.register_attrs(t)
        elif dataclasses.is_dataclass(t):
            self.register_dataclass(t)
        else:
            raise ValueError(f"Unable to register {t}")

    def register_transformer(self, transformer: CustomTransformer):
        """Register a custom transformer under its name and supported types.

        The transformer's name, direct types and subtypes must not already be
        registered; this is checked with assertions before anything is stored.
        """
        direct_types = list(transformer.supported_direct_types)
        subtypes = list(transformer.supported_subtypes)

        # Validate everything up front so a failed check leaves the maps untouched.
        assert transformer.name not in self._transformers
        for type_ in direct_types:
            assert type_ not in self._direct_type_map
        for type_ in subtypes:
            assert type_ not in self._subtype_map

        self._transformers[transformer.name] = transformer

        for type_ in direct_types:
            self._direct_type_map[type_] = transformer

        for type_ in subtypes:
            self._subtype_map[type_] = transformer
        if subtypes:
            # Re-sort so isinstance checks try the most derived base types first.
            self._subtype_order = order_classes(self._subtype_map.keys())

    def register_transformable(self, transformable_type: type[Transformable]):
        """Register a Transformable class under its ``__name__``."""
        name = transformable_type.__name__
        assert name not in self._transformable_types
        self._transformable_types[name] = transformable_type

    def register_attrs(self, attrs_type: type):
        """Register an attrs-decorated class under its ``__name__``."""
        name = attrs_type.__name__
        assert name not in self._attrs_types
        self._attrs_types[name] = attrs_type

    def register_dataclass(self, dataclass_type: type):
        """Register a dataclass under its ``__name__``."""
        name = dataclass_type.__name__
        assert name not in self._dataclass_types
        self._dataclass_types[name] = dataclass_type

    def get_transformer_for_obj(self, obj) -> CustomTransformer | None:
        """Return the transformer registered for ``obj``, or None.

        An exact match on ``type(obj)`` takes priority; otherwise the first
        registered base type that ``obj`` is an instance of is used.
        """
        transformer = self._direct_type_map.get(type(obj))
        if transformer is not None:
            return transformer
        for tp in self._subtype_order:
            if isinstance(obj, tp):
                return self._subtype_map[tp]
        return None

    def get_transformer_for_name(self, name) -> CustomTransformer | None:
        """Return the transformer registered under ``name``, or None."""
        return self._transformers.get(name)

    @staticmethod
    def _is_primitive(value) -> bool:
        """Return True for values stored as-is: None, str, bool, int and float."""
        return value is None or isinstance(value, (str, bool, int, float))

    def to_dict(self, o):
        """Convert ``o`` to its serializable representation.

        Raises:
            UntransformableTypeError: In strict mode, if no handler exists for
                ``o`` or for a value nested inside it.
        """
        if self._is_primitive(o):
            return o
        if isinstance(o, tuple):
            return {KEY_TYPE: TYPENAME_TUPLE, KEY_VALUES: [self.to_dict(x) for x in o]}
        if isinstance(o, list):
            return [self.to_dict(x) for x in o]
        if isinstance(o, dict):
            return self._dict_to_dict(o)
        if isinstance(o, Transformable):
            return {KEY_TYPE: TYPENAME_TRANSFORMABLE, KEY_CLASS: type(o).__name__, KEY_DATA: o.to_dict(self)}
        if HAS_ATTRS and attrs.has(o):
            return self._attrs_to_dict(o)
        if dataclasses.is_dataclass(o):
            return self._dataclass_to_dict(o)
        return self._to_dict_transformer(o)

    def _dict_to_dict(self, o):
        """Convert the values of a dict, wrapping it if it contains KEY_TYPE."""
        d = {k: self.to_dict(v) for k, v in o.items()}
        if KEY_TYPE in o:
            # Wrap so from_dict does not mistake the user's own key for a type tag.
            return {KEY_TYPE: TYPENAME_DICT, KEY_DATA: d}
        return d

    def _attrs_to_dict(self, o):
        """Encode an attrs instance; KEY_DATA is omitted when it has no attributes."""
        data = {a.name: self.to_dict(getattr(o, a.name)) for a in o.__attrs_attrs__}
        res = {KEY_TYPE: TYPENAME_ATTRS, KEY_CLASS: type(o).__name__}
        if data:
            res[KEY_DATA] = data
        return res

    def _dataclass_to_dict(self, o):
        """Encode a dataclass instance; KEY_DATA is omitted when it has no fields."""
        data = {f.name: self.to_dict(getattr(o, f.name)) for f in dataclasses.fields(o)}
        res = {KEY_TYPE: TYPENAME_DATACLASS, KEY_CLASS: type(o).__name__}
        if data:
            res[KEY_DATA] = data
        return res

    def _to_dict_transformer(self, o):
        """Encode ``o`` with its registered CustomTransformer.

        Returns None in non-strict mode when no transformer matches.

        Raises:
            UntransformableTypeError: In strict mode when no transformer matches.
        """
        transformer = self.get_transformer_for_obj(o)
        if transformer is None:
            if self.strict:
                raise UntransformableTypeError(f"Could not transform object of type {type(o).__name__}")
            return None
        d = transformer.to_dict(self, o)
        # The tag is written into the transformer's own dict, replacing any
        # KEY_TYPE entry the transformer may have set.
        d[KEY_TYPE] = transformer.name
        return d

    def from_dict(self, d):
        """Convert a serialized representation back to the original object.

        Raises:
            UnrecognizedTypeError: If ``d`` is not a primitive, list or dict
                (regardless of strict mode), or in strict mode if it names a
                class or transformer that is not registered.
        """
        if self._is_primitive(d):
            return d
        if isinstance(d, list):
            return [self.from_dict(x) for x in d]
        if not isinstance(d, dict):
            raise UnrecognizedTypeError(f"Unable to reconstruct value of type {type(d).__name__}")

        type_ = d.get(KEY_TYPE)
        if type_ is None:
            return {k: self.from_dict(v) for k, v in d.items()}
        if type_ == TYPENAME_TUPLE:
            return tuple(self.from_dict(x) for x in d[KEY_VALUES])
        if type_ == TYPENAME_DICT:
            return {k: self.from_dict(v) for k, v in d[KEY_DATA].items()}
        if type_ == TYPENAME_TRANSFORMABLE:
            return self._from_dict_transformable(d)
        if type_ == TYPENAME_ATTRS:
            return self._from_attrs(d)
        if type_ == TYPENAME_DATACLASS:
            return self._from_dataclass(d)
        return self._from_dict_transformer(type_, d)

    def _from_dict_transformable(self, d):
        """Rebuild a Transformable via its registered class's ``from_dict``."""
        classname = d[KEY_CLASS]
        cls = self._transformable_types.get(classname)
        if cls is None:
            if self.strict:
                raise UnrecognizedTypeError(f"Unable to transform Transformable object of class {classname}")
            return MissingObject()
        return cls.from_dict(self, d[KEY_DATA])

    def _from_registered_class(self, d, registry, kind):
        """Rebuild an instance of a class looked up by name in ``registry``.

        Args:
            d: Serialized dictionary holding KEY_CLASS and, optionally, KEY_DATA.
            registry: Mapping of class name to class.
            kind: Label used in the error message ("attrs" or "dataclass").
        """
        classname = d[KEY_CLASS]
        cls = registry.get(classname)
        if cls is None:
            if self.strict:
                # Report the requested name; cls itself is always None here.
                raise UnrecognizedTypeError(f"Unable to create {kind} object of type {classname}")
            return MissingObject()
        kwargs = {key: self.from_dict(value) for key, value in d.get(KEY_DATA, {}).items()}
        return cls(**kwargs)

    def _from_attrs(self, d):
        """Rebuild an attrs instance by passing the decoded data as keyword arguments."""
        if not HAS_ATTRS:
            if self.strict:
                raise UnrecognizedTypeError("attrs package not installed")
            return MissingObject()
        return self._from_registered_class(d, self._attrs_types, TYPENAME_ATTRS)

    def _from_dataclass(self, d):
        """Rebuild a dataclass instance by passing the decoded data as keyword arguments."""
        return self._from_registered_class(d, self._dataclass_types, TYPENAME_DATACLASS)

    def _from_dict_transformer(self, type_, d):
        """Rebuild an object with the CustomTransformer registered under ``type_``."""
        transformer = self.get_transformer_for_name(type_)
        if transformer is None:
            if self.strict:
                raise UnrecognizedTypeError(f"Unable to transform object of type {type_}")
            return MissingObject()
        return transformer.from_dict(self, d)

317 

318 

class NdArrayTransformer(CustomTransformer):
    """Custom transformer that serializes ``numpy.ndarray`` objects."""

    @property
    def name(self):
        """Return the name this transformer is registered under."""
        return "ndarray"

    def to_dict(self, transformer: "Transformer", o: object) -> dict:
        """Encode an array as its shape, dtype string and flattened values."""
        assert isinstance(o, np.ndarray)
        shape = list(o.shape)
        dtype = o.dtype.str
        # The flattened values go through the transformer like any other list.
        flat_values = transformer.to_dict(o.ravel().tolist())
        return {"shape": shape, "dtype": dtype, "data": flat_values}

    def from_dict(self, transformer: "Transformer", d: dict) -> object:
        """Rebuild the array from its flattened values, dtype and shape."""
        flat_values = transformer.from_dict(d["data"])
        flat_array = np.array(flat_values, d["dtype"])
        return flat_array.reshape(d["shape"])

    @property
    def supported_direct_types(self):
        """Return the exact types handled: ``numpy.ndarray`` only."""
        return [np.ndarray]