Skip to content
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
767441d
copy code from old PR
Jiaqi-Lv Nov 6, 2025
d2a9702
preliminiary testing
Jiaqi-Lv Nov 7, 2025
44c4994
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2025
0f8d4fe
initial prototype
Jiaqi-Lv Nov 10, 2025
d42b78a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
b7f829c
clean up
Jiaqi-Lv Nov 10, 2025
cba5fd5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
f2cdcc4
update
Jiaqi-Lv Nov 11, 2025
dd99d97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 11, 2025
c468a8b
update pipeline
Jiaqi-Lv Nov 12, 2025
14f870a
update pipeline
Jiaqi-Lv Nov 12, 2025
6e65fba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
7eb916e
refactor code
Jiaqi-Lv Nov 12, 2025
8442ac2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
de83074
clean up
Jiaqi-Lv Nov 12, 2025
17e5422
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-detecti…
shaneahmed Nov 17, 2025
f5b1885
update patch mode processing
Jiaqi-Lv Nov 22, 2025
551e43c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2025
12f985a
tidy up code
Jiaqi-Lv Nov 22, 2025
05b2c7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2025
2afbf8c
fix precommit
Jiaqi-Lv Nov 23, 2025
367295d
update test
Jiaqi-Lv Nov 23, 2025
7912abe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2025
0a72e8b
improve tests
Jiaqi-Lv Nov 24, 2025
f8b4189
improve tests
Jiaqi-Lv Nov 24, 2025
228731c
precommit
Jiaqi-Lv Nov 24, 2025
6c26a0f
fix deepsource
Jiaqi-Lv Nov 25, 2025
7ffea5b
fix deepsource
Jiaqi-Lv Nov 27, 2025
79fc088
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-detecti…
shaneahmed Dec 1, 2025
198982f
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-detecti…
shaneahmed Dec 5, 2025
884bdf0
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-detecti…
shaneahmed Dec 5, 2025
d63a7cc
refactor code and improve typing
Jiaqi-Lv Dec 10, 2025
a90f748
add test for map_overlap postprocessing
Jiaqi-Lv Dec 10, 2025
2e004f1
refactor postprocessing and saving
Jiaqi-Lv Dec 11, 2025
3638985
update save as zarr
Jiaqi-Lv Dec 11, 2025
af6d2c1
improve test coverage
Jiaqi-Lv Dec 11, 2025
1974b2f
improve tests and address comments
Jiaqi-Lv Dec 12, 2025
775e6a1
fix tests
Jiaqi-Lv Dec 12, 2025
b5f9c7d
:wrench: Add `run` function
shaneahmed Dec 12, 2025
86e809d
use smaller wsi for testing
Jiaqi-Lv Dec 12, 2025
abf02f6
reduce code complexity
Jiaqi-Lv Dec 12, 2025
5428618
add run params
Jiaqi-Lv Dec 12, 2025
dd74781
rename function
Jiaqi-Lv Dec 15, 2025
89505ae
:construction: Review and minor changes.
shaneahmed Dec 15, 2025
f14e4cb
fix tests
Jiaqi-Lv Dec 15, 2025
54a6447
fix tests
Jiaqi-Lv Dec 15, 2025
656dbb9
fix tests
Jiaqi-Lv Dec 15, 2025
10d09bd
update post processing and saving
Jiaqi-Lv Dec 15, 2025
b58dd88
fix tests
Jiaqi-Lv Dec 15, 2025
d2eae54
update model postproc function
Jiaqi-Lv Dec 15, 2025
d44d943
fix deepsource
Jiaqi-Lv Dec 15, 2025
b875e97
update RunParams
Jiaqi-Lv Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/pretrained.rst
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ The input output configuration is as follows:
ioconfig = IOPatchPredictorConfig(
patch_input_shape=(31, 31),
stride_shape=(8, 8),
input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
)


Expand All @@ -369,7 +369,7 @@ The input output configuration is as follows:
ioconfig = IOPatchPredictorConfig(
patch_input_shape=(252, 252),
stride_shape=(150, 150),
input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
)


Expand All @@ -393,7 +393,7 @@ The input output configuration is as follows:
ioconfig = IOPatchPredictorConfig(
patch_input_shape=(31, 31),
stride_shape=(8, 8),
input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
)


Expand All @@ -409,7 +409,7 @@ The input output configuration is as follows:
ioconfig = IOPatchPredictorConfig(
patch_input_shape=(252, 252),
stride_shape=(150, 150),
input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
)


