From cd3e18384f0005246ce5f78f475ff20986b8a09d Mon Sep 17 00:00:00 2001 From: mahaocong90 Date: Thu, 19 Feb 2026 17:39:57 +0800 Subject: [PATCH 1/2] Add a FLOPs collection interface that supports real-time collection of floating-point operations (FLOPs) during the training process for WAN models. --- diffsynth/diffusion/runner.py | 60 ++++- diffsynth/models/wan_video_text_encoder.py | 7 +- diffsynth/utils/profiling/__init__.py | 6 + diffsynth/utils/profiling/flops_profiler.py | 252 ++++++++++++++++++ .../utils/xfuser/xdit_context_parallel.py | 7 +- 5 files changed, 327 insertions(+), 5 deletions(-) create mode 100644 diffsynth/utils/profiling/__init__.py create mode 100644 diffsynth/utils/profiling/flops_profiler.py diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index 6e26035e8..5fa39dc57 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -3,7 +3,13 @@ from accelerate import Accelerator from .training_module import DiffusionTrainingModule from .logger import ModelLogger - +import time +from ..utils.profiling.flops_profiler import ( + print_model_profile, + get_flops, + profile_entire_model, + unprofile_entire_model, +) def launch_training_task( accelerator: Accelerator, @@ -29,21 +35,69 @@ def launch_training_task( dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) model.to(device=accelerator.device) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) - + + train_step = 0 + profile_entire_model(model) + for epoch_id in range(num_epochs): - for data in tqdm(dataloader): + progress = tqdm( + dataloader, + disable=not accelerator.is_main_process, + desc=f"Epoch {epoch_id + 1}/{num_epochs}", + ) + + for data in progress: + iter_start = time.time() + timing = {} + if data is None: + continue + with accelerator.accumulate(model): optimizer.zero_grad() + if dataset.load_from_cache: loss = model({}, inputs=data) else: loss = model(data) + + t5_Tflops, wan_Tflops, vae_Tflops = get_flops(model) accelerator.backward(loss) optimizer.step() + model_logger.on_step_end(accelerator, model, save_steps, loss=loss) scheduler.step() + + torch.cuda.synchronize() + iter_end = time.time() + timing["step"] = iter_end - iter_start + train_step += 1 + + total_flops = t5_Tflops + wan_Tflops + vae_Tflops + TFLOPS = total_flops * 3 / timing["step"] + + if accelerator.is_main_process: + def format_time(key: str) -> str: + value = timing.get(key, 0.0) + return f"{value:.3f}s" + + postfix_dict = { + "Rank": f"{accelerator.process_index}", + "loss": f"{loss.item():.5f}", + "lr": f"{optimizer.param_groups[0]['lr']:.5e}", + "step/t": format_time("step"), + "[t5] Tflops": f"{t5_Tflops:.3f}", + "[dit] Tflops": f"{wan_Tflops:.3f}", + "[vae] Tflops": f"{vae_Tflops:.3f}", + "TFLOPS": f"{TFLOPS:.3f}", + } + progress.set_postfix(postfix_dict) + log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items()) + progress.write(log_msg) + if save_steps is None: model_logger.on_epoch_end(accelerator, model, epoch_id) + + unprofile_entire_model(model) model_logger.on_training_end(accelerator, model, save_steps) diff --git a/diffsynth/models/wan_video_text_encoder.py b/diffsynth/models/wan_video_text_encoder.py index 64090db8c..4f840fe78 100644 --- a/diffsynth/models/wan_video_text_encoder.py +++ b/diffsynth/models/wan_video_text_encoder.py @@ -70,6 +70,11 @@ def forward(self, x, context=None, mask=None, pos_bias=None): k = 
self.k(context).view(b, -1, n, c) v = self.v(context).view(b, -1, n, c) + # For caculate flops + self.q_shape = q.shape + self.k_shape = k.shape + self.v_shape = v.shape + # attention bias attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) if pos_bias is not None: @@ -327,4 +332,4 @@ def _clean(self, text): text = whitespace_clean(basic_clean(text)).lower() elif self.clean == 'canonicalize': text = canonicalize(basic_clean(text)) - return text \ No newline at end of file + return text diff --git a/diffsynth/utils/profiling/__init__.py b/diffsynth/utils/profiling/__init__.py new file mode 100644 index 000000000..afad49e1d --- /dev/null +++ b/diffsynth/utils/profiling/__init__.py @@ -0,0 +1,6 @@ +from .flops_profiler import ( + profile_entire_model, + unprofile_entire_model, + get_flops, + print_model_profile, +) diff --git a/diffsynth/utils/profiling/flops_profiler.py b/diffsynth/utils/profiling/flops_profiler.py new file mode 100644 index 000000000..002a5585f --- /dev/null +++ b/diffsynth/utils/profiling/flops_profiler.py @@ -0,0 +1,252 @@ +import torch +import torch.nn as nn +from functools import wraps +import time +from collections import defaultdict +import flash_attn +from einops import rearrange +from torch.utils.flop_counter import conv_flop_count + +def get_dit_flops(model): + def get_dit_flops(dit_block_model): + total_flops = 0 + for sub_model in dit_block_model.modules(): + total_flops += getattr(sub_model, '__flops__', 0) + return total_flops + + total_flops = 0 + total_duration = 0 + for sub_module in model.modules(): + if sub_module.__class__.__name__ == 'DiTBlock': + total_flops += get_dit_flops(sub_module) + total_duration += getattr(sub_module, '__duration__', 0) + + Tflops = total_flops / 1e12 + return Tflops + +def get_flops(model): + def get_module_flops(module): + if not hasattr(module, "__flops__"): + module.__flops__ = 0 + + flops = module.__flops__ + # iterate over immediate children modules + for child in module.children(): + flops += get_module_flops(child) + return flops + + t5_flops = 0 + wan_flops = 0 + vae_flops = 0 + for module in model.modules(): + if module.__class__.__name__ == 'WanTextEncoder': + t5_flops = get_module_flops(module) + if module.__class__.__name__ == 'WanModel': + wan_flops = get_module_flops(module) + if module.__class__.__name__ == 'WanVideoVAE38': + vae_flops = get_module_flops(module) + return t5_flops / 1e12, wan_flops / 1e12, vae_flops / 1e12 + +def print_model_profile(model): + def get_module_flops(module): + if not hasattr(module, "__flops__"): + module.__flops__ = 0 + + flops = module.__flops__ + # iterate over immediate children modules + for child in module.children(): + flops += get_module_flops(child) + return flops + + def get_module_duration(module): + if not hasattr(module, "__duration__"): + module.__duration__ = 0 + + duration = module.__duration__ + if duration == 0: # e.g. 
ModuleList + for m in module.children(): + duration += get_module_duration(m) + return duration + + def flops_repr(module): + flops = get_module_flops(module) + duration = get_module_duration(module) * 1000 + items = [ + "{:,} flops".format(flops), + "{:.3f} ms".format(duration), + ] + original_extra_repr = module.original_extra_repr() + if original_extra_repr: + items.append(original_extra_repr) + return ", ".join(items) + + def add_extra_repr(module): + flops_extra_repr = flops_repr.__get__(module) + if module.extra_repr != flops_extra_repr: + module.original_extra_repr = module.extra_repr + module.extra_repr = flops_extra_repr + assert module.extra_repr != module.original_extra_repr + + def del_extra_repr(module): + if hasattr(module, "original_extra_repr"): + module.extra_repr = module.original_extra_repr + del module.original_extra_repr + + model.apply(add_extra_repr) + print(model) + model.apply(del_extra_repr) + +def get_module_flops(module, *args, result=None, **kwargs): + module_type = module.__class__.__name__ + module_original_fwd = module._original_forward.__name__ + + if module_type == 'RMSNorm': + x = args[0] + return x.numel() * 4 + + elif module_type == 'RMS_norm': + x = args[0] + return x.numel() * 4 + + elif module_type == 'Dropout': + x = args[0] + return x.numel() * 2 + + elif module_type == 'LayerNorm': + x = args[0] + has_affine = module.weight is not None + return x.numel() * (5 if has_affine else 4) + + elif module_type == 'Linear': + x = args[0] + return x.numel() * module.weight.size(0) * 2 + + elif module_type == 'ReLU': + x = args[0] + return x.numel() + + elif module_type == 'GELU': + x = args[0] + return x.numel() + + elif module_type == 'SiLU': + x = args[0] + return x.numel() + + elif module_type == 'Conv3d' or module_type == 'CausalConv3d' or module_type == 'Conv2d': + x_shape = args[0].shape + weight = getattr(module, 'weight', None) + w_shape = weight.shape + out_shape = result.shape + + flops = conv_flop_count( + x_shape=x_shape, + w_shape=w_shape, + out_shape=out_shape, + transposed=False + ) + return flops + + # AttentionModule input is 3D shape, USP input is 4D shape. 
+ # + # 3D shape: + # q [batch, target_seq_len, Dim] + # k [batch, source_seq_len, Dim] + # v [batch, source_seq_len, Dim] + # flops = (batch * target_seq_len * source_seq_len) * Dim * 2 + # + (batch * target_seq_len * Dim) * source_seq_len * 2 + # = 4 * (batch * target_seq_len * source_seq_len * Dim) + # + # 4D shape: + # q [batch, target_seq_len, head, dim] + # k [batch, source_seq_len, head, dim] + # v [batch, source_seq_len, head, dim] + # flops = 4 * (batch * target_seq_len * source_seq_len * head * dim) + # + elif module_type == 'AttentionModule': + q = args[0] + k = args[1] + v = args[2] + + b, ts, dq = q.shape + _, ss, _ = k.shape + _, _, dv = v.shape + flops = (b * ts * ss * dq) * 2 + (b * ts * ss * dv) * 2 + return flops + + elif module_original_fwd == 'usp_attn_forward' or module_type == 'T5Attention': + q_shape = module.q_shape + k_shape = module.k_shape + v_shape = module.v_shape + + b, ts, n, dq = q_shape + _, ss, _, _ = k_shape + _, _, _, dv = v_shape + flops = (b * ts * ss * n * dq) * 2 + (b * ts * ss * n * dv) * 2 + return flops + + elif module_type == 'GateModule': + x = args[0] + return x.numel() * 2 + + elif module_type == 'T5LayerNorm': + x = args[0] + return x.numel() * 4 + + elif module_type == 'T5RelativeEmbedding': + lq = args[0] + lk = args[1] + return lq * lk * 10 + + else: + return 0 + +def flops_counter(flops_func=None): + def decorator(forward_func): + @wraps(forward_func) + def wrapper(self, *args, **kwargs): + start_time = time.perf_counter() + + result = forward_func(self, *args, **kwargs) + + self.__flops__ = get_module_flops(self, *args, result=result, **kwargs) + + end_time = time.perf_counter() + self.__duration__ = (end_time - start_time) + + return result + return wrapper + return decorator + + +def wrap_existing_module(module, verbose_profiling=False): + # save original fwd + module.verbose_profiling = verbose_profiling + module._original_forward = module.forward + + @flops_counter() + def profiled_forward(self, x, *args, **kwargs): + return module._original_forward(x, *args, **kwargs) + + module.forward = profiled_forward.__get__(module, type(module)) + return module + +def profile_entire_model(model, verbose_profiling=True): + for name, module in model.named_modules(): + wrap_existing_module(module, verbose_profiling) + return model + +def unwrap_existing_module(module): + if hasattr(module, "_original_forward"): + module.forward = module._original_forward + del module._original_forward + + if hasattr(module, "verbose_profiling"): + del module.verbose_profiling + return module + +def unprofile_entire_model(model): + for name, module in model.named_modules(): + unwrap_existing_module(module) + return model + diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 228e7b877..708b69f1a 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -128,6 +128,11 @@ def usp_attn_forward(self, x, freqs): k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) + # For caculate flops + self.q_shape = q.shape + self.k_shape = k.shape + self.v_shape = v.shape + attn_type = AttnType.FA ring_impl_type = "basic" if IS_NPU_AVAILABLE: @@ -143,4 +148,4 @@ def usp_attn_forward(self, x, freqs): del q, k, v getattr(torch, parse_device_type(x.device)).empty_cache() - return self.o(x) \ No newline at end of file + return self.o(x) From c09f61c5b1aea2e9814357b0d78871d3b6d52851 Mon Sep 17 
00:00:00 2001 From: mahaocong90 Date: Thu, 19 Feb 2026 21:17:35 +0800 Subject: [PATCH 2/2] [Bugfix] Fix the code issues reviewed by gemini-code-assist about [feature] Add a FLOPs collection interface. --- diffsynth/diffusion/runner.py | 12 ++---- diffsynth/models/wan_video_text_encoder.py | 2 +- diffsynth/utils/profiling/flops_profiler.py | 42 +++++-------------- .../utils/xfuser/xdit_context_parallel.py | 2 +- 4 files changed, 15 insertions(+), 43 deletions(-) diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index 5fa39dc57..868e36f6c 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -48,7 +48,6 @@ def launch_training_task( for data in progress: iter_start = time.time() - timing = {} if data is None: continue @@ -68,23 +67,18 @@ def launch_training_task( scheduler.step() torch.cuda.synchronize() - iter_end = time.time() - timing["step"] = iter_end - iter_start + time_step = time.time() - iter_start train_step += 1 total_flops = t5_Tflops + wan_Tflops + vae_Tflops - TFLOPS = total_flops * 3 / timing["step"] + TFLOPS = total_flops * 3 / time_step if accelerator.is_main_process: - def format_time(key: str) -> str: - value = timing.get(key, 0.0) - return f"{value:.3f}s" - postfix_dict = { "Rank": f"{accelerator.process_index}", "loss": f"{loss.item():.5f}", "lr": f"{optimizer.param_groups[0]['lr']:.5e}", - "step/t": format_time("step"), + "step/t": f"{time_step:.3f}", "[t5] Tflops": f"{t5_Tflops:.3f}", "[dit] Tflops": f"{wan_Tflops:.3f}", "[vae] Tflops": f"{vae_Tflops:.3f}", diff --git a/diffsynth/models/wan_video_text_encoder.py b/diffsynth/models/wan_video_text_encoder.py index 4f840fe78..33004e256 100644 --- a/diffsynth/models/wan_video_text_encoder.py +++ b/diffsynth/models/wan_video_text_encoder.py @@ -70,7 +70,7 @@ def forward(self, x, context=None, mask=None, pos_bias=None): k = self.k(context).view(b, -1, n, c) v = self.v(context).view(b, -1, n, c) - # For caculate flops + # For calculate flops self.q_shape = q.shape self.k_shape = k.shape self.v_shape = v.shape diff --git a/diffsynth/utils/profiling/flops_profiler.py b/diffsynth/utils/profiling/flops_profiler.py index 002a5585f..6954f730b 100644 --- a/diffsynth/utils/profiling/flops_profiler.py +++ b/diffsynth/utils/profiling/flops_profiler.py @@ -2,28 +2,8 @@ import torch.nn as nn from functools import wraps import time -from collections import defaultdict -import flash_attn -from einops import rearrange from torch.utils.flop_counter import conv_flop_count -def get_dit_flops(model): - def get_dit_flops(dit_block_model): - total_flops = 0 - for sub_model in dit_block_model.modules(): - total_flops += getattr(sub_model, '__flops__', 0) - return total_flops - - total_flops = 0 - total_duration = 0 - for sub_module in model.modules(): - if sub_module.__class__.__name__ == 'DiTBlock': - total_flops += get_dit_flops(sub_module) - total_duration += getattr(sub_module, '__duration__', 0) - - Tflops = total_flops / 1e12 - return Tflops - def get_flops(model): def get_module_flops(module): if not hasattr(module, "__flops__"): @@ -96,7 +76,7 @@ def del_extra_repr(module): print(model) model.apply(del_extra_repr) -def get_module_flops(module, *args, result=None, **kwargs): +def calculate_module_flops(module, *args, result=None, **kwargs): module_type = module.__class__.__name__ module_original_fwd = module._original_forward.__name__ @@ -193,6 +173,8 @@ def get_module_flops(module, *args, result=None, **kwargs): x = args[0] return x.numel() * 4 + # The 10x factor is an estimate of the 
computational coefficient for torch.log. + # The search and move operations in position encoding do not involve flop operations. elif module_type == 'T5RelativeEmbedding': lq = args[0] lk = args[1] @@ -201,7 +183,7 @@ def get_module_flops(module, *args, result=None, **kwargs): else: return 0 -def flops_counter(flops_func=None): +def flops_counter(): def decorator(forward_func): @wraps(forward_func) def wrapper(self, *args, **kwargs): @@ -209,7 +191,7 @@ def wrapper(self, *args, **kwargs): result = forward_func(self, *args, **kwargs) - self.__flops__ = get_module_flops(self, *args, result=result, **kwargs) + self.__flops__ = calculate_module_flops(self, *args, result=result, **kwargs) end_time = time.perf_counter() self.__duration__ = (end_time - start_time) @@ -218,22 +200,20 @@ def wrapper(self, *args, **kwargs): return wrapper return decorator - -def wrap_existing_module(module, verbose_profiling=False): +def wrap_existing_module(module): # save original fwd - module.verbose_profiling = verbose_profiling module._original_forward = module.forward @flops_counter() - def profiled_forward(self, x, *args, **kwargs): - return module._original_forward(x, *args, **kwargs) + def profiled_forward(self, *args, **kwargs): + return module._original_forward(*args, **kwargs) module.forward = profiled_forward.__get__(module, type(module)) return module -def profile_entire_model(model, verbose_profiling=True): +def profile_entire_model(model): for name, module in model.named_modules(): - wrap_existing_module(module, verbose_profiling) + wrap_existing_module(module) return model def unwrap_existing_module(module): @@ -241,8 +221,6 @@ def unwrap_existing_module(module): module.forward = module._original_forward del module._original_forward - if hasattr(module, "verbose_profiling"): - del module.verbose_profiling return module def unprofile_entire_model(model): diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 708b69f1a..c6c4883f2 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -128,7 +128,7 @@ def usp_attn_forward(self, x, freqs): k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) - # For caculate flops + # For calculate flops self.q_shape = q.shape self.k_shape = k.shape self.v_shape = v.shape
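
A minimal usage sketch of the profiling interface this series adds, exercised against a stand-in model rather than the full WAN training pipeline. The toy nn.Sequential, its tensor shapes, and the printed totals below are illustrative assumptions; note that get_flops() only attributes FLOPs to modules named WanTextEncoder, WanModel, and WanVideoVAE38, so this sketch reads the per-module __flops__ counters directly instead.

    import torch
    import torch.nn as nn
    from diffsynth.utils.profiling import (
        profile_entire_model,
        unprofile_entire_model,
        print_model_profile,
    )

    # Hypothetical stand-in model; any nn.Module works for the wrapping itself.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

    profile_entire_model(model)            # wraps every module.forward with the FLOP/time counter
    _ = model(torch.randn(2, 256, 1024))   # one forward pass fills __flops__ / __duration__

    print_model_profile(model)             # per-module "N flops, M ms" appended to repr(model)
    total = sum(getattr(m, "__flops__", 0) for m in model.modules())
    print(f"forward pass: {total / 1e12:.4f} TFLOPs")

    unprofile_entire_model(model)          # restores the original forward methods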
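
As a sanity check on the attention accounting above (4 * b * ts * ss * n * d), the count reduces to two batched matmuls, Q @ K^T and scores @ V, each costing 2 multiply-add FLOPs per output element. The shapes below are illustrative, not taken from a real WAN run.

    b, ts, ss, n, d = 1, 32760, 32760, 40, 128   # batch, target len, source len, heads, head dim (example)

    qk_flops = 2 * b * n * ts * ss * d           # scores = Q @ K^T
    av_flops = 2 * b * n * ts * ss * d           # out    = scores @ V
    total = qk_flops + av_flops
    assert total == 4 * b * ts * ss * n * d      # matches the profiler's formula
    print(f"{total / 1e12:.2f} TFLOPs per attention call")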
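
The TFLOPS figure logged in runner.py divides three times the measured forward FLOPs by the wall-clock step time; the factor of 3 appears to be the usual forward-plus-backward approximation (backward ≈ 2× forward), although the patch does not state this explicitly. A worked example with made-up numbers:

    t5_tflops, dit_tflops, vae_tflops = 0.8, 95.0, 4.2   # example get_flops() output, not measured
    step_time = 6.5                                       # seconds per optimizer step (example)

    total_tflops = t5_tflops + dit_tflops + vae_tflops    # forward cost of one step
    throughput = total_tflops * 3 / step_time             # forward + backward estimate
    print(f"{throughput:.1f} TFLOPS")                     # 100.0 * 3 / 6.5 ≈ 46.2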