diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index 8c9ea9399..a12516b5d 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -227,6 +227,7 @@ def distributed_log_softmax(
     return logits_norm - sum_exp_logits.log()  # log_softmax
 
 
+@torch.compile
 def _reverse_kl_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
@@ -259,24 +260,21 @@ def _reverse_kl_forward_backward(
     if loss_mask is not None:
         Assert.eq(loss_mask.shape, logits.shape[:-1])
 
-    # Compute log probabilities
     teacher_log_probs = distributed_log_softmax(target.float(), group=group)
-    student_log_probs = distributed_log_softmax(logits, group=group)
-
-    # Reverse KL: input=teacher_log_probs, target=student_probs
-    loss_terms = torch.nn.functional.kl_div(
-        teacher_log_probs,  # input = log(p)
-        student_log_probs,  # target = log(q)
-        reduction="none",
-        log_target=True,
-    ).sum(dim=-1)
+    log_ratio = distributed_log_softmax(logits, group=group)
+
+    student_probs = log_ratio.exp()
+    log_ratio = log_ratio - teacher_log_probs  # log_ratio now holds student_log_probs - teacher_log_probs
+    del teacher_log_probs
+    # Compute loss terms: student_probs * log_ratio, then sum over vocab
+    # This is equivalent to kl_div(..., log_target=True) but more memory efficient
+    loss_terms = (student_probs * log_ratio).sum(dim=-1)
+
     if loss_mask is not None:
         # loss mask is the same on all ranks for TP over vocab.
         valid = loss_mask.to(loss_terms.dtype)
         loss_terms = loss_terms * valid
-        valid_tokens = valid.sum()
-    else:
-        valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
+    valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
     loss = loss_terms.sum()  # sums over batch and seq. len.
 
     if group is not None:
@@ -284,20 +282,20 @@ def _reverse_kl_forward_backward(
     loss /= valid_tokens
 
     if grad_output is not None:
-        # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
-        log_ratio = student_log_probs - teacher_log_probs
-        expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
-        # expected E_q(log s - log t) -- this is actually dependent on the full vocab!
+        # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
+        # where E_q[log(q/p)] is the expected log ratio under the student distribution
+        expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
         if group is not None:
             all_reduce(expected, op=ReduceOp.SUM, group=group)
-        grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
+        log_ratio = log_ratio - expected
+        log_ratio = log_ratio * student_probs
+        del student_probs  # Free after use
 
         if loss_mask is not None:
-            valid = loss_mask.to(logits.dtype).unsqueeze(-1)
-            grad_base = grad_base * valid
+            log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1)
 
-        grad = grad_base.mul(grad_output / valid_tokens)
-        grad = grad.to(logits.dtype)
+        log_ratio = log_ratio * (grad_output / valid_tokens)
+        grad = log_ratio.to(logits.dtype)
     else:
         grad = None
 
diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py
index 72644d061..20d16bb96 100644
--- a/tests/functional/test_cross_entropy.py
+++ b/tests/functional/test_cross_entropy.py
@@ -104,7 +104,9 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso
         reduction="none",
         log_target=True,
     ).sum(dim=-1)
-    output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum()
+    if loss_mask is not None:
+        per_sample = per_sample * loss_mask
+    output = per_sample.mean()
     output.backward()
     return output, logits.grad
 
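Not part of the patch above: a minimal single-process sketch checking the gradient identity the new code relies on, d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]), against autograd through torch.nn.functional.kl_div, in the same spirit as the reference in _reverse_kl_forward_backward_torch. It assumes no tensor parallelism and no loss mask; the helper name manual_reverse_kl_grad and the tensor shapes are illustrative only.

# Sanity check (illustrative, not from fast_llm): manual reverse-KL gradient vs. autograd.
import torch

def manual_reverse_kl_grad(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor:
    # q = softmax(logits) is the student, p = softmax(target_logits) is the teacher.
    student_log_probs = torch.log_softmax(logits, dim=-1)
    teacher_log_probs = torch.log_softmax(target_logits, dim=-1)
    student_probs = student_log_probs.exp()
    log_ratio = student_log_probs - teacher_log_probs
    expected = (student_probs * log_ratio).sum(dim=-1, keepdim=True)  # E_q[log(q/p)] per token
    # Per-token gradient of sum_i q_i * log(q_i/p_i) w.r.t. the logits.
    return student_probs * (log_ratio - expected)

torch.manual_seed(0)
logits = torch.randn(4, 8, requires_grad=True)
target_logits = torch.randn(4, 8)

# Reference: autograd through kl_div with log_target=True, averaged over tokens.
loss = torch.nn.functional.kl_div(
    torch.log_softmax(target_logits, dim=-1),  # input = log(p)
    torch.log_softmax(logits, dim=-1),         # target = log(q)
    reduction="none",
    log_target=True,
).sum(dim=-1).mean()
loss.backward()

# Divide by the token count because the reference takes the mean over tokens.
manual = manual_reverse_kl_grad(logits.detach(), target_logits) / logits.shape[0]
torch.testing.assert_close(manual, logits.grad)

The division by the token count mirrors the loss /= valid_tokens normalization in the patched kernel; with a loss mask or tensor parallelism the same identity holds after masking and an all-reduce of the expectation term, as in the diff.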