@@ -168,6 +168,96 @@ def mock_do_bench(*args, **kwargs):
168168 self .assertIn ("kernel" , captured )
169169 self .assertIn ("helion_repro_caller()" , captured )
170170
171+ def test_print_repro_on_device_ir_lowering_error (self ):
172+ """Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering."""
173+ with self ._with_print_repro_enabled ():
174+
175+ @helion .kernel (config = helion .Config (block_sizes = [32 ], num_warps = 4 ))
176+ def kernel_with_compile_error (x : torch .Tensor ) -> torch .Tensor :
177+ out = torch .empty_like (x )
178+ n = x .shape [0 ]
179+ for tile_n in hl .tile ([n ]):
180+ # Using torch.nonzero inside device loop causes compilation error
181+ # because it produces data-dependent output shape
182+ torch .nonzero (x [tile_n ])
183+ out [tile_n ] = x [tile_n ]
184+ return out
185+
186+ torch .manual_seed (0 )
187+ x = torch .randn ([128 ], dtype = torch .float32 , device = DEVICE )
188+
189+ with self .capture_logs () as log_capture :
190+ # This should trigger a compilation error during device IR lowering
191+ with self .assertRaises (RuntimeError ):
192+ kernel_with_compile_error (x )
193+
194+ # Extract repro script from logs
195+ repro_script = None
196+ for record in log_capture .records :
197+ if "# === HELION KERNEL REPRO ===" in record .message :
198+ repro_script = record .message
199+ break
200+
201+ # Verify that a repro script was printed when compilation failed
202+ self .assertIsNotNone (
203+ repro_script ,
204+ "Expected repro script to be printed when device IR lowering fails" ,
205+ )
206+ self .assertIn ("# === HELION KERNEL REPRO ===" , repro_script )
207+ self .assertIn ("# === END HELION KERNEL REPRO ===" , repro_script )
208+ self .assertIn ("kernel_with_compile_error" , repro_script )
209+ self .assertIn ("helion_repro_caller()" , repro_script )
210+
211+ def test_print_repro_on_triton_codegen_error (self ):
212+ """Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails."""
213+ with self ._with_print_repro_enabled ():
214+
215+ @helion .kernel (config = helion .Config (block_sizes = [32 ], num_warps = 4 ))
216+ def kernel_with_triton_error (x : torch .Tensor ) -> torch .Tensor :
217+ out = torch .empty_like (x )
218+ n = x .shape [0 ]
219+ for tile_n in hl .tile ([n ]):
220+ out [tile_n ] = x [tile_n ] + 1
221+ return out
222+
223+ torch .manual_seed (0 )
224+ x = torch .randn ([128 ], dtype = torch .float32 , device = DEVICE )
225+
226+ # Mock PyCodeCache.load to simulate a Triton codegen error
227+ from torch ._inductor .codecache import PyCodeCache
228+
229+ original_load = PyCodeCache .load
230+
231+ def mock_load (code , * args , ** kwargs ):
232+ if "kernel_with_triton_error" in code :
233+ raise RuntimeError ("Simulated Triton codegen error" )
234+ return original_load (code , * args , ** kwargs )
235+
236+ with (
237+ self .capture_logs () as log_capture ,
238+ mock .patch .object (PyCodeCache , "load" , mock_load ),
239+ ):
240+ # This should trigger a Triton codegen error
241+ with self .assertRaises (RuntimeError ):
242+ kernel_with_triton_error (x )
243+
244+ # Extract repro script from logs
245+ repro_script = None
246+ for record in log_capture .records :
247+ if "# === HELION KERNEL REPRO ===" in record .message :
248+ repro_script = record .message
249+ break
250+
251+ # Verify that a repro script was printed when Triton codegen failed
252+ self .assertIsNotNone (
253+ repro_script ,
254+ "Expected repro script to be printed when Triton codegen fails" ,
255+ )
256+ self .assertIn ("# === HELION KERNEL REPRO ===" , repro_script )
257+ self .assertIn ("# === END HELION KERNEL REPRO ===" , repro_script )
258+ self .assertIn ("kernel_with_triton_error" , repro_script )
259+ self .assertIn ("helion_repro_caller()" , repro_script )
260+
171261
172262if __name__ == "__main__" :
173263 unittest .main ()
0 commit comments