[ET-VK] Add apply_rotary_emb_interleaved fused operator #19115
SS-JIA wants to merge 3 commits into gh/SS-JIA/524/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19115

Note: links to docs will display an error until the docs builds have completed.

❌ 1 New Failure, 3 Cancelled Jobs, 2 Unrelated Failures as of commit 9ad2459 with merge base eef7921:

- NEW FAILURE: one job failed that is not present on the merge base.
- CANCELLED JOBS: three jobs were cancelled; please retry them.
- BROKEN TRUNK: two jobs failed but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Stack from ghstack (oldest at bottom):
Introduces `et_vk.apply_rotary_emb_interleaved`, a fused Vulkan custom operator for the "complex-number" RoPE variant used by SAM2/EdgeTAM's memory attention. It replaces a chain of 12+ layout-shuffle ops (`view/unbind/stack/view`, which lowers to `slice_copy + squeeze_copy + unsqueeze_copy + cat + view_copy`) with a single GPU dispatch.

**Math**: On pair-interleaved inputs where element `2k` is real and `2k+1` is imaginary, for each `k in [0, C/2)`:
    out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
    out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
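As a reference for the math above, a minimal PyTorch sketch (the function name and the flattening of `freqs_cis` are illustrative, not the shipped CPU kernel verbatim):

```python
import torch

def apply_rotary_emb_interleaved_ref(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [B, N, C], pair-interleaved (even channels real, odd channels imag).
    # freqs_cis: any rank with N * C elements, [cos, sin] interleaved on the
    # innermost dim; flatten to [N, C/2, 2] to index per (position, pair).
    B, N, C = x.shape
    fc = freqs_cis.reshape(N, C // 2, 2)
    cos, sin = fc[..., 0], fc[..., 1]        # each [N, C/2]
    xr, xi = x[..., 0::2], x[..., 1::2]      # real / imag halves, [B, N, C/2]
    out = torch.empty_like(x)
    out[..., 0::2] = xr * cos - xi * sin
    out[..., 1::2] = xr * sin + xi * cos
    return out
```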
**Why a new op instead of reusing `et_vk.apply_rotary_emb`**: The existing LLM-oriented operator takes `(xq, xk)` pairs with separate `freqs_cos`/`freqs_sin` tensors and 4D `(B, S, H, D)` shapes, optimized for two-texel-per-thread reuse during LLM prefill. SAM2's memory attention instead passes a single 3D `(B, N, C)` tensor through RoPE (no heads dim) with a fused `[N, C/2, 2]` freqs tensor. Reusing the existing op would force runtime splits of the fused freqs and separate dispatches for Q and K, defeating the fusion. A sibling shader keeps both workloads tight.

**Op contract**: `apply_rotary_emb_interleaved(x, freqs_cis) -> Tensor`, where `x` is `[B, N, C]` and `freqs_cis` may have any rank as long as it holds `N*C` elements with the `cos`/`sin` values interleaved on the innermost dim. In EdgeTAM's memory attention the native shape is `[1, N, C/2, 2]`; passing it through without a reshape keeps the exported graph free of bracketing `view_copy` dispatches. A hypothetical call site is sketched below.
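A hypothetical call site, assuming the op has been registered under the `et_vk` namespace (the sizes are made up for illustration):

```python
import torch

B, N, C = 1, 4096, 256                    # illustrative sizes
x = torch.randn(B, N, C)
freqs_cis = torch.randn(1, N, C // 2, 2)  # EdgeTAM-native shape, passed through as-is

# No reshape of freqs_cis needed: the op only requires N * C elements
# with cos/sin interleaved on the innermost dim.
out = torch.ops.et_vk.apply_rotary_emb_interleaved(x, freqs_cis)
assert out.shape == x.shape
```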
**Shader**: A single-dispatch kernel that produces one output texel per thread. Each thread reads one `x` texel (2 real/imag pairs) and the corresponding `freqs_cis` entries (2 cos/sin pairs) flat-indexed from buffer storage, then writes one output texel. `x` and the output support both buffer and texture3d storage; `freqs_cis` is always buffer storage (it is a small tensor, so flat indexing is simplest). fp16 and fp32 are supported via the `FP_T` dtype iterator in the YAML.
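To illustrate the per-thread work item, here is the same logic in Python rather than GLSL; the exact flat-index arithmetic in the shipped shader may differ, so treat this purely as a sketch:

```python
def rotate_one_texel(x_texel, freqs_flat, n, t, C):
    # One work item: x_texel holds 4 consecutive channels of row n,
    # i.e. 2 (real, imag) pairs starting at pair index t * 2.
    # freqs_flat is the flat N*C-element buffer, [cos, sin] per pair.
    out = [0.0] * 4
    for p in range(2):                   # the two pairs in this texel
        k = t * 2 + p                    # global pair index along C
        base = (n * (C // 2) + k) * 2    # flat offset of this pair's cos
        cos, sin = freqs_flat[base], freqs_flat[base + 1]
        re, im = x_texel[2 * p], x_texel[2 * p + 1]
        out[2 * p] = re * cos - im * sin
        out[2 * p + 1] = re * sin + im * cos
    return out
```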
**Op registration**: The `Meta` kernel returns `torch.empty_like(x)` to keep the op opaque during `torch.export`. The `CPU` kernel holds the reference math so non-Vulkan backends keep working. `op_registry.py` pins `freqs_cis` storage to `CONTIGUOUS_BUFFER` while leaving `x` at `CONTIGUOUS_ANY`.

Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/)
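As a rough sketch of what the eager-side registration might look like with `torch.library` (ExecuTorch registers its custom ops through its own machinery, so this is illustrative only; it reuses the reference function from the math sketch above):

```python
import torch
from torch.library import Library

# Define the op schema in the et_vk namespace.
lib = Library("et_vk", "DEF")
lib.define("apply_rotary_emb_interleaved(Tensor x, Tensor freqs_cis) -> Tensor")

def _meta(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Shapes/dtypes only: keeps the op opaque during torch.export.
    return torch.empty_like(x)

lib.impl("apply_rotary_emb_interleaved", _meta, "Meta")
# Reference math (the sketch above) keeps non-Vulkan backends working.
lib.impl("apply_rotary_emb_interleaved", apply_rotary_emb_interleaved_ref, "CPU")
```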