Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -543,11 +543,10 @@ def compute_loss(

ce_teacher_sum = jnp.sum(ce_teacher_per_pos * mask)

base_logit_loss = (alpha * soft_loss_mean) + ((1.0 - alpha) * hard_loss_mean)

# Feature Loss Scaling (match Logit Loss magnitude)
feature_loss = jnp.array(0.0, dtype=jnp.float32)

if self.beta_feature > 0.0:
assert s_features is not None and t_features is not None
if self.layer_indices is not None:
s_features_sliced = jnp.take(s_features, self.layer_indices, axis=0)
t_features_sliced = jnp.take(t_features, self.layer_indices, axis=0)
Expand All @@ -558,9 +557,23 @@ def compute_loss(
s_features_sliced = s_features_sliced.astype(jnp.float32)
t_features_sliced = t_features_sliced.astype(jnp.float32)

feature_loss = beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced, mask)
# Calculate raw feature loss
raw_feature_loss = self.feature_loss_fn(s_features_sliced, t_features_sliced, mask)

# Match magnitude: Scale raw feature loss to match soft loss mean
feature_mag_scale = jax.lax.stop_gradient(soft_loss_mean) / jnp.maximum(
jax.lax.stop_gradient(raw_feature_loss), 1e-8
)
feature_loss = raw_feature_loss * feature_mag_scale

# Combined Distillation Loss (Logits + Balanced Features)
kd_loss = soft_loss_mean + feature_loss

# KD Scaling (match Hard/Task Loss magnitude)
kd_mag_scale = jax.lax.stop_gradient(hard_loss_mean) / jnp.maximum(jax.lax.stop_gradient(kd_loss), 1e-8)
balanced_kd_loss = kd_loss * kd_mag_scale

total_loss = base_logit_loss + feature_loss
total_loss = hard_loss_mean + (alpha * balanced_kd_loss)

moe_lb_loss = jnp.array(0.0)
if student_output.moe_lb_loss is not None:
Expand Down
Loading