Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions modelopt/torch/distill/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class LogitsDistillationLoss(Loss):
This function implements the distillation loss found in the paper: https://arxiv.org/abs/1503.02531.
"""

def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"):
def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
"""Constructor.

Args:
Expand All @@ -57,11 +57,12 @@ def forward(self, logits_s: torch.Tensor, logits_t: torch.Tensor) -> torch.Tenso
soft_log_probs = F.log_softmax(logits_s / self._temperature, dim=-1)
soft_targets = F.softmax(logits_t / self._temperature, dim=-1)

soft_log_probs = soft_log_probs.view(-1, soft_log_probs.size(-1))
soft_targets = soft_targets.view(-1, soft_targets.size(-1))

kd_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction=self._reduction)

if self._reduction == "none":
# Remove vocab dimension
kd_loss = kd_loss.sum(dim=-1)

# Since the magnitudes of the gradients produced by the soft logits scale as 1/(T^2),
# multiplying them by T^2 ensures that the relative contributions of the logits
# remain roughly unchanged while experimenting with meta-parameters.
Expand Down
36 changes: 33 additions & 3 deletions modelopt/torch/distill/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@

"""ModelOpt plugin to train HuggingFace models with knowledge distillation."""

from torch import Tensor
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.trainer_pt_utils import LabelSmoother

import modelopt.torch.distill as mtd
from modelopt.torch.opt.plugins import ModelOptHFTrainer
from modelopt.torch.utils import print_rank_0

IGNORE_TOKEN_ID = LabelSmoother.ignore_index # equals -100


class KDTrainer(ModelOptHFTrainer):
"""Distillation trainer for HuggingFace models."""
Expand Down Expand Up @@ -98,12 +102,37 @@ def save_model(

def train(self, *args, **kwargs):
    """Train the model using the knowledge-distillation loss.

    Installs ``compute_loss_func`` so the HF ``Trainer`` computes the loss via
    the distillation wrapper's ``compute_kd_loss`` instead of the default LM
    loss, then delegates to the parent class ``train``.

    Args:
        *args: Positional arguments forwarded to ``super().train``.
        **kwargs: Keyword arguments forwarded to ``super().train``.

    Returns:
        Whatever ``super().train`` returns.
    """
    # NOTE: the previous one-line lambda assignment here was dead code — it was
    # unconditionally overwritten below, so it has been removed.

    def _compute_kd_loss(outputs: Tensor, labels: Tensor | None, **kwargs):
        # Reduction callback applied to the per-token KD loss produced by the
        # distillation model (LMLogitsLoss defaults to reduction="none").
        def loss_reduction_fn(loss: Tensor):
            if labels is None:
                # No labels available: fall back to a plain mean over all tokens.
                return loss.mean()
            # Mask out positions marked with IGNORE_TOKEN_ID (padding /
            # non-assistant tokens) before averaging.
            loss_mask = labels != IGNORE_TOKEN_ID
            # clamp(min=1) guards against division by zero when every token
            # in the batch is masked.
            return (loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)

        return self.model.compute_kd_loss(loss_reduction_fn=loss_reduction_fn)

    self.compute_loss_func = _compute_kd_loss
    return super().train(*args, **kwargs)


class LMLogitsLoss(mtd.LogitsDistillationLoss):
"""Logits loss for knowledge distillation."""
"""Logits loss for language-model knowledge distillation.

Defaults to ``reduction="none"`` to support per-token loss masking via ``loss_reduction_fn``
in :meth:`DistillationModel.compute_kd_loss`. This allows masking out padding and non-assistant
tokens before reducing the loss.
"""

def __init__(self, temperature: float = 1.0, reduction: str = "none"):
    """Initialize the language-model logits distillation loss.

    Args:
        temperature: Softening factor applied to the logits before the loss
            is computed.
        reduction: How the pointwise loss is reduced. Defaults to ``"none"``
            so per-token values survive and can be masked later through
            ``loss_reduction_fn`` in ``compute_kd_loss``.
    """
    # Delegate to the generic logits-distillation loss; only the default
    # reduction differs from the parent class.
    super().__init__(temperature, reduction)

def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutputWithPast):
    """Forward pass for logits distillation loss.

    Args:
        out_student: The student model output.
        out_teacher: The teacher model output.

    Returns:
        The distillation loss computed by the parent class from the two
        logits tensors (per-token when ``reduction="none"``).
    """
    # Cast logits to float32 before softmax/KL so the loss is numerically
    # stable under mixed-precision (fp16/bf16) training.
    student_logits = out_student.logits.float()
    teacher_logits = out_teacher.logits.float()
    return super().forward(student_logits, teacher_logits)
Loading