Commit b05d8fc

Add QAT support for NeutronQuantizer
1 parent b24c39a
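
This change turns the module-level act_qspec, wgt_qspec, and wgt_fc_qspec constants into factory functions parameterized by an is_qat flag, threads the flag through NeutronQuantizer and every QuantizationPattern, and, in QAT mode, swaps the post-training observers (HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver) for fake-quantize modules backed by moving-average observers (FusedMovingAvgObsFakeQuantize, FakeQuantize, FixedQParamsFakeQuantize).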

File tree: 2 files changed (+147, -69 lines)

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 106 additions & 58 deletions
@@ -49,7 +49,13 @@
 )
 from torch import fx
 from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -149,74 +155,116 @@ def get_supported_operators(cls) -> list[OperatorConfig]:


 # Quantization Specification used by Neutron NPU
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-    ch_axis=0,
-)
+def act_qspec(is_qat: bool):
+    eps = 2**-12
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver, eps=eps
+        )
+        if is_qat
+        else HistogramObserver.with_args(eps=eps)
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
+
+
+def wgt_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
+    )
+
+
+def wgt_fc_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )

-wgt_fc_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)

 # Is set by the *PatternQuantizer directly.
 bias_qspec = None


 class NeutronQuantizer(ComposableQuantizer):
-    def __init__(self, neutron_target_spec: NeutronTargetSpec):
+    def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
         self.neutron_target_spec = neutron_target_spec
-        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
-        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        self.is_qat = is_qat
+
+        static_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_qspec(is_qat=is_qat),
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_fc_qspec(is_qat=is_qat),
+            None,
+        )
+
+        OpQuantizer = NeutronAtenQuantizer
         super().__init__(
             [
-                NeutronAtenQuantizer(AbsPattern(), static_qconfig),
-                NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(CatPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
-                NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
-                NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
-                NeutronAtenQuantizer(MmPattern(self), static_qconfig),
-                NeutronAtenQuantizer(PadPattern(), static_qconfig),
-                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
-                NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
-                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ViewPattern(), static_qconfig),
+                OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
             ]
         )
+
         # Mapping ops defined in quantizer partition types to its quantizer
         self.op_to_quantizer = {
             pt: q for q in self.quantizers for pt in q.pattern.partition_types()
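
For intuition, a minimal sketch of the difference between the two modes, using the act_qspec factory added above (the comments describe the standard roles of observers and fake-quantize modules in the PT2E flow):

    from executorch.backends.nxp.quantizer.neutron_quantizer import act_qspec

    # PTQ: the spec carries an observer constructor; observers only collect
    # range statistics while calibration data flows through the model.
    ptq_spec = act_qspec(is_qat=False)

    # QAT: the spec carries a fused observer + fake-quantize constructor that
    # also simulates int8 rounding in the forward pass, so training can adapt
    # the weights to the quantization error.
    qat_spec = act_qspec(is_qat=True)

    # Instantiating the stored constructor yields the module that the PT2E
    # flow inserts at each annotated activation.
    fake_quant = qat_spec.observer_or_fake_quant_ctr()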
@@ -272,7 +320,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
                 continue

             if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(node, act_qspec)
+                _annotate_output_qspec(node, act_qspec(self.is_qat))
                 self._mark_input_node_as_annotated(node)

     def validate(self, model: torch.fx.GraphModule) -> None:
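
End to end, the flag is intended for the QAT variants of the PT2E entry points. A minimal sketch, assuming a toy model and an already-built NeutronTargetSpec (constructing the spec is target-specific and outside this commit):

    import torch
    from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

    model = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU())
    example_inputs = (torch.randn(1, 32),)

    neutron_target_spec = ...  # a NeutronTargetSpec for the target NPU (not shown)
    quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)

    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_qat_pt2e(exported, quantizer)

    # ... fine-tune `prepared` here so the fake-quantize modules learn qparams ...

    quantized = convert_pt2e(prepared)

With the default is_qat=False, the same quantizer keeps driving the usual prepare_pt2e post-training flow, so existing callers are unaffected.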

backends/nxp/quantizer/patterns.py

Lines changed: 41 additions & 11 deletions
@@ -13,7 +13,12 @@
 from executorch.backends.nxp.quantizer.utils import get_bias_qparams
 from torch import fx
 from torch._ops import OpOverload
-from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FixedQParamsFakeQuantize,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
@@ -57,7 +62,8 @@ class PartitionAnchors:
         | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
     ] = field(default_factory=list)
     weights: list[
-        tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec],
+        tuple[fx.Node, NodeArgsIdx]
+        | tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize],
     ] = field(default_factory=list)
     biases: list[
         tuple[fx.Node, NodeArgsIdx]
@@ -67,12 +73,20 @@ class PartitionAnchors:
     literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
     output: list[
         tuple[fx.Node]
-        | tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec],
+        | tuple[
+            fx.Node,
+            FixedQParamsQuantizationSpec
+            | FixedQParamsFakeQuantize
+            | SharedQuantizationSpec,
+        ],
     ] = field(default_factory=list)
     empty: bool = False


 class QuantizationPattern(ABC):
+    def __init__(self, is_qat: bool):
+        self.is_qat = is_qat
+
     @abstractmethod
     def partition_types(self) -> list[OpOverload]:
         """
@@ -145,11 +159,15 @@ def get_anchors_for_fixed_quant_specs(
     zero_point: int,
     quant_min: int = -128,
     quant_max: int = 127,
+    is_qat: bool = False,
 ) -> PartitionAnchors:
     node = fused_partition[0].nodes[-1]
     assert len(fused_partition[0].input_nodes) == 1

-    qspec = FixedQParamsQuantizationSpec(
+    QSpecOrFakeQuantize = (
+        FixedQParamsFakeQuantize if is_qat else FixedQParamsQuantizationSpec
+    )
+    qspec_or_fake_quantize = QSpecOrFakeQuantize(
         dtype=torch.int8,
         scale=scale,
         zero_point=zero_point,
@@ -163,7 +181,7 @@ def get_anchors_for_fixed_quant_specs(
         weights=[],
         biases=[],
         output=[
-            (node, qspec),
+            (node, qspec_or_fake_quantize),
         ],
     )
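
For illustration, a fixed-range activation pattern such as sigmoid would thread its flag through this helper roughly as follows. This is a hypothetical call shape: the qparams shown are the common int8 convention for sigmoid (outputs in [0, 1] mapped to 256 steps), not values taken from this diff:

    # Inside a pattern's get_anchors(), with `fused_partition` supplied by the
    # pattern matcher:
    return get_anchors_for_fixed_quant_specs(
        fused_partition=fused_partition,
        scale=1.0 / 256.0,
        zero_point=-128,
        is_qat=self.is_qat,
    )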

@@ -187,11 +205,12 @@ def partition_types(self):


 class AddmmPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.addmm.default]
@@ -363,7 +382,11 @@ def get_anchors(
             ch_axis=0,
         )

-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        weight_observer_or_fake_quant_ctr = (
+            FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
+            if self.is_qat
+            else PerChannelMinMaxObserver
+        )
         weight_quantization_spec = QuantizationSpec(
             dtype=torch.int8,
             observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
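
The weight-side swap in isolation (names as imported above; the comments describe the standard module roles):

    from torchao.quantization.pt2e import (
        FakeQuantize,
        MovingAveragePerChannelMinMaxObserver,
    )

    # QAT weights: FakeQuantize simulates int8 rounding on each forward pass,
    # while the wrapped observer keeps an exponential moving average of the
    # per-channel min/max ranges; the channel axis (ch_axis=0) comes from the
    # enclosing QuantizationSpec.
    qat_weight_ctr = FakeQuantize.with_args(
        observer=MovingAveragePerChannelMinMaxObserver
    )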
@@ -392,11 +415,12 @@ def partition_types(self) -> list[OpOverload]:


 class Conv2dPattern(ConvPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.conv2d.default]
@@ -419,7 +443,11 @@ def get_anchors(
             ch_axis=0,
         )

-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        weight_observer_or_fake_quant_ctr = (
+            FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
+            if self.is_qat
+            else PerChannelMinMaxObserver
+        )
         weight_quantization_spec = QuantizationSpec(
             dtype=torch.int8,
             observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -511,11 +539,12 @@ def replacement_op(self):


 class LinearPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.linear.default]
@@ -585,11 +614,12 @@ def partition_types(self):


 class MmPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
         )
+        self.is_qat = is_qat

     def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.mm.default]
