diff --git a/examples/dynamo/attention_plugin_example.py b/examples/dynamo/attention_plugin_example.py new file mode 100644 index 0000000000..979efb8dcd --- /dev/null +++ b/examples/dynamo/attention_plugin_example.py @@ -0,0 +1,747 @@ +""" +.. _attention_plugin_example: + +Custom Attention Plugin with KV Cache Management +================================================= + +This example demonstrates how to use a custom TensorRT AttentionPlugin that implements +efficient multi-head attention with Rotary Position Embedding (RoPE) and KV cache management +for autoregressive generation in Large Language Models (LLMs). + +**Plugin Library:** + +This example uses a custom TensorRT plugin shared library (``libNvInfer_edgellm_plugin.so``) +that replaces standard transformer attention operations and RoPE computations with optimized +CUDA kernels. The plugin source code is available at (internal access only): + +https://gitlab-master.nvidia.com/hoonkyungc/tensorrt-edgellm/-/blob/torchtrt-plugin-build/README_TORCHTRT_PLUGIN.md + +Build instructions and implementation details can be found in the repository above. + +**Key Features:** + +- **Dual Kernel Support:** + + - **FMHA (Fused Multi-Head Attention)** for context phase when ``seq_len > 1`` (processing multiple tokens) + - **XQA (Extended Query Attention)** for decode phase when ``seq_len = 1`` (single token generation) + +- **KV Cache Management:** Efficiently manages key-value cache for autoregressive generation +- **Perfect Accuracy:** Achieves cosine similarity = 1.0 with PyTorch's ``scaled_dot_product_attention`` +- **Grouped Query Attention (GQA):** Supports efficient attention with fewer KV heads + +**What This Example Tests:** + +1. **XQA Kernel (seq_len=1):** Single token generation, with and without past context +2. **FMHA Kernel (seq_len>1):** Context processing with multiple tokens +3. **Multi-Step Generation:** Realistic LLM scenario - process prompt (FMHA), then generate tokens (XQA) +4. **Perfect Accuracy:** All tests achieve ``cosine_similarity ≥ 0.99`` with PyTorch SDPA + +**Installation Requirements:** + +.. code-block:: bash + + pip install torch torch_tensorrt tensorrt + +Build the AttentionPlugin shared library following instructions at the GitLab repository above. 
+The compiled library should be located at: ``/path/to/tensorrt-edgellm/build/libNvInfer_edgellm_plugin.so`` +""" + +# %% +# Imports and Setup +# ----------------- + +import ctypes +import os +from typing import Tuple + +import numpy as np +import tensorrt as trt +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from torch_tensorrt.dynamo.conversion import ( + ConversionContext, + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +# %% +# Enable plugin debug logging +# ---------------------------- +os.environ["EDGELLM_DEBUG_PLUGIN"] = "1" + +# %% +# Initialize CUDA and Load Plugin +# -------------------------------- +# CUDA must be initialized before loading the TensorRT plugin library + +print("Initializing CUDA context...") +DEVICE = torch.device("cuda:0") +_ = torch.zeros(1, device=DEVICE) # Initialize CUDA +print(f"CUDA initialized on {DEVICE}\n") + +PLUGIN_PATH = ( + "/develop/TensorRT/TensorRT-Edge-LLM-release/build/libNvInfer_edgellm_plugin.so" +) +ctypes.CDLL(PLUGIN_PATH) +print(f"Loaded plugin: {PLUGIN_PATH}\n") + +# %% +# Model Configuration +# ------------------- +# These hyperparameters match typical LLM architectures with Grouped Query Attention (GQA) + +BATCH_SIZE = 1 +NUM_Q_HEADS = 4 # Number of query heads +NUM_KV_HEADS = 2 # Number of key/value heads (GQA: fewer than query heads) +HEAD_DIM = 64 # Dimension per head +KV_CACHE_CAPACITY = 128 # Maximum sequence length +HIDDEN_DIM = NUM_Q_HEADS * HEAD_DIM # 256 +NUM_KV_GROUPS = NUM_Q_HEADS // NUM_KV_HEADS # 2 + +DTYPE = torch.float16 + +# %% +# RoPE (Rotary Position Embedding) Utilities +# ------------------------------------------- +# RoPE encodes positional information through rotation in complex space + + +def precompute_rope(head_dim: int, max_seq_len: int = 128, base: float = 10000.0): + """ + Precompute RoPE cos/sin for all positions. + + Returns: + Tensor of shape [1, max_seq_len, head_dim] in FP32 + """ + inv_freq = 1.0 / ( + base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + rope = torch.cat([cos, sin], dim=-1) + return rope.unsqueeze(0).to(DEVICE) + + +def apply_rope(x, rope_cache, position_ids): + """ + Apply RoPE to input tensor. 
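+
+    Uses the half-split RoPE formulation computed below:
+    ``out = [x1*cos - x2*sin, x1*sin + x2*cos]``, where ``x1``/``x2`` are the two
+    halves of the last (head) dimension and ``cos``/``sin`` are read from ``rope_cache``.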
+ + Args: + x: [batch, num_heads, seq_len, head_dim] + rope_cache: [1, max_seq_len, head_dim] + position_ids: [seq_len] position indices + """ + seq_len = x.shape[2] + rope = rope_cache[:, position_ids, :] # [1, seq_len, head_dim] + rope = rope.unsqueeze(1) # [1, 1, seq_len, head_dim] + + half_dim = x.shape[-1] // 2 + cos = rope[..., :half_dim] + sin = rope[..., half_dim:] + + x_fp32 = x.float() + x1 = x_fp32[..., :half_dim] + x2 = x_fp32[..., half_dim:] + + rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) + + return rotated.half() + + +def repeat_kv(x, n_rep): + """Repeat KV heads for Grouped Query Attention""" + if n_rep == 1: + return x + bs, n_kv_heads, slen, head_dim = x.shape + return ( + x[:, :, None, :, :] + .expand(bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(bs, n_kv_heads * n_rep, slen, head_dim) + ) + + +# %% +# PyTorch SDPA Reference Implementation +# ------------------------------------- +# This serves as the ground truth for correctness validation + + +class SDPAModel(nn.Module): + """Reference attention using PyTorch's scaled_dot_product_attention""" + + def __init__(self): + super().__init__() + self.num_q_heads = NUM_Q_HEADS + self.num_kv_heads = NUM_KV_HEADS + self.head_dim = HEAD_DIM + self.num_key_value_groups = NUM_KV_GROUPS + + self.qkv = nn.Linear( + HIDDEN_DIM, HIDDEN_DIM + 2 * NUM_KV_HEADS * HEAD_DIM, bias=True + ) + self.out = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + + def forward(self, x, kv_cache, ctx_len_tensor, rope): + """ + Args: + x: [batch, seq_len, hidden_dim] + kv_cache: [batch, 2, num_kv_heads, capacity, head_dim] + ctx_len_tensor: [batch] - total context length including current tokens + rope: [1, max_seq_len, head_dim] + """ + batch_size, seq_len, _ = x.shape + ctx_len = ctx_len_tensor[0].item() + past_len = ctx_len - seq_len + + # QKV projection + qkv = self.qkv(x) + q_size = self.num_q_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + query, key, value = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + + # Reshape to multi-head format + query = query.view( + batch_size, seq_len, self.num_q_heads, self.head_dim + ).transpose(1, 2) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose( + 1, 2 + ) + value = value.view( + batch_size, seq_len, self.num_kv_heads, self.head_dim + ).transpose(1, 2) + + # Apply RoPE + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device) + query = apply_rope(query, rope, position_ids) + key = apply_rope(key, rope, position_ids) + + # Update KV cache + kv_cache[:, 0, :, past_len : past_len + seq_len, :] = key + kv_cache[:, 1, :, past_len : past_len + seq_len, :] = value + + # Get full K/V from cache + full_key = kv_cache[:, 0, :, :ctx_len, :] + full_value = kv_cache[:, 1, :, :ctx_len, :] + + # Expand for GQA + full_key = repeat_kv(full_key, self.num_key_value_groups) + full_value = repeat_kv(full_value, self.num_key_value_groups) + + # Scaled dot-product attention + is_causal = seq_len > 1 + attn_out = F.scaled_dot_product_attention( + query.contiguous(), + full_key.contiguous(), + full_value.contiguous(), + attn_mask=None, + dropout_p=0.0, + is_causal=is_causal, + ) + + # Output projection + attn_out = ( + attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, HIDDEN_DIM) + ) + output = self.out(attn_out) + + return output, kv_cache + + +# %% +# TensorRT Plugin Integration +# ---------------------------- +# Register custom operation and converter for TensorRT plugin + + +def register_plugin_op(): + """ + 
Register custom attention operation. + + Note: The release version of TensorRT-Edge-LLM requires 5 inputs: + - qkv: [B, S, (Hq+Hk+Hv)*D] fused QKV tensor + - kv: [B, 2, Hkv, Capacity, D] KV cache tensor + - ctx_len: [B] context length per batch + - rope: [S, D] rotary position encoding + - kv_cache_start_idx: [B] starting index in KV cache (required for release version) + """ + + @torch.library.custom_op("xqa::attn", mutates_args=()) + def attn( + qkv: torch.Tensor, + kv: torch.Tensor, + ctx_len: torch.Tensor, + rope: torch.Tensor, + kv_cache_start_idx: torch.Tensor, # Required 5th input for release plugin + nq: int, + nkv: int, + d: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.zeros( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + updated_kv = kv.clone() + return attn_out, updated_kv + + @torch.library.register_fake("xqa::attn") + def _(qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d): + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.empty( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + updated_kv = kv.clone() + return attn_out, updated_kv + + +register_plugin_op() + + +@dynamo_tensorrt_converter(torch.ops.xqa.attn.default, supports_dynamic_shapes=True) +def convert_attn(ctx: ConversionContext, target, args, kwargs, name): + """ + Convert PyTorch custom op to TensorRT plugin. + + Release version of TensorRT-Edge-LLM requires 5 inputs: + - qkv, kv, ctx_len, rope, kv_cache_start_idx + + Plugin fields for release version: + - num_q_heads, num_kv_heads, head_size, enable_tree_attention, enable_delta_kv_output + """ + # args: qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d + qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d = args[:8] + + # Get plugin creator + creator = trt.get_plugin_registry().get_plugin_creator("AttentionPlugin", "1", "") + if creator is None: + raise RuntimeError("AttentionPlugin not found! 
Make sure plugin is loaded.") + + # Plugin fields for release version of TensorRT-Edge-LLM + field_list = [ + trt.PluginField( + field_name, np.array([field_val], dtype=np.int32), trt.PluginFieldType.INT32 + ) + for field_name, field_val in [ + ("num_q_heads", nq), + ("num_kv_heads", nkv), + ("head_size", d), + ("enable_tree_attention", 0), + ("enable_delta_kv_output", 1), # Enable for python runtime+torch_tensorrt + ] + ] + + fields = trt.PluginFieldCollection(field_list) + plugin = creator.create_plugin(name, fields) + + if plugin is None: + raise RuntimeError("Failed to create plugin") + + # 5 inputs for release version: qkv, kv, ctx_len, rope, kv_cache_start_idx + inputs = [ + ( + get_trt_tensor(ctx, i, f"{name}_i{idx}") + if not isinstance(i, trt.ITensor) + else i + ) + for idx, i in enumerate([qkv, kv, ctx_len, rope, kv_cache_start_idx]) + ] + + # Handle kv_cache_start_idx shape if needed (squeeze if [B, 1] -> [B]) + if len(inputs[4].shape) == 2 and inputs[4].shape[1] == 1: + shuffle_layer = ctx.net.add_shuffle(inputs[4]) + shuffle_layer.reshape_dims = (inputs[4].shape[0],) + inputs[4] = shuffle_layer.get_output(0) + + layer = ctx.net.add_plugin_v2(inputs, plugin) + + return layer.get_output(0), layer.get_output(1) + + +class PluginModel(nn.Module): + """Attention model using TensorRT plugin""" + + def __init__(self): + super().__init__() + self.qkv = nn.Linear( + HIDDEN_DIM, HIDDEN_DIM + 2 * NUM_KV_HEADS * HEAD_DIM, bias=True + ) + self.out = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + + def forward(self, x, kv_cache, ctx_len_tensor, rope): + bsz, seq_len, _ = x.shape + qkv = self.qkv(x) + + # kv_cache_start_idx: starting position in KV cache for each batch + # For normal inference, this is 0 (start from beginning) + kv_cache_start_idx = torch.zeros(bsz, dtype=torch.int32, device=x.device) + + # Custom plugin call (5 inputs for release version) + attn_out, updated_kv = torch.ops.xqa.attn.default( + qkv, + kv_cache, + ctx_len_tensor, + rope, + kv_cache_start_idx, + NUM_Q_HEADS, + NUM_KV_HEADS, + HEAD_DIM, + ) + + # Reshape from [B, S, num_heads, head_dim] to [B, S, hidden_dim] + attn_out = attn_out.reshape(bsz, seq_len, HIDDEN_DIM) + + return self.out(attn_out), updated_kv + + +# %% +# Test Functions +# -------------- + + +def test_case( + name: str, seq_len: int, has_past_context: bool, sdpa_model, trt_model, rope +): + """ + Run a single test case and validate correctness. + + Args: + name: Test case name + seq_len: Sequence length (1 for XQA, >1 for FMHA) + has_past_context: Whether to initialize KV cache with past tokens + sdpa_model: PyTorch SDPA reference model + trt_model: Compiled TensorRT model + rope: Precomputed RoPE cache + + Note: + With enable_delta_kv_output=1, TRT plugin outputs only the delta KV: + - Context Phase: [B, 2, H, seq_len, D] (newly processed tokens) + - Generation Phase: [B, 2, H, 1, D] (single new token) + Python runtime must merge this delta into the main KV cache. 
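+
+        For example, a delta covering the current step's tokens at offset
+        ``past_len`` is merged back exactly as in the function body below::
+
+            delta_seq_len = trt_kv_delta.shape[3]  # seq_len for FMHA, 1 for XQA
+            trt_kv[:, :, :, past_len : past_len + delta_seq_len, :] = trt_kv_delta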
+ """ + print(f"\n{name}") + + # Determine context length + past_len = 10 if has_past_context else 0 + ctx_len = torch.tensor([past_len + seq_len], dtype=torch.int32, device=DEVICE) + + # Initialize KV caches + sdpa_kv = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + trt_kv = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + + # Add past context if needed + if has_past_context: + past_values = torch.randn( + BATCH_SIZE, 2, NUM_KV_HEADS, past_len, HEAD_DIM, dtype=DTYPE, device=DEVICE + ) + sdpa_kv[:, :, :, :past_len, :] = past_values + trt_kv[:, :, :, :past_len, :] = past_values + print(f" Input: {seq_len} new tokens + {past_len} past tokens in cache") + else: + print(f" Input: {seq_len} tokens (empty KV cache)") + + # Generate input tokens + x = torch.randn(BATCH_SIZE, seq_len, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + + # Run both models + with torch.no_grad(): + sdpa_out, sdpa_kv_new = sdpa_model(x, sdpa_kv, ctx_len, rope) + trt_out, trt_kv_delta = trt_model(x, trt_kv, ctx_len, rope) + + # TRT plugin with enable_delta_kv_output=1 returns only delta KV + # Merge delta into main KV cache at the correct position + delta_seq_len = trt_kv_delta.shape[3] # Should be seq_len + trt_kv[:, :, :, past_len : past_len + delta_seq_len, :] = trt_kv_delta + + # Compute similarities + attn_sim = F.cosine_similarity( + sdpa_out.flatten().float(), trt_out.flatten().float(), dim=0 + ).item() + + # Compare the newly updated portion of KV cache (after merge) + new_kv_sim = F.cosine_similarity( + sdpa_kv_new[:, :, :, past_len : past_len + seq_len, :].flatten().float(), + trt_kv[:, :, :, past_len : past_len + seq_len, :].flatten().float(), + dim=0, + ).item() + + # Determine which kernel was used + kernel_type = "XQA (decode)" if seq_len == 1 else "FMHA (context)" + + # Print results + print(f" Kernel Used: {kernel_type}") + print(f" Attention Output: cosine_similarity = {attn_sim:.6f}") + print(f" Updated KV Cache: cosine_similarity = {new_kv_sim:.6f}") + + # If there's past context, verify it's preserved in our main buffer + if has_past_context: + past_sim = F.cosine_similarity( + sdpa_kv_new[:, :, :, :past_len, :].flatten().float(), + trt_kv[:, :, :, :past_len, :].flatten().float(), + dim=0, + ).item() + print(f" Past KV Preserved: cosine_similarity = {past_sim:.6f}") + passed = attn_sim >= 0.99 and new_kv_sim >= 0.99 and past_sim >= 0.99 + else: + passed = attn_sim >= 0.99 and new_kv_sim >= 0.99 + + status = "PASS" if passed else "FAIL" + print(f" Result: {status}") + + return passed, attn_sim, new_kv_sim + + +# %% +# Main Execution +# -------------- + +if __name__ == "__main__": + print("\nCustom Attention Plugin - Correctness Validation") + + # Precompute RoPE + rope = precompute_rope(HEAD_DIM, KV_CACHE_CAPACITY) + + # Create models + print("\nCreating models...") + sdpa_model = SDPAModel().to(DEVICE).to(DTYPE).eval() + plugin_model = PluginModel().to(DEVICE).to(DTYPE).eval() + + # Share weights + plugin_model.qkv.weight.data.copy_(sdpa_model.qkv.weight.data) + plugin_model.qkv.bias.data.copy_(sdpa_model.qkv.bias.data) + plugin_model.out.weight.data.copy_(sdpa_model.out.weight.data) + print("Weights shared between models") + + # Compile with Torch-TensorRT (with dynamic shapes for seq_len) + print("\nCompiling with Torch-TensorRT...") + x_example = torch.randn(BATCH_SIZE, 1, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + kv_example = torch.zeros( + BATCH_SIZE, + 2, + 
NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + ctx_example = torch.tensor([1], dtype=torch.int32, device=DEVICE) + + # Enable dynamic shapes for seq_len dimension + inputs_spec = [ + torch_tensorrt.Input( + min_shape=(BATCH_SIZE, 1, HIDDEN_DIM), + opt_shape=(BATCH_SIZE, 8, HIDDEN_DIM), + max_shape=(BATCH_SIZE, 32, HIDDEN_DIM), + dtype=DTYPE, + ), + kv_example, + ctx_example, + rope, + ] + + with torch_tensorrt.logging.errors(): + trt_model = torch_tensorrt.compile( + plugin_model, + inputs=inputs_spec, + enabled_precisions={torch.float16}, + min_block_size=1, + truncate_double=True, + device=DEVICE, + ) + print("Compilation complete") + + # %% + # Run Test Cases + # -------------- + # Test all 4 combinations: {seq_len=1, seq_len>1} × {empty cache, with past} + + print("\nRunning Test Cases") + + results = [] + + # Test 1: Single token, empty cache (XQA kernel, cold start) + results.append( + test_case( + "Test 1: Single Token Generation (XQA) - Empty Cache", + seq_len=1, + has_past_context=False, + sdpa_model=sdpa_model, + trt_model=trt_model, + rope=rope, + ) + ) + + # Test 2: Single token, with past context (XQA kernel, typical decode) + results.append( + test_case( + "Test 2: Single Token Generation (XQA) - With Past Context", + seq_len=1, + has_past_context=True, + sdpa_model=sdpa_model, + trt_model=trt_model, + rope=rope, + ) + ) + + # Test 3: Multiple tokens, empty cache (FMHA kernel, prefill phase) + results.append( + test_case( + "Test 3: Context Processing (FMHA) - Empty Cache", + seq_len=16, + has_past_context=False, + sdpa_model=sdpa_model, + trt_model=trt_model, + rope=rope, + ) + ) + + # %% + # Multi-Step Generation Test + # --------------------------- + # Realistic test: Process initial context (FMHA), then generate tokens one by one (XQA) + # Note: With enable_delta_kv_output=1, we must merge delta KV into main buffer + + print("\nTest 4: Multi-Step Generation (FMHA -> XQA x 3)") + print("Simulating real LLM generation:") + print(" 1. Process initial prompt with FMHA (seq_len=16)") + print(" 2. 
Generate tokens one by one with XQA (seq_len=1)") + + # Step 1: Process initial prompt (FMHA) + initial_seq_len = 16 + x_init = torch.randn( + BATCH_SIZE, initial_seq_len, HIDDEN_DIM, dtype=DTYPE, device=DEVICE + ) + ctx_len_init = torch.tensor([initial_seq_len], dtype=torch.int32, device=DEVICE) + + sdpa_kv_multi = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + trt_kv_multi = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + + with torch.no_grad(): + sdpa_out_init, sdpa_kv_multi = sdpa_model( + x_init, sdpa_kv_multi, ctx_len_init, rope + ) + trt_out_init, trt_kv_delta = trt_model(x_init, trt_kv_multi, ctx_len_init, rope) + + # Merge delta KV into main buffer (context phase: delta has shape [B, 2, H, seq_len, D]) + delta_len = trt_kv_delta.shape[3] + trt_kv_multi[:, :, :, :delta_len, :] = trt_kv_delta + + init_sim = F.cosine_similarity( + sdpa_out_init.flatten().float(), trt_out_init.flatten().float(), dim=0 + ).item() + + print(f"\nStep 1: Initial prompt (FMHA, seq_len={initial_seq_len})") + print(f" Similarity: {init_sim:.6f}") + + # Step 2: Generate tokens one by one (XQA) + num_gen_tokens = 3 + all_passed_multi = init_sim > 0.99 + current_pos = initial_seq_len # Track current position in KV cache + + for gen_step in range(num_gen_tokens): + current_ctx_len = initial_seq_len + gen_step + 1 + x_gen = torch.randn(BATCH_SIZE, 1, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + ctx_len_gen = torch.tensor([current_ctx_len], dtype=torch.int32, device=DEVICE) + + with torch.no_grad(): + sdpa_out_gen, sdpa_kv_multi = sdpa_model( + x_gen, sdpa_kv_multi, ctx_len_gen, rope + ) + trt_out_gen, trt_kv_delta = trt_model( + x_gen, trt_kv_multi, ctx_len_gen, rope + ) + + # Merge delta KV into main buffer (generation phase: delta has shape [B, 2, H, 1, D]) + trt_kv_multi[:, :, :, current_pos : current_pos + 1, :] = trt_kv_delta + current_pos += 1 + + gen_sim = F.cosine_similarity( + sdpa_out_gen.flatten().float(), trt_out_gen.flatten().float(), dim=0 + ).item() + + kv_sim_gen = F.cosine_similarity( + sdpa_kv_multi[:, :, :, :current_ctx_len, :].flatten().float(), + trt_kv_multi[:, :, :, :current_ctx_len, :].flatten().float(), + dim=0, + ).item() + + passed = gen_sim > 0.99 and kv_sim_gen > 0.99 + all_passed_multi = all_passed_multi and passed + + print(f"\nStep {gen_step + 2}: Generate token {gen_step + 1} (XQA, seq_len=1)") + print(f" Attn similarity: {gen_sim:.6f}") + print(f" KV similarity: {kv_sim_gen:.6f}") + + results.append( + ( + all_passed_multi, + 1.0 if all_passed_multi else 0.0, + 1.0 if all_passed_multi else 0.0, + ) + ) + + print(f"\nResult: {'PASS - All steps matched!' 
if all_passed_multi else 'FAIL'}") + + # %% + # Summary + # ------- + + print("\nSUMMARY") + + test_names = [ + "Test 1: XQA - Empty Cache", + "Test 2: XQA - With Past", + "Test 3: FMHA - Empty Cache", + "Test 4: Multi-Step (FMHA->XQA)", + ] + + for name, (passed, attn_sim, kv_sim) in zip(test_names, results): + status = "PASS" if passed else "FAIL" + print(f"{name}: {status}") + print(f" Attention: {attn_sim:.4f}, KV Cache: {kv_sim:.4f}") + + all_passed = all(r[0] for r in results) + + if all_passed: + print("SUCCESS: All tests passed!") + print("Both FMHA and XQA kernels work correctly") + print("KV cache management is accurate") + print("Perfect agreement with PyTorch SDPA (cosine similarity >= 0.99)") + else: + print("FAILURE: Some tests failed") diff --git a/examples/dynamo/end_to_end_llm_generation_example.py b/examples/dynamo/end_to_end_llm_generation_example.py new file mode 100644 index 0000000000..02822014ad --- /dev/null +++ b/examples/dynamo/end_to_end_llm_generation_example.py @@ -0,0 +1,400 @@ +""" +End-to-End LLM Generation Example with TensorRT Attention Plugin + +This example demonstrates how to use the TensorRT attention plugin for +efficient LLM inference with KV caching. + +The plugin utilities are shared with tools/llm/run_llm.py for consistency. +""" + +import os +import sys +import time + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +# Add tools/llm to path for shared utilities +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../tools/llm")) + +from plugin_utils import ( + LLMPluginWrapper, + PluginAttention, + benchmark_plugin_generation, + compile_plugin_model, + create_kv_caches, + generate_with_plugin, + get_plugin_config, + get_plugin_rope_cache, + load_plugin, + register_plugin_op, + replace_attention_with_plugin, + set_plugin_config_from_model, +) + +# Configuration +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +MAX_SEQ_LEN = 2048 +DTYPE = torch.float16 +DEVICE = torch.device("cuda:0") + +# Load the plugin +load_plugin() +register_plugin_op() + + +# ----------------------------------------------------------------------------- +# Backward Compatibility Exports +# ----------------------------------------------------------------------------- + +# These are exported for backward compatibility with any code that imports +# from this module directly. + +# Re-export Qwen2Wrapper as an alias for LLMPluginWrapper +Qwen2Wrapper = LLMPluginWrapper + + +# Re-export replace_attention for backward compatibility +def replace_attention(model, config): + """ + Replace attention modules with plugin attention. + + This is a backward-compatible wrapper around replace_attention_with_plugin. + """ + return replace_attention_with_plugin(model, config, MAX_SEQ_LEN, DEVICE, DTYPE) + + +def compile_model(model, input_ids, position_ids, kv_caches, ctx_len): + """ + Compile a model for TensorRT inference. + + This is a backward-compatible wrapper that extracts config from the model. 
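+
+    Only ``model`` (and its config) is actually used; the remaining arguments are
+    accepted for signature compatibility. Typical call, mirroring ``__main__`` below::
+
+        trt_model_func = compile_model(
+            wrapper, dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len
+        )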
+ """ + # Get config from the wrapped model + if hasattr(model, "model"): + inner_model = model.model + if hasattr(inner_model, "config"): + config = inner_model.config + else: + config = inner_model.model.config + else: + config = model.config + + return compile_plugin_model(model, config, MAX_SEQ_LEN, DEVICE, DTYPE) + + +# Global config for backward compatibility with converter +TARGET_CONFIG = None + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def apply_repetition_penalty(logits, generated_ids, penalty): + """Apply repetition penalty to logits.""" + if penalty == 1.0: + return logits + + score = torch.gather(logits, 1, generated_ids) + score = torch.where(score < 0, score * penalty, score / penalty) + logits.scatter_(1, generated_ids, score) + return logits + + +# ----------------------------------------------------------------------------- +# Benchmarking +# ----------------------------------------------------------------------------- + + +def benchmark_generation(model_func, isl, osl, config, run_name="Model"): + """ + Benchmark generation with the plugin model. + + This wraps benchmark_plugin_generation for backward compatibility. + """ + return benchmark_plugin_generation( + model_func, config, isl, osl, MAX_SEQ_LEN, DEVICE, DTYPE, run_name + ) + + +def run_pytorch_benchmark_manual(model, config, isl, osl): + """Run PyTorch benchmark with manual loop (no KV cache).""" + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=DEVICE) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + with torch.no_grad(): + generated_ids = input_ids + + for _ in range(osl): + outputs = model(generated_ids, use_cache=False) + next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(0) + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"PyTorch (Manual - No Cache) | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms + + +def run_pytorch_benchmark_generate(model, config, isl, osl): + """Run PyTorch benchmark with model.generate() API.""" + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=DEVICE) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + with torch.no_grad(): + _ = model.generate( + input_ids, + max_new_tokens=osl, + min_new_tokens=osl, + do_sample=False, + use_cache=True, + pad_token_id=config.eos_token_id, + ) + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"PyTorch (Generate) | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms + + +def generate_reference(model, tokenizer, prompt, max_new_tokens=20): + """ + Generate reference output with PyTorch (greedy, no cache). 
+ """ + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE) + generated_ids = input_ids + + repetition_penalty = getattr(model.generation_config, "repetition_penalty", 1.0) + print( + f"DEBUG: Using repetition_penalty={repetition_penalty} for Reference Generation" + ) + + for _ in range(max_new_tokens): + current_seq_len = generated_ids.shape[1] + position_ids = torch.arange( + current_seq_len, dtype=torch.long, device=DEVICE + ).unsqueeze(0) + + outputs = model(generated_ids, position_ids=position_ids, use_cache=False) + next_token_logits = outputs.logits[:, -1, :] + + next_token_logits = apply_repetition_penalty( + next_token_logits, generated_ids, repetition_penalty + ) + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0) + + if next_token.item() == tokenizer.eos_token_id: + break + + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + return tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + +def verify_output(trt_model_func, model_pytorch, tokenizer, prompt, max_new_tokens=20): + """Verify TensorRT output matches PyTorch reference.""" + print(f"\nPrompt: '{prompt}'") + + # 1. PyTorch Reference Generation + print("\n=== PyTorch Reference Generation ===") + inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) + input_ids = inputs.input_ids + + with torch.no_grad(): + pyt_outputs = generate_reference( + model_pytorch, tokenizer, prompt, max_new_tokens=30 + ) + print(f"PyTorch Reference Text Output: {pyt_outputs}") + + with torch.no_grad(): + pyt_outputs_generate_ids = model_pytorch.generate( + input_ids, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + pad_token_id=tokenizer.eos_token_id, + ) + pyt_outputs_generate_text = tokenizer.decode( + pyt_outputs_generate_ids[0], skip_special_tokens=True + ) + print(f"PyTorch Generate Text Output: {pyt_outputs_generate_text}") + + pyt_text = pyt_outputs + print(f"PyTorch Output: {pyt_text}") + + # 2. 
TensorRT Plugin Generation + print("\n=== TensorRT Plugin Generation ===") + + repetition_penalty = getattr( + model_pytorch.generation_config, "repetition_penalty", 1.0 + ) + print( + f"DEBUG: Using repetition_penalty={repetition_penalty} for TensorRT Generation" + ) + + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len, dtype=torch.long, device=DEVICE).unsqueeze(0) + + config = model_pytorch.config + kv_caches = create_kv_caches(config, MAX_SEQ_LEN, 1, DEVICE, DTYPE) + + generated_ids = input_ids + + # Prefill + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=DEVICE) + logits, kv_caches_delta = trt_model_func( + input_ids, position_ids, kv_caches, ctx_len + ) + + for i, delta in enumerate(kv_caches_delta): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token_logits = logits[:, -1, :] + next_token_logits = apply_repetition_penalty( + next_token_logits, generated_ids, repetition_penalty + ) + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0) + + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + # Decode + cur_pos = seq_len + + if next_token.item() != tokenizer.eos_token_id: + for _ in range(max_new_tokens - 1): + input_ids_step = next_token + position_ids_step = torch.tensor( + [[cur_pos]], dtype=torch.long, device=DEVICE + ) + ctx_len_step = torch.tensor([cur_pos + 1], dtype=torch.int32, device=DEVICE) + + logits, kv_caches_delta = trt_model_func( + input_ids_step, position_ids_step, kv_caches, ctx_len_step + ) + + for i, delta in enumerate(kv_caches_delta): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token_logits = logits[:, -1, :] + next_token_logits = apply_repetition_penalty( + next_token_logits, generated_ids, repetition_penalty + ) + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0) + + if next_token.item() == tokenizer.eos_token_id: + break + + generated_ids = torch.cat([generated_ids, next_token], dim=1) + cur_pos += 1 + + trt_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + print(f"TensorRT Output: {trt_text}") + + # 3. Comparison + print("\n=== Comparison ===") + if pyt_text == trt_text: + print("SUCCESS: Outputs match exactly!") + else: + print("FAILURE: Outputs differ.") + print(f"PyTorch: {pyt_text}") + print(f"TensorRT: {trt_text}") + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + torch.manual_seed(42) + + print(f"Loading {MODEL_NAME}...") + config = AutoConfig.from_pretrained(MODEL_NAME) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + + # Set global config for backward compatibility + # Note: TARGET_CONFIG is defined at module level for backward compatibility + globals()["TARGET_CONFIG"] = config + + # Set plugin config + set_plugin_config_from_model(config, MAX_SEQ_LEN) + + # 1. PyTorch Model + model_pytorch = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + ).to(DEVICE) + model_pytorch.eval() + + # 2. 
TensorRT Plugin Model + model_trt = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to( + DEVICE + ) + model_trt.eval() + + model_trt = replace_attention(model_trt, config) + wrapper = LLMPluginWrapper(model_trt) + + # Compilation + print("Compiling TensorRT model...") + + dummy_input_ids = torch.tensor([[1, 2, 3]], device=DEVICE) + dummy_pos_ids = torch.tensor([[0, 1, 2]], device=DEVICE) + dummy_ctx_len = torch.tensor([3], dtype=torch.int32, device=DEVICE) + dummy_kvs = create_kv_caches(config, MAX_SEQ_LEN, 1, DEVICE, DTYPE) + + trt_model_func = compile_model( + wrapper, dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len + ) + + # 3. Verification + print("\n=== Verifying Output Accuracy ===") + verify_output( + trt_model_func, + model_pytorch, + tokenizer, + "What is parallel programming?", + max_new_tokens=30, + ) + + # 4. Benchmarks + benchmarks = [ + (128, 128), + (256, 128), + (512, 256), + ] + + print("\n=== Starting Benchmarks ===") + print(f"Device: {torch.cuda.get_device_name(0)}") + + for isl, osl in benchmarks: + print("-" * 60) + # PyTorch Manual Loop + run_pytorch_benchmark_manual(model_pytorch, config, isl, osl) + + # PyTorch Generate API + run_pytorch_benchmark_generate(model_pytorch, config, isl, osl) + + # TensorRT + benchmark_generation(trt_model_func, isl, osl, config, run_name="TensorRT") diff --git a/tools/llm/README.md b/tools/llm/README.md index 05a1e3cc60..e5f6f47b25 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -7,6 +7,9 @@ This directory provides utilities and scripts for compiling, optimizing, and ben - **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc. - **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2. - **Precision Modes:** Supports FP16, BF16, and FP32. +- **Multiple Backends:** + - **SDPA Backend** (default): Registers custom lowering pass for SDPA operations, enabling TensorRT conversion with optional static KV cache support + - **Plugin Backend**: Uses TensorRT Edge-LLM attention plugin for optimized inference with built-in KV cache management - **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding. - **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends. - **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT. @@ -37,41 +40,164 @@ We have officially verified support for the following models: #### Text-only LLMs: `run_llm.py` +**1. Generation with Output Verification** + +Compare PyTorch and TensorRT outputs to verify correctness: + +*SDPA Backend:* +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --backend sdpa \ + --prompt "What is parallel programming?" --precision FP16 --num_tokens 30 --enable_pytorch_run +``` +
+*Expected Output:*
+
+```
+========= PyTorch =========
+PyTorch model generated text: What is parallel programming? Parallel programming is a technique used to improve the performance of a program by dividing the work into smaller tasks and executing them simultaneously on multiple processors or cores.
+===================================
+========= TensorRT =========
+TensorRT model generated text: What is parallel programming? Parallel programming is a technique used to improve the performance of a program by dividing the work into smaller tasks and executing them simultaneously on multiple processors or cores.
+===================================
+PyTorch and TensorRT outputs match: True
+```
+
+
+*Plugin Backend:*
+```bash
+python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --backend plugin \
+  --prompt "What is parallel programming?" --precision FP16 --num_tokens 30 --enable_pytorch_run
+```
+
+*Expected Output:*
+
+```
+========= PyTorch =========
+PyTorch model generated text: What is parallel programming? What are the benefits of parallel programming? What are the challenges of parallel programming? What are the different types of parallel programming? What are the advantages of
+===================================
+========= TensorRT =========
+TensorRT model generated text: What is parallel programming? What are the benefits of parallel programming? What are the challenges of parallel programming? What are the different types of parallel programming? What are the advantages of
+===================================
+PyTorch and TensorRT outputs match: True
+```
+
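+Under the hood, the plugin backend swaps each attention module for a wrapper whose
+forward pass calls the registered `xqa::attn` custom op, which Torch-TensorRT converts
+to the Edge-LLM `AttentionPlugin` (see `plugin_utils.py`). The sketch below shows the
+op's interface in eager mode; it assumes it is run from `tools/llm` with the plugin
+library built, and the head counts (14 query / 2 KV heads, head size 64) are only
+illustrative values. In eager mode the op returns placeholder outputs (zeros and a
+cache copy); the real kernels run only after Torch-TensorRT compilation.
+
+```python
+import torch
+from plugin_utils import load_plugin, register_plugin_op
+
+load_plugin()         # loads libNvInfer_edgellm_plugin.so (path configured in plugin_utils.py)
+register_plugin_op()  # registers torch.ops.xqa.attn (idempotent)
+
+B, S, Hq, Hkv, D, CAP = 1, 8, 14, 2, 64, 2048
+qkv = torch.randn(B, S, (Hq + 2 * Hkv) * D, dtype=torch.float16, device="cuda")  # fused Q/K/V
+kv_cache = torch.zeros(B, 2, Hkv, CAP, D, dtype=torch.float16, device="cuda")    # per-layer KV cache
+ctx_len = torch.tensor([S], dtype=torch.int32, device="cuda")                    # total context length
+rope = torch.zeros(1, CAP, D, dtype=torch.float32, device="cuda")                # RoPE cos/sin table (FP32)
+kv_start = torch.zeros(B, dtype=torch.int32, device="cuda")                      # cache start index per batch
+
+attn_out, updated_kv = torch.ops.xqa.attn.default(
+    qkv, kv_cache, ctx_len, rope, kv_start, Hq, Hkv, D
+)
+print(attn_out.shape)  # [B, S, Hq, D]
+```
+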
+ +**2. Benchmarking for Performance Comparison** + +*Plugin Backend (compares TensorRT-Plugin vs PyTorch):* ```bash -python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark +python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --backend plugin --precision FP16 \ + --benchmark --iterations 5 --isl 128 --num_tokens 20 --batch_size 1 --enable_pytorch_run ``` +*SDPA with Static Cache (compares TensorRT-SDPA-StaticCache vs PyTorch):* +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --backend sdpa --cache static_v2 \ + --precision FP16 --benchmark --iterations 5 --isl 128 --num_tokens 20 --batch_size 1 --enable_pytorch_run +``` + +> **Note**: In benchmark mode, `--prompt` is not used. Random input tokens are generated based on `--isl` (input sequence length). + #### Vision Language Models: `run_vlm.py` +*Generation with Output Verification:* +```bash +python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 64 --cache static_v1 --enable_pytorch_run +``` + +*Benchmarking:* ```bash -python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --cache static_v1 --enable_pytorch_run --benchmark +python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --cache static_v1 --benchmark --iterations 5 --num_tokens 128 ``` #### Key Arguments +**Model Configuration:** - `--model`: Name or path of the HuggingFace LLM/VLM. - `--tokenizer`: (Optional) Tokenizer name; defaults to model. -- `--prompt`: Input prompt for generation. +- `--backend`: Backend to use (`sdpa` or `plugin`). Default is `sdpa`. Only applicable for LLM models. + +**Generation Settings:** +- `--prompt`: Input prompt for generation (generation mode only, ignored in benchmark mode). - `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image. - `--precision`: Precision mode (`FP16`, `FP32`). - `--num_tokens`: Number of output tokens to generate. -- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). -- `--benchmark`: Enable benchmarking mode. + +**Cache and Optimization:** +- `--cache`: KV cache type for SDPA backend (`static_v1`, `static_v2`, or empty for no KV caching). + - Note: Not applicable for plugin backend (manages cache internally). + +**Benchmarking:** +- `--benchmark`: Enable benchmarking mode (uses random inputs instead of prompt). +- `--iterations`: Number of benchmark iterations. Default is 5. +- `--isl`: Input sequence length for benchmarking. Default is 2048. +- `--batch_size`: Batch size for benchmarking. Default is 1. - `--enable_pytorch_run`: Also run and compare PyTorch baseline. ### Caching Strategies +#### SDPA Backend - **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse. - **No Cache:** Standard autoregressive decoding. Please read our tutorial on how static cache is implemented. +#### Plugin Backend +The plugin backend uses the TensorRT Edge-LLM AttentionPlugin which manages KV cache internally. The `--cache` option is not applicable and will be ignored if specified with `--backend plugin`. + +## Plugin Backend Setup + +To use the plugin backend (`--backend plugin`), you need to build the TensorRT Edge-LLM AttentionPlugin library. 
+ +### Building the AttentionPlugin + +Currently, the plugin support requires a custom build from a feature branch: + +```bash +# Clone the repository with the torch-tensorrt-python-runtime feature +git clone -b feature/torch-tensorrt-python-runtime https://github.com/chohk88/TensorRT-Edge-LLM.git +cd TensorRT-Edge-LLM + +# Build the plugin library +mkdir build && cd build + +# Configure with CMake (adjust paths based on your environment) +# Example for typical Ubuntu setup with CUDA 12.9 and TensorRT in /usr: +cmake .. -DTRT_PACKAGE_DIR=/usr -DCUDA_VERSION=12.9 + +# Build +make -j$(nproc) + +# The plugin library will be at: build/libNvInfer_edgellm_plugin.so +``` + +> **Note**: CMake configuration may vary depending on your system setup. Common options include: +> - `-DTRT_PACKAGE_DIR`: TensorRT installation directory (e.g., `/usr`, `/usr/local`) +> - `-DCUDA_VERSION`: CUDA version (e.g., `12.9`, `12.6`) +> +> Refer to the [TensorRT-Edge-LLM build documentation](https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime) for complete build instructions and dependencies. + +After building, update the plugin path in `plugin_utils.py`: +```python +DEFAULT_PLUGIN_PATH = "/path/to/your/TensorRT-Edge-LLM/build/libNvInfer_edgellm_plugin.so" +``` + +### Additional Examples + +Two comprehensive examples are provided in `examples/dynamo/` to demonstrate plugin usage: + +- **`attention_plugin_example.py`**: Standalone example showing how to use the AttentionPlugin with custom models +- **`end_to_end_llm_generation_example.py`**: End-to-end LLM generation example with plugin integration + +These examples can serve as references for integrating the plugin into your own applications. + ## Extension This codebase can be extended to - Add new models by specifying their HuggingFace name. - Implement new cache strategies by adding FX graph passes. - Customize SDPA conversion for new attention mechanisms. +- Add new backend implementations (see `plugin_utils.py` for plugin backend reference). ## Limitations - We do not currently support sliding window attention (used in Gemma3 and Qwen 3 models) yet. diff --git a/tools/llm/plugin_utils.py b/tools/llm/plugin_utils.py new file mode 100644 index 0000000000..4622a76785 --- /dev/null +++ b/tools/llm/plugin_utils.py @@ -0,0 +1,942 @@ +""" +Plugin utilities for TensorRT LLM inference with custom attention plugins. + +This module provides model-agnostic utilities for using TensorRT attention plugins +with various LLM architectures (Qwen, Llama, etc.). +""" + +import ctypes +import inspect +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +import numpy as np +import tensorrt as trt +import torch +import torch.nn as nn +import torch_tensorrt +from torch_tensorrt.dynamo.conversion import ( + ConversionContext, + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +# Default plugin path - can be overridden +DEFAULT_PLUGIN_PATH = ( + "/develop/TensorRT/TensorRT-Edge-LLM-release/build/libNvInfer_edgellm_plugin.so" +) + +# Global configuration for plugin converter +_PLUGIN_CONFIG: Dict[str, Any] = {} + + +def load_plugin(plugin_path: Optional[str] = None) -> bool: + """ + Load the TensorRT attention plugin library. + + Args: + plugin_path: Path to the plugin .so file. If None, uses DEFAULT_PLUGIN_PATH. + + Returns: + True if plugin was loaded successfully, False otherwise. + + Raises: + RuntimeError: If plugin file does not exist. 
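+
+    Example (sketch; the default path is machine-specific and will usually need
+    to be overridden)::
+
+        load_plugin("/path/to/TensorRT-Edge-LLM/build/libNvInfer_edgellm_plugin.so")
+        register_plugin_op()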
+ """ + path = plugin_path or DEFAULT_PLUGIN_PATH + if not os.path.exists(path): + raise RuntimeError(f"Plugin not found at {path}") + ctypes.CDLL(path) + return True + + +def set_plugin_config( + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + max_seq_len: int = 2048, + max_batch_size: int = 4, +) -> None: + """ + Set global configuration for the plugin converter. + + Args: + num_attention_heads: Number of query attention heads. + num_key_value_heads: Number of key/value attention heads (for GQA). + head_dim: Dimension of each attention head. + max_seq_len: Maximum sequence length for KV cache. + max_batch_size: Maximum batch size. + """ + global _PLUGIN_CONFIG + _PLUGIN_CONFIG = { + "num_attention_heads": num_attention_heads, + "num_key_value_heads": num_key_value_heads, + "head_dim": head_dim, + "max_seq_len": max_seq_len, + "max_batch_size": max_batch_size, + } + + +def get_plugin_config() -> Dict[str, Any]: + """Get the current plugin configuration.""" + return _PLUGIN_CONFIG.copy() + + +def set_plugin_config_from_model(model_config: Any, max_seq_len: int = 2048) -> None: + """ + Set plugin configuration from a HuggingFace model config. + + Args: + model_config: HuggingFace model configuration object. + max_seq_len: Maximum sequence length for KV cache. + """ + # Qwen3 has explicit head_dim in config that differs from hidden_size // num_attention_heads + if hasattr(model_config, "head_dim") and model_config.head_dim is not None: + head_dim = model_config.head_dim + else: + head_dim = model_config.hidden_size // model_config.num_attention_heads + + set_plugin_config( + num_attention_heads=model_config.num_attention_heads, + num_key_value_heads=model_config.num_key_value_heads, + head_dim=head_dim, + max_seq_len=max_seq_len, + ) + + +# ----------------------------------------------------------------------------- +# Plugin Op Registration +# ----------------------------------------------------------------------------- + + +def _register_plugin_op_impl() -> None: + """ + Internal implementation to register the xqa::attn custom op for PyTorch. + + Note: The release version of TensorRT-Edge-LLM requires 5 inputs: + - qkv: [B, S, (Hq+Hk+Hv)*D] fused QKV tensor + - kv: [B, 2, Hkv, Capacity, D] KV cache tensor + - ctx_len: [B] context length per batch + - rope: [S, D] rotary position encoding + - kv_cache_start_idx: [B] starting index in KV cache (required for release version) + """ + + @torch.library.custom_op("xqa::attn", mutates_args=()) + def attn( + qkv: torch.Tensor, + kv: torch.Tensor, + ctx_len: torch.Tensor, + rope: torch.Tensor, + kv_cache_start_idx: torch.Tensor, # Required 5th input for release plugin + nq: int, + nkv: int, + d: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.zeros( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + updated_kv = kv.clone() + return attn_out, updated_kv + + @torch.library.register_fake("xqa::attn") + def _(qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d): + batch_size = qkv.shape[0] + seq_len = qkv.shape[1] + attn_out = torch.empty( + batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device + ) + updated_kv = kv.clone() + return attn_out, updated_kv + + +def register_plugin_op() -> None: + """ + Register the xqa::attn custom op for PyTorch. + + This function is idempotent - safe to call multiple times. 
+ """ + if hasattr(torch.ops, "xqa") and hasattr(torch.ops.xqa, "attn"): + return + _register_plugin_op_impl() + + +# Register the op at module import time so the converter decorator works +# This is safe because the op registration is idempotent +if not (hasattr(torch.ops, "xqa") and hasattr(torch.ops.xqa, "attn")): + _register_plugin_op_impl() + + +# ----------------------------------------------------------------------------- +# TensorRT Converter +# ----------------------------------------------------------------------------- + + +@dynamo_tensorrt_converter(torch.ops.xqa.attn.default, supports_dynamic_shapes=True) +def convert_attn(ctx: ConversionContext, target, args, kwargs, name): + """ + Convert xqa::attn op to TensorRT AttentionPlugin. + + Release version of TensorRT-Edge-LLM requires 5 inputs: + - qkv, kv, ctx_len, rope, kv_cache_start_idx + + Plugin fields for release version: + - num_q_heads, num_kv_heads, head_size, enable_tree_attention, enable_delta_kv_output + """ + # args: qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d + qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d = args[:8] + + creator = trt.get_plugin_registry().get_plugin_creator("AttentionPlugin", "1", "") + if creator is None: + raise RuntimeError("AttentionPlugin not found in TensorRT plugin registry!") + + # Get config from global settings + config = get_plugin_config() + if config: + nq_val = config["num_attention_heads"] + nkv_val = config["num_key_value_heads"] + d_val = config["head_dim"] + else: + # Fallback to values from args (may not work correctly) + nq_val = nq if isinstance(nq, int) else 14 + nkv_val = nkv if isinstance(nkv, int) else 2 + d_val = d if isinstance(d, int) else 64 + + # Plugin fields for release version of TensorRT-Edge-LLM + field_list = [ + trt.PluginField( + "num_q_heads", np.array([nq_val], dtype=np.int32), trt.PluginFieldType.INT32 + ), + trt.PluginField( + "num_kv_heads", + np.array([nkv_val], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + trt.PluginField( + "head_size", np.array([d_val], dtype=np.int32), trt.PluginFieldType.INT32 + ), + trt.PluginField( + "enable_tree_attention", + np.array([0], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + trt.PluginField( + "enable_delta_kv_output", + np.array([1], dtype=np.int32), + trt.PluginFieldType.INT32, + ), + ] + + fields = trt.PluginFieldCollection(field_list) + plugin = creator.create_plugin(name, fields) + + # 5 inputs for release version: qkv, kv, ctx_len, rope, kv_cache_start_idx + inputs = [ + ( + get_trt_tensor(ctx, i, f"{name}_i{idx}") + if not isinstance(i, trt.ITensor) + else i + ) + for idx, i in enumerate([qkv, kv, ctx_len, rope, kv_cache_start_idx]) + ] + + # Handle ctx_len shape if needed (squeeze if [B, 1] -> [B]) + if len(inputs[2].shape) == 2 and inputs[2].shape[1] == 1: + shuffle_layer = ctx.net.add_shuffle(inputs[2]) + shuffle_layer.reshape_dims = (inputs[2].shape[0],) + inputs[2] = shuffle_layer.get_output(0) + + # Handle kv_cache_start_idx shape if needed (squeeze if [B, 1] -> [B]) + if len(inputs[4].shape) == 2 and inputs[4].shape[1] == 1: + shuffle_layer = ctx.net.add_shuffle(inputs[4]) + shuffle_layer.reshape_dims = (inputs[4].shape[0],) + inputs[4] = shuffle_layer.get_output(0) + + layer = ctx.net.add_plugin_v2(inputs, plugin) + return layer.get_output(0), layer.get_output(1) + + +# ----------------------------------------------------------------------------- +# RoPE Cache Generation +# ----------------------------------------------------------------------------- + + +def 
get_plugin_rope_cache( + rotary_emb: nn.Module, + max_seq_len: int, + head_dim: int, + device: torch.device, +) -> torch.Tensor: + """ + Generate RoPE cache tensor for the plugin from a rotary embedding module. + + Args: + rotary_emb: The rotary embedding module from the model. + max_seq_len: Maximum sequence length. + head_dim: Dimension of each attention head. + device: Device to create the cache on. + + Returns: + RoPE cache tensor of shape [1, max_seq_len, head_dim]. + """ + inv_freq = rotary_emb.inv_freq.to(device).float() + attention_scaling = getattr(rotary_emb, "attention_scaling", 1.0) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos_half = freqs.cos() * attention_scaling + sin_half = freqs.sin() * attention_scaling + rope = torch.cat([cos_half, sin_half], dim=-1) + return rope.unsqueeze(0) + + +# ----------------------------------------------------------------------------- +# Plugin Attention Module +# ----------------------------------------------------------------------------- + + +class PluginAttention(nn.Module): + """ + Model-agnostic Plugin Attention module that replaces standard attention. + + This module wraps the projection layers from the original attention module + and uses the xqa::attn plugin op for the attention computation. + + Supports: + - Qwen2.5, Llama: Standard attention + - Qwen3: Attention with QK Normalization (q_norm, k_norm) + """ + + def __init__( + self, + original_attn: nn.Module, + config: Any, + layer_idx: int, + rope_cache: torch.Tensor, + ): + """ + Initialize PluginAttention. + + Args: + original_attn: The original attention module to wrap. + config: Model configuration. + layer_idx: Index of this layer in the model. + rope_cache: Pre-computed RoPE cache tensor. + """ + super().__init__() + self.q_proj = original_attn.q_proj + self.k_proj = original_attn.k_proj + self.v_proj = original_attn.v_proj + self.o_proj = original_attn.o_proj + + # Qwen3 has QK Normalization + self.q_norm = getattr(original_attn, "q_norm", None) + self.k_norm = getattr(original_attn, "k_norm", None) + + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + + # For Qwen3, attention output size is num_heads * head_dim, not hidden_size + self.attn_hidden_size = self.num_heads * self.head_dim + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.register_buffer("rope_cache", rope_cache) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[torch.Tensor] = None, + ctx_len: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass using the plugin attention. + + Args: + hidden_states: Input tensor of shape [batch, seq_len, hidden_size]. + attention_mask: Unused (plugin handles masking internally). + position_ids: Position IDs (unused, plugin uses RoPE cache). + past_key_value: KV cache tensor of shape [batch, 2, num_kv_heads, capacity, head_dim]. + ctx_len: Context length tensor for each batch item. + + Returns: + Tuple of (output tensor, updated KV cache). 
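+
+        Note:
+            In eager mode the registered op returns a full copy of the cache; when
+            compiled with the TensorRT plugin (``enable_delta_kv_output=1``), the
+            second output contains only the delta for the newly written positions,
+            which the caller must merge back into its cache buffer.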
+ """ + batch_size, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Qwen3: Apply QK Normalization if available + if self.q_norm is not None: + # Reshape for per-head normalization: [B, S, num_heads, head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + q = self.q_norm(q) + q = q.view(batch_size, seq_len, -1) + + if self.k_norm is not None: + # Reshape for per-head normalization: [B, S, num_kv_heads, head_dim] + k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim) + k = self.k_norm(k) + k = k.view(batch_size, seq_len, -1) + + qkv = torch.cat([q, k, v], dim=-1) + + if ctx_len is None: + ctx_len = torch.tensor( + [seq_len], dtype=torch.int32, device=hidden_states.device + ).expand(batch_size) + + rope_fp32 = self.rope_cache.float() + + if past_key_value is None: + raise ValueError("past_key_value (KV cache tensor) must be provided") + + # kv_cache_start_idx: starting position in KV cache for each batch + # For normal inference, this is 0 (start from beginning) + kv_cache_start_idx = torch.zeros( + batch_size, dtype=torch.int32, device=hidden_states.device + ) + + attn_out, updated_kv = torch.ops.xqa.attn.default( + qkv, + past_key_value, + ctx_len, + rope_fp32, + kv_cache_start_idx, + self.num_heads, + self.num_key_value_heads, + self.head_dim, + ) + + # Use attn_hidden_size for reshape (may differ from hidden_size in Qwen3) + attn_out = attn_out.reshape(batch_size, seq_len, self.attn_hidden_size) + output = self.o_proj(attn_out) + return output, updated_kv + + +# ----------------------------------------------------------------------------- +# Model Wrappers +# ----------------------------------------------------------------------------- + + +class LLMPluginWrapper(nn.Module): + """ + Generic wrapper for LLM models with plugin attention. + + This wrapper handles the forward pass for models with replaced attention modules, + managing KV caches and context lengths appropriately. + """ + + def __init__(self, model: nn.Module, model_type: str = "auto"): + """ + Initialize the wrapper. + + Args: + model: The model with replaced attention modules. + model_type: Type of model ("qwen", "llama", or "auto" for auto-detection). 
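+
+        Example (sketch; assumes the attention modules were already swapped via
+        ``replace_attention_with_plugin``)::
+
+            wrapper = LLMPluginWrapper(model)
+            logits, new_kv_caches = wrapper(input_ids, position_ids, kv_caches, ctx_len)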
+ """ + super().__init__() + self.model = model + self.model_type = ( + self._detect_model_type(model) if model_type == "auto" else model_type + ) + + def _detect_model_type(self, model: nn.Module) -> str: + """Auto-detect model type from model structure.""" + model_class = model.__class__.__name__.lower() + if "qwen" in model_class: + return "qwen" + elif "llama" in model_class or "mistral" in model_class: + return "llama" + else: + # Default to generic transformer structure + return "generic" + + def _get_transformer(self) -> nn.Module: + """Get the transformer backbone based on model type.""" + if self.model_type == "qwen": + return self.model.model + elif self.model_type == "llama": + return self.model.model + else: + # Try common attribute names + for attr in ["model", "transformer", "backbone"]: + if hasattr(self.model, attr): + return getattr(self.model, attr) + raise ValueError( + f"Cannot find transformer backbone for model type: {self.model_type}" + ) + + def _get_layers(self, transformer: nn.Module) -> nn.ModuleList: + """Get the list of transformer layers.""" + for attr in ["layers", "h", "blocks"]: + if hasattr(transformer, attr): + return getattr(transformer, attr) + raise ValueError("Cannot find transformer layers") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + ctx_len: torch.Tensor, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Forward pass with plugin attention. + + Args: + input_ids: Input token IDs [batch, seq_len]. + position_ids: Position IDs [batch, seq_len]. + kv_caches: List of KV cache tensors, one per layer. + ctx_len: Context length tensor [batch]. + + Returns: + Tuple of (logits, list of updated KV caches). + """ + transformer = self._get_transformer() + hidden_states = transformer.embed_tokens(input_ids) + + layers = self._get_layers(transformer) + new_kv_caches = [] + + for i, layer in enumerate(layers): + past_key_value = kv_caches[i] + residual = hidden_states + + # Input layer norm + if hasattr(layer, "input_layernorm"): + hidden_states = layer.input_layernorm(hidden_states) + elif hasattr(layer, "ln_1"): + hidden_states = layer.ln_1(hidden_states) + + # Self attention + hidden_states, updated_kv = layer.self_attn( + hidden_states=hidden_states, + attention_mask=None, + position_ids=position_ids, + past_key_value=past_key_value, + ctx_len=ctx_len, + ) + hidden_states = residual + hidden_states + + # Post attention layer norm + MLP + residual = hidden_states + if hasattr(layer, "post_attention_layernorm"): + hidden_states = layer.post_attention_layernorm(hidden_states) + elif hasattr(layer, "ln_2"): + hidden_states = layer.ln_2(hidden_states) + hidden_states = layer.mlp(hidden_states) + hidden_states = residual + hidden_states + + new_kv_caches.append(updated_kv) + + # Final layer norm + if hasattr(transformer, "norm"): + hidden_states = transformer.norm(hidden_states) + elif hasattr(transformer, "ln_f"): + hidden_states = transformer.ln_f(hidden_states) + + # LM head + logits = self.model.lm_head(hidden_states) + + return logits, new_kv_caches + + +# ----------------------------------------------------------------------------- +# Model Modification Functions +# ----------------------------------------------------------------------------- + + +def replace_attention_with_plugin( + model: nn.Module, + config: Any, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype = torch.float16, +) -> nn.Module: + """ + Replace all attention modules in a model with PluginAttention. 
+ + Args: + model: The HuggingFace model to modify. + config: Model configuration. + max_seq_len: Maximum sequence length for RoPE cache. + device: Device for the model. + dtype: Data type for the model. + + Returns: + The modified model with plugin attention. + """ + # Get rotary embedding from model + transformer = model.model if hasattr(model, "model") else model + + # Try to find rotary embedding + rotary_emb = None + if hasattr(transformer, "rotary_emb"): + rotary_emb = transformer.rotary_emb + elif hasattr(transformer, "layers") and len(transformer.layers) > 0: + first_layer = transformer.layers[0] + if hasattr(first_layer, "self_attn") and hasattr( + first_layer.self_attn, "rotary_emb" + ): + rotary_emb = first_layer.self_attn.rotary_emb + + if rotary_emb is None: + raise ValueError("Cannot find rotary embedding in model") + + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + rope_cache = get_plugin_rope_cache(rotary_emb, max_seq_len, head_dim, device) + + # Get layers + if hasattr(transformer, "layers"): + layers = transformer.layers + elif hasattr(transformer, "h"): + layers = transformer.h + else: + raise ValueError("Cannot find transformer layers") + + # Replace attention modules + for i, layer in enumerate(layers): + layer.self_attn = PluginAttention(layer.self_attn, config, i, rope_cache) + + return model + + +# ----------------------------------------------------------------------------- +# Compilation +# ----------------------------------------------------------------------------- + + +def compile_plugin_model( + model: nn.Module, + config: Any, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype = torch.float16, + debug: bool = False, +) -> Callable: + """ + Compile a model with plugin attention for TensorRT inference. + + Args: + model: The wrapped model (should be LLMPluginWrapper or similar). + config: Model configuration. + max_seq_len: Maximum sequence length. + device: Device for compilation. + dtype: Data type. + debug: Whether to enable debug logging. + + Returns: + Compiled TensorRT model function. 
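+
+    Example (sketch of the full preparation pipeline; ``model`` and ``config``
+    are assumed to come from a HuggingFace checkpoint, and 2048 is an arbitrary
+    KV cache capacity)::
+
+        device = torch.device("cuda:0")
+        model = replace_attention_with_plugin(model, config, max_seq_len=2048,
+                                              device=device)
+        wrapper = LLMPluginWrapper(model)
+        trt_model = compile_plugin_model(wrapper, config, max_seq_len=2048,
+                                         device=device)
+        # Compiled signature: trt_model(input_ids, position_ids, kv_caches, ctx_len)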
+ """ + # Prepare dummy inputs + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + + dummy_input_ids = torch.tensor([[1, 2, 3]], device=device) + dummy_pos_ids = torch.tensor([[0, 1, 2]], device=device) + dummy_ctx_len = torch.tensor([3], dtype=torch.int32, device=device) + dummy_kvs = [ + torch.zeros( + 1, 2, num_kv_heads, max_seq_len, head_dim, dtype=dtype, device=device + ) + for _ in range(num_layers) + ] + + # Dynamic shapes + seq_len_dim = torch.export.Dim("seq_len", min=1, max=max_seq_len) + kv_cache_dynamics = [{}] * num_layers + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + "position_ids": {1: seq_len_dim}, + "kv_caches": kv_cache_dynamics, + "ctx_len": {}, + } + + # Export + ep = torch.export.export( + model, + args=(dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # Compile + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len], + enabled_precisions={torch.float32}, + use_explicit_typing=True, + use_fp32_acc=True, + device=device, + disable_tf32=True, + min_block_size=1, + debug=debug, + ) + + return trt_model + + +# ----------------------------------------------------------------------------- +# KV Cache Utilities +# ----------------------------------------------------------------------------- + + +def create_kv_caches( + config: Any, + max_seq_len: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype = torch.float16, +) -> List[torch.Tensor]: + """ + Create empty KV cache tensors for all layers. + + Args: + config: Model configuration. + max_seq_len: Maximum sequence length (capacity). + batch_size: Batch size. + device: Device to create tensors on. + dtype: Data type for the tensors. + + Returns: + List of KV cache tensors, one per layer. + """ + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + + return [ + torch.zeros( + batch_size, + 2, + num_kv_heads, + max_seq_len, + head_dim, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + + +# ----------------------------------------------------------------------------- +# Generation Utilities +# ----------------------------------------------------------------------------- + + +def generate_with_plugin( + model_func: Callable, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + max_new_tokens: int, + eos_token_id: Optional[int] = None, + device: torch.device = torch.device("cuda:0"), +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Generate tokens using the plugin model. + + Args: + model_func: The compiled model function. + input_ids: Input token IDs [batch, seq_len]. + kv_caches: List of KV cache tensors. + max_new_tokens: Maximum number of new tokens to generate. + eos_token_id: EOS token ID for early stopping (optional). + device: Device for computation. + + Returns: + Tuple of (generated token IDs, updated KV caches). 
+ """ + generated_ids = input_ids.clone() + seq_len = input_ids.shape[1] + + # Prefill + position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0) + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=device) + + output = model_func(input_ids, position_ids, kv_caches, ctx_len) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + else: + logits = output + delta_kvs = [] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + generated_ids = torch.cat([generated_ids, next_token], dim=1) + + # Check for EOS + if eos_token_id is not None and next_token.item() == eos_token_id: + return generated_ids, kv_caches + + # Decode + cur_pos = seq_len + + for _ in range(max_new_tokens - 1): + input_ids_step = next_token + position_ids_step = torch.tensor([[cur_pos]], dtype=torch.long, device=device) + ctx_len_step = torch.tensor([cur_pos + 1], dtype=torch.int32, device=device) + + output = model_func(input_ids_step, position_ids_step, kv_caches, ctx_len_step) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + generated_ids = torch.cat([generated_ids, next_token], dim=1) + cur_pos += 1 + + # Check for EOS + if eos_token_id is not None and next_token.item() == eos_token_id: + break + + return generated_ids, kv_caches + + +def benchmark_plugin_generation( + model_func: Callable, + config: Any, + isl: int, + osl: int, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype = torch.float16, + run_name: str = "Plugin", +) -> float: + """ + Benchmark plugin model generation. + + Args: + model_func: The compiled model function. + config: Model configuration. + isl: Input sequence length. + osl: Output sequence length (number of tokens to generate). + max_seq_len: Maximum sequence length for KV cache. + device: Device for computation. + dtype: Data type. + run_name: Name for logging. + + Returns: + Elapsed time in milliseconds. 
+ """ + # Check for extra kwargs the model might need + extra_kwargs = {} + if hasattr(model_func, "forward"): + sig = inspect.signature(model_func.forward) + if "arg_start_idx" in sig.parameters: + extra_kwargs["arg_start_idx"] = 0 + if "arg_end_idx" in sig.parameters: + extra_kwargs["arg_end_idx"] = 0 + + # Prepare inputs + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=device) + kv_caches = create_kv_caches(config, max_seq_len, 1, device, dtype) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + # Prefill + seq_len = isl + position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0) + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=device) + + output = model_func(input_ids, position_ids, kv_caches, ctx_len, **extra_kwargs) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + else: + logits = output + delta_kvs = [] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + + # Decode + cur_pos = seq_len + + for _ in range(osl - 1): + input_ids_step = next_token + position_ids_step = torch.tensor([[cur_pos]], dtype=torch.long, device=device) + ctx_len_step = torch.tensor([cur_pos + 1], dtype=torch.int32, device=device) + + output = model_func( + input_ids_step, position_ids_step, kv_caches, ctx_len_step, **extra_kwargs + ) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + cur_pos += 1 + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"{run_name} | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 1531c30622..184baefbe6 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -10,6 +10,7 @@ import argparse import copy import os +import sys import timeit from contextlib import nullcontext @@ -28,6 +29,24 @@ time_generate, ) +# Import plugin utilities (optional) +try: + from plugin_utils import ( + LLMPluginWrapper, + benchmark_plugin_generation, + compile_plugin_model, + create_kv_caches, + generate_with_plugin, + load_plugin, + register_plugin_op, + replace_attention_with_plugin, + set_plugin_config_from_model, + ) + + PLUGIN_AVAILABLE = True +except ImportError as e: + PLUGIN_AVAILABLE = False + DEVICE = torch.device("cuda:0") @@ -49,17 +68,23 @@ def get_model(args): moved to CUDA device with the specified precision """ with torch.no_grad(): + # For plugin backend, we don't set attn_implementation + attn_impl_kwargs = {} + if args.backend == "sdpa": + attn_impl_kwargs["attn_implementation"] = "sdpa" + model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, - attn_implementation="sdpa", + **attn_impl_kwargs, ) .eval() .cuda() ) - # register SDPA variant for the model - register_sdpa.enable_sdpa_converter(args.model, 
model.config) + # register SDPA variant for the model (only for sdpa backend) + if args.backend == "sdpa": + register_sdpa.enable_sdpa_converter(args.model, model.config) if args.precision == "FP16": model = model.to(torch.float16) @@ -195,6 +220,12 @@ def measure_perf(trt_model, input_signature, backend_name): default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32", ) + arg_parser.add_argument( + "--backend", + type=str, + default="sdpa", + help="Backend to use. Options: sdpa, plugin", + ) arg_parser.add_argument( "--iterations", type=int, default=5, help="no. of iterations to run" ) @@ -238,6 +269,16 @@ def measure_perf(trt_model, input_signature, backend_name): ) args = arg_parser.parse_args() + + # Validate arguments + if args.backend == "plugin" and not PLUGIN_AVAILABLE: + raise RuntimeError( + "Plugin backend requested but plugin utilities are not available." + ) + if args.cache and args.backend == "plugin": + print("Warning: --cache is only applicable with 'sdpa' backend. Ignoring.") + args.cache = "" + with torch.inference_mode(): model = get_model(args) @@ -281,54 +322,114 @@ def measure_perf(trt_model, input_signature, backend_name): compile_time_s=None, ) - if args.cache == "static_v1": - # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 - if args.cache == "static_v2": - # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 + # Backend selection: sdpa or plugin + if args.backend == "plugin": + # Plugin backend + if not PLUGIN_AVAILABLE: + raise RuntimeError("Plugin backend requested but not available") + + dtype = ( + torch.float16 + if args.precision == "FP16" + else (torch.bfloat16 if args.precision == "BF16" else torch.float32) + ) + config = model.config + max_seq_len = max(2048, MAX_OUTPUT_SEQ_LENGTH) + + # Load plugin and register op + load_plugin() + register_plugin_op() + set_plugin_config_from_model(config, max_seq_len) + + # Replace attention with plugin + model = replace_attention_with_plugin( + model, config, max_seq_len, DEVICE, dtype + ) + wrapper = LLMPluginWrapper(model) - # Compile the model with Torch-TensorRT - trt_model = compile_torchtrt(model, input_ids, args) + # Compile plugin model + trt_model = compile_plugin_model( + wrapper, config, max_seq_len, DEVICE, dtype, args.debug + ) - if args.cache == "static_v1" or args.cache == "static_v2": - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
- # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) + # Create KV caches + kv_caches = create_kv_caches( + config, max_seq_len, args.batch_size, DEVICE, dtype + ) - trt_gen_tokens = generate_with_static_cache( + # Generate + trt_gen_tokens, _ = generate_with_plugin( trt_model, input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, + kv_caches, + args.num_tokens, tokenizer.eos_token_id, + DEVICE, ) if args.benchmark: - trt_timings = time_generate( - generate_with_static_cache, + trt_timings = [] + for i in range(args.iterations): + elapsed_ms = benchmark_plugin_generation( + trt_model, + config, + input_ids.shape[1], + args.num_tokens, + max_seq_len, + DEVICE, + dtype, + ) + trt_timings.append(elapsed_ms / 1000.0) + else: + # SDPA backend (default) + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. + # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - iterations=args.iterations, ) - else: - trt_gen_tokens = generate( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate, + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - iterations=args.iterations, ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) if args.benchmark: trt_stats = record_stats(