Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/maxdiffusion/configs/ltx2_3_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,20 @@ enable_lora: False

# Distilled LoRA
# lora_config: {
# lora_model_name_or_path: ["Lightricks/LTX-2"],
# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
# lora_model_name_or_path: ["Lightricks/LTX-2.3"],
# weight_name: ["ltx-2.3-22b-distilled-lora-384.safetensors"],
# adapter_name: ["distilled-lora-384"],
# # placeholder - the real value is mixed per-layer ranks: 32/128/256/384
# # and the loader reads each layer's REAL rank from the LoRA tensor shapes
# rank: [384]
# }

# Standard LoRA
lora_config: {
lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
adapter_name: ["camera-control-dolly-in"],
rank: [32]
lora_model_name_or_path: ["Lightricks/LTX-2.3-22b-IC-LoRA-Colorization"],
weight_name: ["ltx-2.3-22b-ic-lora-colorization-0.9.safetensors"],
adapter_name: ["colorization"],
rank: [128]
}


Expand Down
13 changes: 13 additions & 0 deletions src/maxdiffusion/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,31 +716,37 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
"attn1.to_k": "attn1.to_k",
"attn1.to_v": "attn1.to_v",
"attn1.to_out": "attn1.to_out.0",
"attn1.to_gate_logits": "attn1.to_gate_logits",
# Audio Self Attention (audio_attn1)
"audio_attn1.to_q": "audio_attn1.to_q",
"audio_attn1.to_k": "audio_attn1.to_k",
"audio_attn1.to_v": "audio_attn1.to_v",
"audio_attn1.to_out": "audio_attn1.to_out.0",
"audio_attn1.to_gate_logits": "audio_attn1.to_gate_logits",
# Audio Cross Attention (audio_attn2)
"audio_attn2.to_q": "audio_attn2.to_q",
"audio_attn2.to_k": "audio_attn2.to_k",
"audio_attn2.to_v": "audio_attn2.to_v",
"audio_attn2.to_out": "audio_attn2.to_out.0",
"audio_attn2.to_gate_logits": "audio_attn2.to_gate_logits",
# Cross Attention (attn2)
"attn2.to_q": "attn2.to_q",
"attn2.to_k": "attn2.to_k",
"attn2.to_v": "attn2.to_v",
"attn2.to_out": "attn2.to_out.0",
"attn2.to_gate_logits": "attn2.to_gate_logits",
# Audio to Video Cross Attention
"audio_to_video_attn.to_q": "audio_to_video_attn.to_q",
"audio_to_video_attn.to_k": "audio_to_video_attn.to_k",
"audio_to_video_attn.to_v": "audio_to_video_attn.to_v",
"audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0",
"audio_to_video_attn.to_gate_logits": "audio_to_video_attn.to_gate_logits",
# Video to Audio Cross Attention
"video_to_audio_attn.to_q": "video_to_audio_attn.to_q",
"video_to_audio_attn.to_k": "video_to_audio_attn.to_k",
"video_to_audio_attn.to_v": "video_to_audio_attn.to_v",
"video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0",
"video_to_audio_attn.to_gate_logits": "video_to_audio_attn.to_gate_logits",
# Feed Forward
"ff.net_0": "ff.net.0.proj",
"ff.net_2": "ff.net.2",
Expand Down Expand Up @@ -778,6 +784,13 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
"caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
"audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
"audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
# Prompt-conditioned AdaLN
"prompt_adaln.linear": "diffusion_model.prompt_adaln_single.linear",
"prompt_adaln.emb.timestep_embedder.linear_1": "diffusion_model.prompt_adaln_single.emb.timestep_embedder.linear_1",
"prompt_adaln.emb.timestep_embedder.linear_2": "diffusion_model.prompt_adaln_single.emb.timestep_embedder.linear_2",
"audio_prompt_adaln.linear": "diffusion_model.audio_prompt_adaln_single.linear",
"audio_prompt_adaln.emb.timestep_embedder.linear_1": "diffusion_model.audio_prompt_adaln_single.emb.timestep_embedder.linear_1",
"audio_prompt_adaln.emb.timestep_embedder.linear_2": "diffusion_model.audio_prompt_adaln_single.emb.timestep_embedder.linear_2",
# Connectors
"feature_extractor.linear": "text_embedding_projection.aggregate_embed",
}
Expand Down
34 changes: 21 additions & 13 deletions src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,34 @@ def load_lora_weights(
def translate_fn(nnx_path_str):
return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

h_state_dict = None
if hasattr(pipeline, "transformer") and transformer_weight_name:
if not transformer_weight_name:
max_logging.log("No LoRA weight name provided; skipping LoRA load.")
return pipeline

h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
transformer_state_dict = {}
connector_state_dict = {}
if hasattr(pipeline, "transformer"):
max_logging.log(f"Merging LoRA into transformer with rank={rank}")
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
# Filter state dict for transformer keys to avoid confusing warnings
transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")}
transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model.")}
merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
else:
max_logging.log("transformer not found or no weight name provided for LoRA.")
max_logging.log("transformer not found.")

if hasattr(pipeline, "connectors"):
max_logging.log(f"Merging LoRA into connectors with rank={rank}")
if h_state_dict is None and transformer_weight_name:
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection.")}
merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
else:
max_logging.log("connectors not found.")

if h_state_dict is not None:
# Filter state dict for connector keys to avoid confusing warnings
connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}
merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
else:
max_logging.log("Could not load LoRA state dict for connectors.")
# Warn if there are keys routed to no target.
# the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers
unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict)
if unmatched_keys:
max_logging.log(
f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}"
)

return pipeline
Loading