@@ -92,9 +92,9 @@ def cumsum_helper(x: torch.Tensor) -> torch.Tensor:
9292 return hl .associative_scan (add_combine_fn , x , dim = 0 )
9393
9494
95- @helion .jit
95+ @helion .kernel
9696def jit_add_combine_fn (x , y ):
97- """Addition combine function with @helion.jit decorator (should be ignored)."""
97+ """Addition combine function with @helion.kernel decorator (should be ignored)."""
9898 return x + y
9999
100100
@@ -496,10 +496,10 @@ def test_codegen_kernel(x: torch.Tensor) -> torch.Tensor:
496496 self .assertNotIn ("placeholder" , code )
497497
498498 @skipIfRefEager (
499- "torch._higher_order_ops.associative_scan with nested @helion.jit is not supported by ref eager mode yet"
499+ "torch._higher_order_ops.associative_scan with nested @helion.kernel is not supported by ref eager mode yet"
500500 )
501501 def test_associative_scan_jit_decorator_ignored (self ):
502- """Test that @helion.jit decorator on combine functions is ignored."""
502+ """Test that @helion.kernel decorator on combine functions is ignored."""
503503
504504 @helion .kernel (autotune_effort = "none" )
505505 def test_jit_kernel (x : torch .Tensor ) -> torch .Tensor :
@@ -521,8 +521,8 @@ def test_jit_kernel(x: torch.Tensor) -> torch.Tensor:
521521 self .assertIn ("def jit_add_combine_fn_" , code )
522522 self .assertIn ("tl.associative_scan" , code )
523523 self .assertIn ("param_0 + param_1" , code )
524- # Verify @helion.jit decorator doesn't appear in generated code
525- self .assertNotIn ("@helion.jit " , code )
524+ # Verify @helion.kernel decorator doesn't appear in generated code
525+ self .assertNotIn ("@helion.kernel " , code )
526526
527527 @skipIfRefEager (
528528 "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
0 commit comments