Commit 1c03e12

Add QAT support for NeutronQuantizer
1 parent b24c39a commit 1c03e12
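
This commit threads an is_qat flag through NeutronQuantizer and its operator patterns, so one quantizer can drive either post-training quantization (observers collect statistics during calibration) or quantization-aware training (fake-quantize modules simulate int8 rounding during fine-tuning). A minimal sketch of the intended QAT flow follows; the toy model, the neutron_target_spec value, and the torchao quantize_pt2e import path are illustrative assumptions, not part of this commit.

    import torch
    from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

    # Toy model; any module matched by the Neutron patterns works the same way.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    example_inputs = (torch.randn(1, 3, 32, 32),)
    exported = torch.export.export(model, example_inputs).module()

    # neutron_target_spec is assumed to be a NeutronTargetSpec built elsewhere.
    quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)
    prepared = prepare_qat_pt2e(exported, quantizer)
    # ... fine-tune `prepared` here so the fake-quantize modules learn qparams ...
    quantized = convert_pt2e(prepared)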

File tree

2 files changed: +138 -69 lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 97 additions & 58 deletions
@@ -49,7 +49,11 @@
 )
 from torch import fx
 from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -149,74 +153,109 @@ def get_supported_operators(cls) -> list[OperatorConfig]:


 # Quantization Specification used by Neutron NPU
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
+def act_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize
+        if is_qat
+        else HistogramObserver.with_args(eps=2**-12)
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
+
+
+def wgt_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize if is_qat else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
+    )
+
+
+def wgt_fc_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize if is_qat else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )

-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-    ch_axis=0,
-)
-
-wgt_fc_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)

 # Is set by the *PatternQuantizer directly.
 bias_qspec = None


 class NeutronQuantizer(ComposableQuantizer):
-    def __init__(self, neutron_target_spec: NeutronTargetSpec):
+    def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
         self.neutron_target_spec = neutron_target_spec
-        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
-        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        self.is_qat = is_qat
+
+        static_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_qspec(is_qat=is_qat),
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_fc_qspec(is_qat=is_qat),
+            None,
+        )
+
+        OpQuantizer = NeutronAtenQuantizer
         super().__init__(
             [
-                NeutronAtenQuantizer(AbsPattern(), static_qconfig),
-                NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(CatPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
-                NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
-                NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
-                NeutronAtenQuantizer(MmPattern(self), static_qconfig),
-                NeutronAtenQuantizer(PadPattern(), static_qconfig),
-                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
-                NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
-                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ViewPattern(), static_qconfig),
+                OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
             ]
         )
+
         # Mapping ops defined in quantizer partition types to its quantizer
         self.op_to_quantizer = {
             pt: q for q in self.quantizers for pt in q.pattern.partition_types()
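
The former module-level constants act_qspec, wgt_qspec, and wgt_fc_qspec are now factories keyed on the QAT flag. For illustration, the two activation modes differ only in the constructor handed to observer_or_fake_quant_ctr:

    ptq_act = act_qspec(is_qat=False)
    # -> HistogramObserver.with_args(eps=2**-12): statistics are gathered on
    #    calibration data after export.
    qat_act = act_qspec(is_qat=True)
    # -> FusedMovingAvgObsFakeQuantize: observes and fake-quantizes in one
    #    fused op during training.
    # dtype, quant range, and qscheme are identical in both modes.
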
@@ -272,7 +311,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
                 continue

             if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(node, act_qspec)
+                _annotate_output_qspec(node, act_qspec(self.is_qat))
                 self._mark_input_node_as_annotated(node)

     def validate(self, model: torch.fx.GraphModule) -> None:

backends/nxp/quantizer/patterns.py

Lines changed: 41 additions & 11 deletions
@@ -13,7 +13,11 @@
 from executorch.backends.nxp.quantizer.utils import get_bias_qparams
 from torch import fx
 from torch._ops import OpOverload
-from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FixedQParamsFakeQuantize,
+    PerChannelMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
@@ -57,7 +61,8 @@ class PartitionAnchors:
         | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
     ] = field(default_factory=list)
     weights: list[
-        tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec],
+        tuple[fx.Node, NodeArgsIdx]
+        | tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize],
     ] = field(default_factory=list)
     biases: list[
         tuple[fx.Node, NodeArgsIdx]
@@ -67,12 +72,20 @@ class PartitionAnchors:
     literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
     output: list[
         tuple[fx.Node]
-        | tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec],
+        | tuple[
+            fx.Node,
+            FixedQParamsQuantizationSpec
+            | FixedQParamsFakeQuantize
+            | SharedQuantizationSpec,
+        ],
     ] = field(default_factory=list)
     empty: bool = False


 class QuantizationPattern(ABC):
+    def __init__(self, is_qat: bool):
+        self.is_qat = is_qat
+
     @abstractmethod
     def partition_types(self) -> list[OpOverload]:
         """
@@ -145,11 +158,15 @@ def get_anchors_for_fixed_quant_specs(
     zero_point: int,
     quant_min: int = -128,
     quant_max: int = 127,
+    is_qat: bool = False,
 ) -> PartitionAnchors:
     node = fused_partition[0].nodes[-1]
     assert len(fused_partition[0].input_nodes) == 1

-    qspec = FixedQParamsQuantizationSpec(
+    QSpecOrFakeQuantize = (
+        FixedQParamsFakeQuantize if is_qat else FixedQParamsQuantizationSpec
+    )
+    qspec_or_fake_quantize = QSpecOrFakeQuantize(
         dtype=torch.int8,
         scale=scale,
         zero_point=zero_point,
@@ -163,7 +180,7 @@ def get_anchors_for_fixed_quant_specs(
         weights=[],
         biases=[],
         output=[
-            (node, qspec),
+            (node, qspec_or_fake_quantize),
         ],
     )

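
get_anchors_for_fixed_quant_specs backs ops with fixed output ranges, such as sigmoid, tanh, and softmax. An illustrative call; fused_partition comes from the caller, and the scale and zero-point values here are examples rather than values taken from this commit:

    anchors = get_anchors_for_fixed_quant_specs(
        fused_partition,      # the matched partition passed into get_anchors
        scale=1.0 / 256.0,    # example fixed qparams for an output in [0, 1)
        zero_point=-128,
        is_qat=True,          # the output anchor gets FixedQParamsFakeQuantize
    )
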
@@ -187,11 +204,12 @@ def partition_types(self):


 class AddmmPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.addmm.default]
@@ -363,7 +381,12 @@ def get_anchors(
             ch_axis=0,
         )

-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        weight_observer_or_fake_quant_ctr = (
+            # TODO: Check feasibility vs. HistogramObserver
+            FakeQuantize.with_args(observer=PerChannelMinMaxObserver)
+            if self.is_qat
+            else PerChannelMinMaxObserver
+        )
         weight_quantization_spec = QuantizationSpec(
             dtype=torch.int8,
             observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
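
For per-channel weights the QAT branch keeps PerChannelMinMaxObserver for statistics and wraps it in a FakeQuantize factory, rather than switching to the fused fake-quantize used for activations. A sketch of what that constructor yields, based on the stock FakeQuantize API; the behavior described here is inferred, not shown in the diff:

    ctr = FakeQuantize.with_args(observer=PerChannelMinMaxObserver)
    weight_fq = ctr()  # training-time module: observe per-channel, then fake-quantize
    # The surrounding QuantizationSpec still supplies dtype, range, qscheme, ch_axis.
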
@@ -392,11 +415,12 @@ def partition_types(self) -> list[OpOverload]:


 class Conv2dPattern(ConvPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.conv2d.default]
@@ -419,7 +443,11 @@ def get_anchors(
             ch_axis=0,
         )

-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        weight_observer_or_fake_quant_ctr = (
+            FakeQuantize.with_args(observer=PerChannelMinMaxObserver)
+            if self.is_qat
+            else PerChannelMinMaxObserver
+        )
         weight_quantization_spec = QuantizationSpec(
             dtype=torch.int8,
             observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -511,11 +539,12 @@ def replacement_op(self):


 class LinearPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.linear.default]
@@ -585,11 +614,12 @@ def partition_types(self):


 class MmPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.mm.default]
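
Finally, a hedged smoke test for the new path, reusing the prepared model from the sketch at the top of the commit; it assumes FakeQuantizeBase is exported from torchao.quantization.pt2e alongside the classes this commit imports:

    from torchao.quantization.pt2e import FakeQuantizeBase

    # After prepare_qat_pt2e, fake-quantize modules should replace plain observers.
    fq_modules = [m for m in prepared.modules() if isinstance(m, FakeQuantizeBase)]
    assert fq_modules, "QAT preparation inserted no fake-quantize modules"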
