diff --git a/modelopt/torch/distill/losses.py b/modelopt/torch/distill/losses.py
index 258824bf0..832437b7e 100644
--- a/modelopt/torch/distill/losses.py
+++ b/modelopt/torch/distill/losses.py
@@ -31,7 +31,7 @@ class LogitsDistillationLoss(Loss):
     This function implements the distillation loss found in the paper: https://arxiv.org/abs/1503.02531.
     """
 
-    def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"):
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
         """Constructor.
 
         Args:
@@ -57,9 +57,11 @@ def forward(self, logits_s: torch.Tensor, logits_t: torch.Tensor) -> torch.Tensor:
         soft_log_probs = F.log_softmax(logits_s / self._temperature, dim=-1)
         soft_targets = F.softmax(logits_t / self._temperature, dim=-1)
 
-        soft_log_probs = soft_log_probs.view(-1, soft_log_probs.size(-1))
-        soft_targets = soft_targets.view(-1, soft_targets.size(-1))
         kd_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction=self._reduction)
+        if self._reduction == "none":
+            # Remove vocab dimension
+            kd_loss = kd_loss.sum(dim=-1)
+
         # Since the magnitudes of the gradients produced by the soft logits scale as 1/(T^2),
         # multiplying them by T^2 ensures that the relative contributions of the logits
         # remain roughly unchanged while experimenting with meta-parameters.
diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index ff5cdc426..c865d8857 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -15,12 +15,16 @@ """ModelOpt plugin to train HuggingFace models with knowledge distillation.""" +from torch import Tensor from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.trainer_pt_utils import LabelSmoother import modelopt.torch.distill as mtd from modelopt.torch.opt.plugins import ModelOptHFTrainer from modelopt.torch.utils import print_rank_0 +IGNORE_TOKEN_ID = LabelSmoother.ignore_index # equals -100 + class KDTrainer(ModelOptHFTrainer): """Distillation trainer for HuggingFace models.""" @@ -98,12 +102,37 @@ def save_model( def train(self, *args, **kwargs): """Train the model.""" - self.compute_loss_func = lambda *args, **kwargs: self.model.compute_kd_loss() + + def _compute_kd_loss(outputs: Tensor, labels: Tensor | None, **kwargs): + def loss_reduction_fn(loss: Tensor): + if labels is None: + return loss.mean() + loss_mask = labels != IGNORE_TOKEN_ID + return (loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + return self.model.compute_kd_loss(loss_reduction_fn=loss_reduction_fn) + + self.compute_loss_func = _compute_kd_loss return super().train(*args, **kwargs) class LMLogitsLoss(mtd.LogitsDistillationLoss): - """Logits loss for knowledge distillation.""" + """Logits loss for language-model knowledge distillation. + + Defaults to ``reduction="none"`` to support per-token loss masking via ``loss_reduction_fn`` + in :meth:`DistillationModel.compute_kd_loss`. This allows masking out padding and non-assistant + tokens before reducing the loss. + """ + + def __init__(self, temperature: float = 1.0, reduction: str = "none"): + """Constructor. + + Args: + temperature: A value used to soften the logits before computing loss. + reduction: How to reduce the final pointwise loss. 
Defaults to ``"none"`` to + allow loss-masking via ``loss_reduction_fn`` in ``compute_kd_loss``. + """ + super().__init__(temperature=temperature, reduction=reduction) def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutputWithPast): """Forward pass for logits distillation loss. @@ -112,4 +141,5 @@ def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutp out_student: The student model output. out_teacher: The teacher model output. """ - return super().forward(out_student.logits, out_teacher.logits) + student_logits, teacher_logits = out_student.logits.float(), out_teacher.logits.float() + return super().forward(student_logits, teacher_logits)