diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py
index cbde3ebae9..c3cbe5f14c 100644
--- a/monai/apps/detection/utils/anchor_utils.py
+++ b/monai/apps/detection/utils/anchor_utils.py
@@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
         # compute anchor centers regarding to the image.
         # shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
         shifts_centers = [
-            torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
+            torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2
             for axis in range(self.spatial_dims)
         ]
 
diff --git a/tests/apps/detection/utils/test_anchor_box.py b/tests/apps/detection/utils/test_anchor_box.py
index 7543c84ed9..22bf350f0b 100644
--- a/tests/apps/detection/utils/test_anchor_box.py
+++ b/tests/apps/detection/utils/test_anchor_box.py
@@ -45,9 +45,9 @@ class TestAnchorGenerator(unittest.TestCase):
     @parameterized.expand(TEST_CASES_2D)
     def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
         torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils")
-        image_list, _ = optional_import("torchvision.models.detection.image_list")
 
-        # test it behaves the same with torchvision for 2d
+        # test it behaves for new functionality of centered anchors
+        # pytorch does not follow this functionality
         anchor = AnchorGenerator(**input_param, indexing="xy")
         anchor_ref = torch_anchor_utils.AnchorGenerator(**input_param)
         for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):
@@ -57,15 +57,18 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
 
 
         grid_sizes = [[2, 2], [1, 1]]
         strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]]
-        for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)):
-            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)
-        images = torch.rand(image_shape)
-        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
-        result = anchor(images, feature_maps)
-        result_ref = anchor_ref(image_list.ImageList(images, ([123, 122],)), feature_maps)
-        for a, a_f in zip(result, result_ref):
-            assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)
+        monai_anchors = anchor.grid_anchors(grid_sizes, strides)
+        torchvision_anchors = anchor_ref.grid_anchors(grid_sizes, strides)
+
+        for a, a_f, s in zip(monai_anchors, torchvision_anchors, strides):
+            stride_y, stride_x = s
+
+            offset_x = stride_x // 2
+            offset_y = stride_y // 2
+            offset = torch.tensor([offset_x, offset_y, offset_x, offset_y], dtype=a_f.dtype, device=a_f.device)
+
+            assert_allclose(a, a_f + offset, type_test=True, device_test=False, atol=1e-3)
 
     @parameterized.expand(TEST_CASES_2D)
     def test_script_2d(self, input_param, image_shape, feature_maps_shapes):