"""
Minimal repro for a TensorRT 10.16 build failure when an opset-23 ONNX
`Attention` op (auto-fused from F.scaled_dot_product_attention by
torch.onnx with dynamo=True) lives inside a torch.cond -> ONNX `If` subgraph.
Three variants:
A. SDPA inside torch.cond -> BUILD FAILS (this report)
B. SDPA outside torch.cond -> builds OK
C. Manual matmul/softmax inside torch.cond -> builds OK
"""
import math, os, torch, torch.nn as nn, torch.nn.functional as F, tensorrt as trt
# Repro configuration shared by all three variants: sequence length,
# channel width, and the device/dtype used for export and the TRT build.
SEQ_LEN, EMBED_DIM, DEVICE, DTYPE = 256, 64, "cuda", torch.bfloat16
class AttnSDPA(nn.Module):
    """Single-head self-attention via F.scaled_dot_product_attention.

    With the dynamo torch.onnx exporter at opset 23, the SDPA call is
    auto-fused into a single ONNX `Attention` op — the op under test.
    """

    def __init__(self):
        super().__init__()
        # 1x1 convs act as per-position linear projections on (B, C, S) input.
        self.to_qkv = nn.Conv1d(EMBED_DIM, EMBED_DIM * 3, 1)
        self.proj = nn.Conv1d(EMBED_DIM, EMBED_DIM, 1)

    def forward(self, x):
        # x: (batch, EMBED_DIM, seq) — channels-first Conv1d layout.
        b, c, s = x.shape
        # (B, 3C, S) -> (B, 1, S, 3C): insert a singleton head dim and move
        # the sequence axis ahead of channels, matching SDPA's expected
        # (batch, heads, seq, head_dim) layout.
        qkv = self.to_qkv(x).reshape(b, 1, c * 3, s).permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)
        # Attention, then back to (B, C, S) for the output projection.
        x = F.scaled_dot_product_attention(q, k, v).squeeze(1).permute(0, 2, 1).contiguous()
        return self.proj(x)
class AttnManual(nn.Module):
    """Same attention computation as AttnSDPA, but written out as explicit
    MatMul/Softmax/MatMul so the exporter emits no ONNX `Attention` op
    (control variant C)."""

    def __init__(self):
        super().__init__()
        # 1x1 convs act as per-position linear projections on (B, C, S) input.
        self.to_qkv = nn.Conv1d(EMBED_DIM, EMBED_DIM * 3, 1)
        self.proj = nn.Conv1d(EMBED_DIM, EMBED_DIM, 1)

    def forward(self, x):
        # x: (batch, EMBED_DIM, seq); same layout reshuffling as AttnSDPA.
        b, c, s = x.shape
        qkv = self.to_qkv(x).reshape(b, 1, c * 3, s).permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)
        # Hand-rolled scaled dot-product attention: QK^T / sqrt(d), softmax
        # over keys, then weighted sum of values.
        scale = 1.0 / math.sqrt(q.shape[-1])
        attn = (torch.matmul(q, k.transpose(-1, -2)) * scale).softmax(dim=-1)
        x = torch.matmul(attn, v).squeeze(1).permute(0, 2, 1).contiguous()
        return self.proj(x)
class CondWrapper(nn.Module):
def __init__(self, body):
super().__init__()
self.body = body
def _branch(self, x):
return (self.body(x).contiguous(),)
def forward(self, x, first_chunk):
return torch.cond(first_chunk, self._branch, self._branch, (x,))
class FlatWrapper(nn.Module):
    """Control wrapper: runs `body` directly, with no torch.cond (variant B)."""

    def __init__(self, body):
        super().__init__()
        self.body = body

    def forward(self, x):
        result = self.body(x)
        return result.contiguous()
