ONNX parser failure of TensorRT 10.16 when running Attention node on any GPU #4738

@ir2718

Description

Creating an ONNX model containing a single Attention node with KV-cache inputs and outputs, then parsing it with TensorRT, produces an error.

The following error is produced:

[04/22/2026-19:02:08] [TRT] [E] ModelImporter.cpp:149: ERROR: onnxOpImporters.cpp:191 In function importAttention:
[6] Assertion failed: node.output().size() == 1: TensorRT only supports Attention nodes with one output.

Environment

TensorRT Version: 10.16.0.72

NVIDIA GPU: NVIDIA RTX 6000 Ada

NVIDIA Driver Version: 550.163.01

CUDA Version: 12.4

Operating System: Debian 13

Python Version: 3.13.5

Relevant Files

The script for creating an ONNX model that contains a single Attention node:

import onnx

hidden_dim = 64
num_heads = 8

q = onnx.helper.make_tensor_value_info(
    "q", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "curr_seq_len", hidden_dim])
k = onnx.helper.make_tensor_value_info(
    "k", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "curr_seq_len", hidden_dim])
v = onnx.helper.make_tensor_value_info(
    "v", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "curr_seq_len", hidden_dim])
o = onnx.helper.make_tensor_value_info(
    "o", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "curr_seq_len", hidden_dim])

past_k = onnx.helper.make_tensor_value_info(
    "past_k", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "past_seq_len", hidden_dim])
past_v = onnx.helper.make_tensor_value_info(
    "past_v", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "past_seq_len", hidden_dim])

present_k = onnx.helper.make_tensor_value_info(
    "present_k", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "total_seq_len", hidden_dim])
present_v = onnx.helper.make_tensor_value_info(
    "present_v", onnx.TensorProto.FLOAT,
    ["batch_size", num_heads, "total_seq_len", hidden_dim])

# The empty string omits the optional attn_mask input.
inputs = ["q", "k", "v", "", "past_k", "past_v"]
outputs = ["o", "present_k", "present_v"]

inputs_graph = [q, k, v, past_k, past_v]
outputs_graph = [o, present_k, present_v]

attention_node = onnx.helper.make_node(
    "Attention",
    inputs=inputs,
    outputs=outputs,
    q_num_heads=num_heads,
    kv_num_heads=num_heads,
    is_causal=1,
    domain=""
)
graph = onnx.helper.make_graph(
    nodes=[attention_node],
    name="attention_graph",
    inputs=inputs_graph,
    outputs=outputs_graph,
)
model = onnx.helper.make_model(
    graph,
    opset_imports=[onnx.helper.make_opsetid("", 24)],
    producer_name="attention-example",
)
save_name = "model.onnx"
onnx.save_model(model, save_name)

Running the script generates the model file model.onnx:

python3 model.py

Steps To Reproduce

Here is the script used for parsing the ONNX model:

import tensorrt as trt

model_path = "model.onnx"
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
parser = trt.OnnxParser(network, trt_logger)
success = parser.parse_from_file(model_path)
if not success:
    # Surface the parser errors explicitly.
    for i in range(parser.num_errors):
        print(parser.get_error(i))

Running the script with model.onnx in the working directory produces an error about the number of outputs:

python3 trt_onnx_parse.py

Here is the full output of the script, including the error:

[04/22/2026-18:37:30] [TRT] [E] ModelImporter.cpp:138: While parsing node number 0 [Attention -> "o"]:
[04/22/2026-18:37:30] [TRT] [E] ModelImporter.cpp:140: --- Begin node ---
input: "q"
input: "k"
input: "v"
input: ""
input: "past_k"
input: "past_v"
output: "o"
output: "present_k"
output: "present_v"
op_type: "Attention"
attribute {
  name: "is_causal"
  i: 1
  type: INT
}
attribute {
  name: "kv_num_heads"
  i: 8
  type: INT
}
attribute {
  name: "q_num_heads"
  i: 8
  type: INT
}
domain: ""

[04/22/2026-18:37:30] [TRT] [E] ModelImporter.cpp:141: --- End node ---
[04/22/2026-18:37:30] [TRT] [E] ModelImporter.cpp:149: ERROR: onnxOpImporters.cpp:191 In function importAttention:
[6] Assertion failed: node.output().size() == 1: TensorRT only supports Attention nodes with one output.

Have you tried the latest release?: Yes

Can this model run on other frameworks?

The model runs correctly with ONNX Runtime on the CPU execution provider:

import numpy as np
import onnxruntime as ort

batch_size = 2
past_seq_len = 7
curr_seq_len = 1
hidden_dim = 64
num_heads = 8

opts = ort.SessionOptions()
ort_model = ort.InferenceSession(
    "model.onnx", opts, providers=["CPUExecutionProvider"],
)
init_tensor_fn = lambda seq: np.random.randn(
    batch_size, num_heads, seq, hidden_dim).astype(np.float32)
q = init_tensor_fn(curr_seq_len)
k = init_tensor_fn(curr_seq_len)
v = init_tensor_fn(curr_seq_len)
past_k = init_tensor_fn(past_seq_len)
past_v = init_tensor_fn(past_seq_len)
out = ort_model.run(None, {
    "q": q,
    "k": k,
    "v": v,
    "past_k": past_k,
    "past_v": past_v,
})
# o: (2, 8, 1, 64); present_k / present_v: (2, 8, 8, 64)
print([x.shape for x in out])
