23 changes: 17 additions & 6 deletions invokeai/backend/model_manager/configs/main.py
@@ -1120,12 +1120,20 @@ def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool:
     (unique to Anima - the LLM Adapter that bridges Qwen3 text encoder to the Cosmos DiT)
     alongside Cosmos Predict2 DiT keys (blocks, t_embedder, x_embedder, final_layer).
 
-    The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`).
+    The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`)
+    or a `model.diffusion_model.` prefix (ComfyUI bundled checkpoint format).
     """
     has_llm_adapter = False
     has_cosmos_dit = False
 
-    # Cosmos DiT key prefixes — support both with and without `net.` prefix
+    # LLM adapter key prefixes — support bare, `net.`, and `model.diffusion_model.` prefixes
+    llm_adapter_prefixes = (
+        "llm_adapter.",
+        "net.llm_adapter.",
+        "model.diffusion_model.llm_adapter.",
+    )
+
+    # Cosmos DiT key prefixes — support bare, `net.`, and `model.diffusion_model.` prefixes
     cosmos_prefixes = (
         "blocks.",
         "t_embedder.",
@@ -1135,16 +1143,19 @@ def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool:
         "net.t_embedder.",
         "net.x_embedder.",
         "net.final_layer.",
+        "model.diffusion_model.blocks.",
+        "model.diffusion_model.t_embedder.",
+        "model.diffusion_model.x_embedder.",
+        "model.diffusion_model.final_layer.",
     )
 
     for key in state_dict.keys():
         if isinstance(key, int):
             continue
-        if key.startswith("llm_adapter.") or key.startswith("net.llm_adapter."):
+        if any(key.startswith(p) for p in llm_adapter_prefixes):
            has_llm_adapter = True
-        for prefix in cosmos_prefixes:
-            if key.startswith(prefix):
-                has_cosmos_dit = True
+        if any(key.startswith(p) for p in cosmos_prefixes):
+            has_cosmos_dit = True
        if has_llm_adapter and has_cosmos_dit:
            return True
 
11 changes: 6 additions & 5 deletions invokeai/backend/model_manager/load/model_loaders/anima.py
@@ -67,10 +67,12 @@ def _load_from_singlefile(
         # Load the state dict from safetensors
         sd = load_file(model_path)
 
-        # Strip the `net.` prefix that all Anima checkpoint keys have
-        # e.g., "net.blocks.0.self_attn.q_proj.weight" -> "blocks.0.self_attn.q_proj.weight"
+        # Handle different checkpoint packaging formats:
+        # - Official format: keys prefixed with `net.` (e.g. `net.blocks.0...`)
+        # - ComfyUI bundled format: transformer keys prefixed with `model.diffusion_model.`
+        #   alongside `first_stage_model.*` (VAE) and `cond_stage_model.*` (text encoder)
         prefix_to_strip = None
-        for prefix in ["net."]:
+        for prefix in ["model.diffusion_model.", "net."]:
             if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
                 prefix_to_strip = prefix
                 break
@@ -80,8 +82,7 @@ def _load_from_singlefile(
             for key, value in sd.items():
                 if isinstance(key, str) and key.startswith(prefix_to_strip):
                     stripped_sd[key[len(prefix_to_strip) :]] = value
-                else:
-                    stripped_sd[key] = value
+            # Skip non-transformer keys from bundled checkpoints (VAE, text encoder)
             sd = stripped_sd
 
         # Create an empty AnimaTransformer with Anima's default architecture parameters
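To make the bundled-checkpoint handling concrete, here is a hypothetical standalone restatement of the stripping behavior above. The helper name and sample keys are illustrative only, not part of the PR:

def strip_prefix(sd: dict[str, object], prefixes: list[str]) -> dict[str, object]:
    # Find the first prefix present in the checkpoint and keep only keys under it.
    # Bundled VAE (`first_stage_model.*`) and text encoder (`cond_stage_model.*`)
    # keys fall outside the prefix and are therefore dropped.
    for prefix in prefixes:
        if any(k.startswith(prefix) for k in sd):
            return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
    return sd

bundled = {
    "model.diffusion_model.blocks.0.self_attn.q_proj.weight": 0,
    "first_stage_model.encoder.downsamples.0.weight": 0,
    "cond_stage_model.qwen3_06b.transformer.model.embed_tokens.weight": 0,
}
assert strip_prefix(bundled, ["model.diffusion_model.", "net."]) == {
    "blocks.0.self_attn.q_proj.weight": 0
}

Note the design change this illustrates: the old code copied non-matching keys through unchanged, while the new code keeps only the transformer keys, which is what a bundled checkpoint requires.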
162 changes: 162 additions & 0 deletions tests/backend/model_manager/configs/test_anima_model_identification.py
@@ -0,0 +1,162 @@
import pytest

from invokeai.backend.model_manager.configs.main import _has_anima_keys


def _make_state_dict(prefixes: list[str], keys: list[str]) -> dict[str, object]:
    """Build a minimal fake state dict with the given prefixes applied to the given keys."""
    return {f"{prefix}{key}": None for prefix in prefixes for key in keys}


# Minimal keys that satisfy both llm_adapter and cosmos DiT requirements
ANIMA_LLM_ADAPTER_KEYS = ["llm_adapter.blocks.0.cross_attn.k_norm.weight"]
ANIMA_COSMOS_DIT_KEYS = [
    "blocks.0.adaln_modulation_cross_attn.1.weight",
    "t_embedder.1.linear_1.weight",
    "x_embedder.proj.1.weight",
    "final_layer.adaln_modulation.1.weight",
]


class TestHasAnimaKeys:
    """Tests for _has_anima_keys heuristic used during model identification."""

    def test_bare_keys(self):
        """Bare keys (no prefix) should be recognized."""
        sd = _make_state_dict([""], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True

    def test_net_prefix(self):
        """Official format with `net.` prefix should be recognized."""
        sd = _make_state_dict(["net."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True

    def test_comfyui_bundled_prefix(self):
        """ComfyUI bundled format with `model.diffusion_model.` prefix should be recognized."""
        sd = _make_state_dict(["model.diffusion_model."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True

    def test_comfyui_bundled_with_extra_keys(self):
        """Bundled checkpoint with VAE and text encoder keys should still be recognized."""
        sd = _make_state_dict(["model.diffusion_model."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        # Add bundled VAE and text encoder keys (should not interfere)
        sd["first_stage_model.conv1.weight"] = None
        sd["first_stage_model.encoder.downsamples.0.weight"] = None
        sd["cond_stage_model.qwen3_06b.transformer.model.embed_tokens.weight"] = None
        assert _has_anima_keys(sd) is True

    def test_missing_llm_adapter_keys(self):
        """Should not match if llm_adapter keys are absent."""
        sd = _make_state_dict([""], ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is False

    def test_missing_cosmos_dit_keys(self):
        """Should not match if Cosmos DiT keys are absent."""
        sd = _make_state_dict([""], ANIMA_LLM_ADAPTER_KEYS)
        assert _has_anima_keys(sd) is False

    def test_empty_state_dict(self):
        """Empty state dict should not match."""
        assert _has_anima_keys({}) is False

    def test_unrelated_keys(self):
        """State dict with unrelated keys should not match."""
        sd = {
            "model.diffusion_model.input_blocks.0.0.weight": None,
            "model.diffusion_model.output_blocks.0.0.weight": None,
            "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": None,
        }
        assert _has_anima_keys(sd) is False

    @pytest.mark.parametrize(
        "prefix",
        ["", "net.", "model.diffusion_model."],
    )
    def test_all_prefixes_parametrized(self, prefix: str):
        """All supported prefix formats should be recognized."""
        sd = _make_state_dict([prefix], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True


class TestAnimaDoesNotConflictWithOtherModels:
    """Verify that _has_anima_keys does not false-positive on similar model architectures."""

    def test_flux_bundled_checkpoint(self):
        """FLUX bundled checkpoints use double_blocks/single_blocks, not blocks — should not match."""
        sd = {
            "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": None,
            "model.diffusion_model.double_blocks.0.img_attn.proj.weight": None,
            "model.diffusion_model.single_blocks.0.linear1.weight": None,
            "model.diffusion_model.context_embedder.weight": None,
            "model.diffusion_model.img_in.weight": None,
        }
        assert _has_anima_keys(sd) is False

    def test_sd1_bundled_checkpoint(self):
        """SD1/SD2/SDXL bundled checkpoints use input_blocks/output_blocks — should not match."""
        sd = {
            "model.diffusion_model.input_blocks.0.0.weight": None,
            "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": None,
            "model.diffusion_model.output_blocks.0.0.weight": None,
            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": None,
            "first_stage_model.encoder.down.0.block.0.conv1.weight": None,
            "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": None,
        }
        assert _has_anima_keys(sd) is False

    def test_raw_cosmos_dit_without_llm_adapter(self):
        """A raw Cosmos Predict2 DiT (without Anima's LLM adapter) should not match."""
        sd = {
            "blocks.0.adaln_modulation_cross_attn.1.weight": None,
            "blocks.0.self_attn.q_proj.weight": None,
            "t_embedder.1.linear_1.weight": None,
            "x_embedder.proj.1.weight": None,
            "final_layer.adaln_modulation.1.weight": None,
        }
        assert _has_anima_keys(sd) is False

    def test_z_image_checkpoint(self):
        """Z-Image uses blocks.* but with cap_embedder/context_refiner — should not match."""
        sd = {
            "model.diffusion_model.blocks.0.attn.to_q.weight": None,
            "model.diffusion_model.blocks.0.attn.to_k.weight": None,
            "model.diffusion_model.cap_embedder.0.weight": None,
            "model.diffusion_model.context_refiner.blocks.0.weight": None,
            "model.diffusion_model.t_embedder.mlp.0.weight": None,
            "model.diffusion_model.x_embedder.proj.weight": None,
        }
        # Z-Image has blocks/t_embedder/x_embedder but NOT llm_adapter
        assert _has_anima_keys(sd) is False

    def test_qwen_image_checkpoint(self):
        """QwenImage uses txt_in/txt_norm/img_in — should not match."""
        sd = {
            "txt_in.weight": None,
            "txt_norm.weight": None,
            "img_in.weight": None,
            "double_blocks.0.img_attn.proj.weight": None,
            "single_blocks.0.linear1.weight": None,
        }
        assert _has_anima_keys(sd) is False

    def test_flux_lora_does_not_match(self):
        """FLUX LoRA weights should not match as Anima."""
        sd = {
            "double_blocks.0.img_attn.proj.lora_down.weight": None,
            "double_blocks.0.img_attn.proj.lora_up.weight": None,
            "single_blocks.0.linear1.lora_down.weight": None,
        }
        assert _has_anima_keys(sd) is False

    def test_cosmos_dit_bundled_without_llm_adapter(self):
        """Bundled Cosmos DiT (model.diffusion_model. prefix) but no llm_adapter — should not match."""
        sd = {
            "model.diffusion_model.blocks.0.self_attn.q_proj.weight": None,
            "model.diffusion_model.t_embedder.1.linear_1.weight": None,
            "model.diffusion_model.x_embedder.proj.1.weight": None,
            "model.diffusion_model.final_layer.adaln_modulation.1.weight": None,
            "first_stage_model.encoder.downsamples.0.weight": None,
            "cond_stage_model.transformer.model.embed_tokens.weight": None,
        }
        # Has all the Cosmos DiT keys but missing llm_adapter — not Anima
        assert _has_anima_keys(sd) is False
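The new tests are self-contained (they build fake state dicts rather than loading real checkpoints), so they should run directly with pytest, e.g.:

pytest tests/backend/model_manager/configs/test_anima_model_identification.py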