Coverage for src / loman / util.py: 61%
80 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 23:34 +0000
1"""Utility functions and classes for loman computation graphs."""
3import itertools
4import types
6import numpy as np
7import pandas as pd
def apply1(f, xs, *args, **kwds):
    """Apply ``f`` element-wise to ``xs``, keeping the kind of the input.

    - a generator input gives back a lazy generator of results;
    - a list input gives back a new list of results;
    - anything else (tuples and sets included) is handed to ``f`` as one value.

    Extra positional and keyword arguments are forwarded to every call of ``f``.
    """
    def call(item):
        return f(item, *args, **kwds)

    if isinstance(xs, list):
        return [call(item) for item in xs]
    if isinstance(xs, types.GeneratorType):
        # Stay lazy: results are produced only as the caller iterates.
        return (call(item) for item in xs)
    return call(xs)
def as_iterable(xs):
    """Return ``xs`` itself when it is a generator, list or set.

    Any other value (tuples and strings included) is wrapped in a
    one-element tuple so the caller can always iterate the result.
    """
    passthrough_types = (types.GeneratorType, list, set)
    return xs if isinstance(xs, passthrough_types) else (xs,)
def apply_n(f, *xs, **kwds):
    """Call ``f`` once for every combination of items drawn from ``xs``.

    Each positional argument is normalised with ``as_iterable`` (a plain value
    acts as a one-element sequence). ``f`` receives one item from each, in
    ``itertools.product`` order, plus the forwarded keyword arguments.
    Nothing is returned; ``f`` is called only for its side effects.
    """
    combos = itertools.product(*map(as_iterable, xs))
    for combo in combos:
        f(*combo, **kwds)
class AttributeView:
    """Provides attribute-style access to dynamic collections.

    Attribute lookups that normal lookup cannot satisfy are delegated to
    ``get_attribute``. Item lookups (``view[key]``) go to ``get_item``.
    ``dir(view)`` reports whatever ``get_attribute_list`` returns.
    """

    def __init__(self, get_attribute_list, get_attribute, get_item=None):
        """Initialize with functions to get attribute list and individual attributes.

        Args:
            get_attribute_list: Function that returns list of available attributes
            get_attribute: Function that takes an attribute name and returns its value
            get_item: Optional function for item access, defaults to get_attribute
        """
        self.get_attribute_list = get_attribute_list
        self.get_attribute = get_attribute
        self.get_item = get_attribute if get_item is None else get_item

    def __dir__(self):
        """Return list of available attributes."""
        return self.get_attribute_list()

    def __getattr__(self, attr):
        """Get attribute by name, raising AttributeError if not found.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: if ``get_attribute`` raises KeyError for ``attr``,
                or if the instance's own state has not been set yet.
        """
        if attr in ("get_attribute_list", "get_attribute", "get_item"):
            # Reaching here means the instance was created without __init__ /
            # __setstate__ (e.g. via __new__ during copy or unpickling).
            # Delegating would re-enter this method and recurse forever.
            raise AttributeError(attr)
        try:
            return self.get_attribute(attr)
        except KeyError:
            # Translate so hasattr()/getattr(obj, name, default) work, and drop
            # the KeyError context to keep tracebacks readable.
            raise AttributeError(attr) from None

    def __getitem__(self, key):
        """Get item by key."""
        return self.get_item(key)

    def __getstate__(self):
        """Prepare object for serialization."""
        return {
            "get_attribute_list": self.get_attribute_list,
            "get_attribute": self.get_attribute,
            "get_item": self.get_item,
        }

    def __setstate__(self, state):
        """Restore object from serialized state."""
        self.get_attribute_list = state["get_attribute_list"]
        self.get_attribute = state["get_attribute"]
        self.get_item = state["get_item"]
        if self.get_item is None:
            self.get_item = self.get_attribute

    @staticmethod
    def from_dict(d, use_apply1=True):
        """Create an AttributeView from a dictionary.

        Args:
            d: Dictionary whose keys are listed by ``dir()`` and looked up via ``d.get``.
            use_apply1: If True, lookups go through ``apply1``. A list (or
                generator) of keys then yields a list (or generator) of values.

        Lookups use ``d.get``, so a missing key gives ``None``; it does not
        raise AttributeError.
        """
        if use_apply1:

            def get_attribute(xs):
                return apply1(d.get, xs)
        else:
            get_attribute = d.get
        return AttributeView(d.keys, get_attribute)
pandas_types = (pd.Series, pd.DataFrame)


def value_eq(a, b):
    """Compare two values for equality, handling pandas and numpy objects safely.

    - Uses .equals for pandas Series/DataFrame
    - For numpy arrays, returns a single boolean using np.array_equal. NaNs are
      treated as equal for dtypes that support NaN. str/bytes/object arrays,
      where ``isnan`` is undefined, are compared element-wise with ``==``.
    - Falls back to == and coerces to bool when possible

    Any exception raised while comparing is treated as "not equal".

    Args:
        a: First value.
        b: Second value.

    Returns:
        bool: True if ``a`` and ``b`` are considered equal.
    """
    if a is b:
        return True

    # pandas objects: use robust equality (a single bool, not element-wise)
    if isinstance(a, pandas_types):
        return a.equals(b)
    if isinstance(b, pandas_types):
        return b.equals(a)

    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        try:
            return bool(np.array_equal(a, b, equal_nan=True))
        except TypeError:
            # equal_nan relies on np.isnan, which has no loop for str/bytes/
            # object dtypes (and very old numpy lacks the keyword). Fall
            # through to a plain element-wise comparison.
            pass
        except Exception:
            return False
        try:
            return bool(np.array_equal(a, b))
        except Exception:
            return False

    try:
        # Default comparison; ensure a single boolean
        result = a == b
        # If result is an array-like truth value, reduce safely
        if isinstance(result, np.ndarray):
            return bool(np.all(result))
        return bool(result)
    except Exception:
        # Comparison raised, or the result has no truth value: be conservative.
        return False