diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 0afb608af84a..4dcd5457fb41 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1533,9 +1533,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # from the cat above, but collate_fn also doubles the prompts list. Use half the # prompts count to avoid a 2x over-repeat that produces more embeddings than latents. num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) - prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) if prompt_embeds_mask is not None: - prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1) + prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_repeat_elements, dim=0) # Convert images to latent space if args.cache_latents: model_input = latents_cache[step].sample() @@ -1602,10 +1602,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1,