Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/configs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Main_Checkpoint_Anima_Config,
Main_Checkpoint_Flux2_Config,
Main_Checkpoint_FLUX_Config,
Main_Checkpoint_QwenImage_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Expand Down Expand Up @@ -183,6 +184,7 @@
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()],
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
Annotated[Main_Checkpoint_QwenImage_Config, Main_Checkpoint_QwenImage_Config.get_tag()],
Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
Annotated[Main_Checkpoint_Anima_Config, Main_Checkpoint_Anima_Config.get_tag()],
# Main (Pipeline) - quantized formats
Expand Down
96 changes: 76 additions & 20 deletions invokeai/backend/model_manager/configs/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
from abc import ABC
from pathlib import Path
from typing import Any, Literal, Self

from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -1329,20 +1331,88 @@ def _get_qwen_image_variant(cls, mod: ModelOnDisk) -> QwenImageVariantType:
return QwenImageVariantType.Generate


# ComfyUI single-file checkpoints prefix every transformer key with one of these.
# The loaders strip them before instantiating the model (see `_strip_comfyui_prefix`
# in the qwen_image loader); detection must strip them too so the two paths agree.
_COMFYUI_KEY_PREFIXES = ("model.diffusion_model.", "diffusion_model.")


def _strip_comfyui_key_prefix(key: str) -> str:
    """Return `key` with a leading ComfyUI prefix removed, if it has one.

    The recognised prefixes are `model.diffusion_model.` and `diffusion_model.`.
    At most one prefix is removed; a key without either prefix is returned unchanged.
    """
    # Find the first prefix the key starts with; the empty string means "no prefix",
    # which makes the slice below a no-op.
    matched = next((p for p in _COMFYUI_KEY_PREFIXES if key.startswith(p)), "")
    return key[len(matched) :]


def _has_qwen_image_keys(state_dict: dict[str | int, Any]) -> bool:
    """Check if state dict contains Qwen Image Edit transformer keys.

    Qwen Image Edit uses 'txt_in' and 'txt_norm' instead of 'context_embedder' (FLUX).
    This distinguishes it from FLUX and other architectures. ComfyUI-style prefixes are
    stripped first so prefixed checkpoints are detected and reach the loader.

    Args:
        state_dict: Mapping of tensor names to values. Non-string keys are ignored.

    Returns:
        True when, after prefix stripping, there is at least one key under each of
        `txt_in.`, `txt_norm.` and `img_in.`, and no key contains `context_embedder`.
    """
    # Normalise the keys once; every check below works on the stripped names.
    keys = [_strip_comfyui_key_prefix(k) for k in state_dict if isinstance(k, str)]

    has_txt_in = any(k.startswith("txt_in.") for k in keys)
    has_txt_norm = any(k.startswith("txt_norm.") for k in keys)
    has_img_in = any(k.startswith("img_in.") for k in keys)
    # Must NOT have context_embedder (which would indicate FLUX)
    has_context_embedder = any("context_embedder" in k for k in keys)
    return has_txt_in and has_txt_norm and has_img_in and not has_context_embedder


# Matches "edit" as a standalone token (delimited by start/end or any non-alphanumeric
# separator), so `qwen_image_edit_2509` matches but `credited` / `edited` / `unedited` do not.
_EDIT_TOKEN_RE = re.compile(r"(?:^|[^a-z0-9])edit(?:[^a-z0-9]|$)")


def _infer_qwen_image_variant(sd: dict[str | int, Any], path: Path) -> QwenImageVariantType:
    """Decide between the Edit and Generate variants of a Qwen Image model.

    Two signals are consulted, in order:

    1. The state dict contains an `__index_timestep_zero__` key, either bare or behind a
       ComfyUI prefix. Edit-variant models include this tensor (used by the
       `zero_cond_t` dual-modulation path).
    2. The lower-cased file stem contains "edit" as a standalone token. This covers
       converters that do not emit the marker tensor.

    Args:
        sd: The model's state dict. Non-string keys are ignored by the prefixed lookup.
        path: Path of the model file; only its stem is inspected.

    Returns:
        `QwenImageVariantType.Edit` if either signal fires, otherwise
        `QwenImageVariantType.Generate`.
    """
    marker = "__index_timestep_zero__"
    # Lazily yields each string key with any ComfyUI prefix removed.
    unprefixed_keys = (_strip_comfyui_key_prefix(name) for name in sd if isinstance(name, str))
    if marker in sd or marker in unprefixed_keys:
        return QwenImageVariantType.Edit

    filename_has_edit_token = _EDIT_TOKEN_RE.search(path.stem.lower()) is not None
    return QwenImageVariantType.Edit if filename_has_edit_token else QwenImageVariantType.Generate


