From 872782ecef0121c66f2e2065dcef0802cfa23254 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 20 Nov 2025 08:58:47 +0800 Subject: [PATCH 1/6] call fp8 fsdpa for the context and query twice --- .../fp8_quant/_quant_common/helper_modules.py | 116 ++++++++++++++---- 1 file changed, 93 insertions(+), 23 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index 05770f7b171..cb4ddc02cd9 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import torch.nn as nn import types @@ -1330,6 +1331,40 @@ def forward_qdq( seq_padding_type, ) return results + + def fp8_fsdpa_fwd(self, + q, + k, + v, + attn_mask, + dropout_p, + scale, + softmax_mode, + ): + results = torch.ops.hpu.fp8_sdpa_recomp_fwd( + q, + k, + v, + attn_mask, + dropout_p, + scale, + False, # is_causal + True, # requires_backward + softmax_mode, # softmax_mode + self.scale_q, # d_scale_q + self.scale_k, # d_scale_k + self.scale_v, # d_scale_v + self.scale_amax, # q_scale_s + self.scale_output, # q_scale_o + self.descale_amax, # d_scale_s + False, # is_amax_s + False, # is_amax_o + None, # valid_seq_len + "right", # seq_padding_type + (-1, -1), # window_size + None, # sink + ) + return results def forward_quant( self, @@ -1345,32 +1380,67 @@ def forward_quant( valid_seq_len=None, seq_padding_type="None", ): - sm_mode = softmax_mode if softmax_mode == "fp32" else "None" + sm_mode = softmax_mode if softmax_mode == "fp32" else "none" qinput = self.quant_q(q).detach() kinput = self.quant_k(k).detach() vinput = self.quant_v(v).detach() - results = self.fp8_fused_sdpa( - qinput, - kinput, - vinput, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - softmax_mode=sm_mode, - d_scale_q=self.scale_q, - d_scale_k=self.scale_k, - d_scale_v=self.scale_v, - q_scale_s=self.scale_amax, - q_scale_o=self.scale_output, - d_scale_s=self.descale_amax, - is_amax_s=False, - valid_seq_len=valid_seq_len, - seq_padding_type=seq_padding_type, - ) - output = results[0] - d_out = self.dequant_output(output) - return d_out + q_len = q.shape[-2] + kv_len = kinput.size(-2) + + # for prefill with prefix caching + if q_len != 1 and q_len != kv_len: + from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape + gqa = is_gqa(qinput, kinput) + if gqa: + qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask) + + prefix_k = kinput[..., :kv_len - q_len, :] + prefix_v = vinput[..., :kv_len - q_len, :] + prefix_mask = None + prefix_res = self.fp8_fsdpa_fwd(qinput, prefix_k, prefix_v, prefix_mask, dropout_p, scale, sm_mode) + prefix_out, prefix_m, prefix_linv = (gqa_output_reshape(x) for x in (prefix_res[:3])) if gqa else prefix_res[:3] + + text_k = kinput[..., kv_len - q_len:, :] + text_v = vinput[..., kv_len - q_len:, :] + text_mask = attn_mask[..., kv_len - q_len:] + text_res = self.fp8_fsdpa_fwd(qinput, text_k, text_v, text_mask, dropout_p, scale, sm_mode) + text_out, text_m, text_linv = (gqa_output_reshape(x) for x in (text_res[:3])) if gqa else text_res[:3] + + prefix_linv = prefix_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else 
prefix_linv.to(torch.float32) + text_linv = text_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else text_linv.to(torch.float32) + prefix_out = self.dequant_output(prefix_out).to(torch.float32) + text_out = self.dequant_output(text_out).to(torch.float32) + new_m = torch.maximum(prefix_m, text_m) + l_rescaled = (1.0 / prefix_linv) * torch.exp(prefix_m - new_m) + block_l_rescaled = (1.0 / text_linv) * torch.exp(text_m - new_m) + new_linv = 1.0 / (l_rescaled + block_l_rescaled) + attn_weights = (l_rescaled * new_linv) * prefix_out + ( + block_l_rescaled * new_linv) * text_out + return attn_weights.to(q.dtype) + + else: + results = self.fp8_fused_sdpa( + qinput, + kinput, + vinput, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + softmax_mode=sm_mode, + d_scale_q=self.scale_q, + d_scale_k=self.scale_k, + d_scale_v=self.scale_v, + q_scale_s=self.scale_amax, + q_scale_o=self.scale_output, + d_scale_s=self.descale_amax, + is_amax_s=False, + valid_seq_len=valid_seq_len, + seq_padding_type=seq_padding_type, + ) + output = results[0] + d_out = self.dequant_output(output) + return d_out def forward_measure( self, From 7323ad810af725988e496a6fe224fa4a5b451158 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 20 Nov 2025 10:40:28 +0800 Subject: [PATCH 2/6] apply slicing for the fsdpa --- .../fp8_quant/_quant_common/helper_modules.py | 77 +++++++++++++------ 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index cb4ddc02cd9..b99e089c30f 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -1297,6 +1297,10 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format) self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format) self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format) + self.qkv_slice_thld = int(os.getenv("VLLM_FUSEDSPA_QKV_SLICE_SEQ_LEN_THLD", 8192)) + if self.qkv_slice_thld > 0: + self.q_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_Q_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) + self.kv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_KV_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) def forward_qdq( self, @@ -1388,36 +1392,59 @@ def forward_quant( kv_len = kinput.size(-2) # for prefill with prefix caching - if q_len != 1 and q_len != kv_len: + if q_len != 1 and q_len != kv_len \ + and kv_len > self.qkv_slice_thld: + ctx_len = kv_len - q_len from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape gqa = is_gqa(qinput, kinput) if gqa: qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask) - prefix_k = kinput[..., :kv_len - q_len, :] - prefix_v = vinput[..., :kv_len - q_len, :] - prefix_mask = None - prefix_res = self.fp8_fsdpa_fwd(qinput, prefix_k, prefix_v, prefix_mask, dropout_p, scale, sm_mode) - prefix_out, prefix_m, prefix_linv = (gqa_output_reshape(x) for x in (prefix_res[:3])) if gqa else prefix_res[:3] - - text_k = kinput[..., kv_len - q_len:, :] - text_v = vinput[..., kv_len - q_len:, :] - text_mask = attn_mask[..., kv_len - q_len:] - text_res = self.fp8_fsdpa_fwd(qinput, text_k, 
text_v, text_mask, dropout_p, scale, sm_mode) - text_out, text_m, text_linv = (gqa_output_reshape(x) for x in (text_res[:3])) if gqa else text_res[:3] - - prefix_linv = prefix_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else prefix_linv.to(torch.float32) - text_linv = text_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else text_linv.to(torch.float32) - prefix_out = self.dequant_output(prefix_out).to(torch.float32) - text_out = self.dequant_output(text_out).to(torch.float32) - new_m = torch.maximum(prefix_m, text_m) - l_rescaled = (1.0 / prefix_linv) * torch.exp(prefix_m - new_m) - block_l_rescaled = (1.0 / text_linv) * torch.exp(text_m - new_m) - new_linv = 1.0 / (l_rescaled + block_l_rescaled) - attn_weights = (l_rescaled * new_linv) * prefix_out + ( - block_l_rescaled * new_linv) * text_out - return attn_weights.to(q.dtype) - + num_q_chunks = (q_len + self.q_chunk_size - 1) // self.q_chunk_size + num_kv_chunks = (kv_len + self.kv_chunk_size - 1) // self.kv_chunk_size + chunk_outputs = [] + for q_chunk_idx in range(num_q_chunks): + q_start = q_chunk_idx * self.q_chunk_size + q_end = min((q_chunk_idx + 1) * self.q_chunk_size, q_len) + q_chunk = qinput[..., q_start:q_end, :] + + last_out = None + last_m = None + last_linv = None + for kv_chunk_idx in range(num_kv_chunks): + kv_start = kv_chunk_idx * self.kv_chunk_size + kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, kv_len) + k_chunk = kinput[..., kv_start:kv_end, :] + v_chunk = vinput[..., kv_start:kv_end, :] + attn_mask_chunk = attn_mask[..., kv_start:kv_end] if attn_mask is not None else None + attn_mask_chunk = None if kv_end < ctx_len else attn_mask_chunk + + # skip the upper triangular part for causal attention + if kv_start > ctx_len + q_end: + continue + + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, attn_mask_chunk, dropout_p, scale, sm_mode) + chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] + + chunk_m = chunk_m.to(torch.float32) + chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32) + chunk_out = self.dequant_output(chunk_out).to(torch.float32) + + if kv_chunk_idx == 0: + last_out = chunk_out + last_m = chunk_m + last_linv = chunk_linv + else: + new_m = torch.maximum(last_m, chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + ( + chunk_linv_rescaled * last_linv) * chunk_out + last_m = new_m + chunk_outputs.append(last_out) + output = torch.cat(chunk_outputs, dim=-2) + return output.to(q.dtype) else: results = self.fp8_fused_sdpa( qinput, From 688ae027ba708d4fffc4ca549476945aebb91d42 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Fri, 21 Nov 2025 08:21:32 +0800 Subject: [PATCH 3/6] pass less attn_mask to get better perf. 
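[Reviewer note, illustrative only] The intent of this change: context-only K/V chunks need no mask at all, the chunk that lines up with the end of the cached context can use the kernel's built-in is_causal mode when the chunk sizes are multiples of 1024 (per the comment in the hunk, otherwise the returned m/linv are not reliable), and only the unaligned diagonal chunk still receives a slice of the precomputed attention mask. Below is a minimal sketch of that decision in plain Python; the helper name select_chunk_mask, the granularity default, and the 2-D slicing convention are assumptions for illustration and do not match the hunk verbatim.

    import torch

    def select_chunk_mask(attn_mask: torch.Tensor, q_start: int, q_end: int,
                          kv_start: int, kv_end: int, ctx_len: int,
                          granularity: int = 1024):
        # Return (is_causal, mask_chunk) for one (q, kv) chunk.
        if kv_end <= ctx_len:
            # Pure context chunk: every cached key is visible, no mask needed.
            return False, None
        aligned = ((q_end - q_start) % granularity == 0
                   and (kv_end - kv_start) % granularity == 0)
        if kv_start == ctx_len and aligned:
            # Diagonal chunk with kernel-friendly sizes: rely on is_causal.
            return True, None
        # Fallback: slice the existing mask instead of building a new one.
        return False, attn_mask[..., q_start:q_end, kv_start:kv_end]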
--- .../fp8_quant/_quant_common/helper_modules.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index b99e089c30f..9f7bf3287a9 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -1343,6 +1343,7 @@ def fp8_fsdpa_fwd(self, attn_mask, dropout_p, scale, + is_causal, softmax_mode, ): results = torch.ops.hpu.fp8_sdpa_recomp_fwd( @@ -1352,7 +1353,7 @@ def fp8_fsdpa_fwd(self, attn_mask, dropout_p, scale, - False, # is_causal + is_causal, True, # requires_backward softmax_mode, # softmax_mode self.scale_q, # d_scale_q @@ -1394,6 +1395,7 @@ def forward_quant( # for prefill with prefix caching if q_len != 1 and q_len != kv_len \ and kv_len > self.qkv_slice_thld: + assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching." ctx_len = kv_len - q_len from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape gqa = is_gqa(qinput, kinput) @@ -1416,14 +1418,21 @@ def forward_quant( kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, kv_len) k_chunk = kinput[..., kv_start:kv_end, :] v_chunk = vinput[..., kv_start:kv_end, :] - attn_mask_chunk = attn_mask[..., kv_start:kv_end] if attn_mask is not None else None - attn_mask_chunk = None if kv_end < ctx_len else attn_mask_chunk # skip the upper triangular part for causal attention if kv_start > ctx_len + q_end: continue + + is_causal= True if kv_start-ctx_len==0 else False + + # current chunk_size should be multiple of 1024 to get right m/linv + if kv_end-ctx_len==0 and ((q_end-q_start)%1024!=0 or (kv_end-kv_start)%1024!=0): + is_causal = False + attn_mask_chunk = attn_mask[..., kv_start:kv_end] + else: + attn_mask_chunk = None - chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, attn_mask_chunk, dropout_p, scale, sm_mode) + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, attn_mask_chunk, dropout_p, scale, is_causal, sm_mode) chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] chunk_m = chunk_m.to(torch.float32) From dbb94f5335bfb7d8538747a893b8511889053543 Mon Sep 17 00:00:00 2001 From: Kurt Chen Date: Fri, 21 Nov 2025 13:52:18 +0000 Subject: [PATCH 4/6] Split context and causal and do QKV slice for APC FusedSDPA --- .../fp8_quant/_quant_common/helper_modules.py | 88 ++++++++++++++----- 1 file changed, 66 insertions(+), 22 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index 9f7bf3287a9..88da743db7e 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -1297,7 +1297,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format) self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format) self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format) - self.qkv_slice_thld = int(os.getenv("VLLM_FUSEDSPA_QKV_SLICE_SEQ_LEN_THLD", 8192)) + self.qkv_slice_thld = 
int(os.getenv("VLLM_FUSEDSDPA_QKV_SLICE_SEQ_LEN_THLD", 8192)) if self.qkv_slice_thld > 0: self.q_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_Q_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) self.kv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_KV_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) @@ -1393,17 +1393,17 @@ def forward_quant( kv_len = kinput.size(-2) # for prefill with prefix caching - if q_len != 1 and q_len != kv_len \ - and kv_len > self.qkv_slice_thld: + if self.qkv_slice_thld > 0 and q_len != 1 and q_len != kv_len and kv_len > self.qkv_slice_thld: assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching." ctx_len = kv_len - q_len from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape gqa = is_gqa(qinput, kinput) if gqa: qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask) - + num_q_chunks = (q_len + self.q_chunk_size - 1) // self.q_chunk_size - num_kv_chunks = (kv_len + self.kv_chunk_size - 1) // self.kv_chunk_size + num_context_kv_chunks = (ctx_len + self.kv_chunk_size - 1) // self.kv_chunk_size + num_causal_kv_chunks = num_q_chunks chunk_outputs = [] for q_chunk_idx in range(num_q_chunks): q_start = q_chunk_idx * self.q_chunk_size @@ -1413,28 +1413,15 @@ def forward_quant( last_out = None last_m = None last_linv = None - for kv_chunk_idx in range(num_kv_chunks): + for kv_chunk_idx in range(num_context_kv_chunks): kv_start = kv_chunk_idx * self.kv_chunk_size - kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, kv_len) + kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, ctx_len) k_chunk = kinput[..., kv_start:kv_end, :] v_chunk = vinput[..., kv_start:kv_end, :] - # skip the upper triangular part for causal attention - if kv_start > ctx_len + q_end: - continue - - is_causal= True if kv_start-ctx_len==0 else False - - # current chunk_size should be multiple of 1024 to get right m/linv - if kv_end-ctx_len==0 and ((q_end-q_start)%1024!=0 or (kv_end-kv_start)%1024!=0): - is_causal = False - attn_mask_chunk = attn_mask[..., kv_start:kv_end] - else: - attn_mask_chunk = None - - chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, attn_mask_chunk, dropout_p, scale, is_causal, sm_mode) + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, sm_mode) chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] - + chunk_m = chunk_m.to(torch.float32) chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32) chunk_out = self.dequant_output(chunk_out).to(torch.float32) @@ -1451,6 +1438,63 @@ def forward_quant( last_out = (last_linv_rescaled * last_linv) * last_out + ( chunk_linv_rescaled * last_linv) * chunk_out last_m = new_m + + kv_causal_start = ctx_len + q_start + kv_causal_end = ctx_len + q_end + k_causal_chunk = kinput[..., kv_causal_start:kv_causal_end, :] + v_causal_chunk = vinput[..., kv_causal_start:kv_causal_end, :] + + bs = q_chunk.size(0) + q_chunk_len = q_chunk.size(-2) + if q_chunk.size(-2) < self.q_chunk_size: + mask = (1 - torch.tril( + torch.ones(bs, + 1, + 1, + q_chunk_len, + q_chunk_len, + dtype=q.dtype, + device=q.device))) * torch.finfo( + q.dtype).min + causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, mask, dropout_p, scale, False, sm_mode) + else: + causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, True, sm_mode) + + 
causal_chunk_out, causal_chunk_m, causal_chunk_linv = (gqa_output_reshape(x) for x in (causal_chunk_res[:3])) if gqa else causal_chunk_res[:3] + causal_chunk_m = causal_chunk_m.to(torch.float32) + causal_chunk_linv = causal_chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else causal_chunk_linv.to(torch.float32) + causal_chunk_out = self.dequant_output(causal_chunk_out).to(torch.float32) + + if num_causal_kv_chunks == 1: + new_m = torch.maximum(last_m, causal_chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / causal_chunk_linv) * torch.exp(causal_chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + ( + chunk_linv_rescaled * last_linv) * causal_chunk_out + last_m = new_m + else: + for kv_chunk_idx in range(0, q_chunk_idx): + kv_causal_start = ctx_len + kv_chunk_idx * self.q_chunk_size + kv_causal_end = ctx_len + (kv_chunk_idx + 1) * self.q_chunk_size + k_causal_chunk = kinput[..., kv_causal_start:kv_causal_end, :] + v_causal_chunk = vinput[..., kv_causal_start:kv_causal_end, :] + + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, False, sm_mode) + + chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] + chunk_m = chunk_m.to(torch.float32) + chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32) + chunk_out = self.dequant_output(chunk_out).to(torch.float32) + + new_m = torch.maximum(last_m, chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + ( + chunk_linv_rescaled * last_linv) * chunk_out + last_m = new_m + chunk_outputs.append(last_out) output = torch.cat(chunk_outputs, dim=-2) return output.to(q.dtype) From 663e67dc2e0dd719dd553a9157bed176844c9eac Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Fri, 28 Nov 2025 11:30:13 +0800 Subject: [PATCH 5/6] use slice of attn_mask instead of new causal_mask --- .../fp8_quant/_quant_common/helper_modules.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index 88da743db7e..c1cf3776bb7 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -1447,15 +1447,7 @@ def forward_quant( bs = q_chunk.size(0) q_chunk_len = q_chunk.size(-2) if q_chunk.size(-2) < self.q_chunk_size: - mask = (1 - torch.tril( - torch.ones(bs, - 1, - 1, - q_chunk_len, - q_chunk_len, - dtype=q.dtype, - device=q.device))) * torch.finfo( - q.dtype).min + mask = attn_mask[..., q_start:q_end, kv_causal_start:kv_causal_end] causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, mask, dropout_p, scale, False, sm_mode) else: causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, True, sm_mode) From 7abff40453dba264a1b4d96b5739954ad69cd8a6 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 1 Dec 2025 14:28:17 +0800 Subject: [PATCH 6/6] use slicing for the causal part only --- 
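[Reviewer note, not applied by git am] For readers of the hunks below: the prefix part is computed once without a mask, the causal part is computed in chunks, and the partial results are combined with a running-softmax (flash-attention style) merge using the row-wise max m and the inverse row sum linv returned alongside the partial output. The following is a self-contained plain-PyTorch sketch of that merge with a sanity check against unchunked attention; the helper names, tensor shapes, and the linv = 1/row-sum convention are assumptions inferred from how the code uses these values, and the FP8 scaling plus the 128.0 fast-softmax correction are deliberately omitted.

    import torch

    def partial_attention(q, k, v, scale):
        # Reference softmax attention over one K/V segment; also returns the
        # row max m and the inverse row sum linv needed for merging.
        s = torch.matmul(q, k.transpose(-1, -2)) * scale
        m = s.amax(dim=-1, keepdim=True)
        p = torch.exp(s - m)
        l = p.sum(dim=-1, keepdim=True)
        return torch.matmul(p / l, v), m, 1.0 / l

    def merge_partial_attention(out_a, m_a, linv_a, out_b, m_b, linv_b):
        # Combine two partial results over disjoint K/V segments, following the
        # same rescaling pattern as the chunk loop in the hunks below.
        new_m = torch.maximum(m_a, m_b)
        l_a = (1.0 / linv_a) * torch.exp(m_a - new_m)   # rescaled row sums
        l_b = (1.0 / linv_b) * torch.exp(m_b - new_m)
        new_linv = 1.0 / (l_a + l_b)
        out = (l_a * new_linv) * out_a + (l_b * new_linv) * out_b
        return out, new_m, new_linv

    if __name__ == "__main__":
        torch.manual_seed(0)
        q = torch.randn(1, 2, 4, 8)        # [batch, heads, q_len, head_dim]
        k = torch.randn(1, 2, 16, 8)
        v = torch.randn(1, 2, 16, 8)
        scale = 8 ** -0.5
        out_a, m_a, linv_a = partial_attention(q, k[..., :10, :], v[..., :10, :], scale)
        out_b, m_b, linv_b = partial_attention(q, k[..., 10:, :], v[..., 10:, :], scale)
        merged, _, _ = merge_partial_attention(out_a, m_a, linv_a, out_b, m_b, linv_b)
        reference, _, _ = partial_attention(q, k, v, scale)
        assert torch.allclose(merged, reference, atol=1e-5)

Because the merge is mathematically exact, the prefix result only has to be computed once per prompt; each causal chunk then folds into it without revisiting the cached keys and values.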
.../fp8_quant/_quant_common/helper_modules.py | 113 ++++++------------ 1 file changed, 36 insertions(+), 77 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index c1cf3776bb7..3611b2796ad 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -1297,10 +1297,9 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format) self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format) self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format) - self.qkv_slice_thld = int(os.getenv("VLLM_FUSEDSDPA_QKV_SLICE_SEQ_LEN_THLD", 8192)) + self.qkv_slice_thld = int(os.getenv("PT_HPU_QKV_SLICE_SEQ_LEN_THLD", 4096)) if self.qkv_slice_thld > 0: - self.q_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_Q_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) - self.kv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_KV_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) + self.qkv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_QKV_SLICE_CHUNK_SIZE", self.qkv_slice_thld)) def forward_qdq( self, @@ -1385,7 +1384,7 @@ def forward_quant( valid_seq_len=None, seq_padding_type="None", ): - sm_mode = softmax_mode if softmax_mode == "fp32" else "none" + sm_mode = softmax_mode if softmax_mode == "fp32" else "fast" qinput = self.quant_q(q).detach() kinput = self.quant_k(k).detach() vinput = self.quant_v(v).detach() @@ -1393,101 +1392,61 @@ def forward_quant( kv_len = kinput.size(-2) # for prefill with prefix caching - if self.qkv_slice_thld > 0 and q_len != 1 and q_len != kv_len and kv_len > self.qkv_slice_thld: + if self.qkv_slice_thld > 0 and q_len != 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld: assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching." 
- ctx_len = kv_len - q_len + prefix_len = kv_len - q_len from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape gqa = is_gqa(qinput, kinput) if gqa: qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask) - num_q_chunks = (q_len + self.q_chunk_size - 1) // self.q_chunk_size - num_context_kv_chunks = (ctx_len + self.kv_chunk_size - 1) // self.kv_chunk_size - num_causal_kv_chunks = num_q_chunks + # calculate the prefix SDPA w/o mask + prefix_kinput = kinput[..., :prefix_len, :] + prefix_vinput = vinput[..., :prefix_len, :] + prefix_results = self.fp8_fsdpa_fwd(qinput, prefix_kinput, prefix_vinput, None, dropout_p, scale, False, sm_mode) + prefix_out, prefix_m, prefix_linv = (gqa_output_reshape(x) for x in (prefix_results[:3])) if gqa else prefix_results[:3] + prefix_m = prefix_m.to(torch.float32) + prefix_linv = prefix_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else prefix_linv.to(torch.float32) + prefix_out = self.dequant_output(prefix_out).to(torch.float32) + + # calculate the causal part in chunks chunk_outputs = [] - for q_chunk_idx in range(num_q_chunks): - q_start = q_chunk_idx * self.q_chunk_size - q_end = min((q_chunk_idx + 1) * self.q_chunk_size, q_len) + num_chunks = (q_len + self.qkv_chunk_size - 1) // self.qkv_chunk_size + for q_chunk_idx in range(num_chunks): + q_start = q_len - (q_chunk_idx + 1) * self.qkv_chunk_size + q_start = max(q_start, 0) + q_end = q_len - q_chunk_idx * self.qkv_chunk_size q_chunk = qinput[..., q_start:q_end, :] - last_out = None - last_m = None - last_linv = None - for kv_chunk_idx in range(num_context_kv_chunks): - kv_start = kv_chunk_idx * self.kv_chunk_size - kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, ctx_len) + last_out = prefix_out[..., q_start:q_end, :] + last_m = prefix_m[..., q_start:q_end, :] + last_linv = prefix_linv[..., q_start:q_end, :] + + for kv_chunk_idx in range(0, num_chunks - q_chunk_idx): + kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.qkv_chunk_size + kv_start = max(kv_start, prefix_len) + kv_end = prefix_len + q_end - kv_chunk_idx * self.qkv_chunk_size k_chunk = kinput[..., kv_start:kv_end, :] v_chunk = vinput[..., kv_start:kv_end, :] - chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, sm_mode) - chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] + is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx !=0 + mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end] if kv_chunk_idx == 0 and not is_causal_chunk else None + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, sm_mode) + chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] chunk_m = chunk_m.to(torch.float32) chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32) chunk_out = self.dequant_output(chunk_out).to(torch.float32) - if kv_chunk_idx == 0: - last_out = chunk_out - last_m = chunk_m - last_linv = chunk_linv - else: - new_m = torch.maximum(last_m, chunk_m) - last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) - chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) - last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) - last_out = (last_linv_rescaled * last_linv) * last_out + ( - chunk_linv_rescaled * last_linv) * chunk_out - last_m = new_m - - 
kv_causal_start = ctx_len + q_start - kv_causal_end = ctx_len + q_end - k_causal_chunk = kinput[..., kv_causal_start:kv_causal_end, :] - v_causal_chunk = vinput[..., kv_causal_start:kv_causal_end, :] - - bs = q_chunk.size(0) - q_chunk_len = q_chunk.size(-2) - if q_chunk.size(-2) < self.q_chunk_size: - mask = attn_mask[..., q_start:q_end, kv_causal_start:kv_causal_end] - causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, mask, dropout_p, scale, False, sm_mode) - else: - causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, True, sm_mode) - - causal_chunk_out, causal_chunk_m, causal_chunk_linv = (gqa_output_reshape(x) for x in (causal_chunk_res[:3])) if gqa else causal_chunk_res[:3] - causal_chunk_m = causal_chunk_m.to(torch.float32) - causal_chunk_linv = causal_chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else causal_chunk_linv.to(torch.float32) - causal_chunk_out = self.dequant_output(causal_chunk_out).to(torch.float32) - - if num_causal_kv_chunks == 1: - new_m = torch.maximum(last_m, causal_chunk_m) + new_m = torch.maximum(last_m, chunk_m) last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) - chunk_linv_rescaled = (1.0 / causal_chunk_linv) * torch.exp(causal_chunk_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) last_out = (last_linv_rescaled * last_linv) * last_out + ( - chunk_linv_rescaled * last_linv) * causal_chunk_out + chunk_linv_rescaled * last_linv) * chunk_out last_m = new_m - else: - for kv_chunk_idx in range(0, q_chunk_idx): - kv_causal_start = ctx_len + kv_chunk_idx * self.q_chunk_size - kv_causal_end = ctx_len + (kv_chunk_idx + 1) * self.q_chunk_size - k_causal_chunk = kinput[..., kv_causal_start:kv_causal_end, :] - v_causal_chunk = vinput[..., kv_causal_start:kv_causal_end, :] - - chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, False, sm_mode) - - chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3] - chunk_m = chunk_m.to(torch.float32) - chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32) - chunk_out = self.dequant_output(chunk_out).to(torch.float32) - - new_m = torch.maximum(last_m, chunk_m) - last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) - chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) - last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) - last_out = (last_linv_rescaled * last_linv) * last_out + ( - chunk_linv_rescaled * last_linv) * chunk_out - last_m = new_m - chunk_outputs.append(last_out) + chunk_outputs = list(reversed(chunk_outputs)) output = torch.cat(chunk_outputs, dim=-2) return output.to(q.dtype) else: