@@ -49,7 +49,11 @@
 )
 from torch import fx
 from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -149,74 +153,109 @@ def get_supported_operators(cls) -> list[OperatorConfig]:
 
 
 # Quantization Specification used by Neutron NPU
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
+def act_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize
+        if is_qat
+        else HistogramObserver.with_args(eps=2**-12)
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
+
+
+def wgt_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize if is_qat else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
+    )
+
+
+def wgt_fc_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize if is_qat else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
 
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-    ch_axis=0,
-)
-
-wgt_fc_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)
 
 # Is set by the *PatternQuantizer directly.
 bias_qspec = None
 
 
 class NeutronQuantizer(ComposableQuantizer):
-    def __init__(self, neutron_target_spec: NeutronTargetSpec):
+    def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
         self.neutron_target_spec = neutron_target_spec
-        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
-        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        self.is_qat = is_qat
+
+        static_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_qspec(is_qat=is_qat),
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_fc_qspec(is_qat=is_qat),
+            None,
+        )
+
+        OpQuantizer = NeutronAtenQuantizer
         super().__init__(
             [
-                NeutronAtenQuantizer(AbsPattern(), static_qconfig),
-                NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(CatPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
-                NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
-                NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
-                NeutronAtenQuantizer(MmPattern(self), static_qconfig),
-                NeutronAtenQuantizer(PadPattern(), static_qconfig),
-                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
-                NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
-                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ViewPattern(), static_qconfig),
+                OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
             ]
         )
+
        # Mapping ops defined in quantizer partition types to its quantizer
         self.op_to_quantizer = {
             pt: q for q in self.quantizers for pt in q.pattern.partition_types()
@@ -272,7 +311,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
                 continue
 
             if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(node, act_qspec)
+                _annotate_output_qspec(node, act_qspec(self.is_qat))
                 self._mark_input_node_as_annotated(node)
 
     def validate(self, model: torch.fx.GraphModule) -> None:
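For anyone trying the new flag, here is a minimal PT2E sketch of how the `is_qat` switch is meant to be used. The model, example inputs, and the `NeutronTargetSpec` construction are placeholders, and the `prepare_qat_pt2e`/`convert_pt2e` import path is assumed from torchao's pt2e migration (on older stacks the same helpers live under `torch.ao.quantization.quantize_pt2e`):

```python
import torch
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import (  # assumed path, see note above
    convert_pt2e,
    prepare_qat_pt2e,
)

# Placeholder model and inputs; any exportable eager model works here.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_inputs = (torch.randn(1, 3, 32, 32),)

# Capture the graph, then attach the quantizer in QAT mode so that
# FusedMovingAvgObsFakeQuantize is inserted instead of the PTQ observers.
graph_module = export(model, example_inputs).module()
quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)  # spec assumed built elsewhere

prepared = prepare_qat_pt2e(graph_module, quantizer)
# ... fine-tune `prepared` for a few epochs before conversion ...
quantized = convert_pt2e(prepared)
```

With the default `is_qat=False`, the same quantizer keeps selecting `HistogramObserver`/`MinMaxObserver` and drives the usual `prepare_pt2e` + calibration PTQ flow, so existing call sites are unaffected.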