class Main_Checkpoint_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for Qwen Image single-file checkpoint models (safetensors, etc).

    Covers both raw bf16/fp16 checkpoints and ComfyUI-style fp8_scaled checkpoints.
    The loader dequantizes fp8 weights back to bf16 at load time; the
    `default_settings.fp8_storage` toggle can then optionally re-cast to fp8 for
    VRAM savings.
    """

    # Discriminator fields: always the QwenImage base in checkpoint format.
    base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    # Edit vs. Generate; filled in by `from_model_on_disk` when not overridden.
    variant: QwenImageVariantType | None = Field(default=None)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for `mod` if it is a non-GGUF Qwen Image checkpoint.

        Args:
            mod: The model on disk being probed.
            override_fields: Caller-supplied field values. A `variant` entry, if present,
                is popped from this dict and takes precedence over inference.

        Returns:
            A config instance built from the remaining override fields plus the variant.

        Raises:
            NotAMatchError: If the state dict lacks the Qwen Image keys, or if it holds
                GGML tensors. The two `raise_*` helpers called first may also raise;
                their behaviour is defined outside this block.
        """
        raise_if_not_file(mod)
        raise_for_override_fields(cls, override_fields)

        state_dict = mod.load_state_dict()

        if not _has_qwen_image_keys(state_dict):
            raise NotAMatchError("state dict does not look like a Qwen Image model")
        if _has_ggml_tensors(state_dict):
            raise NotAMatchError("state dict looks like GGUF quantized")

        # An explicit (truthy) override wins; otherwise infer from the weights / filename.
        variant = override_fields.pop("variant", None)
        if not variant:
            variant = _infer_qwen_image_variant(state_dict, mod.path)

        return cls(**override_fields, variant=variant)


class Main_GGUF_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for GGUF-quantized Qwen Image transformer models."""

Expand All @@ -1364,21 +1434,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
if not _has_ggml_tensors(sd):
raise NotAMatchError("state dict does not look like GGUF quantized")

# Infer variant from the state dict if not explicitly provided.
# The Edit variant includes an extra tensor `__index_timestep_zero__` (used by the
# `zero_cond_t` dual-modulation path in diffusers' QwenImageTransformer2DModel).
# If the marker tensor is missing, fall back to the filename heuristic since older
# or alternate GGUF converters may not emit it.
explicit_variant = override_fields.pop("variant", None)
if explicit_variant is None:
if "__index_timestep_zero__" in sd:
explicit_variant = QwenImageVariantType.Edit
else:
filename = mod.path.stem.lower()
if "edit" in filename:
explicit_variant = QwenImageVariantType.Edit
else:
explicit_variant = QwenImageVariantType.Generate
explicit_variant = override_fields.pop("variant", None) or _infer_qwen_image_variant(sd, mod.path)

return cls(**override_fields, variant=explicit_variant)

Expand Down
36 changes: 32 additions & 4 deletions invokeai/backend/model_manager/configs/qwen3_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
return any(isinstance(v, GGMLTensor) for v in state_dict.values())


def _has_qwen_vl_visual_tower(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict bundles a Qwen2.5-VL / Qwen2-VL vision tower.

Qwen-VL encoders ship the visual tower (`visual.blocks.*`, `visual.patch_embed.*`)
alongside the language model, whereas a text-only Qwen3 encoder never does. A Qwen-VL
file otherwise satisfies the Qwen3 key heuristic (it has `model.layers.*` /
`model.embed_tokens.weight` too), so without this check it matches *both* the Qwen3 and
the QwenVLEncoder configs and the tiebreak can misroute it to Qwen3. We use it to keep
the two mutually exclusive.
"""
for key in state_dict.keys():
if isinstance(key, str) and (key.startswith("visual.blocks.") or key.startswith("visual.patch_embed.")):
return True
return False


def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
"""Determine Qwen3 variant (0.6B, 4B, or 8B) from state dict based on hidden_size.

Expand Down Expand Up @@ -137,9 +153,15 @@ def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:

@classmethod
def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
    """Raise unless `mod` looks like a text-only Qwen3 encoder.

    Args:
        mod: The model on disk being probed.

    Raises:
        NotAMatchError: If the state dict lacks the Qwen3 keys, or if it bundles a
            Qwen-VL visual tower.
    """
    # Load the state dict once and reuse it for both checks.
    state_dict = mod.load_state_dict()
    if not _has_qwen3_keys(state_dict):
        raise NotAMatchError("state dict does not look like a Qwen3 model")
    # Reject Qwen2.5-VL / Qwen2-VL encoders: they carry a visual tower and must be
    # classified as QwenVLEncoder (text-only Qwen3 encoders never have one).
    if _has_qwen_vl_visual_tower(state_dict):
        raise NotAMatchError(
            "state dict bundles a Qwen-VL visual tower; this is a Qwen-VL encoder, not a text-only Qwen3 encoder"
        )

@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
Expand Down Expand Up @@ -272,9 +294,15 @@ def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:

@classmethod
def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
    """Raise unless `mod` looks like a text-only Qwen3 encoder.

    Args:
        mod: The model on disk being probed.

    Raises:
        NotAMatchError: If the state dict lacks the Qwen3 keys, or if it bundles a
            Qwen-VL visual tower.
    """
    # Load the state dict once and reuse it for both checks.
    state_dict = mod.load_state_dict()
    if not _has_qwen3_keys(state_dict):
        raise NotAMatchError("state dict does not look like a Qwen3 model")
    # Reject Qwen2.5-VL / Qwen2-VL encoders: they carry a visual tower and must be
    # classified as QwenVLEncoder (text-only Qwen3 encoders never have one).
    if _has_qwen_vl_visual_tower(state_dict):
        raise NotAMatchError(
            "state dict bundles a Qwen-VL visual tower; this is a Qwen-VL encoder, not a text-only Qwen3 encoder"
        )

@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
Expand Down
Loading
Loading