From 740071a28e07d6ac5a46f79ceb6f1303155ebb0f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C5=A9=20Ho=C3=A0ng=20Nh=E1=BA=ADt=20Tr=C6=B0=E1=BB=9Dng?=
Date: Wed, 15 Apr 2026 11:51:11 +0000
Subject: [PATCH 1/2] Fix: Handle TorchDynamo incompatibility with torch._scaled_mm

---
 colossalai/quantization/fp8.py | 47 +++++++++++++++-------------------
 1 file changed, 21 insertions(+), 26 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index e23da5cccd4d..a7793ee7d8f8 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -797,37 +797,32 @@ def forward(
         ctx.w_fp8_t = w_fp8.t()
         ctx.inv_scale_x = inv_scale_x
         ctx.inv_scale_w = inv_scale_w
-        out = torch._scaled_mm(
-            x_fp8,
-            ctx.w_fp8_t,
-            bias=bias,
-            out_dtype=ctx.out_dtype,
-            scale_a=inv_scale_x,
-            scale_b=inv_scale_w,
-            use_fast_accum=True,
-        )[0]
+
+        # Dequantize and compute matrix multiplication (compatible with TorchDynamo)
+        x_deq = x_fp8.to(ctx.out_dtype) * inv_scale_x
+        w_t_deq = ctx.w_fp8_t.to(ctx.out_dtype) * inv_scale_w
+
+        out = x_deq @ w_t_deq
+        if bias is not None:
+            out = out + bias.to(ctx.out_dtype)
+
+        out = out.to(ctx.out_dtype)
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
     @staticmethod
     def backward(ctx: Any, out_grad) -> Any:
         out_grad = out_grad.reshape(-1, out_grad.shape[-1])
         out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2")
-        x_grad = torch._scaled_mm(
-            out_grad_fp8,
-            ctx.w_fp8_t.contiguous().t(),
-            out_dtype=ctx.out_dtype,
-            scale_a=out_grad_scale,
-            scale_b=ctx.inv_scale_w,
-            use_fast_accum=True,
-        )[0]
-        w_grad = torch._scaled_mm(
-            out_grad_fp8.t().contiguous(),
-            ctx.x_fp8.t().contiguous().t(),
-            out_dtype=ctx.out_dtype,
-            scale_a=out_grad_scale,
-            scale_b=ctx.inv_scale_x,
-            use_fast_accum=True,
-        )[0]
+
+        # Dequantize (force contiguous after cast)
+        out_grad_deq = (out_grad_fp8.to(ctx.out_dtype) * out_grad_scale).contiguous()
+        w_t_deq = (ctx.w_fp8_t.to(ctx.out_dtype) * ctx.inv_scale_w).contiguous()
+        x_deq = (ctx.x_fp8.to(ctx.out_dtype) * ctx.inv_scale_x).contiguous()
+
+        # Compute gradients
+        x_grad = out_grad_deq @ w_t_deq.t()
+        w_grad = out_grad_deq.t() @ x_deq
+
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
@@ -843,4 +838,4 @@ def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.T
     if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
         return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
-    return out
+    return out
\ No newline at end of file

From c3571c7c7007c051f55c2dc72284117de2b08905 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 15 Apr 2026 12:02:15 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/quantization/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index a7793ee7d8f8..0fbc5a850144 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -838,4 +838,4 @@ def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.T
     if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
         return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
-    return out
\ No newline at end of file
+    return out
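
Note on the change above: patch 1 replaces the fused torch._scaled_mm FP8 kernel with an explicit dequantize-then-matmul, which TorchDynamo can trace, at the cost of running the GEMM in the dequantized dtype (patch 2 only restores the trailing newline). The standalone sketch below illustrates the same pattern outside the patch; quantize_e4m3 is a hypothetical helper standing in for colossalai's cast_to_fp8, and PyTorch >= 2.1 is assumed for the torch.float8_e4m3fn dtype.

    import torch

    def quantize_e4m3(t: torch.Tensor):
        # Per-tensor scale chosen so the largest magnitude lands near the E4M3 max (448).
        scale = 448.0 / t.abs().max().clamp(min=1e-12)
        return (t * scale).to(torch.float8_e4m3fn), 1.0 / scale  # FP8 tensor + inverse scale

    x = torch.randn(32, 64)
    w = torch.randn(16, 64)
    x_fp8, inv_scale_x = quantize_e4m3(x)
    w_fp8, inv_scale_w = quantize_e4m3(w)

    # Dequantize to a regular dtype and use a plain matmul; unlike torch._scaled_mm,
    # this path is traceable by TorchDynamo (no fused FP8 GEMM, though).
    x_deq = x_fp8.to(torch.float32) * inv_scale_x
    w_deq = w_fp8.to(torch.float32) * inv_scale_w
    out = x_deq @ w_deq.t()

    err = (out - x @ w.t()).abs().max()
    print(f"max abs deviation from full-precision linear: {err.item():.3f}")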