🐛 Describe the bug
Arm backend bug: TOSAPartitioner.partition() makes export order depend on Python hash randomization
Summary
Arm's TOSAPartitioner.partition() currently accumulates partition tags in a
Python set and then builds partition_tags from that set:
tags: set[str] = set()
...
partition_tags = {tag: self.delegation_spec for tag in tags}
Because Python set iteration order depends on the interpreter hash seed,
partition_tags insertion order is non-deterministic across fresh processes.
That would be harmless if later stages treated partition_tags as unordered,
but ExecuTorch backend lowering iterates partition_tags.items() in insertion
order while:
- duplicating constants
- creating delegate submodules
- lowering those submodules to backend modules
So the emitted lowered graph and downstream backend artifacts can change across
Python processes even when the input model and weights are identical.
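As a self-contained illustration of that propagation (a sketch with placeholder names, not the executorch code itself): the dict built from the set inherits the set's iteration order, and every later .items() loop replays it.
# Sketch with placeholder names: the dict comprehension inherits the
# set's (hash-seed-dependent) iteration order, and dict.items() then
# replays that same order for every downstream consumer.
tags = {"tag1", "tag0"}  # iteration order depends on PYTHONHASHSEED
delegation_spec = "DelegationSpec placeholder"  # stand-in for the real spec

partition_tags = {tag: delegation_spec for tag in tags}

for tag, spec in partition_tags.items():
    print(tag)  # order matches whatever order the set happened to produce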
Affected code
In the current executorch checkout:
- backends/arm/tosa/partitioner.py
  - _tag_module() accumulates tags in tags: set[str] = set() at line 207
  - partition() builds partition_tags = {tag: self.delegation_spec for tag in tags} at line 314
- exir/backend/backend_api.py
  - iterates partition_result.partition_tags.items() while partitioning/lowering at lines 271, 421, 455, and 740
Tiny Arm end-to-end reproducer
The repro below is intentionally tiny, but it is a true Arm lowering repro. It
uses:
- TOSAPartitioner
- TOSABackend
- to_edge_transform_and_lower()
The model is the same small two-delegate shape already used in Arm tests:
def forward(self, x, y):
    z = x + y
    s = torch.max(z)
    return s * z
This keeps the repro small while still exercising Arm's real partitioning and
lowering path end to end.
It demonstrates the core problem:
- same graph
- same weights
- same inputs
- only PYTHONHASHSEED changes
- lowered Arm graph changes
The script depends only on torch, the local executorch checkout, and the
normal Arm Python dependency set needed to run TOSA lowering.
Standalone repro file in this repo:
backends/arm/test/misc/partition_tag_order_repro.py
Run:
python backends/arm/test/misc/partition_tag_order_repro.py
Reproducer script
from __future__ import annotations

import hashlib
import json
import os
import subprocess
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[4]
SRC_ROOT = REPO_ROOT / "src"
_RESULT_PREFIX = "JSON_RESULT="
_CHILD_ENV = "ARM_PARTITION_TAG_ORDER_REPRO_CHILD"


def _prepare_imports() -> None:
    if str(SRC_ROOT) not in sys.path:
        sys.path.insert(0, str(SRC_ROOT))
    try:
        import tosa_serializer  # noqa: F401
    except ImportError:
        import serializer.tosa_serializer as serializer_tosa

        sys.modules["tosa_serializer"] = serializer_tosa


def _run_once() -> dict[str, object]:
    _prepare_imports()
    import torch
    from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
    from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

    class MultipleDelegatesModule(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = x + y
            s = torch.max(z)
            return s * z

    torch.manual_seed(0)
    model = MultipleDelegatesModule().eval()
    inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
    exported = torch.export.export(model, inputs, strict=True)

    partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))
    edge_program = to_edge_transform_and_lower(
        exported,
        partitioner=[partitioner],
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    exported_program = edge_program.exported_program()
    graph_module = exported_program.graph_module
    graph_text = str(graph_module.graph)
    delegate_count = sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function"
        and node.target == torch.ops.higher_order.executorch_call_delegate
    )
    return {
        "pythonhashseed": os.environ.get("PYTHONHASHSEED", ""),
        "delegate_count": delegate_count,
        "lowered_graph_sha256": hashlib.sha256(graph_text.encode("utf-8")).hexdigest(),
        "lowered_graph": graph_text,
    }


def _run_child(seed: int) -> dict[str, object]:
    env = os.environ.copy()
    env["PYTHONHASHSEED"] = str(seed)
    env[_CHILD_ENV] = "1"
    out = subprocess.check_output(
        [sys.executable, __file__],
        env=env,
        text=True,
        stderr=subprocess.STDOUT,
    )
    for line in out.splitlines():
        if line.startswith(_RESULT_PREFIX):
            return json.loads(line[len(_RESULT_PREFIX) :])
    raise RuntimeError(out)


if __name__ == "__main__":
    if os.environ.get(_CHILD_ENV) == "1":
        # Child mode: run one lowering pass and emit the result as JSON.
        print(_RESULT_PREFIX + json.dumps(_run_once()))
    else:
        seed1 = _run_child(1)
        seed2 = _run_child(2)
        print("seed", seed1["pythonhashseed"], "sha256 =", seed1["lowered_graph_sha256"])
        print("seed", seed2["pythonhashseed"], "sha256 =", seed2["lowered_graph_sha256"])
        print("hashes_match =", seed1["lowered_graph_sha256"] == seed2["lowered_graph_sha256"])
Expected result
Changing only PYTHONHASHSEED should not change the lowered Arm graph for a
fixed model, weights, and inputs.
Actual result
It does change the lowered Arm graph.
For example, in one local run:
PYTHONHASHSEED=1 produced lowered graph hash:
c5504f6a7079731396c58bd7e56bb4420b18bb1b5d612df9781247c271aa8230
PYTHONHASHSEED=2 produced lowered graph hash:
787da2bc084cec910bee01c1480afb0e2655f0be066821757bbbdfcef26eef76
Both runs still produced two Arm delegates. The only relevant change was the
Python hash seed affecting the iteration order of the internal tag set.
Why this is specifically an Arm bug trigger
The repro above already exercises Arm's real lowering path, and the direct trigger sits in Arm code:
TOSAPartitioner materializes partition_tags from a Python set.
That means the final lowered Arm/VGF graph can depend on PYTHONHASHSEED even
for the exact same model.
For example, in fresh Python processes, the same set literal can iterate in
different orders:
import os
import subprocess
import sys
for seed in ("1", "2"):
    env = os.environ.copy()
    env["PYTHONHASHSEED"] = seed
    out = subprocess.check_output(
        [sys.executable, "-c", "print(list({'tag0', 'tag1'}))"],
        env=env,
        text=True,
    ).strip()
    print(seed, out)
One local run produced:
- seed 1: ['tag0', 'tag1']
- seed 2: ['tag1', 'tag0']
Proposed fix
Make Arm's partitioner return partition_tags in deterministic order.
Minimal stabilization:
partition_tags = {
    tag: self.delegation_spec
    for tag in sorted(tags)
}
More robust fix:
- preserve tag discovery order inside TOSAPartitioner instead of storing tags in a set (a rough sketch follows this list)
- also make ExecuTorch lowering stop depending on incoming dict insertion order
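A rough sketch of the first bullet, assuming a plain insertion-ordered dict can stand in for the current set (names and the surrounding TOSAPartitioner plumbing are simplified here, not a drop-in patch):
# Hypothetical sketch, not a drop-in patch: use a dict as an
# insertion-ordered set so the tag order follows discovery order instead
# of the hash-seed-dependent order of a plain set.
delegation_spec = "DelegationSpec placeholder"  # stand-in for the real spec

tags: dict[str, None] = {}
for tag in ("tag1", "tag0", "tag1"):  # tags in discovery order, with a duplicate
    tags.setdefault(tag, None)  # dedupes while preserving first-seen order

partition_tags = {tag: delegation_spec for tag in tags}
print(list(partition_tags))  # always ['tag1', 'tag0'], independent of PYTHONHASHSEED
On the lowering side, sorting before iterating (for example sorted(partition_tags.items()) in exir/backend/backend_api.py) would remove the dependency on the incoming insertion order as well.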
Why this matters
Backend export should be deterministic for a fixed model, weights, and tool
versions. Requiring callers to set PYTHONHASHSEED to get stable artifacts is
not robust enough for CI, golden generation, or release bundling.
Versions
Collecting environment information...
PyTorch version: 2.10.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Enterprise (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti
Nvidia driver version: 576.88
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Name: 13th Gen Intel(R) Core(TM) i9-13900KF
Manufacturer: GenuineIntel
Family: 207
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3000
MaxClockSpeed: 3000
L2CacheSize: 32768
L2CacheSpeed: None
Revision: None
Versions of relevant libraries:
[pip3] executorch==1.2.0.dev20260305+cpu
[pip3] numpy==2.1.3
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[pip3] torchvision==0.25.0
[conda] Could not collect
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell