Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 62 additions & 40 deletions backends/cuda/triton/kernels/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,22 @@ def _sdpa_fwd_kernel_body(

offs_n_init = tl.arange(0, BLOCK_N)

# Window-aware early-exit. A KV block that is fully masked (sliding-window
# or causal) contributes nothing to the online softmax: every score in it is
# -inf, so p = 0 and m_i / l_i / acc are left unchanged. We detect such
# blocks up front and skip their K/V loads and both matmuls. When an explicit
# mask is supplied, the mask tile itself is still loaded, because the skip
# decision is computed from it. The skip is exact: it only removes work whose
# result the mask would have zeroed out anyway. At seq=2048 the 50
# sliding-window(1024) layers and the 10 causal layers each leave roughly
# half (or more) of their KV blocks fully masked, so this is a large cut to
# the dominant prefill cost. The skip condition is derived from a CTA-wide
# reduction (sum over the mask tile, or the hoisted max of seq_pos), so the
# branch is uniform and turns into a real skip (not predication).
if IS_CAUSAL:
max_seq_pos = tl.max(seq_pos)

for start_n in tl.range(0, Lk, BLOCK_N):
offs_n = start_n + offs_n_init

# K load: uniform (single KV head, shared across all Q heads in tile)
k_ptrs = K_ptr + (
b * stride_kb
+ h_kv * stride_kh
+ (offs_n[:, None] * stride_kn)
+ (offs_d[None, :] * stride_kd)
)
k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16)

qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)

# Decide whether any row in this tile actually attends to this KV block.
if HAS_MASK:
mask_ptrs = Mask_ptr + (
b * stride_mb
Expand All @@ -445,39 +446,60 @@ def _sdpa_fwd_kernel_body(
)
mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk)
mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False)
qk = tl.where(
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
block_active = tl.sum(mask_block.to(tl.int32)) > 0
elif IS_CAUSAL:
# Block is entirely in the future for every row -> skip.
block_active = start_n <= max_seq_pos
else:
block_active = True

if block_active:
# K load: uniform (single KV head, shared across Q heads in tile)
k_ptrs = K_ptr + (
b * stride_kb
+ h_kv * stride_kh
+ (offs_n[:, None] * stride_kn)
+ (offs_d[None, :] * stride_kd)
)
k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16)

if IS_CAUSAL:
causal = offs_n[None, :] > seq_pos[:, None]
qk = tl.where(
causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk
)
qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)

m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
safe_diff = tl.where(
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
)
p_f32 = tl.exp(safe_diff).to(tl.float32)
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
alpha = tl.exp(safe_alpha_diff).to(tl.float32)
if HAS_MASK:
qk = tl.where(
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
)

# V load: uniform (single KV head)
v_ptrs = V_ptr + (
b * stride_vb
+ h_kv * stride_vh
+ (offs_n[:, None] * stride_vn)
+ (offs_d[None, :] * stride_vd)
)
v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16)
if IS_CAUSAL:
causal = offs_n[None, :] > seq_pos[:, None]
qk = tl.where(
causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk
)

p_bf16 = p_f32.to(tl.bfloat16)
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i * alpha + l_ij).to(tl.float32)
m_i = m_ij
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
safe_diff = tl.where(
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
)
p_f32 = tl.exp(safe_diff).to(tl.float32)
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
alpha = tl.exp(safe_alpha_diff).to(tl.float32)

# V load: uniform (single KV head)
v_ptrs = V_ptr + (
b * stride_vb
+ h_kv * stride_vh
+ (offs_n[:, None] * stride_vn)
+ (offs_d[None, :] * stride_vd)
)
v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16)

p_bf16 = p_f32.to(tl.bfloat16)
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i * alpha + l_ij).to(tl.float32)
m_i = m_ij

inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0)
acc = acc * inv_l_i[:, None]
Expand Down
Loading