Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 21 additions & 32 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from xarray import DataArray, DataTree

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot._logging import _log_context
from spatialdata_plot.pl.render import (
_draw_channel_legend,
_render_images,
Expand Down Expand Up @@ -190,7 +190,8 @@ def render_shapes(
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
**kwargs: Any,
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
transfunc: Callable[[float], float] | None = None,
) -> sd.SpatialData:
"""
Render shapes elements in SpatialData.
Expand Down Expand Up @@ -279,15 +280,10 @@ def render_shapes(
specified, the shapes are converted to a circle/hexagon/square before rendering. If "visium_hex" is
specified, the shapes are assumed to be Visium spots and the size of the hexagons is adjusted to be adjacent
to each other.

**kwargs : Any
Additional arguments for customization. This can include:

datashader_reduction : Literal[
"sum", "mean", "any", "count", "std", "var", "max", "min"
], default: "max"
Reduction method for datashader when coloring by continuous values. Defaults to 'max'.

datashader_reduction : Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, optional
Reduction method for datashader when coloring by continuous values. When ``None``, defaults to ``"max"``.
transfunc : Callable[[float], float] | None, optional
Optional transformation applied to the continuous color vector before normalization and colormap mapping.

Notes
-----
Expand All @@ -300,8 +296,6 @@ def render_shapes(
sd.SpatialData
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.
"""
if "vmin" in kwargs or "vmax" in kwargs:
logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.")
params_dict = _validate_shape_render_params(
self._sdata,
element=element,
Expand All @@ -320,7 +314,7 @@ def render_shapes(
table_layer=table_layer,
shape=shape,
method=method,
ds_reduction=kwargs.get("datashader_reduction"),
ds_reduction=datashader_reduction,
colorbar=colorbar,
colorbar_params=colorbar_params,
gene_symbols=gene_symbols,
Expand Down Expand Up @@ -351,7 +345,7 @@ def render_shapes(
palette=param_values["palette"],
outline_alpha=final_outline_alpha,
fill_alpha=param_values["fill_alpha"],
transfunc=kwargs.get("transfunc"),
transfunc=transfunc,
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
shape=param_values["shape"],
Expand Down Expand Up @@ -384,7 +378,8 @@ def render_points(
gene_symbols: str | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
**kwargs: Any,
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
transfunc: Callable[[float], float] | None = None,
) -> sd.SpatialData:
"""
Render points elements in SpatialData.
Expand Down Expand Up @@ -452,22 +447,16 @@ def render_points(
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.

**kwargs : Any
Additional arguments for customization. This can include:

datashader_reduction : Literal[
"sum", "mean", "any", "count", "std", "var", "max", "min"
], default: "sum"
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
datashader_reduction : Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, optional
Reduction method for datashader when coloring by continuous values. When ``None``, defaults to ``"sum"``.
transfunc : Callable[[float], float] | None, optional
Optional transformation applied to the continuous color vector before normalization and colormap mapping.

Returns
-------
sd.SpatialData
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.
"""
if "vmin" in kwargs or "vmax" in kwargs:
logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.")
params_dict = _validate_points_render_params(
self._sdata,
element=element,
Expand All @@ -481,7 +470,7 @@ def render_points(
size=size,
table_name=table_name,
table_layer=table_layer,
ds_reduction=kwargs.get("datashader_reduction"),
ds_reduction=datashader_reduction,
colorbar=colorbar,
colorbar_params=colorbar_params,
gene_symbols=gene_symbols,
Expand Down Expand Up @@ -511,7 +500,7 @@ def render_points(
cmap_params=cmap_params,
palette=param_values["palette"],
alpha=param_values["alpha"],
transfunc=kwargs.get("transfunc"),
transfunc=transfunc,
size=param_values["size"],
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
Expand Down Expand Up @@ -730,7 +719,7 @@ def render_labels(
table_name: str | None = None,
table_layer: str | None = None,
gene_symbols: str | None = None,
**kwargs: Any,
transfunc: Callable[[float], float] | None = None,
) -> sd.SpatialData:
"""
Render labels elements in SpatialData.
Expand Down Expand Up @@ -806,14 +795,14 @@ def render_labels(
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
transfunc : Callable[[float], float] | None, optional
Optional transformation applied to the continuous color vector before normalization and colormap mapping.

Returns
-------
sd.SpatialData
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.
"""
if "vmin" in kwargs or "vmax" in kwargs:
logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.")
params_dict = _validate_label_render_params(
self._sdata,
element=element,
Expand Down Expand Up @@ -859,7 +848,7 @@ def render_labels(
scale=param_values["scale"],
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
transfunc=kwargs.get("transfunc"),
transfunc=transfunc,
zorder=n_steps,
colorbar=param_values["colorbar"],
colorbar_params=param_values["colorbar_params"],
Expand Down
15 changes: 15 additions & 0 deletions tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,21 @@ def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData)
plt.close(fig)


# Regression for #612: vmin/vmax kwargs are no longer accepted on any render
# function. The check covers all four to prevent the asymmetry from re-emerging.
@pytest.mark.parametrize("kwarg", ["vmin", "vmax"])
@pytest.mark.parametrize("func", ["render_images", "render_shapes", "render_points", "render_labels"])
def test_vmin_vmax_kwargs_rejected_uniformly(sdata_blobs: SpatialData, func: str, kwarg: str) -> None:
elements = {
"render_images": "blobs_image",
"render_labels": "blobs_labels",
"render_points": "blobs_points",
"render_shapes": "blobs_circles",
}
with pytest.raises(TypeError, match=kwarg):
getattr(sdata_blobs.pl, func)(elements[func], **{kwarg: 0})


# ---------------------------------------------------------------------------
# channels_as_legend visual tests (#459)
# ---------------------------------------------------------------------------
Expand Down
14 changes: 7 additions & 7 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ def _make_tablemodel_with_categorical_labels(sdata_blobs, label):

_, axs = plt.subplots(nrows=1, ncols=3, layout="tight")

sdata_blobs.pl.render_labels(label, color="channel_1_sum", table="other_table", scale="scale0").pl.show(
ax=axs[0], title="ch_1_sum", colorbar=False
)
sdata_blobs.pl.render_labels(label, color="channel_2_sum", table="other_table", scale="scale0").pl.show(
ax=axs[1], title="ch_2_sum", colorbar=False
)
sdata_blobs.pl.render_labels(label, color="which_max", table="other_table", scale="scale0").pl.show(
sdata_blobs.pl.render_labels(
label, color="channel_1_sum", table_name="other_table", scale="scale0"
).pl.show(ax=axs[0], title="ch_1_sum", colorbar=False)
sdata_blobs.pl.render_labels(
label, color="channel_2_sum", table_name="other_table", scale="scale0"
).pl.show(ax=axs[1], title="ch_2_sum", colorbar=False)
sdata_blobs.pl.render_labels(label, color="which_max", table_name="other_table", scale="scale0").pl.show(
ax=axs[2], legend_fontsize=6
)

Expand Down
4 changes: 2 additions & 2 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData
new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis")

# baseline img
sdata_blobs.pl.render_labels("blobs_labels", color="my_var", cmap="viridis", table="table").pl.show(
sdata_blobs.pl.render_labels("blobs_labels", color="my_var", cmap="viridis", table_name="table").pl.show(
ax=axs[0], colorbar=False
)

sdata_blobs.tables["table"].obs.iloc[8:12, 2] = 0

# image with 0s as transparent, so some labels are "missing"
sdata_blobs.pl.render_labels("blobs_labels", color="my_var", cmap=new_cmap, table="table").pl.show(
sdata_blobs.pl.render_labels("blobs_labels", color="my_var", cmap=new_cmap, table_name="table").pl.show(
ax=axs[1], colorbar=False
)

Expand Down
Loading