Skip to content

Commit e682fd8

Browse files
authored
Fix O(n^2) pd.concat in _get_cs_contents (#642)
1 parent 49acff0 commit e682fd8

3 files changed

Lines changed: 44 additions & 39 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,9 +1054,10 @@ def show(
10541054

10551055
# Check if user specified only certain elements to be plotted
10561056
cs_contents = _get_cs_contents(sdata)
1057+
cs_index = cs_contents.set_index("cs")
10571058
pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = []
10581059

1059-
elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_contents, cs)
1060+
elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs)
10601061

10611062
# filter out cs without relevant elements
10621063
cmds = [cmd for cmd, _ in render_cmds]
@@ -1079,7 +1080,7 @@ def show(
10791080
strict_cs = [
10801081
cs_name
10811082
for cs_name in coordinate_systems
1082-
if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags)
1083+
if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags)
10831084
]
10841085
if strict_cs:
10851086
coordinate_systems = strict_cs
@@ -1197,15 +1198,15 @@ def _draw_colorbar(
11971198
elif location == "top":
11981199
trackers_axes["top"] = pad_axes + bbox_axes.height
11991200

1200-
cs_contents = _get_cs_contents(sdata)
1201-
12021201
# go through tree
12031202

12041203
for i, cs in enumerate(coordinate_systems):
12051204
sdata = self._copy()
1206-
_, has_images, has_labels, has_points, has_shapes = (
1207-
cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist()
1208-
)
1205+
cs_row = cs_index.loc[cs]
1206+
has_images = cs_row["has_images"]
1207+
has_labels = cs_row["has_labels"]
1208+
has_points = cs_row["has_points"]
1209+
has_shapes = cs_row["has_shapes"]
12091210
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
12101211
assert isinstance(ax, Axes)
12111212
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None

src/spatialdata_plot/pl/utils.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -334,35 +334,21 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
334334
"""Check which coordinate systems contain which elements and return that info."""
335335
cs_mapping = _get_coordinate_system_mapping(sdata)
336336
content_flags = ["has_images", "has_labels", "has_points", "has_shapes"]
337-
cs_contents = pd.DataFrame(columns=["cs"] + content_flags)
338337

338+
rows = []
339339
for cs_name, element_ids in cs_mapping.items():
340-
# determine if coordinate system has the respective elements
341-
cs_has_images = any(e in sdata.images for e in element_ids)
342-
cs_has_labels = any(e in sdata.labels for e in element_ids)
343-
cs_has_points = any(e in sdata.points for e in element_ids)
344-
cs_has_shapes = any(e in sdata.shapes for e in element_ids)
345-
346-
cs_contents = pd.concat(
347-
[
348-
cs_contents,
349-
pd.DataFrame(
350-
{
351-
"cs": cs_name,
352-
"has_images": [cs_has_images],
353-
"has_labels": [cs_has_labels],
354-
"has_points": [cs_has_points],
355-
"has_shapes": [cs_has_shapes],
356-
}
357-
),
358-
]
340+
rows.append(
341+
{
342+
"cs": cs_name,
343+
"has_images": any(e in sdata.images for e in element_ids),
344+
"has_labels": any(e in sdata.labels for e in element_ids),
345+
"has_points": any(e in sdata.points for e in element_ids),
346+
"has_shapes": any(e in sdata.shapes for e in element_ids),
347+
}
359348
)
360349

361-
cs_contents["has_images"] = cs_contents["has_images"].astype("bool")
362-
cs_contents["has_labels"] = cs_contents["has_labels"].astype("bool")
363-
cs_contents["has_points"] = cs_contents["has_points"].astype("bool")
364-
cs_contents["has_shapes"] = cs_contents["has_shapes"].astype("bool")
365-
350+
cs_contents = pd.DataFrame(rows, columns=["cs"] + content_flags)
351+
cs_contents[content_flags] = cs_contents[content_flags].astype("bool")
366352
return cs_contents
367353

368354

@@ -2102,7 +2088,7 @@ def _get_elements_to_be_rendered(
21022088
ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
21032089
]
21042090
],
2105-
cs_contents: pd.DataFrame,
2091+
cs_index: pd.DataFrame,
21062092
cs: str,
21072093
) -> list[str]:
21082094
"""
@@ -2112,23 +2098,23 @@ def _get_elements_to_be_rendered(
21122098
----------
21132099
render_cmds
21142100
List of tuples containing the commands and their respective parameters.
2115-
cs_contents
2116-
The dataframe indicating for each coordinate system which SpatialElements it contains.
2101+
cs_index
2102+
The cs_contents dataframe indexed by the "cs" column.
21172103
cs
2118-
The name of the coordinate system to query cs_contents for.
2104+
The name of the coordinate system to query cs_index for.
21192105
21202106
Returns
21212107
-------
21222108
List of names of the SpatialElements to be rendered in the plot.
21232109
"""
21242110
elements_to_be_rendered: list[str] = []
21252111

2126-
cs_query = cs_contents.query(f"cs == '{cs}'")
2112+
cs_row = cs_index.loc[cs] if cs in cs_index.index else None
21272113

21282114
for cmd, params in render_cmds:
21292115
key = _RENDER_CMD_TO_CS_FLAG.get(cmd)
2130-
if key and cs_query[key][0]:
2131-
elements_to_be_rendered += [params.element]
2116+
if key and cs_row is not None and cs_row[key]:
2117+
elements_to_be_rendered.append(params.element)
21322118

21332119
return elements_to_be_rendered
21342120

tests/pl/test_render.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
import matplotlib.pyplot as plt
2+
import numpy as np
23
import pytest
4+
from spatialdata import SpatialData
5+
from spatialdata.models import Image2DModel
6+
from spatialdata.transformations import Identity, set_transformation
7+
8+
import spatialdata_plot # noqa: F401
39

410

511
def test_render_images_can_plot_one_cyx_image(request):
@@ -97,3 +103,15 @@ def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs):
97103
with pytest.raises(ValueError, match="coordinate_systems="):
98104
# Only render shapes (present in both CS), so strict filter can't narrow down
99105
sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax)
106+
107+
108+
def test_cs_name_with_apostrophe_does_not_crash():
109+
# Regression test for #602: .query(f"cs == '{cs}'") raised TokenError for cs names
110+
# containing single quotes.
111+
data = np.zeros((1, 10, 10), dtype=np.float64)
112+
img = Image2DModel.parse(data, dims=("c", "y", "x"))
113+
sdata = SpatialData(images={"img": img})
114+
set_transformation(sdata["img"], Identity(), to_coordinate_system="patient's_cs")
115+
_, ax = plt.subplots()
116+
sdata.pl.render_images("img").pl.show(ax=ax, coordinate_systems="patient's_cs")
117+
plt.close("all")

0 commit comments

Comments
 (0)