From 02e449adbfb0c25b6410abab963d16307b904b45 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Tue, 14 Oct 2025 13:07:33 -0700
Subject: [PATCH] Fix bug with unit sized dims and block_sizes

stack-info: PR: https://github.com/pytorch/helion/pull/932, branch: jansel/stack/191
---
 helion/_compiler/compile_environment.py |  8 ++++---
 test/test_matmul.py                     | 28 +++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index d620a9890..27c39ca34 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -582,9 +582,11 @@ def from_config(
 @dataclasses.dataclass
 class LoopSpecBlockSizeSource(BlockSizeSource):
     def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int:
-        index = CompileEnvironment.current().config_spec.block_sizes.block_id_to_index(
-            block_size_info.block_id
-        )
+        env = CompileEnvironment.current()
+        size = block_size_info.size
+        if isinstance(size, (int, torch.SymInt)) and env.known_equal(size, 1):
+            return 1
+        index = env.config_spec.block_sizes.block_id_to_index(block_size_info.block_id)
         return config.block_sizes[index]


diff --git a/test/test_matmul.py b/test/test_matmul.py
index 4c797bb01..77e1c66ac 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -14,6 +14,7 @@
 from helion._testing import code_and_output
 from helion._testing import import_path
 from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm
 import helion.language as hl

 torch.backends.cuda.matmul.fp32_precision = "tf32"
@@ -272,6 +273,33 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
         self.assertExpectedJournal(code)

+    @skipIfRocm("ROCm triton error in TritonAMDGPUBlockPingpong")
+    @skipIfRefEager("config_spec is not supported in ref eager mode")
+    def test_matmul_config_reuse_with_unit_dim(self):
+        torch.manual_seed(0)
+        big_args = (
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+        )
+        big_bound = matmul_with_addmm.bind(big_args)
+        big_spec = big_bound.config_spec
+        self.assertEqual(len(big_spec.block_sizes), 3)
+        big_config = big_spec.default_config()
+
+        small_args = (
+            torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+        )
+        small_bound = matmul_with_addmm.bind(small_args)
+        small_spec = small_bound.config_spec
+        self.assertEqual(len(small_spec.block_sizes), 3)
+
+        # Previously raised when reusing configs tuned on larger shapes.
+        small_bound.set_config(big_config)
+        result = small_bound(*small_args)
+        expected = small_args[0] @ small_args[1]
+        torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
+
     def test_matmul_packed_rhs(self):
         @helion.kernel(static_shapes=False)
         def matmul_with_packed_b(
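The following is a minimal standalone sketch, separate from the patch itself and not the Helion API, of the behavior the fix targets: a dimension statically known to have size 1 short-circuits to block size 1 instead of being looked up in the tuned config, so a config tuned on larger shapes can be reused on inputs with unit-sized dims. FakeConfigSpec, block_size_from_config, and the assumption that a unit-sized dim is never registered as tunable are hypothetical simplifications for illustration only.

# Hypothetical, simplified stand-ins; not the Helion classes touched by the patch.
from dataclasses import dataclass, field


@dataclass
class FakeConfigSpec:
    # block_ids that have a tunable block size; a unit-sized dim is assumed
    # (for illustration) to never be registered here.
    tunable_block_ids: list[int] = field(default_factory=list)

    def block_id_to_index(self, block_id: int) -> int:
        # Raises ValueError for a block_id that was never registered.
        return self.tunable_block_ids.index(block_id)


def block_size_from_config(
    spec: FakeConfigSpec,
    tuned_block_sizes: list[int],
    block_id: int,
    dim_size: int,
) -> int:
    # The fix in miniature: a dim known to be exactly 1 can only ever use
    # block size 1, so skip the config lookup entirely.
    if dim_size == 1:
        return 1
    return tuned_block_sizes[spec.block_id_to_index(block_id)]


if __name__ == "__main__":
    # Config tuned on a large shape where only block_ids 1 and 2 are tunable;
    # block_id 0 belongs to the (hypothetically unregistered) unit-sized dim.
    spec = FakeConfigSpec(tunable_block_ids=[1, 2])
    tuned = [64, 16]

    # Without the size == 1 short-circuit, this lookup would raise ValueError.
    assert block_size_from_config(spec, tuned, block_id=0, dim_size=1) == 1
    assert block_size_from_config(spec, tuned, block_id=2, dim_size=64) == 16
    print("unit-sized dim falls back to block size 1")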