From 23a07f5a2d9e31fb48893f9138bf5b67a86d4142 Mon Sep 17 00:00:00 2001 From: rouniuyizu Date: Fri, 28 Apr 2017 15:18:58 +0800 Subject: [PATCH] support different H*W --- tests/test_deform_conv.py | 19 ++++++++++------- torch_deform_conv/deform_conv.py | 35 +++++++++++++++++++++++--------- torch_deform_conv/layers.py | 8 ++++---- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/tests/test_deform_conv.py b/tests/test_deform_conv.py index d157636..0576999 100644 --- a/tests/test_deform_conv.py +++ b/tests/test_deform_conv.py @@ -24,8 +24,8 @@ def test_th_map_coordinates(): def test_th_batch_map_coordinates(): np.random.seed(42) - input = np.random.random((4, 100, 100)) - coords = (np.random.random((4, 200, 2)) * 99) + input = np.random.random((4, 100, 156)) + coords = np.random.random((4, 100*156, 2)) * 99 sp_mapped_vals = sp_batch_map_coordinates(input, coords) th_mapped_vals = th_batch_map_coordinates( @@ -36,8 +36,8 @@ def test_th_batch_map_coordinates(): def test_th_batch_map_offsets(): np.random.seed(42) - input = np.random.random((4, 100, 100)) - offsets = (np.random.random((4, 100, 100, 2)) * 2) + input = np.random.random((4, 100, 156)) + offsets = (np.random.random((4, 100, 156, 2)) * 2) sp_mapped_vals = sp_batch_map_offsets(input, offsets) th_mapped_vals = th_batch_map_offsets( @@ -48,14 +48,19 @@ def test_th_batch_map_offsets(): def test_th_batch_map_offsets_grad(): np.random.seed(42) - input = np.random.random((4, 100, 100)) - offsets = (np.random.random((4, 100, 100, 2)) * 2) + input = np.random.random((4, 100, 156)) + offsets = (np.random.random((4, 100, 156, 2)) * 2) input = Variable(torch.from_numpy(input), requires_grad=True) offsets = Variable(torch.from_numpy(offsets), requires_grad=True) th_mapped_vals = th_batch_map_offsets(input, offsets) - e = torch.from_numpy(np.random.random((4, 100, 100))) + e = torch.from_numpy(np.random.random((4, 100, 156))) th_mapped_vals.backward(e) assert not np.allclose(input.grad.data.numpy(), 0) assert not np.allclose(offsets.grad.data.numpy(), 0) + +if __name__ == '__main__': + test_th_batch_map_coordinates() + test_th_batch_map_offsets() + test_th_batch_map_offsets_grad() \ No newline at end of file diff --git a/torch_deform_conv/deform_conv.py b/torch_deform_conv/deform_conv.py index 21ed8b5..bcc3746 100644 --- a/torch_deform_conv/deform_conv.py +++ b/torch_deform_conv/deform_conv.py @@ -66,7 +66,13 @@ def th_map_coordinates(input, coords, order=1): def sp_batch_map_coordinates(inputs, coords): """Reference implementation for batch_map_coordinates""" - coords = coords.clip(0, inputs.shape[1] - 1) + # coords = coords.clip(0, inputs.shape[1] - 1) + + assert (coords.shape[2] == 2) + height = coords[:,:,0].clip(0, inputs.shape[1] - 1) + weight = coords[:,:,1].clip(0, inputs.shape[2] - 1) + np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(weight, axis=2)), 2) + mapped_vals = np.array([ sp_map_coordinates(input, coord.T, mode='nearest', order=1) for input, coord in zip(inputs, coords) @@ -87,10 +93,17 @@ def th_batch_map_coordinates(input, coords, order=1): """ batch_size = input.size(0) - input_size = input.size(1) + input_height = input.size(1) + input_weight = input.size(2) + n_coords = coords.size(1) - coords = torch.clamp(coords, 0, input_size - 1) + # coords = torch.clamp(coords, 0, input_size - 1) + + coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_weight - 1)), 2) + + assert (coords.size(1) == n_coords) + coords_lt = coords.floor().long() coords_rb = coords.ceil().long() coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2) @@ -125,21 +138,22 @@ def sp_batch_map_offsets(input, offsets): """Reference implementation for tf_batch_map_offsets""" batch_size = input.shape[0] - input_size = input.shape[1] + input_height = input.shape[1] + input_weight = input.shape[2] offsets = offsets.reshape(batch_size, -1, 2) - grid = np.stack(np.mgrid[:input_size, :input_size], -1).reshape(-1, 2) + grid = np.stack(np.mgrid[:input_height, :input_weight], -1).reshape(-1, 2) grid = np.repeat([grid], batch_size, axis=0) coords = offsets + grid - coords = coords.clip(0, input_size - 1) + # coords = coords.clip(0, input_size - 1) mapped_vals = sp_batch_map_coordinates(input, coords) return mapped_vals -def th_generate_grid(batch_size, input_size, dtype, cuda): +def th_generate_grid(batch_size, input_height, input_weight, dtype, cuda): grid = np.meshgrid( - range(input_size), range(input_size), indexing='ij' + range(input_height), range(input_weight), indexing='ij' ) grid = np.stack(grid, axis=-1) grid = grid.reshape(-1, 2) @@ -162,11 +176,12 @@ def th_batch_map_offsets(input, offsets, grid=None, order=1): torch.Tensor. shape = (b, s, s) """ batch_size = input.size(0) - input_size = input.size(1) + input_height = input.size(1) + input_weight = input.size(2) offsets = offsets.view(batch_size, -1, 2) if grid is None: - grid = th_generate_grid(batch_size, input_size, offsets.data.type(), offsets.data.is_cuda) + grid = th_generate_grid(batch_size, input_height, input_weight, offsets.data.type(), offsets.data.is_cuda) coords = offsets + grid diff --git a/torch_deform_conv/layers.py b/torch_deform_conv/layers.py index 721d64d..c1a1bbc 100644 --- a/torch_deform_conv/layers.py +++ b/torch_deform_conv/layers.py @@ -54,12 +54,12 @@ def forward(self, x): @staticmethod def _get_grid(self, x): - batch_size, input_size= x.size(0), x.size(1) + batch_size, input_height, input_weight = x.size(0), x.size(1), x.size(2) dtype, cuda = x.data.type(), x.data.is_cuda - if self._grid_param == (batch_size, input_size, dtype, cuda): + if self._grid_param == (batch_size, input_height, input_weight, dtype, cuda): return self._grid - self._grid_param = (batch_size, input_size, dtype, cuda) - self._grid = th_generate_grid(batch_size, input_size, dtype, cuda) + self._grid_param = (batch_size, input_height, input_weight, dtype, cuda) + self._grid = th_generate_grid(batch_size, input_height, input_weight, dtype, cuda) return self._grid @staticmethod