Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import torch

from executorch.backends.nxp.backend.data_format import DataFormat
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
create_channels_last_to_channels_first_permutation,
)
Expand Down Expand Up @@ -89,10 +92,15 @@ def _is_supported_in_IR(
def _to_pos_dim(d: int, rank: int):
return d + rank if d < 0 else d

@staticmethod
def _normalize_dim(dim: list[int], rank: int) -> list[int]:
# convert negative index to positive
return [MeanDimConverter._to_pos_dim(d, rank) for d in dim]

@staticmethod
def _normalize_and_to_channel_last_dim(dim: list[int], rank: int) -> list[int]:
# convert negative index to positive
dim = [MeanDimConverter._to_pos_dim(d, rank) for d in dim]
dim = MeanDimConverter._normalize_dim(dim, rank)

perm = create_channels_last_to_channels_first_permutation(rank, True)
dim = [perm[d] for d in dim]
Expand All @@ -106,6 +114,114 @@ def _get_attrs(node: Node) -> tuple[list[int], bool]:
keepdim = node.args[2] if len(node.args) >= 3 else False
return dim, keepdim

def _get_dim_and_handle_io_formats(
self, ops: OpsList, dim: list[int], keep_dim: bool
):
t_op = ops.middle_op
x = t_op.tmp_inputs[0]
y = t_op.tmp_outputs[0]

channels_last_input = x.tensor_format.is_channels_last()
channels_last_output = y.tensor_format.is_channels_last()
formatless_input = not channels_last_input
formatless_output = not channels_last_output

dim = self._normalize_dim(dim, x.rank)

if keep_dim:
# The rank is preserved and the io formats should always be equal.
assert (
x.tensor_format == y.tensor_format
), "NXP backend: There is a bug in `mean.dim` format inference."

# Just adjust the dim to match the input format.
if channels_last_input:
dim = self._normalize_and_to_channel_last_dim(dim, x.rank)

else:
# `keep_dim = False`, so the output rank != input rank, and the operator changes the tensor format.

if channels_last_input and formatless_output:
if 1 in dim:
# If we are reducing over the channels, the channels dimension gets removed and the output ends up
# exactly equal in channels last and channels first, regardless of which other dimensions are
# removed. Therefore, we can just adjust the `dim` and we don't need to insert any `Transpose` ops.
dim = self._normalize_and_to_channel_last_dim(dim, x.rank)
elif all(spatial_dim in dim for spatial_dim in range(2, x.rank)):
# All spatial dims are reduced, leaving only batch and channels (both optionally). So the result is
# equal in channels first and channels last as long as we adjust the `dim` to match a channels last
# input (similarly to the case above).
dim = self._normalize_and_to_channel_last_dim(dim, x.rank)
else:
# If the channels dimension is preserved, we must transpose the input to channels first (to match
# the edge model) and we must keep the `dim` unchanged (referencing channels first dimensions).
# Otherwise, the output would not match the input.
to_channels_first_perm = (
translator.create_channels_last_to_channels_first_permutation(
x.rank
)
)
ops.add_pre(
self.builder.create_transpose_operator_before(
t_op, 0, to_channels_first_perm
)
)
t_op.tmp_inputs[0].tensor_format = DataFormat.CHANNELS_FIRST

elif formatless_input and channels_last_output:
# We need apply the `mean` with the original `dim`, which will produce a channels first output. Then,
# we need to append a `Transpose` operator to make the output channels last.
to_channels_last_perm = (
translator.create_channels_first_to_channels_last_permutation(
y.rank, True
)
)
ops.add_post(
self.builder.create_transpose_operator_after(
t_op, 0, to_channels_last_perm
)
)
t_op.tmp_outputs[0].tensor_format = DataFormat.CHANNELS_FIRST

elif formatless_input and formatless_output:
# No action needed.
pass

else: # channels_last_input and channels_last_output
# This case cannot currently occur, as it would require the case:
# channels last 4D -> mean -> channels_last 3D
# which cannot currently happen as the 3D conv/pooling/... is supported by adding `view_copy` nodes in
# the edge dialect and converting the node to 4D, and the `view_copy` nodes prevent the propagation of
# the format to the `mean.dim` output.
# Therefore, the implementation cannot be tested. But from experience with other operators, it should
# work correctly. We just need to add 2 `Transpose` ops to make the IO channels first, and keep the
# `dim` unchanged.
to_channels_first_perm = (
translator.create_channels_last_to_channels_first_permutation(
x.rank
)
)
ops.add_pre(
self.builder.create_transpose_operator_before(
t_op, 0, to_channels_first_perm
)
)
t_op.tmp_inputs[0].tensor_format = DataFormat.CHANNELS_FIRST

to_channels_last_perm = (
translator.create_channels_first_to_channels_last_permutation(
y.rank, True
)
)
ops.add_post(
self.builder.create_transpose_operator_after(
t_op, 0, to_channels_last_perm
)
)
t_op.tmp_outputs[0].tensor_format = DataFormat.CHANNELS_FIRST

return dim

def convert(self, node: Node):
"""Convert the 'mean.dim' operator to NeutronIR 'Mean'.
The ExecuTorch schema is:
Expand All @@ -123,10 +239,9 @@ def convert(self, node: Node):

t_op = self._create_tflite_op_with_io_tensors(node)
t_op.builtin_options = mean_options.Mean(keepdim)
x = t_op.tmp_inputs[0]

if x.tensor_format.is_channels_last():
dim = self._normalize_and_to_channel_last_dim(dim, x.rank)
ops = OpsList(middle_op=t_op)
dim = self._get_dim_and_handle_io_formats(ops, dim, keepdim)

convert_axes_from_attribute(t_op, self.builder, dim)
self.builder.append_operators([t_op])
self.builder.append_operators(ops.flatten())
73 changes: 53 additions & 20 deletions backends/nxp/backend/node_format_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,27 @@
import torch

from executorch.backends.nxp.backend.data_format import DataFormat, NXP_NODE_FORMAT

from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order
from executorch.backends.nxp.backend.edge_helper import (
is_channels_last_dim_order,
try_get_arg,
)
from executorch.backends.nxp.backend.edge_program_converter import functions_converters
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.backends.nxp.tests.ops_aliases import (
AdaptiveAvgPool2D,
AvgPool2D,
Convolution,
DequantizePerChannel,
DequantizePerTensor,
GetItem,
MaxPool2D,
MaxPool2DWithIndices,
MeanDim,
PermuteCopy,
QuantizePerTensor,
UpsampleBilinear2D,
UpsampleNearest2D,
ViewCopy,
)
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch.export import ExportedProgram
from torch.fx import Node
Expand All @@ -25,21 +42,22 @@ class NodeFormatInference:
# The op in the dictionary is mapped to a dictionary, which holds indices to input nodes
# that are always channels first.
ops_with_channels_first_nodes = {
exir_ops.edge.aten._adaptive_avg_pool2d.default: {"inputs": [0]},
AdaptiveAvgPool2D: {"inputs": [0]},
torch.ops.aten.adaptive_avg_pool2d.default: {"inputs": [0]},
exir_ops.edge.aten.avg_pool2d.default: {"inputs": [0]},
exir_ops.edge.aten.convolution.default: {"inputs": [0, 1]},
exir_ops.edge.aten.max_pool2d_with_indices.default: {"inputs": [0]},
exir_ops.edge.aten.max_pool2d.default: {"inputs": [0]},
exir_ops.edge.aten.upsample_bilinear2d.vec: {"inputs": [0]},
exir_ops.edge.aten.upsample_nearest2d.vec: {"inputs": [0]},
AvgPool2D: {"inputs": [0]},
Convolution: {"inputs": [0, 1]},
MaxPool2DWithIndices: {"inputs": [0]},
MaxPool2D: {"inputs": [0]},
UpsampleBilinear2D: {"inputs": [0]},
UpsampleNearest2D: {"inputs": [0]},
}

# A set of Edge Aten ops, which have the ability to change the format (for example - input nodes
# are channels first but output is formatless).
ops_that_can_change_tensor_format = {
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.permute_copy.default,
ViewCopy,
PermuteCopy,
MeanDim,
}

_type_changed_during_last_run: bool
Expand Down Expand Up @@ -71,10 +89,10 @@ def __init__(self, edge_program: ExportedProgram, only_for_op_support_check=Fals
self._type_changed_during_last_run = False

self._known_targets = list(functions_converters) + [
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
operator.getitem,
DequantizePerTensor,
DequantizePerChannel,
QuantizePerTensor,
GetItem,
]

def identify_node_formats(self):
Expand Down Expand Up @@ -104,10 +122,7 @@ def _infer_format_of_nodes(self, node: Node):
self._handle_node_which_uses_channels_first_format(node)

elif op_type in self.ops_that_can_change_tensor_format:
if op_type in [
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.permute_copy.default,
]:
if op_type in [ViewCopy, PermuteCopy]:
# Try to assign the `formatless` format to the input and output. The converter will then handle the
# transition.
# Note: If the format for the input/output has already been assigned as channels first, it will NOT be
Expand All @@ -119,10 +134,28 @@ def _infer_format_of_nodes(self, node: Node):
self._node_inputs[node][0], DataFormat.FORMATLESS
)

elif op_type == MeanDim:
# The operator schema is:
# mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
keep_dim = try_get_arg(node, 2) or False
if keep_dim:
# The operator preserves the rank, so we can handle it as an operator that can use any node format.
self._handle_node_which_can_use_any_node_format(node)
else:
# The operator removes dimensions, so the IO must be marked as `formatless` (unless overridden by
# channels first of course).
self._assign_format_to_node(
self._node_outputs[node][0], DataFormat.FORMATLESS
)
self._assign_format_to_node(
self._node_inputs[node][0], DataFormat.FORMATLESS
)

else:
logger.error(
f"Node format inference for node type: {op_type} not found!"
)

elif node.op != "call_function" or (
hasattr(node, "target") and node.target in self._known_targets
):
Expand Down
Loading
Loading