2 changes: 0 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
@@ -190,7 +190,6 @@ def _tosa_pipeline(

         # Node transformation passes (post q/dq folding)

-        self.add_pass(DecomposeExpm1Pass())
         self.add_pass(DecomposeLogitPass())
         self.add_pass(DecomposeMaskedFill())
         self.add_pass(DecomposeRoundPass())
@@ -209,7 +208,6 @@ def _tosa_pipeline(
         self.add_pass(DecomposeSinhPass())
         self.add_pass(DecomposeSignPass())
         self.add_pass(DecomposeFloorDividePass())
-        self.add_pass(DecomposeDivTensorModePass())
         self.add_pass(DecomposeGeluPass())
         self.add_pass(DecomposeAddSubAlphaPass())
         self.add_pass(DecomposeGroupedConv())
9 changes: 9 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -114,6 +114,7 @@ def create_node(
     quantize: bool = False,
     q_params: Optional[tuple] = None,
     from_node: Optional[torch.fx.Node] = None,
+    inherit_qparams: bool = False,
 ):
     """
     Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -132,6 +133,14 @@
         keys = from_node.meta.keys()
         for key in keys:
             new_meta[key] = from_node.meta[key]
+        if not inherit_qparams:
+            if "input_qparams" in new_meta:
+                new_meta["input_qparams"] = {}
+            if "output_qparams" in new_meta:
+                new_meta["output_qparams"] = {}
+    elif inherit_qparams:
+        raise ValueError("inherit_qparams is only valid when from_node is given")
+
     old_stack_trace = new_meta.get("stack_trace", "")
     new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
     node.meta = new_meta
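For context, a minimal sketch of how a decomposition pass might call the extended helper. The rewrite below is hypothetical; create_node and the inherit_qparams semantics follow the diff above (structural ops copy the source node's meta with emptied qparams, while the op that takes over the computation keeps them):

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops


def rewrite_node(graph_module: torch.fx.GraphModule, node: torch.fx.Node) -> None:
    """Hypothetical rewrite replacing `node` with view_copy -> neg."""
    graph = graph_module.graph
    with graph.inserting_before(node):
        # Structural reshape: copies node.meta (stack trace etc.) but
        # starts out with empty input/output qparams (the new default).
        view = create_node(
            graph,
            exir_ops.edge.aten.view_copy.default,
            args=(node.args[0], [1, -1]),
            from_node=node,
            inherit_qparams=False,
        )
        # The op that takes over the original computation keeps the
        # original quantization parameters.
        new_op = create_node(
            graph,
            exir_ops.edge.aten.neg.default,
            args=(view,),
            from_node=node,
            inherit_qparams=True,
        )
    node.replace_all_uses_with(new_op)
    graph.erase_node(node)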
1 change: 1 addition & 0 deletions backends/arm/_passes/broadcast_args_pass.py
@@ -63,6 +63,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     args=(arg, multiples),
                     kwargs={},
                     from_node=node,
+                    inherit_qparams=False,
                 )
                 node.replace_input_with(arg, repeat)

11 changes: 9 additions & 2 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
@@ -40,13 +40,17 @@ def call_operator(self, op, args, kwargs, meta):
         if len(stride) != 1:
             return super().call_operator(op, args, kwargs, meta)

+        x_meta = meta.copy()
+        x_meta.data["input_qparams"] = {}
+        x_meta.data["output_qparams"] = {}
+
         x = args[0]
         x_unsqueezed_shape = list(x.data.shape) + [1]
         x = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x, x_unsqueezed_shape),
             {},
-            meta,
+            x_meta,
             updated=True,
         )

@@ -79,12 +83,15 @@ def call_operator(self, op, args, kwargs, meta):
             exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True
         )

+        x_squeezed_meta = meta.copy()
+        x_squeezed_meta.data["input_qparams"] = {}
+        x_squeezed_meta.data["output_qparams"] = {}
         x_squeezed_shape = list(x.data.shape)[:-1]
         x = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x, x_squeezed_shape),
             {},
-            meta,
+            x_squeezed_meta,
             updated=True,
         )
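The view_copy pair inserted here is the standard conv1d-as-conv2d rewrite. A self-contained numerical check of that equivalence in plain PyTorch (not part of this PR; shapes and padding are arbitrary illustration values):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)  # (N, C_in, L)
w = torch.randn(8, 4, 3)   # (C_out, C_in, K)
b = torch.randn(8)

ref = F.conv1d(x, w, b, stride=1, padding=1)

# Unsqueeze a trailing width dim on input and weights, run conv2d,
# then squeeze the width dim away again.
out = F.conv2d(x.unsqueeze(-1), w.unsqueeze(-1), b, stride=(1, 1), padding=(1, 0))
out = out.squeeze(-1)

assert torch.allclose(out, ref, atol=1e-5)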
23 changes: 18 additions & 5 deletions backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py
@@ -12,7 +12,7 @@
 from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d

 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, NodeMetadata

 edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,)
 aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,)
@@ -60,6 +60,11 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
         # Vela currently only allows a stride in the interval of [1,3] for AvgPool2d.
         # To accommodate this, the AvgPool2d op is applied to pooling regions and the results are concatenated.

+        # Slices and concats does not require quantization parameters
+        metadata_dict = dict(meta.data)
+        metadata_dict["input_qparams"] = {}
+        metadata_dict["output_qparams"] = {}
+        meta_with_no_qparams = NodeMetadata(metadata_dict)
         res = []
         for out_i in range(output_size_h):
             row = []
@@ -72,11 +77,15 @@ def call_operator(self, op, args, kwargs, meta, updated=False):

                 # Slice along H
                 x_h = super().call_operator(
-                    slice_op, (x, 2, start_h, end_h), kwargs, meta, True
+                    slice_op, (x, 2, start_h, end_h), kwargs, meta_with_no_qparams, True
                 )
                 # Slice along W
                 x_hw = super().call_operator(
-                    slice_op, (x_h, 3, start_w, end_w), kwargs, meta, True
+                    slice_op,
+                    (x_h, 3, start_w, end_w),
+                    kwargs,
+                    meta_with_no_qparams,
+                    True,
                 )

                 # Apply avg pooling with kernel size equal to the pooling region
@@ -89,9 +98,13 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
                 row.append(pooled)

             # Concatenate row results along width (dim=3)
-            row_tensor = super().call_operator(cat_op, (row, 3), kwargs, meta, True)
+            row_tensor = super().call_operator(
+                cat_op, (row, 3), kwargs, meta_with_no_qparams, True
+            )
             res.append(row_tensor)

         # Concatenate all rows along height (dim=2)
-        out = super().call_operator(cat_op, (res, 2), kwargs, meta, True)
+        out = super().call_operator(
+            cat_op, (res, 2), kwargs, meta_with_no_qparams, True
+        )
         return out
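The strategy described by the comment and loop above can be checked numerically in plain PyTorch; the start/end expressions below are the standard adaptive-pooling region bounds, and the sizes are arbitrary:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 7, 7)
out_h = out_w = 3
ref = F.adaptive_avg_pool2d(x, (out_h, out_w))

H, W = x.shape[-2:]
rows = []
for i in range(out_h):
    # Pooling region along H: start = floor(i*H/out_h), end = ceil((i+1)*H/out_h)
    h0, h1 = (i * H) // out_h, -((-(i + 1) * H) // out_h)
    cols = []
    for j in range(out_w):
        w0, w1 = (j * W) // out_w, -((-(j + 1) * W) // out_w)
        region = x[..., h0:h1, w0:w1]  # slice along H, then W
        # Avg pooling with kernel size equal to the pooling region
        cols.append(F.avg_pool2d(region, tuple(region.shape[-2:])))
    rows.append(torch.cat(cols, dim=3))  # concatenate along width
out = torch.cat(rows, dim=2)             # concatenate along height

assert torch.allclose(out, ref, atol=1e-6)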
24 changes: 20 additions & 4 deletions backends/arm/_passes/decompose_cumsum_pass.py
@@ -101,7 +101,13 @@ def call(self, graph_module):
             with graph.inserting_before(node):
                 # Reshape to 4D with
                 view_args = (input_node, conv_shape)
-                view_node = create_node(graph, view_op, args=view_args, from_node=node)
+                view_node = create_node(
+                    graph,
+                    view_op,
+                    args=view_args,
+                    from_node=node,
+                    inherit_qparams=False,
+                )

                 conv_args = (
                     view_node,
@@ -114,7 +120,9 @@ def call(self, graph_module):
                     [0],
                     1,
                 )
-                conv_node = create_node(graph, conv_op, args=conv_args, from_node=node)
+                conv_node = create_node(
+                    graph, conv_op, args=conv_args, from_node=node, inherit_qparams=True
+                )

                 # The convolution is inserted after quantization, so we need to set our
                 # own quantization parameters for the weights here. However since the
@@ -129,12 +137,20 @@ def call(self, graph_module):

                 slice_args = (conv_node, 2, 0, original_shape[dim])
                 slice_node = create_node(
-                    graph, slice_op, args=slice_args, from_node=node
+                    graph,
+                    slice_op,
+                    args=slice_args,
+                    from_node=node,
+                    inherit_qparams=False,
                 )

                 view_original_args = (slice_node, original_shape)
                 view_original_node = create_node(
-                    graph, view_op, args=view_original_args, from_node=node
+                    graph,
+                    view_op,
+                    args=view_original_args,
+                    from_node=node,
+                    inherit_qparams=False,
                 )

                 # Replace and remove original
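Background for the inherit_qparams choices above: cumsum is realized as a convolution with an all-ones kernel plus a slice, so only the conv stands in for the original computation and keeps the cumsum node's qparams. A plain-PyTorch sketch of that identity (illustrative sizes, not part of this PR):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 6)      # (N, C, L)
ref = torch.cumsum(x, dim=-1)

L = x.shape[-1]
weight = torch.ones(1, 1, L)  # all-ones kernel of length L
# With symmetric padding L-1, output position i sums exactly x[0..i];
# the surplus trailing outputs are sliced away, as in the pass above.
out = F.conv1d(x, weight, padding=L - 1)[..., :L]

assert torch.allclose(out, ref, atol=1e-5)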
13 changes: 6 additions & 7 deletions backends/arm/_passes/decompose_linear_pass.py
@@ -55,6 +55,8 @@ def call(self, graph_module):
                     op_target=exir_ops.edge.aten.view_copy.default,
                     args=(input, input_reshaped_shape),
                     kwargs={},
+                    from_node=node,
+                    inherit_qparams=False,
                 )

                 # Reshape weights to 4D with shape (Co, Ci, 1, 1)
@@ -63,6 +65,8 @@ def call(self, graph_module):
                     op_target=exir_ops.edge.aten.view_copy.default,
                     args=(weights, weights_reshaped_shape),
                     kwargs={},
+                    from_node=node,
+                    inherit_qparams=False,
                 )

                 conv = create_node(
@@ -81,6 +85,7 @@ def call(self, graph_module):
                     ),
                     kwargs={},
                     from_node=node,
+                    inherit_qparams=True,
                 )

             with graph_module.graph.inserting_after(conv):
@@ -93,14 +98,8 @@ def call(self, graph_module):
                     args=(conv, list(output_shape)),
                     kwargs={},
                     from_node=node,
+                    inherit_qparams=False,
                 )
-                # Quantization parameters are inherited from original linear node, but
-                # output reshape should use the linear node's output qparams for both input
-                # and output.
-                if "input_qparams" in output.meta:
-                    output.meta["input_qparams"] = output.meta.get(
-                        "output_qparams", None
-                    )

             node.replace_all_uses_with(output)
             graph_module.graph.erase_node(node)
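For reference, the rewrite this pass implements, linear as a 1x1 convolution between reshapes, can be verified numerically in plain PyTorch. Only the conv carries the original computation, which is why it alone inherits the linear node's qparams:

import torch
import torch.nn.functional as F

B, Ci, Co = 4, 8, 16
x = torch.randn(B, Ci)
w = torch.randn(Co, Ci)
b = torch.randn(Co)

ref = F.linear(x, w, b)

# Reshape input to (B, Ci, 1, 1) and weights to (Co, Ci, 1, 1),
# apply a 1x1 convolution, then reshape back to (B, Co).
out = F.conv2d(x.reshape(B, Ci, 1, 1), w.reshape(Co, Ci, 1, 1), b)
out = out.reshape(B, Co)

assert torch.allclose(out, ref, atol=1e-5)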
36 changes: 21 additions & 15 deletions backends/arm/_passes/decompose_maxpool2d_with_dilation.py
@@ -70,34 +70,40 @@ def call_operator(self, op, args, kwargs, meta):
             ph2 += extra_h * d_h
             pw2 += extra_w * d_w

+        meta_with_no_qparams = meta.copy()
+        meta_with_no_qparams.data["output_qparams"] = {}
+        meta_with_no_qparams.data["input_qparams"] = {}
+        meta_with_no_output_qparams = meta.copy()
+        meta_with_no_output_qparams.data["output_qparams"] = {}
+
         # 1) Pad via EXIR edge pad (preserves dtype)
         pad_edge = exir_ops.edge.aten.constant_pad_nd.default
         pads = [pw, pw2, ph, ph2, 0, 0, 0, 0]
         x_pad = super().call_operator(
             pad_edge,
             (x, pads, 0),
             {},
-            meta,
+            meta_with_no_output_qparams,
         )

         # 2) Space-to-batch: reshape and permute
         x2 = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x_pad, [N, C, H_pack, d_h, W_pack, d_w]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         x2 = super().call_operator(
             exir_ops.edge.aten.permute_copy.default,
             (x2, [3, 5, 0, 1, 2, 4]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         x2 = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (x2, [N * d_h * d_w, C, H_pack, W_pack]),
             {},
-            meta,
+            meta_with_no_qparams,
         )

         # 3) Core pooling on packed tensor
@@ -120,13 +126,13 @@ def call_operator(self, op, args, kwargs, meta):
                 operator.getitem,
                 (pool_out, 0),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             indices_proxy = super().call_operator(
                 operator.getitem,
                 (pool_out, 1),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             pooled_fake, _ = pool_out.data
         else:
@@ -141,20 +147,20 @@ def call_operator(self, op, args, kwargs, meta):
             exir_ops.edge.aten.view_copy.default,
             (pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         out = super().call_operator(
             exir_ops.edge.aten.permute_copy.default,
             (out, [2, 3, 4, 0, 5, 1]),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         # now flatten back into (N, C, H_out*d_h, W_out*d_w)
         out = super().call_operator(
             exir_ops.edge.aten.view_copy.default,
             (out, [N, C_out, H_out * d_h, W_out * d_w]),
             {},
-            meta,
+            meta_with_no_qparams,
         )

         # 5) Final crop
@@ -166,13 +172,13 @@ def call_operator(self, op, args, kwargs, meta):
             exir_ops.edge.aten.slice_copy.Tensor,
             (out, 2, S_top, S_top + H),
             {},
-            meta,
+            meta_with_no_qparams,
         )
         out = super().call_operator(
             exir_ops.edge.aten.slice_copy.Tensor,
             (out, 3, S_left, S_left + W),
             {},
-            meta,
+            meta_with_no_qparams,
         )

         if is_with_indices:
@@ -181,7 +187,7 @@ def call_operator(self, op, args, kwargs, meta):
                 exir_ops.edge.aten.view_copy.default,
                 (indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             idx = super().call_operator(
                 exir_ops.edge.aten.permute_copy.default,
@@ -193,19 +199,19 @@ def call_operator(self, op, args, kwargs, meta):
                 exir_ops.edge.aten.view_copy.default,
                 (idx, [N, C_out, H_out * d_h, W_out * d_w]),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             idx = super().call_operator(
                 exir_ops.edge.aten.slice_copy.Tensor,
                 (idx, 2, S_top, S_top + H),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
             idx = super().call_operator(
                 exir_ops.edge.aten.slice_copy.Tensor,
                 (idx, 3, S_left, S_left + W),
                 {},
-                meta,
+                meta_with_no_qparams,
             )
         return out, idx
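The space-to-batch trick behind steps 2 and 4 can be verified in eager PyTorch. A minimal numerical sketch for the stride-1 case with H and W divisible by the dilation, so the pad and crop bookkeeping of steps 1 and 5 is not needed:

import torch
import torch.nn.functional as F

N, C, H, W = 1, 2, 8, 8
k, d = 2, 2  # 2x2 kernel, dilation 2
x = torch.randn(N, C, H, W)

ref = F.max_pool2d(x, kernel_size=k, stride=1, dilation=d)

# 2) Space-to-batch: move the d*d dilation phases into the batch dim.
x2 = x.reshape(N, C, H // d, d, W // d, d).permute(3, 5, 0, 1, 2, 4)
x2 = x2.reshape(N * d * d, C, H // d, W // d)

# 3) Dense pooling on the packed tensor.
pooled = F.max_pool2d(x2, kernel_size=k, stride=1)
Hp, Wp = pooled.shape[-2:]

# 4) Batch-to-space: interleave the phases back into H and W.
out = pooled.reshape(d, d, N, C, Hp, Wp).permute(2, 3, 4, 0, 5, 1)
out = out.reshape(N, C, Hp * d, Wp * d)

assert torch.equal(out, ref)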
12 changes: 10 additions & 2 deletions backends/arm/_passes/decompose_select.py
@@ -52,10 +52,18 @@ def call(self, graph_module: torch.fx.GraphModule):

             with graph_module.graph.inserting_before(node):
                 slice_node = create_node(
-                    graph_module.graph, slice_op, (input_node, dim, index, index + 1)
+                    graph_module.graph,
+                    slice_op,
+                    (input_node, dim, index, index + 1),
+                    from_node=node,
+                    inherit_qparams=False,
                 )
                 squeeze_node = create_node(
-                    graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node
+                    graph_module.graph,
+                    squeeze_op,
+                    (slice_node, [dim]),
+                    from_node=node,
+                    inherit_qparams=True,
                 )

                 node.replace_all_uses_with(squeeze_node)
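For reference, the select decomposition in eager PyTorch; narrow is the eager counterpart of the length-1 slice:

import torch

x = torch.randn(2, 3, 4)
dim, index = 1, 2

ref = torch.select(x, dim, index)  # the op being decomposed

# Length-1 slice along dim, then squeeze that dim away.
out = x.narrow(dim, index, 1).squeeze(dim)

assert torch.equal(out, ref)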
6 changes: 5 additions & 1 deletion backends/arm/_passes/decompose_sum_pass.py
@@ -77,7 +77,11 @@ def call_operator(self, op, args, kwargs, meta):

         for dim in dims:
             input_node = super().call_operator(
-                sum_op, (input_node, dim, True), kwargs, meta, updated=True
+                sum_op,
+                (input_node, dim, True),
+                kwargs,
+                meta,
+                updated=True,
             )

         if not keepdims:
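Context for the reformatted call: this pass lowers a multi-dim sum into a chain of single-dim keepdim sums and drops the kept dims afterwards. A quick eager-mode sketch with illustrative sizes:

import torch

x = torch.randn(2, 3, 4)
dims, keepdims = (0, 2), False

ref = x.sum(dim=dims, keepdim=keepdims)

# One single-dim sum per entry in dims, always with keepdim=True...
out = x
for dim in dims:
    out = out.sum(dim=dim, keepdim=True)
# ...then drop the kept dims if the original op did not keep them.
if not keepdims:
    out = out.reshape(ref.shape)

assert torch.allclose(out, ref, atol=1e-6)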