@@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens,
279279}
280280
281281template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
282- int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128 ,
283- int NUM_STAGES = 3 >
282+ typename scale_t , int THREADS, typename Idx_t, bool CEIL_UE8M0 ,
283+ int GROUP_SIZE = 128 , int NUM_STAGES = 3 >
284284__global__ void silu_mul_fp8_quant_deep_gemm_kernel (
285285 const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
286- float * __restrict__ _y_s, const int32_t * __restrict__ tokens_per_expert,
286+ scale_t * __restrict__ _y_s, const int32_t * __restrict__ tokens_per_expert,
287287 // sizes
288288 Idx_t E, Idx_t T, Idx_t H,
289289 // strides (in elements)
290290 Idx_t stride_i_e, Idx_t stride_i_t , Idx_t stride_i_h, Idx_t stride_yq_e,
291291 Idx_t stride_yq_t , Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t ,
292- Idx_t stride_ys_g, Idx_t stride_counts_e) {
292+ Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
293293#ifndef USE_ROCM
294294 static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
295295
@@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
466466
467467 __nv_fp8x4_e4m3* y_q_base_ptr =
468468 reinterpret_cast <__nv_fp8x4_e4m3*>(_y_q) + lane_id;
469- auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
469+
470+ Idx_t scale_group_offset = 0 ;
471+ if constexpr (std::is_same<scale_t , uint8_t >::value) {
472+ // packed int32_t format
473+ int pack_id = warp_position_scales / 4 ;
474+ int scale_in_pack = warp_position_scales % 4 ;
475+ scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
476+ } else {
477+ scale_group_offset = warp_position_scales * stride_ys_g;
478+ }
479+
480+ scale_t * const y_scale_base_ptr = _y_s + scale_group_offset;
470481
471482 for (auto j = tokens_lower; j < tokens_upper; j++) {
483+ int current_group_id = warp_position_scales; // Running count of which
484+ // group is being processed
472485 const Idx_t base_ys = expert_id * stride_ys_e;
473486 auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t ;
474487 __nv_fp8x4_e4m3* y_q_ptr =
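
A concrete worked example of the offset arithmetic above (my own sketch, not part of the patch). With four UE8M0 bytes packed per int32 and the byte strides the host code installs (stride_ys_g = 1, stride_ys_p = T * 4), a warp whose first group is group 6 of a token with T = 4 lands 18 bytes into that token's packed row:

    // Sketch only: the same pack_id / scale_in_pack split as in the kernel.
    constexpr int warp_position_scales = 6;   // first group handled by this warp
    constexpr int T = 4;                      // tokens per expert (example value)
    constexpr int stride_ys_p = T * 4;        // bytes to the next pack
    constexpr int stride_ys_g = 1;            // bytes to the next group in a pack
    constexpr int pack_id = warp_position_scales / 4;        // -> 1 (second int32)
    constexpr int scale_in_pack = warp_position_scales % 4;  // -> 2 (third byte)
    static_assert(pack_id * stride_ys_p + scale_in_pack * stride_ys_g == 18,
                  "group 6 sits at byte 18 of the token's packed row");

In the float path the same warp simply starts at warp_position_scales * stride_ys_g elements, as before.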
@@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
509522
510523 __nv_bfloat16 y_s = __hmul (warp_max (_y_max2.x ), fp8_inv);
511524
512- if constexpr (USE_UE8M0 ) {
525+ if constexpr (CEIL_UE8M0 ) {
513526 y_s = hexp2 (hceil (hlog2 (y_s)));
514527 }
515528
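
For reference, a plain-float sketch (an illustration under my own naming, not code from the patch) of what the hexp2(hceil(hlog2(y_s))) chain computes: it rounds the per-group scale up to the next power of two, so only the exponent survives and can later be stored as a single UE8M0 byte.

    #include <cmath>
    // CPU analogue of hexp2(hceil(hlog2(y_s))) for a positive scale.
    static float ceil_to_pow2(float s) {
      return std::exp2(std::ceil(std::log2(s)));
    }
    // ceil_to_pow2(0.3f) == 0.5f, ceil_to_pow2(1.0f) == 1.0f, ceil_to_pow2(3.0f) == 4.0f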
@@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
527540 y_q_ptr += WARP_SIZE * stride_yq_h;
528541
529542 if (!lane_id) {
530- *y_s_ptr = y_s;
531- y_s_ptr += stride_ys_g;
543+ // Store scales.
544+ if constexpr (std::is_same<scale_t , uint8_t >::value) {
545+ // Packed UE8M0 format. Drop the mantissa bits.
546+ *y_s_ptr = reinterpret_cast <int16_t &>(y_s) >> 7 ;
547+
548+ bool const jump_pack = (current_group_id + 1 ) % 4 == 0 ;
549+ // Minus 3 because we need to get to the first group in the
550+ // next pack.
551+ y_s_ptr += jump_pack ? (stride_ys_p - 3 ) : stride_ys_g;
552+
553+ } else {
554+ // float32 format
555+ static_assert (std::is_same<scale_t , float >::value);
556+ *y_s_ptr = y_s;
557+ y_s_ptr += stride_ys_g;
558+ }
559+
560+ current_group_id += 1 ;
532561 }
533562 }
534563 }
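
A standalone illustration of the exponent extraction (my own sketch, not part of the diff): __nv_bfloat16 stores 1 sign bit, 8 exponent bits and 7 mantissa bits, so for a non-negative scale, shifting the raw 16-bit pattern right by 7 leaves exactly the biased 8-bit exponent, which is the UE8M0 byte the kernel writes. The same thing can be demonstrated on the CPU via float32, whose top 16 bits are the bfloat16 pattern:

    #include <cstdint>
    #include <cstring>

    // Returns the UE8M0 byte of a positive scale; CPU analogue of
    // reinterpret_cast<int16_t&>(y_s) >> 7 in the kernel.
    static uint8_t ue8m0_byte(float s) {
      uint32_t bits;
      std::memcpy(&bits, &s, sizeof(bits));
      uint16_t bf16 = static_cast<uint16_t>(bits >> 16);  // truncate to bfloat16
      return static_cast<uint8_t>(bf16 >> 7);             // drop the 7 mantissa bits
    }
    // ue8m0_byte(1.0f) == 127, ue8m0_byte(2.0f) == 128, ue8m0_byte(0.5f) == 126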
@@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant(
573602 const at::Tensor& tokens_per_expert, // (E)
574603 at::Tensor& y_q, // (E, T, H) [OUT]
575604 at::Tensor& y_s, // (E, T, H//group_size) [OUT]
576- bool use_ue8m0 ) {
605+ bool cast_scale_ue8m0 ) {
577606#ifndef USE_ROCM
578607
579608 // This kernel currently only supports H % 128 == 0 and assumes a
@@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant(
583612 TORCH_CHECK (input.dtype () == torch::kBFloat16 );
584613 TORCH_CHECK (y_q.dtype () == torch::kFloat8_e4m3fn ||
585614 y_q.dtype () == torch::kFloat8_e4m3fnuz );
586- TORCH_CHECK (y_s.dtype () == torch::kFloat32 );
587615 TORCH_CHECK (input.size (-1 ) % (GROUP_SIZE * 2 ) == 0 );
588616
617+ bool const is_packed_ue8m0 =
618+ (y_s.dtype () == torch::kInt32 && cast_scale_ue8m0);
619+ TORCH_CHECK (y_s.dtype () == torch::kFloat32 || is_packed_ue8m0);
620+
589621 using Idx_t = int64_t ;
590622
591623 Idx_t E = input.size (0 );
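
A hypothetical caller-side sketch of how the packed path would be selected (the allocation below is my assumption, not code from this PR): y_s is an int32 tensor with ceil(G / 4) packs per token, laid out pack-major within each expert, and the op is invoked with cast_scale_ue8m0 = true.

    // Assumed setup; G groups of GROUP_SIZE = 128 per token.
    int64_t G = H / 128;
    int64_t packs = (G + 3) / 4;  // four UE8M0 bytes per int32
    // Pack-major storage, viewed as (E, T, packs): stride(1) == 1, stride(2) == T.
    at::Tensor y_s = torch::empty({E, packs, T},
                                  input.options().dtype(torch::kInt32))
                         .transpose(1, 2);
    persistent_masked_m_silu_mul_quant(input, tokens_per_expert, y_q, y_s,
                                       /*cast_scale_ue8m0=*/true);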
@@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant(
597629 Idx_t stride_yq_e = y_q.stride (0 );
598630 Idx_t stride_yq_t = y_q.stride (1 );
599631 Idx_t stride_yq_h = y_q.stride (2 );
600- Idx_t stride_ys_e = y_s.stride (0 );
601- Idx_t stride_ys_t = y_s.stride (1 );
602- Idx_t stride_ys_g = y_s.stride (2 );
603632
604633 Idx_t stride_counts_e = tokens_per_expert.stride (0 );
605634
635+ int const NUM_GROUPS = H / GROUP_SIZE;
636+
606637 const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
607638
608- #define KERNEL (BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES ) \
639+ // TODO: Get this from cuda_arch ?
640+ static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32 ;
641+
642+ #define KERNEL (BLOCK_COUNT, scale_t , STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
643+ STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
609644 static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
610645 int sms = SILU_V2_BLOCK_COUNT; \
611646 static constexpr int max_shared_mem_bytes = \
@@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant(
615650 VLLM_DISPATCH_FP8_TYPES ( \
616651 y_q.scalar_type (), " silu_mul_fp8_quant_deep_gemm_kernel" , [&] { \
617652 vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
618- BLOCK_COUNT, max_shared_mem_bytes, fp8_t , THREAD_COUNT, Idx_t, \
619- USE_UE8M0, GROUP_SIZE, STAGES> \
653+ BLOCK_COUNT, max_shared_mem_bytes, fp8_t , scale_t , THREAD_COUNT, \
654+ Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
620655 <<<grid, block, max_shared_mem_bytes + (E + 1 ) * 16 , stream>>> ( \
621656 reinterpret_cast <__nv_bfloat16*>(input.data_ptr ()), \
622- (fp8_t *)y_q.data_ptr (), y_s.data_ptr <float >(), \
657+ (fp8_t *)y_q.data_ptr (), \
658+ reinterpret_cast <scale_t *>(y_s.data_ptr ()), \
623659 reinterpret_cast <int32_t *>(tokens_per_expert.data_ptr ()), E, \
624660 T, H, stride_i_e, stride_i_t , stride_i_h, stride_yq_e, \
625- stride_yq_t , stride_yq_h, stride_ys_e, stride_ys_t , \
626- stride_ys_g, stride_counts_e); \
661+ stride_yq_t , stride_yq_h, STRIDE_YS_E, STRIDE_YS_T , \
662+ STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
627663 });
628664
629- static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32 ;
630-
631- int const NUM_GROUPS = H / GROUP_SIZE;
632- if (!use_ue8m0) {
633- if (H >= 4096 && (NUM_GROUPS % 8 == 0 )) {
634- /* 8 warps config */
635- static constexpr int NUM_STAGES = 4 ;
636- static constexpr int THREAD_COUNT = 256 ;
637- KERNEL (SILU_V2_BLOCK_COUNT, false , THREAD_COUNT, NUM_STAGES);
638- } else {
639- /* 1 warp config */
640- static constexpr int THREAD_COUNT = 32 ;
641- KERNEL (SILU_V2_BLOCK_COUNT, false , THREAD_COUNT, 2 );
642- }
643- } else {
644- if (H >= 4096 && (NUM_GROUPS % 8 == 0 )) {
645- /* 8 warps config */
646- static constexpr int NUM_STAGES = 4 ;
647- static constexpr int THREAD_COUNT = 256 ;
648- KERNEL (SILU_V2_BLOCK_COUNT, true , THREAD_COUNT, NUM_STAGES);
649- } else {
650- /* 1 warp config */
651- static constexpr int THREAD_COUNT = 32 ;
652- KERNEL (SILU_V2_BLOCK_COUNT, true , THREAD_COUNT, 2 );
665+ #define LAUNCH_ON_H (scale_t , STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
666+ STRIDE_YS_P, CEIL_UE8M0) \
667+ if (H >= 4096 && (NUM_GROUPS % 8 ) == 0 ) { \
668+ /* 8 warp config */ \
669+ static constexpr int NUM_STAGES = 4 ; \
670+ static constexpr int THREAD_COUNT = 256 ; \
671+ KERNEL (SILU_V2_BLOCK_COUNT, scale_t , STRIDE_YS_E, STRIDE_YS_T, \
672+ STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
673+ } else { \
674+ /* 1 warp config */ \
675+ static constexpr int THREAD_COUNT = 32 ; \
676+ KERNEL (SILU_V2_BLOCK_COUNT, scale_t , STRIDE_YS_E, STRIDE_YS_T, \
677+ STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2 ); \
653678 }
679+
680+ Idx_t stride_ys_e = y_s.stride (0 );
681+ Idx_t stride_ys_t = y_s.stride (1 );
682+ Idx_t stride_ys_g = y_s.stride (2 );
683+ Idx_t stride_ys_p = 0 ;
684+ if (!cast_scale_ue8m0) {
685+ TORCH_CHECK (!is_packed_ue8m0);
686+ LAUNCH_ON_H (float , stride_ys_e, stride_ys_t , stride_ys_g, stride_ys_p,
687+ false );
688+ return ;
689+ }
690+
691+ if (!is_packed_ue8m0) {
692+ // UE8M0 but not packed
693+ LAUNCH_ON_H (float , stride_ys_e, stride_ys_t , stride_ys_g, stride_ys_p,
694+ true );
695+ return ;
654696 }
655697
698+ TORCH_CHECK (cast_scale_ue8m0 && is_packed_ue8m0);
699+ TORCH_CHECK (y_s.dtype () == torch::kInt32 );
700+
701+ // Int32 packed ue8m0 scales tensor.
702+ // Let E, T, G be the number of experts, number of tokens and number of groups
703+ // respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales
704+ // tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
705+ // to be arranged as follows,
706+ // [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
707+ // [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
708+ // [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
709+ // [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
710+ // where TxGy is the UE8M0 scale value of token x, group y.
711+ //
712+ // In memory (in bytes) the scale values are arranged as,
713+ // [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
714+ // T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
715+ // X, X, T3G4, T3G5, X, X]
716+ //
717+ // An int32 tensor of shape [1, 4, 2] and stride [8, 1, 4] can be viewed
718+ // as a uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
719+ // other words, ignoring the expert dimension, the original int32 tensor is
720+ // simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32
721+ // tensors). The stride settings below reflect this view. Caveat: this
722+ // means that the G dimension is no longer contiguous, i.e. to move
723+ // from G3 to G4, we need to jump along the packing dimension. The kernel
724+ // handles this case.
725+
726+ stride_ys_e *= sizeof (int32_t );
727+ stride_ys_p = T * sizeof (int32_t ); // Packing dimension
728+ stride_ys_t = sizeof (int32_t );
729+ stride_ys_g = 1 ;
730+
731+ LAUNCH_ON_H (uint8_t , stride_ys_e, stride_ys_t , stride_ys_g, stride_ys_p,
732+ true );
733+
656734#endif
657735}
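
To make the stride translation in the layout comment concrete, here is a small standalone check (my own sketch, not part of the patch) that the byte strides installed for the uint8 view address exactly the same bytes as the original int32 view in the E = 1, T = 4, G = 6 example:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t T = 4;
      // int32 view: element (t, pack) lives at int32 index t * 1 + pack * 4
      // (shape [1, 4, 2], stride [8, 1, 4]).
      auto int32_index = [&](int64_t t, int64_t pack) { return t + pack * 4; };
      // uint8 view used by the kernel: stride_ys_t = 4, stride_ys_g = 1,
      // stride_ys_p = T * 4 (all in bytes).
      auto byte_offset = [&](int64_t t, int64_t g) {
        return t * 4 + (g / 4) * (T * 4) + (g % 4) * 1;
      };
      for (int64_t t = 0; t < T; ++t)
        for (int64_t g = 0; g < 6; ++g)
          // Byte g % 4 of the int32 that holds (t, g / 4), little-endian.
          assert(byte_offset(t, g) == int32_index(t, g / 4) * 4 + g % 4);
      return 0;
    }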