diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 41965dec..1f11f20b 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1054,9 +1054,10 @@ def show( # Check if user specified only certain elements to be plotted cs_contents = _get_cs_contents(sdata) + cs_index = cs_contents.set_index("cs") pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = [] - elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_contents, cs) + elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) # filter out cs without relevant elements cmds = [cmd for cmd, _ in render_cmds] @@ -1079,7 +1080,7 @@ def show( strict_cs = [ cs_name for cs_name in coordinate_systems - if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags) + if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags) ] if strict_cs: coordinate_systems = strict_cs @@ -1197,15 +1198,15 @@ def _draw_colorbar( elif location == "top": trackers_axes["top"] = pad_axes + bbox_axes.height - cs_contents = _get_cs_contents(sdata) - # go through tree for i, cs in enumerate(coordinate_systems): sdata = self._copy() - _, has_images, has_labels, has_points, has_shapes = ( - cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist() - ) + cs_row = cs_index.loc[cs] + has_images = cs_row["has_images"] + has_labels = cs_row["has_labels"] + has_points = cs_row["has_points"] + has_shapes = cs_row["has_shapes"] ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i] assert isinstance(ax, Axes) axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 1aa283e9..d36698d6 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -334,35 +334,21 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame: """Check which coordinate systems contain which elements and return that info.""" cs_mapping = _get_coordinate_system_mapping(sdata) content_flags = ["has_images", "has_labels", "has_points", "has_shapes"] - cs_contents = pd.DataFrame(columns=["cs"] + content_flags) + rows = [] for cs_name, element_ids in cs_mapping.items(): - # determine if coordinate system has the respective elements - cs_has_images = any(e in sdata.images for e in element_ids) - cs_has_labels = any(e in sdata.labels for e in element_ids) - cs_has_points = any(e in sdata.points for e in element_ids) - cs_has_shapes = any(e in sdata.shapes for e in element_ids) - - cs_contents = pd.concat( - [ - cs_contents, - pd.DataFrame( - { - "cs": cs_name, - "has_images": [cs_has_images], - "has_labels": [cs_has_labels], - "has_points": [cs_has_points], - "has_shapes": [cs_has_shapes], - } - ), - ] + rows.append( + { + "cs": cs_name, + "has_images": any(e in sdata.images for e in element_ids), + "has_labels": any(e in sdata.labels for e in element_ids), + "has_points": any(e in sdata.points for e in element_ids), + "has_shapes": any(e in sdata.shapes for e in element_ids), + } ) - cs_contents["has_images"] = cs_contents["has_images"].astype("bool") - cs_contents["has_labels"] = cs_contents["has_labels"].astype("bool") - cs_contents["has_points"] = cs_contents["has_points"].astype("bool") - cs_contents["has_shapes"] = cs_contents["has_shapes"].astype("bool") - + cs_contents = pd.DataFrame(rows, columns=["cs"] + content_flags) + cs_contents[content_flags] = cs_contents[content_flags].astype("bool") return cs_contents @@ -2106,7 +2092,7 @@ def _get_elements_to_be_rendered( ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, ] ], - cs_contents: pd.DataFrame, + cs_index: pd.DataFrame, cs: str, ) -> list[str]: """ @@ -2116,10 +2102,10 @@ def _get_elements_to_be_rendered( ---------- render_cmds List of tuples containing the commands and their respective parameters. - cs_contents - The dataframe indicating for each coordinate system which SpatialElements it contains. + cs_index + The cs_contents dataframe indexed by the "cs" column. cs - The name of the coordinate system to query cs_contents for. + The name of the coordinate system to query cs_index for. Returns ------- @@ -2127,12 +2113,12 @@ def _get_elements_to_be_rendered( """ elements_to_be_rendered: list[str] = [] - cs_query = cs_contents.query(f"cs == '{cs}'") + cs_row = cs_index.loc[cs] if cs in cs_index.index else None for cmd, params in render_cmds: key = _RENDER_CMD_TO_CS_FLAG.get(cmd) - if key and cs_query[key][0]: - elements_to_be_rendered += [params.element] + if key and cs_row is not None and cs_row[key]: + elements_to_be_rendered.append(params.element) return elements_to_be_rendered diff --git a/tests/pl/test_render.py b/tests/pl/test_render.py index 83c6ee3c..29dd36b1 100644 --- a/tests/pl/test_render.py +++ b/tests/pl/test_render.py @@ -1,5 +1,11 @@ import matplotlib.pyplot as plt +import numpy as np import pytest +from spatialdata import SpatialData +from spatialdata.models import Image2DModel +from spatialdata.transformations import Identity, set_transformation + +import spatialdata_plot # noqa: F401 def test_render_images_can_plot_one_cyx_image(request): @@ -97,3 +103,15 @@ def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs): with pytest.raises(ValueError, match="coordinate_systems="): # Only render shapes (present in both CS), so strict filter can't narrow down sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax) + + +def test_cs_name_with_apostrophe_does_not_crash(): + # Regression test for #602: .query(f"cs == '{cs}'") raised TokenError for cs names + # containing single quotes. + data = np.zeros((1, 10, 10), dtype=np.float64) + img = Image2DModel.parse(data, dims=("c", "y", "x")) + sdata = SpatialData(images={"img": img}) + set_transformation(sdata["img"], Identity(), to_coordinate_system="patient's_cs") + _, ax = plt.subplots() + sdata.pl.render_images("img").pl.show(ax=ax, coordinate_systems="patient's_cs") + plt.close("all")