Skip to content

Commit dc7a48e

Browse files
author
Varun Sundar Rabindranath
committed
pass in disable_ue8m0_cast
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 581b3f2 commit dc7a48e

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,22 @@ def apply(
352352
expected_m,
353353
)
354354

355+
quant_scale_fmt = DeepGemmQuantScaleFMT.from_target_arch()
355356
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
356357
workspace1,
357358
expert_num_tokens,
358-
quant_scale_fmt=DeepGemmQuantScaleFMT.from_target_arch(),
359+
quant_scale_fmt=quant_scale_fmt,
359360
)
360361

362+
# If we have committed to the UE8M0 format, this flag must be set so
363+
# DeepGEMM does the same to the weights if they are not in UE8M0
364+
# format.
365+
enable_dg_ue8m0_cast = quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
361366
fp8_m_grouped_gemm_nt_masked(
362-
(a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m
367+
(a2q, a2q_scale),
368+
(w2, self.w2_scale),
369+
output,
370+
expert_num_tokens,
371+
expected_m,
372+
disable_ue8m0_cast=not enable_dg_ue8m0_cast,
363373
)

vllm/utils/deep_gemm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,12 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
195195
_lazy_init()
196196
if _grouped_masked_impl is None:
197197
return _missing(*args, **kwargs)
198-
return _grouped_masked_impl(
199-
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
200-
)
198+
if "disable_ue8m0_cast" in kwargs:
199+
disable_ue8m0_cast = kwargs["disable_ue8m0_cast"]
200+
del kwargs["disable_ue8m0_cast"]
201+
else:
202+
disable_ue8m0_cast = not is_deep_gemm_e8m0_used()
203+
return _grouped_masked_impl(*args, disable_ue8m0_cast=disable_ue8m0_cast, **kwargs)
201204

202205

203206
def fp8_mqa_logits(

0 commit comments

Comments
 (0)