From 73b1bb4417342a1512f50a0c833c7506de8ef5b5 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 10 Mar 2026 10:01:33 -0700 Subject: [PATCH] [ET-VK][qlinear] Add bias support to q4gsw and dq8ca_q4gsw quantized linear ops Wire bias through the q4gsw and dq8ca_q4gsw quantized linear operators. Add add_bias_to_out_tile() helper in the output tile computation header and call it from all three shader variants (tiled, coop, dq8ca_tiled). Remove the bias guard in the pattern matcher to allow biased linear layers. Differential Revision: [D95970172](https://our.internmc.facebook.com/intern/diff/D95970172/) [ghstack-poisoned] --- backends/vulkan/custom_ops_lib.py | 4 ++-- backends/vulkan/patterns/quantized_linear.py | 8 +++----- .../graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl | 6 ++++++ .../glsl/linear_fp_output_tile_fp_compute.glslh | 10 ++++++++++ .../graph/ops/glsl/linear_q4gsw_coop.glsl | 5 +++++ .../graph/ops/glsl/linear_q4gsw_tiled.glsl | 6 ++++++ backends/vulkan/test/custom_ops/q4gsw_linear.cpp | 16 +++++----------- 7 files changed, 37 insertions(+), 18 deletions(-) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 87506f0b773..5c60592dca8 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -259,7 +259,7 @@ def linear_q4gsw( weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7 ) - out = torch.nn.functional.linear(x, weights) + out = torch.nn.functional.linear(x, weights, bias) return out @@ -273,7 +273,7 @@ def linear_dq8ca_q4gsw( group_size: int, bias: Optional[torch.Tensor] = None, ): - return linear_q4gsw(x, weights, weight_scales, group_size) + return linear_q4gsw(x, weights, weight_scales, group_size, bias) name = "linear_q4gsw" diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index df80749e72f..14684a2bff1 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -392,6 +392,7 @@ def make_linear_q4gsw_op( match.weight_node, match.weight_scales_node, group_size, + match.bias_node, ), ) @@ -459,6 +460,7 @@ def make_linear_dq8ca_q4gsw_op( weight_sums_node, match.weight_scales_node, group_size, + match.bias_node, ), ) @@ -523,6 +525,7 @@ def make_linear_q8ta_q8csw_custom_op( match.weight_node, weight_sums_node, match.weight_scales_node, + match.bias_node, ), ) @@ -622,7 +625,6 @@ def replace_quantized_linear_patterns( assert weight_zeros_tensor is not None # Route to appropriate custom op. - # q8ta_linear supports bias, so check it first before the bias guard. if ( match.is_input_static_per_tensor_quantized() and match.is_weight_perchannel_quantized() @@ -631,10 +633,6 @@ def replace_quantized_linear_patterns( make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor) return - # Remaining ops do not support bias - if match.bias_node is not None: - return - if ( match.is_weight_only_quantized() and match.is_weight_pergroup_quantized() diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl index b6c32863eb9..fa0129b65a5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl @@ -144,5 +144,11 @@ void main() { group_size); } + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + add_bias_to_out_tile(out_tile, bias_tile); + } + write_output_tile_with_checks(out_tile, n4, m, N4, M); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh index 01b3c762e39..60a19ca9fc9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -73,6 +73,16 @@ void apply_weight_scales_and_biases( } } +void add_bias_to_out_tile( + inout FPOutTile tile, + const FPPerOutChannelParams bias) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] + bias.data[n4]; + } + } +} + void accumulate_out_tile_with_out_tile( inout FPOutTile accum, const FPOutTile other) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl index 02bfe3fff0f..053f27d6c9b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl @@ -142,6 +142,11 @@ void main() { // Only the first thread will write out result if (lid == 0) { out_tile = partial_sums[0]; + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + add_bias_to_out_tile(out_tile, bias_tile); + } write_output_tile_with_checks(out_tile, n4, 0, N4, 1); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl index 9a42a7fa67f..70a637ed0f8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl @@ -110,5 +110,11 @@ void main() { } } + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + add_bias_to_out_tile(out_tile, bias_tile); + } + write_output_tile_with_checks(out_tile, n4, m, N4, M); } diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp index 2af1488541d..ef6369c6b1f 100644 --- a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp +++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp @@ -148,7 +148,7 @@ TestCase create_test_case_from_config( input_dtype, storage_type, utils::kWidthPacked, - DataGenType::ZEROS); + config.has_bias ? DataGenType::RANDOM : DataGenType::ZEROS); bias.set_constant(true); if (!config.has_bias) { bias.set_none(true); @@ -237,9 +237,10 @@ std::vector generate_quantized_linear_test_cases() { {32, 64, 32, 16}, {32, 128, 64, 32}, {32, 256, 128, 64}, - // No bias tests - {32, 128, 64, 32, false}, - {32, 256, 128, 64, false}, + // With bias + {4, 64, 32, 16, true}, + {4, 128, 64, 32, true}, + {32, 128, 64, 32, true}, // Performance test cases {1, 2048, 2048, 128}, {128, 2048, 2048, 128}, @@ -499,13 +500,6 @@ void reference_impl(TestCase& test_case) { } int64_t quantized_linear_flop_calculator(const TestCase& test_case) { - int input_idx = 0; - int weight_idx = 1; - if (test_case.operator_name().find("dq8ca") != std::string::npos) { - input_idx = 0; - weight_idx = 3; // Weight comes after input, input_scale, input_zero_point - } - // Get input and weight dimensions const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes();