Fix: Perform sigmoid calculation in fp32 for aux loss stability #2765
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR updates
compute_routing_scores_for_aux_lossto performsigmoidrouting score calculations infloat32, matching the existing behavior of thesoftmaxpath.Fixes #2741
Currently, the
softmaxrouting path explicitly casts tofloat32to avoid underflow/overflow in BF16/FP16. However, thesigmoidpath performs operations in the input dtype.While
sigmoidis bounded [0, 1], the subsequent normalization (scores / sum) involves accumulation that can suffer from precision loss in BF16, especially with a large number of experts. This change aligns both methods to use high-precision accumulation for auxiliary loss stability.