
[NVBug 6102977] Add _disable_use_cache context manager to fix PTQ AttributeError on custom configs#1324

Open
meenchen wants to merge 4 commits into main from weimingc/nvbug_6102977

Conversation

@meenchen
Contributor

@meenchen meenchen commented Apr 22, 2026

What does this PR do?

Type of change: Bug fix

  • Summary: Running hf_ptq.py on stepfun-ai/Step-3.5-Flash (and any model whose custom HF config doesn't assign use_cache) crashed in get_max_batch_size() with AttributeError: 'Step3p5Config' object has no attribute 'use_cache' before calibration could start.
  • Extract the existing "disable KV cache during calibration" logic into a _disable_use_cache(model) context manager and apply it to both get_max_batch_size and _forward_loop. The context manager sets config.use_cache = False unconditionally (not only when the attribute already exists) and restores the prior value on exit if one was set.
  • Behavior unchanged for normal configs; the NemotronH hybrid-cache correctness guarantee from Add layerwise calibration for large models #1251 is preserved.

Usage

# Add a code snippet demonstrating how to use this

Testing

Step-3.5-Flash PTQ now passes get_max_batch_size

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Refactor

    • Improved memory handling during model evaluation and calibration by consistently disabling KV cache for both single-batch probes and full dataloader runs, simplifying and stabilizing inference flow and ensuring cache state is managed reliably.
  • Tests

    • Added unit tests verifying cache-state handling across models with and without cache settings, including correct restoration behavior even when errors occur.

@meenchen meenchen requested a review from a team as a code owner April 22, 2026 19:16
@meenchen meenchen requested a review from AAnoosheh April 22, 2026 19:16
@meenchen meenchen self-assigned this Apr 22, 2026
@meenchen meenchen added labels bug (Something isn't working) and cherry-pick-0.44.0 (After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc) Apr 22, 2026
@coderabbitai
Contributor

coderabbitai Bot commented Apr 22, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 50f1c80e-6711-485a-82e9-e7d8f6024e55

📥 Commits

Reviewing files that changed from the base of the PR and between 60b97ae and bde46ae.

📒 Files selected for processing (1)
  • tests/unit/torch/utils/test_dataset_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/torch/utils/test_dataset_utils.py

📝 Walkthrough

Walkthrough

Adds a private _disable_use_cache(model) context manager and updates dataset utilities to run dummy memory probes and dataloader forwards with caching disabled by using _disable_use_cache(model) together with torch.no_grad(). Tests added to exercise context manager behaviors and _forward_loop interactions.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Dataset utils (cache behavior): `modelopt/torch/utils/dataset_utils.py` | Adds `_disable_use_cache(model)` context manager that sets `model.config.use_cache = False` inside the block and restores or removes the attribute on exit. Wraps the initial dummy forward probe and the candidate batch OOM-testing loop in `get_max_batch_size()` with this context. Refactors `_forward_loop()` to use the shared context plus `torch.no_grad()` instead of bespoke save/restore logic. |
| Tests (context manager & forward loop): `tests/unit/torch/utils/test_dataset_utils.py` | Adds unit tests that import and validate `_disable_use_cache` behaviors for models without config, with existing `config.use_cache`, and with config but no `use_cache`; verifies restoration/removal on normal exit and on exception. Tests `_forward_loop` to ensure `config.use_cache` is False during iteration and restored afterward. Introduces a minimal `_Config` helper for tests. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
  participant Caller
  participant Disable as "disable_use_cache\n(context)"
  participant NoGrad as "torch.no_grad\n(context)"
  participant DataLoader
  participant Model

  Caller->>Disable: enter — save state / set config.use_cache=False
  Disable->>NoGrad: enter
  NoGrad->>DataLoader: iterate batches
  DataLoader->>Model: forward(batch)
  Model-->>DataLoader: outputs
  DataLoader-->>NoGrad: next batch / done
  NoGrad->>Disable: exit
  Disable->>Caller: restore original use_cache (or delete)
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 68.75%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped — CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title directly and clearly summarizes the main change: adding a _disable_use_cache context manager to fix an AttributeError on custom configs during PTQ, which matches the primary intent and changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | PR introduces internal utility code for managing model caching behavior during quantization without introducing security anti-patterns such as torch.load(), numpy.load(), eval(), exec(), or trust_remote_code issues. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions Bot commented Apr 22, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1324/

Built to branch gh-pages at 2026-04-27 20:12 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Apr 22, 2026

Codecov Report

❌ Patch coverage is 90.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.42%. Comparing base (0678136) to head (bde46ae).
⚠️ Report is 21 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| modelopt/torch/utils/dataset_utils.py | 90.00% | 4 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1324      +/-   ##
==========================================
- Coverage   74.46%   73.42%   -1.05%     
==========================================
  Files         464      485      +21     
  Lines       50089    53684    +3595     
==========================================
+ Hits        37300    39418    +2118     
- Misses      12789    14266    +1477     
| Flag | Coverage Δ |
|---|---|
| examples | 41.50% <80.00%> (+5.45%) ⬆️ |
| gpu | 58.57% <80.00%> (-0.52%) ⬇️ |
| regression | 14.85% <10.00%> (+0.06%) ⬆️ |
| unit | 52.77% <57.50%> (+0.33%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@meenchen meenchen changed the title from "[NVBug 6102977] Fix step3.5 PTQ" to "[NVBug 6102977] Add _disable_use_cache context manager to fix PTQ AttributeError on custom configs" Apr 22, 2026
Collaborator

@cjluo-nv cjluo-nv left a comment


The context manager has a state-leak bug: when config didn't originally have use_cache, the attribute is set to False but never cleaned up on exit. After the context manager completes, the config permanently has use_cache = False that wasn't there before. The finally block needs to delattr in the not had_attr case.

Also, no unit tests are included. A simple test for _disable_use_cache covering the three cases (no config, config with existing use_cache, config without use_cache) would be straightforward and valuable.

try:
    yield
finally:
    if had_attr:
Collaborator


Bug: When had_attr is False (the config didn't originally have use_cache), the context manager sets config.use_cache = False on entry but never removes it on exit. This leaks a new attribute onto the config object.

The finally block should clean up:

finally:
    if had_attr:
        config.use_cache = prev
    else:
        delattr(config, "use_cache")

Without this, after calling get_max_batch_size or _forward_loop, a config that never had use_cache will now permanently have use_cache = False, which could change model behavior for subsequent inference.

@@ -437,6 +438,33 @@ def get_supported_datasets() -> list[str]:
return list(SUPPORTED_DATASET_CONFIG.keys())
Collaborator


Missing tests: Please add unit tests for _disable_use_cache covering:

  1. Model with no config attribute (no-op)
  2. Model whose config already has use_cache (restored on exit)
  3. Model whose config lacks use_cache (attribute should not persist after exit)

These are simple to write with a mock nn.Module and would directly validate the bug fix.
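The three cases above could be sketched as plain test functions along the following lines. The context manager body here is an inline stand-in reconstructed from this review thread, not the actual implementation in modelopt/torch/utils/dataset_utils.py, and the test names are illustrative:

```python
from contextlib import contextmanager, suppress


@contextmanager
def _disable_use_cache(model):  # inline stand-in for the helper under test
    config = getattr(model, "config", None)
    if config is None:
        yield
        return
    had_attr = hasattr(config, "use_cache")
    prev = getattr(config, "use_cache", None)
    config.use_cache = False
    try:
        yield
    finally:
        if had_attr:
            config.use_cache = prev
        else:
            with suppress(AttributeError):
                delattr(config, "use_cache")


class _Config:
    pass


class _Model:
    def __init__(self):
        self.config = _Config()


def test_no_config_is_noop():
    with _disable_use_cache(object()):  # no .config attribute at all
        pass  # must not raise


def test_existing_use_cache_restored():
    for original in (True, False):
        model = _Model()
        model.config.use_cache = original
        with _disable_use_cache(model):
            assert model.config.use_cache is False
        assert model.config.use_cache is original


def test_missing_use_cache_not_leaked():
    model = _Model()
    with _disable_use_cache(model):
        assert model.config.use_cache is False
    assert not hasattr(model.config, "use_cache")


def test_restored_on_exception():
    model = _Model()
    model.config.use_cache = True
    try:
        with _disable_use_cache(model):
            raise RuntimeError("boom")
    except RuntimeError:
        pass
    assert model.config.use_cache is True
```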

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/utils/dataset_utils.py`:
- Around line 524-540: The retry loop halves target_data_batch on OOM but never
rebuilds target_input, so infer_method keeps receiving the original (too-large)
tensor; modify the loop to recreate target_input from sample_input_single_batch
after each reduction of target_data_batch (i.e., move or add the target_input =
sample_input_single_batch.expand([...]) construction inside the while loop using
the updated target_data_batch and same shape logic with enumerate), then call
infer_method with the rebuilt tensor so each smaller batch size is actually
tested.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 25913188-255c-4bd2-88c5-91fbddba1bb2

📥 Commits

Reviewing files that changed from the base of the PR and between 204daaf and 66b7ec4.

📒 Files selected for processing (2)
  • modelopt/torch/utils/dataset_utils.py
  • tests/unit/torch/utils/test_dataset_utils.py

Comment on lines +524 to 540

```python
    target_input = sample_input_single_batch.expand(
        [
            target_data_batch if index == 0 else dim
            for index, dim in enumerate(sample_input_single_batch.shape)
        ]
    )
    target_data_batch = 1
else:
    target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
    target_input = sample_input_single_batch.expand(
        [
            target_data_batch if index == 0 else dim
            for index, dim in enumerate(sample_input_single_batch.shape)
        ]
    )

# For some models on multi GPU, we observe the memory per batch is not a constant.
# So we just test the target batch size and make sure we do not go OOM.
while target_data_batch > 1:
    with torch.set_grad_enabled(enable_grad):
        try:
            infer_method(target_input)
            break
        except torch.cuda.OutOfMemoryError:
            target_data_batch = target_data_batch // 2
```

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot Apr 27, 2026


⚠️ Potential issue | 🟠 Major

Retry loop halves target_data_batch but reuses stale target_input.

After an OOM, the loop updates target_data_batch only; it keeps retrying with the original expanded tensor shape, so reduced batch sizes are never actually tested.

💡 Proposed fix
```diff
-        target_input = sample_input_single_batch.expand(
-            [
-                target_data_batch if index == 0 else dim
-                for index, dim in enumerate(sample_input_single_batch.shape)
-            ]
-        )
-
         # For some models on multi GPU, we observe the memory per batch is not a constant.
         # So we just test the target batch size and make sure we do not go OOM.
         while target_data_batch > 1:
+            target_input = sample_input_single_batch.expand(
+                [
+                    target_data_batch if index == 0 else dim
+                    for index, dim in enumerate(sample_input_single_batch.shape)
+                ]
+            )
             with torch.set_grad_enabled(enable_grad):
                 try:
                     infer_method(target_input)
                     break
                 except torch.cuda.OutOfMemoryError:
                     target_data_batch = target_data_batch // 2
```

Contributor Author


Good catch, but this is pre-existing — target_input is built outside the OOM retry loop on the parent commit too (verified via git show 204daaf^:modelopt/torch/utils/dataset_utils.py; the loop body is unchanged here, the lines only show in the diff because we re-indented to wrap them in with _disable_use_cache(model):). Will file a separate fix to move target_input = ...expand(...) inside the while loop. Out of scope for this PR.

Contributor


@meenchen, thanks for verifying — that's a fair point. Since the bug predates this PR and you're planning a follow-up fix, would you like me to open a GitHub issue to track moving target_input = sample_input_single_batch.expand(...) inside the while loop? That way it won't get lost.


🧠 Learnings used
Learnt from: kevalmorabia97
Repo: NVIDIA/Model-Optimizer PR: 974
File: modelopt/torch/puzzletron/pruning/pruning_utils.py:195-201
Timestamp: 2026-03-06T14:26:21.439Z
Learning: In NVIDIA/Model-Optimizer, for PyTorch >= 2.6, torch.load() calls without an explicit weights_only argument are safe. Do not flag bare torch.load(...) as a security issue in files under the modelopt package (e.g., modelopt/torch/puzzletron/pruning/pruning_utils.py) as long as the PyTorch version constraint is maintained. If supporting PyTorch < 2.6, require an explicit weights_only argument to torch.load() to avoid potential issues.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
@meenchen meenchen force-pushed the weimingc/nvbug_6102977 branch from eaf6bc8 to 60b97ae Compare April 27, 2026 19:25
Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

All critical issues from the previous review have been addressed:

  1. State-leak bug (critical) — Fixed: The finally block now properly calls delattr(config, "use_cache") (wrapped in suppress(AttributeError)) when the attribute didn't exist before entry. This prevents leaking use_cache = False onto configs that never had it.

  2. Missing unit tests (critical) — Fixed: Four well-structured tests added covering all cases: no config attribute (no-op), existing use_cache restored on exit (parametrized True/False), missing use_cache cleaned up on exit (no leak), and restoration on exception.

  3. OOM retry loop stale target_input (pre-existing): Author correctly identified this as a pre-existing bug unrelated to this PR — the re-indentation just made it visible in the diff. Appropriate to fix separately.

The context manager is clean, well-documented, and correctly applied to both get_max_batch_size and _forward_loop. The refactoring from inline try/finally to a shared @contextmanager is a good simplification.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>