diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 03387051b3b..8abe0088cd1 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -122,7 +122,9 @@ class FillValueCoder: """ @classmethod - def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any: + def encode( + cls, value: int | float | complex | str | bytes, dtype: np.dtype[Any] + ) -> Any: if dtype.kind in "S": # byte string, this implies that 'value' must also be `bytes` dtype. assert isinstance(value, bytes) @@ -132,16 +134,33 @@ def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any: return bool(value) elif dtype.kind in "iu": # todo: do we want to check for decimals? + assert isinstance(value, int | float) return int(value) elif dtype.kind in "f": + assert isinstance(value, int | float) return base64.standard_b64encode(struct.pack(" None: assert actual3 == expected3 +@requires_zarr +@pytest.mark.parametrize("dtype", [complex, np.complex64, np.complex128]) +def test_fill_value_coder_complex(dtype) -> None: + """Test that FillValueCoder round-trips complex fill values.""" + from xarray.backends.zarr import FillValueCoder + + for value in [dtype(1 + 2j), dtype(-3.5 + 4.5j), dtype(complex("nan+nanj"))]: + encoded = FillValueCoder.encode(value, np.dtype(dtype)) + decoded = FillValueCoder.decode(encoded, np.dtype(dtype)) + np.testing.assert_equal(np.array(decoded, dtype=dtype), np.array(value)) + + @requires_zarr def test_extract_zarr_variable_encoding() -> None: var = xr.Variable("x", [1, 2])