From b7d4d31109c9758823d610086973aa626704377d Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Fri, 12 Jun 2026 15:35:44 -0700 Subject: [PATCH 1/2] [ExecuTorch][WebGPU] GPU timestamp query profiling (general implementation) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20201 Backend-agnostic GPU-timestamp infrastructure, split out so the general implementation is foundational (below SDPA) while the SDPA-specific dispatch labeling stays above the SDPA op. Composed of: `WebGPUQueryPool`, a faithful re-port of Vulkan's `vkapi::QueryPool` (`backends/vulkan/runtime/vk_api/QueryPool.{h,cpp}`) — same `ShaderDuration` data model and ticks->ns conversion; three deviations are forced by the WebGPU API (per-dispatch bracketing via a compute-pass `timestampWrites` descriptor since there is no mid-encoder `writeTimestamp`; readback via `resolveQuerySet` + buffer map rather than host-side `vkGetQueryPoolResults`; the `TimestampQuery` capability requested as an explicit device feature, fail-open if the adapter lacks it). `WebGPUDevice` gains timestamp-feature detection, and `WebGPUGraph` gains a per-dispatch `kernel_name` label plus `execute()` bracketing of each compute pass when the pool is active. Opt-in via the `EXECUTORCH_BUILD_WEBGPU_PROFILING` CMake option (which defines `WGPU_BACKEND_ENABLE_PROFILING`) together with the `WEBGPU_TIMESTAMP_QUERY` env var at runtime; both are off by default, so the production `execute()` path is byte-identical. The SDPA per-kernel labeling lives in the companion "for SDPA" diff above the SDPA op. Co-authored with Claude. 
ghstack-source-id: 392975889 @exported-using-ghexport Differential Revision: [D108188287](https://our.internmc.facebook.com/intern/diff/D108188287/) --- backends/webgpu/CMakeLists.txt | 17 ++ backends/webgpu/runtime/WebGPUDevice.cpp | 19 ++ backends/webgpu/runtime/WebGPUDevice.h | 12 ++ backends/webgpu/runtime/WebGPUGraph.cpp | 71 ++++++- backends/webgpu/runtime/WebGPUGraph.h | 1 + backends/webgpu/runtime/WebGPUQueryPool.cpp | 224 ++++++++++++++++++++ backends/webgpu/runtime/WebGPUQueryPool.h | 88 ++++++++ backends/webgpu/test/test_webgpu_native.cpp | 132 +++++++++++- 8 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 backends/webgpu/runtime/WebGPUQueryPool.cpp create mode 100644 backends/webgpu/runtime/WebGPUQueryPool.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 9b1476f2290..1fc0860fc4b 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -30,6 +30,7 @@ set(WEBGPU_SRCS runtime/WebGPUGraph.cpp runtime/WebGPUDelegateHeader.cpp runtime/WebGPUDevice.cpp + runtime/WebGPUQueryPool.cpp runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp runtime/ops/rms_norm/RmsNorm.cpp @@ -76,6 +77,17 @@ endif() target_compile_options(webgpu_backend PRIVATE -fexceptions) +# Opt-in GPU timestamp profiling (WebGPUQueryPool); OFF so production builds +# request no TimestampQuery device feature. Mirrors Vulkan's compile-flag gate. 
+option(EXECUTORCH_BUILD_WEBGPU_PROFILING + "Enable WebGPU GPU timestamp-query profiling" OFF +) +if(EXECUTORCH_BUILD_WEBGPU_PROFILING) + target_compile_definitions( + webgpu_backend PRIVATE WGPU_BACKEND_ENABLE_PROFILING + ) +endif() + # Link with --whole-archive for static registration of backend + ops executorch_target_link_options_shared_lib(webgpu_backend) @@ -114,6 +126,11 @@ function(add_webgpu_native_test test_name test_src) target_link_libraries(${test_name} PRIVATE dl m pthread) endif() target_compile_options(${test_name} PRIVATE -fexceptions) + if(EXECUTORCH_BUILD_WEBGPU_PROFILING) + target_compile_definitions( + ${test_name} PRIVATE WGPU_BACKEND_ENABLE_PROFILING + ) + endif() set_property(TARGET ${test_name} PROPERTY CXX_STANDARD 17) endfunction() diff --git a/backends/webgpu/runtime/WebGPUDevice.cpp b/backends/webgpu/runtime/WebGPUDevice.cpp index 041cbe5a703..e69101851a2 100644 --- a/backends/webgpu/runtime/WebGPUDevice.cpp +++ b/backends/webgpu/runtime/WebGPUDevice.cpp @@ -13,6 +13,9 @@ #include #include #include +#ifdef WGPU_BACKEND_ENABLE_PROFILING +#include +#endif // WGPU_BACKEND_ENABLE_PROFILING namespace executorch { namespace backends { @@ -137,6 +140,18 @@ WebGPUContext create_webgpu_context() { WGPUStatus_Success) { device_desc.requiredLimits = &supported_limits; } + +#ifdef WGPU_BACKEND_ENABLE_PROFILING + // Bench: enable TimestampQuery if available; fail-open (skip timing if not). 
+ std::vector required_features; + if (wgpuAdapterHasFeature(ctx.adapter, WGPUFeatureName_TimestampQuery)) { + required_features.push_back(WGPUFeatureName_TimestampQuery); + device_desc.requiredFeatureCount = required_features.size(); + device_desc.requiredFeatures = required_features.data(); + ctx.timestamp_supported = true; + } +#endif // WGPU_BACKEND_ENABLE_PROFILING + device_desc.uncapturedErrorCallbackInfo.callback = on_device_error; WGPUWaitStatus device_wait = webgpu_wait( @@ -192,6 +207,10 @@ WebGPUContext* get_default_webgpu_context() { } void destroy_webgpu_context(WebGPUContext& ctx) { +#ifdef WGPU_BACKEND_ENABLE_PROFILING + // Release device-child GPU resources before the device handle. + ctx.querypool.reset(); +#endif // WGPU_BACKEND_ENABLE_PROFILING if (ctx.queue) { wgpuQueueRelease(ctx.queue); ctx.queue = nullptr; diff --git a/backends/webgpu/runtime/WebGPUDevice.h b/backends/webgpu/runtime/WebGPUDevice.h index 78afd96316a..a332edef443 100644 --- a/backends/webgpu/runtime/WebGPUDevice.h +++ b/backends/webgpu/runtime/WebGPUDevice.h @@ -10,6 +10,12 @@ #include +#ifdef WGPU_BACKEND_ENABLE_PROFILING +#include + +#include +#endif // WGPU_BACKEND_ENABLE_PROFILING + namespace executorch { namespace backends { namespace webgpu { @@ -19,6 +25,12 @@ struct WebGPUContext { WGPUAdapter adapter = nullptr; WGPUDevice device = nullptr; WGPUQueue queue = nullptr; +#ifdef WGPU_BACKEND_ENABLE_PROFILING + // True if the device was created with the TimestampQuery feature (bench). + bool timestamp_supported = false; + // Bench-only: timestamp-query pool, lazily created in execute() (env-gated). 
+ std::unique_ptr querypool; +#endif // WGPU_BACKEND_ENABLE_PROFILING }; WebGPUContext create_webgpu_context(); diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index b3ae5511d13..1c977d130dd 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -496,18 +497,57 @@ void WebGPUGraph::copy_inputs( } } +namespace { +// Bench gate: compiled out unless WGPU_BACKEND_ENABLE_PROFILING; then the +// WEBGPU_TIMESTAMP_QUERY env var enables per-pass GPU timestamp queries. +bool should_timestamp_query() { +#ifdef WGPU_BACKEND_ENABLE_PROFILING + static const bool enabled = std::getenv("WEBGPU_TIMESTAMP_QUERY") != nullptr; + return enabled; +#else + return false; +#endif +} +} // namespace + void WebGPUGraph::execute() { const size_t n = dispatches_.size(); const size_t chunk = execute_config_.chunk_size; if (chunk == 0 || n <= chunk) { +#ifdef WGPU_BACKEND_ENABLE_PROFILING + // Bench: timestamp-query pool, null unless env-gated + feature present. + WebGPUQueryPool* qp = nullptr; + if (should_timestamp_query() && n > 0) { + if (auto* ctx = get_default_webgpu_context()) { + if (ctx->timestamp_supported) { + if (!ctx->querypool || ctx->querypool->capacity() < n) { + ctx->querypool = std::make_unique(); + ctx->querypool->initialize(device_, static_cast(n)); + } + qp = ctx->querypool.get(); + qp->reset(static_cast(n)); + } + } + } +#endif // WGPU_BACKEND_ENABLE_PROFILING + WGPUCommandEncoderDescriptor enc_desc = {}; WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(device_, &enc_desc); // One pass per dispatch: enforces storage RAW ordering across deps. - for (const auto& dispatch : dispatches_) { + for (size_t i = 0; i < n; i++) { + const auto& dispatch = dispatches_[i]; WGPUComputePassDescriptor pass_desc = {}; +#ifdef WGPU_BACKEND_ENABLE_PROFILING + // tw must outlive BeginComputePass (the descriptor points at it). 
+ WGPUPassTimestampWrites tw = {}; + if (qp) { + tw = qp->writes_for(static_cast(i)); + pass_desc.timestampWrites = &tw; + } +#endif // WGPU_BACKEND_ENABLE_PROFILING WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(encoder, &pass_desc); wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline); @@ -517,6 +557,15 @@ void WebGPUGraph::execute() { pass, dispatch.workgroup_count_x, 1, 1); wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); +#ifdef WGPU_BACKEND_ENABLE_PROFILING + if (qp) { + qp->record( + static_cast(i), + dispatch.kernel_name, + {dispatch.workgroup_count_x, 1, 1}, + {1, 1, 1}); + } +#endif // WGPU_BACKEND_ENABLE_PROFILING } for (const auto& copy : output_copies_) { @@ -524,15 +573,35 @@ void WebGPUGraph::execute() { encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes); } +#ifdef WGPU_BACKEND_ENABLE_PROFILING + if (qp) { + qp->resolve(encoder); + } +#endif // WGPU_BACKEND_ENABLE_PROFILING + WGPUCommandBufferDescriptor cmd_desc = {}; WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, &cmd_desc); wgpuQueueSubmit(queue_, 1, &cmd); wgpuCommandBufferRelease(cmd); wgpuCommandEncoderRelease(encoder); + +#ifdef WGPU_BACKEND_ENABLE_PROFILING + if (qp) { + qp->extract_results(instance_); + qp->print_results(); + } +#endif // WGPU_BACKEND_ENABLE_PROFILING return; } + // GPU timestamp queries assume one submit; chunked execute is multi-submit. + if (should_timestamp_query()) { + throw std::runtime_error( + "WebGPU: WEBGPU_TIMESTAMP_QUERY is incompatible with chunked execute " + "(multi-submit); disable chunking to use GPU timestamp queries"); + } + const size_t first_chunk = execute_config_.initial_chunk_size > 0 ? 
execute_config_.initial_chunk_size : chunk; diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 9f656ce4d14..92aa14d59b6 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -31,6 +31,7 @@ struct WebGPUDispatch { WGPUComputePipeline pipeline = nullptr; WGPUBindGroup bind_group = nullptr; uint32_t workgroup_count_x = 1; + std::string kernel_name; // bench label }; struct OutputCopy { diff --git a/backends/webgpu/runtime/WebGPUQueryPool.cpp b/backends/webgpu/runtime/WebGPUQueryPool.cpp new file mode 100644 index 00000000000..89e08a2afce --- /dev/null +++ b/backends/webgpu/runtime/WebGPUQueryPool.cpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include + +namespace executorch::backends::webgpu { + +#ifdef WGPU_BACKEND_ENABLE_PROFILING + +namespace { + +struct MapCallbackData { + WGPUMapAsyncStatus status = WGPUMapAsyncStatus_Error; +}; + +void map_callback( + WGPUMapAsyncStatus status, + WGPUStringView /*message*/, + void* userdata1, + void* /*userdata2*/) { + auto* data = static_cast(userdata1); + data->status = status; +} + +constexpr uint64_t kTimestampBytes = sizeof(uint64_t); + +} // namespace + +WebGPUQueryPool::~WebGPUQueryPool() { + if (readback_buf_) { + wgpuBufferRelease(readback_buf_); + } + if (resolve_buf_) { + wgpuBufferRelease(resolve_buf_); + } + if (qset_) { + wgpuQuerySetRelease(qset_); + } +} + +void WebGPUQueryPool::initialize(WGPUDevice device, uint32_t max_pairs) { + if (max_pairs == 0) { + return; + } + // Re-init guard; mirrors Vulkan QueryPool (avoids leaking a prior QuerySet). 
+ if (qset_ != nullptr) { + return; + } + capacity_pairs_ = max_pairs; + const uint32_t count = 2 * max_pairs; + const uint64_t bytes = static_cast(count) * kTimestampBytes; + + WGPUQuerySetDescriptor qsd = {}; + qsd.type = WGPUQueryType_Timestamp; + qsd.count = count; + qset_ = wgpuDeviceCreateQuerySet(device, &qsd); + + WGPUBufferDescriptor rbd = {}; + rbd.size = bytes; + rbd.usage = WGPUBufferUsage_QueryResolve | WGPUBufferUsage_CopySrc; + resolve_buf_ = wgpuDeviceCreateBuffer(device, &rbd); + + WGPUBufferDescriptor mbd = {}; + mbd.size = bytes; + mbd.usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst; + readback_buf_ = wgpuDeviceCreateBuffer(device, &mbd); + // WebGPU timestamps are already nanoseconds, so ns_per_tick_ stays 1.0. +} + +void WebGPUQueryPool::reset(uint32_t num_dispatches) { + // Fail loud on overrun; mirrors Vulkan QueryPool VK_CHECK_COND guard. + if (num_dispatches > capacity_pairs_) { + throw std::runtime_error( + "WebGPUQueryPool: num_dispatches " + std::to_string(num_dispatches) + + " exceeds capacity " + std::to_string(capacity_pairs_)); + } + num_pairs_ = num_dispatches; + durations_.clear(); +} + +WGPUPassTimestampWrites WebGPUQueryPool::writes_for(uint32_t i) { + WGPUPassTimestampWrites tw = {}; + tw.querySet = qset_; + tw.beginningOfPassWriteIndex = 2 * i; + tw.endOfPassWriteIndex = 2 * i + 1; + return tw; +} + +void WebGPUQueryPool::record( + uint32_t i, + const std::string& name, + std::array gwg, + std::array lwg) { + ShaderDuration d; + d.idx = i; + d.kernel_name = name; + d.global_wg = gwg; + d.local_wg = lwg; + durations_.push_back(d); +} + +void WebGPUQueryPool::resolve(WGPUCommandEncoder encoder) { + if (num_pairs_ == 0) { + return; + } + const uint32_t count = 2 * num_pairs_; + wgpuCommandEncoderResolveQuerySet(encoder, qset_, 0, count, resolve_buf_, 0); + wgpuCommandEncoderCopyBufferToBuffer( + encoder, + resolve_buf_, + 0, + readback_buf_, + 0, + static_cast(count) * kTimestampBytes); +} + +void 
WebGPUQueryPool::extract_results(WGPUInstance instance) { + if (num_pairs_ == 0) { + return; + } + const uint32_t count = 2 * num_pairs_; + const uint64_t bytes = static_cast(count) * kTimestampBytes; + + MapCallbackData cb; + WGPUBufferMapCallbackInfo cb_info = {}; + cb_info.mode = WGPUCallbackMode_WaitAnyOnly; + cb_info.callback = map_callback; + cb_info.userdata1 = &cb; + webgpu_wait( + instance, + wgpuBufferMapAsync(readback_buf_, WGPUMapMode_Read, 0, bytes, cb_info)); + + if (cb.status != WGPUMapAsyncStatus_Success) { + printf( + "WebGPUQueryPool: readback map failed (status %d)\n", (int)cb.status); + return; + } + const uint64_t* ticks = static_cast( + wgpuBufferGetConstMappedRange(readback_buf_, 0, bytes)); + if (ticks != nullptr) { + for (auto& d : durations_) { + const uint64_t t0 = ticks[2 * d.idx]; + const uint64_t t1 = ticks[2 * d.idx + 1]; + d.start_time_ns = static_cast(t0 * ns_per_tick_); + d.end_time_ns = static_cast(t1 * ns_per_tick_); + d.execution_duration_ns = + (t1 >= t0) ? static_cast((t1 - t0) * ns_per_tick_) : 0; + } + } + wgpuBufferUnmap(readback_buf_); +} + +void WebGPUQueryPool::print_results(bool tsv) const { + const char* sep = tsv ? "\t" : " "; + if (tsv) { + printf("idx%skernel%sgwg%sduration_us\n", sep, sep, sep); + } else { + printf("=== WebGPUQueryPool: per-dispatch GPU time ===\n"); + } + for (const auto& d : durations_) { + const double us = d.execution_duration_ns / 1000.0; + printf( + "%u%s%s%s(%u,%u,%u)%s%.3f\n", + d.idx, + sep, + d.kernel_name.empty() ? "dispatch" : d.kernel_name.c_str(), + sep, + d.global_wg[0], + d.global_wg[1], + d.global_wg[2], + sep, + us); + } + if (tsv) { + return; + } + std::map> totals; + for (const auto& d : durations_) { + auto& t = totals[d.kernel_name.empty() ? 
"dispatch" : d.kernel_name]; + t.first += d.execution_duration_ns; + t.second += 1; + } + printf("--- per-kernel mean / total (us) ---\n"); + for (const auto& kv : totals) { + const double mean_us = kv.second.first / kv.second.second / 1000.0; + const double total_us = kv.second.first / 1000.0; + printf( + "%s%smean %.3f%stotal %.3f (n=%u)\n", + kv.first.c_str(), + sep, + mean_us, + sep, + total_us, + kv.second.second); + } +} + +uint64_t WebGPUQueryPool::get_mean_shader_ns( + const std::string& kernel_name) const { + uint64_t sum = 0; + uint32_t n = 0; + for (const auto& d : durations_) { + if (d.kernel_name == kernel_name) { + sum += d.execution_duration_ns; + n += 1; + } + } + return n == 0 ? 0 : sum / n; +} + +#endif // WGPU_BACKEND_ENABLE_PROFILING + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/WebGPUQueryPool.h b/backends/webgpu/runtime/WebGPUQueryPool.h new file mode 100644 index 00000000000..9e5d6cb788c --- /dev/null +++ b/backends/webgpu/runtime/WebGPUQueryPool.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace executorch::backends::webgpu { + +#ifdef WGPU_BACKEND_ENABLE_PROFILING + +// Per-dispatch GPU timing; mirrors Vulkan QueryPool ShaderDuration. +struct ShaderDuration { + uint32_t idx = 0; + std::string kernel_name; + std::array global_wg{}; + std::array local_wg{}; + uint64_t start_time_ns = 0; + uint64_t end_time_ns = 0; + uint64_t execution_duration_ns = 0; +}; + +// GPU timestamp-query pool; re-port of Vulkan vk_api/QueryPool. 
+class WebGPUQueryPool { + public: + WebGPUQueryPool() = default; + ~WebGPUQueryPool(); + + WebGPUQueryPool(const WebGPUQueryPool&) = delete; + WebGPUQueryPool& operator=(const WebGPUQueryPool&) = delete; + + // Create the QuerySet + resolve/readback buffers (ns-per-tick stays 1.0). + void initialize(WGPUDevice device, uint32_t max_pairs); + bool is_initialized() const { + return qset_ != nullptr; + } + uint32_t capacity() const { + return capacity_pairs_; + } + + // Clear durations and set the dispatch count for this run. + void reset(uint32_t num_dispatches); + + // timestampWrites for pass i: begin=2i, end=2i+1. + WGPUPassTimestampWrites writes_for(uint32_t i); + + // Record pass i's label + workgroup sizes (start/end filled by extract). + void record( + uint32_t i, + const std::string& name, + std::array gwg, + std::array lwg); + + // Resolve the QuerySet into the readback buffer; call before submit. + void resolve(WGPUCommandEncoder encoder); + + // Map the readback, convert ticks->ns, fill durations; call after submit. 
+ void extract_results(WGPUInstance instance); + + const std::vector& results() const { + return durations_; + } + void print_results(bool tsv = false) const; + uint64_t get_mean_shader_ns(const std::string& kernel_name) const; + + private: + WGPUQuerySet qset_ = nullptr; + WGPUBuffer resolve_buf_ = nullptr; // QueryResolve | CopySrc + WGPUBuffer readback_buf_ = nullptr; // MapRead | CopyDst + uint32_t capacity_pairs_ = 0; + uint32_t num_pairs_ = 0; + double ns_per_tick_ = 1.0; // WebGPU timestamps are already nanoseconds + std::vector durations_; +}; + +#endif // WGPU_BACKEND_ENABLE_PROFILING + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 5b9d538223e..e62d6f2b53c 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -133,6 +133,131 @@ static bool test_chained_add(const std::string& model_path) { return true; } +#ifdef WGPU_BACKEND_ENABLE_PROFILING +// Capacity-overrun must throw; runs without a device or TimestampQuery. +static bool test_query_pool_overrun_throws() { + printf("\n--- Test: WebGPUQueryPool capacity-overrun guard ---\n"); + WebGPUQueryPool qp; + try { + qp.reset(1); + } catch (const std::exception&) { + printf("PASS: reset beyond capacity throws\n"); + return true; + } + printf("FAIL: reset beyond capacity did not throw\n"); + return false; +} + +// WebGPUQueryPool roundtrip: time a probe pass; assert non-zero GPU duration. +static bool test_query_pool_roundtrip(const WebGPUContext& ctx) { + printf("\n--- Test: WebGPUQueryPool roundtrip ---\n"); + if (!ctx.timestamp_supported) { + printf("SKIP: adapter lacks TimestampQuery feature\n"); + return true; + } + WGPUDevice device = ctx.device; + + // Probe loop iterates enough to burn a measurable, non-zero GPU duration. 
+ const char* kProbeWGSL = + "@group(0) @binding(0) var out: array;\n" + "@compute @workgroup_size(64)\n" + "fn main(@builtin(global_invocation_id) gid: vec3) {\n" + " var acc = 0.0;\n" + " for (var i = 0u; i < 8192u; i = i + 1u) {\n" + " acc = acc + f32(i) * 1.000001;\n" + " }\n" + " out[gid.x] = acc;\n" + "}\n"; + + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kProbeWGSL, WGPU_STRLEN}; + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + WGPUBindGroupLayoutEntry bgl_entry = {}; + bgl_entry.binding = 0; + bgl_entry.visibility = WGPUShaderStage_Compute; + bgl_entry.buffer.type = WGPUBufferBindingType_Storage; + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 1; + bgl_desc.entries = &bgl_entry; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUComputePipelineDescriptor pipe_desc = {}; + pipe_desc.layout = pl; + pipe_desc.compute.module = shader; + pipe_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + WGPUComputePipeline pipe = + wgpuDeviceCreateComputePipeline(device, &pipe_desc); + + WGPUBufferDescriptor obd = {}; + obd.size = 64 * sizeof(float); + obd.usage = WGPUBufferUsage_Storage; + WGPUBuffer out_buf = wgpuDeviceCreateBuffer(device, &obd); + + WGPUBindGroupEntry bg_entry = {}; + bg_entry.binding = 0; + bg_entry.buffer = out_buf; + bg_entry.size = obd.size; + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 1; + bg_desc.entries = &bg_entry; + WGPUBindGroup bg = wgpuDeviceCreateBindGroup(device, &bg_desc); + + WebGPUQueryPool qp; + qp.initialize(device, 1); + qp.reset(1); + + 
WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(device, nullptr); + WGPUPassTimestampWrites tw = qp.writes_for(0); + WGPUComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &tw; + WGPUComputePassEncoder pass = + wgpuCommandEncoderBeginComputePass(enc, &pass_desc); + wgpuComputePassEncoderSetPipeline(pass, pipe); + wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr); + wgpuComputePassEncoderDispatchWorkgroups(pass, 1, 1, 1); + wgpuComputePassEncoderEnd(pass); + wgpuComputePassEncoderRelease(pass); + qp.record(0, "probe", {1, 1, 1}, {64, 1, 1}); + qp.resolve(enc); + WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(enc, nullptr); + wgpuQueueSubmit(ctx.queue, 1, &cmd); + wgpuCommandBufferRelease(cmd); + wgpuCommandEncoderRelease(enc); + + qp.extract_results(ctx.instance); + + wgpuBufferRelease(out_buf); + wgpuComputePipelineRelease(pipe); + wgpuPipelineLayoutRelease(pl); + wgpuBindGroupLayoutRelease(bgl); + wgpuBindGroupRelease(bg); + wgpuShaderModuleRelease(shader); + + if (qp.results().size() != 1) { + printf("FAIL: expected 1 duration, got %zu\n", qp.results().size()); + return false; + } + const uint64_t dur = qp.results()[0].execution_duration_ns; + printf(" probe duration: %llu ns\n", (unsigned long long)dur); + if (dur == 0) { + printf("FAIL: probe duration is zero (expected monotonic non-zero)\n"); + return false; + } + printf("PASS: WebGPUQueryPool roundtrip -- non-zero GPU kernel duration\n"); + return true; +} +#endif // WGPU_BACKEND_ENABLE_PROFILING + int main(int argc, char** argv) { std::string model_path = "webgpu_add_test.pte"; if (argc > 1) { @@ -158,7 +283,12 @@ int main(int argc, char** argv) { set_default_webgpu_context(&ctx); printf("WebGPU device acquired (native)\n"); - bool ok = test_single_add(model_path); + bool ok = true; +#ifdef WGPU_BACKEND_ENABLE_PROFILING + ok = test_query_pool_overrun_throws() && ok; + ok = test_query_pool_roundtrip(ctx) && ok; +#endif // WGPU_BACKEND_ENABLE_PROFILING + ok = 
test_single_add(model_path) && ok; if (!chained_model_path.empty()) { ok = test_chained_add(chained_model_path) && ok; From 49c61608e7f495903ccd5315af8f2fdb01ca3a50 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Fri, 12 Jun 2026 15:35:52 -0700 Subject: [PATCH 2/2] [ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos Pull Request resolved: https://github.com/pytorch/executorch/pull/20086 Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32). `input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design. Tests live in the stacked test-suite diff above (clean op diff here). Authored with assistance from Claude. 
ghstack-source-id: 392609088 @exported-using-ghexport Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/) --- backends/webgpu/CMakeLists.txt | 1 + backends/webgpu/runtime/WebGPUGraph.h | 3 + backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 616 ++++++++++++++++++ .../ops/sdpa/sdpa_compute_attn_weights.wgsl | 55 ++ .../ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 79 +++ .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 46 ++ .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 70 ++ .../webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl | 101 +++ .../runtime/ops/sdpa/sdpa_softmax_wgsl.h | 125 ++++ 9 files changed, 1096 insertions(+) create mode 100644 backends/webgpu/runtime/ops/sdpa/Sdpa.cpp create mode 100644 backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl create mode 100644 backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h create mode 100644 backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl create mode 100644 backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h create mode 100644 backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl create mode 100644 backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 1fc0860fc4b..3393d3e35e6 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -35,6 +35,7 @@ set(WEBGPU_SRCS runtime/ops/add/BinaryOp.cpp runtime/ops/rms_norm/RmsNorm.cpp runtime/ops/update_cache/UpdateCache.cpp + runtime/ops/sdpa/Sdpa.cpp runtime/ops/select_as_symint/SelectAsSymint.cpp ) diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 92aa14d59b6..3cff09ecb6d 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -106,6 +106,9 @@ class WebGPUGraph { int64_t get_int(int id) const { return ints_[id]; } + bool get_bool(int id) const { + return bools_[id]; + } // Live-scalar (SymInt) API; mirrors the Vulkan 
SymInt/ParamsBuffer UBO. // set_symint writes the buffer + marks dirty only if the value changed. diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp new file mode 100644 index 00000000000..dd48f6f5902 --- /dev/null +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -0,0 +1,616 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Uniform param structs (all 16-byte aligned, matching the WGSL Params). +struct UpdateCacheParams { + uint32_t numel; + uint32_t dst_offset; + uint32_t cache_numel; + uint32_t _pad0; +}; +static_assert(sizeof(UpdateCacheParams) == 16, "UpdateCacheParams must be 16B"); + +struct AttnWeightsParams { + uint32_t S; + uint32_t Hq; + uint32_t Hkv; + uint32_t D; + uint32_t context_len; + uint32_t input_pos; + uint32_t g; + float scale; +}; +static_assert(sizeof(AttnWeightsParams) == 32, "AttnWeightsParams must be 32B"); + +struct SoftmaxParams { + uint32_t num_rows; + uint32_t row_width; + uint32_t _pad0; + uint32_t _pad1; +}; +static_assert(sizeof(SoftmaxParams) == 16, "SoftmaxParams must be 16B"); + +struct ComputeOutParams { + uint32_t S; + uint32_t Hq; + uint32_t Hkv; + uint32_t D; + uint32_t context_len; + uint32_t g; + uint32_t _pad0; + uint32_t _pad1; +}; +static_assert(sizeof(ComputeOutParams) == 32, "ComputeOutParams must be 32B"); + +// Param-struct builder helpers — used in both initial build and resize hook. 
+static UpdateCacheParams make_update_cache_params( + uint64_t kv_numel, + uint32_t dst_offset, + uint64_t cache_numel) { + UpdateCacheParams p = {}; + p.numel = static_cast(kv_numel); + p.dst_offset = dst_offset; + p.cache_numel = static_cast(cache_numel); + return p; +} + +static AttnWeightsParams make_attn_weights_params( + int64_t S, + int64_t Hq, + int64_t Hkv, + int64_t D, + int64_t ctx, + int64_t pos, + int64_t g, + float scale) { + AttnWeightsParams p = {}; + p.S = static_cast(S); + p.Hq = static_cast(Hq); + p.Hkv = static_cast(Hkv); + p.D = static_cast(D); + p.context_len = static_cast(ctx); + p.input_pos = static_cast(pos); + p.g = static_cast(g); + p.scale = scale; + return p; +} + +static SoftmaxParams make_softmax_params(int64_t Hq, int64_t S, int64_t ctx) { + SoftmaxParams p = {}; + p.num_rows = static_cast(Hq * S); + p.row_width = static_cast(ctx); + return p; +} + +static ComputeOutParams make_compute_out_params( + int64_t S, + int64_t Hq, + int64_t Hkv, + int64_t D, + int64_t ctx, + int64_t g) { + ComputeOutParams p = {}; + p.S = static_cast(S); + p.Hq = static_cast(Hq); + p.Hkv = static_cast(Hkv); + p.D = static_cast(D); + p.context_len = static_cast(ctx); + p.g = static_cast(g); + return p; +} + +// Create a uniform buffer initialized with the given bytes. +WGPUBuffer +make_uniform_buffer(WebGPUGraph& graph, const void* data, size_t size) { + WGPUDevice device = graph.device(); + WGPUBufferDescriptor desc = {}; + desc.size = size; + desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + desc.mappedAtCreation = true; + WGPUBuffer buffer = wgpuDeviceCreateBuffer(device, &desc); + void* mapped = wgpuBufferGetMappedRange(buffer, 0, size); + std::memcpy(mapped, data, size); + wgpuBufferUnmap(buffer); + graph.add_uniform_buffer_bytes(size); + return buffer; +} + +// A buffer + its byte size, for binding. 
+struct BufferBinding {
+  WGPUBuffer buffer;
+  uint64_t size;
+};
+
+// Build one dispatch (pipeline + bind group) and record it on the graph.
+// Bindings 0..n_storage-1 are storage buffers (index 0 read-write, the rest
+// read-only); binding n_storage is the uniform. A wg_size of 0 means the
+// shader has no `wg_size` override and no pipeline constant is passed.
+// When retain_uniform is true the graph takes ownership of uniform_buffer;
+// otherwise this function releases the caller's reference to it.
+// Throws std::runtime_error if n_storage + 1 exceeds the entry capacity.
+void build_dispatch(
+    WebGPUGraph& graph,
+    const char* wgsl_source,
+    const BufferBinding* storage_bindings,
+    uint32_t n_storage, // includes the rw output at index 0
+    WGPUBuffer uniform_buffer,
+    uint64_t uniform_size,
+    uint32_t workgroup_count_x,
+    uint32_t wg_size,
+    bool retain_uniform = false) {
+  // Validate before creating any GPU object so a throw cannot leak one; the
+  // `>=` form also cannot wrap the way `n_storage + 1` could.
+  constexpr uint32_t kMaxEntries = 8;
+  if (n_storage >= kMaxEntries) {
+    throw std::runtime_error("WebGPU sdpa: n_storage exceeds kMaxEntries");
+  }
+
+  WGPUDevice device = graph.device();
+
+  WGPUShaderSourceWGSL wgsl_desc = {};
+  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
+  wgsl_desc.code = {wgsl_source, WGPU_STRLEN};
+  WGPUShaderModuleDescriptor shader_desc = {};
+  shader_desc.nextInChain = &wgsl_desc.chain;
+  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
+
+  // Bind group layout: storage entries then the uniform.
+  WGPUBindGroupLayoutEntry bgl_entries[kMaxEntries] = {};
+  const uint32_t uniform_binding = n_storage;
+  for (uint32_t i = 0; i < n_storage; i++) {
+    bgl_entries[i].binding = i;
+    bgl_entries[i].visibility = WGPUShaderStage_Compute;
+    bgl_entries[i].buffer.type = (i == 0)
+        ? WGPUBufferBindingType_Storage
+        : WGPUBufferBindingType_ReadOnlyStorage;
+  }
+  bgl_entries[uniform_binding].binding = uniform_binding;
+  bgl_entries[uniform_binding].visibility = WGPUShaderStage_Compute;
+  bgl_entries[uniform_binding].buffer.type = WGPUBufferBindingType_Uniform;
+
+  WGPUBindGroupLayoutDescriptor bgl_desc = {};
+  bgl_desc.entryCount = n_storage + 1;
+  bgl_desc.entries = bgl_entries;
+  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
+
+  WGPUPipelineLayoutDescriptor pl_desc = {};
+  pl_desc.bindGroupLayoutCount = 1;
+  pl_desc.bindGroupLayouts = &bgl;
+  WGPUPipelineLayout pipeline_layout =
+      wgpuDeviceCreatePipelineLayout(device, &pl_desc);
+
+  // QK/AV/update_cache have an `override wg_size`; softmax (0) keeps a const.
+  WGPUConstantEntry wg_size_constant = {};
+  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
+  wg_size_constant.value = static_cast<double>(wg_size);
+
+  WGPUComputePipelineDescriptor pipeline_desc = {};
+  pipeline_desc.layout = pipeline_layout;
+  pipeline_desc.compute.module = shader;
+  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
+  if (wg_size != 0) {
+    pipeline_desc.compute.constantCount = 1;
+    pipeline_desc.compute.constants = &wg_size_constant;
+  }
+  WGPUComputePipeline pipeline =
+      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
+
+  WGPUBindGroupEntry bg_entries[kMaxEntries] = {};
+  for (uint32_t i = 0; i < n_storage; i++) {
+    bg_entries[i].binding = i;
+    bg_entries[i].buffer = storage_bindings[i].buffer;
+    bg_entries[i].size = storage_bindings[i].size;
+  }
+  bg_entries[uniform_binding].binding = uniform_binding;
+  bg_entries[uniform_binding].buffer = uniform_buffer;
+  bg_entries[uniform_binding].size = uniform_size;
+
+  WGPUBindGroupDescriptor bg_desc = {};
+  bg_desc.layout = bgl;
+  bg_desc.entryCount = n_storage + 1;
+  bg_desc.entries = bg_entries;
+  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
+
+  graph.add_dispatch({pipeline, bind_group, workgroup_count_x});
+
+  wgpuShaderModuleRelease(shader);
+  wgpuBindGroupLayoutRelease(bgl);
+  wgpuPipelineLayoutRelease(pipeline_layout);
+  if (retain_uniform) {
+    // Graph owns it so a resize hook can rewrite it; freed in the dtor.
+    graph.own_uniform_buffer(uniform_buffer);
+  } else {
+    // Drop our ref; the bind group keeps the uniform alive.
+    wgpuBufferRelease(uniform_buffer);
+  }
+}
+
+// Dispatch one update_cache (K or V); returns the uniform buffer handle.
+// The handle is retained by the graph only when dynamic_pos is true; otherwise
+// build_dispatch has already released our reference and it must not be used.
+static WGPUBuffer record_update_cache_dispatch(
+    WebGPUGraph& graph,
+    WGPUDevice device,
+    const WebGPUTensor& cache,
+    const WebGPUTensor& src,
+    uint64_t kv_numel,
+    uint32_t kv_dst_offset,
+    uint64_t cache_numel,
+    uint32_t uc_wg,
+    bool dynamic_pos,
+    const char* label) {
+  const uint32_t wgc = utils::compute_1d_workgroup_count(
+      device, static_cast<uint32_t>(kv_numel), uc_wg, label);
+  UpdateCacheParams uc =
+      make_update_cache_params(kv_numel, kv_dst_offset, cache_numel);
+  WGPUBuffer ubuf = make_uniform_buffer(graph, &uc, sizeof(uc));
+  BufferBinding bindings[2] = {
+      {cache.buffer, cache.nbytes}, {src.buffer, src.nbytes}};
+  build_dispatch(
+      graph,
+      kUpdateCacheWGSL,
+      bindings,
+      2,
+      ubuf,
+      sizeof(uc),
+      wgc,
+      uc_wg,
+      dynamic_pos);
+  return ubuf;
+}
+
+// llama.sdpa_with_kv_cache.default args mirror the Vulkan impl.
+// Validates shapes and arguments, then records five dispatches: update_cache
+// for K and V, QK -> attn_weights, row softmax, and AV -> out. When input_pos
+// is a SymInt, a resize hook rewrites the uniforms (and the QK workgroup
+// count) each step. Throws std::runtime_error on any unsupported input.
+void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
+  const int q_id = args.at(0);
+  const int k_id = args.at(1);
+  const int v_id = args.at(2);
+  const int k_cache_id = args.at(3);
+  const int v_cache_id = args.at(4);
+  const int input_pos_id = args.at(5);
+  // arg 6 (seq_len) is derived from q; args 7-9 validated below.
+  const int attn_mask_id = args.at(7);
+  const int drop_p_id = args.at(8);
+  const int is_causal_id = args.at(9);
+  const int scale_id = args.at(10);
+  const int out_id = args.at(11);
+
+  const auto& q = graph.get_tensor(q_id);
+  const auto& k = graph.get_tensor(k_id);
+  const auto& v = graph.get_tensor(v_id);
+  const auto& k_cache = graph.get_tensor(k_cache_id);
+  const auto& v_cache = graph.get_tensor(v_cache_id);
+  const auto& out = graph.get_tensor(out_id);
+
+  if (q.dims.size() < 3 || k.dims.size() < 3 || v.dims.size() < 3 ||
+      k_cache.dims.size() < 3) {
+    throw std::runtime_error("WebGPU sdpa: q/k/v/k_cache must be rank >= 3");
+  }
+
+  // q [1, S, Hq, D]; k/v [1, S, Hkv, D]; caches [1, Cmax, Hkv, D].
+  const size_t qn = q.dims.size();
+  const int64_t S = q.dims[qn - 3];
+  const int64_t Hq = q.dims[qn - 2];
+  const int64_t D = q.dims[qn - 1];
+
+  const size_t kn = k.dims.size();
+  const int64_t Hkv = k.dims[kn - 2];
+
+  const size_t cn = k_cache.dims.size();
+  const int64_t Cmax = k_cache.dims[cn - 3];
+
+  // Validate B == 1 (leading dims must all be 1).
+  for (size_t i = 0; i + 3 < qn; i++) {
+    if (q.dims[i] != 1) {
+      throw std::runtime_error("WebGPU sdpa: only batch size 1 is supported");
+    }
+  }
+  if (S <= 0 || Hq <= 0 || D <= 0 || Hkv <= 0 || Cmax <= 0) {
+    throw std::runtime_error("WebGPU sdpa: non-positive dimension");
+  }
+  if (Hq % Hkv != 0) {
+    throw std::runtime_error("WebGPU sdpa: Hq must be a multiple of Hkv (GQA)");
+  }
+  const int64_t g = Hq / Hkv;
+
+  // k/v seq-len must match q's S.
+  if (k.dims[kn - 3] != S || v.dims[v.dims.size() - 3] != S) {
+    throw std::runtime_error("WebGPU sdpa: k/v seq_len must match q");
+  }
+
+  // k/v projected shapes must match q/k; mirrors Vulkan update_cache -1/-2.
+  if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) {
+    throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q");
+  }
+  if (v.dims[v.dims.size() - 2] != Hkv) {
+    throw std::runtime_error("WebGPU sdpa: v num_heads must match k");
+  }
+
+  // Mirrors Vulkan SDPA: q/k_cache head_dim + k_cache/v_cache shape must match.
+  if (D != k_cache.dims[cn - 1]) {
+    throw std::runtime_error("WebGPU sdpa: q and k_cache head_dim mismatch");
+  }
+  if (k_cache.dims != v_cache.dims) {
+    throw std::runtime_error("WebGPU sdpa: k_cache and v_cache shape mismatch");
+  }
+  // The shaders and the resize hook stride the caches by Hkv * D, so the
+  // cache head count must equal k's; otherwise rows would be mis-addressed.
+  if (k_cache.dims[cn - 2] != Hkv) {
+    throw std::runtime_error("WebGPU sdpa: k_cache num_heads must match k");
+  }
+
+  // fp32-only: validate byte counts against fp32 element counts.
+  auto numel = [](const WebGPUTensor& t) {
+    uint64_t n = 1;
+    for (int64_t d : t.dims) {
+      n *= static_cast<uint64_t>(d);
+    }
+    return n;
+  };
+  if (q.nbytes != numel(q) * sizeof(float) ||
+      k.nbytes != numel(k) * sizeof(float) ||
+      v.nbytes != numel(v) * sizeof(float) ||
+      out.nbytes != numel(out) * sizeof(float)) {
+    throw std::runtime_error("WebGPU sdpa: fp32-only (byte-size mismatch)");
+  }
+  // The AV shader writes S*Hq*D floats into out, which is exactly q's numel.
+  if (numel(out) != numel(q)) {
+    throw std::runtime_error("WebGPU sdpa: out numel must match q");
+  }
+
+  // input_pos: build-time Int (baked) OR runtime SymInt (dynamic decode).
+  int64_t input_pos = 0;
+  const auto input_pos_type = graph.get_value_type(input_pos_id);
+  const bool dynamic_pos = input_pos_type == WebGPUGraph::ValueType::SymInt;
+  if (dynamic_pos) {
+    input_pos = graph.read_symint(input_pos_id); // build placeholder (e.g. 0)
+  } else if (input_pos_type == WebGPUGraph::ValueType::Int) {
+    input_pos = graph.get_int(input_pos_id);
+  } else {
+    // No silent default-to-0; mirrors Vulkan get_or_create_int_param_buffer.
+    throw std::runtime_error("WebGPU sdpa: input_pos must be Int or SymInt");
+  }
+  if (input_pos < 0) {
+    throw std::runtime_error("WebGPU sdpa: input_pos must be non-negative");
+  }
+  const int64_t context_len = S + input_pos;
+  if (context_len <= 0 || context_len > Cmax) {
+    throw std::runtime_error("WebGPU sdpa: context_len exceeds cache capacity");
+  }
+
+  // scale arg is None (use 1/sqrt(D)) or an explicit Double; reject others.
+  float scale = 1.0f / std::sqrt(static_cast<float>(D));
+  const auto scale_type = graph.get_value_type(scale_id);
+  if (scale_type == WebGPUGraph::ValueType::Double) {
+    scale = static_cast<float>(graph.get_double(scale_id));
+  } else if (scale_type != WebGPUGraph::ValueType::Null) {
+    throw std::runtime_error("WebGPU sdpa: scale must be None or a Double");
+  }
+
+  // Unsupported attention args must be absent/default; mirrors Vulkan
+  // SDPA.cpp:587-593 (scale is handled above as an intentional extension).
+  using VT = WebGPUGraph::ValueType;
+  if (graph.get_value_type(attn_mask_id) != VT::Null) {
+    throw std::runtime_error("WebGPU sdpa: attn_mask is not supported");
+  }
+  // dropout_p: serializer may dedup 0.0 onto input_pos's Int(0) when pos=0.
+  const auto drop_type = graph.get_value_type(drop_p_id);
+  if (!(drop_type == VT::Null ||
+        (drop_type == VT::Double && graph.get_double(drop_p_id) == 0.0) ||
+        (drop_type == VT::Int && graph.get_int(drop_p_id) == 0))) {
+    throw std::runtime_error("WebGPU sdpa: only dropout_p=0 is supported");
+  }
+  const auto causal_type = graph.get_value_type(is_causal_id);
+  if (!(causal_type == VT::Null ||
+        (causal_type == VT::Bool && graph.get_bool(is_causal_id)))) {
+    throw std::runtime_error("WebGPU sdpa: only is_causal=true is supported");
+  }
+
+  // KV cache written in place; only attn_weights/softmax need scratch.
+  const uint64_t aw_floats = static_cast<uint64_t>(Hq) *
+      static_cast<uint64_t>(S) * static_cast<uint64_t>(context_len);
+  // Dynamic input_pos: size+bind scratch for Cmax (no realloc; covers any ctx).
+  const uint64_t aw_cap_floats = static_cast<uint64_t>(Hq) *
+      static_cast<uint64_t>(S) *
+      static_cast<uint64_t>(dynamic_pos ? Cmax : context_len);
+  const uint64_t aw_bytes = aw_cap_floats * sizeof(float);
+  // Prefill scratch scales as Hq·S·Cmax; can be large for long-context prefill.
+  WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes);
+  WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes);
+
+  // Dynamic input_pos: the resize hook rewrites these per step.
+  WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr,
+             softmax_buf = nullptr, av_buf = nullptr;
+  size_t qk_idx = 0;
+
+  const WGPUDevice device = graph.device();
+  const uint32_t uc_wg =
+      utils::clamp_workgroup_size(device, kUpdateCacheWorkgroupSizeX);
+  const uint32_t qk_wg = utils::clamp_workgroup_size(
+      device, kSdpaComputeAttnWeightsWorkgroupSizeX);
+  const uint32_t av_wg =
+      utils::clamp_workgroup_size(device, kSdpaComputeOutWorkgroupSizeX);
+
+  // The cache numel is narrowed to 32 bits for the update_cache uniform.
+  // Bounding it also bounds kv_numel and kv_dst_offset, because
+  // S + input_pos <= Cmax and the cache shares Hkv and D with k.
+  if (numel(k_cache) > UINT32_MAX) {
+    throw std::runtime_error("WebGPU sdpa: cache numel exceeds uint32 max");
+  }
+
+  // Dispatches 1-2: write new K/V into the caches (reuses update_cache).
+  const uint64_t kv_numel = static_cast<uint64_t>(S) *
+      static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
+  const uint32_t kv_dst_offset = static_cast<uint32_t>(
+      static_cast<uint64_t>(input_pos) * static_cast<uint64_t>(Hkv) *
+      static_cast<uint64_t>(D));
+  uc_k_buf = record_update_cache_dispatch(
+      graph,
+      device,
+      k_cache,
+      k,
+      kv_numel,
+      kv_dst_offset,
+      numel(k_cache),
+      uc_wg,
+      dynamic_pos,
+      "update_cache(K)");
+  uc_v_buf = record_update_cache_dispatch(
+      graph,
+      device,
+      v_cache,
+      v,
+      kv_numel,
+      kv_dst_offset,
+      numel(v_cache),
+      uc_wg,
+      dynamic_pos,
+      "update_cache(V)");
+
+  // --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element.
+  {
+    if (aw_floats > UINT32_MAX) {
+      throw std::runtime_error(
+          "WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
+    }
+    const uint32_t wgc = utils::compute_1d_workgroup_count(
+        device, static_cast<uint32_t>(aw_floats), qk_wg, "QK");
+    AttnWeightsParams p = make_attn_weights_params(
+        S, Hq, Hkv, D, context_len, input_pos, g, scale);
+    WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
+    BufferBinding bindings[3] = {
+        {attn_weights, aw_bytes},
+        {q.buffer, q.nbytes},
+        {k_cache.buffer, k_cache.nbytes}};
+    build_dispatch(
+        graph,
+        kSdpaComputeAttnWeightsWGSL,
+        bindings,
+        3,
+        ubuf,
+        sizeof(p),
+        wgc,
+        qk_wg,
+        dynamic_pos);
+    qk_buf = ubuf;
+    qk_idx = graph.num_dispatches() - 1;
+  }
+
+  // Dispatch 4: softmax, one workgroup per (h,s) row of width context_len.
+  {
+    // One workgroup per (h,s) row; wg_size 1 keeps the device dispatch check.
+    // Hq * S fits in 32 bits here because aw_floats >= Hq * S was bounded above.
+    const uint32_t wgc = utils::compute_1d_workgroup_count(
+        device, static_cast<uint32_t>(Hq * S), 1, "softmax");
+    SoftmaxParams p = make_softmax_params(Hq, S, context_len);
+    WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
+    BufferBinding bindings[2] = {
+        {attn_weights_softmax, aw_bytes}, {attn_weights, aw_bytes}};
+    build_dispatch(
+        graph,
+        kSdpaSoftmaxWGSL,
+        bindings,
+        2,
+        ubuf,
+        sizeof(p),
+        wgc,
+        0,
+        dynamic_pos);
+    softmax_buf = ubuf;
+  }
+
+  // --- Dispatch 5: AV -> out. One thread per (s,h,d) output element.
+  {
+    const uint64_t out_floats = static_cast<uint64_t>(S) *
+        static_cast<uint64_t>(Hq) * static_cast<uint64_t>(D);
+    // Same guard as QK: the thread count is narrowed to 32 bits below.
+    if (out_floats > UINT32_MAX) {
+      throw std::runtime_error("WebGPU sdpa: S*Hq*D exceeds uint32 max");
+    }
+    const uint32_t wgc = utils::compute_1d_workgroup_count(
+        device, static_cast<uint32_t>(out_floats), av_wg, "AV");
+    ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
+    WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
+    BufferBinding bindings[3] = {
+        {out.buffer, out.nbytes},
+        {attn_weights_softmax, aw_bytes},
+        {v_cache.buffer, v_cache.nbytes}};
+    build_dispatch(
+        graph,
+        kSdpaComputeOutWGSL,
+        bindings,
+        3,
+        ubuf,
+        sizeof(p),
+        wgc,
+        av_wg,
+        dynamic_pos);
+    av_buf = ubuf;
+  }
+
+  // Per-step recompute hook; mirrors Vulkan DynamicDispatchNode.
+  // Everything is captured by value: the hook outlives this function.
+  if (dynamic_pos) {
+    graph.add_resize_hook(
+        input_pos_id,
+        [input_pos_id,
+         S,
+         Hq,
+         Hkv,
+         D,
+         Cmax,
+         g,
+         scale,
+         qk_idx,
+         qk_wg,
+         uc_k_buf,
+         uc_v_buf,
+         qk_buf,
+         softmax_buf,
+         av_buf](WebGPUGraph& gr) {
+          const int32_t pos = gr.read_symint(input_pos_id);
+          if (pos < 0) {
+            throw std::runtime_error(
+                "WebGPU sdpa: input_pos must be non-negative");
+          }
+          const int64_t ctx = S + pos;
+          if (ctx <= 0 || ctx > Cmax) {
+            throw std::runtime_error(
+                "WebGPU sdpa: context_len exceeds cache capacity");
+          }
+          const uint32_t kv_off = static_cast<uint32_t>(
+              static_cast<uint64_t>(pos) * static_cast<uint64_t>(Hkv) *
+              static_cast<uint64_t>(D));
+          const uint64_t aw_floats = static_cast<uint64_t>(Hq) *
+              static_cast<uint64_t>(S) * static_cast<uint64_t>(ctx);
+          if (aw_floats > UINT32_MAX) {
+            throw std::runtime_error(
+                "WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
+          }
+          const uint64_t kv_numel = static_cast<uint64_t>(S) *
+              static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
+          const uint64_t k_cache_numel = static_cast<uint64_t>(Cmax) *
+              static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
+
+          UpdateCacheParams uc =
+              make_update_cache_params(kv_numel, kv_off, k_cache_numel);
+          wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc));
+          wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc));
+
+          AttnWeightsParams qp =
+              make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
+          wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
+          // Only QK's thread count depends on ctx; the other dispatches keep
+          // their build-time workgroup counts.
+          const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
+              gr.device(),
+              static_cast<uint32_t>(aw_floats),
+              qk_wg,
+              "QK(resize)");
+          gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;
+
+          SoftmaxParams sp = make_softmax_params(Hq, S, ctx);
+          wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp));
+
+          ComputeOutParams op = make_compute_out_params(S, Hq, Hkv, D, ctx, g);
+          wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op));
+        });
+  }
+}
+
+} // namespace
+
+WEBGPU_REGISTER_OPERATORS {
+  WEBGPU_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+}
+
+} // namespace executorch::backends::webgpu
diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl
new file mode 100644
index 00000000000..b9905a59376
--- /dev/null
+++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl
@@ -0,0 +1,55 @@
+@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
+@group(0) @binding(1) var<storage, read> t_q: array<f32>;
+@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
+
+struct Params {
+  S: u32,
+  Hq: u32,
+  Hkv: u32,
+  D: u32,
+  context_len: u32,
+  input_pos: u32,
+  g: u32,
+  scale: f32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+// WGSL forbids literal -inf; large finite negative is a WGSL-safe stand-in.
+const NEG_INF: f32 = -1.0e30;
+
+override wg_size: u32 = 64;
+
+@compute @workgroup_size(wg_size, 1, 1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let total = params.Hq * params.S * params.context_len;
+  let idx = gid.x;
+  if (idx >= total) {
+    return;
+  }
+  let c = idx % params.context_len;
+  let s = (idx / params.context_len) % params.S;
+  let h = idx / (params.context_len * params.S);
+
+  let kvh = h / params.g;
+
+  let q_base = s * params.Hq * params.D + h * params.D;
+  let k_base = c * params.Hkv * params.D + kvh * params.D;
+
+  var acc: f32 = 0.0;
+  var d: u32 = 0u;
+  loop {
+    if (d >= params.D) {
+      break;
+    }
+    acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
+    d = d + 1u;
+  }
+  acc = acc * params.scale;
+
+  // Causal mask: position c may not attend beyond s + input_pos.
+  if (c > s + params.input_pos) {
+    acc = NEG_INF;
+  }
+
+  t_attn_weights[idx] = acc;
+}
diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h
new file mode 100644
index 00000000000..3f3f3d6b085
--- /dev/null
+++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace executorch::backends::webgpu {
+
+// @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT.
+// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1
+inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"(
+@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
+@group(0) @binding(1) var<storage, read> t_q: array<f32>;
+@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
+
+struct Params {
+  S: u32,
+  Hq: u32,
+  Hkv: u32,
+  D: u32,
+  context_len: u32,
+  input_pos: u32,
+  g: u32,
+  scale: f32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+// WGSL forbids literal -inf; large finite negative is a WGSL-safe stand-in.
+const NEG_INF: f32 = -1.0e30;
+
+override wg_size: u32 = 64;
+
+@compute @workgroup_size(wg_size, 1, 1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let total = params.Hq * params.S * params.context_len;
+  let idx = gid.x;
+  if (idx >= total) {
+    return;
+  }
+  let c = idx % params.context_len;
+  let s = (idx / params.context_len) % params.S;
+  let h = idx / (params.context_len * params.S);
+
+  let kvh = h / params.g;
+
+  let q_base = s * params.Hq * params.D + h * params.D;
+  let k_base = c * params.Hkv * params.D + kvh * params.D;
+
+  var acc: f32 = 0.0;
+  var d: u32 = 0u;
+  loop {
+    if (d >= params.D) {
+      break;
+    }
+    acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
+    d = d + 1u;
+  }
+  acc = acc * params.scale;
+
+  // Causal mask: position c may not attend beyond s + input_pos.
+  if (c > s + params.input_pos) {
+    acc = NEG_INF;
+  }
+
+  t_attn_weights[idx] = acc;
+}
+)";
+
+inline constexpr uint32_t kSdpaComputeAttnWeightsWorkgroupSizeX = 64;
+inline constexpr uint32_t kSdpaComputeAttnWeightsWorkgroupSizeY = 1;
+inline constexpr uint32_t kSdpaComputeAttnWeightsWorkgroupSizeZ = 1;
+
+} // namespace executorch::backends::webgpu
diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl
new file mode 100644
index 00000000000..97642670f60
--- /dev/null
+++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl
@@ -0,0 +1,46 @@
+@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
+@group(0) @binding(2) var<storage, read> t_v_cache: array<f32>;
+
+struct Params {
+  S: u32,
+  Hq: u32,
+  Hkv: u32,
+  D: u32,
+  context_len: u32,
+  g: u32,
+  _pad0: u32,
+  _pad1: u32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+override wg_size: u32 = 64;
+
+@compute @workgroup_size(wg_size, 1, 1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let total = params.S * params.Hq * params.D;
+  let idx = gid.x;
+  if (idx >= total) {
+    return;
+  }
+  let d = idx % params.D;
+  let h = (idx / params.D) % params.Hq;
+  let s = idx / (params.D * params.Hq);
+
+  let kvh = h / params.g;
+
+  let aw_base = h * params.S * params.context_len + s * params.context_len;
+
+  var acc: f32 = 0.0;
+  var c: u32 = 0u;
+  loop {
+    if (c >= params.context_len) {
+      break;
+    }
+    let v_off = c * params.Hkv * params.D + kvh * params.D + d;
+    acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off];
+    c = c + 1u;
+  }
+
+  t_out[idx] = acc;
+}
diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h
new file mode 100644
index 00000000000..ce25df06876
--- /dev/null
+++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace executorch::backends::webgpu {
+
+// @generated from sdpa_compute_out.wgsl - DO NOT EDIT.
+// wgsl-sha256: 67b9c64fbffdcb72264dda42e24b59e414719411c64c504f84f2ba57b5dcfc0f
+inline constexpr const char* kSdpaComputeOutWGSL = R"(
+@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
+@group(0) @binding(2) var<storage, read> t_v_cache: array<f32>;
+
+struct Params {
+  S: u32,
+  Hq: u32,
+  Hkv: u32,
+  D: u32,
+  context_len: u32,
+  g: u32,
+  _pad0: u32,
+  _pad1: u32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+override wg_size: u32 = 64;
+
+@compute @workgroup_size(wg_size, 1, 1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let total = params.S * params.Hq * params.D;
+  let idx = gid.x;
+  if (idx >= total) {
+    return;
+  }
+  let d = idx % params.D;
+  let h = (idx / params.D) % params.Hq;
+  let s = idx / (params.D * params.Hq);
+
+  let kvh = h / params.g;
+
+  let aw_base = h * params.S * params.context_len + s * params.context_len;
+
+  var acc: f32 = 0.0;
+  var c: u32 = 0u;
+  loop {
+    if (c >= params.context_len) {
+      break;
+    }
+    let v_off = c * params.Hkv * params.D + kvh * params.D + d;
+    acc = acc + t_attn_weights_softmax[aw_base + c] * t_v_cache[v_off];
+    c = c + 1u;
+  }
+
+  t_out[idx] = acc;
+}
+)";
+
+inline constexpr uint32_t kSdpaComputeOutWorkgroupSizeX = 64;
+inline constexpr uint32_t kSdpaComputeOutWorkgroupSizeY = 1;
+inline constexpr uint32_t kSdpaComputeOutWorkgroupSizeZ = 1;
+
+} // namespace executorch::backends::webgpu
diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl
new file mode 100644
index 00000000000..6ef223c3a98
--- /dev/null
+++ b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl
@@ -0,0 +1,101 @@
+@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(1) var<storage, read> t_in: array<f32>;
+
+struct Params {
+  num_rows: u32,
+  row_width: u32,
+  _pad0: u32,
+  _pad1: u32,
+}
+@group(0) @binding(2) var<uniform> params: Params;
+
+const WG_SIZE: u32 = 64u;
+
+// WGSL forbids literal -inf; a large finite negative inits the running max.
+const NEG_INF: f32 = -1.0e30;
+
+var<workgroup> shared_max: array<f32, WG_SIZE>;
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE, 1, 1)
+fn main(
+    @builtin(workgroup_id) wid: vec3<u32>,
+    @builtin(local_invocation_id) lid: vec3<u32>) {
+  // One workgroup per (h, s) row of length context_len (= row_width).
+  let row_idx = wid.x;
+  let worker_id = lid.x;
+
+  let base = row_idx * params.row_width;
+  let valid = row_idx < params.num_rows;
+  let width = params.row_width;
+
+  // Pass 1: row max (stable softmax). Threads stride over the row.
+  var local_max: f32 = NEG_INF;
+  if (valid) {
+    var x: u32 = worker_id;
+    loop {
+      if (x >= width) {
+        break;
+      }
+      local_max = max(local_max, t_in[base + x]);
+      x = x + WG_SIZE;
+    }
+  }
+  shared_max[worker_id] = local_max;
+
+  // Reduce max. workgroupBarrier() calls are in uniform control flow.
+  workgroupBarrier();
+  var stride: u32 = WG_SIZE / 2u;
+  loop {
+    if (stride == 0u) {
+      break;
+    }
+    if (worker_id < stride) {
+      shared_max[worker_id] = max(shared_max[worker_id], shared_max[worker_id + stride]);
+    }
+    workgroupBarrier();
+    stride = stride >> 1u;
+  }
+  let row_max = shared_max[0];
+
+  // Pass 2: sum of exp(x - max).
+  var local_sum: f32 = 0.0;
+  if (valid) {
+    var x: u32 = worker_id;
+    loop {
+      if (x >= width) {
+        break;
+      }
+      local_sum = local_sum + exp(t_in[base + x] - row_max);
+      x = x + WG_SIZE;
+    }
+  }
+  shared_sum[worker_id] = local_sum;
+
+  workgroupBarrier();
+  stride = WG_SIZE / 2u;
+  loop {
+    if (stride == 0u) {
+      break;
+    }
+    if (worker_id < stride) {
+      shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
+    }
+    workgroupBarrier();
+    stride = stride >> 1u;
+  }
+  let row_sum = shared_sum[0];
+
+  // Pass 3: normalize. Guard division by zero defensively.
+  if (valid) {
+    let inv = select(0.0, 1.0 / row_sum, row_sum > 0.0);
+    var x: u32 = worker_id;
+    loop {
+      if (x >= width) {
+        break;
+      }
+      t_out[base + x] = exp(t_in[base + x] - row_max) * inv;
+      x = x + WG_SIZE;
+    }
+  }
+}
diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h
new file mode 100644
index 00000000000..94f0ab5790a
--- /dev/null
+++ b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace executorch::backends::webgpu {
+
+// @generated from sdpa_softmax.wgsl - DO NOT EDIT.
+// wgsl-sha256: e2714ec4c2400b37f6fd39c410075c519effc0273354a4f906fb924334809024
+inline constexpr const char* kSdpaSoftmaxWGSL = R"(
+@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(1) var<storage, read> t_in: array<f32>;
+
+struct Params {
+  num_rows: u32,
+  row_width: u32,
+  _pad0: u32,
+  _pad1: u32,
+}
+@group(0) @binding(2) var<uniform> params: Params;
+
+const WG_SIZE: u32 = 64u;
+
+// WGSL forbids literal -inf; a large finite negative inits the running max.
+const NEG_INF: f32 = -1.0e30;
+
+var<workgroup> shared_max: array<f32, WG_SIZE>;
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE, 1, 1)
+fn main(
+    @builtin(workgroup_id) wid: vec3<u32>,
+    @builtin(local_invocation_id) lid: vec3<u32>) {
+  // One workgroup per (h, s) row of length context_len (= row_width).
+  let row_idx = wid.x;
+  let worker_id = lid.x;
+
+  let base = row_idx * params.row_width;
+  let valid = row_idx < params.num_rows;
+  let width = params.row_width;
+
+  // Pass 1: row max (stable softmax). Threads stride over the row.
+  var local_max: f32 = NEG_INF;
+  if (valid) {
+    var x: u32 = worker_id;
+    loop {
+      if (x >= width) {
+        break;
+      }
+      local_max = max(local_max, t_in[base + x]);
+      x = x + WG_SIZE;
+    }
+  }
+  shared_max[worker_id] = local_max;
+
+  // Reduce max. workgroupBarrier() calls are in uniform control flow.
+  workgroupBarrier();
+  var stride: u32 = WG_SIZE / 2u;
+  loop {
+    if (stride == 0u) {
+      break;
+    }
+    if (worker_id < stride) {
+      shared_max[worker_id] = max(shared_max[worker_id], shared_max[worker_id + stride]);
+    }
+    workgroupBarrier();
+    stride = stride >> 1u;
+  }
+  let row_max = shared_max[0];
+
+  // Pass 2: sum of exp(x - max).
+  var local_sum: f32 = 0.0;
+  if (valid) {
+    var x: u32 = worker_id;
+    loop {
+      if (x >= width) {
+        break;
+      }
+      local_sum = local_sum + exp(t_in[base + x] - row_max);
+      x = x + WG_SIZE;
+    }
+  }
+  shared_sum[worker_id] = local_sum;
+
+  workgroupBarrier();
+  stride = WG_SIZE / 2u;
+  loop {
+    if (stride == 0u) {
+      break;
+    }
+    if (worker_id < stride) {
+      shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
+    }
+    workgroupBarrier();
+    stride = stride >> 1u;
+  }
+  let row_sum = shared_sum[0];
+
+  // Pass 3: normalize. Guard division by zero defensively.
+  if (valid) {
+    let inv = select(0.0, 1.0 / row_sum, row_sum > 0.0);
+    var x: u32 = worker_id;
+    loop {
+      if (x >= width) {
+        break;
+      }
+      t_out[base + x] = exp(t_in[base + x] - row_max) * inv;
+      x = x + WG_SIZE;
+    }
+  }
+}
+)";
+
+inline constexpr uint32_t kSdpaSoftmaxWorkgroupSizeX = 64;
+inline constexpr uint32_t kSdpaSoftmaxWorkgroupSizeY = 1;
+inline constexpr uint32_t kSdpaSoftmaxWorkgroupSizeZ = 1;
+
+} // namespace executorch::backends::webgpu