def export_onnx(wrapper, args, in_names, out_names, onnx_path):
    """Export `wrapper` to an opset-23 ONNX model at `onnx_path`.

    Goes through torch.export first, then the dynamo-based torch.onnx
    exporter, so that F.scaled_dot_product_attention can be fused into the
    opset-23 ONNX `Attention` op.

    Args:
        wrapper: nn.Module to export.
        args: example positional inputs used for tracing.
        in_names: names for the ONNX graph inputs.
        out_names: names for the ONNX graph outputs.
        onnx_path: destination path; weights are saved as external data.
    """
    # NOTE(review): reaches into a private torch.export config object; this
    # may break across torch versions.
    from torch.export import _trace as _et
    cfg = _et.DEFAULT_EXPORT_DYNAMO_CONFIG
    saved = cfg.assume_static_by_default
    cfg.assume_static_by_default = True  # required so the inner cond compile doesn't symbolize input dims
    try:
        ep = torch.export.export(wrapper, args, strict=False)
    finally:
        # Restore the global flag even if export raises.
        cfg.assume_static_by_default = saved
    p = torch.onnx.export(
        ep, args, None,
        input_names=in_names, output_names=out_names,
        opset_version=23, dynamo=True, optimize=False,
    )
    # Export unoptimized, then run the exporter's optimize pass explicitly
    # (presumably this pass performs the SDPA -> Attention fusion — confirm
    # against torch.onnx docs for the installed version).
    p.optimize()
    from torch.onnx._internal._lazy_import import onnxscript_apis
    onnxscript_apis.save_model_with_external_data(p.model, onnx_path, verbose=False)
def count_ops(onnx_path):
    """Return a {op_type: count} dict over the whole model, recursing into
    nested subgraphs (If branches, Loop bodies, ...)."""
    import onnx

    model = onnx.load(onnx_path, load_external_data=False)
    counts = {}

    def visit(graph):
        for node in graph.node:
            counts[node.op_type] = counts.get(node.op_type, 0) + 1
            # GRAPH-typed attributes hold the nested subgraphs.
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    visit(attr.g)

    visit(model.graph)
    return counts
class _Logger(trt.ILogger):
    """TensorRT logger that only prints WARNING-and-worse messages."""

    def __init__(self):
        super().__init__()

    def log(self, sev, msg):
        # TRT severity values increase as messages get less severe, so
        # `<= WARNING` keeps INTERNAL_ERROR / ERROR / WARNING and drops
        # INFO / VERBOSE.
        if sev <= trt.ILogger.Severity.WARNING:
            print(f"[TRT {sev.name}] {msg}")
def build_engine(onnx_path, engine_path):
    """Parse `onnx_path` and build a serialized TRT engine at `engine_path`.

    Returns True on success, False when parsing or building fails; errors
    are printed (parser errors here, build errors via _Logger) rather than
    raised so the caller can summarize all variants.
    """
    logger = _Logger()
    builder = trt.Builder(logger)
    cfg = builder.create_builder_config()
    # 4 GiB workspace pool.
    cfg.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)
    # Strongly-typed network: layer dtypes follow the ONNX model (bf16 here)
    # instead of TRT's precision auto-selection.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_path):
        for i in range(parser.num_errors):
            print(f"[TRT PARSE] {parser.get_error(i)}")
        return False
    s = builder.build_serialized_network(network, cfg)
    if s is None:
        # Build failure details were already emitted through _Logger.
        return False
    with open(engine_path, "wb") as f:
        f.write(bytes(s))
    return True
def run_variant(name):
    """Run one repro variant end-to-end: export ONNX, then build a TRT engine.

    Variants:
        "A": SDPA inside torch.cond   -> Attention op inside `If` (the failure).
        "B": SDPA without torch.cond  -> Attention op at top level (control).
        "C": manual attention inside torch.cond (control).

    Args:
        name: one of "A", "B", "C".

    Returns:
        (name, build_ok): build_ok is True iff the TRT engine built.

    Raises:
        ValueError: if `name` is not a known variant.
    """
    print(f"\n=== VARIANT {name} ===")
    torch.manual_seed(0)
    # Module construction precedes torch.randn in each branch so the RNG
    # consumption order (weight init, then input) is identical to before.
    if name == "A":
        w = CondWrapper(AttnSDPA()).to(DEVICE, DTYPE).eval()
        args = (torch.randn(1, EMBED_DIM, SEQ_LEN, device=DEVICE, dtype=DTYPE),
                torch.tensor(False, device=DEVICE))
        in_names = ["x", "first_chunk"]
    elif name == "B":
        w = FlatWrapper(AttnSDPA()).to(DEVICE, DTYPE).eval()
        args = (torch.randn(1, EMBED_DIM, SEQ_LEN, device=DEVICE, dtype=DTYPE),)
        in_names = ["x"]
    elif name == "C":
        w = CondWrapper(AttnManual()).to(DEVICE, DTYPE).eval()
        args = (torch.randn(1, EMBED_DIM, SEQ_LEN, device=DEVICE, dtype=DTYPE),
                torch.tensor(False, device=DEVICE))
        in_names = ["x", "first_chunk"]
    else:
        # Previously an unknown name fell through every branch and crashed
        # below with UnboundLocalError; fail fast with a clear message.
        raise ValueError(f"unknown variant {name!r}; expected 'A', 'B' or 'C'")
    onnx_path, engine_path = f"/tmp/repro_{name}.onnx", f"/tmp/repro_{name}.engine"
    # Remove stale artifacts (including the external-data sidecar) so each
    # run exports and builds from scratch.
    for p in (onnx_path, engine_path, onnx_path + ".data"):
        if os.path.exists(p):
            os.remove(p)
    export_onnx(w, args, in_names, ["y"], onnx_path)
    print(f" ONNX ops: {count_ops(onnx_path)}")
    print(" building TRT...")
    return name, build_engine(onnx_path, engine_path)
