[NVBug 6102977] Add _disable_use_cache context manager to fix PTQ AttributeError on custom configs#1324
Conversation
Note: Reviews paused. This branch is under active development, so to avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info: Configuration used: .coderabbit.yaml · Review profile: CHILL · Plan: Enterprise
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough
Adds a private `_disable_use_cache` context manager to the dataset utilities that temporarily sets `config.use_cache = False` around calibration forward passes and restores the original state on exit; it is applied in `get_max_batch_size` and `_forward_loop`.
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant Disable as "disable_use_cache\n(context)"
participant NoGrad as "torch.no_grad\n(context)"
participant DataLoader
participant Model
Caller->>Disable: enter — save state / set config.use_cache=False
Disable->>NoGrad: enter
NoGrad->>DataLoader: iterate batches
DataLoader->>Model: forward(batch)
Model-->>DataLoader: outputs
DataLoader-->>NoGrad: next batch / done
NoGrad->>Disable: exit
Disable->>Caller: restore original use_cache (or delete)
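To make the diagram concrete, here is a minimal runnable sketch of that flow. The `forward_loop` shape, the model, and the dataloader are illustrative; the stand-in manager below shows only the naive save/restore, whose edge cases are exactly what the review below is about.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def _disable_use_cache(model: nn.Module):
    # Stand-in for the manager added in this PR: force use_cache off
    # inside the block. Naive restore only; see the review below for
    # the missing-attribute edge case.
    config = getattr(model, "config", None)
    prev = getattr(config, "use_cache", None)
    if config is not None:
        config.use_cache = False
    try:
        yield
    finally:
        if config is not None:
            config.use_cache = prev


def forward_loop(model: nn.Module, dataloader) -> None:
    # The diagram's flow: disable use_cache, enter no_grad, then run one
    # forward pass per batch, discarding the outputs.
    with _disable_use_cache(model), torch.no_grad():
        for batch in dataloader:
            model(batch)
```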
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 5 passed | ❌ 1 failed (1 warning)
Codecov Report: ❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## main #1324 +/- ##
==========================================
- Coverage 74.46% 73.42% -1.05%
==========================================
Files 464 485 +21
Lines 50089 53684 +3595
==========================================
+ Hits 37300 39418 +2118
- Misses 12789 14266 +1477
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
cjluo-nv
left a comment
The context manager has a state-leak bug: when the config didn't originally have `use_cache`, the attribute is set to `False` on entry but never cleaned up on exit. After the context manager completes, the config permanently carries a `use_cache = False` attribute that wasn't there before. The `finally` block needs to `delattr` it in the not-`had_attr` case.
Also, no unit tests are included. A simple test for _disable_use_cache covering the three cases (no config, config with existing use_cache, config without use_cache) would be straightforward and valuable.
```python
try:
    yield
finally:
    if had_attr:
```
Bug: When had_attr is False (the config didn't originally have use_cache), the context manager sets config.use_cache = False on entry but never removes it on exit. This leaks a new attribute onto the config object.
The finally block should clean up:
```python
finally:
    if had_attr:
        config.use_cache = prev
    else:
        delattr(config, "use_cache")
```
Without this, after calling `get_max_batch_size` or `_forward_loop`, a config that never had `use_cache` will now permanently have `use_cache = False`, which could change model behavior for subsequent inference.
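To make the failure mode concrete, a hypothetical repro of the leak, using a bare namespace in place of a real model config:

```python
from types import SimpleNamespace

config = SimpleNamespace()  # custom config without use_cache, as in the NVBug
had_attr = hasattr(config, "use_cache")  # False

config.use_cache = False  # what the manager does on entry
# ... calibration forward passes would run here ...
if had_attr:  # False, so the buggy finally block restores nothing
    config.use_cache = None

# The attribute set on entry is still there after "exit":
assert config.use_cache is False  # leaked state
```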
```diff
@@ -437,6 +438,33 @@ def get_supported_datasets() -> list[str]:
     return list(SUPPORTED_DATASET_CONFIG.keys())
```
Missing tests: Please add unit tests for `_disable_use_cache` covering:
- Model with no `config` attribute (no-op)
- Model whose config already has `use_cache` (restored on exit)
- Model whose config lacks `use_cache` (attribute should not persist after exit)
These are simple to write with a mock nn.Module and would directly validate the bug fix.
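For instance, a minimal pytest sketch of those three cases. It assumes the manager takes the model and reads `model.config`, and that the import path matches the file under review; the toy modules stand in for real HF models:

```python
from types import SimpleNamespace

import torch.nn as nn

from modelopt.torch.utils.dataset_utils import _disable_use_cache


def test_no_config_is_noop():
    model = nn.Linear(2, 2)  # plain module without a `config` attribute
    with _disable_use_cache(model):
        pass  # should not raise


def test_existing_use_cache_restored():
    model = nn.Linear(2, 2)
    model.config = SimpleNamespace(use_cache=True)
    with _disable_use_cache(model):
        assert model.config.use_cache is False
    assert model.config.use_cache is True  # original value restored


def test_missing_use_cache_does_not_persist():
    model = nn.Linear(2, 2)
    model.config = SimpleNamespace()  # config that never had use_cache
    with _disable_use_cache(model):
        assert model.config.use_cache is False
    assert not hasattr(model.config, "use_cache")  # no leaked attribute
```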
Actionable comments posted: 1
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 25913188-255c-4bd2-88c5-91fbddba1bb2
📒 Files selected for processing (2)
- modelopt/torch/utils/dataset_utils.py
- tests/unit/torch/utils/test_dataset_utils.py
```python
            target_input = sample_input_single_batch.expand(
                [
                    target_data_batch if index == 0 else dim
                    for index, dim in enumerate(sample_input_single_batch.shape)
                ]
            )
            target_data_batch = 1
        else:
            target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
            target_input = sample_input_single_batch.expand(
                [
                    target_data_batch if index == 0 else dim
                    for index, dim in enumerate(sample_input_single_batch.shape)
                ]
            )

        # For some models on multi GPU, we observe the memory per batch is not a constant.
        # So we just test the target batch size and make sure we do not go OOM.
        while target_data_batch > 1:
            with torch.set_grad_enabled(enable_grad):
                try:
                    infer_method(target_input)
                    break
                except torch.cuda.OutOfMemoryError:
                    target_data_batch = target_data_batch // 2
```
Retry loop halves target_data_batch but reuses stale target_input.
After an OOM, the loop updates target_data_batch only; it keeps retrying with the original expanded tensor shape, so reduced batch sizes are never actually tested.
💡 Proposed fix
```diff
-        target_input = sample_input_single_batch.expand(
-            [
-                target_data_batch if index == 0 else dim
-                for index, dim in enumerate(sample_input_single_batch.shape)
-            ]
-        )
-
         # For some models on multi GPU, we observe the memory per batch is not a constant.
         # So we just test the target batch size and make sure we do not go OOM.
         while target_data_batch > 1:
+            target_input = sample_input_single_batch.expand(
+                [
+                    target_data_batch if index == 0 else dim
+                    for index, dim in enumerate(sample_input_single_batch.shape)
+                ]
+            )
             with torch.set_grad_enabled(enable_grad):
                 try:
                     infer_method(target_input)
                     break
                 except torch.cuda.OutOfMemoryError:
                     target_data_batch = target_data_batch // 2
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/utils/dataset_utils.py` around lines 524 - 540, The retry loop
halves target_data_batch on OOM but never rebuilds target_input, so infer_method
keeps receiving the original (too-large) tensor; modify the loop to recreate
target_input from sample_input_single_batch after each reduction of
target_data_batch (i.e., move or add the target_input =
sample_input_single_batch.expand([...]) construction inside the while loop using
the updated target_data_batch and same shape logic with enumerate), then call
infer_method with the rebuilt tensor so each smaller batch size is actually
tested.
meenchen (author) replied:
Good catch, but this is pre-existing — target_input is built outside the OOM retry loop on the parent commit too (verified via git show 204daaf^:modelopt/torch/utils/dataset_utils.py; the loop body is unchanged here, the lines only show in the diff because we re-indented to wrap them in with _disable_use_cache(model):). Will file a separate fix to move target_input = ...expand(...) inside the while loop. Out of scope for this PR.
coderabbitai replied:
@meenchen, thanks for verifying — that's a fair point. Since the bug predates this PR and you're planning a follow-up fix, would you like me to open a GitHub issue to track moving target_input = sample_input_single_batch.expand(...) inside the while loop? That way it won't get lost.
🧠 Learnings used
Learnt from: kevalmorabia97
Repo: NVIDIA/Model-Optimizer PR: 974
File: modelopt/torch/puzzletron/pruning/pruning_utils.py:195-201
Timestamp: 2026-03-06T14:26:21.439Z
Learning: In NVIDIA/Model-Optimizer, for PyTorch >= 2.6, torch.load() calls without an explicit weights_only argument are safe. Do not flag bare torch.load(...) as a security issue in files under the modelopt package (e.g., modelopt/torch/puzzletron/pruning/pruning_utils.py) as long as the PyTorch version constraint is maintained. If supporting PyTorch < 2.6, require an explicit weights_only argument to torch.load() to avoid potential issues.
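As an aside, the guarded pattern that learning encodes looks like this when older PyTorch must still be supported (illustrative, not project code):

```python
import torch


def load_state(path: str):
    # PyTorch >= 2.6 defaults to weights_only=True, so a bare torch.load(path)
    # is acceptable there. For older versions, pass the flag explicitly so
    # arbitrary pickled objects are never deserialized.
    return torch.load(path, weights_only=True)
```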
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>

Force-pushed eaf6bc8 to 60b97ae (Compare)
cjluo-nv
left a comment
Bot review — DM the bot to share feedback.
All critical issues from the previous review have been addressed:

- State-leak bug (critical) — Fixed: The `finally` block now properly calls `delattr(config, "use_cache")` (wrapped in `suppress(AttributeError)`) when the attribute didn't exist before entry. This prevents leaking `use_cache = False` onto configs that never had it.
- Missing unit tests (critical) — Fixed: Four well-structured tests added covering all cases: no config attribute (no-op), existing `use_cache` restored on exit (parametrized True/False), missing `use_cache` cleaned up on exit (no leak), and restoration on exception.
- OOM retry loop stale `target_input` (pre-existing): Author correctly identified this as a pre-existing bug unrelated to this PR — the re-indentation just made it visible in the diff. Appropriate to fix separately.
The context manager is clean, well-documented, and correctly applied to both get_max_batch_size and _forward_loop. The refactoring from inline try/finally to a shared @contextmanager is a good simplification.
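For readers skimming the thread, here is a sketch reconstructed from the review comments of what the fixed manager looks like; it is not a verbatim copy of the merged code:

```python
from contextlib import contextmanager, suppress


@contextmanager
def _disable_use_cache(model):
    config = getattr(model, "config", None)
    if config is None:
        yield  # no config on the model: nothing to toggle
        return
    had_attr = hasattr(config, "use_cache")
    prev = getattr(config, "use_cache", None)
    config.use_cache = False
    try:
        yield
    finally:
        if had_attr:
            config.use_cache = prev  # restore the original value
        else:
            with suppress(AttributeError):
                delattr(config, "use_cache")  # drop the attribute we added
```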
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
What does this PR do?
Type of change: Bug fix
Usage
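Since the template snippet was never filled in, here is a hedged sketch of the behavior this PR guarantees; it assumes `_disable_use_cache` is importable from `modelopt.torch.utils.dataset_utils` and uses a toy module in place of a real LLM:

```python
from types import SimpleNamespace

import torch.nn as nn

from modelopt.torch.utils.dataset_utils import _disable_use_cache

model = nn.Linear(4, 4)
model.config = SimpleNamespace()  # custom config without use_cache (NVBug 6102977)

with _disable_use_cache(model):
    # calibration / max-batch-size probing runs here with caching forced off
    assert model.config.use_cache is False

assert not hasattr(model.config, "use_cache")  # nothing leaked afterwards
```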
Testing
Step-3.5-Flash PTQ now passes get_max_batch_size
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
- Refactor
- Tests