[Prototype] Normalising by valid tokens #426
Conversation
This was already discussed in #344, #345, and we chose to average over all tokens and not just the valid tokens. This is arguably preferable because every token ends up having the same contribution to the gradients, independently of masking. I'd recommend replacing reverse KL with a simple mean instead.
Ach, I see. Thanks for pointing to that discussion; I should have checked those past conversations before spending time on this. So IIUC, if we do, say, SFT/SFT distillation and have a batch of 10 prompt tokens (masked with 0) and 10 output tokens, the gradients would have 2x smaller magnitude when normalising over all tokens as opposed to normalising over valid tokens only. So while it is true that each token contributes the same relative amount to the grad, if we have samples with different amounts of masking, samples with less masking will contribute more? This could lead to instabilities, especially if we want to mix CPT and SFT data for distillation, where CPT data would have no masking and so would dominate the grads? @RaymondLi0 Also, wouldn't this bias the model towards longer generations (i.e. samples with more unmasked tokens contribute more)?
I think the main potential problem with averaging over valid tokens only is that it would artificially bump the gradient contributions of very short samples, maybe leading to a higher variance in gradients. As Joel said, in the current version where we average over all tokens, all tokens have the same contribution to the gradients. So whether we bias the model towards longer generations only depends on our dataset mix. If we include enough short samples, there is no reason we would bias the model towards long generations. Keep in mind also that documents are packed together, so a long sample could contain several short documents.
Thanks for the clarification. After the discussion in the office, the trade-offs of the two approaches seem clear. Not sure how much it matters in practice (it depends on the data distribution); it might slow down learning. It seems to me that having guarantees that the loss/grad magnitude will not randomly explode is important, but it potentially comes at some efficiency and complexity cost in Fast-LLM.
✨ Description
Fixes several issues related to loss masking and adds loss-matching checks in the model tests.
1. Cross entropy loss seems to be incorrectly averaged when loss masks are used?
Symptom:
- df4 tests with loss masking were not failing for CE loss (they did for rev. KL loss)
Proposed solution:
- Normalise by the number of valid tokens (loss_mask.sum()) when a mask is used, not by all tokens (similarly to the rev. KL case)
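A minimal sketch of the intended reduction (illustrative helper only, not the actual Fast-LLM code path):

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, loss_mask=None):
    # Illustrative only: with a loss mask, reduce by the number of valid
    # tokens rather than the total token count.
    per_token = F.cross_entropy(logits, targets, reduction="none")
    if loss_mask is None:
        return per_token.mean()
    # divide by loss_mask.sum(), not by per_token.numel()
    return (per_token * loss_mask).sum() / loss_mask.sum()
```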
2. Incorrect grad & loss averaging with loss masking & grad. accumulation
When training with micro-batches (grad. accumulation) & loss masking, the loss and grads are incorrectly averaged across micro-batches because each micro-batch contains a different number of valid tokens (i.e. loss_mask.sum() differs across micro-batches).
Symptom:
- tests/models/test_model.py::test_and_compare_model[mistral_reverse_kl-df4] @dependency_group_15 is failing
Proposed solution:
- Scale grad_output by loss_mask.sum() / total_valid_tokens to make sure the average over micro-batches is correct. Both pytest tests/models/test_model.py --models mistral_reverse_kl and pytest tests/models/test_model.py --models mistral_distill_logits pass now; the losses and grads seem to match the simple baseline.
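A small numerical check of the proposed rescaling (illustrative PyTorch only; the two-micro-batch split and variable names are assumptions, not Fast-LLM code): rescaling each micro-batch's per-valid-token average by loss_mask.sum() / total_valid_tokens reproduces the gradient of a single global average over all valid tokens.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 8
logits = torch.randn(4, 6, vocab, requires_grad=True)  # 4 sequences of 6 tokens
targets = torch.randint(0, vocab, (4, 6))
loss_mask = (torch.rand(4, 6) > 0.4).float()            # uneven masking across sequences
total_valid_tokens = loss_mask.sum()

def masked_sum(lg, tg, mask):
    per_token = F.cross_entropy(lg.reshape(-1, vocab), tg.reshape(-1), reduction="none")
    return (per_token * mask.reshape(-1)).sum()

# Reference: one global average over all valid tokens in the full batch.
ref_loss = masked_sum(logits, targets, loss_mask) / total_valid_tokens
ref_grad = torch.autograd.grad(ref_loss, logits)[0]

# Grad accumulation over 2 micro-batches: each micro-batch loss is averaged over
# its own valid tokens, then rescaled by loss_mask.sum() / total_valid_tokens.
acc_grad = torch.zeros_like(logits)
for mb in (slice(0, 2), slice(2, 4)):
    mb_mask = loss_mask[mb]
    mb_loss = masked_sum(logits[mb], targets[mb], mb_mask) / mb_mask.sum()
    scale = mb_mask.sum() / total_valid_tokens
    acc_grad += torch.autograd.grad(mb_loss * scale, logits)[0]

print(torch.allclose(ref_grad, acc_grad, atol=1e-6))  # True
```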
🔍 Type of change
Select all that apply:
📝 Changes
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.