
Skip softmax calibration with list of thresholds#987

Open
rohansjoshi wants to merge 1 commit into main from rohjoshi/sa-calib

Conversation

@rohansjoshi
Contributor

@rohansjoshi rohansjoshi commented Mar 6, 2026

Modify skip softmax calibration to use a list of thresholds instead of a single threshold. Sparsity during inference is unchanged, but during calibration the list lets us gather statistics for many thresholds in a single forward pass, making calibration 20x faster.
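The speedup comes from amortizing the forward pass: block statistics are collected once, then every candidate threshold is scored against the cached values. A minimal sketch of the idea (illustrative names, not the ModelOpt implementation):

```python
def sparsity_per_threshold(block_scores, thresholds):
    """Score many candidate thresholds from one pass of cached statistics.

    block_scores: per-block max attention scores gathered during a single
    forward pass. Returns the fraction of blocks each threshold would skip.
    """
    total = len(block_scores)
    return [sum(score < t for score in block_scores) / total for t in thresholds]

# One "forward pass" worth of cached block statistics...
scores = [0.1, 0.5, 0.9, 0.05]
# ...evaluated against every calibration threshold at once.
print(sparsity_per_threshold(scores, [0.2, 0.6]))  # [0.5, 0.75]
```

With the old per-threshold loop, each entry in the threshold list would have cost one full forward pass.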

Summary by CodeRabbit

  • New Features

    • Multi-threshold sparsity configuration for fine-grained control over attention sparsity levels.
  • Improvements

    • Calibration efficiency: Single forward pass for collecting all threshold data instead of iterating per threshold.
    • Configuration format updated to support threshold lists for prefill and decode phases.
  • Breaking Changes

    • Configuration API: threshold field renamed to thresholds and now expects lists of values instead of scalars.
    • Sparsity statistics output updated to return per-threshold values.

@copy-pr-bot

copy-pr-bot bot commented Mar 6, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

Walkthrough

The pull request transitions attention sparsity functionality from single-threshold to multi-threshold support across the ModelOpt library. The calibrator now collects all threshold data in one forward pass, configurations use threshold lists per phase, and implementation methods handle multiple thresholds. All tests updated to align with the new schema.

Changes

Cohort / File(s): Summary

  • Configuration Schema (modelopt/torch/sparsity/attention_sparsity/config.py): Updated SparseAttentionAttributeConfig.threshold to thresholds (dict of lists); updated validators for multi-threshold validation, ensuring equal lengths across phases; updated FlashSkipSoftmaxConfig and global constants to use the new structure.
  • Calibration Logic (modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py): Refactored threshold collection from per-threshold loops to a single-pass forward; renamed _set_threshold() to _set_thresholds(), which now accepts a list of thresholds; updated internal data aggregation to handle per-sample multi-threshold sparsity lists.
  • Sparse Method Implementation (modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py): Renamed threshold/threshold_config to thresholds/thresholds_config; replaced _update_threshold() with _update_thresholds(); added a conditional multi-threshold path in threshold computation; updated sparsity/masking logic to support multiple thresholds per phase.
  • Statistics Aggregation (modelopt/torch/sparsity/attention_sparsity/stats_manager.py): Enhanced sparse_blocks aggregation to handle both scalar and list forms with conditional branch logic; updated get_summary() to compute per-element or scalar averages depending on block type.
  • Documentation Updates (modelopt/torch/sparsity/attention_sparsity/model_sparsify.py): Updated configuration key examples from "threshold" to "thresholds" in docstrings and example configs.
  • Test Configuration Updates (tests/_test_utils/torch/sparsity/sparse_attention_common.py, tests/gpu/torch/sparsity/attention_sparsity/test_*.py, tests/unit/torch/sparsity/attention_sparsity/test_*.py): Updated all test configurations to use the "thresholds" key with list values (e.g., {"prefill": [1e-3], "decode": [1e-4]}); adjusted assertions and sparsity expectations to reflect list-based outputs.
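In practice the schema change looks like the following before/after config fragments (key names are taken from this PR; the surrounding config nesting is elided):

```python
# Old schema: one scalar threshold per phase.
old_cfg = {"threshold": {"prefill": 1e-3, "decode": 1e-4}}

# New schema: a list of thresholds per phase. Per the new validators,
# the prefill and decode lists must have the same length, and the
# first entry's mask is the one actually applied at inference.
new_cfg = {"thresholds": {"prefill": [1e-3, 5e-4], "decode": [1e-4, 5e-5]}}

assert len(new_cfg["thresholds"]["prefill"]) == len(new_cfg["thresholds"]["decode"])
```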

Sequence Diagram(s)

sequenceDiagram
    participant Calibrator
    participant Module
    participant SparseMethod as Sparse Method
    participant Aggregator

    rect rgba(100, 150, 200, 0.5)
    Note over Calibrator,Aggregator: Old Approach: Per-Threshold Loops
    loop For each threshold
        Calibrator->>Module: forward(threshold=t)
        Module->>SparseMethod: compute sparsity(threshold=t)
        SparseMethod->>Module: return sparsity_t
        Calibrator->>Aggregator: aggregate(t, sparsity_t)
    end
    end

    rect rgba(150, 200, 100, 0.5)
    Note over Calibrator,Aggregator: New Approach: Single-Pass Collection
    Calibrator->>Module: forward(no threshold param)
    Module->>SparseMethod: compute sparsity for ALL thresholds
    SparseMethod->>Module: return sparsity_list = [s1, s2, ..., sN]
    Module->>Calibrator: return per-sample sparsity_list
    loop For each threshold in list
        Calibrator->>Aggregator: unpack and aggregate (scale_factor=t×length, sparsity)
    end
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check: ✅ Passed (check skipped because CodeRabbit’s high-level summary is enabled).
  • Title Check: ✅ Passed. The title accurately and concisely describes the main change: refactoring skip softmax calibration to accept multiple thresholds instead of a single threshold, which is the primary focus across all modified files.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Security Anti-Patterns: ✅ Passed. Review of all modified files reveals no critical security anti-patterns defined in SECURITY.md; changes consist of refactoring threshold configuration handling.



…single pass

Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
@rohansjoshi rohansjoshi marked this pull request as ready for review March 6, 2026 19:14
@rohansjoshi rohansjoshi requested a review from a team as a code owner March 6, 2026 19:14
@kevalmorabia97 kevalmorabia97 requested review from kaix-nv and removed request for kevalmorabia97 March 6, 2026 19:15

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

137-140: Guard against silent data loss when pairing thresholds and sparsities.

At Line 140, zip(self.threshold_trials, sparsity_list) silently truncates on length mismatch, which can hide calibration stat drift.

💡 Suggested fix
         for sample_stat in per_sample_stats:
             length = sample_stat["sample_length"]
             sparsity_list = sample_stat["sparsity"]
+            if len(sparsity_list) != len(self.threshold_trials):
+                raise ValueError(
+                    f"Expected {len(self.threshold_trials)} sparsity values, got {len(sparsity_list)}"
+                )
             for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
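The hazard this check guards against is easy to reproduce: Python's zip() stops at the shorter iterable, so a missing sparsity value disappears without any error.

```python
thresholds = [1e-3, 5e-4, 1e-4]
sparsities = [0.42, 0.57]  # one value missing

# zip() silently drops the unmatched 1e-4 trial
paired = list(zip(thresholds, sparsities))
print(paired)  # [(0.001, 0.42), (0.0005, 0.57)]
```

On Python 3.10+, `zip(thresholds, sparsities, strict=True)` would raise instead of truncating, which is another way to implement the same guard.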
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 137 - 140, The loop silently truncates mismatched pairs by using
zip(self.threshold_trials, sparsity_list); before iterating (in the calibrator
that processes per_sample_stats), validate that len(sparsity_list) ==
len(self.threshold_trials) and if not, raise a clear exception or log an error
and skip the sample to avoid silent data loss—use the
sample_stat/"sample_length" and sparsity_list context to include identifying
info in the message; do not rely on zip_longest to silently fill values,
explicitly enforce or handle length mismatches.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: aced1124-5d44-4cab-b27a-89ab4c75bffa

📥 Commits

Reviewing files that changed from the base of the PR and between 1ccd945 and 8f455c1.

📒 Files selected for processing (12)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py

Comment on lines +130 to +136

        self._set_thresholds(attention_modules, self.threshold_trials)
        self._enable_calibration_mode(attention_modules)
        with torch.no_grad():
            forward_loop(model)
        per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
        self._disable_calibration_mode(attention_modules)


⚠️ Potential issue | 🟠 Major

Ensure calibration mode is always disabled on failure.

At Line 130–Line 136, an exception in forward_loop(model) or _extract_calibration_stats(...) leaves modules in calibration mode with trial thresholds still set.

💡 Suggested fix
         self._set_thresholds(attention_modules, self.threshold_trials)
         self._enable_calibration_mode(attention_modules)
-        with torch.no_grad():
-            forward_loop(model)
-        per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
-        self._disable_calibration_mode(attention_modules)
+        try:
+            with torch.no_grad():
+                forward_loop(model)
+            per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+        finally:
+            self._disable_calibration_mode(attention_modules)
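The try/finally pattern above guarantees cleanup no matter how the forward loop exits; a toy reproduction (names are illustrative, not the calibrator's API):

```python
calibration_mode = {"enabled": False}

def run_calibration(forward_loop):
    calibration_mode["enabled"] = True  # analogue of _enable_calibration_mode
    try:
        forward_loop()
    finally:
        calibration_mode["enabled"] = False  # always restored, even on error

def failing_forward():
    raise RuntimeError("simulated failure during forward_loop")

try:
    run_calibration(failing_forward)
except RuntimeError:
    pass

print(calibration_mode["enabled"])  # False: mode was cleaned up despite the failure
```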
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 130 - 136, Wrap the calibration sequence in a try/finally so calibration
mode is always cleaned up on error: after calling self._set_thresholds(...) and
self._enable_calibration_mode(...), run forward_loop(model) and
self._extract_calibration_stats(...) inside a try block and call
self._disable_calibration_mode(...) (and reset any trial thresholds if
applicable) in the finally block; reference the methods _set_thresholds,
_enable_calibration_mode, forward_loop, _extract_calibration_stats, and
_disable_calibration_mode so you locate and wrap that exact sequence to ensure
modules are disabled and thresholds cleared even when exceptions occur.

Comment on lines +60 to 61

        self.thresholds = self.thresholds_config.get("prefill", [1e-3])


⚠️ Potential issue | 🟠 Major

Avoid phase-order-dependent threshold fallback.

At Line 71, missing phase keys fall back to the previous runtime value (self.thresholds), so behavior depends on which phase ran first instead of config-only defaults.

💡 Suggested fix
-        # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
+        self.thresholds = list(self._fallback_thresholds)

     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))

Also applies to: 69-72

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 60 - 61, The code currently falls back to the runtime value
self.thresholds when a phase key is missing, making behavior depend on the order
phases run; instead, when resolving per-phase thresholds use only configuration
defaults (e.g. phase_val = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3]))), so replace any use of
self.thresholds as the fallback with a config-only chain (phase -> "prefill" ->
literal default) in the code that looks up phase thresholds (references:
self.thresholds_config and self.thresholds).

Comment on lines 179 to 183

            if self.is_causal:
                # For causal attention, only count lower triangle blocks (including diagonal)
                num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
                total_valid_blocks = batch_size * num_heads * num_causal_blocks
                dense_blocks = block_mask.sum()
                total_blocks = num_causal_blocks
            else:

⚠️ Potential issue | 🟠 Major

Mask non-causal blocks before counting dense blocks.

At Line 179–Line 183, denominator is causal-only, but at Line 194–Line 197 numerator counts all block positions. This can produce invalid sparsity values (< 0). It also explains the weakened prefill test bounds.

💡 Suggested fix
-            if self.is_causal:
-                num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
+            causal_block_layout = None
+            if self.is_causal:
+                causal_block_layout = torch.tril(
+                    torch.ones(
+                        (num_block_rows, num_block_cols),
+                        dtype=torch.bool,
+                        device=attn_weights.device,
+                    )
+                )
+                num_causal_blocks = int(causal_block_layout.sum().item())
                 total_valid_blocks = batch_size * num_heads * num_causal_blocks
                 total_blocks = num_causal_blocks
             else:
                 total_valid_blocks = batch_size * num_heads * num_block_rows * num_block_cols
                 total_blocks = num_block_rows * num_block_cols
@@
             for i, log_threshold in enumerate(log_thresholds):
                 block_mask = (block_max - block_max_cummax > log_threshold).any(dim=-2)
+                if causal_block_layout is not None:
+                    block_mask = block_mask & causal_block_layout

                 dense_blocks_list.append(block_mask.sum().item())

Also applies to: 194-197

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 179 - 183, The code computes total_blocks/total_valid_blocks using
num_causal_blocks when self.is_causal but later counts dense block positions
across all blocks (producing negative sparsity); update the dense-block counting
to mask out non-causal positions or else compute both numerator and denominator
from the same masked positions: use the same causal mask used to derive
num_causal_blocks when counting dense blocks (and when computing
total_blocks/total_valid_blocks) so numerator and denominator align (apply this
fix in the block that sets total_blocks/total_valid_blocks and also in the later
dense-counting section referenced around the second occurrence at Lines
~194-197); refer to self.is_causal, num_causal_blocks, total_valid_blocks,
total_blocks and the dense-block counting logic to locate and change the code.
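The closed-form count in the snippet above is a trapezoidal variant of a triangular number. As a sanity check (not from the PR), it agrees with a brute-force count of lower-trapezoid block positions whenever num_block_cols >= num_block_rows:

```python
def causal_block_count(num_block_rows, num_block_cols):
    # Closed form used in flash_skip_softmax.py
    return num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2

def brute_force_count(num_block_rows, num_block_cols):
    # Count block positions (i, j) on or below the shifted diagonal,
    # i.e. the blocks a causal mask keeps for a rectangular block grid.
    offset = num_block_cols - num_block_rows
    return sum(
        1
        for i in range(num_block_rows)
        for j in range(num_block_cols)
        if j <= i + offset
    )

for rows, cols in [(1, 2), (2, 4), (3, 3), (4, 8)]:
    assert causal_block_count(rows, cols) == brute_force_count(rows, cols)
```

The review comment's point is that this denominator is causal-only, so the dense-block numerator must be restricted to the same positions or sparsity can go negative.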


Copilot AI left a comment


Pull request overview

Updates FlashSkipSoftmax “skip softmax” calibration to support evaluating multiple sparsity thresholds in a single forward pass, improving calibration throughput while keeping inference sparsity behavior unchanged.

Changes:

  • Rename sparse attention config from threshold (scalar per phase) to thresholds (list per phase) and propagate through configs/tests.
  • Update FlashSkipSoftmax to compute per-threshold sparsity stats in one pass (and use the first threshold for the applied mask).
  • Extend stats aggregation to handle sparse_blocks as either a scalar or a list.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 6 comments.

Show a summary per file
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py: Updates expected threshold info to the thresholds dict-of-lists.
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py: Updates sparse attention conversion tests to use thresholds.
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py: Updates calibration tests/configs to use thresholds.
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py: Updates FlashSkipSoftmax unit tests for list-based sparsity outputs.
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py: Updates GPU integration configs to thresholds.
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py: Updates GPU calibration configs to thresholds.
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py: Updates shared test config fixtures to thresholds.
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py: Adds support for aggregating list-valued sparse_blocks and list average sparsity.
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py: Updates public-facing doc/example to thresholds.
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py: Implements multi-threshold stats collection and threshold list handling.
  • modelopt/torch/sparsity/attention_sparsity/config.py: Renames/validates thresholds as dict-of-float-lists (with length checks).
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py: Switches calibration data collection to single-pass multi-threshold stats extraction.
Comments suppressed due to low confidence (1)

modelopt/torch/sparsity/attention_sparsity/config.py:132

  • validate_thresholds still raises an error that says "Threshold must be..." even though the field is now thresholds and expects lists. Updating this message will make validation failures much clearer to users.
    def validate_thresholds(cls, v):
        """Validate thresholds is a dict of lists with valid phases and values in range (0, 1)."""
        if not isinstance(v, dict):
            raise ValueError(
                f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
            )


Comment on lines 119 to 146

     def test_correction_factor_calculation_prefill(self):
         """Test correction factor for prefill phase."""
         method = FlashSkipSoftmax(
             {
-                "threshold": {"prefill": 1e-3, "decode": 1e-4},
+                "thresholds": {"prefill": [1e-3], "decode": [1e-4]},
                 "br": 128,
                 "bc": 128,
                 "backend": "pytorch",
                 "is_causal": True,
             }
         )

         # Create simple attention pattern
         attn = torch.randn(1, 1, 128, 256)

         mask, stats = method.calc_correction_factor_and_p(attn, "prefill")

         # Verify stats structure
         assert "correction_factor" in stats
         assert "sparsity" in stats
         assert "phase" in stats
         assert "total_blocks" in stats
         assert stats["phase"] == "prefill"
         assert 0 <= stats["correction_factor"] <= 1
-        # Sparsity can be negative if threshold is too low (more blocks kept than expected)
-        assert -1 <= stats["sparsity"] <= 1
+        # sparsity is now a list (one entry per threshold)
+        assert isinstance(stats["sparsity"], list)
+        assert all(-1 <= s <= 1 for s in stats["sparsity"])


Copilot AI Mar 6, 2026


The new multi-threshold behavior isn't covered by unit tests yet: the updated tests still only use single-entry lists. Adding a test with thresholds length > 1 would help verify that stats['sparsity'] length matches the threshold list and that the applied mask corresponds to the first threshold (as documented).

Comment on lines +49 to 58

     thresholds: dict[str, list[float]] = ModeloptField(
         default={"prefill": [1e-3], "decode": [1e-4]},
         title="Sparsity thresholds.",
         description=(
-            "Threshold for determining which attention values to skip. "
-            "Must be a dict with 'prefill' and 'decode' keys."
+            "Thresholds for determining which attention values to skip. "
+            "Must be a dict with 'prefill' and/or 'decode' keys, each mapping to a list of floats. "
+            "Prefill and decode lists must have the same length. "
+            "Sparsity is computed per threshold; the first threshold's mask is applied."
         ),
     )

Copilot AI Mar 6, 2026


The config rename to thresholds is a breaking change but the repo still has callers using the old threshold field (e.g. tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py). This will cause validation errors and failing tests unless you either update the remaining call sites/tests or provide a backwards-compatible alias (e.g., accept threshold as a deprecated alias that maps to a single-entry thresholds list).

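One way to soften the breaking change the reviewer describes is a normalization shim that accepts the deprecated scalar key before validation. This is a hypothetical sketch, not code from the PR; the function name and placement are assumptions:

```python
def normalize_sparsity_config(cfg):
    """Map a legacy {"threshold": {phase: float}} config onto the new
    {"thresholds": {phase: [float]}} schema before validation."""
    cfg = dict(cfg)  # avoid mutating the caller's dict
    if "threshold" in cfg and "thresholds" not in cfg:
        legacy = cfg.pop("threshold")
        cfg["thresholds"] = {phase: [value] for phase, value in legacy.items()}
    return cfg

# Legacy configs keep working; new-style configs pass through untouched.
print(normalize_sparsity_config({"threshold": {"prefill": 1e-3, "decode": 1e-4}}))
```

A deprecation warning in the legacy branch would make the migration path visible to users.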
Comment on lines +205 to +219

         # Step 6-7: Expand block_mask_0 to element level (skip in calibration mode)
         if not self._calibration_mode and block_mask_0 is not None:
             element_mask = (
                 block_mask_0.unsqueeze(-2)
                 .unsqueeze(-1)
                 .expand(batch_size, num_heads, num_block_rows, self.br, num_block_cols, self.bc)
             )
             del block_mask_0
             element_mask = element_mask.reshape(
                 batch_size, num_heads, padded_seq_q, padded_seq_k
             )
             element_mask = element_mask[:, :, :seq_q, :seq_k]
         else:
             element_mask = None

Copilot AI Mar 6, 2026


In calibration mode calc_correction_factor_and_p returns element_mask = None, but calculate_sparsity() (and the base SparseAttentionMethod API) is typed/used as returning a tensor mask. Even if the current softmax patch ignores the mask during calibration, returning None risks surprising failures for any other caller of calculate_sparsity() while calibration_mode is enabled. Consider returning a lightweight sentinel tensor (or updating the method contract/call sites to explicitly allow None).

         if current == 0:
             # Lazy init: first time seeing a list
             self.aggregated_stats["sparse_blocks"] = list(incoming)
         elif isinstance(current, list):

Copilot AI Mar 6, 2026


SparseAttentionStatsManager.collect() accumulates sparse_blocks lists by index but doesn't validate that incoming and current have the same length. If the configured threshold list changes between calls, this can raise an IndexError (or silently keep stale tail values if the new list is shorter). Consider explicitly checking list lengths and either resetting safely or raising a clear error.

Suggested change
-        elif isinstance(current, list):
+        elif isinstance(current, list):
+            if len(current) != len(incoming):
+                raise ValueError(
+                    f"SparseAttentionStatsManager.collect: length mismatch for "
+                    f"'sparse_blocks' in module '{self.module_name}': "
+                    f"existing length {len(current)}, incoming length {len(incoming)}."
+                )
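The accumulation this check protects reduces to elementwise addition with lazy initialization; a standalone sketch (illustrative names, not the stats_manager internals):

```python
def accumulate_sparse_blocks(current, incoming):
    """Add a new per-threshold sparse-block count into a running total."""
    if current == 0:  # lazy init: first list seen replaces the scalar default
        return list(incoming)
    if len(current) != len(incoming):
        raise ValueError(f"length mismatch: {len(current)} vs {len(incoming)}")
    return [c + i for c, i in zip(current, incoming)]

running = 0
for batch in ([3, 1, 0], [2, 2, 5]):
    running = accumulate_sparse_blocks(running, batch)
print(running)  # [5, 3, 5]
```

Without the length check, a shorter incoming list would silently leave stale tail values in the running total.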


        for sample_stat in per_sample_stats:
            length = sample_stat["sample_length"]
            sparsity_list = sample_stat["sparsity"]

Copilot AI Mar 6, 2026


During calibration, zip(self.threshold_trials, sparsity_list) will silently truncate if a module reports a sparsity list of the wrong length (e.g. if a module is still using the calibrated single-threshold path and returns a scalar/length-1 list). It would be safer to validate len(sparsity_list) == len(self.threshold_trials) and fail fast (or temporarily disable calibration_params on modules during data collection) to avoid producing incorrect calibration data.

Suggested change
-            sparsity_list = sample_stat["sparsity"]
+            sparsity_list = sample_stat["sparsity"]
+            # Validate that each sample reports one sparsity value per threshold.
+            # Using zip() without this check would silently truncate if the lengths differ.
+            if len(sparsity_list) != len(self.threshold_trials):
+                raise ValueError(
+                    "Mismatch between number of thresholds and reported sparsity values during "
+                    f"{phase} calibration: expected {len(self.threshold_trials)} sparsity values "
+                    f"per sample, but got {len(sparsity_list)}. This may indicate that a module "
+                    "is still using a single-threshold path or is misconfigured for calibration."
+                )

Comment on lines +320 to +322

        assert len(set(lengths)) == 1, (
            f"All modules must have the same number of thresholds, got {lengths}"
        )

Copilot AI Mar 6, 2026


_extract_calibration_stats() uses a bare assert to enforce that all modules report the same number of thresholds. Asserts can be stripped with python -O, which would turn this into silent miscalibration. Prefer raising a ValueError (or similar) with the same message so the check always runs.

Suggested change
-        assert len(set(lengths)) == 1, (
-            f"All modules must have the same number of thresholds, got {lengths}"
-        )
+        if len(set(lengths)) != 1:
+            raise ValueError(
+                f"All modules must have the same number of thresholds, got {lengths}"
+            )

