Commit 6624d6d

Faster int4 gemm (#751)
1 parent d588560 commit 6624d6d

2 files changed: +51, -50 lines

2 files changed

+51
-50
lines changed

examples/int4_gemm.py

Lines changed: 19 additions & 22 deletions
@@ -56,6 +56,14 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
 
         for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
+            # Load corresponding tiles from A (need to load twice the packed tile size)
+            # We need to map tile_k_packed to the corresponding range in A
+            a_tile_begin = tile_k_packed.begin * 2
+            a_tile_len = block_size_k_packed * 2
+            a_tile = A[tile_m, a_tile_begin : (a_tile_begin + a_tile_len)].to(
+                torch.float32
+            )  # [BLOCK_SIZE_M, BLOCK_SIZE_K]
+
             # Load packed int8 data from B
             b_tile = B[tile_k_packed, tile_n]  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
 
@@ -64,29 +72,19 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
             b_lo = ((b_tile << 4) >> 4).to(torch.int8)  # Sign-extend low 4 bits
             b_hi = (b_tile >> 4).to(torch.int8)  # Sign-extend high 4 bits
 
-            # Convert to bfloat16
-            b_lo_bf16 = b_lo.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
-            b_hi_bf16 = b_hi.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
-
             # Stack and reshape to interleave low and high bits
             # Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N]
-            b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1)
+            b_stacked = torch.stack([b_lo, b_hi], dim=1)
 
             # Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
             # This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
             b_unpacked = b_stacked.reshape(
                 tile_k_packed.block_size * 2, tile_n.block_size
-            )
-
-            # Load corresponding tiles from A (need to load twice the packed tile size)
-            # We need to map tile_k_packed to the corresponding range in A
-            a_tile_begin = tile_k_packed.begin * 2
-            a_tile_len = tile_k_packed.block_size * 2
-            a_tile = A[
-                tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
-            ]  # [BLOCK_SIZE_M, BLOCK_SIZE_K]
+            ).to(torch.float32)
 
-            acc = acc + hl.dot(a_tile, b_unpacked)  # [BLOCK_SIZE_M, BLOCK_SIZE_N]
+            a_tile = a_tile.unsqueeze(2)  # [BLOCK_SIZE_M, BLOCK_SIZE_K, 1]
+            b_unpacked = b_unpacked.unsqueeze(0)
+            acc = acc + (a_tile * b_unpacked).sum(dim=1)  # [BLOCK_SIZE_M, BLOCK_SIZE_N]
 
         C[tile_m, tile_n] = acc.to(torch.bfloat16)
 
@@ -113,14 +111,13 @@ def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Ca
         Callable: A function that performs the int4 gemm.
     """
 
-    def run_kernel() -> torch.Tensor:
-        x_2d = x.reshape(-1, x.size(-1))
-
-        # Pack w to int4 format (two 4-bit values per int8 byte)
-        w_int8 = w.to(torch.int8)
-        w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
-        w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
+    # Pack w to int4 format (two 4-bit values per int8 byte)
+    x_2d = x.reshape(-1, x.size(-1))
+    w_int8 = w.to(torch.int8)
+    w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
+    w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
 
+    def run_kernel() -> torch.Tensor:
         return matmul_bf16_int4(x_2d, w_packed)
 
     return run_kernel
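
The packing and shift-based unpacking used above can be sanity-checked in eager PyTorch. The sketch below is not part of the commit; it assumes eager int8 shifts wrap and sign-extend the same way the generated Triton code's do, and it round-trips a small weight tensor through the same pack / unpack / interleave steps.

# Sketch (not from the commit): round-trip the int4 packing scheme in eager PyTorch.
import torch

K, N = 8, 4
w = torch.randint(-8, 8, (K, N), dtype=torch.int8)  # values already fit in 4 signed bits

# Pack: even K-rows go into the low nibble, odd K-rows into the high nibble.
w_reshaped = w.reshape(K // 2, 2, N).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)  # [K//2, N]

# Unpack: left-then-right shift sign-extends the low nibble; an arithmetic
# right shift sign-extends the high nibble (the same trick as in the kernel).
b_lo = ((w_packed << 4) >> 4).to(torch.int8)
b_hi = (w_packed >> 4).to(torch.int8)

# Interleave back to [K, N] in the order lo[0], hi[0], lo[1], hi[1], ...
w_unpacked = torch.stack([b_lo, b_hi], dim=1).reshape(K, N)
assert torch.equal(w_unpacked, w)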

test/test_examples.expected

Lines changed: 32 additions & 28 deletions
@@ -1429,7 +1429,7 @@ from torch._inductor.runtime import triton_helpers
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
+def _helion_matmul_bf16_int4(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_1: tl.constexpr):
     num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
@@ -1441,37 +1441,40 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
     mask_2 = indices_2 < N
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
     floordiv = triton_helpers.div_floor_integer(K, 2)
-    for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
-        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_0 < floordiv
+    for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_3 < floordiv
         acc_copy = acc
         acc_copy_0 = acc_copy
-        b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
-        v_0 = tl.full([], 4, tl.int8)
-        v_1 = b_tile << v_0
-        v_2 = tl.full([], 4, tl.int8)
-        v_3 = v_1 >> v_2
-        v_4 = tl.full([], 4, tl.int8)
-        v_5 = b_tile >> v_4
-        v_6 = tl.cast(v_3, tl.bfloat16)
-        v_7 = tl.cast(v_5, tl.bfloat16)
+        mul = 2 * offset_3
+        iota = mul + tl.arange(0, mul_1)
+        load = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
+        v_0 = tl.cast(load, tl.float32)
+        b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        v_1 = tl.full([], 4, tl.int8)
+        v_2 = b_tile << v_1
+        v_3 = tl.full([], 4, tl.int8)
+        v_4 = v_2 >> v_3
+        v_5 = tl.full([], 4, tl.int8)
+        v_6 = b_tile >> v_5
         stack_idx = tl.arange(0, 2)
         broadcast_idx = stack_idx[None, :, None]
-        expanded_0 = tl.expand_dims(v_6, 1)
-        expanded_1 = tl.expand_dims(v_7, 1)
+        expanded_0 = tl.expand_dims(v_4, 1)
+        expanded_1 = tl.expand_dims(v_6, 1)
         stacked_result = tl.zeros_like(expanded_0)
-        mask_3 = broadcast_idx == 0
-        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
-        mask_4 = broadcast_idx == 1
-        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
-        b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
-        mul_5 = 2 * offset_0
-        iota = mul_5 + tl.arange(0, mul)
-        a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
-        dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
-        acc = acc_copy_0 + dot
-    v_9 = tl.cast(acc, tl.bfloat16)
-    tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])
+        mask_4 = broadcast_idx == 0
+        stacked_result = tl.where(mask_4, expanded_0, stacked_result)
+        mask_5 = broadcast_idx == 1
+        stacked_result = tl.where(mask_5, expanded_1, stacked_result)
+        view = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
+        v_7 = tl.cast(view, tl.float32)
+        a_tile_1 = v_0[:, :, None]
+        b_unpacked_1 = v_7[None, :, :]
+        v_8 = a_tile_1 * b_unpacked_1
+        sum_1 = tl.cast(tl.sum(v_8, 1), tl.float32)
+        acc = acc_copy_0 + sum_1
+    v_10 = tl.cast(acc, tl.bfloat16)
+    tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])
 
 def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
     """
@@ -1495,7 +1498,8 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 32
     _BLOCK_SIZE_0 = 64
-    _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _RDIM_SIZE_3 = triton.next_power_of_2(2 * _BLOCK_SIZE_0)
+    _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return C
 
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
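
The expected kernel above replaces the old tl.dot path with a broadcast multiply followed by a sum over the K axis. The snippet below is not from the commit; it only illustrates, with plain float32 PyTorch tensors, that the two reductions compute the same tile update (it says nothing about the precision or performance differences the commit is after).

# Sketch (not from the commit): broadcast multiply + sum over K equals a matmul on the tile.
import torch

M, K, N = 16, 32, 8  # stand-ins for BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N
a_tile = torch.randn(M, K, dtype=torch.float32)
b_unpacked = torch.randn(K, N, dtype=torch.float32)

dot_update = a_tile @ b_unpacked  # old hl.dot / tl.dot path
bcast_update = (a_tile.unsqueeze(2) * b_unpacked.unsqueeze(0)).sum(dim=1)  # new path

torch.testing.assert_close(bcast_update, dot_update)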
