diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index 6e26035e8..868e36f6c 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -3,7 +3,13 @@
 from accelerate import Accelerator
 from .training_module import DiffusionTrainingModule
 from .logger import ModelLogger
-
+import time
+from ..utils.profiling.flops_profiler import (
+    print_model_profile,
+    get_flops,
+    profile_entire_model,
+    unprofile_entire_model,
+)
 
 def launch_training_task(
     accelerator: Accelerator,
@@ -29,21 +35,63 @@
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     model.to(device=accelerator.device)
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
-
+
+    train_step = 0
+    profile_entire_model(model)
+
     for epoch_id in range(num_epochs):
-        for data in tqdm(dataloader):
+        progress = tqdm(
+            dataloader,
+            disable=not accelerator.is_main_process,
+            desc=f"Epoch {epoch_id + 1}/{num_epochs}",
+        )
+
+        for data in progress:
+            iter_start = time.time()
+            if data is None:
+                continue
+
             with accelerator.accumulate(model):
                 optimizer.zero_grad()
+
                 if dataset.load_from_cache:
                     loss = model({}, inputs=data)
                 else:
                     loss = model(data)
+
+                t5_Tflops, wan_Tflops, vae_Tflops = get_flops(model)
                 accelerator.backward(loss)
                 optimizer.step()
+
                 model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                 scheduler.step()
+
+                torch.cuda.synchronize()
+                time_step = time.time() - iter_start
+                train_step += 1
+
+                total_flops = t5_Tflops + wan_Tflops + vae_Tflops
+                TFLOPS = total_flops * 3 / time_step
+
+                if accelerator.is_main_process:
+                    postfix_dict = {
+                        "Rank": f"{accelerator.process_index}",
+                        "loss": f"{loss.item():.5f}",
+                        "lr": f"{optimizer.param_groups[0]['lr']:.5e}",
+                        "step/t": f"{time_step:.3f}",
+                        "[t5] Tflops": f"{t5_Tflops:.3f}",
+                        "[dit] Tflops": f"{wan_Tflops:.3f}",
+                        "[vae] Tflops": f"{vae_Tflops:.3f}",
+                        "TFLOPS": f"{TFLOPS:.3f}",
+                    }
+                    progress.set_postfix(postfix_dict)
+                    log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
+                    progress.write(log_msg)
+
         if save_steps is None:
             model_logger.on_epoch_end(accelerator, model, epoch_id)
+
+    unprofile_entire_model(model)
     model_logger.on_training_end(accelerator, model, save_steps)
 
 
diff --git a/diffsynth/models/wan_video_text_encoder.py b/diffsynth/models/wan_video_text_encoder.py
index 64090db8c..33004e256 100644
--- a/diffsynth/models/wan_video_text_encoder.py
+++ b/diffsynth/models/wan_video_text_encoder.py
@@ -70,6 +70,11 @@ def forward(self, x, context=None, mask=None, pos_bias=None):
         k = self.k(context).view(b, -1, n, c)
         v = self.v(context).view(b, -1, n, c)
 
+        # Record q/k/v shapes for the FLOPs profiler
+        self.q_shape = q.shape
+        self.k_shape = k.shape
+        self.v_shape = v.shape
+
         # attention bias
         attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
         if pos_bias is not None:
@@ -327,4 +332,4 @@ def _clean(self, text):
             text = whitespace_clean(basic_clean(text)).lower()
         elif self.clean == 'canonicalize':
             text = canonicalize(basic_clean(text))
-        return text
\ No newline at end of file
+        return text
diff --git a/diffsynth/utils/profiling/__init__.py b/diffsynth/utils/profiling/__init__.py
new file mode 100644
index 000000000..afad49e1d
--- /dev/null
+++ b/diffsynth/utils/profiling/__init__.py
@@ -0,0 +1,6 @@
+from .flops_profiler import (
+    profile_entire_model,
+    unprofile_entire_model,
+    get_flops,
+    print_model_profile,
+)
diff --git a/diffsynth/utils/profiling/flops_profiler.py b/diffsynth/utils/profiling/flops_profiler.py
new file mode 100644
index 000000000..6954f730b
--- /dev/null
+++ b/diffsynth/utils/profiling/flops_profiler.py
@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+from functools import wraps
+import time
+from torch.utils.flop_counter import conv_flop_count
+
+def get_flops(model):
+    def get_module_flops(module):
+        if not hasattr(module, "__flops__"):
+            module.__flops__ = 0
+
+        flops = module.__flops__
+        # iterate over immediate children modules
+        for child in module.children():
+            flops += get_module_flops(child)
+        return flops
+
+    t5_flops = 0
+    wan_flops = 0
+    vae_flops = 0
+    for module in model.modules():
+        if module.__class__.__name__ == 'WanTextEncoder':
+            t5_flops = get_module_flops(module)
+        if module.__class__.__name__ == 'WanModel':
+            wan_flops = get_module_flops(module)
+        if module.__class__.__name__ == 'WanVideoVAE38':
+            vae_flops = get_module_flops(module)
+    return t5_flops / 1e12, wan_flops / 1e12, vae_flops / 1e12
+
+def print_model_profile(model):
+    def get_module_flops(module):
+        if not hasattr(module, "__flops__"):
+            module.__flops__ = 0
+
+        flops = module.__flops__
+        # iterate over immediate children modules
+        for child in module.children():
+            flops += get_module_flops(child)
+        return flops
+
+    def get_module_duration(module):
+        if not hasattr(module, "__duration__"):
+            module.__duration__ = 0
+
+        duration = module.__duration__
+        if duration == 0:  # e.g. ModuleList
+            for m in module.children():
+                duration += get_module_duration(m)
+        return duration
+
+    def flops_repr(module):
+        flops = get_module_flops(module)
+        duration = get_module_duration(module) * 1000
+        items = [
+            "{:,} flops".format(flops),
+            "{:.3f} ms".format(duration),
+        ]
+        original_extra_repr = module.original_extra_repr()
+        if original_extra_repr:
+            items.append(original_extra_repr)
+        return ", ".join(items)
+
+    def add_extra_repr(module):
+        flops_extra_repr = flops_repr.__get__(module)
+        if module.extra_repr != flops_extra_repr:
+            module.original_extra_repr = module.extra_repr
+            module.extra_repr = flops_extra_repr
+            assert module.extra_repr != module.original_extra_repr
+
+    def del_extra_repr(module):
+        if hasattr(module, "original_extra_repr"):
+            module.extra_repr = module.original_extra_repr
+            del module.original_extra_repr
+
+    model.apply(add_extra_repr)
+    print(model)
+    model.apply(del_extra_repr)
+
+def calculate_module_flops(module, *args, result=None, **kwargs):
+    module_type = module.__class__.__name__
+    module_original_fwd = module._original_forward.__name__
+
+    if module_type == 'RMSNorm':
+        x = args[0]
+        return x.numel() * 4
+
+    elif module_type == 'RMS_norm':
+        x = args[0]
+        return x.numel() * 4
+
+    elif module_type == 'Dropout':
+        x = args[0]
+        return x.numel() * 2
+
+    elif module_type == 'LayerNorm':
+        x = args[0]
+        has_affine = module.weight is not None
+        return x.numel() * (5 if has_affine else 4)
+
+    elif module_type == 'Linear':
+        x = args[0]
+        return x.numel() * module.weight.size(0) * 2
+
+    elif module_type == 'ReLU':
+        x = args[0]
+        return x.numel()
+
+    elif module_type == 'GELU':
+        x = args[0]
+        return x.numel()
+
+    elif module_type == 'SiLU':
+        x = args[0]
+        return x.numel()
+
+    elif module_type == 'Conv3d' or module_type == 'CausalConv3d' or module_type == 'Conv2d':
+        x_shape = args[0].shape
+        weight = getattr(module, 'weight', None)
+        w_shape = weight.shape
+        out_shape = result.shape
+
+        flops = conv_flop_count(
+            x_shape=x_shape,
+            w_shape=w_shape,
+            out_shape=out_shape,
+            transposed=False
+        )
+        return flops
+
+    # AttentionModule receives 3D q/k/v tensors; the USP path receives 4D tensors.
+    #
+    # 3D shape:
+    #   q [batch, target_seq_len, Dim]
+    #   k [batch, source_seq_len, Dim]
+    #   v [batch, source_seq_len, Dim]
+    #   flops = (batch * target_seq_len * source_seq_len) * Dim * 2
+    #         + (batch * target_seq_len * Dim) * source_seq_len * 2
+    #         = 4 * (batch * target_seq_len * source_seq_len * Dim)
+    #
+    # 4D shape:
+    #   q [batch, target_seq_len, head, dim]
+    #   k [batch, source_seq_len, head, dim]
+    #   v [batch, source_seq_len, head, dim]
+    #   flops = 4 * (batch * target_seq_len * source_seq_len * head * dim)
+    #
+    elif module_type == 'AttentionModule':
+        q = args[0]
+        k = args[1]
+        v = args[2]
+
+        b, ts, dq = q.shape
+        _, ss, _ = k.shape
+        _, _, dv = v.shape
+        flops = (b * ts * ss * dq) * 2 + (b * ts * ss * dv) * 2
+        return flops
+
+    elif module_original_fwd == 'usp_attn_forward' or module_type == 'T5Attention':
+        q_shape = module.q_shape
+        k_shape = module.k_shape
+        v_shape = module.v_shape
+
+        b, ts, n, dq = q_shape
+        _, ss, _, _ = k_shape
+        _, _, _, dv = v_shape
+        flops = (b * ts * ss * n * dq) * 2 + (b * ts * ss * n * dv) * 2
+        return flops
+
+    elif module_type == 'GateModule':
+        x = args[0]
+        return x.numel() * 2
+
+    elif module_type == 'T5LayerNorm':
+        x = args[0]
+        return x.numel() * 4
+
+    # The 10x factor is a rough estimate of the per-element cost of torch.log.
+    # The bucket lookup and indexing in the position encoding involve no floating-point operations.
+    elif module_type == 'T5RelativeEmbedding':
+        lq = args[0]
+        lk = args[1]
+        return lq * lk * 10
+
+    else:
+        return 0
+
+def flops_counter():
+    def decorator(forward_func):
+        @wraps(forward_func)
+        def wrapper(self, *args, **kwargs):
+            start_time = time.perf_counter()
+
+            result = forward_func(self, *args, **kwargs)
+
+            self.__flops__ = calculate_module_flops(self, *args, result=result, **kwargs)
+
+            end_time = time.perf_counter()
+            self.__duration__ = (end_time - start_time)
+
+            return result
+        return wrapper
+    return decorator
+
+def wrap_existing_module(module):
+    # keep a reference to the original forward
+    module._original_forward = module.forward
+
+    @flops_counter()
+    def profiled_forward(self, *args, **kwargs):
+        return module._original_forward(*args, **kwargs)
+
+    module.forward = profiled_forward.__get__(module, type(module))
+    return module
+
+def profile_entire_model(model):
+    for name, module in model.named_modules():
+        wrap_existing_module(module)
+    return model
+
+def unwrap_existing_module(module):
+    if hasattr(module, "_original_forward"):
+        module.forward = module._original_forward
+        del module._original_forward
+
+    return module
+
+def unprofile_entire_model(model):
+    for name, module in model.named_modules():
+        unwrap_existing_module(module)
+    return model
+
diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py
index 228e7b877..c6c4883f2 100644
--- a/diffsynth/utils/xfuser/xdit_context_parallel.py
+++ b/diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -128,6 +128,11 @@ def usp_attn_forward(self, x, freqs):
     k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
     v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
 
+    # Record q/k/v shapes for the FLOPs profiler
+    self.q_shape = q.shape
+    self.k_shape = k.shape
+    self.v_shape = v.shape
+
     attn_type = AttnType.FA
     ring_impl_type = "basic"
     if IS_NPU_AVAILABLE:
@@ -143,4 +148,4 @@ def usp_attn_forward(self, x, freqs):
     del q, k, v
     getattr(torch, parse_device_type(x.device)).empty_cache()
 
-    return self.o(x)
\ No newline at end of file
+    return self.o(x)
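
Usage sketch (not part of the patch): the snippet below shows how the new profiling helpers are intended to compose around a training step, mirroring the launch_training_task changes above. The run_profiled_steps wrapper and its model / dataloader / optimizer arguments are hypothetical placeholders; get_flops only reports non-zero numbers when the model contains WanTextEncoder / WanModel / WanVideoVAE38 submodules, and only profile_entire_model, get_flops, print_model_profile, and unprofile_entire_model come from this patch.

    import time
    import torch

    from diffsynth.utils.profiling import (
        profile_entire_model,
        unprofile_entire_model,
        get_flops,
        print_model_profile,
    )

    def run_profiled_steps(model, dataloader, optimizer, num_steps=10):
        # Monkey-patch every module's forward so each call records __flops__ and __duration__.
        profile_entire_model(model)
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            start = time.time()
            loss = model(batch)  # placeholder: assumes the module returns a loss, as in launch_training_task
            # Forward FLOPs (in TFLOPs) of the text encoder, DiT, and VAE for this step.
            t5_tflops, dit_tflops, vae_tflops = get_flops(model)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            torch.cuda.synchronize()  # assumes a CUDA device, as in the patched training loop
            elapsed = time.time() - start
            # Forward + backward is approximated as 3x the forward FLOPs,
            # the same factor used in launch_training_task.
            tflops_per_s = (t5_tflops + dit_tflops + vae_tflops) * 3 / elapsed
            print(f"step {step}: {tflops_per_s:.2f} TFLOPS, {elapsed:.3f}s")
        # Optionally print per-module FLOPs and latency, then restore the original forwards.
        print_model_profile(model)
        unprofile_entire_model(model)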