Skip to content

Commit 2df8931

Browse files
bzgoogle
authored and committed
Update sharding to support pure 2D tensor parallelism (TP)
1 parent 2f69cae commit 2df8931

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tpu_inference/models/jax/deepseek_v3.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def __init__(self,
120120
hidden_size=hidden_size,
121121
dtype=dtype,
122122
rngs=self.rng,
123-
vd_sharding=(('data', 'expert', 'model'),
123+
vd_sharding=(('data', 'model', 'expert'),
124124
None),
125125
random_init=self.random_init)
126126

@@ -218,8 +218,8 @@ def _create_mla() -> MLA:
218218
random_init=self.random_init,
219219
activation_ffw_td=('data', 'model'),
220220
activation_ffw_ted=('data', None, 'model'),
221-
def_sharding=('expert', 'model', None),
222-
fed_sharding=('expert', None, 'model'),
221+
def_sharding=(None , 'model', 'expert'),
222+
fed_sharding=(None , 'expert', 'model'),
223223
router=router) if is_moe_layer else DenseFFW(
224224
dtype=dtype,
225225
hidden_act=hidden_act,
@@ -302,8 +302,8 @@ def _create_mla() -> MLA:
302302
hidden_size=hidden_size,
303303
dtype=dtype,
304304
rngs=self.rng,
305-
vd_sharding=(('data', 'expert', 'model'), None),
306-
dv_sharding=(None, ('data', 'expert', 'model')),
305+
vd_sharding=(('data', 'model', 'expert'), None),
306+
dv_sharding=(None, ('data', 'model', 'expert')),
307307
random_init=self.random_init)
308308

309309
# For compatibility with flax.

0 commit comments

Comments
 (0)