diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index a59b150e7ae..64c6d3e46d9 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -960,6 +960,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int): lib.impl(name, select_as_symint_impl, "Meta") select_as_symint_op = getattr(getattr(torch.ops, namespace), name) +########## +## sdpa ## +########## + + +def sdpa_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + scale: Optional[float] = None, +): + if scale is None: + scale = 1.0 / (q.size(-1) ** 0.5) + attn = torch.matmul(q, k.transpose(-2, -1)) * scale + if attn_mask is not None: + attn = attn + attn_mask + attn = torch.softmax(attn, dim=-1) + return torch.matmul(attn, v) + + +name = "sdpa" +lib.define( + f"{name}(Tensor q, Tensor k, Tensor v, Tensor? attn_mask = None, float? scale = None) -> Tensor" +) +lib.impl(name, sdpa_impl, "CompositeExplicitAutograd") +sdpa_op = getattr(getattr(torch.ops, namespace), name) + ################ ## rms_norm ## ################ diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index ff056d76c3a..2e313f0f91b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops(): ) +# ============================================================================= +# SDPA.cpp (fused SDPA entry point) +# ============================================================================= + + +@update_features("et_vk::sdpa") +def register_general_sdpa(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + inputs_dtypes=utils.FP_T, + supports_resize=True, + ) + + # ============================================================================= # RotaryEmbedding.cpp # ============================================================================= diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh index 72b5bdb812e..581b05072df 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -13,12 +13,19 @@ * Macro Settings: * - TILE_M * - TILE_K4 + * + * Optional: + * - LINEAR_FP_INPUT_TILE_VEC4_T — input tile vec4 type (default: VEC4_T). 
*/ #extension GL_EXT_control_flow_attributes : require +#ifndef LINEAR_FP_INPUT_TILE_VEC4_T +#define LINEAR_FP_INPUT_TILE_VEC4_T VEC4_T +#endif + struct FPInputTile { - VEC4_T data[TILE_M][TILE_K4]; + LINEAR_FP_INPUT_TILE_VEC4_T data[TILE_M][TILE_K4]; }; #ifdef DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh index 358379b3efd..84bf2e07bea 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -21,11 +21,11 @@ #include "linear_fp_input_tile.glslh" -VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { +LINEAR_FP_INPUT_TILE_VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { #ifdef INPUT_BUFFER - return t_input[(m * ntexels_k) + k4]; + return LINEAR_FP_INPUT_TILE_VEC4_T(t_input[(m * ntexels_k) + k4]); #else - return texelFetch(t_input, ivec3(k4, m, 0), 0); + return LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_input, ivec3(k4, m, 0), 0)); #endif } @@ -53,7 +53,7 @@ void load_input_tile_with_checks( if (m_start + m < M && k4_start + k4 < K4) { in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); } else { - in_tile.data[m][k4] = VEC4_T(0.0); + in_tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh index ca466447084..b6fc31951f5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -10,6 +10,11 @@ * Macro Settings: * - TILE_M * - TILE_N4 + * + * Optional: + * - LINEAR_FP_OUTPUT_TILE_VEC4_T — accumulator vec4 type (default: VEC4_T). + * Set this to `vec4` to force fp32 accumulation regardless of DTYPE; used + * by fused SDPA QK to avoid fp16 overflow in Q@K^T. 
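+ *
+ * (For intuition: with head_dim = 128 and activation magnitudes around 30,
+ * a single Q@K^T dot product is roughly 30 * 30 * 128 = 115200, which
+ * already exceeds the fp16 maximum of 65504, so a half-precision
+ * accumulator saturates to infinity before softmax ever runs.)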
*/ #ifndef LINEAR_FP_OUTPUT_TILE_GLSLH @@ -17,14 +22,19 @@ #extension GL_EXT_control_flow_attributes : require +#ifndef LINEAR_FP_OUTPUT_TILE_VEC4_T +#define LINEAR_FP_OUTPUT_TILE_VEC4_T VEC4_T +#define LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT +#endif + struct FPOutTile { - VEC4_T data[TILE_M][TILE_N4]; + LINEAR_FP_OUTPUT_TILE_VEC4_T data[TILE_M][TILE_N4]; }; void initialize(out FPOutTile out_tile) { [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - out_tile.data[m][n4] = VEC4_T(0); + out_tile.data[m][n4] = LINEAR_FP_OUTPUT_TILE_VEC4_T(0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh index 60a19ca9fc9..73faf9074ac 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -21,6 +21,12 @@ #include "linear_fp_per_out_channel_params.glslh" #include "linear_fp_weight_tile.glslh" +#if defined(LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT) == defined(LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT) +#define MAYBE_CAST_WVEC4(x) (x) +#else +#define MAYBE_CAST_WVEC4(x) LINEAR_FP_OUTPUT_TILE_VEC4_T(x) +#endif + void fp_accumulate_with_fp_weight( inout FPOutTile accum, FPInputTile in_tile, @@ -29,23 +35,23 @@ void fp_accumulate_with_fp_weight( [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][0]), - w_tile.data[mul_4(k4)][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][0]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4)][n4]), accum.data[m][n4]); accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][1]), - w_tile.data[mul_4(k4) + 1][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][1]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 1][n4]), accum.data[m][n4]); accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][2]), - w_tile.data[mul_4(k4) + 2][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][2]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 2][n4]), accum.data[m][n4]); accum.data[m][n4] = - fma(VEC4_T(in_tile.data[m][k4][3]), - w_tile.data[mul_4(k4) + 3][n4], + fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][3]), + MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 3][n4]), accum.data[m][n4]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh index 6fb399ff99b..9ee5f004cf5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -25,14 +25,14 @@ #include "linear_fp_output_tile.glslh" void write_output_x4( - const VEC4_T out_texel, + const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel, const int n4, const int m, const int N4) { #ifdef OUTPUT_BUFFER - t_output[m * N4 + n4] = out_texel; + t_output[m * N4 + n4] = VEC4_T(out_texel); #else - imageStore(t_output, ivec3(n4, m, 0), out_texel); + imageStore(t_output, ivec3(n4, m, 0), VEC4_T(out_texel)); #endif } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh index 36b2a7296ef..5592042f6f7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh +++ 
b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh @@ -23,12 +23,12 @@ #include "linear_fp_weight_tile.glslh" -VEC4_T load_packed_weight_x4( +LINEAR_FP_WEIGHT_TILE_VEC4_T load_packed_weight_x4( const int n4, const int dk, const int k4, const int b, const int K4, const int N4) { #ifdef WEIGHT_BUFFER - return t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]; + return LINEAR_FP_WEIGHT_TILE_VEC4_T(t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]); #else - return VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0)); + return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0)); #endif } @@ -65,7 +65,7 @@ void load_packed_weight_tile_with_checks( if (k4 < K4 && n4_start + n4 < N4) { tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4); } else { - tile.data[k][n4] = VEC4_T(0); + tile.data[k][n4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh index 5e010442540..c57c5e72f0d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -10,6 +10,9 @@ * Macro Settings: * - TILE_K * - TILE_N4 + * + * Optional: + * - LINEAR_FP_WEIGHT_TILE_VEC4_T — weight tile vec4 type (default: VEC4_T). */ #ifndef LINEAR_FP_WEIGHT_TILE_GLSLH @@ -19,8 +22,13 @@ #include "common.glslh" +#ifndef LINEAR_FP_WEIGHT_TILE_VEC4_T +#define LINEAR_FP_WEIGHT_TILE_VEC4_T VEC4_T +#define LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT +#endif + struct FPWeightTile { - VEC4_T data[TILE_K][TILE_N4]; + LINEAR_FP_WEIGHT_TILE_VEC4_T data[TILE_K][TILE_N4]; }; #ifdef DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index e6c118b6ab2..6c095e66255 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -9,14 +9,22 @@ #version 450 core #define PRECISION ${PRECISION} -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} -#define T ${texel_load_component_type(DTYPE, STORAGE)} #define NUM_WORKERS_PER_WG 64 +$if MODE == "llm": + #define HAS_INPUT_POS + +#define IN_DTYPE ${IN_DTYPE} +#define OUT_DTYPE ${OUT_DTYPE} +#define SOFTMAX_IN_VEC4_T ${texel_load_type(IN_DTYPE, STORAGE)} +#define SOFTMAX_ACC_T ${texel_load_component_type(IN_DTYPE, STORAGE)} +#define VEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)} +#define T ${texel_load_component_type(OUT_DTYPE, STORAGE)} + ${define_active_storage_type(STORAGE)} -${define_required_extensions(STORAGE, DTYPE)} +${define_required_extensions(STORAGE, [IN_DTYPE, OUT_DTYPE])} #extension GL_EXT_control_flow_attributes : require @@ -24,19 +32,22 @@ layout(std430) buffer; #include "common.glslh" -${layout_declare_tensor(B, "w", "t_attn_weights_softmax", DTYPE, STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_attn_weights_softmax", OUT_DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", IN_DTYPE, STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "int", "input_pos")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", 
"k_sizes")} +$if MODE == "llm": + ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// Shared memory for cooperative max finding and exp sum reduction -shared T shared_max[NUM_WORKERS_PER_WG]; -shared T shared_exp_sum[NUM_WORKERS_PER_WG]; +// Shared memory for cooperative max finding and exp sum reduction. +// For fused SDPA, reductions happen in fp32 for numerical stability. +shared SOFTMAX_ACC_T shared_max[NUM_WORKERS_PER_WG]; +shared SOFTMAX_ACC_T shared_exp_sum[NUM_WORKERS_PER_WG]; -VEC4_T load_attn_weights_c4( +SOFTMAX_IN_VEC4_T load_attn_weights_c4( const int c4, const int s, const int q_h, @@ -46,7 +57,7 @@ VEC4_T load_attn_weights_c4( #ifdef USING_BUFFER return t_attn_weights[(q_h * S * C4) + (s * C4) + c4]; #else - return texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0); + return SOFTMAX_IN_VEC4_T(texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0)); #endif } @@ -65,26 +76,61 @@ void store_attn_weights_softmax_c4( #endif } +/* + * 3-pass numerically stable softmax over the context_len dimension of + * attention weights. + * + * LLM SDPA (HAS_INPUT_POS): + * reads VEC4_T (input dtype), reduces in T, writes VEC4_T. + * attn_weights S dim is padded to S_aligned. + * current context_len = input_pos + S. + * + * Fused SDPA (!HAS_INPUT_POS): + * reads vec4 (fp32 from QK), reduces in fp32, writes VEC4_T (input dtype). + * attn_weights S dim is not padded. + * context_len = k_sizes.y. + * + * Dispatch: (1, S, H * B) — for LLM (batch=1), H * B == H_q. + */ void main() { const int worker_id = int(gl_LocalInvocationID.x); // Index along attention weight's sequence_len dim const int s = int(gl_GlobalInvocationID.y); - // idx along attention weight's num_q_heads dim + // For LLM: q_head index. For fused: combined batch*H + head index. const int q_h = int(gl_GlobalInvocationID.z); - // number of Q heads - const int Q_H = q_projected_sizes.y; - // sequence length - const int S = q_projected_sizes.z; +#ifdef HAS_INPUT_POS + // LLM: q_sizes is WHCN {D, H_q, S, B} + const int Q_H = q_sizes.y; + const int S = q_sizes.z; +#else + // Fused: q_sizes is WHCN {D, S, H, B} + const int Q_H = q_sizes.z; + const int S = q_sizes.y; +#endif const int S_aligned = align_up_4(S); + +#ifdef HAS_INPUT_POS // manually determine size of the context_len dim of the attention weight. // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow // memory loads to be aligned to texel boundaries. const int context_len = input_pos + S; +#else + const int context_len = k_sizes.y; +#endif const int context_texel_len = div_up_4(context_len); - if (s >= S || q_h >= Q_H) { + // LLM: attn_weights S dim is padded to S_aligned; fused: not padded. +#ifdef HAS_INPUT_POS + const int attn_S = S_aligned; +#else + const int attn_S = S; +#endif + + // bounds check — q_h bound is Q_H * batch_size; for LLM (batch=1) this + // equals Q_H, for fused this equals H * B. + if (s >= S || q_h >= Q_H * q_sizes.w) { return; } @@ -96,25 +142,25 @@ void main() { // Without this, exp(x) can overflow float32 when x > ~88.7. 
// ========================================================================= - T local_max = T(-1.0 / 0.0); // -infinity + SOFTMAX_ACC_T local_max = SOFTMAX_ACC_T(-1.0 / 0.0); // -infinity for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); for (int comp = 0; comp < 4; comp++) { - local_max = max(local_max, in_texel[comp]); + local_max = max(local_max, SOFTMAX_ACC_T(in_texel[comp])); } } if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - local_max = max(local_max, in_texel[comp]); + local_max = max(local_max, SOFTMAX_ACC_T(in_texel[comp])); } } } @@ -135,31 +181,31 @@ void main() { barrier(); } - const T global_max = shared_max[0]; + const SOFTMAX_ACC_T global_max = shared_max[0]; // ========================================================================= // Pass 2: Compute sum(exp(x - max)) using the global max for stability // ========================================================================= - T local_exp_sum = T(0); + SOFTMAX_ACC_T local_exp_sum = SOFTMAX_ACC_T(0); for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); for (int comp = 0; comp < 4; comp++) { - local_exp_sum += exp(in_texel[comp] - global_max); + local_exp_sum += exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max); } } if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - local_exp_sum += exp(in_texel[comp] - global_max); + local_exp_sum += exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max); } } } @@ -187,27 +233,32 @@ void main() { // ========================================================================= for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); - VEC4_T out_texel = exp(in_texel - global_max) / local_exp_sum; + VEC4_T out_texel; + [[unroll]] for (int comp = 0; comp < 4; comp++) { + out_texel[comp] = T( + exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max) / local_exp_sum); + } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); + out_texel, c4, s, q_h, context_texel_len, attn_S, Q_H); } if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); - VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S_aligned, Q_H); + SOFTMAX_IN_VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, attn_S, Q_H); VEC4_T out_texel = 
VEC4_T(0); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - out_texel[comp] = exp(in_texel[comp] - global_max) / local_exp_sum; + out_texel[comp] = T( + exp(SOFTMAX_ACC_T(in_texel[comp]) - global_max) / local_exp_sum); } } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); + out_texel, c4, s, q_h, context_texel_len, attn_S, Q_H); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml index 66ec030680e..d46e301e203 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml @@ -6,14 +6,30 @@ sdpa_attn_weights_softmax: parameter_names_with_default_values: - DTYPE: float + IN_DTYPE: float + OUT_DTYPE: float STORAGE: texture3d + MODE: llm generate_variant_forall: STORAGE: - VALUE: texture3d - VALUE: buffer - DTYPE: - - VALUE: float - - VALUE: half + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [float, float] + suffix: float + - parameter_values: [half, half] + suffix: half shader_variants: - NAME: sdpa_attn_weights_softmax + - NAME: fused_sdpa_softmax + MODE: fused + IN_DTYPE: float + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + OUT_DTYPE: + - VALUE: float + - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index e50ca0612fd..b7f14f435fa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -15,9 +15,13 @@ $if IO_STORAGE == "buffer": #define OUTPUT_BUFFER #define INPUT_BUFFER + #define ATTN_WEIGHTS_BUFFER $if K_CACHE_STORAGE == "buffer": #define K_CACHE_BUFFER +#define Q_LAYOUT DHSB +#define K_LAYOUT DHSB + #define TILE_K4 ${TILE_K4} #define TILE_N4 ${TILE_N4} @@ -34,11 +38,11 @@ layout(std430) buffer; #include "common.glslh" ${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "k_sizes")} ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -75,18 +79,20 @@ void main() { // 1. 
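Note on sdpa_attn_weights_softmax.glsl above: the 3-pass structure is easiest to see in scalar form. A minimal sketch in plain Python (illustrative only; the shader additionally strides the loops across 64 workers and reduces through shared memory):

```python
import math

def softmax_3pass(xs):
    # Pass 1: global max, so exp() never sees a large positive argument
    # (exp(x) overflows float32 once x > ~88.7).
    m = max(xs)
    # Pass 2: accumulate sum(exp(x - max)) in full precision.
    denom = sum(math.exp(x - m) for x in xs)
    # Pass 3: normalize and write out.
    return [math.exp(x - m) / denom for x in xs]
```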
const int s = 0; + // head dimension + const int D = q_sizes.x; // texel size of head_dim, over which the dot product is accumulated - const int D4 = div_up_4(q_projected_sizes.x); + const int D4 = div_up_4(D); // number of Q heads - const int Q_H = q_projected_sizes.y; + const int Q_H = q_sizes.y; // sequence length - const int S = q_projected_sizes.z; + const int S = q_sizes.z; const int S_aligned = align_up_4(S); // number of K/V heads - const int KV_H = k_cache_sizes.y; + const int KV_H = k_sizes.y; // Max context length - const int C = k_cache_sizes.z; + const int C = k_sizes.z; const int C4 = div_up_4(C); int kv_h = q_h; @@ -126,8 +132,9 @@ void main() { s, q_h, D4, - Q_H, - S); + D, + S, + Q_H); load_k_cache_tile_with_checks( w_tile, @@ -135,6 +142,7 @@ void main() { c, kv_h, D4, + D, context_len, C, KV_H); diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index a6703437c41..662d5edde68 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -9,8 +9,14 @@ #version 450 core #define PRECISION ${PRECISION} -#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} -#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +#define IN_DTYPE ${IN_DTYPE} +#define OUT_DTYPE ${OUT_DTYPE} + +#define VEC4_T ${texel_load_type(IN_DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(IN_DTYPE, IO_STORAGE)} + +#define LINEAR_FP_OUTPUT_TILE_VEC4_T ${texel_load_type(OUT_DTYPE, IO_STORAGE)} $if IO_STORAGE == "buffer": #define OUTPUT_BUFFER @@ -18,6 +24,19 @@ $if IO_STORAGE == "buffer": $if K_CACHE_STORAGE == "buffer": #define K_CACHE_BUFFER +$if MODE == "llm": + #define HAS_INPUT_POS + #define HAS_GQA + #define Q_LAYOUT DHSB + #define K_LAYOUT DHSB +$else: + #define SDPA_PAD_D + #define Q_LAYOUT DSHB + #define K_LAYOUT DSHB + +$if HAS_BIAS: + #define HAS_BIAS + #define TILE_M4 ${TILE_M4} #define TILE_K4 ${TILE_K4} #define TILE_N4 ${TILE_N4} @@ -26,19 +45,24 @@ $if K_CACHE_STORAGE == "buffer": #define TILE_K ${TILE_K4 * 4} #define TILE_N ${TILE_N4 * 4} -${define_required_extensions(IO_STORAGE, DTYPE)} +${define_required_extensions(IO_STORAGE, [IN_DTYPE, OUT_DTYPE])} layout(std430) buffer; #include "common.glslh" -${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_attn_weights", OUT_DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q", IN_DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k", IN_DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", IN_DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} -${layout_declare_ubo(B, "int", "input_pos")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "k_sizes")} +$if MODE == "llm": + ${layout_declare_ubo(B, "int", "input_pos")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -50,33 +74,28 @@ ${layout_declare_spec_const(C, "float", 
"inv_scale", "1.0")} #include "sdpa_fp_attn_weight_tile_store.glslh" /* - * Compute attention weights given the q_projected and k_cache tensors. - * q_projected has shape (batches, seq_len, num_q_heads, head_dim) - * k_cache has shape (batches, max_context_len, num_kv_heads, head_dim) - * output has shape (batches, num_q_heads, seq_len, context_len) - * - * This shader also applies scales and masking to the computed attention - * weights. + * Compute attention weights (Q @ K^T) given the Q and K tensors. * - * The scale applied is 1.0 / sqrt(head_dim_length). + * LLM SDPA (HAS_INPUT_POS, HAS_GQA): + * q: [B, S, H_q, D] (DHSB layout) + * k (k_cache): [B, C_max, H_kv, D] (DHSB layout) + * attn_weights: [B, H_q, S, context_len] in input dtype + * current context_len = input_pos + S + * Applies combined scale + causal mask. * - * The mask applied is a bit more complicated. Imagine you create a square - * matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the - * lower triangular section of the matrix to -inf. Then, slice the matrix along - * the row dimension starting from input_pos to input_pos + seq_len. You end up - * with a partial mask with size (seq_len, input_pos + seq_len). This is the - * mask that is applied to the attention weight. - * - * In the shader, instead of generating the mask, the index of the elment is - * inspected to determine if it would have been masked. Given an element at - * tensor index (n, c, h, w), it would be masked if w < h + input_pos. + * Fused SDPA: + * q: [B, H, S, D] (DSHB layout) + * k: [B, H, L, D] (DSHB layout) + * attn_weights: [B, H, S, L] in fp32 to prevent fp16 overflow in Q@K^T + * Applies scalar scale, optionally adds bias. * + * Dispatch: (context_tiles, S_tiles, H * B) — for LLM (batch=1), H * B == H_q. */ void main() { const int tile_idx_x = int(gl_GlobalInvocationID.x); const int tile_idx_y = int(gl_GlobalInvocationID.y); - // idx along output num_q_heads dim + // For LLM: q_head index. For fused: combined batch*H + head index. 
const int q_h = int(gl_GlobalInvocationID.z); // idx along the output context_len dim @@ -85,32 +104,48 @@ void main() { // idx along the output seq_len dim const int s = tile_idx_y * TILE_M; - const int s4 = div_4(s); - - // texel size of head_dim, over which the dot product is accumulated - const int D4 = div_up_4(q_projected_sizes.x); - // number of Q heads - const int Q_H = q_projected_sizes.y; - // sequence length - const int S = q_projected_sizes.z; + +#ifdef HAS_INPUT_POS + // LLM: q_sizes is WHCN {D, H_q, S, B} + const int D = q_sizes.x; + const int Q_H = q_sizes.y; + const int S = q_sizes.z; + // k_sizes is WHCN {D, H_kv, C_max, B} + const int KV_H = k_sizes.y; + const int C = k_sizes.z; +#else + // Fused: q_sizes is WHCN {D, S, H, B} + const int D = q_sizes.x; + const int S = q_sizes.y; + const int Q_H = q_sizes.z; + // k_sizes is WHCN {D, L, H, B} + const int KV_H = k_sizes.z; + const int C = k_sizes.y; +#endif + const int D4 = div_up_4(D); const int S_aligned = align_up_4(S); - // number of K/V heads - const int KV_H = k_cache_sizes.y; - // Max context length - const int C = k_cache_sizes.z; - const int C4 = div_up_4(C); +#ifdef HAS_INPUT_POS + // current context length for LLM decode/prefill + const int context_len = input_pos + S; +#else + // fused: full key sequence length from k_sizes + const int context_len = k_sizes.y; +#endif + const int context_texel_len = div_up_4(context_len); +#ifdef HAS_GQA int kv_h = q_h; if (KV_H < Q_H) { kv_h = q_h / (Q_H / KV_H); } +#else + const int kv_h = q_h; +#endif - const int context_len = input_pos + S; - const int context_texel_len = div_up_4(context_len); - - // bounds check - if (c >= context_len || s >= S || q_h >= Q_H) { + // bounds check — q_h bound is Q_H * batch_size; for LLM (batch=1) this + // equals Q_H, for fused this equals H * B. + if (c >= context_len || s >= S || q_h >= Q_H * q_sizes.w) { return; } @@ -120,6 +155,16 @@ void main() { FPInputTile q_tile; FPWeightTile w_tile; + // The LLM attn_weights tensor is padded to S_aligned in its S dim, while + // fused attn_weights is not padded. The store/bias helpers bound-check + // against this. +#ifdef HAS_INPUT_POS + const int attn_S = S_aligned; +#else + const int attn_S = S; +#endif + +#ifdef HAS_INPUT_POS // If the tile is completely inside the mask region, then there is no need to // compute the output tile. All the elements in the output tile can be set to // negative infinity. 
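On the mask-region early-out referenced above: an element at (row s, column c) of the LLM attention weights is visible iff c <= s + input_pos, so an entire tile can be skipped when even its left-most column lies past the last visible column of its bottom row. A sketch of that predicate (illustrative; the shader's actual check lives in unchanged context not shown in this hunk):

```python
def tile_fully_masked(c_start, s_start, tile_m, input_pos):
    # The bottom row of the tile sees the most context; if even that row
    # cannot see the tile's first column, every element is -inf after
    # masking and the Q @ K^T compute can be skipped.
    last_visible_col = s_start + (tile_m - 1) + input_pos
    return c_start > last_visible_col
```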
@@ -127,9 +172,9 @@ void main() { if (tile_in_mask_region) { const VEC4_T negative_infinity_vec = VEC4_T(negative_infinity_val); set_out_tile_to_vec(out_tile, negative_infinity_vec); - } - // Otherwise, need to actually compute output tile - else { + } else +#endif + { for (int d4 = 0; d4 < D4; d4++) { load_q_projected_tile_with_checks( q_tile, @@ -137,8 +182,9 @@ void main() { s, q_h, D4, - Q_H, - S); + D, + S, + Q_H); load_k_cache_tile_with_checks( w_tile, @@ -146,15 +192,16 @@ void main() { c, kv_h, D4, + D, context_len, C, KV_H); - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } - // Apply scale and mask +#ifdef HAS_INPUT_POS + // LLM: combined scale + causal mask VEC4_T inv_scale_vec = VEC4_T(inv_scale); apply_scale_and_mask( out_tile, @@ -162,6 +209,13 @@ void main() { input_pos, c, s); +#else + // Fused: scalar scale, optional bias + apply_scale(out_tile, inv_scale); + #ifdef HAS_BIAS + apply_bias(out_tile, c4, s, q_h, Q_H, context_texel_len, attn_S); + #endif +#endif } store_attn_weight_tile_with_checks( @@ -170,6 +224,6 @@ void main() { s, q_h, context_texel_len, - S_aligned, + attn_S, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml index 7fc016cf3c3..24494b408fa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -6,9 +6,12 @@ sdpa_compute_attn_weights_tiled: parameter_names_with_default_values: - DTYPE: float + IN_DTYPE: float + OUT_DTYPE: float IO_STORAGE: texture3d K_CACHE_STORAGE: texture3d + MODE: llm + HAS_BIAS: false TILE_M4: 1 TILE_K4: 1 TILE_N4: 1 @@ -19,8 +22,39 @@ sdpa_compute_attn_weights_tiled: - parameter_values: [texture3d, texture3d] - parameter_values: [buffer, texture3d] - parameter_values: [buffer, buffer] - DTYPE: - - VALUE: float - - VALUE: half + combination1: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [float, float] + suffix: float + - parameter_values: [half, half] + suffix: half shader_variants: - NAME: sdpa_compute_attn_weights_tiled + - NAME: fused_sdpa_qk_tiled + MODE: fused + OUT_DTYPE: float + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + IN_DTYPE: + - VALUE: float + - VALUE: half + - NAME: fused_sdpa_qk_tiled_bias + MODE: fused + OUT_DTYPE: float + HAS_BIAS: true + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + IN_DTYPE: + - VALUE: float + - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl index 2e5fda18e14..cd2c689ebc8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -15,9 +15,14 @@ $if IO_STORAGE == "buffer": #define OUTPUT_BUFFER #define INPUT_BUFFER + #define ATTN_WEIGHTS_BUFFER $if V_CACHE_STORAGE == "buffer": #define V_CACHE_BUFFER +#define V_LAYOUT DHSB +#define OUT_LAYOUT DHSB +#define SDPA_V_BUF t_v_cache + #define TILE_K4 ${TILE_K4} #define TILE_N4 ${TILE_N4} @@ -37,8 +42,8 @@ 
${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=F ${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "v_sizes")} ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -76,17 +81,17 @@ void main() { const int s = 0; // texel size of head_dim - const int D4 = div_up_4(q_projected_sizes.x); + const int D4 = div_up_4(q_sizes.x); // number of Q heads - const int Q_H = q_projected_sizes.y; + const int Q_H = q_sizes.y; // sequence length - const int S = q_projected_sizes.z; + const int S = q_sizes.z; const int S_aligned = align_up_4(S); // number of K/V heads - const int KV_H = v_cache_sizes.y; + const int KV_H = v_sizes.y; // Max context length - const int C = v_cache_sizes.z; + const int C = v_sizes.z; const int C4 = div_up_4(C); int kv_h = q_h; diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl index 2027a9908a9..9f8f2dbc231 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -18,6 +18,15 @@ $if IO_STORAGE == "buffer": $if V_CACHE_STORAGE == "buffer": #define V_CACHE_BUFFER +$if MODE == "llm": + #define HAS_INPUT_POS + #define HAS_GQA + #define V_LAYOUT DHSB + #define OUT_LAYOUT DHSB +$else: + #define V_LAYOUT DSHB + #define OUT_LAYOUT DSHB + #define TILE_M4 ${TILE_M4} // Equvalent to K4 in matrix multiplication #define TILE_K4 ${TILE_K4} @@ -36,11 +45,12 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_v", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} -${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} -${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} -${layout_declare_ubo(B, "int", "input_pos")} +${layout_declare_ubo(B, "ivec4", "q_sizes")} +${layout_declare_ubo(B, "ivec4", "v_sizes")} +$if MODE == "llm": + ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -50,16 +60,27 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "sdpa_fp_out_tile_store.glslh" /* - * Compute SDPA output given the attention weights and v_cache tensors. - * attention weights has shape (batches, num_q_heads, seq_len, context_len) - * v_cache has shape (batches, max_context_len, num_kv_heads, head_dim) - * output has shape (batches, seq_len, num_q_heads, head_dim) + * Compute SDPA output given the attention weights and V tensors. 
+ * + * LLM SDPA (HAS_INPUT_POS, HAS_GQA): + * attn_weights: [B, H_q, S, context_len] + * v (v_cache): [B, C_max, H_kv, D] (DHSB layout) + * output: [B, S, H_q, D] (DHSB layout) + * current context_len = input_pos + S + * GQA: Q heads may be > KV heads; kv_h = q_h / (H_q / H_kv) + * + * Fused SDPA: + * attn_weights: [B, H, S, context_len] + * v: [B, H, context_len, D] (DSHB layout) + * output: [B, H, S, D] (DSHB layout) + * + * Dispatch: (D_tiles, S_tiles, H * B) — for LLM (batch=1), H * B == H_q. */ void main() { const int tile_idx_x = int(gl_GlobalInvocationID.x); const int tile_idx_y = int(gl_GlobalInvocationID.y); - // idx along output num_q_heads dim + // For LLM: q_head index. For fused: combined batch*H + head index. const int q_h = int(gl_GlobalInvocationID.z); // idx along the output head_dim dim @@ -69,31 +90,47 @@ void main() { // idx along the output seq_len dim const int s = tile_idx_y * TILE_M; - // texel size of head_dim - const int D4 = div_up_4(q_projected_sizes.x); - // number of Q heads - const int Q_H = q_projected_sizes.y; - // sequence length - const int S = q_projected_sizes.z; +#ifdef HAS_INPUT_POS + // LLM: q_sizes is WHCN {D, H_q, S, B} + const int D = q_sizes.x; + const int Q_H = q_sizes.y; + const int S = q_sizes.z; + // v_sizes is WHCN {D, H_kv, C_max, B} + const int KV_H = v_sizes.y; + const int C = v_sizes.z; +#else + // Fused: q_sizes is WHCN {D, S, H, B} + const int D = q_sizes.x; + const int S = q_sizes.y; + const int Q_H = q_sizes.z; + // v_sizes is WHCN {D, context_len, H, B} + const int KV_H = v_sizes.z; + const int C = v_sizes.y; +#endif + const int D4 = div_up_4(D); const int S_aligned = align_up_4(S); - // number of K/V heads - const int KV_H = v_cache_sizes.y; - // Max context length - const int C = v_cache_sizes.z; - const int C4 = div_up_4(C); +#ifdef HAS_INPUT_POS + // current context length for LLM decode/prefill + const int context_len = input_pos + S; +#else + // fused: full key sequence length from v_sizes (DSHB: {D, L, H, B}) + const int context_len = v_sizes.y; +#endif + const int context_texel_len = div_up_4(context_len); +#ifdef HAS_GQA int kv_h = q_h; if (KV_H < Q_H) { kv_h = q_h / (Q_H / KV_H); } +#else + const int kv_h = q_h; +#endif - // current context length - const int context_len = input_pos + S; - const int context_texel_len = div_up_4(context_len); - - // bounds check - if (d4 >= D4 || s >= S || q_h >= Q_H) { + // bounds check — q_h bound is Q_H * batch_size; for LLM (batch=1) this + // equals Q_H, for fused this equals H * B. + if (d4 >= D4 || s >= S || q_h >= Q_H * q_sizes.w) { return; } @@ -103,62 +140,33 @@ void main() { FPInputTile attn_weight_tile; FPWeightTile w_tile; + // For LLM, the attn_weights tensor has seq_len padded up to a multiple of 4 + // (S_aligned). The loader accesses (head * attn_S * C4 + s * C4 + c4), so + // pass S_aligned in LLM mode and S in fused mode. 
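A compact way to see why attn_S matters: the row stride of attn_weights along the head axis bakes in the (possibly padded) sequence length, so using the wrong stride misattributes rows across heads. Sketch (pure Python, illustrative):

```python
def align_up_4(x):
    return (x + 3) // 4 * 4

def attn_weight_texel_index(head, s, c4, attn_S, C4):
    # LLM mode stores attn_S = align_up_4(S) rows per head; fused mode
    # stores attn_S = S (unpadded).
    return (head * attn_S + s) * C4 + c4
```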
+#ifdef HAS_INPUT_POS + const int attn_S = S_aligned; +#else + const int attn_S = S; +#endif + + // Split loop into aligned + tail for efficiency const int context_len_aligned_down = context_len - mod_4(context_len); const int C4_limit = div_4(context_len_aligned_down); for (int c4 = 0; c4 < C4_limit; c4++) { const int c = mul_4(c4); load_attn_weight_tile_no_checks( - attn_weight_tile, - c4, - s, - q_h, - context_texel_len, - S_aligned, - Q_H); - - load_v_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - + attn_weight_tile, c4, s, q_h, context_texel_len, attn_S, Q_H); + load_v_cache_tile_no_checks(w_tile, d4, c, kv_h, D4, context_len, C, KV_H); fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); } for (int c4 = C4_limit; c4 < context_texel_len; c4++) { const int c = mul_4(c4); load_attn_weight_tile_with_checks( - attn_weight_tile, - c4, - s, - q_h, - context_texel_len, - S_aligned, - Q_H); - - load_v_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - + attn_weight_tile, c4, s, q_h, context_texel_len, attn_S, Q_H); + load_v_cache_tile_with_checks(w_tile, d4, c, kv_h, D4, context_len, C, KV_H); fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); } - store_sdpa_out_tile_with_checks( - out_tile, - d4, - s, - q_h, - D4, - S, - Q_H); + store_sdpa_out_tile_with_checks(out_tile, d4, s, q_h, D4, S, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml index eac2c6f37dd..ba91114ae92 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -9,6 +9,7 @@ sdpa_compute_out_tiled: DTYPE: float IO_STORAGE: texture3d V_CACHE_STORAGE: texture3d + MODE: llm TILE_M4: 1 TILE_K4: 1 TILE_N4: 1 @@ -24,3 +25,15 @@ sdpa_compute_out_tiled: - VALUE: half shader_variants: - NAME: sdpa_compute_out_tiled + - NAME: fused_sdpa_av_tiled + MODE: fused + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh index 12b2292fa45..829f03beb60 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh @@ -7,11 +7,19 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_attn_weights + * Shared attention weight tile load for both LLM SDPA and fused SDPA + * (used in the AV shader to read softmax output). * - * Macro Settings: - * - INPUT_BUFFER + * The attn_weights tensor layout is [head, S, L] in both cases: + * index = (head * S * L4) + (s * L4) + l4 + * + * No layout switch needed — both variants use the same index formula. + * + * Optional macros: + * INPUT_BUFFER — use buffer path; otherwise texture. Set at the shader + * level when IO_STORAGE == "buffer" and applies to all + * IO tensors uniformly (attn_weights now follows the + * output's storage type). 
*/ #ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH @@ -21,7 +29,7 @@ #include "linear_fp_input_tile.glslh" -VEC4_T load_attn_weight_c4( +LINEAR_FP_INPUT_TILE_VEC4_T load_attn_weight_c4( const int c4, const int s, const int q_h, @@ -44,7 +52,7 @@ void load_attn_weight_tile_no_checks( const int S, const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { - [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + [[unroll]] for (int c4 = 0; c4 < TILE_K4; ++c4) { tile.data[s][c4] = load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); } @@ -60,12 +68,12 @@ void load_attn_weight_tile_with_checks( const int S, const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { - [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + [[unroll]] for (int c4 = 0; c4 < TILE_K4; ++c4) { if (c4_start + c4 < C4 && s_start + s < S) { tile.data[s][c4] = load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); } else { - tile.data[s][c4] = VEC4_T(0.0); + tile.data[s][c4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh index c64d9af8cfb..83d81ddf812 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh @@ -7,24 +7,39 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_attn_weights + * Shared attention weight tile store for both LLM SDPA and fused SDPA. * - * Macro Settings: - * - OUTPUT_BUFFER + * The attn_weights tensor layout is [head, S, L] in both cases: + * index = (head * S * L4) + (s * L4) + l4 + * + * Tile precision is controlled by the caller via LINEAR_FP_OUTPUT_TILE_VEC4_T + * (fused SDPA sets it to vec4 for fp32 accumulation; LLM SDPA leaves it as + * VEC4_T for input-dtype accumulation). All helper functions below are + * available at all times and compile correctly in both modes. The only + * gated helper is apply_bias, which requires t_bias/bias_sizes and is + * therefore guarded by HAS_BIAS. + * + * Required macros/variables: + * t_attn_weights — output buffer/texture + * OUTPUT_BUFFER — buffer mode (otherwise texture). Set at the shader + * level when IO_STORAGE == "buffer" and applies to all + * IO tensors uniformly (attn_weights now follows the + * output's storage type). 
*/ -#ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH -#define SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#ifndef SDPA_FP_ATTN_WEIGHT_TILE_STORE_GLSLH +#define SDPA_FP_ATTN_WEIGHT_TILE_STORE_GLSLH #extension GL_EXT_control_flow_attributes : require #include "linear_fp_output_tile.glslh" -T negative_infinity_val = T(-1.0 / 0.0); +// ============================================================ +// Shared store helpers (buffer or texture; fp32 or input-dtype) +// ============================================================ void store_attn_weight_c4( - const VEC4_T out_texel, + const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel, const int c4, const int s, const int q_h, @@ -72,7 +87,71 @@ void store_attn_weight_tile_with_checks( } } -void set_out_tile_to_vec(out FPOutTile tile, const VEC4_T vec) { +// ============================================================ +// Tile transform helpers (scale, bias, mask, set) +// ============================================================ + +T negative_infinity_val = T(-1.0 / 0.0); + +void apply_scale(inout FPOutTile tile, const float scale) { + const LINEAR_FP_OUTPUT_TILE_VEC4_T scale_vec = + LINEAR_FP_OUTPUT_TILE_VEC4_T(scale); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scale_vec; + } + } +} + +#ifdef HAS_BIAS +void apply_bias( + inout FPOutTile tile, + const int c4_start, + const int s_start, + const int bh, + const int Q_H, + const int C4, + const int S) { + const int bias_C4 = div_up_4(bias_sizes.x); + const int bias_S = bias_sizes.y; + // bias WHCN sizes map to logical shape [B, H, ?, L]: + // sizes.x = L, sizes.y = ?, sizes.z = H, sizes.w = B + // Decompose bh (= q_batch * Q_H + q_head) into Q's batch/head, then + // broadcast each dim independently. This matches standard broadcast + // semantics when bias_H or bias_B is 1. + const int bias_H = bias_sizes.z; + const int bias_B = bias_sizes.w; + const int q_batch = bh / Q_H; + const int q_head = bh - q_batch * Q_H; + const int bias_batch = (bias_B == 1) ? 0 : q_batch; + const int bias_head = (bias_H == 1) ? 0 : q_head; + + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + if (c4_start + c4 < C4 && s_start + s < S) { + const int bias_bh = bias_batch * bias_H + bias_head; + const int bias_s = (s_start + s) < bias_S ? 
(s_start + s) : 0; + const int bias_c4 = c4_start + c4; + if (bias_c4 < bias_C4) { +#ifdef INPUT_BUFFER + const LINEAR_FP_OUTPUT_TILE_VEC4_T bias_val = + LINEAR_FP_OUTPUT_TILE_VEC4_T(VEC4_T( + t_bias[(bias_bh * bias_S * bias_C4) + (bias_s * bias_C4) + + bias_c4])); +#else + const LINEAR_FP_OUTPUT_TILE_VEC4_T bias_val = + LINEAR_FP_OUTPUT_TILE_VEC4_T(VEC4_T( + texelFetch(t_bias, ivec3(bias_c4, bias_s, bias_bh), 0))); +#endif + tile.data[s][c4] = tile.data[s][c4] + bias_val; + } + } + } + } +} +#endif // HAS_BIAS + +void set_out_tile_to_vec(out FPOutTile tile, const LINEAR_FP_OUTPUT_TILE_VEC4_T vec) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { tile.data[s][c4] = vec; } } @@ -80,7 +159,7 @@ void set_out_tile_to_vec(out FPOutTile tile, const VEC4_T vec) { void apply_scale_and_mask( inout FPOutTile tile, - const VEC4_T inv_scale_vec, + const LINEAR_FP_OUTPUT_TILE_VEC4_T inv_scale_vec, const int input_pos, const int c_idx_start, const int s_idx_start) { @@ -102,4 +181,4 @@ void apply_scale_and_mask( } } -#endif // SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#endif // SDPA_FP_ATTN_WEIGHT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh index 1880397181d..65d08755528 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -7,32 +7,74 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_k_cache + * Shared K transposed tile load for both LLM SDPA and fused SDPA. + * Loads K[head, l, d4] and transposes in-place so the tile represents K^T. * - * Macro Settings: - * - K_CACHE_BUFFER + * Layout selection (caller must #define K_LAYOUT before including): + * DSHB (0) — fused SDPA: [B, H, L, D] → WHCN {D, L, H, B} + * index = (head * L + l) * D4 + d4 + * DHSB (1) — LLM SDPA: [B, L, H, D] → WHCN {D, H, L, B} + * index = (l * H + head) * D4 + d4 + * + * Optional macros: + * K_CACHE_BUFFER / K_BUFFER — use buffer path; otherwise texture + * SDPA_PAD_D — zero-pad last d4 texel when D % 4 != 0 */ #ifndef SDPA_FP_K_CACHE_TILE_LOAD_GLSLH #define SDPA_FP_K_CACHE_TILE_LOAD_GLSLH +#ifndef DSHB +#define DSHB 0 +#define DHSB 1 +#endif + #extension GL_EXT_control_flow_attributes : require #include "linear_fp_weight_tile.glslh" -VEC4_T load_k_cache_d4( +// Determine whether buffer mode is active. Both K_CACHE_BUFFER (LLM) and +// K_BUFFER (fused) activate the buffer path. 
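For reviewers of apply_bias above, here is the broadcast decomposition in index form (pure Python sketch; bias_sizes follows the WHCN order used in the shader, and all names are illustrative):

```python
def bias_texel_index(bh, s, c4, Q_H, bias_sizes):
    bias_L, bias_S, bias_H, bias_B = bias_sizes  # WHCN {L, S, H, B}
    bias_C4 = (bias_L + 3) // 4
    # Split the fused batch*head index back into Q's batch and head, then
    # collapse any bias dim of size 1 (standard broadcast semantics).
    q_batch, q_head = divmod(bh, Q_H)
    bias_batch = 0 if bias_B == 1 else q_batch
    bias_head = 0 if bias_H == 1 else q_head
    bias_s = s if s < bias_S else 0
    bias_bh = bias_batch * bias_H + bias_head
    return (bias_bh * bias_S + bias_s) * bias_C4 + c4
```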
+#if defined(K_CACHE_BUFFER) || defined(K_BUFFER) +#define _SDPA_K_USE_BUFFER +#endif + +LINEAR_FP_WEIGHT_TILE_VEC4_T load_k_cache_d4( const int d4, const int c, const int kv_h, const int D4, + const int D, const int C, const int KV_H) { -#ifdef K_CACHE_BUFFER - return VEC4_T(t_k_cache[(c * KV_H * D4) + (kv_h * D4) + d4]); -#else - return VEC4_T(texelFetch(t_k_cache, ivec3(d4, kv_h, c), 0)); + LINEAR_FP_WEIGHT_TILE_VEC4_T val; + +#ifdef _SDPA_K_USE_BUFFER + #if K_LAYOUT == DSHB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(t_k[(kv_h * C * D4) + (c * D4) + d4]); + #elif K_LAYOUT == DHSB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(t_k[(c * KV_H * D4) + (kv_h * D4) + d4]); + #endif +#else // texture + #if K_LAYOUT == DSHB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_k, ivec3(d4, c, kv_h), 0)); + #elif K_LAYOUT == DHSB + val = LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_k, ivec3(d4, kv_h, c), 0)); + #endif #endif + +#ifdef SDPA_PAD_D + if (d4 == D4 - 1) { + const int valid = D - mul_4(d4); + [[unroll]] for (int i = 0; i < 4; ++i) { + if (i >= valid) { + val[i] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0)[i]; + } + } + } +#endif + + return val; } void load_k_cache_tile_no_checks( @@ -41,6 +83,7 @@ void load_k_cache_tile_no_checks( const int c_start, const int kv_h, const int D4, + const int D, const int context_len, const int C, const int KV_H) { @@ -48,8 +91,8 @@ void load_k_cache_tile_no_checks( const int c4 = div_4(c); const int c4i = mod_4(c); [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { - VEC4_T d4_row = - load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + LINEAR_FP_WEIGHT_TILE_VEC4_T d4_row = + load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, D, C, KV_H); // Transpose in-place const int d_base = mul_4(d4); @@ -67,6 +110,7 @@ void load_k_cache_tile_with_checks( const int c_start, const int kv_h, const int D4, + const int D, const int context_len, const int C, const int KV_H) { @@ -74,9 +118,9 @@ void load_k_cache_tile_with_checks( const int c4 = div_4(c); const int c4i = mod_4(c); [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { - VEC4_T d4_row = VEC4_T(0.0); + LINEAR_FP_WEIGHT_TILE_VEC4_T d4_row = LINEAR_FP_WEIGHT_TILE_VEC4_T(0.0); if (d4_start + d4 < D4 && c_start + c < context_len) { - d4_row = load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + d4_row = load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, D, C, KV_H); } // Transpose in-place diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh index 17e0988a6a4..382747bf9a4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh @@ -7,22 +7,33 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_attn_weights + * Shared output tile store for both LLM SDPA and fused SDPA. 
* - * Macro Settings: - * - OUTPUT_BUFFER + * Layout selection (caller must #define OUT_LAYOUT before including): + * DSHB (0) — fused SDPA: [B, H, S, D] → WHCN {D, S, H, B} + * index = (head * S + s) * D4 + d4 + * DHSB (1) — LLM SDPA: [B, S, H, D] → WHCN {D, H, S, B} + * index = (s * H + head) * D4 + d4 + * + * Required macros/variables: + * t_output — output tensor binding + * OUTPUT_BUFFER — buffer mode (otherwise texture) */ -#ifndef SDPA_FP_OUT_TILE_LOAD_GLSLH -#define SDPA_FP_OUT_TILE_LOAD_GLSLH +#ifndef SDPA_FP_OUT_TILE_STORE_GLSLH +#define SDPA_FP_OUT_TILE_STORE_GLSLH + +#ifndef DSHB +#define DSHB 0 +#define DHSB 1 +#endif #extension GL_EXT_control_flow_attributes : require #include "linear_fp_output_tile.glslh" void store_out_d4( - const VEC4_T out_texel, + const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel, const int d4, const int q_h, const int s, @@ -30,9 +41,17 @@ void store_out_d4( const int Q_H, const int S) { #ifdef OUTPUT_BUFFER - t_output[(s * Q_H * D4) + (q_h * D4) + d4] = out_texel; -#else - imageStore(t_output, ivec3(d4, q_h, s), out_texel); + #if OUT_LAYOUT == DSHB + t_output[(q_h * S * D4) + (s * D4) + d4] = VEC4_T(out_texel); + #elif OUT_LAYOUT == DHSB + t_output[(s * Q_H * D4) + (q_h * D4) + d4] = VEC4_T(out_texel); + #endif +#else // texture + #if OUT_LAYOUT == DSHB + imageStore(t_output, ivec3(d4, s, q_h), VEC4_T(out_texel)); + #elif OUT_LAYOUT == DHSB + imageStore(t_output, ivec3(d4, q_h, s), VEC4_T(out_texel)); + #endif #endif } @@ -54,4 +73,4 @@ void store_sdpa_out_tile_with_checks( } } -#endif // SDPA_FP_OUT_TILE_LOAD_GLSLH +#endif // SDPA_FP_OUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh index a304e5019e9..752762b623d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh @@ -7,32 +7,65 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_input + * Shared Q tile load for both LLM SDPA and fused SDPA. 
* - * Macro Settings: - * - INPUT_BUFFER + * Layout selection (caller must #define Q_LAYOUT before including): + * DSHB (0) — fused SDPA: [B, H, S, D] → WHCN {D, S, H, B} + * index = (head * S + s) * D4 + d4 + * DHSB (1) — LLM SDPA: [B, S, H, D] → WHCN {D, H, S, B} + * index = (s * H + head) * D4 + d4 + * + * Optional macros: + * INPUT_BUFFER — use buffer path; otherwise texture + * SDPA_PAD_D — zero-pad last d4 texel when D % 4 != 0 */ #ifndef SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH #define SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH +#define DSHB 0 +#define DHSB 1 + #extension GL_EXT_control_flow_attributes : require #include "linear_fp_input_tile.glslh" -VEC4_T load_q_projected_d4( +LINEAR_FP_INPUT_TILE_VEC4_T load_q_projected_d4( const int d4, - const int q_h, const int s, + const int q_h, const int D4, - const int Q_H, - const int S) { + const int D, + const int S, + const int Q_H) { + LINEAR_FP_INPUT_TILE_VEC4_T val; + #ifdef INPUT_BUFFER - return t_q_projected[(s * Q_H * D4) + (q_h * D4) + d4]; -#else - return texelFetch(t_q_projected, ivec3(d4, q_h, s), 0); + #if Q_LAYOUT == DSHB + val = LINEAR_FP_INPUT_TILE_VEC4_T(t_q[(q_h * S * D4) + (s * D4) + d4]); + #elif Q_LAYOUT == DHSB + val = LINEAR_FP_INPUT_TILE_VEC4_T(t_q[(s * Q_H * D4) + (q_h * D4) + d4]); + #endif +#else // texture + #if Q_LAYOUT == DSHB + val = LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_q, ivec3(d4, s, q_h), 0)); + #elif Q_LAYOUT == DHSB + val = LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_q, ivec3(d4, q_h, s), 0)); + #endif #endif + +#ifdef SDPA_PAD_D + if (d4 == D4 - 1) { + const int valid = D - mul_4(d4); + [[unroll]] for (int i = 0; i < 4; ++i) { + if (i >= valid) { + val[i] = T(0); + } + } + } +#endif + + return val; } void load_q_projected_tile_no_checks( @@ -41,12 +74,13 @@ void load_q_projected_tile_no_checks( const int s_start, const int q_h, const int D4, - const int Q_H, - const int S) { + const int D, + const int S, + const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { tile.data[s][d4] = - load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + load_q_projected_d4(d4_start + d4, s_start + s, q_h, D4, D, S, Q_H); } } } @@ -57,15 +91,16 @@ void load_q_projected_tile_with_checks( const int s_start, const int q_h, const int D4, - const int Q_H, - const int S) { + const int D, + const int S, + const int Q_H) { [[unroll]] for (int s = 0; s < TILE_M; ++s) { [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { if (d4_start + d4 < D4 && s_start + s < S) { tile.data[s][d4] = - load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + load_q_projected_d4(d4_start + d4, s_start + s, q_h, D4, D, S, Q_H); } else { - tile.data[s][d4] = VEC4_T(0.0); + tile.data[s][d4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh index bf94b251c43..98516744b44 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh @@ -7,31 +7,60 @@ */ /* - * Assume the following variables are defined in the shader layout: - * - t_v_cache + * Shared V tile load for both LLM SDPA and fused SDPA (no transpose). 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh
index bf94b251c43..98516744b44 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh
@@ -7,31 +7,60 @@
  */

 /*
- * Assume the following variables are defined in the shader layout:
- * - t_v_cache
+ * Shared V tile load for both LLM SDPA and fused SDPA (no transpose).
  *
- * Macro Settings:
- * - V_CACHE_BUFFER
+ * Layout selection (caller must #define V_LAYOUT before including):
+ *   DSHB (0) — fused SDPA: [B, H, L, D] → WHCN {D, L, H, B}
+ *     index = (head * L + l) * D4 + d4
+ *   DHSB (1) — LLM SDPA: [B, L, H, D] → WHCN {D, H, L, B}
+ *     index = (l * H + head) * D4 + d4
+ *
+ * Optional macros:
+ *   SDPA_V_BUF — tensor name (default: t_v)
+ *   V_CACHE_BUFFER / V_BUFFER — use buffer path; otherwise texture
  */

 #ifndef SDPA_FP_V_CACHE_TILE_LOAD_GLSLH
 #define SDPA_FP_V_CACHE_TILE_LOAD_GLSLH

+#ifndef DSHB
+#define DSHB 0
+#define DHSB 1
+#endif
+
 #extension GL_EXT_control_flow_attributes : require

 #include "linear_fp_weight_tile.glslh"

-VEC4_T load_v_cache_d4(
+#ifndef SDPA_V_BUF
+#define SDPA_V_BUF t_v
+#endif
+
+// Determine whether buffer mode is active. Both V_CACHE_BUFFER (LLM) and
+// V_BUFFER (fused) activate the buffer path.
+#if defined(V_CACHE_BUFFER) || defined(V_BUFFER)
+#define _SDPA_V_USE_BUFFER
+#endif
+
+LINEAR_FP_WEIGHT_TILE_VEC4_T load_v_cache_d4(
     const int d4,
     const int c,
     const int kv_h,
     const int D4,
     const int C,
     const int KV_H) {
-#ifdef V_CACHE_BUFFER
-  return VEC4_T(t_v_cache[(c * KV_H * D4) + (kv_h * D4) + d4]);
-#else
-  return VEC4_T(texelFetch(t_v_cache, ivec3(d4, kv_h, c), 0));
+#ifdef _SDPA_V_USE_BUFFER
+  #if V_LAYOUT == DSHB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(SDPA_V_BUF[(kv_h * C * D4) + (c * D4) + d4]);
+  #elif V_LAYOUT == DHSB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(SDPA_V_BUF[(c * KV_H * D4) + (kv_h * D4) + d4]);
+  #endif
+#else // texture
+  #if V_LAYOUT == DSHB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(SDPA_V_BUF, ivec3(d4, c, kv_h), 0));
+  #elif V_LAYOUT == DHSB
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(SDPA_V_BUF, ivec3(d4, kv_h, c), 0));
+  #endif
 #endif
 }

@@ -44,8 +73,8 @@ void load_v_cache_tile_no_checks(
     const int context_len,
     const int C,
     const int KV_H) {
-  [[unroll]] for (int c = 0; c < TILE_N; ++c) {
-    [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) {
+  [[unroll]] for (int c = 0; c < TILE_K; ++c) {
+    [[unroll]] for (int d4 = 0; d4 < TILE_N4; ++d4) {
       tile.data[c][d4] =
           load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H);
     }
@@ -61,13 +90,13 @@ void load_v_cache_tile_with_checks(
     const int context_len,
     const int C,
     const int KV_H) {
-  [[unroll]] for (int c = 0; c < TILE_N; ++c) {
-    [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) {
+  [[unroll]] for (int c = 0; c < TILE_K; ++c) {
+    [[unroll]] for (int d4 = 0; d4 < TILE_N4; ++d4) {
       if (d4_start + d4 < D4 && c_start + c < context_len) {
         tile.data[c][d4] =
             load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H);
       } else {
-        tile.data[c][d4] = VEC4_T(0.0);
+        tile.data[c][d4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0.0);
       }
     }
   }
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 65dae5e25cb..aa0c23ba993 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -21,8 +21,71 @@
 #include <...>
 #include <...>
+#include <...>
+
 namespace vkcompute {

+namespace {
+
+//
+// SDPA mode: distinguishes the two dispatch families sharing this file.
+//   LLM   — Llama-style KV-cache SDPA. Q layout [B=1, S, H, D] (DHSB).
+//           Separate k_cache/v_cache inputs + input_pos_symint for dynamic
+//           context_len. attn_weights are padded to multiples of 4 in the
+//           S/context_len dims and carry the input dtype. A coop (GEMV)
+//           shader variant is selected for single-token decode.
+//   FUSED — General SDPA fused op. Q layout [B, H, S, D] (DSHB). No cache,
+//           optional additive attn_mask, optional scale arg. attn_weights
+//           are unpadded and always fp32. Tiled shader variant only.
+//
+enum class SDPAMode { LLM, FUSED };
+
+//
+// Common dimension helper: folds the axis-swap for LLM vs fused Q layouts.
+// `input_pos_symint` is used only for LLM (context_len = S + input_pos);
+// pass kDummyValueRef for FUSED.
+//
+struct SDPADims {
+  int64_t B = 1;
+  int64_t H = 0;
+  int64_t S = 0;
+  int64_t D = 0;
+  int64_t context_len = 0; // LLM: S + input_pos_val; FUSED: size_at(-2, k)
+  int64_t max_context_len = 0; // LLM: size_at(-3, k); FUSED: size_at(-2, k)
+};
+
+} // namespace
+
+SDPADims compute_sdpa_dims(
+    ComputeGraph& graph,
+    const ValueRef q,
+    const ValueRef k,
+    const ValueRef input_pos_symint,
+    const SDPAMode mode) {
+  SDPADims d;
+  d.D = graph.size_at<int64_t>(-1, q);
+  if (mode == SDPAMode::LLM) {
+    // Q: [B=1, S, H, D] (DHSB), K: [B=1, C_max, H_kv, D]
+    // `k` may be kDummyValueRef in dispatch pickers that don't need it;
+    // max_context_len is only read when k is valid.
+    d.B = 1;
+    d.H = graph.size_at<int64_t>(-2, q);
+    d.S = graph.size_at<int64_t>(-3, q);
+    d.max_context_len = is_valid(k) ? graph.size_at<int64_t>(-3, k) : 0;
+    const int32_t input_pos_val =
+        is_valid(input_pos_symint) ? graph.read_symint(input_pos_symint) : 0;
+    d.context_len = d.S + input_pos_val;
+  } else {
+    // Q: [B, H, S, D] (DSHB), K: [B, H_kv, L, D]
+    d.B = graph.size_at<int64_t>(-4, q);
+    d.H = graph.size_at<int64_t>(-3, q);
+    d.S = graph.size_at<int64_t>(-2, q);
+    d.context_len = graph.size_at<int64_t>(-2, k);
+    d.max_context_len = d.context_len;
+  }
+  return d;
+}
+
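+// Worked examples (illustrative values, not from the change): in LLM mode,
+// Q sized [1, 7, 8, 64] (S = 7, H = 8, D = 64) with input_pos = 30 gives
+// context_len = 7 + 30 = 37; in FUSED mode, Q sized [2, 8, 5, 64] with
+// K sized [2, 8, 12, 64] gives B = 2, H = 8, S = 5, D = 64 and
+// context_len = max_context_len = 12.
+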
 bool is_single_token(ComputeGraph* graph, const ValueRef& q_projected) {
   return graph->size_at<int64_t>(-3, q_projected) == 1;
 }
@@ -31,30 +94,42 @@ bool is_single_token(ComputeGraph* graph, const ValueRef& q_projected) {
 //
 // Resize functions
 //

-void resize_compute_attn_weights_node(
+// Unified attn_weights resize. In LLM mode the shape is padded to multiples of
+// 4 in the S/context_len dims (to match the tiled shader's iteration space);
+// in fused mode it's the unpadded [B, H, S, L].
+// resize_args layout: [q, k, input_pos_symint_or_dummy, mode_as_int]
+void resize_sdpa_attn_weights_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
   const ValueRef attn_weights = args.at(0).refs.at(0);
-  const ValueRef q_projected = args.at(1).refs.at(0);
-  const ValueRef input_pos_symint = resize_args.at(0);
-
-  const uint32_t num_q_heads = graph->size_at<uint32_t>(-2, q_projected);
-  const uint32_t seq_len = graph->size_at<uint32_t>(-3, q_projected);
-
-  const int32_t input_pos_val = graph->read_symint(input_pos_symint);
-
-  const uint32_t context_len = seq_len + input_pos_val;
-
-  std::vector<int64_t> out_sizes = {
-      1, // batch
-      num_q_heads,
-      utils::align_up_4(seq_len),
-      utils::align_up_4(context_len)};
-
+  const ValueRef q = resize_args.at(0);
+  const ValueRef k = resize_args.at(1);
+  const ValueRef input_pos_symint = resize_args.at(2);
+  const SDPAMode mode = static_cast<SDPAMode>(resize_args.at(3));
+
+  std::vector<int64_t> out_sizes;
+  if (mode == SDPAMode::LLM) {
+    const int64_t num_q_heads = graph->size_at<int64_t>(-2, q);
+    const int64_t seq_len = graph->size_at<int64_t>(-3, q);
+    const int32_t input_pos_val = graph->read_symint(input_pos_symint);
+    const int64_t context_len = seq_len + input_pos_val;
+    out_sizes = {
+        1,
+        num_q_heads,
+        static_cast<int64_t>(utils::align_up_4(seq_len)),
+        static_cast<int64_t>(utils::align_up_4(context_len))};
+  } else {
+    const int64_t B = graph->size_at<int64_t>(-4, q);
+    const int64_t H = graph->size_at<int64_t>(-3, q);
+    const int64_t S = graph->size_at<int64_t>(-2, q);
+    const int64_t L = graph->size_at<int64_t>(-2, k);
+    out_sizes = {B, H, S, L};
+  }
   graph->virtual_resize(attn_weights, out_sizes);
 }

+// Softmax preserves attn_weights shape exactly; identical across modes.
 void resize_sdpa_attn_weights_softmax_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
@@ -65,26 +140,15 @@ void resize_sdpa_attn_weights_softmax_node(
   graph->virtual_resize(attn_weights_softmax, graph->sizes_of(attn_weights));
 }

-void resize_sdpa_compute_out_node(
+// Out matches Q's shape in both modes. resize_args[0] = q.
+void resize_sdpa_out_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
   const ValueRef out = args.at(0).refs.at(0);
-  const ValueRef q_projected = resize_args.at(0);
+  const ValueRef q = resize_args.at(0);

-  graph->virtual_resize(out, graph->sizes_of(q_projected));
-}
-
-void resize_sdpa_out(
-    ComputeGraph* graph,
-    const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)args;
-
-  int arg_idx = 0;
-  const ValueRef q_projected = extra_args[arg_idx++];
-  const ValueRef out = extra_args[arg_idx++];
-  graph->virtual_resize(out, graph->sizes_of(q_projected));
+  graph->virtual_resize(out, graph->sizes_of(q));
 }

 //
@@ -108,167 +172,183 @@ utils::uvec3 kv_cache_update_global_wg_size(
   return {utils::div_up_4(head_dim_size), seq_len, num_heads};
 }

-utils::uvec3 attn_weight_scale_and_mask_global_wg_size(
-    ComputeGraph* graph,
-    const vkapi::ShaderInfo& shader,
-    const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& resize_args) {
-  (void)shader;
-  (void)resize_args;
-
-  const ValueRef attn_weight = args.at(0).refs.at(0);
-
-  if (graph->is_buffer_storage(attn_weight)) {
-    return {
-        graph->size_at<uint32_t>(-1, attn_weight),
-        graph->size_at<uint32_t>(-2, attn_weight),
-        graph->size_at<uint32_t>(-3, attn_weight),
-    };
-  } else {
-    return graph->logical_limits_of(attn_weight);
-  }
+// resize_args layout for SDPA dispatch pickers mirrors the node creation
+// helper: [q, k, input_pos_symint_or_dummy, mode_as_int].
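+// e.g. a node created with resize_args {q, k, kDummyValueRef,
+// static_cast<ValueRef>(SDPAMode::FUSED)} decodes back to SDPAMode::FUSED
+// here; ValueRef is an integer handle, so the enum value round-trips
+// through it unchanged.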
+static inline SDPAMode mode_of(const std::vector<ValueRef>& resize_args) {
+  return static_cast<SDPAMode>(resize_args.at(3));
 }

-vkapi::ShaderInfo pick_sdpa_compute_attn_weights_shader(
+vkapi::ShaderInfo pick_sdpa_qk_shader(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const ValueRef q_projected = args.at(1).refs.at(0);
-  const ValueRef k_cache = args.at(1).refs.at(1);
-
-  const bool is_gemv = is_single_token(graph, q_projected);
-
-  std::string shader_name = "sdpa_compute_attn_weights";
-  if (is_gemv) {
-    shader_name += "_coop";
+  const SDPAMode mode = mode_of(resize_args);
+  if (mode == SDPAMode::LLM) {
+    const ValueRef q_projected = args.at(1).refs.at(0);
+    const ValueRef k_cache = args.at(1).refs.at(1);
+    const bool is_gemv = is_single_token(graph, q_projected);
+
+    std::string shader_name = "sdpa_compute_attn_weights";
+    shader_name += is_gemv ? "_coop" : "_tiled";
+    add_storage_type_suffix(shader_name, graph->storage_type_of(q_projected));
+    add_storage_type_suffix(shader_name, graph->storage_type_of(k_cache));
+    add_dtype_suffix(shader_name, graph->dtype_of(q_projected));
+    return VK_KERNEL_FROM_STR(shader_name);
   } else {
-    shader_name += "_tiled";
+    const ValueRef q = args.at(1).refs.at(0);
+    const ValueRef k = args.at(1).refs.at(1);
+    // Fused path uses bias variant iff attn_mask was provided (signalled via
+    // 3 inputs in the read group: q, k, attn_mask).
+    const bool has_bias = args.at(1).refs.size() >= 3;
+    std::string shader_name =
+        has_bias ? "fused_sdpa_qk_tiled_bias" : "fused_sdpa_qk_tiled";
+    add_storage_type_suffix(shader_name, graph->storage_type_of(q));
+    add_storage_type_suffix(shader_name, graph->storage_type_of(k));
+    add_dtype_suffix(shader_name, graph->dtype_of(q));
+    return VK_KERNEL_FROM_STR(shader_name);
   }
-
-  add_storage_type_suffix(shader_name, graph->storage_type_of(q_projected));
-  add_storage_type_suffix(shader_name, graph->storage_type_of(k_cache));
-  add_dtype_suffix(shader_name, graph->dtype_of(q_projected));
-
-  return VK_KERNEL_FROM_STR(shader_name);
 }

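+// For example (suffix order as applied above; exact suffix spellings are
+// illustrative), a fused dispatch with buffer-backed fp16 tensors and a mask
+// would resolve to a kernel name like
+// "fused_sdpa_qk_tiled_bias_buffer_buffer_half".
+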
-utils::uvec3 pick_sdpa_compute_attn_weights_global_wg_size(
+utils::uvec3 pick_sdpa_qk_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const ValueRef q_projected = args.at(1).refs.at(0);
-  const ValueRef input_pos_symint = resize_args.at(0);
-
-  const uint32_t num_q_heads = graph->size_at<uint32_t>(-2, q_projected);
-  const uint32_t seq_len = graph->size_at<uint32_t>(-3, q_projected);
-
-  const int32_t input_pos_val = graph->read_symint(input_pos_symint);
-
-  const uint32_t context_len = seq_len + input_pos_val;
-
-  const uint32_t N4 = utils::div_up_4(context_len);
-  const uint32_t M4 = utils::div_up_4(seq_len);
-
-  return {N4, M4, num_q_heads};
+  (void)shader;
+  (void)args;
+  const SDPAMode mode = mode_of(resize_args);
+  const ValueRef q = resize_args.at(0);
+  const ValueRef k = resize_args.at(1);
+  const ValueRef input_pos_symint = resize_args.at(2);
+  const SDPADims d = compute_sdpa_dims(*graph, q, k, input_pos_symint, mode);
+
+  // Dispatch grid: (context_len tiles, S tiles, H * B).
+  const uint32_t N4 = utils::div_up_4(static_cast<uint32_t>(d.context_len));
+  const uint32_t M4 = utils::div_up_4(static_cast<uint32_t>(d.S));
+  return {N4, M4, static_cast<uint32_t>(d.H * d.B)};
 }

-utils::uvec3 pick_sdpa_compute_attn_weights_local_wg_size(
+utils::uvec3 pick_sdpa_qk_local_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const utils::uvec3& global_workgroup_size,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const bool use_coop_algorithm =
-      shader.kernel_name.find("_coop") != std::string::npos;
-
-  if (use_coop_algorithm) {
-    return {1, 64, 1};
-  } else {
+  const SDPAMode mode = mode_of(resize_args);
+  if (mode == SDPAMode::LLM) {
+    const bool use_coop_algorithm =
+        shader.kernel_name.find("_coop") != std::string::npos;
+    if (use_coop_algorithm) {
+      return {1, 64, 1};
+    }
     return pick_hw_square_wg_size(
         graph, shader, global_workgroup_size, args, resize_args);
   }
+  return default_pick_local_wg_size(
+      graph, shader, global_workgroup_size, args, resize_args);
 }

-utils::uvec3 pick_sdpa_attn_weights_softmax_global_wg_size(
+utils::uvec3 pick_sdpa_softmax_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
-  const ValueRef q_projected = resize_args.at(0);
-
-  const uint32_t num_q_heads = graph->size_at<uint32_t>(-2, q_projected);
-  const uint32_t seq_len = graph->size_at<uint32_t>(-3, q_projected);
-
-  return {1, seq_len, num_q_heads};
+  (void)shader;
+  const SDPAMode mode = mode_of(resize_args);
+  const ValueRef q = resize_args.at(0);
+  // LLM reads H from axis -2 and S from axis -3; fused swaps them (the same
+  // axis fold that compute_sdpa_dims performs).
+  const int64_t num_q_heads = (mode == SDPAMode::LLM)
+      ? graph->size_at<int64_t>(-2, q)
+      : graph->size_at<int64_t>(-3, q);
+  const int64_t seq_len = (mode == SDPAMode::LLM)
+      ? graph->size_at<int64_t>(-3, q)
+      : graph->size_at<int64_t>(-2, q);
+  const int64_t B =
+      (mode == SDPAMode::LLM) ? 1 : graph->size_at<int64_t>(-4, q);
+  return {
+      1,
+      static_cast<uint32_t>(seq_len),
+      static_cast<uint32_t>(num_q_heads * B)};
 }

-utils::uvec3 pick_sdpa_attn_weights_softmax_local_wg_size(
+utils::uvec3 pick_sdpa_softmax_local_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const utils::uvec3& global_workgroup_size,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)shader;
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
   return {64, 1, 1};
 }

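+// The same {x = context_len tiles, y = S tiles, z = H * B} convention as
+// pick_sdpa_qk_global_wg_size above, worked through (example values only):
+// FUSED with B = 2, H = 4, S = 5, context_len = 10 gives
+// {div_up_4(10), div_up_4(5), 4 * 2} = {3, 2, 8}.
+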
"_coop" : "_tiled"; + add_storage_type_suffix(shader_name, graph->storage_type_of(out)); + add_storage_type_suffix(shader_name, graph->storage_type_of(v_cache)); + add_dtype_suffix(shader_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(shader_name); } else { - shader_name += "_tiled"; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef v = args.at(1).refs.at(1); + std::string shader_name = "fused_sdpa_av_tiled"; + add_storage_type_suffix(shader_name, graph->storage_type_of(out)); + add_storage_type_suffix(shader_name, graph->storage_type_of(v)); + add_dtype_suffix(shader_name, graph->dtype_of(out)); + return VK_KERNEL_FROM_STR(shader_name); } - - add_storage_type_suffix(shader_name, graph->storage_type_of(out)); - add_storage_type_suffix(shader_name, graph->storage_type_of(v_cache)); - add_dtype_suffix(shader_name, graph->dtype_of(out)); - - return VK_KERNEL_FROM_STR(shader_name); } -utils::uvec3 pick_sdpa_compute_out_global_wg_size( +utils::uvec3 pick_sdpa_av_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { - const ValueRef q_projected = resize_args.at(0); - - const uint32_t head_dim = graph->size_at(-1, q_projected); - const uint32_t num_q_heads = graph->size_at(-2, q_projected); - const uint32_t seq_len = graph->size_at(-3, q_projected); - - const uint32_t N4 = utils::div_up_4(head_dim); - const uint32_t M4 = utils::div_up_4(seq_len); - - return {N4, M4, num_q_heads}; + (void)shader; + const SDPAMode mode = mode_of(resize_args); + const ValueRef q = resize_args.at(0); + const ValueRef k = resize_args.at(1); + const ValueRef input_pos_symint = resize_args.at(2); + const SDPADims d = compute_sdpa_dims(*graph, q, k, input_pos_symint, mode); + + const uint32_t N4 = utils::div_up_4(static_cast(d.D)); + const uint32_t M4 = utils::div_up_4(static_cast(d.S)); + return {N4, M4, static_cast(d.H * d.B)}; } -utils::uvec3 pick_sdpa_compute_out_local_wg_size( +utils::uvec3 pick_sdpa_av_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { - const bool use_coop_algorithm = - shader.kernel_name.find("_coop") != std::string::npos; - - if (use_coop_algorithm) { - return {1, 64, 1}; - } else { + const SDPAMode mode = mode_of(resize_args); + if (mode == SDPAMode::LLM) { + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + if (use_coop_algorithm) { + return {1, 64, 1}; + } return pick_hw_square_wg_size( graph, shader, global_workgroup_size, args, resize_args); } + return default_pick_local_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } // @@ -309,59 +389,93 @@ void add_sdpa_kv_cache_update_node( nullptr)); } +// Unified QK node (attn_weights = scale * Q @ K^T [+ bias]). +// LLM: pass input_pos_symint (real symint), attn_mask = kDummyValueRef. +// FUSED: pass input_pos_symint = kDummyValueRef, attn_mask = valid ref or +// kDummyValueRef to indicate no bias. scale_val is always passed as +// a spec const; the LLM path computes it per head_dim and FUSED may +// inherit from the caller-supplied scale. 
 void add_sdpa_compute_attn_weights_node(
     ComputeGraph& graph,
-    const ValueRef q_projected,
-    const ValueRef k_cache,
+    const ValueRef q,
+    const ValueRef k,
     const ValueRef input_pos_symint,
-    const ValueRef attn_weights) {
-  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
-  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
-
+    const ValueRef attn_mask,
+    const float scale_val,
+    const ValueRef attn_weights,
+    const SDPAMode mode) {
   vkapi::ParamsBindList param_ubos = {
-      graph.sizes_ubo(q_projected),
-      graph.sizes_ubo(k_cache),
-      graph.get_or_create_int_param_buffer(input_pos_symint)};
+      graph.sizes_ubo(q),
+      graph.sizes_ubo(k),
+  };
+  std::vector<ValueRef> read_inputs = {q, k};
+
+  if (mode == SDPAMode::LLM) {
+    param_ubos.append(graph.get_or_create_int_param_buffer(input_pos_symint));
+  } else if (is_valid(attn_mask)) {
+    param_ubos.append(graph.sizes_ubo(attn_mask));
+    read_inputs.push_back(attn_mask);
+  }
+
+  const ValueRef mode_ref = static_cast<ValueRef>(mode);

   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      pick_sdpa_compute_attn_weights_shader,
-      pick_sdpa_compute_attn_weights_global_wg_size,
-      pick_sdpa_compute_attn_weights_local_wg_size,
+      pick_sdpa_qk_shader,
+      pick_sdpa_qk_global_wg_size,
+      pick_sdpa_qk_local_wg_size,
       // Inputs and Outputs
-      {{attn_weights, vkapi::kWrite}, {{q_projected, k_cache}, vkapi::kRead}},
+      {{attn_weights, vkapi::kWrite}, {read_inputs, vkapi::kRead}},
      // Shader param buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {scale_val},
-      // Resize Args
-      {input_pos_symint},
+      // Resize Args: [q, k, input_pos_symint_or_dummy, mode]
+      {q, k, input_pos_symint, mode_ref},
      // Resizing Logic
-      resize_compute_attn_weights_node));
+      resize_sdpa_attn_weights_node));
 }

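+// Invariant relied on by pick_sdpa_qk_shader: the read group built above is
+// {q, k} for no-bias dispatches and {q, k, attn_mask} when a mask is present,
+// so the picker can detect the bias variant from refs.size() alone.
+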
 void add_sdpa_attn_weights_softmax_node(
     ComputeGraph& graph,
     const ValueRef attn_weights,
-    const ValueRef q_projected,
+    const ValueRef q,
+    const ValueRef k,
     const ValueRef input_pos_symint,
-    const ValueRef attn_weights_softmax) {
-  std::string shader_name = "sdpa_attn_weights_softmax";
-  add_storage_type_suffix(
-      shader_name, graph.storage_type_of(attn_weights_softmax));
-  add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax));
+    const ValueRef attn_weights_softmax,
+    const SDPAMode mode) {
+  std::string shader_name;
+  if (mode == SDPAMode::LLM) {
+    shader_name = "sdpa_attn_weights_softmax";
+    add_storage_type_suffix(
+        shader_name, graph.storage_type_of(attn_weights_softmax));
+    add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax));
+  } else {
+    shader_name = "fused_sdpa_softmax";
+    add_storage_type_suffix(
+        shader_name, graph.storage_type_of(attn_weights_softmax));
+    add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax));
+  }

-  vkapi::ParamsBindList param_ubos = {
-      graph.sizes_ubo(q_projected),
-      graph.get_or_create_int_param_buffer(input_pos_symint)};
+  vkapi::ParamsBindList param_ubos;
+  if (mode == SDPAMode::LLM) {
+    param_ubos = {
+        graph.sizes_ubo(q),
+        graph.sizes_ubo(k),
+        graph.get_or_create_int_param_buffer(input_pos_symint)};
+  } else {
+    param_ubos = {graph.sizes_ubo(q), graph.sizes_ubo(k)};
+  }
+
+  const ValueRef mode_ref = static_cast<ValueRef>(mode);

   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(shader_name),
-      pick_sdpa_attn_weights_softmax_global_wg_size,
-      pick_sdpa_attn_weights_softmax_local_wg_size,
+      pick_sdpa_softmax_global_wg_size,
+      pick_sdpa_softmax_local_wg_size,
      // Inputs and Outputs
      {{attn_weights_softmax, vkapi::kWrite}, {attn_weights, vkapi::kRead}},
      // Shader param buffers
@@ -370,8 +484,8 @@ void add_sdpa_attn_weights_softmax_node(
      {},
      // Specialization Constants
      {},
-      // Resize Args
-      {q_projected, input_pos_symint},
+      // Resize Args: [q, k, input_pos_symint_or_dummy, mode]
+      {q, k, input_pos_symint, mode_ref},
      // Resizing Logic
      resize_sdpa_attn_weights_softmax_node));
 }

@@ -379,32 +493,41 @@ void add_sdpa_attn_weights_softmax_node(
 void add_sdpa_compute_out_node(
     ComputeGraph& graph,
     const ValueRef attn_weights_softmax,
-    const ValueRef v_cache,
-    const ValueRef q_projected,
+    const ValueRef v,
+    const ValueRef q,
+    const ValueRef k,
     const ValueRef input_pos_symint,
-    const ValueRef out) {
-  vkapi::ParamsBindList param_ubos = {
-      graph.sizes_ubo(q_projected),
-      graph.sizes_ubo(v_cache),
-      graph.get_or_create_int_param_buffer(input_pos_symint)};
+    const ValueRef out,
+    const SDPAMode mode) {
+  vkapi::ParamsBindList param_ubos;
+  if (mode == SDPAMode::LLM) {
+    param_ubos = {
+        graph.sizes_ubo(q),
+        graph.sizes_ubo(v),
+        graph.get_or_create_int_param_buffer(input_pos_symint)};
+  } else {
+    param_ubos = {graph.sizes_ubo(q), graph.sizes_ubo(v)};
+  }
+
+  const ValueRef mode_ref = static_cast<ValueRef>(mode);

   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      pick_sdpa_compute_out_shader,
-      pick_sdpa_compute_out_global_wg_size,
-      pick_sdpa_compute_out_local_wg_size,
+      pick_sdpa_av_shader,
+      pick_sdpa_av_global_wg_size,
+      pick_sdpa_av_local_wg_size,
      // Inputs and Outputs
-      {{out, vkapi::kWrite}, {{attn_weights_softmax, v_cache}, vkapi::kRead}},
+      {{out, vkapi::kWrite}, {{attn_weights_softmax, v}, vkapi::kRead}},
      // Shader param buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
-      // Resize Args
-      {q_projected, input_pos_symint},
+      // Resize Args: [q, k, input_pos_symint_or_dummy, mode]
+      {q, k, input_pos_symint, mode_ref},
      // Resizing Logic
-      resize_sdpa_compute_out_node));
+      resize_sdpa_out_node));
 }

 //
@@ -515,14 +638,37 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       attn_weights_storage,
       utils::kWidthPacked);

+  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
+  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
+
   add_sdpa_compute_attn_weights_node(
-      graph, q_projected, k_cache, input_pos_symint, attn_weights);
+      graph,
+      q_projected,
+      k_cache,
+      input_pos_symint,
+      /*attn_mask=*/kDummyValueRef,
+      scale_val,
+      attn_weights,
+      SDPAMode::LLM);

   add_sdpa_attn_weights_softmax_node(
-      graph, attn_weights, q_projected, input_pos_symint, attn_weights_softmax);
+      graph,
+      attn_weights,
+      q_projected,
+      k_cache,
+      input_pos_symint,
+      attn_weights_softmax,
+      SDPAMode::LLM);

   add_sdpa_compute_out_node(
-      graph, attn_weights_softmax, v_cache, q_projected, input_pos_symint, out);
+      graph,
+      attn_weights_softmax,
+      v_cache,
+      q_projected,
+      /*k=*/kDummyValueRef,
+      input_pos_symint,
+      out,
+      SDPAMode::LLM);
 }

 void sdpa_with_kv_cache_impl(
   const ValueRef scale = args[arg_idx++];

   // Output tensors
-  const ValueRef out = args[arg_idx++];
+  const ValueRef out = args[arg_idx];

   (void)sequence_len;
@@ -602,8 +748,121 @@ void compute_attn_weight_with_kv_cache_impl(
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});

+  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
+  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
+
   add_sdpa_compute_attn_weights_node(
-      graph, q_projected, k_cache, input_pos_symint, out);
+      graph,
+      q_projected,
+      k_cache,
+      input_pos_symint,
+      /*attn_mask=*/kDummyValueRef,
+      scale_val,
+      out,
+      SDPAMode::LLM);
+}
+
+//
+// Fused SDPA entry point (et_vk.sdpa.default).
+//
+// Accepts pre-reshaped [B, H, S, D] tensors (DSHB) plus an optional additive
+// attn_mask and an optional scale scalar. No KV cache; this is the general
+// SDPA fused op used by non-LLM models.
+//
+void fused_sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q = args[arg_idx++];
+  const ValueRef k = args[arg_idx++];
+  const ValueRef v = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef scale_ref = args[arg_idx++];
+  const ValueRef out = args[arg_idx];
+
+  // Validate inputs
+  VK_CHECK_COND(graph.dim_of(q) == 4);
+  VK_CHECK_COND(graph.dim_of(k) == 4);
+  VK_CHECK_COND(graph.dim_of(v) == 4);
+  // Head dim must match between Q and K
+  VK_CHECK_COND(graph.size_at<int64_t>(-1, q) == graph.size_at<int64_t>(-1, k));
+  // K and V must have the same sequence length
+  VK_CHECK_COND(graph.size_at<int64_t>(-2, k) == graph.size_at<int64_t>(-2, v));
+  // All tensors must be width-packed
+  VK_CHECK_COND(graph.packed_dim_of(q) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v) == WHCN::kWidthDim);
+
+  // Compute scale
+  const int32_t head_dim = graph.size_at<int32_t>(-1, q);
+  float scale_val;
+  if (graph.val_is_none(scale_ref)) {
+    scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim));
+  } else {
+    scale_val = graph.extract_scalar<float>(scale_ref);
+  }
+
+  // Resolve attn_mask: a None value is normalized to kDummyValueRef so the
+  // unified helpers can branch with a single `is_valid()` check.
+  const ValueRef attn_mask_ref =
+      graph.val_is_none(attn_mask) ? kDummyValueRef : attn_mask;
+
+  // Get dimensions for intermediate allocation
+  const int64_t B = graph.size_at<int64_t>(-4, q);
+  const int64_t H = graph.size_at<int64_t>(-3, q);
+  const int64_t S = graph.size_at<int64_t>(-2, q);
+  const int64_t L = graph.size_at<int64_t>(-2, k);
+
+  std::vector<int64_t> attn_weight_sizes = {B, H, S, L};
+
+  // attn_weights and attn_weights_softmax follow the output's storage so the
+  // entire fused SDPA pipeline uses a uniform storage type. attn_weights stays
+  // in fp32 for numerical stability of the Q@K^T accumulation.
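+  // Illustrative magnitude check (example values, not from the change): fp16
+  // tops out at 65504, and an unscaled q-k dot product over D = 256 lanes
+  // with |q| = |k| = 16 already reaches 256 * 16 * 16 = 65536, so
+  // accumulating Q@K^T in fp16 could overflow before the scale is applied.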
+  const utils::StorageType attn_storage = graph.storage_type_of(out);
+
+  TmpTensor attn_weights(
+      &graph,
+      attn_weight_sizes,
+      vkapi::ScalarType::Float,
+      attn_storage,
+      utils::kWidthPacked);
+
+  TmpTensor attn_weights_softmax(
+      &graph,
+      attn_weight_sizes,
+      graph.dtype_of(q),
+      attn_storage,
+      utils::kWidthPacked);
+
+  // Phase 1: Q @ K^T with fp32 accumulation, apply scale and optional bias
+  add_sdpa_compute_attn_weights_node(
+      graph,
+      q,
+      k,
+      /*input_pos_symint=*/kDummyValueRef,
+      attn_mask_ref,
+      scale_val,
+      attn_weights,
+      SDPAMode::FUSED);
+
+  // Phase 2: Softmax in fp32, output in input dtype
+  add_sdpa_attn_weights_softmax_node(
+      graph,
+      attn_weights,
+      q,
+      k,
+      /*input_pos_symint=*/kDummyValueRef,
+      attn_weights_softmax,
+      SDPAMode::FUSED);
+
+  // Phase 3: attn_weights_softmax @ V
+  add_sdpa_compute_out_node(
+      graph,
+      attn_weights_softmax,
+      v,
+      q,
+      k,
+      /*input_pos_symint=*/kDummyValueRef,
+      out,
+      SDPAMode::FUSED);
+}

 REGISTER_OPERATORS {
@@ -613,6 +872,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       testing.compute_attn_weight_with_kv_cache.default,
       compute_attn_weight_with_kv_cache_impl);
+  VK_REGISTER_OP(et_vk.sdpa.default, fused_sdpa_impl);
 }

 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp
index d881ce7a7f4..fd2afc7408e 100644
--- a/backends/vulkan/test/op_tests/sdpa_test.cpp
+++ b/backends/vulkan/test/op_tests/sdpa_test.cpp
@@ -616,6 +616,242 @@ void test_vulkan_sdpa(
   }
 }

+//
+// General-purpose fused SDPA tests (et_vk.sdpa)
+//
+
+/*
+ * Reference implementation of general SDPA: softmax(Q @ K^T * scale + bias) @ V
+ *   Q: [B, H, S, D], K: [B, H, L, D], V: [B, H, L, D]
+ *   Returns: [B, H, S, D]
+ */
+at::Tensor general_sdpa_reference_impl(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const std::optional<at::Tensor>& attn_mask = std::nullopt,
+    const std::optional<float> scale = std::nullopt) {
+  float scale_val =
+      scale.has_value() ? scale.value() : (1.0 / sqrt(q.size(-1)));
+  at::Tensor attn = at::matmul(q, k.transpose(-2, -1)) * scale_val;
+  if (attn_mask.has_value()) {
+    attn = attn + attn_mask.value();
+  }
+  attn = at::softmax(attn, -1);
+  return at::matmul(attn, v);
+}
+
+void test_vulkan_general_sdpa(
+    const int batch_size,
+    const int num_heads,
+    const int q_seq_len,
+    const int kv_seq_len,
+    const int head_dim,
+    const bool has_bias,
+    at::ScalarType dtype = at::kFloat) {
+  torch::manual_seed(42);
+
+  // Generate random inputs in [B, H, S, D] layout
+  at::Tensor q = at::rand(
+      {batch_size, num_heads, q_seq_len, head_dim},
+      at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor k = at::rand(
+      {batch_size, num_heads, kv_seq_len, head_dim},
+      at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor v = at::rand(
+      {batch_size, num_heads, kv_seq_len, head_dim},
+      at::device(at::kCPU).dtype(at::kFloat));
+
+  std::optional<at::Tensor> bias = std::nullopt;
+  if (has_bias) {
+    // Broadcastable bias: [B, 1, 1, kv_seq_len]
+    bias = at::rand(
+               {batch_size, 1, 1, kv_seq_len},
+               at::device(at::kCPU).dtype(at::kFloat)) *
+            2.0 -
+        1.0;
+  }
+
+  // Compute reference output in fp32
+  at::Tensor reference_out = general_sdpa_reference_impl(q, k, v, bias);
+
+  // Cast to test dtype for Vulkan
+  q = q.to(dtype);
+  k = k.to(dtype);
+  v = v.to(dtype);
+  if (bias.has_value()) {
+    bias = bias.value().to(dtype);
+  }
+
+  // Build Vulkan compute graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  ComputeGraph graph(config);
+
+  IOValueRef r_q = graph.add_input_tensor(
+      q.sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+  IOValueRef r_k = graph.add_input_tensor(
+      k.sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+  IOValueRef r_v = graph.add_input_tensor(
+      v.sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+
+  ValueRef r_bias = kDummyValueRef;
+  IOValueRef r_bias_io = {};
+  if (has_bias) {
+    r_bias_io = graph.add_input_tensor(
+        bias.value().sizes().vec(), from_at_scalartype(dtype), utils::kBuffer);
+    r_bias = r_bias_io.value;
+  }
+
+  const ValueRef r_out = graph.add_tensor(
+      {batch_size, num_heads, q_seq_len, head_dim},
+      from_at_scalartype(dtype),
+      utils::kBuffer);
+
+  VK_GET_OP_FN("et_vk.sdpa.default")
+  (graph,
+   {
+       r_q.value,
+       r_k.value,
+       r_v.value,
+       r_bias,
+       kDummyValueRef, // scale (None -> 1/sqrt(head_dim))
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.prepack();
+
+  // Copy inputs
+  graph.maybe_cast_and_copy_into_staging(
+      r_q.staging, q.const_data_ptr(), q.numel(), from_at_scalartype(dtype));
+  graph.maybe_cast_and_copy_into_staging(
+      r_k.staging, k.const_data_ptr(), k.numel(), from_at_scalartype(dtype));
+  graph.maybe_cast_and_copy_into_staging(
+      r_v.staging, v.const_data_ptr(), v.numel(), from_at_scalartype(dtype));
+  if (has_bias) {
+    graph.maybe_cast_and_copy_into_staging(
+        r_bias_io.staging,
+        bias.value().const_data_ptr(),
+        bias.value().numel(),
+        from_at_scalartype(dtype));
+  }
+
+  graph.execute();
+
+  // Extract output
+  at::Tensor vk_out = at::zeros(
+                          {batch_size, num_heads, q_seq_len, head_dim},
+                          at::device(at::kCPU).dtype(dtype))
+                          .contiguous();
+  graph.maybe_cast_and_copy_from_staging(
+      staging_out,
+      vk_out.mutable_data_ptr(),
+      vk_out.numel(),
+      from_at_scalartype(dtype));
+
+  // Compare in fp32
+  vk_out = vk_out.to(at::kFloat);
+
+  // Use appropriate tolerance based on dtype
+  double atol = dtype == at::kHalf ? 1e-2 : 1e-4;
+  double rtol = dtype == at::kHalf ? 1e-2 : 1e-5;
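+  // Heuristic bounds: fp16 carries roughly 3 decimal digits, so 1e-2 absorbs
+  // rounding across the kv_seq_len-wide softmax/matmul reductions, while fp32
+  // stays near allclose's default tolerances.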
+
+  const bool output_correct = at::allclose(reference_out, vk_out, rtol, atol);
+  if (!output_correct) {
+    at::Tensor diffs = at::abs(reference_out - vk_out);
+    std::cout << "General SDPA test failed:" << " B=" << batch_size
+              << " H=" << num_heads << " S=" << q_seq_len
+              << " L=" << kv_seq_len << " D=" << head_dim
+              << " bias=" << has_bias << " dtype=" << dtype << std::endl;
+    std::cout << "Max diff: " << at::max(diffs).item<float>() << std::endl;
+    std::cout << "Max value: "
+              << at::max(at::abs(at::cat({reference_out, vk_out}, -1)))
+                     .item<float>()
+              << std::endl;
+
+    // Print all elements for small tensors
+    if (reference_out.numel() <= 64) {
+      auto ref_flat = reference_out.flatten();
+      auto vk_flat = vk_out.flatten();
+      std::cout << "Reference vs Vulkan:" << std::endl;
+      for (int i = 0; i < ref_flat.numel(); ++i) {
+        std::cout << "  [" << i << "] ref=" << ref_flat[i].item<float>()
+                  << " vk=" << vk_flat[i].item<float>() << " diff="
+                  << std::abs(
+                         ref_flat[i].item<float>() - vk_flat[i].item<float>())
+                  << std::endl;
+      }
+    }
+  }
+  ASSERT_TRUE(output_correct);
+}
+
+// Basic correctness: small sizes, no bias, fp32
+TEST(VulkanGeneralSDPATest, test_general_sdpa_small_no_bias) {
+  test_vulkan_general_sdpa(1, 2, 4, 4, 8, false);
+}
+
+// With an additive bias mask
+TEST(VulkanGeneralSDPATest, test_general_sdpa_small_with_bias) {
+  test_vulkan_general_sdpa(1, 2, 4, 8, 8, true);
+}
+
+// Cross-attention: Q and K have different sequence lengths
+TEST(VulkanGeneralSDPATest, test_general_sdpa_cross_attention) {
+  test_vulkan_general_sdpa(1, 4, 4, 16, 16, false);
+}
+
+// Batch size > 1
+TEST(VulkanGeneralSDPATest, test_general_sdpa_batched) {
+  test_vulkan_general_sdpa(2, 4, 8, 8, 16, false);
+}
+
+// Larger head_dim with bias (EdgeTAM-like)
+TEST(VulkanGeneralSDPATest, test_general_sdpa_large_head_dim) {
+  test_vulkan_general_sdpa(1, 8, 4, 4, 32, true);
+}
+
+// Non-aligned S (S is the height dim, not the width, so no padding issue)
+TEST(VulkanGeneralSDPATest, test_general_sdpa_non_aligned_s) {
+  test_vulkan_general_sdpa(1, 2, 5, 4, 32, false);
+}
+
+// Large number of heads
+TEST(VulkanGeneralSDPATest, test_general_sdpa_many_heads) {
+  test_vulkan_general_sdpa(1, 8, 4, 8, 32, false);
+}
+
+// fp16 — validates fp32 internal accumulation
+TEST(VulkanGeneralSDPATest, test_general_sdpa_fp16) {
+  test_vulkan_general_sdpa(
+      /*batch_size=*/1,
+      /*num_heads=*/4,
+      /*q_seq_len=*/8,
+      /*kv_seq_len=*/8,
+      /*head_dim=*/16,
+      /*has_bias=*/false,
+      /*dtype=*/at::kHalf);
+}
+
+// fp16 with bias
+TEST(VulkanGeneralSDPATest, test_general_sdpa_fp16_with_bias) {
+  test_vulkan_general_sdpa(
+      /*batch_size=*/1,
+      /*num_heads=*/4,
+      /*q_seq_len=*/8,
+      /*kv_seq_len=*/16,
+      /*head_dim=*/16,
+      /*has_bias=*/true,
+      /*dtype=*/at::kHalf);
+}
+
+//
+// Existing KV-cache SDPA tests
+//
+
 TEST(VulkanSDPATest, test_sdpa_op_small_params) {
   const int base_sequence_len = 3;
   const int num_heads = 8;