@@ -120,7 +120,7 @@ def __init__(self,
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(('data', 'expert', 'model'),
+            vd_sharding=(('data', 'model', 'expert'),
                          None),
             random_init=self.random_init)
 
@@ -218,8 +218,8 @@ def _create_mla() -> MLA:
             random_init=self.random_init,
             activation_ffw_td=('data', 'model'),
             activation_ffw_ted=('data', None, 'model'),
-            def_sharding=('expert', 'model', None),
-            fed_sharding=('expert', None, 'model'),
+            def_sharding=(None, 'model', 'expert'),
+            fed_sharding=(None, 'expert', 'model'),
             router=router) if is_moe_layer else DenseFFW(
             dtype=dtype,
             hidden_act=hidden_act,
@@ -302,8 +302,8 @@ def _create_mla() -> MLA:
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(('data', 'expert', 'model'), None),
-            dv_sharding=(None, ('data', 'expert', 'model')),
+            vd_sharding=(('data', 'model', 'expert'), None),
+            dv_sharding=(None, ('data', 'model', 'expert')),
             random_init=self.random_init)
 
         # For compatibility with flax.
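
Note on the sharding tuples above: each position names the mesh axis (or group of axes) that the corresponding weight dimension is split over, and None means that dimension is replicated. Assuming these tuples are ultimately consumed as jax.sharding.PartitionSpec entries (the usual convention in JAX/Flax sharded-model code), the sketch below shows what a spec like (('data', 'model', 'expert'), None) does to a 2-D embedding table. The mesh shape, axis sizes, and array dimensions are illustrative only and are not taken from this repository.

# Hedged sketch: how a vd_sharding-style tuple maps onto jax.sharding.
# The (2, 2, 2) mesh and the (vocab, hidden) sizes are made-up values.
# Needs 8 devices; on CPU you can set
# XLA_FLAGS=--xla_force_host_platform_device_count=8 to try it.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((2, 2, 2)),
            axis_names=('data', 'model', 'expert'))

# (('data', 'model', 'expert'), None): dim 0 is split across the product of
# all three mesh axes (8 shards here); dim 1 is replicated on every device.
vd_spec = P(('data', 'model', 'expert'), None)

embedding = jnp.zeros((32768, 4096))  # (vocab, hidden), illustrative sizes
sharded = jax.device_put(embedding, NamedSharding(mesh, vd_spec))
print(sharded.sharding)                            # NamedSharding with vd_spec
print(sharded.addressable_shards[0].data.shape)    # (4096, 4096): 32768 / 8 rows per shard

Reordering the axes inside such a tuple, as this commit does, does not change how many shards a dimension is split into; it only changes which mesh coordinate ends up holding which slice of the weight.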