Expand Down
225 changes: 225 additions & 0 deletions tests/engines/test_nucleus_detection_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""Tests for NucleusDetector."""

import pathlib
import shutil
from collections.abc import Callable

import dask.array as da
import numpy as np
import pandas as pd
import pytest

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import imwrite
from tiatoolbox.wsicore.wsireader import WSIReader

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def _rm_dir(path: pathlib.Path) -> None:
"""Helper func to remove directory."""
if pathlib.Path(path).exists():
shutil.rmtree(path, ignore_errors=True)


def check_output(path: pathlib.Path) -> None:
"""Check NucleusDetector output."""


def test_nucleus_detection_nms_empty_dataframe() -> None:
"""nucleus_detection_nms should return a copy for empty inputs."""
df = pd.DataFrame(columns=["x", "y", "type", "prob"])

result = NucleusDetector.nucleus_detection_nms(df, radius=3)

assert result.empty
assert result is not df
assert list(result.columns) == ["x", "y", "type", "prob"]


def test_nucleus_detection_nms_invalid_radius() -> None:
"""Radius must be strictly positive."""
df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})

with pytest.raises(ValueError, match="radius must be > 0"):
NucleusDetector.nucleus_detection_nms(df, radius=0)


def test_nucleus_detection_nms_invalid_overlap_threshold() -> None:
"""overlap_threshold must lie in (0, 1]."""
df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})

message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0"
with pytest.raises(ValueError, match=message):
NucleusDetector.nucleus_detection_nms(df, radius=1, overlap_threshold=0)


def test_nucleus_detection_nms_suppresses_overlapping_detections() -> None:
"""Lower-probability overlapping detections are removed."""
df = pd.DataFrame(
{
"x": [2, 0, 20],
"y": [1, 0, 20],
"type": [1, 1, 2],
"prob": [0.6, 0.9, 0.7],
}
)

result = NucleusDetector.nucleus_detection_nms(df, radius=5)

expected = pd.DataFrame(
{"x": [0, 20], "y": [0, 20], "type": [1, 2], "prob": [0.9, 0.7]}
)
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)


def test_nucleus_detection_nms_suppresses_across_types() -> None:
"""Overlapping detections of different types are also suppressed."""
df = pd.DataFrame(
{
"x": [0, 0, 20],
"y": [0, 0, 0],
"type": [1, 2, 1],
"prob": [0.6, 0.95, 0.4],
}
)

result = NucleusDetector.nucleus_detection_nms(df, radius=5)

expected = pd.DataFrame(
{"x": [0, 20], "y": [0, 0], "type": [2, 1], "prob": [0.95, 0.4]}
)
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)


def test_nucleus_detection_nms_retains_non_overlapping_candidates() -> None:
"""Detections with IoU below the threshold are preserved."""
df = pd.DataFrame(
{
"x": [0, 10],
"y": [0, 0],
"type": [1, 1],
"prob": [0.8, 0.5],
}
)

result = NucleusDetector.nucleus_detection_nms(df, radius=5, overlap_threshold=0.5)

expected = pd.DataFrame(
{"x": [0, 10], "y": [0, 0], "type": [1, 1], "prob": [0.8, 0.5]}
)
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)


def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -> None:
"""Test for nucleus detection engine."""
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))

pretrained_model = "mapde-conic"

save_dir = tmp_path

nucleus_detector = NucleusDetector(model=pretrained_model)
_ = nucleus_detector.run(
patch_mode=False,
device=device,
output_type="annotationstore",
memory_threshold=50,
images=[mini_wsi_svs],
save_dir=save_dir,
overwrite=True,
)

store = SQLiteStore.open(save_dir / "wsi4_512_512.db")
assert len(store.values()) == 281
store.close()

_rm_dir(save_dir)


def test_nucleus_detector_patch(
remote_sample: Callable, tmp_path: pathlib.Path
) -> None:
"""Test for nucleus detection engine in patch mode."""
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))

wsi_reader = WSIReader.open(mini_wsi_svs)
patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp")
patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp")

pretrained_model = "mapde-conic"

save_dir = tmp_path

nucleus_detector = NucleusDetector(model=pretrained_model)
_ = nucleus_detector.run(
patch_mode=True,
device=device,
output_type="annotationstore",
memory_threshold=50,
images=[patch_1, patch_2],
save_dir=save_dir,
overwrite=True,
class_dict=None,
)

store_1 = SQLiteStore.open(save_dir / "0.db")
assert len(store_1.values()) == 270
store_1.close()

store_2 = SQLiteStore.open(save_dir / "1.db")
assert len(store_2.values()) == 52
store_2.close()

imwrite(save_dir / "patch_0.png", patch_1)
imwrite(save_dir / "patch_1.png", patch_2)
_ = nucleus_detector.run(
patch_mode=True,
device=device,
output_type="zarr",
memory_threshold=50,
images=[save_dir / "patch_0.png", save_dir / "patch_1.png"],
save_dir=save_dir,
overwrite=True,
)

store_1 = SQLiteStore.open(save_dir / "patch_0.db")
assert len(store_1.values()) == 270
store_1.close()

store_2 = SQLiteStore.open(save_dir / "patch_1.db")
assert len(store_2.values()) == 52
store_2.close()

_rm_dir(save_dir)


def test_nucleus_detector_write_centroid_maps(tmp_path: pathlib.Path) -> None:
"""Test for _write_centroid_maps function."""
detection_maps = np.zeros((20, 20, 1), dtype=np.uint8)
detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1))

store = NucleusDetector.write_centroid_maps_to_store(
detection_maps=detection_maps, class_dict=None
)
assert len(store.values()) == 0
store.close()

detection_maps = np.zeros((20, 20, 1), dtype=np.uint8)
detection_maps[10, 10, 0] = 1
detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1))
_ = NucleusDetector.write_centroid_maps_to_store(
detection_maps=detection_maps,
save_path=tmp_path / "test.db",
class_dict={0: "nucleus"},
)
store = SQLiteStore.open(tmp_path / "test.db")
assert len(store.values()) == 1
annotation = next(iter(store.values()))
print(annotation)
assert annotation.properties["type"] == "nucleus"
assert annotation.geometry.centroid.x == 10.0
assert annotation.geometry.centroid.y == 10.0
store.close()
31 changes: 30 additions & 1 deletion tests/models/test_arch_mapde.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from tiatoolbox.models import MapDe
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader
Expand Down Expand Up @@ -48,7 +49,35 @@ def test_functionality(remote_sample: Callable) -> None:
batch = torch.from_numpy(patch)[None]
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
output = model.postproc(output[0])
assert np.all(output[0:2] == [[19, 171], [53, 89]])
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)

np.testing.assert_array_equal(xs[0:2], np.array([242, 192]))
np.testing.assert_array_equal(ys[0:2], np.array([10, 13]))

patch = reader.read_bounds(
(0, 0, 252, 252),
resolution=0.50,
units="mpp",
coord_space="resolution",
)

model, weights_path = _load_mapde(name="mapde-conic")
patch = model.preproc(patch)
batch = torch.from_numpy(patch)[None]
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
block_info = {
0: {
"array-location": [
[0, 1],
[0, 1],
], # dummy block to test no valid detections
}
}
output = model.postproc(output[0], block_info=block_info)
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
np.testing.assert_array_equal(xs, np.array([]))
np.testing.assert_array_equal(ys, np.array([]))

Path(weights_path).unlink()


Expand Down
36 changes: 33 additions & 3 deletions tests/models/test_arch_sccnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tiatoolbox.models import SCCNN
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
from tiatoolbox.utils import env_detection
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader
Expand Down Expand Up @@ -48,13 +49,42 @@ def test_functionality(remote_sample: Callable) -> None:
device=select_device(on_gpu=env_detection.has_gpu()),
)
output = model.postproc(output[0])
np.testing.assert_array_equal(output, np.array([[8, 7]]))
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)

np.testing.assert_array_equal(xs, np.array([8]))
np.testing.assert_array_equal(ys, np.array([7]))

model = _load_sccnn(name="sccnn-conic")
output = model.infer_batch(
model,
batch,
device=select_device(on_gpu=env_detection.has_gpu()),
)
output = model.postproc(output[0])
np.testing.assert_array_equal(output, np.array([[7, 8]]))
block_info = {
0: {
"array-location": [[0, 31], [0, 31]],
}
}
output = model.postproc(output[0], block_info=block_info)
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
np.testing.assert_array_equal(xs, np.array([7]))
np.testing.assert_array_equal(ys, np.array([8]))

model = _load_sccnn(name="sccnn-conic")
output = model.infer_batch(
model,
batch,
device=select_device(on_gpu=env_detection.has_gpu()),
)
block_info = {
0: {
"array-location": [
[0, 1],
[0, 1],
], # dummy block to test no valid detections
}
}
output = model.postproc(output[0], block_info=block_info)
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
np.testing.assert_array_equal(xs, np.array([]))
np.testing.assert_array_equal(ys, np.array([]))
Loading
Loading