Skip softmax calibration with list of thresholds #987

rohansjoshi wants to merge 1 commit into `main`
Conversation
📝 Walkthrough

The pull request transitions attention sparsity functionality from single-threshold to multi-threshold support across the ModelOpt library. The calibrator now collects all threshold data in one forward pass, configurations use threshold lists per phase, and implementation methods handle multiple thresholds. All tests are updated to align with the new schema.
Sequence diagram:

```mermaid
sequenceDiagram
    participant Calibrator
    participant Module
    participant SparseMethod as Sparse Method
    participant Aggregator
    rect rgba(100, 150, 200, 0.5)
        Note over Calibrator,Aggregator: Old Approach: Per-Threshold Loops
        loop For each threshold
            Calibrator->>Module: forward(threshold=t)
            Module->>SparseMethod: compute sparsity(threshold=t)
            SparseMethod->>Module: return sparsity_t
            Calibrator->>Aggregator: aggregate(t, sparsity_t)
        end
    end
    rect rgba(150, 200, 100, 0.5)
        Note over Calibrator,Aggregator: New Approach: Single-Pass Collection
        Calibrator->>Module: forward(no threshold param)
        Module->>SparseMethod: compute sparsity for ALL thresholds
        SparseMethod->>Module: return sparsity_list = [s1, s2, ..., sN]
        Module->>Calibrator: return per-sample sparsity_list
        loop For each threshold in list
            Calibrator->>Aggregator: unpack and aggregate (scale_factor=t×length, sparsity)
        end
    end
```
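The single-pass collection shown in the diagram can be sketched in plain Python. This is a simplified, hypothetical illustration, not the actual ModelOpt implementation: `block_scores` stands in for the per-block statistics the real method derives from attention weights, and a block is "dense" (kept) when its score exceeds the threshold.

```python
def sparsity_per_threshold(block_scores, thresholds):
    """One sparsity value per threshold from a single set of block scores.

    Every threshold reuses the same scores, so only one forward pass over
    the data is needed; sparsity is the fraction of blocks skipped.
    """
    total = len(block_scores)
    return [1.0 - sum(s > t for s in block_scores) / total for t in thresholds]

print(sparsity_per_threshold([0.5, 0.01, 0.2, 0.0005], [1e-3, 1e-1, 1.0]))
# → [0.25, 0.5, 1.0]
```

The key property is that the loop over thresholds touches only the cached scores, never the model, which is where the claimed calibration speedup comes from.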
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ✅ 4 passed
…single pass Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
7cb2377 to 8f455c1
Actionable comments posted: 3
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
137-140: Guard against silent data loss when pairing thresholds and sparsities.

At Line 140, `zip(self.threshold_trials, sparsity_list)` silently truncates on length mismatch, which can hide calibration stat drift.

💡 Suggested fix
```diff
 for sample_stat in per_sample_stats:
     length = sample_stat["sample_length"]
     sparsity_list = sample_stat["sparsity"]
+    if len(sparsity_list) != len(self.threshold_trials):
+        raise ValueError(
+            f"Expected {len(self.threshold_trials)} sparsity values, got {len(sparsity_list)}"
+        )
     for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 137 - 140, The loop silently truncates mismatched pairs by using zip(self.threshold_trials, sparsity_list); before iterating (in the calibrator that processes per_sample_stats), validate that len(sparsity_list) == len(self.threshold_trials) and if not, raise a clear exception or log an error and skip the sample to avoid silent data loss—use the sample_stat/"sample_length" and sparsity_list context to include identifying info in the message; do not rely on zip_longest to silently fill values, explicitly enforce or handle length mismatches.
ℹ️ Review info
Configuration used: .coderabbit.yaml | Review profile: CHILL | Plan: Pro
Run ID: aced1124-5d44-4cab-b27a-89ab4c75bffa
📒 Files selected for processing (12)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
```python
self._set_thresholds(attention_modules, self.threshold_trials)
self._enable_calibration_mode(attention_modules)
with torch.no_grad():
    forward_loop(model)
per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
self._disable_calibration_mode(attention_modules)
```
Ensure calibration mode is always disabled on failure.
At Line 130–Line 136, an exception in forward_loop(model) or _extract_calibration_stats(...) leaves modules in calibration mode with trial thresholds still set.
💡 Suggested fix

```diff
 self._set_thresholds(attention_modules, self.threshold_trials)
 self._enable_calibration_mode(attention_modules)
-with torch.no_grad():
-    forward_loop(model)
-per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
-self._disable_calibration_mode(attention_modules)
+try:
+    with torch.no_grad():
+        forward_loop(model)
+    per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+finally:
+    self._disable_calibration_mode(attention_modules)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 130 - 136, Wrap the calibration sequence in a try/finally so calibration
mode is always cleaned up on error: after calling self._set_thresholds(...) and
self._enable_calibration_mode(...), run forward_loop(model) and
self._extract_calibration_stats(...) inside a try block and call
self._disable_calibration_mode(...) (and reset any trial thresholds if
applicable) in the finally block; reference the methods _set_thresholds,
_enable_calibration_mode, forward_loop, _extract_calibration_stats, and
_disable_calibration_mode so you locate and wrap that exact sequence to ensure
modules are disabled and thresholds cleared even when exceptions occur.
```python
self.thresholds = self.thresholds_config.get("prefill", [1e-3])
```
Avoid phase-order-dependent threshold fallback.
At Line 71, missing phase keys fall back to the previous runtime value (self.thresholds), so behavior depends on which phase ran first instead of config-only defaults.
💡 Suggested fix

```diff
-# Initialize thresholds from dict config (prefill phase as default)
-self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+# Deterministic fallback for configs that define only one phase
+self._fallback_thresholds = (
+    self.thresholds_config.get("prefill")
+    or self.thresholds_config.get("decode")
+    or [1e-3]
+)
+self.thresholds = list(self._fallback_thresholds)

 def _update_thresholds(self, phase: str):
     """Update thresholds list based on phase."""
-    self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+    self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))
```

Also applies to: 69-72
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 60 - 61, The code currently falls back to the runtime value
self.thresholds when a phase key is missing, making behavior depend on the order
phases run; instead, when resolving per-phase thresholds use only configuration
defaults (e.g. phase_val = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3]))), so replace any use of
self.thresholds as the fallback with a config-only chain (phase -> "prefill" ->
literal default) in the code that looks up phase thresholds (references:
self.thresholds_config and self.thresholds).
```python
if self.is_causal:
    # For causal attention, only count lower triangle blocks (including diagonal)
    num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
    total_valid_blocks = batch_size * num_heads * num_causal_blocks
    dense_blocks = block_mask.sum()
    total_blocks = num_causal_blocks
else:
```
Mask non-causal blocks before counting dense blocks.
At Line 179–Line 183, denominator is causal-only, but at Line 194–Line 197 numerator counts all block positions. This can produce invalid sparsity values (< 0). It also explains the weakened prefill test bounds.
💡 Suggested fix

```diff
-if self.is_causal:
-    num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
+causal_block_layout = None
+if self.is_causal:
+    causal_block_layout = torch.tril(
+        torch.ones(
+            (num_block_rows, num_block_cols),
+            dtype=torch.bool,
+            device=attn_weights.device,
+        )
+    )
+    num_causal_blocks = int(causal_block_layout.sum().item())
     total_valid_blocks = batch_size * num_heads * num_causal_blocks
     total_blocks = num_causal_blocks
 else:
     total_valid_blocks = batch_size * num_heads * num_block_rows * num_block_cols
     total_blocks = num_block_rows * num_block_cols
@@
 for i, log_threshold in enumerate(log_thresholds):
     block_mask = (block_max - block_max_cummax > log_threshold).any(dim=-2)
+    if causal_block_layout is not None:
+        block_mask = block_mask & causal_block_layout
     dense_blocks_list.append(block_mask.sum().item())
```

Also applies to: 194-197
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 179 - 183, The code computes total_blocks/total_valid_blocks using
num_causal_blocks when self.is_causal but later counts dense block positions
across all blocks (producing negative sparsity); update the dense-block counting
to mask out non-causal positions or else compute both numerator and denominator
from the same masked positions: use the same causal mask used to derive
num_causal_blocks when counting dense blocks (and when computing
total_blocks/total_valid_blocks) so numerator and denominator align (apply this
fix in the block that sets total_blocks/total_valid_blocks and also in the later
dense-counting section referenced around the second occurrence at Lines
~194-197); refer to self.is_causal, num_causal_blocks, total_valid_blocks,
total_blocks and the dense-block counting logic to locate and change the code.
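As a sanity check on the `num_causal_blocks` closed form quoted in this comment, here is a standalone comparison against a direct count. This is pure Python, independent of the actual module, and assumes the causal diagonal is anchored at the bottom-right corner, i.e. block `(r, c)` is valid when `c <= r + (cols - rows)` with `cols >= rows`.

```python
def num_causal_blocks_closed_form(rows, cols):
    """Closed form used in the code for the lower-triangle block count."""
    return rows * (2 * cols - rows + 1) // 2

def num_causal_blocks_direct(rows, cols):
    """Direct count of block positions kept by a causal layout (cols >= rows)."""
    offset = cols - rows
    return sum(1 for r in range(rows) for c in range(cols) if c <= r + offset)

for rows, cols in [(1, 1), (2, 4), (3, 3), (4, 8)]:
    assert num_causal_blocks_closed_form(rows, cols) == num_causal_blocks_direct(rows, cols)
print("closed form matches direct count")
```

Any fix that masks the numerator should count dense blocks over exactly these same positions so numerator and denominator stay aligned.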
Pull request overview
Updates FlashSkipSoftmax “skip softmax” calibration to support evaluating multiple sparsity thresholds in a single forward pass, improving calibration throughput while keeping inference sparsity behavior unchanged.
Changes:
- Rename sparse attention config from `threshold` (scalar per phase) to `thresholds` (list per phase) and propagate through configs/tests.
- Update FlashSkipSoftmax to compute per-threshold sparsity stats in one pass (and use the first threshold for the applied mask).
- Extend stats aggregation to handle `sparse_blocks` as either a scalar or a list.
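Concretely, the rename looks like the following (illustrative plain-dict values; the extra list entries in the new schema are hypothetical trial thresholds, not defaults from the library):

```python
# Old schema: one scalar threshold per phase (pre-#987)
old_config = {"threshold": {"prefill": 1e-3, "decode": 1e-4}}

# New schema: a list of thresholds per phase. Both lists must have the same
# length; only the first threshold's mask is applied at inference time, the
# rest exist to gather calibration statistics in a single pass.
new_config = {"thresholds": {"prefill": [1e-3, 1e-2, 1e-1], "decode": [1e-4, 1e-3, 1e-2]}}

# The first entry of each list plays the role of the old scalar threshold.
assert new_config["thresholds"]["prefill"][0] == old_config["threshold"]["prefill"]
assert len(new_config["thresholds"]["prefill"]) == len(new_config["thresholds"]["decode"])
```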
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py | Updates expected threshold info to thresholds dict-of-lists. |
| tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py | Updates sparse attention conversion tests to use thresholds. |
| tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py | Updates calibration tests/configs to use thresholds. |
| tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py | Updates FlashSkipSoftmax unit tests for list-based sparsity outputs. |
| tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py | Updates GPU integration configs to thresholds. |
| tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py | Updates GPU calibration configs to thresholds. |
| tests/_test_utils/torch/sparsity/sparse_attention_common.py | Updates shared test config fixtures to thresholds. |
| modelopt/torch/sparsity/attention_sparsity/stats_manager.py | Adds support for aggregating list-valued sparse_blocks and list average sparsity. |
| modelopt/torch/sparsity/attention_sparsity/model_sparsify.py | Updates public-facing doc/example to thresholds. |
| modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py | Implements multi-threshold stats collection and threshold list handling. |
| modelopt/torch/sparsity/attention_sparsity/config.py | Renames/validates thresholds as dict-of-float-lists (with length checks). |
| modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py | Switches calibration data collection to single-pass multi-threshold stats extraction. |
Comments suppressed due to low confidence (1)
modelopt/torch/sparsity/attention_sparsity/config.py:132

`validate_thresholds` still raises an error that says "Threshold must be..." even though the field is now `thresholds` and expects lists. Updating this message will make validation failures much clearer to users.

```python
def validate_thresholds(cls, v):
    """Validate thresholds is a dict of lists with valid phases and values in range (0, 1)."""
    if not isinstance(v, dict):
        raise ValueError(
            f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
        )
```
```diff
 def test_correction_factor_calculation_prefill(self):
     """Test correction factor for prefill phase."""
     method = FlashSkipSoftmax(
         {
-            "threshold": {"prefill": 1e-3, "decode": 1e-4},
+            "thresholds": {"prefill": [1e-3], "decode": [1e-4]},
             "br": 128,
             "bc": 128,
             "backend": "pytorch",
             "is_causal": True,
         }
     )

     # Create simple attention pattern
     attn = torch.randn(1, 1, 128, 256)

     mask, stats = method.calc_correction_factor_and_p(attn, "prefill")

     # Verify stats structure
     assert "correction_factor" in stats
     assert "sparsity" in stats
     assert "phase" in stats
     assert "total_blocks" in stats
     assert stats["phase"] == "prefill"
     assert 0 <= stats["correction_factor"] <= 1
-    # Sparsity can be negative if threshold is too low (more blocks kept than expected)
-    assert -1 <= stats["sparsity"] <= 1
+    # sparsity is now a list (one entry per threshold)
+    assert isinstance(stats["sparsity"], list)
+    assert all(-1 <= s <= 1 for s in stats["sparsity"])
```
The new multi-threshold behavior isn't covered by unit tests yet: the updated tests still only use single-entry lists. Adding a test with thresholds length > 1 would help verify that stats['sparsity'] length matches the threshold list and that the applied mask corresponds to the first threshold (as documented).
```diff
 thresholds: dict[str, list[float]] = ModeloptField(
     default={"prefill": [1e-3], "decode": [1e-4]},
     title="Sparsity thresholds.",
     description=(
-        "Threshold for determining which attention values to skip. "
-        "Must be a dict with 'prefill' and 'decode' keys."
+        "Thresholds for determining which attention values to skip. "
+        "Must be a dict with 'prefill' and/or 'decode' keys, each mapping to a list of floats. "
+        "Prefill and decode lists must have the same length. "
+        "Sparsity is computed per threshold; the first threshold's mask is applied."
     ),
 )
```
The config rename to thresholds is a breaking change but the repo still has callers using the old threshold field (e.g. tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py). This will cause validation errors and failing tests unless you either update the remaining call sites/tests or provide a backwards-compatible alias (e.g., accept threshold as a deprecated alias that maps to a single-entry thresholds list).
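One possible shape for such an alias, sketched as a plain-dict pre-processing step rather than a change to the actual pydantic model (`normalize_sparse_attention_config` is a hypothetical helper, not ModelOpt API):

```python
def normalize_sparse_attention_config(cfg: dict) -> dict:
    """Map the deprecated scalar `threshold` field onto the new list-valued
    `thresholds` schema; configs already using `thresholds` pass through."""
    cfg = dict(cfg)
    if "threshold" in cfg and "thresholds" not in cfg:
        legacy = cfg.pop("threshold")
        cfg["thresholds"] = {phase: [value] for phase, value in legacy.items()}
    return cfg

print(normalize_sparse_attention_config({"threshold": {"prefill": 1e-3, "decode": 1e-4}}))
# → {'thresholds': {'prefill': [0.001], 'decode': [0.0001]}}
```

In a pydantic model the same mapping could live in a pre-validation hook, optionally emitting a deprecation warning.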
```python
# Step 6-7: Expand block_mask_0 to element level (skip in calibration mode)
if not self._calibration_mode and block_mask_0 is not None:
    element_mask = (
        block_mask_0.unsqueeze(-2)
        .unsqueeze(-1)
        .expand(batch_size, num_heads, num_block_rows, self.br, num_block_cols, self.bc)
    )
    del block_mask_0
    element_mask = element_mask.reshape(
        batch_size, num_heads, padded_seq_q, padded_seq_k
    )
    element_mask = element_mask[:, :, :seq_q, :seq_k]
else:
    element_mask = None
```
In calibration mode calc_correction_factor_and_p returns element_mask = None, but calculate_sparsity() (and the base SparseAttentionMethod API) is typed/used as returning a tensor mask. Even if the current softmax patch ignores the mask during calibration, returning None risks surprising failures for any other caller of calculate_sparsity() while calibration_mode is enabled. Consider returning a lightweight sentinel tensor (or updating the method contract/call sites to explicitly allow None).
```python
if current == 0:
    # Lazy init: first time seeing a list
    self.aggregated_stats["sparse_blocks"] = list(incoming)
elif isinstance(current, list):
```
SparseAttentionStatsManager.collect() accumulates sparse_blocks lists by index but doesn't validate that incoming and current have the same length. If the configured threshold list changes between calls, this can raise an IndexError (or silently keep stale tail values if the new list is shorter). Consider explicitly checking list lengths and either resetting safely or raising a clear error.
```diff
 elif isinstance(current, list):
+    if len(current) != len(incoming):
+        raise ValueError(
+            f"SparseAttentionStatsManager.collect: length mismatch for "
+            f"'sparse_blocks' in module '{self.module_name}': "
+            f"existing length {len(current)}, incoming length {len(incoming)}."
+        )
```
```python
for sample_stat in per_sample_stats:
    length = sample_stat["sample_length"]
    sparsity_list = sample_stat["sparsity"]
```
During calibration, zip(self.threshold_trials, sparsity_list) will silently truncate if a module reports a sparsity list of the wrong length (e.g. if a module is still using the calibrated single-threshold path and returns a scalar/length-1 list). It would be safer to validate len(sparsity_list) == len(self.threshold_trials) and fail fast (or temporarily disable calibration_params on modules during data collection) to avoid producing incorrect calibration data.
```diff
 sparsity_list = sample_stat["sparsity"]
+# Validate that each sample reports one sparsity value per threshold.
+# Using zip() without this check would silently truncate if the lengths differ.
+if len(sparsity_list) != len(self.threshold_trials):
+    raise ValueError(
+        "Mismatch between number of thresholds and reported sparsity values during "
+        f"{phase} calibration: expected {len(self.threshold_trials)} sparsity values "
+        f"per sample, but got {len(sparsity_list)}. This may indicate that a module "
+        "is still using a single-threshold path or is misconfigured for calibration."
+    )
```
```python
assert len(set(lengths)) == 1, (
    f"All modules must have the same number of thresholds, got {lengths}"
)
```
_extract_calibration_stats() uses a bare assert to enforce that all modules report the same number of thresholds. Asserts can be stripped with python -O, which would turn this into silent miscalibration. Prefer raising a ValueError (or similar) with the same message so the check always runs.
```diff
-assert len(set(lengths)) == 1, (
-    f"All modules must have the same number of thresholds, got {lengths}"
-)
+if len(set(lengths)) != 1:
+    raise ValueError(
+        f"All modules must have the same number of thresholds, got {lengths}"
+    )
```
Modify skip softmax calibration to use a list of thresholds instead of a single threshold. Sparsity during inference is unchanged, but during calibration the list lets us gather statistics for many thresholds in a single forward pass, making calibration 20x faster.
Summary by CodeRabbit
New Features
Improvements
Breaking Changes

- `threshold` field renamed to `thresholds` and now expects lists of values instead of scalars.