
Commit fe1cd77

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)
[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent fdfd507 commit fe1cd77


7 files changed: +461 -146 lines


csrc/quantization/activation_kernels.cu

Lines changed: 121 additions & 43 deletions
@@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens,
 }
 
 template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
-          int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
-          int NUM_STAGES = 3>
+          typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0,
+          int GROUP_SIZE = 128, int NUM_STAGES = 3>
 __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
     const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
-    float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
+    scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
     // sizes
     Idx_t E, Idx_t T, Idx_t H,
     // strides (in elements)
     Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
     Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
-    Idx_t stride_ys_g, Idx_t stride_counts_e) {
+    Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
 #ifndef USE_ROCM
   static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
 
@@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
 
   __nv_fp8x4_e4m3* y_q_base_ptr =
       reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
-  auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
+
+  Idx_t scale_group_offset = 0;
+  if constexpr (std::is_same<scale_t, uint8_t>::value) {
+    // packed int32_t format
+    int pack_id = warp_position_scales / 4;
+    int scale_in_pack = warp_position_scales % 4;
+    scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
+  } else {
+    scale_group_offset = warp_position_scales * stride_ys_g;
+  }
+
+  scale_t* const y_scale_base_ptr = _y_s + scale_group_offset;
 
   for (auto j = tokens_lower; j < tokens_upper; j++) {
+    int current_group_id = warp_position_scales;  // Running count of which
+                                                  // group is being processed
     const Idx_t base_ys = expert_id * stride_ys_e;
     auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
     __nv_fp8x4_e4m3* y_q_ptr =
@@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
 
       __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
 
-      if constexpr (USE_UE8M0) {
+      if constexpr (CEIL_UE8M0) {
         y_s = hexp2(hceil(hlog2(y_s)));
       }
 
@@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
       y_q_ptr += WARP_SIZE * stride_yq_h;
 
       if (!lane_id) {
-        *y_s_ptr = y_s;
-        y_s_ptr += stride_ys_g;
+        // Store scales.
+        if constexpr (std::is_same<scale_t, uint8_t>::value) {
+          // Packed UE8MO format. Remove Mantissa.
+          *y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7;
+
+          bool const jump_pack = (current_group_id + 1) % 4 == 0;
+          // Minus 3 because we need to get to the first group in the
+          // next pack.
+          y_s_ptr += jump_pack ? (stride_ys_p - 3) : stride_ys_g;
+
+        } else {
+          // float32 format
+          static_assert(std::is_same<scale_t, float>::value);
+          *y_s_ptr = y_s;
+          y_s_ptr += stride_ys_g;
+        }
+
+        current_group_id += 1;
       }
     }
   }
@@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant(
     const at::Tensor& tokens_per_expert,  // (E)
     at::Tensor& y_q,                      // (E, T, H) [OUT]
     at::Tensor& y_s,                      // (E, T, H//group_size) [OUT]
-    bool use_ue8m0) {
+    bool cast_scale_ue8m0) {
 #ifndef USE_ROCM
 
   // This kernel currently only supports H % 128 == 0 and assumes a
@@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant(
   TORCH_CHECK(input.dtype() == torch::kBFloat16);
   TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
               y_q.dtype() == torch::kFloat8_e4m3fnuz);
-  TORCH_CHECK(y_s.dtype() == torch::kFloat32);
   TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
 
+  bool const is_packed_ue8m0 =
+      (y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
+  TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
+
   using Idx_t = int64_t;
 
   Idx_t E = input.size(0);
@@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant(
   Idx_t stride_yq_e = y_q.stride(0);
   Idx_t stride_yq_t = y_q.stride(1);
   Idx_t stride_yq_h = y_q.stride(2);
-  Idx_t stride_ys_e = y_s.stride(0);
-  Idx_t stride_ys_t = y_s.stride(1);
-  Idx_t stride_ys_g = y_s.stride(2);
 
   Idx_t stride_counts_e = tokens_per_expert.stride(0);
 
+  int const NUM_GROUPS = H / GROUP_SIZE;
+
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
+  // TODO: Get this from cuda_arch ?
+  static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
+
+#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
+               STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
   static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
   int sms = SILU_V2_BLOCK_COUNT; \
   static constexpr int max_shared_mem_bytes = \
@@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant(
   VLLM_DISPATCH_FP8_TYPES( \
       y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
         vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
-            BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \
-            USE_UE8M0, GROUP_SIZE, STAGES> \
+            BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
+            Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
             <<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
                 reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
-                (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
+                (fp8_t*)y_q.data_ptr(), \
+                reinterpret_cast<scale_t*>(y_s.data_ptr()), \
                 reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
                 T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
-                stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \
-                stride_ys_g, stride_counts_e); \
+                stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
+                STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
       });
 
-  static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
-
-  int const NUM_GROUPS = H / GROUP_SIZE;
-  if (!use_ue8m0) {
-    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
-      /* 8 warps config */
-      static constexpr int NUM_STAGES = 4;
-      static constexpr int THREAD_COUNT = 256;
-      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
-    } else {
-      /* 1 warp config */
-      static constexpr int THREAD_COUNT = 32;
-      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
-    }
-  } else {
-    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
-      /* 8 warps config */
-      static constexpr int NUM_STAGES = 4;
-      static constexpr int THREAD_COUNT = 256;
-      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
-    } else {
-      /* 1 warp config */
-      static constexpr int THREAD_COUNT = 32;
-      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
+#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
+                    STRIDE_YS_P, CEIL_UE8M0) \
+  if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
+    /* 8 warp config */ \
+    static constexpr int NUM_STAGES = 4; \
+    static constexpr int THREAD_COUNT = 256; \
+    KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
+           STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
+  } else { \
+    /* 1 warp config */ \
+    static constexpr int THREAD_COUNT = 32; \
+    KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
+           STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
   }
+
+  Idx_t stride_ys_e = y_s.stride(0);
+  Idx_t stride_ys_t = y_s.stride(1);
+  Idx_t stride_ys_g = y_s.stride(2);
+  Idx_t stride_ys_p = 0;
+  if (!cast_scale_ue8m0) {
+    TORCH_CHECK(!is_packed_ue8m0);
+    LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
+                false);
+    return;
+  }
+
+  if (!is_packed_ue8m0) {
+    // UE8M0 but not packed
+    LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
+                true);
+    return;
   }
 
+  TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
+  TORCH_CHECK(y_s.dtype() == torch::kInt32);
+
+  // Int32 packed ue8m0 scales tensor.
+  // Let E, T, G be the number to experts, number of tokens and number of groups
+  // respectively. Let, E = 2, T = 4, G = 6, in this case the int32 scales
+  // tensor are of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
+  // to be arranged as follows,
+  // [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
+  //  [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
+  //  [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
+  //  [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
+  // where, TxGy is the scale ue8m0 scale value of Token x, Group y.
+  //
+  // In memory (in bytes) the scale values are arranged as,
+  // [T0G0, T0G1, T0G2, T0G3, T1G0, T1G2, T1G3, T1G4, T2G0, T2G1, T2G3, T2G4,
+  //  T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
+  //  X, X, T3G4, T3G5, X, X]
+  //
+  // An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
+  // as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
+  // english, ignoring the Experts dimension, the original int32 tensor is
+  // simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
+  // tensor). The following strides setting reflects this change. Caveat: This
+  // means that the G dimension is no longer contiguous. i.e. Note that to move
+  // from G3 to G4, we need to jump along the packing dimension. The kernel
+  // handles this case.
+
+  stride_ys_e *= sizeof(int32_t);
+  stride_ys_p = T * sizeof(int32_t);  // Packing dimension
+  stride_ys_t = sizeof(int32_t);
+  stride_ys_g = 1;
+
+  LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
+              true);
+
 #endif
 }
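Note on the layout described in the comment above: below is a minimal PyTorch sketch (illustrative only, not part of this commit) of how an (E, T, G) set of power-of-two float32 group scales maps onto the packed int32 UE8M0 tensor the kernel writes. The sizes E = 1, T = 4, G = 6, the tensor names, and the host-side packing itself are assumptions chosen to mirror the example in the comment.

# Minimal sketch of the packed int32 UE8M0 scale layout (assumed host-side
# reference, not the kernel's actual packing code).
import torch

E, T, G = 1, 4, 6            # experts, tokens, groups per token
G_PACKED = (G + 3) // 4      # int32 words per token; tail bytes stay unused ("X")

# Power-of-two scales, as produced by y_s = exp2(ceil(log2(y_s))).
scales = torch.exp2(torch.randint(-8, 8, (E, T, G)).float())

# UE8M0 value = the 8-bit biased exponent; for float32 that is bits [30:23]
# (the kernel does the bf16 equivalent with `>> 7`).
ue8m0 = ((scales.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)

# Lay the bytes out pack-major: all T tokens of a 4-group pack are contiguous
# before the next pack begins, so the packing stride is T * sizeof(int32).
packed_bytes = torch.zeros(E, G_PACKED, T, 4, dtype=torch.uint8)
for g in range(G):
    packed_bytes[:, g // 4, :, g % 4] = ue8m0[:, :, g]

# Reinterpret as int32 and expose the logical (E, T, G_PACKED) view.
y_s = packed_bytes.view(torch.int32).squeeze(-1).permute(0, 2, 1)
print(y_s.shape, y_s.stride())   # torch.Size([1, 4, 2]) (8, 1, 4)

With this layout the kernel steps by stride_ys_g within a pack and jumps by stride_ys_p - 3 groups when crossing from G3 into G4, which is the jump_pack branch added above.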

tests/conftest.py

Lines changed: 13 additions & 0 deletions
@@ -1384,3 +1384,16 @@ def image_urls(request, local_asset_server) -> list[str]:
     """Indirect fixture: takes a list of names, returns list of full URLs."""
     names: list[str] = request.param
     return [local_asset_server.url_for(name) for name in names]
+
+
+@pytest.fixture
+def disable_deepgemm_ue8m0(monkeypatch):
+    from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+
+    with monkeypatch.context() as monkeypatch_ctx:
+        monkeypatch_ctx.setenv("VLLM_USE_DEEP_GEMM_E8M0", "0")
+        is_deep_gemm_e8m0_used.cache_clear()
+        yield
+    # Clear cache so the next time it is used it is processed with the
+    # default VLLM_USE_DEEP_GEMM_E8M0 setting.
+    is_deep_gemm_e8m0_used.cache_clear()

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 9 additions & 9 deletions
@@ -21,7 +21,11 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+    is_deep_gemm_supported,
+)
 from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
 
 from ...utils import multi_gpu_test
@@ -413,27 +417,24 @@ def _test_deepep_deepgemm_moe(
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(
-    is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
-)
 def test_ht_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],
     num_experts: int,
     topk: int,
     world_dp_size: tuple[int, int],
+    disable_deepgemm_ue8m0,
 ):
     """
     Tests for High-Throughput DeepEP + DeepGemm integration.
     """
-    import deep_gemm
 
     m, n, k = mnk
     current_platform.seed_everything(7)
 
     if topk > num_experts:
         pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
 
-    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     block_size = [block_m, block_m]
 
     world_size, dp_size = world_dp_size
@@ -487,20 +488,19 @@ def test_ht_deepep_deepgemm_moe(
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(
-    is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
-)
 def test_ll_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],
    num_experts: int,
     topk: int,
     use_fp8_dispatch: bool,
     block_size: list[int],
     world_dp_size: tuple[int, int],
+    disable_deepgemm_ue8m0,
 ):
     """
     Tests for Low-Latency DeepEP + DeepGemm integration.
     """
+    assert not is_deep_gemm_e8m0_used()
 
     m, n, k = mnk
     current_platform.seed_everything(7)

tests/kernels/moe/test_deepep_moe.py

Lines changed: 1 addition & 1 deletion
@@ -294,7 +294,7 @@ def torch_moe_impl(
         # blockwise quant and de-quant.
         assert not per_act_token_quant
         a = test_tensors.rank_tokens
-        aq, aq_scale = per_token_group_quant_fp8(a, 128)
+        aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False)
         a = (
             (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
             .view(a.shape)
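The use_ue8m0=False pin keeps the torch reference path on plain float32 scales even on platforms where DeepGEMM's UE8M0 mode is the default; UE8M0 rounds each group scale up to a power of two, so the two modes quantize differently. A small sketch with assumed values (not taken from the test):

import torch

# Assumed example: group absmax of 3.0, float8_e4m3fn max of 448.0.
amax = torch.tensor(3.0)
fp8_max = 448.0
scale = amax / fp8_max                                    # plain float32 scale
scale_ue8m0 = torch.exp2(torch.ceil(torch.log2(scale)))   # rounded up to a power of two
print(scale.item(), scale_ue8m0.item())                   # ~0.006696 vs 0.0078125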
