33import builtins
44from contextlib import contextmanager
55import os
6+ import subprocess
67import sys
8+ import textwrap
79from typing import TYPE_CHECKING
810import unittest
911from unittest import mock
1012
1113import torch
14+ from torch ._environment import is_fbcode
1215
1316import helion
1417from helion import exc
@@ -64,6 +67,42 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
6467
6568 return kernel
6669
70+ def _run_breakpoint_in_subprocess (
71+ self ,
72+ * ,
73+ test_name : str ,
74+ runner_method : str ,
75+ triton_interpret : int ,
76+ helion_interpret : int ,
77+ ) -> None :
78+ """Run a breakpoint test in a subprocess to isolate interpreter state."""
79+ script = textwrap .dedent (
80+ f"""
81+ from test.test_breakpoint import TestBreakpoint
82+
83+ case = TestBreakpoint({ test_name !r} )
84+ case.setUp()
85+ try:
86+ getattr(case, { runner_method !r} )(triton_interpret={ triton_interpret } , helion_interpret={ helion_interpret } )
87+ finally:
88+ case.tearDown()
89+ """
90+ )
91+
92+ env = os .environ .copy ()
93+ result = subprocess .run (
94+ [sys .executable , "-c" , script ],
95+ env = env ,
96+ capture_output = True ,
97+ )
98+ if result .returncode != 0 :
99+ raise AssertionError (
100+ f"{ test_name } subprocess failed" ,
101+ result .returncode ,
102+ result .stdout .decode (),
103+ result .stderr .decode (),
104+ )
105+
67106 def _run_device_breakpoint_test (
68107 self , triton_interpret : int , helion_interpret : int
69108 ) -> None :
@@ -90,14 +129,32 @@ def _run_device_breakpoint_test(
90129 out = bound (x )
91130 torch .testing .assert_close (out , x )
92131
132+ @unittest .skipIf (is_fbcode (), "subprocess test doesn't work in internal CI" )
93133 def test_device_breakpoint_no_interpret (self ) -> None :
94- self ._run_device_breakpoint_test (triton_interpret = 0 , helion_interpret = 0 )
95-
134+ self ._run_breakpoint_in_subprocess (
135+ test_name = self ._testMethodName ,
136+ runner_method = "_run_device_breakpoint_test" ,
137+ triton_interpret = 0 ,
138+ helion_interpret = 0 ,
139+ )
140+
141+ @unittest .skipIf (is_fbcode (), "subprocess test doesn't work in internal CI" )
96142 def test_device_breakpoint_triton_interpret (self ) -> None :
97- self ._run_device_breakpoint_test (triton_interpret = 1 , helion_interpret = 0 )
98-
143+ self ._run_breakpoint_in_subprocess (
144+ test_name = self ._testMethodName ,
145+ runner_method = "_run_device_breakpoint_test" ,
146+ triton_interpret = 1 ,
147+ helion_interpret = 0 ,
148+ )
149+
150+ @unittest .skipIf (is_fbcode (), "subprocess test doesn't work in internal CI" )
99151 def test_device_breakpoint_helion_interpret (self ) -> None :
100- self ._run_device_breakpoint_test (triton_interpret = 0 , helion_interpret = 1 )
152+ self ._run_breakpoint_in_subprocess (
153+ test_name = self ._testMethodName ,
154+ runner_method = "_run_device_breakpoint_test" ,
155+ triton_interpret = 0 ,
156+ helion_interpret = 1 ,
157+ )
101158
102159 def _run_host_breakpoint_test (
103160 self , triton_interpret : int , helion_interpret : int
@@ -116,14 +173,32 @@ def _run_host_breakpoint_test(
116173 out = bound (x )
117174 torch .testing .assert_close (out , x )
118175
176+ @unittest .skipIf (is_fbcode (), "subprocess test doesn't work in internal CI" )
119177 def test_host_breakpoint_no_interpret (self ) -> None :
120- self ._run_host_breakpoint_test (triton_interpret = 0 , helion_interpret = 0 )
121-
178+ self ._run_breakpoint_in_subprocess (
179+ test_name = self ._testMethodName ,
180+ runner_method = "_run_host_breakpoint_test" ,
181+ triton_interpret = 0 ,
182+ helion_interpret = 0 ,
183+ )
184+
185+ @unittest .skipIf (is_fbcode (), "subprocess test doesn't work in internal CI" )
122186 def test_host_breakpoint_triton_interpret (self ) -> None :
123- self ._run_host_breakpoint_test (triton_interpret = 1 , helion_interpret = 0 )
124-
187+ self ._run_breakpoint_in_subprocess (
188+ test_name = self ._testMethodName ,
189+ runner_method = "_run_host_breakpoint_test" ,
190+ triton_interpret = 1 ,
191+ helion_interpret = 0 ,
192+ )
193+
194+ @unittest .skipIf (is_fbcode (), "subprocess test doesn't work in internal CI" )
125195 def test_host_breakpoint_helion_interpret (self ) -> None :
126- self ._run_host_breakpoint_test (triton_interpret = 0 , helion_interpret = 1 )
196+ self ._run_breakpoint_in_subprocess (
197+ test_name = self ._testMethodName ,
198+ runner_method = "_run_host_breakpoint_test" ,
199+ triton_interpret = 0 ,
200+ helion_interpret = 1 ,
201+ )
127202
128203
129204if __name__ == "__main__" :
0 commit comments