diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 6fcd246f63c5..2521b2797e2c 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
   // This kernel currently only supports H % 128 == 0 and assumes a
   // fixed GROUP_SIZE of 128.
 
+  static constexpr int GROUP_SIZE = 128;
+
   TORCH_CHECK(input.dtype() == torch::kBFloat16);
   TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
               y_q.dtype() == torch::kFloat8_e4m3fnuz);
   TORCH_CHECK(y_s.dtype() == torch::kFloat32);
-  TORCH_CHECK(input.size(-1) % 256 == 0);
+  TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
 
   using Idx_t = int64_t;
 
@@ -601,8 +603,6 @@
 
   Idx_t stride_counts_e = tokens_per_expert.stride(0);
 
-  static constexpr int GROUP_SIZE = 128;
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
 #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@@ -628,21 +628,26 @@
 
   static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
 
+  int const NUM_GROUPS = H / GROUP_SIZE;
   if (!use_ue8m0) {
-    if (H >= 4096) {
+    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
+      /* 8 warps config */
       static constexpr int NUM_STAGES = 4;
       static constexpr int THREAD_COUNT = 256;
       KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
     } else {
+      /* 1 warp config */
       static constexpr int THREAD_COUNT = 32;
       KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
     }
   } else {
-    if (H >= 4096) {
+    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
+      /* 8 warps config */
       static constexpr int NUM_STAGES = 4;
       static constexpr int THREAD_COUNT = 256;
       KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
     } else {
+      /* 1 warp config */
       static constexpr int THREAD_COUNT = 32;
       KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
     }
diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
index 97a55c37b9a3..420dbbffaac0 100644
--- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -25,6 +25,7 @@
     (8, 16, 128 * 2, fp8_dtype),
     (8, 16, 128 * 3, fp8_dtype),
     (8, 64, 7168, fp8_dtype),
+    (8, 128, 128 * 33, fp8_dtype),
     (8, 128, 7168, fp8_dtype),
     (8, 512, 7168, fp8_dtype),
     (8, 1024, 7168, fp8_dtype),
@@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
     )
 
     # Run the SiLU V2 kernel
+    # TODO (varun): use_ue8m0 is set to False as the reference impl does
+    # not handle that case.
     y_q, y_s = persistent_masked_m_silu_mul_quant(
-        y, tokens_per_expert, group_size=group_size
+        y, tokens_per_expert, group_size=group_size, use_ue8m0=False
     )
 
     torch.cuda.synchronize()
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 095ec966ea7e..b8a97e92ab79 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
     group_size: int = 128,
+    use_ue8m0: bool | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
     y has shape (E, T, 2*H). The first half of the last dimension is
@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
         device=y.device,
     )
 
-    use_ue8m0 = is_deep_gemm_e8m0_used()
+    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
 
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
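
For reviewers, the semantics being exercised: persistent_masked_m_silu_mul_quant applies SiLU to the first half of the hidden dimension, gates it with the second half, and quantizes the result to FP8 with one scale per group_size-wide group per token, masked by tokens_per_expert. The sketch below is a minimal eager-PyTorch illustration of the use_ue8m0=False (plain FP32 scales) path only; the helper name silu_mul_fp8_quant_ref is hypothetical, not the repo's reference implementation.

    import torch
    import torch.nn.functional as F

    def silu_mul_fp8_quant_ref(y: torch.Tensor,
                               tokens_per_expert: torch.Tensor,
                               group_size: int = 128):
        # y: (E, T, 2*H). Only the first tokens_per_expert[e] rows of each
        # expert are valid; the rest stay zero in the outputs.
        E, T, H2 = y.shape
        H = H2 // 2
        fp8 = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8).max
        y_q = torch.zeros(E, T, H, dtype=fp8, device=y.device)
        y_s = torch.zeros(E, T, H // group_size, dtype=torch.float32,
                          device=y.device)
        for e in range(E):
            n = int(tokens_per_expert[e])
            # silu(first half) * second half, computed in fp32 for accuracy
            out = F.silu(y[e, :n, :H].float()) * y[e, :n, H:].float()
            g = out.view(n, H // group_size, group_size)
            # One FP32 scale per (token, group). The use_ue8m0=True path
            # would additionally round each scale up to a power of two.
            s = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / fp8_max
            y_q[e, :n] = (g / s).clamp(-fp8_max, fp8_max).view(n, H).to(fp8)
            y_s[e, :n] = s.squeeze(-1)
        return y_q, y_s

With the new keyword argument, callers can pin the scale format explicitly, e.g. persistent_masked_m_silu_mul_quant(y, tokens_per_expert, use_ue8m0=False) as the test now does, while the default None preserves the previous behavior of deferring to is_deep_gemm_e8m0_used().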