Skip to content

Commit a171d3e

Browse files
bzgoogle
authored and committed
address some comments
Signed-off-by: bzgoogle <beinuoz_google_com@t1v-n-3f245f84-w-0.us-central1-c.c.tpu-prod-env-one-vm.internal>
1 parent 39ba1d5 commit a171d3e

File tree

2 files changed

+0
-13
lines changed

2 files changed

+0
-13
lines changed

tests/models/jax/common/moe/test_deepseek_moe.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import os
22
import unittest
33

4-
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
5-
64
import jax
75
import jax.numpy as jnp
86
import numpy as np

tpu_commons/models/jax/common/moe/deepseek_moe.py

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -325,14 +325,6 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
325325
processed_tokens, jnp.argsort(sort_indices))
326326
reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
327327
-1, self.num_experts_per_tok, self.hidden_size)
328-
# jax.debug.print(
329-
# "✅ reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}",
330-
# t=reshaped_tokens_TXD[5, 0,:5]
331-
# )
332-
# jax.debug.print(
333-
# "✅ router_weights_TX on device: router_weights_TX={t}",
334-
# t=router_weights_TX[5, :]
335-
# )
336328
with jax.named_scope("combine_weights"):
337329
output_TD = jnp.einsum(
338330
"TXD,TX -> TD",
@@ -484,9 +476,6 @@ def _distributed_sparse_moe_fwd(
484476
compute_expert_ids = global_sorted_experts
485477
local_sorted_indices = jnp.arange(sorted_inputs.shape[0])
486478

487-
#debug_position_in_sorted = jnp.argsort(global_sort_indices)[40:48]
488-
#debug_position_compute_inputs = jnp.argsort(local_sorted_indices)[debug_position_in_sorted]
489-
490479
# 4. Compute: Apply experts using Grouped Matrix Multiply
491480
with jax.named_scope("gating"):
492481
# compute_inputs: (local total assignments, D)

0 commit comments

Comments
 (0)