From b9fe0c193e4f5bc5866570a7eafbdc9dc1858478 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Nov 2025 19:11:06 +0000 Subject: [PATCH 1/4] squashed and cleaned the commits --- py/torch_tensorrt/dynamo/_compiler.py | 28 +- py/torch_tensorrt/dynamo/_defaults.py | 2 + py/torch_tensorrt/dynamo/_settings.py | 2 + .../partitioning/_resource_partitioner.py | 562 ++++++++++++++++++ .../dynamo/partitioning/fusion_patterns.py | 185 ++++++ tools/llm/torchtrt_ext/register_sdpa.py | 7 +- 6 files changed, 783 insertions(+), 3 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py create mode 100644 py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..42cf580c6a 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -104,6 +104,7 @@ def cross_compile_for_windows( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -178,6 +179,7 @@ def cross_compile_for_windows( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -333,6 +335,7 @@ def cross_compile_for_windows( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } # disable the following settings is not supported for cross compilation for windows feature @@ -434,6 +437,7 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -680,8 +684,8 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } - settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -850,6 +854,16 @@ def preserve_module_specs( require_full_compilation=settings.require_full_compilation, ) + from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + resource_partition, + ) + + partitioned_module = resource_partition( + gm, + partitioned_module, + cpu_memory_budget=settings.cpu_memory_budget, + ) + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators # The global partitioner leaves non-TRT nodes as-is @@ -868,6 +882,16 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. + # This is done to release CPU memory. + for attr in dir(gm): + if attr.startswith("_frozen_param"): + delattr(gm, attr) + + from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + + DYNAMO_CONVERTERS.disallowed_targets = set() + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1243,7 +1267,7 @@ def convert_exported_program_to_serialized_trt_engine( # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) - trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) + trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..0b4c0a2b54 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -2,6 +2,7 @@ import platform import tempfile +import psutil import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype @@ -57,6 +58,7 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +CPU_MEMORY_BUDGET = psutil.virtual_memory().available if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..52ac86012c 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,6 +7,7 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + CPU_MEMORY_BUDGET, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -140,6 +141,7 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + cpu_memory_budget: int = CPU_MEMORY_BUDGET def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py new file mode 100644 index 0000000000..967711ba02 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -0,0 +1,562 @@ +"""Resource-aware graph partitioner for TensorRT compilation. + +This module refines an existing capability-based partitioning (accelerated vs +non-accelerated subgraphs) by further splitting accelerated subgraphs to meet +host CPU memory constraints during TensorRT engine building. + +High-level algorithm +-------------------- +Given an original `torch.fx.GraphModule` and a capability-partitioned +`GraphModule` (produced earlier in the pipeline), we: + +1) Reconstruct subgraphs on the original graph + - Iterate over the capability-partitioned module to determine which original + nodes belong to which subgraph (accelerated or not). + - Preserve fusion groups discovered in each subgraph so that all nodes in a fusion + group remain in the same subgraph and not be split across subgraphs. + - Verify subgraphs respect topological order. This is to ensure the validity of the subgraphs. + - Reconstruting subgraphs from partitioned module is easier than building nasted partitioned graph modules and flattening them later. + +2) Estimate memory cost of each possible subgraphs + - Compute a per-subgraph "size" by traversing the graph to find weights + (get_attr) reachable from its nodes and summing tensor bytes. + - Use a set to record the visited nodes and avoid double counting shared parameters across subgraphs. + + +4) Split large accelerated subgraphs + - While a subgraph exceeds the per-engine budget, split it into two or more subgraphs. + - Move nodes incrementally from the front of the original subgraph into a + new left subgraph, repeatedly validating/correcting topological, partitioning, and + dependency constraints. + - Ensure we never split across a fusion group; when a split would break a + fusion, we backtrack dependencies and move the entire fusion and related nodes into the left + side. + - Continue until the left subgraph fits the budget + - Repeat the process for the right subgraph until all subgraphs fit the budget. + +5) Finalize + - After splitting, assert all fusion groups reside in a single subgraph. + - Tag nodes and produce a `GraphModule` where each subgraph becomes either a + TRT engine (accelerated) or runs in Torch (non-accelerated). + +Notes +----- +- The budget is a heuristic bound. If the total model size exceeds 40x the + per-engine budget, we fail early with a clear error suggesting remedies. +""" + +import logging +from typing import Dict, List, Tuple + +import psutil +import torch +from torch.fx.passes.splitter_base import Subgraph, _SplitterBase +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch_tensorrt.dynamo.partitioning.fusion_patterns import ( + get_node_in_fusion_pattern, +) + +logger = logging.getLogger(__name__) + + +class ResourcePartitioner(_SplitterBase): # type: ignore + """Refine capability-based subgraphs to meet host CPU memory constraints. + + This partitioner takes: + - an original `torch.fx.GraphModule` (`module`) + - a capability-partitioned `GraphModule` (`partitioned_module`) containing + submodules that delineate accelerated vs non-accelerated regions + - a CPU memory budget in bytes (`cpu_memory_budget`) + + It maps nodes from `module` into subgraphs according to `partitioned_module` + and then splits oversized accelerated subgraphs so that each resulting TRT + engine's estimated size fits within a conservative budget derived from + available CPU memory or predefined CPU budget. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + partitioned_module: torch.fx.GraphModule, + cpu_memory_budget: int, + ): + + assert isinstance(module, torch.fx.GraphModule) + assert isinstance(partitioned_module, torch.fx.GraphModule) + + self.module = module + self.partitioned_module = partitioned_module + self.cpu_memory_budget = cpu_memory_budget + + self.deps = self.find_deps() + + self.non_acc_submodule_name = "_run_on_gpu_" + self._node_submodule_map: Dict[str, str] = {} + self._return_tuple = False + self.fusion_patterns: Dict[torch.fx.Node, List[torch.fx.Node]] = {} + + def partition_graph(self) -> torch.fx.GraphModule: + """Build the final partitioned `GraphModule` honoring memory constraints. + + Steps: + - Build subgraph assignments from the capability-partitioned module + - Split oversized accelerated subgraphs based on memory budget + - Tag nodes and construct the final split graph + + Returns: + torch.fx.GraphModule: A graph split into subgraphs based on capability partitioning and memory constraints. + """ + # Delegate nodes based on operator coverage + subgraphs = self.put_nodes_into_subgraphs() + + subgraphs = self.break_subgraphs( + subgraphs, subgraph_size_budget=self.calculate_size_budget() + ) + + # Set the number of TRT engines to be generated + self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) + + # Tag the accelerated nodes and split the graph accordingly + self.tag(subgraphs) + + gm = self.split() + + return gm + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + """Map original graph nodes into capability-based subgraphs. + + - Iterates `partitioned_module` submodules to establish which node names + belong to which subgraph (accelerated or not). + - Builds a fusion pattern map for each subgraph so that known fusion groups remain intact. + Note that since fusion map is built for each subgraph, the capability partitioning can still break the fusion groups. + - Put the nodes into the subgraphs based on the capability partitioning. + - Verifies the resulting list of subgraphs is topologically ordered. + + Returns: + list[Subgraph]: Ordered subgraphs consisting of nodes in `module` based on capability partitioning. + """ + subgraphs_map = {} + subgraphs = [] + name_to_node_map = ( + {} + ) # We use this map to help map the nodes in partitioned module to the nodes in original module. + for name, _ in self.partitioned_module.named_children(): + # We first iterate over the partitioned module to find the subgraphs based on capability partitioning. + submodule = getattr(self.partitioned_module, name) + if not isinstance(submodule, torch.fx.graph_module.GraphModule): + continue + subgraph = Subgraph(is_acc="acc" in name, nodes=[]) + subgraphs.append(subgraph) + self.fusion_patterns.update(get_node_in_fusion_pattern(submodule.graph)) + + for node in submodule.graph.nodes: + # Erase the tag from previous partitioner if it exists + if hasattr(node, "tag"): + delattr(node, "tag") + + if node.op in CALLABLE_NODE_OPS: + # Store which subgraph the node should be put in + subgraphs_map[node.name] = subgraph + + # We then iterate over the original module to put the nodes into the subgraphs. + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + # Erase the tag from previous partitioner + delattr(node, "tag") + if node.op in CALLABLE_NODE_OPS: + name_to_node_map[node.name] = node + subgraphs_map[node.name].nodes.append(node) + + assert self.check_topological_order( + subgraphs + ), "The subgraphs are not topologically ordered" + self.fusion_patterns = { + name_to_node_map[node.name]: [ + name_to_node_map[n.name] for n in fusion_nodes + ] + for node, fusion_nodes in self.fusion_patterns.items() + } + + return subgraphs + + def check_topological_order(self, subgraphs: List[Subgraph]) -> bool: + """Return True if subgraphs are in a valid topological order. + + Each node's dependencies must appear in earlier subgraphs or earlier + positions within the same subgraph. Subgraphs should be topologically ordered to ensure the validity of the subgraphs. + """ + visited_nodes: set[torch.fx.Node] = set() + for subgraph in subgraphs: + for node in subgraph.nodes: + if self.deps[node] > visited_nodes: + return False + visited_nodes.add(node) + return True + + def calculate_size_budget( + self, engine_compilation_memory_usage_multiplier: int = 4 + ) -> int: + """Compute the per-engine size budget in bytes. + + Uses explicit `cpu_memory_budget` minus used RSS + divided by a safety multiplier. + + Args: + engine_compilation_memory_usage_multiplier: Safety divisor applied to + available memory to approximate a per-engine budget. By default we assume TensorRT + compilation requires up to 4x the model's size. + + Returns: + int: Budget in bytes for a single accelerated subgraph. + """ + + used_rss: int = psutil.virtual_memory().used + available_rss = self.cpu_memory_budget - used_rss + return available_rss // engine_compilation_memory_usage_multiplier + + def break_subgraphs( + self, subgraphs: List[Subgraph], subgraph_size_budget: int + ) -> List[Subgraph]: + """Split oversized accelerated subgraphs until they fit within budget. + + - Compute sizes for each subgraph (in bytes of parameters reachable from + that subgraph). + - If the sum of all sizes is catastrophically larger than budget + (threshold 40x), raise a ValueError with guidance. + - For any subgraph whose size exceeds `subgraph_size_budget`, iteratively + split it using `break_subgraph_by_size` and append resulting segments. + - Validate that fusion groups remain intact post splitting. + + Args: + subgraphs: Ordered list of subgraphs from capability partitioning. + subgraph_size_budget: Target maximum size per accelerated subgraph. + + Returns: + List[Subgraph]: New list of subgraphs after resource-aware splitting. + """ + + new_subgraphs = [] + # We throw an error if the remaining memory is almost empty compared to the model size. + # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. + sizes = self.size_of_subgraphs(subgraphs) + if sum(sizes) > subgraph_size_budget * 40: + raise ValueError( + f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else "All available memory"} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + + "Consider setting cpu_memory_budget to a larger value or disable offload_module_to_cpu to save more CPU memory." + ) + for subgraph, size in zip(subgraphs, sizes): + + while size > subgraph_size_budget: + broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size( + subgraph, subgraph_size_budget + ) + size = size_1 + new_subgraphs.append(broken_subgraphs[0]) + subgraph = broken_subgraphs[1] + new_subgraphs.append(subgraph) + + self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + + return new_subgraphs + + def _varify_all_fusion_nodes_in_same_subgraph( + self, subgraphs: List[Subgraph] + ) -> None: + """Assert that every fusion group is contained in exactly one subgraph.""" + node_to_subgraph = {} + for i, s in enumerate(subgraphs): + for n in s.nodes: + node_to_subgraph[n] = i + + fusion_nodes_map_list = [ + len({node_to_subgraph[n] for n in ns}) == 1 + for ns in self.fusion_patterns.values() + ] # fusion nodes must be in the same subgraph + + assert all( + fusion_nodes_map_list + ), "All fusion nodes must be in the same subgraph" + logger.info("All fusion nodes are in the same subgraph.") + + def break_subgraph_by_size( + self, subgraph: Subgraph, size_to_break: int + ) -> Tuple[List[Subgraph], int, int]: + """Split a single oversized subgraph into two valid subgraphs. + + Moves nodes from the head of `subgraph` into a new left segment until + the left segment's estimated size exceeds `size_to_break`. During the + process we: + - Repeatedly validate/correct topological placement + - Detect and avoid splitting fusion groups by moving all fused nodes + (and their producer chain) into the left segment + + Returns: + (segments, size_left, size_right): + segments[0] is the new left subgraph, segments[1] is the residual + right subgraph. Sizes are estimated parameter bytes of each. + """ + all_nodes = subgraph.nodes + device_ordinal = subgraph.device_ordinal + new_subgraphs = [ + Subgraph( + is_acc=True, + nodes=[], + device_ordinal=device_ordinal, + ), + Subgraph( + is_acc=True, + nodes=all_nodes, + device_ordinal=device_ordinal, + ), + ] + + # We break the subgraph until the left subgraph fits the budget. + while True: + # Set a step size proportional to the size of the subgraph to make the algorithm more efficient. + # This reduce the time complexity from O(N**2) to O(N). The max number of steps is 50. + # Note: we want the first step size to be 1. + step_size = ( + 1 if not new_subgraphs[0].nodes else max(1, len(all_nodes) // 50) + ) + new_subgraphs = self.step_and_validate(new_subgraphs, step_size) + size_0, size_1 = self.size_of_subgraphs(new_subgraphs) + if size_0 > size_to_break: + break + + if len(new_subgraphs[1].nodes) == 0: + new_subgraphs.pop(1) + return new_subgraphs, size_0, size_1 + + def step_and_validate( + self, new_subgraphs: List[Subgraph], step_size: int = 1 + ) -> List[Subgraph]: + """Advance the split by `step_size` nodes, then add more nodes to the left subgraph if rules are broken. + There are two rules to check: + 1. The subgraphs should be ordered in a way that is safely to partition. + This is checked by validate_and_correct_subgraphs. Check that function for more details. + 2. The subgraphs should not break any fusion groups. + - Move `step_size` nodes from the right to the left subgraph. + - Run validation/correction to ensure a legal partitioning placement. + - Get all leaf nodes in the left subgraph and check whether any of them are in a fusion group. + - If the move splits a fusion group, migrate the entire fusion into the left subgraph. + + Returns: + List[Subgraph]: Updated pair of subgraphs after stabilization. + """ + + for _ in range(step_size): + new_subgraphs[0].nodes.append(new_subgraphs[1].nodes.pop(0)) + + while True: + new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) + nodes_in_first_subgraph = set(new_subgraphs[0].nodes) + nodes_in_second_subgraph = set(new_subgraphs[1].nodes) + leaf_node = self.get_leaf_node(nodes_in_first_subgraph) + broken_fusion = self.step_if_break_fusion( + new_subgraphs, + leaf_node, + nodes_in_first_subgraph, + nodes_in_second_subgraph, + ) + if not broken_fusion or len(new_subgraphs[1].nodes) == 0: + break + + return new_subgraphs + + def step_if_break_fusion( + self, + subgraphs: List[Subgraph], + leaf_nodes: set[torch.fx.Node], + nodes_in_first_subgraph: set[torch.fx.Node], + nodes_in_second_subgraph: set[torch.fx.Node], + ) -> bool: + """Detect a fusion split and migrate fused nodes to the left subgraph. + + Given the current split boundary (captured by `leaf_nodes` of the left + subgraph), check all recorded fusion groups. If any fused node remains + on the right while its peer is on the left, pull the node and all of its + producer chain into the left subgraph to keep fusions intact. + + Returns: + bool: True if any fusion was migrated (i.e., a split would have + broken a fusion), otherwise False. + """ + + def add_nodes(node: torch.fx.Node) -> None: + """ + This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order. + """ + if ( + node.op in CALLABLE_NODE_OPS + and node not in nodes_in_first_subgraph + and node in nodes_in_second_subgraph + ): + # Exclude all nodes already in the first subgraph + nodes_in_first_subgraph.add(node) + nodes_in_second_subgraph.remove(node) + for input_node in node._input_nodes: + add_nodes(input_node) + subgraphs[0].nodes.append(node) + subgraphs[1].nodes.remove(node) + + fusion_broken = False + for leaf in leaf_nodes: + for node in self.fusion_patterns.get(leaf, []): + if ( + node not in nodes_in_first_subgraph + and node in nodes_in_second_subgraph + ): + fusion_broken = True + add_nodes(node) + + return fusion_broken + + def get_leaf_node( + self, nodes_in_first_subgraph: set[torch.fx.Node] + ) -> set[torch.fx.Node]: + """Return nodes in the left subgraph that feed any node on the right. + + A node is considered a leaf if at least one of its users is not in the + left subgraph. + """ + leaf_node = set() + + for node in nodes_in_first_subgraph: + for user in node.users: + if user not in nodes_in_first_subgraph: + leaf_node.add(node) + break + return leaf_node + + def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]: + """Estimate parameter footprint (bytes) for each subgraph. + + Traverses each subgraph's nodes and their producer chains to find + parameters referenced via `get_attr`, summing tensor bytes. Shared + parameters are counted only once globally. + + Returns: + List[int]: Size per subgraph in bytes. + """ + state_dict = self.module.state_dict(keep_vars=True) + sizes = [] + weight_visited_nodes = set() + for subgraph in subgraphs: + nodes_in_subgraph = set(subgraph.nodes) + stack = subgraph.nodes.copy() + size = 0 + while stack: + node = stack.pop() + if node in weight_visited_nodes: + continue + weight_visited_nodes.add(node) + if node.op == "get_attr": + weight = state_dict.get(node.target, None) + if weight is None: + logger.warning(f"Weight {node.target} not found in state_dict") + continue + size += weight.numel() * weight.element_size() + continue + if node not in nodes_in_subgraph: + # Trace to other subgraphs + continue + for input_node in node._input_nodes: + if input_node not in weight_visited_nodes: + stack.append(input_node) + sizes.append(size) + return sizes + + def validate_and_correct_subgraphs( + self, subgraphs: List[Subgraph] + ) -> List[Subgraph]: + """This is very important for the correctness of the partitioning. Torch gives undefined behavior if the subgraphs are not ordered correctly. + + The principle is: nodes that have all dependencies resolved in previous subgraphs should also be moved to the previous subgraph. + For example, given a breakpoint node n resulting in two subgraphs S1 [..., n] and S2 [n+1, ...], all nodes in S2 that is not directly or indirectly depend on n should be moved to S1. + + We use a map to record the index of the subgraph that a node's users should belong to. If the node N is in subgraph S1 and is not the breakpoint node (subgraph.nodes[-1]), + then the users that only depend on N should also be moved to S1. However, N is a breakpoint node, then the users that only depend on N should also be moved to S2. + + With the map, we can determine with subgraph a later node should be moved to according to all its inputs. We take max indices of all inputs nodes to determine the subgraph index. + + Returns: + List[Subgraph]: Corrected subgraphs. + """ + # a map from a node to the index of the subgraph it's user should belong to + visited_nodes = {} + + for i, subgraph in enumerate(subgraphs): + if i == 0: + for node in subgraph.nodes: + visited_nodes[node] = i + # breakpoint node's users should belong to the next subgraph + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + elif not subgraph.is_acc: + # non-accelerated subgraphs should be put in the next subgraph + for node in subgraph.nodes: + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + else: + to_remove_nodes = [] + for j, node in enumerate(subgraph.nodes): + if j == len(subgraph.nodes) - 1: + # breakpoint node's users should belong to the next subgraph + visited_nodes[node] = i + 1 + continue + subgraph_idx = 0 + for dep in self.deps[node]: + if dep in visited_nodes: + # We take max indices of all inputs nodes to determine the subgraph index. + subgraph_idx = max(subgraph_idx, visited_nodes[dep]) + + if subgraph_idx != i: + # If the node should be moved to a different subgraph, we move it and remove it from the current subgraph. + subgraphs[subgraph_idx].nodes.append(node) + to_remove_nodes.append(node) + # Record the the subgraph that the users of this node should belong to + visited_nodes[node] = subgraph_idx + + # Remove the nodes that are moved to other subgraphs + for node in to_remove_nodes: + subgraph.nodes.remove(node) + + return subgraphs + + +def resource_partition( + gm: torch.fx.GraphModule, + partitioned_module: torch.fx.GraphModule, + cpu_memory_budget: int, +) -> torch.fx.GraphModule: + """Resource-aware partitioning entry point. + + Takes an original FX graph (`gm`) and a capability-partitioned module + (`partitioned_module`) and returns a final graph where accelerated segments + are split further, if necessary, to satisfy CPU memory limits for TRT + engine compilation. + + Args: + gm: Original FX `GraphModule`. + partitioned_module: Capability-partitioned `GraphModule` indicating + accelerated vs non-accelerated regions. + cpu_memory_budget: CPU memory budget in bytes for engine compilation. + Use -1 to base the budget on currently available system memory. + + Returns: + torch.fx.GraphModule: Final graph with resource-constrained subgraphs. + """ + + # Construct + partitioner = ResourcePartitioner( + gm, + partitioned_module, + cpu_memory_budget=cpu_memory_budget, + ) + + partitioned_graph = partitioner.partition_graph() + + return partitioned_graph diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py b/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py new file mode 100644 index 0000000000..a5b3e74ee5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py @@ -0,0 +1,185 @@ +from functools import lru_cache +from typing import Dict, List, Set + +import torch +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.ops import aten + + +class ConvBNReLU(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + momentum: float, + eps: float, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten._native_batch_norm_legit_no_training.default( + x, bn_weight, bn_bias, running_mean, running_var, momentum, eps + )[0] + x = aten.relu.default(x) + return x + + +class ConvReLU(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten.relu.default(x) + return x + + +class ConvGelu(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten.gelu.default(x) + return x + + +class ConvSilu(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.convolution.default( + x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1 + ) + x = aten.silu.default(x) + return x + + +class MulAdd(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, weight) + x = aten.add.Tensor(x, bias) + return x + + +class MulMul(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, y) + x = aten.mul.Tensor(x, z) + return x + + +All_FUSION_PATTERNS = [ + ConvBNReLU, + ConvReLU, + ConvGelu, + ConvSilu, + MulAdd, + MulMul, +] + + +@lru_cache(maxsize=None) +def get_node_in_fusion_pattern( + graph: torch.fx.Graph, +) -> Dict[torch.fx.Node, Set[torch.fx.Node]]: + """ + This function gets the nodes map of the fusion pattern from the graph. + Key: node that appears in the fusion pattern + Value: the list of nodes that should be fused together + """ + fusion_nodes = {} + for pattern in All_FUSION_PATTERNS: + pattern_graph = torch.fx.symbolic_trace(pattern()) + subgraph_matcher = SubgraphMatcher(pattern_graph.graph) + match_result = subgraph_matcher.match(graph) + for match in match_result: + fusion_group = { + node + for node in match.nodes_map.values() + if node + and type(node) == torch.fx.Node + and node.op == "call_function" + and node not in match.placeholder_nodes + } + for node in fusion_group: + fusion_nodes[node] = fusion_group + + return fusion_nodes diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index a82384fda9..c86ee6f3a4 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -23,6 +23,7 @@ torch.ops.aten.scaled_dot_product_attention.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, ) @@ -43,6 +44,7 @@ def _remove_decompositions(): REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, } from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( @@ -79,7 +81,10 @@ def _process_sdpa_node( ValueError: If the SDPA node has an unexpected number of arguments """ - if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if node.target in [ + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + ]: if len(node.args) == 7: ( query, From 99581e2b7891cde8ec1e9dedf5057509b29f5560 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Nov 2025 00:46:50 +0000 Subject: [PATCH 2/4] Added decorator and tests --- py/torch_tensorrt/dynamo/_compiler.py | 4 + ...usion_patterns.py => _atomic_subgraphs.py} | 34 ++++--- .../partitioning/_resource_partitioner.py | 4 +- .../test_resource_partitioning.py | 93 +++++++++++++++++++ 4 files changed, 121 insertions(+), 14 deletions(-) rename py/torch_tensorrt/dynamo/partitioning/{fusion_patterns.py => _atomic_subgraphs.py} (86%) create mode 100644 tests/py/dynamo/partitioning/test_resource_partitioning.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 42cf580c6a..5a84218015 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -618,6 +618,10 @@ def compile( "'arg_inputs' and 'inputs' should not be used at the same time." ) + assert ( + cpu_memory_budget >= 2 * 1024 * 1024 * 1024 + ), "CPU memory budget must be greater than 10GB" + arg_inputs = inputs or arg_inputs if kwarg_inputs is None: diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py similarity index 86% rename from py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py rename to py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index a5b3e74ee5..dbda162dcb 100644 --- a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -1,11 +1,25 @@ from functools import lru_cache -from typing import Dict, List, Set +from typing import Callable, Dict, List, Set import torch from torch.fx.passes.utils.matcher_utils import SubgraphMatcher from torch.ops import aten +ATOMIC_SUBGRAPHS = [] + +def register_atomic_subgraph( + is_aten: bool = False, +) -> Callable[[torch.nn.Module], torch.nn.Module]: + + def decorator(subgraph: torch.nn.Module) -> torch.nn.Module: + ATOMIC_SUBGRAPHS.append((subgraph, is_aten)) + return subgraph + + return decorator + + +@register_atomic_subgraph(is_aten=True) class ConvBNReLU(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -46,6 +60,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class ConvReLU(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -77,6 +92,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class ConvGelu(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -108,6 +124,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class ConvSilu(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -122,6 +139,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class MulAdd(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -134,6 +152,7 @@ def forward( return x +@register_atomic_subgraph(is_aten=True) class MulMul(torch.nn.Module): # type: ignore[misc] def __init__(self) -> None: super().__init__() @@ -146,16 +165,6 @@ def forward( return x -All_FUSION_PATTERNS = [ - ConvBNReLU, - ConvReLU, - ConvGelu, - ConvSilu, - MulAdd, - MulMul, -] - - @lru_cache(maxsize=None) def get_node_in_fusion_pattern( graph: torch.fx.Graph, @@ -166,8 +175,9 @@ def get_node_in_fusion_pattern( Value: the list of nodes that should be fused together """ fusion_nodes = {} - for pattern in All_FUSION_PATTERNS: + for pattern, is_aten in ATOMIC_SUBGRAPHS: pattern_graph = torch.fx.symbolic_trace(pattern()) + # TODO: Add decomposition and lowering if is_aten is False subgraph_matcher = SubgraphMatcher(pattern_graph.graph) match_result = subgraph_matcher.match(graph) for match in match_result: diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py index 967711ba02..0ef8d76e3a 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py @@ -52,7 +52,7 @@ import torch from torch.fx.passes.splitter_base import Subgraph, _SplitterBase from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch_tensorrt.dynamo.partitioning.fusion_patterns import ( +from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import ( get_node_in_fusion_pattern, ) @@ -211,7 +211,7 @@ def calculate_size_budget( int: Budget in bytes for a single accelerated subgraph. """ - used_rss: int = psutil.virtual_memory().used + used_rss: int = psutil.Process().memory_info().rss available_rss = self.cpu_memory_budget - used_rss return available_rss // engine_compilation_memory_usage_multiplier diff --git a/tests/py/dynamo/partitioning/test_resource_partitioning.py b/tests/py/dynamo/partitioning/test_resource_partitioning.py new file mode 100644 index 0000000000..f059b0b166 --- /dev/null +++ b/tests/py/dynamo/partitioning/test_resource_partitioning.py @@ -0,0 +1,93 @@ +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering +from torch_tensorrt.dynamo.partitioning._resource_partitioner import resource_partition + + +class TestResourcePartitioning(TestCase): + def test_resource_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn2 = nn.BatchNorm2d(1024) + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + } + settings = CompilationSettings(**compilation_options) + with torchtrt.dynamo.Debugger( + log_level="debug", + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, + ): + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + gm, partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024 # 2GB, + ) + + self.assertEqual( + len(list[Any](partitioned_module.named_children())), + 2, + "The graph should have 2 subgraphs", + ) + + +if __name__ == "__main__": + run_tests() From 290cc3985f33d5308a9a90c1f78f9927f1274908 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Nov 2025 20:17:10 +0000 Subject: [PATCH 3/4] Added example and fixed lru problem --- examples/dynamo/low_cpu_memory_compilation.py | 84 +++++++++++++++++++ py/torch_tensorrt/dynamo/_compiler.py | 7 +- .../dynamo/partitioning/_atomic_subgraphs.py | 25 ++++-- 3 files changed, 107 insertions(+), 9 deletions(-) create mode 100644 examples/dynamo/low_cpu_memory_compilation.py diff --git a/examples/dynamo/low_cpu_memory_compilation.py b/examples/dynamo/low_cpu_memory_compilation.py new file mode 100644 index 0000000000..2ff5356490 --- /dev/null +++ b/examples/dynamo/low_cpu_memory_compilation.py @@ -0,0 +1,84 @@ +""" + +.. _low_cpu_memory_compilation: + +Low CPU Memory Compilation Example +================================== + +This example demonstrates compiling a model with a bounded CPU (host) memory +budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on +memory-constrained machines or when compiling very large models. + +Key notes: +- The toy model below has roughly 430 MB of parameters. We set the CPU + memory budget to 2 GiB. At compile time, only about 900 MB of host RAM + may remain available. We expect at most 403 * 4 = 1612 MB of memory to be used by the model. + So the model is partitioned into two subgraphs to fit the memory budget. + +- Performance impact varies by model. When the number of TensorRT engines + created is small, the impact is typically minimal. + +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo.conversion import CompilationSettings + + +class net(nn.Module): + def __init__(self): + super().__init__() + # Intentionally large layers to stress host memory during compilation. + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn2 = nn.BatchNorm2d(1024) + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + +model = net().eval() +model.to("cuda") +inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + +enabled_precisions = {torch.float} +use_python_runtime = False + +compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "cpu_memory_budget": 2 * 1024 * 1024 * 1024, # 2 GiB in bytes +} + +settings = CompilationSettings(**compilation_options) +with torchtrt.dynamo.Debugger( + log_level="debug", + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, +): + + exp_program = torch.export.export(model, tuple(inputs)) + trt_gm = torchtrt.dynamo.compile( + exp_program, + inputs=inputs, + **compilation_options, + ) + + # Expect two back-to-back TensorRT engines due to partitioning under the memory budget. + print(trt_gm) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 5a84218015..a0ec195a42 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -40,6 +40,9 @@ post_lowering, pre_export_lowering, ) +from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + resource_partition, +) from torch_tensorrt.dynamo.utils import ( deallocate_module, get_flat_args_with_check, @@ -858,10 +861,6 @@ def preserve_module_specs( require_full_compilation=settings.require_full_compilation, ) - from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( - resource_partition, - ) - partitioned_module = resource_partition( gm, partitioned_module, diff --git a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py index dbda162dcb..e9c0420add 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -165,7 +165,6 @@ def forward( return x -@lru_cache(maxsize=None) def get_node_in_fusion_pattern( graph: torch.fx.Graph, ) -> Dict[torch.fx.Node, Set[torch.fx.Node]]: @@ -175,10 +174,8 @@ def get_node_in_fusion_pattern( Value: the list of nodes that should be fused together """ fusion_nodes = {} - for pattern, is_aten in ATOMIC_SUBGRAPHS: - pattern_graph = torch.fx.symbolic_trace(pattern()) - # TODO: Add decomposition and lowering if is_aten is False - subgraph_matcher = SubgraphMatcher(pattern_graph.graph) + for compiled_pattern_graph in get_compiled_atomic_subgraphs(): + subgraph_matcher = SubgraphMatcher(compiled_pattern_graph.graph) match_result = subgraph_matcher.match(graph) for match in match_result: fusion_group = { @@ -193,3 +190,21 @@ def get_node_in_fusion_pattern( fusion_nodes[node] = fusion_group return fusion_nodes + + +@lru_cache(maxsize=None) +def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]: + """ + This function gets the compiled atomic subgraphs from the graph. + LRU cache the result to avoid recompiling the same pattern multiple times. + """ + compiled_atomic_subgraphs = [] + for pattern, is_aten in ATOMIC_SUBGRAPHS: + pattern_graph = torch.fx.symbolic_trace(pattern()) + if not is_aten: + # TODO: Add decomposition and lowering if is_aten is False + raise NotImplementedError( + "Atomic subgraphs are not supported for non-aten subgraphs yet." + ) + compiled_atomic_subgraphs.append(pattern_graph) + return compiled_atomic_subgraphs From aba69ef1bcdb74a8cb1ffcca3254283a8ba4d586 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Nov 2025 23:01:30 +0000 Subject: [PATCH 4/4] deleted changes that don't belong to this PR --- py/torch_tensorrt/dynamo/_compiler.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a0ec195a42..3ee0da4322 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -885,16 +885,6 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. - # This is done to release CPU memory. - for attr in dir(gm): - if attr.startswith("_frozen_param"): - delattr(gm, attr) - - from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS - - DYNAMO_CONVERTERS.disallowed_targets = set() - for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule