diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 7f687bb10f4..b9b100797e3 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -8,7 +8,6 @@ import executorch.backends.vulkan.patterns as vk_patterns import torch.library - from torch._subclasses.fake_tensor import FakeTensor namespace = "et_vk" @@ -259,7 +258,7 @@ def linear_q4gsw( weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7 ) - out = torch.nn.functional.linear(x, weights) + out = torch.nn.functional.linear(x, weights, bias) return out @@ -273,26 +272,23 @@ def linear_dq8ca_q4gsw( group_size: int, bias: Optional[torch.Tensor] = None, ): - return linear_q4gsw(x, weights, weight_scales, group_size) + return linear_q4gsw(x, weights, weight_scales, group_size, bias) name = "linear_q4gsw" -lib.define( - f""" +lib.define(f""" {name}( Tensor self, Tensor weights, Tensor weight_scales, int group_size, Tensor? bias = None) -> Tensor - """ -) + """) lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd") linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) name = "linear_dq8ca_q4gsw" -lib.define( - f""" +lib.define(f""" {name}( Tensor input, Tensor input_scales, @@ -302,8 +298,7 @@ def linear_dq8ca_q4gsw( Tensor weight_scales, int group_size, Tensor? bias = None) -> Tensor - """ -) + """) lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd") linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name) @@ -341,8 +336,7 @@ def linear_q8ta_q8csw( name = "linear_q8ta_q8csw" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -351,8 +345,7 @@ def linear_q8ta_q8csw( Tensor weight_sums, Tensor weight_scales, Tensor? bias = None) -> Tensor - """ -) + """) lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd") qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name) @@ -403,8 +396,7 @@ def q8ta_linear( name = "q8ta_linear" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -416,8 +408,7 @@ def q8ta_linear( int output_zero_point, Tensor? bias = None, str activation = "none") -> Tensor - """ -) + """) lib.impl(name, q8ta_linear, "CompositeExplicitAutograd") q8ta_linear_op = getattr(getattr(torch.ops, namespace), name) @@ -468,8 +459,7 @@ def q8ta_linear_gemv( name = "q8ta_linear_gemv" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -481,8 +471,7 @@ def q8ta_linear_gemv( int output_zero_point, Tensor? bias = None, str activation = "none") -> Tensor - """ -) + """) lib.impl(name, q8ta_linear_gemv, "CompositeExplicitAutograd") q8ta_linear_gemv_op = getattr(getattr(torch.ops, namespace), name) @@ -560,8 +549,7 @@ def q8ta_conv2d( name = "q8ta_conv2d" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -578,15 +566,13 @@ def q8ta_conv2d( SymInt[] dilation, SymInt groups, str activation) -> Tensor - """ -) + """) lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd") q8ta_conv2d_op = getattr(getattr(torch.ops, namespace), name) name = "q8ta_conv2d_pw" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -603,8 +589,7 @@ def q8ta_conv2d( SymInt[] dilation, SymInt groups, str activation) -> Tensor - """ -) + """) lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd") q8ta_conv2d_pw_op = getattr(getattr(torch.ops, namespace), name) @@ -662,8 +647,7 @@ def q8ta_conv2d_dw( name = "q8ta_conv2d_dw" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -680,8 +664,7 @@ def q8ta_conv2d_dw( SymInt[] dilation, SymInt groups, str activation) -> Tensor - """ -) + """) lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd") conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name) @@ -760,8 +743,7 @@ def q8ta_conv2d_transposed( name = "q8ta_conv2d_transposed" -lib.define( - f""" +lib.define(f""" {name}( Tensor x, float input_scale, @@ -779,8 +761,7 @@ def q8ta_conv2d_transposed( SymInt[] dilation, SymInt groups, str activation) -> Tensor - """ -) + """) lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd") q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index b9b307e14f1..0ff9d66498f 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -5,28 +5,22 @@ # LICENSE file in the root directory of this source tree. import operator - from typing import Optional import executorch.backends.vulkan.utils as utils - import torch import torch.nn.functional as F - from executorch.backends.transforms.utils import ( create_constant_placeholder, get_param_tensor, ) - from executorch.backends.vulkan.patterns.pattern_registry import ( PatternMatch, register_pattern_detector, register_pattern_replacement, ) - from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops - from torch.export.graph_signature import InputKind @@ -407,6 +401,7 @@ def make_linear_q4gsw_op( match.weight_node, match.weight_scales_node, group_size, + match.bias_node, ), ) @@ -474,6 +469,7 @@ def make_linear_dq8ca_q4gsw_op( weight_sums_node, match.weight_scales_node, group_size, + match.bias_node, ), ) @@ -538,6 +534,7 @@ def make_linear_q8ta_q8csw_custom_op( match.weight_node, weight_sums_node, match.weight_scales_node, + match.bias_node, ), ) @@ -637,7 +634,6 @@ def replace_quantized_linear_patterns( assert weight_zeros_tensor is not None # Route to appropriate custom op. - # q8ta_linear supports bias, so check it first before the bias guard. if ( match.is_input_static_per_tensor_quantized() and match.is_weight_perchannel_quantized() @@ -646,10 +642,6 @@ def replace_quantized_linear_patterns( make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor) return - # Remaining ops do not support bias - if match.bias_node is not None: - return - if ( match.is_weight_only_quantized() and match.is_weight_pergroup_quantized() diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl index b6c32863eb9..fa0129b65a5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl @@ -144,5 +144,11 @@ void main() { group_size); } + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + add_bias_to_out_tile(out_tile, bias_tile); + } + write_output_tile_with_checks(out_tile, n4, m, N4, M); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh index 01b3c762e39..60a19ca9fc9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -73,6 +73,16 @@ void apply_weight_scales_and_biases( } } +void add_bias_to_out_tile( + inout FPOutTile tile, + const FPPerOutChannelParams bias) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] + bias.data[n4]; + } + } +} + void accumulate_out_tile_with_out_tile( inout FPOutTile accum, const FPOutTile other) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl index 02bfe3fff0f..053f27d6c9b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl @@ -142,6 +142,11 @@ void main() { // Only the first thread will write out result if (lid == 0) { out_tile = partial_sums[0]; + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + add_bias_to_out_tile(out_tile, bias_tile); + } write_output_tile_with_checks(out_tile, n4, 0, N4, 1); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl index 9a42a7fa67f..70a637ed0f8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl @@ -110,5 +110,11 @@ void main() { } } + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + add_bias_to_out_tile(out_tile, bias_tile); + } + write_output_tile_with_checks(out_tile, n4, m, N4, M); } diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp index 2af1488541d..ef6369c6b1f 100644 --- a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp +++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp @@ -148,7 +148,7 @@ TestCase create_test_case_from_config( input_dtype, storage_type, utils::kWidthPacked, - DataGenType::ZEROS); + config.has_bias ? DataGenType::RANDOM : DataGenType::ZEROS); bias.set_constant(true); if (!config.has_bias) { bias.set_none(true); @@ -237,9 +237,10 @@ std::vector generate_quantized_linear_test_cases() { {32, 64, 32, 16}, {32, 128, 64, 32}, {32, 256, 128, 64}, - // No bias tests - {32, 128, 64, 32, false}, - {32, 256, 128, 64, false}, + // With bias + {4, 64, 32, 16, true}, + {4, 128, 64, 32, true}, + {32, 128, 64, 32, true}, // Performance test cases {1, 2048, 2048, 128}, {128, 2048, 2048, 128}, @@ -499,13 +500,6 @@ void reference_impl(TestCase& test_case) { } int64_t quantized_linear_flop_calculator(const TestCase& test_case) { - int input_idx = 0; - int weight_idx = 1; - if (test_case.operator_name().find("dq8ca") != std::string::npos) { - input_idx = 0; - weight_idx = 3; // Weight comes after input, input_scale, input_zero_point - } - // Get input and weight dimensions const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes();