     from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
 except ImportError:
     EventStore = None
+
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

@@ -598,6 +599,7 @@ def __init__(
         mlp_layer,
         send_mtp_embed,
         using_post_norm_recompute=False,
+        stepped_recompute_fwd_gate_up=False,
         name="",
     ):
         self.attn_and_gate_node = attn_and_gate_node
@@ -606,6 +608,7 @@ def __init__(
         self.send_mtp_embed = send_mtp_embed

         self.using_post_norm_recompute = using_post_norm_recompute
+        self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up
         self.name = name

         self.moe_group = mlp_layer.moe_group
@@ -1058,6 +1061,8 @@ def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_str
         return output_grad, event_to_wait

     def forward(self, inputs):
+        if self.stepped_recompute_fwd_gate_up:
+            self.fp8_fusion_moe_node.mlp_node.set_recompute_fwd_gate_up(True)
         inputs = self.attn_forward(inputs)
         inputs = self.dispatch_forward(inputs)
         inputs = self.mlp_forward(inputs)
@@ -1820,6 +1825,7 @@ def build_schedule_node(self):
                 mlp_layer=self.mlp,
                 send_mtp_embed=self.config.send_mtp_embed,
                 using_post_norm_recompute=self.config.using_post_norm_recompute,
+                stepped_recompute_fwd_gate_up=self.config.stepped_recompute_fwd_gate_up,
                 name="FusionFp8DecoderLayerNode",
             )
         else:
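
Taken together, the hunks thread one boolean, `stepped_recompute_fwd_gate_up`, from the model config through the `FusionFp8DecoderLayerNode` constructor, and `forward()` uses it to switch the inner MLP node into gate/up recompute mode via `set_recompute_fwd_gate_up(True)` before running the layer. The sketch below illustrates only that plumbing pattern; `SimpleConfig`, `SimpleMLPNode`, and `SimpleDecoderNode` are hypothetical stand-ins for `self.config`, `mlp_node`, and the decoder schedule node, not the actual PaddleNLP implementation, and the FP8/MoE logic is omitted.

```python
# Minimal sketch of the flag-threading pattern in this diff (stand-in classes,
# not PaddleNLP code).
from dataclasses import dataclass


@dataclass
class SimpleConfig:
    # Mirrors config.stepped_recompute_fwd_gate_up in the diff.
    stepped_recompute_fwd_gate_up: bool = False


class SimpleMLPNode:
    def __init__(self):
        self._recompute_fwd_gate_up = False

    def set_recompute_fwd_gate_up(self, flag: bool) -> None:
        # Stand-in for mlp_node.set_recompute_fwd_gate_up(True) in forward().
        self._recompute_fwd_gate_up = flag

    def forward(self, x):
        mode = "recompute" if self._recompute_fwd_gate_up else "cache"
        print(f"MLP forward, gate/up activations: {mode}")
        return x


class SimpleDecoderNode:
    def __init__(self, mlp_node, stepped_recompute_fwd_gate_up=False):
        # The constructor stores the flag, as __init__ does in the diff.
        self.mlp_node = mlp_node
        self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up

    def forward(self, inputs):
        # Enable recompute on the inner MLP node before running it,
        # mirroring the two lines added to forward() above.
        if self.stepped_recompute_fwd_gate_up:
            self.mlp_node.set_recompute_fwd_gate_up(True)
        return self.mlp_node.forward(inputs)


if __name__ == "__main__":
    cfg = SimpleConfig(stepped_recompute_fwd_gate_up=True)
    # build_schedule_node() analogue: the flag flows from config into the node.
    node = SimpleDecoderNode(SimpleMLPNode(), cfg.stepped_recompute_fwd_gate_up)
    node.forward(inputs=[1.0, 2.0])
```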