Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(WEBGPU_SRCS
runtime/ops/update_cache/UpdateCache.cpp
runtime/ops/sdpa/Sdpa.cpp
runtime/ops/select_as_symint/SelectAsSymint.cpp
runtime/ops/quantized_linear/QuantizedLinear.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
230 changes: 230 additions & 0 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <cstring>
#include <stdexcept>

namespace executorch::backends::webgpu {

namespace {

// Uniform layout matching the WGSL Params struct (16-byte aligned, 32 bytes).
struct Q4gswParams {
uint32_t M;
uint32_t N;
uint32_t K;
uint32_t K_packed;
uint32_t group_size;
uint32_t padded_N;
uint32_t has_bias;
uint32_t _pad;
};
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");

// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int in_id = args.at(0);
const int weight_id = args.at(1);
const int scales_id = args.at(2);
const int group_size_id = args.at(3);
const int bias_id = args.at(4);
const int out_id = args.at(5);

WGPUDevice device = graph.device();

const auto& in = graph.get_tensor(in_id);
const auto& weight = graph.get_tensor(weight_id);
const auto& scales = graph.get_tensor(scales_id);
const auto& out = graph.get_tensor(out_id);

if (in.dims.empty() || weight.dims.size() < 2 || scales.dims.size() < 2) {
throw std::runtime_error("WebGPU linear_q4gsw: malformed input dims");
}

// Shapes from the tensors' own dims (no dtype field at runtime).
const uint32_t K = static_cast<uint32_t>(in.dims.back());
if (K == 0) {
throw std::runtime_error("WebGPU linear_q4gsw: K == 0");
}
uint64_t in_numel = 1;
for (int64_t d : in.dims) {
in_numel *= static_cast<uint64_t>(d);
}
const uint32_t M = static_cast<uint32_t>(in_numel / K);
const uint32_t N = static_cast<uint32_t>(weight.dims[0]);
const uint32_t K_packed = static_cast<uint32_t>(weight.dims[1]);
const uint32_t num_groups = static_cast<uint32_t>(scales.dims[0]);
const uint32_t padded_N = static_cast<uint32_t>(scales.dims[1]);
if (M == 0 || N == 0) {
throw std::runtime_error("WebGPU linear_q4gsw: M or N == 0");
}
// int4 packing is 2 nibbles/byte, so K_packed must be ceil(K/2) (guards OOB).
if (K_packed != (K + 1) / 2) {
throw std::runtime_error("WebGPU linear_q4gsw: K_packed must be ceil(K/2)");
}

// One workgroup per output row (M); validate dispatch before any alloc.
const uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");

// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
const uint64_t scales_numel =
static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N);
const uint64_t weight_numel =
static_cast<uint64_t>(N) * static_cast<uint64_t>(K_packed);
if (in.nbytes != in_numel * sizeof(float) ||
out.nbytes != static_cast<uint64_t>(M) * N * sizeof(float) ||
scales.nbytes != scales_numel * sizeof(float) ||
weight.nbytes != weight_numel) {
throw std::runtime_error(
"WebGPU linear_q4gsw: fp32-only (byte-size mismatch)");
}

int64_t group_size = 0;
if (graph.get_value_type(group_size_id) == WebGPUGraph::ValueType::Int) {
group_size = graph.get_int(group_size_id);
}
if (group_size <= 0) {
throw std::runtime_error("WebGPU linear_q4gsw: group_size <= 0");
}

// Optional bias: real buffer if present, else a dummy for the fixed layout.
uint32_t has_bias = 0;
WGPUBuffer bias_buffer = nullptr;
uint64_t bias_size = 4;
if (graph.get_value_type(bias_id) == WebGPUGraph::ValueType::Tensor) {
const auto& bias = graph.get_tensor(bias_id);
if (bias.buffer == nullptr || bias.nbytes < N * sizeof(float)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: bias present but null/undersized");
}
has_bias = 1;
bias_buffer = bias.buffer;
bias_size = bias.nbytes;
}
if (bias_buffer == nullptr) {
bias_buffer = graph.create_scratch_buffer(4);
bias_size = 4;
}

Q4gswParams params = {};
params.M = M;
params.N = N;
params.K = K;
params.K_packed = K_packed;
params.group_size = static_cast<uint32_t>(group_size);
params.padded_N = padded_N;
params.has_bias = has_bias;

WGPUBufferDescriptor uniform_desc = {};
uniform_desc.size = sizeof(Q4gswParams);
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
uniform_desc.mappedAtCreation = true;
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
void* mapped =
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(Q4gswParams));
std::memcpy(mapped, &params, sizeof(Q4gswParams));
wgpuBufferUnmap(uniform_buffer);
graph.add_uniform_buffer_bytes(sizeof(Q4gswParams));

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// Bind group layout: out (rw) + in/weight/scales/bias (ro storage) + uniform.
WGPUBindGroupLayoutEntry entries[6] = {};
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_Storage;
for (uint32_t i = 1; i <= 4; i++) {
entries[i].binding = i;
entries[i].visibility = WGPUShaderStage_Compute;
entries[i].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
}
entries[5].binding = 5;
entries[5].visibility = WGPUShaderStage_Compute;
entries[5].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 6;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[6] = {};
bg_entries[0].binding = 0;
bg_entries[0].buffer = out.buffer;
bg_entries[0].size = out.nbytes;
bg_entries[1].binding = 1;
bg_entries[1].buffer = in.buffer;
bg_entries[1].size = in.nbytes;
bg_entries[2].binding = 2;
bg_entries[2].buffer = weight.buffer;
bg_entries[2].size = weight.nbytes;
bg_entries[3].binding = 3;
bg_entries[3].buffer = scales.buffer;
bg_entries[3].size = scales.nbytes;
bg_entries[4].binding = 4;
bg_entries[4].buffer = bias_buffer;
bg_entries[4].size = bias_size;
bg_entries[5].binding = 5;
bg_entries[5].buffer = uniform_buffer;
bg_entries[5].size = sizeof(Q4gswParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 6;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
wgpuBufferRelease(uniform_buffer);
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.linear_q4gsw.default, q4gsw_linear_impl);
}

} // namespace executorch::backends::webgpu
64 changes: 64 additions & 0 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;

struct Params {
M: u32,
N: u32,
K: u32,
K_packed: u32,
group_size: u32,
padded_N: u32,
has_bias: u32,
_pad: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
return;
}
let in_base = m * params.K;

var n: u32 = lid.x;
loop {
if (n >= params.N) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
88 changes: 88 additions & 0 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39
inline constexpr const char* kQ4gswLinearWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;

struct Params {
M: u32,
N: u32,
K: u32,
K_packed: u32,
group_size: u32,
padded_N: u32,
has_bias: u32,
_pad: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
return;
}
let in_base = m * params.K;

var n: u32 = lid.x;
loop {
if (n >= params.N) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
)";

inline constexpr uint32_t kQ4gswLinearWorkgroupSizeX = 64;
inline constexpr uint32_t kQ4gswLinearWorkgroupSizeY = 1;
inline constexpr uint32_t kQ4gswLinearWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
Loading