diff --git a/src/maxdiffusion/configs/ltx2_3_video.yml b/src/maxdiffusion/configs/ltx2_3_video.yml index a57106231..a6275cfae 100644 --- a/src/maxdiffusion/configs/ltx2_3_video.yml +++ b/src/maxdiffusion/configs/ltx2_3_video.yml @@ -142,18 +142,20 @@ enable_lora: False # Distilled LoRA # lora_config: { -# lora_model_name_or_path: ["Lightricks/LTX-2"], -# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"], +# lora_model_name_or_path: ["Lightricks/LTX-2.3"], +# weight_name: ["ltx-2.3-22b-distilled-lora-384.safetensors"], # adapter_name: ["distilled-lora-384"], +# # placeholder - the real value is mixed per-layer ranks: 32/128/256/384 +# # and the loader reads each layer's REAL rank from the LoRA tensor shapes # rank: [384] # } # Standard LoRA lora_config: { - lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"], - weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"], - adapter_name: ["camera-control-dolly-in"], - rank: [32] + lora_model_name_or_path: ["Lightricks/LTX-2.3-22b-IC-LoRA-Colorization"], + weight_name: ["ltx-2.3-22b-ic-lora-colorization-0.9.safetensors"], + adapter_name: ["colorization"], + rank: [128] } diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index ca0371b76..97fc63dcd 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -716,31 +716,37 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): "attn1.to_k": "attn1.to_k", "attn1.to_v": "attn1.to_v", "attn1.to_out": "attn1.to_out.0", + "attn1.to_gate_logits": "attn1.to_gate_logits", # Audio Self Attention (audio_attn1) "audio_attn1.to_q": "audio_attn1.to_q", "audio_attn1.to_k": "audio_attn1.to_k", "audio_attn1.to_v": "audio_attn1.to_v", "audio_attn1.to_out": "audio_attn1.to_out.0", + "audio_attn1.to_gate_logits": "audio_attn1.to_gate_logits", # Audio Cross Attention (audio_attn2) "audio_attn2.to_q": "audio_attn2.to_q", "audio_attn2.to_k": "audio_attn2.to_k", "audio_attn2.to_v": "audio_attn2.to_v", "audio_attn2.to_out": "audio_attn2.to_out.0", + "audio_attn2.to_gate_logits": "audio_attn2.to_gate_logits", # Cross Attention (attn2) "attn2.to_q": "attn2.to_q", "attn2.to_k": "attn2.to_k", "attn2.to_v": "attn2.to_v", "attn2.to_out": "attn2.to_out.0", + "attn2.to_gate_logits": "attn2.to_gate_logits", # Audio to Video Cross Attention "audio_to_video_attn.to_q": "audio_to_video_attn.to_q", "audio_to_video_attn.to_k": "audio_to_video_attn.to_k", "audio_to_video_attn.to_v": "audio_to_video_attn.to_v", "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0", + "audio_to_video_attn.to_gate_logits": "audio_to_video_attn.to_gate_logits", # Video to Audio Cross Attention "video_to_audio_attn.to_q": "video_to_audio_attn.to_q", "video_to_audio_attn.to_k": "video_to_audio_attn.to_k", "video_to_audio_attn.to_v": "video_to_audio_attn.to_v", "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0", + "video_to_audio_attn.to_gate_logits": "video_to_audio_attn.to_gate_logits", # Feed Forward "ff.net_0": "ff.net.0.proj", "ff.net_2": "ff.net.2", @@ -778,6 +784,13 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2", "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1", "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2", + # Prompt-conditioned AdaLN + "prompt_adaln.linear": "diffusion_model.prompt_adaln_single.linear", + "prompt_adaln.emb.timestep_embedder.linear_1": "diffusion_model.prompt_adaln_single.emb.timestep_embedder.linear_1", + "prompt_adaln.emb.timestep_embedder.linear_2": "diffusion_model.prompt_adaln_single.emb.timestep_embedder.linear_2", + "audio_prompt_adaln.linear": "diffusion_model.audio_prompt_adaln_single.linear", + "audio_prompt_adaln.emb.timestep_embedder.linear_1": "diffusion_model.audio_prompt_adaln_single.emb.timestep_embedder.linear_1", + "audio_prompt_adaln.emb.timestep_embedder.linear_2": "diffusion_model.audio_prompt_adaln_single.emb.timestep_embedder.linear_2", # Connectors "feature_extractor.linear": "text_embedding_projection.aggregate_embed", } diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py index 247b3ba2e..1fad541c6 100644 --- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py +++ b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py @@ -50,26 +50,34 @@ def load_lora_weights( def translate_fn(nnx_path_str): return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) - h_state_dict = None - if hasattr(pipeline, "transformer") and transformer_weight_name: + if not transformer_weight_name: + max_logging.log("No LoRA weight name provided; skipping LoRA load.") + return pipeline + + h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) + transformer_state_dict = {} + connector_state_dict = {} + if hasattr(pipeline, "transformer"): max_logging.log(f"Merging LoRA into transformer with rank={rank}") - h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) # Filter state dict for transformer keys to avoid confusing warnings - transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")} + transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model.")} merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype) else: - max_logging.log("transformer not found or no weight name provided for LoRA.") + max_logging.log("transformer not found.") if hasattr(pipeline, "connectors"): max_logging.log(f"Merging LoRA into connectors with rank={rank}") - if h_state_dict is None and transformer_weight_name: - h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) + connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection.")} + merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype) + else: + max_logging.log("connectors not found.") - if h_state_dict is not None: - # Filter state dict for connector keys to avoid confusing warnings - connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")} - merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype) - else: - max_logging.log("Could not load LoRA state dict for connectors.") + # Warn if there are keys routed to no target. + # the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers + unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict) + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}" + ) return pipeline