Skip to content

Commit b1d6da6

Browse files
authored
Use HELION_PRINT_REPRO=1 to print repro when device IR lowering or Triton codegen error (#1078)
1 parent 5ef76af commit b1d6da6

File tree

2 files changed

+99
-3
lines changed

2 files changed

+99
-3
lines changed

helion/runtime/kernel.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,14 @@ def __init__(
356356
_maybe_skip_dtype_check_in_meta_registrations(),
357357
patch_inductor_lowerings(),
358358
):
359-
self.host_function: HostFunction = HostFunction(
360-
self.kernel.fn, self.fake_args, constexpr_args
361-
)
359+
try:
360+
self.host_function: HostFunction = HostFunction(
361+
self.kernel.fn, self.fake_args, constexpr_args
362+
)
363+
except Exception:
364+
config = self.env.config_spec.default_config()
365+
self.maybe_log_repro(log.warning, args, config=config)
366+
raise
362367

363368
@property
364369
def settings(self) -> Settings:
@@ -456,6 +461,7 @@ def compile_config(
456461
self.format_kernel_decorator(config, self.settings),
457462
exc_info=True,
458463
)
464+
self.maybe_log_repro(log.warning, self.fake_args, config=config)
459465
raise
460466
if allow_print:
461467
log.info("Output code: \n%s", triton_code)

test/test_debug_utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,96 @@ def mock_do_bench(*args, **kwargs):
168168
self.assertIn("kernel", captured)
169169
self.assertIn("helion_repro_caller()", captured)
170170

171+
def test_print_repro_on_device_ir_lowering_error(self):
172+
"""Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering."""
173+
with self._with_print_repro_enabled():
174+
175+
@helion.kernel(config=helion.Config(block_sizes=[32], num_warps=4))
176+
def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
177+
out = torch.empty_like(x)
178+
n = x.shape[0]
179+
for tile_n in hl.tile([n]):
180+
# Using torch.nonzero inside device loop causes compilation error
181+
# because it produces data-dependent output shape
182+
torch.nonzero(x[tile_n])
183+
out[tile_n] = x[tile_n]
184+
return out
185+
186+
torch.manual_seed(0)
187+
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
188+
189+
with self.capture_logs() as log_capture:
190+
# This should trigger a compilation error during device IR lowering
191+
with self.assertRaises(RuntimeError):
192+
kernel_with_compile_error(x)
193+
194+
# Extract repro script from logs
195+
repro_script = None
196+
for record in log_capture.records:
197+
if "# === HELION KERNEL REPRO ===" in record.message:
198+
repro_script = record.message
199+
break
200+
201+
# Verify that a repro script was printed when compilation failed
202+
self.assertIsNotNone(
203+
repro_script,
204+
"Expected repro script to be printed when device IR lowering fails",
205+
)
206+
self.assertIn("# === HELION KERNEL REPRO ===", repro_script)
207+
self.assertIn("# === END HELION KERNEL REPRO ===", repro_script)
208+
self.assertIn("kernel_with_compile_error", repro_script)
209+
self.assertIn("helion_repro_caller()", repro_script)
210+
211+
def test_print_repro_on_triton_codegen_error(self):
212+
"""Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails."""
213+
with self._with_print_repro_enabled():
214+
215+
@helion.kernel(config=helion.Config(block_sizes=[32], num_warps=4))
216+
def kernel_with_triton_error(x: torch.Tensor) -> torch.Tensor:
217+
out = torch.empty_like(x)
218+
n = x.shape[0]
219+
for tile_n in hl.tile([n]):
220+
out[tile_n] = x[tile_n] + 1
221+
return out
222+
223+
torch.manual_seed(0)
224+
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
225+
226+
# Mock PyCodeCache.load to simulate a Triton codegen error
227+
from torch._inductor.codecache import PyCodeCache
228+
229+
original_load = PyCodeCache.load
230+
231+
def mock_load(code, *args, **kwargs):
232+
if "kernel_with_triton_error" in code:
233+
raise RuntimeError("Simulated Triton codegen error")
234+
return original_load(code, *args, **kwargs)
235+
236+
with (
237+
self.capture_logs() as log_capture,
238+
mock.patch.object(PyCodeCache, "load", mock_load),
239+
):
240+
# This should trigger a Triton codegen error
241+
with self.assertRaises(RuntimeError):
242+
kernel_with_triton_error(x)
243+
244+
# Extract repro script from logs
245+
repro_script = None
246+
for record in log_capture.records:
247+
if "# === HELION KERNEL REPRO ===" in record.message:
248+
repro_script = record.message
249+
break
250+
251+
# Verify that a repro script was printed when Triton codegen failed
252+
self.assertIsNotNone(
253+
repro_script,
254+
"Expected repro script to be printed when Triton codegen fails",
255+
)
256+
self.assertIn("# === HELION KERNEL REPRO ===", repro_script)
257+
self.assertIn("# === END HELION KERNEL REPRO ===", repro_script)
258+
self.assertIn("kernel_with_triton_error", repro_script)
259+
self.assertIn("helion_repro_caller()", repro_script)
260+
171261

172262
if __name__ == "__main__":
173263
unittest.main()

0 commit comments

Comments
 (0)