@@ -100,9 +100,56 @@ def create_scheduler(self):
100100 noise_scheduler_state = noise_scheduler .set_timesteps (noise_scheduler_state , num_inference_steps = 1000 , training = True )
101101 return noise_scheduler , noise_scheduler_state
102102
103- def calculate_tflops (self , pipeline ):
104- max_logging .log ("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0..." )
105- return 0
@staticmethod
def calculate_tflops(pipeline):
    """Estimate per-device training TFLOPs of the Wan transformer.

    Only the matmul work inside the transformer blocks is counted: the
    attention projections, the attention score/value products and the
    feed-forward layers. A multiply-accumulate is counted as 2 FLOPs.

    Args:
      pipeline: Object exposing ``config`` (``height``, ``width``,
        ``num_frames``, ``per_device_batch_size``), ``transformer.config``
        (``num_layers``, ``num_attention_heads``, ``attention_head_dim``,
        ``ffn_dim``) and ``vae_scale_factor_temporal``.

    Returns:
      A tuple ``(train_tflops, total_attn_flops, seq_len)`` where
      ``train_tflops`` is 3x the forward-pass TFLOPs for the per-device
      batch, ``total_attn_flops`` is the attention FLOPs of a single block
      for a single sample (raw FLOPs, not tera), and ``seq_len`` is the
      number of latent tokens fed to the transformer.
    """
    config = pipeline.config
    transformer_config = pipeline.transformer.config

    num_layers = transformer_config.num_layers
    ffn_dim = transformer_config.ffn_dim
    # Width of the attention projections (all heads concatenated).
    inner_dim = transformer_config.num_attention_heads * transformer_config.attention_head_dim

    # Number of latent tokens. The divisors 8 and 4 are hard-coded here;
    # presumably 8 is the VAE spatial downsampling factor and 4 the 2x2
    # spatial patching of the transformer -- TODO confirm against the model.
    latent_frames = (config.num_frames - 1) // pipeline.vae_scale_factor_temporal + 1
    seq_len = int(((config.height / 8) * (config.width / 8) * latent_frames) / 4)

    # Number of text tokens attended to in cross-attention. This is a
    # sequence length, not a feature dimension; the value is hard-coded.
    text_seq_len = 512

    # --- Self-attention ---
    # Q, K and V projections over the latent tokens.
    self_attn_qkv_proj_flops = 3 * (2 * seq_len * inner_dim**2)
    # QK^T scores plus the scores @ V product.
    self_attn_qk_v_flops = 2 * (2 * seq_len**2 * inner_dim)

    # --- Cross-attention ---
    # Only K and V are projected from the text tokens (two projections);
    # Q is accounted for separately below.
    cross_attn_kv_proj_flops = 2 * (2 * text_seq_len * inner_dim**2)
    cross_attn_q_proj_flops = 2 * seq_len * inner_dim**2
    cross_attn_qk_v_flops = 2 * (2 * seq_len * text_seq_len * inner_dim)

    # Output projections of the self- and cross-attention layers.
    attn_output_proj_flops = 2 * (2 * seq_len * inner_dim**2)

    total_attn_flops = (
        self_attn_qkv_proj_flops
        + self_attn_qk_v_flops
        + cross_attn_kv_proj_flops
        + cross_attn_q_proj_flops
        + cross_attn_qk_v_flops
        + attn_output_proj_flops
    )

    # Feed-forward: two linear layers (inner_dim -> ffn_dim -> inner_dim).
    ffn_flops = 2 * (2 * seq_len * inner_dim * ffn_dim)

    flops_per_block = total_attn_flops + ffn_flops
    total_transformer_flops = flops_per_block * num_layers

    forward_tflops = config.per_device_batch_size * total_transformer_flops / 1e12
    # Training cost is taken as 3x the forward cost (presumably forward plus
    # a backward pass at roughly twice the forward cost).
    train_tflops = 3 * forward_tflops

    max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}")
    return train_tflops, total_attn_flops, seq_len
106153
107154 def get_data_shardings (self , mesh ):
108155 data_sharding = jax .sharding .NamedSharding (mesh , P (* self .config .data_sharding ))
@@ -225,7 +272,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
225272 )
226273 # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
227274 start_step = 0
228- per_device_tflops = self .calculate_tflops (pipeline )
275+ per_device_tflops , _ , _ = WanTrainer .calculate_tflops (pipeline )
229276 scheduler_state = pipeline .scheduler_state
230277 example_batch = load_next_batch (train_data_iterator , None , self .config )
231278
0 commit comments