From c09f4d324078c264d9acb63f06c912c7f8268c3a Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Mon, 15 Sep 2025 15:37:50 +0000 Subject: [PATCH 01/11] Implement the SparseMatmul for DeepSeek Signed-off-by: bzgoogle Signed-off-by: bzgoogle --- .../models/jax/common/moe/deepseek_moe.py | 594 ++++++++++++++++++ 1 file changed, 594 insertions(+) create mode 100644 tpu_commons/models/jax/common/moe/deepseek_moe.py diff --git a/tpu_commons/models/jax/common/moe/deepseek_moe.py b/tpu_commons/models/jax/common/moe/deepseek_moe.py new file mode 100644 index 000000000..5ccfbd838 --- /dev/null +++ b/tpu_commons/models/jax/common/moe/deepseek_moe.py @@ -0,0 +1,594 @@ +import enum +from dataclasses import InitVar, dataclass +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +from flax import nnx +from flax.typing import Sharding +from jax.sharding import PartitionSpec +from jaxtyping import Float + +from tpu_commons.models.jax.common.base import create_param +from tpu_commons.models.jax.common.layers import FlaxUtils +from tpu_commons.models.jax.common.moe.moe import MoE + +modeling_flax_utils = FlaxUtils() + + +@dataclass +class DeepSeekV3Router(nnx.Module): + """Router module for Mixture-of-Experts (MoE) layers. + + This module determines which experts each token should be routed to based on the input. + + """ + + hidden_size: int + num_experts: int + num_experts_per_tok: int + n_groups: int + topk_groups: int + norm_topk_prob: bool + routed_scaling_factor: float + dtype: jnp.dtype + rngs: InitVar[nnx.Rngs] + + # Sharding Attributes + activation_ffw_td: Sharding = () + ed_sharding: Sharding = () + e_sharding: Sharding = () + + random_init: bool = False + + router_bias_dtype: jnp.dtype = jnp.float32 + + def get_topk_indices(self, scores_TE: Float) -> Float: + """Get the topk indices of the scores. + + Args: + scores_TE: The scores to get the topk indices of. Shape (sequence, num_experts). + + Returns: + The topk indices of the scores. Shape (sequence, num_experts_per_tok). + """ + + scores_TE = scores_TE + self.bias_E + if self.n_groups > 1: + experts_per_group = self.num_experts // self.n_groups + group_scores_TGM = jnp.reshape( + scores_TE, (-1, self.n_groups, experts_per_group)) + group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0] + group_scores_TG = jnp.sum(group_scores_TG2, axis=-1) + indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1] + + mask_TG = jnp.any(jnp.arange( + self.n_groups)[:, None] == indices[..., None, :], + axis=-1) + mask_TE = jnp.repeat(mask_TG, + scores_TE.shape[-1] // mask_TG.shape[-1], -1) + scores_TE = jnp.where(mask_TE, scores_TE, 0.0) + + indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1] + + return indices_TX + + def __call__(self, x_TD: Float) -> Tuple[Float, Float]: + """Routes tokens to top k experts. + + Args: + x_TD: Input array of shape (sequence, d_model). + + Returns: + A tuple containing: + - weights: Normalized weights for selected experts, shape (sequence, num_experts_per_tok). + - indices: Indices of selected experts, shape (sequence, num_experts_per_tok). 
+ """ + x_TD = jnp.asarray(x_TD, self.dtype) + x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td) + + scores_TE = jnp.einsum("TD,DE -> TE", x_TD, self.kernel_DE.value) + scores_TE = nnx.sigmoid(scores_TE) + + original_scores_TE = scores_TE + topk_indices_TX = self.get_topk_indices(scores_TE) + weights_TX = jnp.take_along_axis(original_scores_TE, + topk_indices_TX, + axis=-1) + + if self.norm_topk_prob: + weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20 + + weights_TX *= self.routed_scaling_factor + + return weights_TX, topk_indices_TX + + def __post_init__(self, rngs: nnx.Rngs): + """Generates the router kernel (weights and bias) for routing.""" + D = self.hidden_size + E = self.num_experts + self.kernel_DE = create_param(rngs, + shape=(D, E), + dtype=self.dtype, + sharding=self.ed_sharding, + random_init=self.random_init) + self.bias_E = create_param(rngs, + shape=(E, ), + dtype=self.router_bias_dtype, + sharding=self.e_sharding, + random_init=self.random_init) + + +@dataclass(kw_only=True) +class SparseMoE(MoE): + """Mixture-of-Experts (MoE) Routed MLP Layer. + + This module implements a Sparse MoE layer with a router and multiple expert MLPs. + + Attributes: + num_experts_per_tok: The number of experts each token is routed to. + tile_size: A tuple (batch, activation_dim, weight_dim) for GMM tiling. + use_megablox: If True, uses the MegaBlox GMM kernel. + mesh: The device mesh. + # TODO: need to redesign this I/O for parallelism + num_expert_parallelism: The size of the 'expert' mesh dimension. + # TODO: determine if we get it from external or extrat it in MoE class + is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim. + """ + num_experts_per_tok: int + #TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText + tile_size: tuple[int, int, int] = (128, 64, 128) + use_megablox: bool = False + mesh: jax.sharding.Mesh + + def __post_init__(self, rngs: nnx.Rngs): + super().__post_init__(rngs) + + # Derive the expert sharding + self.expert_axis_name = self.edf_sharding[0] + if self.expert_axis_name is None: + self.num_expert_parallelism = 1 + else: + self.num_expert_parallelism = self.mesh.shape[ + self.expert_axis_name] + + # Derive if data is sharded by expert + self.data_axis_name = self.activation_ffw_td[0] + self.is_batch_sharded_by_expert = ( + self.expert_axis_name is not None) and (self.expert_axis_name + == self.data_axis_name) + + def _sort_activations(self, inputs: jax.Array, + sort_indices: jax.Array) -> jax.Array: + """Sorts activations(inputs) by `sort_indices` for the forward pass.""" + return inputs[sort_indices, ...] 
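+
+    # Illustrative sketch only (assumed toy values, not used by the layer):
+    # `_sort_activations` is a plain row gather. For example,
+    #   inputs       = jnp.array([[0., 0.], [1., 1.], [2., 2.]])  # (tokens, D)
+    #   expert_ids   = jnp.array([2, 0, 1])
+    #   sort_indices = jnp.argsort(expert_ids)                    # -> [1, 2, 0]
+    # then `inputs[sort_indices, ...]` returns [[1., 1.], [2., 2.], [0., 0.]],
+    # i.e. tokens grouped by expert id, which is the row order the grouped
+    # matmul path (`_gmm` / `jax.lax.ragged_dot`) expects.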
+ + @staticmethod + def get_all_to_all_params( + all_shards_group_sizes, + shard_id, + num_expert_parallelism, + is_batch_sharded=True, + ): + """Generates params for ragged_all_to_all communication.""" + + class TransformStrategy(enum.Enum): + INPUT_OFFSET = enum.auto() + SEND_SIZE = enum.auto() + OUTPUT_OFFSET = enum.auto() + RECV_SIZE = enum.auto() + + def transform_array(input_array, shard_id, strategy, is_batch_sharded): + if is_batch_sharded: + if strategy == TransformStrategy.INPUT_OFFSET: + local_array = input_array[shard_id] + return jnp.concatenate( + (jnp.array([0]), jnp.cumsum(local_array)[:-1])) + elif strategy == TransformStrategy.SEND_SIZE: + return input_array[shard_id] + elif strategy == TransformStrategy.OUTPUT_OFFSET: + zero_row = jnp.zeros((1, ) + input_array.shape[1:], + dtype=input_array.dtype) + array_with_zeros = jnp.concatenate((zero_row, input_array), + axis=0) + cumulated_array = jnp.cumsum(array_with_zeros, + axis=0, + dtype=input_array.dtype) + return cumulated_array[shard_id] + elif strategy == TransformStrategy.RECV_SIZE: + return input_array[:, shard_id] + else: + raise ValueError( + f"Unknown transform array strategy: {strategy}") + else: + if strategy == TransformStrategy.INPUT_OFFSET: + return jnp.zeros(num_expert_parallelism, + dtype=input_array.dtype) + elif strategy == TransformStrategy.SEND_SIZE: + return jnp.repeat(input_array[shard_id], + num_expert_parallelism) + elif strategy == TransformStrategy.OUTPUT_OFFSET: + output_offset = jnp.concatenate( + (jnp.array([0]), + jnp.cumsum(input_array[:-1])))[shard_id] + return jnp.repeat(output_offset, num_expert_parallelism) + elif strategy == TransformStrategy.RECV_SIZE: + return input_array + else: + raise ValueError( + f"Unknown transform array strategy: {strategy}") + + input_offsets = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.INPUT_OFFSET, + is_batch_sharded) + send_sizes = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.SEND_SIZE, + is_batch_sharded) + output_offsets = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.OUTPUT_OFFSET, + is_batch_sharded) + recv_sizes = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.RECV_SIZE, + is_batch_sharded) + return input_offsets, send_sizes, output_offsets, recv_sizes + + def _local_permute( + self, + inputs, + global_group_sizes, + local_expert_size, + shard_index, + is_offset=False, + global_sorted_experts=None, + ): + """Permutes tokens locally within an expert shard.""" + # global_group_sizes: (tokens parallelism, num_total_experts) + # all_shard_local_sizes: (tokens parallelism, num local experts in the shard) + all_shard_local_sizes = jax.lax.dynamic_slice_in_dim( + global_group_sizes, + shard_index * local_expert_size, + local_expert_size, + axis=1, + ) + local_sizes = all_shard_local_sizes.reshape(-1) + + # local_group_size: (tokens parallelism, ) + local_group_size = jnp.sum(all_shard_local_sizes, axis=0) + + # When token replicated in devices + if is_offset: + global_sorted_shard_assignments = jnp.floor_divide( + global_sorted_experts, local_expert_size) + expert_indices = jnp.where( + global_sorted_shard_assignments == shard_index, + jnp.mod(global_sorted_experts, local_expert_size), + local_expert_size, + ) + + # When token sharded in devices + else: + base_indices = jnp.mod(jnp.arange(local_sizes.shape[0]), + local_expert_size) + expert_indices = jnp.repeat(base_indices, + local_sizes, + total_repeat_length=inputs.shape[0]) + + sorted_indices = 
jnp.argsort(expert_indices) + # sort the inputs based on the local expert_indices + sorted_inputs = self._sort_activations(inputs, sorted_indices) + # sortted local expert id from 0 to local expert size + sorted_experts_ids = expert_indices[sorted_indices] + return ( + sorted_inputs, + sorted_indices, + local_group_size, + sorted_experts_ids, + ) + + def _permute(self, inputs_TD: Float, selected_experts_TX: jax.Array): + """Global permute: Sorts tokens by assigned expert.""" + # suffix t = T * X = total_assignments for the local tokens(T) on this device. + total_tokens = inputs_TD.shape[0] + flat_expert_indices = selected_experts_TX.flatten() + sort_indices_t = jnp.argsort(flat_expert_indices) + + replicated_inputs_tD = jnp.repeat(inputs_TD, + self.num_experts_per_tok, + axis=0) + sorted_inputs_tD = self._sort_activations(replicated_inputs_tD, + sort_indices_t) + + # number of tokens assigned to each expert + group_sizes_E = jnp.bincount(flat_expert_indices, + length=self.num_local_experts) + + expert_ids = jnp.arange(self.num_local_experts) + total_assignments = total_tokens * self.num_experts_per_tok + sorted_expert_assignments_t = jnp.repeat( + expert_ids, + repeats=group_sizes_E, + total_repeat_length=total_assignments) + + return ( + sorted_inputs_tD, + sort_indices_t, + group_sizes_E, + sorted_expert_assignments_t, + ) + + def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, + router_weights_TX: jax.Array): + """Unsorts tokens to their original order and combines expert outputs with router's weight.""" + with jax.named_scope("unpermute"): + unsorted_tokens_tD = self._sort_activations( + processed_tokens, jnp.argsort(sort_indices)) + reshaped_tokens_TXD = unsorted_tokens_tD.reshape( + -1, self.num_experts_per_tok, self.hidden_size) + # jax.debug.print( + # "✅ reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}", + # t=reshaped_tokens_TXD[5, 0,:5] + # ) + # jax.debug.print( + # "✅ router_weights_TX on device: router_weights_TX={t}", + # t=router_weights_TX[5, :] + # ) + with jax.named_scope("combine_weights"): + output_TD = jnp.einsum( + "TXD,TX -> TD", + reshaped_tokens_TXD.astype(jnp.float32), + router_weights_TX.astype(jnp.float32), + precision='float32', + ) + + return output_TD.astype(self.dtype) + + def _gmm(self, inputs, kernel, group_sizes): + """Performs Grouped Matrix Multiply.""" + num_rows = inputs.shape[0] + pad_amount = (self.tile_size[0] - + num_rows % self.tile_size[0]) % self.tile_size[0] + if pad_amount > 0: + inputs = jnp.pad(inputs, ((0, pad_amount), (0, 0))) + + if self.use_megablox: + #TODO: megablox is used in MaxText, keep a placeholder here for future implement + raise NotImplementedError( + "MegaBlox kernel call is not implemented.") + else: + output = jax.lax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=group_sizes, + preferred_element_type=self.dtype, + ) + + if pad_amount > 0: + output = output[:num_rows, :] + return output + + @staticmethod + def _distributed_sparse_moe_fwd( + self, + x_TD: jax.Array, + router_weights_TX: jax.Array, + selected_experts_TX: jax.Array, + kernel_gating: jax.Array, + kernel_up_proj: jax.Array, + kernel_down_proj: jax.Array, + ): + """ + The sparse MoE forward pass with fully distributed logic. + This assumes it is running within a distributed TPU. + """ + + # 1. 
Global Permute, perpute all tokens across shards + ( + sorted_inputs, + global_sort_indices, + global_group_sizes, + global_sorted_experts, + ) = self._permute(x_TD, selected_experts_TX) + + # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis + # or we sould derive it from the model init + expert_shard_id = jax.lax.axis_index(self.expert_axis_name) + local_expert_size = self.num_local_experts // self.num_expert_parallelism + + if self.num_expert_parallelism > 1: + if self.is_batch_sharded_by_expert: + # When token sharded in devices + # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name + + # 2a. Send Tokens To Experts (All-to-All) + # Gather group sizes from all data shards + # all_shards_group_sizes: (data parallelism = expert parallelism, number of total experts ) + all_shards_group_sizes = jax.lax.all_gather( + global_group_sizes, axis_name=self.data_axis_name) + + # all_shards_group_sizes_per_expert_shard[i][j] = # tokens on shard[i] to be sent to expert shard[j] + all_shards_group_sizes_per_expert_shard = jnp.sum( + all_shards_group_sizes.reshape( + self.num_expert_parallelism, # data parallelism + self.num_expert_parallelism, # expert parallelism + local_expert_size # Experts per shard + ), + axis=2) + input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params( + all_shards_group_sizes_per_expert_shard, expert_shard_id, + self.num_expert_parallelism) + # Estimate buffer size + local_total_assignments = x_TD.shape[ + 0] * self.num_experts_per_tok + global_total_assignments = local_total_assignments * self.num_expert_parallelism + output_shape_est = jnp.zeros( + (global_total_assignments, self.hidden_size), + dtype=sorted_inputs.dtype) + + inputs_after_all2all = jax.lax.ragged_all_to_all( + sorted_inputs, + output_shape_est, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self.expert_axis_name) + + # 3a. Local Permute + # Get full group sizes from all shards + full_global_group_sizes = jax.lax.all_gather( + global_group_sizes, axis_name=self.expert_axis_name) + ( + compute_inputs, + local_sorted_indices, + compute_group_sizes, + compute_expert_ids, + ) = self._local_permute( + inputs_after_all2all, + full_global_group_sizes, + local_expert_size, + shard_index=expert_shard_id, + is_offset=False, + ) + + else: + # When token replicated in devices + + # 2. No send all-to-all needed, as the tokens are sorted and replicated on all devices + # 3b. Local "Permute" + ( + compute_inputs, + local_sorted_indices, + compute_group_sizes, + compute_expert_ids, + ) = self._local_permute( + sorted_inputs, + global_group_sizes[None, :], + local_expert_size, + shard_index=expert_shard_id, + is_offset=True, + global_sorted_experts=global_sorted_experts, + ) + + # Calculate group sizes for return all-to-all + reshaped_group_sizes = jnp.sum(global_group_sizes.reshape( + -1, local_expert_size), + axis=1) + mask = compute_expert_ids < local_expert_size + compute_inputs = compute_inputs * mask[..., None] + + else: + # --- NO EXPERT PARALLELISM --- + compute_inputs = sorted_inputs + compute_group_sizes = global_group_sizes + compute_expert_ids = global_sorted_experts + local_sorted_indices = jnp.arange(sorted_inputs.shape[0]) + + #debug_position_in_sorted = jnp.argsort(global_sort_indices)[40:48] + #debug_position_compute_inputs = jnp.argsort(local_sorted_indices)[debug_position_in_sorted] + + # 4. 
Compute: Apply experts using Grouped Matrix Multiply + with jax.named_scope("gating"): + # compute_inputs: (local total assignments, D) + gating_TEF = self._gmm(compute_inputs, kernel_gating, + compute_group_sizes) + activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act]( + gating_TEF) + + with jax.named_scope("up_projection"): + up_proj_TEF = self._gmm(compute_inputs, kernel_up_proj, + compute_group_sizes) + + fuse_TEF = activated_gating_TEF * up_proj_TEF + + with jax.named_scope("down_projection"): + # intermediate_output: (local total assignments, D) + intermediate_output = self._gmm(fuse_TEF, kernel_down_proj, + compute_group_sizes) + + # 5. Return Results (All-to-All) + if self.num_expert_parallelism > 1: + local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok + output_shape = jnp.zeros( + (local_total_assignments, self.hidden_size), + dtype=intermediate_output.dtype) + + if self.is_batch_sharded_by_expert: + # When token sharded in devices + # Unsort locally before sending back + local_output = self._sort_activations( + intermediate_output, jnp.argsort(local_sorted_indices)) + + input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params( + jnp.transpose(all_shards_group_sizes), + expert_shard_id, + self.num_expert_parallelism, + ) + final_intermediate_output = jax.lax.ragged_all_to_all( + local_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self.expert_axis_name) + else: + # When token replicated in devices + input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params( + reshaped_group_sizes, + expert_shard_id, + self.num_expert_parallelism, + is_batch_sharded=False, + ) + final_intermediate_output = jax.lax.ragged_all_to_all( + intermediate_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self.expert_axis_name) + else: + final_intermediate_output = intermediate_output + + # 6. 
Global Unpermute (on the data shard) + with jax.named_scope("unpermute"): + output_TD = self._unpermute(final_intermediate_output, + global_sort_indices, router_weights_TX) + + return output_TD + + def __call__(self, x_TD: Float): + """Performs the forward pass of the Sparse MoE layer.""" + x_TD = jnp.asarray(x_TD, self.dtype) + x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td) + router_weights_TX, selected_experts_TX = self.router(x_TD) + + in_specs = ( + PartitionSpec(), # Replicated `self` + PartitionSpec(*self.activation_ffw_td), # Sharded x_TD + PartitionSpec(), # Replicated router_weights_TX + PartitionSpec(), # Replicated selected_experts_TX + PartitionSpec(*self.edf_sharding), # Sharded gating kernel + PartitionSpec(*self.edf_sharding), # Sharded up-projection kernel + PartitionSpec( + *self.efd_sharding), # Sharded down-projection kernel + ) + out_specs = PartitionSpec(*self.activation_ffw_td) + + mapped_moe_fwd = partial(jax.experimental.shard_map.shard_map, + mesh=self.mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False)( + SparseMoE._distributed_sparse_moe_fwd) + + return mapped_moe_fwd( + self, + x_TD, + router_weights_TX, + selected_experts_TX, + self.kernel_gating_EDF.value, + self.kernel_up_proj_EDF.value, + self.kernel_down_proj_EFD.value, + ) From 601708cbede79f4a2e225e3519f7268690964928 Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Mon, 15 Sep 2025 23:07:38 +0000 Subject: [PATCH 02/11] add unit test; add flag to support switching between dense/sparse matmul Signed-off-by: bzgoogle Signed-off-by: bzgoogle --- .../jax/common/moe/test_deepseek_moe.py | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 tests/models/jax/common/moe/test_deepseek_moe.py diff --git a/tests/models/jax/common/moe/test_deepseek_moe.py b/tests/models/jax/common/moe/test_deepseek_moe.py new file mode 100644 index 000000000..4c3942013 --- /dev/null +++ b/tests/models/jax/common/moe/test_deepseek_moe.py @@ -0,0 +1,224 @@ +import os +import unittest + +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.sharding import Mesh, PartitionSpec + +from tpu_commons.models.jax.common.moe.deepseek_moe import (DeepSeekV3Router, + SparseMoE) + + +class TestDeepSeekV3Router(unittest.TestCase): + + def setUp(self): + self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', )) + + def test_get_topk_indices_single_group(self): + """Test get_topk_indices with single expert group.""" + with jax.set_mesh(self.cpu_mesh): + router = DeepSeekV3Router(random_init=True, + hidden_size=512, + num_experts=4, + num_experts_per_tok=2, + n_groups=1, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(42)) + router.bias_E = jnp.zeros((4, )) + + scores = jnp.array([[0.1, 0.3, 0.2, 0.4]]) # shape: (1, 4) + indices = router.get_topk_indices(scores) + + # Should return indices of top 2 experts + expected_indices = jnp.array([[3, + 1]]) # experts with scores 0.4, 0.3 + self.assertTrue(jnp.array_equal(indices, expected_indices)) + + def test_get_topk_indices_2_groups(self): + """Test get_topk_indices with 2 expert groups.""" + with jax.set_mesh(self.cpu_mesh): + router = DeepSeekV3Router(random_init=True, + hidden_size=512, + num_experts=4, + num_experts_per_tok=2, + n_groups=2, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(42)) + router.bias_E = 
jnp.zeros((4, )) + + # 4 experts, 2 groups, 2 experts per group + scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]]) # shape: (1, 1, 4) + indices = router.get_topk_indices(scores) + + # Should return indices of top 2 experts + expected_indices = jnp.array([[[3, 2]]]) + self.assertTrue(jnp.array_equal(indices, expected_indices)) + + def test_router_e2e(self): + with jax.set_mesh(self.cpu_mesh): + router = DeepSeekV3Router(random_init=True, + hidden_size=512, + num_experts=8, + num_experts_per_tok=2, + n_groups=2, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(42)) + x = jnp.ones((2, 512)) + weights, indices = router(x) + self.assertEqual(weights.shape, (2, 2)) + self.assertEqual(indices.shape, (2, 2)) + + +class TestSparseMoE(unittest.TestCase): + + def setUp(self): + """Set up a multi-device mesh and a sample MoE layer for testing.""" + devices = jax.devices() + self.device_count = len(devices) + if self.device_count < 8: + self.skipTest("This test requires at least 8 simulated devices.") + + # This mesh will have a 'model' axis for expert parallelism + mesh_shape = (self.device_count, 1) + device_mesh_array = np.array(devices).reshape(mesh_shape) + + # Define the axis names + axis_names = ('model', 'data') + + # Create the 2D mesh + self.mesh = Mesh(device_mesh_array, axis_names=axis_names) + + # --- Model Configuration --- + self.B, self.S, self.D = 2, 4, 16 # Batch, Sequence, Hidden Dim + self.E, self.K = 16, 8 # Num Experts, Experts per Token + self.moe_intermediate_size = 32 # FFN Dim + self.num_expert_parallelism = 8 # Shard experts across 8 devices + + self.key = jax.random.PRNGKey(42) + self.x = jax.random.normal(self.key, (self.B * self.S, self.D), + dtype=jnp.bfloat16) + + # --- Instantiate MoE Layer --- + # We need to do this inside the mesh context + with self.mesh: + router = DeepSeekV3Router(hidden_size=self.D, + num_experts=self.E, + num_experts_per_tok=self.K, + n_groups=1, + topk_groups=1, + norm_topk_prob=False, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(self.key), + ed_sharding=PartitionSpec(), + e_sharding=PartitionSpec(), + activation_ffw_td=PartitionSpec( + 'data', None)) + # Instantiation updated to match user's code snippet + self.moe = SparseMoE( + hidden_size=self.D, + intermediate_size_moe=self.moe_intermediate_size, + num_local_experts=self.E, + hidden_act="silu", + num_experts_per_tok=self.K, + router=router, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(self.key), + mesh=self.mesh, + apply_expert_weight_before_computation=False, + + # Sharding specs updated based on user's snippet + edf_sharding=PartitionSpec('model', None, None), + efd_sharding=PartitionSpec('model', None, None), + activation_ffw_ted=PartitionSpec('data', None), + activation_ffw_td=PartitionSpec( + 'data', None) # Activations are replicated + ) + + def test_token_replicated_expert_parallel_fwd(self): + """ + Validates the MoE forward pass against a simple, dense equivalent. + This specifically tests the is_batch_sharded_by_expert=False path. + """ + # --- 1. Get the ACTUAL output from the complex distributed MoE layer --- + # The __call__ method will trigger the shard_map, which requires the mesh context. + with self.mesh: + actual_output = self.moe(self.x) + + # --- 2. Calculate the EXPECTED output using a simple, sequential process --- + # This serves as the "ground truth". 
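+        # Reference formula restated (silu is the hidden_act configured above;
+        # Wg/Wu/Wd are shorthand for the gating/up/down kernels): for each token x
+        # with k-th selected expert e_k and router weight w_k,
+        #   y = sum_k w_k * ( silu(x @ Wg[e_k]) * (x @ Wu[e_k]) ) @ Wd[e_k]
+        # The sequential loop below computes exactly this, token by token.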
+ + # Get router decisions (router params are replicated, so this is fine) + router_weights, selected_experts = self.moe.router(self.x) + + # Gather the full, unsharded weights from all devices --- + # .value on a sharded param gives the *local* shard. + # jax.device_get() retrieves the *full* GlobalDeviceArray to the host. + gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value) + up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value) + down_proj_kernel_full = jax.device_get( + self.moe.kernel_down_proj_EFD.value) + + # Check that we really got the full weights + self.assertEqual(gating_kernel_full.shape, + (self.E, self.D, self.moe_intermediate_size)) + + # Flatten inputs for easier iteration + flat_x = self.x.reshape(self.B * self.S, self.D) + flat_weights = router_weights.reshape(self.B * self.S, self.K) + flat_experts = selected_experts.reshape(self.B * self.S, self.K) + + expected_output = jnp.zeros_like(flat_x) + + # Manually apply each expert to each token sequentially + for i in range(self.B * self.S): # For each token + token_input = flat_x[i] + combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16) + + for k in range(self.K): # For each chosen expert for that token + expert_idx = flat_experts[i, k] + weight = flat_weights[i, k] + + # Get kernels from the *full* gathered arrays --- + gating_kernel = gating_kernel_full[expert_idx] + up_proj_kernel = up_proj_kernel_full[expert_idx] + down_proj_kernel = down_proj_kernel_full[expert_idx] + + # Perform the expert computation (dense matmuls) + gating_proj = jnp.dot(token_input, gating_kernel) + up_proj = jnp.dot(token_input, up_proj_kernel) + + # Note: Assuming 'silu' activation as specified in MoE init + fused = nnx.silu(gating_proj) * up_proj + + expert_output = jnp.dot(fused, down_proj_kernel) + + # Apply router weight after computation (matches implementation) + combined_expert_output += weight * expert_output + + expected_output = expected_output.at[i].set(combined_expert_output) + + expected_output = expected_output.reshape(self.B * self.S, self.D) + + # --- 3. Compare the results --- + self.assertTrue( + jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2), + f"The output of the distributed MoE does not match the dense equivalent.\n" + f"Actual:\n{actual_output}\n" + f"Expected:\n{expected_output}") + print( + "\n✅ Test Passed: Distributed MoE output matches the dense ground truth." 
+ ) From 787ba7d5ccd9960ca12d68be234604067e78e794 Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Tue, 30 Sep 2025 22:00:12 +0000 Subject: [PATCH 03/11] address some comments Signed-off-by: bzgoogle --- tests/models/jax/common/moe/test_deepseek_moe.py | 2 -- tpu_commons/models/jax/common/moe/deepseek_moe.py | 11 ----------- 2 files changed, 13 deletions(-) diff --git a/tests/models/jax/common/moe/test_deepseek_moe.py b/tests/models/jax/common/moe/test_deepseek_moe.py index 4c3942013..d0b0d77bb 100644 --- a/tests/models/jax/common/moe/test_deepseek_moe.py +++ b/tests/models/jax/common/moe/test_deepseek_moe.py @@ -1,8 +1,6 @@ import os import unittest -os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' - import jax import jax.numpy as jnp import numpy as np diff --git a/tpu_commons/models/jax/common/moe/deepseek_moe.py b/tpu_commons/models/jax/common/moe/deepseek_moe.py index 5ccfbd838..03d64f5f2 100644 --- a/tpu_commons/models/jax/common/moe/deepseek_moe.py +++ b/tpu_commons/models/jax/common/moe/deepseek_moe.py @@ -325,14 +325,6 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, processed_tokens, jnp.argsort(sort_indices)) reshaped_tokens_TXD = unsorted_tokens_tD.reshape( -1, self.num_experts_per_tok, self.hidden_size) - # jax.debug.print( - # "✅ reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}", - # t=reshaped_tokens_TXD[5, 0,:5] - # ) - # jax.debug.print( - # "✅ router_weights_TX on device: router_weights_TX={t}", - # t=router_weights_TX[5, :] - # ) with jax.named_scope("combine_weights"): output_TD = jnp.einsum( "TXD,TX -> TD", @@ -484,9 +476,6 @@ def _distributed_sparse_moe_fwd( compute_expert_ids = global_sorted_experts local_sorted_indices = jnp.arange(sorted_inputs.shape[0]) - #debug_position_in_sorted = jnp.argsort(global_sort_indices)[40:48] - #debug_position_compute_inputs = jnp.argsort(local_sorted_indices)[debug_position_in_sorted] - # 4. 
Compute: Apply experts using Grouped Matrix Multiply with jax.named_scope("gating"): # compute_inputs: (local total assignments, D) From 96b6f36ebafac43eb513cc21d6242532b5b34e41 Mon Sep 17 00:00:00 2001 From: Jacob Platin <31421084+jrplatin@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:19:15 -0700 Subject: [PATCH 04/11] [JAX][Quantization] Add Qwix support for SparseMatul (#740) Signed-off-by: Jacob Platin --- .../jax/common/moe/test_deepseek_moe.py | 1 - .../models/jax/common/moe/deepseek_moe.py | 47 ++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/tests/models/jax/common/moe/test_deepseek_moe.py b/tests/models/jax/common/moe/test_deepseek_moe.py index d0b0d77bb..6983649b4 100644 --- a/tests/models/jax/common/moe/test_deepseek_moe.py +++ b/tests/models/jax/common/moe/test_deepseek_moe.py @@ -1,4 +1,3 @@ -import os import unittest import jax diff --git a/tpu_commons/models/jax/common/moe/deepseek_moe.py b/tpu_commons/models/jax/common/moe/deepseek_moe.py index 03d64f5f2..018480437 100644 --- a/tpu_commons/models/jax/common/moe/deepseek_moe.py +++ b/tpu_commons/models/jax/common/moe/deepseek_moe.py @@ -1,7 +1,7 @@ import enum from dataclasses import InitVar, dataclass from functools import partial -from typing import Tuple +from typing import Optional, Tuple import jax import jax.numpy as jnp @@ -9,10 +9,14 @@ from flax.typing import Sharding from jax.sharding import PartitionSpec from jaxtyping import Float +from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot +from qwix._src.providers import ptq from tpu_commons.models.jax.common.base import create_param from tpu_commons.models.jax.common.layers import FlaxUtils from tpu_commons.models.jax.common.moe.moe import MoE +from tpu_commons.models.jax.utils.quantization.quantization_utils import ( + manually_quantize_qwix_activation, manually_quantize_qwix_weight) modeling_flax_utils = FlaxUtils() @@ -141,6 +145,8 @@ class SparseMoE(MoE): tile_size: tuple[int, int, int] = (128, 64, 128) use_megablox: bool = False mesh: jax.sharding.Mesh + # This should be set if and only if you have quantized your model (via Qwix) + quantized_dtype: Optional[jnp.dtype] = None def __post_init__(self, rngs: nnx.Rngs): super().__post_init__(rngs) @@ -348,7 +354,11 @@ def _gmm(self, inputs, kernel, group_sizes): raise NotImplementedError( "MegaBlox kernel call is not implemented.") else: - output = jax.lax.ragged_dot( + inputs = manually_quantize_qwix_activation( + inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {}, + "absmax") if self.quantized_dtype else inputs + ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot + output = ragged_dot_func( lhs=inputs, rhs=kernel, group_sizes=group_sizes, @@ -572,12 +582,27 @@ def __call__(self, x_TD: Float): check_rep=False)( SparseMoE._distributed_sparse_moe_fwd) - return mapped_moe_fwd( - self, - x_TD, - router_weights_TX, - selected_experts_TX, - self.kernel_gating_EDF.value, - self.kernel_up_proj_EDF.value, - self.kernel_down_proj_EFD.value, - ) + kernel_gating_EDF = self.kernel_gating_EDF.value + kernel_up_proj_EDF = self.kernel_up_proj_EDF.value + kernel_down_proj_EFD = self.kernel_down_proj_EFD.value + + if self.quantized_dtype: + if not isinstance(kernel_gating_EDF, ptq.WithAux): + kernel_gating_EDF = manually_quantize_qwix_weight( + kernel_gating_EDF, self.quantized_dtype, [0, 2], {}, + "absmax") + if not isinstance(kernel_up_proj_EDF, ptq.WithAux): + kernel_up_proj_EDF = manually_quantize_qwix_weight( + kernel_up_proj_EDF, 
self.quantized_dtype, [0, 2], {}, + "absmax") + if not isinstance(kernel_down_proj_EFD, ptq.WithAux): + kernel_down_proj_EFD = manually_quantize_qwix_weight( + kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {}, + "absmax") + kernel_gating_EDF = kernel_gating_EDF.array + kernel_up_proj_EDF = kernel_up_proj_EDF.array + kernel_down_proj_EFD = kernel_down_proj_EFD.array + + return mapped_moe_fwd(self, x_TD, router_weights_TX, + selected_experts_TX, kernel_gating_EDF, + kernel_up_proj_EDF, kernel_down_proj_EFD) From 5147850a3dab16e52f8357cd100851b9e4c573dd Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Fri, 3 Oct 2025 21:54:38 +0000 Subject: [PATCH 05/11] local change to support 2d TP for DeepSeek --- .../models/jax/common/moe/deepseek_moe.py | 110 ++++++++++-------- .../jax/attention/deepseek_v3_attention.py | 4 +- tpu_inference/layers/jax/moe/moe.py | 4 +- tpu_inference/models/common/model_loader.py | 4 +- tpu_inference/models/jax/deepseek_v3.py | 38 +++--- tpu_inference/runner/compilation_manager.py | 8 +- tpu_inference/runner/kv_cache.py | 4 +- tpu_inference/runner/kv_cache_manager.py | 10 +- 8 files changed, 99 insertions(+), 83 deletions(-) diff --git a/tpu_commons/models/jax/common/moe/deepseek_moe.py b/tpu_commons/models/jax/common/moe/deepseek_moe.py index 018480437..a079e737e 100644 --- a/tpu_commons/models/jax/common/moe/deepseek_moe.py +++ b/tpu_commons/models/jax/common/moe/deepseek_moe.py @@ -1,7 +1,7 @@ import enum from dataclasses import InitVar, dataclass from functools import partial -from typing import Optional, Tuple +from typing import Tuple import jax import jax.numpy as jnp @@ -9,14 +9,10 @@ from flax.typing import Sharding from jax.sharding import PartitionSpec from jaxtyping import Float -from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot -from qwix._src.providers import ptq from tpu_commons.models.jax.common.base import create_param from tpu_commons.models.jax.common.layers import FlaxUtils from tpu_commons.models.jax.common.moe.moe import MoE -from tpu_commons.models.jax.utils.quantization.quantization_utils import ( - manually_quantize_qwix_activation, manually_quantize_qwix_weight) modeling_flax_utils = FlaxUtils() @@ -140,19 +136,43 @@ class SparseMoE(MoE): # TODO: determine if we get it from external or extrat it in MoE class is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim. 
""" + def_sharding: Sharding + fed_sharding: Sharding num_experts_per_tok: int #TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText tile_size: tuple[int, int, int] = (128, 64, 128) use_megablox: bool = False mesh: jax.sharding.Mesh - # This should be set if and only if you have quantized your model (via Qwix) - quantized_dtype: Optional[jnp.dtype] = None def __post_init__(self, rngs: nnx.Rngs): - super().__post_init__(rngs) + + D = self.hidden_size + F = self.intermediate_size_moe + # shape_gating = (D, self.num_local_experts, F) + # shape_up = (D, self.num_local_experts, F) + # shape_down = (F, self.num_local_experts,D) + shape_gating = (self.num_local_experts, D, F) + shape_up = (self.num_local_experts, D, F) + shape_down = (self.num_local_experts, F, D) + + self.kernel_gating_DEF = create_param(rngs, + shape=shape_gating, + dtype=self.dtype, + sharding=self.def_sharding, + random_init=self.random_init) + self.kernel_up_proj_DEF = create_param(rngs, + shape=shape_up, + dtype=self.dtype, + sharding=self.def_sharding, + random_init=self.random_init) + self.kernel_down_proj_FED = create_param(rngs, + shape=shape_down, + dtype=self.dtype, + sharding=self.fed_sharding, + random_init=self.random_init) # Derive the expert sharding - self.expert_axis_name = self.edf_sharding[0] + self.expert_axis_name = self.def_sharding[0] if self.expert_axis_name is None: self.num_expert_parallelism = 1 else: @@ -329,20 +349,29 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, with jax.named_scope("unpermute"): unsorted_tokens_tD = self._sort_activations( processed_tokens, jnp.argsort(sort_indices)) + D = unsorted_tokens_tD.shape[1] reshaped_tokens_TXD = unsorted_tokens_tD.reshape( - -1, self.num_experts_per_tok, self.hidden_size) + -1, self.num_experts_per_tok, D) + # jax.debug.print( + # "✅ reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}", + # t=reshaped_tokens_TXD[5, 0,:5] + # ) + # jax.debug.print( + # "✅ router_weights_TX on device: router_weights_TX={t}", + # t=router_weights_TX[5, :] + # ) with jax.named_scope("combine_weights"): output_TD = jnp.einsum( "TXD,TX -> TD", - reshaped_tokens_TXD.astype(jnp.float32), - router_weights_TX.astype(jnp.float32), - precision='float32', + reshaped_tokens_TXD.astype(self.dtype), + router_weights_TX.astype(self.dtype), ) return output_TD.astype(self.dtype) def _gmm(self, inputs, kernel, group_sizes): """Performs Grouped Matrix Multiply.""" + jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True) num_rows = inputs.shape[0] pad_amount = (self.tile_size[0] - num_rows % self.tile_size[0]) % self.tile_size[0] @@ -354,11 +383,8 @@ def _gmm(self, inputs, kernel, group_sizes): raise NotImplementedError( "MegaBlox kernel call is not implemented.") else: - inputs = manually_quantize_qwix_activation( - inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {}, - "absmax") if self.quantized_dtype else inputs - ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot - output = ragged_dot_func( + + output = jax.lax.ragged_dot( lhs=inputs, rhs=kernel, group_sizes=group_sizes, @@ -394,10 +420,12 @@ def _distributed_sparse_moe_fwd( # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis # or we sould derive it from the model init - expert_shard_id = jax.lax.axis_index(self.expert_axis_name) + local_expert_size = self.num_local_experts // self.num_expert_parallelism - if self.num_expert_parallelism > 1: + #if 
self.num_expert_parallelism > 1: + if self.expert_axis_name: + expert_shard_id = jax.lax.axis_index(self.expert_axis_name) if self.is_batch_sharded_by_expert: # When token sharded in devices # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name @@ -508,8 +536,9 @@ def _distributed_sparse_moe_fwd( # 5. Return Results (All-to-All) if self.num_expert_parallelism > 1: local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok + D = x_TD.shape[1] output_shape = jnp.zeros( - (local_total_assignments, self.hidden_size), + (local_total_assignments, D), dtype=intermediate_output.dtype) if self.is_batch_sharded_by_expert: @@ -568,10 +597,10 @@ def __call__(self, x_TD: Float): PartitionSpec(*self.activation_ffw_td), # Sharded x_TD PartitionSpec(), # Replicated router_weights_TX PartitionSpec(), # Replicated selected_experts_TX - PartitionSpec(*self.edf_sharding), # Sharded gating kernel - PartitionSpec(*self.edf_sharding), # Sharded up-projection kernel + PartitionSpec(*self.def_sharding), # Sharded gating kernel + PartitionSpec(*self.def_sharding), # Sharded up-projection kernel PartitionSpec( - *self.efd_sharding), # Sharded down-projection kernel + *self.fed_sharding), # Sharded down-projection kernel ) out_specs = PartitionSpec(*self.activation_ffw_td) @@ -582,27 +611,12 @@ def __call__(self, x_TD: Float): check_rep=False)( SparseMoE._distributed_sparse_moe_fwd) - kernel_gating_EDF = self.kernel_gating_EDF.value - kernel_up_proj_EDF = self.kernel_up_proj_EDF.value - kernel_down_proj_EFD = self.kernel_down_proj_EFD.value - - if self.quantized_dtype: - if not isinstance(kernel_gating_EDF, ptq.WithAux): - kernel_gating_EDF = manually_quantize_qwix_weight( - kernel_gating_EDF, self.quantized_dtype, [0, 2], {}, - "absmax") - if not isinstance(kernel_up_proj_EDF, ptq.WithAux): - kernel_up_proj_EDF = manually_quantize_qwix_weight( - kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {}, - "absmax") - if not isinstance(kernel_down_proj_EFD, ptq.WithAux): - kernel_down_proj_EFD = manually_quantize_qwix_weight( - kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {}, - "absmax") - kernel_gating_EDF = kernel_gating_EDF.array - kernel_up_proj_EDF = kernel_up_proj_EDF.array - kernel_down_proj_EFD = kernel_down_proj_EFD.array - - return mapped_moe_fwd(self, x_TD, router_weights_TX, - selected_experts_TX, kernel_gating_EDF, - kernel_up_proj_EDF, kernel_down_proj_EFD) + return mapped_moe_fwd( + self, + x_TD, + router_weights_TX, + selected_experts_TX, + self.kernel_gating_DEF.value, + self.kernel_up_proj_DEF.value, + self.kernel_down_proj_FED.value, + ) diff --git a/tpu_inference/layers/jax/attention/deepseek_v3_attention.py b/tpu_inference/layers/jax/attention/deepseek_v3_attention.py index a1634b923..3f086ee1a 100644 --- a/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +++ b/tpu_inference/layers/jax/attention/deepseek_v3_attention.py @@ -317,13 +317,13 @@ def attention( self.query_tnh, # q self.keyvalue_skh, # k self.keyvalue_skh, # v - P(None, None, "model"), # kv_cache + P(None, None, ('model', 'expert')), # kv_cache P(), # md.seq_lens: Replicated P(), # page_indices_flat: Replicated P(), # query_start_loc: Replicated P(), # distribution: Replicated ) - out_specs = (self.attn_o_tnh, P(None, None, "model")) + out_specs = (self.attn_o_tnh, P(None, None, ('model', 'expert'))) def _ragged_paged_attention(*args): return ragged_paged_attention( diff --git a/tpu_inference/layers/jax/moe/moe.py b/tpu_inference/layers/jax/moe/moe.py 
index c8e08ea40..a32ccb96f 100644 --- a/tpu_inference/layers/jax/moe/moe.py +++ b/tpu_inference/layers/jax/moe/moe.py @@ -84,8 +84,8 @@ class MoE(nnx.Module): router: nnx.Module activation_ffw_td: Sharding activation_ffw_ted: Sharding - edf_sharding: Sharding - efd_sharding: Sharding + edf_sharding: Sharding = () + efd_sharding: Sharding = () random_init: bool = False def __call__(self, x_TD: Float): diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py index eb142111f..d07f430fa 100644 --- a/tpu_inference/models/common/model_loader.py +++ b/tpu_inference/models/common/model_loader.py @@ -201,7 +201,7 @@ def get_flax_model( vllm_config.model_config.hf_config) jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh) kv_cache_sharding = NamedSharding( - mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model")) + mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, ("model", "expert"))) hidden_states_sharding = NamedSharding(mesh, PartitionSpec( ShardingAxisName.ATTN_DATA, @@ -226,7 +226,7 @@ def run_model(graphdef, state, *args): return model(*args) logits_sharding = NamedSharding( - mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model")) + mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, ("model", "expert"))) @functools.partial( jax.jit, diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py index 15686e881..583fa48c6 100644 --- a/tpu_inference/models/jax/deepseek_v3.py +++ b/tpu_inference/models/jax/deepseek_v3.py @@ -148,14 +148,14 @@ def _create_mla() -> MLA: rngs=self.rng, activation_attention_td=(None, None), activation_q_td=(None, None), - query_tnh=P(None, 'model', None), - keyvalue_skh=P(None, 'model', None), + query_tnh=P(None, ('model', 'expert'), None), + keyvalue_skh=P(None, ('model', 'expert'), None), activation_attention_out_td=(None, None), - attn_o_tnh=P(None, 'model', None), - q_da_sharding=(None, 'model'), - anh_sharding=(None, 'model', None), - kv_da_sharding=(None, 'model'), - nhd_sharding=('model', None, None)) + attn_o_tnh=P(None, ('model', 'expert'), None), + q_da_sharding=(None, ('model', 'expert')), + anh_sharding=(None, ('model', 'expert'), None), + kv_da_sharding=(None, ('model', 'expert')), + nhd_sharding=(('model', 'expert'), None, None)) for i in range(first_k_dense_replace): block = TransformerBlock( @@ -201,8 +201,8 @@ def _create_mla() -> MLA: routed_scaling_factor=2.5, dtype=dtype, activation_ffw_td=('data', None), - ed_sharding=('model', None), - e_sharding=('model', )) + ed_sharding=(None, None), + e_sharding=(None, )) if self.sparse_matmul: # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces custom_module = SparseMoE( @@ -216,12 +216,10 @@ def _create_mla() -> MLA: hidden_act=hidden_act, rngs=self.rng, random_init=self.random_init, - activation_ffw_td=('data', None), - activation_ffw_ted=('data', None, None), - edf_sharding=('model', None, None), - efd_sharding=('model', None, None), - quantized_dtype=self.weight_loader.quant_dtype - if self.weight_loader.is_model_quantized else None, + activation_ffw_td=('data', 'model'), + activation_ffw_ted=('data', None, 'model'), + def_sharding=('expert', 'model', None), + fed_sharding=('expert', None, 'model'), router=router) if is_moe_layer else DenseFFW( dtype=dtype, hidden_act=hidden_act, @@ -241,10 +239,10 @@ def _create_mla() -> MLA: hidden_act=hidden_act, rngs=self.rng, random_init=self.random_init, - activation_ffw_td=('data', None), + activation_ffw_td=('data', 'model'), 
activation_ffw_ted=('data', None, None), - edf_sharding=('model', None, None), - efd_sharding=('model', None, None), + edf_sharding=('expert', 'model', None), + efd_sharding=('expert', None, 'model'), router=router) if is_moe_layer else DenseFFW( dtype=dtype, hidden_act=hidden_act, @@ -865,4 +863,8 @@ def weights_dequant_cpu(x: torch.Tensor, scale = s[M // block_size, j // block_size] y[M_main:M, j:j + block_size] = block * scale +<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py return y.to(j2t_dtype(jnp.dtype(output_dtype))) +======= + return y.to(torch.get_default_dtype()) +>>>>>>> 307bbd62 (local change to support 2d TP for DeepSeek):tpu_commons/models/jax/deepseek_v3.py diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 42b9b199d..cb9c94df7 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -352,7 +352,7 @@ def _precompile_select_from_array(self) -> None: indices_paddings=self.runner.num_reqs_paddings, hidden_dim=vocab_size, input_sharding=NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")), + PartitionSpec(None, ('model', 'expert')), ) self._precompile_select_from_array_helper( name="select target tokens for spec decoding", @@ -360,7 +360,7 @@ def _precompile_select_from_array(self) -> None: indices_paddings=self.runner.num_logits_paddings, hidden_dim=vocab_size, input_sharding=NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")), + PartitionSpec(None, ('model', 'expert')), only_equal_paddings=True, ) @@ -392,7 +392,7 @@ def _precompile_sampling(self) -> None: for num_reqs in self.runner.num_reqs_paddings: logits_sharding = NamedSharding( self.runner.mesh, - PartitionSpec(ShardingAxisName.ATTN_DATA, "model")) + PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert')) dp_size = self.runner.vllm_config.sharding_config.total_dp_size sampling_metadata_sharding = NamedSharding( self.runner.mesh, PartitionSpec( @@ -482,7 +482,7 @@ def _precompile_rejection_sampler(self) -> None: for num_logits in self.runner.num_logits_paddings: for num_reqs in self.runner.num_reqs_paddings: sharding = NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")) + PartitionSpec(None, ('model', 'expert'))) target_probs = self._create_dummy_tensor( (num_logits, vocab_size), jnp.bfloat16, sharding) draft_token_ids = self._create_dummy_tensor((num_logits, ), diff --git a/tpu_inference/runner/kv_cache.py b/tpu_inference/runner/kv_cache.py index 612ac3bea..f902d0399 100644 --- a/tpu_inference/runner/kv_cache.py +++ b/tpu_inference/runner/kv_cache.py @@ -22,7 +22,7 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int, actual_head_dim: int, kv_dtype: any): """Gets the KV cache shape based on the mesh configuration.""" - model_cnt = mesh.shape["model"] + model_cnt = mesh.shape["model"] * mesh.shape["expert"] assert actual_num_kv_heads % model_cnt == 0 # NOTE(chengjiyao): Currently, the attention kernel is tailored to the # specific model, rather than being determined by the head_dim. 
If new @@ -79,7 +79,7 @@ def create_kv_caches( sharding = NamedSharding( mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, - ShardingAxisName.ATTN_HEAD)) + ('model', 'expert')) def _allocate() -> jax.Array: return jnp.empty( diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py index dcb94a966..ae34a9d48 100644 --- a/tpu_inference/runner/kv_cache_manager.py +++ b/tpu_inference/runner/kv_cache_manager.py @@ -55,7 +55,7 @@ def get_kv_cache_spec(self): # Pad num_kv_heads to multiple of TP size. num_kv_heads = common_utils.get_padded_num_heads( model_config.get_total_num_kv_heads(), - self.runner.mesh.shape["model"]) + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]) head_size = common_utils.get_padded_head_dim( model_config.get_head_size()) for i in range(model_config.get_num_layers(parallel_config)): @@ -78,7 +78,7 @@ def get_kv_cache_spec(self): hf_config = draft_model_config.hf_config num_kv_heads = common_utils.get_padded_num_heads( hf_config.num_key_value_heads, - self.runner.mesh.shape["model"]) + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]) head_size = common_utils.get_padded_head_dim( hf_config.hidden_size // hf_config.num_attention_heads) @@ -120,7 +120,7 @@ def get_kv_cache_spec(self): block_size=block_size, num_kv_heads=common_utils.get_padded_num_heads( attn_module.num_kv_heads, - self.runner.mesh.shape["model"]), + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]), head_size=common_utils.get_padded_head_dim( attn_module.head_size), dtype=self.runner.kv_cache_dtype, @@ -138,7 +138,7 @@ def get_kv_cache_spec(self): block_size=block_size, num_kv_heads=common_utils.get_padded_num_heads( attn_module.num_kv_heads, - self.runner.mesh.shape["model"]), + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]), head_size=common_utils.get_padded_head_dim( attn_module.head_size), dtype=self.runner.kv_cache_dtype) @@ -378,7 +378,7 @@ def transfer_kv_cache(self, f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes" ) sharding = NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")) + PartitionSpec(None, ("model", "expert"))) if envs.VLLM_TPU_USING_PATHWAYS: from pathwaysutils.experimental import \ reshard as experimental_reshard From c8abe25593dbf91e4b52e59595a91c98056b80bb Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Fri, 3 Oct 2025 22:24:22 +0000 Subject: [PATCH 06/11] update layer to full --- tpu_inference/models/jax/deepseek_v3.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py index 583fa48c6..3538978df 100644 --- a/tpu_inference/models/jax/deepseek_v3.py +++ b/tpu_inference/models/jax/deepseek_v3.py @@ -863,8 +863,4 @@ def weights_dequant_cpu(x: torch.Tensor, scale = s[M // block_size, j // block_size] y[M_main:M, j:j + block_size] = block * scale -<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py return y.to(j2t_dtype(jnp.dtype(output_dtype))) -======= - return y.to(torch.get_default_dtype()) ->>>>>>> 307bbd62 (local change to support 2d TP for DeepSeek):tpu_commons/models/jax/deepseek_v3.py From 139494e38178a645d2aa26953e618754bae38a75 Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Sat, 4 Oct 2025 18:55:15 +0000 Subject: [PATCH 07/11] update sharding to support pure 2d TP --- tpu_inference/models/jax/deepseek_v3.py | 10 +++++----- 1 
file changed, 5 insertions(+), 5 deletions(-) diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py index 3538978df..1842aeaad 100644 --- a/tpu_inference/models/jax/deepseek_v3.py +++ b/tpu_inference/models/jax/deepseek_v3.py @@ -120,7 +120,7 @@ def __init__(self, hidden_size=hidden_size, dtype=dtype, rngs=self.rng, - vd_sharding=(('data', 'expert', 'model'), + vd_sharding=(('data', 'model', 'expert'), None), random_init=self.random_init) @@ -218,8 +218,8 @@ def _create_mla() -> MLA: random_init=self.random_init, activation_ffw_td=('data', 'model'), activation_ffw_ted=('data', None, 'model'), - def_sharding=('expert', 'model', None), - fed_sharding=('expert', None, 'model'), + def_sharding=(None , 'model', 'expert'), + fed_sharding=(None , 'expert', 'model'), router=router) if is_moe_layer else DenseFFW( dtype=dtype, hidden_act=hidden_act, @@ -302,8 +302,8 @@ def _create_mla() -> MLA: hidden_size=hidden_size, dtype=dtype, rngs=self.rng, - vd_sharding=(('data', 'expert', 'model'), None), - dv_sharding=(None, ('data', 'expert', 'model')), + vd_sharding=(('data', 'model', 'expert'), None), + dv_sharding=(None, ('data', 'model', 'expert')), random_init=self.random_init) # For compatibility with flax. From f67e580d0fe2777e3cb155e67a52c54acc4539d0 Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Mon, 6 Oct 2025 23:01:48 +0000 Subject: [PATCH 08/11] Fix dtype bug of weight_loading. It not only occurs for sparsematmul, but densematmul --- .../models/jax/common/moe/deepseek_moe.py | 30 +++++++++---------- tpu_inference/models/jax/deepseek_v3.py | 12 +++++--- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/tpu_commons/models/jax/common/moe/deepseek_moe.py b/tpu_commons/models/jax/common/moe/deepseek_moe.py index a079e737e..5da2a2b61 100644 --- a/tpu_commons/models/jax/common/moe/deepseek_moe.py +++ b/tpu_commons/models/jax/common/moe/deepseek_moe.py @@ -136,8 +136,8 @@ class SparseMoE(MoE): # TODO: determine if we get it from external or extrat it in MoE class is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim. 
""" - def_sharding: Sharding - fed_sharding: Sharding + edf_sharding: Sharding + efd_sharding: Sharding num_experts_per_tok: int #TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText tile_size: tuple[int, int, int] = (128, 64, 128) @@ -155,24 +155,24 @@ def __post_init__(self, rngs: nnx.Rngs): shape_up = (self.num_local_experts, D, F) shape_down = (self.num_local_experts, F, D) - self.kernel_gating_DEF = create_param(rngs, + self.kernel_gating_EDF = create_param(rngs, shape=shape_gating, dtype=self.dtype, - sharding=self.def_sharding, + sharding=self.edf_sharding, random_init=self.random_init) - self.kernel_up_proj_DEF = create_param(rngs, + self.kernel_up_proj_EDF = create_param(rngs, shape=shape_up, dtype=self.dtype, - sharding=self.def_sharding, + sharding=self.edf_sharding, random_init=self.random_init) - self.kernel_down_proj_FED = create_param(rngs, + self.kernel_down_proj_EFD = create_param(rngs, shape=shape_down, dtype=self.dtype, - sharding=self.fed_sharding, + sharding=self.efd_sharding, random_init=self.random_init) # Derive the expert sharding - self.expert_axis_name = self.def_sharding[0] + self.expert_axis_name = self.edf_sharding[0] if self.expert_axis_name is None: self.num_expert_parallelism = 1 else: @@ -597,10 +597,10 @@ def __call__(self, x_TD: Float): PartitionSpec(*self.activation_ffw_td), # Sharded x_TD PartitionSpec(), # Replicated router_weights_TX PartitionSpec(), # Replicated selected_experts_TX - PartitionSpec(*self.def_sharding), # Sharded gating kernel - PartitionSpec(*self.def_sharding), # Sharded up-projection kernel + PartitionSpec(*self.edf_sharding), # Sharded gating kernel + PartitionSpec(*self.edf_sharding), # Sharded up-projection kernel PartitionSpec( - *self.fed_sharding), # Sharded down-projection kernel + *self.efd_sharding), # Sharded down-projection kernel ) out_specs = PartitionSpec(*self.activation_ffw_td) @@ -616,7 +616,7 @@ def __call__(self, x_TD: Float): x_TD, router_weights_TX, selected_experts_TX, - self.kernel_gating_DEF.value, - self.kernel_up_proj_DEF.value, - self.kernel_down_proj_FED.value, + self.kernel_gating_EDF.value, + self.kernel_up_proj_EDF.value, + self.kernel_down_proj_EFD.value, ) diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py index 1842aeaad..6f4ec317e 100644 --- a/tpu_inference/models/jax/deepseek_v3.py +++ b/tpu_inference/models/jax/deepseek_v3.py @@ -218,8 +218,8 @@ def _create_mla() -> MLA: random_init=self.random_init, activation_ffw_td=('data', 'model'), activation_ffw_ted=('data', None, 'model'), - def_sharding=(None , 'model', 'expert'), - fed_sharding=(None , 'expert', 'model'), + edf_sharding=(None , 'model', 'expert'), + efd_sharding=(None , 'expert', 'model'), router=router) if is_moe_layer else DenseFFW( dtype=dtype, hidden_act=hidden_act, @@ -363,7 +363,10 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size, "is_verbose", None) is not None self.num_routed_experts = num_local_experts self.model_dtype = model_dtype +<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py +======= +>>>>>>> 641cb6d4 (Fix dtype bug of weight_loading. 
It not only occurs for sparsematmul, but densematmul):tpu_commons/models/jax/deepseek_v3.py self._transpose_map = { # dense mlp r"mlp\.down_proj": (1, 0), @@ -827,9 +830,10 @@ def load_weights(self, model_for_loading: nnx.Module): def weights_dequant_cpu(x: torch.Tensor, s: torch.Tensor, - output_dtype: jnp.dtype, + output_dtype: torch.dtype, block_size: int = 128) -> torch.Tensor: assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors" + torch_output_type = DTYPE_VIEW_MAP.get(jnp.dtype(output_dtype)) M, N = x.shape x = x.to(torch.float32) @@ -863,4 +867,4 @@ def weights_dequant_cpu(x: torch.Tensor, scale = s[M // block_size, j // block_size] y[M_main:M, j:j + block_size] = block * scale - return y.to(j2t_dtype(jnp.dtype(output_dtype))) + return y.to(torch_output_type) From e5b5e1118e1fde23aa4a20ac8c71b7fb15065a0d Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Thu, 30 Oct 2025 18:22:00 +0000 Subject: [PATCH 09/11] bug fix after rebase --- tpu_inference/layers/jax/moe/deepseek_v3_moe.py | 9 +++++---- tpu_inference/models/jax/deepseek_v3.py | 4 ---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tpu_inference/layers/jax/moe/deepseek_v3_moe.py b/tpu_inference/layers/jax/moe/deepseek_v3_moe.py index 4aff8b9e8..8471f0a9a 100644 --- a/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +++ b/tpu_inference/layers/jax/moe/deepseek_v3_moe.py @@ -19,7 +19,7 @@ manually_quantize_qwix_activation, manually_quantize_qwix_weight) modeling_flax_utils = FlaxUtils() - +jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True), @dataclass class DeepSeekV3Router(nnx.Module): @@ -329,8 +329,9 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, with jax.named_scope("unpermute"): unsorted_tokens_tD = self._sort_activations( processed_tokens, jnp.argsort(sort_indices)) + D = unsorted_tokens_tD.shape[-1] reshaped_tokens_TXD = unsorted_tokens_tD.reshape( - -1, self.num_experts_per_tok, self.hidden_size) + -1, self.num_experts_per_tok, D) with jax.named_scope("combine_weights"): output_TD = jnp.einsum( "TXD,TX -> TD", @@ -394,10 +395,10 @@ def _distributed_sparse_moe_fwd( # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis # or we sould derive it from the model init - expert_shard_id = jax.lax.axis_index(self.expert_axis_name) - local_expert_size = self.num_local_experts // self.num_expert_parallelism if self.num_expert_parallelism > 1: + expert_shard_id = jax.lax.axis_index(self.expert_axis_name) + local_expert_size = self.num_local_experts // self.num_expert_parallelism if self.is_batch_sharded_by_expert: # When token sharded in devices # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py index 6f4ec317e..30f10aa0e 100644 --- a/tpu_inference/models/jax/deepseek_v3.py +++ b/tpu_inference/models/jax/deepseek_v3.py @@ -363,10 +363,6 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size, "is_verbose", None) is not None self.num_routed_experts = num_local_experts self.model_dtype = model_dtype -<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py - -======= ->>>>>>> 641cb6d4 (Fix dtype bug of weight_loading. 
It not only occurs for sparsematmul, but densematmul):tpu_commons/models/jax/deepseek_v3.py self._transpose_map = { # dense mlp r"mlp\.down_proj": (1, 0), From b8c59b4b69eaa59dad97e70e02bb039698898418 Mon Sep 17 00:00:00 2001 From: bzgoogle Date: Tue, 11 Nov 2025 18:14:39 +0000 Subject: [PATCH 10/11] fix bug after rebase --- tpu_inference/runner/compilation_manager.py | 33 +++++----- tpu_inference/runner/kv_cache.py | 9 ++- tpu_inference/runner/tpu_runner.py | 70 +++++++++++++++------ 3 files changed, 73 insertions(+), 39 deletions(-) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index cb9c94df7..a0a5ec99c 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -1,13 +1,14 @@ import os import time -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import jax import jax.numpy as jnp import numpy as np -import vllm.envs as envs from jax.sharding import NamedSharding, PartitionSpec +import vllm.envs as envs from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata from tpu_inference.layers.common.sharding import ShardingAxisName @@ -36,9 +37,9 @@ def __init__(self, runner: "TPUModelRunner"): envs.VLLM_XLA_CACHE_PATH) def _create_dummy_tensor(self, - shape: Tuple[int, ...], + shape: tuple[int, ...], dtype: Any, - sharding: Optional[NamedSharding] = None) -> Any: + sharding: NamedSharding | None = None) -> Any: """Helper to create dummy tensors for precompilation.""" tensor = jnp.ones(shape, dtype=dtype) if sharding: @@ -273,11 +274,11 @@ def _precompile_backbone_with_inputs_embeds(self) -> None: def _precompile_select_from_array_helper( self, name: str, - source_paddings: List[int], - indices_paddings: List[int], + source_paddings: list[int], + indices_paddings: list[int], hidden_dim: int, - input_sharding: Optional[NamedSharding] = None, - indices_sharding: Optional[NamedSharding] = None, + input_sharding: NamedSharding | None = None, + indices_sharding: NamedSharding | None = None, only_equal_paddings: bool = False, check_should_skip_padding: bool = True, ) -> None: @@ -351,16 +352,18 @@ def _precompile_select_from_array(self) -> None: source_paddings=self.runner.num_logits_paddings, indices_paddings=self.runner.num_reqs_paddings, hidden_dim=vocab_size, - input_sharding=NamedSharding(self.runner.mesh, - PartitionSpec(None, ('model', 'expert')), + input_sharding=NamedSharding( + self.runner.mesh, PartitionSpec(None, + ('model', 'expert'))), ) self._precompile_select_from_array_helper( name="select target tokens for spec decoding", source_paddings=self.runner.num_logits_paddings, indices_paddings=self.runner.num_logits_paddings, hidden_dim=vocab_size, - input_sharding=NamedSharding(self.runner.mesh, - PartitionSpec(None, ('model', 'expert')), + input_sharding=NamedSharding( + self.runner.mesh, PartitionSpec(None, + ('model', 'expert'))), only_equal_paddings=True, ) @@ -392,7 +395,7 @@ def _precompile_sampling(self) -> None: for num_reqs in self.runner.num_reqs_paddings: logits_sharding = NamedSharding( self.runner.mesh, - PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert')) + PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert'))) dp_size = self.runner.vllm_config.sharding_config.total_dp_size sampling_metadata_sharding = NamedSharding( self.runner.mesh, PartitionSpec( @@ -481,8 +484,8 @@ 
def _precompile_rejection_sampler(self) -> None: vocab_size = self.runner.model_config.get_vocab_size() for num_logits in self.runner.num_logits_paddings: for num_reqs in self.runner.num_reqs_paddings: - sharding = NamedSharding(self.runner.mesh, - PartitionSpec(None, ('model', 'expert'))) + sharding = NamedSharding( + self.runner.mesh, PartitionSpec(None, ('model', 'expert'))) target_probs = self._create_dummy_tensor( (num_logits, vocab_size), jnp.bfloat16, sharding) draft_token_ids = self._create_dummy_tensor((num_logits, ), diff --git a/tpu_inference/runner/kv_cache.py b/tpu_inference/runner/kv_cache.py index f902d0399..f54ec1861 100644 --- a/tpu_inference/runner/kv_cache.py +++ b/tpu_inference/runner/kv_cache.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any import jax import jax.numpy as jnp @@ -46,9 +46,9 @@ def create_kv_caches( num_kv_heads: int, head_size: int, mesh: Mesh, - layer_names: List[str], + layer_names: list[str], cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE, -) -> List[jax.Array]: +) -> list[jax.Array]: """ Creates a list of KV cache where each array mapps to single attention layer. @@ -78,8 +78,7 @@ def create_kv_caches( sharding = NamedSharding( mesh, - PartitionSpec(ShardingAxisName.ATTN_DATA, None, - ('model', 'expert')) + PartitionSpec(ShardingAxisName.ATTN_DATA, None, ('model', 'expert'))) def _allocate() -> jax.Array: return jnp.empty( diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index e76b9056b..aa9e254f4 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -2,23 +2,56 @@ import functools import os import random +from collections.abc import Callable from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, cast import jax import jax.numpy as jnp import jaxtyping import numpy as np import torch -import vllm.envs as envs from flax import nnx from jax.experimental import mesh_utils from jax.sharding import NamedSharding, PartitionSpec from torchax.ops.mappings import j2t_dtype + +import vllm.envs as envs +from tpu_inference import utils as common_utils +from tpu_inference.layers.common.attention_metadata import AttentionMetadata +from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler +from tpu_inference.layers.jax.sample.sampling import ( + compute_logprobs, + gather_logprobs, + sample, +) +from tpu_inference.layers.jax.sample.sampling_metadata import ( + TPUSupportedSamplingMetadata, +) +from tpu_inference.layers.jax.sharding import ShardingAxisName, ShardingConfigManager +from tpu_inference.logger import init_logger +from tpu_inference.models.common.model_loader import get_model +from tpu_inference.models.jax.utils.weight_utils import ( + shard_put, + transfer_state_with_mappings, +) +from tpu_inference.runner import utils as runner_utils +from tpu_inference.runner.compilation_manager import CompilationManager +from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch +from tpu_inference.runner.kv_cache_manager import KVCacheManager +from tpu_inference.runner.lora_utils import LoraUtils +from tpu_inference.runner.multimodal_manager import MultiModalManager +from tpu_inference.runner.persistent_batch_manager import PersistentBatchManager +from tpu_inference.runner.speculative_decoding_manager import ( + SpecDecodeMetadata, + SpeculativeDecodingManager, +) +from tpu_inference.runner.structured_decoding_manager 
import StructuredDecodingManager +from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer +from tpu_inference.utils import device_array, make_optimized_mesh, time_function from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.forward_context import set_forward_context from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask @@ -31,8 +64,7 @@ LogprobsTensors, ModelRunnerOutput) from vllm.v1.request import Request from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorModelRunnerMixin +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from tpu_inference import utils as common_utils @@ -108,7 +140,7 @@ def __init__( next_tokens: jax.Array, num_reqs: int, discard_sampled_tokens_req_indices: list[int], - logits_indices_selector: Optional[List[int]] = None, + logits_indices_selector: list[int] | None = None, ): self._model_runner_output = model_runner_output self._next_tokens = next_tokens @@ -137,7 +169,7 @@ class AsyncPreResults: request_seq_lens: list[tuple[int, CachedRequestState, int]] discard_sampled_tokens_req_indices: list[int] placeholder_req_id_to_index: dict[str, int] - logits_indices_selector: Optional[List[int]] = None + logits_indices_selector: list[int] | None = None @dataclass @@ -147,7 +179,7 @@ class ExecuteModelState: scheduler_output: "VllmSchedulerOutput" attn_metadata: AttentionMetadata - input_ids: Optional[jax.Array] + input_ids: jax.Array | None hidden_states: jax.Array logits: jax.Array aux_hidden_states: Optional[jax.Array] @@ -552,7 +584,7 @@ def capture_model(self) -> None: def execute_model( self, scheduler_output: "VllmSchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: IntermediateTensors | None = None, ) -> ModelRunnerOutput | None: if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called " @@ -797,7 +829,7 @@ def _sample_from_logits( self, scheduler_output: "VllmSchedulerOutput", attn_metadata: AttentionMetadata, - input_ids: Optional[jax.Array], + input_ids: jax.Array | None, hidden_states: jax.Array, logits: jax.Array, aux_hidden_states: Optional[jax.Array], @@ -1617,26 +1649,26 @@ def _get_input_ids_embeds(self, input_ids: jax.Array, else: return input_ids, None - def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + def take_draft_token_ids(self) -> DraftTokenIds | None: return self.speculative_decoding_manager.take_draft_token_ids() ###### Local disagg utilities ###### def get_kv_cache_for_block_ids( self, - block_ids: List[int], - ) -> List[jax.Array]: + block_ids: list[int], + ) -> list[jax.Array]: return self.kv_cache_manager.get_kv_cache_for_block_ids(block_ids) def transfer_kv_cache(self, - kv_cache_slices: List[jax.Array]) -> List[jax.Array]: + kv_cache_slices: list[jax.Array]) -> list[jax.Array]: return self.kv_cache_manager.transfer_kv_cache(kv_cache_slices) def insert_request_with_kv_cache( self, request: "Request", - kv_cache_slices: List[jax.Array], - block_ids: List[List[int]], + kv_cache_slices: list[jax.Array], + block_ids: list[list[int]], ): return self.kv_cache_manager.insert_request_with_kv_cache( request, kv_cache_slices, block_ids) @@ -1646,8 
+1678,8 @@ def insert_request_with_kv_cache(
 
     def _sync_weights(
         self,
         updated_weights: jaxtyping.PyTree,
-        mappings: Dict[str, Tuple[str, Tuple[str]]],
-        transpose_keys: Dict[str, Tuple[int]],
+        mappings: dict[str, tuple[str, tuple[str]]],
+        transpose_keys: dict[str, tuple[int]],
         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
                              jaxtyping.PyTree] = None
     ) -> None:

From 1a71211a77a006aa6ed165160ba593ca0070331d Mon Sep 17 00:00:00 2001
From: bzgoogle
Date: Thu, 27 Nov 2025 00:11:44 +0000
Subject: [PATCH 11/11] Attn optimization for latent vectors

---
 tpu_inference/models/jax/deepseek_v3.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py
index 30f10aa0e..bc0303de8 100644
--- a/tpu_inference/models/jax/deepseek_v3.py
+++ b/tpu_inference/models/jax/deepseek_v3.py
@@ -52,7 +52,7 @@ def __init__(self,
         self.rng = nnx.Rngs(rng)
 
         # NOTE: the default is 61
-        num_layers: int = vllm_config.model_config.hf_config.num_hidden_layers
+        num_layers: int = 20
         num_local_experts: int = 256
         vocab_size: int = 129280
 
@@ -152,9 +152,9 @@ def _create_mla() -> MLA:
             keyvalue_skh=P(None, ('model', 'expert'), None),
             activation_attention_out_td=(None, None),
             attn_o_tnh=P(None, ('model', 'expert'), None),
-            q_da_sharding=(None, ('model', 'expert')),
+            q_da_sharding=('model', None),
             anh_sharding=(None, ('model', 'expert'), None),
-            kv_da_sharding=(None, ('model', 'expert')),
+            kv_da_sharding=('model', None),
             nhd_sharding=(('model', 'expert'), None, None))
 
         for i in range(first_k_dense_replace):
@@ -220,6 +220,8 @@ def _create_mla() -> MLA:
             activation_ffw_ted=('data', None, 'model'),
             edf_sharding=(None , 'model', 'expert'),
             efd_sharding=(None , 'expert', 'model'),
+            quantized_dtype=self.weight_loader.quant_dtype
+            if self.weight_loader.is_model_quantized else None,
             router=router) if is_moe_layer else DenseFFW(
             dtype=dtype,
             hidden_act=hidden_act,
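
The SparseMoE hunks in PATCH 08 rename the expert kernel shardings to edf_sharding / efd_sharding and keep deriving the expert axis from the first entry of that tuple, with the degree of expert parallelism read off the mesh. A minimal self-contained sketch of that derivation follows; the mesh layout and axis names below are illustrative assumptions, not the tpu_inference setup:

# Sketch only: derive expert parallelism from a sharding tuple and a device mesh,
# mirroring the expert_axis_name / num_expert_parallelism logic in SparseMoE.
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(1, -1, 1)  # put every device on the 'expert' axis
mesh = Mesh(devices, axis_names=("data", "expert", "model"))

edf_sharding = ("expert", "model", None)  # (E, D, F) expert-kernel layout
expert_axis_name = edf_sharding[0]
num_expert_parallelism = (1 if expert_axis_name is None
                          else mesh.shape[expert_axis_name])
print(num_expert_parallelism)  # equals the device count for this illustrative mesh

The tuple above puts 'expert' first only to exercise the expert-parallel branch; the model config in PATCH 08 actually passes edf_sharding=(None, 'model', 'expert'), which takes the num_expert_parallelism == 1 path.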
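PATCH 08 also changes weights_dequant_cpu to receive a torch.dtype and cast through DTYPE_VIEW_MAP instead of round-tripping a jnp.dtype. The simplified sketch below shows only the core 128x128 block-wise scaling idea, assuming both weight dimensions are exact multiples of block_size; it is not the patched function, which additionally handles the trailing partial row block:

# Sketch only: block-wise dequantization in the spirit of weights_dequant_cpu.
import torch

def dequant_blocks(x: torch.Tensor, s: torch.Tensor,
                   output_dtype: torch.dtype,
                   block_size: int = 128) -> torch.Tensor:
    assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
    y = x.to(torch.float32).clone()
    M, N = y.shape
    for i in range(0, M, block_size):
        for j in range(0, N, block_size):
            # One scale per (block_size x block_size) tile of the quantized weight.
            y[i:i + block_size, j:j + block_size] *= s[i // block_size,
                                                       j // block_size]
    return y.to(output_dtype)

w = torch.randn(256, 256)            # stand-in for a quantized weight
scales = torch.full((2, 2), 0.5)     # one scale per 128x128 block
print(dequant_blocks(w, scales, torch.bfloat16).dtype)  # torch.bfloat16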
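PATCH 09 turns on jax_ragged_dot_use_ragged_dot_instruction for the sparse path. The snippet below only illustrates the grouped ("ragged") matmul shape contract that the sparse expert computation relies on, assuming a recent JAX that provides jax.lax.ragged_dot and using made-up shapes; it is not the kernel invoked by SparseMoE:

# Sketch only: grouped matmul of expert-sorted tokens against a stacked
# (E, D, F) kernel; group_sizes[i] is the number of tokens routed to expert i.
import jax
import jax.numpy as jnp

E, D, F, T = 4, 8, 16, 12
tokens_sorted_tD = jnp.ones((T, D), dtype=jnp.bfloat16)   # already sorted by expert id
kernel_EDF = jnp.ones((E, D, F), dtype=jnp.bfloat16)
group_sizes = jnp.array([3, 5, 2, 2], dtype=jnp.int32)    # sums to T

out_tF = jax.lax.ragged_dot(tokens_sorted_tD, kernel_EDF, group_sizes)
print(out_tF.shape)  # (12, 16)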
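The _unpermute fix in PATCH 09 reads the hidden dimension off the processed array instead of self.hidden_size. The toy round trip below shows the sort / process / argsort-unsort pattern that fix sits in; the shapes and the "processing" step are purely illustrative:

# Sketch only: sort tokens by expert id, process them, then restore the
# original order with argsort of the sort indices, deriving D from the array.
import jax.numpy as jnp

tokens_tD = jnp.arange(12.0).reshape(6, 2)        # 6 token copies, hidden dim 2
expert_ids_t = jnp.array([2, 0, 1, 0, 2, 1])

sort_indices = jnp.argsort(expert_ids_t)          # group tokens by expert
sorted_tokens = tokens_tD[sort_indices, ...]

processed = sorted_tokens * 10.0                  # stand-in for the expert MLPs

unsorted = processed[jnp.argsort(sort_indices), ...]  # back to original token order
D = unsorted.shape[-1]                                # read D off the array itself
print(bool(jnp.allclose(unsorted, tokens_tD * 10.0)), D)  # True 2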
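PATCH 10 reshapes several NamedSharding constructions in the runner so that the vocab dimension of logits-shaped dummy tensors is split over the combined ('model', 'expert') axes. A minimal stand-alone version of that sharding setup, with an assumed mesh and illustrative tensor shape, looks like this:

# Sketch only: shard the last (vocab) dimension over the joint model+expert axes.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devs = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devs, axis_names=("data", "model", "expert"))

sharding = NamedSharding(mesh, PartitionSpec(None, ("model", "expert")))
logits = jax.device_put(jnp.ones((8, 1024), dtype=jnp.bfloat16), sharding)
print(logits.sharding)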