From 3939aea20bf038b015a69ae0203d62eb71a1672c Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Mon, 20 Oct 2025 10:50:30 -0700 Subject: [PATCH] Add simplified se_block kernel (#989) Summary: We add a helion kernel to compute 2 * x * sigmoid(x @ w) Differential Revision: D84968671 --- examples/se_block.py | 213 ++++++++++++++++++++++++++++++++++++ test/test_examples.expected | 191 ++++++++++++++++++++++++++++++++ test/test_examples.py | 97 ++++++++++++++++ 3 files changed, 501 insertions(+) create mode 100644 examples/se_block.py diff --git a/examples/se_block.py b/examples/se_block.py new file mode 100644 index 000000000..a3046482f --- /dev/null +++ b/examples/se_block.py @@ -0,0 +1,213 @@ +""" +Helion SE Block Example +============================ +This example demonstrates a Helion kernel implementation of SE Block. +""" + +# %% +from __future__ import annotations + +import torch +from torch import Tensor + +import helion +from helion._testing import DEVICE +from helion._testing import run_example +import helion.language as hl + + +# %% +@helion.kernel( + # static_shapes=True gives a performance boost for matmuls + static_shapes=True, +) +def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]: + """ + Performs 2 * x * sigmoid(x @ w) + Args: + x: 2D tensor of shape [m, n]. + w: 2D tensor of shape [n, n]. + Returns: + out: Resulting matrix of shape [m, n]. + s: sigmoid(x @ w) of shape [m, n]. + """ + m, n = x.size() + + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + s = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m in hl.tile(m): + for tile_n in hl.tile(n): + # Compute sigmoid in float32 + sigmoid_result = torch.sigmoid(x[tile_m, :] @ w[:, tile_n]) + s[tile_m, tile_n] = sigmoid_result + # Compute output: 2 * x * sigmoid, cast to input dtype + acc = 2.0 * x[tile_m, tile_n].to(torch.float32) * sigmoid_result + out[tile_m, tile_n] = acc.to(x.dtype) + + return out, s + + +# %% +@helion.kernel(static_shapes=True) +def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor) -> Tensor: + """ + Compute gradient for x. + grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T + + Args: + grad_out: Gradient w.r.t output [m, n] + x: Input tensor [m, n] + w: Weight matrix [n, n] + s: sigmoid(x @ w) from forward pass [m, n] + + Returns: + grad_x: Gradient w.r.t x [m, n] + """ + m, n = x.size() + + grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + # 2 * grad_out * s + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n] + + for tile_k in hl.tile(n): + # 2 * grad_out * x * s * (1-s) for tile_k + grad_to_w = ( + 2.0 + * grad_out[tile_m, tile_k].to(torch.float32) + * x[tile_m, tile_k].to(torch.float32) + * s[tile_m, tile_k].to(torch.float32) + * (1.0 - s[tile_m, tile_k].to(torch.float32)) + ) + # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T + acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T + + grad_x[tile_m, tile_n] = acc.to(x.dtype) + + return grad_x + + +# %% +@helion.kernel(static_shapes=True) +def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor) -> Tensor: + """ + Compute gradient for w. 
+ grad_w = x.T @ (2 * grad_out * x * s * (1 - s)) + + Args: + grad_out: Gradient w.r.t output [m, n] + x: Input tensor [m, n] + s: sigmoid(x @ w) from forward pass [m, n] + + Returns: + grad_w: Gradient w.r.t w [n, n] + """ + m, n = x.size() + + grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device) + + for tile_n1, tile_n2 in hl.tile([n, n]): + acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32) + for tile_m in hl.tile(m): + # 2 * grad_out * x * s * (1-s) + grad_to_w = ( + 2.0 + * grad_out[tile_m, tile_n2].to(torch.float32) + * x[tile_m, tile_n2].to(torch.float32) + * s[tile_m, tile_n2].to(torch.float32) + * (1.0 - s[tile_m, tile_n2].to(torch.float32)) + ) + # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2] + acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w + + grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype) + + return grad_w + + +# %% +# Reference Implementation +# -------------------- +def se_block_pytorch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + """ + PyTorch reference implementation se_block. + + Args: + x, w: Input tensors + + Returns: + tensor of 2 * x * sigmoid(x @ w) + """ + return 2 * x * torch.sigmoid(x @ w) + + +# %% +# Autograd Function +# ------------------ +class SEBlockFunction(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx: object, + x: torch.Tensor, + w: torch.Tensor, + ) -> torch.Tensor: + """Forward pass for se block.""" + out, s = se_block_fwd(x, w) + ctx.save_for_backward(x, w, s) # type: ignore[attr-defined] + return out + + @staticmethod + def backward( # type: ignore[override] + ctx: object, + grad_out: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Backward pass for se block.""" + x, w, s = ctx.saved_tensors # type: ignore[attr-defined] + + grad_x = se_block_bwd_dx(grad_out, x, w, s) + grad_w = se_block_bwd_dw(grad_out, x, s) + + return grad_x, grad_w + + +def se_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + """ + SE Block with autograd support. + + Args: + x: Input tensor [m, n] + w: Weight matrix [n, n] + + Returns: + Output tensor [m, n] + """ + return SEBlockFunction.apply(x, w) # type: ignore[no-any-return] + + +def check(m: int, n: int) -> None: + """ + Checks the correctness against PyTorch. + Args: + m (int): Number of rows in matrix x. + n (int): Number of columns in matrix x. + """ + x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True) + w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True) + for bwd in [True, False]: + run_example(se_block, se_block_pytorch, (x, w), bwd=bwd) + + +# %% +def main() -> None: + """ + Main function to run correctness checks. 
+ """ + check(1024, 1024) + + +# %% +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index 6fd318400..b0237f052 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -4559,3 +4559,194 @@ def squeeze_and_excitation_net_bwd_db(grad_out: Tensor, x: Tensor, d: Tensor, c: _BLOCK_SIZE_2 = 16 _launcher(_helion_squeeze_and_excitation_net_bwd_db, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), grad_out, x, d, c, grad_b, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) return grad_b + +--- assertExpectedJournal(TestExamples.test_se_block_fwd) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_se_block_fwd(x, w, s, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + for offset_1 in tl.range(0, 128, _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + load = tl.load(x + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + load_1 = tl.load(w + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + mm = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32) + v_0 = tl.sigmoid(tl.cast(mm, tl.float32)) + tl.store(s + (indices_0[:, None] * 128 + indices_1[None, :] * 1), v_0, None) + load_2 = tl.load(x + (indices_0[:, None] * 128 + indices_1[None, :] * 1), None) + v_1 = tl.cast(load_2, tl.float32) + v_2 = 2.0 + v_3 = v_1 * v_2 + v_4 = tl.cast(v_0, tl.float32) + v_5 = v_3 * v_4 + v_6 = tl.cast(v_5, tl.bfloat16) + tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), v_6, None) + +def se_block_fwd(x: Tensor, w: Tensor, *, _launcher=_default_launcher): + """ + Performs 2 * x * sigmoid(x @ w) + Args: + x: 2D tensor of shape [m, n]. + w: 2D tensor of shape [n, n]. + Returns: + out: Resulting matrix of shape [m, n]. + s: sigmoid(x @ w) of shape [m, n]. 
+ """ + m, n = x.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + s = torch.empty([m, n], dtype=x.dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _RDIM_SIZE_2 = 128 + _BLOCK_SIZE_1 = 32 + _launcher(_helion_se_block_fwd, (triton.cdiv(128, _BLOCK_SIZE_0),), x, w, s, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return (out, s) + +--- assertExpectedJournal(TestExamples.test_se_block_bwd_dx) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_se_block_bwd_dx(grad_out, s, x, w, grad_x, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + load = tl.load(grad_out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), None) + v_0 = 2.0 + v_1 = load * v_0 + load_1 = tl.load(s + (indices_0[:, None] * 128 + indices_1[None, :] * 1), None) + v_2 = v_1 * load_1 + v_3 = tl.cast(v_2, tl.float32) + v_4 = acc + v_3 + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + v_4_copy = v_4 + v_4_copy_0 = v_4_copy + load_2 = tl.load(grad_out + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + v_5 = tl.cast(load_2, tl.float32) + v_6 = 2.0 + v_7 = v_5 * v_6 + load_3 = tl.load(x + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + v_8 = tl.cast(load_3, tl.float32) + v_9 = v_7 * v_8 + load_4 = tl.load(s + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + v_10 = tl.cast(load_4, tl.float32) + v_11 = v_9 * v_10 + load_5 = tl.load(s + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + v_12 = tl.cast(load_5, tl.float32) + v_13 = 1.0 + v_14 = v_13 - v_12 + v_15 = v_11 * v_14 + load_6 = tl.load(w + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None) + v_16 = tl.cast(load_6, tl.float32) + permute = tl.permute(v_16, [1, 0]) + mm = tl.dot(tl.cast(v_15, tl.float32), tl.cast(permute, tl.float32), input_precision='tf32', out_dtype=tl.float32) + v_4 = v_4_copy_0 + mm + v_18 = tl.cast(v_4, tl.float16) + v_19 = tl.cast(v_18, tl.float32) + tl.store(grad_x + (indices_0[:, None] * 128 + indices_1[None, :] * 1), v_19, None) + +def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor, *, _launcher=_default_launcher): + """ + Compute gradient for x. 
+ grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T + + Args: + grad_out: Gradient w.r.t output [m, n] + x: Input tensor [m, n] + w: Weight matrix [n, n] + s: sigmoid(x @ w) from forward pass [m, n] + + Returns: + grad_x: Gradient w.r.t x [m, n] + """ + m, n = x.size() + grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_helion_se_block_bwd_dx, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), grad_out, s, x, w, grad_x, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return grad_x + +--- assertExpectedJournal(TestExamples.test_se_block_bwd_dw) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_se_block_bwd_dw(grad_out, x, s, grad_w, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + acc_w = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_w_copy = acc_w + acc_w_copy_0 = acc_w_copy + load = tl.load(grad_out + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + v_0 = tl.cast(load, tl.float32) + v_1 = 2.0 + v_2 = v_0 * v_1 + load_1 = tl.load(x + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + v_3 = tl.cast(load_1, tl.float32) + v_4 = v_2 * v_3 + load_2 = tl.load(s + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + v_5 = tl.cast(load_2, tl.float32) + v_6 = v_4 * v_5 + load_3 = tl.load(s + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + v_7 = tl.cast(load_3, tl.float32) + v_8 = 1.0 + v_9 = v_8 - v_7 + v_10 = v_6 * v_9 + load_4 = tl.load(x + (indices_2[:, None] * 128 + indices_0[None, :] * 1), None) + v_11 = tl.cast(load_4, tl.float32) + permute = tl.permute(v_11, [1, 0]) + mm = tl.dot(tl.cast(permute, tl.float32), tl.cast(v_10, tl.float32), input_precision='tf32', out_dtype=tl.float32) + acc_w = acc_w_copy_0 + mm + v_13 = tl.cast(acc_w, tl.float16) + v_14 = tl.cast(v_13, tl.float32) + tl.store(grad_w + (indices_0[:, None] * 128 + indices_1[None, :] * 1), v_14, None) + +def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor, *, _launcher=_default_launcher): + """ + Compute gradient for w. 
+ grad_w = x.T @ (2 * grad_out * x * s * (1 - s)) + + Args: + grad_out: Gradient w.r.t output [m, n] + x: Input tensor [m, n] + s: sigmoid(x @ w) from forward pass [m, n] + + Returns: + grad_w: Gradient w.r.t w [n, n] + """ + m, n = x.size() + grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_helion_se_block_bwd_dw, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), grad_out, x, s, grad_w, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return grad_w diff --git a/test/test_examples.py b/test/test_examples.py index 0c6a6b311..8fa936810 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1567,6 +1567,103 @@ def test_squeeze_and_excitation_net_bwd_db(self): ) ) + @skipIfRocm("failure on rocm") + @skipIfA10G("failure on a10g") + def test_se_block_fwd(self): + m, n = 128, 128 + x = torch.randn([m, n], device=DEVICE, dtype=torch.bfloat16) + w = torch.randn([n, n], device=DEVICE, dtype=torch.bfloat16) + + # Compute expected output with PyTorch + expected = 2 * x * torch.sigmoid(x @ w) + + args = (x, w) + + self.assertExpectedJournal( + check_example( + "se_block", + args, + (expected, None), # (output, sigmoid) + fn_name="se_block_fwd", + block_sizes=[32, 32], + num_warps=4, + num_stages=3, + ) + ) + + @skipIfRocm("failure on rocm") + @skipIfA10G("failure on a10g") + def test_se_block_bwd_dx(self): + m, n = 128, 128 + x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True) + w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True) + grad_out = torch.randn([m, n], device=DEVICE, dtype=torch.float16) + + # Compute expected gradients with PyTorch + x_torch = x.detach().clone().requires_grad_(True) + w_torch = w.detach().clone().requires_grad_(True) + out_torch = 2 * x_torch * torch.sigmoid(x_torch @ w_torch) + out_torch.backward(grad_out) + + # Get sigmoid values from forward pass for the backward kernel + # Configure forward kernel to avoid autotuning during backward test + from examples.se_block import se_block_fwd + + config = helion.Config(block_size=[32, 32], num_warps=4, num_stages=3) + configured_kernel = helion.kernel(se_block_fwd.fn, config=config) + _, s = configured_kernel(x, w) + + args = (grad_out, x, w, s) + + self.assertExpectedJournal( + check_example( + "se_block", + args, + x_torch.grad, + fn_name="se_block_bwd_dx", + block_sizes=[32, 32, 32], + num_warps=4, + num_stages=3, + ) + ) + + @skipIfRocm("failure on rocm") + @skipIfA10G("failure on a10g") + def test_se_block_bwd_dw(self): + m, n = 128, 128 + x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True) + w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True) + grad_out = torch.randn([m, n], device=DEVICE, dtype=torch.float16) + + # Compute expected gradients with PyTorch + x_torch = x.detach().clone().requires_grad_(True) + w_torch = w.detach().clone().requires_grad_(True) + out_torch = 2 * x_torch * torch.sigmoid(x_torch @ w_torch) + out_torch.backward(grad_out) + + # Get sigmoid values from forward pass for the backward kernel + # Configure forward kernel to avoid autotuning during backward test + from examples.se_block import se_block_fwd + + config = helion.Config(block_size=[32, 32], num_warps=4, num_stages=3) + configured_kernel = helion.kernel(se_block_fwd.fn, config=config) + _, s = configured_kernel(x, w) + + args = (grad_out, x, s) + + self.assertExpectedJournal( + 
check_example( + "se_block", + args, + w_torch.grad, + fn_name="se_block_bwd_dw", + block_sizes=[32, 32, 32], + num_warps=4, + num_stages=3, + rtol=1e-2, + ) + ) + if __name__ == "__main__": unittest.main()
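
The gradient formulas quoted in the se_block_bwd_dx / se_block_bwd_dw docstrings follow
from the chain rule: with z = x @ w and s = sigmoid(z), out = 2 * x * s, so the gradient
flowing into z is g = 2 * grad_out * x * s * (1 - s), giving
grad_x = 2 * grad_out * s + g @ w.T (the direct x term plus the z path) and
grad_w = x.T @ g. The standalone sketch below (not part of the patch) cross-checks those
formulas against torch.autograd; the 128x128 float32 shapes mirror the unit tests above,
while the seed and tolerances are arbitrary assumptions.

    import torch

    torch.manual_seed(0)
    m, n = 128, 128
    x = torch.randn(m, n, dtype=torch.float32, requires_grad=True)
    w = torch.randn(n, n, dtype=torch.float32, requires_grad=True)
    grad_out = torch.randn(m, n, dtype=torch.float32)

    # Forward pass, keeping s = sigmoid(x @ w) exactly as the Helion kernels do.
    s = torch.sigmoid(x @ w)
    out = 2 * x * s
    out.backward(grad_out)

    # Manual gradients, matching the formulas in the kernel docstrings.
    with torch.no_grad():
        g = 2 * grad_out * x * s * (1 - s)  # gradient w.r.t. z = x @ w
        grad_x_manual = 2 * grad_out * s + g @ w.T
        grad_w_manual = x.T @ g

    # Tolerances are an assumption, loosened beyond the float32 defaults to absorb
    # op-ordering differences between autograd and the manual expressions.
    torch.testing.assert_close(grad_x_manual, x.grad, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(grad_w_manual, w.grad, rtol=1e-4, atol=1e-4)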