From 47636150e6eeff0befc2eb0fade3643505da60e6 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Mon, 27 Oct 2025 11:21:45 +0100 Subject: [PATCH 1/3] Add QAT support for NeutronQuantizer --- backends/nxp/quantizer/neutron_quantizer.py | 164 +++++++++++++------- backends/nxp/quantizer/patterns.py | 52 +++++-- 2 files changed, 147 insertions(+), 69 deletions(-) diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index f476e16628e..78c386e3436 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -50,7 +50,13 @@ ) from torch import fx from torch.ao.quantization.quantizer.utils import _annotate_output_qspec -from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( ComposableQuantizer, DerivedQuantizationSpec, @@ -150,74 +156,116 @@ def get_supported_operators(cls) -> list[OperatorConfig]: # Quantization Specification used by Neutron NPU -act_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), -) - -wgt_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_tensor_symmetric, - is_dynamic=False, - observer_or_fake_quant_ctr=MinMaxObserver, - ch_axis=0, -) +def act_qspec(is_qat: bool): + eps = 2**-12 + observer_or_fake_quant_ctr = ( + FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, eps=eps + ) + if is_qat + else HistogramObserver.with_args(eps=eps) + ) + + return QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ) + + +def wgt_qspec(is_qat: bool): + observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAverageMinMaxObserver) + if is_qat + else MinMaxObserver + ) + + return QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ch_axis=0, + ) + + +def wgt_fc_qspec(is_qat: bool): + observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAverageMinMaxObserver) + if is_qat + else MinMaxObserver + ) + + return QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ) -wgt_fc_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_tensor_symmetric, - is_dynamic=False, - observer_or_fake_quant_ctr=MinMaxObserver, -) # Is set by the *PatternQuantizer directly. 
bias_qspec = None class NeutronQuantizer(ComposableQuantizer): - def __init__(self, neutron_target_spec: NeutronTargetSpec): + def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False): self.neutron_target_spec = neutron_target_spec - static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None) - static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None) + self.is_qat = is_qat + + static_qconfig = QuantizationConfig( + act_qspec(is_qat=is_qat), + act_qspec(is_qat=is_qat), + wgt_qspec(is_qat=is_qat), + None, + ) + static_fc_qconfig = QuantizationConfig( + act_qspec(is_qat=is_qat), + act_qspec(is_qat=is_qat), + wgt_fc_qspec(is_qat=is_qat), + None, + ) + + OpQuantizer = NeutronAtenQuantizer super().__init__( [ - NeutronAtenQuantizer(AbsPattern(), static_qconfig), - NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig), - NeutronAtenQuantizer(AddTensorPattern(), static_qconfig), - NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig), - NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig), - NeutronAtenQuantizer(CatPattern(), static_qconfig), - NeutronAtenQuantizer(Conv1dPattern(), static_qconfig), - NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig), - NeutronAtenQuantizer(DropoutPattern(), static_qconfig), - NeutronAtenQuantizer(FlattenPattern(), static_qconfig), - NeutronAtenQuantizer(HardTanhPattern(), static_qconfig), - NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig), - NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig), - NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig), - NeutronAtenQuantizer(MeanDimPattern(), static_qconfig), - NeutronAtenQuantizer(MmPattern(self), static_qconfig), - NeutronAtenQuantizer(PadPattern(), static_qconfig), - NeutronAtenQuantizer(PermutePattern(), static_qconfig), - NeutronAtenQuantizer(ReluPattern(), static_qconfig), - NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig), - NeutronAtenQuantizer(ReshapePattern(), static_qconfig), - NeutronAtenQuantizer(SigmoidPattern(), static_qconfig), - NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig), - NeutronAtenQuantizer(SubTensorPattern(), static_qconfig), - NeutronAtenQuantizer(TanhPattern(), static_qconfig), - NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig), - NeutronAtenQuantizer(ViewPattern(), static_qconfig), + OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig), + OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig), + OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig), + OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig), + OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(ReluPattern(is_qat=is_qat), 
static_qconfig), + OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig), ] ) + # Mapping ops defined in quantizer partition types to its quantizer self.op_to_quantizer = { pt: q for q in self.quantizers for pt in q.pattern.partition_types() @@ -280,7 +328,7 @@ def _annotate_inputs(self, model: fx.GraphModule): continue if node.op == "placeholder" and len(node.users) > 0: - _annotate_output_qspec(node, act_qspec) + _annotate_output_qspec(node, act_qspec(self.is_qat)) self._mark_input_node_as_annotated(node) def validate(self, model: torch.fx.GraphModule) -> None: diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index ee92cd42ef1..a0fb580ab33 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -14,7 +14,12 @@ from torch import fx from torch._ops import OpOverload from torch.fx import Node -from torchao.quantization.pt2e import PerChannelMinMaxObserver +from torchao.quantization.pt2e import ( + FakeQuantize, + FixedQParamsFakeQuantize, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, @@ -59,7 +64,8 @@ class PartitionAnchors: | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec], ] = field(default_factory=list) weights: list[ - tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec], + tuple[fx.Node, NodeArgsIdx] + | tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize], ] = field(default_factory=list) biases: list[ tuple[fx.Node, NodeArgsIdx] @@ -69,12 +75,20 @@ class PartitionAnchors: literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) output: list[ tuple[fx.Node] - | tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec], + | tuple[ + fx.Node, + FixedQParamsQuantizationSpec + | FixedQParamsFakeQuantize + | SharedQuantizationSpec, + ], ] = field(default_factory=list) empty: bool = False class QuantizationPattern(ABC): + def __init__(self, is_qat: bool): + self.is_qat = is_qat + @abstractmethod def partition_types(self) -> list[OpOverload]: """ @@ -147,11 +161,15 @@ def get_anchors_for_fixed_quant_specs( zero_point: int, quant_min: int = -128, quant_max: int = 127, + is_qat: bool = False, ) -> PartitionAnchors: node = fused_partition[0].nodes[-1] assert len(fused_partition[0].input_nodes) == 1 - qspec = FixedQParamsQuantizationSpec( + QSpecOrFakeQuantize = ( + FixedQParamsFakeQuantize if is_qat else FixedQParamsQuantizationSpec + ) + qspec_or_fake_quantize = QSpecOrFakeQuantize( dtype=torch.int8, scale=scale, zero_point=zero_point, @@ -165,7 +183,7 @@ def get_anchors_for_fixed_quant_specs( weights=[], biases=[], output=[ - (node, qspec), + (node, qspec_or_fake_quantize), ], ) @@ -189,11 +207,12 @@ def partition_types(self): class AddmmPattern(QuantizationPattern): - def __init__(self, neutron_quantizer): + def __init__(self, neutron_quantizer, is_qat: bool): self.neutron_quantizer = neutron_quantizer self.neutron_target_info = ( self.neutron_quantizer.neutron_target_spec.neutron_target_info ) + 
self.is_qat = is_qat def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.addmm.default] @@ -364,7 +383,11 @@ def get_anchors( ch_axis=0, ) - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + weight_observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver) + if self.is_qat + else PerChannelMinMaxObserver + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, @@ -393,11 +416,12 @@ def partition_types(self) -> list[OpOverload]: class Conv2dPattern(ConvPattern): - def __init__(self, neutron_quantizer): + def __init__(self, neutron_quantizer, is_qat: bool): self.neutron_quantizer = neutron_quantizer self.neutron_target_info = ( self.neutron_quantizer.neutron_target_spec.neutron_target_info ) + self.is_qat = is_qat def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.conv2d.default] @@ -420,7 +444,11 @@ def get_anchors( ch_axis=0, ) - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + weight_observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver) + if self.is_qat + else PerChannelMinMaxObserver + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, @@ -512,11 +540,12 @@ def replacement_op(self): class LinearPattern(QuantizationPattern): - def __init__(self, neutron_quantizer): + def __init__(self, neutron_quantizer, is_qat: bool): self.neutron_quantizer = neutron_quantizer self.neutron_target_info = ( self.neutron_quantizer.neutron_target_spec.neutron_target_info ) + self.is_qat = is_qat def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.linear.default] @@ -586,11 +615,12 @@ def partition_types(self): class MmPattern(QuantizationPattern): - def __init__(self, neutron_quantizer): + def __init__(self, neutron_quantizer, is_qat: bool): self.neutron_quantizer = neutron_quantizer self.neutron_target_info = ( self.neutron_quantizer.neutron_target_spec.neutron_target_info ) + self.is_qat = is_qat def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.mm.default] From d7829f7eb1d4598a1651adab970b1323130d43f9 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Mon, 10 Nov 2025 11:42:43 +0100 Subject: [PATCH 2/3] Add quantizer targeted unit tests for QAT --- backends/nxp/tests/models.py | 15 +++ backends/nxp/tests/test_quantizer.py | 138 +++++++++++++++++++++------ 2 files changed, 124 insertions(+), 29 deletions(-) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 2bd1f2b6d77..f52df3c2571 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -571,3 +571,18 @@ def __init__(self, activation: str, inplace: bool, in_channels: int): def forward(self, x): x = self.conv(x) return self.activation(x) + + +class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.sequential = torch.nn.Sequential( + torch.nn.Linear(1, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + + def forward(self, x): + return self.sequential(x) diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index d76fbaf460d..ca01a246070 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -29,9 +29,17 @@ ToChannelLastPreprocess, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import 
ExportedProgram +from torch.export import export, ExportedProgram from torch.fx import GraphModule -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e import ( + move_exported_model_to_eval, + move_exported_model_to_train, +) +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) fuse_activation_ops = [ exir_ops.edge.aten.addmm.default, @@ -44,16 +52,31 @@ ] +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(23) + + +def _prepare_for_quantization(exported_model, is_qat: bool = False): + if is_qat: + return prepare_qat_pt2e( + exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True) + ) + else: + return prepare_pt2e( + exported_model.module(), NeutronQuantizer(neutron_target_spec) + ) + + def test_quantizer_conv2d(): model = models.Conv2dModule() model.eval() example_input = (torch.ones(1, 4, 32, 32),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -87,11 +110,10 @@ def test_quantizer_linear(): model.eval() example_input = (torch.ones(10, 32),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -123,11 +145,10 @@ def test_quantizer_maxpool2d(): model.eval() example_input = (torch.ones(1, 8, 32, 32),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -158,11 +179,10 @@ def test_quantizer_softmax(): model.eval() example_input = (torch.ones(1, 10),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -192,11 +212,10 @@ def test_quantizer_single_maxpool2d(): model.eval() example_input = (torch.ones(1, 4, 32, 32),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -214,11 +233,10 @@ def test_quantizer_conv2d_relu(): model.eval() example_input = (torch.ones(1, 4, 32, 32),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, 
strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -241,11 +259,10 @@ def test_quantizer_conv2d_avg_pool2d(): model.eval() example_input = (torch.ones(1, 4, 16, 16),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -269,11 +286,10 @@ def test_quantizer_conv2d_permute(): model.eval() example_input = (torch.ones(1, 4, 16, 16),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -301,11 +317,10 @@ def test_multiple_shared_spec_ops_in_row(): model.eval() example_input = (torch.ones(1, 3, 64, 64),) - quantizer = NeutronQuantizer(neutron_target_spec) - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -579,3 +594,68 @@ def test_quantizer__conv_w_activation(mocker, activation, inplace): tflite_output_preprocess=ToChannelFirstPreprocess(), atol=1.0, ) + + +def test_qat_train(loss_tolerance: float = 0.02): + def evaluate(model, inputs, gts): + with torch.no_grad(): + test_outputs = model(inputs) + loss = torch.nn.functional.mse_loss(test_outputs, gts) + return loss + + def train_step(model, optimizer): + optimizer.zero_grad() + batch = torch.randn(100, 1).clamp(-1, 1) + outputs = model(batch) + loss = torch.nn.functional.mse_loss(outputs, torch.sin(batch)) + loss.backward() + optimizer.step() + + model = models.MLP() + model.train() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + for _ in range(100): + train_step(model, optimizer) + + test_inputs = torch.randn(20, 1).clamp(-1, 1) + test_outputs = torch.sin(test_inputs) + + model.eval() + eval_loss = evaluate(model, test_inputs, test_outputs) + + exported_model = export(model, (torch.randn(1, 1),), strict=True) + prepared_model = _prepare_for_quantization(exported_model, is_qat=True) + + prepared_model = move_exported_model_to_train(prepared_model) + for _ in range(30): + train_step(prepared_model, optimizer) + prepared_model = move_exported_model_to_eval(prepared_model) + + quantized_model = convert_pt2e(prepared_model) + + test_inputs = torch.randn(100, 1).clamp(-1, 1) + test_outputs = torch.sin(test_inputs) + + quant_eval_loss = evaluate(quantized_model, test_inputs, test_outputs) + + assert (quant_eval_loss - eval_loss) < loss_tolerance + + +def test_qat_produces_same_graph_as_ptq(): + model = models.Conv2dModule(in_channels=8, out_channels=32, kernel_size=5) + model.eval() + exported_model = export(model, ((torch.randn(1, 8, 32, 32),)), strict=True) + + qat_prepared_model = _prepare_for_quantization(exported_model, is_qat=True) + qat_quantized_model = convert_pt2e(qat_prepared_model) + + ptq_prepared_model = 
_prepare_for_quantization(exported_model, is_qat=False) + ptq_quantized_model = convert_pt2e(ptq_prepared_model) + + assert all( + ptqn.target == qatn.target + for qatn, ptqn in zip( + qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes + ) + ) From 35fce40c22c911a579cc9ac493953c97deda12d9 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Mon, 10 Nov 2025 15:08:38 +0100 Subject: [PATCH 3/3] Update per-op pytest tests to cover QAT --- backends/nxp/tests/executorch_pipeline.py | 36 ++++++++++--- .../node_converter/test_abs_converter.py | 7 ++- .../test_adaptive_avg_pool2d_converter.py | 17 +++--- .../test_add_tensor_converter.py | 17 +++--- .../test_avg_pool2d_converter.py | 11 ++-- .../node_converter/test_cat_converter.py | 52 ++++++++++-------- .../test_constant_pad_nd_converter.py | 31 +++++++---- .../node_converter/test_conv_converter.py | 29 +++++----- .../node_converter/test_hardtanh_converter.py | 17 ++++-- .../test_max_pool_2d_converter.py | 7 ++- .../node_converter/test_mean_dim_converter.py | 53 ++++++++++++++----- .../test_permute_copy_converter.py | 7 ++- .../node_converter/test_relu_converter.py | 9 ++-- .../node_converter/test_sigmoid_converter.py | 9 ++-- .../test_sub_tensor_converter.py | 19 ++++--- .../test_view_copy_converter.py | 12 +++-- backends/nxp/tests/use_qat.py | 11 ++++ 17 files changed, 233 insertions(+), 111 deletions(-) create mode 100644 backends/nxp/tests/use_qat.py diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index d209ce3ea01..91c56d20658 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -32,7 +32,15 @@ ) from torch import nn from torch.export import export -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e import ( + move_exported_model_to_eval, + move_exported_model_to_train, +) +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) from torchao.quantization.pt2e.quantizer import Quantizer default_neutron_converter_flavor = "SDK_25_09" @@ -48,11 +56,23 @@ class ModelInputSpec: def _quantize_model( - model, quantizer, calibration_inputs: list[tuple[torch.Tensor, ...]] + model, quantizer, calibration_inputs: list[tuple[torch.Tensor, ...]], is_qat: bool ): - m = prepare_pt2e(model, quantizer) + if is_qat: + m = prepare_qat_pt2e(model, quantizer) + m = move_exported_model_to_train(m) # TODO: Find if this call is necessary. + else: + m = prepare_pt2e(model, quantizer) + + # We omit training in case of QAT mode as it is not necessary for producing quantized model + # and would introduce slow downs. There are unit tests covering full QAT in quantizer tests. + # When in QAT mode, observers are being updated during every forward pass as it is with PTQ. 
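+    # For context on the calls above/below: move_exported_model_to_train() and
+    # move_exported_model_to_eval() switch train-only ops (e.g. dropout, batch norm) on the
+    # exported GraphModule, since calling .train()/.eval() directly on an exported model is
+    # not supported.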
for data in calibration_inputs: m(*data) + + if is_qat: + m = move_exported_model_to_eval(m) + m = convert_pt2e(m) return m @@ -67,8 +87,8 @@ def get_random_calibration_inputs( ] -def _get_default_quantizer(target_spec: NeutronTargetSpec) -> Quantizer: - return NeutronQuantizer(target_spec) +def _get_default_quantizer(target_spec: NeutronTargetSpec, use_qat: bool) -> Quantizer: + return NeutronQuantizer(target_spec, is_qat=use_qat) def to_model_input_spec( @@ -102,13 +122,16 @@ def to_quantized_edge_program( ] = get_random_calibration_inputs, target="imxrt700", neutron_converter_flavor=default_neutron_converter_flavor, + use_qat=False, remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 get_quantizer_fn=None, ) -> EdgeProgramManager: _neutron_target_spec = NeutronTargetSpec(target, neutron_converter_flavor) if get_quantizer_fn is None: - get_quantizer_fn = partial(_get_default_quantizer, _neutron_target_spec) + get_quantizer_fn = partial( + _get_default_quantizer, _neutron_target_spec, use_qat + ) quantizer = get_quantizer_fn() calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) @@ -123,6 +146,7 @@ def to_quantized_edge_program( exir_program_aten.module(), quantizer, calibration_inputs, + is_qat=use_qat, ) compile_spec = generate_neutron_compile_spec( diff --git a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py index 315c76a7614..4b62dd64dfb 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py @@ -19,6 +19,7 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -62,12 +63,14 @@ def forward(self, x): return x.abs() -def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): +def test_conv_abs(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112)): model = ConvBlocksWithAbs(conv_in_channels=input_shape[1]) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py index 9c8235f7eda..7546d0f3730 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py @@ -16,6 +16,7 @@ AdaptiveAvgPool2dConvModule, ) from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -40,14 +41,16 @@ def reseed_model_per_test_run(): ], ) def test_adaptive_avg_pool_2d_delegated_quant_conversion( - mocker, input_shape, output_size + mocker, input_shape, output_size, use_qat ): model = AdaptiveAvgPool2dConvModule(output_size) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, 
input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() nodes = [str(node) for node in edge_program.graph.nodes] # Input size is a multiple of output size, can be converted to AveragePool, node is delegated @@ -84,14 +87,16 @@ def test_adaptive_avg_pool_2d_delegated_quant_conversion( ], ) def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( - mocker, input_shape, output_size + mocker, input_shape, output_size, use_qat ): model = AdaptiveAvgPool2dConvModule(output_size) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() nodes = list(edge_program.graph.nodes) # Input size is not a multiple of output size, cannot be converted to AveragePool, node is not delegated @@ -115,14 +120,14 @@ def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( ) -def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker): +def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker, use_qat): input_shape = (1, 4, 16, 16) model = AdaptiveAvgPool2dConvMeanDimModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 2c3107eae77..57ba7628dfb 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -21,6 +21,7 @@ AddTensorOneInputModule, ) from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -38,13 +39,13 @@ def reseed_model_per_test_run(): pytest.param((1, 4, 8, 8), id="4D."), ], ) -def test_add_tensor_quant_conversion(mocker, input_shape): +def test_add_tensor_quant_conversion(mocker, input_shape, use_qat): model = AddTensorModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, [input_shape, input_shape]) + _ = to_quantized_edge_program(model, [input_shape, input_shape], use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -69,13 +70,13 @@ def test_add_tensor_quant_conversion(mocker, input_shape): pytest.param((1, 4, 8, 8), id="4D."), ], ) -def test_add_tensor_one_input_quant_conversion(mocker, input_shape): +def test_add_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): model = AddTensorOneInputModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -97,13 +98,13 @@ def test_add_tensor_one_input_quant_conversion(mocker, input_shape): pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), ], ) -def test_add_tensor_w_conv_quant_conversion(mocker, 
input_shape): +def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): model = AddTensorConvModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -135,13 +136,13 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape): ], ) def test_add_tensor_broadcasting_unsupported_quant_conversion( - x_input_shape, y_input_shape + x_input_shape, y_input_shape, use_qat ): model = AddTensorModule() # Run conversion edge_program = to_quantized_edge_program( - model, [x_input_shape, y_input_shape] + model, [x_input_shape, y_input_shape], use_qat=use_qat ).exported_program() nodes = list(edge_program.graph.nodes) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index bcdbd955c71..f701f91bf0c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -27,6 +27,7 @@ ) from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -139,13 +140,15 @@ def test_avg_pool_2d_conversion(input_shape, padding, count_include_pad): ), ], ) -def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_include_pad): +def test_avg_pool_2d_quant_conversion( + mocker, input_shape, padding, count_include_pad, use_qat +): model = AvgPool2dConvModule(padding=padding, count_include_pad=count_include_pad) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -164,7 +167,7 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ ) -def test_avg_pool_2d_quant_conversion__padded(mocker): +def test_avg_pool_2d_quant_conversion__padded(mocker, use_qat): input_shape = (1, 8, 8, 8) model = AvgPool2dModule(True, 1) @@ -172,7 +175,7 @@ def test_avg_pool_2d_quant_conversion__padded(mocker): ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture the converter operators. 
ops = ops_spy.spy_return.sub_graphs[0].operators.vector diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index 590b0be6a6b..e3ee2fff90b 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -22,6 +22,7 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 def _normalized_dim(dim, rank): @@ -84,13 +85,13 @@ def forward(self, *inputs: torch.Tensor): pytest.param(4, 5, -3, id="4D, 5 inputs, dim=-3"), ], ) -def test_cat__same_shapes(dim, num_inputs, rank, mocker): +def test_cat__same_shapes(dim, num_inputs, rank, mocker, use_qat): input_shape = tuple([8, 8, 8, 8][:rank]) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") quantized_program = to_quantized_edge_program( - CatModule(dim), [input_shape] * num_inputs + CatModule(dim), [input_shape] * num_inputs, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -115,13 +116,13 @@ def test_cat__same_shapes(dim, num_inputs, rank, mocker): @pytest.mark.parametrize("dim", [3, -2, -3]) @pytest.mark.parametrize("num_inputs", [2, 5]) -def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): +def test_cat__channels_first__same_shapes(dim, num_inputs, mocker, use_qat): input_shape = (2, 8, 6, 8) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") channels = input_shape[1] if dim not in {1, -3} else input_shape[1] * num_inputs quantized_program = to_quantized_edge_program( - CatConvModule(dim, channels), [input_shape] * num_inputs + CatConvModule(dim, channels), [input_shape] * num_inputs, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -158,13 +159,13 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), ], ) -def test_cat__unsupported__imxrt700(dim, input_shape): +def test_cat__unsupported__imxrt700(dim, input_shape, use_qat): """This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`). In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated. """ num_inputs = 2 quantized_program = to_quantized_edge_program( - CatModule(dim), [input_shape] * num_inputs, target="imxrt700" + CatModule(dim), [input_shape] * num_inputs, target="imxrt700", use_qat=use_qat ).exported_program() # Make sure the `Cat` was NOT delegated. @@ -188,13 +189,16 @@ def test_cat__unsupported__imxrt700(dim, input_shape): pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), ], ) -def test_cat__context_dependent__imxrt700(dim, input_shape): +def test_cat__context_dependent__imxrt700(dim, input_shape, use_qat): """This test is conjoined with the one above (`test_cat__unsupported__imxrt700`). In this case, the inputs of the `cat` are compute ops, so the `cat` is delegated. """ num_inputs = 2 ep = to_quantized_edge_program( - AddCatModule(dim), [input_shape] * num_inputs, target="imxrt700" + AddCatModule(dim), + [input_shape] * num_inputs, + target="imxrt700", + use_qat=use_qat, ).exported_program() # Make sure the `Cat` was delegated. 
@@ -218,7 +222,7 @@ def test_cat__context_dependent__imxrt700(dim, input_shape): pytest.param(4, 5, -3, id="4D, 5 inputs, dim=-3"), ], ) -def test_cat__different_shapes(dim, num_inputs, rank, mocker): +def test_cat__different_shapes(dim, num_inputs, rank, mocker, use_qat): input_shape = tuple([2, 8, 8, 8, 8][-rank:]) # The shape of every input will be different along the concatenated dimension. @@ -231,7 +235,7 @@ def test_cat__different_shapes(dim, num_inputs, rank, mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") quantized_program = to_quantized_edge_program( - CatModule(dim), input_shapes + CatModule(dim), input_shapes, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -258,7 +262,7 @@ def test_cat__different_shapes(dim, num_inputs, rank, mocker): @pytest.mark.parametrize( "num_inputs", [2, 5], ids=lambda num_inputs: f"num_inputs = {num_inputs}" ) -def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): +def test_cat__channels_first__different_shapes(dim, num_inputs, mocker, use_qat): input_shape = (2, 8, 6, 8) # The shape of every input will be different along the concatenated dimension. @@ -276,7 +280,7 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1] ) quantized_program = to_quantized_edge_program( - CatConvModule(dim, channels), input_shapes + CatConvModule(dim, channels), input_shapes, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -301,7 +305,7 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): ) -def test_cat__different_shapes__unsupported_channels__imxrt700(): +def test_cat__different_shapes__unsupported_channels__imxrt700(use_qat): input_shape = (2, 4, 6, 7) # (channels % 8) != 0 num_inputs = 2 @@ -315,7 +319,7 @@ def test_cat__different_shapes__unsupported_channels__imxrt700(): input_shapes.append(tuple(tmp_shape)) quantized_program = to_quantized_edge_program( - CatModule(dim), input_shapes, target="imxrt700" + CatModule(dim), input_shapes, target="imxrt700", use_qat=use_qat ).exported_program() # Make sure the `Cat` was NOT delegated. @@ -327,7 +331,7 @@ def test_cat__different_shapes__unsupported_channels__imxrt700(): ) -def test_cat__force_delegate(): +def test_cat__force_delegate(use_qat): target = "imxrt700" # The Partitioner doesn't know if the `8` or the `1` will become the channels in the IR. Therefore, it would @@ -339,6 +343,7 @@ def test_cat__force_delegate(): [input_shape, input_shape], target=target, custom_delegation_options=CustomDelegationOptions(force_delegate_cat=True), + use_qat=use_qat, ).exported_program() # Make sure the `Cat` was delegated. @@ -348,7 +353,7 @@ def test_cat__force_delegate(): assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) -def test_cat__same_shapes_converter_padding_last_dimension(): +def test_cat__same_shapes_converter_padding_last_dimension(use_qat): target = "imxrt700" # The Converter is capable of padding the last dimension of `cat` with the same input shapes. @@ -360,6 +365,7 @@ def test_cat__same_shapes_converter_padding_last_dimension(): target=target, neutron_converter_flavor="SDK_25_09", custom_delegation_options=CustomDelegationOptions(), + use_qat=use_qat, ).exported_program() # Make sure the `Cat` was delegated. 
@@ -369,7 +375,7 @@ def test_cat__same_shapes_converter_padding_last_dimension(): assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) -def test_cat__same_shapes__channels_first__padding_channels(): +def test_cat__same_shapes__channels_first__padding_channels(use_qat): target = "imxrt700" # The Converter is capable of padding the last dimension of `cat` with the same input shapes. @@ -381,6 +387,7 @@ def test_cat__same_shapes__channels_first__padding_channels(): target=target, neutron_converter_flavor="SDK_25_09", custom_delegation_options=CustomDelegationOptions(), + use_qat=use_qat, ).exported_program() # Make sure the `Cat` was delegated. @@ -390,7 +397,7 @@ def test_cat__same_shapes__channels_first__padding_channels(): assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) -def test_cat__same_shapes_converter_padding_middle_dimension(): +def test_cat__same_shapes_converter_padding_middle_dimension(use_qat): target = "imxrt700" # The Converter is not capable of padding the middle dimensions of `cat` with the same input shapes. @@ -401,6 +408,7 @@ def test_cat__same_shapes_converter_padding_middle_dimension(): [input_shape, input_shape], target=target, custom_delegation_options=CustomDelegationOptions(), + use_qat=use_qat, ).exported_program() # Make sure the `Cat` was NOT delegated. @@ -412,7 +420,7 @@ def test_cat__same_shapes_converter_padding_middle_dimension(): ) -def test_cat__format_specific_support__formatless(mocker): +def test_cat__format_specific_support__formatless(mocker, use_qat): # The last dim will end up being the channels, as the format is `formatless`. # Only the last dim satisfies the Neutron requirements for the channels. input_shape = (3, 3, 3, 8) @@ -424,7 +432,7 @@ def test_cat__format_specific_support__formatless(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") quantized_program = to_quantized_edge_program( - CatModule(dim), input_shapes + CatModule(dim), input_shapes, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -447,7 +455,7 @@ def test_cat__format_specific_support__formatless(mocker): ) -def test_cat__format_specific_support__channels_first(mocker): +def test_cat__format_specific_support__channels_first(mocker, use_qat): # The second dim will end up being the channels, as the format is `formatless`. # Only the second dim satisfies the Neutron requirements for the channels. input_shape = (3, 8, 3, 3) @@ -462,7 +470,7 @@ def test_cat__format_specific_support__channels_first(mocker): sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1] ) quantized_program = to_quantized_edge_program( - CatConvModule(dim, channels), input_shapes + CatConvModule(dim, channels), input_shapes, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. 
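(Editor's note between the per-file test diffs: the new `use_qat` argument threads through `to_quantized_edge_program` in the executorch_pipeline.py hunk of this patch, so the same entry point the tests exercise can be driven from user code. A minimal sketch, assuming `executorch_pipeline` and `models` are importable under `executorch.backends.nxp.tests` — the same package the tests import `use_qat` from — and reusing the `Conv2dModule` arguments from `test_qat_produces_same_graph_as_ptq`:)

    from executorch.backends.nxp.tests import models
    from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program

    # With use_qat=True the pipeline prepares the exported model with prepare_qat_pt2e(),
    # runs the calibration inputs through the fake-quantize modules, and finishes with
    # convert_pt2e(), mirroring what the QAT-parametrized tests in this patch exercise.
    model = models.Conv2dModule(in_channels=8, out_channels=32, kernel_size=5)
    edge_program_manager = to_quantized_edge_program(model, (1, 8, 32, 32), use_qat=True)
    exported = edge_program_manager.exported_program()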
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index 56be613a664..563b8c7393c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -21,6 +21,7 @@ ConstantPadNDConvModule, ConstantPadNDModule, ) +from executorch.backends.nxp.tests.use_qat import * # noqa F403 from executorch.exir.dialects._ops import ops as exir_ops @@ -116,20 +117,24 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): pytest.param((1, 1, 6, 8), (1, 2, 3, 4, 2, 1), id="4D, padding C, H, W"), ], ) -def test_constant_pad_nd__unsupported_paddings(input_shape, paddings): +def test_constant_pad_nd__unsupported_paddings(input_shape, paddings, use_qat): model = ConstantPadNDModule(paddings) - exec_program = to_quantized_edge_program(model, input_shape).exported_program() + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() nodes = list(exec_program.graph.nodes) # There is at least one non-delegated Pad node assert any(node.name == "aten_constant_pad_nd_default" for node in nodes) -def test_constant_pad_nd__delegation__formatless__supported_padding(): +def test_constant_pad_nd__delegation__formatless__supported_padding(use_qat): input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. paddings = [0, 0, 1, 2, 3, 4] # The last dim is padded using the first 2 paddings. model = ConstantPadNDModule(paddings) - exec_program = to_quantized_edge_program(model, input_shape).exported_program() + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() # Make sure the `pad` was delegated. assert not graph_contains_any_of_ops( @@ -137,11 +142,13 @@ def test_constant_pad_nd__delegation__formatless__supported_padding(): ) -def test_constant_pad_nd__delegation__formatless__unsupported_padding(): +def test_constant_pad_nd__delegation__formatless__unsupported_padding(use_qat): input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. paddings = [0, 1] # The last dim is padded using the first 2 paddings. model = ConstantPadNDModule(paddings) - exec_program = to_quantized_edge_program(model, input_shape).exported_program() + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() # Make sure the `pad` was NOT delegated. assert graph_contains_any_of_ops( @@ -149,11 +156,13 @@ def test_constant_pad_nd__delegation__formatless__unsupported_padding(): ) -def test_constant_pad_nd__delegation__channels_first__supported_padding(): +def test_constant_pad_nd__delegation__channels_first__supported_padding(use_qat): input_shape = (2, 4, 6, 8) # Channels first -> the second dim (4) will be padded. paddings = [1, 2, 3, 4, 0, 0] # The second dim is padded using the paddings[4:6]. model = ConstantPadNDConvModule(paddings) - exec_program = to_quantized_edge_program(model, input_shape).exported_program() + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() # Make sure the `pad` was delegated. 
assert not graph_contains_any_of_ops( @@ -161,11 +170,13 @@ def test_constant_pad_nd__delegation__channels_first__supported_padding(): ) -def test_constant_pad_nd__delegation__channels_first__unsupported_padding(): +def test_constant_pad_nd__delegation__channels_first__unsupported_padding(use_qat): input_shape = (2, 3, 6, 8) # Channels first -> the second dim (3) will be padded. paddings = [0, 0, 0, 0, 1, 0] # The second dim is padded using the paddings[4:6]. model = ConstantPadNDConvModule(paddings) - exec_program = to_quantized_edge_program(model, input_shape).exported_program() + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() # Make sure the `pad` was NOT delegated. assert graph_contains_any_of_ops( diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index d7a59cad6d6..4a30eeea5b7 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -27,6 +27,7 @@ ) from executorch.backends.nxp.tests.models import Conv1dModule, Conv2dModule from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -38,14 +39,14 @@ def reseed_model_per_test_run(): @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) @pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion(stride, dilation, kernel_size, mocker): +def test_conv1d_quant_conversion(stride, dilation, kernel_size, mocker, use_qat): input_shape = (1, 4, 16) model = Conv1dModule(stride=stride, dilation=dilation, kernel_size=kernel_size) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -90,7 +91,7 @@ def test_conv1d_quant_conversion(stride, dilation, kernel_size, mocker): ) @pytest.mark.parametrize("padding", [(1,), 2]) def test_conv1d_quant_conversion__padded( - stride, dilation, kernel_size, padding, mocker + stride, dilation, kernel_size, padding, mocker, use_qat ): input_shape = (1, 4, 16) model = Conv1dModule( @@ -100,7 +101,7 @@ def test_conv1d_quant_conversion__padded( ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -145,7 +146,9 @@ def test_conv1d_quant_conversion__padded( @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) @pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocker): +def test_conv1d_quant_conversion__depthwise( + stride, dilation, kernel_size, mocker, use_qat +): input_shape = (1, 4, 16) group = input_shape[1] model = Conv1dModule( @@ -160,7 +163,7 @@ def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocke ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, 
use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -204,7 +207,7 @@ def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocke
 )
 @pytest.mark.parametrize("padding", [(1,), 2])
 def test_conv1d_quant_conversion__depthwise__padded(
-    stride, dilation, kernel_size, padding, mocker
+    stride, dilation, kernel_size, padding, mocker, use_qat
 ):
     input_shape = (1, 4, 16)
     group = input_shape[1]
@@ -221,7 +224,7 @@ def test_conv1d_quant_conversion__depthwise__padded(
     ops_spy = mocker.spy(ModelBuilder, "finish")
 
     # Run conversion
-    _ = to_quantized_edge_program(model, input_shape)
+    _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -371,11 +374,11 @@ def test_conv1d_quant_conversion__depthwise__padded(
         ),
     ],
 )
-def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape):
+def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape, use_qat):
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(model, input_shape)
+    _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -435,7 +438,7 @@ def test_conv2d_conversion__depthwise(stride, dilation, kernel_shape, mocker):
 @pytest.mark.parametrize("dilation", [1, 2])
 @pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]])
 def test_conv2d_conversion__depthwise__quantized(
-    stride, dilation, kernel_shape, mocker
+    stride, dilation, kernel_shape, mocker, use_qat
 ):
     input_shape = (1, 4, 12, 12)
     group = input_shape[1]
@@ -451,6 +454,7 @@ def test_conv2d_conversion__depthwise__quantized(
             kernel_size=kernel_shape,
         ),
         tuple(input_shape),
+        use_qat=use_qat,
     ).exported_program()
 
     ops = spy.spy_return.sub_graphs[0].operators.vector
@@ -495,7 +499,7 @@ def test_conv2d_conversion__depthwise__padded(padding, mocker):
 
 
 @pytest.mark.parametrize("padding", [1, 2])
-def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker):
+def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker, use_qat):
     input_shape = (1, 4, 12, 12)
     group = input_shape[1]
     spy = mocker.spy(ModelBuilder, "finish")
@@ -505,6 +509,7 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker):
             group=group, in_channels=group, out_channels=group, padding=padding
         ),
         tuple(input_shape),
+        use_qat=use_qat,
     ).exported_program()
 
     ops = spy.spy_return.sub_graphs[0].operators.vector
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py
index c4bc559817b..e753a1704f2 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py
@@ -23,6 +23,7 @@
 from executorch.backends.nxp.tests.models import Conv2dWithActivation
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -33,7 +34,7 @@ def reseed_model_per_test_run():
 
 @pytest.mark.parametrize("input_shape", [(1, 3, 128, 128)])
 @pytest.mark.parametrize("inplace", [True, False])
-def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
+def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bool):
     # The torch.nn.Relu6 inherits from torch.nn.Hardtanh, and hence represented as HardTanh in ATen.
     # Testing the hardtanh originated from torch.nn.Relu6 op.
     model = Conv2dWithActivation(
@@ -42,7 +43,9 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
     )
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
-    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+    quantized_program = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
@@ -67,7 +70,11 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
 )
 @pytest.mark.parametrize("inplace", [True, False])
 def test_custom_hardtanh_quant(
-    mocker, input_shape: tuple[int], activation_range: tuple[int, int], inplace: bool
+    mocker,
+    input_shape: tuple[int],
+    activation_range: tuple[int, int],
+    inplace: bool,
+    use_qat: bool,
 ):
     # TODO(13063): This test suffers from non-ideal testing random quantization, because we always use range <0,1>.
     # We should update (decrease atol) when the Conv/Linear + Activation fuse at quantization is in place.
@@ -79,7 +86,9 @@ def test_custom_hardtanh_quant(
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
-    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+    quantized_program = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py
index 50bbf100980..13e891194b9 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py
@@ -24,6 +24,7 @@
 from executorch.backends.xnnpack._passes import RemoveGetItemPass
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -99,11 +100,13 @@ def test_max_pool_2d_conversion(input_shape, padding):
         ),
     ],
 )
-def test_max_pool_2d_quant_conversion(mocker, input_shape, padding):
+def test_max_pool_2d_quant_conversion(mocker, input_shape, padding, use_qat):
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(MaxPool2dConvModule(padding=padding), input_shape)
+    _ = to_quantized_edge_program(
+        MaxPool2dConvModule(padding=padding), input_shape, use_qat=use_qat
+    )
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py
index 4bbd89cc01d..8e15c9fb140 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py
@@ -18,6 +18,7 @@
     ToChannelLastPreprocess,
 )
 from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export import ExportedProgram
 
@@ -47,13 +48,17 @@ def forward(self, x):
         pytest.param((1, 4, 8, 8), (3, 2), id="Dim 3, 2."),
     ],
 )
-def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True):
+def test_mean_dim_conv_quant_conversion(
+    mocker, input_shape, dim, use_qat, keepdim=True
+):
     model = MeanDimConvModule(dim, keepdim)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    ep = to_quantized_edge_program(model, input_shape).exported_program()
+    ep = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     # Make sure the `mean.dim` was delegated.
     assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
@@ -92,14 +97,16 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True):
     ],
 )
 def test_mean_dim_linear_unsupported_quant_conversion(
-    mocker, input_shape, dim, keepdim
+    mocker, input_shape, dim, use_qat, keepdim
 ):
     model = MeanDimLinearModule(dim, keepdim)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    edge_program = to_quantized_edge_program(model, input_shape).exported_program()
+    edge_program = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     nodes = list(edge_program.graph.nodes)
     # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated
@@ -137,13 +144,17 @@ def test_mean_dim_linear_unsupported_quant_conversion(
         pytest.param(True, id="Keep dim."),
     ],
 )
-def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keepdim):
+def test_mean_dim_conv_unsupported_quant_conversion(
+    mocker, input_shape, dim, use_qat, keepdim
+):
     model = MeanDimConvModule(dim, keepdim)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    edge_program = to_quantized_edge_program(model, input_shape).exported_program()
+    edge_program = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     nodes = list(edge_program.graph.nodes)
     # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated
@@ -175,12 +186,16 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, ke
         pytest.param((1, 2, 3, 8), (-2, -3), id="Dim -2, -3."),
     ],
 )
-def test_mean_dim__formatless__supported(mocker, input_shape, dim, keepdim=True):
+def test_mean_dim__formatless__supported(
+    mocker, input_shape, dim, use_qat, keepdim=True
+):
     model = MeanDimModule(dim, keepdim)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
-    ep = to_quantized_edge_program(model, input_shape).exported_program()
+    ep = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     # Make sure the `mean.dim` was delegated.
     assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
@@ -208,10 +223,12 @@ def test_mean_dim__formatless__supported(mocker, input_shape, dim, keepdim=True)
         pytest.param((1, 2, 3, 8), (2, 3), id="Dim 2, 3."),
     ],
 )
-def test_mean_dim__formatless__unsupported(input_shape, dim, keepdim=True):
+def test_mean_dim__formatless__unsupported(input_shape, dim, use_qat, keepdim=True):
     model = MeanDimModule(dim, keepdim)
 
-    ep = to_quantized_edge_program(model, input_shape).exported_program()
+    ep = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     # Make sure the `mean.dim` was NOT delegated.
     assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
@@ -226,10 +243,14 @@ def test_mean_dim__formatless__unsupported(input_shape, dim, keepdim=True):
         ),
     ],
 )
-def test_mean_dim__formatless__unsupported_channels(input_shape, dim, keepdim=True):
+def test_mean_dim__formatless__unsupported_channels(
+    input_shape, dim, use_qat, keepdim=True
+):
     model = MeanDimModule(dim, keepdim)
 
-    ep = to_quantized_edge_program(model, input_shape).exported_program()
+    ep = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     # Make sure the `mean.dim` was NOT delegated.
     assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
@@ -244,13 +265,17 @@ def test_mean_dim__formatless__unsupported_channels(input_shape, dim, keepdim=Tr
         ),
     ],
 )
-def test_mean_dim__channels_first__unsupported_channels(input_shape, dim, keepdim=True):
+def test_mean_dim__channels_first__unsupported_channels(
+    input_shape, dim, use_qat, keepdim=True
+):
     model = MeanDimConvModule(
         dim, keepdim, out_channels=5
     )  # Only multiples of 8 (num_macs) are supported.
 
     # Run conversion
-    ep = to_quantized_edge_program(model, input_shape).exported_program()
+    ep = to_quantized_edge_program(
+        model, input_shape, use_qat=use_qat
+    ).exported_program()
 
     # Make sure the `mean.dim` was NOT delegated.
     assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py
index d25e2759cc8..de44e1bf470 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py
@@ -18,6 +18,7 @@
 )
 from executorch.backends.nxp.tests.models import Conv2dModule
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -37,14 +38,16 @@ def forward(self, x):
         return torch.permute(x, self.new_dims)
 
 
-def test_permute_copy_quant_conversion__with_bias(mocker):
+def test_permute_copy_quant_conversion__with_bias(mocker, use_qat):
     input_shape = (1, 4, 8, 8)
     new_dims = (0, 2, 3, 1)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(Conv2dPermuteCopyModule(new_dims), input_shape)
+    _ = to_quantized_edge_program(
+        Conv2dPermuteCopyModule(new_dims), input_shape, use_qat=use_qat
+    )
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py
index 8d903e3e0b5..88de9a92b55 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py
@@ -21,6 +21,7 @@
 )
 from executorch.backends.nxp.tests.models import Conv2dModule, LinearModule, ReLUModule
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -62,12 +63,12 @@ def test_relu_conversion():
     convert_run_compare(edge_program, input_data=input_data)
 
 
-def test_relu_with_conv_quant_conversion(mocker):
+def test_relu_with_conv_quant_conversion(mocker, use_qat):
     input_shape = (1, 4, 32, 32)
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(ConvReLUModule(), input_shape)
+    _ = to_quantized_edge_program(ConvReLUModule(), input_shape, use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, _ = converter_spy.spy_return
@@ -88,12 +89,12 @@ def test_relu_with_conv_quant_conversion(mocker):
     )
 
 
-def test_relu_with_linear_quant_conversion(mocker):
+def test_relu_with_linear_quant_conversion(mocker, use_qat):
     input_shape = (256, 32)
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(LinearReLUModule(), input_shape)
+    _ = to_quantized_edge_program(LinearReLUModule(), input_shape, use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, _ = converter_spy.spy_return
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py
index c5d7d4d6a38..f8d92d97b77 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py
@@ -20,6 +20,7 @@
 from executorch.backends.nxp.tests.models import ConvWithSigmoid
 from torch import nn
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -28,12 +29,12 @@ def reseed_model_per_test_run():
     np.random.seed(23)
 
 
-def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)):
+def test_conv_sigmoid(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112)):
     model = ConvWithSigmoid(conv_in_channels=input_shape[1])
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
-    to_quantized_edge_program(model, input_shape).exported_program()
+    to_quantized_edge_program(model, input_shape, use_qat=use_qat).exported_program()
 
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
@@ -59,12 +60,12 @@ def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)):
         pytest.param((10, 3, 25, 25, 25), id="4D"),
     ],
 )
-def test_sigmoid_only(mocker, input_shape):
+def test_sigmoid_only(mocker, use_qat, input_shape):
     model = nn.Sigmoid()
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
-    to_quantized_edge_program(model, input_shape).exported_program()
+    to_quantized_edge_program(model, input_shape, use_qat=use_qat).exported_program()
 
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py
index 98566ff1ad6..92602972065 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py
@@ -22,6 +22,7 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -39,13 +40,13 @@ def reseed_model_per_test_run():
         pytest.param((1, 4, 8, 8), id="4D."),
     ],
 )
-def test_sub_tensor_quant_conversion(mocker, input_shape):
+def test_sub_tensor_quant_conversion(mocker, input_shape, use_qat):
     model = SubTensorModule()
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(model, [input_shape, input_shape])
+    _ = to_quantized_edge_program(model, [input_shape, input_shape], use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -78,13 +79,13 @@ def test_sub_tensor_quant_conversion(mocker, input_shape):
         pytest.param((1, 4, 8, 8), id="4D."),
     ],
 )
-def test_sub_tensor_one_input_quant_conversion(mocker, input_shape):
+def test_sub_tensor_one_input_quant_conversion(mocker, input_shape, use_qat):
     model = SubTensorOneInputModule()
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(model, input_shape)
+    _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat)
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -109,7 +110,7 @@ def test_sub_tensor_one_input_quant_conversion(mocker, input_shape):
         pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."),
     ],
 )
-def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape):
+def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat):
     model = SubTensorConvModule()
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
@@ -118,7 +119,9 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape):
     y_input_shape = (n, 8, h, w)
 
     # Run conversion
-    _ = to_quantized_edge_program(model, [x_input_shape, y_input_shape])
+    _ = to_quantized_edge_program(
+        model, [x_input_shape, y_input_shape], use_qat=use_qat
+    )
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -159,13 +162,13 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape):
     ],
 )
 def test_sub_tensor_broadcasting_unsupported_quant_conversion(
-    x_input_shape, y_input_shape
+    x_input_shape, y_input_shape, use_qat
 ):
     model = SubTensorModule()
 
     # Run conversion
     edge_program = to_quantized_edge_program(
-        model, [x_input_shape, y_input_shape]
+        model, [x_input_shape, y_input_shape], use_qat=use_qat
     ).exported_program()
 
     nodes = list(edge_program.graph.nodes)
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py
index 448a9753000..709679f829f 100644
--- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py
+++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py
@@ -35,6 +35,7 @@
 )
 from torch import nn
 from torch.export import ExportedProgram
+from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
 
 @pytest.fixture(autouse=True)
@@ -209,11 +210,13 @@ def test__formatless_to_formatless(mocker):
         pytest.param((8, 64), (1, 16, 4, 4), id="2D"),
     ],
 )
-def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape):
+def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape, use_qat):
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(LinearReshapeModule(new_shape=new_shape), input_shape)
+    _ = to_quantized_edge_program(
+        LinearReshapeModule(new_shape=new_shape), input_shape, use_qat=use_qat
+    )
 
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -234,7 +237,9 @@ def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape):
         pytest.param((1, 4, 16, 16), 196, id="4D"),
     ],
 )
-def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_out):
+def test_view_w_conv_linear_quant_conversion(
+    mocker, input_shape, channels_view_out, use_qat
+):
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
@@ -243,6 +248,7 @@ def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_
             channels=input_shape[1], channels_view_out=channels_view_out
         ),
         input_shape,
+        use_qat=use_qat,
     )
 
     # Capture generated model
diff --git a/backends/nxp/tests/use_qat.py b/backends/nxp/tests/use_qat.py
new file mode 100644
index 00000000000..5994d5aa193
--- /dev/null
+++ b/backends/nxp/tests/use_qat.py
@@ -0,0 +1,11 @@
+import pytest
+
+
+@pytest.fixture
+def use_qat(request):
+    return request.param
+
+
+def pytest_generate_tests(metafunc):
+    if "use_qat" in metafunc.fixturenames:
+        metafunc.parametrize("use_qat", [True, False], indirect=True)
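
Reviewer note (illustration only, not part of the patch): the new backends/nxp/tests/use_qat.py module is star-imported by each converter test module above, which places both the `use_qat` fixture and the `pytest_generate_tests` hook in the test module's namespace, so every test that accepts `use_qat` is collected twice, once with `use_qat=False` and once with `use_qat=True`. A minimal sketch of a consuming test module follows; the file and test names are hypothetical.

# backends/nxp/tests/test_use_qat_example.py (hypothetical file, for illustration)
# The star import brings in both the `use_qat` fixture and the
# `pytest_generate_tests` hook defined in use_qat.py above.
from executorch.backends.nxp.tests.use_qat import *  # noqa F403


def test_runs_in_both_quantization_modes(use_qat):
    # `pytest_generate_tests` parametrizes `use_qat` to True and False,
    # so this test body executes once per quantization mode.
    assert use_qat in (True, False)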