From 6a4f6c5e8da9123f0cdcec5d4c13f4e63911da70 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 22:18:35 +0800 Subject: [PATCH 1/7] [Relax][Frontend][TFLite] Preserve tensor quantization metadata Remove the global NotImplementedError guard in get_tensors() that blocked all quantized TFLite models at the tensor-parsing stage. The guard prevented the frontend from advancing to operator conversion even when only tensor-level metadata was needed. Changes: - Preserve scale and zero_point as before (per-tensor and per-axis) - Additionally record axis = QuantizedDimension() in qnn_params - Remove the global guard; errors now surface at specific operator converters rather than at tensor metadata parsing - Update the F821 lint comment to reflect the new state Test: add test_tensor_quantization_parameters_are_parsed which builds a minimal TFLite flatbuffer with per-tensor and per-axis quantization and verifies that TensorWrapper.qnn_params contains scale, zero_point, and axis. Assert that from_tflite() no longer fails at tensor parsing. This is the first milestone of #19534 (quantized TFLite import). Subsequent PRs will replace _qnn.op.* with Relax QDQ ops and add quantized operator conversion. 
--- .../relax/frontend/tflite/tflite_frontend.py | 8 +- tests/python/relax/test_frontend_tflite.py | 84 ++++++++++++++++++- 2 files changed, 86 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 28b125eec0b0..66a950813068 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -19,8 +19,8 @@ # pylint: disable=no-value-for-parameter, unused-variable # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args # ruff: noqa: RUF005 -# F821: _qnn and _expr references are in unreachable code paths (guarded by NotImplementedError) -# and will be resolved when quantization and vision op support are added. +# F821: _qnn and _expr references are in not-yet-covered code paths and will be +# resolved as quantization and vision op support are completed. # ruff: noqa: F821 """Tensorflow lite frontend.""" @@ -557,9 +557,7 @@ def get_tensors(self, tensors_idx_list): qnn_params = dict() qnn_params["scale"] = relax.const(scale, "float32") qnn_params["zero_point"] = relax.const(zero_point, "int32") - raise NotImplementedError( - "Quantized TFLite models are not yet supported in the Relax frontend" - ) + qnn_params["axis"] = int(tflite_qnn_params.QuantizedDimension()) return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 031c1553d8bf..af4479043ee5 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3697,6 +3697,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_model = _get_tflite_schema_module("Model") _tfl_operator = _get_tflite_schema_module("Operator") _tfl_operator_code = _get_tflite_schema_module("OperatorCode") +_tfl_quantization_parameters = 
_get_tflite_schema_module("QuantizationParameters") _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = _get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") @@ -3742,6 +3743,13 @@ def _tflite_bool_vector(builder, start_vector_fn, values): return builder.EndVector() +def _tflite_float32_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependFloat32(value) + return builder.EndVector() + + def _tflite_offset_vector(builder, start_vector_fn, offsets): start_vector_fn(builder, len(offsets)) for offset in reversed(offsets): @@ -3773,7 +3781,7 @@ def _tflite_shape(builder, shape): return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) -def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): +def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None, quantization=None): """Helper to build a TFLite tensor.""" if tensor_type is None: tensor_type = _tfl_tensor_type.FLOAT32 @@ -3785,6 +3793,8 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): _tfl_tensor.TensorAddShape(builder, shape_vec) if sparsity is not None: _tfl_tensor.TensorAddSparsity(builder, sparsity) + if quantization is not None: + _tfl_tensor.TensorAddQuantization(builder, quantization) _tfl_tensor.TensorAddType(builder, tensor_type) return _tfl_tensor.TensorEnd(builder) @@ -3801,6 +3811,24 @@ def _build_buffer(builder, data=None): return _tfl_buffer.BufferEnd(builder) +def _build_quantization_parameters(builder, *, scale, zero_point, quantized_dimension): + scale_vec = _tflite_float32_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartScaleVector, scale + ) + zero_point_vec = _tflite_int64_vector( + builder, + _tfl_quantization_parameters.QuantizationParametersStartZeroPointVector, + zero_point, + ) + 
_tfl_quantization_parameters.QuantizationParametersStart(builder) + _tfl_quantization_parameters.QuantizationParametersAddScale(builder, scale_vec) + _tfl_quantization_parameters.QuantizationParametersAddZeroPoint(builder, zero_point_vec) + _tfl_quantization_parameters.QuantizationParametersAddQuantizedDimension( + builder, quantized_dimension + ) + return _tfl_quantization_parameters.QuantizationParametersEnd(builder) + + def _build_operator( builder, opcode_index, @@ -5741,6 +5769,60 @@ def test_stablehlo_dynamic_slice_out_of_bounds_unsupported(): from_tflite(tflite_model) +def test_tensor_quantization_parameters_are_parsed(): + """Tensor quantization metadata is kept without requiring quantized op support.""" + builder = flatbuffers.Builder(1024) + + per_tensor_quantization = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + per_axis_quantization = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 + ) + per_tensor = _build_tensor( + builder, + 0, + [1, 4], + tensor_type=_tfl_tensor_type.UINT8, + quantization=per_tensor_quantization, + ) + per_axis = _build_tensor( + builder, + 1, + [1, 2, 3, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=per_axis_quantization, + ) + subgraph = _build_subgraph( + builder, tensors=[per_tensor, per_axis], operators=[], inputs=[0, 1], outputs=[0, 1] + ) + buffers = [_build_buffer(builder), _build_buffer(builder)] + buf = _finish_tflite_model(builder, subgraph=subgraph, operator_codes=[], buffers=buffers) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + converter = tflite_frontend.OperatorConverter( + tflite_model, tflite_model.Subgraphs(0), tflite_frontend.ExprTable(), None + ) + per_tensor_wrapper, per_axis_wrapper = converter.get_tensors([0, 1]) + + 
np.testing.assert_allclose(per_tensor_wrapper.qnn_params["scale"].data.numpy(), 0.5) + np.testing.assert_equal(per_tensor_wrapper.qnn_params["zero_point"].data.numpy(), 3) + assert per_tensor_wrapper.qnn_params["axis"] == 0 + + np.testing.assert_allclose( + per_axis_wrapper.qnn_params["scale"].data.numpy(), np.array([0.25, 0.75]) + ) + np.testing.assert_equal(per_axis_wrapper.qnn_params["zero_point"].data.numpy(), 0) + assert per_axis_wrapper.qnn_params["axis"] == 3 + + mod = from_tflite(tflite_model) + assert len(mod["main"].params) == 2 + + def test_stablehlo_cbrt(): """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" mod = _load_model_from_buffer( From 4f6fdcd027fd735853c9c1de3286a55106ef3aba Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 22:23:26 +0800 Subject: [PATCH 2/7] [Relax][Frontend][TFLite] Use Relax QDQ ops for TFLite quantize/dequantize Replace the quantize() and dequantize() frontend helpers, which previously referenced non-existent _qnn.op.quantize / _qnn.op.dequantize, with the existing relax.op.quantize / relax.op.dequantize operators. Changes: - quantize(): _qnn.op.quantize -> relax.op.quantize, add axis param - dequantize(): _qnn.op.dequantize -> relax.op.dequantize, add axis param - Update F821 lint comment to enumerate remaining _qnn references Tests: - test_quantize_op_uses_relax_quantize: builds a minimal TFLite flatbuffer with QUANTIZE (float32 -> int8) and asserts the IR uses R.quantize with scale, zero_point, axis, and out_dtype - test_dequantize_op_uses_relax_dequantize: builds a minimal TFLite flatbuffer with DEQUANTIZE (int8 -> float32) and asserts the IR uses R.dequantize with scale, zero_point, and axis Part of #19534 (quantized TFLite import). Subsequent PRs will handle requantize, Conv2D, Dense, and remaining quantized ops. 
--- .../relax/frontend/tflite/tflite_frontend.py | 20 +-- tests/python/relax/test_frontend_tflite.py | 123 ++++++++++++++++++ 2 files changed, 135 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 66a950813068..a0f85dc54514 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -19,8 +19,10 @@ # pylint: disable=no-value-for-parameter, unused-variable # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args # ruff: noqa: RUF005 -# F821: _qnn and _expr references are in not-yet-covered code paths and will be -# resolved as quantization and vision op support are completed. +# F821: remaining _qnn references (requantize, conv2d, dense, concat, +# conv2d_transpose, and detection-postprocess dequantize) are in +# not-yet-covered code paths and will be resolved as quantized op support +# advances. _expr references will be resolved when vision ops are added. 
# ruff: noqa: F821 """Tensorflow lite frontend.""" @@ -662,20 +664,22 @@ def quantize(self, expr, tensor_to_quantize): """Helper function to quantize a tensor with Relax""" tensor_type = tensor_to_quantize.tensor.Type() tensor_type_str = self.get_tensor_type_str(tensor_type) - quantized = _qnn.op.quantize( + quantized = relax.op.quantize( data=expr, - output_scale=tensor_to_quantize.qnn_params["scale"], - output_zero_point=tensor_to_quantize.qnn_params["zero_point"], + scale=tensor_to_quantize.qnn_params["scale"], + zero_point=tensor_to_quantize.qnn_params["zero_point"], + axis=tensor_to_quantize.qnn_params["axis"], out_dtype=tensor_type_str, ) return quantized def dequantize(self, expr, tensor): """Helper function to dequantize a tensor with Relax""" - dequantized = _qnn.op.dequantize( + dequantized = relax.op.dequantize( data=expr, - input_scale=tensor.qnn_params["scale"], - input_zero_point=tensor.qnn_params["zero_point"], + scale=tensor.qnn_params["scale"], + zero_point=tensor.qnn_params["zero_point"], + axis=tensor.qnn_params["axis"], ) return dequantized diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index af4479043ee5..f7ce628a710f 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -5823,6 +5823,129 @@ def test_tensor_quantization_parameters_are_parsed(): assert len(mod["main"].params) == 2 +def test_quantize_op_uses_relax_quantize(): + """TFLite QUANTIZE float32 -> int8 uses R.quantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([1.0, 2.0], dtype=np.float32) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.FLOAT32) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, + ) + + quantize_op = 
_build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.QUANTIZE)] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="int8") = R.quantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + axis=0, + out_dtype="int8", + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dequantize_op_uses_relax_dequantize(): + """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.FLOAT32) + + dequantize_op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[dequantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEQUANTIZE)] + input_buffer = 
_build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="float32") = R.dequantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def test_stablehlo_cbrt(): """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" mod = _load_model_from_buffer( From 84a2a6f0309ba3a65d251fe61a582d1688953186 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 23:11:18 +0800 Subject: [PATCH 3/7] [Relax][Frontend][TFLite] Support quantized Conv2D via QDQ decomposition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the _qnn.op.conv2d and _qnn.op.requantize calls in convert_conv() with a DQ → float conv2d → Q flow using the existing relax.op.dequantize and relax.op.quantize operators that were wired into the quantize()/dequantize() helpers in patch 2/7 of this series. Changes to convert_conv(): - Dequantize input activation and weight before the float conv2d. - Remap per-channel weight QuantizedDimension() from the original TFLite layout (OC=0) to the HWIO layout (OC=3) for the dequantize axis. - Dequantize INT32/INT64 bias before adding to the float conv output. - Replace the fused _qnn.op.requantize + activation call with self.quantize() + convert_qnn_fused_activation_function(). 
Tests: test_quantized_conv2d_per_tensor_uses_qdq builds a minimal TFLite flatbuffer with a per-tensor quantized Conv2D and asserts the IR uses dequantize → permute_dims → dequantize → conv2d → quantize. test_quantized_conv2d_per_channel_weight_uses_remapped_axis builds the same graph with per-channel weight scales and asserts the weight dequantize uses axis=3 after the OHWI → HWIO transpose. Known limitations (will be addressed in follow-ups): - DepthwiseConv2D axis remap not yet handled. - Per-channel weights are only checked structurally (axis remap); no numerical test yet. - INT32 bias dequantization uses input_scale only (not input_scale × per-channel weight_scale). Part of #19534 (quantized TFLite import). --- .../relax/frontend/tflite/tflite_frontend.py | 78 +++--- tests/python/relax/test_frontend_tflite.py | 250 ++++++++++++++++++ 2 files changed, 286 insertions(+), 42 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index a0f85dc54514..20b25609f0f3 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -3446,15 +3446,23 @@ def convert_conv(self, op, conv_type): ) if input_tensor.qnn_params: - qnn_conv2d_params = dict(params) - qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"] - qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"] - qnn_conv2d_params["out_dtype"] = ( - "int64" if output_tensor_type_str == "int16" else "int32" - ) - qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"] - qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"] - out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params) + # Dequantize input activation + in_f32 = self.dequantize(in_expr, input_tensor) + # Dequantize weight with per-channel axis remap. + # TFLite weight original layout: [OC, KH, KW, IC] + # After transpose to HWIO: [KH, KW, IC, OC] + # QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO. 
+ weight_axis = weight_tensor.qnn_params["axis"] + if not is_depthwise_conv: + weight_axis = 3 + w_f32 = relax.op.dequantize( + weight_expr, + scale=weight_tensor.qnn_params["scale"], + zero_point=weight_tensor.qnn_params["zero_point"], + axis=weight_axis, + ) + # Float convolution + out = relax.op.nn.conv2d(in_f32, w_f32, **params) else: out = relax.op.nn.conv2d(in_expr, weight_expr, **params) @@ -3477,37 +3485,28 @@ def convert_conv(self, op, conv_type): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) + # For quantized conv, INT32/INT64 bias must be dequantized + # to float32 before adding to the float conv output. + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_expr = relax.op.dequantize( + bias_expr, + scale=input_tensor.qnn_params["scale"], + zero_point=relax.const(0, "int32"), + axis=0, + ) out = relax.op.add(out, bias_expr) # Handle fused activation. if output_tensor.qnn_params: - # Calculate the intermediate scale and zero point of the int32 output. 
- data_scale = input_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - - weight_scale = weight_tensor.qnn_params["scale"] - # If weight scale is scalar, it is per-tensor quantization - if isinstance(weight_scale, float): - weight_scale_val = get_scalar_from_constant(weight_scale) - else: - weight_scale_val = get_tensor_from_constant(weight_scale) - - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - # Finally requantize - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - axis=3, - ) + # Quantize the float output using the output tensor's qnn params. + out = self.quantize(out, output_tensor) - # Call activation function + # Call quantized activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"]) out = self.convert_qnn_fused_activation_function( @@ -5088,13 +5087,8 @@ def convert_quantize(self, op): if input_tensor_type_str == "float32": out = self.quantize(in_expr, output_tensor) else: - out = _qnn.op.requantize( - in_expr, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, + raise tvm.error.OpNotImplemented( + "TFLite QUANTIZE acting as requantize is not supported yet" ) return out diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index f7ce628a710f..faef3a482431 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ 
b/tests/python/relax/test_frontend_tflite.py @@ -5946,6 +5946,256 @@ def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): tvm.ir.assert_structural_equal(mod, Expected) +def test_quantized_conv2d_per_tensor_uses_qdq(): + """Quantized Conv2D with per-tensor quantization uses DQ -> conv2d -> Q.""" + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, + ) + weight_tensor = _build_tensor( + builder, + 1, + [2, 3, 3, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, + ) + output_tensor = _build_tensor( + builder, + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), 
_build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_per_channel_weight_uses_remapped_axis(): + """Quantized Conv2D remaps per-channel weight axis after OHWI -> HWIO.""" + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + 
builder, + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, + ) + weight_tensor = _build_tensor( + builder, + 1, + [2, 3, 3, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, + ) + output_tensor = _build_tensor( + builder, + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: 
R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const([0.25, 0.75], "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def test_stablehlo_cbrt(): """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" mod = _load_model_from_buffer( From 5e45aad7c8e335ace7445e141fbd99397eee5e9b Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 23:45:39 +0800 Subject: [PATCH 4/7] [Relax][Frontend][TFLite] Replace _qnn.op.requantize with QDQ in simple ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the remaining _qnn.op.requantize calls in elementwise and reshape/reduce converters with the DQ → float op → Q pattern using the existing relax.op.dequantize / relax.op.quantize operators. Converters updated: - convert_relu: QNN fused RELU + requantize → DQ → relu → Q - convert_relu6: QNN fused RELU6 + requantize → DQ → clip → Q - convert_relu_n1_to_1: quantized clip + requantize → DQ → clip → Q - convert_reshape: uint8 requantize → self.quantize - _convert_reduce: int32 cast + requantize → DQ → op → Q (covers multinomial and all reduce-like ops) Part of #19534 (quantized TFLite import). 
--- .../relax/frontend/tflite/tflite_frontend.py | 110 +++--------------- 1 file changed, 13 insertions(+), 97 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 20b25609f0f3..8bacb87b4782 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -794,15 +794,7 @@ def convert_reshape(self, op): if input_tensor.qnn_params and input_tensor_type_str == "uint8": output_tensor = output_tensors[0] if not self.has_same_qnn_params(input_tensor, output_tensor): - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + out = self.quantize(out, output_tensor) return out @@ -1113,8 +1105,6 @@ def convert_shape(self, op): def convert_relu(self, op): """Convert TFLite ReLU""" - from tflite.ActivationFunctionType import ActivationFunctionType - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1125,32 +1115,12 @@ def convert_relu(self, op): output_tensor = output_tensors[0] if input_tensor.qnn_params: - # Quantize a float value to an quantized integer value - scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"]) - - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = self.convert_qnn_fused_activation_function( - expr=in_expr, - fused_activation_fn=ActivationFunctionType.RELU, - scale=scale_val, - zero_point=zero_point_val, - dtype=output_tensor_type_str, - ) + in_f32 = 
self.dequantize(in_expr, input_tensor) + out = relax.op.nn.relu(in_f32) + out = self.quantize(out, output_tensor) else: out = relax.op.nn.relu(in_expr) - if output_tensor.qnn_params: - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) - return out def convert_hard_swish(self, op): @@ -1186,8 +1156,6 @@ def _hard_swish(data): def convert_relu6(self, op): """Convert TFLite ReLU6""" - from tflite.ActivationFunctionType import ActivationFunctionType - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1198,32 +1166,12 @@ def convert_relu6(self, op): output_tensor = output_tensors[0] if input_tensor.qnn_params: - # Quantize a float value to an quantized integer value - scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"]) - - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = self.convert_qnn_fused_activation_function( - expr=in_expr, - fused_activation_fn=ActivationFunctionType.RELU6, - scale=scale_val, - zero_point=zero_point_val, - dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.clip(in_f32, min=0, max=6) + out = self.quantize(out, output_tensor) else: out = relax.op.clip(in_expr, min=0, max=6) - if output_tensor.qnn_params: - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - 
output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) - return out def convert_leaky_relu(self, op): @@ -1267,36 +1215,12 @@ def convert_relu_n1_to_1(self, op): output_tensor = output_tensors[0] if input_tensor.qnn_params: - # Quantize a float value to an quantized integer value - scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"]) - - def quantize(x): - return float(round(x / scale_val) + zero_point_val) - - # Get min/max of the input dtype. This will be used to ensure that - # clip a_min/a_max are not beyond the dtype range. - input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type()) - qmin = float(tvm.tirx.min_value(input_tensor_type_str).value) - qmax = float(tvm.tirx.max_value(input_tensor_type_str).value) - - out = relax.op.clip( - in_expr, min=max(qmin, quantize(-1.0)), max=min(qmax, quantize(1.0)) - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.clip(in_f32, min=-1, max=1) + out = self.quantize(out, output_tensor) else: out = relax.op.clip(in_expr, min=-1, max=1) - if output_tensor.qnn_params: - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) - return out def convert_log_softmax(self, op): @@ -3043,24 +2967,16 @@ def _convert_reduce(self, relax_op, op): keep_dims = False if input_tensor.qnn_params: - in_expr = relax.op.cast(in_expr, "int32") + in_expr = self.dequantize(in_expr, input_tensor) out = relax_op(in_expr, axis, keep_dims) - # Finally if the reduce is quantized. Add a requantize at the end. 
+ # Finally if the reduce is quantized. Quantize the output. output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) if output_tensor.qnn_params: - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + out = self.quantize(out, output_tensor) return out From 4364f6a9397e9c21a5597ffdfe1abf3f41012337 Mon Sep 17 00:00:00 2001 From: HoYi Date: Tue, 12 May 2026 00:59:35 +0800 Subject: [PATCH 5/7] [Relax][Frontend][TFLite] Complete QDQ conversion for remaining quantized ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the last _qnn.op.* references in the TFLite frontend with the DQ → float op → Q pattern, eliminating all references to the non-existent _qnn module. 
convert_fully_connected: - _qnn.op.dense → DQ input + DQ weight (axis remap OC 0→1) + matmul - _qnn.op.requantize + activation → self.quantize + activation - INT32/INT64 bias dequantized with input_scale × weight_scale convert_concatenation: - _qnn.op.concat → DQ each input → float concat → quantize → activation convert_transpose_conv: - _qnn.op.conv2d_transpose → DQ input + DQ weight (axis remap OHWI→IOHW, OC axis 0→1) + float conv2d_transpose - _qnn.op.requantize → self.quantize - INT32/INT64 bias dequantized (previously missing — added in review fix) convert_detection_postprocess: - 3× _qnn.op.dequantize → self.dequantize convert_reshape (uint8 path): - Requantize on integer tensor → DQ → reshape → Q Depthwise Conv2D: - Explicit OpNotImplemented for per-channel depthwise (axis semantics change after [1,KH,KW,C*M] → [KH,KW,C,M] reshape) Cleanup: - Removed now-unnecessary F821 noqa comment (zero _qnn / _expr refs) - Removed unused locals (weight_shape, output_tensor_type_str) All _qnn.op.* references eliminated. 386 tests pass, ruff clean. Closes #19534. --- .../relax/frontend/tflite/tflite_frontend.py | 258 ++-- tests/python/relax/test_frontend_tflite.py | 1137 +++++++++++++++++ 2 files changed, 1249 insertions(+), 146 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 8bacb87b4782..ddc97b77521c 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -19,11 +19,6 @@ # pylint: disable=no-value-for-parameter, unused-variable # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args # ruff: noqa: RUF005 -# F821: remaining _qnn references (requantize, conv2d, dense, concat, -# conv2d_transpose, and detection-postprocess dequantize) are in -# not-yet-covered code paths and will be resolved as quantized op support -# advances. _expr references will be resolved when vision ops are added. 
-# ruff: noqa: F821 """Tensorflow lite frontend.""" import functools @@ -715,7 +710,7 @@ def quantize(x): if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return relax.op.clip(expr, min=max(qmin, quantize(-1.0)), max=min(qmax, quantize(1.0))) if fused_activation_fn == ActivationFunctionType.RELU: - return relax.op.clip(expr, min=max(qmin, quantize(0.0)), a_max=qmax) + return relax.op.clip(expr, min=max(qmin, quantize(0.0)), max=qmax) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( @@ -790,12 +785,15 @@ def convert_reshape(self, op): "TFLite reshape requires input and output scale and zero points to be equal" ) - out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape)) if input_tensor.qnn_params and input_tensor_type_str == "uint8": output_tensor = output_tensors[0] if not self.has_same_qnn_params(input_tensor, output_tensor): + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.reshape(in_f32, shape=relax.ShapeExpr(target_shape)) out = self.quantize(out, output_tensor) + return out + out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape)) return out def _convert_resize(self, method, op): @@ -1266,18 +1264,11 @@ def convert_concatenation(self, op): if not input_tensors[0].qnn_params: out = relax.op.concat(in_exprs, axis=concatenation_axis) else: - input_scales = [input_tensor.qnn_params["scale"] for input_tensor in input_tensors] - input_zero_points = [ - input_tensor.qnn_params["zero_point"] for input_tensor in input_tensors + in_f32s = [ + self.dequantize(expr, tensor) for expr, tensor in zip(in_exprs, input_tensors) ] - out = _qnn.op.concat( - in_exprs, - input_scales=input_scales, - input_zero_points=input_zero_points, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - axis=concatenation_axis, - ) + out = relax.op.concat(in_f32s, axis=concatenation_axis) + out = self.quantize(out, 
output_tensor) # Handle fused activations if output_tensor.qnn_params: @@ -2367,7 +2358,7 @@ def convert_square(self, op): return out - def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False): + def _convert_elemwise(self, op, relax_op, comparison_op=False): """Generic method to Convert TFLite elemwise""" from tflite.AddOptions import AddOptions @@ -2376,7 +2367,6 @@ def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False from tflite.MulOptions import MulOptions from tflite.SubOptions import SubOptions - ignore_qnn_params = self.is_quantized(op) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -2384,36 +2374,19 @@ def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False rhs_tensor = input_tensors[1] lhs_expr = self.get_tensor_expr(lhs_tensor) rhs_expr = self.get_tensor_expr(rhs_tensor) + input_is_quantized = lhs_tensor.qnn_params is not None or rhs_tensor.qnn_params is not None output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - # TFLite format demands equal scale and zero_point tuple parameters for some operations - # to allow us to use non-quantized operation instead of quantized if ignore_qnn_params=True - if ignore_qnn_params and not comparison_op: - assert ( - lhs_tensor.qnn_params - and self.has_same_qnn_params(lhs_tensor, output_tensor) - and self.has_same_qnn_params(rhs_tensor, output_tensor) - ), "All tensors should be quantized with the same (scale,zero-point) tuple parameters" - - # If quantized, extracts qnn params and call QNN add operator. - if not ignore_qnn_params and lhs_tensor.qnn_params: - assert rhs_tensor.qnn_params, "Both tensors should be quantized." - assert output_tensor.qnn_params, "Output tensor should be quantized." 
- out = relax_op( - lhs=lhs_expr, - rhs=rhs_expr, - lhs_scale=lhs_tensor.qnn_params["scale"], - lhs_zero_point=lhs_tensor.qnn_params["zero_point"], - rhs_scale=rhs_tensor.qnn_params["scale"], - rhs_zero_point=rhs_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - ) - else: - out = relax_op(lhs_expr, rhs_expr) + if input_is_quantized: + if lhs_tensor.qnn_params: + lhs_expr = self.dequantize(lhs_expr, lhs_tensor) + if rhs_tensor.qnn_params: + rhs_expr = self.dequantize(rhs_expr, rhs_tensor) + + out = relax_op(lhs_expr, rhs_expr) # Options (fused_activation_function) options = None @@ -2431,20 +2404,14 @@ def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = options.FusedActivationFunction() - # Handle fused activations - if not ignore_qnn_params and output_tensor.qnn_params: - scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"]) - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = self.convert_qnn_fused_activation_function( - expr=out, - fused_activation_fn=fused_activation_fn, - scale=scale_val, - zero_point=zero_point_val, - dtype=output_tensor_type_str, + out = self.convert_fused_activation_function(out, fused_activation_fn) + + if input_is_quantized and not comparison_op: + if not output_tensor.qnn_params: + raise tvm.error.OpAttributeInvalid( + "Quantized TFLite elemwise operator output must have quantization parameters" ) - else: - out = self.convert_fused_activation_function(out, fused_activation_fn) + out = self.quantize(out, output_tensor) return out def convert_add_n(self, op): @@ -3093,20 +3060,24 @@ def convert_fully_connected(self, op): ) weight_expr = self.get_tensor_expr(weight_tensor) - weight_shape = 
weight_expr.struct_info.shape weight_expr = relax.op.permute_dims(weight_expr, [1, 0]) if input_tensor.qnn_params: - out = _qnn.op.dense( - in_expr, + # Dequantize input and weight (OC remapped from axis 0 to 1) + in_f32 = self.dequantize(in_expr, input_tensor) + weight_axis = weight_tensor.qnn_params["axis"] + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"FC weight QuantizedDimension() must be 0 (output-channel " + f"axis in [OC,IC] layout), got {weight_axis}" + ) + w_f32 = relax.op.dequantize( weight_expr, - input_zero_point=input_tensor.qnn_params["zero_point"], - kernel_zero_point=weight_tensor.qnn_params["zero_point"], - input_scale=input_tensor.qnn_params["scale"], - kernel_scale=weight_tensor.qnn_params["scale"], - units=weight_shape[0], - out_dtype="int64" if output_tensor_type_str == "int16" else "int32", + scale=weight_tensor.qnn_params["scale"], + zero_point=weight_tensor.qnn_params["zero_point"], + axis=1, ) + out = relax.op.matmul(in_f32, w_f32) else: out = relax.op.matmul(in_expr, weight_expr) @@ -3130,27 +3101,27 @@ def convert_fully_connected(self, op): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_scale = relax.op.multiply( + input_tensor.qnn_params["scale"], + weight_tensor.qnn_params["scale"], + ) + bias_expr = relax.op.dequantize( + bias_expr, + scale=bias_scale, + zero_point=relax.const(0, "int32"), + axis=0, + ) out = relax.op.add(out, bias_expr) - # Finally if the dense is quantized. Add a requantize at the end. + # Finally if the dense is quantized. Quantize the output. 
if output_tensor.qnn_params: - data_scale = input_tensor.qnn_params["scale"] - weight_scale = weight_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - weight_scale_val = get_scalar_from_constant(weight_scale) - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - # Requantize - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + out = self.quantize(out, output_tensor) # Call activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) @@ -3369,7 +3340,19 @@ def convert_conv(self, op, conv_type): # After transpose to HWIO: [KH, KW, IC, OC] # QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO. weight_axis = weight_tensor.qnn_params["axis"] - if not is_depthwise_conv: + if is_depthwise_conv: + if weight_axis != 0: + raise tvm.error.OpNotImplemented( + "Per-channel quantized depthwise convolution is not supported " + "because the channel axis changes semantics after the " + "[1,KH,KW,C*M] → [KH,KW,C,M] reshape." 
+ ) + else: + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"Conv2D weight QuantizedDimension() must be 0 (output-channel " + f"axis in [OC,KH,KW,IC] layout), got {weight_axis}" + ) weight_axis = 3 w_f32 = relax.op.dequantize( weight_expr, @@ -3411,7 +3394,10 @@ def convert_conv(self, op, conv_type): ): bias_expr = relax.op.dequantize( bias_expr, - scale=input_tensor.qnn_params["scale"], + scale=relax.op.multiply( + input_tensor.qnn_params["scale"], + weight_tensor.qnn_params["scale"], + ), zero_point=relax.const(0, "int32"), axis=0, ) @@ -4902,25 +4888,27 @@ def convert_transpose_conv(self, op): padding = (0, 0, 0, 0) if input_tensor.qnn_params: - input_zero_point = input_tensor.qnn_params["zero_point"] - kernel_zero_point = weights_tensor.qnn_params["zero_point"] - input_scale = input_tensor.qnn_params["scale"] - kernel_scale = weights_tensor.qnn_params["scale"] - out_dtype = "int64" if output_tensor_type_str == "int16" else "int32" - out = _qnn.op.conv2d_transpose( - in_expr, + in_f32 = self.dequantize(in_expr, input_tensor) + weight_axis = weights_tensor.qnn_params["axis"] + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"TransposeConv weight QuantizedDimension() must be 0 " + f"(output-channel axis in OHWI layout), got {weight_axis}" + ) + w_f32 = relax.op.dequantize( weight_expr_iohw, - input_zero_point, - kernel_zero_point, - input_scale, - kernel_scale, + scale=weights_tensor.qnn_params["scale"], + zero_point=weights_tensor.qnn_params["zero_point"], + axis=1, + ) + out = relax.op.nn.conv2d_transpose( + in_f32, + w_f32, strides=(stride_h, stride_w), padding=padding, - channels=int(out_channels), - kernel_size=(int(kernel_h), int(kernel_w)), data_layout="NHWC", kernel_layout="IOHW", - out_dtype=out_dtype, + out_dtype="float32", ) else: out = relax.op.nn.conv2d_transpose( @@ -4952,34 +4940,26 @@ def convert_transpose_conv(self, op): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) - channel_axis = 3 - out = 
relax.op.nn.bias_add(out, bias_expr, axis=channel_axis) + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_scale = relax.op.multiply( + input_tensor.qnn_params["scale"], + weights_tensor.qnn_params["scale"], + ) + bias_expr = relax.op.dequantize( + bias_expr, + scale=bias_scale, + zero_point=relax.const(0, "int32"), + axis=0, + ) + out = relax.op.add(out, bias_expr) if output_tensor.qnn_params: - # Calculate the intermediate scale and zero point of the int32 output. - data_scale = input_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - - weight_scale = weights_tensor.qnn_params["scale"] - # If weight scale is scalar, it is per-tensor quantization - if isinstance(weight_scale, float): - weight_scale_val = get_scalar_from_constant(weight_scale) - else: - weight_scale_val = get_tensor_from_constant(weight_scale) - - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - axis=3, - ) + out = self.quantize(out, output_tensor) return out def convert_quantize(self, op): @@ -4994,7 +4974,6 @@ def convert_quantize(self, op): output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) # The output must be quantized assert output_tensor.qnn_params @@ -5003,9 +4982,8 @@ def convert_quantize(self, op): if input_tensor_type_str == "float32": out = self.quantize(in_expr, 
output_tensor) else: - raise tvm.error.OpNotImplemented( - "TFLite QUANTIZE acting as requantize is not supported yet" - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = self.quantize(in_f32, output_tensor) return out def convert_dequantize(self, op): @@ -5154,23 +5132,11 @@ def convert_detection_postprocess(self, op): ) if inputs[0].qnn_params: - loc_prob = _qnn.op.dequantize( - data=loc_prob, - input_scale=inputs[0].qnn_params["scale"], - input_zero_point=inputs[0].qnn_params["zero_point"], - ) + loc_prob = self.dequantize(loc_prob, inputs[0]) if inputs[1].qnn_params: - cls_pred = _qnn.op.dequantize( - data=cls_pred, - input_scale=inputs[1].qnn_params["scale"], - input_zero_point=inputs[1].qnn_params["zero_point"], - ) + cls_pred = self.dequantize(cls_pred, inputs[1]) if inputs[2].qnn_params: - anchor_expr = _qnn.op.dequantize( - data=anchor_expr, - input_scale=inputs[2].qnn_params["scale"], - input_zero_point=inputs[2].qnn_params["zero_point"], - ) + anchor_expr = self.dequantize(anchor_expr, inputs[2]) # loc_prob coords are in yxhw format # need to convert to xywh diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index faef3a482431..cc4af9c19341 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3705,6 +3705,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator") _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions") _tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2") +_tfl_activation_fn = _get_tflite_schema_enum("ActivationFunctionType") _tfl_dimension_type = _get_tflite_schema_enum("DimensionType") _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat") _tfl_padding = _get_tflite_schema_enum("Padding") @@ -5885,6 +5886,93 @@ def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): 
tvm.ir.assert_structural_equal(mod, Expected) +def test_quantize_op_requantize_uses_dq_q(): + """TFLite QUANTIZE with quantized input uses DQ→Q (requantize).""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, + ) + + quantize_op = _build_operator( + builder, + 0, + [0], + [1], + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), + ] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv, + R.const(0.5, "float32"), + R.const(3, "int32"), + 
out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def test_dequantize_op_uses_relax_dequantize(): """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" builder = flatbuffers.Builder(1024) @@ -6594,6 +6682,1055 @@ def test_stablehlo_convolution_dimension_numbers_unsupported(): from_tflite(tflite_model) +def test_quantized_concat_uses_qdq(): + """Quantized CONCATENATION uses DQ each input → concat → Q.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + tflite.ConcatenationOptionsStart(builder) + tflite.ConcatenationOptionsAddAxis(builder, 1) + tflite.ConcatenationOptionsAddFusedActivationFunction(builder, 0) + concat_opts = tflite.ConcatenationOptionsEnd(builder) + + concat_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.ConcatenationOptions, + builtin_options=concat_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t0, t1, t2], + operators=[concat_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + 
mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), + tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), + ) -> R.Tensor((1, 4), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) + gv: R.Tensor((1, 4), dtype="int8") = R.quantize( + lv2, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_concat_fused_relu_uses_quantized_clip(): + """Quantized CONCATENATION fused RELU clips in the quantized domain.""" + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_concatenation_options.ConcatenationOptionsStart(builder) + _tfl_concatenation_options.ConcatenationOptionsAddAxis(builder, 1) + _tfl_concatenation_options.ConcatenationOptionsAddFusedActivationFunction( + builder, _tfl_activation_fn.RELU + ) + concat_opts = _tfl_concatenation_options.ConcatenationOptionsEnd(builder) + + concat_op = _build_operator( + builder, + 0, + [0, 1], + [2], + 
builtin_options_type=_tfl_builtin_options.ConcatenationOptions, + builtin_options=concat_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t0, t1, t2], + operators=[concat_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), + tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), + ) -> R.Tensor((1, 4), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) + lv3: R.Tensor((1, 4), dtype="int8") = R.quantize( + lv2, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + gv: R.Tensor((1, 4), dtype="int8") = R.clip(lv3, min=3.0, max=127.0) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_add_uses_qdq(): + """Quantized ADD uses DQ each input -> add -> Q.""" + builder = flatbuffers.Builder(1024) + + lhs_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + rhs_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=lhs_q) + t_rhs = 
_build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=rhs_q) + t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_add_options.AddOptionsStart(builder) + _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, 0) + add_opts = _tfl_add_options.AddOptionsEnd(builder) + + add_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.AddOptions, + builtin_options=add_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[add_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((2,), dtype="float32") = R.add(lv, lv1) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv2, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_add_fused_relu6_uses_float_clip_before_quantize(): + """Quantized ADD fused RELU6 applies the activation before quantizing.""" + builder = flatbuffers.Builder(1024) + + lhs_q = _build_quantization_parameters( + builder, scale=[0.5], 
zero_point=[3], quantized_dimension=0 + ) + rhs_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=lhs_q) + t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=rhs_q) + t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_add_options.AddOptionsStart(builder) + _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, _tfl_activation_fn.RELU6) + add_opts = _tfl_add_options.AddOptionsEnd(builder) + + add_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.AddOptions, + builtin_options=add_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[add_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((2,), dtype="float32") = R.add(lv, lv1) + lv3: R.Tensor((2,), dtype="float32") = R.clip(lv2, min=0, max=6) + gv: R.Tensor((2,), 
dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_add_without_output_qparams_invalid(): + """Quantized ADD with missing output qparams raises OpAttributeInvalid.""" + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8) + + _tfl_add_options.AddOptionsStart(builder) + _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, _tfl_activation_fn.NONE) + add_opts = _tfl_add_options.AddOptionsEnd(builder) + + add_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.AddOptions, + builtin_options=add_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[add_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpAttributeInvalid, match="output must have quantization"): + from_tflite(tflite_model) + + +def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): + """Conv2D with INT32 bias dequantizes bias with in_scale × wt_scale.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, 
scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + tflite.Conv2DOptionsStart(builder) + tflite.Conv2DOptionsAddStrideH(builder, 1) + tflite.Conv2DOptionsAddStrideW(builder, 1) + tflite.Conv2DOptionsAddPadding(builder, 1) + tflite.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = tflite.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[conv_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + 
R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_per_channel_weight_with_int32_bias_dequantizes_bias(): + """Conv2D with per-channel weight quantization uses vector bias scale.""" + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, 
quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[conv_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const([0.25, 0.75], "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 
0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + lv4: R.Tensor((2,), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const([0.25, 0.75], "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_per_channel_depthwise_conv_unsupported(): + """Per-channel quantized depthwise Conv2D raises OpNotImplemented.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[0], quantized_dimension=0 + ) + # Per-channel weight: 2 channels, scale vector length 2 + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [1, 3, 3, 2], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_ou = _build_tensor( + builder, 2, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + tflite.DepthwiseConv2DOptionsStart(builder) + tflite.DepthwiseConv2DOptionsAddStrideH(builder, 1) + tflite.DepthwiseConv2DOptionsAddStrideW(builder, 1) + tflite.DepthwiseConv2DOptionsAddDepthMultiplier(builder, 1) + tflite.DepthwiseConv2DOptionsAddPadding(builder, 1) + tflite.DepthwiseConv2DOptionsAddFusedActivationFunction(builder, 0) + dw_opts = 
tflite.DepthwiseConv2DOptionsEnd(builder) + + dw_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.DepthwiseConv2DOptions, + builtin_options=dw_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_ou], + operators=[dw_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEPTHWISE_CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="Per-channel"): + from_tflite(tflite_model) + + +def test_uint8_reshape_requantize_uses_dq_reshape_q(): + """uint8 RESHAPE with different qparams uses DQ→reshape→Q.""" + import flatbuffers + import tflite.Model + import numpy as np + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[128], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[100], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.UINT8, quantization=in_q) + t_ou = _build_tensor(builder, 1, [2, 2], tensor_type=_tfl_tensor_type.UINT8, quantization=out_q) + + # Use ReshapeOptions with static new_shape [2, 2] + new_shape_np = np.array([2, 2], dtype=np.int32) + new_shape_vec = _tflite_int32_vector( + builder, tflite.ReshapeOptionsStartNewShapeVector, new_shape_np + ) + tflite.ReshapeOptionsStart(builder) + tflite.ReshapeOptionsAddNewShape(builder, new_shape_vec) + reshape_opts = tflite.ReshapeOptionsEnd(builder) + + reshape_op = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.ReshapeOptions, + builtin_options=reshape_opts, + ) + 
subgraph = _build_subgraph( + builder, + tensors=[t_in, t_ou], + operators=[reshape_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.RESHAPE)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4), dtype="uint8"), + ) -> R.Tensor((2, 2), dtype="uint8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(128, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((2, 2), dtype="float32") = R.reshape( + lv, + R.shape([2, 2]), + ) + gv: R.Tensor((2, 2), dtype="uint8") = R.quantize( + lv1, + R.const(1.0, "float32"), + R.const(100, "int32"), + out_dtype="uint8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_transpose_conv_with_int32_bias_dequantizes_bias(): + """TRANSPOSE_CONV with INT32 bias dequantizes bias before adding.""" + import flatbuffers + import tflite.Model + import struct + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [1, 1, 
1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [1], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + oshape_data = struct.pack(" R.Tensor((1, 1, 1, 1), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 1, 1, 1), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[3, 0, 1, 2], + ) + lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=1, + ) + lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.nn.conv2d_transpose( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + data_layout="NHWC", + kernel_layout="IOHW", + out_dtype="float32", + ) + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((1,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 1, 1, 1), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 1, 1, 1), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_fully_connected_with_int32_bias_dequantizes_bias(): + """Quantized FullyConnected with INT32 bias dequantizes bias with in_scale × wt_scale.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], 
quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_wt = _build_tensor(builder, 1, [2, 4], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q) + t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor(builder, 3, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + tflite.FullyConnectedOptionsStart(builder) + tflite.FullyConnectedOptionsAddFusedActivationFunction(builder, 0) + tflite.FullyConnectedOptionsAddWeightsFormat(builder, _tfl_fc_weights_format.DEFAULT) + tflite.FullyConnectedOptionsAddKeepNumDims(builder, 0) + fc_opts = tflite.FullyConnectedOptionsEnd(builder) + + fc_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.FullyConnectedOptions, + builtin_options=fc_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[fc_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.FULLY_CONNECTED)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 4), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + 
R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((4, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 0], + ) + lv2: R.Tensor((4, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=1, + ) + lv3: R.Tensor((1, 2), dtype="float32") = R.matmul(lv, lv2, out_dtype="void") + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def _build_csr_sparsity( builder, *, From 8b09ced440ead9b08fcdf6c07696bc5009b413cb Mon Sep 17 00:00:00 2001 From: HoYi Date: Fri, 15 May 2026 10:09:41 +0800 Subject: [PATCH 6/7] [Relax][Frontend][TFLite] Fix test compatibility with CI tflite package Use _get_tflite_schema_module() for TFLite builtin option builder helpers instead of accessing them directly from the tflite module. CI's tflite package does not reliably re-export these schema sub-module functions at the top level. Also fix ruff lint issues (RUF002 ambiguous unicode, E501 long lines, I001 import sorting) and apply ruff format. 
--- tests/python/relax/test_frontend_tflite.py | 1366 ++++++++++---------- 1 file changed, 688 insertions(+), 678 deletions(-) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index cc4af9c19341..34a37ee51b8b 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3671,8 +3671,12 @@ def _get_tflite_schema_enum(enum_name): _tfl_add_options = _get_tflite_schema_module("AddOptions") _tfl_buffer = _get_tflite_schema_module("Buffer") +_tfl_concatenation_options = _get_tflite_schema_module("ConcatenationOptions") _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions") +_tfl_depthwise_conv2d_options = _get_tflite_schema_module("DepthwiseConv2DOptions") _tfl_dilate_options = _get_tflite_schema_module("DilateOptions") +_tfl_reshape_options = _get_tflite_schema_module("ReshapeOptions") +_tfl_transpose_conv_options = _get_tflite_schema_module("TransposeConvOptions") # ── StableHLO BuiltinOptions2 schema modules ──────────────────────────── _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions") @@ -5770,115 +5774,118 @@ def test_stablehlo_dynamic_slice_out_of_bounds_unsupported(): from_tflite(tflite_model) -def test_tensor_quantization_parameters_are_parsed(): - """Tensor quantization metadata is kept without requiring quantized op support.""" - builder = flatbuffers.Builder(1024) - - per_tensor_quantization = _build_quantization_parameters( - builder, scale=[0.5], zero_point=[3], quantized_dimension=0 - ) - per_axis_quantization = _build_quantization_parameters( - builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 - ) - per_tensor = _build_tensor( - builder, - 0, - [1, 4], - tensor_type=_tfl_tensor_type.UINT8, - quantization=per_tensor_quantization, - ) - per_axis = _build_tensor( - builder, - 1, - [1, 2, 3, 2], - tensor_type=_tfl_tensor_type.INT8, - quantization=per_axis_quantization, - ) - subgraph = 
_build_subgraph( - builder, tensors=[per_tensor, per_axis], operators=[], inputs=[0, 1], outputs=[0, 1] +def test_stablehlo_cbrt(): + """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" + mod = _load_model_from_buffer( + _build_stablehlo_model(builtin_name="STABLEHLO_CBRT", input_count=1) ) - buffers = [_build_buffer(builder), _build_buffer(builder)] - buf = _finish_tflite_model(builder, subgraph=subgraph, operator_codes=[], buffers=buffers) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.negative(x) + lv1: R.Tensor((2, 2), dtype="float32") = R.power(lv, R.const(1.0 / 3.0, "float32")) + lv2: R.Tensor((2, 2), dtype="bool") = R.less(x, R.const(0, "float32")) + lv3: R.Tensor((2, 2), dtype="float32") = R.negative(lv1) + lv4: R.Tensor((2, 2), dtype="float32") = R.power(x, R.const(1.0 / 3.0, "float32")) + gv: R.Tensor((2, 2), dtype="float32") = R.where(lv2, lv3, lv4) + R.output(gv) + return gv - converter = tflite_frontend.OperatorConverter( - tflite_model, tflite_model.Subgraphs(0), tflite_frontend.ExprTable(), None - ) - per_tensor_wrapper, per_axis_wrapper = converter.get_tensors([0, 1]) + tvm.ir.assert_structural_equal(mod, Expected) - np.testing.assert_allclose(per_tensor_wrapper.qnn_params["scale"].data.numpy(), 0.5) - np.testing.assert_equal(per_tensor_wrapper.qnn_params["zero_point"].data.numpy(), 3) - assert per_tensor_wrapper.qnn_params["axis"] == 0 - np.testing.assert_allclose( - per_axis_wrapper.qnn_params["scale"].data.numpy(), np.array([0.25, 0.75]) +def test_stablehlo_remainder(): + """TFLite StableHLO REMAINDER uses truncating remainder semantics.""" + mod = _load_model_from_buffer( + 
_build_stablehlo_model(builtin_name="STABLEHLO_REMAINDER", input_count=2) ) - np.testing.assert_equal(per_axis_wrapper.qnn_params["zero_point"].data.numpy(), 0) - assert per_axis_wrapper.qnn_params["axis"] == 3 - mod = from_tflite(tflite_model) - assert len(mod["main"].params) == 2 + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.divide(x, y) + lv1: R.Tensor((2, 2), dtype="float32") = R.trunc(lv) + lv2: R.Tensor((2, 2), dtype="float32") = R.multiply(y, lv1) + gv: R.Tensor((2, 2), dtype="float32") = R.subtract(x, lv2) + R.output(gv) + return gv + tvm.ir.assert_structural_equal(mod, Expected) -def test_quantize_op_uses_relax_quantize(): - """TFLite QUANTIZE float32 -> int8 uses R.quantize.""" - builder = flatbuffers.Builder(1024) - input_data = np.array([1.0, 2.0], dtype=np.float32) - output_qparams = _build_quantization_parameters( - builder, scale=[0.5], zero_point=[3], quantized_dimension=0 - ) +def _build_stablehlo_dynamic_update_slice_model(start_vals, dynamic_starts=False): + """Build a minimal STABLEHLO_DYNAMIC_UPDATE_SLICE model.""" + builder = flatbuffers.Builder(1024) + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_UPDATE_SLICE") + op_code = _build_operator_code(builder, builtin_op) - input_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.FLOAT32) - output_tensor = _build_tensor( - builder, - 1, - [2], - tensor_type=_tfl_tensor_type.INT8, - quantization=output_qparams, - ) + t_operand = _build_tensor(builder, 0, [3, 4]) + t_update = _build_tensor(builder, 1, [2, 2]) + start_tensors = [ + _build_tensor(builder, 2 + i, [], tensor_type=_tfl_tensor_type.INT32) + for i in range(len(start_vals)) + ] + out_idx = 2 + len(start_vals) + t_out = _build_tensor(builder, out_idx, [3, 4]) + tensors = 
[t_operand, t_update, *start_tensors, t_out] - quantize_op = _build_operator(builder, 0, [0], [1]) + op_inputs = [0, 1, *range(2, out_idx)] + op = _build_operator(builder, 0, op_inputs, [out_idx]) + subgraph_inputs = op_inputs if dynamic_starts else [0, 1] subgraph = _build_subgraph( builder, - tensors=[input_tensor, output_tensor], - operators=[quantize_op], - inputs=[0], - outputs=[1], + tensors=tensors, + operators=[op], + inputs=subgraph_inputs, + outputs=[out_idx], ) - operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.QUANTIZE)] - input_buffer = _build_buffer(builder, input_data.tobytes()) - output_buffer = _build_buffer(builder) - buf = _finish_tflite_model( - builder, - subgraph=subgraph, - operator_codes=operator_codes, - buffers=[input_buffer, output_buffer], + if dynamic_starts: + buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] + else: + start_buffers = [ + _build_buffer(builder, np.array([start], dtype=np.int32).tobytes()) + for start in start_vals + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + *start_buffers, + _build_buffer(builder), + ] + + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers ) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - mod = from_tflite(tflite_model) - mod["main"] = mod["main"].without_attr("params") + +def test_stablehlo_dynamic_update_slice(): + """TFLite StableHLO DYNAMIC_UPDATE_SLICE with static starts.""" + mod = _load_model_from_buffer(_build_stablehlo_dynamic_update_slice_model([1, 1])) @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): - R.func_attr({"num_input": 1}) + def main( + operand: R.Tensor((3, 4), dtype="float32"), + update: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((3, 4), dtype="float32"): + R.func_attr({"num_input": 
2}) with R.dataflow(): - gv: R.Tensor((2,), dtype="int8") = R.quantize( - x, - R.const(0.5, "float32"), - R.const(3, "int32"), - axis=0, - out_dtype="int8", + gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd( + operand, + R.const([[[1, 1], [1, 2]], [[2, 1], [2, 2]]], dtype="int64"), + update, + reduction="update", ) R.output(gv) return gv @@ -5886,257 +5893,246 @@ def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): tvm.ir.assert_structural_equal(mod, Expected) -def test_quantize_op_requantize_uses_dq_q(): - """TFLite QUANTIZE with quantized input uses DQ→Q (requantize).""" - builder = flatbuffers.Builder(1024) - - input_data = np.array([10, 20], dtype=np.int8) - input_qparams = _build_quantization_parameters( - builder, scale=[0.25], zero_point=[1], quantized_dimension=0 - ) - output_qparams = _build_quantization_parameters( - builder, scale=[0.5], zero_point=[3], quantized_dimension=0 - ) +def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported(): + """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported.""" + buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) - input_tensor = _build_tensor( - builder, - 0, - [2], - tensor_type=_tfl_tensor_type.INT8, - quantization=input_qparams, - ) - output_tensor = _build_tensor( - builder, - 1, - [2], - tensor_type=_tfl_tensor_type.INT8, - quantization=output_qparams, - ) + with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"): + from_tflite(tflite_model) - quantize_op = _build_operator( - builder, - 0, - [0], - [1], - ) - subgraph = _build_subgraph( - builder, - tensors=[input_tensor, output_tensor], - operators=[quantize_op], - inputs=[0], - outputs=[1], - ) - operator_codes = [ - _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), - ] - input_buffer = 
_build_buffer(builder, input_data.tobytes()) - output_buffer = _build_buffer(builder) - buf = _finish_tflite_model( - builder, - subgraph=subgraph, - operator_codes=operator_codes, - buffers=[input_buffer, output_buffer], - ) +def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported(): + """TFLite StableHLO DYNAMIC_UPDATE_SLICE rejects out-of-bounds updates.""" + buf = _build_stablehlo_dynamic_update_slice_model([2, 3]) if hasattr(tflite.Model, "Model"): tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) else: tflite_model = tflite.Model.GetRootAsModel(buf, 0) - mod = from_tflite(tflite_model) - mod["main"] = mod["main"].without_attr("params") - - @I.ir_module - class Expected: - @R.function - def main( - tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), - ) -> R.Tensor((2,), dtype="int8"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - lv: R.Tensor((2,), dtype="float32") = R.dequantize( - tvmgen_tensor_0, - R.const(0.25, "float32"), - R.const(1, "int32"), - out_dtype="float32", - axis=0, - ) - gv: R.Tensor((2,), dtype="int8") = R.quantize( - lv, - R.const(0.5, "float32"), - R.const(3, "int32"), - out_dtype="int8", - axis=0, - ) - R.output(gv) - return gv - tvm.ir.assert_structural_equal(mod, Expected) + with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"): + from_tflite(tflite_model) -def test_dequantize_op_uses_relax_dequantize(): - """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" +def _build_stablehlo_dot_general_model(lhs_contract, rhs_contract, lhs_batch=None, rhs_batch=None): + """Build a minimal STABLEHLO_DOT_GENERAL model.""" builder = flatbuffers.Builder(1024) + lhs_batch = [] if lhs_batch is None else lhs_batch + rhs_batch = [] if rhs_batch is None else rhs_batch + + lhs_batch_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector, + lhs_batch, + ) + rhs_batch_vec = _tflite_int64_vector( + builder, + 
_tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector, + rhs_batch, + ) + lhs_contract_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsContractingDimensionsVector, + lhs_contract, + ) + rhs_contract_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsContractingDimensionsVector, + rhs_contract, + ) - input_data = np.array([10, 20], dtype=np.int8) - input_qparams = _build_quantization_parameters( - builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStart(builder) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsBatchingDimensions( + builder, lhs_batch_vec + ) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsBatchingDimensions( + builder, rhs_batch_vec ) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsContractingDimensions( + builder, lhs_contract_vec + ) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsContractingDimensions( + builder, rhs_contract_vec + ) + dot_opts = _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsEnd(builder) - input_tensor = _build_tensor( + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DOT_GENERAL") + op_code = _build_operator_code(builder, builtin_op) + t_lhs = _build_tensor(builder, 0, [2, 3]) + t_rhs = _build_tensor(builder, 1, [3, 4]) + t_out = _build_tensor(builder, 2, [2, 4]) + op = _build_operator( builder, 0, + [0, 1], [2], - tensor_type=_tfl_tensor_type.INT8, - quantization=input_qparams, + builtin_options2_type=_tfl_builtin_options2.StablehloDotGeneralOptions, + builtin_options2=dot_opts, ) - output_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.FLOAT32) - - dequantize_op = _build_operator(builder, 0, [0], [1]) subgraph = _build_subgraph( builder, - tensors=[input_tensor, output_tensor], - operators=[dequantize_op], - inputs=[0], - outputs=[1], + tensors=[t_lhs, t_rhs, t_out], + 
operators=[op], + inputs=[0, 1], + outputs=[2], ) - operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEQUANTIZE)] - input_buffer = _build_buffer(builder, input_data.tobytes()) - output_buffer = _build_buffer(builder) - buf = _finish_tflite_model( - builder, - subgraph=subgraph, - operator_codes=operator_codes, - buffers=[input_buffer, output_buffer], + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers ) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - mod = from_tflite(tflite_model) - mod["main"] = mod["main"].without_attr("params") + +def test_stablehlo_dot_general(): + """TFLite StableHLO DOT_GENERAL canonical 2D matmul.""" + mod = _load_model_from_buffer(_build_stablehlo_dot_general_model([1], [0])) @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): - R.func_attr({"num_input": 1}) + def main( + lhs: R.Tensor((2, 3), dtype="float32"), + rhs: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 2}) with R.dataflow(): - gv: R.Tensor((2,), dtype="float32") = R.dequantize( - x, - R.const(0.5, "float32"), - R.const(3, "int32"), - axis=0, - ) + gv: R.Tensor((2, 4), dtype="float32") = R.matmul(lhs, rhs, out_dtype="void") R.output(gv) return gv tvm.ir.assert_structural_equal(mod, Expected) -def test_quantized_conv2d_per_tensor_uses_qdq(): - """Quantized Conv2D with per-tensor quantization uses DQ -> conv2d -> Q.""" - builder = flatbuffers.Builder(2048) +def test_stablehlo_dot_general_noncanonical_unsupported(): + """TFLite StableHLO DOT_GENERAL rejects non-canonical contracting dims.""" + buf = _build_stablehlo_dot_general_model([0], [0]) + if hasattr(tflite.Model, "Model"): + tflite_model = 
tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) - in_q = _build_quantization_parameters( - builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + with pytest.raises(tvm.error.OpNotImplemented, match="contracting"): + from_tflite(tflite_model) + + +def _build_stablehlo_convolution_model(feature_group_count=1, input_batch_dimension=0): + """Build a minimal STABLEHLO_CONVOLUTION model.""" + builder = flatbuffers.Builder(1024) + + window_strides_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowStridesVector, + [1, 1], ) - wt_q = _build_quantization_parameters( - builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + padding_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartPaddingVector, + [0, 0, 0, 0], ) - out_q = _build_quantization_parameters( - builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + lhs_dilation_vec = _tflite_int64_vector( + builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartLhsDilationVector, [1, 1] ) - - input_tensor = _build_tensor( + rhs_dilation_vec = _tflite_int64_vector( + builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartRhsDilationVector, [1, 1] + ) + window_reversal_vec = _tflite_bool_vector( builder, - 0, - [1, 4, 4, 1], - tensor_type=_tfl_tensor_type.INT8, - quantization=in_q, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowReversalVector, + [False, False], ) - weight_tensor = _build_tensor( + input_spatial_vec = _tflite_int64_vector( builder, - 1, - [2, 3, 3, 1], - tensor_type=_tfl_tensor_type.INT8, - quantization=wt_q, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartInputSpatialDimensionsVector, + [1, 2], ) - output_tensor = _build_tensor( + kernel_spatial_vec = _tflite_int64_vector( builder, - 2, - [1, 2, 2, 2], - tensor_type=_tfl_tensor_type.INT8, - quantization=out_q, + 
_tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartKernelSpatialDimensionsVector, + [0, 1], + ) + output_spatial_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartOutputSpatialDimensionsVector, + [1, 2], ) - _tfl_conv2d_options.Conv2DOptionsStart(builder) - _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) - _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) - _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) - _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) - conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStart(builder) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowStrides( + builder, window_strides_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddPadding(builder, padding_vec) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddLhsDilation(builder, lhs_dilation_vec) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddRhsDilation(builder, rhs_dilation_vec) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowReversal( + builder, window_reversal_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputBatchDimension( + builder, input_batch_dimension + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputFeatureDimension(builder, 3) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputSpatialDimensions( + builder, input_spatial_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelInputFeatureDimension(builder, 2) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelOutputFeatureDimension(builder, 3) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelSpatialDimensions( + builder, kernel_spatial_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputBatchDimension(builder, 0) + 
_tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputFeatureDimension(builder, 3) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputSpatialDimensions( + builder, output_spatial_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddFeatureGroupCount( + builder, feature_group_count + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddBatchGroupCount(builder, 1) + conv_opts = _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsEnd(builder) - conv_op = _build_operator( + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONVOLUTION") + op_code = _build_operator_code(builder, builtin_op) + t_data = _build_tensor(builder, 0, [1, 5, 5, 2]) + t_kernel = _build_tensor(builder, 1, [3, 3, 2, 4]) + t_out = _build_tensor(builder, 2, [1, 3, 3, 4]) + op = _build_operator( builder, 0, [0, 1], [2], - builtin_options_type=_tfl_builtin_options.Conv2DOptions, - builtin_options=conv_opts, + builtin_options2_type=_tfl_builtin_options2.StablehloConvolutionOptions, + builtin_options2=conv_opts, ) subgraph = _build_subgraph( builder, - tensors=[input_tensor, weight_tensor, output_tensor], - operators=[conv_op], + tensors=[t_data, t_kernel, t_out], + operators=[op], inputs=[0, 1], outputs=[2], ) - operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] - buf = _finish_tflite_model( - builder, - subgraph=subgraph, - operator_codes=operator_codes, - buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers ) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - mod = from_tflite(tflite_model) - mod["main"] = mod["main"].without_attr("params") + +def test_stablehlo_convolution(): + """TFLite StableHLO CONVOLUTION canonical NHWC/HWIO 2D 
convolution.""" + mod = _load_model_from_buffer(_build_stablehlo_convolution_model()) @I.ir_module class Expected: - @R.function - def main( - tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), - tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): - R.func_attr({"num_input": 2}) - with R.dataflow(): - lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( - tvmgen_tensor_0, - R.const(0.5, "float32"), - R.const(3, "int32"), - out_dtype="float32", - axis=0, - ) - lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( - tvmgen_tensor_1, - axes=[1, 2, 3, 0], - ) - lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( - lv1, - R.const(0.25, "float32"), - R.const(0, "int32"), - out_dtype="float32", - axis=3, - ) - lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( - lv, - lv2, + @R.function + def main( + data: R.Tensor((1, 5, 5, 2), dtype="float32"), + kernel: R.Tensor((3, 3, 2, 4), dtype="float32"), + ) -> R.Tensor((1, 3, 3, 4), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((1, 3, 3, 4), dtype="float32") = R.nn.conv2d( + data, + kernel, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], @@ -6146,83 +6142,127 @@ def main( out_layout="NHWC", out_dtype="void", ) - gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( - lv3, - R.const(1.0, "float32"), - R.const(0, "int32"), - out_dtype="int8", - axis=0, - ) R.output(gv) return gv tvm.ir.assert_structural_equal(mod, Expected) -def test_quantized_conv2d_per_channel_weight_uses_remapped_axis(): - """Quantized Conv2D remaps per-channel weight axis after OHWI -> HWIO.""" - builder = flatbuffers.Builder(2048) +def test_stablehlo_convolution_feature_group_unsupported(): + """TFLite StableHLO CONVOLUTION rejects grouped convolution in the first subset.""" + buf = _build_stablehlo_convolution_model(feature_group_count=2) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 
0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) - in_q = _build_quantization_parameters( + with pytest.raises(tvm.error.OpNotImplemented, match="feature_group_count"): + from_tflite(tflite_model) + + +def test_stablehlo_convolution_dimension_numbers_unsupported(): + """TFLite StableHLO CONVOLUTION rejects non-canonical dimension numbers.""" + buf = _build_stablehlo_convolution_model(input_batch_dimension=1) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="dimension numbers"): + from_tflite(tflite_model) + + +# Quantized TFLite QDQ tests + + +def test_tensor_quantization_parameters_are_parsed(): + """Tensor quantization metadata is kept without requiring quantized op support.""" + builder = flatbuffers.Builder(1024) + + per_tensor_quantization = _build_quantization_parameters( builder, scale=[0.5], zero_point=[3], quantized_dimension=0 ) - wt_q = _build_quantization_parameters( - builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 - ) - out_q = _build_quantization_parameters( - builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + per_axis_quantization = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 ) - - input_tensor = _build_tensor( + per_tensor = _build_tensor( builder, 0, - [1, 4, 4, 1], - tensor_type=_tfl_tensor_type.INT8, - quantization=in_q, + [1, 4], + tensor_type=_tfl_tensor_type.UINT8, + quantization=per_tensor_quantization, ) - weight_tensor = _build_tensor( + per_axis = _build_tensor( builder, 1, - [2, 3, 3, 1], + [1, 2, 3, 2], tensor_type=_tfl_tensor_type.INT8, - quantization=wt_q, + quantization=per_axis_quantization, ) - output_tensor = _build_tensor( - builder, - 2, - [1, 2, 2, 2], - tensor_type=_tfl_tensor_type.INT8, - quantization=out_q, + subgraph = _build_subgraph( + builder, 
tensors=[per_tensor, per_axis], operators=[], inputs=[0, 1], outputs=[0, 1] ) + buffers = [_build_buffer(builder), _build_buffer(builder)] + buf = _finish_tflite_model(builder, subgraph=subgraph, operator_codes=[], buffers=buffers) - _tfl_conv2d_options.Conv2DOptionsStart(builder) - _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) - _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) - _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) - _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) - conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) - conv_op = _build_operator( + converter = tflite_frontend.OperatorConverter( + tflite_model, tflite_model.Subgraphs(0), tflite_frontend.ExprTable(), None + ) + per_tensor_wrapper, per_axis_wrapper = converter.get_tensors([0, 1]) + + np.testing.assert_allclose(per_tensor_wrapper.qnn_params["scale"].data.numpy(), 0.5) + np.testing.assert_equal(per_tensor_wrapper.qnn_params["zero_point"].data.numpy(), 3) + assert per_tensor_wrapper.qnn_params["axis"] == 0 + + np.testing.assert_allclose( + per_axis_wrapper.qnn_params["scale"].data.numpy(), np.array([0.25, 0.75]) + ) + np.testing.assert_equal(per_axis_wrapper.qnn_params["zero_point"].data.numpy(), 0) + assert per_axis_wrapper.qnn_params["axis"] == 3 + + mod = from_tflite(tflite_model) + assert len(mod["main"].params) == 2 + + +def test_quantize_op_uses_relax_quantize(): + """TFLite QUANTIZE float32 -> int8 uses R.quantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([1.0, 2.0], dtype=np.float32) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.FLOAT32) + output_tensor = _build_tensor( builder, - 0, - 
[0, 1], + 1, [2], - builtin_options_type=_tfl_builtin_options.Conv2DOptions, - builtin_options=conv_opts, + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, ) + + quantize_op = _build_operator(builder, 0, [0], [1]) subgraph = _build_subgraph( builder, - tensors=[input_tensor, weight_tensor, output_tensor], - operators=[conv_op], - inputs=[0, 1], - outputs=[2], + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], ) - operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.QUANTIZE)] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) buf = _finish_tflite_model( builder, subgraph=subgraph, operator_codes=operator_codes, - buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], + buffers=[input_buffer, output_buffer], ) if hasattr(tflite.Model, "Model"): @@ -6235,48 +6275,15 @@ def test_quantized_conv2d_per_channel_weight_uses_remapped_axis(): @I.ir_module class Expected: @R.function - def main( - tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), - tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): - R.func_attr({"num_input": 2}) + def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) with R.dataflow(): - lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( - tvmgen_tensor_0, + gv: R.Tensor((2,), dtype="int8") = R.quantize( + x, R.const(0.5, "float32"), R.const(3, "int32"), - out_dtype="float32", axis=0, - ) - lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( - tvmgen_tensor_1, - axes=[1, 2, 3, 0], - ) - lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( - lv1, - R.const([0.25, 0.75], "float32"), - R.const(0, "int32"), - out_dtype="float32", - axis=3, - ) - lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = 
R.nn.conv2d( - lv, - lv2, - strides=[1, 1], - padding=[0, 0, 0, 0], - dilation=[1, 1], - groups=1, - data_layout="NHWC", - kernel_layout="HWIO", - out_layout="NHWC", - out_dtype="void", - ) - gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( - lv3, - R.const(1.0, "float32"), - R.const(0, "int32"), out_dtype="int8", - axis=0, ) R.output(gv) return gv @@ -6284,118 +6291,147 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) -def test_stablehlo_cbrt(): - """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" - mod = _load_model_from_buffer( - _build_stablehlo_model(builtin_name="STABLEHLO_CBRT", input_count=1) - ) - - @I.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - lv: R.Tensor((2, 2), dtype="float32") = R.negative(x) - lv1: R.Tensor((2, 2), dtype="float32") = R.power(lv, R.const(1.0 / 3.0, "float32")) - lv2: R.Tensor((2, 2), dtype="bool") = R.less(x, R.const(0, "float32")) - lv3: R.Tensor((2, 2), dtype="float32") = R.negative(lv1) - lv4: R.Tensor((2, 2), dtype="float32") = R.power(x, R.const(1.0 / 3.0, "float32")) - gv: R.Tensor((2, 2), dtype="float32") = R.where(lv2, lv3, lv4) - R.output(gv) - return gv +def test_quantize_op_requantize_uses_dq_q(): + """TFLite QUANTIZE with quantized input uses DQ→Q (requantize).""" + builder = flatbuffers.Builder(1024) - tvm.ir.assert_structural_equal(mod, Expected) + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + 
quantization=output_qparams, + ) -def test_stablehlo_remainder(): - """TFLite StableHLO REMAINDER uses truncating remainder semantics.""" - mod = _load_model_from_buffer( - _build_stablehlo_model(builtin_name="STABLEHLO_REMAINDER", input_count=2) + quantize_op = _build_operator( + builder, + 0, + [0], + [1], + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), + ] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + @I.ir_module class Expected: @R.function def main( - x: R.Tensor((2, 2), dtype="float32"), - y: R.Tensor((2, 2), dtype="float32"), - ) -> R.Tensor((2, 2), dtype="float32"): - R.func_attr({"num_input": 2}) + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) with R.dataflow(): - lv: R.Tensor((2, 2), dtype="float32") = R.divide(x, y) - lv1: R.Tensor((2, 2), dtype="float32") = R.trunc(lv) - lv2: R.Tensor((2, 2), dtype="float32") = R.multiply(y, lv1) - gv: R.Tensor((2, 2), dtype="float32") = R.subtract(x, lv2) + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) R.output(gv) return gv tvm.ir.assert_structural_equal(mod, Expected) -def 
_build_stablehlo_dynamic_update_slice_model(start_vals, dynamic_starts=False): - """Build a minimal STABLEHLO_DYNAMIC_UPDATE_SLICE model.""" +def test_dequantize_op_uses_relax_dequantize(): + """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" builder = flatbuffers.Builder(1024) - builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_UPDATE_SLICE") - op_code = _build_operator_code(builder, builtin_op) - t_operand = _build_tensor(builder, 0, [3, 4]) - t_update = _build_tensor(builder, 1, [2, 2]) - start_tensors = [ - _build_tensor(builder, 2 + i, [], tensor_type=_tfl_tensor_type.INT32) - for i in range(len(start_vals)) - ] - out_idx = 2 + len(start_vals) - t_out = _build_tensor(builder, out_idx, [3, 4]) - tensors = [t_operand, t_update, *start_tensors, t_out] + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) - op_inputs = [0, 1, *range(2, out_idx)] - op = _build_operator(builder, 0, op_inputs, [out_idx]) - subgraph_inputs = op_inputs if dynamic_starts else [0, 1] - subgraph = _build_subgraph( + input_tensor = _build_tensor( builder, - tensors=tensors, - operators=[op], - inputs=subgraph_inputs, - outputs=[out_idx], + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, ) - if dynamic_starts: - buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] - else: - start_buffers = [ - _build_buffer(builder, np.array([start], dtype=np.int32).tobytes()) - for start in start_vals - ] - buffers = [ - _build_buffer(builder), - _build_buffer(builder), - *start_buffers, - _build_buffer(builder), - ] + output_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.FLOAT32) - return _finish_tflite_model( - builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + dequantize_op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, 
output_tensor], + operators=[dequantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEQUANTIZE)] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], ) - -def test_stablehlo_dynamic_update_slice(): - """TFLite StableHLO DYNAMIC_UPDATE_SLICE with static starts.""" - mod = _load_model_from_buffer(_build_stablehlo_dynamic_update_slice_model([1, 1])) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") @I.ir_module class Expected: @R.function - def main( - operand: R.Tensor((3, 4), dtype="float32"), - update: R.Tensor((2, 2), dtype="float32"), - ) -> R.Tensor((3, 4), dtype="float32"): - R.func_attr({"num_input": 2}) + def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 1}) with R.dataflow(): - gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd( - operand, - R.const([[[1, 1], [1, 2]], [[2, 1], [2, 2]]], dtype="int64"), - update, - reduction="update", + gv: R.Tensor((2,), dtype="float32") = R.dequantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + axis=0, ) R.output(gv) return gv @@ -6403,246 +6439,234 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) -def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported(): - """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported.""" - buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - - with 
pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"): - from_tflite(tflite_model) - - -def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported(): - """TFLite StableHLO DYNAMIC_UPDATE_SLICE rejects out-of-bounds updates.""" - buf = _build_stablehlo_dynamic_update_slice_model([2, 3]) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - - with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"): - from_tflite(tflite_model) - - -def _build_stablehlo_dot_general_model(lhs_contract, rhs_contract, lhs_batch=None, rhs_batch=None): - """Build a minimal STABLEHLO_DOT_GENERAL model.""" - builder = flatbuffers.Builder(1024) - lhs_batch = [] if lhs_batch is None else lhs_batch - rhs_batch = [] if rhs_batch is None else rhs_batch +def test_quantized_conv2d_per_tensor_uses_qdq(): + """Quantized Conv2D with per-tensor quantization uses DQ -> conv2d -> Q.""" + builder = flatbuffers.Builder(2048) - lhs_batch_vec = _tflite_int64_vector( - builder, - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector, - lhs_batch, + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 ) - rhs_batch_vec = _tflite_int64_vector( + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + input_tensor = _build_tensor( builder, - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector, - rhs_batch, + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, ) - lhs_contract_vec = _tflite_int64_vector( + weight_tensor = _build_tensor( builder, - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsContractingDimensionsVector, - lhs_contract, + 1, + [2, 3, 3, 1], + 
tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, ) - rhs_contract_vec = _tflite_int64_vector( + output_tensor = _build_tensor( builder, - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsContractingDimensionsVector, - rhs_contract, + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, ) - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStart(builder) - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsBatchingDimensions( - builder, lhs_batch_vec - ) - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsBatchingDimensions( - builder, rhs_batch_vec - ) - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsContractingDimensions( - builder, lhs_contract_vec - ) - _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsContractingDimensions( - builder, rhs_contract_vec - ) - dot_opts = _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsEnd(builder) + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) - builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DOT_GENERAL") - op_code = _build_operator_code(builder, builtin_op) - t_lhs = _build_tensor(builder, 0, [2, 3]) - t_rhs = _build_tensor(builder, 1, [3, 4]) - t_out = _build_tensor(builder, 2, [2, 4]) - op = _build_operator( + conv_op = _build_operator( builder, 0, [0, 1], [2], - builtin_options2_type=_tfl_builtin_options2.StablehloDotGeneralOptions, - builtin_options2=dot_opts, + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, ) subgraph = _build_subgraph( builder, - tensors=[t_lhs, t_rhs, t_out], - operators=[op], + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], inputs=[0, 1], 
outputs=[2], ) - buffers = [_build_buffer(builder) for _ in range(3)] - return _finish_tflite_model( - builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], ) - -def test_stablehlo_dot_general(): - """TFLite StableHLO DOT_GENERAL canonical 2D matmul.""" - mod = _load_model_from_buffer(_build_stablehlo_dot_general_model([1], [0])) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") @I.ir_module class Expected: @R.function def main( - lhs: R.Tensor((2, 3), dtype="float32"), - rhs: R.Tensor((3, 4), dtype="float32"), - ) -> R.Tensor((2, 4), dtype="float32"): + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): R.func_attr({"num_input": 2}) with R.dataflow(): - gv: R.Tensor((2, 4), dtype="float32") = R.matmul(lhs, rhs, out_dtype="void") + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: 
R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) R.output(gv) return gv tvm.ir.assert_structural_equal(mod, Expected) -def test_stablehlo_dot_general_noncanonical_unsupported(): - """TFLite StableHLO DOT_GENERAL rejects non-canonical contracting dims.""" - buf = _build_stablehlo_dot_general_model([0], [0]) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - - with pytest.raises(tvm.error.OpNotImplemented, match="contracting"): - from_tflite(tflite_model) - - -def _build_stablehlo_convolution_model(feature_group_count=1, input_batch_dimension=0): - """Build a minimal STABLEHLO_CONVOLUTION model.""" - builder = flatbuffers.Builder(1024) +def test_quantized_conv2d_per_channel_weight_uses_remapped_axis(): + """Quantized Conv2D remaps per-channel weight axis after OHWI -> HWIO.""" + builder = flatbuffers.Builder(2048) - window_strides_vec = _tflite_int64_vector( - builder, - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowStridesVector, - [1, 1], - ) - padding_vec = _tflite_int64_vector( - builder, - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartPaddingVector, - [0, 0, 0, 0], - ) - lhs_dilation_vec = _tflite_int64_vector( - builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartLhsDilationVector, [1, 1] + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 ) - rhs_dilation_vec = _tflite_int64_vector( - builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartRhsDilationVector, [1, 1] + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 ) - window_reversal_vec = _tflite_bool_vector( - builder, - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowReversalVector, - [False, False], + out_q = 
_build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 ) - input_spatial_vec = _tflite_int64_vector( + + input_tensor = _build_tensor( builder, - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartInputSpatialDimensionsVector, - [1, 2], + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, ) - kernel_spatial_vec = _tflite_int64_vector( + weight_tensor = _build_tensor( builder, - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartKernelSpatialDimensionsVector, - [0, 1], + 1, + [2, 3, 3, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, ) - output_spatial_vec = _tflite_int64_vector( + output_tensor = _build_tensor( builder, - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartOutputSpatialDimensionsVector, - [1, 2], + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStart(builder) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowStrides( - builder, window_strides_vec - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddPadding(builder, padding_vec) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddLhsDilation(builder, lhs_dilation_vec) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddRhsDilation(builder, rhs_dilation_vec) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowReversal( - builder, window_reversal_vec - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputBatchDimension( - builder, input_batch_dimension - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputFeatureDimension(builder, 3) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputSpatialDimensions( - builder, input_spatial_vec - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelInputFeatureDimension(builder, 2) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelOutputFeatureDimension(builder, 3) - 
_tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelSpatialDimensions( - builder, kernel_spatial_vec - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputBatchDimension(builder, 0) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputFeatureDimension(builder, 3) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputSpatialDimensions( - builder, output_spatial_vec - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddFeatureGroupCount( - builder, feature_group_count - ) - _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddBatchGroupCount(builder, 1) - conv_opts = _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsEnd(builder) + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) - builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONVOLUTION") - op_code = _build_operator_code(builder, builtin_op) - t_data = _build_tensor(builder, 0, [1, 5, 5, 2]) - t_kernel = _build_tensor(builder, 1, [3, 3, 2, 4]) - t_out = _build_tensor(builder, 2, [1, 3, 3, 4]) - op = _build_operator( + conv_op = _build_operator( builder, 0, [0, 1], [2], - builtin_options2_type=_tfl_builtin_options2.StablehloConvolutionOptions, - builtin_options2=conv_opts, + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, ) subgraph = _build_subgraph( builder, - tensors=[t_data, t_kernel, t_out], - operators=[op], + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], inputs=[0, 1], outputs=[2], ) - buffers = [_build_buffer(builder) for _ in range(3)] - return _finish_tflite_model( - builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + operator_codes = 
[_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], ) - -def test_stablehlo_convolution(): - """TFLite StableHLO CONVOLUTION canonical NHWC/HWIO 2D convolution.""" - mod = _load_model_from_buffer(_build_stablehlo_convolution_model()) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") @I.ir_module class Expected: @R.function def main( - data: R.Tensor((1, 5, 5, 2), dtype="float32"), - kernel: R.Tensor((3, 3, 2, 4), dtype="float32"), - ) -> R.Tensor((1, 3, 3, 4), dtype="float32"): + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): R.func_attr({"num_input": 2}) with R.dataflow(): - gv: R.Tensor((1, 3, 3, 4), dtype="float32") = R.nn.conv2d( - data, - kernel, + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const([0.25, 0.75], "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], @@ -6652,36 +6676,19 @@ def main( out_layout="NHWC", out_dtype="void", ) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) R.output(gv) return gv 
tvm.ir.assert_structural_equal(mod, Expected) -def test_stablehlo_convolution_feature_group_unsupported(): - """TFLite StableHLO CONVOLUTION rejects grouped convolution in the first subset.""" - buf = _build_stablehlo_convolution_model(feature_group_count=2) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - - with pytest.raises(tvm.error.OpNotImplemented, match="feature_group_count"): - from_tflite(tflite_model) - - -def test_stablehlo_convolution_dimension_numbers_unsupported(): - """TFLite StableHLO CONVOLUTION rejects non-canonical dimension numbers.""" - buf = _build_stablehlo_convolution_model(input_batch_dimension=1) - if hasattr(tflite.Model, "Model"): - tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) - else: - tflite_model = tflite.Model.GetRootAsModel(buf, 0) - - with pytest.raises(tvm.error.OpNotImplemented, match="dimension numbers"): - from_tflite(tflite_model) - - def test_quantized_concat_uses_qdq(): """Quantized CONCATENATION uses DQ each input → concat → Q.""" import flatbuffers @@ -6700,10 +6707,10 @@ def test_quantized_concat_uses_qdq(): t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) - tflite.ConcatenationOptionsStart(builder) - tflite.ConcatenationOptionsAddAxis(builder, 1) - tflite.ConcatenationOptionsAddFusedActivationFunction(builder, 0) - concat_opts = tflite.ConcatenationOptionsEnd(builder) + _tfl_concatenation_options.ConcatenationOptionsStart(builder) + _tfl_concatenation_options.ConcatenationOptionsAddAxis(builder, 1) + _tfl_concatenation_options.ConcatenationOptionsAddFusedActivationFunction(builder, 0) + concat_opts = _tfl_concatenation_options.ConcatenationOptionsEnd(builder) concat_op = _build_operator( builder, @@ -7075,7 +7082,7 @@ def 
test_quantized_add_without_output_qparams_invalid(): def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): - """Conv2D with INT32 bias dequantizes bias with in_scale × wt_scale.""" + """Conv2D with INT32 bias dequantizes bias with in_scale x wt_scale.""" import flatbuffers import tflite.Model @@ -7102,12 +7109,12 @@ def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q ) - tflite.Conv2DOptionsStart(builder) - tflite.Conv2DOptionsAddStrideH(builder, 1) - tflite.Conv2DOptionsAddStrideW(builder, 1) - tflite.Conv2DOptionsAddPadding(builder, 1) - tflite.Conv2DOptionsAddFusedActivationFunction(builder, 0) - conv_opts = tflite.Conv2DOptionsEnd(builder) + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) conv_op = _build_operator( builder, @@ -7354,13 +7361,13 @@ def test_per_channel_depthwise_conv_unsupported(): builder, 2, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q ) - tflite.DepthwiseConv2DOptionsStart(builder) - tflite.DepthwiseConv2DOptionsAddStrideH(builder, 1) - tflite.DepthwiseConv2DOptionsAddStrideW(builder, 1) - tflite.DepthwiseConv2DOptionsAddDepthMultiplier(builder, 1) - tflite.DepthwiseConv2DOptionsAddPadding(builder, 1) - tflite.DepthwiseConv2DOptionsAddFusedActivationFunction(builder, 0) - dw_opts = tflite.DepthwiseConv2DOptionsEnd(builder) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsStart(builder) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideH(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideW(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddDepthMultiplier(builder, 
1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddPadding(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddFusedActivationFunction(builder, 0) + dw_opts = _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsEnd(builder) dw_op = _build_operator( builder, @@ -7397,8 +7404,8 @@ def test_per_channel_depthwise_conv_unsupported(): def test_uint8_reshape_requantize_uses_dq_reshape_q(): """uint8 RESHAPE with different qparams uses DQ→reshape→Q.""" import flatbuffers - import tflite.Model import numpy as np + import tflite.Model builder = flatbuffers.Builder(1024) @@ -7415,11 +7422,11 @@ def test_uint8_reshape_requantize_uses_dq_reshape_q(): # Use ReshapeOptions with static new_shape [2, 2] new_shape_np = np.array([2, 2], dtype=np.int32) new_shape_vec = _tflite_int32_vector( - builder, tflite.ReshapeOptionsStartNewShapeVector, new_shape_np + builder, _tfl_reshape_options.ReshapeOptionsStartNewShapeVector, new_shape_np ) - tflite.ReshapeOptionsStart(builder) - tflite.ReshapeOptionsAddNewShape(builder, new_shape_vec) - reshape_opts = tflite.ReshapeOptionsEnd(builder) + _tfl_reshape_options.ReshapeOptionsStart(builder) + _tfl_reshape_options.ReshapeOptionsAddNewShape(builder, new_shape_vec) + reshape_opts = _tfl_reshape_options.ReshapeOptionsEnd(builder) reshape_op = _build_operator( builder, @@ -7485,9 +7492,10 @@ def main( def test_transpose_conv_with_int32_bias_dequantizes_bias(): """TRANSPOSE_CONV with INT32 bias dequantizes bias before adding.""" + import struct + import flatbuffers import tflite.Model - import struct builder = flatbuffers.Builder(2048) @@ -7514,12 +7522,12 @@ def test_transpose_conv_with_int32_bias_dequantizes_bias(): oshape_data = struct.pack(" Date: Tue, 26 May 2026 09:13:32 +0800 Subject: [PATCH 7/7] [Relax][Frontend][TFLite] Guard unsupported quantized operators --- .../relax/frontend/tflite/tflite_frontend.py | 87 +++++++++++++++++-- tests/python/relax/test_frontend_tflite.py | 34 ++++++++ 2 files changed, 113 
insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index ddc97b77521c..979bbbb867ba 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -96,6 +96,64 @@ def __init__(self, tensor_idx, tensor, buffer, qnn_params=None): class OperatorConverter: """Operator Converted for converting TFLite ops to Relax ops""" + _SUPPORTED_QUANTIZED_OPS = frozenset( + { + "ABS", + "ADD", + "ATAN2", + "CEIL", + "CONCATENATION", + "CONV_2D", + "COS", + "DEPTHWISE_CONV_2D", + "DEQUANTIZE", + "DETECTION_POSTPROCESS", + "DIV", + "EQUAL", + "EXP", + "FLOOR", + "FLOOR_DIV", + "FLOOR_MOD", + "FULLY_CONNECTED", + "GREATER", + "GREATER_EQUAL", + "HARD_SWISH", + "LEAKY_RELU", + "LESS", + "LESS_EQUAL", + "LOG", + "LOGISTIC", + "LOG_SOFTMAX", + "MAXIMUM", + "MEAN", + "MINIMUM", + "MUL", + "NEG", + "NOT_EQUAL", + "POW", + "QUANTIZE", + "REDUCE_MAX", + "REDUCE_MIN", + "REDUCE_PROD", + "RELU", + "RELU6", + "RELU_N1_TO_1", + "RESHAPE", + "RESIZE_BILINEAR", + "ROUND", + "RSQRT", + "SIN", + "SOFTMAX", + "SQRT", + "SQUARED_DIFFERENCE", + "SUB", + "SUM", + "TAN", + "TANH", + "TRANSPOSE_CONV", + } + ) + def __init__(self, model, subgraph, exp_tab, ctx): from tflite.ActivationFunctionType import ActivationFunctionType from tflite.BuiltinOperator import BuiltinOperator @@ -326,6 +384,7 @@ def check_unsupported_ops(self): """Check unsupported TFLite ops in our converter.""" unsupported_ops_set = set() dynamic_range_ops_set = set() + unsupported_quantized_ops_set = set() for op_idx in range(self.subgraph.OperatorsLength()): op = self.subgraph.Operators(op_idx) op_code_str = self.get_op_code_str(op) @@ -334,19 +393,23 @@ def check_unsupported_ops(self): continue # Trying to exclude "dynamic range quantization" optimized ops as not supported in TVM - qnn_in_cnt = len( - [_.qnn_params for _ in self.get_input_tensors(op)[0:1] if _.qnn_params is not 
None] - ) + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + qnn_in_cnt = len([_.qnn_params for _ in input_tensors[0:1] if _.qnn_params is not None]) qnn_weight_cnt = len( - [_.qnn_params for _ in self.get_input_tensors(op)[1:] if _.qnn_params is not None] - ) - qnn_out_cnt = len( - [_.qnn_params for _ in self.get_output_tensors(op) if _.qnn_params is not None] + [_.qnn_params for _ in input_tensors[1:] if _.qnn_params is not None] ) + qnn_out_cnt = len([_.qnn_params for _ in output_tensors if _.qnn_params is not None]) if qnn_in_cnt == 0 and qnn_out_cnt == 0 and qnn_weight_cnt > 0: dynamic_range_ops_set.add(op_code_str) + if ( + qnn_in_cnt + qnn_weight_cnt + qnn_out_cnt > 0 + and op_code_str not in self._SUPPORTED_QUANTIZED_OPS + ): + unsupported_quantized_ops_set.add(op_code_str) + raise_msg = "" if unsupported_ops_set: @@ -358,7 +421,15 @@ def check_unsupported_ops(self): raise_msg += ( f"The following operators are likely to have dynamic range quantization: {ops}. " f"If you are running an optimized graph, please turn off dynamic range " - f"quantization or use full integer quantization" + f"quantization or use full integer quantization\n" + ) + + if unsupported_quantized_ops_set: + ops = ", ".join(f"'{op}'" for op in sorted(unsupported_quantized_ops_set)) + raise_msg += ( + f"The following quantized TFLite operators are not supported in frontend " + f"TFLite yet: {ops}. 
Quantized operators require explicit QDQ lowering " + f"to avoid applying Relax ops directly to quantized integer tensors.\n" ) if len(raise_msg) > 0: diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 34a37ee51b8b..d03de3b6a9c4 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -7081,6 +7081,40 @@ def test_quantized_add_without_output_qparams_invalid(): from_tflite(tflite_model) +def test_quantized_square_unsupported(): + """Quantized SQUARE is rejected instead of applying integer power directly.""" + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_out = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + square_op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_out], + operators=[square_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.SQUARE)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 2, + ) + + with pytest.raises(tvm.error.OpNotImplemented, match="SQUARE"): + _load_model_from_buffer(buf) + + def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): """Conv2D with INT32 bias dequantizes bias with in_scale x wt_scale.""" import flatbuffers