@@ -66,7 +66,13 @@ def th_map_coordinates(input, coords, order=1):
6666
6767def sp_batch_map_coordinates (inputs , coords ):
6868 """Reference implementation for batch_map_coordinates"""
69- coords = coords .clip (0 , inputs .shape [1 ] - 1 )
69+ # coords = coords.clip(0, inputs.shape[1] - 1)
70+
71+ assert (coords .shape [2 ] == 2 )
72+ height = coords [:,:,0 ].clip (0 , inputs .shape [1 ] - 1 )
73+ width = coords [:,:,1 ].clip (0 , inputs .shape [2 ] - 1 )
74+ np .concatenate ((np .expand_dims (height , axis = 2 ), np .expand_dims (width , axis = 2 )), 2 )
75+
7076 mapped_vals = np .array ([
7177 sp_map_coordinates (input , coord .T , mode = 'nearest' , order = 1 )
7278 for input , coord in zip (inputs , coords )
@@ -87,15 +93,22 @@ def th_batch_map_coordinates(input, coords, order=1):
8793 """
8894
8995 batch_size = input .size (0 )
90- input_size = input .size (1 )
96+ input_height = input .size (1 )
97+ input_width = input .size (2 )
98+
9199 n_coords = coords .size (1 )
92100
93- coords = torch .clamp (coords , 0 , input_size - 1 )
101+ # coords = torch.clamp(coords, 0, input_size - 1)
102+
103+ coords = torch .cat ((torch .clamp (coords .narrow (2 , 0 , 1 ), 0 , input_height - 1 ), torch .clamp (coords .narrow (2 , 1 , 1 ), 0 , input_width - 1 )), 2 )
104+
105+ assert (coords .size (1 ) == n_coords )
106+
94107 coords_lt = coords .floor ().long ()
95108 coords_rb = coords .ceil ().long ()
96109 coords_lb = torch .stack ([coords_lt [..., 0 ], coords_rb [..., 1 ]], 2 )
97110 coords_rt = torch .stack ([coords_rb [..., 0 ], coords_lt [..., 1 ]], 2 )
98- idx = th_repeat (torch .range (0 , batch_size - 1 ), n_coords ).long ()
111+ idx = th_repeat (torch .arange (0 , batch_size ), n_coords ).long ()
99112 idx = Variable (idx , requires_grad = False )
100113 if input .is_cuda :
101114 idx = idx .cuda ()
@@ -108,7 +121,7 @@ def _get_vals_by_coords(input, coords):
108121 vals = th_flatten (input ).index_select (0 , inds )
109122 vals = vals .view (batch_size , n_coords )
110123 return vals
111-
124+
112125 vals_lt = _get_vals_by_coords (input , coords_lt .detach ())
113126 vals_rb = _get_vals_by_coords (input , coords_rb .detach ())
114127 vals_lb = _get_vals_by_coords (input , coords_lb .detach ())
@@ -125,21 +138,22 @@ def sp_batch_map_offsets(input, offsets):
125138 """Reference implementation for tf_batch_map_offsets"""
126139
127140 batch_size = input .shape [0 ]
128- input_size = input .shape [1 ]
141+ input_height = input .shape [1 ]
142+ input_width = input .shape [2 ]
129143
130144 offsets = offsets .reshape (batch_size , - 1 , 2 )
131- grid = np .stack (np .mgrid [:input_size , :input_size ], - 1 ).reshape (- 1 , 2 )
145+ grid = np .stack (np .mgrid [:input_height , :input_width ], - 1 ).reshape (- 1 , 2 )
132146 grid = np .repeat ([grid ], batch_size , axis = 0 )
133147 coords = offsets + grid
134- coords = coords .clip (0 , input_size - 1 )
148+ # coords = coords.clip(0, input_size - 1)
135149
136150 mapped_vals = sp_batch_map_coordinates (input , coords )
137151 return mapped_vals
138152
139153
140- def th_generate_grid (batch_size , input_size , dtype , cuda ):
154+ def th_generate_grid (batch_size , input_height , input_width , dtype , cuda ):
141155 grid = np .meshgrid (
142- range (input_size ), range (input_size ), indexing = 'ij'
156+ range (input_height ), range (input_width ), indexing = 'ij'
143157 )
144158 grid = np .stack (grid , axis = - 1 )
145159 grid = grid .reshape (- 1 , 2 )
@@ -162,11 +176,12 @@ def th_batch_map_offsets(input, offsets, grid=None, order=1):
162176 torch.Tensor. shape = (b, s, s)
163177 """
164178 batch_size = input .size (0 )
165- input_size = input .size (1 )
179+ input_height = input .size (1 )
180+ input_width = input .size (2 )
166181
167182 offsets = offsets .view (batch_size , - 1 , 2 )
168183 if grid is None :
169- grid = th_generate_grid (batch_size , input_size , offsets .data .type (), offsets .data .is_cuda )
184+ grid = th_generate_grid (batch_size , input_height , input_width , offsets .data .type (), offsets .data .is_cuda )
170185
171186 coords = offsets + grid
172187
0 commit comments