[New Feature]: Support for Flow-GRPO/Dance-GRPO #1141
base: main
Conversation
Summary of Changes

Hello @Jayce-Ping, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates an online Reinforcement Learning (RL) training pipeline into the system, specifically supporting the Flow-GRPO and Dance-GRPO algorithms for flow-matching models. This enhancement allows models to be fine-tuned using non-differentiable objectives such as human preferences, eliminating the need for traditional supervised fine-tuning with paired image data and significantly improving data efficiency.

Highlights
Code Review
This pull request introduces a FlowMatchSDEScheduler to support Reinforcement Learning based training for flow-matching models, as described in the Flow-GRPO and Dance-GRPO papers. The new scheduler implements stochastic sampling to enable exploration. My review of the new scheduler implementation has identified several critical issues related to correctness and numerical stability, including a type error that would cause a crash, potential division-by-zero errors, and an incorrect log-probability calculation. I have provided detailed comments and code suggestions to address these problems.
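To make the concept concrete: the point of an SDE-style scheduler in this setting is that each denoising step becomes a Gaussian transition with a well-defined log-probability, which the RL objective can reuse. The sketch below illustrates only that general idea with made-up names (stochastic_euler_step, noise_level); it is not the update rule implemented by the PR's FlowMatchSDEScheduler.

```python
import math
import torch

def stochastic_euler_step(sample, velocity, dt, noise_level):
    """Illustrative SDE-style Euler step: deterministic drift plus scaled Gaussian noise.

    Conceptual sketch only; not the Flow-GRPO/Dance-GRPO update rule from this PR.
    """
    dt = torch.as_tensor(dt, dtype=sample.dtype)
    noise_level = torch.as_tensor(noise_level, dtype=sample.dtype)
    # Deterministic part of the update, as in a plain ODE/Euler step.
    mean = sample + velocity * dt
    # Stochastic part: noise scaled with the step size enables exploration during rollouts.
    std = noise_level * torch.sqrt(dt.abs())
    prev_sample = mean + std * torch.randn_like(sample)
    # Per-element Gaussian log-probability of the sampled transition; RL objectives
    # need this to form policy (importance) ratios.
    log_prob = (
        -((prev_sample - mean) ** 2) / (2 * std**2)
        - torch.log(std)
        - 0.5 * math.log(2 * math.pi)
    )
    return prev_sample, log_prob

x, v = torch.randn(2, 4), torch.randn(2, 4)
x_prev, logp = stochastic_euler_step(x, v, dt=-0.05, noise_level=0.7)
```

The returned log_prob is the quantity a GRPO-style loss compares between the sampling policy and the updated policy.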
```python
if noise_window is None:
    self.noise_window = list(range(self.num_train_timesteps))
else:
    self.noise_window = list(noise_window)
```
self.noise_window is initialized as a list. However, in the current_noise_steps property (line 208), it is indexed with a torch.Tensor on line 212 (self.noise_window[selected_indices]), which will raise a TypeError. To fix this, self.noise_window should be a torch.Tensor. This change will also make the return type of current_noise_steps consistent with its type hint (torch.Tensor).
Suggested change:
```python
if noise_window is None:
    self.noise_window = torch.arange(self.num_train_timesteps)
else:
    self.noise_window = torch.tensor(noise_window)
```
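As a standalone illustration of the behaviour described above (not code from this PR), the indexing difference is easy to reproduce:

```python
import torch

noise_window_list = list(range(10))
noise_window_tensor = torch.arange(10)
selected_indices = torch.tensor([2, 5, 7])

# Indexing a plain Python list with a multi-element tensor raises TypeError:
# noise_window_list[selected_indices]

# Indexing a tensor with an integer tensor selects the requested elements:
print(noise_window_tensor[selected_indices])  # tensor([2, 5, 7])
```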
```python
elif self.sde_type == 'Dance-SDE':
    pred_original_sample = sample - sigma * model_output
    std_dev_t = current_noise_level * torch.sqrt(-1 * dt)
    log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2
```
This line involves a division by sigma**2. Some scheduler configurations, like set_timesteps_wan, can produce sigma values of 0. This will lead to a division-by-zero error. Please add a small epsilon to the denominator for numerical stability.
Suggested change:
```python
log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / (sigma**2 + 1e-9)
```
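A standalone snippet (not from the PR) showing the failure mode and the effect of the epsilon guard:

```python
import torch

sigma = torch.tensor([0.0, 0.5])
numerator = torch.tensor([1.0, 1.0])

print(numerator / sigma**2)           # inf at sigma == 0, which then propagates into the loss
print(numerator / (sigma**2 + 1e-9))  # finite everywhere (large but bounded at sigma == 0)
```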
```python
(-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)))
- math.log(std_dev_t)
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
```
There are two issues in this block:

1. math.log is used on a tensor std_dev_t (line 285), which will cause a TypeError. You should use torch.log instead.
2. If dt is zero, std_dev_t will be zero, leading to division by zero in the first term and log(0) in the second term. Adding a small epsilon will improve numerical stability.
Suggested change:
```python
(-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2) + 1e-9))
- torch.log(std_dev_t + 1e-9)
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
```
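A standalone illustration of the two points above (not code from the PR):

```python
import torch

std_dev_t = torch.tensor([0.0, 0.3])

# math.log(std_dev_t) fails on a multi-element tensor, and even for a 0-d tensor it
# returns a plain Python float, detaching the value from the autograd graph.
# torch.log keeps the computation in torch; the epsilon avoids log(0) when dt == 0.
print(torch.log(std_dev_t + 1e-9))
```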
```python
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
    - torch.log(std_dev_t * torch.sqrt(-1 * dt))
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
The log probability calculation for Flow-SDE can be unstable if dt is zero, which would cause division by zero or log(0). It's safer to add a small epsilon for numerical stability. Additionally, torch.sqrt(-1 * dt) is computed multiple times; pre-calculating it can improve clarity and slightly improve efficiency.
Suggested change:
```python
sqrt_neg_dt = torch.sqrt(-dt)
variance = (std_dev_t * sqrt_neg_dt) ** 2
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance + 1e-9)
    - torch.log(std_dev_t * sqrt_neg_dt + 1e-9)
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
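Since the same stabilized Gaussian log-probability appears in several branches, it could also be written once as a helper. The sketch below is a standalone, hypothetical refactoring, not code from the PR; gaussian_log_prob and the example shapes are made up for illustration:

```python
import math
import torch

def gaussian_log_prob(x, mean, std, eps=1e-9):
    """Element-wise log N(x; mean, std**2) with a small epsilon for numerical stability."""
    return (
        -((x - mean) ** 2) / (2 * std**2 + eps)
        - torch.log(std + eps)
        - 0.5 * math.log(2 * math.pi)
    )

# Mirroring the Flow-SDE branch: the standard deviation is std_dev_t * sqrt(-dt) (dt < 0),
# and the result is averaged over all non-batch dimensions.
prev_sample = torch.randn(2, 3)
prev_sample_mean = torch.zeros(2, 3)
std = torch.tensor(0.5) * torch.sqrt(torch.tensor(0.05))
log_prob = gaussian_log_prob(prev_sample.detach(), prev_sample_mean, std)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```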
```python
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
The calculation for log_prob for the CPS SDE type appears to be incomplete. It's currently calculating only the negative squared error. A proper log probability for a Gaussian distribution should also include terms for the variance and the normalization constant, following the formula: log P(x) = - (x - μ)² / (2σ²) - log(σ) - 0.5 * log(2π). Here, the variance is std_dev_t**2.
Suggested change:
```python
variance = std_dev_t**2
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance + 1e-9)
    - torch.log(std_dev_t + 1e-9)
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
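The formula quoted above matches the standard Gaussian log-density, which can be sanity-checked against torch.distributions.Normal (standalone snippet, not from the PR):

```python
import math
import torch
from torch.distributions import Normal

x = torch.randn(4)
mean = torch.zeros(4)
std = torch.full((4,), 0.3)

manual = (
    -((x - mean) ** 2) / (2 * std**2)
    - torch.log(std)
    - 0.5 * math.log(2 * math.pi)
)
reference = Normal(mean, std).log_prob(x)
print(torch.allclose(manual, reference))  # True
```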
Description
This PR introduces a Reinforcement Learning (RL) training pipeline for flow-matching models, implementing the algorithms proposed in Flow-GRPO (arXiv:2505.05470) and Dance-GRPO (arXiv:2505.07818).
Traditional flow-matching models (like FLUX or Stable Diffusion 3) rely on Supervised Fine-Tuning (SFT) using paired data. This PR enables Online RL directly on the flow-matching vector field, allowing the model to optimize towards non-differentiable objectives (such as human preference, aesthetic scores, or structural constraints) in a data-efficient manner without requiring ground-truth target images.
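The PR description does not show the reward interface, so the following is a purely hypothetical example of what a reward function in such a pipeline could look like; the name aesthetic_reward, its signature, and the scoring rule are assumptions rather than the PR's API. Its only job is to map generated images to scalar scores, and it never needs to be differentiable:

```python
import torch

def aesthetic_reward(images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
    """Hypothetical reward function: one scalar score per generated image.

    In an online-RL setup like the one this PR adds, the reward can be arbitrary and
    non-differentiable (a human-preference model, an OCR check, a structural constraint);
    the policy-gradient update only needs its scalar outputs, never its gradients.
    """
    # Placeholder scoring rule (stands in for a real reward model): favor brighter images.
    return images.float().mean(dim=(1, 2, 3))

# Example: scores for a batch of 4 RGB images.
scores = aesthetic_reward(torch.rand(4, 3, 64, 64), ["a cat"] * 4)
```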
Key technical contributions include:
Features
Data-Efficient Fine-Tuning: Enables model alignment using only prompts and a reward function (no paired image data required).
Flow-GRPO Algorithm: Full implementation of the GRPO loss specifically derived for flow-matching vector fields (a minimal sketch of the group-relative advantage idea follows this list).
Stochastic Scheduler for Exploration: Added a new scheduler that supports ODE-SDE mixed sampling to inject noise for exploration during the RL rollout phase.
RL Trainer Wrapper: A modular FlowGRPOTrainer that integrates seamlessly with the existing DiffSynth-Studio training loop.
Reward Model Registry:
GRPO Sampler: A DistributedSampler object designed for diverse and evenly distributed RL sampling.
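For context on the GRPO loss mentioned above, the group-relative advantage that characterizes GRPO-style objectives can be sketched in a few lines. This is a generic illustration of the idea, not the implementation added by this PR:

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, group_size: int, eps: float = 1e-6) -> torch.Tensor:
    """Normalize rewards within each group of rollouts generated from the same prompt.

    GRPO-style objectives use these group-relative advantages in place of a learned
    value baseline. Standalone sketch; not the PR's implementation.
    """
    grouped = rewards.view(-1, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)

# Example: 2 prompts x 4 rollouts each.
rewards = torch.tensor([0.1, 0.4, 0.3, 0.2, 0.9, 0.7, 0.8, 0.6])
print(group_relative_advantages(rewards, group_size=4))
```

Each prompt's rollouts are scored by the reward model, and advantages are computed relative to the other rollouts for the same prompt, removing the need for a learned value function.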
TODO

… pipeline/wrapper object.
… Flux.1 and Qwen-Image backbones.

Type of Change
References

Flow-GRPO (arXiv:2505.05470)
Dance-GRPO (arXiv:2505.07818)
Relevant Issues
#1111 #1110