Conversation

@zewenli98 (Collaborator) commented Jan 8, 2026

Description

Added a lowering pass that replaces SymInt placeholders with aten.sym_size calls. Users can now explicitly use dynamic dims in their models, e.g.:

import logging

import torch
import torch.nn as nn
import torch_tensorrt

logging.basicConfig(level=logging.DEBUG)

torch.manual_seed(0)


class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]  # dynamic dim
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
        return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()
torch._dynamo.mark_dynamic(x, index=0, min=2, max=32)  # batch dim 0 is dynamic in [2, 32]
trt_module = torch.compile(model, backend="tensorrt")
out = trt_module(x)
print(out)
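
For context, here is a minimal sketch of the kind of graph rewrite such a pass performs. This is an illustration, not the code merged in this PR: the replace_symint_placeholders helper and the symint_to_source mapping are hypothetical names, while the FX APIs (graph.call_function, replace_all_uses_with, erase_node) and torch.ops.aten.sym_size.int are real:

import torch
from torch.fx import GraphModule, Node
from typing import Dict, Tuple


def replace_symint_placeholders(
    gm: GraphModule, symint_to_source: Dict[Node, Tuple[Node, int]]
) -> GraphModule:
    # symint_to_source maps each SymInt placeholder to the (tensor placeholder,
    # dim) it originated from; building this mapping is the hard part (see the
    # review discussion below).
    for symint_node, (tensor_node, dim) in symint_to_source.items():
        with gm.graph.inserting_after(tensor_node):
            sym_size = gm.graph.call_function(
                torch.ops.aten.sym_size.int, args=(tensor_node, dim)
            )
        # Reroute all consumers of the SymInt to the sym_size call, then drop
        # the now-unused SymInt placeholder from the graph.
        symint_node.replace_all_uses_with(sym_size)
        gm.graph.erase_node(symint_node)
    gm.graph.lint()
    gm.recompile()
    return gm

After this rewrite, the dynamic dimension is read from the tensor input itself, so no SymInt placeholder needs to survive into TensorRT lowering.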

Fixes #3981

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 requested a review from narendasan January 8, 2026 18:57
@zewenli98 zewenli98 self-assigned this Jan 8, 2026
@meta-cla meta-cla bot added the cla signed label Jan 8, 2026
@github-actions github-actions bot added the component: lowering, component: core, component: api [Python], and component: dynamo labels Jan 8, 2026
@github-actions github-actions bot requested a review from cehongwang January 8, 2026 18:58
@narendasan (Collaborator) left a comment:

Add a test case

Review comments on this snippet:

        and issubclass(node.type, torch.Tensor)
    ):
        for symint_node, (arg_name, idx) in symint_node_arg_dict.items():
            if node.target == "L_" + arg_name + "_":
A reviewer (Collaborator) commented:

Will this name template be consistent?

Another reviewer (Collaborator) commented:

Same concern. Is there any other way to identify it?

@zewenli98 (Collaborator, Author) replied:

That's a torch-generated internal name, so it should stay consistent unless torch changes the design. In any case, I have switched to a more robust way of identifying it.
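
For readers following along, one more robust approach is to ignore the placeholder name entirely and match the SymInt's recorded symbolic value against the symbolic sizes in each tensor placeholder's fake-tensor metadata. This is a hypothetical sketch under that assumption, not necessarily what this PR ended up doing; find_symint_source is an illustrative name:

import torch
from torch.fx import GraphModule, Node
from typing import Optional, Tuple


def find_symint_source(
    gm: GraphModule, symint_node: Node
) -> Optional[Tuple[Node, int]]:
    sym_val = symint_node.meta.get("val")  # the torch.SymInt seen at trace time
    if not isinstance(sym_val, torch.SymInt):
        return None
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        fake = node.meta.get("val")  # FakeTensor metadata for tensor inputs
        if not isinstance(fake, torch.Tensor):
            continue
        for dim, size in enumerate(fake.shape):
            # Matching on the printed symbol (e.g. "s0") is a simplification;
            # comparing the underlying SymNode expressions would be stricter.
            if isinstance(size, torch.SymInt) and str(size) == str(sym_val):
                return node, dim
    return None

This avoids depending on dynamo's "L_<arg_name>_" naming convention, which is an internal detail that could change between torch releases.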

@github-actions github-actions bot added the component: tests label Jan 17, 2026



Development

Successfully merging this pull request may close this issue:

🐛 [Bug] SymInt placeholder nodes not fully removed in remove_sym_nodes pass, causing issues with TensorRT lowering
