
Allow HF trainer to mask sequences prior to reduction#1009

Open
AAnoosheh wants to merge 5 commits into main from
aanoosheh/hf-trainer-masking

Conversation


@AAnoosheh AAnoosheh commented Mar 9, 2026

What does this PR do?

Type of change: ?

Previously, the HF trainer did not account for loss masking when reducing the distillation loss.

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Improvements
    • Knowledge-distillation loss now properly ignores padding/special tokens and supports masked per-token averaging.
    • Default loss reduction behavior adjusted for finer-grained training control and clearer per-token outputs.
    • More robust logit handling with consistent numeric casting for improved stability and accuracy, including mixed-precision scenarios.

Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
@AAnoosheh AAnoosheh self-assigned this Mar 9, 2026
@AAnoosheh AAnoosheh requested a review from a team as a code owner March 9, 2026 19:38

coderabbitai bot commented Mar 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Adds IGNORE_TOKEN_ID and a trainer-level _compute_kd_loss wrapper that builds a per-sample reduction function applying label masking. Updates LMLogitsLoss to support per-token losses (casts logits to float) and changes LogitsDistillationLoss default reduction from "batchmean" to "mean".

Changes

Changes by cohort / file(s):

  • HuggingFace KD plugin (modelopt/torch/distill/plugins/huggingface.py): Added Tensor and LabelSmoother imports and an IGNORE_TOKEN_ID constant. Introduced _compute_kd_loss, which constructs a loss_reduction_fn that masks tokens equal to IGNORE_TOKEN_ID and either averages the remainder or returns per-token losses; assigns self.compute_loss_func. Updated the LMLogitsLoss.__init__ signature to (temperature: float = 1.0, reduction: str = "none") and modified forward to cast student/teacher logits to float and return per-token losses when reduction == "none".
  • KD loss utilities (modelopt/torch/distill/losses.py): Changed the LogitsDistillationLoss.__init__ default reduction from "batchmean" to "mean". Adjusted forward to handle reduction == "none" by summing over the last (vocab) dimension when returning per-token losses.
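To make the per-token path concrete, here is a minimal sketch of KL distillation with reduction="none" summed over the vocab dimension. Shapes, names, and the absence of temperature scaling on the output are assumptions for illustration, not the exact modelopt implementation:

```python
# Minimal sketch of the reduction="none" path: elementwise KL over the vocab,
# then summed over dim=-1 so each token gets a single scalar loss.
import torch
import torch.nn.functional as F

def per_token_kd_loss(student_logits, teacher_logits, temperature=1.0):
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kd = F.kl_div(log_p_student, p_teacher, reduction="none")  # (batch, seq, vocab)
    return kd.sum(dim=-1)                                      # (batch, seq)

torch.manual_seed(0)
student = torch.randn(2, 5, 11)
teacher = torch.randn(2, 5, 11)
loss = per_token_kd_loss(student, teacher)
print(loss.shape)  # torch.Size([2, 5])
```

Summing over the vocab dimension is what turns an elementwise KL tensor into one nonnegative loss per token, which is the shape the masking step needs.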

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as KDTrainer
    participant Model as Student/Teacher Model
    participant Loss as LMLogitsLoss
    participant Labels as Labels (may contain -100)

    Trainer->>Model: forward(inputs) → logits_student, logits_teacher
    Trainer->>Labels: provide labels (optional)
    Trainer->>Trainer: build loss_reduction_fn(labels, IGNORE_TOKEN_ID)
    Trainer->>Model: call model.compute_kd_loss(loss_reduction_fn=...)
    Model->>Loss: compute_kd_loss invokes LMLogitsLoss
    Loss->>Loss: cast logits to float, compute per-token KD loss
    Loss->>Labels: apply mask where labels == IGNORE_TOKEN_ID (if labels present)
    Loss-->>Model: return per-sample/per-token loss
    Model-->>Trainer: aggregated loss (averaged or per-sample)
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 71.43%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)
  • Description Check: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: The title "Allow HF trainer to mask sequences prior to reduction" directly matches the PR's main objective of enabling sequence masking before loss reduction in the Hugging Face trainer integration.
  • Security Anti-Patterns: The pull request contains no security anti-patterns such as unsafe torch.load/numpy.load calls, hardcoded trust_remote_code=True, eval/exec on external input, or # nosec bypasses.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/huggingface.py (1)

104-115: Clarify batch-mean reduction logic for the labels is None case.

On line 108, loss.sum() / len(loss) divides by the first dimension's size. If loss is already flattened or has shape (batch * seq,), this gives correct "mean" behavior. However, if loss is multi-dimensional (e.g., (batch, seq, vocab) before the sum(dim=-1) on line 110), len(loss) returns only the batch size, not the total element count.

Consider whether this branch should also handle the loss.ndim >= 2 case consistently with the masked branch, or use loss.numel() / loss.mean() for clarity:

 def loss_reduction_fn(loss: Tensor):
     if labels is None:
-        return loss.sum() / len(loss)  # batchmean reduction
+        per_token_loss = loss.sum(dim=-1) if loss.ndim >= 2 else loss
+        return per_token_loss.mean()  # batchmean reduction
     loss_mask = (labels.view(-1) != IGNORE_INDEX).to(loss.dtype)

Also, the outputs parameter is declared but unused—if this is intentional (KD loss is computed from captured intermediate outputs), consider prefixing with _ to indicate it's intentionally unused.
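The pitfall flagged here is easy to reproduce with a two-line check (illustrative only):

```python
# len(loss) counts only the first dimension, so sum()/len() over-weights
# multi-dimensional losses; mean() divides by every element.
import torch

loss = torch.ones(4, 8)           # e.g. (batch, seq) per-token losses
by_len = loss.sum() / len(loss)   # 32 / 4 = 8.0, divides by batch size only
by_mean = loss.mean()             # 32 / 32 = 1.0, divides by all elements
print(by_len.item(), by_mean.item())  # 8.0 1.0
```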


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 921c1563-4097-472a-996b-cdf258d8f308

📥 Commits

Reviewing files that changed from the base of the PR and between 2bb404e and 8004d0c.

📒 Files selected for processing (1)
  • modelopt/torch/distill/plugins/huggingface.py


codecov bot commented Mar 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 71.73%. Comparing base (5d0e012) to head (04b5654).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1009   +/-   ##
=======================================
  Coverage   71.73%   71.73%           
=======================================
  Files         211      211           
  Lines       23948    23949    +1     
=======================================
+ Hits        17180    17181    +1     
  Misses       6768     6768           


@coderabbitai coderabbitai bot left a comment
The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/huggingface.py (1)

126-134: Breaking change: Default reduction changed to "none".

The constructor now defaults to reduction="none" instead of inheriting the parent's default. This is intentional for the new masking feature but may affect existing users who instantiate LMLogitsLoss() without explicit arguments.

Consider noting this in the changelog or migration notes as mentioned in the PR checklist.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 99809bd0-fa7c-4c07-9bd0-2fc5bea07050

📥 Commits

Reviewing files that changed from the base of the PR and between 3f77bf4 and 04b5654.

📒 Files selected for processing (1)
  • modelopt/torch/distill/plugins/huggingface.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/distill/losses.py`:
- Line 34: The default reduction for KL in LogitsDistillationLoss should remain
"batchmean" to match PyTorch's mathematical KL divergence and preserve backward
compatibility; change the default value in LogitsDistillationLoss.__init__ from
"mean" to "batchmean" (and update any related docstring or tests that assume the
previous default) so that torch.nn.functional.kl_div uses reduction="batchmean"
unless explicitly overridden.

In `@modelopt/torch/distill/plugins/huggingface.py`:
- Around line 144-148: The per-token loss returned by
LogitsDistillationLoss.forward() is flattened by the parent class and becomes a
1D tensor; after computing loss = super().forward(student_logits,
teacher_logits) (and handling self._reduction == "none"), reshape the per-token
loss back to the original token shape student_logits.shape[:-1] (i.e., (batch,
seq)) before returning so subsequent masking (e.g., multiplying by labels !=
IGNORE_TOKEN_ID) matches dimensions; update the return path in huggingface.py to
reshape loss to student_logits.shape[:-1] when reduction == "none".
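The reshape suggested in the second comment can be demonstrated in isolation (tensor names and sizes here are assumptions):

```python
# A flattened 1D per-token loss is restored to (batch, seq) via the logits'
# leading dimensions so it can be masked elementwise against the labels.
import torch

student_logits = torch.randn(2, 5, 11)            # (batch, seq, vocab)
flat_loss = torch.arange(10, dtype=torch.float)   # as if flattened upstream
loss = flat_loss.view(student_logits.shape[:-1])  # back to (batch, seq)

labels = torch.full((2, 5), -100)
labels[:, :3] = 1                                 # first 3 tokens unmasked
masked = loss * (labels != -100)                  # shapes now align
print(masked.shape)  # torch.Size([2, 5])
```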

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bbc4d4a7-08bd-41e3-ad0d-ac5a1010c546

📥 Commits

Reviewing files that changed from the base of the PR and between 04b5654 and 732dc47.

📒 Files selected for processing (2)
  • modelopt/torch/distill/losses.py
  • modelopt/torch/distill/plugins/huggingface.py

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
modelopt/torch/distill/losses.py (1)

34-34: ⚠️ Potential issue | 🟠 Major

Keep "batchmean" as the default to preserve backward compatibility and mathematical correctness.

The change from "batchmean" to "mean" scales every KD loss down by 1 / num_classes relative to previous behavior and deviates from the mathematical definition of KL divergence. This affects all existing callers of LogitsDistillationLoss.

Suggested fix
-    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"):
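The scaling difference at issue can be checked directly; the vocab size below is arbitrary:

```python
# reduction="mean" divides the summed KL by every element (batch * vocab),
# while "batchmean" divides by batch size only, so "mean" shrinks the loss
# by a factor of the class count, as the comment above notes.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 16
log_p_student = F.log_softmax(torch.randn(4, vocab), dim=-1)
p_teacher = F.softmax(torch.randn(4, vocab), dim=-1)

kl_mean = F.kl_div(log_p_student, p_teacher, reduction="mean")
kl_batchmean = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
print((kl_batchmean / kl_mean).item())  # ≈ 16.0, the vocab size
```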
🧹 Nitpick comments (1)
modelopt/torch/distill/losses.py (1)

62-64: Document the semantic change to reduction="none" behavior.

This implementation is mathematically correct—KL divergence requires summing over the class dimension to produce per-token losses. However, this changes the output shape of reduction="none" from (batch, seq, vocab) to (batch, seq), which is a breaking change for any existing callers that expected element-wise losses.

Consider adding a note in the docstring (line 39-40) clarifying that reduction="none" returns per-token losses (summed over vocab), not element-wise losses.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ee265919-44e3-44f1-bdc3-28d263f81b4d

📥 Commits

Reviewing files that changed from the base of the PR and between 732dc47 and 820bad7.

📒 Files selected for processing (2)
  • modelopt/torch/distill/losses.py
  • modelopt/torch/distill/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/distill/plugins/huggingface.py

@realAsma realAsma left a comment

Looks great!
