import os
import unittest

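# Force XLA to expose 8 simulated host (CPU) devices. This flag only takes
# effect if it is set before jax is imported.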
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax.sharding import Mesh, PartitionSpec

from tpu_commons.models.jax.common.moe.deepseek_moe import (DeepSeekV3Router,
                                                            SparseMoE)


class TestDeepSeekV3Router(unittest.TestCase):

    def setUp(self):
        self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', ))

    def test_get_topk_indices_single_group(self):
        """Test get_topk_indices with a single expert group."""
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=4,
                                      num_experts_per_tok=2,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
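            # Zero the routing bias so selection depends only on the scores
            # passed in below (bias_E is assumed to be the per-expert bias the
            # router adds to its scores before the top-k selection).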
            router.bias_E = jnp.zeros((4, ))

            scores = jnp.array([[0.1, 0.3, 0.2, 0.4]])  # shape: (1, 4)
            indices = router.get_topk_indices(scores)

            # Should return the indices of the top 2 experts
            # (experts 3 and 1, with scores 0.4 and 0.3)
            expected_indices = jnp.array([[3, 1]])
            self.assertTrue(jnp.array_equal(indices, expected_indices))

    def test_get_topk_indices_2_groups(self):
        """Test get_topk_indices with 2 expert groups."""
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=4,
                                      num_experts_per_tok=2,
                                      n_groups=2,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            router.bias_E = jnp.zeros((4, ))

            # 4 experts, 2 groups, 2 experts per group
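            # Assumption: experts are grouped contiguously, so group 0 holds
            # experts 0-1 (scores 0.1, 0.3) and group 1 holds experts 2-3
            # (scores 0.2, 0.4). With topk_groups=1, only the higher-scoring
            # group 1 survives the group-level filtering.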
            scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]])  # shape: (1, 1, 4)
            indices = router.get_topk_indices(scores)

            # Should return the top 2 experts within the selected group,
            # highest score first
            expected_indices = jnp.array([[[3, 2]]])
            self.assertTrue(jnp.array_equal(indices, expected_indices))

    def test_router_e2e(self):
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=8,
                                      num_experts_per_tok=2,
                                      n_groups=2,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            x = jnp.ones((2, 512))
            weights, indices = router(x)
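            # Two tokens, each routed to its top-2 experts: both the routing
            # weights and the selected expert indices should have shape
            # (num_tokens, num_experts_per_tok) = (2, 2).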
            self.assertEqual(weights.shape, (2, 2))
            self.assertEqual(indices.shape, (2, 2))


class TestSparseMoE(unittest.TestCase):

    def setUp(self):
        """Set up a multi-device mesh and a sample MoE layer for testing."""
        devices = jax.devices()
        self.device_count = len(devices)
        if self.device_count < 8:
            self.skipTest("This test requires at least 8 simulated devices.")

        # This mesh will have a 'model' axis for expert parallelism
        mesh_shape = (self.device_count, 1)
        device_mesh_array = np.array(devices).reshape(mesh_shape)

        # Define the axis names
        axis_names = ('model', 'data')

        # Create the 2D mesh
        self.mesh = Mesh(device_mesh_array, axis_names=axis_names)
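        # With shape (device_count, 1), the 'model' axis spans all devices
        # while the 'data' axis has size 1, so anything sharded only along
        # 'data' is effectively replicated.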

        # --- Model Configuration ---
        self.B, self.S, self.D = 2, 4, 16  # Batch, Sequence, Hidden Dim
        self.E, self.K = 16, 8  # Num Experts, Experts per Token
        self.moe_intermediate_size = 32  # FFN Dim
        self.num_expert_parallelism = 8  # Shard experts across 8 devices

        self.key = jax.random.PRNGKey(42)
        self.x = jax.random.normal(self.key, (self.B * self.S, self.D),
                                   dtype=jnp.bfloat16)

        # --- Instantiate MoE Layer ---
        # We need to do this inside the mesh context
        with self.mesh:
            router = DeepSeekV3Router(hidden_size=self.D,
                                      num_experts=self.E,
                                      num_experts_per_tok=self.K,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=False,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(self.key),
                                      ed_sharding=PartitionSpec(),
                                      e_sharding=PartitionSpec(),
                                      activation_ffw_td=PartitionSpec(
                                          'data', None))
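            # Empty PartitionSpecs leave the router parameters replicated on
            # every device; the dense reference computation in the test below
            # relies on that.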
            self.moe = SparseMoE(
                hidden_size=self.D,
                intermediate_size_moe=self.moe_intermediate_size,
                num_local_experts=self.E,
                hidden_act="silu",
                num_experts_per_tok=self.K,
                router=router,
                dtype=jnp.bfloat16,
                rngs=nnx.Rngs(self.key),
                mesh=self.mesh,
                apply_expert_weight_before_computation=False,
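                # apply_expert_weight_before_computation=False means the
                # router weight scales each expert's *output* rather than its
                # input; the dense reference below applies it the same way.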

                # Experts are sharded along the 'model' axis
                edf_sharding=PartitionSpec('model', None, None),
                efd_sharding=PartitionSpec('model', None, None),
                activation_ffw_ted=PartitionSpec('data', None),
                activation_ffw_td=PartitionSpec(
                    'data', None)  # 'data' has size 1: effectively replicated
            )

    def test_token_replicated_expert_parallel_fwd(self):
        """
        Validates the MoE forward pass against a simple, dense equivalent.
        This specifically tests the is_batch_sharded_by_expert=False path.
        """
        # --- 1. Get the ACTUAL output from the distributed MoE layer ---
        # The __call__ method will trigger the shard_map, which requires the
        # mesh context.
        with self.mesh:
            actual_output = self.moe(self.x)

        # --- 2. Calculate the EXPECTED output using a simple, sequential
        # process. This serves as the "ground truth".

        # Get router decisions (router params are replicated, so this is fine)
        router_weights, selected_experts = self.moe.router(self.x)

        # --- Gather the full, unsharded weights from all devices ---
        # .value on a sharded param gives the *local* shard.
        # jax.device_get() retrieves the *full* sharded jax.Array to the host.
        gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value)
        up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value)
        down_proj_kernel_full = jax.device_get(
            self.moe.kernel_down_proj_EFD.value)

        # Check that we really got the full weights
        self.assertEqual(gating_kernel_full.shape,
                         (self.E, self.D, self.moe_intermediate_size))

        # Flatten inputs for easier iteration
        flat_x = self.x.reshape(self.B * self.S, self.D)
        flat_weights = router_weights.reshape(self.B * self.S, self.K)
        flat_experts = selected_experts.reshape(self.B * self.S, self.K)

        expected_output = jnp.zeros_like(flat_x)

        # Manually apply each expert to each token sequentially
        for i in range(self.B * self.S):  # For each token
            token_input = flat_x[i]
            combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16)

            for k in range(self.K):  # For each chosen expert for that token
                expert_idx = flat_experts[i, k]
                weight = flat_weights[i, k]

                # Get kernels from the *full* gathered arrays
                gating_kernel = gating_kernel_full[expert_idx]
                up_proj_kernel = up_proj_kernel_full[expert_idx]
                down_proj_kernel = down_proj_kernel_full[expert_idx]

                # Perform the expert computation (dense matmuls)
                gating_proj = jnp.dot(token_input, gating_kernel)
                up_proj = jnp.dot(token_input, up_proj_kernel)

                # Note: Assuming 'silu' activation as specified in MoE init
                fused = nnx.silu(gating_proj) * up_proj

                expert_output = jnp.dot(fused, down_proj_kernel)

                # Apply router weight after computation (matches implementation)
                combined_expert_output += weight * expert_output

            expected_output = expected_output.at[i].set(combined_expert_output)

        expected_output = expected_output.reshape(self.B * self.S, self.D)

        # --- 3. Compare the results ---
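        # bfloat16 matmuls (and a different reduction order in the sharded
        # path) accumulate rounding error, hence the fairly loose tolerances.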
        self.assertTrue(
            jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2),
            f"The output of the distributed MoE does not match the dense equivalent.\n"
            f"Actual:\n{actual_output}\n"
            f"Expected:\n{expected_output}")
        print(
            "\n✅ Test Passed: Distributed MoE output matches the dense ground truth."
        )