From bc588a760a804f783c4242d4435863a43a5f3f9f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 23 Oct 2024 13:30:49 -0600 Subject: [PATCH] Fix JSON encoding of complex fill values (#2432) * Fix JSON encoding of complex fill values We were not replacing NaNs and Infs with the string versions. * Fix decoding of complex fill values * try excluding `math.inf` * Check complex numbers explicitly * Update src/zarr/core/metadata/v3.py --- src/zarr/core/metadata/v3.py | 26 ++++++++++++++++++++++---- tests/test_array.py | 23 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 6b6f28dd96..7a38e9fd70 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -43,6 +43,13 @@ DEFAULT_DTYPE = "float64" +# Keep in sync with _replace_special_floats +SPECIAL_FLOATS_ENCODED = { + "Infinity": np.inf, + "-Infinity": -np.inf, + "NaN": np.nan, +} + def parse_zarr_format(data: object) -> Literal[3]: if data == 3: @@ -149,7 +156,7 @@ def default(self, o: object) -> Any: if isinstance(out, complex): # python complex types are not JSON serializable, so we use the # serialization defined in the zarr v3 spec - return [out.real, out.imag] + return _replace_special_floats([out.real, out.imag]) elif np.isnan(out): return "NaN" elif np.isinf(out): @@ -447,8 +454,11 @@ def parse_fill_value( if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): if data_type in (DataType.complex64, DataType.complex128): if len(fill_value) == 2: + decoded_fill_value = tuple( + SPECIAL_FLOATS_ENCODED.get(value, value) for value in fill_value + ) # complex datatypes serialize to JSON arrays with two elements - return np_dtype.type(complex(*fill_value)) + return np_dtype.type(complex(*decoded_fill_value)) else: msg = ( f"Got an invalid fill value for complex data type {data_type.value}." @@ -475,12 +485,20 @@ def parse_fill_value( pass elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value): pass - elif np_dtype.kind in "cf": + elif np_dtype.kind == "f": # float comparison is not exact, especially when dtype None: + store = MemoryStore({}, mode="w") + Array.create(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value) + content = await store.get("zarr.json", prototype=default_buffer_prototype()) + assert content is not None + actual = json.loads(content.to_bytes()) + assert actual["fill_value"] == expected