diff --git a/linopy/common.py b/linopy/common.py index dce26a7a..6ca79a9e 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -282,20 +282,29 @@ def _coords_to_dict( Normalize coords to a dict mapping dim names to coordinate values. Entries must be ``pd.Index`` (named or not) or unnamed sequences - (``list`` / ``tuple`` / ``range`` / ``np.ndarray``). Other types — - notably ``xarray.DataArray`` — raise ``TypeError`` rather than - being silently dropped: callers should convert via - ``variable.indexes[]`` (or ``pd.Index(...)``) first. + (``list`` / ``tuple`` / ``range`` / ``np.ndarray``). A + ``pd.MultiIndex`` must have ``.name`` set — xarray requires a + single dimension name for the flattened index. Other types — + notably ``xarray.DataArray`` — raise ``TypeError``; callers should + convert via ``variable.indexes[]`` first. """ if isinstance(coords, Mapping): return dict(coords) result: dict[str, Any] = {} for c in coords: - if isinstance(c, pd.Index): + if isinstance(c, pd.MultiIndex): + if not c.name: + raise TypeError( + "MultiIndex coords entries must have .name set so " + "xarray can use it as the dimension name. Set it via " + "`idx.name = 'my_dim'` before passing to coords." + ) + result[c.name] = c + elif isinstance(c, pd.Index): if c.name: result[c.name] = c elif isinstance(c, list | tuple | range | np.ndarray): - pass # unnamed sequence contributes no named dim + pass else: raise TypeError( f"coords entries must be pd.Index or an unnamed sequence " @@ -310,21 +319,22 @@ def _named_pandas_to_dataarray(arr: pd.Series | pd.DataFrame) -> DataArray | Non """ Convert a pandas Series or DataFrame with fully named axes to a DataArray. - DataFrame columns (and column-MultiIndex levels) are stacked into the row - MultiIndex so each axis name becomes its own dimension. Returns ``None`` - if any axis (or MultiIndex level) is unnamed, so the caller can fall back - to ``as_dataarray``. + Returns ``None`` if any axis (or MultiIndex level) is unnamed or + non-string, so the caller can fall back to ``as_dataarray``. """ names = list(arr.index.names) if isinstance(arr, pd.DataFrame): names += list(arr.columns.names) - # pd.Index.names entries can be any hashable (tuples, ints, ...). Only - # strings map cleanly to xarray dim names; everything else falls through. if any(not isinstance(n, str) for n in names): return None if isinstance(arr, pd.DataFrame): - arr = arr.stack(list(range(arr.columns.nlevels)), future_stack=True) + if isinstance(arr.index, pd.MultiIndex) or isinstance( + arr.columns, pd.MultiIndex + ): + arr = arr.stack(list(range(arr.columns.nlevels)), future_stack=True) + return arr.to_xarray() + return DataArray(arr) return arr.to_xarray() @@ -392,7 +402,14 @@ def as_dataarray_in_coords(arr: Any, coords: Any, **kwargs: Any) -> DataArray: for dim, coord_values in expected.items(): if dim not in arr.dims: continue - if isinstance(arr.indexes.get(dim), pd.MultiIndex): + expected_is_mi = isinstance(coord_values, pd.MultiIndex) + actual_is_mi = isinstance(arr.indexes.get(dim), pd.MultiIndex) + if expected_is_mi or actual_is_mi: + if expected_is_mi and actual_is_mi: + if not arr.indexes[dim].equals(coord_values): + raise ValueError( + f"MultiIndex for dimension '{dim}' does not match coords" + ) continue expected_idx = ( coord_values @@ -401,7 +418,6 @@ def as_dataarray_in_coords(arr: Any, coords: Any, **kwargs: Any) -> DataArray: ) actual_idx = arr.coords[dim].to_index() if not actual_idx.equals(expected_idx): - # Same values, different order → reindex to match expected order if len(actual_idx) == len(expected_idx) and set(actual_idx) == set( expected_idx ): diff --git a/test/test_variable.py b/test/test_variable.py index 1a49abd6..6a123ff0 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -488,15 +488,6 @@ def test_unnamed_coords_short_circuit(self, model: "Model") -> None: var = model.add_variables(upper=bound, coords=[pd.Index([0, 1, 2])], name="x") assert (var.data.upper == [1, 2, 3]).all() - def test_dataarray_bound_with_multiindex_coord(self, model: "Model") -> None: - """A DataArray bound carrying a MultiIndex coord skips the value check.""" - midx = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=("l1", "l2")) - midx.name = "multi" - bound = DataArray([1, 2, 3, 4], dims=["multi"], coords={"multi": midx}) - var = model.add_variables(upper=bound, coords=[midx], name="x") - assert var.shape == (4,) - assert (var.data.upper == [1, 2, 3, 4]).all() - # -- Broadcasting missing dims ----------------------------------------- @pytest.mark.parametrize( @@ -596,14 +587,6 @@ def test_dataarray_broadcast_missing_dim_order( # -- Special coord formats --------------------------------------------- - def test_multiindex_coords(self, model: "Model") -> None: - idx = pd.MultiIndex.from_product( - [[1, 2], ["a", "b"]], names=("level1", "level2") - ) - idx.name = "multi" - var = model.add_variables(lower=0, upper=1, coords=[idx], name="x") - assert var.shape == (4,) - def test_xarray_coordinates_object(self, model: "Model") -> None: time = pd.RangeIndex(3, name="time") base = model.add_variables(lower=0, coords=[time], name="base") @@ -811,3 +794,54 @@ def test_string_coords_mismatch(self, model: "Model") -> None: coords={"region": ["north", "south", "east"]}, name="x", ) + + +class TestAddVariablesMultiIndexCoords: + """MultiIndex-specific coord handling in add_variables.""" + + @pytest.fixture + def model(self) -> "Model": + return Model() + + @pytest.fixture + def midx(self) -> pd.MultiIndex: + mi = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=("l1", "l2")) + mi.name = "multi" + return mi + + def test_scalar_bounds(self, model: "Model", midx: pd.MultiIndex) -> None: + var = model.add_variables(lower=0, upper=1, coords=[midx], name="x") + assert var.shape == (4,) + assert var.dims == ("multi",) + + def test_dataarray_bound(self, model: "Model", midx: pd.MultiIndex) -> None: + bound = DataArray([1, 2, 3, 4], dims=["multi"], coords={"multi": midx}) + var = model.add_variables(upper=bound, coords=[midx], name="x") + assert var.shape == (4,) + assert (var.data.upper == [1, 2, 3, 4]).all() + + def test_dataarray_bound_broadcast( + self, model: "Model", midx: pd.MultiIndex + ) -> None: + time = pd.Index([10, 20, 30], name="time") + bound = DataArray([1, 2, 3, 4], dims=["multi"], coords={"multi": midx}) + var = model.add_variables( + lower=-bound, upper=bound, coords=[midx, time], name="x" + ) + assert var.dims == ("multi", "time") + assert var.shape == (4, 3) + assert (var.data.upper.sel(time=10) == [1, 2, 3, 4]).all() + + def test_without_name_raises(self, model: "Model") -> None: + midx = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=("l1", "l2")) + with pytest.raises(TypeError, match="MultiIndex.*must have .name set"): + model.add_variables(lower=0, upper=1, coords=[midx], name="x") + + def test_mismatched_multiindex_raises( + self, model: "Model", midx: pd.MultiIndex + ) -> None: + other = pd.MultiIndex.from_product([[0, 1], ["x", "y"]], names=("l1", "l2")) + other.name = "multi" + bound = DataArray([1, 2, 3, 4], dims=["multi"], coords={"multi": other}) + with pytest.raises(ValueError, match="MultiIndex.*does not match"): + model.add_variables(upper=bound, coords=[midx], name="x")