@@ -49,7 +49,13 @@
 )
 from torch import fx
 from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -149,74 +155,116 @@ def get_supported_operators(cls) -> list[OperatorConfig]: |
 
 
 # Quantization Specification used by Neutron NPU
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-    ch_axis=0,
-)
+def act_qspec(is_qat: bool):
+    eps = 2**-12
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver, eps=eps
+        )
+        if is_qat
+        else HistogramObserver.with_args(eps=eps)
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
+
+
+def wgt_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
+    )
+
+
+def wgt_fc_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
 
-wgt_fc_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)
 
 # Is set by the *PatternQuantizer directly.
 bias_qspec = None
 
 
 class NeutronQuantizer(ComposableQuantizer):
-    def __init__(self, neutron_target_spec: NeutronTargetSpec):
+    def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
         self.neutron_target_spec = neutron_target_spec
-        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
-        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        self.is_qat = is_qat
+
+        static_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_qspec(is_qat=is_qat),
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_fc_qspec(is_qat=is_qat),
+            None,
+        )
+
+        OpQuantizer = NeutronAtenQuantizer
         super().__init__(
             [
-                NeutronAtenQuantizer(AbsPattern(), static_qconfig),
-                NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(CatPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
-                NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
-                NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
-                NeutronAtenQuantizer(MmPattern(self), static_qconfig),
-                NeutronAtenQuantizer(PadPattern(), static_qconfig),
-                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
-                NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
-                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ViewPattern(), static_qconfig),
+                OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
             ]
         )
+
         # Mapping ops defined in quantizer partition types to its quantizer
         self.op_to_quantizer = {
             pt: q for q in self.quantizers for pt in q.pattern.partition_types()
@@ -272,7 +320,7 @@ def _annotate_inputs(self, model: fx.GraphModule): |
                 continue
 
             if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(node, act_qspec)
+                _annotate_output_qspec(node, act_qspec(self.is_qat))
                 self._mark_input_node_as_annotated(node)
 
     def validate(self, model: torch.fx.GraphModule) -> None:
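
For context, a minimal sketch of how the new flag might be used downstream. The `quantize_model` helper, model, and target-spec arguments are illustrative assumptions, not part of this change; `prepare_pt2e`, `prepare_qat_pt2e`, and `convert_pt2e` are torchao's standard PT2E entry points, and the export step may differ by torch version:

# Hypothetical usage sketch, not part of this commit.
import torch
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)

def quantize_model(model, example_inputs, target_spec, is_qat=False):
    # Export to an ATen-level graph module for the PT2E flow.
    exported = torch.export.export_for_training(model, example_inputs).module()
    quantizer = NeutronQuantizer(target_spec, is_qat=is_qat)
    if is_qat:
        # QAT: fake-quantize modules are inserted; fine-tune the returned module.
        prepared = prepare_qat_pt2e(exported, quantizer)
    else:
        # PTQ: observers are inserted; run calibration batches instead.
        prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # one calibration (or training) step
    return convert_pt2e(prepared)

With is_qat=False the activation spec falls back to HistogramObserver and the weight specs to MinMaxObserver, matching the previous post-training-quantization behavior by default.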