@@ -1730,3 +1730,158 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
17301730 _launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
17311731 # src[test_loops.py:N]: return out
17321732 return out
1733+
1734+ --- assertExpectedJournal(TestLoops.test_while_accumulates_tensor)
1735+ from __future__ import annotations
1736+
1737+ import torch
1738+ import triton
1739+ import triton.language as tl
1740+ from helion.runtime import default_launcher as _default_launcher
1741+
1742+ @triton.jit
1743+ def _helion_kernel(out, _BLOCK_SIZE_0: tl.constexpr):
1744+ # src[test_loops.py:N]: for tile in hl.tile(x.shape):
1745+ pid_0 = tl.program_id(0)
1746+ offset_0 = pid_0 * _BLOCK_SIZE_0
1747+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1748+ # src[test_loops.py:N]: acc = torch.zeros_like(x[tile])
1749+ acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
1750+ # src[test_loops.py:N]: steps = torch.zeros([], device=x.device, dtype=torch.int32)
1751+ steps = tl.full([], 0, tl.int32)
1752+ # src[test_loops.py:N]: while steps < 4:
1753+ # src[test_loops.py:N]: acc = acc + 1
1754+ # src[test_loops.py:N]: steps = steps + 1
1755+ steps_copy = steps
1756+ steps_copy_0 = steps_copy
1757+ # src[test_loops.py:N]: while steps < 4:
1758+ v_0 = tl.full([], 4, tl.int32)
1759+ v_1 = steps_copy_0 < v_0
1760+ # src[test_loops.py:N]: while steps < 4:
1761+ # src[test_loops.py:N]: acc = acc + 1
1762+ # src[test_loops.py:N]: steps = steps + 1
1763+ while_cond = v_1
1764+ while while_cond:
1765+ steps_copy_1 = steps
1766+ acc_copy = acc
1767+ steps_copy_1_0 = steps_copy_1
1768+ acc_copy_0 = acc_copy
1769+ # src[test_loops.py:N]: acc = acc + 1
1770+ v_2 = 1.0
1771+ acc = acc_copy_0 + v_2
1772+ # src[test_loops.py:N]: steps = steps + 1
1773+ v_4 = tl.full([], 1, tl.int32)
1774+ steps = steps_copy_1_0 + v_4
1775+ # src[test_loops.py:N]: while steps < 4:
1776+ # src[test_loops.py:N]: acc = acc + 1
1777+ # src[test_loops.py:N]: steps = steps + 1
1778+ steps_copy_2 = steps
1779+ steps_copy_2_0 = steps_copy_2
1780+ # src[test_loops.py:N]: while steps < 4:
1781+ v_6 = tl.full([], 4, tl.int32)
1782+ v_7 = steps_copy_2_0 < v_6
1783+ # src[test_loops.py:N]: while steps < 4:
1784+ # src[test_loops.py:N]: acc = acc + 1
1785+ # src[test_loops.py:N]: steps = steps + 1
1786+ while_cond = v_7
1787+ # src[test_loops.py:N]: out[tile] = acc
1788+ tl.store(out + indices_0 * 1, acc, None)
1789+
1790+ def kernel(x: torch.Tensor, *, _launcher=_default_launcher):
1791+ # src[test_loops.py:N]: out = torch.empty_like(x)
1792+ out = torch.empty_like(x)
1793+ # src[test_loops.py:N]: for tile in hl.tile(x.shape):
1794+ _BLOCK_SIZE_0 = 16
1795+ # src[test_loops.py:N]: for tile in hl.tile(x.shape):
1796+ # src[test_loops.py:N]: acc = torch.zeros_like(x[tile])
1797+ # src[test_loops.py:N]: steps = torch.zeros([], device=x.device, dtype=torch.int32)
1798+ # src[test_loops.py:N-N]: ...
1799+ _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
1800+ # src[test_loops.py:N]: return out
1801+ return out
1802+
1803+ --- assertExpectedJournal(TestLoops.test_while_atomic_add_accumulates)
1804+ from __future__ import annotations
1805+
1806+ import torch
1807+ import triton
1808+ import triton.language as tl
1809+ from helion.runtime import default_launcher as _default_launcher
1810+
1811+ @triton.jit
1812+ def _helion_kernel(counters, values, totals):
1813+ # src[test_loops.py:N]: for idx in hl.tile(values.size(0)):
1814+ pid_0 = tl.program_id(0)
1815+ offset_0 = pid_0
1816+ indices_0 = offset_0 + tl.zeros([1], tl.int32)
1817+ # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
1818+ atomic_add = tl.atomic_add(counters + indices_0 * 1, 1, mask=None, sem='relaxed')
1819+ sum_1 = tl.cast(tl.sum(atomic_add, 0), tl.float32)
1820+ v_0 = 1.0
1821+ v_1 = sum_1 < v_0
1822+ # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
1823+ # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
1824+ while_cond = v_1
1825+ while while_cond:
1826+ # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
1827+ load = tl.load(values + indices_0 * 1, None)
1828+ tl.atomic_add(totals + indices_0 * 1, load, mask=None, sem='relaxed')
1829+ # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
1830+ atomic_add_1 = tl.atomic_add(counters + indices_0 * 1, 1, mask=None, sem='relaxed')
1831+ sum_2 = tl.cast(tl.sum(atomic_add_1, 0), tl.float32)
1832+ v_2 = 1.0
1833+ v_3 = sum_2 < v_2
1834+ # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
1835+ # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
1836+ while_cond = v_3
1837+
1838+ def kernel(values: torch.Tensor, totals: torch.Tensor, counters: torch.Tensor, *, _launcher=_default_launcher):
1839+ # src[test_loops.py:N]: for idx in hl.tile(values.size(0)):
1840+ # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
1841+ # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
1842+ _launcher(_helion_kernel, (8,), counters, values, totals, num_warps=4, num_stages=1)
1843+ # src[test_loops.py:N]: return totals
1844+ return totals
1845+
1846+ --- assertExpectedJournal(TestLoops.test_while_atomic_cas_pass)
1847+ from __future__ import annotations
1848+
1849+ import torch
1850+ import triton
1851+ import triton.language as tl
1852+ from helion.runtime import default_launcher as _default_launcher
1853+
1854+ @triton.jit
1855+ def _helion_kernel(grad_x_lock, _BLOCK_SIZE_0: tl.constexpr):
1856+ # src[test_loops.py:N]: for idx in hl.tile(grad_x_lock.size(0)):
1857+ pid_0 = tl.program_id(0)
1858+ offset_0 = pid_0 * _BLOCK_SIZE_0
1859+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1860+ # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
1861+ atomic_cas = tl.atomic_cas(grad_x_lock + indices_0 * 1, 0, 1, sem='relaxed')
1862+ v_0 = tl.full([], 1, tl.int32)
1863+ v_1 = atomic_cas == v_0
1864+ # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
1865+ # src[test_loops.py:N]: pass
1866+ while_cond = v_1
1867+ while while_cond:
1868+ # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
1869+ atomic_cas_1 = tl.atomic_cas(grad_x_lock + indices_0 * 1, 0, 1, sem='relaxed')
1870+ v_2 = tl.full([], 1, tl.int32)
1871+ v_3 = atomic_cas_1 == v_2
1872+ # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
1873+ # src[test_loops.py:N]: pass
1874+ while_cond = v_3
1875+ # src[test_loops.py:N]: hl.atomic_cas(grad_x_lock, [idx], 1, 0)
1876+ tl.atomic_cas(grad_x_lock + indices_0 * 1, 1, 0, sem='relaxed')
1877+
1878+ def kernel(grad_x_lock: torch.Tensor, *, _launcher=_default_launcher):
1879+ # src[test_loops.py:N]: for idx in hl.tile(grad_x_lock.size(0)):
1880+ _BLOCK_SIZE_0 = 16
1881+ # src[test_loops.py:N]: for idx in hl.tile(grad_x_lock.size(0)):
1882+ # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
1883+ # src[test_loops.py:N]: pass
1884+ # src[test_loops.py:N-N]: ...
1885+ _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), grad_x_lock, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
1886+ # src[test_loops.py:N]: return grad_x_lock
1887+ return grad_x_lock
0 commit comments