From 643d0fd3490cba07cea88827ab3b3275142eb905 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 21 Sep 2023 15:01:09 -0400 Subject: [PATCH] Enable custom dtype arrays as return values (#1) * Enable custom dtype arrays as return values * Version bump --- stanio/__init__.py | 2 +- stanio/reshape.py | 44 +++++++++++++++++++++--- test/test_reshape.py | 82 +++++++++++++++++++++++++++++++++++--------- 3 files changed, 106 insertions(+), 22 deletions(-) diff --git a/stanio/__init__.py b/stanio/__init__.py index de13f5b..3eec581 100644 --- a/stanio/__init__.py +++ b/stanio/__init__.py @@ -11,4 +11,4 @@ "stan_variables", ] -__version__ = "0.3.1" +__version__ = "0.4.0" diff --git a/stanio/reshape.py b/stanio/reshape.py index cc7dc63..06fecf9 100644 --- a/stanio/reshape.py +++ b/stanio/reshape.py @@ -47,6 +47,23 @@ class Variable: # list of nested parameters contents: List["Variable"] + def dtype(self, top=True): + if self.type == VariableType.TUPLE: + elts = [ + (str(i + 1), param.dtype(top=False)) + for i, param in enumerate(self.contents) + ] + dtype = np.dtype(elts) + elif self.type == VariableType.SCALAR: + dtype = np.float64 + elif self.type == VariableType.COMPLEX: + dtype = np.complex128 + + if top: + return dtype + else: + return np.dtype((dtype, self.dimensions)) + def columns(self) -> Iterable[int]: return range(self.start_idx, self.end_idx) @@ -81,7 +98,7 @@ def _extract_helper(self, src: np.ndarray, offset: int = 0): out[i, idx] = tuple(elt[i] for elt in elts) return out.reshape(-1, *self.dimensions, order="F") - def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]: + def extract_reshape(self, src: np.ndarray, object=True) -> npt.NDArray[Any]: """ Given an array where the final dimension is the flattened output of a Stan model, (e.g. one row of a Stan CSV file), extract the variable @@ -98,6 +115,10 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]: Indicies besides the final dimension are preserved in the output. + object : bool + If True, the output of tuple types will be an object array, + otherwise it will use custom dtypes to represent tuples. + Returns ------- npt.NDArray[Any] @@ -106,10 +127,14 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]: otherwise it will have a dtype of either float64 or complex128. """ out = self._extract_helper(src) + if not object: + out = out.astype(self.dtype()) if src.ndim > 1: - return out.reshape(*src.shape[:-1], *self.dimensions, order="F") + out = out.reshape(*src.shape[:-1], *self.dimensions, order="F") else: - return out.squeeze(axis=0) + out = out.squeeze(axis=0) + + return out def _munge_first_tuple(tup: str) -> str: @@ -194,7 +219,10 @@ def parse_header(header: str) -> Dict[str, Variable]: def stan_variables( - parameters: Dict[str, Variable], source: npt.NDArray[np.float64] + parameters: Dict[str, Variable], + source: npt.NDArray[np.float64], + *, + object: bool = True, ) -> Dict[str, npt.NDArray[Any]]: """ Given a dictionary of :class:`Variable` objects and a source array, @@ -208,6 +236,9 @@ def stan_variables( like that returned by :func:`parse_header()`. source : npt.NDArray[np.float64] The array to extract from. + object : bool + If True, the output of tuple types will be an object array, + otherwise it will use custom dtypes to represent tuples. Returns ------- @@ -215,4 +246,7 @@ def stan_variables( A dictionary mapping the base name of each variable to the extracted and reshaped data. """ - return {param.name: param.extract_reshape(source) for param in parameters.values()} + return { + param.name: param.extract_reshape(source, object=object) + for param in parameters.values() + } diff --git a/test/test_reshape.py b/test/test_reshape.py index 8f857ad..4152fc3 100644 --- a/test/test_reshape.py +++ b/test/test_reshape.py @@ -12,12 +12,12 @@ # see file data/rectangles/output.stan -@pytest.fixture(scope="module") -def rect_data(): +@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"]) +def rect_data(request): files = [DATA / "rectangles" / f"output_{i}.csv" for i in range(1, 5)] header, data = read_csv(files) params = parse_header(header) - yield stan_variables(params, data) + yield stan_variables(params, data, object=request.param) def test_basic_shapes(rect_data): @@ -91,43 +91,93 @@ def test_basic_values(rect_data): # see file data/tuples/output.stan -@pytest.fixture(scope="module") -def tuple_data(): +@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"]) +def tuple_data(request): files = [DATA / "tuples" / f"output_{i}.csv" for i in range(1, 5)] header, data = read_csv(files) params = parse_header(header) - yield stan_variables(params, data) + yield stan_variables(params, data, object=request.param) def test_tuple_shapes(tuple_data): - assert isinstance(tuple_data["pair"][0, 0], tuple) assert len(tuple_data["pair"][0, 0]) == 2 - assert isinstance(tuple_data["nested"][0, 0], tuple) assert len(tuple_data["nested"][0, 0]) == 2 - assert isinstance(tuple_data["nested"][0, 0][1], tuple) assert len(tuple_data["nested"][0, 0][1]) == 2 assert tuple_data["arr_pair"].shape == (4, 1000, 2) - assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple) assert tuple_data["arr_very_nested"].shape == (4, 1000, 3) + + assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2) + + assert tuple_data["ultimate"].shape == (4, 1000, 2, 3) + assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,) + assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,) + assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5) + + +def check_tuple_shapes_objects(tuple_data): + assert isinstance(tuple_data["pair"][0, 0], tuple) + + assert isinstance(tuple_data["nested"][0, 0], tuple) + assert isinstance(tuple_data["nested"][0, 0][1], tuple) + + assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple) + assert isinstance(tuple_data["arr_very_nested"][0, 0, 0], tuple) assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0], tuple) assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0][1], tuple) - assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2) assert isinstance(tuple_data["arr_2d_pair"][0, 0, 0, 0], tuple) - assert tuple_data["ultimate"].shape == (4, 1000, 2, 3) assert isinstance(tuple_data["ultimate"][0, 0, 0, 0], tuple) - assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,) assert isinstance(tuple_data["ultimate"][0, 0, 0, 0][0][0], tuple) - assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,) - assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5) + + +def check_tuple_shapes_custom_dtypes(tuple_data): + for value in tuple_data.values(): + assert not value.dtype.hasobject + + pair_dtype = np.dtype([("1", "f8"), ("2", "f8")]) + assert tuple_data["pair"].dtype == pair_dtype + + nested_dtype = np.dtype([("1", "f8"), ("2", [("1", "f8"), ("2", "c16")])]) + assert tuple_data["nested"].dtype == nested_dtype + assert tuple_data["nested"][0, 0][1].dtype == nested_dtype[1] + + assert tuple_data["arr_pair"].dtype == pair_dtype + + very_nested_dtype = np.dtype( + [ + ("1", nested_dtype), + ("2", "f8"), + ] + ) + assert tuple_data["arr_very_nested"].dtype == very_nested_dtype + assert tuple_data["arr_very_nested"][0, 0, 0][0].dtype == nested_dtype + assert tuple_data["arr_very_nested"][0, 0, 0][0][1].dtype == nested_dtype[1] + + ultimate_dtype = np.dtype( + [ + ("1", ([("1", "f8"), ("2", "(2,)f8")], (2,))), + ("2", "(4,5)f8"), + ] + ) + assert tuple_data["ultimate"].dtype == ultimate_dtype + + +def test_tuple_dtypes(tuple_data): + if isinstance(tuple_data["pair"][0, 0], tuple): + check_tuple_shapes_objects(tuple_data) + else: + check_tuple_shapes_custom_dtypes(tuple_data) def assert_tuple_equal(t1, t2): + if hasattr(t1, "dtype") and t1.dtype.kind == "V": + t1 = t1.tolist() + assert len(t1) == len(t2) for x, y in zip(t1, t2): if isinstance(x, tuple): @@ -140,7 +190,7 @@ def check_tuples(tuple_data, chain, draw): base = tuple_data["base"][chain, draw] base_i = tuple_data["base_i"][chain, draw] pair_exp = (base, 2 * base) - np.testing.assert_almost_equal(tuple_data["pair"][chain, draw], pair_exp) + assert_tuple_equal(tuple_data["pair"][chain, draw], pair_exp) nested_exp = (base * 3, (base_i, 4j * base)) assert_tuple_equal(tuple_data["nested"][chain, draw], nested_exp)