diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index de534f46..4da2cfa6 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1286,9 +1286,20 @@ def _render_images( ), ) - # True if user gave n cmaps for n channels - got_multiple_cmaps = isinstance(render_params.cmap_params, list) - if got_multiple_cmaps: + # A list of cmap_params can be either user-supplied (one cmap per channel) or + # synthesized upstream to carry per-channel norms when the user only set `norm` + # (or `palette + norm=list`). The synthesized form must not trigger the + # blending warning or conflict with `palette`. + if isinstance(render_params.cmap_params, list): + got_multiple_cmaps = True + user_supplied_multi_cmaps = any(not cp.cmap_is_default for cp in render_params.cmap_params) + if len(render_params.cmap_params) != n_channels: + raise ValueError("If 'cmap' is provided, its length must match the number of channels.") + else: + got_multiple_cmaps = False + user_supplied_multi_cmaps = False + + if user_supplied_multi_cmaps: logger.warning( "You're blending multiple cmaps. " "If the plot doesn't look like you expect, it might be because your " @@ -1297,10 +1308,6 @@ def _render_images( "Consider using 'palette' instead." ) - # not using got_multiple_cmaps here because of ruff :( - if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels: - raise ValueError("If 'cmap' is provided, its length must match the number of channels.") - # Detect RGB(A) images by channel names — skip when user overrides with palette/cmap is_rgb, has_alpha = _is_rgb_image(channels) has_explicit_cmap = ( @@ -1527,8 +1534,9 @@ def _render_images( zorder=render_params.zorder, ) - # 2C) Image has n channels and palette info - elif palette is not None and not got_multiple_cmaps: + # 2C) palette set; also covers `palette + norm=list` since synthesized + # default cmaps don't conflict and per-channel norms are already in `layers`. + elif palette is not None and not user_supplied_multi_cmaps: if len(palette) != n_channels: raise ValueError("If 'palette' is provided, its length must match the number of channels.") @@ -1567,10 +1575,6 @@ def _render_images( zorder=render_params.zorder, ) - # 2D) Image has n channels, no palette but cmap info - elif palette is not None and got_multiple_cmaps: - raise ValueError("If 'palette' is provided, 'cmap' must be None.") - # Collect channel legend entries (single point for all multi-channel paths) if render_params.channels_as_legend and channel_legend_entries is not None: if legend_colors is not None: diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 63efffd2..6971f31a 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2989,15 +2989,22 @@ def _validate_image_render_params( ) element_params[el]["palette"] = palette + expected_len = len(channel) if channel is not None else len(spatial_element_ch) + cmap = param_dict["cmap"] if cmap is not None: - expected_len = len(channel) if channel is not None else len(spatial_element_ch) if len(cmap) == 1: cmap = cmap * expected_len if len(cmap) != expected_len: - cmap = None + raise ValueError( + f"Length of 'cmap' list ({len(cmap)}) must match the number of channels ({expected_len})." + ) element_params[el]["cmap"] = cmap - element_params[el]["norm"] = param_dict["norm"] + + norm = param_dict["norm"] + if isinstance(norm, list) and len(norm) > 1 and len(norm) != expected_len: + raise ValueError(f"Length of 'norm' list ({len(norm)}) must match the number of channels ({expected_len}).") + element_params[el]["norm"] = norm scale = param_dict["scale"] if scale and isinstance(param_dict["sdata"][el], DataTree): if scale not in list(param_dict["sdata"][el].keys()) and scale != "full": diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 4a6bbe9c..abac517a 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -496,6 +496,32 @@ def test_norm_list_without_explicit_cmap(): plt.close(fig) +# Regression tests for #622: misleading 'cmap' errors when norm/palette interact. +def test_norm_list_wrong_length_raises_with_norm_message(): + # Without an explicit cmap the user only set norm; the error must mention norm, + # not cmap, and report both lengths. + sdata = _make_multichannel_sdata() + with pytest.raises(ValueError, match=r"'norm' list \(2\).*channels \(3\)"): + sdata.pl.render_images("img", norm=[Normalize(0, 1), Normalize(0, 2)]).pl.show() + + +def test_cmap_wrong_length_with_norm_list_no_longer_silent(): + # Previously the wrong-length cmap was silently nulled when norm was a list of + # the correct length, hiding the bug. It must now raise just like the no-norm path. + sdata = _make_multichannel_sdata() + with pytest.raises(ValueError, match=r"'cmap' list \(2\).*channels \(3\)"): + sdata.pl.render_images("img", cmap=["Reds", "Greens"], norm=[Normalize(0, 1)] * 3).pl.show() + + +def test_palette_with_norm_list_renders(): + # palette + per-channel norms used to fail with "If 'palette' is provided, 'cmap' + # must be None." even though the user never passed cmap. Should now render. + sdata = _make_multichannel_sdata() + fig, ax = plt.subplots() + sdata.pl.render_images("img", palette=["red", "green", "blue"], norm=[Normalize(0, 1)] * 3).pl.show(ax=ax) + plt.close(fig) + + def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData): """Cmap length should be validated against selected channels, not the full image channel count.""" # blobs_image has 3 channels; select 1 with a matching length-1 cmap