Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,9 +1286,20 @@ def _render_images(
),
)

# True if user gave n cmaps for n channels
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
if got_multiple_cmaps:
# A list of cmap_params can be either user-supplied (one cmap per channel) or
# synthesized upstream to carry per-channel norms when the user only set `norm`
# (or `palette + norm=list`). The synthesized form must not trigger the
# blending warning or conflict with `palette`.
if isinstance(render_params.cmap_params, list):
got_multiple_cmaps = True
user_supplied_multi_cmaps = any(not cp.cmap_is_default for cp in render_params.cmap_params)
if len(render_params.cmap_params) != n_channels:
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")
else:
got_multiple_cmaps = False
user_supplied_multi_cmaps = False

if user_supplied_multi_cmaps:
logger.warning(
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
Expand All @@ -1297,10 +1308,6 @@ def _render_images(
"Consider using 'palette' instead."
)

# not using got_multiple_cmaps here because of ruff :(
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")

# Detect RGB(A) images by channel names — skip when user overrides with palette/cmap
is_rgb, has_alpha = _is_rgb_image(channels)
has_explicit_cmap = (
Expand Down Expand Up @@ -1527,8 +1534,9 @@ def _render_images(
zorder=render_params.zorder,
)

# 2C) Image has n channels and palette info
elif palette is not None and not got_multiple_cmaps:
# 2C) palette set; also covers `palette + norm=list` since synthesized
# default cmaps don't conflict and per-channel norms are already in `layers`.
elif palette is not None and not user_supplied_multi_cmaps:
if len(palette) != n_channels:
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

Expand Down Expand Up @@ -1567,10 +1575,6 @@ def _render_images(
zorder=render_params.zorder,
)

# 2D) Image has n channels, no palette but cmap info
elif palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")

# Collect channel legend entries (single point for all multi-channel paths)
if render_params.channels_as_legend and channel_legend_entries is not None:
if legend_colors is not None:
Expand Down
13 changes: 10 additions & 3 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2989,15 +2989,22 @@ def _validate_image_render_params(
)
element_params[el]["palette"] = palette

expected_len = len(channel) if channel is not None else len(spatial_element_ch)

cmap = param_dict["cmap"]
if cmap is not None:
expected_len = len(channel) if channel is not None else len(spatial_element_ch)
if len(cmap) == 1:
cmap = cmap * expected_len
if len(cmap) != expected_len:
cmap = None
raise ValueError(
f"Length of 'cmap' list ({len(cmap)}) must match the number of channels ({expected_len})."
)
element_params[el]["cmap"] = cmap
element_params[el]["norm"] = param_dict["norm"]

norm = param_dict["norm"]
if isinstance(norm, list) and len(norm) > 1 and len(norm) != expected_len:
raise ValueError(f"Length of 'norm' list ({len(norm)}) must match the number of channels ({expected_len}).")
element_params[el]["norm"] = norm
scale = param_dict["scale"]
if scale and isinstance(param_dict["sdata"][el], DataTree):
if scale not in list(param_dict["sdata"][el].keys()) and scale != "full":
Expand Down
26 changes: 26 additions & 0 deletions tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,32 @@ def test_norm_list_without_explicit_cmap():
plt.close(fig)


# Regression tests for #622: misleading 'cmap' errors when norm/palette interact.
def test_norm_list_wrong_length_raises_with_norm_message():
# Without an explicit cmap the user only set norm; the error must mention norm,
# not cmap, and report both lengths.
sdata = _make_multichannel_sdata()
with pytest.raises(ValueError, match=r"'norm' list \(2\).*channels \(3\)"):
sdata.pl.render_images("img", norm=[Normalize(0, 1), Normalize(0, 2)]).pl.show()


def test_cmap_wrong_length_with_norm_list_no_longer_silent():
# Previously the wrong-length cmap was silently nulled when norm was a list of
# the correct length, hiding the bug. It must now raise just like the no-norm path.
sdata = _make_multichannel_sdata()
with pytest.raises(ValueError, match=r"'cmap' list \(2\).*channels \(3\)"):
sdata.pl.render_images("img", cmap=["Reds", "Greens"], norm=[Normalize(0, 1)] * 3).pl.show()


def test_palette_with_norm_list_renders():
# palette + per-channel norms used to fail with "If 'palette' is provided, 'cmap'
# must be None." even though the user never passed cmap. Should now render.
sdata = _make_multichannel_sdata()
fig, ax = plt.subplots()
sdata.pl.render_images("img", palette=["red", "green", "blue"], norm=[Normalize(0, 1)] * 3).pl.show(ax=ax)
plt.close(fig)


def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData):
"""Cmap length should be validated against selected channels, not the full image channel count."""
# blobs_image has 3 channels; select 1 with a matching length-1 cmap
Expand Down
Loading