
Commit 910f6e3

bzgoogle authored and committed
local change to support 2d TP for DeepSeek
1 parent aed9b57 commit 910f6e3

8 files changed: +99 −83 lines changed


tpu_commons/models/jax/common/moe/deepseek_moe.py

Lines changed: 62 additions & 48 deletions
@@ -1,22 +1,18 @@
 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Optional, Tuple
+from typing import Tuple

 import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
 from jax.sharding import PartitionSpec
 from jaxtyping import Float
-from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
-from qwix._src.providers import ptq

 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
-from tpu_commons.models.jax.utils.quantization.quantization_utils import (
-    manually_quantize_qwix_activation, manually_quantize_qwix_weight)

 modeling_flax_utils = FlaxUtils()

@@ -140,19 +136,43 @@ class SparseMoE(MoE):
     # TODO: determine if we get it from external or extrat it in MoE class
     is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim.
     """
+    def_sharding: Sharding
+    fed_sharding: Sharding
     num_experts_per_tok: int
     #TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
-    # This should be set if and only if you have quantized your model (via Qwix)
-    quantized_dtype: Optional[jnp.dtype] = None

     def __post_init__(self, rngs: nnx.Rngs):
-        super().__post_init__(rngs)
+
+        D = self.hidden_size
+        F = self.intermediate_size_moe
+        # shape_gating = (D, self.num_local_experts, F)
+        # shape_up = (D, self.num_local_experts, F)
+        # shape_down = (F, self.num_local_experts,D)
+        shape_gating = (self.num_local_experts, D, F)
+        shape_up = (self.num_local_experts, D, F)
+        shape_down = (self.num_local_experts, F, D)
+
+        self.kernel_gating_DEF = create_param(rngs,
+                                              shape=shape_gating,
+                                              dtype=self.dtype,
+                                              sharding=self.def_sharding,
+                                              random_init=self.random_init)
+        self.kernel_up_proj_DEF = create_param(rngs,
+                                               shape=shape_up,
+                                               dtype=self.dtype,
+                                               sharding=self.def_sharding,
+                                               random_init=self.random_init)
+        self.kernel_down_proj_FED = create_param(rngs,
+                                                 shape=shape_down,
+                                                 dtype=self.dtype,
+                                                 sharding=self.fed_sharding,
+                                                 random_init=self.random_init)

         # Derive the expert sharding
-        self.expert_axis_name = self.edf_sharding[0]
+        self.expert_axis_name = self.def_sharding[0]
         if self.expert_axis_name is None:
             self.num_expert_parallelism = 1
         else:
@@ -329,20 +349,29 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
         with jax.named_scope("unpermute"):
             unsorted_tokens_tD = self._sort_activations(
                 processed_tokens, jnp.argsort(sort_indices))
+            D = unsorted_tokens_tD.shape[1]
             reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
-                -1, self.num_experts_per_tok, self.hidden_size)
+                -1, self.num_experts_per_tok, D)
+            # jax.debug.print(
+            #     "✅ reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}",
+            #     t=reshaped_tokens_TXD[5, 0,:5]
+            # )
+            # jax.debug.print(
+            #     "✅ router_weights_TX on device: router_weights_TX={t}",
+            #     t=router_weights_TX[5, :]
+            # )
         with jax.named_scope("combine_weights"):
             output_TD = jnp.einsum(
                 "TXD,TX -> TD",
-                reshaped_tokens_TXD.astype(jnp.float32),
-                router_weights_TX.astype(jnp.float32),
-                precision='float32',
+                reshaped_tokens_TXD.astype(self.dtype),
+                router_weights_TX.astype(self.dtype),
             )

         return output_TD.astype(self.dtype)

     def _gmm(self, inputs, kernel, group_sizes):
         """Performs Grouped Matrix Multiply."""
+        jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True)
         num_rows = inputs.shape[0]
         pad_amount = (self.tile_size[0] -
                       num_rows % self.tile_size[0]) % self.tile_size[0]
@@ -354,11 +383,8 @@ def _gmm(self, inputs, kernel, group_sizes):
             raise NotImplementedError(
                 "MegaBlox kernel call is not implemented.")
         else:
-            inputs = manually_quantize_qwix_activation(
-                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
-                "absmax") if self.quantized_dtype else inputs
-            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
-            output = ragged_dot_func(
+
+            output = jax.lax.ragged_dot(
                 lhs=inputs,
                 rhs=kernel,
                 group_sizes=group_sizes,
@@ -394,10 +420,12 @@ def _distributed_sparse_moe_fwd(

         # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis
         # or we sould derive it from the model init
-        expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
+
         local_expert_size = self.num_local_experts // self.num_expert_parallelism

-        if self.num_expert_parallelism > 1:
+        #if self.num_expert_parallelism > 1:
+        if self.expert_axis_name:
+            expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
             if self.is_batch_sharded_by_expert:
                 # When token sharded in devices
                 # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name
@@ -508,8 +536,9 @@ def _distributed_sparse_moe_fwd(
         # 5. Return Results (All-to-All)
         if self.num_expert_parallelism > 1:
             local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok
+            D = x_TD.shape[1]
             output_shape = jnp.zeros(
-                (local_total_assignments, self.hidden_size),
+                (local_total_assignments, D),
                 dtype=intermediate_output.dtype)

             if self.is_batch_sharded_by_expert:
@@ -568,10 +597,10 @@ def __call__(self, x_TD: Float):
             PartitionSpec(*self.activation_ffw_td),  # Sharded x_TD
             PartitionSpec(),  # Replicated router_weights_TX
             PartitionSpec(),  # Replicated selected_experts_TX
-            PartitionSpec(*self.edf_sharding),  # Sharded gating kernel
-            PartitionSpec(*self.edf_sharding),  # Sharded up-projection kernel
+            PartitionSpec(*self.def_sharding),  # Sharded gating kernel
+            PartitionSpec(*self.def_sharding),  # Sharded up-projection kernel
             PartitionSpec(
-                *self.efd_sharding),  # Sharded down-projection kernel
+                *self.fed_sharding),  # Sharded down-projection kernel
         )
         out_specs = PartitionSpec(*self.activation_ffw_td)

@@ -582,27 +611,12 @@ def __call__(self, x_TD: Float):
             check_rep=False)(
                 SparseMoE._distributed_sparse_moe_fwd)

-        kernel_gating_EDF = self.kernel_gating_EDF.value
-        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
-        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
-
-        if self.quantized_dtype:
-            if not isinstance(kernel_gating_EDF, ptq.WithAux):
-                kernel_gating_EDF = manually_quantize_qwix_weight(
-                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
-                    "absmax")
-            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
-                kernel_up_proj_EDF = manually_quantize_qwix_weight(
-                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
-                    "absmax")
-            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
-                kernel_down_proj_EFD = manually_quantize_qwix_weight(
-                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
-                    "absmax")
-            kernel_gating_EDF = kernel_gating_EDF.array
-            kernel_up_proj_EDF = kernel_up_proj_EDF.array
-            kernel_down_proj_EFD = kernel_down_proj_EFD.array
-
-        return mapped_moe_fwd(self, x_TD, router_weights_TX,
-                              selected_experts_TX, kernel_gating_EDF,
-                              kernel_up_proj_EDF, kernel_down_proj_EFD)
+        return mapped_moe_fwd(
+            self,
+            x_TD,
+            router_weights_TX,
+            selected_experts_TX,
+            self.kernel_gating_DEF.value,
+            self.kernel_up_proj_DEF.value,
+            self.kernel_down_proj_FED.value,
+        )
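
Note: the non-MegaBlox path above now calls jax.lax.ragged_dot directly with the expert-major kernels created in __post_init__. A minimal, self-contained sketch of that grouped-matmul contract, with toy shapes and data that are not taken from this repository:

import jax
import jax.numpy as jnp

# Assumed toy sizes: E experts, hidden dim D, intermediate dim F.
E, D, F = 4, 8, 16
num_assignments = 12  # routed token-expert pairs, already sorted by expert

lhs = jnp.ones((num_assignments, D), dtype=jnp.bfloat16)  # permuted activations
rhs = jnp.ones((E, D, F), dtype=jnp.bfloat16)             # expert-major kernel, shape (E, D, F)
group_sizes = jnp.array([3, 5, 0, 4], dtype=jnp.int32)    # rows per expert, sums to 12

# Each contiguous group of lhs rows is multiplied by its expert's (D, F) slice.
out = jax.lax.ragged_dot(lhs=lhs, rhs=rhs, group_sizes=group_sizes)
print(out.shape)  # (12, 16)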

tpu_inference/layers/jax/attention/deepseek_v3_attention.py

Lines changed: 2 additions & 2 deletions
@@ -317,13 +317,13 @@ def attention(
             self.query_tnh,  # q
             self.keyvalue_skh,  # k
             self.keyvalue_skh,  # v
-            P(None, None, "model"),  # kv_cache
+            P(None, None, ('model', 'expert')),  # kv_cache
             P(),  # md.seq_lens: Replicated
             P(),  # page_indices_flat: Replicated
             P(),  # query_start_loc: Replicated
             P(),  # distribution: Replicated
         )
-        out_specs = (self.attn_o_tnh, P(None, None, "model"))
+        out_specs = (self.attn_o_tnh, P(None, None, ('model', 'expert')))

         def _ragged_paged_attention(*args):
             return ragged_paged_attention(
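
Note: the specs above are the heart of the 2D TP change. Putting a tuple of mesh axis names in a single PartitionSpec position shards that one tensor dimension over the product of those axes; the same pattern recurs in model_loader.py, compilation_manager.py, and kv_cache.py below. A minimal sketch under an assumed (data=1, model=2, expert=2) mesh, which is not necessarily the repo's real topology (on CPU you can emulate 4 devices with XLA_FLAGS=--xla_force_host_platform_device_count=4):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(1, 2, 2),
            ("data", "model", "expert"))

# KV-cache-like tensor: (pages, page_size, combined head dim); the last axis is
# split across model * expert = 4 shards.
kv = jnp.zeros((16, 8, 64))
kv = jax.device_put(kv, NamedSharding(mesh, P(None, None, ("model", "expert"))))
print(kv.addressable_shards[0].data.shape)  # (16, 8, 16)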

tpu_inference/layers/jax/moe/moe.py

Lines changed: 2 additions & 2 deletions
@@ -84,8 +84,8 @@ class MoE(nnx.Module):
     router: nnx.Module
     activation_ffw_td: Sharding
     activation_ffw_ted: Sharding
-    edf_sharding: Sharding
-    efd_sharding: Sharding
+    edf_sharding: Sharding = ()
+    efd_sharding: Sharding = ()
     random_init: bool = False

     def __call__(self, x_TD: Float):

tpu_inference/models/common/model_loader.py

Lines changed: 2 additions & 2 deletions
@@ -199,7 +199,7 @@ def get_flax_model(
         vllm_config.model_config.hf_config)
     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
     kv_cache_sharding = NamedSharding(
-        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
+        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, ("model", "expert")))
     hidden_states_sharding = NamedSharding(mesh,
                                            PartitionSpec(
                                                ShardingAxisName.ATTN_DATA,
@@ -224,7 +224,7 @@ def run_model(graphdef, state, *args):
         return model(*args)

     logits_sharding = NamedSharding(
-        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, ("model", "expert")))

     @functools.partial(
         jax.jit,

tpu_inference/models/jax/deepseek_v3.py

Lines changed: 20 additions & 18 deletions
@@ -148,14 +148,14 @@ def _create_mla() -> MLA:
                 rngs=self.rng,
                 activation_attention_td=(None, None),
                 activation_q_td=(None, None),
-                query_tnh=P(None, 'model', None),
-                keyvalue_skh=P(None, 'model', None),
+                query_tnh=P(None, ('model', 'expert'), None),
+                keyvalue_skh=P(None, ('model', 'expert'), None),
                 activation_attention_out_td=(None, None),
-                attn_o_tnh=P(None, 'model', None),
-                q_da_sharding=(None, 'model'),
-                anh_sharding=(None, 'model', None),
-                kv_da_sharding=(None, 'model'),
-                nhd_sharding=('model', None, None))
+                attn_o_tnh=P(None, ('model', 'expert'), None),
+                q_da_sharding=(None, ('model', 'expert')),
+                anh_sharding=(None, ('model', 'expert'), None),
+                kv_da_sharding=(None, ('model', 'expert')),
+                nhd_sharding=(('model', 'expert'), None, None))

         for i in range(first_k_dense_replace):
             block = TransformerBlock(
@@ -201,8 +201,8 @@ def _create_mla() -> MLA:
                 routed_scaling_factor=2.5,
                 dtype=dtype,
                 activation_ffw_td=('data', None),
-                ed_sharding=('model', None),
-                e_sharding=('model', ))
+                ed_sharding=(None, None),
+                e_sharding=(None, ))
             if self.sparse_matmul:
                 # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces
                 custom_module = SparseMoE(
@@ -216,12 +216,10 @@ def _create_mla() -> MLA:
                     hidden_act=hidden_act,
                     rngs=self.rng,
                     random_init=self.random_init,
-                    activation_ffw_td=('data', None),
-                    activation_ffw_ted=('data', None, None),
-                    edf_sharding=('model', None, None),
-                    efd_sharding=('model', None, None),
-                    quantized_dtype=self.weight_loader.quant_dtype
-                    if self.weight_loader.is_model_quantized else None,
+                    activation_ffw_td=('data', 'model'),
+                    activation_ffw_ted=('data', None, 'model'),
+                    def_sharding=('expert', 'model', None),
+                    fed_sharding=('expert', None, 'model'),
                     router=router) if is_moe_layer else DenseFFW(
                         dtype=dtype,
                         hidden_act=hidden_act,
@@ -241,10 +239,10 @@ def _create_mla() -> MLA:
                     hidden_act=hidden_act,
                     rngs=self.rng,
                     random_init=self.random_init,
-                    activation_ffw_td=('data', None),
+                    activation_ffw_td=('data', 'model'),
                     activation_ffw_ted=('data', None, None),
-                    edf_sharding=('model', None, None),
-                    efd_sharding=('model', None, None),
+                    edf_sharding=('expert', 'model', None),
+                    efd_sharding=('expert', None, 'model'),
                     router=router) if is_moe_layer else DenseFFW(
                         dtype=dtype,
                         hidden_act=hidden_act,
@@ -865,4 +863,8 @@ def weights_dequant_cpu(x: torch.Tensor,
             scale = s[M // block_size, j // block_size]
             y[M_main:M, j:j + block_size] = block * scale

+<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py
     return y.to(j2t_dtype(jnp.dtype(output_dtype)))
+=======
+    return y.to(torch.get_default_dtype())
+>>>>>>> 307bbd62 (local change to support 2d TP for DeepSeek):tpu_commons/models/jax/deepseek_v3.py
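
Note: for the SparseMoE layers configured above, def_sharding=('expert', 'model', None) and fed_sharding=('expert', None, 'model') place the expert dimension of each (E, D, F) or (E, F, D) kernel on the 'expert' mesh axis and one matmul dimension on 'model'. A rough sketch of the resulting per-device shard shapes, using toy sizes and the same assumed (1, 2, 2) mesh as earlier:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(1, 2, 2),
            ("data", "model", "expert"))

E, D, F = 8, 16, 32  # toy sizes, not DeepSeek's real dimensions
gating_DEF = jax.device_put(jnp.zeros((E, D, F)),
                            NamedSharding(mesh, P("expert", "model", None)))
down_FED = jax.device_put(jnp.zeros((E, F, D)),
                          NamedSharding(mesh, P("expert", None, "model")))
# Experts split 2-way over 'expert'; hidden/intermediate split 2-way over 'model'.
print(gating_DEF.addressable_shards[0].data.shape)  # (4, 8, 32)
print(down_FED.addressable_shards[0].data.shape)    # (4, 32, 8)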

tpu_inference/runner/compilation_manager.py

Lines changed: 4 additions & 4 deletions
@@ -350,15 +350,15 @@ def _precompile_select_from_array(self) -> None:
             indices_paddings=self.runner.num_reqs_paddings,
             hidden_dim=vocab_size,
             input_sharding=NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, "model")),
+                                         PartitionSpec(None, ('model', 'expert')),
         )
         self._precompile_select_from_array_helper(
             name="select target tokens for spec decoding",
             source_paddings=self.runner.num_logits_paddings,
             indices_paddings=self.runner.num_logits_paddings,
             hidden_dim=vocab_size,
             input_sharding=NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, "model")),
+                                         PartitionSpec(None, ('model', 'expert')),
             only_equal_paddings=True,
         )

@@ -390,7 +390,7 @@ def _precompile_sampling(self) -> None:
         for num_reqs in self.runner.num_reqs_paddings:
             logits_sharding = NamedSharding(
                 self.runner.mesh,
-                PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+                PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert'))
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             sampling_metadata_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(
@@ -480,7 +480,7 @@ def _precompile_rejection_sampler(self) -> None:
         for num_logits in self.runner.num_logits_paddings:
             for num_reqs in self.runner.num_reqs_paddings:
                 sharding = NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, "model"))
+                                         PartitionSpec(None, ('model', 'expert')))
                 target_probs = self._create_dummy_tensor(
                     (num_logits, vocab_size), jnp.bfloat16, sharding)
                 draft_token_ids = self._create_dummy_tensor((num_logits, ),

tpu_inference/runner/kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
                                  actual_head_dim: int, kv_dtype: any):
     """Gets the KV cache shape based on the mesh configuration."""

-    model_cnt = mesh.shape["model"]
+    model_cnt = mesh.shape["model"] * mesh.shape["expert"]
     assert actual_num_kv_heads % model_cnt == 0
     # NOTE(chengjiyao): Currently, the attention kernel is tailored to the
     # specific model, rather than being determined by the head_dim. If new
@@ -79,7 +79,7 @@ def create_kv_caches(
     sharding = NamedSharding(
         mesh,
         PartitionSpec(ShardingAxisName.ATTN_DATA, None,
-                      ShardingAxisName.ATTN_HEAD))
+                      ('model', 'expert'))

     def _allocate() -> jax.Array:
         return jnp.empty(
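
Note: mesh.shape maps axis names to sizes, so the change above makes the KV-head partition count the product of the 'model' and 'expert' axes. A small sketch of that check under the same assumed (1, 2, 2) mesh and an illustrative head count:

import jax
import numpy as np
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()[:4]).reshape(1, 2, 2),
            ("data", "model", "expert"))

actual_num_kv_heads = 16  # illustrative head count, not a DeepSeek value
model_cnt = mesh.shape["model"] * mesh.shape["expert"]  # 2 * 2 = 4
assert actual_num_kv_heads % model_cnt == 0
print(actual_num_kv_heads // model_cnt)  # 4 KV heads per tensor-parallel shard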
