@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn as nn
import types
@@ -1296,6 +1297,9 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format)
self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format)
self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format)
# slice long prefix-caching prompts into chunks once the K/V sequence length reaches this threshold; <= 0 disables slicing
self.qkv_slice_thld = int(os.getenv("PT_HPU_QKV_SLICE_SEQ_LEN_THLD", 4096))
if self.qkv_slice_thld > 0:
self.qkv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_QKV_SLICE_CHUNK_SIZE", self.qkv_slice_thld))

def forward_qdq(
self,
@@ -1330,6 +1334,41 @@ def forward_qdq(
seq_padding_type,
)
return results

# wrapper over the raw fp8 SDPA recompute-forward kernel; besides the attention
# output it returns the softmax row-max (m) and inverse row-sum (linv) used to merge chunks
def fp8_fsdpa_fwd(self,
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
softmax_mode,
):
results = torch.ops.hpu.fp8_sdpa_recomp_fwd(
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
True, # requires_backward
softmax_mode, # softmax_mode
self.scale_q, # d_scale_q
self.scale_k, # d_scale_k
self.scale_v, # d_scale_v
self.scale_amax, # q_scale_s
self.scale_output, # q_scale_o
self.descale_amax, # d_scale_s
False, # is_amax_s
False, # is_amax_o
None, # valid_seq_len
"right", # seq_padding_type
(-1, -1), # window_size
None, # sink
)
return results

def forward_quant(
self,
@@ -1345,32 +1384,94 @@ def forward_quant(
valid_seq_len=None,
seq_padding_type="None",
):
sm_mode = softmax_mode if softmax_mode == "fp32" else "None"
sm_mode = softmax_mode if softmax_mode == "fp32" else "fast"
qinput = self.quant_q(q).detach()
kinput = self.quant_k(k).detach()
vinput = self.quant_v(v).detach()
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out
q_len = q.shape[-2]
kv_len = kinput.size(-2)

# for prefill with prefix caching
if self.qkv_slice_thld > 0 and q_len != 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld:
assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching."
prefix_len = kv_len - q_len
from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
gqa = is_gqa(qinput, kinput)
if gqa:
qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask)

# calculate the prefix SDPA w/o mask
prefix_kinput = kinput[..., :prefix_len, :]
prefix_vinput = vinput[..., :prefix_len, :]
prefix_results = self.fp8_fsdpa_fwd(qinput, prefix_kinput, prefix_vinput, None, dropout_p, scale, False, sm_mode)
prefix_out, prefix_m, prefix_linv = (gqa_output_reshape(x) for x in (prefix_results[:3])) if gqa else prefix_results[:3]
prefix_m = prefix_m.to(torch.float32)
prefix_linv = prefix_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else prefix_linv.to(torch.float32)
prefix_out = self.dequant_output(prefix_out).to(torch.float32)

# calculate the causal part in chunks
chunk_outputs = []
num_chunks = (q_len + self.qkv_chunk_size - 1) // self.qkv_chunk_size
for q_chunk_idx in range(num_chunks):
q_start = q_len - (q_chunk_idx + 1) * self.qkv_chunk_size
q_start = max(q_start, 0)
q_end = q_len - q_chunk_idx * self.qkv_chunk_size
q_chunk = qinput[..., q_start:q_end, :]

last_out = prefix_out[..., q_start:q_end, :]
last_m = prefix_m[..., q_start:q_end, :]
last_linv = prefix_linv[..., q_start:q_end, :]

for kv_chunk_idx in range(0, num_chunks - q_chunk_idx):
kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.qkv_chunk_size
kv_start = max(kv_start, prefix_len)
kv_end = prefix_len + q_end - kv_chunk_idx * self.qkv_chunk_size
k_chunk = kinput[..., kv_start:kv_end, :]
v_chunk = vinput[..., kv_start:kv_end, :]

# diagonal chunk (kv_chunk_idx == 0): earlier query chunks use is_causal, while the last
# query chunk (q_chunk_idx == 0) applies the sliced attn_mask; off-diagonal chunks need no mask
is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0
mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end] if kv_chunk_idx == 0 and not is_causal_chunk else None
chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, sm_mode)

chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3]
chunk_m = chunk_m.to(torch.float32)
chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32)
chunk_out = self.dequant_output(chunk_out).to(torch.float32)

new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (
chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m
chunk_outputs.append(last_out)
chunk_outputs = list(reversed(chunk_outputs))
output = torch.cat(chunk_outputs, dim=-2)
return output.to(q.dtype)
else:
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out

def forward_measure(
self,
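For reference, the chunked path in forward_quant merges partial attention results with an online-softmax rescaling of the per-chunk row max (m) and inverse row sum (linv). Below is a minimal PyTorch sketch of that merge rule, using plain torch ops rather than the HPU fp8 kernels; the function names and toy shapes are illustrative assumptions only, and masks are omitted so the merged result can be checked against a plain full softmax.

```python
import torch


def partial_attn(q, k, v, scale):
    # Attention over one K/V slice; also return the softmax row max (m) and
    # row sum (l), the statistics the fused kernel exposes as m and linv = 1/l.
    s = torch.matmul(q, k.transpose(-2, -1)) * scale
    m = s.max(dim=-1, keepdim=True).values
    p = torch.exp(s - m)
    l = p.sum(dim=-1, keepdim=True)
    return torch.matmul(p / l, v), m, l


def merge(out_a, m_a, l_a, out_b, m_b, l_b):
    # Same rescaling as in forward_quant: bring both row sums to the joint max,
    # then recombine the already-normalized partial outputs.
    m = torch.maximum(m_a, m_b)
    l_a_r = l_a * torch.exp(m_a - m)  # == (1 / linv_a) * exp(m_a - new_m)
    l_b_r = l_b * torch.exp(m_b - m)
    linv = 1.0 / (l_a_r + l_b_r)
    return (l_a_r * linv) * out_a + (l_b_r * linv) * out_b


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(1, 2, 8, 16)   # [batch, heads, q_len, head_dim]
    k = torch.randn(1, 2, 24, 16)  # 16 "prefix" keys + 8 "new" keys
    v = torch.randn(1, 2, 24, 16)
    scale = 16 ** -0.5
    out_a, m_a, l_a = partial_attn(q, k[..., :16, :], v[..., :16, :], scale)
    out_b, m_b, l_b = partial_attn(q, k[..., 16:, :], v[..., 16:, :], scale)
    merged = merge(out_a, m_a, l_a, out_b, m_b, l_b)
    full = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1) @ v
    assert torch.allclose(merged, full, atol=1e-5)
```

The kernel path applies the same rule iteratively, folding one K/V chunk at a time into the running (out, m, linv) state before concatenating the query chunks.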