From ddef0c306fcad6f01c42ae278a729bf8d792892a Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 09:59:27 -0400 Subject: [PATCH 1/7] Adding per_component functionality to Hausdorff Distance metric Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/hausdorff_distance.py | 126 ++++++++++++++++++++--- tests/metrics/test_hausdorff_distance.py | 53 ++++++++++ 2 files changed, 166 insertions(+), 13 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1b83c93e5b..2302ba3038 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -17,7 +17,13 @@ import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.metrics.utils import ( + compute_voronoi_regions_fast, + do_metric_reduction, + get_edge_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -37,6 +43,18 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + The ``per_component=True`` approach computes the Hausdorff distance on a per-connected component basis in the ground + truth segmentation. This ensures that each component contributes equally to the final metric, regardless of its size. + Traditional Hausdorff distance can be dominated by large structures, but the per-component method gives a more + balanced evaluation, particularly for small or fragmented objects. This provides a granular assessment of segmentation + quality, which is especially important in cases with multiple disconnected foreground components. + Note: + - The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background), + with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction + and ground truth is B2HW[D]. + - This method cannot be used with multiclass segmentation. + For more information, refer to the original paper: https://arxiv.org/abs/2410.18684 + Args: include_background: whether to include distance computation on the first channel of the predicted output. Defaults to ``False``. @@ -51,6 +69,7 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``. """ @@ -62,6 +81,7 @@ def __init__( directed: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + per_component: bool = False, ) -> None: super().__init__() self.include_background = include_background @@ -70,6 +90,7 @@ def __init__( self.directed = directed self.reduction = reduction self.get_not_nans = get_not_nans + self.per_component = per_component def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -96,7 +117,17 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") - + if self.per_component: + if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2: + same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5) + binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2 + same_shape = y_pred.shape == y.shape + if not (same_rank and binary_channels and same_shape): + raise ValueError( + "per_component requires matching 4D/5D binary tensors " + "(B, 2, H, W) or (B, 2, D, H, W). " + f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." + ) # compute (BxC) for each channel for each batch return compute_hausdorff_distance( y_pred=y_pred, @@ -106,6 +137,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) percentile=self.percentile, directed=self.directed, spacing=kwargs.get("spacing"), + per_component=self.per_component, ) def aggregate( @@ -137,6 +169,7 @@ def compute_hausdorff_distance( percentile: float | None = None, directed: bool = False, spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + per_component: bool = False, ) -> torch.Tensor: """ Compute the Hausdorff distance. @@ -162,6 +195,7 @@ def compute_hausdorff_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``. """ if not include_background: @@ -179,17 +213,83 @@ def compute_hausdorff_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): - _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], - distance_metric=distance_metric, - spacing=spacing_list[b], - symmetric=not directed, - class_index=c, - ) - percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] - max_distance = torch.max(torch.stack(percentile_distances)) - hd[b, c] = max_distance + if per_component: + if y[b, c].sum() == 0 or y_pred[b, c].sum() == 0: + if y_pred[b, c].sum() == 0 and y[b, c].sum() == 0: + hd[b, c] = 0.0 + else: + hd[b, c] = 1.0 + continue + cc_assignment = compute_voronoi_regions_fast(y_pred[b, c].cpu().numpy()) + if cc_assignment.device != y_pred[b, c].device: + cc_assignment = cc_assignment.to(y_pred[b, c].device) + max_list = [] + for cc_id in torch.unique(cc_assignment.view(-1)): + cc_mask = cc_assignment == cc_id + + coords = torch.nonzero(cc_mask, as_tuple=False) + min_corner_idx = coords.min(dim=0).values + max_corner_idx = coords.max(dim=0).values + + crop_pred = ( + y_pred[b, c][ + min_corner_idx[0] : max_corner_idx[0] + 1, + min_corner_idx[1] : max_corner_idx[1] + 1, + min_corner_idx[2] : max_corner_idx[2] + 1, + ] + if y_pred.ndim == 5 + else y_pred[b, c][ + min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1 + ] + ) + + crop_label = ( + y[b, c][ + min_corner_idx[0] : max_corner_idx[0] + 1, + min_corner_idx[1] : max_corner_idx[1] + 1, + min_corner_idx[2] : max_corner_idx[2] + 1, + ] + if y.ndim == 5 + else y[b, c][min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1] + ) + + cc_crop_mask = ( + cc_mask[ + min_corner_idx[0] : max_corner_idx[0] + 1, + min_corner_idx[1] : max_corner_idx[1] + 1, + min_corner_idx[2] : max_corner_idx[2] + 1, + ] + if y_pred.ndim == 5 + else cc_mask[min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1] + ) + + pred_masked = crop_pred * cc_crop_mask + label_masked = crop_label * cc_crop_mask + + _, distances, _ = get_edge_surface_distance( + pred_masked, + label_masked, + distance_metric=distance_metric, + spacing=spacing_list[b], + symmetric=not directed, + class_index=c, + ) + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] + max_list.append(torch.stack(percentile_distances)) + + hd[b, c] = torch.nanmean(torch.stack(max_list)) if max_list else 0.0 + else: + _, distances, _ = get_edge_surface_distance( + y_pred[b, c], + y[b, c], + distance_metric=distance_metric, + spacing=spacing_list[b], + symmetric=not directed, + class_index=c, + ) + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] + max_distance = torch.max(torch.stack(percentile_distances)) + hd[b, c] = max_distance return hd diff --git a/tests/metrics/test_hausdorff_distance.py b/tests/metrics/test_hausdorff_distance.py index 20276a1832..9b137e6d41 100644 --- a/tests/metrics/test_hausdorff_distance.py +++ b/tests/metrics/test_hausdorff_distance.py @@ -161,6 +161,45 @@ def create_spherical_seg_3d( for i, (metric, directed) in enumerate(product(["euclidean", "chessboard", "taxicab"], [True, False])): TEST_CASES_EXPANDED.append((_device, metric, directed, test_input, test_output[i])) +TEST_CASES_CC_METRICS = [] +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat[0, 1, 5:10, 5:10, 5:10] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y[0, 1, 10:15, 10:15, 10:15] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 10:15, 10:15, 10:15] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y[0, 1, 10:15, 10:15, 10:15] = 1 +y[0, 1, 20:25, 20:25, 20:25] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 11:16, 10:15, 10:15] = 1 +y_hat[0, 1, 21:26, 19:24, 20:25] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.2071], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32), device=_device) +y[0, 1, 10:15, 10:15] = 1 +y[0, 1, 20:25, 20:25] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 10:15, 10:15] = 1 +y_hat[0, 1, 21:26, 19:24] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.7071], [0.0]]]) + def _describe_test_case(test_func, test_number, params): _device, metric, directed, test_input, test_output = params.args @@ -204,6 +243,20 @@ def test_nans(self, input_data): np.testing.assert_allclose(0, result, rtol=1e-7) np.testing.assert_allclose(0, not_nans, rtol=1e-7) + @parameterized.expand(TEST_CASES_CC_METRICS) + def test_cc_metrics(self, input_data, expected_value): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + hd_metric = HausdorffDistanceMetric(per_component=True) + hd_metric(seg_1, seg_2) + result = hd_metric.aggregate(reduction="none") + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + def test_channel_dimensions(self): + with self.assertRaises(ValueError): + HausdorffDistanceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 144, 144])) + if __name__ == "__main__": unittest.main() From 5b21617eb8ce82d59af865ffb1210a0f60c6b767 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 10:17:07 -0400 Subject: [PATCH 2/7] Resolving coderabbitai comments --- monai/metrics/hausdorff_distance.py | 6 +++--- tests/metrics/test_hausdorff_distance.py | 20 ++++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 2302ba3038..039603754c 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -223,7 +223,7 @@ def compute_hausdorff_distance( cc_assignment = compute_voronoi_regions_fast(y_pred[b, c].cpu().numpy()) if cc_assignment.device != y_pred[b, c].device: cc_assignment = cc_assignment.to(y_pred[b, c].device) - max_list = [] + component_scores = [] for cc_id in torch.unique(cc_assignment.view(-1)): cc_mask = cc_assignment == cc_id @@ -275,9 +275,9 @@ def compute_hausdorff_distance( class_index=c, ) percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] - max_list.append(torch.stack(percentile_distances)) + component_scores.append(torch.max(torch.stack(percentile_distances))) - hd[b, c] = torch.nanmean(torch.stack(max_list)) if max_list else 0.0 + hd[b, c] = torch.nanmean(torch.stack(component_scores)) if component_scores else 0.0 else: _, distances, _ = get_edge_surface_distance( y_pred[b, c], diff --git a/tests/metrics/test_hausdorff_distance.py b/tests/metrics/test_hausdorff_distance.py index 9b137e6d41..d31d52148d 100644 --- a/tests/metrics/test_hausdorff_distance.py +++ b/tests/metrics/test_hausdorff_distance.py @@ -162,26 +162,26 @@ def create_spherical_seg_3d( TEST_CASES_EXPANDED.append((_device, metric, directed, test_input, test_output[i])) TEST_CASES_CC_METRICS = [] -y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) -y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y_hat[0, 1, 5:10, 5:10, 5:10] = 1 y_hat[0, 0] = 1 - y_hat[0, 1] TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.0], [0.0]]]) -y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y[0, 1, 10:15, 10:15, 10:15] = 1 y[0, 0] = 1 - y[0, 1] y_hat[0, 1, 10:15, 10:15, 10:15] = 1 y_hat[0, 0] = 1 - y_hat[0, 1] TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) -y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y[0, 1, 10:15, 10:15, 10:15] = 1 y[0, 1, 20:25, 20:25, 20:25] = 1 y[0, 0] = 1 - y[0, 1] @@ -190,8 +190,8 @@ def create_spherical_seg_3d( y_hat[0, 0] = 1 - y_hat[0, 1] TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.2071], [0.0]]]) -y = torch.zeros((2, 2, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32), device=_device) +y = torch.zeros((2, 2, 32, 32), device=_devices[-1]) +y_hat = torch.zeros((2, 2, 32, 32), device=_devices[-1]) y[0, 1, 10:15, 10:15] = 1 y[0, 1, 20:25, 20:25] = 1 y[0, 0] = 1 - y[0, 1] From feb29384478ff26e36e083196c4d60170eae04c8 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 10:19:55 -0400 Subject: [PATCH 3/7] Resolving coderabbitai comments Signed-off-by: Vijay Vignesh Prasad Rao From e3ff3c2d850d85bab3b35462ab16e503efafa339 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 10:23:58 -0400 Subject: [PATCH 4/7] Signing Off Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/hausdorff_distance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 039603754c..17aa2a9660 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -70,7 +70,6 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. per_component: whether to compute the Hausdorff distance on a per-connected component basis. Defaults to ``False``. - """ def __init__( From af863a94e81a62d05a64b47426c9faec304cd3ba Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 10:25:03 -0400 Subject: [PATCH 5/7] DCO Remediation Commit for Vijay Vignesh Prasad Rao I, Vijay Vignesh Prasad Rao , hereby add my Signed-off-by to this commit: 5b21617eb8ce82d59af865ffb1210a0f60c6b767 Signed-off-by: Vijay Vignesh Prasad Rao From 09a467ca3f1247648cbe3b8d8d6479ea18a66366 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 10:43:22 -0400 Subject: [PATCH 6/7] Resolving bugs Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/hausdorff_distance.py | 11 +++++------ tests/metrics/test_hausdorff_distance.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 17aa2a9660..82a15761cb 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -213,13 +213,12 @@ def compute_hausdorff_distance( for b, c in np.ndindex(batch_size, n_class): if per_component: - if y[b, c].sum() == 0 or y_pred[b, c].sum() == 0: - if y_pred[b, c].sum() == 0 and y[b, c].sum() == 0: - hd[b, c] = 0.0 - else: - hd[b, c] = 1.0 + pred_empty = y_pred[b, c].sum() == 0 + label_empty = y[b, c].sum() == 0 + if pred_empty and label_empty: + hd[b, c] = 0.0 if (pred_empty and label_empty) else float("nan") continue - cc_assignment = compute_voronoi_regions_fast(y_pred[b, c].cpu().numpy()) + cc_assignment = compute_voronoi_regions_fast(y[b, c].cpu().numpy()) if cc_assignment.device != y_pred[b, c].device: cc_assignment = cc_assignment.to(y_pred[b, c].device) component_scores = [] diff --git a/tests/metrics/test_hausdorff_distance.py b/tests/metrics/test_hausdorff_distance.py index d31d52148d..e264549fae 100644 --- a/tests/metrics/test_hausdorff_distance.py +++ b/tests/metrics/test_hausdorff_distance.py @@ -170,7 +170,7 @@ def create_spherical_seg_3d( y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y_hat[0, 1, 5:10, 5:10, 5:10] = 1 y_hat[0, 0] = 1 - y_hat[0, 1] -TEST_CASES_CC_METRICS.append([[y, y_hat], [[1.0], [0.0]]]) +TEST_CASES_CC_METRICS.append([[y, y_hat], [[float('inf')], [0.0]]]) y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) From 4611fb1a5421dd0e8b9ffc9a6f51740499443b7c Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 7 May 2026 10:57:05 -0400 Subject: [PATCH 7/7] Resolving linting bugs Signed-off-by: Vijay Vignesh Prasad Rao --- tests/metrics/test_hausdorff_distance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_hausdorff_distance.py b/tests/metrics/test_hausdorff_distance.py index e264549fae..4b1b3fa654 100644 --- a/tests/metrics/test_hausdorff_distance.py +++ b/tests/metrics/test_hausdorff_distance.py @@ -170,7 +170,7 @@ def create_spherical_seg_3d( y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y_hat[0, 1, 5:10, 5:10, 5:10] = 1 y_hat[0, 0] = 1 - y_hat[0, 1] -TEST_CASES_CC_METRICS.append([[y, y_hat], [[float('inf')], [0.0]]]) +TEST_CASES_CC_METRICS.append([[y, y_hat], [[float("inf")], [0.0]]]) y = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1]) y_hat = torch.zeros((2, 2, 32, 32, 32), device=_devices[-1])