diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 477697fadb64..ce75565b5e8a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1717,12 +1717,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( + cond_model_input_ids = Flux2Pipeline._prepare_image_ids([cond_model_input[0:1]]).to( device=cond_model_input.device ) - cond_model_input_ids = cond_model_input_ids.view( - cond_model_input.shape[0], -1, model_input_ids.shape[-1] + cond_model_input_ids = cond_model_input_ids.expand( + cond_model_input.shape[0], -1, -1 ) # Sample noise that we'll add to the latents