Skip to content

Commit

Permalink
Enable custom dtype arrays as return values (#1)
Browse files Browse the repository at this point in the history
* Enable custom dtype arrays as return values

* Version bump
  • Loading branch information
WardBrian authored Sep 21, 2023
1 parent eae5228 commit 643d0fd
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 22 deletions.
2 changes: 1 addition & 1 deletion stanio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"stan_variables",
]

__version__ = "0.3.1"
__version__ = "0.4.0"
44 changes: 39 additions & 5 deletions stanio/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,23 @@ class Variable:
# list of nested parameters
contents: List["Variable"]

def dtype(self, top=True):
if self.type == VariableType.TUPLE:
elts = [
(str(i + 1), param.dtype(top=False))
for i, param in enumerate(self.contents)
]
dtype = np.dtype(elts)
elif self.type == VariableType.SCALAR:
dtype = np.float64
elif self.type == VariableType.COMPLEX:
dtype = np.complex128

if top:
return dtype
else:
return np.dtype((dtype, self.dimensions))

def columns(self) -> Iterable[int]:
return range(self.start_idx, self.end_idx)

Expand Down Expand Up @@ -81,7 +98,7 @@ def _extract_helper(self, src: np.ndarray, offset: int = 0):
out[i, idx] = tuple(elt[i] for elt in elts)
return out.reshape(-1, *self.dimensions, order="F")

def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
def extract_reshape(self, src: np.ndarray, object=True) -> npt.NDArray[Any]:
"""
Given an array where the final dimension is the flattened output of a
Stan model, (e.g. one row of a Stan CSV file), extract the variable
Expand All @@ -98,6 +115,10 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
Indicies besides the final dimension are preserved
in the output.
object : bool
If True, the output of tuple types will be an object array,
otherwise it will use custom dtypes to represent tuples.
Returns
-------
npt.NDArray[Any]
Expand All @@ -106,10 +127,14 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
otherwise it will have a dtype of either float64 or complex128.
"""
out = self._extract_helper(src)
if not object:
out = out.astype(self.dtype())
if src.ndim > 1:
return out.reshape(*src.shape[:-1], *self.dimensions, order="F")
out = out.reshape(*src.shape[:-1], *self.dimensions, order="F")
else:
return out.squeeze(axis=0)
out = out.squeeze(axis=0)

return out


def _munge_first_tuple(tup: str) -> str:
Expand Down Expand Up @@ -194,7 +219,10 @@ def parse_header(header: str) -> Dict[str, Variable]:


def stan_variables(
parameters: Dict[str, Variable], source: npt.NDArray[np.float64]
parameters: Dict[str, Variable],
source: npt.NDArray[np.float64],
*,
object: bool = True,
) -> Dict[str, npt.NDArray[Any]]:
"""
Given a dictionary of :class:`Variable` objects and a source array,
Expand All @@ -208,11 +236,17 @@ def stan_variables(
like that returned by :func:`parse_header()`.
source : npt.NDArray[np.float64]
The array to extract from.
object : bool
If True, the output of tuple types will be an object array,
otherwise it will use custom dtypes to represent tuples.
Returns
-------
Dict[str, npt.NDArray[Any]]
A dictionary mapping the base name of each variable to the extracted
and reshaped data.
"""
return {param.name: param.extract_reshape(source) for param in parameters.values()}
return {
param.name: param.extract_reshape(source, object=object)
for param in parameters.values()
}
82 changes: 66 additions & 16 deletions test/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@


# see file data/rectangles/output.stan
@pytest.fixture(scope="module")
def rect_data():
@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"])
def rect_data(request):
files = [DATA / "rectangles" / f"output_{i}.csv" for i in range(1, 5)]
header, data = read_csv(files)
params = parse_header(header)
yield stan_variables(params, data)
yield stan_variables(params, data, object=request.param)


def test_basic_shapes(rect_data):
Expand Down Expand Up @@ -91,43 +91,93 @@ def test_basic_values(rect_data):


