Skip to content

Commit

Permalink
Fix JSON encoding of complex fill values (zarr-developers#2432)
Browse files Browse the repository at this point in the history
* Fix JSON encoding of complex fill values

We were not replacing NaNs and Infs with the string versions.

* Fix decoding of complex fill values

* try excluding `math.inf`

* Check complex numbers explicitly

* Update src/zarr/core/metadata/v3.py
  • Loading branch information
dcherian authored Oct 23, 2024
1 parent 6ce0526 commit bc588a7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@

DEFAULT_DTYPE = "float64"

# Keep in sync with _replace_special_floats
SPECIAL_FLOATS_ENCODED = {
"Infinity": np.inf,
"-Infinity": -np.inf,
"NaN": np.nan,
}


def parse_zarr_format(data: object) -> Literal[3]:
if data == 3:
Expand Down Expand Up @@ -149,7 +156,7 @@ def default(self, o: object) -> Any:
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return _replace_special_floats([out.real, out.imag])
elif np.isnan(out):
return "NaN"
elif np.isinf(out):
Expand Down Expand Up @@ -447,8 +454,11 @@ def parse_fill_value(
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
if data_type in (DataType.complex64, DataType.complex128):
if len(fill_value) == 2:
decoded_fill_value = tuple(
SPECIAL_FLOATS_ENCODED.get(value, value) for value in fill_value
)
# complex datatypes serialize to JSON arrays with two elements
return np_dtype.type(complex(*fill_value))
return np_dtype.type(complex(*decoded_fill_value))
else:
msg = (
f"Got an invalid fill value for complex data type {data_type.value}."
Expand All @@ -475,12 +485,20 @@ def parse_fill_value(
pass
elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value):
pass
elif np_dtype.kind in "cf":
elif np_dtype.kind == "f":
# float comparison is not exact, especially when dtype <float64
# so we us np.isclose for this comparison.
# so we use np.isclose for this comparison.
# this also allows us to compare nan fill_values
if not np.isclose(fill_value, casted_value, equal_nan=True):
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
elif np_dtype.kind == "c":
# confusingly np.isclose(np.inf, np.inf + 0j) is False on numpy<2, so compare real and imag parts
# explicitly.
if not (
np.isclose(np.real(fill_value), np.real(casted_value), equal_nan=True)
and np.isclose(np.imag(fill_value), np.imag(casted_value), equal_nan=True)
):
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
else:
if fill_value != casted_value:
raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}")
Expand Down
23 changes: 23 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import math
import pickle
from itertools import accumulate
from typing import Any, Literal
Expand All @@ -9,6 +11,7 @@
from zarr import Array, AsyncArray, Group
from zarr.codecs import BytesCodec, VLenBytesCodec
from zarr.core.array import chunks_initialized
from zarr.core.buffer import default_buffer_prototype
from zarr.core.buffer.cpu import NDBuffer
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
from zarr.core.group import AsyncGroup
Expand Down Expand Up @@ -624,3 +627,23 @@ def test_array_create_order(
assert vals.flags.f_contiguous
else:
raise AssertionError


@pytest.mark.parametrize(
("fill_value", "expected"),
[
(np.nan * 1j, ["NaN", "NaN"]),
(np.nan, ["NaN", 0.0]),
(np.inf, ["Infinity", 0.0]),
(np.inf * 1j, ["NaN", "Infinity"]),
(-np.inf, ["-Infinity", 0.0]),
(math.inf, ["Infinity", 0.0]),
],
)
async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None:
store = MemoryStore({}, mode="w")
Array.create(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value)
content = await store.get("zarr.json", prototype=default_buffer_prototype())
assert content is not None
actual = json.loads(content.to_bytes())
assert actual["fill_value"] == expected

0 comments on commit bc588a7

Please sign in to comment.