From d3608dd5cd934e66ace36497a36e8ecbe22b3047 Mon Sep 17 00:00:00 2001 From: Ajay Krishna Vajjala Date: Mon, 22 Jun 2026 18:19:28 +0000 Subject: [PATCH] Updated loss to contain dynamic weighting for feature loss --- .../distillation/distillation_utils.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index f063cdb23a..e7d87de7f5 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -543,11 +543,10 @@ def compute_loss( ce_teacher_sum = jnp.sum(ce_teacher_per_pos * mask) - base_logit_loss = (alpha * soft_loss_mean) + ((1.0 - alpha) * hard_loss_mean) - + # Feature Loss Scaling (match Logit Loss magnitude) feature_loss = jnp.array(0.0, dtype=jnp.float32) + if self.beta_feature > 0.0: - assert s_features is not None and t_features is not None if self.layer_indices is not None: s_features_sliced = jnp.take(s_features, self.layer_indices, axis=0) t_features_sliced = jnp.take(t_features, self.layer_indices, axis=0) @@ -558,9 +557,23 @@ def compute_loss( s_features_sliced = s_features_sliced.astype(jnp.float32) t_features_sliced = t_features_sliced.astype(jnp.float32) - feature_loss = beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced, mask) + # Calculate raw feature loss + raw_feature_loss = self.feature_loss_fn(s_features_sliced, t_features_sliced, mask) + + # Match magnitude: Scale raw feature loss to match soft loss mean + feature_mag_scale = jax.lax.stop_gradient(soft_loss_mean) / jnp.maximum( + jax.lax.stop_gradient(raw_feature_loss), 1e-8 + ) + feature_loss = raw_feature_loss * feature_mag_scale + + # Combined Distillation Loss (Logits + Balanced Features) + kd_loss = soft_loss_mean + feature_loss + + # KD Scaling (match Hard/Task Loss magnitude) + kd_mag_scale = jax.lax.stop_gradient(hard_loss_mean) / jnp.maximum(jax.lax.stop_gradient(kd_loss), 1e-8) + balanced_kd_loss = kd_loss * kd_mag_scale - total_loss = base_logit_loss + feature_loss + total_loss = hard_loss_mean + (alpha * balanced_kd_loss) moe_lb_loss = jnp.array(0.0) if student_output.moe_lb_loss is not None: