[New Feature]: Add a FLOPs collection interface #1302
mahaocong90 wants to merge 2 commits into modelscope:main
Conversation
Summary of Changes

Hello @mahaocong90, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a robust system for real-time profiling of floating-point operations (FLOPs) during the training of deep learning models, specifically targeting WAN models. By wrapping neural network modules and employing decorators, it estimates computational costs for various layers, providing insight into model efficiency. The integration into the training loop gives immediate feedback on TFLOPS per step, aiding performance monitoring and optimization.
Code Review
This pull request introduces a valuable FLOPs collection interface for real-time performance monitoring during model training. The implementation correctly uses decorators and module wrapping to estimate computation. My review includes suggestions to improve the robustness and maintainability of the new profiling code, addressing a potential bug in the module wrapping logic, removing unused code, and enhancing clarity in a few areas.
```python
def profiled_forward(self, x, *args, **kwargs):
    return module._original_forward(x, *args, **kwargs)
```
The signature of profiled_forward is (self, x, *args, **kwargs), which assumes that every wrapped module's forward method has a first positional argument x. This is not always true and can lead to a TypeError for modules with different forward signatures (e.g., no arguments, or keyword-only arguments). The signature should be (self, *args, **kwargs) to be generic and robust.
Suggested change:

```diff
-def profiled_forward(self, x, *args, **kwargs):
-    return module._original_forward(x, *args, **kwargs)
+def profiled_forward(self, *args, **kwargs):
+    return module._original_forward(*args, **kwargs)
```
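For illustration, here is a minimal sketch of the signature-agnostic wrapping this suggestion points at. The accumulation logic is stubbed out; `estimate_flops` and `wrap_forward` are placeholder names, not functions from this PR:

```python
import types
import torch.nn as nn

def estimate_flops(module, args, kwargs, output):
    # Placeholder: a real estimator would inspect the module type,
    # input/output shapes, and parameter counts.
    return 0

def wrap_forward(module: nn.Module):
    """Replace `forward` with a profiling wrapper that accepts any signature."""
    if hasattr(module, "_original_forward"):
        return  # already wrapped
    module._original_forward = module.forward
    module.__flops__ = 0

    def profiled_forward(self, *args, **kwargs):
        # *args/**kwargs keeps the wrapper valid for forwards with no
        # positional arguments or with keyword-only arguments.
        out = module._original_forward(*args, **kwargs)
        module.__flops__ += estimate_flops(module, args, kwargs, out)
        return out

    module.forward = types.MethodType(profiled_forward, module)
```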
diffsynth/diffusion/runner.py (Outdated)
```python
def format_time(key: str) -> str:
    value = timing.get(key, 0.0)
    return f"{value:.3f}s"
```
This function is redefined on every iteration of the training loop, which is inefficient. It's better to define it once outside the loop. To do so, you'll need to pass the timing dictionary as an argument.
For example, you could define it before the loop:
```python
def format_time(timing_dict: dict, key: str) -> str:
    value = timing_dict.get(key, 0.0)
    return f"{value:.3f}s"
```

And then call it inside the loop as `format_time(timing, "step")`.
Since I cannot suggest changes outside of the current diff hunk, I'm leaving this as a comment for you to refactor.
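Put concretely, a self-contained sketch of the refactor; the loop body below is a stand-in for the actual training step, not the runner's code:

```python
import time

def format_time(timing_dict: dict, key: str) -> str:
    """Format a timing entry as seconds with millisecond precision."""
    return f"{timing_dict.get(key, 0.0):.3f}s"

# Defined once, outside the loop; each iteration only builds the dict
# and calls the helper.
for step in range(3):
    start = time.perf_counter()
    time.sleep(0.01)                      # stand-in for the real training step
    timing = {"step": time.perf_counter() - start}
    print(f"step {step}: {format_time(timing, 'step')}")
```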
```python
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
```

```python
# For calculating FLOPs
from collections import defaultdict
import flash_attn
from einops import rearrange
```
```python
def get_dit_flops(model):
    def get_dit_flops(dit_block_model):
        total_flops = 0
        for sub_model in dit_block_model.modules():
            total_flops += getattr(sub_model, '__flops__', 0)
        return total_flops

    total_flops = 0
    total_duration = 0
    for sub_module in model.modules():
        if sub_module.__class__.__name__ == 'DiTBlock':
            total_flops += get_dit_flops(sub_module)
            total_duration += getattr(sub_module, '__duration__', 0)

    Tflops = total_flops / 1e12
    return Tflops
```
This function has a couple of issues affecting readability and maintainability:

- The nested function at line 11 has the same name as the outer function, which can be confusing. It's better to give it a more descriptive name, like `_get_dit_flops_recursive`.
- The `total_duration` variable is calculated but never used. It should be removed.
Suggested change:

```diff
 def get_dit_flops(model):
-    def get_dit_flops(dit_block_model):
+    def _get_dit_flops_recursive(dit_block_model):
         total_flops = 0
         for sub_model in dit_block_model.modules():
             total_flops += getattr(sub_model, '__flops__', 0)
         return total_flops

     total_flops = 0
-    total_duration = 0
     for sub_module in model.modules():
         if sub_module.__class__.__name__ == 'DiTBlock':
-            total_flops += get_dit_flops(sub_module)
-            total_duration += getattr(sub_module, '__duration__', 0)
+            total_flops += _get_dit_flops_recursive(sub_module)
     Tflops = total_flops / 1e12
     return Tflops
```
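For reference, a minimal way to exercise the suggested function, with `get_dit_flops` from above in scope. The toy `DiTBlock` and the hard-coded `__flops__` value are purely illustrative; in practice the profiling wrappers populate `__flops__` during the forward pass:

```python
import torch.nn as nn

class DiTBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # Normally populated by the profiling wrapper during forward.
        self.proj.__flops__ = 1.0e9

    def forward(self, x):
        return self.proj(x)

model = nn.Sequential(DiTBlock(), DiTBlock())
print(f"{get_dit_flops(model):.4f} TFLOPs")   # 0.0020 with the toy numbers above
```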
…ature] Add a FLOPs collection interface.
This PR adds a FLOPs collection interface that supports real-time collection of floating-point operations (FLOPs) during the training process for WAN models.
Description
Wraps nn.Module submodules and uses Python decorators to estimate forward-pass FLOPs from module inputs, outputs, and model parameters.
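To make that description more concrete, here is a minimal sketch of such a pattern: per-module FLOPs formulas registered via a decorator and accumulated through forward hooks. All names (`flops_formula`, `attach_flops_counters`, the formulas themselves) are illustrative and not taken from this PR's actual implementation:

```python
import torch.nn as nn

_FLOPS_FORMULAS = {}

def flops_formula(module_cls):
    """Decorator: register a FLOPs estimate for a given nn.Module class."""
    def register(fn):
        _FLOPS_FORMULAS[module_cls] = fn
        return fn
    return register

@flops_formula(nn.Linear)
def _linear_flops(module, inputs, output):
    # One multiply and one add per weight element, per output "token".
    tokens = output.numel() // module.out_features
    return 2 * tokens * module.in_features * module.out_features

def attach_flops_counters(model: nn.Module):
    """Attach forward hooks that accumulate estimates into `__flops__`."""
    for m in model.modules():
        fn = _FLOPS_FORMULAS.get(type(m))
        if fn is None:
            continue
        m.__flops__ = 0

        def hook(mod, inputs, output, fn=fn):
            mod.__flops__ += fn(mod, inputs, output)

        m.register_forward_hook(hook)

# Usage: attach_flops_counters(model); run a forward pass; then sum
# getattr(m, '__flops__', 0) over the modules of interest (e.g. DiT blocks).
```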
Environment version
OS: Ubuntu 24.04
CUDA driver: 550.163.01 + CUDA 12.9
Python: 3.12.3
torch: 2.8.0
xfuser: 0.4.5
transformers: 4.55.2
GPU: A800 x 8, one node
Print result
Once we have estimated the FLOPs of the forward pass, we can roughly estimate the TFLOPS of a training step via (fwd + bwd) / time ≈ 3 × fwd / time, using the common approximation that the backward pass costs about twice the forward pass.
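As a concrete (purely illustrative) computation of that formula, assuming a measured forward estimate and step time:

```python
def step_tflops(fwd_tflops: float, step_time_s: float) -> float:
    """Approximate achieved TFLOPS for one training step.

    Assumes backward costs ~2x forward, so total work is ~3x the forward FLOPs.
    """
    return 3.0 * fwd_tflops / step_time_s

# Example numbers (illustrative only, not measurements from this PR):
fwd = 120.0        # forward TFLOPs estimated by the profiler for one step
elapsed = 2.4      # wall-clock seconds for that step
print(f"~{step_tflops(fwd, elapsed):.1f} TFLOPS for this step")   # ~150.0
```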