From 564b282d177887886ac220b92571a4725cbb422e Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sun, 3 May 2026 03:23:04 -0400
Subject: [PATCH] fix(mm): support ComfyUI bundled checkpoint format for Anima
 model identification

Anima finetunes packaged in the ComfyUI bundled checkpoint format use
`model.diffusion_model.*`-prefixed keys instead of bare or `net.*`-prefixed
keys. Update the probe and loader to recognize and handle this format; the
supported key layouts are illustrated below.
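
For reference, the supported key layouts (illustrative key names; only the
prefix differs, and the ComfyUI bundled layout additionally carries
`first_stage_model.*` VAE and `cond_stage_model.*` text encoder weights):

    blocks.0.self_attn.q_proj.weight                        (bare)
    net.blocks.0.self_attn.q_proj.weight                    (official `net.` prefix)
    model.diffusion_model.blocks.0.self_attn.q_proj.weight  (ComfyUI bundled)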

Co-Authored-By: Claude Opus 4.6
---
 .../backend/model_manager/configs/main.py          |  23 ++-
 .../model_manager/load/model_loaders/anima.py      |  11 +-
 .../test_anima_model_identification.py             | 162 ++++++++++++++++++
 3 files changed, 185 insertions(+), 11 deletions(-)
 create mode 100644 tests/backend/model_manager/configs/test_anima_model_identification.py

diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py
index da5bc5eed36..2d002d68dc5 100644
--- a/invokeai/backend/model_manager/configs/main.py
+++ b/invokeai/backend/model_manager/configs/main.py
@@ -1120,12 +1120,20 @@ def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool:
     (unique to Anima - the LLM Adapter that bridges Qwen3 text encoder to the Cosmos DiT)
     alongside Cosmos Predict2 DiT keys (blocks, t_embedder, x_embedder, final_layer).
 
-    The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`).
+    The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`)
+    or a `model.diffusion_model.` prefix (ComfyUI bundled checkpoint format).
     """
     has_llm_adapter = False
     has_cosmos_dit = False
 
-    # Cosmos DiT key prefixes — support both with and without `net.` prefix
+    # LLM adapter key prefixes — support bare, `net.`, and `model.diffusion_model.` prefixes
+    llm_adapter_prefixes = (
+        "llm_adapter.",
+        "net.llm_adapter.",
+        "model.diffusion_model.llm_adapter.",
+    )
+
+    # Cosmos DiT key prefixes — support bare, `net.`, and `model.diffusion_model.` prefixes
     cosmos_prefixes = (
         "blocks.",
         "t_embedder.",
@@ -1135,16 +1143,19 @@ def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool:
         "net.t_embedder.",
         "net.x_embedder.",
         "net.final_layer.",
+        "model.diffusion_model.blocks.",
+        "model.diffusion_model.t_embedder.",
+        "model.diffusion_model.x_embedder.",
+        "model.diffusion_model.final_layer.",
     )
 
     for key in state_dict.keys():
         if isinstance(key, int):
             continue
-        if key.startswith("llm_adapter.") or key.startswith("net.llm_adapter."):
+        if any(key.startswith(p) for p in llm_adapter_prefixes):
             has_llm_adapter = True
-        for prefix in cosmos_prefixes:
-            if key.startswith(prefix):
-                has_cosmos_dit = True
+        if any(key.startswith(p) for p in cosmos_prefixes):
+            has_cosmos_dit = True
 
         if has_llm_adapter and has_cosmos_dit:
             return True
diff --git a/invokeai/backend/model_manager/load/model_loaders/anima.py b/invokeai/backend/model_manager/load/model_loaders/anima.py
index 6549c220a86..05da7c26038 100644
--- a/invokeai/backend/model_manager/load/model_loaders/anima.py
+++ b/invokeai/backend/model_manager/load/model_loaders/anima.py
@@ -67,10 +67,12 @@ def _load_from_singlefile(
     # Load the state dict from safetensors
     sd = load_file(model_path)
 
-    # Strip the `net.` prefix that all Anima checkpoint keys have
-    # e.g., "net.blocks.0.self_attn.q_proj.weight" -> "blocks.0.self_attn.q_proj.weight"
+    # Handle different checkpoint packaging formats:
+    # - Official format: keys prefixed with `net.` (e.g. `net.blocks.0...`)
+    # - ComfyUI bundled format: transformer keys prefixed with `model.diffusion_model.`
+    #   alongside `first_stage_model.*` (VAE) and `cond_stage_model.*` (text encoder)
     prefix_to_strip = None
-    for prefix in ["net."]:
+    for prefix in ["model.diffusion_model.", "net."]:
         if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
             prefix_to_strip = prefix
             break
@@ -80,8 +82,7 @@
         for key, value in sd.items():
             if isinstance(key, str) and key.startswith(prefix_to_strip):
                 stripped_sd[key[len(prefix_to_strip) :]] = value
-            else:
-                stripped_sd[key] = value
+            # Skip non-transformer keys from bundled checkpoints (VAE, text encoder)
         sd = stripped_sd
 
     # Create an empty AnimaTransformer with Anima's default architecture parameters
diff --git a/tests/backend/model_manager/configs/test_anima_model_identification.py b/tests/backend/model_manager/configs/test_anima_model_identification.py
new file mode 100644
index 00000000000..04bfea2dca9
--- /dev/null
+++ b/tests/backend/model_manager/configs/test_anima_model_identification.py
@@ -0,0 +1,162 @@
+import pytest
+
+from invokeai.backend.model_manager.configs.main import _has_anima_keys
+
+
+def _make_state_dict(prefixes: list[str], keys: list[str]) -> dict[str, object]:
+    """Build a minimal fake state dict with the given prefixes applied to the given keys."""
+    return {f"{prefix}{key}": None for prefix in prefixes for key in keys}
+
+
+# Minimal keys that satisfy both llm_adapter and cosmos DiT requirements
+ANIMA_LLM_ADAPTER_KEYS = ["llm_adapter.blocks.0.cross_attn.k_norm.weight"]
+ANIMA_COSMOS_DIT_KEYS = [
+    "blocks.0.adaln_modulation_cross_attn.1.weight",
+    "t_embedder.1.linear_1.weight",
+    "x_embedder.proj.1.weight",
+    "final_layer.adaln_modulation.1.weight",
+]
+
+
+class TestHasAnimaKeys:
+    """Tests for _has_anima_keys heuristic used during model identification."""
+
+    def test_bare_keys(self):
+        """Bare keys (no prefix) should be recognized."""
+        sd = _make_state_dict([""], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
+        assert _has_anima_keys(sd) is True
+
+    def test_net_prefix(self):
+        """Official format with `net.` prefix should be recognized."""
+        sd = _make_state_dict(["net."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
+        assert _has_anima_keys(sd) is True
+
+    def test_comfyui_bundled_prefix(self):
+        """ComfyUI bundled format with `model.diffusion_model.` prefix should be recognized."""
+        sd = _make_state_dict(["model.diffusion_model."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
+        assert _has_anima_keys(sd) is True
+
+    def test_comfyui_bundled_with_extra_keys(self):
+        """Bundled checkpoint with VAE and text encoder keys should still be recognized."""
+        sd = _make_state_dict(["model.diffusion_model."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
+        # Add bundled VAE and text encoder keys (should not interfere)
+        sd["first_stage_model.conv1.weight"] = None
+        sd["first_stage_model.encoder.downsamples.0.weight"] = None
+        sd["cond_stage_model.qwen3_06b.transformer.model.embed_tokens.weight"] = None
+        assert _has_anima_keys(sd) is True
+
+    def test_missing_llm_adapter_keys(self):
+        """Should not match if llm_adapter keys are absent."""
+        sd = _make_state_dict([""], ANIMA_COSMOS_DIT_KEYS)
+        assert _has_anima_keys(sd) is False
+
+    def test_missing_cosmos_dit_keys(self):
+        """Should not match if Cosmos DiT keys are absent."""
+        sd = _make_state_dict([""], ANIMA_LLM_ADAPTER_KEYS)
+        assert _has_anima_keys(sd) is False
+
+    def test_empty_state_dict(self):
+        """Empty state dict should not match."""
match.""" + assert _has_anima_keys({}) is False + + def test_unrelated_keys(self): + """State dict with unrelated keys should not match.""" + sd = { + "model.diffusion_model.input_blocks.0.0.weight": None, + "model.diffusion_model.output_blocks.0.0.weight": None, + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": None, + } + assert _has_anima_keys(sd) is False + + @pytest.mark.parametrize( + "prefix", + ["", "net.", "model.diffusion_model."], + ) + def test_all_prefixes_parametrized(self, prefix: str): + """All supported prefix formats should be recognized.""" + sd = _make_state_dict([prefix], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS) + assert _has_anima_keys(sd) is True + + +class TestAnimaDoesNotConflictWithOtherModels: + """Verify that _has_anima_keys does not false-positive on similar model architectures.""" + + def test_flux_bundled_checkpoint(self): + """FLUX bundled checkpoints use double_blocks/single_blocks, not blocks — should not match.""" + sd = { + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": None, + "model.diffusion_model.double_blocks.0.img_attn.proj.weight": None, + "model.diffusion_model.single_blocks.0.linear1.weight": None, + "model.diffusion_model.context_embedder.weight": None, + "model.diffusion_model.img_in.weight": None, + } + assert _has_anima_keys(sd) is False + + def test_sd1_bundled_checkpoint(self): + """SD1/SD2/SDXL bundled checkpoints use input_blocks/output_blocks — should not match.""" + sd = { + "model.diffusion_model.input_blocks.0.0.weight": None, + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": None, + "model.diffusion_model.output_blocks.0.0.weight": None, + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": None, + "first_stage_model.encoder.down.0.block.0.conv1.weight": None, + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": None, + } + assert _has_anima_keys(sd) is False + + def test_raw_cosmos_dit_without_llm_adapter(self): + """A raw Cosmos Predict2 DiT (without Anima's LLM adapter) should not match.""" + sd = { + "blocks.0.adaln_modulation_cross_attn.1.weight": None, + "blocks.0.self_attn.q_proj.weight": None, + "t_embedder.1.linear_1.weight": None, + "x_embedder.proj.1.weight": None, + "final_layer.adaln_modulation.1.weight": None, + } + assert _has_anima_keys(sd) is False + + def test_z_image_checkpoint(self): + """Z-Image uses blocks.* but with cap_embedder/context_refiner — should not match.""" + sd = { + "model.diffusion_model.blocks.0.attn.to_q.weight": None, + "model.diffusion_model.blocks.0.attn.to_k.weight": None, + "model.diffusion_model.cap_embedder.0.weight": None, + "model.diffusion_model.context_refiner.blocks.0.weight": None, + "model.diffusion_model.t_embedder.mlp.0.weight": None, + "model.diffusion_model.x_embedder.proj.weight": None, + } + # Z-Image has blocks/t_embedder/x_embedder but NOT llm_adapter + assert _has_anima_keys(sd) is False + + def test_qwen_image_checkpoint(self): + """QwenImage uses txt_in/txt_norm/img_in — should not match.""" + sd = { + "txt_in.weight": None, + "txt_norm.weight": None, + "img_in.weight": None, + "double_blocks.0.img_attn.proj.weight": None, + "single_blocks.0.linear1.weight": None, + } + assert _has_anima_keys(sd) is False + + def test_flux_lora_does_not_match(self): + """FLUX LoRA weights should not match as Anima.""" + sd = { + "double_blocks.0.img_attn.proj.lora_down.weight": None, + "double_blocks.0.img_attn.proj.lora_up.weight": None, + 
"single_blocks.0.linear1.lora_down.weight": None, + } + assert _has_anima_keys(sd) is False + + def test_cosmos_dit_bundled_without_llm_adapter(self): + """Bundled Cosmos DiT (model.diffusion_model. prefix) but no llm_adapter — should not match.""" + sd = { + "model.diffusion_model.blocks.0.self_attn.q_proj.weight": None, + "model.diffusion_model.t_embedder.1.linear_1.weight": None, + "model.diffusion_model.x_embedder.proj.1.weight": None, + "model.diffusion_model.final_layer.adaln_modulation.1.weight": None, + "first_stage_model.encoder.downsamples.0.weight": None, + "cond_stage_model.transformer.model.embed_tokens.weight": None, + } + # Has all the Cosmos DiT keys but missing llm_adapter — not Anima + assert _has_anima_keys(sd) is False