Coverage for src / loman / util.py: 100%

68 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 21:24 +0000

1"""Utility functions and classes for loman computation graphs.""" 

2 

3import itertools 

4import types 

5from collections.abc import Callable, Generator, Iterable 

6from typing import Any, TypeVar, overload 

7 

8import numpy as np 

9import pandas as pd 

10 

# Generic type parameters used in the helper signatures below:
# T is the type of an input element, R the type a mapped function returns.
T, R = TypeVar("T"), TypeVar("R")

13 

14 

# Overloads are matched top to bottom. The catch-all ``xs: T`` signature must
# come last: a bare TypeVar also accepts lists and generators, so placing it
# earlier would shadow the more specific signatures after it.
@overload
def apply1(f: Callable[..., R], xs: list[T], *args: Any, **kwds: Any) -> list[R]: ...


@overload
def apply1(
    f: Callable[..., R], xs: Generator[T, None, None], *args: Any, **kwds: Any
) -> Generator[R, None, None]: ...


@overload
def apply1(f: Callable[..., R], xs: T, *args: Any, **kwds: Any) -> R: ...


def apply1(
    f: Callable[..., R], xs: T | list[T] | Generator[T, None, None], *args: Any, **kwds: Any
) -> R | list[R] | Generator[R, None, None]:
    """Apply function f to xs, handling generators, lists, and single values.

    Args:
        f: Function to apply. Each call receives one element (or ``xs`` itself)
            as its first argument, followed by ``*args`` and ``**kwds``.
        xs: A generator, a list, or any other value. Only generators and
            ``list`` instances are mapped element-wise; every other value,
            including tuples and sets, is passed to ``f`` as a single value.
        *args: Extra positional arguments forwarded to every call of ``f``.
        **kwds: Extra keyword arguments forwarded to every call of ``f``.

    Returns:
        A generator when ``xs`` is a generator (``f`` is not called until the
        result is iterated), a new list when ``xs`` is a list, otherwise the
        single result ``f(xs, *args, **kwds)``.
    """
    if isinstance(xs, types.GeneratorType):
        # Stay lazy: wrap the input generator rather than consuming it here.
        return (f(x, *args, **kwds) for x in xs)
    if isinstance(xs, list):
        return [f(x, *args, **kwds) for x in xs]
    return f(xs, *args, **kwds)

36 

37 

def as_iterable(xs: T | Iterable[T]) -> Iterable[T]:
    """Return ``xs`` itself when it is a list, set or generator; otherwise wrap it.

    Any other value, including tuples, strings and dicts, is treated as a
    single item and returned inside a one-element tuple.
    """
    passes_through = isinstance(xs, (list, set, types.GeneratorType))
    if not passes_through:
        return (xs,)  # type: ignore[return-value]
    return xs  # type: ignore[return-value]

43 

44 

def apply_n(f: Callable[..., Any], *xs: Any, **kwds: Any) -> None:
    """Call ``f`` once for every combination drawn from the positional arguments.

    Each argument in ``xs`` is first normalised with ``as_iterable``: a list,
    set or generator contributes each of its items, while any other value
    contributes itself. ``kwds`` are passed unchanged to every call, and the
    return values of ``f`` are discarded.
    """
    axes = tuple(map(as_iterable, xs))
    for combination in itertools.product(*axes):
        f(*combination, **kwds)

49 

50 

51class AttributeView: 

52 """Provides attribute-style access to dynamic collections.""" 

53 

54 def __init__( 

55 self, 

56 get_attribute_list: Callable[[], Iterable[str]], 

57 get_attribute: Callable[[str], Any], 

58 get_item: Callable[[Any], Any] | None = None, 

59 ) -> None: 

60 """Initialize with functions to get attribute list and individual attributes. 

61 

62 Args: 

63 get_attribute_list: Function that returns list of available attributes 

64 get_attribute: Function that takes an attribute name and returns its value 

65 get_item: Optional function for item access, defaults to get_attribute 

66 """ 

67 self.get_attribute_list = get_attribute_list 

68 self.get_attribute = get_attribute 

69 self.get_item: Callable[[Any], Any] = get_item if get_item is not None else get_attribute 

70 

71 def __dir__(self) -> list[str]: 

72 """Return list of available attributes.""" 

73 return list(self.get_attribute_list()) 

74 

75 def __getattr__(self, attr: str) -> Any: 

76 """Get attribute by name, raising AttributeError if not found.""" 

77 try: 

78 return self.get_attribute(attr) 

79 except KeyError as e: 

80 raise AttributeError(attr) from e 

81 

82 def __getitem__(self, key: Any) -> Any: 

83 """Get item by key.""" 

84 return self.get_item(key) 

85 

86 def __getstate__(self) -> dict[str, Any]: 

87 """Prepare object for serialization.""" 

88 return { 

89 "get_attribute_list": self.get_attribute_list, 

90 "get_attribute": self.get_attribute, 

91 "get_item": self.get_item, 

92 } 

93 

94 def __setstate__(self, state: dict[str, Any]) -> None: 

95 """Restore object from serialized state.""" 

96 self.get_attribute_list = state["get_attribute_list"] 

97 self.get_attribute = state["get_attribute"] 

98 self.get_item = state["get_item"] 

99 if self.get_item is None: 

100 self.get_item = self.get_attribute 

101 

102 @staticmethod 

103 def from_dict(d: dict[Any, Any], use_apply1: bool = True) -> "AttributeView": 

104 """Create an AttributeView from a dictionary.""" 

105 if use_apply1: 

106 

107 def get_attribute(xs: Any) -> Any: 

108 """Get attribute value from dictionary with apply1 support.""" 

109 return apply1(d.get, xs) 

110 else: 

111 get_attribute = d.get 

112 return AttributeView(d.keys, get_attribute) 

113 

114 

# pandas container types that value_eq compares with ``.equals`` instead of ``==``.
pandas_types = (pd.Series, pd.DataFrame)


def value_eq(a: Any, b: Any) -> bool:
    """Compare two values for equality, handling pandas and numpy objects safely.

    - Identical objects (``a is b``) are always equal.
    - Uses ``.equals`` for pandas Series/DataFrame.
    - If either side is a numpy array, returns a single boolean using
      ``np.array_equal`` with ``equal_nan=True`` (NaNs treated as equal). For
      dtypes where NaN handling is unsupported (e.g. string or object arrays),
      it falls back to plain element-wise equality instead of reporting the
      arrays as different.
    - Falls back to ``==``, reducing an ndarray result with ``np.all`` and
      coercing anything else with ``bool``.

    Never raises: an exception during comparison yields ``False``.
    """
    if a is b:
        return True

    # pandas objects: use robust equality
    if isinstance(a, pandas_types):
        return bool(a.equals(b))
    if isinstance(b, pandas_types):  # pragma: no cover
        return bool(b.equals(a))

    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        try:
            return bool(np.array_equal(a, b, equal_nan=True))
        except TypeError:
            # equal_nan=True needs NaN detection, which numpy does not support
            # for string/object dtypes and raises TypeError. Retry below without
            # it, so equal non-numeric arrays are not reported as different.
            pass
        except Exception:
            return False
        try:
            return bool(np.array_equal(a, b))
        except Exception:
            return False

    # Default comparison; ensure a single boolean
    try:
        result = a == b
        # e.g. a numpy scalar compared with a list yields an ndarray; reduce it.
        if isinstance(result, np.ndarray):
            return bool(np.all(result))
        return bool(result)
    except Exception:
        # A raising __eq__ or an ambiguous truth value counts as "not equal".
        return False