From dec033ee1311cb0c295cb9182a49772dd94f6657 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 24 Apr 2026 09:46:38 -0700 Subject: [PATCH] [ET-VK] Add apply_rotary_emb_interleaved fused operator Introduces `et_vk.apply_rotary_emb_interleaved`, a fused Vulkan custom operator for the "complex-number" RoPE variant used by SAM2/EdgeTAM's memory attention. This replaces a 12+-op layout-shuffle chain (`view/unbind/stack/view` -> lowers to `slice_copy + squeeze_copy + unsqueeze_copy + cat + view_copy`) with a single GPU dispatch. **Math**: On pair-interleaved inputs where element `2k` is real and `2k+1` is imag, for each `k in [0, C/2)`: out[2k] = x[2k] * cos[k] - x[2k+1] * sin[k] out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k] **Why a new op instead of reusing `et_vk.apply_rotary_emb`**: The existing LLM-oriented operator takes `(xq, xk)` pairs with separate `freqs_cos` / `freqs_sin` tensors and 4D `(B, S, H, D)` shapes optimized for LLM prefill two-texel-per-thread reuse. SAM2's memory attention passes a single 3D `(B, N, C)` tensor through RoPE (no heads dim) with a fused `[N, C/2, 2]` freqs tensor. Reusing the existing op would force runtime splits of the fused freqs and double-dispatch Q/K separately, defeating the fuse. A sibling shader is tighter for both workloads. **Op contract**: `apply_rotary_emb_interleaved(x, freqs_cis) -> Tensor` where `x` is `[B, N, C]` and `freqs_cis` is any rank with `N*C` elements and the `cos`/`sin` values interleaved on the innermost dim. In EdgeTAM's memory attention the native shape is `[1, N, C/2, 2]`; passing it through without a reshape keeps the exported graph clean of bracketing view_copy dispatches. **Shader**: Single-dispatch kernel, one texel out per thread. Each thread reads one `x` texel (2 real/imag pairs) and the corresponding `freqs_cis` entries (2 cos/sin pairs) flat-indexed from buffer storage, writes one output texel. 
`x` and output support buffer + texture3d; `freqs_cis` is always buffer-storage (small tensor, flat indexing is simplest). Supports fp16 and fp32 via the `FP_T` dtype iterator in the YAML. **Op registration**: `Meta` kernel returns `torch.empty_like(x)` to keep the op opaque during `torch.export`. `CPU` kernel holds the reference math so non-Vulkan backends keep working. `op_registry.py` pins `freqs_cis` storage to `CONTIGUOUS_BUFFER` while leaving `x` at `CONTIGUOUS_ANY`. Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/) [ghstack-poisoned] --- backends/vulkan/custom_ops_lib.py | 49 ++++ backends/vulkan/op_registry.py | 16 ++ .../glsl/apply_rotary_emb_interleaved.glsl | 114 +++++++++ .../glsl/apply_rotary_emb_interleaved.yaml | 13 + .../graph/ops/impl/RotaryEmbedding.cpp | 108 +++++++++ .../test/op_tests/rotary_embedding_test.cpp | 224 ++++++++++++++++++ 6 files changed, 524 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.yaml diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 64c6d3e46d9..4b1b02466ee 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -828,6 +828,55 @@ def apply_rotary_emb_hf_impl( lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd") apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name) +################################## +## apply_rotary_emb_interleaved ## +################################## + + +def apply_rotary_emb_interleaved_impl( + x: torch.Tensor, freqs_cis: torch.Tensor +) -> torch.Tensor: + # EdgeTAM's pair-interleaved complex-number RoPE. + # x: [B, N, C] with (real, imag) pairs interleaved along C + # freqs_cis: any rank whose flattened layout is [N, C]. 
Commonly 2D + # [N, C] or 4D [1, N, C/2, 2] from + # `torch.view_as_real(...).unsqueeze(0)`. The (cos, sin) + # pairs are interleaved along the innermost axis in the + # flattened view. + # Semantically equivalent to: + # freqs_cis.reshape(N, C // 2, 2) -> (cos, sin) + # out[2k] = x[2k] * cos[k] - x[2k+1] * sin[k] + # out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k] + B, N, C = x.shape + a_real, a_imag = x.view(B, N, C // 2, 2).unbind(-1) + # Use reshape so callers may pass freqs_cis at any rank. + cs = freqs_cis.reshape(N, C // 2, 2) + b_real, b_imag = cs[..., 0], cs[..., 1] + out = torch.stack( + (a_real * b_real - a_imag * b_imag, a_real * b_imag + a_imag * b_real), + dim=-1, + ) + return out.view(B, N, C) + + +def apply_rotary_emb_interleaved_meta( + x: torch.Tensor, freqs_cis: torch.Tensor +) -> torch.Tensor: + # Meta kernel: shape-only. Keeps the op opaque during torch.export (no + # inlining of view/reshape calls into the exported graph) and does not + # constrain the rank of freqs_cis — any shape with N * C elements is + # accepted by the Vulkan dispatcher. + return torch.empty_like(x) + + +name = "apply_rotary_emb_interleaved" +lib.define(f"{name}(Tensor x, Tensor freqs_cis) -> Tensor") +# CPU kernel preserves eager-mode reference semantics. +lib.impl(name, apply_rotary_emb_interleaved_impl, "CPU") +# Meta kernel keeps the op opaque in the exported graph. 
+lib.impl(name, apply_rotary_emb_interleaved_meta, "Meta") +apply_rotary_emb_interleaved_op = getattr(getattr(torch.ops, namespace), name) + ######################## ## q8ta_add ## ######################## diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 2e313f0f91b..9345f0a9090 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1110,6 +1110,22 @@ def register_apply_rotary_emb_hf(): ) +@update_features(exir_ops.edge.et_vk.apply_rotary_emb_interleaved.default) +def register_apply_rotary_emb_interleaved(): + return OpFeatures( + # freqs_cis is pinned to buffer storage so the shader can compute a + # flat [N, C] linear address regardless of the tensor's declared rank + # (callers commonly pass 4D [1, N, C/2, 2] without a preceding view). + inputs_storage=[ + utils.CONTIGUOUS_ANY, # x + utils.CONTIGUOUS_BUFFER, # freqs_cis + ], + inputs_dtypes=utils.FP_T, + supports_resize=True, + supports_highdim=True, + ) + + # ============================================================================= # Permute.cpp # ============================================================================= diff --git a/backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.glsl b/backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.glsl new file mode 100644 index 00000000000..121b1207d31 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.glsl @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +${define_required_extensions(STORAGE, DTYPE)} +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define BUFFER_VEC4_T ${texel_load_type(DTYPE, "buffer")} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)} +// freqs_cis is always bound as a buffer so the shader can flat-index it +// regardless of the caller's declared rank (2D [N, C] or 4D [1, N, C/2, 2]). +${layout_declare_tensor(B, "r", "t_freqs", DTYPE, "buffer", is_scalar_array=False)} + +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "outp")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "outp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} + +/* + * Applies rotary positional embeddings to a tensor whose last dimension + * contains pair-interleaved (real, imag) components. This matches EdgeTAM's + * `apply_rotary_enc_without_complex` semantics, where the fused cos/sin + * freqs tensor has a flattened [N, C] layout (cos, sin pairs interleaved). + * + * Inputs: + * t_in: [B, N, C] (last dim packed as [r0, i0, r1, i1, ...]) + * t_freqs: contiguous memory with N * C elements. May arrive at any rank + * (e.g. 2D [N, C] or 4D [1, N, C/2, 2]). Physically the values + * are [cos0, sin0, cos1, sin1, ...] along the final axis. + * + * Output: + * t_out: same shape as t_in + * + * Math per k in [0, C/2): + * out[2k] = x[2k] * cos[k] - x[2k+1] * sin[k] + * out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k] + * + * Each thread processes one width-packed texel (4 elements = 2 (r, i) pairs). 
+ * All participating tensors are assumed to be width-packed with standard axis + * maps. + * + * The freqs tensor is indexed using a flat (n_idx * C + c_offset) address to + * remain correct regardless of input rank — the shape of t_freqs does not + * need to match the logical [N, C] layout, only the underlying memory does. + */ +void main() { + // Each thread computes one output texel of 4 elements along the last dim. + TensorIndex4D out_tidx = zero_tensor4d_idx(); + out_tidx.data.x = int(gl_GlobalInvocationID.x) * 4; + out_tidx.data.y = int(gl_GlobalInvocationID.y); + out_tidx.data.z = int(gl_GlobalInvocationID.z); + + if (out_of_bounds(out_tidx, outp)) { + return; + } + + // Freqs tensor is always a contiguous buffer of N * C elements. Compute + // a flat texel index directly from logical (n_idx, c_elem_idx / 4). The + // logical width C comes from the output tensor metadata — both buffer + // and texture metadata store this at index 0 (sizes[0][0] / sizes.x). +#ifdef USING_BUFFER + const uint C_width = outp.sizes[0][0]; +#else + const uint C_width = uint(outp.sizes.x); +#endif + const uint freqs_texel_bufi = + uint(out_tidx.data.y) * div_4(C_width) + + uint(gl_GlobalInvocationID.x); + BUFFER_VEC4_T f_tex = t_freqs[freqs_texel_bufi]; + +#ifdef USING_BUFFER + const uint out_texel_bufi = + div_4(tensor4d_idx_to_linear_idx(outp, out_tidx)); + VEC4_T x_tex = t_in[out_texel_bufi]; +#else // USING_TEXTURE + const ivec3 out_pos = + tensor4d_idx_to_texel_pos_simple(outp, out_tidx, outp_layout); + VEC4_T x_tex = texelFetch(t_in, out_pos, 0); +#endif + + // x_tex = (r0, i0, r1, i1), f_tex = (c0, s0, c1, s1) + VEC4_T out_tex; + out_tex.x = x_tex.x * VEC4_T(f_tex).x - x_tex.y * VEC4_T(f_tex).y; + out_tex.y = x_tex.x * VEC4_T(f_tex).y + x_tex.y * VEC4_T(f_tex).x; + out_tex.z = x_tex.z * VEC4_T(f_tex).z - x_tex.w * VEC4_T(f_tex).w; + out_tex.w = x_tex.z * VEC4_T(f_tex).w + x_tex.w * VEC4_T(f_tex).z; + +#ifdef USING_BUFFER + t_out[out_texel_bufi] = out_tex; +#else + 
imageStore(t_out, out_pos, out_tex); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.yaml b/backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.yaml new file mode 100644 index 00000000000..3d118c25832 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.yaml @@ -0,0 +1,13 @@ +apply_rotary_emb_interleaved: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: apply_rotary_emb_interleaved diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index 7f90e2557cb..d1e70cc8c41 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -219,9 +219,117 @@ void apply_rotary_emb_hf( graph, args[0], args[1], args[2], args[3], args[4], xq_out, xk_out); } +// +// EdgeTAM-style RoPE variant with fused [cos, sin] freqs tensor +// +// Operates on a single tensor (Q or K) of shape [B, N, C] with pair-interleaved +// (real, imag) components along the last dim, and a freqs tensor with a total +// element count of N * C that packs (cos, sin) pairs in the same interleaved +// order as the x tensor. The freqs tensor may be passed in at any rank whose +// flattened layout is [N, C] — e.g. 2D `[N, C]` or 4D `[1, N, C/2, 2]`. This +// avoids callers having to emit a `view` dispatch (view_copy) purely to +// normalize rank. 
+// + +void resize_rotary_embedding_interleaved_node( + ComputeGraph* graph, + const std::vector<ArgGroup>& args, + const std::vector<ValueRef>& resize_args) { + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + + graph->virtual_resize(out, graph->sizes_of(in)); +} + +utils::uvec3 rotary_embedding_interleaved_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector<ArgGroup>& args, + const std::vector<ValueRef>& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + + const std::vector<int64_t> out_sizes = graph->sizes_of(out); + VK_CHECK_COND(out_sizes.size() == 3); + + const uint32_t B = static_cast<uint32_t>(out_sizes.at(0)); + const uint32_t N = static_cast<uint32_t>(out_sizes.at(1)); + const uint32_t C = static_cast<uint32_t>(out_sizes.at(2)); + + // One thread per output texel of 4 elements along C. + return {utils::div_up_4(C), N, B}; +} + +void add_rotary_embedding_interleaved_node( + ComputeGraph& graph, + const ValueRef x, + const ValueRef freqs_cis, + const ValueRef out) { + const std::vector<int64_t> x_sizes = graph.sizes_of(x); + const std::vector<int64_t> freqs_sizes = graph.sizes_of(freqs_cis); + + VK_CHECK_COND(x_sizes.size() == 3); + VK_CHECK_COND(x_sizes.at(2) % 4 == 0); + + // freqs_cis may arrive at any rank (commonly 2D [N, C] or 4D [1, N, C/2, 2] + // from `torch.view_as_real(...).unsqueeze(0)`). Validate via numel rather + // than per-dim equality so callers do not need to emit a view_copy purely + // to flatten the shape. 
+ int64_t freqs_numel = 1; + for (const int64_t s : freqs_sizes) { + freqs_numel *= s; + } + const int64_t expected_numel = x_sizes.at(1) * x_sizes.at(2); + VK_CHECK_COND(freqs_numel == expected_numel); + + VK_CHECK_COND(graph.packed_dim_of(x) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + VK_CHECK_COND(graph.has_standard_axis_map(x)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); + // freqs_cis is pinned to buffer storage via op_registry so the shader can + // use flat (row, col) indexing regardless of its declared rank. + VK_CHECK_COND(graph.is_buffer_storage(freqs_cis)); + + std::string kernel_name = "apply_rotary_emb_interleaved"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + rotary_embedding_interleaved_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{x, freqs_cis}, vkapi::kRead}}, + // Parameter buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {graph.hashed_layout_of(out)}, + // Resize Args + {}, + // Resizing Logic + resize_rotary_embedding_interleaved_node)); +} + +void apply_rotary_emb_interleaved( + ComputeGraph& graph, + const std::vector<ValueRef>& args) { + add_rotary_embedding_interleaved_node(graph, args[0], args[1], args[2]); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb); VK_REGISTER_OP(et_vk.apply_rotary_emb_hf.default, apply_rotary_emb_hf); + VK_REGISTER_OP( + et_vk.apply_rotary_emb_interleaved.default, apply_rotary_emb_interleaved); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index e2be1526a4a..a2be5affb65 100644 --- 
a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -820,3 +820,227 @@ TEST( /*start_pos=*/10, /*max_seq_len=*/64); } + +// +// Interleaved (EdgeTAM) RoPE reference and tests +// +// x: [B, N, C] with (real, imag) pairs interleaved along C +// freqs_cis: any rank with N * C elements. Commonly 2D [N, C] or 4D +// [1, N, C/2, 2] from torch.view_as_real(...).unsqueeze(0). The +// (cos, sin) pairs are interleaved along the innermost axis in +// the flattened view. +// + +at::Tensor rotary_embedding_interleaved_impl( + const at::Tensor& x, + const at::Tensor& freqs_cis) { + const int64_t B = x.size(0); + const int64_t N = x.size(1); + const int64_t C = x.size(2); + + std::vector<at::Tensor> x_pairs = at::unbind(x.reshape({B, N, C / 2, 2}), -1); + at::Tensor& a_real = x_pairs[0]; + at::Tensor& a_imag = x_pairs[1]; + + at::Tensor freqs_3d = freqs_cis.reshape({N, C / 2, 2}); + std::vector<at::Tensor> freq_pairs = at::unbind(freqs_3d, -1); + at::Tensor& b_real = freq_pairs[0]; + at::Tensor& b_imag = freq_pairs[1]; + + at::Tensor out_real = a_real * b_real - a_imag * b_imag; + at::Tensor out_imag = a_real * b_imag + a_imag * b_real; + + return at::stack({out_real, out_imag}, -1).reshape({B, N, C}); +} + +void test_reference_interleaved( + const int B = 1, + const int N = 256, + const int C = 256, + const at::ScalarType dtype = at::kFloat, + const std::vector<int64_t>& freqs_shape = {}) { + at::Tensor x = at::rand({B, N, C}, at::device(at::kCPU).dtype(dtype)); + std::vector<int64_t> fshape = + freqs_shape.empty() ? 
std::vector<int64_t>{N, C} : freqs_shape; + at::Tensor freqs_cis = at::rand(fshape, at::device(at::kCPU).dtype(dtype)); + + at::Tensor ref = rotary_embedding_interleaved_impl(x, freqs_cis); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type())); + // freqs_cis is always buffer-backed: the shader flat-indexes it so it is + // insensitive to the tensor's declared rank (matches op_registry pinning). + IOValueRef r_freqs_cis = graph.add_input_tensor( + freqs_cis.sizes().vec(), + from_at_scalartype(freqs_cis.scalar_type()), + utils::kBuffer); + + const ValueRef r_out = graph.add_tensor( + ref.sizes().vec(), from_at_scalartype(ref.scalar_type())); + + VK_GET_OP_FN("et_vk.apply_rotary_emb_interleaved.default") + (graph, {r_x.value, r_freqs_cis.value, r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_x.staging, + x.const_data_ptr(), + x.numel(), + from_at_scalartype(x.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cis.staging, + freqs_cis.const_data_ptr(), + freqs_cis.numel(), + from_at_scalartype(freqs_cis.scalar_type())); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(ref); + graph.maybe_cast_and_copy_from_staging( + staging_out, + vk_out.mutable_data_ptr(), + vk_out.numel(), + from_at_scalartype(vk_out.scalar_type())); + + const double tol = (dtype == at::kHalf) ? 
5e-3 : 1e-4; + EXPECT_TRUE(at::allclose(ref, vk_out, tol, tol)); +} + +// EdgeTAM self-attention shape +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_self_attn_fp32) { + test_reference_interleaved(/*B=*/1, /*N=*/256, /*C=*/256, at::kFloat); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_self_attn_fp16) { + test_reference_interleaved(/*B=*/1, /*N=*/256, /*C=*/256, at::kHalf); +} + +// EdgeTAM cross-attention memory-bank shape +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_cross_attn_fp32) { + test_reference_interleaved(/*B=*/1, /*N=*/1792, /*C=*/256, at::kFloat); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_cross_attn_fp16) { + test_reference_interleaved(/*B=*/1, /*N=*/1792, /*C=*/256, at::kHalf); +} + +// Buffer storage path +void test_reference_interleaved_buffer( + const int B = 1, + const int N = 256, + const int C = 256, + const at::ScalarType dtype = at::kFloat, + const std::vector<int64_t>& freqs_shape = {}) { + at::Tensor x = at::rand({B, N, C}, at::device(at::kCPU).dtype(dtype)); + std::vector<int64_t> fshape = + freqs_shape.empty() ? 
std::vector<int64_t>{N, C} : freqs_shape; + at::Tensor freqs_cis = at::rand(fshape, at::device(at::kCPU).dtype(dtype)); + + at::Tensor ref = rotary_embedding_interleaved_impl(x, freqs_cis); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kBuffer); + ComputeGraph graph(config); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type())); + IOValueRef r_freqs_cis = graph.add_input_tensor( + freqs_cis.sizes().vec(), from_at_scalartype(freqs_cis.scalar_type())); + + const ValueRef r_out = graph.add_tensor( + ref.sizes().vec(), from_at_scalartype(ref.scalar_type())); + + VK_GET_OP_FN("et_vk.apply_rotary_emb_interleaved.default") + (graph, {r_x.value, r_freqs_cis.value, r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_x.staging, + x.const_data_ptr(), + x.numel(), + from_at_scalartype(x.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cis.staging, + freqs_cis.const_data_ptr(), + freqs_cis.numel(), + from_at_scalartype(freqs_cis.scalar_type())); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(ref); + graph.maybe_cast_and_copy_from_staging( + staging_out, + vk_out.mutable_data_ptr(), + vk_out.numel(), + from_at_scalartype(vk_out.scalar_type())); + + const double tol = (dtype == at::kHalf) ? 
5e-3 : 1e-4; + EXPECT_TRUE(at::allclose(ref, vk_out, tol, tol)); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_self_attn_buffer_fp32) { + test_reference_interleaved_buffer( + /*B=*/1, /*N=*/256, /*C=*/256, at::kFloat); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_cross_attn_buffer_fp32) { + test_reference_interleaved_buffer( + /*B=*/1, /*N=*/1792, /*C=*/256, at::kFloat); +} + +// 4D freqs_cis [1, N, C/2, 2] — EdgeTAM exporter emits this shape directly +// from torch.view_as_real(...).unsqueeze(0), so this case must work without +// the caller inserting a view_copy. +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_self_attn_buffer_4d_freqs) { + test_reference_interleaved_buffer( + /*B=*/1, + /*N=*/256, + /*C=*/256, + at::kFloat, + /*freqs_shape=*/{1, 256, 128, 2}); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_cross_attn_buffer_4d_freqs) { + test_reference_interleaved_buffer( + /*B=*/1, + /*N=*/1792, + /*C=*/256, + at::kFloat, + /*freqs_shape=*/{1, 1792, 128, 2}); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_self_attn_4d_freqs_fp16) { + test_reference_interleaved( + /*B=*/1, + /*N=*/256, + /*C=*/256, + at::kHalf, + /*freqs_shape=*/{1, 256, 128, 2}); +} + +TEST(VulkanRotaryEmbeddingInterleavedTest, edgetam_cross_attn_4d_freqs_fp16) { + test_reference_interleaved( + /*B=*/1, + /*N=*/1792, + /*C=*/256, + at::kHalf, + /*freqs_shape=*/{1, 1792, 128, 2}); +}