src/MaxText/utils/ckpt_conversion/utils/param_mapping.py (351 additions, 5 deletions)
@@ -959,12 +959,138 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
)

# Add "thinker." prefix to text mapping values
def add_prefix_recursive(value):
"""Recursively add 'thinker.' prefix to strings, handling nested lists."""
if isinstance(value, list):
return [add_prefix_recursive(v) for v in value]
else:
return f"thinker.{value}"

for key, value in text_mapping.items():
text_mapping[key] = add_prefix_recursive(value)
mapping.update(text_mapping)

# Vision mapping
vision_config = config["thinker_config"]["vision_config"]
n_vision_layers = vision_config["depth"]

# Vision patch embedding
mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = (
"thinker.visual.patch_embed.proj.weight"
)
mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-bias"] = (
"thinker.visual.patch_embed.proj.bias"
)

# Vision positional embedding
mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-pos_embed_interpolate-pos_embed"] = (
"thinker.visual.pos_embed.weight"
)

# Vision blocks (n_vision_layers from vision_config["depth"]; 27 in this config)
for i in range(n_vision_layers):
prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-blocks_{i}"
hf_prefix = f"thinker.visual.blocks.{i}"

# Layer norms
mapping[f"{prefix}-ln1-scale"] = f"{hf_prefix}.norm1.weight"
mapping[f"{prefix}-ln1-bias"] = f"{hf_prefix}.norm1.bias"
mapping[f"{prefix}-ln2-scale"] = f"{hf_prefix}.norm2.weight"
mapping[f"{prefix}-ln2-bias"] = f"{hf_prefix}.norm2.bias"

# Attention (HF uses a fused QKV projection; MaxText keeps separate Q/K/V).
# The split/fusion is handled in the hook functions.
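# Layout assumption, consistent with the split hooks below: qkv.weight has shape
# (3 * hidden_size, hidden_size) with rows stacked as [Q; K; V], so all three MaxText
# kernels here point at the same HF tensor and each hook slices out its own third.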
mapping[f"{prefix}-attn-attn-query-kernel"] = f"{hf_prefix}.attn.qkv.weight"
mapping[f"{prefix}-attn-attn-query-bias"] = f"{hf_prefix}.attn.qkv.bias"
mapping[f"{prefix}-attn-attn-key-kernel"] = f"{hf_prefix}.attn.qkv.weight"
mapping[f"{prefix}-attn-attn-key-bias"] = f"{hf_prefix}.attn.qkv.bias"
mapping[f"{prefix}-attn-attn-value-kernel"] = f"{hf_prefix}.attn.qkv.weight"
mapping[f"{prefix}-attn-attn-value-bias"] = f"{hf_prefix}.attn.qkv.bias"
mapping[f"{prefix}-attn-attn-out-kernel"] = f"{hf_prefix}.attn.proj.weight"
mapping[f"{prefix}-attn-attn-out-bias"] = f"{hf_prefix}.attn.proj.bias"

# MLP
mapping[f"{prefix}-mlp-kernel"] = f"{hf_prefix}.mlp.linear_fc1.weight"
mapping[f"{prefix}-mlp-bias"] = f"{hf_prefix}.mlp.linear_fc1.bias"
mapping[f"{prefix}-mlp_out-kernel"] = f"{hf_prefix}.mlp.linear_fc2.weight"
mapping[f"{prefix}-mlp_out-bias"] = f"{hf_prefix}.mlp.linear_fc2.bias"

# Vision merger_list (deepstack mergers, by default at layers 8, 16, 24)
deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
for merger_idx, _ in enumerate(deepstack_indexes):
prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-merger_{merger_idx}"
hf_prefix = f"thinker.visual.merger_list.{merger_idx}"

mapping[f"{prefix}-ln_q-scale"] = f"{hf_prefix}.ln_q.weight"
mapping[f"{prefix}-ln_q-bias"] = f"{hf_prefix}.ln_q.bias"
mapping[f"{prefix}-mlp_0-kernel"] = f"{hf_prefix}.mlp.0.weight"
mapping[f"{prefix}-mlp_0-bias"] = f"{hf_prefix}.mlp.0.bias"
mapping[f"{prefix}-mlp_2-kernel"] = f"{hf_prefix}.mlp.2.weight"
mapping[f"{prefix}-mlp_2-bias"] = f"{hf_prefix}.mlp.2.bias"

# Vision projector (final merger)
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-scale"] = "thinker.visual.merger.ln_q.weight"
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-bias"] = "thinker.visual.merger.ln_q.bias"
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = (
"thinker.visual.merger.mlp.0.weight"
)
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-bias"] = "thinker.visual.merger.mlp.0.bias"
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = (
"thinker.visual.merger.mlp.2.weight"
)
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-bias"] = "thinker.visual.merger.mlp.2.bias"

# Audio mapping
audio_config = config["thinker_config"]["audio_config"]
n_audio_layers = audio_config["encoder_layers"]

# Audio conv layers (3 Conv2D layers for downsampling)
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-kernel"] = "thinker.audio_tower.conv2d1.weight"
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-bias"] = "thinker.audio_tower.conv2d1.bias"
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-kernel"] = "thinker.audio_tower.conv2d2.weight"
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-bias"] = "thinker.audio_tower.conv2d2.bias"
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-kernel"] = "thinker.audio_tower.conv2d3.weight"
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-bias"] = "thinker.audio_tower.conv2d3.bias"

# Audio conv output projection
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv_out-kernel"] = "thinker.audio_tower.conv_out.weight"

# Audio encoder layers (n_audio_layers from audio_config["encoder_layers"]; 32 in this config)
for i in range(n_audio_layers):
prefix = f"params-audio_encoder-Qwen3OmniAudioEncoder_0-layers_{i}"
hf_prefix = f"thinker.audio_tower.layers.{i}"

# Layer norms
mapping[f"{prefix}-input_layer_norm-scale"] = f"{hf_prefix}.self_attn_layer_norm.weight"
mapping[f"{prefix}-input_layer_norm-bias"] = f"{hf_prefix}.self_attn_layer_norm.bias"
mapping[f"{prefix}-post_attention_layer_norm-scale"] = f"{hf_prefix}.final_layer_norm.weight"
mapping[f"{prefix}-post_attention_layer_norm-bias"] = f"{hf_prefix}.final_layer_norm.bias"

# Attention (separate Q/K/V)
mapping[f"{prefix}-self_attention_audio-query-kernel"] = f"{hf_prefix}.self_attn.q_proj.weight"
mapping[f"{prefix}-self_attention_audio-query-bias"] = f"{hf_prefix}.self_attn.q_proj.bias"
mapping[f"{prefix}-self_attention_audio-key-kernel"] = f"{hf_prefix}.self_attn.k_proj.weight"
mapping[f"{prefix}-self_attention_audio-key-bias"] = f"{hf_prefix}.self_attn.k_proj.bias"
mapping[f"{prefix}-self_attention_audio-value-kernel"] = f"{hf_prefix}.self_attn.v_proj.weight"
mapping[f"{prefix}-self_attention_audio-value-bias"] = f"{hf_prefix}.self_attn.v_proj.bias"
mapping[f"{prefix}-self_attention_audio-out-kernel"] = f"{hf_prefix}.self_attn.out_proj.weight"
mapping[f"{prefix}-self_attention_audio-out-bias"] = f"{hf_prefix}.self_attn.out_proj.bias"

# MLP (AudioMLP has 2 linear layers: fc1 and fc2)
mapping[f"{prefix}-AudioMLP-wi-kernel"] = f"{hf_prefix}.fc1.weight"
mapping[f"{prefix}-AudioMLP-wi-bias"] = f"{hf_prefix}.fc1.bias"
mapping[f"{prefix}-AudioMLP-wo-kernel"] = f"{hf_prefix}.fc2.weight"
mapping[f"{prefix}-AudioMLP-wo-bias"] = f"{hf_prefix}.fc2.bias"

# Audio post layer norm
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-layernorm_post-scale"] = "thinker.audio_tower.ln_post.weight"
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-layernorm_post-bias"] = "thinker.audio_tower.ln_post.bias"

# Audio projector (2 linear layers)
mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-kernel"] = "thinker.audio_tower.proj1.weight"
mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-bias"] = "thinker.audio_tower.proj1.bias"
mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-kernel"] = "thinker.audio_tower.proj2.weight"
mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-bias"] = "thinker.audio_tower.proj2.bias"

return mapping

@@ -1001,8 +1127,228 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving
)
mapping.update(text_hooks)

# Vision hooks
vision_config = config["thinker_config"]["vision_config"]
n_vision_layers = vision_config["depth"]
hidden_size = vision_config["hidden_size"]

def reshape_kernel_vision(input_tensor, target_shape):
"""Reshape kernel for vision layers."""
if saving_to_hf:
flipped_target_shape = np.flip(np.array(target_shape))
return input_tensor.reshape(flipped_target_shape).T
else:
return input_tensor.T.reshape(target_shape)

def reshape_conv3d_patch_embed(input_tensor, target_shape):
"""Reshape 3D conv patch embedding weight.
HF: (out_channels, in_channels, temporal, height, width)
MaxText: (temporal, height, width, in_channels, out_channels)
"""
if saving_to_hf:
# MaxText -> HF: (T, H, W, C_in, C_out) -> (C_out, C_in, T, H, W)
return input_tensor.transpose(4, 3, 0, 1, 2)
else:
# HF -> MaxText: (C_out, C_in, T, H, W) -> (T, H, W, C_in, C_out)
return input_tensor.transpose(2, 3, 4, 1, 0)
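
# Round-trip sanity sketch (illustrative; placeholder sizes, assuming numpy is
# imported as np as elsewhere in this module):
#   w_hf = np.zeros((8, 3, 2, 16, 16))          # (C_out, C_in, T, H, W)
#   w_mt = w_hf.transpose(2, 3, 4, 1, 0)        # (T, H, W, C_in, C_out)
#   assert w_mt.transpose(4, 3, 0, 1, 2).shape == w_hf.shape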

def split_qkv_query(input_tensor, target_shape):
"""Extract Q from fused QKV for HF->MaxText conversion.
HF has fused QKV: (3*hidden_size, hidden_size)
MaxText Q: (hidden_size, num_heads, head_dim)
"""
if saving_to_hf:
# MaxText -> HF: will be handled by fusion hook
raise NotImplementedError("Use fusion hook for MaxText->HF")
else:
# HF -> MaxText: Extract Q from fused QKV
# input_tensor shape: (3*hidden_size, hidden_size)
q_weight = input_tensor[:hidden_size, :] # (hidden_size, hidden_size)
return q_weight.T.reshape(target_shape) # (hidden_size, num_heads, head_dim)

def split_qkv_key(input_tensor, target_shape):
"""Extract K from fused QKV for HF->MaxText conversion."""
if saving_to_hf:
raise NotImplementedError("Use fusion hook for MaxText->HF")
else:
# Extract K from fused QKV
k_weight = input_tensor[hidden_size : 2 * hidden_size, :]
return k_weight.T.reshape(target_shape)

def split_qkv_value(input_tensor, target_shape):
"""Extract V from fused QKV for HF->MaxText conversion."""
if saving_to_hf:
raise NotImplementedError("Use fusion hook for MaxText->HF")
else:
# Extract V from fused QKV
v_weight = input_tensor[2 * hidden_size :, :]
return v_weight.T.reshape(target_shape)

def split_qkv_bias_query(input_tensor, target_shape):
"""Extract Q bias from fused QKV bias."""
if saving_to_hf:
raise NotImplementedError("Use fusion hook for MaxText->HF")
else:
q_bias = input_tensor[:hidden_size]
return q_bias.reshape(target_shape) # (num_heads, head_dim)

def split_qkv_bias_key(input_tensor, target_shape):
"""Extract K bias from fused QKV bias."""
if saving_to_hf:
raise NotImplementedError("Use fusion hook for MaxText->HF")
else:
k_bias = input_tensor[hidden_size : 2 * hidden_size]
return k_bias.reshape(target_shape)

def split_qkv_bias_value(input_tensor, target_shape):
"""Extract V bias from fused QKV bias."""
if saving_to_hf:
raise NotImplementedError("Use fusion hook for MaxText->HF")
else:
v_bias = input_tensor[2 * hidden_size :]
return v_bias.reshape(target_shape)
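
# Worked slice example (illustrative): with hidden_size = 4, a fused qkv.weight of
# shape (12, 4) splits row-wise as [0:4] -> Q, [4:8] -> K, [8:12] -> V; each kernel
# slice is then transposed and reshaped to (hidden_size, num_heads, head_dim).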

def reshape_vision_attn_out(input_tensor, target_shape):
"""Reshape vision attention output projection.
HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)
"""
if saving_to_hf:
# MaxText -> HF: (num_heads, head_dim, hidden_size) -> (hidden_size, hidden_size)
return input_tensor.reshape(hidden_size, hidden_size).T
else:
# HF -> MaxText: (hidden_size, hidden_size) -> (num_heads, head_dim, hidden_size)
return input_tensor.T.reshape(target_shape)

# Vision patch embedding
mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed

# Vision blocks
for i in range(n_vision_layers):
prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-blocks_{i}"

# Attention Q/K/V - split from fused QKV
mapping[f"{prefix}-attn-attn-query-kernel"] = split_qkv_query
mapping[f"{prefix}-attn-attn-query-bias"] = split_qkv_bias_query
mapping[f"{prefix}-attn-attn-key-kernel"] = split_qkv_key
mapping[f"{prefix}-attn-attn-key-bias"] = split_qkv_bias_key
mapping[f"{prefix}-attn-attn-value-kernel"] = split_qkv_value
mapping[f"{prefix}-attn-attn-value-bias"] = split_qkv_bias_value

# Attention output
mapping[f"{prefix}-attn-attn-out-kernel"] = reshape_vision_attn_out
# attn-attn-out-bias doesn't need a hook (no reshape needed)

# MLP
mapping[f"{prefix}-mlp-kernel"] = reshape_kernel_vision
mapping[f"{prefix}-mlp_out-kernel"] = reshape_kernel_vision

# Vision merger_list and projector MLPs
deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
for merger_idx, _ in enumerate(deepstack_indexes):
prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-merger_{merger_idx}"
mapping[f"{prefix}-mlp_0-kernel"] = reshape_kernel_vision
mapping[f"{prefix}-mlp_2-kernel"] = reshape_kernel_vision

# Vision projector (final merger)
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = reshape_kernel_vision
mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = reshape_kernel_vision

# Audio hooks
audio_config = config["thinker_config"]["audio_config"]
n_audio_layers = audio_config["encoder_layers"]
hidden_size_audio = audio_config["d_model"]

def reshape_kernel_audio(input_tensor, target_shape):
"""Reshape kernel for audio layers."""
if saving_to_hf:
flipped_target_shape = np.flip(np.array(target_shape))
return input_tensor.reshape(flipped_target_shape).T
else:
return input_tensor.T.reshape(target_shape)

def reshape_conv2d_audio(input_tensor, target_shape):
"""Reshape Conv2D weight for audio.

HF: (out_channels, in_channels, height, width)
MaxText: (height, width, in_channels, out_channels)
"""
if saving_to_hf:
# MaxText -> HF: (H, W, C_in, C_out) -> (C_out, C_in, H, W)
return input_tensor.transpose(3, 2, 0, 1)
else:
# HF -> MaxText: (C_out, C_in, H, W) -> (H, W, C_in, C_out)
return input_tensor.transpose(2, 3, 1, 0)
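
# Shape sketch (illustrative; placeholder sizes, not the real model dimensions):
# an HF weight of shape (480, 1, 3, 3) (C_out, C_in, H, W) becomes a MaxText kernel
# of shape (3, 3, 1, 480) (H, W, C_in, C_out), and the two transposes invert exactly.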

def reshape_audio_attn_qkv(input_tensor, target_shape):
"""Reshape audio attention Q/K/V projection.

HF: (hidden_size, hidden_size)
MaxText: (hidden_size, num_heads, head_dim)
"""
if saving_to_hf:
# MaxText -> HF: (hidden_size, num_heads, head_dim) -> (hidden_size, hidden_size)
return input_tensor.reshape(hidden_size_audio, hidden_size_audio).T
else:
# HF -> MaxText: (hidden_size, hidden_size) -> (hidden_size, num_heads, head_dim)
return input_tensor.T.reshape(target_shape)

def reshape_audio_attn_out(input_tensor, target_shape):
"""Reshape audio attention output projection.
F
HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)
"""
if saving_to_hf:
# MaxText -> HF: (num_heads, head_dim, hidden_size) -> (hidden_size, hidden_size)
return input_tensor.reshape(hidden_size_audio, hidden_size_audio).T
else:
# HF -> MaxText: (hidden_size, hidden_size) -> (num_heads, head_dim, hidden_size)
return input_tensor.T.reshape(target_shape)
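
# Note (implied by the reshapes above): the audio tower satisfies
# num_heads * head_dim == d_model, so reshape(hidden_size_audio, hidden_size_audio)
# flattens the head axes without padding or truncation.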

def reshape_audio_attn_qkv_bias(input_tensor, target_shape):
"""Reshape audio attention Q/K/V bias.

HF: (hidden_size,)
MaxText: (num_heads, head_dim)
"""
if saving_to_hf:
# MaxText -> HF: (num_heads, head_dim) -> (hidden_size,)
return input_tensor.reshape(hidden_size_audio)
else:
# HF -> MaxText: (hidden_size,) -> (num_heads, head_dim)
return input_tensor.reshape(target_shape)

# Audio conv layers
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-kernel"] = reshape_conv2d_audio
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-kernel"] = reshape_conv2d_audio
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-kernel"] = reshape_conv2d_audio

# Audio conv output projection
mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv_out-kernel"] = reshape_kernel_audio

# Audio encoder layers
for i in range(n_audio_layers):
prefix = f"params-audio_encoder-Qwen3OmniAudioEncoder_0-layers_{i}"

# Attention Q/K/V
mapping[f"{prefix}-self_attention_audio-query-kernel"] = reshape_audio_attn_qkv
mapping[f"{prefix}-self_attention_audio-query-bias"] = reshape_audio_attn_qkv_bias
mapping[f"{prefix}-self_attention_audio-key-kernel"] = reshape_audio_attn_qkv
mapping[f"{prefix}-self_attention_audio-key-bias"] = reshape_audio_attn_qkv_bias
mapping[f"{prefix}-self_attention_audio-value-kernel"] = reshape_audio_attn_qkv
mapping[f"{prefix}-self_attention_audio-value-bias"] = reshape_audio_attn_qkv_bias

# Attention output
mapping[f"{prefix}-self_attention_audio-out-kernel"] = reshape_audio_attn_out

# MLP
mapping[f"{prefix}-AudioMLP-wi-kernel"] = reshape_kernel_audio
mapping[f"{prefix}-AudioMLP-wo-kernel"] = reshape_kernel_audio

# Audio projector
mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-kernel"] = reshape_kernel_audio
mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-kernel"] = reshape_kernel_audio

return mapping
