diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 608ab3ef3135..8c83bb5466b6 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -895,9 +895,8 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
                 self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)

             # random initialization of new tokens
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
             std_token_embedding = embeds.weight.data.std()

             logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
@@ -905,9 +904,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             # if initializer_concept are not provided, token embeddings are initialized randomly
             if args.initializer_concept is None:
-                hidden_size = (
-                    text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
-                )
+                hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
                 embeds.weight.data[train_ids] = (
                     torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
                     * std_token_embedding
@@ -940,7 +937,8 @@ def save_embeddings(self, file_path: str):
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
-            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+            text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]

@@ -962,7 +960,8 @@ def device(self):
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
-            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+            text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -2112,7 +2111,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            _te_one = unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
         elif args.train_text_encoder_ti:  # textual inversion / pivotal tuning
             text_encoder_one.train()
             if args.enable_t5_ti:
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index a47e4dd96dcb..ae438f720aa2 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -763,19 +763,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

             # random initialization of new tokens
-            std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            std_token_embedding = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.std()

             print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

-            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
-                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids] = (
+                torch.randn(
+                    len(self.train_ids),
+                    (
+                        text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+                    ).config.hidden_size,
+                )
                 .to(device=self.device)
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
             self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding

             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -794,10 +803,14 @@ def save_embeddings(self, file_path: str):
         # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
-            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
-                self.tokenizers[0]
-            ), "Tokenizers should be the same."
-            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+            assert (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+                "Tokenizers should be the same."
+            )
+            new_token_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids]

             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
             # text_encoder 1) to keep compatible with the ecosystem.
@@ -819,7 +832,9 @@ def device(self):
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
                 .to(device=text_encoder.device)
                 .to(dtype=text_encoder.dtype)
@@ -830,11 +845,15 @@
             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]

             index_updates = ~index_no_updates
-            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+            new_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates]
             off_ratio = std_token_embedding / new_embeddings.std()

             new_embeddings = new_embeddings * (off_ratio**0.1)
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings


 class DreamBoothDataset(Dataset):
@@ -1704,7 +1723,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
             if args.train_text_encoder:
-                text_encoder_one.text_model.embeddings.requires_grad_(True)
+                _te_one = text_encoder_one
+                (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

         unet.train()
         for step, batch in enumerate(train_dataloader):
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index dcaa5a38fc37..8d6e04a35bbb 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -929,19 +929,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

             # random initialization of new tokens
-            std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            std_token_embedding = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.std()

             print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

-            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
-                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids] = (
+                torch.randn(
+                    len(self.train_ids),
+                    (
+                        text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+                    ).config.hidden_size,
+                )
                 .to(device=self.device)
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
             self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding

             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -959,10 +968,14 @@ def save_embeddings(self, file_path: str):
         # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
-            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
-                self.tokenizers[0]
-            ), "Tokenizers should be the same."
-            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+            assert (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+                "Tokenizers should be the same."
+            )
+            new_token_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids]

             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
             # text_encoder 1) to keep compatible with the ecosystem.
@@ -984,7 +997,9 @@ def device(self):
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
                 .to(device=text_encoder.device)
                 .to(dtype=text_encoder.dtype)
@@ -995,11 +1010,15 @@
             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]

             index_updates = ~index_no_updates
-            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+            new_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates]
             off_ratio = std_token_embedding / new_embeddings.std()

             new_embeddings = new_embeddings * (off_ratio**0.1)
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings


 class DreamBoothDataset(Dataset):
@@ -2083,8 +2102,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             text_encoder_two.train()
             # set top parameter requires_grad = True for gradient checkpointing works
             if args.train_text_encoder:
-                accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-                accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+                _te_one = accelerator.unwrap_model(text_encoder_one)
+                (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+                _te_two = accelerator.unwrap_model(text_encoder_two)
+                (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             if pivoted:
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 2ce451917709..e7647917d10c 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -874,10 +874,11 @@ def main(args):
             token_embeds[x] = token_embeds[y]

         # Freeze all parameters except for the token embeddings in text encoder
+        text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
         params_to_freeze = itertools.chain(
-            text_encoder.text_model.encoder.parameters(),
-            text_encoder.text_model.final_layer_norm.parameters(),
-            text_encoder.text_model.embeddings.position_embedding.parameters(),
+            text_module.encoder.parameters(),
+            text_module.final_layer_norm.parameters(),
+            text_module.embeddings.position_embedding.parameters(),
         )
         freeze_params(params_to_freeze)
         ########################################################
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index e0e7d2e40e56..6514962b4a58 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1691,7 +1691,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            _te_one = unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index dee65761e92b..e8fb88ce6c10 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -1896,7 +1896,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            _te_one = unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 4f49ef4bd801..41b98f6d8e7a 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1719,8 +1719,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
            text_encoder_two.train()

             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+            _te_one = accelerator.unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+            _te_two = accelerator.unwrap_model(text_encoder_two)
+            (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 502ce1a3f1ec..cfd144bd566d 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1661,8 +1661,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
            text_encoder_two.train()

             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+            _te_one = accelerator.unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+            _te_two = accelerator.unwrap_model(text_encoder_two)
+            (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 1aaa701d8ceb..46efa0d00559 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -702,9 +702,10 @@ def main():
     vae.requires_grad_(False)
     unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+    text_module.encoder.requires_grad_(False)
+    text_module.final_layer_norm.requires_grad_(False)
+    text_module.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         # Keep unet in train mode if we are using gradient checkpointing to save memory.
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 3e9151034eaa..8fde356d445b 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -717,12 +717,14 @@ def main():
     unet.requires_grad_(False)

     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder_1.text_model.encoder.requires_grad_(False)
-    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+    text_module_1.encoder.requires_grad_(False)
+    text_module_1.final_layer_norm.requires_grad_(False)
+    text_module_1.embeddings.position_embedding.requires_grad_(False)
+    text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+    text_module_2.encoder.requires_grad_(False)
+    text_module_2.final_layer_norm.requires_grad_(False)
+    text_module_2.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
@@ -767,8 +769,12 @@ def main():
     optimizer = optimizer_class(
         # only optimize the embeddings
         [
-            text_encoder_1.text_model.embeddings.token_embedding.weight,
-            text_encoder_2.text_model.embeddings.token_embedding.weight,
+            (
+                text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+            ).embeddings.token_embedding.weight,
+            (
+                text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+            ).embeddings.token_embedding.weight,
         ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
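
Every hunk above applies the same dispatch: use the `.text_model` submodule when the text encoder exposes one (CLIP-style encoders such as `CLIPTextModel` nest `embeddings`, `encoder`, and `config` under `.text_model`), and otherwise treat the encoder itself as that module. A minimal standalone sketch of the pattern follows; the helper name `_get_text_module` is illustrative and not part of this patch:

    def _get_text_module(text_encoder):
        # CLIP text encoders nest their submodules under `.text_model`;
        # other encoders may expose `.embeddings` / `.config` at the top level.
        return text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder

    # Usage, mirroring the hunks above (illustrative):
    # std = _get_text_module(text_encoder).embeddings.token_embedding.weight.data.std()
    # hidden_size = _get_text_module(text_encoder).config.hidden_size

The patch inlines this guard at each call site rather than introducing a shared helper, which keeps each example script self-contained at the cost of some repetition.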