Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/autotune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from .region_pattern import RegionPattern
from .region_search import CombinedRegionSearch
from .torch_region_builder import TorchRegionBuilder

__all__ = [
"AutotunerError",
Expand All @@ -61,5 +62,6 @@
"RegionType",
"ResolvedInsertionPoint",
"TensorRTPyBenchmark",
"TorchRegionBuilder",
"TrtExecBenchmark",
]
22 changes: 15 additions & 7 deletions modelopt/onnx/quantization/autotune/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from modelopt.onnx.quantization.autotune.autotuner_base import QDQAutotunerBase
from modelopt.onnx.quantization.autotune.common import Config, PatternCache, Region, RegionType
from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch
from modelopt.onnx.quantization.autotune.torch_region_builder import (
TorchRegionBuilder,
check_torch_naming_convention,
)


class QDQAutotuner(QDQAutotunerBase):
Expand Down Expand Up @@ -94,13 +98,17 @@ def _search_regions(self) -> None:
- Phase 2: Top-down refinement creating hierarchical structure
"""
logger.info("Discovering optimization regions")
search = CombinedRegionSearch(
self.graph,
maximum_sequence_region_size=self.config.maximum_sequence_region_size,
minimum_topdown_search_size=self.config.minimum_topdown_search_size,
)
self.regions = search.search_regions()
self._reassign_region_ids(self.regions)
if check_torch_naming_convention(self.graph):
torch_search = TorchRegionBuilder(self.graph)
self.regions = torch_search.build_regions(linearize=True, only_quantizable=True)
else:
Comment on lines +101 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Torch branch double-flattens regions and can produce duplicates.

Line 103 requests linearize=True, but Lines 114–120 still recursively flatten every entry. If the returned list already includes descendants, this duplicates regions and skews later profiling/ordering.

Suggested fix
         if check_torch_naming_convention(self.graph):
             torch_search = TorchRegionBuilder(self.graph)
-            self.regions = torch_search.build_regions(linearize=True, only_quantizable=True)
+            self.regions = torch_search.build_regions(linearize=False, only_quantizable=True)
+            self._reassign_region_ids(self.regions)
         else:
             default_search = CombinedRegionSearch(
                 self.graph,
                 maximum_sequence_region_size=self.config.maximum_sequence_region_size,
                 minimum_topdown_search_size=self.config.minimum_topdown_search_size,

Also applies to: 114-120

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/quantization/autotune/autotuner.py` around lines 101 - 104, the
Torch branch double-flattens regions. When
check_torch_naming_convention(self.graph) is true, the code calls
TorchRegionBuilder(self.graph).build_regions(linearize=True, ...), which already
returns linearized descendants. The subsequent recursive flatten loop (the code
that iterates over self.regions and extends a flat list) then flattens again and
creates duplicates. Fix this in one of the following ways: (a) detect the
linearize=True case and skip the extra recursive flatten (i.e., assign
self.regions directly to the build_regions result when linearize is requested);
(b) guard the recursive flatten with a check for already-linearized entries; or
(c) deduplicate by region identity. Reference check_torch_naming_convention,
TorchRegionBuilder, build_regions and self.regions when applying the change.

default_search = CombinedRegionSearch(
self.graph,
maximum_sequence_region_size=self.config.maximum_sequence_region_size,
minimum_topdown_search_size=self.config.minimum_topdown_search_size,
)
self.regions = default_search.search_regions()
self._reassign_region_ids(self.regions)
logger.debug(f"Found {len(self.regions)} top-level regions")

# Flatten the hierarchy into a list of all regions
Expand Down
Loading