
Commit 39ba1d5

bzgoogle authored and committed
add unit test; add flag to support switching between dense/sparse matmul
Signed-off-by: bzgoogle <beinuoz@google.com>
Signed-off-by: bzgoogle <beinuoz_google_com@t1v-n-3f245f84-w-0.us-central1-c.c.tpu-prod-env-one-vm.internal>
1 parent 68a4e2e commit 39ba1d5

File tree

1 file changed: 224 additions, 0 deletions

@@ -0,0 +1,224 @@
import os
import unittest

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax.sharding import Mesh, PartitionSpec

from tpu_commons.models.jax.common.moe.deepseek_moe import (DeepSeekV3Router,
                                                            SparseMoE)
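

# Note: the XLA_FLAGS assignment above must run before `import jax`; XLA
# reads the flag when the backend initializes, making the CPU platform
# report eight simulated devices so that TestSparseMoE can exercise 8-way
# expert parallelism on a machine without TPUs.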
class TestDeepSeekV3Router(unittest.TestCase):

    def setUp(self):
        self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', ))

    def test_get_topk_indices_single_group(self):
        """Test get_topk_indices with a single expert group."""
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=4,
                                      num_experts_per_tok=2,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            router.bias_E = jnp.zeros((4, ))

            scores = jnp.array([[0.1, 0.3, 0.2, 0.4]])  # shape: (1, 4)
            indices = router.get_topk_indices(scores)

            # Should return the indices of the top-2 experts.
            expected_indices = jnp.array([[3, 1]])  # experts with scores 0.4, 0.3
            self.assertTrue(jnp.array_equal(indices, expected_indices))
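
    # DeepSeek-V3 routing is group-limited: experts are partitioned into
    # n_groups, the top `topk_groups` groups are kept by group score, and
    # top-k expert selection then runs only within the kept groups.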
    def test_get_topk_indices_2_groups(self):
        """Test get_topk_indices with 2 expert groups."""
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=4,
                                      num_experts_per_tok=2,
                                      n_groups=2,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            router.bias_E = jnp.zeros((4, ))

            # 4 experts, 2 groups, 2 experts per group
            scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]])  # shape: (1, 1, 4)
            indices = router.get_topk_indices(scores)

            # With topk_groups=1 only the higher-scoring group (experts 2
            # and 3) is eligible, so the top-2 picks are experts 3 and 2.
            expected_indices = jnp.array([[[3, 2]]])
            self.assertTrue(jnp.array_equal(indices, expected_indices))
    def test_router_e2e(self):
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=8,
                                      num_experts_per_tok=2,
                                      n_groups=2,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            x = jnp.ones((2, 512))
            weights, indices = router(x)
            self.assertEqual(weights.shape, (2, 2))
            self.assertEqual(indices.shape, (2, 2))
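

# TestSparseMoE validates the sharded forward pass against a per-token dense
# recomputation that uses the gathered, unsharded expert weights.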
class TestSparseMoE(unittest.TestCase):

    def setUp(self):
        """Set up a multi-device mesh and a sample MoE layer for testing."""
        devices = jax.devices()
        self.device_count = len(devices)
        if self.device_count < 8:
            self.skipTest("This test requires at least 8 simulated devices.")

        # This mesh has a 'model' axis for expert parallelism.
        mesh_shape = (self.device_count, 1)
        device_mesh_array = np.array(devices).reshape(mesh_shape)

        # Define the axis names.
        axis_names = ('model', 'data')

        # Create the 2D mesh.
        self.mesh = Mesh(device_mesh_array, axis_names=axis_names)

        # --- Model Configuration ---
        self.B, self.S, self.D = 2, 4, 16  # Batch, Sequence, Hidden Dim
        self.E, self.K = 16, 8  # Num Experts, Experts per Token
        self.moe_intermediate_size = 32  # FFN Dim
        self.num_expert_parallelism = 8  # Shard experts across 8 devices

        self.key = jax.random.PRNGKey(42)
        self.x = jax.random.normal(self.key, (self.B * self.S, self.D),
                                   dtype=jnp.bfloat16)

        # --- Instantiate MoE Layer ---
        # The layer must be instantiated inside the mesh context.
        with self.mesh:
            router = DeepSeekV3Router(hidden_size=self.D,
                                      num_experts=self.E,
                                      num_experts_per_tok=self.K,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=False,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(self.key),
                                      ed_sharding=PartitionSpec(),
                                      e_sharding=PartitionSpec(),
                                      activation_ffw_td=PartitionSpec(
                                          'data', None))
            self.moe = SparseMoE(
                hidden_size=self.D,
                intermediate_size_moe=self.moe_intermediate_size,
                num_local_experts=self.E,
                hidden_act="silu",
                num_experts_per_tok=self.K,
                router=router,
                dtype=jnp.bfloat16,
                rngs=nnx.Rngs(self.key),
                mesh=self.mesh,
                apply_expert_weight_before_computation=False,
                # Expert kernels are sharded along the 'model' axis;
                # activations are sharded along 'data' (size 1 in this mesh,
                # i.e. effectively replicated).
                edf_sharding=PartitionSpec('model', None, None),
                efd_sharding=PartitionSpec('model', None, None),
                activation_ffw_ted=PartitionSpec('data', None),
                activation_ffw_td=PartitionSpec('data', None))
    def test_token_replicated_expert_parallel_fwd(self):
        """
        Validates the MoE forward pass against a simple, dense equivalent.
        This specifically tests the is_batch_sharded_by_expert=False path.
        """
        # --- 1. Get the ACTUAL output from the distributed MoE layer ---
        # The __call__ method triggers the shard_map, which requires the
        # mesh context.
        with self.mesh:
            actual_output = self.moe(self.x)

        # --- 2. Compute the EXPECTED output with a simple, sequential process ---
        # This serves as the "ground truth".

        # Get router decisions (router params are replicated, so this is fine).
        router_weights, selected_experts = self.moe.router(self.x)

        # Gather the full, unsharded weights from all devices.
        # .value on a sharded param gives the *local* shard;
        # jax.device_get() retrieves the *full* global array to the host.
        gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value)
        up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value)
        down_proj_kernel_full = jax.device_get(
            self.moe.kernel_down_proj_EFD.value)

        # Check that we really got the full weights.
        self.assertEqual(gating_kernel_full.shape,
                         (self.E, self.D, self.moe_intermediate_size))

        # Flatten inputs for easier iteration.
        flat_x = self.x.reshape(self.B * self.S, self.D)
        flat_weights = router_weights.reshape(self.B * self.S, self.K)
        flat_experts = selected_experts.reshape(self.B * self.S, self.K)

        expected_output = jnp.zeros_like(flat_x)

        # Manually apply each expert to each token sequentially.
        for i in range(self.B * self.S):  # For each token
            token_input = flat_x[i]
            combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16)

            for k in range(self.K):  # For each chosen expert for that token
                expert_idx = flat_experts[i, k]
                weight = flat_weights[i, k]

                # Get kernels from the *full* gathered arrays.
                gating_kernel = gating_kernel_full[expert_idx]
                up_proj_kernel = up_proj_kernel_full[expert_idx]
                down_proj_kernel = down_proj_kernel_full[expert_idx]

                # Perform the expert computation (dense matmuls).
                gating_proj = jnp.dot(token_input, gating_kernel)
                up_proj = jnp.dot(token_input, up_proj_kernel)

                # 'silu' activation, as specified in the MoE init.
                fused = nnx.silu(gating_proj) * up_proj

                expert_output = jnp.dot(fused, down_proj_kernel)

                # Apply the router weight after computation (matches the
                # implementation, since
                # apply_expert_weight_before_computation is False).
                combined_expert_output += weight * expert_output

            expected_output = expected_output.at[i].set(combined_expert_output)

        expected_output = expected_output.reshape(self.B * self.S, self.D)

        # --- 3. Compare the results ---
        self.assertTrue(
            jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2),
            f"The output of the distributed MoE does not match the dense equivalent.\n"
            f"Actual:\n{actual_output}\n"
            f"Expected:\n{expected_output}")
        print(
            "\n✅ Test Passed: Distributed MoE output matches the dense ground truth."
        )
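
The committed file ends without a runner stanza, so these tests are presumably invoked through a test runner such as pytest or python -m unittest. A conventional guard, not part of this commit, would be:

if __name__ == '__main__':
    unittest.main()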

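The commit message also mentions a flag for switching between dense and sparse matmul, which does not appear in this test-only diff. As a rough, hypothetical illustration of the distinction (the names moe_ffn and use_dense_matmul are invented here, not taken from this commit): a dense dispatch computes every expert for every token and masks the results with one-hot router weights, while a sparse dispatch gathers only the K selected experts' kernels per token.

import jax
import jax.numpy as jnp


def moe_ffn(x_TD, w_gate_EDF, w_up_EDF, w_down_EFD, weights_TK, idx_TK,
            use_dense_matmul=False):
    """Hypothetical sketch of a dense/sparse switch; not the SparseMoE API."""
    if use_dense_matmul:
        # Dense: run all E experts on all T tokens, then zero out unselected
        # experts with a one-hot combine mask. O(E) FLOPs per token, but
        # only static einsums (no gathers).
        h_ETF = (jax.nn.silu(jnp.einsum('td,edf->etf', x_TD, w_gate_EDF)) *
                 jnp.einsum('td,edf->etf', x_TD, w_up_EDF))
        y_ETD = jnp.einsum('etf,efd->etd', h_ETF, w_down_EFD)
        onehot_TKE = jax.nn.one_hot(idx_TK, w_gate_EDF.shape[0])
        combine_TE = (weights_TK[..., None] * onehot_TKE).sum(axis=1)
        return jnp.einsum('te,etd->td', combine_TE, y_ETD)
    # Sparse: gather only the K selected experts' kernels for each token.
    # O(K) FLOPs per token, at the cost of gather ops.
    h_TKF = (jax.nn.silu(jnp.einsum('td,tkdf->tkf', x_TD, w_gate_EDF[idx_TK])) *
             jnp.einsum('td,tkdf->tkf', x_TD, w_up_EDF[idx_TK]))
    y_TKD = jnp.einsum('tkf,tkfd->tkd', h_TKF, w_down_EFD[idx_TK])
    return (weights_TK[..., None] * y_TKD).sum(axis=1)


# Both paths agree up to floating-point accumulation order.
T, D, F, E, K = 4, 16, 32, 8, 2
keys = jax.random.split(jax.random.PRNGKey(0), 6)
x = jax.random.normal(keys[0], (T, D))
w_gate = jax.random.normal(keys[1], (E, D, F)) / jnp.sqrt(D)
w_up = jax.random.normal(keys[2], (E, D, F)) / jnp.sqrt(D)
w_down = jax.random.normal(keys[3], (E, F, D)) / jnp.sqrt(F)
weights = jax.nn.softmax(jax.random.normal(keys[4], (T, K)), axis=-1)
idx = jax.random.randint(keys[5], (T, K), 0, E)
assert jnp.allclose(moe_ffn(x, w_gate, w_up, w_down, weights, idx, True),
                    moe_ffn(x, w_gate, w_up, w_down, weights, idx, False),
                    atol=1e-3)

Which path wins depends on E, K, and the hardware: the dense form wastes FLOPs when K is much smaller than E but avoids dynamic gathers, which is often the better trade on TPU; the flag presumably lets callers pick per model or per platform.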