Allow HF trainer to mask sequences prior to reduction #1009
Conversation
Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
📝 Walkthrough

Adds IGNORE_TOKEN_ID and a trainer-level _compute_kd_loss wrapper that builds a per-sample reduction function applying label masking. Updates LMLogitsLoss to support per-token losses (casting logits to float) and changes the LogitsDistillationLoss default reduction from "batchmean" to "mean".
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as KDTrainer
    participant Model as Student/Teacher Model
    participant Loss as LMLogitsLoss
    participant Labels as Labels (may contain -100)
    Trainer->>Model: forward(inputs) → logits_student, logits_teacher
    Trainer->>Labels: provide labels (optional)
    Trainer->>Trainer: build loss_reduction_fn(labels, IGNORE_TOKEN_ID)
    Trainer->>Model: call model.compute_kd_loss(loss_reduction_fn=...)
    Model->>Loss: compute_kd_loss invokes LMLogitsLoss
    Loss->>Loss: cast logits to float, compute per-token KD loss
    Loss->>Labels: apply mask where labels == IGNORE_TOKEN_ID (if labels present)
    Loss-->>Model: return per-sample/per-token loss
    Model-->>Trainer: aggregated loss (averaged or per-sample)
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/huggingface.py (1)
104-115: Clarify the batch-mean reduction logic for the `labels is None` case.

On line 108, `loss.sum() / len(loss)` divides by the first dimension's size. If `loss` is already flattened or has shape `(batch * seq,)`, this gives correct "mean" behavior. However, if `loss` is multi-dimensional (e.g., `(batch, seq, vocab)` before the `sum(dim=-1)` on line 110), `len(loss)` returns only the batch size, not the total element count. Consider whether this branch should also handle the `loss.ndim >= 2` case consistently with the masked branch, or use `loss.numel()` / `loss.mean()` for clarity:

```diff
 def loss_reduction_fn(loss: Tensor):
     if labels is None:
-        return loss.sum() / len(loss)  # batchmean reduction
+        per_token_loss = loss.sum(dim=-1) if loss.ndim >= 2 else loss
+        return per_token_loss.mean()  # batchmean reduction
     loss_mask = (labels.view(-1) != IGNORE_INDEX).to(loss.dtype)
```

Also, the `outputs` parameter is declared but unused; if this is intentional (KD loss is computed from captured intermediate outputs), consider prefixing it with `_` to indicate it is intentionally unused.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/distill/plugins/huggingface.py` around lines 104 - 115, The batch-mean branch in _compute_kd_loss currently uses loss.sum() / len(loss) which only divides by the first-dimension size and is incorrect for multi-dim losses; modify loss_reduction_fn to mirror the masked branch by computing per_token_loss = loss.sum(dim=-1) if loss.ndim >= 2 else loss and then return per_token_loss.mean() (or per_token_loss.sum() / per_token_loss.numel()) for the labels is None case so reduction is correct for flattened and multi-dim tensors, and call self.model.compute_kd_loss(loss_reduction_fn=loss_reduction_fn) as before; also rename the unused outputs parameter in _compute_kd_loss to _outputs (or prefix with _) to indicate it's intentionally unused.
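The masked branch the comment refers to can be sketched in isolation. Below is a dependency-free illustration of the logic, with plain Python lists standing in for tensors; `masked_mean` is a hypothetical name introduced here for illustration, not part of the codebase:

```python
IGNORE_TOKEN_ID = -100  # HF convention for label positions excluded from the loss

def masked_mean(per_token_loss, labels):
    """Average per-token losses, skipping positions labeled IGNORE_TOKEN_ID.

    When labels are absent, fall back to a plain mean over all tokens;
    otherwise divide the masked sum by the number of unmasked positions.
    """
    if labels is None:
        return sum(per_token_loss) / len(per_token_loss)
    mask = [0.0 if lab == IGNORE_TOKEN_ID else 1.0 for lab in labels]
    kept = sum(mask)
    if kept == 0:
        return 0.0  # every position masked; avoid division by zero
    return sum(l * m for l, m in zip(per_token_loss, mask)) / kept
```

Note that dividing by the mask count rather than the sequence length is what keeps fully padded positions from diluting the average.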
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 921c1563-4097-472a-996b-cdf258d8f308
📒 Files selected for processing (1)
modelopt/torch/distill/plugins/huggingface.py
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main    #1009   +/-   ##
=======================================
  Coverage   71.73%   71.73%
=======================================
  Files         211      211
  Lines       23948    23949       +1
=======================================
+ Hits        17180    17181       +1
  Misses       6768     6768
=======================================
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/huggingface.py (1)
126-134: Breaking change: default `reduction` changed to `"none"`.

The constructor now defaults to `reduction="none"` instead of inheriting the parent's default. This is intentional for the new masking feature but may affect existing users who instantiate `LMLogitsLoss()` without explicit arguments. Consider noting this in the changelog or migration notes, as mentioned in the PR checklist.

🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/distill/plugins/huggingface.py` around lines 126 - 134, The default for reduction in LMLogitsLoss.__init__ was changed to "none" (was previously inherited) which is a breaking change; update the project changelog and migration notes to explicitly document this behavioral change (include the class name LMLogitsLoss and the parameter reduction default change to "none"), and add a short note in the constructor docstring of huggingface.py indicating the new default and recommended migration steps for users who relied on the old behavior.
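The pattern flagged here reduces to a minimal stub (the class names below are illustrative stand-ins, not the real ones): a subclass that overrides an inherited default silently changes what no-argument construction does.

```python
class ParentLoss:
    """Stand-in for the base distillation loss with its original default."""
    def __init__(self, reduction: str = "mean"):
        self.reduction = reduction

class ChildLoss(ParentLoss):
    """Stand-in for LMLogitsLoss after the change: default no longer inherited."""
    def __init__(self, reduction: str = "none"):
        super().__init__(reduction=reduction)

# Callers relying on the old inherited default must now pass it explicitly:
old_behavior = ChildLoss(reduction="mean")
new_default = ChildLoss()
```

This is why the reviewer asks for a changelog note: code written against the parent's default keeps compiling but computes a different reduction.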
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 99809bd0-fa7c-4c07-9bd0-2fc5bea07050
📒 Files selected for processing (1)
modelopt/torch/distill/plugins/huggingface.py
Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/distill/losses.py`:
- Line 34: The default reduction for KL in LogitsDistillationLoss should remain
"batchmean" to match PyTorch's mathematical KL divergence and preserve backward
compatibility; change the default value in LogitsDistillationLoss.__init__ from
"mean" to "batchmean" (and update any related docstring or tests that assume the
previous default) so that torch.nn.functional.kl_div uses reduction="batchmean"
unless explicitly overridden.
In `@modelopt/torch/distill/plugins/huggingface.py`:
- Around line 144-148: The per-token loss returned by
LogitsDistillationLoss.forward() is flattened by the parent class and becomes a
1D tensor; after computing loss = super().forward(student_logits,
teacher_logits) (and handling self._reduction == "none"), reshape the per-token
loss back to the original token shape student_logits.shape[:-1] (i.e., (batch,
seq)) before returning so subsequent masking (e.g., multiplying by labels !=
IGNORE_TOKEN_ID) matches dimensions; update the return path in huggingface.py to
reshape loss to student_logits.shape[:-1] when reduction == "none".
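The reshape the second comment asks for can be shown without torch (nested lists stand in for tensors; in the real code this would presumably be something like `loss.view(*student_logits.shape[:-1])`, which is my assumption, not quoted from the PR):

```python
def reshape_per_token(flat_loss, batch, seq):
    """Rebuild a (batch, seq) layout from a flattened per-token loss.

    After the parent class flattens the per-token KD loss to length
    batch * seq, masking against (batch, seq)-shaped labels requires
    restoring the token layout first.
    """
    assert len(flat_loss) == batch * seq, "flattened length must match batch * seq"
    return [flat_loss[b * seq:(b + 1) * seq] for b in range(batch)]
```

Without this step, elementwise multiplication against the `(batch, seq)` label mask would fail or silently broadcast incorrectly.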
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: bbc4d4a7-08bd-41e3-ad0d-ac5a1010c546
📒 Files selected for processing (2)
modelopt/torch/distill/losses.py
modelopt/torch/distill/plugins/huggingface.py
♻️ Duplicate comments (1)
modelopt/torch/distill/losses.py (1)
34-34: ⚠️ Potential issue | 🟠 Major

Keep `"batchmean"` as the default to preserve backward compatibility and mathematical correctness.

The change from `"batchmean"` to `"mean"` scales every KD loss down by `1 / num_classes` relative to previous behavior and deviates from the mathematical definition of KL divergence. This affects all existing callers of `LogitsDistillationLoss`.

Suggested fix:

```diff
-    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/distill/losses.py` at line 34, The constructor for LogitsDistillationLoss changed the default reduction from "batchmean" to "mean", which alters KD loss scaling; revert the default in LogitsDistillationLoss.__init__ to reduction: str = "batchmean" (and ensure any docstring or tests referencing default behavior are updated if present) so the KL divergence uses batchmean as before and preserves backward compatibility and mathematical correctness.
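The `1 / num_classes` scaling claim is easy to verify with a dependency-free sketch of the reduction arithmetic. The function below mimics only how `kl_div`'s reductions divide the summed pointwise terms, not the full PyTorch API:

```python
import math

def kl_reduced(P, Q, reduction):
    """Sum pointwise KL terms p * log(p / q), then apply the named reduction.

    "batchmean" divides by the batch size (the mathematically correct mean
    KL per sample); "mean" divides by the total element count, which is an
    extra factor of num_classes smaller.
    """
    terms = [p * math.log(p / q) for ps, qs in zip(P, Q) for p, q in zip(ps, qs)]
    total = sum(terms)
    if reduction == "batchmean":
        return total / len(P)       # divide by batch size
    if reduction == "mean":
        return total / len(terms)   # divide by batch * num_classes
    raise ValueError(f"unknown reduction: {reduction}")
```

For a batch of distributions over `k` classes, `batchmean` is exactly `k` times `mean`, which is precisely the scale change this review comment warns about.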
🧹 Nitpick comments (1)
modelopt/torch/distill/losses.py (1)
62-64: Document the semantic change to `reduction="none"` behavior.

This implementation is mathematically correct; KL divergence requires summing over the class dimension to produce per-token losses. However, this changes the output shape of `reduction="none"` from `(batch, seq, vocab)` to `(batch, seq)`, which is a breaking change for any existing callers that expected element-wise losses. Consider adding a note in the docstring (lines 39-40) clarifying that `reduction="none"` returns per-token losses (summed over vocab), not element-wise losses.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/distill/losses.py` around lines 62 - 64, Update the docstring for the loss implementation that performs KL-divergence (the method containing the self._reduction check) to explicitly document that reduction="none" returns per-token losses summed over the vocab dimension (shape becomes (batch, seq)), not element-wise per-class losses; mention that the implementation sums kd_loss over dim=-1 to produce this per-token result so callers expecting (batch, seq, vocab) must sum or change their code accordingly.
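The shape change the docstring should describe can be shown with nested lists (again a torch-free sketch; the real implementation sums `kd_loss` over `dim=-1`):

```python
def sum_over_vocab(elementwise_loss):
    """Collapse (batch, seq, vocab)-shaped element-wise losses to (batch, seq).

    Each token's KL contribution is the sum of its per-class terms, so
    reduction="none" yields one scalar per token rather than one per class.
    """
    return [[sum(vocab_terms) for vocab_terms in seq] for seq in elementwise_loss]
```

Callers that previously indexed into the vocab dimension of the unreduced loss would need to drop that indexing after this change.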
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ee265919-44e3-44f1-bdc3-28d263f81b4d
📒 Files selected for processing (2)
modelopt/torch/distill/losses.py
modelopt/torch/distill/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/distill/plugins/huggingface.py
What does this PR do?
Type of change: ?
Previously, the HF trainer did not account for loss masking.
Usage

# Add a code snippet demonstrating how to use this

Testing
Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit