Skip to content

Commit a1a4fee

Browse files
abcarlislepytorchmergebot
authored andcommitted
Native channel shuffle floating point exception (pytorch#144010)
Fixes pytorch#142453 Added TORCH_CHECKS to prevent the user from using the native_channel_shuffle function incorrectly and getting a "Floating point exception (core dumped)" Pull Request resolved: pytorch#144010 Approved by: https://github.com/albanD
1 parent 8f420a5 commit a1a4fee

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

aten/src/ATen/native/ChanelShuffle.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@
2020
namespace at::native {
2121

2222
Tensor channel_shuffle_cpu(const Tensor& self, int64_t groups) {
23+
TORCH_CHECK(self.dim() > 2,
24+
"channel_shuffle expects input with > 2 dims, but got input with sizes ",
25+
self.sizes());
26+
int64_t c = self.size(1);
27+
TORCH_CHECK(groups > 0,
28+
"Number of groups to divide channels in must be positive.",
29+
" Value of groups:", groups);
30+
TORCH_CHECK((c % groups) == 0,
31+
"Number of channels must be divisible by groups. Got ",
32+
c, " channels and ", groups, " groups.");
33+
2334
Tensor output;
2435
if (self.numel() == 0) {
2536
output = self.alias();

test/test_nn.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6400,6 +6400,24 @@ def test_channel_shuffle_return_alias_of_self(self):
64006400
output = torch.nn.ChannelShuffle(groups)(input_tensor)
64016401
torch.testing.assert_close(output, input_tensor)
64026402

6403+
def test_channel_shuffle_input_checks(self):
6404+
input_tensor = torch.rand([1, 3, 2, 2])
6405+
with self.assertRaisesRegex(RuntimeError,
6406+
"Number of groups to divide channels in must be positive.*"):
6407+
groups = 0
6408+
torch.native_channel_shuffle(input_tensor, groups)
6409+
6410+
with self.assertRaisesRegex(RuntimeError,
6411+
"Number of channels must be divisible by groups.*"):
6412+
groups = 2
6413+
torch.native_channel_shuffle(input_tensor, groups)
6414+
6415+
with self.assertRaisesRegex(RuntimeError,
6416+
"channel_shuffle expects input with > 2 dims,.*"):
6417+
input_tensor = torch.rand([1, 2])
6418+
groups = 2
6419+
torch.native_channel_shuffle(input_tensor, groups)
6420+
64036421
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
64046422
def test_native_channel_shuffle_return_alias_of_self(self):
64056423
groups = 3

0 commit comments

Comments
 (0)