@@ -895,19 +895,16 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)

# random initialization of new tokens
- embeds = (
- text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
- )
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
std_token_embedding = embeds.weight.data.std()

logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

train_ids = self.train_ids if idx == 0 else self.train_ids_t5
# if initializer_concept are not provided, token embeddings are initialized randomly
if args.initializer_concept is None:
- hidden_size = (
- text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
- )
+ hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
embeds.weight.data[train_ids] = (
torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
* std_token_embedding
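Note (illustrative, not from the PR): the hunk above initializes the newly inserted token rows by sampling from a normal distribution and scaling by the std of the existing embedding table, so new tokens start with the same statistics as the pretrained vocabulary. A minimal sketch of that step, using a hypothetical helper name:

import torch

def init_new_token_rows(embedding: torch.nn.Embedding, new_ids: list) -> None:
    # Hypothetical helper mirroring the initialization above.
    std = embedding.weight.data.std()  # spread of the pretrained rows
    hidden_size = embedding.weight.shape[1]
    noise = torch.randn(len(new_ids), hidden_size, device=embedding.weight.device, dtype=embedding.weight.dtype)
    embedding.weight.data[new_ids] = noise * std  # new rows ~ N(0, std^2)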
@@ -940,7 +937,8 @@ def save_embeddings(self, file_path: str):
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
for idx, text_encoder in enumerate(self.text_encoders):
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
- embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
new_token_embeddings = embeds.weight.data[train_ids]

@@ -962,7 +960,8 @@ def device(self):
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
- embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
embeds.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -2112,7 +2111,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
- unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ _te_one = unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
text_encoder_one.train()
if args.enable_t5_ti:
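Note: the change repeated throughout this PR is one small pattern. Instead of assuming every text encoder wraps its transformer in a .text_model attribute (as transformers' CLIPTextModel does), each call site now falls back to the encoder object itself when that attribute is missing. A minimal sketch of the idea as a standalone helper (the helper name is hypothetical; the PR inlines the expression at each call site):

def resolve_text_module(text_encoder):
    # CLIPTextModel-style encoders expose the transformer under .text_model;
    # other (or already unwrapped) encoders are used directly.
    return text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder

# Example usage with either kind of encoder:
# embeds = resolve_text_module(text_encoder).embeddings.token_embedding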
[Next file]
@@ -763,19 +763,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

# random initialization of new tokens
- std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+ std_token_embedding = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.std()

print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

- text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
- torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(
+ len(self.train_ids),
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).config.hidden_size,
+ )
.to(device=self.device)
.to(dtype=self.dtype)
* std_token_embedding
)
self.embeddings_settings[f"original_embeddings_{idx}"] = (
- text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
- )
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding

inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -794,10 +803,14 @@ def save_embeddings(self, file_path: str):
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
- assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
- self.tokenizers[0]
- ), "Tokenizers should be the same."
- new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+ assert (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+ "Tokenizers should be the same."
+ )
+ new_token_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids]

# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
# text_encoder 1) to keep compatible with the ecosystem.
@@ -819,7 +832,9 @@ def device(self):
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
.to(device=text_encoder.device)
.to(dtype=text_encoder.dtype)
@@ -830,11 +845,15 @@ def retract_embeddings(self):
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]

index_updates = ~index_no_updates
- new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ new_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates]
off_ratio = std_token_embedding / new_embeddings.std()

new_embeddings = new_embeddings * (off_ratio**0.1)
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings


class DreamBoothDataset(Dataset):
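Note (illustrative, not from the PR): retract_embeddings above restores the frozen rows from the saved originals and then gently renormalizes the trained rows so their standard deviation drifts back toward the original table's std; the 0.1 exponent makes this a soft nudge rather than a hard reset. A minimal sketch of that rescaling step, with hypothetical names:

import torch

@torch.no_grad()
def soft_rescale_trained_rows(weight: torch.Tensor, index_updates: torch.Tensor, std_target: float) -> None:
    # Hypothetical helper summarizing the rescaling in retract_embeddings.
    new_rows = weight[index_updates]
    off_ratio = std_target / new_rows.std()  # >1 if the trained rows' spread shrank, <1 if it grew
    weight[index_updates] = new_rows * (off_ratio ** 0.1)  # dampened correction toward std_target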
@@ -1704,7 +1723,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
if args.train_text_encoder:
- text_encoder_one.text_model.embeddings.requires_grad_(True)
+ _te_one = text_encoder_one
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

unet.train()
for step, batch in enumerate(train_dataloader):
[Next file]
@@ -929,19 +929,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

# random initialization of new tokens
- std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+ std_token_embedding = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.std()

print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

- text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
- torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(
+ len(self.train_ids),
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).config.hidden_size,
+ )
.to(device=self.device)
.to(dtype=self.dtype)
* std_token_embedding
)
self.embeddings_settings[f"original_embeddings_{idx}"] = (
- text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
- )
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding

inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -959,10 +968,14 @@ def save_embeddings(self, file_path: str):
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
- assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
- self.tokenizers[0]
- ), "Tokenizers should be the same."
- new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+ assert (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+ "Tokenizers should be the same."
+ )
+ new_token_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids]

# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
# text_encoder 1) to keep compatible with the ecosystem.
@@ -984,7 +997,9 @@ def device(self):
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
.to(device=text_encoder.device)
.to(dtype=text_encoder.dtype)
@@ -995,11 +1010,15 @@ def retract_embeddings(self):
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]

index_updates = ~index_no_updates
- new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ new_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates]
off_ratio = std_token_embedding / new_embeddings.std()

new_embeddings = new_embeddings * (off_ratio**0.1)
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings


class DreamBoothDataset(Dataset):
@@ -2083,8 +2102,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
if args.train_text_encoder:
- accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
- accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+ _te_one = accelerator.unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+ _te_two = accelerator.unwrap_model(text_encoder_two)
+ (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
if pivoted:
7 changes: 4 additions & 3 deletions examples/custom_diffusion/train_custom_diffusion.py
@@ -874,10 +874,11 @@ def main(args):
token_embeds[x] = token_embeds[y]

# Freeze all parameters except for the token embeddings in text encoder
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
params_to_freeze = itertools.chain(
- text_encoder.text_model.encoder.parameters(),
- text_encoder.text_model.final_layer_norm.parameters(),
- text_encoder.text_model.embeddings.position_embedding.parameters(),
+ text_module.encoder.parameters(),
+ text_module.final_layer_norm.parameters(),
+ text_module.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
########################################################
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1691,7 +1691,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
- unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ _te_one = unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -1896,7 +1896,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
- unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ _te_one = unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
6 changes: 4 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1719,8 +1719,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_two.train()

# set top parameter requires_grad = True for gradient checkpointing works
- accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
- accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+ _te_one = accelerator.unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+ _te_two = accelerator.unwrap_model(text_encoder_two)
+ (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
6 changes: 4 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1661,8 +1661,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_two.train()

# set top parameter requires_grad = True for gradient checkpointing works
- accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
- accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+ _te_one = accelerator.unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+ _te_two = accelerator.unwrap_model(text_encoder_two)
+ (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
7 changes: 4 additions & 3 deletions examples/textual_inversion/textual_inversion.py
@@ -702,9 +702,10 @@ def main():
vae.requires_grad_(False)
unet.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
- text_encoder.text_model.encoder.requires_grad_(False)
- text_encoder.text_model.final_layer_norm.requires_grad_(False)
- text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ text_module.encoder.requires_grad_(False)
+ text_module.final_layer_norm.requires_grad_(False)
+ text_module.embeddings.position_embedding.requires_grad_(False)

if args.gradient_checkpointing:
# Keep unet in train mode if we are using gradient checkpointing to save memory.