if __name__ == "__main__":
    print(f"PyTorch={torch.__version__} TRT={trt.__version__}")
    # Run all three variants first, then print a compact pass/fail summary.
    results = [run_variant(variant) for variant in ("A", "B", "C")]
    for variant, built in results:
        print(f" {variant}: {'OK' if built else 'FAIL'}")
## TensorRT 10.16: opset-23 `Attention` op fails inside ONNX `If` subgraph (myelin "Unnamed Layer* N [ElementWise]_output" error)

### Description
When the standard ONNX opset-23 `Attention` op (auto-fused by `torch.onnx`
with `dynamo=True` from `F.scaled_dot_product_attention`) lives inside an
ONNX `If` subgraph (lowered from `torch.cond`), TRT 10.16 fails to build the
engine. The error is reproducible with a 200-line standalone script: a
single `Attention` op plus `Conv1d` projections wrapped in `torch.cond`.
The build succeeds when either: the `Attention` op is moved outside the `If`
(variant B), or `Attention` is decomposed into explicit
`MatMul`/`Softmax`/`MatMul` and left inside the `If` (variant C). So the
failure is specific to the combination
{opset-23 `Attention`} ∩ {`If` subgraph}. We hit this on a real workload (a
video VAE that uses `torch.cond` to unify two control-flow paths in one
engine) and traced it back to this minimal case.

There also seems to be a related minor symptom on the parser side:
`[TRT WARNING] ImporterContext.hpp:378: A node named node_Split_1 already
exists` is emitted for variant A — torch.onnx's QKV `Split` ends up in both
branches of the `If` with the same auto-generated name, and the parser can't
query the second instance's outputs. This is a warning rather than a build
failure, but might be a related symptom if the unnamed scaling layers TRT
creates are similarly affected by If-branch scoping.

We searched NVIDIA/TensorRT issues, release notes (10.16 / 10.17 / 10.18 /
11.0), and the developer forum; the closest related report we found is #4705
(also opset-23 `Attention`, also scoped-ops machinery, but a different
failure mode — a single-`Attention`-layer parse-time crash on an RTX 4080,
not the myelin/`If` interaction shown here). That one is open with no fix or
NVIDIA response since 2026-02-26.
### Environment

### Repro

The script below is fully standalone (no third-party deps beyond
torch / tensorrt / onnx / onnxscript). Variant A reproduces the failure;
B and C are controls.

`trt_bug_repro.py` (click to expand)

### Output
### Expected behavior

Variant A should build successfully. The `Attention` op should compose with
`If` the same way `MatMul`/`Softmax`/`MatMul` do.

### Notes / hypothesis (from the user side)

When TRT's ONNX importer parses the opset-23 `Attention` op it appears to
create a few internal helper layers (e.g. an unnamed `ElementWise` for the
Q*scale broadcast, and the helper layers we see in verbose mode named
`ONNXTRT_ShapeTensorFromDims_*`, `ONNXTRT_castHelper_*`,
`ONNXTRT_unsqueezeTensor_*`). Inside an `If` ForeignNode, those unnamed
layers are referenced by myelin via `setInput(...)` but the lookup fails,
suggesting an If-subgraph scoping issue in the importer's name table or in
myelin's IR builder — not a problem with the op semantics themselves
(variant B builds fine; variant C with explicit MatMul/Softmax also builds
fine inside the same `If`).

Workaround: decompose SDPA into explicit `MatMul`/`Softmax`/`MatMul` before
calling `torch.onnx.export` — i.e. don't rely on the opset-23 `Attention`
auto-fusion when the call site is reachable from inside a `torch.cond`.
We're using this in production but it's not desirable long-term — we'd like
to use the native `Attention` op for performance.

Happy to provide more diagnostics (verbose build log, ONNX file) on request.