Skip to content
This repository was archived by the owner on Jul 21, 2021. It is now read-only.

Commit e7a664f

Browse files
committed
manually merge pull request #3
1 parent b2076dc commit e7a664f

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

torch_deform_conv/deform_conv.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,13 @@ def th_map_coordinates(input, coords, order=1):
6666

6767
def sp_batch_map_coordinates(inputs, coords):
6868
"""Reference implementation for batch_map_coordinates"""
69-
coords = coords.clip(0, inputs.shape[1] - 1)
69+
# coords = coords.clip(0, inputs.shape[1] - 1)
70+
71+
assert (coords.shape[2] == 2)
72+
height = coords[:,:,0].clip(0, inputs.shape[1] - 1)
73+
width = coords[:,:,1].clip(0, inputs.shape[2] - 1)
74+
np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(width, axis=2)), 2)
75+
7076
mapped_vals = np.array([
7177
sp_map_coordinates(input, coord.T, mode='nearest', order=1)
7278
for input, coord in zip(inputs, coords)
@@ -87,15 +93,22 @@ def th_batch_map_coordinates(input, coords, order=1):
8793
"""
8894

8995
batch_size = input.size(0)
90-
input_size = input.size(1)
96+
input_height = input.size(1)
97+
input_width = input.size(2)
98+
9199
n_coords = coords.size(1)
92100

93-
coords = torch.clamp(coords, 0, input_size - 1)
101+
# coords = torch.clamp(coords, 0, input_size - 1)
102+
103+
coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2)
104+
105+
assert (coords.size(1) == n_coords)
106+
94107
coords_lt = coords.floor().long()
95108
coords_rb = coords.ceil().long()
96109
coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
97110
coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2)
98-
idx = th_repeat(torch.range(0, batch_size-1), n_coords).long()
111+
idx = th_repeat(torch.arange(0, batch_size), n_coords).long()
99112
idx = Variable(idx, requires_grad=False)
100113
if input.is_cuda:
101114
idx = idx.cuda()
@@ -108,7 +121,7 @@ def _get_vals_by_coords(input, coords):
108121
vals = th_flatten(input).index_select(0, inds)
109122
vals = vals.view(batch_size, n_coords)
110123
return vals
111-
124+
112125
vals_lt = _get_vals_by_coords(input, coords_lt.detach())
113126
vals_rb = _get_vals_by_coords(input, coords_rb.detach())
114127
vals_lb = _get_vals_by_coords(input, coords_lb.detach())
@@ -125,21 +138,22 @@ def sp_batch_map_offsets(input, offsets):
125138
"""Reference implementation for tf_batch_map_offsets"""
126139

127140
batch_size = input.shape[0]
128-
input_size = input.shape[1]
141+
input_height = input.shape[1]
142+
input_width = input.shape[2]
129143

130144
offsets = offsets.reshape(batch_size, -1, 2)
131-
grid = np.stack(np.mgrid[:input_size, :input_size], -1).reshape(-1, 2)
145+
grid = np.stack(np.mgrid[:input_height, :input_width], -1).reshape(-1, 2)
132146
grid = np.repeat([grid], batch_size, axis=0)
133147
coords = offsets + grid
134-
coords = coords.clip(0, input_size - 1)
148+
# coords = coords.clip(0, input_size - 1)
135149

136150
mapped_vals = sp_batch_map_coordinates(input, coords)
137151
return mapped_vals
138152

139153

140-
def th_generate_grid(batch_size, input_size, dtype, cuda):
154+
def th_generate_grid(batch_size, input_height, input_width, dtype, cuda):
141155
grid = np.meshgrid(
142-
range(input_size), range(input_size), indexing='ij'
156+
range(input_height), range(input_width), indexing='ij'
143157
)
144158
grid = np.stack(grid, axis=-1)
145159
grid = grid.reshape(-1, 2)
@@ -162,11 +176,12 @@ def th_batch_map_offsets(input, offsets, grid=None, order=1):
162176
torch.Tensor. shape = (b, s, s)
163177
"""
164178
batch_size = input.size(0)
165-
input_size = input.size(1)
179+
input_height = input.size(1)
180+
input_width = input.size(2)
166181

167182
offsets = offsets.view(batch_size, -1, 2)
168183
if grid is None:
169-
grid = th_generate_grid(batch_size, input_size, offsets.data.type(), offsets.data.is_cuda)
184+
grid = th_generate_grid(batch_size, input_height, input_width, offsets.data.type(), offsets.data.is_cuda)
170185

171186
coords = offsets + grid
172187

torch_deform_conv/layers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def forward(self, x):
5454

5555
@staticmethod
5656
def _get_grid(self, x):
57-
batch_size, input_size= x.size(0), x.size(1)
57+
batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2)
5858
dtype, cuda = x.data.type(), x.data.is_cuda
59-
if self._grid_param == (batch_size, input_size, dtype, cuda):
59+
if self._grid_param == (batch_size, input_height, input_width, dtype, cuda):
6060
return self._grid
61-
self._grid_param = (batch_size, input_size, dtype, cuda)
62-
self._grid = th_generate_grid(batch_size, input_size, dtype, cuda)
61+
self._grid_param = (batch_size, input_height, input_width, dtype, cuda)
62+
self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda)
6363
return self._grid
64-
64+
6565
@staticmethod
6666
def _init_weights(weights, std):
6767
fan_out = weights.size(0)

0 commit comments

Comments
 (0)