diff --git a/examples/all_reduce.py b/examples/all_reduce.py
index 57fd8ed81..e4399440c 100644
--- a/examples/all_reduce.py
+++ b/examples/all_reduce.py
@@ -23,6 +23,7 @@
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl

 # %%
@@ -212,6 +213,27 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     )


+def reference_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
+    """
+    Reference implementation using the symmetric memory one-shot primitive.
+    """
+    dist_group = dist.group.WORLD
+    if dist_group is None:
+        raise RuntimeError("No distributed group available")
+
+    a_shared_clone = symm_mem.empty(
+        a_shared.shape,
+        dtype=a_shared.dtype,
+        device=a_shared.device,
+    )
+    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
+    a_shared_clone.copy_(a_shared)
+
+    return torch.ops.symm_mem.one_shot_all_reduce(  # pyright: ignore[reportCallIssue]
+        a_shared_clone, "sum", dist_group.group_name
+    )
+
+
 # %%
 # Testing Function
 # ----------------
@@ -232,21 +254,13 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     world_size = dist.get_world_size()

     a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
-    a_shared_clone = symm_mem.empty(
-        a_shared.shape,
-        dtype=a_shared.dtype,
-        device=a_shared.device,
+    run_example(
+        helion_one_shot_all_reduce,
+        reference_one_shot_all_reduce,
+        (a_shared,),
+        rtol=1e-1,
+        atol=1e-1,
     )
-    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
-    a_shared_clone.copy_(a_shared)
-
-    a_out = helion_one_shot_all_reduce(a_shared)
-
-    gloden_o = torch.ops.symm_mem.one_shot_all_reduce(
-        a_shared_clone, "sum", dist_group.group_name
-    )
-
-    torch.testing.assert_close(a_out, gloden_o, rtol=1e-1, atol=1e-1)


 def main() -> None:
diff --git a/examples/bf16xint16_gemm.py b/examples/bf16xint16_gemm.py
index 729b537d2..c7571aa70 100644
--- a/examples/bf16xint16_gemm.py
+++ b/examples/bf16xint16_gemm.py
@@ -15,6 +15,7 @@
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl

@@ -140,19 +141,25 @@ def check(m: int, k: int, n: int) -> None:
     """
     x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
     w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)
-
-    result = bf16xint16_gemm(x, w, transpose=False)
-    expected = reference_bf16xint16_pytorch(x, w, transpose=False)
-    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+    run_example(
+        bf16xint16_gemm,
+        reference_bf16xint16_pytorch,
+        (x, w, False),
+        rtol=1e-2,
+        atol=1e-2,
+    )

     x_int16 = torch.randint(
         -(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
     )
     w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)
-
-    result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
-    expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
-    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+    run_example(
+        bf16xint16_gemm,
+        reference_bf16xint16_pytorch,
+        (x_int16, w_bf16, True),
+        rtol=1e-2,
+        atol=1e-2,
+    )


 # %%
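
Note on the pattern used throughout this patch: judging from the call sites above, run_example takes the kernel under test, a reference callable, a tuple of positional arguments, and rtol/atol tolerances. A minimal sketch of the hand-rolled comparison it replaces (manual_check is a hypothetical name, not part of the change):

    import torch

    def manual_check(kernel_fn, ref_fn, args, rtol, atol):
        # What each removed block did by hand; run_example is assumed to wrap
        # an equivalent comparison (plus its own reporting) internally.
        result = kernel_fn(*args)
        expected = ref_fn(*args)
        torch.testing.assert_close(result, expected, rtol=rtol, atol=atol)
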
diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py
index 35f60e1c3..0780c67cc 100644
--- a/examples/grouped_gemm.py
+++ b/examples/grouped_gemm.py
@@ -40,6 +40,7 @@
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl

 # %%
@@ -310,6 +311,26 @@ def _reference_grouped_gemm(
     return torch.cat(outs, dim=0)


+def grouped_gemm_jagged_example(
+    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
+) -> torch.Tensor:
+    """
+    Wrapper to run grouped_gemm_jagged with unpacked TritonBench inputs.
+    """
+    A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
+    return grouped_gemm_jagged(A_packed, B_shared, group_offsets)
+
+
+def grouped_gemm_jagged_persistent_example(
+    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
+) -> torch.Tensor:
+    """
+    Wrapper to run grouped_gemm_jagged_persistent with unpacked TritonBench inputs.
+    """
+    A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
+    return grouped_gemm_jagged_persistent(A_packed, B_shared, group_offsets)
+
+
 # %%
 # Test Harness and Validation
 # ---------------------------
@@ -330,18 +351,23 @@ def main() -> None:
     # Shared weight matrix B replicated for each group (as per TritonBench convention)
     group_B = [torch.randn(K, N, device=device, dtype=dtype).contiguous()] * G

-    ref = _reference_grouped_gemm(group_A, group_B)
-
     print("Testing grouped GEMM kernels...")
-
-    # Test basic jagged kernel correctness
-    out = grouped_gemm_jagged_tritonbench(None, group_A, group_B)()
-    torch.testing.assert_close(out.float(), ref.float(), atol=1e-2, rtol=1e-2)
+    run_example(
+        grouped_gemm_jagged_example,
+        _reference_grouped_gemm,
+        (group_A, group_B),
+        rtol=1e-2,
+        atol=1e-2,
+    )
     print("✓ Non-persistent kernel passed")

-    # Test persistent kernel with dynamic tiling
-    out_p = grouped_gemm_jagged_persistent_tritonbench(None, group_A, group_B)()
-    torch.testing.assert_close(out_p.float(), ref.float(), atol=1e-2, rtol=1e-2)
+    run_example(
+        grouped_gemm_jagged_persistent_example,
+        _reference_grouped_gemm,
+        (group_A, group_B),
+        rtol=1e-2,
+        atol=1e-2,
+    )
     print("✓ Persistent kernel passed")

     print("\nAll tests passed!")
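
The two *_example wrappers above exist so that run_example can call the kernels with the same (group_A, group_B) signature as _reference_grouped_gemm. A rough standalone invocation mirroring main(); the shapes, dtype, device, and import path here are illustrative assumptions, not taken from the change:

    import torch
    from examples.grouped_gemm import grouped_gemm_jagged_example

    G, K, N = 4, 256, 128
    # Jagged group sizes: each group has a different number of rows
    group_A = [torch.randn(64 * (i + 1), K, device="cuda", dtype=torch.bfloat16) for i in range(G)]
    # Shared weight matrix replicated per group, as in main()
    group_B = [torch.randn(K, N, device="cuda", dtype=torch.bfloat16).contiguous()] * G
    out = grouped_gemm_jagged_example(group_A, group_B)  # group outputs concatenated along dim 0
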
diff --git a/examples/int4_gemm.py b/examples/int4_gemm.py
index c2d6bb4c5..6c1657b25 100644
--- a/examples/int4_gemm.py
+++ b/examples/int4_gemm.py
@@ -21,6 +21,7 @@
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl

 # %%
@@ -130,37 +131,72 @@ def run_kernel() -> torch.Tensor:


 # %%
-def check(m: int, k: int, n: int) -> None:
+def _pack_int4_matrix(unpacked: torch.Tensor) -> torch.Tensor:
     """
-    Test the INT4 GEMM implementation.
+    Pack int4 matrix into int8 container with two values per byte.

     Args:
-        m (int): Number of rows in the left input matrix.
-        k (int): Shared dimension (must be even).
-        n (int): Number of columns in the right input matrix.
+        unpacked (torch.Tensor): Tensor of shape [K, N] with values in [-8, 7].
+
+    Returns:
+        torch.Tensor: Packed tensor of shape [K//2, N] in int8 format.
     """
-    # Create test matrices
-    A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
+    k, n = unpacked.shape
+    assert k % 2 == 0, "K dimension must be even for int4 packing"
+    reshaped = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
+    return ((reshaped[0] & 0xF) | (reshaped[1] << 4)).to(torch.int8)

-    # Create packed int4 matrix B (K//2 x N)
-    # Generate random int4 values in range [-8, 7] and pack them
-    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
-    # Pack using the same format as tritonbench
-    B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
-    B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
+def _unpack_int4_matrix(packed: torch.Tensor) -> torch.Tensor:
+    """
+    Unpack an int4 matrix stored as two 4-bit values per int8 byte.
+
+    Args:
+        packed (torch.Tensor): Packed tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        torch.Tensor: Unpacked tensor of shape [K, N] in int8 format.
+    """
+    b_lo = ((packed << 4) >> 4).to(torch.int8)
+    b_hi = (packed >> 4).to(torch.int8)
+    stacked = torch.stack([b_lo, b_hi], dim=1)
+    return stacked.reshape(packed.shape[0] * 2, packed.shape[1])
+

-    # Convert unpacked values to bfloat16 for reference
-    B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
+def reference_matmul_bf16_int4(A: Tensor, B_packed: Tensor) -> Tensor:
+    """
+    Reference implementation that unpacks the int4 weights and performs matmul.
+
+    Args:
+        A (Tensor): Input tensor in bfloat16 format.
+        B_packed (Tensor): Packed int4 tensor.
+
+    Returns:
+        Tensor: Output tensor in bfloat16 format.
+    """
+    B_unpacked = _unpack_int4_matrix(B_packed).to(torch.bfloat16)
+    return torch.matmul(A, B_unpacked)

-    # Compute reference result
-    expected = torch.matmul(A, B_unpacked_bf16)

-    # Run the kernel
-    result = matmul_bf16_int4(A, B_packed)
+def check(m: int, k: int, n: int) -> None:
+    """
+    Test the INT4 GEMM implementation using the run_example utility.

-    # Check accuracy with appropriate tolerance
-    torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0)
+    Args:
+        m (int): Number of rows in the left input matrix.
+        k (int): Shared dimension (must be even).
+        n (int): Number of columns in the right input matrix.
+    """
+    A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
+    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
+    B_packed = _pack_int4_matrix(B_unpacked)
+    run_example(
+        matmul_bf16_int4,
+        reference_matmul_bf16_int4,
+        (A, B_packed),
+        rtol=2e-1,
+        atol=1.0,
+    )

     print(f"Test passed for shapes: M={m}, K={k}, N={n}")
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index b81692ea3..bed4b324a 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -409,7 +409,7 @@ def parallel_benchmark(
         iterator = iter_with_progress(
             zip(configs, fns, is_workings, strict=True),
             total=len(configs),
-            description=f"{desc}: exploring neighbors",
+            description=f"{desc} exploring neighbors",
             enabled=self.settings.autotune_progress_bar,
         )
         for config, fn, is_working in iterator:
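
One property worth calling out in the int4_gemm.py change: for values in [-8, 7], _unpack_int4_matrix inverts _pack_int4_matrix exactly, because (x << 4) >> 4 sign-extends the low nibble and x >> 4 sign-extends the high nibble via arithmetic shifts. A small CPU-only sanity sketch, assuming the two helpers are in scope (e.g. run from within examples/int4_gemm.py):

    import torch

    B = torch.randint(-8, 8, (6, 5), dtype=torch.int8)   # K=6 (even), N=5
    packed = _pack_int4_matrix(B)                         # [3, 5]; row 2r -> low nibble, row 2r+1 -> high nibble
    assert packed.shape == (6 // 2, 5)
    assert torch.equal(_unpack_int4_matrix(packed), B)    # exact round trip, including negative values
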