42 changes: 20 additions & 22 deletions fast_llm/functional/cross_entropy.py
@@ -227,6 +227,7 @@ def distributed_log_softmax(
return logits_norm - sum_exp_logits.log() # log_softmax


@torch.compile
def _reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
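
The only change in this first hunk is decorating the fused loss with `@torch.compile`. As a minimal sketch of the pattern (a hypothetical toy function, not the repository's kernel), compiling a softmax-heavy loss lets PyTorch fuse the elementwise chain where the backend supports it:

```python
import torch

@torch.compile  # JIT-compiles and fuses the log_softmax/exp/sum chain where supported
def toy_reverse_kl(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    student_log_probs = torch.log_softmax(logits, dim=-1)
    teacher_log_probs = torch.log_softmax(target, dim=-1)
    return (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1).mean()

loss = toy_reverse_kl(torch.randn(4, 8), torch.randn(4, 8))
```
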
@@ -259,45 +260,42 @@ def _reverse_kl_forward_backward(
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# Compute log probabilities
teacher_log_probs = distributed_log_softmax(target.float(), group=group)
student_log_probs = distributed_log_softmax(logits, group=group)

# Reverse KL: input=teacher_log_probs, target=student_probs
loss_terms = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="none",
log_target=True,
).sum(dim=-1)
log_ratio = distributed_log_softmax(logits, group=group)

student_probs = log_ratio.exp()
log_ratio = log_ratio - teacher_log_probs # log_ratio now holds student_log_probs - teacher_log_probs
del teacher_log_probs
# Compute loss terms: student_probs * log_ratio, then sum over vocab
# This is equivalent to kl_div(..., log_target=True) but more memory efficient
loss_terms = (student_probs * log_ratio).sum(dim=-1)

if loss_mask is not None:
# loss mask is the same on all ranks for TP over vocab.
valid = loss_mask.to(loss_terms.dtype)
loss_terms = loss_terms * valid
valid_tokens = valid.sum()
else:
valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
loss = loss_terms.sum() # sums over batch and seq. len.

if group is not None:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= valid_tokens

if grad_output is not None:
# Need to compute the gradient manually: backpropagating through all_reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
log_ratio = student_log_probs - teacher_log_probs
expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
# expected E_q(log s - log t) -- this is actually dependent on the full vocab!
# Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
# where E_q[log(q/p)] is the expected log ratio under the student distribution
expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
if group is not None:
all_reduce(expected, op=ReduceOp.SUM, group=group)
grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
log_ratio = log_ratio - expected
log_ratio = log_ratio * student_probs
del student_probs # Free after use

if loss_mask is not None:
valid = loss_mask.to(logits.dtype).unsqueeze(-1)
grad_base = grad_base * valid
log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1)

grad = grad_base.mul(grad_output / valid_tokens)
grad = grad.to(logits.dtype)
log_ratio = log_ratio * (grad_output / valid_tokens)
grad = log_ratio.to(logits.dtype)
else:
grad = None
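
For reference, the rewritten loss computes the same per-token value as the removed `kl_div` call, just without materializing the extra intermediate. A single-device sketch (no tensor-parallel group; `reverse_kl_loss_terms` is a name used here for illustration only):

```python
import torch

def reverse_kl_loss_terms(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Reverse KL(q || p): q = softmax(logits) is the student, p = softmax(target) the teacher.
    teacher_log_probs = torch.log_softmax(target.float(), dim=-1)
    log_ratio = torch.log_softmax(logits.float(), dim=-1)
    student_probs = log_ratio.exp()
    log_ratio = log_ratio - teacher_log_probs  # student_log_probs - teacher_log_probs
    return (student_probs * log_ratio).sum(dim=-1)  # sum_v q * (log q - log p)

logits, target = torch.randn(4, 8), torch.randn(4, 8)
reference = torch.nn.functional.kl_div(
    torch.log_softmax(target.float(), dim=-1),  # input = log(p)
    torch.log_softmax(logits.float(), dim=-1),  # target = log(q)
    reduction="none",
    log_target=True,
).sum(dim=-1)
torch.testing.assert_close(reverse_kl_loss_terms(logits, target), reference)
```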

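The backward path avoids backpropagating through `all_reduce` by using the closed-form gradient d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]). A single-device sketch (toy shapes, no parallelism) checking that formula against autograd:

```python
import torch

logits = torch.randn(4, 8, requires_grad=True)
target = torch.randn(4, 8)

teacher_log_probs = torch.log_softmax(target, dim=-1)
student_log_probs = torch.log_softmax(logits, dim=-1)
student_probs = student_log_probs.exp()
log_ratio = student_log_probs - teacher_log_probs

# Autograd gradient of the summed reverse KL with respect to the logits.
(student_probs * log_ratio).sum().backward()

# Closed-form gradient: q * (log(q/p) - E_q[log(q/p)]), expectation taken per token.
expected = (student_probs * log_ratio).sum(dim=-1, keepdim=True)
manual_grad = student_probs * (log_ratio - expected)

torch.testing.assert_close(logits.grad, manual_grad.detach())
```
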
4 changes: 3 additions & 1 deletion tests/functional/test_cross_entropy.py
@@ -104,7 +104,9 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso
reduction="none",
log_target=True,
).sum(dim=-1)
output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum()
if loss_mask is not None:
per_sample = per_sample * loss_mask
output = per_sample.mean()
output.backward()
return output, logits.grad

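Note the change in the reference helper's masking convention: masked tokens are zeroed but the mean still runs over every token, rather than only over the unmasked ones. A toy numeric comparison of the two conventions (illustrative values only):

```python
import torch

per_sample = torch.tensor([0.5, 1.0, 2.0, 4.0])  # per-token reverse KL (toy values)
loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])

# Updated convention: zero out masked tokens, average over all tokens.
mean_over_all = (per_sample * loss_mask).mean()                     # 5.5 / 4 = 1.375
# Previous convention: average over unmasked tokens only.
mean_over_valid = (per_sample * loss_mask).sum() / loss_mask.sum()  # 5.5 / 3 ≈ 1.833
```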