Conversation

@oleksost (Contributor) commented Dec 16, 2025

✨ Description

Fixes several issues related to loss masking and adds loss-matching checks to the model tests.

1. Cross-entropy loss seems to be incorrectly averaged when loss masks are used?

Symptom:

  • no real symptom, since the reference torch implementation was averaging incorrectly here too
  • discovered because the df4 tests with loss masking were not failing for CE loss (they did fail for the rev. KL loss)

Proposed solution:

  • IIUC, when computing the loss and gradient mean we need to divide by the number of valid tokens (loss_mask.sum()) when a mask is used, not by the total number of tokens (similar to the rev. KL case); see the sketch below.
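
A minimal sketch of the proposed averaging, assuming standard PyTorch cross-entropy (the helper name `masked_cross_entropy` and the tensor shapes are illustrative, not the actual Fast-LLM kernel):

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(
    logits: torch.Tensor,     # (batch, seq, vocab)
    targets: torch.Tensor,    # (batch, seq)
    loss_mask: torch.Tensor,  # (batch, seq), 1 for valid tokens, 0 for masked ones
) -> torch.Tensor:
    # Per-token loss with no reduction, so the mask can be applied explicitly.
    per_token = F.cross_entropy(
        logits.flatten(0, -2), targets.flatten(), reduction="none"
    ).view_as(targets)
    # Divide by the number of valid tokens, not by the total token count.
    return (per_token * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```

With this, each unmasked token is weighted by 1 / loss_mask.sum() in the mean instead of 1 / (batch * seq).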

2. Incorrect grad & loss averaging with loss masking & grad. accumulation

When training with micro-batch grad. accumulation & loss masking, the loss and grads are incorrectly averaged across micro-batches because each micro-batch has a different number of valid tokens (i.e. loss_mask.sum() differs across micro-batches).

Symptom:

  • test tests/models/test_model.py::test_and_compare_model[mistral_reverse_kl-df4]@dependency_group_15 is failing

Proposed solution:

  • calculate the global number of valid tokens (requires an additional iteration over the micro-batches in the runner) --> scale grad_output by loss_mask.sum() / total_valid_tokens so that the average over micro-batches is correct (see the sketch after this list).
  • added loss comparison to the model tests: both pytest tests/models/test_model.py --models mistral_reverse_kl and pytest tests/models/test_model.py --models mistral_distill_logits pass now; the losses and grads seem to match the simple baseline.
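
A minimal sketch of the proposed scaling, assuming each micro-batch loss is already averaged over its own loss_mask.sum() as in the sketch above (the function name and the loop are illustrative, not the actual runner code):

```python
import torch


def micro_batch_scales(loss_masks: list[torch.Tensor]) -> list[torch.Tensor]:
    # Extra pass over the micro-batches to compute the global valid-token count.
    valid_counts = [mask.sum() for mask in loss_masks]
    total_valid = torch.stack(valid_counts).sum().clamp(min=1)
    # Scaling micro-batch i by loss_mask_i.sum() / total_valid_tokens makes the
    # accumulated gradient equal the mean over all valid tokens of the global batch.
    return [count / total_valid for count in valid_counts]


# Illustrative accumulation loop:
# for (logits, targets, mask), scale in zip(micro_batches, micro_batch_scales(masks)):
#     loss = masked_cross_entropy(logits, targets, mask)
#     (loss * scale).backward()  # same effect as scaling grad_output by `scale`
```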

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost changed the title from "Fixes: CE loss with loss masking & loss masking with micro-batches" to "Loss masking fixes: cross entropy averaging & grad. accumulation" on Dec 16, 2025
@oleksost marked this pull request as draft on December 16, 2025, 22:45
@jlamypoirier (Collaborator) commented Dec 17, 2025

This was already discussed in #344, #345, and we chose to average over all tokens and not just the valid tokens. This is arguably preferable because every token ends up having the same contribution to the gradients, independently of masking. I'd recommend replacing the reverse KL averaging with a simple mean instead.

@oleksost (Contributor, Author) commented Dec 18, 2025

Ach, I see. Thanks for pointing to that discussion, I should have checked those past conversations before spending time on this.

So IIUC, if we do, say, SFT/SFT distillation and have a batch of 10 prompt tokens (masked with 0) and 10 output tokens, the gradients would have a 2x smaller magnitude when normalising over all tokens as opposed to normalising over valid tokens only (the sum of per-token losses gets divided by 20 instead of 10).

So while it is true that each token contributes the same relative amount to the grad, if we have samples with different amounts of masking, samples with less masking will contribute more? This could lead to instabilities, especially if we want to mix CPT and SFT data for distillation, where the CPT data would have no masking and so would dominate the grads. @RaymondLi0

Also, wouldn't this bias the model towards longer generations? (i.e. samples with more unmasked tokens contribute more)

@RaymondLi0 (Contributor) commented:

I think the main potential problem with averaging over valid tokens only is that it would artificially bump the gradient contributions of very short samples, maybe leading to a higher variance in gradients.

As Joel said, in the current version where we average over all tokens, all tokens have the same contribution to the gradients. So whether we bias the model towards longer generations depends only on our dataset mix. If we include enough short samples, there is no reason we would bias the model towards long generations. Remember also that documents are packed together, so a long sample could contain several short documents.
If we adequately mix CPT vs SFT and long samples vs short samples, I think the current implementation should work just fine.

@oleksost (Contributor, Author) commented Dec 18, 2025

Thanks for the clarification. After the discussion in the office, the trade-offs of the two approaches seem to be clear:

  1. Average losses & grads over a constant L for each batch: the magnitude of a batch's contribution to the gradients/loss depends on the number of valid tokens in the batch: more valid tokens => larger contribution from the batch. If a batch has disproportionately more samples with long generations, it will contribute more. Here we have equal contribution per token.
    This could potentially lead to training instabilities?

  2. Average loss & grad over the valid-token count: here the loss and grad magnitudes per batch are independent of the number of masked tokens in each batch, so training is always more stable. However, we don't have equal contribution per token across batches.
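
For concreteness, with per-token losses $\ell_i$, mask values $m_i \in \{0,1\}$ and a fixed sequence length $L$, the two options correspond to (my notation, not taken from the code):

$$
\bar{\ell}_{\text{all}} = \frac{1}{L}\sum_{i=1}^{L} m_i\,\ell_i
\qquad\text{vs.}\qquad
\bar{\ell}_{\text{valid}} = \frac{\sum_{i=1}^{L} m_i\,\ell_i}{\sum_{i=1}^{L} m_i}
$$

They differ exactly by the factor $\sum_i m_i / L$, which is why option 1's batch contribution grows with the valid-token count while option 2 keeps it constant.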

Not sure how much it matters in practice (it depends on the data distribution); it might slow down learning.

It seems to me that having a guarantee that the loss/grad magnitude will not randomly explode is important, but it potentially comes at some efficiency + complexity cost in Fast-LLM.

@oleksost changed the title from "Loss masking fixes: cross entropy averaging & grad. accumulation" to "[Prototype] Normalising by valid tokens" on Dec 19, 2025