-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[New Feasture]: Support for Flow-GRPO/Dance-GRPO #1141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||||||||||||||||||||||||||
| import torch, math | ||||||||||||||||||||||||||||||
| from typing_extensions import Literal | ||||||||||||||||||||||||||||||
| from typing_extensions import Literal, Optional, List | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class FlowMatchScheduler(): | ||||||||||||||||||||||||||||||
|
|
@@ -177,3 +177,145 @@ def training_weight(self, timestep): | |||||||||||||||||||||||||||||
| timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) | ||||||||||||||||||||||||||||||
| weights = self.linear_timesteps_weights[timestep_id] | ||||||||||||||||||||||||||||||
| return weights | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class FlowMatchSDEScheduler(FlowMatchScheduler): | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def __init__(self, | ||||||||||||||||||||||||||||||
| template : Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1", | ||||||||||||||||||||||||||||||
| noise_level : float = 0.1, | ||||||||||||||||||||||||||||||
| noise_window : Optional[List[int]] = None, | ||||||||||||||||||||||||||||||
| sde_step_num : Optional[int] = None, | ||||||||||||||||||||||||||||||
| sde_type : Literal['Flow-SDE', 'Dance-SDE', 'CPS'] = 'Flow-SDE', | ||||||||||||||||||||||||||||||
| seed: Optional[int] = None, | ||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| super().__init__(template=template, **kwargs) | ||||||||||||||||||||||||||||||
| if noise_window is None: | ||||||||||||||||||||||||||||||
| self.noise_window = list(range(self.num_train_timesteps)) | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| self.noise_window = list(noise_window) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| self.noise_level = noise_level | ||||||||||||||||||||||||||||||
| self.sde_step_num = sde_step_num or len(self.noise_window) | ||||||||||||||||||||||||||||||
| self.sde_type = sde_type | ||||||||||||||||||||||||||||||
| self.seed = seed or 42 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def set_seed(self, seed: int) -> torch.Tensor: | ||||||||||||||||||||||||||||||
| self.seed = seed | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||
| def current_noise_steps(self) -> torch.Tensor: | ||||||||||||||||||||||||||||||
| if self.sde_step_num >= len(self.noise_window): | ||||||||||||||||||||||||||||||
| return self.noise_window | ||||||||||||||||||||||||||||||
| generator = torch.Generator().manual_seed(self.seed) | ||||||||||||||||||||||||||||||
| selected_indices = torch.randperm(len(self.noise_window), generator=generator)[:self.sde_step_num] | ||||||||||||||||||||||||||||||
| return self.noise_window[selected_indices] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def get_current_noise_level(self, timestep) -> float: | ||||||||||||||||||||||||||||||
| if isinstance(timestep, torch.Tensor): | ||||||||||||||||||||||||||||||
| timestep = timestep.cpu() | ||||||||||||||||||||||||||||||
| timestep_id = torch.argmin((self.timesteps - timestep).abs()) | ||||||||||||||||||||||||||||||
| if timestep_id in self.current_noise_steps: | ||||||||||||||||||||||||||||||
| return self.noise_level | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| return 0.0 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def step(self, | ||||||||||||||||||||||||||||||
| model_output, | ||||||||||||||||||||||||||||||
| timestep, | ||||||||||||||||||||||||||||||
| sample, | ||||||||||||||||||||||||||||||
| to_final=False, | ||||||||||||||||||||||||||||||
| prev_sample=None, | ||||||||||||||||||||||||||||||
| generator=None, | ||||||||||||||||||||||||||||||
| return_log_prob=False, | ||||||||||||||||||||||||||||||
| return_dict=False, | ||||||||||||||||||||||||||||||
| **kwargs | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| if isinstance(timestep, torch.Tensor): | ||||||||||||||||||||||||||||||
| timestep = timestep.cpu() | ||||||||||||||||||||||||||||||
| timestep_id = torch.argmin((self.timesteps - timestep).abs()) | ||||||||||||||||||||||||||||||
| sigma = self.sigmas[timestep_id] | ||||||||||||||||||||||||||||||
| if to_final or timestep_id + 1 >= len(self.timesteps): | ||||||||||||||||||||||||||||||
| sigma_ = 0 | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| sigma_ = self.sigmas[timestep_id + 1] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Convert to float32 for numerical stability | ||||||||||||||||||||||||||||||
| model_output = model_output.float() | ||||||||||||||||||||||||||||||
| sample = sample.float() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| current_noise_level = self.get_current_noise_level(timestep) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| dt = sigma_ - sigma | ||||||||||||||||||||||||||||||
| sigma_max = self.sigmas[1] # Use the max sigma < 1 | ||||||||||||||||||||||||||||||
| if self.sde_type == 'Flow-SDE': | ||||||||||||||||||||||||||||||
| std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * current_noise_level | ||||||||||||||||||||||||||||||
| prev_sample_mean = sample * (1 + std_dev_t**2 / (2 * sigma) * dt) + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt | ||||||||||||||||||||||||||||||
| if prev_sample is None: | ||||||||||||||||||||||||||||||
| variance_noise = torch.randn_like( | ||||||||||||||||||||||||||||||
| model_output.shape, | ||||||||||||||||||||||||||||||
| generator=generator, | ||||||||||||||||||||||||||||||
| ).to(device=model_output.device, dtype=model_output.dtype) | ||||||||||||||||||||||||||||||
| prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if return_log_prob: | ||||||||||||||||||||||||||||||
| log_prob = ( | ||||||||||||||||||||||||||||||
| -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2)) | ||||||||||||||||||||||||||||||
| - torch.log(std_dev_t * torch.sqrt(-1 * dt)) | ||||||||||||||||||||||||||||||
| - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) | ||||||||||||||||||||||||||||||
|
Comment on lines
+263
to
+268
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The log probability calculation for
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| elif self.sde_type == 'Dance-SDE': | ||||||||||||||||||||||||||||||
| pred_original_sample = sample - sigma * model_output | ||||||||||||||||||||||||||||||
| std_dev_t = current_noise_level * torch.sqrt(-1 * dt) | ||||||||||||||||||||||||||||||
| log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2 | ||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line involves a division by
Suggested change
|
||||||||||||||||||||||||||||||
| prev_sample_mean = sample + (model_output + log_term) * dt | ||||||||||||||||||||||||||||||
| if prev_sample is None: | ||||||||||||||||||||||||||||||
| variance_noise = torch.randn( | ||||||||||||||||||||||||||||||
| model_output.shape, | ||||||||||||||||||||||||||||||
| generator=generator, | ||||||||||||||||||||||||||||||
| ).to(device=model_output.device, dtype=model_output.dtype) | ||||||||||||||||||||||||||||||
| prev_sample = prev_sample_mean + std_dev_t * variance_noise | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if return_log_prob: | ||||||||||||||||||||||||||||||
| log_prob = ( | ||||||||||||||||||||||||||||||
| (-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))) | ||||||||||||||||||||||||||||||
| - math.log(std_dev_t) | ||||||||||||||||||||||||||||||
| - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) | ||||||||||||||||||||||||||||||
|
Comment on lines
+284
to
+286
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two issues in this block:
Suggested change
|
||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # mean along all but batch dimension | ||||||||||||||||||||||||||||||
| log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| elif self.sde_type == 'CPS': | ||||||||||||||||||||||||||||||
| # Coefficient Preserving Sampling | ||||||||||||||||||||||||||||||
| std_dev_t = sigma_ * torch.sin(current_noise_level * torch.pi / 2) | ||||||||||||||||||||||||||||||
| pred_original_sample = sample - sigma * model_output | ||||||||||||||||||||||||||||||
| noise_estimate = sample + model_output * (1 - sigma) | ||||||||||||||||||||||||||||||
| prev_sample_mean = pred_original_sample * (1 - sigma_) + noise_estimate * torch.sqrt(sigma_**2 - std_dev_t**2) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if prev_sample is None: | ||||||||||||||||||||||||||||||
| variance_noise = torch.randn( | ||||||||||||||||||||||||||||||
| model_output.shape, | ||||||||||||||||||||||||||||||
| generator=generator, | ||||||||||||||||||||||||||||||
| ).to(device=model_output.device, dtype=model_output.dtype) | ||||||||||||||||||||||||||||||
| prev_sample = prev_sample_mean + std_dev_t * variance_noise | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if return_log_prob: | ||||||||||||||||||||||||||||||
| log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) | ||||||||||||||||||||||||||||||
| log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) | ||||||||||||||||||||||||||||||
|
Comment on lines
+307
to
+308
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The calculation for
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if not return_log_prob: | ||||||||||||||||||||||||||||||
| log_prob = torch.zeros(sample.shape[0], device=sample.device) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if return_dict: | ||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||
| "prev_sample": prev_sample, | ||||||||||||||||||||||||||||||
| "log_prob": log_prob, | ||||||||||||||||||||||||||||||
| "prev_sample_mean": prev_sample_mean, | ||||||||||||||||||||||||||||||
| "std_dev_t": std_dev_t, | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return prev_sample, log_prob, prev_sample_mean, std_dev_t | ||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.noise_windowis initialized as a list. However, in thecurrent_noise_stepsproperty (line 208), it is indexed with atorch.Tensoron line 212 (self.noise_window[selected_indices]), which will raise aTypeError. To fix this,self.noise_windowshould be atorch.Tensor. This change will also make the return type ofcurrent_noise_stepsconsistent with its type hint (torch.Tensor).