diff --git a/src/common/transformations/include/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp b/src/common/transformations/include/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp
new file mode 100644
index 00000000000000..853a025e11da7a
--- /dev/null
+++ b/src/common/transformations/include/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp
@@ -0,0 +1,120 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/matcher_pass.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertWeightCompressedConv1x1ToMatmul;
+
+}  // namespace pass
+}  // namespace ov
+
+/**
+ * @ingroup ov_transformation_common_api
+ * @brief ConvertWeightCompressedConv1x1ToMatmul transformation matches a weight-compressed
+ * Convolution with a 1x1 kernel and replaces it with a MatMul operation.
+ *
+ * The transformation identifies the following pattern:
+ *
+ *        +---------+  +-----------+  +------+
+ *        | Weights |  | ZeroPoint |  |Scale |
+ *        +---------+  +-----------+  +------+
+ *             |             |           |
+ *             v             v           |
+ *         +-------+     +-------+       |
+ *         |Convert|     |Convert|       |
+ *         +-------+     +-------+       |
+ *             |             |           |
+ *             +-----+  +----+           |
+ *                   |  |                |
+ * +------------+    v  v                |
+ * | Activation |  +--------+            |
+ * +------------+  |Subtract| (optional) |
+ *       |         +--------+            |
+ *       v             |                 |
+ * +-------------+     v                 |
+ * | Transpose/  |  +----------+         |
+ * | Reshape     |  | Multiply |<--------+
+ * +-------------+  +----------+
+ *       |               |
+ *       |               v
+ *       |          +-----------+
+ *       +--------->|Convolution|
+ *                  |   (1x1)   |
+ *                  +-----------+
+ *                        |
+ *                        v
+ *                  +----------+
+ *                  |Add (Bias)| (optional)
+ *                  +----------+
+ *                        |
+ *                        v
+ *                  +-----------+
+ *                  |  Convert  | (optional)
+ *                  +-----------+
+ *                        |
+ *                        v
+ *                  +------------+
+ *                  | Transpose/ |
+ *                  |  Reshape   |
+ *                  +------------+
+ *
+ * and replaces it with:
+ *
+ * +------------+
+ * | Activation |
+ * +------------+
+ *       |
+ *       |   +---------+  +-----------+  +------+
+ *       |   | Weights |  | ZeroPoint |  |Scale |
+ *       |   +---------+  +-----------+  +------+
+ *       |        |             |           |
+ *       |        v             v           |
+ *       |    +-------+     +-------+       |
+ *       |    |Convert|     |Convert|       |
+ *       |    +-------+     +-------+       |
+ *       |        |             |           |
+ *       |        +-----+  +----+           |
+ *       |              |  |                |
+ *       |              v  v                |
+ *       |           +--------+             |
+ *       |           |Subtract| (optional)  |
+ *       |           +--------+             |
+ *       |               |                  |
+ *       |               v                  |
+ *       |          +----------+            |
+ *       |          | Multiply |<-----------+
+ *       |          +----------+
+ *       |               |
+ *       |               v
+ *       |           +--------+
+ *       +---------->| MatMul |
+ *                   +--------+
+ *                       |
+ *                       v
+ *                  +----------+
+ *                  |Add (Bias)| (optional)
+ *                  +----------+
+ *                       |
+ *                       v
+ *                  +-----------+
+ *                  |  Convert  | (optional)
+ *                  +-----------+
+ *                       |
+ *                       v
+ *                  +------------+
+ *                  |  Reshape   | (optional)
+ *                  +------------+
+ */
+
+class ov::pass::ConvertWeightCompressedConv1x1ToMatmul : public ov::pass::MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("ConvertWeightCompressedConv1x1ToMatmul");
+    ConvertWeightCompressedConv1x1ToMatmul();
+};
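For reviewers trying the pass in isolation: it is now a plain `MatcherPass` with a default constructor, so a minimal driver looks like the sketch below (assuming `model` is an already-loaded `ov::Model`; nothing here is specific to the GPU plugin anymore).

```cpp
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp"

// Run only this conversion on a model; any matched weight-compressed
// 1x1 Convolution is rewritten into the MatMul form shown above.
void run_conv1x1_to_matmul(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::ConvertWeightCompressedConv1x1ToMatmul>();
    manager.run_passes(model);
}
```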
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.cpp b/src/common/transformations/src/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.cpp
similarity index 81%
rename from src/plugins/intel_gpu/src/plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.cpp
rename to src/common/transformations/src/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.cpp
index 99630f6730a14f..afb13587ec3708 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.cpp
@@ -2,23 +2,23 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include "convert_weight_compressed_conv1x1_to_matmul.hpp"
+#include "transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp"
 
 #include <memory>
 #include <string>
 #include <vector>
 
-#include "graph/include/gemm_inst.h"
-#include "intel_gpu/runtime/utils.hpp"
+#include "itt.hpp"
 #include "openvino/core/graph_util.hpp"
 #include "openvino/core/node_vector.hpp"
 #include "openvino/core/partial_shape.hpp"
 #include "openvino/core/rt_info.hpp"
+#include "openvino/core/rt_info/weightless_caching_attributes.hpp"
 #include "openvino/core/type/element_type.hpp"
 #include "openvino/op/add.hpp"
-#include "openvino/op/convolution.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert.hpp"
+#include "openvino/op/convolution.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/reshape.hpp"
@@ -34,28 +34,26 @@ using namespace ov::pass::pattern;
 using ov::pass::pattern::op::Or;
 
-namespace ov::intel_gpu {
-
-ConvertWeightCompressedConv1x1ToMatmul::ConvertWeightCompressedConv1x1ToMatmul(bool supports_immad) {
-    add_matcher<ConvertWeightCompressedConv1x1ToMatmulMatcher>(supports_immad);
-}
-
-ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToMatmulMatcher(bool supports_immad) {
+ov::pass::ConvertWeightCompressedConv1x1ToMatmul::ConvertWeightCompressedConv1x1ToMatmul() {
+    MATCHER_SCOPE(ConvertWeightCompressedConv1x1ToMatmul);
     auto filter1x1_path = [](const ov::Output<ov::Node>& output) {
         const auto& pshape = output.get_partial_shape();
-        return ov::op::util::is_on_path(output) && pshape.is_static() && pshape[-1] == 1 && pshape[-2] == 1;
+        return ov::op::util::is_on_path(output) && pshape.is_static() &&
+               pshape[-1] == 1 && pshape[-2] == 1;
     };
 
     auto bias_path = [](const ov::Output<ov::Node>& output) {
         const auto& pshape = output.get_partial_shape();
-        return ov::op::util::is_on_path(output) && pshape.is_static() && pshape[0] == 1 && pshape[2] == 1 && pshape[3] == 1;
+        return ov::op::util::is_on_path(output) && pshape.is_static() && pshape[0] == 1 &&
+               pshape[2] == 1 && pshape[3] == 1;
     };
 
     auto first_input_m = ov::pass::pattern::any_input();
     auto a_order_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
     auto transpose_activations_m = ov::pass::pattern::wrap_type<ov::op::v1::Transpose>({first_input_m, a_order_m});
     auto reshape_activations_m = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({first_input_m, a_order_m});
-    auto a_m = std::make_shared<Or>(OutputVector{transpose_activations_m, reshape_activations_m});
+    auto a_m =
+        std::make_shared<Or>(OutputVector{transpose_activations_m, reshape_activations_m});
 
     auto weights_const_m = wrap_type<ov::op::v0::Constant>(rank_equals(4) && has_static_rank() && filter1x1_path);
     auto weights_param_m = wrap_type<ov::op::v0::Parameter>(rank_equals(4) && has_static_rank() && filter1x1_path);
@@ -64,9 +62,11 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
     auto weights_scales_m = ov::pass::pattern::any_input();
     auto weights_zp_m = ov::pass::pattern::any_input();
     auto weights_zp_convert_m = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({weights_zp_m});
-    auto weight_subtract_m = ov::pass::pattern::wrap_type<ov::op::v1::Subtract>({weight_convert_m, weights_zp_convert_m});
+    auto weight_subtract_m =
+        ov::pass::pattern::wrap_type<ov::op::v1::Subtract>({weight_convert_m, weights_zp_convert_m});
     // Make zp subtraction optional to account for symmetrical quantization cases
-    auto weight_dequantized_m = std::make_shared<Or>(OutputVector{weight_convert_m, weight_subtract_m});
+    auto weight_dequantized_m =
+        std::make_shared<Or>(OutputVector{weight_convert_m, weight_subtract_m});
     auto weight_mult_m = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({weight_dequantized_m, weights_scales_m});
 
     auto conv1x1_m = ov::pass::pattern::wrap_type<ov::op::v1::Convolution>({a_m, weight_mult_m});
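The `Or` node is what makes the zero-point branch optional here: the matcher accepts either the `Convert` output directly (symmetric quantization) or `Subtract(Convert(w), Convert(zp))` (asymmetric). A minimal sketch of the same idiom in isolation, using the same pattern API as above (standalone function, hypothetical name):

```cpp
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

using namespace ov::pass::pattern;

// Dequantization subpattern: Multiply(Convert(w) [- Convert(zp)], scale).
// Wrapping the two alternatives in Or lets both a symmetric graph (no
// Subtract) and an asymmetric one (with Subtract) reach the Multiply root.
std::shared_ptr<ov::Node> make_dequant_pattern() {
    auto w = any_input();
    auto zp = any_input();
    auto scale = any_input();
    auto w_convert = wrap_type<ov::op::v0::Convert>({w});
    auto zp_convert = wrap_type<ov::op::v0::Convert>({zp});
    auto w_sub = wrap_type<ov::op::v1::Subtract>({w_convert, zp_convert});
    auto dequant = std::make_shared<op::Or>(ov::OutputVector{w_convert, w_sub});
    return wrap_type<ov::op::v1::Multiply>({dequant, scale});
}
```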
@@ -87,14 +87,21 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
         const auto& pattern_map = m.get_pattern_value_map();
 
         auto conv1x1 = ov::as_type_ptr<ov::op::v1::Convolution>(pattern_map.at(conv1x1_m).get_node_shared_ptr());
-        auto weight_convert = ov::as_type_ptr<ov::op::v0::Convert>(pattern_map.at(weight_convert_m).get_node_shared_ptr());
-        auto weight_sub = (pattern_map.count(weight_subtract_m) > 0) ? pattern_map.at(weight_subtract_m).get_node_shared_ptr() : nullptr;
+        auto weight_convert =
+            ov::as_type_ptr<ov::op::v0::Convert>(pattern_map.at(weight_convert_m).get_node_shared_ptr());
+        auto weight_sub = (pattern_map.count(weight_subtract_m) > 0)
+                              ? pattern_map.at(weight_subtract_m).get_node_shared_ptr()
+                              : nullptr;
         auto weight_mult = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(weight_mult_m).get_node_shared_ptr());
         auto bias_out = (pattern_map.count(bias_m) > 0) ? pattern_map.at(bias_m).get_node_shared_ptr() : nullptr;
-        auto bias_const = (pattern_map.count(bias_const_m) > 0) ? pattern_map.at(bias_const_m).get_node_shared_ptr() : nullptr;
-        auto convert_out = (pattern_map.count(convert_m) > 0) ? pattern_map.at(convert_m).get_node_shared_ptr() : nullptr;
+        auto bias_const =
+            (pattern_map.count(bias_const_m) > 0) ? pattern_map.at(bias_const_m).get_node_shared_ptr() : nullptr;
+        auto convert_out =
+            (pattern_map.count(convert_m) > 0) ? pattern_map.at(convert_m).get_node_shared_ptr() : nullptr;
         auto out_order = (pattern_map.count(c_order_m) > 0) ? pattern_map.at(c_order_m).get_node_shared_ptr() : nullptr;
-        auto reshape_out = (pattern_map.count(reshape_output_m) > 0) ? pattern_map.at(reshape_output_m).get_node_shared_ptr() : nullptr;
+        auto reshape_out = (pattern_map.count(reshape_output_m) > 0)
+                               ? pattern_map.at(reshape_output_m).get_node_shared_ptr()
+                               : nullptr;
         if (!conv1x1 || transformation_callback(conv1x1)) {
             return false;
         }
@@ -137,7 +144,8 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
             auto Reshape_weight = reshape_const_to_2d(weight);
             MatcherPass::register_new_node(Reshape_weight);
             Reshape_weight->set_friendly_name(weight->get_friendly_name() + "_Reshape_weight");
-            weight_squeezed_convert = ov::as_type_ptr<ov::op::v0::Convert>(weight_convert->clone_with_new_inputs({Reshape_weight}));
+            weight_squeezed_convert =
+                ov::as_type_ptr<ov::op::v0::Convert>(weight_convert->clone_with_new_inputs({Reshape_weight}));
             ov::copy_runtime_info(weight_convert, weight_squeezed_convert);
         } else {
             auto param = ov::as_type_ptr<ov::op::v0::Parameter>(weight);
@@ -146,15 +154,17 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
             auto shape_b = param->get_output_partial_shape(0);
             for (size_t i = 0; i < shape_b.size(); i++)
                 if (shape_b.to_shape()[i] != 1) {
-                    values_reshape_b.push_back(shape_b.to_shape()[i]);
+                    values_reshape_b.push_back(static_cast<int32_t>(shape_b.to_shape()[i]));
                 }
             auto reshape_weight_const = ov::op::v0::Constant::create(element::i32, Shape{2}, values_reshape_b);
             auto Reshape_weight = std::make_shared<ov::op::v1::Reshape>(param, reshape_weight_const, false);
             MatcherPass::register_new_node(Reshape_weight);
             Reshape_weight->set_friendly_name(param->get_friendly_name() + "_Reshape_weight");
-            weight_squeezed_convert = ov::as_type_ptr<ov::op::v0::Convert>(weight_convert->clone_with_new_inputs({Reshape_weight}));
+            weight_squeezed_convert =
+                ov::as_type_ptr<ov::op::v0::Convert>(weight_convert->clone_with_new_inputs({Reshape_weight}));
             ov::copy_runtime_info(weight_convert, weight_squeezed_convert);
+            ov::copy_runtime_info(weight_convert, Reshape_weight);
         }
         ov::disable_constant_folding(weight_squeezed_convert);
@@ -170,11 +180,13 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
             auto Reshape_zp = reshape_const_to_2d(zp);
             MatcherPass::register_new_node(Reshape_zp);
             Reshape_zp->set_friendly_name(zp->get_friendly_name() + "_Reshape_zp");
-            auto weights_zp_convert = ov::as_type_ptr<ov::op::v0::Convert>(pattern_map.at(weights_zp_convert_m).get_node_shared_ptr());
+            auto weights_zp_convert =
+                ov::as_type_ptr<ov::op::v0::Convert>(pattern_map.at(weights_zp_convert_m).get_node_shared_ptr());
             auto zp_squeezed_convert = weights_zp_convert->clone_with_new_inputs({Reshape_zp});
             ov::copy_runtime_info(weights_zp_convert, zp_squeezed_convert);
             ov::disable_constant_folding(zp_squeezed_convert);
-            auto zero_adjusted_weight = weight_sub->clone_with_new_inputs({weight_squeezed_convert, zp_squeezed_convert});
+            auto zero_adjusted_weight =
+                weight_sub->clone_with_new_inputs({weight_squeezed_convert, zp_squeezed_convert});
             ov::copy_runtime_info(weight_sub, zero_adjusted_weight);
             scaled_weight = weight_mult->clone_with_new_inputs({zero_adjusted_weight, Reshape_scale});
         }
@@ -182,6 +194,7 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
         ov::disable_constant_folding(scaled_weight);
 
         auto matmul = std::make_shared<ov::op::v0::MatMul>(activation, scaled_weight, false, true);
+        ov::copy_runtime_info(conv1x1, matmul);
         std::shared_ptr<ov::Node> matmul_out;
         if (bias_out) {
             auto bias = ov::as_type_ptr<ov::op::v0::Constant>(bias_const);
@@ -193,12 +206,14 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
             auto new_bias_shape = ov::Shape{bias_shape[0], bias_shape[2], bias_shape[3], bias_shape[1]};
 
             auto Reshape_bias = std::make_shared<ov::op::v0::Constant>(*bias, new_bias_shape);
+            ov::copy_runtime_info(bias, Reshape_bias);
             ov::copy_weightless_cache_attr(bias, Reshape_bias);
             MatcherPass::register_new_node(Reshape_bias);
             Reshape_bias->set_friendly_name(bias->get_friendly_name() + "_Reshape_bias");
 
             matmul_out = bias_out->clone_with_new_inputs({matmul, Reshape_bias});
+            ov::copy_runtime_info(bias_out, matmul_out);
         } else {
             matmul_out = matmul;
         }
@@ -208,6 +223,7 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
             auto convert_final = convert_out->clone_with_new_inputs({matmul_out});
             auto reshape_final = reshape_out->clone_with_new_inputs({convert_final, out_order});
             reshape_final->set_friendly_name(m.get_match_root()->get_friendly_name());
+            ov::copy_runtime_info(convert_out, convert_final);
             ov::copy_runtime_info(m.get_matched_nodes(), reshape_final);
             ov::replace_node(m.get_match_root(), reshape_final);
         } else {
@@ -232,8 +248,6 @@ ConvertWeightCompressedConv1x1ToMatmulMatcher::ConvertWeightCompressedConv1x1ToM
         return true;
     };
 
-    auto m = std::make_shared<ov::pass::pattern::Matcher>(output_m, "ConvertWeightCompressedConv1x1ToMatmulMatcher");
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(output_m, matcher_name);
     this->register_matcher(m, callback);
 }
-
-}  // namespace ov::intel_gpu
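The geometry the rewrite relies on, with the shapes used in the tests below: a 1x1 Convolution over activations transposed to [N, C_in, H, W] with weights [C_out, C_in, 1, 1] computes the same values as a MatMul of the channels-last activations [N, H, W, C_in] against the squeezed weights [C_out, C_in] with transpose_b = true; the bias is likewise permuted from [1, C_out, 1, 1] to [1, 1, 1, C_out] so it broadcasts along the last axis. A small standalone check of the replacement's output shape (not the pass itself, just the shape math):

```cpp
#include <iostream>

#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"

int main() {
    // Channels-last activations, as they are before the Transpose in the matched graph.
    auto act = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
    // 1x1 conv weights [15, 10, 1, 1] squeezed to [15, 10].
    auto weights = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{15, 10}, {1});
    auto matmul = std::make_shared<ov::op::v0::MatMul>(act, weights, false, /*transpose_b=*/true);
    // Prints [1,1,2,15]: the same result shape the Conv + output Transpose pair produced.
    std::cout << matmul->get_output_partial_shape(0) << std::endl;
    return 0;
}
```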
diff --git a/src/common/transformations/tests/op_conversions/convert_weight_compressed_conv1x1_to_matmul_test.cpp b/src/common/transformations/tests/op_conversions/convert_weight_compressed_conv1x1_to_matmul_test.cpp
new file mode 100644
index 00000000000000..182b363fafdc14
--- /dev/null
+++ b/src/common/transformations/tests/op_conversions/convert_weight_compressed_conv1x1_to_matmul_test.cpp
@@ -0,0 +1,243 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <openvino/core/model.hpp>
+#include <openvino/pass/manager.hpp>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "common_test_utils/ov_test_utils.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/convolution.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/opsets/opset1_decl.hpp"
+#include "openvino/opsets/opset3_decl.hpp"
+#include "openvino/opsets/opset7_decl.hpp"
+#include "transformations/rt_info/decompression.hpp"
+
+using namespace ov;
+using namespace testing;
+
+namespace {
+struct Conv1x1ToMatmulTestParams {
+    bool with_zp;
+    bool with_bias;
+    bool with_convert;
+    bool weights_as_param;
+    std::string activation_op_type;
+};
+
+std::shared_ptr<ov::Model> gen_model(const Conv1x1ToMatmulTestParams& p) {
+    auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
+    std::shared_ptr<ov::Node> act_node;
+    if (p.activation_op_type == "Transpose") {
+        auto transpose_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 3, 1, 2});
+        act_node = std::make_shared<ov::opset1::Transpose>(input, transpose_const);
+    } else {
+        auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 10, 1, 2});
+        act_node = std::make_shared<ov::opset1::Reshape>(input, reshape_const, false);
+    }
+
+    std::shared_ptr<ov::Node> weights_node;
+    ov::ParameterVector params = {input};
+    if (p.weights_as_param) {
+        auto weights_param = std::make_shared<ov::opset1::Parameter>(ov::element::i4, ov::Shape{15, 10, 1, 1});
+        weights_node = weights_param;
+        params.push_back(weights_param);
+    } else {
+        weights_node = ov::opset1::Constant::create(ov::element::i4, {15, 10, 1, 1}, {1});
+    }
+
+    auto weights_convert = std::make_shared<ov::opset1::Convert>(weights_node, ov::element::f16);
+    std::shared_ptr<ov::Node> current_node = weights_convert;
+
+    if (p.with_zp) {
+        auto zp_const = ov::opset1::Constant::create(ov::element::i4, {15, 10, 1, 1}, {1});
+        auto zp_convert = std::make_shared<ov::opset1::Convert>(zp_const, ov::element::f16);
+        current_node = std::make_shared<ov::opset1::Subtract>(weights_convert, zp_convert);
+    }
+
+    auto scale_const = ov::opset1::Constant::create(ov::element::f16, {15, 10, 1, 1}, {1});
+    auto mul = std::make_shared<ov::opset1::Multiply>(current_node, scale_const);
+
+    auto conv = std::make_shared<ov::opset1::Convolution>(act_node,
+                                                          mul,
+                                                          ov::Strides{1, 1},
+                                                          ov::CoordinateDiff{0, 0},
+                                                          ov::CoordinateDiff{0, 0},
+                                                          ov::Strides{1, 1},
+                                                          ov::op::PadType::EXPLICIT);
+    current_node = conv;
+
+    if (p.with_bias) {
+        auto bias_const = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 15, 1, 1}, {1});
+        current_node = std::make_shared<ov::opset1::Add>(conv, bias_const);
+    }
+    if (p.with_convert) {
+        current_node = std::make_shared<ov::opset1::Convert>(current_node, ov::element::f32);
+    }
+
+    std::shared_ptr<ov::Node> out_node;
+    if (p.activation_op_type == "Transpose") {
+        auto transpose_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 3, 1});
+        out_node = std::make_shared<ov::opset1::Transpose>(current_node, transpose_const);
+    } else {
+        auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 1, 2, 15});
+        out_node = std::make_shared<ov::opset1::Reshape>(current_node, reshape_const, false);
+    }
+
+    return std::make_shared<ov::Model>(ov::OutputVector{out_node}, params);
+}
+
+std::shared_ptr<ov::Model> gen_model_ref(const Conv1x1ToMatmulTestParams& p) {
+    auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
+
+    std::shared_ptr<ov::Node> weights_node;
+    ov::ParameterVector params = {input};
+    if (p.weights_as_param) {
+        auto weights_param = std::make_shared<ov::opset1::Parameter>(ov::element::i4, ov::Shape{15, 10, 1, 1});
+        auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {15, 10});
+        weights_node = std::make_shared<ov::opset1::Reshape>(weights_param, reshape_const, false);
+        params.push_back(weights_param);
+    } else {
+        weights_node = ov::opset1::Constant::create(ov::element::i4, {15, 10}, {1});
+    }
+
+    auto weights_convert = std::make_shared<ov::opset1::Convert>(weights_node, ov::element::f16);
+    std::shared_ptr<ov::Node> current_node = weights_convert;
+
+    if (p.with_zp) {
+        auto zp_const = ov::opset1::Constant::create(ov::element::i4, {15, 10}, {1});
+        auto zp_convert = std::make_shared<ov::opset1::Convert>(zp_const, ov::element::f16);
+        current_node = std::make_shared<ov::opset1::Subtract>(weights_convert, zp_convert);
+    }
+
+    auto scale_const = ov::opset1::Constant::create(ov::element::f16, {15, 10}, {1});
+    auto mul = std::make_shared<ov::opset1::Multiply>(current_node, scale_const);
+
+    auto matmul = std::make_shared<ov::opset1::MatMul>(input, mul, false, true);
+    current_node = matmul;
+
+    if (p.with_bias) {
+        auto bias_const = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 1, 1, 15}, {1});
+        current_node = std::make_shared<ov::opset1::Add>(matmul, bias_const);
+    }
+    if (p.with_convert) {
+        current_node = std::make_shared<ov::opset1::Convert>(current_node, ov::element::f32);
+    }
+
+    std::shared_ptr<ov::Node> out_node;
+    if (p.activation_op_type == "Reshape") {
+        auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {1, 1, 2, 15});
+        out_node = std::make_shared<ov::opset1::Reshape>(current_node, reshape_const, false);
+    } else {
+        out_node = current_node;
+    }
+
+    return std::make_shared<ov::Model>(ov::OutputVector{out_node}, params);
+}
+}  // namespace
+
+class ConvertWeightCompressedConv1x1ToMatmulTest
+    : public TransformationTestsF,
+      public WithParamInterface<std::tuple<bool, bool, bool, bool, std::string>> {
+public:
+    static std::string get_test_case_name(
+        const testing::TestParamInfo<std::tuple<bool, bool, bool, bool, std::string>>& obj) {
+        bool with_zp, with_bias, with_convert, weights_as_param;
+        std::string activation_op_type;
+        std::tie(with_zp, with_bias, with_convert, weights_as_param, activation_op_type) = obj.param;
+
+        std::ostringstream result;
+        result << "with_zp=" << with_zp << "_";
+        result << "with_bias=" << with_bias << "_";
+        result << "with_convert=" << with_convert << "_";
+        result << "weights_as_param=" << weights_as_param << "_";
+        result << "activation_op_type=" << activation_op_type;
+        return result.str();
+    }
+
+protected:
+    void SetUp() override {
+        TransformationTestsF::SetUp();
+        bool with_zp, with_bias, with_convert, weights_as_param;
+        std::string activation_op_type;
+        std::tie(with_zp, with_bias, with_convert, weights_as_param, activation_op_type) = GetParam();
+        Conv1x1ToMatmulTestParams params{with_zp, with_bias, with_convert, weights_as_param, activation_op_type};
+        model = gen_model(params);
+        model_ref = gen_model_ref(params);
+        manager.register_pass<ov::pass::ConvertWeightCompressedConv1x1ToMatmul>();
+    }
+};
+
+TEST_P(ConvertWeightCompressedConv1x1ToMatmulTest, CompareFunctions) {}
+
+INSTANTIATE_TEST_SUITE_P(TransformationTests,
+                         ConvertWeightCompressedConv1x1ToMatmulTest,
+                         ::testing::Combine(::testing::Bool(),
+                                            ::testing::Bool(),
+                                            ::testing::Bool(),
+                                            ::testing::Bool(),
+                                            ::testing::Values("Transpose", "Reshape")),
+                         ConvertWeightCompressedConv1x1ToMatmulTest::get_test_case_name);
+
+// Check blocked cases: a 3x3 kernel must not be converted.
+TEST(TransformationTests, ConvertWeightCompressedConv1x1ToMatmulExceptionTest_conv3x3) {
+    auto CreateConv = [&]() {
+        ov::Strides strides{1, 1};
+        ov::Strides dilations{1, 1};
+        ov::CoordinateDiff pads_begin{1, 1};
+        ov::CoordinateDiff pads_end{1, 1};
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 1});
+        auto transpose_constant1 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 3, 1, 2});
+        auto transpose_constant2 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 3, 1});
+        auto input2 = ov::opset1::Constant::create(ov::element::i4, ov::Shape{1, 1, 3, 3}, {1});
+        auto input2_convert = std::make_shared<ov::opset1::Convert>(input2, ov::element::f16);
+        auto input2_scale = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 1, 3, 3}, {1});
+        auto mul = std::make_shared<ov::opset1::Multiply>(input2_convert, input2_scale);
+        auto transpose1 = std::make_shared<ov::opset1::Transpose>(input1, transpose_constant1);
+        auto conv3x3 = std::make_shared<ov::opset1::Convolution>(transpose1,
+                                                                 mul,
+                                                                 strides,
+                                                                 pads_begin,
+                                                                 pads_end,
+                                                                 dilations,
+                                                                 ov::op::PadType::EXPLICIT);
+        auto transpose2 = std::make_shared<ov::opset1::Transpose>(conv3x3, transpose_constant2);
+
+        auto model = std::make_shared<ov::Model>(ov::OutputVector{transpose2}, ov::ParameterVector{input1});
+        return model;
+    };
+
+    ov::pass::Manager manager;
+    manager.set_per_pass_validation(false);
+    manager.register_pass<ov::pass::ConvertWeightCompressedConv1x1ToMatmul>();
+
+    auto func = CreateConv();
+
+    manager.run_passes(func);
+
+    bool success = false;
+    for (auto& ops : func->get_ops()) {
+        std::string type_name(ops->get_type_name());
+        if (type_name.find("MatMul") != std::string::npos) {
+            success = true;
+            break;
+        }
+    }
+    ASSERT_FALSE(success);
+}
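To run just this suite locally, a gtest filter over the common transformations test binary should work; the binary name below is an assumption based on the usual target for this directory and may differ per build configuration.

```
./bin/intel64/Release/ov_transformations_tests --gtest_filter='*ConvertWeightCompressedConv1x1ToMatmul*'
```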
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.hpp b/src/plugins/intel_gpu/src/plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.hpp
deleted file mode 100644
index f7d8a2d75c86a3..00000000000000
--- a/src/plugins/intel_gpu/src/plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.hpp
+++ /dev/null
@@ -1,23 +0,0 @@
-// Copyright (C) 2025 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include "openvino/pass/graph_rewrite.hpp"
-
-namespace ov::intel_gpu {
-
-class ConvertWeightCompressedConv1x1ToMatmul : public ov::pass::GraphRewrite {
-public:
-    OPENVINO_GRAPH_REWRITE_RTTI("ConvertWeightCompressedConv1x1ToMatmul");
-    ConvertWeightCompressedConv1x1ToMatmul(bool supports_immad = false);
-};
-
-class ConvertWeightCompressedConv1x1ToMatmulMatcher : public ov::pass::MatcherPass {
-public:
-    OPENVINO_MATCHER_PASS_RTTI("ConvertWeightCompressedConv1x1ToMatmulMatcher");
-    ConvertWeightCompressedConv1x1ToMatmulMatcher(bool supports_immad);
-};
-
-}  // namespace ov::intel_gpu
diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
index b583751399fa2e..f9dab6b5fece1c 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -78,7 +78,6 @@
 #include "plugin/transformations/convert_matmul_to_fc.hpp"
 #include "plugin/transformations/convert_moe_to_compressed.hpp"
 #include "plugin/transformations/convert_stridedslices_to_variadicsplit.hpp"
-#include "plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.hpp"
 #include "plugin/transformations/decompose_reduce_scalar_output.hpp"
 #include "plugin/transformations/dynamic_quantize_fully_connected.hpp"
 #include "plugin/transformations/fc_convert_fusion.hpp"
@@ -168,6 +167,7 @@
 #include "transformations/op_conversions/convert_subtract.hpp"
 #include "transformations/op_conversions/convert_ti_to_sequences.hpp"
 #include "transformations/op_conversions/convert_topk11_downgrade.hpp"
+#include "transformations/op_conversions/convert_weight_compressed_conv1x1_to_matmul.hpp"
 #include "transformations/op_conversions/eye_decomposition.hpp"
 #include "transformations/op_conversions/gelu7_downgrade.hpp"
 #include "transformations/op_conversions/group_normalization_decomposition.hpp"
@@ -1335,7 +1335,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         ov::pass::Manager manager("GPU:PostLPT");
         manager.set_per_pass_validation(false);
 
-        manager.register_pass<ConvertWeightCompressedConv1x1ToMatmul>();
+        manager.register_pass<ov::pass::ConvertWeightCompressedConv1x1ToMatmul>();
         manager.register_pass();
         manager.register_pass(device_info.supports_immad);
         manager.register_pass();
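Since the matcher checks `transformation_callback(conv1x1)` before rewriting, a plugin can still exclude particular convolutions without the old constructor flag, via the standard pass config mechanism; a minimal sketch against the same `manager` as above:

```cpp
// Hypothetical opt-out: returning true from the callback makes the matcher
// skip that Convolution (this is the transformation_callback checked in the pass).
manager.get_pass_config()->set_callback<ov::pass::ConvertWeightCompressedConv1x1ToMatmul>(
    [](const std::shared_ptr<const ov::Node>& node) {
        return node->get_friendly_name() == "keep_me";  // example predicate
    });
```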
diff --git a/src/plugins/intel_gpu/tests/unit/transformations/convert_weight_compressed_conv1x1_to_matmul_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/convert_weight_compressed_conv1x1_to_matmul_test.cpp
deleted file mode 100644
index aedfea3e12ba92..00000000000000
--- a/src/plugins/intel_gpu/tests/unit/transformations/convert_weight_compressed_conv1x1_to_matmul_test.cpp
+++ /dev/null
@@ -1,104 +0,0 @@
-// Copyright (C) 2025 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include <gtest/gtest.h>
-
-#include <memory>
-#include <openvino/core/model.hpp>
-#include <openvino/pass/manager.hpp>
-#include <plugin/transformations/convert_weight_compressed_conv1x1_to_matmul.hpp>
-#include <string>
-#include <transformations/utils/utils.hpp>
-#include <tuple>
-#include <vector>
-
-#include "common_test_utils/ov_test_utils.hpp"
-#include "openvino/op/concat.hpp"
-#include "openvino/op/convolution.hpp"
-#include "openvino/op/matmul.hpp"
-#include "openvino/op/multiply.hpp"
-#include "openvino/op/shape_of.hpp"
-#include "openvino/op/subtract.hpp"
-#include "openvino/op/transpose.hpp"
-#include "openvino/opsets/opset1_decl.hpp"
-#include "openvino/opsets/opset3_decl.hpp"
-#include "openvino/opsets/opset7_decl.hpp"
-#include "transformations/rt_info/decompression.hpp"
-
-using namespace testing;
-using namespace ov::intel_gpu;
-
-TEST_F(TransformationTestsF, ConvertWeightCompressedConv1x1ToMatmulTest1) {
-    ov::Strides strides{1, 1};
-    ov::Strides dilations{1, 1};
-    ov::CoordinateDiff pads_begin{0, 0};
-    ov::CoordinateDiff pads_end{0, 0};
-    {
-        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
-        auto transpose_constant1 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 3, 1, 2});
-        auto transpose_constant2 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 3, 1});
-        auto input2 = ov::opset1::Constant::create(ov::element::i4, ov::Shape{15, 10, 1, 1}, {1});
-        auto input2_convert = std::make_shared<ov::opset1::Convert>(input2, ov::element::f16);
-        auto input2_scale = ov::opset1::Constant::create(ov::element::f16, ov::Shape{15, 10, 1, 1}, {1});
-        auto mul = std::make_shared<ov::opset1::Multiply>(input2_convert, input2_scale);
-        auto transpose1 = std::make_shared<ov::opset1::Transpose>(input1, transpose_constant1);
-        auto conv1x1 = std::make_shared<ov::opset1::Convolution>(transpose1, mul, strides, pads_begin, pads_end, dilations, ov::op::PadType::EXPLICIT);
-        auto transpose2 = std::make_shared<ov::opset1::Transpose>(conv1x1, transpose_constant2);
-
-        model = std::make_shared<ov::Model>(ov::OutputVector{transpose2}, ov::ParameterVector{input1});
-        manager.register_pass<ConvertWeightCompressedConv1x1ToMatmul>();
-    }
-    {
-        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 10});
-        auto input2 = ov::opset1::Constant::create(ov::element::i4, ov::Shape{15, 10}, {1});
-        auto input2_convert = std::make_shared<ov::opset1::Convert>(input2, ov::element::f16);
-        auto input2_scale = ov::opset1::Constant::create(ov::element::f16, ov::Shape{15, 10}, {1});
-        auto mul = std::make_shared<ov::opset1::Multiply>(input2_convert, input2_scale);
-        auto matmul = std::make_shared<ov::opset1::MatMul>(input1, mul, false, true);
-
-        model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{input1});
-    }
-}
-
-// Checked blocked cases
-TEST(TransformationTests, ConvertWeightCompressedConv1x1ToMatmulExceptionTest_conv3x3) {
-    auto CreateConv = [&]() {
-        ov::Strides strides{1, 1};
-        ov::Strides dilations{1, 1};
-        ov::CoordinateDiff pads_begin{1, 1};
-        ov::CoordinateDiff pads_end{1, 1};
-        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 1, 2, 1});
-        auto transpose_constant1 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 3, 1, 2});
-        auto transpose_constant2 = ov::opset1::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 3, 1});
-        auto input2 = ov::opset1::Constant::create(ov::element::i4, ov::Shape{1, 1, 3, 3}, {1});
-        auto input2_convert = std::make_shared<ov::opset1::Convert>(input2, ov::element::f16);
-        auto input2_scale = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 1, 3, 3}, {1});
-        auto mul = std::make_shared<ov::opset1::Multiply>(input2_convert, input2_scale);
-        auto transpose1 = std::make_shared<ov::opset1::Transpose>(input1, transpose_constant1);
-        auto conv3x3 = std::make_shared<ov::opset1::Convolution>(transpose1, mul, strides, pads_begin, pads_end, dilations, ov::op::PadType::EXPLICIT);
-        auto transpose2 = std::make_shared<ov::opset1::Transpose>(conv3x3, transpose_constant2);
-
-        auto model = std::make_shared<ov::Model>(ov::OutputVector{transpose2}, ov::ParameterVector{input1});
-        return model;
-    };
-
-    ov::pass::Manager manager;
-    manager.set_per_pass_validation(false);
-    manager.register_pass<ConvertWeightCompressedConv1x1ToMatmul>();
-
-    auto func = CreateConv();
-
-    manager.run_passes(func);
-
-    bool success = false;
-    for (auto& ops : func->get_ops()) {
-        std::string type_name(ops->get_type_name());
-        if (type_name.find("MatMul") != std::string::npos) {
-            success = true;
-            break;
-        }
-    }
-    ASSERT_TRUE(success == false);
-}
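The deleted plugin test only asserts the negative (3x3) case in its standalone form; the positive direction can be checked the same way. A sketch, assuming `func` was built like `gen_model(...)` from the new test file, or any graph matching the documented pattern:

```cpp
ov::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ov::pass::ConvertWeightCompressedConv1x1ToMatmul>();
manager.run_passes(func);

// After the rewrite the 1x1 Convolution should be gone and a MatMul present.
bool has_matmul = false;
bool has_conv = false;
for (const auto& op : func->get_ops()) {
    const std::string type_name(op->get_type_name());
    has_matmul |= type_name.find("MatMul") != std::string::npos;
    has_conv |= type_name.find("Convolution") != std::string::npos;
}
ASSERT_TRUE(has_matmul);
ASSERT_FALSE(has_conv);
```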