csrc/quantization/activation_kernels.cu (15 changes: 10 additions & 5 deletions)
@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(

// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static constexpr int GROUP_SIZE = 128;

TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);

using Idx_t = int64_t;

@@ -601,8 +603,6 @@

Idx_t stride_counts_e = tokens_per_expert.stride(0);

static constexpr int GROUP_SIZE = 128;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@@ -628,21 +628,26 @@

static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;

int const NUM_GROUPS = H / GROUP_SIZE;
if (!use_ue8m0) {
if (H >= 4096) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
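Note on the dispatch change above: the 8-warp configuration now additionally requires the group count H / GROUP_SIZE to be a multiple of 8; otherwise the kernel falls back to the 1-warp configuration. A minimal Python sketch of the selection rule (names paraphrased from the kernel for illustration, not an exported API):

GROUP_SIZE = 128

def select_config(H: int) -> str:
    # Mirrors the updated branch: 8-warp config only when H is large
    # enough AND the number of 128-wide groups is a multiple of 8.
    num_groups = H // GROUP_SIZE
    if H >= 4096 and num_groups % 8 == 0:
        return "8 warps: THREAD_COUNT=256, 4 stages"
    return "1 warp: THREAD_COUNT=32, 2 stages"

assert select_config(7168).startswith("8 warps")     # 56 groups, 56 % 8 == 0
assert select_config(128 * 33).startswith("1 warp")  # 33 groups, falls back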
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py (5 changes: 4 additions & 1 deletion)
@@ -25,6 +25,7 @@
(8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(8, 64, 7168, fp8_dtype),
(8, 128, 128 * 33, fp8_dtype),
(8, 128, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype),
@@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
)

# Run the SiLU V2 kernel
# TODO (varun): use_ue8m0 is set to False as the reference impl does
# not handle that case.
y_q, y_s = persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, group_size=group_size
y, tokens_per_expert, group_size=group_size, use_ue8m0=False
)

torch.cuda.synchronize()
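The new (8, 128, 128 * 33, fp8_dtype) case targets the fallback path added in the CUDA change: H = 4224 clears the H >= 4096 threshold, but its group count of 33 is not a multiple of 8, so the kernel must take the 1-warp configuration. A quick check of that arithmetic:

H = 128 * 33               # 4224, >= 4096
num_groups = H // 128      # 33
assert H >= 4096 and num_groups % 8 != 0  # forces the 1-warp fallback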
@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
@@ -164,7 +165,7 @@
device=y.device,
)

use_ue8m0 = is_deep_gemm_e8m0_used()
use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()

cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
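With the new keyword argument, callers can pin the scale format instead of inheriting it from is_deep_gemm_e8m0_used(). A usage sketch under stated assumptions: the import path for persistent_masked_m_silu_mul_quant is not shown in this diff, and the tokens_per_expert dtype is assumed here.

import torch

E, T, H = 8, 128, 7168  # example shapes, matching one of the test cases
y = torch.randn(E, T, 2 * H, dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.full((E,), T, dtype=torch.int32, device="cuda")  # dtype assumed

# Default behavior: scale format follows is_deep_gemm_e8m0_used().
y_q, y_s = persistent_masked_m_silu_mul_quant(y, tokens_per_expert)

# Explicit override, as the updated test does for its reference comparison.
y_q, y_s = persistent_masked_m_silu_mul_quant(
    y, tokens_per_expert, group_size=128, use_ue8m0=False
)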