# see file data/tuples/output.stan
@pytest.fixture(scope="module")
def tuple_data():
@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"])
def tuple_data(request):
files = [DATA / "tuples" / f"output_{i}.csv" for i in range(1, 5)]
header, data = read_csv(files)
params = parse_header(header)
yield stan_variables(params, data)
yield stan_variables(params, data, object=request.param)


def test_tuple_shapes(tuple_data):
assert isinstance(tuple_data["pair"][0, 0], tuple)
assert len(tuple_data["pair"][0, 0]) == 2

assert isinstance(tuple_data["nested"][0, 0], tuple)
assert len(tuple_data["nested"][0, 0]) == 2
assert isinstance(tuple_data["nested"][0, 0][1], tuple)
assert len(tuple_data["nested"][0, 0][1]) == 2

assert tuple_data["arr_pair"].shape == (4, 1000, 2)
assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple)

assert tuple_data["arr_very_nested"].shape == (4, 1000, 3)

assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2)

assert tuple_data["ultimate"].shape == (4, 1000, 2, 3)
assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,)
assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,)
assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5)


def check_tuple_shapes_objects(tuple_data):
assert isinstance(tuple_data["pair"][0, 0], tuple)

assert isinstance(tuple_data["nested"][0, 0], tuple)
assert isinstance(tuple_data["nested"][0, 0][1], tuple)

assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple)

assert isinstance(tuple_data["arr_very_nested"][0, 0, 0], tuple)
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0], tuple)
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0][1], tuple)

assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2)
assert isinstance(tuple_data["arr_2d_pair"][0, 0, 0, 0], tuple)

assert tuple_data["ultimate"].shape == (4, 1000, 2, 3)
assert isinstance(tuple_data["ultimate"][0, 0, 0, 0], tuple)
assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,)
assert isinstance(tuple_data["ultimate"][0, 0, 0, 0][0][0], tuple)
assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,)
assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5)


def check_tuple_shapes_custom_dtypes(tuple_data):
for value in tuple_data.values():
assert not value.dtype.hasobject

pair_dtype = np.dtype([("1", "f8"), ("2", "f8")])
assert tuple_data["pair"].dtype == pair_dtype

nested_dtype = np.dtype([("1", "f8"), ("2", [("1", "f8"), ("2", "c16")])])
assert tuple_data["nested"].dtype == nested_dtype
assert tuple_data["nested"][0, 0][1].dtype == nested_dtype[1]

assert tuple_data["arr_pair"].dtype == pair_dtype

very_nested_dtype = np.dtype(
[
("1", nested_dtype),
("2", "f8"),
]
)
assert tuple_data["arr_very_nested"].dtype == very_nested_dtype
assert tuple_data["arr_very_nested"][0, 0, 0][0].dtype == nested_dtype
assert tuple_data["arr_very_nested"][0, 0, 0][0][1].dtype == nested_dtype[1]

ultimate_dtype = np.dtype(
[
("1", ([("1", "f8"), ("2", "(2,)f8")], (2,))),
("2", "(4,5)f8"),
]
)
assert tuple_data["ultimate"].dtype == ultimate_dtype


def test_tuple_dtypes(tuple_data):
if isinstance(tuple_data["pair"][0, 0], tuple):
check_tuple_shapes_objects(tuple_data)
else:
check_tuple_shapes_custom_dtypes(tuple_data)


def assert_tuple_equal(t1, t2):
if hasattr(t1, "dtype") and t1.dtype.kind == "V":
t1 = t1.tolist()

assert len(t1) == len(t2)
for x, y in zip(t1, t2):
if isinstance(x, tuple):
Expand All @@ -140,7 +190,7 @@ def check_tuples(tuple_data, chain, draw):
base = tuple_data["base"][chain, draw]
base_i = tuple_data["base_i"][chain, draw]
pair_exp = (base, 2 * base)
np.testing.assert_almost_equal(tuple_data["pair"][chain, draw], pair_exp)
assert_tuple_equal(tuple_data["pair"][chain, draw], pair_exp)
nested_exp = (base * 3, (base_i, 4j * base))
assert_tuple_equal(tuple_data["nested"][chain, draw], nested_exp)

Expand Down

0 comments on commit 643d0fd

Please sign in to comment.