|
1 | | - #!/usr/bin/env python3 |
| 1 | +#!/usr/bin/env python3 |
2 | 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | 3 | # All rights reserved. |
4 | 4 | # |
@@ -239,34 +239,6 @@ def compare_tensor_groups( |
239 | 239 | {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {}, |
240 | 240 | ) |
241 | 241 |
|
242 | | - @unittest.skipIf(not gpu_available, "CUDA not available") |
243 | | - def test_group_index_select_dim0_duplicate_gradients(self) -> None: |
244 | | - device = torch.device("cuda") |
245 | | - dtype = torch.float |
246 | | - |
247 | | - num_rows = 4 |
248 | | - num_cols = 9 |
249 | | - indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device) |
250 | | - |
251 | | - input_tensor = torch.randn( |
252 | | - (num_rows, num_cols), dtype=dtype, device=device |
253 | | - ).requires_grad_(True) |
254 | | - |
255 | | - output_group = torch.ops.fbgemm.group_index_select_dim0( |
256 | | - [input_tensor], [indices] |
257 | | - ) |
258 | | - output = output_group[0] |
259 | | - |
260 | | - grad = torch.arange( |
261 | | - output.numel(), dtype=dtype, device=device |
262 | | - ).view_as(output) |
263 | | - output.backward(grad) |
264 | | - |
265 | | - ref_grad = torch.zeros_like(input_tensor) |
266 | | - ref_grad.index_add_(0, indices, grad) |
267 | | - |
268 | | - torch.testing.assert_close(input_tensor.grad, ref_grad) |
269 | | - |
270 | 242 | @given( |
271 | 243 | num_inputs=st.integers(0, 100), |
272 | 244 | max_input_rows=st.integers(2, 32), |
|
0 commit comments