Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,55 @@ def apply_rotary_emb_hf_impl(
lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd")
apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name)

##################################
## apply_rotary_emb_interleaved ##
##################################


def apply_rotary_emb_interleaved_impl(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    """Reference CPU implementation of EdgeTAM's pair-interleaved RoPE.

    Args:
        x: ``[B, N, C]`` tensor whose last dim interleaves (real, imag)
            component pairs: ``[r0, i0, r1, i1, ...]``.
        freqs_cis: tensor of any rank whose flattened layout is ``[N, C]``
            with (cos, sin) pairs interleaved along the innermost axis —
            commonly 2D ``[N, C]`` or 4D ``[1, N, C/2, 2]`` produced by
            ``torch.view_as_real(...).unsqueeze(0)``.

    Returns:
        ``[B, N, C]`` tensor where each complex pair of ``x`` has been
        rotated by the matching (cos, sin) pair:
            out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
            out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
    """
    batch, seq_len, channels = x.shape
    half = channels // 2
    # Split the interleaved (real, imag) pairs along the channel dim.
    pairs = x.view(batch, seq_len, half, 2)
    x_re = pairs[..., 0]
    x_im = pairs[..., 1]
    # reshape (not view) so callers may pass freqs_cis at any rank.
    fc = freqs_cis.reshape(seq_len, half, 2)
    cos_t = fc[..., 0]
    sin_t = fc[..., 1]
    # Complex multiply (x_re + i*x_im) * (cos + i*sin), re-interleaved.
    rotated = torch.stack(
        (x_re * cos_t - x_im * sin_t, x_re * sin_t + x_im * cos_t),
        dim=-1,
    )
    return rotated.view(batch, seq_len, channels)


def apply_rotary_emb_interleaved_meta(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    """Shape-only meta kernel: the output mirrors ``x`` exactly.

    Registering a Meta kernel keeps the op opaque during torch.export
    (no inlining of view/reshape calls into the exported graph).
    ``freqs_cis`` is deliberately not inspected here, so its rank is not
    constrained — any shape with N * C elements is accepted downstream
    by the Vulkan dispatcher.
    """
    del freqs_cis  # unused: only the shape/dtype of x matters
    return torch.empty_like(x)


# Register the op under the shared Vulkan custom-op library so it is
# addressable as torch.ops.<namespace>.apply_rotary_emb_interleaved.
name = "apply_rotary_emb_interleaved"
lib.define(f"{name}(Tensor x, Tensor freqs_cis) -> Tensor")
# CPU kernel preserves eager-mode reference semantics.
lib.impl(name, apply_rotary_emb_interleaved_impl, "CPU")
# Meta kernel keeps the op opaque in the exported graph.
lib.impl(name, apply_rotary_emb_interleaved_meta, "Meta")
# Resolve the bound OpOverloadPacket for use elsewhere in the backend.
apply_rotary_emb_interleaved_op = getattr(getattr(torch.ops, namespace), name)

########################
## q8ta_add ##
########################
Expand Down
16 changes: 16 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,22 @@ def register_apply_rotary_emb_hf():
)


@update_features(exir_ops.edge.et_vk.apply_rotary_emb_interleaved.default)
def register_apply_rotary_emb_interleaved():
    """Declare Vulkan backend capabilities for the interleaved RoPE op."""
    # freqs_cis is pinned to buffer storage so the shader can compute a
    # flat [N, C] linear address regardless of the tensor's declared rank
    # (callers commonly pass 4D [1, N, C/2, 2] without a preceding view).
    storage = [
        utils.CONTIGUOUS_ANY,  # x
        utils.CONTIGUOUS_BUFFER,  # freqs_cis
    ]
    return OpFeatures(
        inputs_storage=storage,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
        supports_highdim=True,
    )


# =============================================================================
# Permute.cpp
# =============================================================================
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

${define_required_extensions(STORAGE, DTYPE)}
${define_required_extensions("buffer", DTYPE)}

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define BUFFER_VEC4_T ${texel_load_type(DTYPE, "buffer")}

${define_active_storage_type(STORAGE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
// freqs_cis is always bound as a buffer so the shader can flat-index it
// regardless of the caller's declared rank (2D [N, C] or 4D [1, N, C/2, 2]).
${layout_declare_tensor(B, "r", "t_freqs", DTYPE, "buffer", is_scalar_array=False)}

$if STORAGE == "buffer":
${layout_declare_ubo(B, "BufferMetadata", "outp")}
$else:
${layout_declare_ubo(B, "TextureMetadata", "outp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}

/*
* Applies rotary positional embeddings to a tensor whose last dimension
* contains pair-interleaved (real, imag) components. This matches EdgeTAM's
* `apply_rotary_enc_without_complex` semantics, where the fused cos/sin
* freqs tensor has a flattened [N, C] layout (cos, sin pairs interleaved).
*
* Inputs:
* t_in: [B, N, C] (last dim packed as [r0, i0, r1, i1, ...])
* t_freqs: contiguous memory with N * C elements. May arrive at any rank
* (e.g. 2D [N, C] or 4D [1, N, C/2, 2]). Physically the values
* are [cos0, sin0, cos1, sin1, ...] along the final axis.
*
* Output:
* t_out: same shape as t_in
*
* Math per k in [0, C/2):
* out[2k] = x[2k] * cos[k] - x[2k+1] * sin[k]
* out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
*
* Each thread processes one width-packed texel (4 elements = 2 (r, i) pairs).
* All participating tensors are assumed to be width-packed with standard axis
* maps.
*
* The freqs tensor is indexed using a flat (n_idx * C + c_offset) address to
* remain correct regardless of input rank — the shape of t_freqs does not
* need to match the logical [N, C] layout, only the underlying memory does.
*/
void main() {
  // Each thread computes one output texel of 4 elements along the last dim
  // (the width / channel axis). x maps to channel-texel index, y to the
  // sequence position N, z to the batch B.
  TensorIndex4D out_tidx = zero_tensor4d_idx();
  out_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
  out_tidx.data.y = int(gl_GlobalInvocationID.y);
  out_tidx.data.z = int(gl_GlobalInvocationID.z);

  // Threads past the padded dispatch bounds have no work to do.
  if (out_of_bounds(out_tidx, outp)) {
    return;
  }

  // Freqs tensor is always a contiguous buffer of N * C elements. Compute
  // a flat texel index directly from logical (n_idx, c_elem_idx / 4). The
  // logical width C comes from the output tensor metadata — both buffer
  // and texture metadata store this at index 0 (sizes[0][0] / sizes.x).
#ifdef USING_BUFFER
  const uint C_width = outp.sizes[0][0];
#else
  const uint C_width = uint(outp.sizes.x);
#endif
  // Note the freqs index deliberately ignores the batch coordinate (z):
  // the same [N, C] freqs values are applied to every batch entry.
  const uint freqs_texel_bufi =
      uint(out_tidx.data.y) * div_4(C_width)
      + uint(gl_GlobalInvocationID.x);
  BUFFER_VEC4_T f_tex = t_freqs[freqs_texel_bufi];

  // Load the matching input texel using the storage-appropriate addressing.
#ifdef USING_BUFFER
  const uint out_texel_bufi =
      div_4(tensor4d_idx_to_linear_idx(outp, out_tidx));
  VEC4_T x_tex = t_in[out_texel_bufi];
#else // USING_TEXTURE
  const ivec3 out_pos =
      tensor4d_idx_to_texel_pos_simple(outp, out_tidx, outp_layout);
  VEC4_T x_tex = texelFetch(t_in, out_pos, 0);
#endif

  // Rotate both (real, imag) pairs in the texel by their (cos, sin) pair:
  // x_tex = (r0, i0, r1, i1), f_tex = (c0, s0, c1, s1)
  // out[2k] = r*c - i*s, out[2k+1] = r*s + i*c (complex multiplication).
  VEC4_T out_tex;
  out_tex.x = x_tex.x * VEC4_T(f_tex).x - x_tex.y * VEC4_T(f_tex).y;
  out_tex.y = x_tex.x * VEC4_T(f_tex).y + x_tex.y * VEC4_T(f_tex).x;
  out_tex.z = x_tex.z * VEC4_T(f_tex).z - x_tex.w * VEC4_T(f_tex).w;
  out_tex.w = x_tex.z * VEC4_T(f_tex).w + x_tex.w * VEC4_T(f_tex).z;

  // Store through the same storage-specific path used for the load.
#ifdef USING_BUFFER
  t_out[out_texel_bufi] = out_tex;
#else
  imageStore(t_out, out_pos, out_tex);
#endif
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Codegen configuration for the apply_rotary_emb_interleaved shader.
# One variant is generated per (STORAGE, DTYPE) combination below.
apply_rotary_emb_interleaved:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    # Output/input tensors may live in texture or buffer storage; the
    # freqs tensor itself is always bound as a buffer inside the shader.
    STORAGE:
      - VALUE: texture3d
      - VALUE: buffer
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: apply_rotary_emb_interleaved
108 changes: 108 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,117 @@ void apply_rotary_emb_hf(
graph, args[0], args[1], args[2], args[3], args[4], xq_out, xk_out);
}

//
// EdgeTAM-style RoPE variant with fused [cos, sin] freqs tensor
//
// Operates on a single tensor (Q or K) of shape [B, N, C] with pair-interleaved
// (real, imag) components along the last dim, and a freqs tensor with a total
// element count of N * C that packs (cos, sin) pairs in the same interleaved
// order as the x tensor. The freqs tensor may be passed in at any rank whose
// flattened layout is [N, C] — e.g. 2D `[N, C]` or 4D `[1, N, C/2, 2]`. This
// avoids callers having to emit a `view` dispatch (view_copy) purely to
// normalize rank.
//

// Resize callback: the output of the interleaved RoPE op always has
// exactly the input's shape, so propagating the (possibly updated)
// input sizes to the output is sufficient.
void resize_rotary_embedding_interleaved_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;

  const ValueRef out_ref = args.at(0).refs.at(0);
  const ValueRef in_ref = args.at(1).refs.at(0);
  graph->virtual_resize(out_ref, graph->sizes_of(in_ref));
}

// Computes the global workgroup size for the interleaved RoPE dispatch:
// one thread per output texel of 4 elements along the channel dim C,
// for an output of shape [B, N, C].
utils::uvec3 rotary_embedding_interleaved_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;

  const ValueRef out = args.at(0).refs.at(0);
  const std::vector<int64_t> sizes = graph->sizes_of(out);
  VK_CHECK_COND(sizes.size() == 3);

  const uint32_t batch = static_cast<uint32_t>(sizes.at(0));
  const uint32_t seq_len = static_cast<uint32_t>(sizes.at(1));
  const uint32_t channels = static_cast<uint32_t>(sizes.at(2));

  // x: channel texels, y: sequence positions, z: batch entries.
  return {utils::div_up_4(channels), seq_len, batch};
}

// Builds the graph node for EdgeTAM-style interleaved RoPE. Validates
// tensor layout assumptions, resolves the shader variant matching the
// output's storage type and dtype, and queues the dispatch.
void add_rotary_embedding_interleaved_node(
    ComputeGraph& graph,
    const ValueRef x,
    const ValueRef freqs_cis,
    const ValueRef out) {
  const std::vector<int64_t> x_sizes = graph.sizes_of(x);

  // x must be [B, N, C]; the shader consumes two (real, imag) pairs per
  // thread, so C must be divisible by 4.
  VK_CHECK_COND(x_sizes.size() == 3);
  VK_CHECK_COND(x_sizes.at(2) % 4 == 0);

  // freqs_cis may arrive at any rank (commonly 2D [N, C] or 4D [1, N, C/2, 2]
  // from `torch.view_as_real(...).unsqueeze(0)`). Validate via numel rather
  // than per-dim equality so callers do not need to emit a view_copy purely
  // to flatten the shape.
  int64_t freqs_elem_count = 1;
  for (const int64_t dim : graph.sizes_of(freqs_cis)) {
    freqs_elem_count *= dim;
  }
  VK_CHECK_COND(freqs_elem_count == x_sizes.at(1) * x_sizes.at(2));

  // The shader assumes width-packed tensors with standard axis maps.
  VK_CHECK_COND(graph.packed_dim_of(x) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.has_standard_axis_map(x));
  VK_CHECK_COND(graph.has_standard_axis_map(out));
  // freqs_cis is pinned to buffer storage via op_registry so the shader can
  // use flat (row, col) indexing regardless of its declared rank.
  VK_CHECK_COND(graph.is_buffer_storage(freqs_cis));

  // Pick the shader variant generated for the output's storage and dtype.
  std::string shader_name = "apply_rotary_emb_interleaved";
  add_storage_type_suffix(shader_name, graph.storage_type_of(out));
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  vkapi::ParamsBindList ubos = {graph.meta_ubo(out)};

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      rotary_embedding_interleaved_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{x, freqs_cis}, vkapi::kRead}},
      // Parameter buffers
      ubos,
      // Push Constants
      {},
      // Specialization Constants
      {graph.hashed_layout_of(out)},
      // Resize Args
      {},
      // Resizing Logic
      resize_rotary_embedding_interleaved_node));
}

// Operator entry point: unpacks the (x, freqs_cis, out) argument list
// produced by the exported graph and delegates to the node builder.
void apply_rotary_emb_interleaved(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  const ValueRef x = args[0];
  const ValueRef freqs_cis = args[1];
  const ValueRef out = args[2];
  add_rotary_embedding_interleaved_node(graph, x, freqs_cis, out);
}

// Map exported-graph operator names to their graph-builder entry points.
REGISTER_OPERATORS {
  VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb);
  VK_REGISTER_OP(et_vk.apply_rotary_emb_hf.default, apply_rotary_emb_hf);
  VK_REGISTER_OP(
      et_vk.apply_rotary_emb_interleaved.default, apply_rotary_emb_interleaved);
}

} // namespace vkcompute
Loading
Loading