 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Optional, Tuple
+from typing import Tuple
 
 import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
 from jax.sharding import PartitionSpec
 from jaxtyping import Float
-from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
-from qwix._src.providers import ptq
 
 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
-from tpu_commons.models.jax.utils.quantization.quantization_utils import (
-    manually_quantize_qwix_activation, manually_quantize_qwix_weight)
 
 modeling_flax_utils = FlaxUtils()
 
@@ -140,19 +136,43 @@ class SparseMoE(MoE):
        # TODO: determine whether we get this from outside or extract it in the MoE class
        is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim.
    """
+    def_sharding: Sharding
+    fed_sharding: Sharding
     num_experts_per_tok: int
     # TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim) from MaxText
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
-    # This should be set if and only if you have quantized your model (via Qwix)
-    quantized_dtype: Optional[jnp.dtype] = None
 
     def __post_init__(self, rngs: nnx.Rngs):
-        super().__post_init__(rngs)
+
+        D = self.hidden_size
+        F = self.intermediate_size_moe
+        # shape_gating = (D, self.num_local_experts, F)
+        # shape_up = (D, self.num_local_experts, F)
+        # shape_down = (F, self.num_local_experts, D)
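+        # Expert-stacked kernel shapes: E = num_local_experts, D = hidden_size,
+        # F = intermediate_size_moe.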
+        shape_gating = (self.num_local_experts, D, F)
+        shape_up = (self.num_local_experts, D, F)
+        shape_down = (self.num_local_experts, F, D)
+
+        self.kernel_gating_DEF = create_param(rngs,
+                                              shape=shape_gating,
+                                              dtype=self.dtype,
+                                              sharding=self.def_sharding,
+                                              random_init=self.random_init)
+        self.kernel_up_proj_DEF = create_param(rngs,
+                                               shape=shape_up,
+                                               dtype=self.dtype,
+                                               sharding=self.def_sharding,
+                                               random_init=self.random_init)
+        self.kernel_down_proj_FED = create_param(rngs,
+                                                 shape=shape_down,
+                                                 dtype=self.dtype,
+                                                 sharding=self.fed_sharding,
+                                                 random_init=self.random_init)
 
         # Derive the expert sharding
-        self.expert_axis_name = self.edf_sharding[0]
+        self.expert_axis_name = self.def_sharding[0]
         if self.expert_axis_name is None:
             self.num_expert_parallelism = 1
         else:
@@ -329,20 +349,29 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
         with jax.named_scope("unpermute"):
             unsorted_tokens_tD = self._sort_activations(
                 processed_tokens, jnp.argsort(sort_indices))
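+            # Recover the hidden dim from the local shard; it may differ from
+            # self.hidden_size when activations are sharded along D under shard_map.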
+            D = unsorted_tokens_tD.shape[1]
             reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
-                -1, self.num_experts_per_tok, self.hidden_size)
+                -1, self.num_experts_per_tok, D)
+            # jax.debug.print(
+            #     "✅ reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}",
+            #     t=reshaped_tokens_TXD[5, 0, :5]
+            # )
+            # jax.debug.print(
+            #     "✅ router_weights_TX on device: router_weights_TX={t}",
+            #     t=router_weights_TX[5, :]
+            # )
         with jax.named_scope("combine_weights"):
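+            # Weighted combine of the X selected experts per token:
+            # (T, X, D) x (T, X) -> (T, D).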
             output_TD = jnp.einsum(
                 "TXD,TX -> TD",
-                reshaped_tokens_TXD.astype(jnp.float32),
-                router_weights_TX.astype(jnp.float32),
-                precision='float32',
+                reshaped_tokens_TXD.astype(self.dtype),
+                router_weights_TX.astype(self.dtype),
             )
 
         return output_TD.astype(self.dtype)
 
     def _gmm(self, inputs, kernel, group_sizes):
         """Performs Grouped Matrix Multiply."""
+        jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True)
         num_rows = inputs.shape[0]
         pad_amount = (self.tile_size[0] -
                       num_rows % self.tile_size[0]) % self.tile_size[0]
@@ -354,11 +383,8 @@ def _gmm(self, inputs, kernel, group_sizes):
             raise NotImplementedError(
                 "MegaBlox kernel call is not implemented.")
         else:
-            inputs = manually_quantize_qwix_activation(
-                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
-                "absmax") if self.quantized_dtype else inputs
-            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
-            output = ragged_dot_func(
+
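+            # jax.lax.ragged_dot: lhs is (m, k), rhs is (num_groups, k, n), and
+            # group_sizes gives how many consecutive lhs rows belong to each group.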
+            output = jax.lax.ragged_dot(
                 lhs=inputs,
                 rhs=kernel,
                 group_sizes=group_sizes,
@@ -394,10 +420,12 @@ def _distributed_sparse_moe_fwd(
 
         # TODO: update to 'expert' after we enable expert parallelism; currently experts are sharded along the model axis,
         # or we should derive it from the model init.
-        expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
+
         local_expert_size = self.num_local_experts // self.num_expert_parallelism
 
-        if self.num_expert_parallelism > 1:
+        # if self.num_expert_parallelism > 1:
+        if self.expert_axis_name:
+            expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
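+            # axis_index gives this device's position along the expert mesh axis,
+            # i.e. which contiguous slice of experts this shard owns.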
             if self.is_batch_sharded_by_expert:
                 # When tokens are sharded across devices.
                 # In this path, we assume the data (tokens) are fully sharded over
                 # expert, i.e. data_axis_name == expert_axis_name.
@@ -508,8 +536,9 @@ def _distributed_sparse_moe_fwd(
         # 5. Return Results (All-to-All)
         if self.num_expert_parallelism > 1:
             local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok
+            D = x_TD.shape[1]
             output_shape = jnp.zeros(
-                (local_total_assignments, self.hidden_size),
+                (local_total_assignments, D),
                 dtype=intermediate_output.dtype)
 
             if self.is_batch_sharded_by_expert:
@@ -568,10 +597,10 @@ def __call__(self, x_TD: Float):
             PartitionSpec(*self.activation_ffw_td),  # Sharded x_TD
             PartitionSpec(),  # Replicated router_weights_TX
             PartitionSpec(),  # Replicated selected_experts_TX
-            PartitionSpec(*self.edf_sharding),  # Sharded gating kernel
-            PartitionSpec(*self.edf_sharding),  # Sharded up-projection kernel
+            PartitionSpec(*self.def_sharding),  # Sharded gating kernel
+            PartitionSpec(*self.def_sharding),  # Sharded up-projection kernel
             PartitionSpec(
-                *self.efd_sharding),  # Sharded down-projection kernel
+                *self.fed_sharding),  # Sharded down-projection kernel
         )
         out_specs = PartitionSpec(*self.activation_ffw_td)
 
@@ -582,27 +611,12 @@ def __call__(self, x_TD: Float):
             check_rep=False)(
                 SparseMoE._distributed_sparse_moe_fwd)
 
-        kernel_gating_EDF = self.kernel_gating_EDF.value
-        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
-        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
-
-        if self.quantized_dtype:
-            if not isinstance(kernel_gating_EDF, ptq.WithAux):
-                kernel_gating_EDF = manually_quantize_qwix_weight(
-                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
-                    "absmax")
-            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
-                kernel_up_proj_EDF = manually_quantize_qwix_weight(
-                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
-                    "absmax")
-            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
-                kernel_down_proj_EFD = manually_quantize_qwix_weight(
-                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
-                    "absmax")
-            kernel_gating_EDF = kernel_gating_EDF.array
-            kernel_up_proj_EDF = kernel_up_proj_EDF.array
-            kernel_down_proj_EFD = kernel_down_proj_EFD.array
-
-        return mapped_moe_fwd(self, x_TD, router_weights_TX,
-                              selected_experts_TX, kernel_gating_EDF,
-                              kernel_up_proj_EDF, kernel_down_proj_EFD)
+        return mapped_moe_fwd(
+            self,
+            x_TD,
+            router_weights_TX,
+            selected_experts_TX,
+            self.kernel_gating_DEF.value,
+            self.kernel_up_proj_DEF.value,
+            self.kernel_down_proj_FED.value,
+        )