@@ -50,6 +50,10 @@ def swiglu(x, y=None):
 ]
 
 
+def get_sm_num():
+    return 112
+
+
 def set_parameter_color(
     parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True
 ):
@@ -159,7 +163,7 @@ def padding_and_quant_input(tensor):
     tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
         tensor,
         output_scale_transpose=True,
-        tquant_method="1x128",
+        quant_method="1x128",
         input_transpose=True,
         return_transpose_only=True,
     )
@@ -178,7 +182,7 @@ def kitchen_gemm(
         if out is None:
             out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
         if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
-            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=118)
+            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=get_sm_num())
         return out
 
     if out is not None:
@@ -261,7 +265,9 @@ def compute_fp8_linear(
     if out is None:
         out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype)
 
-    deep_gemm.gemm_fp8_fp8_bf16_nt((input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=118)
+    deep_gemm.gemm_fp8_fp8_bf16_nt(
+        (input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=get_sm_num()
+    )
 
     # Return outputs
     if return_mode == "output_only":
@@ -351,7 +357,7 @@ def common_fp8_mlp_bwd(
     # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
     w1_fp8, w1_scale = weight_quant(w1, True)
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=get_sm_num())
 
     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -838,7 +844,7 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
             (x_fp8[start_idx:end_idx], x_scale_tma_align),
             (w_fp8[i], w_scale[i]),
             gemm_out[start_idx:end_idx],
-            num_sms=118,
+            num_sms=get_sm_num(),
         )
 
         start_idx = end_idx
@@ -927,7 +933,7 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=Non
             (w1_t_quant, w1_t_scale),
             o1,
             m_indices=self.m_indices if m_indices is None else m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )
 
         if m_indices is None:
@@ -981,7 +987,7 @@ def fwd_down(
             (w2_quant, w2_scale),
             o3,
             m_indices=m_indices if self.fwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )
 
         return o3
@@ -1022,7 +1028,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indi
             (bw_w2_quant, bw_w2_scale),
             do2_s,
             m_indices=m_indices if self.bwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )
 
         with paddle.amp.auto_cast(False):
@@ -1068,7 +1074,7 @@ def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, d
             (bw_w1_quant, bw_w1_scale),
             dx,
             m_indices=m_indices if self.bwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )
 
         return dx
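Note (not part of the diff above): the change centralizes the SM budget for every `deep_gemm` call in a single `get_sm_num()` helper, swapping the hardcoded `num_sms=118` for a fixed 112. If the budget should instead follow the actual device, a sketch along the lines below could work. It assumes a CUDA build of Paddle where `paddle.device.cuda.get_device_properties()` exposes `multi_processor_count`; the `reserved_sms` margin for overlapping communication kernels is an illustrative choice, not something this PR specifies.

```python
import paddle


def get_sm_num(reserved_sms: int = 2) -> int:
    """Illustrative alternative to the fixed 112: derive the SM budget
    from the current GPU and keep a few SMs free (assumption, not from
    the PR)."""
    props = paddle.device.cuda.get_device_properties()
    return props.multi_processor_count - reserved_sms
```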