[xnnpack] XNNPACK lowering fails when a permute result has multiple consumers #19152

@wuyii8941

🐛 Describe the bug

Bug

ExecuTorch XNNPACK lowering fails on a small graph where the result of permute is consumed by two downstream ops (reshape and sum). The same graph works through the portable ExecuTorch path.

The error says pass-through arguments are not supported by the XNNPACK delegate. If this graph topology is unsupported, the partitioner should probably avoid creating an invalid XNNPACK delegate partition and leave the graph to portable ops, or report the unsupported topology earlier with a clearer diagnostic.

Reproducer

import importlib.metadata

import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge, to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def forward(self, x):
        # The permute result feeds two consumers: reshape and sum.
        y = x.permute(0, 2, 1)
        return y.reshape(x.shape[0], -1), y.sum(dim=1)


print("torch:", torch.__version__)
print("executorch:", importlib.metadata.version("executorch"))

model = Model().eval()
x = torch.randn(2, 3, 5)

# Portable path works.
ref = model(x)
ep = export(model, (x,), strict=True)
program = to_edge(ep).to_executorch()
module = _load_for_executorch_from_buffer(program.buffer)
out = module.run_method("forward", (x,))
print("portable max_diff:", max((a - b).abs().max().item() for a, b in zip(out, ref)))

# XNNPACK path fails during lowering/preprocess.
ep = export(model, (x,), strict=True)
to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
).to_executorch()
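
For completeness, the shared permute result can be confirmed on the exported program before lowering. This is a small inspection sketch that continues from the reproducer above (exact node names can vary across torch versions):

# Show that the permute node has two users (the reshape/view and the sum).
ep = export(model, (x,), strict=True)
for node in ep.graph_module.graph.nodes:
    if "permute" in node.name:
        print(node.name, "->", sorted(u.name for u in node.users))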

Actual Behavior

The portable path lowers and runs correctly:

portable max_diff: ~4.8e-07

The XNNPACK path fails during lowering:

RuntimeError: Output node 'aten_permute_copy_default' is already in the inputs.
This is likely due to pass through arguments, which are not supported in XNNPACK Delegate.

The error originates from generate_node_to_external_map(...) in executorch/backends/xnnpack/xnnpack_preprocess.py.
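
For orientation, the shape of the failing check is roughly as follows. This is a hedged sketch, not the actual ExecuTorch source; the function body and names below are assumptions based only on the error message:

# Rough approximation (assumed, not the real implementation): the delegate
# preprocess records which nodes are inputs to the delegated graph and rejects
# any output node that was already registered as an input, i.e. a value that
# would simply pass through the delegate unchanged.
def sketch_node_to_external_map(graph):
    external_inputs = set()
    for node in graph.nodes:
        if node.op == "placeholder":
            external_inputs.add(node.name)
    for node in graph.nodes:
        if node.op != "output":
            continue
        for out in node.args[0]:
            if out.name in external_inputs:
                raise RuntimeError(
                    f"Output node '{out.name}' is already in the inputs. "
                    "This is likely due to pass through arguments, which are "
                    "not supported in XNNPACK Delegate."
                )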

Expected Behavior

One of the following would be expected:

  1. XNNPACK lowering succeeds and the program matches PyTorch eager.
  2. The partitioner avoids creating a delegate partition that requires unsupported pass-through arguments and leaves those ops on the portable path.
  3. Lowering rejects the graph before delegate preprocess with a clearer unsupported-topology diagnostic.

Notes

This seems graph-topology specific:

  • y = x.permute(...); return y.reshape(...) lowers successfully.
  • y = x.permute(...); return y.sum(...) lowers successfully.
  • y = x.permute(...); return y, y.sum(...) lowers successfully.
  • y = x.permute(...); return y.reshape(...), y.sum(...) fails.

So the failure appears when the same permute result feeds two nontrivial downstream consumers and the XNNPACK partitioner/preprocess produces a delegate with a pass-through argument.
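
A possible interim workaround is to give each consumer its own permute, so no single permute result needs to be both an input and an output of the delegated partition. This is only a sketch, under the assumption that the duplicated permutes are not fused back into one node during lowering:

class WorkaroundModel(torch.nn.Module):
    def forward(self, x):
        # Duplicate the permute so each consumer reads its own transposed
        # tensor instead of sharing one permute result.
        return (
            x.permute(0, 2, 1).reshape(x.shape[0], -1),
            x.permute(0, 2, 1).sum(dim=1),
        )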

Versions

Environment

Reproduced locally on:

  • torch: 2.11.0+cu130
  • executorch: 1.2.0
  • Python: 3.11
  • Platform: Linux x86_64

Also reproduced during lowering on:

  • torch: 2.13.0.dev20260425+cu126
  • executorch: 1.2.0
