diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 28b125eec0b0..979bbbb867ba 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -19,9 +19,6 @@ # pylint: disable=no-value-for-parameter, unused-variable # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args # ruff: noqa: RUF005 -# F821: _qnn and _expr references are in unreachable code paths (guarded by NotImplementedError) -# and will be resolved when quantization and vision op support are added. -# ruff: noqa: F821 """Tensorflow lite frontend.""" import functools @@ -99,6 +96,64 @@ def __init__(self, tensor_idx, tensor, buffer, qnn_params=None): class OperatorConverter: """Operator Converted for converting TFLite ops to Relax ops""" + _SUPPORTED_QUANTIZED_OPS = frozenset( + { + "ABS", + "ADD", + "ATAN2", + "CEIL", + "CONCATENATION", + "CONV_2D", + "COS", + "DEPTHWISE_CONV_2D", + "DEQUANTIZE", + "DETECTION_POSTPROCESS", + "DIV", + "EQUAL", + "EXP", + "FLOOR", + "FLOOR_DIV", + "FLOOR_MOD", + "FULLY_CONNECTED", + "GREATER", + "GREATER_EQUAL", + "HARD_SWISH", + "LEAKY_RELU", + "LESS", + "LESS_EQUAL", + "LOG", + "LOGISTIC", + "LOG_SOFTMAX", + "MAXIMUM", + "MEAN", + "MINIMUM", + "MUL", + "NEG", + "NOT_EQUAL", + "POW", + "QUANTIZE", + "REDUCE_MAX", + "REDUCE_MIN", + "REDUCE_PROD", + "RELU", + "RELU6", + "RELU_N1_TO_1", + "RESHAPE", + "RESIZE_BILINEAR", + "ROUND", + "RSQRT", + "SIN", + "SOFTMAX", + "SQRT", + "SQUARED_DIFFERENCE", + "SUB", + "SUM", + "TAN", + "TANH", + "TRANSPOSE_CONV", + } + ) + def __init__(self, model, subgraph, exp_tab, ctx): from tflite.ActivationFunctionType import ActivationFunctionType from tflite.BuiltinOperator import BuiltinOperator @@ -329,6 +384,7 @@ def check_unsupported_ops(self): """Check unsupported TFLite ops in our converter.""" unsupported_ops_set = set() dynamic_range_ops_set = set() + unsupported_quantized_ops_set = set() for op_idx in range(self.subgraph.OperatorsLength()): op = self.subgraph.Operators(op_idx) op_code_str = self.get_op_code_str(op) @@ -337,19 +393,23 @@ def check_unsupported_ops(self): continue # Trying to exclude "dynamic range quantization" optimized ops as not supported in TVM - qnn_in_cnt = len( - [_.qnn_params for _ in self.get_input_tensors(op)[0:1] if _.qnn_params is not None] - ) + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + qnn_in_cnt = len([_.qnn_params for _ in input_tensors[0:1] if _.qnn_params is not None]) qnn_weight_cnt = len( - [_.qnn_params for _ in self.get_input_tensors(op)[1:] if _.qnn_params is not None] - ) - qnn_out_cnt = len( - [_.qnn_params for _ in self.get_output_tensors(op) if _.qnn_params is not None] + [_.qnn_params for _ in input_tensors[1:] if _.qnn_params is not None] ) + qnn_out_cnt = len([_.qnn_params for _ in output_tensors if _.qnn_params is not None]) if qnn_in_cnt == 0 and qnn_out_cnt == 0 and qnn_weight_cnt > 0: dynamic_range_ops_set.add(op_code_str) + if ( + qnn_in_cnt + qnn_weight_cnt + qnn_out_cnt > 0 + and op_code_str not in self._SUPPORTED_QUANTIZED_OPS + ): + unsupported_quantized_ops_set.add(op_code_str) + raise_msg = "" if unsupported_ops_set: @@ -361,7 +421,15 @@ def check_unsupported_ops(self): raise_msg += ( f"The following operators are likely to have dynamic range quantization: {ops}. " f"If you are running an optimized graph, please turn off dynamic range " - f"quantization or use full integer quantization" + f"quantization or use full integer quantization\n" + ) + + if unsupported_quantized_ops_set: + ops = ", ".join(f"'{op}'" for op in sorted(unsupported_quantized_ops_set)) + raise_msg += ( + f"The following quantized TFLite operators are not supported in frontend " + f"TFLite yet: {ops}. Quantized operators require explicit QDQ lowering " + f"to avoid applying Relax ops directly to quantized integer tensors.\n" ) if len(raise_msg) > 0: @@ -557,9 +625,7 @@ def get_tensors(self, tensors_idx_list): qnn_params = dict() qnn_params["scale"] = relax.const(scale, "float32") qnn_params["zero_point"] = relax.const(zero_point, "int32") - raise NotImplementedError( - "Quantized TFLite models are not yet supported in the Relax frontend" - ) + qnn_params["axis"] = int(tflite_qnn_params.QuantizedDimension()) return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list @@ -664,20 +730,22 @@ def quantize(self, expr, tensor_to_quantize): """Helper function to quantize a tensor with Relax""" tensor_type = tensor_to_quantize.tensor.Type() tensor_type_str = self.get_tensor_type_str(tensor_type) - quantized = _qnn.op.quantize( + quantized = relax.op.quantize( data=expr, - output_scale=tensor_to_quantize.qnn_params["scale"], - output_zero_point=tensor_to_quantize.qnn_params["zero_point"], + scale=tensor_to_quantize.qnn_params["scale"], + zero_point=tensor_to_quantize.qnn_params["zero_point"], + axis=tensor_to_quantize.qnn_params["axis"], out_dtype=tensor_type_str, ) return quantized def dequantize(self, expr, tensor): """Helper function to dequantize a tensor with Relax""" - dequantized = _qnn.op.dequantize( + dequantized = relax.op.dequantize( data=expr, - input_scale=tensor.qnn_params["scale"], - input_zero_point=tensor.qnn_params["zero_point"], + scale=tensor.qnn_params["scale"], + zero_point=tensor.qnn_params["zero_point"], + axis=tensor.qnn_params["axis"], ) return dequantized @@ -713,7 +781,7 @@ def quantize(x): if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return relax.op.clip(expr, min=max(qmin, quantize(-1.0)), max=min(qmax, quantize(1.0))) if fused_activation_fn == ActivationFunctionType.RELU: - return relax.op.clip(expr, min=max(qmin, quantize(0.0)), a_max=qmax) + return relax.op.clip(expr, min=max(qmin, quantize(0.0)), max=qmax) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( @@ -788,20 +856,15 @@ def convert_reshape(self, op): "TFLite reshape requires input and output scale and zero points to be equal" ) - out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape)) if input_tensor.qnn_params and input_tensor_type_str == "uint8": output_tensor = output_tensors[0] if not self.has_same_qnn_params(input_tensor, output_tensor): - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.reshape(in_f32, shape=relax.ShapeExpr(target_shape)) + out = self.quantize(out, output_tensor) + return out + out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape)) return out def _convert_resize(self, method, op): @@ -1111,8 +1174,6 @@ def convert_shape(self, op): def convert_relu(self, op): """Convert TFLite ReLU""" - from tflite.ActivationFunctionType import ActivationFunctionType - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1123,32 +1184,12 @@ def convert_relu(self, op): output_tensor = output_tensors[0] if input_tensor.qnn_params: - # Quantize a float value to an quantized integer value - scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"]) - - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = self.convert_qnn_fused_activation_function( - expr=in_expr, - fused_activation_fn=ActivationFunctionType.RELU, - scale=scale_val, - zero_point=zero_point_val, - dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.nn.relu(in_f32) + out = self.quantize(out, output_tensor) else: out = relax.op.nn.relu(in_expr) - if output_tensor.qnn_params: - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) - return out def convert_hard_swish(self, op): @@ -1184,8 +1225,6 @@ def _hard_swish(data): def convert_relu6(self, op): """Convert TFLite ReLU6""" - from tflite.ActivationFunctionType import ActivationFunctionType - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1196,32 +1235,12 @@ def convert_relu6(self, op): output_tensor = output_tensors[0] if input_tensor.qnn_params: - # Quantize a float value to an quantized integer value - scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"]) - - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = self.convert_qnn_fused_activation_function( - expr=in_expr, - fused_activation_fn=ActivationFunctionType.RELU6, - scale=scale_val, - zero_point=zero_point_val, - dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.clip(in_f32, min=0, max=6) + out = self.quantize(out, output_tensor) else: out = relax.op.clip(in_expr, min=0, max=6) - if output_tensor.qnn_params: - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) - return out def convert_leaky_relu(self, op): @@ -1265,36 +1284,12 @@ def convert_relu_n1_to_1(self, op): output_tensor = output_tensors[0] if input_tensor.qnn_params: - # Quantize a float value to an quantized integer value - scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"]) - - def quantize(x): - return float(round(x / scale_val) + zero_point_val) - - # Get min/max of the input dtype. This will be used to ensure that - # clip a_min/a_max are not beyond the dtype range. - input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type()) - qmin = float(tvm.tirx.min_value(input_tensor_type_str).value) - qmax = float(tvm.tirx.max_value(input_tensor_type_str).value) - - out = relax.op.clip( - in_expr, min=max(qmin, quantize(-1.0)), max=min(qmax, quantize(1.0)) - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.clip(in_f32, min=-1, max=1) + out = self.quantize(out, output_tensor) else: out = relax.op.clip(in_expr, min=-1, max=1) - if output_tensor.qnn_params: - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) - return out def convert_log_softmax(self, op): @@ -1340,18 +1335,11 @@ def convert_concatenation(self, op): if not input_tensors[0].qnn_params: out = relax.op.concat(in_exprs, axis=concatenation_axis) else: - input_scales = [input_tensor.qnn_params["scale"] for input_tensor in input_tensors] - input_zero_points = [ - input_tensor.qnn_params["zero_point"] for input_tensor in input_tensors + in_f32s = [ + self.dequantize(expr, tensor) for expr, tensor in zip(in_exprs, input_tensors) ] - out = _qnn.op.concat( - in_exprs, - input_scales=input_scales, - input_zero_points=input_zero_points, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - axis=concatenation_axis, - ) + out = relax.op.concat(in_f32s, axis=concatenation_axis) + out = self.quantize(out, output_tensor) # Handle fused activations if output_tensor.qnn_params: @@ -2441,7 +2429,7 @@ def convert_square(self, op): return out - def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False): + def _convert_elemwise(self, op, relax_op, comparison_op=False): """Generic method to Convert TFLite elemwise""" from tflite.AddOptions import AddOptions @@ -2450,7 +2438,6 @@ def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False from tflite.MulOptions import MulOptions from tflite.SubOptions import SubOptions - ignore_qnn_params = self.is_quantized(op) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -2458,36 +2445,19 @@ def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False rhs_tensor = input_tensors[1] lhs_expr = self.get_tensor_expr(lhs_tensor) rhs_expr = self.get_tensor_expr(rhs_tensor) + input_is_quantized = lhs_tensor.qnn_params is not None or rhs_tensor.qnn_params is not None output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - # TFLite format demands equal scale and zero_point tuple parameters for some operations - # to allow us to use non-quantized operation instead of quantized if ignore_qnn_params=True - if ignore_qnn_params and not comparison_op: - assert ( - lhs_tensor.qnn_params - and self.has_same_qnn_params(lhs_tensor, output_tensor) - and self.has_same_qnn_params(rhs_tensor, output_tensor) - ), "All tensors should be quantized with the same (scale,zero-point) tuple parameters" - - # If quantized, extracts qnn params and call QNN add operator. - if not ignore_qnn_params and lhs_tensor.qnn_params: - assert rhs_tensor.qnn_params, "Both tensors should be quantized." - assert output_tensor.qnn_params, "Output tensor should be quantized." - out = relax_op( - lhs=lhs_expr, - rhs=rhs_expr, - lhs_scale=lhs_tensor.qnn_params["scale"], - lhs_zero_point=lhs_tensor.qnn_params["zero_point"], - rhs_scale=rhs_tensor.qnn_params["scale"], - rhs_zero_point=rhs_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - ) - else: - out = relax_op(lhs_expr, rhs_expr) + if input_is_quantized: + if lhs_tensor.qnn_params: + lhs_expr = self.dequantize(lhs_expr, lhs_tensor) + if rhs_tensor.qnn_params: + rhs_expr = self.dequantize(rhs_expr, rhs_tensor) + + out = relax_op(lhs_expr, rhs_expr) # Options (fused_activation_function) options = None @@ -2505,20 +2475,14 @@ def _convert_elemwise(self, op, relax_op, relax_qnn_op=None, comparison_op=False options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = options.FusedActivationFunction() - # Handle fused activations - if not ignore_qnn_params and output_tensor.qnn_params: - scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) - zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"]) - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = self.convert_qnn_fused_activation_function( - expr=out, - fused_activation_fn=fused_activation_fn, - scale=scale_val, - zero_point=zero_point_val, - dtype=output_tensor_type_str, + out = self.convert_fused_activation_function(out, fused_activation_fn) + + if input_is_quantized and not comparison_op: + if not output_tensor.qnn_params: + raise tvm.error.OpAttributeInvalid( + "Quantized TFLite elemwise operator output must have quantization parameters" ) - else: - out = self.convert_fused_activation_function(out, fused_activation_fn) + out = self.quantize(out, output_tensor) return out def convert_add_n(self, op): @@ -3041,24 +3005,16 @@ def _convert_reduce(self, relax_op, op): keep_dims = False if input_tensor.qnn_params: - in_expr = relax.op.cast(in_expr, "int32") + in_expr = self.dequantize(in_expr, input_tensor) out = relax_op(in_expr, axis, keep_dims) - # Finally if the reduce is quantized. Add a requantize at the end. + # Finally if the reduce is quantized. Quantize the output. output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) if output_tensor.qnn_params: - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + out = self.quantize(out, output_tensor) return out @@ -3175,20 +3131,24 @@ def convert_fully_connected(self, op): ) weight_expr = self.get_tensor_expr(weight_tensor) - weight_shape = weight_expr.struct_info.shape weight_expr = relax.op.permute_dims(weight_expr, [1, 0]) if input_tensor.qnn_params: - out = _qnn.op.dense( - in_expr, + # Dequantize input and weight (OC remapped from axis 0 to 1) + in_f32 = self.dequantize(in_expr, input_tensor) + weight_axis = weight_tensor.qnn_params["axis"] + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"FC weight QuantizedDimension() must be 0 (output-channel " + f"axis in [OC,IC] layout), got {weight_axis}" + ) + w_f32 = relax.op.dequantize( weight_expr, - input_zero_point=input_tensor.qnn_params["zero_point"], - kernel_zero_point=weight_tensor.qnn_params["zero_point"], - input_scale=input_tensor.qnn_params["scale"], - kernel_scale=weight_tensor.qnn_params["scale"], - units=weight_shape[0], - out_dtype="int64" if output_tensor_type_str == "int16" else "int32", + scale=weight_tensor.qnn_params["scale"], + zero_point=weight_tensor.qnn_params["zero_point"], + axis=1, ) + out = relax.op.matmul(in_f32, w_f32) else: out = relax.op.matmul(in_expr, weight_expr) @@ -3212,27 +3172,27 @@ def convert_fully_connected(self, op): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_scale = relax.op.multiply( + input_tensor.qnn_params["scale"], + weight_tensor.qnn_params["scale"], + ) + bias_expr = relax.op.dequantize( + bias_expr, + scale=bias_scale, + zero_point=relax.const(0, "int32"), + axis=0, + ) out = relax.op.add(out, bias_expr) - # Finally if the dense is quantized. Add a requantize at the end. + # Finally if the dense is quantized. Quantize the output. if output_tensor.qnn_params: - data_scale = input_tensor.qnn_params["scale"] - weight_scale = weight_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - weight_scale_val = get_scalar_from_constant(weight_scale) - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - # Requantize - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + out = self.quantize(out, output_tensor) # Call activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) @@ -3444,15 +3404,35 @@ def convert_conv(self, op, conv_type): ) if input_tensor.qnn_params: - qnn_conv2d_params = dict(params) - qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"] - qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"] - qnn_conv2d_params["out_dtype"] = ( - "int64" if output_tensor_type_str == "int16" else "int32" - ) - qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"] - qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"] - out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params) + # Dequantize input activation + in_f32 = self.dequantize(in_expr, input_tensor) + # Dequantize weight with per-channel axis remap. + # TFLite weight original layout: [OC, KH, KW, IC] + # After transpose to HWIO: [KH, KW, IC, OC] + # QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO. + weight_axis = weight_tensor.qnn_params["axis"] + if is_depthwise_conv: + if weight_axis != 0: + raise tvm.error.OpNotImplemented( + "Per-channel quantized depthwise convolution is not supported " + "because the channel axis changes semantics after the " + "[1,KH,KW,C*M] → [KH,KW,C,M] reshape." + ) + else: + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"Conv2D weight QuantizedDimension() must be 0 (output-channel " + f"axis in [OC,KH,KW,IC] layout), got {weight_axis}" + ) + weight_axis = 3 + w_f32 = relax.op.dequantize( + weight_expr, + scale=weight_tensor.qnn_params["scale"], + zero_point=weight_tensor.qnn_params["zero_point"], + axis=weight_axis, + ) + # Float convolution + out = relax.op.nn.conv2d(in_f32, w_f32, **params) else: out = relax.op.nn.conv2d(in_expr, weight_expr, **params) @@ -3475,37 +3455,31 @@ def convert_conv(self, op, conv_type): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) + # For quantized conv, INT32/INT64 bias must be dequantized + # to float32 before adding to the float conv output. + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_expr = relax.op.dequantize( + bias_expr, + scale=relax.op.multiply( + input_tensor.qnn_params["scale"], + weight_tensor.qnn_params["scale"], + ), + zero_point=relax.const(0, "int32"), + axis=0, + ) out = relax.op.add(out, bias_expr) # Handle fused activation. if output_tensor.qnn_params: - # Calculate the intermediate scale and zero point of the int32 output. - data_scale = input_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - - weight_scale = weight_tensor.qnn_params["scale"] - # If weight scale is scalar, it is per-tensor quantization - if isinstance(weight_scale, float): - weight_scale_val = get_scalar_from_constant(weight_scale) - else: - weight_scale_val = get_tensor_from_constant(weight_scale) - - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - # Finally requantize - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - axis=3, - ) + # Quantize the float output using the output tensor's qnn params. + out = self.quantize(out, output_tensor) - # Call activation function + # Call quantized activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"]) out = self.convert_qnn_fused_activation_function( @@ -4985,25 +4959,27 @@ def convert_transpose_conv(self, op): padding = (0, 0, 0, 0) if input_tensor.qnn_params: - input_zero_point = input_tensor.qnn_params["zero_point"] - kernel_zero_point = weights_tensor.qnn_params["zero_point"] - input_scale = input_tensor.qnn_params["scale"] - kernel_scale = weights_tensor.qnn_params["scale"] - out_dtype = "int64" if output_tensor_type_str == "int16" else "int32" - out = _qnn.op.conv2d_transpose( - in_expr, + in_f32 = self.dequantize(in_expr, input_tensor) + weight_axis = weights_tensor.qnn_params["axis"] + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"TransposeConv weight QuantizedDimension() must be 0 " + f"(output-channel axis in OHWI layout), got {weight_axis}" + ) + w_f32 = relax.op.dequantize( weight_expr_iohw, - input_zero_point, - kernel_zero_point, - input_scale, - kernel_scale, + scale=weights_tensor.qnn_params["scale"], + zero_point=weights_tensor.qnn_params["zero_point"], + axis=1, + ) + out = relax.op.nn.conv2d_transpose( + in_f32, + w_f32, strides=(stride_h, stride_w), padding=padding, - channels=int(out_channels), - kernel_size=(int(kernel_h), int(kernel_w)), data_layout="NHWC", kernel_layout="IOHW", - out_dtype=out_dtype, + out_dtype="float32", ) else: out = relax.op.nn.conv2d_transpose( @@ -5035,34 +5011,26 @@ def convert_transpose_conv(self, op): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) - channel_axis = 3 - out = relax.op.nn.bias_add(out, bias_expr, axis=channel_axis) + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_scale = relax.op.multiply( + input_tensor.qnn_params["scale"], + weights_tensor.qnn_params["scale"], + ) + bias_expr = relax.op.dequantize( + bias_expr, + scale=bias_scale, + zero_point=relax.const(0, "int32"), + axis=0, + ) + out = relax.op.add(out, bias_expr) if output_tensor.qnn_params: - # Calculate the intermediate scale and zero point of the int32 output. - data_scale = input_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - - weight_scale = weights_tensor.qnn_params["scale"] - # If weight scale is scalar, it is per-tensor quantization - if isinstance(weight_scale, float): - weight_scale_val = get_scalar_from_constant(weight_scale) - else: - weight_scale_val = get_tensor_from_constant(weight_scale) - - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - axis=3, - ) + out = self.quantize(out, output_tensor) return out def convert_quantize(self, op): @@ -5077,7 +5045,6 @@ def convert_quantize(self, op): output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) # The output must be quantized assert output_tensor.qnn_params @@ -5086,14 +5053,8 @@ def convert_quantize(self, op): if input_tensor_type_str == "float32": out = self.quantize(in_expr, output_tensor) else: - out = _qnn.op.requantize( - in_expr, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = self.quantize(in_f32, output_tensor) return out def convert_dequantize(self, op): @@ -5242,23 +5203,11 @@ def convert_detection_postprocess(self, op): ) if inputs[0].qnn_params: - loc_prob = _qnn.op.dequantize( - data=loc_prob, - input_scale=inputs[0].qnn_params["scale"], - input_zero_point=inputs[0].qnn_params["zero_point"], - ) + loc_prob = self.dequantize(loc_prob, inputs[0]) if inputs[1].qnn_params: - cls_pred = _qnn.op.dequantize( - data=cls_pred, - input_scale=inputs[1].qnn_params["scale"], - input_zero_point=inputs[1].qnn_params["zero_point"], - ) + cls_pred = self.dequantize(cls_pred, inputs[1]) if inputs[2].qnn_params: - anchor_expr = _qnn.op.dequantize( - data=anchor_expr, - input_scale=inputs[2].qnn_params["scale"], - input_zero_point=inputs[2].qnn_params["zero_point"], - ) + anchor_expr = self.dequantize(anchor_expr, inputs[2]) # loc_prob coords are in yxhw format # need to convert to xywh diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 031c1553d8bf..d03de3b6a9c4 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3671,8 +3671,12 @@ def _get_tflite_schema_enum(enum_name): _tfl_add_options = _get_tflite_schema_module("AddOptions") _tfl_buffer = _get_tflite_schema_module("Buffer") +_tfl_concatenation_options = _get_tflite_schema_module("ConcatenationOptions") _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions") +_tfl_depthwise_conv2d_options = _get_tflite_schema_module("DepthwiseConv2DOptions") _tfl_dilate_options = _get_tflite_schema_module("DilateOptions") +_tfl_reshape_options = _get_tflite_schema_module("ReshapeOptions") +_tfl_transpose_conv_options = _get_tflite_schema_module("TransposeConvOptions") # ── StableHLO BuiltinOptions2 schema modules ──────────────────────────── _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions") @@ -3697,6 +3701,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_model = _get_tflite_schema_module("Model") _tfl_operator = _get_tflite_schema_module("Operator") _tfl_operator_code = _get_tflite_schema_module("OperatorCode") +_tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters") _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = _get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") @@ -3704,6 +3709,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator") _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions") _tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2") +_tfl_activation_fn = _get_tflite_schema_enum("ActivationFunctionType") _tfl_dimension_type = _get_tflite_schema_enum("DimensionType") _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat") _tfl_padding = _get_tflite_schema_enum("Padding") @@ -3742,6 +3748,13 @@ def _tflite_bool_vector(builder, start_vector_fn, values): return builder.EndVector() +def _tflite_float32_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependFloat32(value) + return builder.EndVector() + + def _tflite_offset_vector(builder, start_vector_fn, offsets): start_vector_fn(builder, len(offsets)) for offset in reversed(offsets): @@ -3773,7 +3786,7 @@ def _tflite_shape(builder, shape): return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) -def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): +def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None, quantization=None): """Helper to build a TFLite tensor.""" if tensor_type is None: tensor_type = _tfl_tensor_type.FLOAT32 @@ -3785,6 +3798,8 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): _tfl_tensor.TensorAddShape(builder, shape_vec) if sparsity is not None: _tfl_tensor.TensorAddSparsity(builder, sparsity) + if quantization is not None: + _tfl_tensor.TensorAddQuantization(builder, quantization) _tfl_tensor.TensorAddType(builder, tensor_type) return _tfl_tensor.TensorEnd(builder) @@ -3801,6 +3816,24 @@ def _build_buffer(builder, data=None): return _tfl_buffer.BufferEnd(builder) +def _build_quantization_parameters(builder, *, scale, zero_point, quantized_dimension): + scale_vec = _tflite_float32_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartScaleVector, scale + ) + zero_point_vec = _tflite_int64_vector( + builder, + _tfl_quantization_parameters.QuantizationParametersStartZeroPointVector, + zero_point, + ) + _tfl_quantization_parameters.QuantizationParametersStart(builder) + _tfl_quantization_parameters.QuantizationParametersAddScale(builder, scale_vec) + _tfl_quantization_parameters.QuantizationParametersAddZeroPoint(builder, zero_point_vec) + _tfl_quantization_parameters.QuantizationParametersAddQuantizedDimension( + builder, quantized_dimension + ) + return _tfl_quantization_parameters.QuantizationParametersEnd(builder) + + def _build_operator( builder, opcode_index, @@ -6139,6 +6172,1609 @@ def test_stablehlo_convolution_dimension_numbers_unsupported(): from_tflite(tflite_model) +# Quantized TFLite QDQ tests + + +def test_tensor_quantization_parameters_are_parsed(): + """Tensor quantization metadata is kept without requiring quantized op support.""" + builder = flatbuffers.Builder(1024) + + per_tensor_quantization = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + per_axis_quantization = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 + ) + per_tensor = _build_tensor( + builder, + 0, + [1, 4], + tensor_type=_tfl_tensor_type.UINT8, + quantization=per_tensor_quantization, + ) + per_axis = _build_tensor( + builder, + 1, + [1, 2, 3, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=per_axis_quantization, + ) + subgraph = _build_subgraph( + builder, tensors=[per_tensor, per_axis], operators=[], inputs=[0, 1], outputs=[0, 1] + ) + buffers = [_build_buffer(builder), _build_buffer(builder)] + buf = _finish_tflite_model(builder, subgraph=subgraph, operator_codes=[], buffers=buffers) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + converter = tflite_frontend.OperatorConverter( + tflite_model, tflite_model.Subgraphs(0), tflite_frontend.ExprTable(), None + ) + per_tensor_wrapper, per_axis_wrapper = converter.get_tensors([0, 1]) + + np.testing.assert_allclose(per_tensor_wrapper.qnn_params["scale"].data.numpy(), 0.5) + np.testing.assert_equal(per_tensor_wrapper.qnn_params["zero_point"].data.numpy(), 3) + assert per_tensor_wrapper.qnn_params["axis"] == 0 + + np.testing.assert_allclose( + per_axis_wrapper.qnn_params["scale"].data.numpy(), np.array([0.25, 0.75]) + ) + np.testing.assert_equal(per_axis_wrapper.qnn_params["zero_point"].data.numpy(), 0) + assert per_axis_wrapper.qnn_params["axis"] == 3 + + mod = from_tflite(tflite_model) + assert len(mod["main"].params) == 2 + + +def test_quantize_op_uses_relax_quantize(): + """TFLite QUANTIZE float32 -> int8 uses R.quantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([1.0, 2.0], dtype=np.float32) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.FLOAT32) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, + ) + + quantize_op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.QUANTIZE)] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="int8") = R.quantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + axis=0, + out_dtype="int8", + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantize_op_requantize_uses_dq_q(): + """TFLite QUANTIZE with quantized input uses DQ→Q (requantize).""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, + ) + + quantize_op = _build_operator( + builder, + 0, + [0], + [1], + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), + ] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dequantize_op_uses_relax_dequantize(): + """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.FLOAT32) + + dequantize_op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[dequantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEQUANTIZE)] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="float32") = R.dequantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_per_tensor_uses_qdq(): + """Quantized Conv2D with per-tensor quantization uses DQ -> conv2d -> Q.""" + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, + ) + weight_tensor = _build_tensor( + builder, + 1, + [2, 3, 3, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, + ) + output_tensor = _build_tensor( + builder, + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_per_channel_weight_uses_remapped_axis(): + """Quantized Conv2D remaps per-channel weight axis after OHWI -> HWIO.""" + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, + ) + weight_tensor = _build_tensor( + builder, + 1, + [2, 3, 3, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, + ) + output_tensor = _build_tensor( + builder, + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const([0.25, 0.75], "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_concat_uses_qdq(): + """Quantized CONCATENATION uses DQ each input → concat → Q.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_concatenation_options.ConcatenationOptionsStart(builder) + _tfl_concatenation_options.ConcatenationOptionsAddAxis(builder, 1) + _tfl_concatenation_options.ConcatenationOptionsAddFusedActivationFunction(builder, 0) + concat_opts = _tfl_concatenation_options.ConcatenationOptionsEnd(builder) + + concat_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.ConcatenationOptions, + builtin_options=concat_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t0, t1, t2], + operators=[concat_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), + tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), + ) -> R.Tensor((1, 4), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) + gv: R.Tensor((1, 4), dtype="int8") = R.quantize( + lv2, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_concat_fused_relu_uses_quantized_clip(): + """Quantized CONCATENATION fused RELU clips in the quantized domain.""" + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_concatenation_options.ConcatenationOptionsStart(builder) + _tfl_concatenation_options.ConcatenationOptionsAddAxis(builder, 1) + _tfl_concatenation_options.ConcatenationOptionsAddFusedActivationFunction( + builder, _tfl_activation_fn.RELU + ) + concat_opts = _tfl_concatenation_options.ConcatenationOptionsEnd(builder) + + concat_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.ConcatenationOptions, + builtin_options=concat_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t0, t1, t2], + operators=[concat_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), + tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), + ) -> R.Tensor((1, 4), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) + lv3: R.Tensor((1, 4), dtype="int8") = R.quantize( + lv2, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + gv: R.Tensor((1, 4), dtype="int8") = R.clip(lv3, min=3.0, max=127.0) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_add_uses_qdq(): + """Quantized ADD uses DQ each input -> add -> Q.""" + builder = flatbuffers.Builder(1024) + + lhs_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + rhs_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=lhs_q) + t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=rhs_q) + t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_add_options.AddOptionsStart(builder) + _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, 0) + add_opts = _tfl_add_options.AddOptionsEnd(builder) + + add_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.AddOptions, + builtin_options=add_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[add_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((2,), dtype="float32") = R.add(lv, lv1) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv2, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_add_fused_relu6_uses_float_clip_before_quantize(): + """Quantized ADD fused RELU6 applies the activation before quantizing.""" + builder = flatbuffers.Builder(1024) + + lhs_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + rhs_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=lhs_q) + t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=rhs_q) + t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_add_options.AddOptionsStart(builder) + _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, _tfl_activation_fn.RELU6) + add_opts = _tfl_add_options.AddOptionsEnd(builder) + + add_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.AddOptions, + builtin_options=add_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[add_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((2,), dtype="float32") = R.add(lv, lv1) + lv3: R.Tensor((2,), dtype="float32") = R.clip(lv2, min=0, max=6) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_add_without_output_qparams_invalid(): + """Quantized ADD with missing output qparams raises OpAttributeInvalid.""" + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8) + + _tfl_add_options.AddOptionsStart(builder) + _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, _tfl_activation_fn.NONE) + add_opts = _tfl_add_options.AddOptionsEnd(builder) + + add_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.AddOptions, + builtin_options=add_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[add_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpAttributeInvalid, match="output must have quantization"): + from_tflite(tflite_model) + + +def test_quantized_square_unsupported(): + """Quantized SQUARE is rejected instead of applying integer power directly.""" + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_out = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + square_op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_out], + operators=[square_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.SQUARE)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 2, + ) + + with pytest.raises(tvm.error.OpNotImplemented, match="SQUARE"): + _load_model_from_buffer(buf) + + +def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): + """Conv2D with INT32 bias dequantizes bias with in_scale x wt_scale.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[conv_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_per_channel_weight_with_int32_bias_dequantizes_bias(): + """Conv2D with per-channel weight quantization uses vector bias scale.""" + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[conv_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + mod = _load_model_from_buffer(buf) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const([0.25, 0.75], "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + lv4: R.Tensor((2,), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const([0.25, 0.75], "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_per_channel_depthwise_conv_unsupported(): + """Per-channel quantized depthwise Conv2D raises OpNotImplemented.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[0], quantized_dimension=0 + ) + # Per-channel weight: 2 channels, scale vector length 2 + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [1, 3, 3, 2], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_ou = _build_tensor( + builder, 2, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsStart(builder) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideH(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideW(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddDepthMultiplier(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddPadding(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddFusedActivationFunction(builder, 0) + dw_opts = _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsEnd(builder) + + dw_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.DepthwiseConv2DOptions, + builtin_options=dw_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_ou], + operators=[dw_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEPTHWISE_CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="Per-channel"): + from_tflite(tflite_model) + + +def test_uint8_reshape_requantize_uses_dq_reshape_q(): + """uint8 RESHAPE with different qparams uses DQ→reshape→Q.""" + import flatbuffers + import numpy as np + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[128], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[100], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.UINT8, quantization=in_q) + t_ou = _build_tensor(builder, 1, [2, 2], tensor_type=_tfl_tensor_type.UINT8, quantization=out_q) + + # Use ReshapeOptions with static new_shape [2, 2] + new_shape_np = np.array([2, 2], dtype=np.int32) + new_shape_vec = _tflite_int32_vector( + builder, _tfl_reshape_options.ReshapeOptionsStartNewShapeVector, new_shape_np + ) + _tfl_reshape_options.ReshapeOptionsStart(builder) + _tfl_reshape_options.ReshapeOptionsAddNewShape(builder, new_shape_vec) + reshape_opts = _tfl_reshape_options.ReshapeOptionsEnd(builder) + + reshape_op = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.ReshapeOptions, + builtin_options=reshape_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_ou], + operators=[reshape_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.RESHAPE)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4), dtype="uint8"), + ) -> R.Tensor((2, 2), dtype="uint8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(128, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((2, 2), dtype="float32") = R.reshape( + lv, + R.shape([2, 2]), + ) + gv: R.Tensor((2, 2), dtype="uint8") = R.quantize( + lv1, + R.const(1.0, "float32"), + R.const(100, "int32"), + out_dtype="uint8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_transpose_conv_with_int32_bias_dequantizes_bias(): + """TRANSPOSE_CONV with INT32 bias dequantizes bias before adding.""" + import struct + + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [1], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + oshape_data = struct.pack(" R.Tensor((1, 1, 1, 1), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 1, 1, 1), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[3, 0, 1, 2], + ) + lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=1, + ) + lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.nn.conv2d_transpose( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + data_layout="NHWC", + kernel_layout="IOHW", + out_dtype="float32", + ) + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((1,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 1, 1, 1), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 1, 1, 1), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_fully_connected_with_int32_bias_dequantizes_bias(): + """Quantized FullyConnected with INT32 bias dequantizes bias with in_scale x wt_scale.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t_wt = _build_tensor(builder, 1, [2, 4], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q) + t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor(builder, 3, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_fully_connected_options.FullyConnectedOptionsStart(builder) + _tfl_fully_connected_options.FullyConnectedOptionsAddFusedActivationFunction(builder, 0) + _tfl_fully_connected_options.FullyConnectedOptionsAddWeightsFormat( + builder, _tfl_fc_weights_format.DEFAULT + ) + _tfl_fully_connected_options.FullyConnectedOptionsAddKeepNumDims(builder, 0) + fc_opts = _tfl_fully_connected_options.FullyConnectedOptionsEnd(builder) + + fc_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.FullyConnectedOptions, + builtin_options=fc_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[fc_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.FULLY_CONNECTED)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 4), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((4, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 0], + ) + lv2: R.Tensor((4, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=1, + ) + lv3: R.Tensor((1, 2), dtype="float32") = R.matmul(lv, lv2, out_dtype="void") + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def _build_csr_sparsity( builder, *,