Conversation

@chohk88 (Collaborator) commented on Jan 14, 2026

Description

This PR adds TensorRT Edge-LLM AttentionPlugin backend support as an alternative to the default SDPA lowering, providing a 1.7x to 3.3x performance improvement for LLM inference.

Supported Models: Llama 3.x (3.1 and 3.2), Qwen 2.5, Qwen 3, and Qwen 3.1

⚠️ Current Implementation: The plugin backend requires building the AttentionPlugin library from a forked repository branch: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime

This is a temporary solution for the initial implementation. The fork contains Torch-TRT compatibility Python runtime support that is not yet available in the official NVIDIA TensorRT-Edge-LLM repository.
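Once the library is built from that branch, the resulting shared object has to be loadable by the process that builds and runs the engine. A minimal sketch of one way to do that (the path below is a placeholder, not something this PR prescribes):

```python
import ctypes

# Placeholder path: point this at wherever the forked TensorRT-Edge-LLM build
# places libNvInfer_edgellm_plugin.so.
PLUGIN_LIB = "/path/to/TensorRT-Edge-LLM/build/libNvInfer_edgellm_plugin.so"

# RTLD_GLOBAL exposes the plugin creators to TensorRT's plugin registry so the
# AttentionPlugin can be resolved when the engine is built or deserialized.
ctypes.CDLL(PLUGIN_LIB, mode=ctypes.RTLD_GLOBAL)
```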

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@narendasan (Collaborator) commented:
@zewenli98 please review

This example uses a custom TensorRT plugin shared library (``libNvInfer_edgellm_plugin.so``)
that replaces standard transformer attention operations and RoPE computations with optimized
CUDA kernels. The plugin source code is available at (internal access only):
@chohk88 can you change this to external links?

- kv_cache_start_idx: [B] starting index in KV cache (required for release version)
"""

@torch.library.custom_op("xqa::attn", mutates_args=())
Let's call the op tensorrt_edge_llm::xqa_attn

- kv_cache_start_idx: [B] starting index in KV cache (required for release version)
"""

@torch.library.custom_op("xqa::attn", mutates_args=())
Same thing here: tensorrt_edge_llm::xqa_attn
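A minimal sketch of the suggested rename, assuming the op keeps a packed-QKV signature like the one in the snippets above; the fake (meta) registration uses placeholder output shapes and is only meant to illustrate the namespace change:

```python
from typing import Tuple

import torch


@torch.library.custom_op("tensorrt_edge_llm::xqa_attn", mutates_args=())
def xqa_attn(qkv: torch.Tensor, nkv: int, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Real execution happens inside the TensorRT engine via the AttentionPlugin;
    # this eager body is not expected to run.
    raise NotImplementedError("tensorrt_edge_llm::xqa_attn is lowered to the plugin")


@torch.library.register_fake("tensorrt_edge_llm::xqa_attn")
def _(qkv: torch.Tensor, nkv: int, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Placeholder shapes for tracing/export only; the actual fake impl should
    # mirror the plugin's real output shapes.
    return torch.empty_like(qkv), torch.empty_like(qkv)
```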

nkv: int,
d: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = qkv.shape[0]
Is it possible to provide a valid implementation here easily? Could we lift the kernel from the .so?

This would be a P1/P2 sort of thing, but I think it would be good for the sake of completeness
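A rough eager reference along these lines might serve as a starting point. It is only a sketch: it assumes a packed [B, S, (nq + 2*nkv) * d] QKV layout and causal attention, and it ignores the RoPE and KV-cache handling the plugin performs, so a real fallback would need to match the plugin's semantics exactly.

```python
import torch
import torch.nn.functional as F


def xqa_attn_reference(
    qkv: torch.Tensor, nq: int, nkv: int, d: int
) -> tuple[torch.Tensor, torch.Tensor]:
    # Assumed packed layout: [B, S, (nq + 2*nkv) * d]
    b, s, _ = qkv.shape
    q, k, v = qkv.split([nq * d, nkv * d, nkv * d], dim=-1)
    q = q.view(b, s, nq, d).transpose(1, 2)   # [B, nq,  S, d]
    k = k.view(b, s, nkv, d).transpose(1, 2)  # [B, nkv, S, d]
    v = v.view(b, s, nkv, d).transpose(1, 2)  # [B, nkv, S, d]

    # Grouped-query attention: broadcast each KV head across its query group.
    rep = nq // nkv
    out = F.scaled_dot_product_attention(
        q,
        k.repeat_interleave(rep, dim=1),
        v.repeat_interleave(rep, dim=1),
        is_causal=True,
    )
    out = out.transpose(1, 2).reshape(b, s, nq * d)

    # Second output is a placeholder for whatever KV state the plugin returns.
    present_kv = torch.stack((k, v), dim=0)
    return out, present_kv
```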
