From 8004d0cfd720b24fbc4e028ea8374e71094490bd Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 9 Mar 2026 12:37:18 -0700 Subject: [PATCH 1/5] Allow HF trainer to mask padding etc Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/huggingface.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index ff5cdc426..b84beed0a 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -15,12 +15,15 @@ """ModelOpt plugin to train HuggingFace models with knowledge distillation.""" +from torch import Tensor from transformers.modeling_outputs import CausalLMOutputWithPast import modelopt.torch.distill as mtd from modelopt.torch.opt.plugins import ModelOptHFTrainer from modelopt.torch.utils import print_rank_0 +IGNORE_INDEX = -100 + class KDTrainer(ModelOptHFTrainer): """Distillation trainer for HuggingFace models.""" @@ -98,12 +101,38 @@ def save_model( def train(self, *args, **kwargs): """Train the model.""" - self.compute_loss_func = lambda *args, **kwargs: self.model.compute_kd_loss() + + def _compute_kd_loss(outputs: Tensor, labels: Tensor | None, **kwargs): + def loss_reduction_fn(loss: Tensor): + if labels is None: + return loss.sum() / len(loss) # batchmean reduction + loss_mask = (labels.view(-1) != IGNORE_INDEX).to(loss.dtype) + per_token_loss = loss.sum(dim=-1) if loss.ndim >= 2 else loss + return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + return self.model.compute_kd_loss(loss_reduction_fn=loss_reduction_fn) + + self.compute_loss_func = _compute_kd_loss return super().train(*args, **kwargs) class LMLogitsLoss(mtd.LogitsDistillationLoss): - """Logits loss for knowledge distillation.""" + """Logits loss for language-model knowledge distillation. + + Defaults to ``reduction="none"`` to support per-token loss masking via ``loss_reduction_fn`` + in :meth:`DistillationModel.compute_kd_loss`. This allows masking out padding and non-assistant + tokens (marked with ``IGNORE_INDEX = -100``) before reducing the loss. + """ + + def __init__(self, temperature: float = 1.0, reduction: str = "none"): + """Constructor. + + Args: + temperature: A value used to soften the logits before computing loss. + reduction: How to reduce the final pointwise loss. Defaults to ``"none"`` to + allow loss-masking via ``loss_reduction_fn`` in ``compute_kd_loss``. + """ + super().__init__(temperature=temperature, reduction=reduction) def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutputWithPast): """Forward pass for logits distillation loss. @@ -112,4 +141,5 @@ def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutp out_student: The student model output. out_teacher: The teacher model output. """ - return super().forward(out_student.logits, out_teacher.logits) + student_logits, teacher_logits = out_student.logits.float(), out_teacher.logits.float() + return super().forward(student_logits, teacher_logits) From 3f77bf4b25166c818b6b0326d5920863e265f821 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 9 Mar 2026 13:51:39 -0700 Subject: [PATCH 2/5] Suggestion Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/huggingface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index b84beed0a..a7cd50b56 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -15,6 +15,7 @@ """ModelOpt plugin to train HuggingFace models with knowledge distillation.""" +import torch.nn as nn from torch import Tensor from transformers.modeling_outputs import CausalLMOutputWithPast @@ -22,8 +23,6 @@ from modelopt.torch.opt.plugins import ModelOptHFTrainer from modelopt.torch.utils import print_rank_0 -IGNORE_INDEX = -100 - class KDTrainer(ModelOptHFTrainer): """Distillation trainer for HuggingFace models.""" @@ -101,12 +100,13 @@ def save_model( def train(self, *args, **kwargs): """Train the model.""" + ignore_index = nn.CrossEntropyLoss().ignore_index # equals -100 def _compute_kd_loss(outputs: Tensor, labels: Tensor | None, **kwargs): def loss_reduction_fn(loss: Tensor): if labels is None: return loss.sum() / len(loss) # batchmean reduction - loss_mask = (labels.view(-1) != IGNORE_INDEX).to(loss.dtype) + loss_mask = (labels.view(-1) != ignore_index).to(loss.dtype) per_token_loss = loss.sum(dim=-1) if loss.ndim >= 2 else loss return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) From 04b5654690e137d8cecb3165b05a1478ee6636d2 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 9 Mar 2026 15:54:47 -0700 Subject: [PATCH 3/5] Small fix Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/huggingface.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index a7cd50b56..0133ad682 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -106,9 +106,8 @@ def _compute_kd_loss(outputs: Tensor, labels: Tensor | None, **kwargs): def loss_reduction_fn(loss: Tensor): if labels is None: return loss.sum() / len(loss) # batchmean reduction - loss_mask = (labels.view(-1) != ignore_index).to(loss.dtype) - per_token_loss = loss.sum(dim=-1) if loss.ndim >= 2 else loss - return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) + loss_mask = (labels != ignore_index).to(loss.dtype) + return (loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) return self.model.compute_kd_loss(loss_reduction_fn=loss_reduction_fn) @@ -142,4 +141,7 @@ def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutp out_teacher: The teacher model output. """ student_logits, teacher_logits = out_student.logits.float(), out_teacher.logits.float() - return super().forward(student_logits, teacher_logits) + loss = super().forward(student_logits, teacher_logits) + if self._reduction == "none": + loss = loss.sum(dim=-1) + return loss From 732dc47267df9f9c7d128f1099ad77fd3a842e41 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 10 Mar 2026 07:05:52 -0700 Subject: [PATCH 4/5] Change batchmean to mean and refine Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/losses.py | 2 +- modelopt/torch/distill/plugins/huggingface.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/distill/losses.py b/modelopt/torch/distill/losses.py index 258824bf0..993228199 100644 --- a/modelopt/torch/distill/losses.py +++ b/modelopt/torch/distill/losses.py @@ -31,7 +31,7 @@ class LogitsDistillationLoss(Loss): This function implements the distillation loss found in the paper: https://arxiv.org/abs/1503.02531. """ - def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"): + def __init__(self, temperature: float = 1.0, reduction: str = "mean"): """Constructor. Args: diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index 0133ad682..33ccd7d4a 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -15,14 +15,16 @@ """ModelOpt plugin to train HuggingFace models with knowledge distillation.""" -import torch.nn as nn from torch import Tensor from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.trainer_pt_utils import LabelSmoother import modelopt.torch.distill as mtd from modelopt.torch.opt.plugins import ModelOptHFTrainer from modelopt.torch.utils import print_rank_0 +IGNORE_TOKEN_ID = LabelSmoother.ignore_index # equals -100 + class KDTrainer(ModelOptHFTrainer): """Distillation trainer for HuggingFace models.""" @@ -100,13 +102,12 @@ def save_model( def train(self, *args, **kwargs): """Train the model.""" - ignore_index = nn.CrossEntropyLoss().ignore_index # equals -100 def _compute_kd_loss(outputs: Tensor, labels: Tensor | None, **kwargs): def loss_reduction_fn(loss: Tensor): if labels is None: - return loss.sum() / len(loss) # batchmean reduction - loss_mask = (labels != ignore_index).to(loss.dtype) + return loss.mean() + loss_mask = labels != IGNORE_TOKEN_ID return (loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) return self.model.compute_kd_loss(loss_reduction_fn=loss_reduction_fn) @@ -120,7 +121,7 @@ class LMLogitsLoss(mtd.LogitsDistillationLoss): Defaults to ``reduction="none"`` to support per-token loss masking via ``loss_reduction_fn`` in :meth:`DistillationModel.compute_kd_loss`. This allows masking out padding and non-assistant - tokens (marked with ``IGNORE_INDEX = -100``) before reducing the loss. + tokens before reducing the loss. """ def __init__(self, temperature: float = 1.0, reduction: str = "none"): From 820bad71a11d812d26b2801fcc5b5dd922ec9cfa Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 10 Mar 2026 07:27:41 -0700 Subject: [PATCH 5/5] Fix bug Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/losses.py | 7 ++++--- modelopt/torch/distill/plugins/huggingface.py | 5 +---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/distill/losses.py b/modelopt/torch/distill/losses.py index 993228199..832437b7e 100644 --- a/modelopt/torch/distill/losses.py +++ b/modelopt/torch/distill/losses.py @@ -57,11 +57,12 @@ def forward(self, logits_s: torch.Tensor, logits_t: torch.Tensor) -> torch.Tenso soft_log_probs = F.log_softmax(logits_s / self._temperature, dim=-1) soft_targets = F.softmax(logits_t / self._temperature, dim=-1) - soft_log_probs = soft_log_probs.view(-1, soft_log_probs.size(-1)) - soft_targets = soft_targets.view(-1, soft_targets.size(-1)) - kd_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction=self._reduction) + if self._reduction == "none": + # Remove vocab dimension + kd_loss = kd_loss.sum(dim=-1) + # Since the magnitudes of the gradients produced by the soft logits scale as 1/(T^2), # multiplying them by T^2 ensures that the relative contributions of the logits # remain roughly unchanged while experimenting with meta-parameters. diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index 33ccd7d4a..c865d8857 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -142,7 +142,4 @@ def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutp out_teacher: The teacher model output. """ student_logits, teacher_logits = out_student.logits.float(), out_teacher.logits.float() - loss = super().forward(student_logits, teacher_logits) - if self._reduction == "none": - loss = loss.sum(dim=-1) - return loss + return super().forward(student_logits, teacher_logits)