diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py index 2bcaed79..b5cf0532 100644 --- a/src/spatialdata_plot/pl/_datashader.py +++ b/src/spatialdata_plot/pl/_datashader.py @@ -43,6 +43,21 @@ # --------------------------------------------------------------------------- +def _apply_user_alpha(result: ds.tf.Image | np.ndarray, alpha: float) -> ds.tf.Image | np.ndarray: + """Scale the alpha channel of a datashader shade result by ``alpha``. + + ``ds.tf.shade(min_alpha=...)`` is a floor, not a scale, so user alpha + must be applied post-hoc. See #617. + """ + if alpha >= 1.0 or result is None: + return result + arr = result if isinstance(result, np.ndarray) else result.to_numpy().base + if arr is None or arr.ndim != 3 or arr.shape[-1] != 4: + return result + arr[..., 3] = (arr[..., 3].astype(np.float32) * alpha).astype(np.uint8) + return result + + def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: """Return a ``pd.Categorical`` from a pandas or dask Series.""" if isinstance(series, dd.Series): @@ -241,6 +256,7 @@ def _ds_shade_continuous( span=color_span, clip=norm.clip, ) + shaded = _apply_user_alpha(shaded, alpha) nan_shaded = None if nan_agg is not None: @@ -251,6 +267,7 @@ def _ds_shade_continuous( # only shapes (no spread) pass min_alpha for NaN shading shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha) nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs) + nan_shaded = _apply_user_alpha(nan_shaded, alpha) return shaded, nan_shaded, reduction_bounds @@ -270,12 +287,13 @@ def _ds_shade_categorical( ds_cmap = _hex_no_alpha(ds_cmap) agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg - return _datashader_map_aggregate_to_color( + shaded = _datashader_map_aggregate_to_color( agg_to_shade, cmap=ds_cmap, color_key=color_key, min_alpha=_convert_alpha_to_datashader_range(alpha), ) + return _apply_user_alpha(shaded, alpha) # --------------------------------------------------------------------------- @@ -338,6 +356,7 @@ def _render_ds_outlines( min_alpha=_convert_alpha_to_datashader_range(alpha), how="linear", ) + shaded = _apply_user_alpha(shaded, alpha) rgba, trans = _create_image_from_datashader_result(shaded, factor, ax) _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, extent=extent) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index e3131d98..04bcda88 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -681,7 +681,7 @@ def _set_outline( """ # A) User doesn't want to see outlines if ( - (outline_alpha and outline_alpha == 0.0) + outline_alpha == 0.0 or (isinstance(outline_alpha, tuple) and np.all(np.array(outline_alpha) == 0.0)) or not (outline_alpha or outline_width or outline_color) ): diff --git a/tests/_images/Points_datashader_continuous_color.png b/tests/_images/Points_datashader_continuous_color.png index 068fa6ee..be92016a 100644 Binary files a/tests/_images/Points_datashader_continuous_color.png and b/tests/_images/Points_datashader_continuous_color.png differ diff --git a/tests/_images/Points_mpl_and_datashader_point_sizes_agree_after_altered_dpi.png b/tests/_images/Points_mpl_and_datashader_point_sizes_agree_after_altered_dpi.png index f5e2f77a..3cfff9c7 100644 Binary files a/tests/_images/Points_mpl_and_datashader_point_sizes_agree_after_altered_dpi.png and b/tests/_images/Points_mpl_and_datashader_point_sizes_agree_after_altered_dpi.png differ diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index db3df973..815071d9 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1228,6 +1228,68 @@ def test_datashader_alpha_not_applied_twice(sdata_blobs: SpatialData): plt.close(fig) +@pytest.mark.parametrize( + ("fill_alpha", "expected_max"), + [ + (0.0, 0), + (0.3, 76), + (0.5, 127), + (1.0, 255), + ], +) +def test_datashader_respects_fill_alpha(sdata_blobs: SpatialData, fill_alpha: float, expected_max: int): + """fill_alpha must scale the rendered alpha channel linearly on the datashader path (#617).""" + fig, ax = plt.subplots() + sdata_blobs.pl.render_shapes( + element="blobs_polygons", + method="datashader", + fill_alpha=fill_alpha, + ).pl.show(ax=ax) + fig.canvas.draw() + + axes_images = [c for c in ax.get_children() if isinstance(c, matplotlib.image.AxesImage)] + assert axes_images + rgba = axes_images[0].get_array() + assert rgba.ndim == 3 and rgba.shape[-1] == 4 + assert int(rgba[..., 3].max()) == expected_max + plt.close(fig) + + +@pytest.mark.parametrize( + ("outline_alpha", "expected_max"), + [ + (0.0, None), + (0.3, 76), + (0.5, 127), + (1.0, 255), + ], +) +def test_datashader_respects_outline_alpha(sdata_blobs: SpatialData, outline_alpha: float, expected_max: int | None): + """outline_alpha must scale the outline image's alpha; alpha=0 must skip rendering entirely (#617).""" + fig, ax = plt.subplots() + sdata_blobs.pl.render_shapes( + element="blobs_polygons", + method="datashader", + fill_alpha=1.0, + outline_alpha=outline_alpha, + outline_color="red", + ).pl.show(ax=ax) + fig.canvas.draw() + + axes_images = [c for c in ax.get_children() if isinstance(c, matplotlib.image.AxesImage)] + outline_imgs = [ + img + for img in axes_images + if (arr := img.get_array()).ndim == 3 and arr.shape[-1] == 4 and arr[..., 0].max() > arr[..., 1].max() + ] + if expected_max is None: + assert not outline_imgs + else: + assert outline_imgs + assert int(outline_imgs[0].get_array()[..., 3].max()) == expected_max + plt.close(fig) + + def test_render_shapes_color_with_conflicting_index_name(): """render_shapes(color=...) must not crash when obs.index.name matches an existing column. diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index a456d765..cac36885 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -8,10 +8,12 @@ from spatialdata import SpatialData import spatialdata_plot +from spatialdata_plot.pl.render_params import Color from spatialdata_plot.pl.utils import ( _apply_cmap_alpha_to_datashader_result, _datashader_map_aggregate_to_color, _get_subplots, + _set_outline, set_zero_in_cmap_to_transparent, ) from tests.conftest import DPI, PlotTester, PlotTesterMeta @@ -164,6 +166,22 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]): assert spatialdata_plot.pl.utils._is_color_like(color) == result +@pytest.mark.parametrize( + ("outline_alpha", "outline_color", "expected"), + [ + (0.0, Color("#ff0000"), (0.0, 0.0)), + (0, Color("#ff0000"), (0.0, 0.0)), + ((0.0, 0.0), Color("#ff0000"), (0.0, 0.0)), + (0.5, Color("#ff0000"), (0.5, 0.0)), + (1.0, Color("#ff0000"), (1.0, 0.0)), + ], +) +def test_set_outline_respects_zero_alpha(outline_alpha, outline_color, expected): + """outline_alpha=0 must yield (0.0, 0.0) even when outline_color is set (#617 follow-up).""" + alpha, _ = _set_outline(outline_alpha=outline_alpha, outline_width=None, outline_color=outline_color) + assert alpha == expected + + class TestCmapAlphaDatashader: """Regression tests for #376: set_zero_in_cmap_to_transparent with datashader."""