Add dynamic loss balancing for distillation feature and logit losses #4220

Open
ajkv-google wants to merge 1 commit into main from ajkv/dynamic-weight-loss
@ajkv-google

Collaborator

Description

This PR updates the distillation loss computation in MaxText by introducing dynamic weighting for the feature loss and the knowledge distillation (KD) terms.

In src/maxtext/trainers/post_train/distillation/distillation_utils.py, a dynamic scale first matches the magnitude of the raw feature loss to that of the softened logit loss (soft_loss_mean). The combined kd_loss is then scaled to match the magnitude of the target task loss (hard_loss_mean).

This dynamic loss balancing strategy is the approach used by the PIE team. It helps stabilize multi-task objectives by keeping the gradient contributions of the different loss components proportional throughout training.

Specific Implementation Details:

  • Feature Loss Scaling: The raw feature loss is multiplied by feature_mag_scale, the ratio of soft_loss_mean to raw_feature_loss. Both terms of the ratio have stop-gradients applied (mirroring detach()), so the scale acts as a constant during backpropagation.
  • Combined Distillation (KD) Scaling: The combined distillation loss kd_loss is multiplied by kd_mag_scale, the ratio of hard_loss_mean to kd_loss, giving balanced_kd_loss.
  • Total Loss: Calculated as: total_loss = hard_loss_mean + (alpha * balanced_kd_loss).
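
For reviewers, the three steps above can be sketched in JAX roughly as follows. This is an illustration and not the code in distillation_utils.py: the function name balanced_total_loss, the EPS guard, the plain sum used to form kd_loss, and the stop-gradient on kd_mag_scale are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp

EPS = 1e-8  # assumption: small guard against division by zero


def balanced_total_loss(hard_loss_mean, soft_loss_mean, raw_feature_loss, alpha):
  """Hypothetical sketch of the dynamic loss balancing described above."""
  sg = jax.lax.stop_gradient  # JAX counterpart of detach()

  # Step 1: bring the raw feature loss to the magnitude of the soft logit loss.
  feature_mag_scale = sg(soft_loss_mean) / (sg(raw_feature_loss) + EPS)
  balanced_feature_loss = feature_mag_scale * raw_feature_loss

  # Assumption: kd_loss is the plain sum of the two distillation terms.
  kd_loss = soft_loss_mean + balanced_feature_loss

  # Step 2: bring the combined KD loss to the magnitude of the hard loss.
  # Assumption: the scale is also treated as a constant here.
  kd_mag_scale = sg(hard_loss_mean) / (sg(kd_loss) + EPS)
  balanced_kd_loss = kd_mag_scale * kd_loss

  # Step 3: total loss as stated in the PR description.
  total_loss = hard_loss_mean + alpha * balanced_kd_loss
  aux = {
      "feature_mag_scale": feature_mag_scale,
      "kd_mag_scale": kd_mag_scale,
      "balanced_kd_loss": balanced_kd_loss,
  }
  return total_loss, aux


if __name__ == "__main__":
  total, aux = balanced_total_loss(
      jnp.float32(2.0), jnp.float32(0.5), jnp.float32(8.0), alpha=0.5
  )
  print(total, aux)
```

With hard_loss_mean = 2.0, soft_loss_mean = 0.5, raw_feature_loss = 8.0 and alpha = 0.5, the sketch gives feature_mag_scale = 0.0625, kd_loss = 1.0, kd_mag_scale = 2.0 and total_loss = 3.0. Because both scales are wrapped in stop-gradients, they only rescale the gradients of the terms they multiply and contribute no gradient of their own.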

Tests

Ran a training job with the new dynamic loss weighting strategy.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
