Coverage for src / loman / util.py: 61%

80 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-02 23:34 +0000

1"""Utility functions and classes for loman computation graphs.""" 

2 

3import itertools 

4import types 

5 

6import numpy as np 

7import pandas as pd 

8 

9 

def apply1(f, xs, *args, **kwds):
    """Call ``f`` on ``xs``, mapping element-wise over lists and generators.

    Args:
        f: Callable invoked as ``f(x, *args, **kwds)``.
        xs: A list, a generator, or any other object. Only ``list`` and
            generator instances are mapped over; everything else (tuples,
            sets, strings, scalars, ...) is passed to ``f`` as one value.
        *args: Extra positional arguments forwarded to every call of ``f``.
        **kwds: Extra keyword arguments forwarded to every call of ``f``.

    Returns:
        A list of results for a list input, a lazy generator of results for
        a generator input, otherwise the single result ``f(xs, ...)``.
    """

    def call(item):
        return f(item, *args, **kwds)

    if isinstance(xs, list):
        return [call(item) for item in xs]
    if isinstance(xs, types.GeneratorType):
        # Stay lazy: nothing is consumed until the caller iterates.
        return (call(item) for item in xs)
    return call(xs)

17 

18 

def as_iterable(xs):
    """Return ``xs`` unchanged if it is a collection, else wrap it in a 1-tuple.

    Only generators, lists and sets count as collections here. Any other
    value, including tuples and strings, is treated as a single item and
    returned as ``(xs,)``.
    """
    passthrough_types = (types.GeneratorType, list, set)
    return xs if isinstance(xs, passthrough_types) else (xs,)

24 

25 

def apply_n(f, *xs, **kwds):
    """Call ``f`` once for every combination drawn from the arguments ``xs``.

    Each positional argument is first normalised with ``as_iterable``, so a
    non-collection value contributes a single choice. ``f`` is then invoked
    as ``f(*combination, **kwds)`` for each element of the cartesian product.
    The results of ``f`` are discarded and nothing is returned.
    """
    choices = [as_iterable(x) for x in xs]
    for combination in itertools.product(*choices):
        f(*combination, **kwds)

30 

31 

class AttributeView:
    """Provides attribute-style access to dynamic collections.

    The view wraps three caller-supplied callables: one that lists the
    available names (used by ``dir()``), one that resolves ``view.name``
    and one that resolves ``view[key]``.
    """

    def __init__(self, get_attribute_list, get_attribute, get_item=None):
        """Initialize with functions to get attribute list and individual attributes.

        Args:
            get_attribute_list: Function that returns list of available attributes
            get_attribute: Function that takes an attribute name and returns its value
            get_item: Optional function for item access, defaults to get_attribute
        """
        self.get_attribute_list = get_attribute_list
        self.get_attribute = get_attribute
        self.get_item = get_attribute if get_item is None else get_item

    def __dir__(self):
        """Return the available attribute names as reported by ``get_attribute_list``."""
        return self.get_attribute_list()

    def __getattr__(self, attr):
        """Resolve ``attr`` through ``get_attribute``.

        Raises:
            AttributeError: If ``get_attribute`` raises ``KeyError`` for
                ``attr``, or if the view itself has not been initialised.
        """
        # __getattr__ only runs after normal lookup has failed, so being asked
        # for one of the wrapped callables means the instance has not been
        # initialised (e.g. it was created with __new__ and no state was
        # restored). Fail fast here; otherwise the self.get_attribute access
        # below would re-enter __getattr__ until RecursionError.
        if attr in ("get_attribute_list", "get_attribute", "get_item"):
            raise AttributeError(attr)
        try:
            return self.get_attribute(attr)
        except KeyError as err:
            raise AttributeError(attr) from err

    def __getitem__(self, key):
        """Resolve ``key`` through ``get_item``."""
        return self.get_item(key)

    def __getstate__(self):
        """Return the three wrapped callables as a dict for serialization."""
        return {
            "get_attribute_list": self.get_attribute_list,
            "get_attribute": self.get_attribute,
            "get_item": self.get_item,
        }

    def __setstate__(self, state):
        """Restore the wrapped callables from a dict produced by ``__getstate__``."""
        self.get_attribute_list = state["get_attribute_list"]
        self.get_attribute = state["get_attribute"]
        self.get_item = state["get_item"]
        # Mirror __init__: a missing item accessor falls back to get_attribute.
        if self.get_item is None:
            self.get_item = self.get_attribute

    @staticmethod
    def from_dict(d, use_apply1=True):
        """Create an AttributeView backed by the dictionary ``d``.

        Lookups use ``d.get``, so a name or key that is not in ``d`` yields
        ``None`` rather than raising.

        Args:
            d: Dictionary whose keys become the available attributes.
            use_apply1: If true (the default), lookups are routed through
                ``apply1``, so a list or generator of keys returns a list or
                generator of values. If false, ``d.get`` is used directly.

        Returns:
            An AttributeView whose attribute and item access both read from ``d``.
        """
        if use_apply1:

            def get_attribute(xs):
                return apply1(d.get, xs)
        else:
            get_attribute = d.get
        return AttributeView(d.keys, get_attribute)

90 

91 

pandas_types = (pd.Series, pd.DataFrame)


def value_eq(a, b):
    """Compare two values for equality, handling pandas and numpy objects safely.

    - Identical objects (``a is b``) are equal.
    - If either value is a pandas Series/DataFrame, its ``.equals`` method decides.
    - If either value is a numpy array, ``np.array_equal`` decides. NaNs in
      matching positions count as equal; for dtypes where NaN checks are not
      defined (e.g. strings or objects) a plain element-wise comparison is used.
    - Anything else falls back to ``==`` coerced to a single ``bool``.

    Args:
        a: First value.
        b: Second value.

    Returns:
        True if the values are considered equal, otherwise False. A
        comparison that raises is reported as False rather than propagated.
    """
    if a is b:
        return True

    # pandas objects: use robust equality
    if isinstance(a, pandas_types):
        return a.equals(b)
    if isinstance(b, pandas_types):
        return b.equals(a)

    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        try:
            return bool(np.array_equal(a, b, equal_nan=True))
        except TypeError:
            # equal_nan relies on np.isnan, which rejects non-numeric dtypes
            # such as str or object. Those cannot be NaN-aligned, so compare
            # them element-wise without NaN handling instead of reporting
            # equal arrays as different.
            try:
                return bool(np.array_equal(a, b))
            except Exception:
                return False
        except Exception:
            return False

    try:
        # Default comparison; ensure a single boolean
        result = a == b
        # If result is an array-like truth value, reduce safely
        if isinstance(result, np.ndarray):
            return bool(np.all(result))
        return bool(result)
    except Exception:
        return False
138 return False