Skip to content

Commit c3bf323

Browse files
authored
Wan flops calc (#243)
* wan 2.1 flops estimation * linting. * add test. Modify padding to use kv instead of kv_compute
1 parent 8fdf3c2 commit c3bf323

File tree

3 files changed

+89
-6
lines changed

3 files changed

+89
-6
lines changed

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def _tpu_flash_attention(
214214
def wrap_flash_attention(query, key, value):
215215

216216
query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
217-
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
218-
value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
217+
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
218+
value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
219219

220220
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
221221
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])

src/maxdiffusion/tests/flop_calculations_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import os
22
import unittest
3+
from unittest.mock import Mock
34
import jax
45
from jax.sharding import Mesh
56
import flax.linen as nn
67
from absl.testing import absltest
78
from maxdiffusion.max_utils import calculate_model_tflops
89
from maxdiffusion.models.attention_flax import FlaxAttention
10+
from maxdiffusion.models.wan.transformers.transformer_wan import WanModel
911
from .. import pyconfig, max_utils
12+
from maxdiffusion.trainers.wan_trainer import WanTrainer
1013

1114
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
1215

@@ -20,6 +23,39 @@ def setUp(self):
2023
devices_array = max_utils.create_device_mesh(self.config)
2124
self.mesh = Mesh(devices_array, self.config.mesh_axes)
2225

26+
def assertFlopsAlmostEqual(self, flops1, flops2, rel_tol=5e-2):
  """Assert that two FLOPs values agree within a relative tolerance.

  The difference is measured against the larger of the two magnitudes.

  Args:
    flops1: First FLOPs value.
    flops2: Second FLOPs value.
    rel_tol: Largest accepted difference, as a fraction of the larger
      magnitude. Defaults to 5e-2 (5%).
  """
  largest_magnitude = max(abs(flops1), abs(flops2))
  # Compare by multiplying instead of dividing: when both values are zero the
  # division form raises ZeroDivisionError, whereas 0 <= 0 correctly passes.
  self.assertTrue(
      abs(flops1 - flops2) <= rel_tol * largest_magnitude,
      f"FLOPs values are not equal: {flops1} != {flops2} (rel_tol={rel_tol:.2e})",
  )
32+
33+
def test_wan_21_flops(self):
  """Check the Wan 2.1 TFLOPs estimate against a golden value.

  Initializes pyconfig from base_wan_14b.yml with 1280x720, 81 frames and a
  per-device batch size of 1, then feeds WanTrainer.calculate_tflops a mocked
  pipeline that carries only the attributes the estimator reads.
  """
  config_path = os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml")
  argv = [
      None,
      config_path,
      "width=1280",
      "height=720",
      "num_frames=81",
      "per_device_batch_size=1",
  ]
  pyconfig.initialize(argv, unittest=True)
  config = pyconfig.config

  # Expose the transformer's config entries as attributes on a mock, so no
  # model weights need to be instantiated.
  wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
  transformer = Mock()
  transformer.config = Mock()
  transformer.config.configure_mock(**wan_config)

  pipeline = Mock()
  pipeline.config = config
  pipeline.vae_scale_factor_temporal = 4
  pipeline.transformer = transformer

  # Only the TFLOPs figure is checked; the other two return values are unused.
  calculated_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
  self.assertFlopsAlmostEqual(calculated_tflops, 19_573)
58+
2359
def test_dense_layer_model_flops(self):
2460
class SimpleLinearModel(nn.Module):
2561

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,56 @@ def create_scheduler(self):
100100
noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True)
101101
return noise_scheduler, noise_scheduler_state
102102

103-
def calculate_tflops(self, pipeline):
  """Placeholder TFLOPs estimate: logs a warning and returns 0.

  Args:
    pipeline: Unused by this placeholder.

  Returns:
    Always 0.
  """
  max_logging.log("WARNING : Calculating tflops is not implemented in Wan 2.1. Returning 0...")
  return 0
103+
@staticmethod
def calculate_tflops(pipeline):
  """Estimate the per-device training TFLOPs of the Wan transformer.

  Args:
    pipeline: Object exposing `config` (height, width, num_frames,
      per_device_batch_size), `transformer.config` (num_layers,
      num_attention_heads, attention_head_dim, ffn_dim) and
      `vae_scale_factor_temporal`.

  Returns:
    A tuple `(train_tflops, total_attn_flops, seq_len)`: the TFLOPs figure
    scaled by 3, the attention FLOPs of a single transformer block (not
    divided by 1e12), and the token sequence length used in the estimate.
  """
  run_config = pipeline.config
  model_config = pipeline.transformer.config

  depth = model_config.num_layers
  inner_dim = model_config.num_attention_heads * model_config.attention_head_dim
  ffn_dim = model_config.ffn_dim

  # Sequence length: spatial dims divided by 8, times the temporally reduced
  # frame count, divided by 4.
  # NOTE(review): 8 and 4 look like the VAE spatial scale and a 2x2 patch
  # size -- confirm.
  latent_frames = (run_config.num_frames - 1) // pipeline.vae_scale_factor_temporal + 1
  seq_len = int((run_config.height / 8) * (run_config.width / 8) * latent_frames / 4)

  # NOTE(review): used below as the text-side sequence length in the
  # cross-attention terms, despite the "dim" name in the original -- confirm.
  text_len = 512

  # Cost of one inner_dim x inner_dim projection applied to every video token.
  token_proj_flops = 2 * seq_len * inner_dim**2

  # Self-attention: Q/K/V projections, then the QK^T and attention-times-V matmuls.
  self_attn_qkv_proj_flops = 3 * token_proj_flops
  self_attn_qk_v_flops = 2 * (2 * seq_len**2 * inner_dim)

  # Cross-attention against the text tokens.
  # NOTE(review): the factor 3 on the K/V projection is kept from the
  # original; two projections (K and V) would give 2 -- confirm intent.
  cross_attn_kv_proj_flops = 3 * (2 * text_len * inner_dim**2)
  cross_attn_q_proj_flops = token_proj_flops
  cross_attention_qk_v_flops = 2 * (2 * seq_len * text_len * inner_dim)

  # Output projections (two per block).
  attn_output_proj_flops = 2 * token_proj_flops

  total_attn_flops = sum((
      self_attn_qkv_proj_flops,
      self_attn_qk_v_flops,
      cross_attn_kv_proj_flops,
      cross_attn_q_proj_flops,
      cross_attention_qk_v_flops,
      attn_output_proj_flops,
  ))

  # Feed-forward network: two matmuls between inner_dim and ffn_dim.
  ffn_flops = 2 * (2 * seq_len * inner_dim * ffn_dim)

  total_transformer_flops = (total_attn_flops + ffn_flops) * depth

  tflops = run_config.per_device_batch_size * total_transformer_flops / 1e12
  # NOTE(review): the 3x presumably accounts for forward plus backward -- confirm.
  train_tflops = 3 * tflops

  max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}")
  return train_tflops, total_attn_flops, seq_len
106153

107154
def get_data_shardings(self, mesh):
108155
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
@@ -225,7 +272,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
225272
)
226273
# TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
227274
start_step = 0
228-
per_device_tflops = self.calculate_tflops(pipeline)
275+
per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
229276
scheduler_state = pipeline.scheduler_state
230277
example_batch = load_next_batch(train_data_iterator, None, self.config)
231278

0 commit comments

Comments
 (0)