From 0392cd8824b1c01d4c43a21564f505b8a19cbb9a Mon Sep 17 00:00:00 2001 From: zhujian <2469395556@qq.com> Date: Tue, 26 May 2026 20:46:32 +0800 Subject: [PATCH] fix: resolve issue #13811 --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 477697fadb64..ce75565b5e8a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1717,12 +1717,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( + cond_model_input_ids = Flux2Pipeline._prepare_image_ids([cond_model_input[0:1]]).to( device=cond_model_input.device ) - cond_model_input_ids = cond_model_input_ids.view( - cond_model_input.shape[0], -1, model_input_ids.shape[-1] + cond_model_input_ids = cond_model_input_ids.expand( + cond_model_input.shape[0], -1, -1 ) # Sample noise that we'll add to the latents