From 4a6be9893898d8264a0766c92517f050ad480aa2 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 19 Dec 2025 15:02:24 +0000
Subject: [PATCH 1/3] manual kl + memory savings

---
 fast_llm/functional/cross_entropy.py | 37 ++++++++++++++--------------
 1 file changed, 18 insertions(+), 19 deletions(-)

diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index 8c9ea939..839b1e41 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -259,17 +259,16 @@ def _reverse_kl_forward_backward(
     if loss_mask is not None:
         Assert.eq(loss_mask.shape, logits.shape[:-1])
 
-    # Compute log probabilities
     teacher_log_probs = distributed_log_softmax(target.float(), group=group)
-    student_log_probs = distributed_log_softmax(logits, group=group)
-
-    # Reverse KL: input=teacher_log_probs, target=student_probs
-    loss_terms = torch.nn.functional.kl_div(
-        teacher_log_probs,  # input = log(p)
-        student_log_probs,  # target = log(q)
-        reduction="none",
-        log_target=True,
-    ).sum(dim=-1)
+    log_ratio = distributed_log_softmax(logits, group=group)
+
+    student_probs = log_ratio.exp()
+    log_ratio.sub_(teacher_log_probs)  # In-place: log_ratio = student_log_probs - teacher_log_probs
+    del teacher_log_probs
+    # Compute loss terms: student_probs * log_ratio, then sum over vocab
+    # This is equivalent to kl_div(..., log_target=True) but more memory efficient
+    loss_terms = (student_probs * log_ratio).sum(dim=-1)
+
     if loss_mask is not None:
         # loss mask is the same on all ranks for TP over vocab.
         valid = loss_mask.to(loss_terms.dtype)
@@ -284,20 +283,20 @@ def _reverse_kl_forward_backward(
         loss /= valid_tokens
 
     if grad_output is not None:
-        # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
-        log_ratio = student_log_probs - teacher_log_probs
-        expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
-        # expected E_q(log s - log t) -- this is actually dependent on the full vocab!
+        # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
+        # where E_q[log(q/p)] is the expected log ratio under the student distribution
+        expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
         if group is not None:
             all_reduce(expected, op=ReduceOp.SUM, group=group)
-        grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
+        log_ratio.sub_(expected)  # In-place: log_ratio -= expected
+        log_ratio.mul_(student_probs)  # In-place: now log_ratio is grad_base
+        del student_probs  # Free after use
 
         if loss_mask is not None:
-            valid = loss_mask.to(logits.dtype).unsqueeze(-1)
-            grad_base = grad_base * valid
+            log_ratio.mul_(loss_mask.to(logits.dtype).unsqueeze(-1))  # In-place
 
-        grad = grad_base.mul(grad_output / valid_tokens)
-        grad = grad.to(logits.dtype)
+        log_ratio.mul_(grad_output / valid_tokens)  # In-place
+        grad = log_ratio.to(logits.dtype)
     else:
         grad = None
 
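PATCH 1/3 rewrites the reverse KL as sum_v q_v * (log q_v - log p_v), with q the student and p the teacher distribution, and forms the gradient analytically as q * (log_ratio - E_q[log_ratio]). The following is a minimal single-process sketch of that math (made-up shapes, plain torch.log_softmax standing in for distributed_log_softmax, no tensor-parallel group), checked against torch.nn.functional.kl_div plus autograd; it is an illustration, not an exact reproduction of the patched function.

    # Single-process check of the manual reverse-KL loss and analytic gradient.
    # Shapes and values are illustrative assumptions.
    import torch

    torch.manual_seed(0)
    batch, seq, vocab = 2, 3, 11
    logits = torch.randn(batch, seq, vocab, requires_grad=True)  # student logits
    target = torch.randn(batch, seq, vocab)                      # teacher logits

    # Reference: reverse KL(q||p) via kl_div with log_target, averaged over positions.
    teacher_log_probs = torch.log_softmax(target, dim=-1)
    student_log_probs = torch.log_softmax(logits, dim=-1)
    ref = torch.nn.functional.kl_div(
        teacher_log_probs, student_log_probs, reduction="none", log_target=True
    ).sum(dim=-1).mean()
    ref.backward()

    # Manual formulation: loss = sum_v q * (log q - log p),
    # grad wrt logits = q * (log_ratio - E_q[log_ratio]) / num_positions.
    with torch.no_grad():
        log_ratio = student_log_probs - teacher_log_probs
        student_probs = student_log_probs.exp()
        loss = (student_probs * log_ratio).sum(dim=-1).mean()
        expected = (student_probs * log_ratio).sum(dim=-1, keepdim=True)
        grad = student_probs * (log_ratio - expected) / (batch * seq)

    assert torch.allclose(loss, ref.detach())
    assert torch.allclose(grad, logits.grad, atol=1e-6)

With tensor parallelism over the vocabulary, the expected term depends on the full vocabulary, which is why the patch all-reduces `expected` across the group before forming the gradient.
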
From eed426a471ddc3f2b0b28f2d5d4b6d526e9737f1 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 19 Dec 2025 19:06:10 +0000
Subject: [PATCH 2/3] average by seq. length

---
 fast_llm/functional/cross_entropy.py   | 5 ++---
 tests/functional/test_cross_entropy.py | 4 +++-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index 839b1e41..e25595a8 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -227,6 +227,7 @@ def distributed_log_softmax(
     return logits_norm - sum_exp_logits.log()  # log_softmax
 
 
+@torch.compile
 def _reverse_kl_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
@@ -273,9 +274,7 @@ def _reverse_kl_forward_backward(
         # loss mask is the same on all ranks for TP over vocab.
         valid = loss_mask.to(loss_terms.dtype)
         loss_terms = loss_terms * valid
-        valid_tokens = valid.sum()
-    else:
-        valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
+    valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
 
     loss = loss_terms.sum()  # sums over batch and seq. len.
     if group is not None:
diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py
index 72644d06..20d16bb9 100644
--- a/tests/functional/test_cross_entropy.py
+++ b/tests/functional/test_cross_entropy.py
@@ -104,7 +104,9 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso
         reduction="none",
         log_target=True,
     ).sum(dim=-1)
-    output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum()
+    if loss_mask is not None:
+        per_sample = per_sample * loss_mask
+    output = per_sample.mean()
     output.backward()
     return output, logits.grad
 

From f179681da6dc7e64fe7526839215e07b96b36081 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Tue, 23 Dec 2025 02:32:47 +0000
Subject: [PATCH 3/3] removed in-place ops.

---
 fast_llm/functional/cross_entropy.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index e25595a8..a12516b5 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -264,7 +264,7 @@ def _reverse_kl_forward_backward(
     log_ratio = distributed_log_softmax(logits, group=group)
 
     student_probs = log_ratio.exp()
-    log_ratio.sub_(teacher_log_probs)  # In-place: log_ratio = student_log_probs - teacher_log_probs
+    log_ratio = log_ratio - teacher_log_probs  # log_ratio = student_log_probs - teacher_log_probs
     del teacher_log_probs
     # Compute loss terms: student_probs * log_ratio, then sum over vocab
     # This is equivalent to kl_div(..., log_target=True) but more memory efficient
@@ -287,14 +287,14 @@ def _reverse_kl_forward_backward(
         expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
         if group is not None:
             all_reduce(expected, op=ReduceOp.SUM, group=group)
-        log_ratio.sub_(expected)  # In-place: log_ratio -= expected
-        log_ratio.mul_(student_probs)  # In-place: now log_ratio is grad_base
+        log_ratio = log_ratio - expected
+        log_ratio = log_ratio * student_probs
         del student_probs  # Free after use
 
         if loss_mask is not None:
-            log_ratio.mul_(loss_mask.to(logits.dtype).unsqueeze(-1))  # In-place
+            log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1)
 
-        log_ratio.mul_(grad_output / valid_tokens)  # In-place
+        log_ratio = log_ratio * (grad_output / valid_tokens)
         grad = log_ratio.to(logits.dtype)
     else:
         grad = None
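
PATCH 2/3 also changes the normalization under a loss mask: valid_tokens is now the total number of positions (batch times sequence length) rather than loss_mask.sum(), so masked positions still count in the denominator and the loss is a mean over the full sequence. A small sketch with made-up values shows the difference between the two conventions:

    # Illustration of the masked-loss normalization change (values are assumptions).
    import torch

    per_token_loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # (batch, seq)
    loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

    masked = per_token_loss * loss_mask
    old_style = masked.sum() / loss_mask.sum()         # mean over valid tokens only -> 7/3
    new_style = masked.sum() / per_token_loss.numel()  # mean over batch * seq. length -> 7/6
    print(old_style.item(), new_style.item())

The test helper in tests/functional/test_cross_entropy.py is updated to the same convention: apply the mask, then take the mean over all positions.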