164 changes: 106 additions & 58 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -49,7 +49,13 @@
)
from torch import fx
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import (
ComposableQuantizer,
DerivedQuantizationSpec,
@@ -149,74 +155,116 @@ def get_supported_operators(cls) -> list[OperatorConfig]:


# Quantization Specification used by Neutron NPU
act_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
ch_axis=0,
)
def act_qspec(is_qat: bool):
eps = 2**-12
observer_or_fake_quant_ctr = (
FusedMovingAvgObsFakeQuantize.with_args(
observer=MovingAverageMinMaxObserver, eps=eps
)
if is_qat
else HistogramObserver.with_args(eps=eps)
)

return QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
)


def wgt_qspec(is_qat: bool):
observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
if is_qat
else MinMaxObserver
)

return QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
ch_axis=0,
)


def wgt_fc_qspec(is_qat: bool):
observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
if is_qat
else MinMaxObserver
)

return QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
)

wgt_fc_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
)

# Is set by the *PatternQuantizer directly.
bias_qspec = None


class NeutronQuantizer(ComposableQuantizer):
def __init__(self, neutron_target_spec: NeutronTargetSpec):
def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
self.neutron_target_spec = neutron_target_spec
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
self.is_qat = is_qat

static_qconfig = QuantizationConfig(
act_qspec(is_qat=is_qat),
act_qspec(is_qat=is_qat),
wgt_qspec(is_qat=is_qat),
None,
)
static_fc_qconfig = QuantizationConfig(
act_qspec(is_qat=is_qat),
act_qspec(is_qat=is_qat),
wgt_fc_qspec(is_qat=is_qat),
None,
)

OpQuantizer = NeutronAtenQuantizer
super().__init__(
[
NeutronAtenQuantizer(AbsPattern(), static_qconfig),
NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
NeutronAtenQuantizer(CatPattern(), static_qconfig),
NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
NeutronAtenQuantizer(MmPattern(self), static_qconfig),
NeutronAtenQuantizer(PadPattern(), static_qconfig),
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
NeutronAtenQuantizer(ReluPattern(), static_qconfig),
NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
]
)

# Mapping ops defined in quantizer partition types to its quantizer
self.op_to_quantizer = {
pt: q for q in self.quantizers for pt in q.pattern.partition_types()
@@ -272,7 +320,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
continue

if node.op == "placeholder" and len(node.users) > 0:
_annotate_output_qspec(node, act_qspec)
_annotate_output_qspec(node, act_qspec(self.is_qat))
self._mark_input_node_as_annotated(node)

def validate(self, model: torch.fx.GraphModule) -> None:
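For readers tracing the new `is_qat` flag end to end, here is a minimal usage sketch, not part of this PR, of how the quantizer might be driven in PTQ and QAT mode. It assumes the standard PT2E flow; the exact import paths, the export step, and the `quantize` helper are illustrative assumptions.

```python
# Hypothetical end-to-end sketch: PTQ vs. QAT with the new is_qat flag.
# Import paths and helper names are assumptions based on the standard PT2E
# flow, not code from this PR.
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def quantize(model, example_inputs, neutron_target_spec, is_qat=False):
    # Export to an ATen-level graph module first.
    exported = torch.export.export(model, example_inputs).module()
    quantizer = NeutronQuantizer(neutron_target_spec, is_qat=is_qat)
    if is_qat:
        # Inserts fake-quant modules; the returned module is fine-tuned
        # before conversion.
        prepared = prepare_qat_pt2e(exported, quantizer)
        # ... run a few training epochs on `prepared` here ...
    else:
        # Inserts observers; feed representative calibration data instead.
        prepared = prepare_pt2e(exported, quantizer)
        # ... run calibration batches through `prepared` here ...
    return convert_pt2e(prepared)
```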
52 changes: 41 additions & 11 deletions backends/nxp/quantizer/patterns.py
@@ -13,7 +13,12 @@
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e import (
FakeQuantize,
FixedQParamsFakeQuantize,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
@@ -57,7 +62,8 @@ class PartitionAnchors:
| tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
] = field(default_factory=list)
weights: list[
tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec],
tuple[fx.Node, NodeArgsIdx]
| tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize],
] = field(default_factory=list)
biases: list[
tuple[fx.Node, NodeArgsIdx]
@@ -67,12 +73,20 @@
literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
output: list[
tuple[fx.Node]
| tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec],
| tuple[
fx.Node,
FixedQParamsQuantizationSpec
| FixedQParamsFakeQuantize
| SharedQuantizationSpec,
],
] = field(default_factory=list)
empty: bool = False


class QuantizationPattern(ABC):
def __init__(self, is_qat: bool):
self.is_qat = is_qat

@abstractmethod
def partition_types(self) -> list[OpOverload]:
"""
@@ -145,11 +159,15 @@ def get_anchors_for_fixed_quant_specs(
zero_point: int,
quant_min: int = -128,
quant_max: int = 127,
is_qat: bool = False,
) -> PartitionAnchors:
node = fused_partition[0].nodes[-1]
assert len(fused_partition[0].input_nodes) == 1

qspec = FixedQParamsQuantizationSpec(
QSpecOrFakeQuantize = (
FixedQParamsFakeQuantize if is_qat else FixedQParamsQuantizationSpec
)
qspec_or_fake_quantize = QSpecOrFakeQuantize(
dtype=torch.int8,
scale=scale,
zero_point=zero_point,
@@ -163,7 +181,7 @@
weights=[],
biases=[],
output=[
(node, qspec),
(node, qspec_or_fake_quantize),
],
)

@@ -187,11 +205,12 @@ def partition_types(self):


class AddmmPattern(QuantizationPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool):
self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
)
self.is_qat = is_qat

def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.addmm.default]
@@ -363,7 +382,11 @@ def get_anchors(
ch_axis=0,
)

weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
weight_observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
if self.is_qat
else PerChannelMinMaxObserver
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -392,11 +415,12 @@ def partition_types(self) -> list[OpOverload]:


class Conv2dPattern(ConvPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool):
self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
)
self.is_qat = is_qat

def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv2d.default]
@@ -419,7 +443,11 @@ def get_anchors(
ch_axis=0,
)

weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
weight_observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
if self.is_qat
else PerChannelMinMaxObserver
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -511,11 +539,12 @@ def replacement_op(self):


class LinearPattern(QuantizationPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool):
self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
)
self.is_qat = is_qat

def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.linear.default]
@@ -585,11 +614,12 @@ def partition_types(self):


class MmPattern(QuantizationPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool):
self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
)
self.is_qat = is_qat

def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.mm.default]
15 changes: 15 additions & 0 deletions backends/nxp/tests/models.py
@@ -571,3 +571,18 @@ def __init__(self, activation: str, inplace: bool, in_channels: int):
def forward(self, x):
x = self.conv(x)
return self.activation(x)


class MLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.sequential = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)

def forward(self, x):
return self.sequential(x)
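The new `MLP` test model gives the QAT path a small stack of Linear/ReLU layers to fine-tune. As a rough illustration, a QAT fine-tuning loop over a model already prepared with fake quantization could look like the sketch below; the synthetic data, loss, and optimizer settings are placeholders, not taken from this PR.

```python
# Illustrative only: fine-tune a QAT-prepared model (fake-quant modules are
# updated together with the weights). Data and hyperparameters are placeholders.
import torch


def qat_finetune(prepared_model, num_steps=100, lr=1e-3):
    optimizer = torch.optim.SGD(prepared_model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    for _ in range(num_steps):
        x = torch.randn(32, 1)   # matches the MLP's (batch, 1) input shape
        target = torch.sin(x)    # placeholder regression target
        optimizer.zero_grad()
        loss = loss_fn(prepared_model(x), target)
        loss.backward()
        optimizer.step()
    return prepared_model
```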