Skip to content

Commit 605e152

Browse files
authored
Fix CI to surface errors correctly, fix all existing errors (#1138)
1 parent a30ce01 commit 605e152

File tree

4 files changed

+89
-10
lines changed

4 files changed

+89
-10
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ jobs:
151151
152152
- name: Run Tests
153153
run: |
154+
set -o pipefail
154155
source .venv/bin/activate
155156
# Conditionally enable ref-eager and golden-accept/dtype-assert test modes
156157
if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi

test/test_breakpoint.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import builtins
44
from contextlib import contextmanager
55
import os
6+
import subprocess
67
import sys
8+
import textwrap
79
from typing import TYPE_CHECKING
810
import unittest
911
from unittest import mock
1012

1113
import torch
14+
from torch._environment import is_fbcode
1215

1316
import helion
1417
from helion import exc
@@ -64,6 +67,42 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
6467

6568
return kernel
6669

70+
def _run_breakpoint_in_subprocess(
71+
self,
72+
*,
73+
test_name: str,
74+
runner_method: str,
75+
triton_interpret: int,
76+
helion_interpret: int,
77+
) -> None:
78+
"""Run a breakpoint test in a subprocess to isolate interpreter state."""
79+
script = textwrap.dedent(
80+
f"""
81+
from test.test_breakpoint import TestBreakpoint
82+
83+
case = TestBreakpoint({test_name!r})
84+
case.setUp()
85+
try:
86+
getattr(case, {runner_method!r})(triton_interpret={triton_interpret}, helion_interpret={helion_interpret})
87+
finally:
88+
case.tearDown()
89+
"""
90+
)
91+
92+
env = os.environ.copy()
93+
result = subprocess.run(
94+
[sys.executable, "-c", script],
95+
env=env,
96+
capture_output=True,
97+
)
98+
if result.returncode != 0:
99+
raise AssertionError(
100+
f"{test_name} subprocess failed",
101+
result.returncode,
102+
result.stdout.decode(),
103+
result.stderr.decode(),
104+
)
105+
67106
def _run_device_breakpoint_test(
68107
self, triton_interpret: int, helion_interpret: int
69108
) -> None:
@@ -90,14 +129,32 @@ def _run_device_breakpoint_test(
90129
out = bound(x)
91130
torch.testing.assert_close(out, x)
92131

132+
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
93133
def test_device_breakpoint_no_interpret(self) -> None:
94-
self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=0)
95-
134+
self._run_breakpoint_in_subprocess(
135+
test_name=self._testMethodName,
136+
runner_method="_run_device_breakpoint_test",
137+
triton_interpret=0,
138+
helion_interpret=0,
139+
)
140+
141+
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
96142
def test_device_breakpoint_triton_interpret(self) -> None:
97-
self._run_device_breakpoint_test(triton_interpret=1, helion_interpret=0)
98-
143+
self._run_breakpoint_in_subprocess(
144+
test_name=self._testMethodName,
145+
runner_method="_run_device_breakpoint_test",
146+
triton_interpret=1,
147+
helion_interpret=0,
148+
)
149+
150+
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
99151
def test_device_breakpoint_helion_interpret(self) -> None:
100-
self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=1)
152+
self._run_breakpoint_in_subprocess(
153+
test_name=self._testMethodName,
154+
runner_method="_run_device_breakpoint_test",
155+
triton_interpret=0,
156+
helion_interpret=1,
157+
)
101158

102159
def _run_host_breakpoint_test(
103160
self, triton_interpret: int, helion_interpret: int
@@ -116,14 +173,32 @@ def _run_host_breakpoint_test(
116173
out = bound(x)
117174
torch.testing.assert_close(out, x)
118175

176+
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
119177
def test_host_breakpoint_no_interpret(self) -> None:
120-
self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=0)
121-
178+
self._run_breakpoint_in_subprocess(
179+
test_name=self._testMethodName,
180+
runner_method="_run_host_breakpoint_test",
181+
triton_interpret=0,
182+
helion_interpret=0,
183+
)
184+
185+
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
122186
def test_host_breakpoint_triton_interpret(self) -> None:
123-
self._run_host_breakpoint_test(triton_interpret=1, helion_interpret=0)
124-
187+
self._run_breakpoint_in_subprocess(
188+
test_name=self._testMethodName,
189+
runner_method="_run_host_breakpoint_test",
190+
triton_interpret=1,
191+
helion_interpret=0,
192+
)
193+
194+
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
125195
def test_host_breakpoint_helion_interpret(self) -> None:
126-
self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=1)
196+
self._run_breakpoint_in_subprocess(
197+
test_name=self._testMethodName,
198+
runner_method="_run_host_breakpoint_test",
199+
triton_interpret=0,
200+
helion_interpret=1,
201+
)
127202

128203

129204
if __name__ == "__main__":

test/test_indexing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ def run_case(
505505
expect_error=None,
506506
)
507507

508+
@skipIfRefEager("specialization_key is not used in ref eager mode")
508509
def test_dynamic_shape_specialization_key_tracks_large_tensors(self) -> None:
509510
@helion.kernel(static_shapes=False)
510511
def passthrough(x: torch.Tensor) -> torch.Tensor:

test/test_unroll_tuples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from helion._testing import RefEagerTestBase
1111
from helion._testing import TestCase
1212
from helion._testing import code_and_output
13+
from helion._testing import skipIfRefEager
1314
import helion.language as hl
1415

1516

@@ -520,6 +521,7 @@ def kernel_static_range_tuple_indexing(
520521
expected = sum(tensors)
521522
torch.testing.assert_close(result, expected)
522523

524+
@skipIfRefEager("Type inference errors are not raised in ref eager mode")
523525
def test_static_range_tuple_indexing_requires_uniform_types(self):
524526
@helion.kernel(autotune_effort="none")
525527
def kernel_static_range_tuple_mismatch(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)