diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
index 03be2560d..8662436e9 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -959,12 +959,138 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
     )
 
   # Add "thinker." prefix to text mapping values
+  def add_prefix_recursive(value):
+    """Recursively add 'thinker.' prefix to strings, handling nested lists."""
+    if isinstance(value, list):
+      return [add_prefix_recursive(v) for v in value]
+    else:
+      return f"thinker.{value}"
+
   for key, value in text_mapping.items():
-    text_mapping[key] = [f"thinker.{v}" for v in value] if isinstance(value, list) else f"thinker.{value}"
+    text_mapping[key] = add_prefix_recursive(value)
   mapping.update(text_mapping)
 
-  # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly
-  # mapping.update(vision_mapping), mapping.update(audio_mapping), etc.
+  # Vision mapping
+  vision_config = config["thinker_config"]["vision_config"]
+  n_vision_layers = vision_config["depth"]
+
+  # Vision patch embedding
+  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = (
+      "thinker.visual.patch_embed.proj.weight"
+  )
+  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-bias"] = (
+      "thinker.visual.patch_embed.proj.bias"
+  )
+
+  # Vision positional embedding
+  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-pos_embed_interpolate-pos_embed"] = (
+      "thinker.visual.pos_embed.weight"
+  )
+
+  # Vision blocks (27 layers)
+  for i in range(n_vision_layers):
+    prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-blocks_{i}"
+    hf_prefix = f"thinker.visual.blocks.{i}"
+
+    # Layer norms
+    mapping[f"{prefix}-ln1-scale"] = f"{hf_prefix}.norm1.weight"
+    mapping[f"{prefix}-ln1-bias"] = f"{hf_prefix}.norm1.bias"
+    mapping[f"{prefix}-ln2-scale"] = f"{hf_prefix}.norm2.weight"
+    mapping[f"{prefix}-ln2-bias"] = f"{hf_prefix}.norm2.bias"
+
+    # Attention (HF has fused QKV, MaxText has separate Q/K/V)
+    # We'll handle the split/fusion in the hook functions
+    mapping[f"{prefix}-attn-attn-query-kernel"] = f"{hf_prefix}.attn.qkv.weight"
+    mapping[f"{prefix}-attn-attn-query-bias"] = f"{hf_prefix}.attn.qkv.bias"
+    mapping[f"{prefix}-attn-attn-key-kernel"] = f"{hf_prefix}.attn.qkv.weight"
+    mapping[f"{prefix}-attn-attn-key-bias"] = f"{hf_prefix}.attn.qkv.bias"
+    mapping[f"{prefix}-attn-attn-value-kernel"] = f"{hf_prefix}.attn.qkv.weight"
+    mapping[f"{prefix}-attn-attn-value-bias"] = f"{hf_prefix}.attn.qkv.bias"
+    mapping[f"{prefix}-attn-attn-out-kernel"] = f"{hf_prefix}.attn.proj.weight"
+    mapping[f"{prefix}-attn-attn-out-bias"] = f"{hf_prefix}.attn.proj.bias"
+
+    # MLP
+    mapping[f"{prefix}-mlp-kernel"] = f"{hf_prefix}.mlp.linear_fc1.weight"
+    mapping[f"{prefix}-mlp-bias"] = f"{hf_prefix}.mlp.linear_fc1.bias"
+    mapping[f"{prefix}-mlp_out-kernel"] = f"{hf_prefix}.mlp.linear_fc2.weight"
+    mapping[f"{prefix}-mlp_out-bias"] = f"{hf_prefix}.mlp.linear_fc2.bias"
+
+  # Vision merger_list (deep mergers at layers 8, 16, 24)
+  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
+  for merger_idx, _ in enumerate(deepstack_indexes):
+    prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-merger_{merger_idx}"
+    hf_prefix = f"thinker.visual.merger_list.{merger_idx}"
+
+    mapping[f"{prefix}-ln_q-scale"] = f"{hf_prefix}.ln_q.weight"
+    mapping[f"{prefix}-ln_q-bias"] = f"{hf_prefix}.ln_q.bias"
+    mapping[f"{prefix}-mlp_0-kernel"] = f"{hf_prefix}.mlp.0.weight"
+    mapping[f"{prefix}-mlp_0-bias"] = f"{hf_prefix}.mlp.0.bias"
+    mapping[f"{prefix}-mlp_2-kernel"] = f"{hf_prefix}.mlp.2.weight"
+    mapping[f"{prefix}-mlp_2-bias"] = f"{hf_prefix}.mlp.2.bias"
+
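A minimal standalone sketch of the prefixing behavior above, using illustrative values rather than real text_mapping entries; this also shows why the helper replaced the old one-liner, which would stringify a nested list (producing "thinker.['...']") instead of prefixing its elements:

import copy

def add_prefix_recursive(value):
  # Same logic as the helper in the diff: prefix strings, recurse into lists.
  if isinstance(value, list):
    return [add_prefix_recursive(v) for v in value]
  return f"thinker.{value}"

print(add_prefix_recursive("model.embed_tokens.weight"))
# thinker.model.embed_tokens.weight
print(add_prefix_recursive([["layers.0.gate.weight", "layers.0.up.weight"]]))
# [['thinker.layers.0.gate.weight', 'thinker.layers.0.up.weight']]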
= f"{hf_prefix}.ln_q.bias" + mapping[f"{prefix}-mlp_0-kernel"] = f"{hf_prefix}.mlp.0.weight" + mapping[f"{prefix}-mlp_0-bias"] = f"{hf_prefix}.mlp.0.bias" + mapping[f"{prefix}-mlp_2-kernel"] = f"{hf_prefix}.mlp.2.weight" + mapping[f"{prefix}-mlp_2-bias"] = f"{hf_prefix}.mlp.2.bias" + + # Vision projector (final merger) + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-scale"] = "thinker.visual.merger.ln_q.weight" + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-bias"] = "thinker.visual.merger.ln_q.bias" + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = ( + "thinker.visual.merger.mlp.0.weight" + ) + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-bias"] = "thinker.visual.merger.mlp.0.bias" + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = ( + "thinker.visual.merger.mlp.2.weight" + ) + mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-bias"] = "thinker.visual.merger.mlp.2.bias" + + # Audio mapping + audio_config = config["thinker_config"]["audio_config"] + n_audio_layers = audio_config["encoder_layers"] + + # Audio conv layers (3 Conv2D layers for downsampling) + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-kernel"] = "thinker.audio_tower.conv2d1.weight" + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-bias"] = "thinker.audio_tower.conv2d1.bias" + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-kernel"] = "thinker.audio_tower.conv2d2.weight" + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-bias"] = "thinker.audio_tower.conv2d2.bias" + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-kernel"] = "thinker.audio_tower.conv2d3.weight" + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-bias"] = "thinker.audio_tower.conv2d3.bias" + + # Audio conv output projection + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv_out-kernel"] = "thinker.audio_tower.conv_out.weight" + + # Audio encoder layers (32 layers) + for i in range(n_audio_layers): + prefix = f"params-audio_encoder-Qwen3OmniAudioEncoder_0-layers_{i}" + hf_prefix = f"thinker.audio_tower.layers.{i}" + + # Layer norms + mapping[f"{prefix}-input_layer_norm-scale"] = f"{hf_prefix}.self_attn_layer_norm.weight" + mapping[f"{prefix}-input_layer_norm-bias"] = f"{hf_prefix}.self_attn_layer_norm.bias" + mapping[f"{prefix}-post_attention_layer_norm-scale"] = f"{hf_prefix}.final_layer_norm.weight" + mapping[f"{prefix}-post_attention_layer_norm-bias"] = f"{hf_prefix}.final_layer_norm.bias" + + # Attention (separate Q/K/V) + mapping[f"{prefix}-self_attention_audio-query-kernel"] = f"{hf_prefix}.self_attn.q_proj.weight" + mapping[f"{prefix}-self_attention_audio-query-bias"] = f"{hf_prefix}.self_attn.q_proj.bias" + mapping[f"{prefix}-self_attention_audio-key-kernel"] = f"{hf_prefix}.self_attn.k_proj.weight" + mapping[f"{prefix}-self_attention_audio-key-bias"] = f"{hf_prefix}.self_attn.k_proj.bias" + mapping[f"{prefix}-self_attention_audio-value-kernel"] = f"{hf_prefix}.self_attn.v_proj.weight" + mapping[f"{prefix}-self_attention_audio-value-bias"] = f"{hf_prefix}.self_attn.v_proj.bias" + mapping[f"{prefix}-self_attention_audio-out-kernel"] = f"{hf_prefix}.self_attn.out_proj.weight" + mapping[f"{prefix}-self_attention_audio-out-bias"] = f"{hf_prefix}.self_attn.out_proj.bias" + + # MLP (AudioMLP has 2 linear layers: fc1 and fc2) + mapping[f"{prefix}-AudioMLP-wi-kernel"] = 
f"{hf_prefix}.fc1.weight" + mapping[f"{prefix}-AudioMLP-wi-bias"] = f"{hf_prefix}.fc1.bias" + mapping[f"{prefix}-AudioMLP-wo-kernel"] = f"{hf_prefix}.fc2.weight" + mapping[f"{prefix}-AudioMLP-wo-bias"] = f"{hf_prefix}.fc2.bias" + + # Audio post layer norm + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-layernorm_post-scale"] = "thinker.audio_tower.ln_post.weight" + mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-layernorm_post-bias"] = "thinker.audio_tower.ln_post.bias" + + # Audio projector (2 linear layers) + mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-kernel"] = "thinker.audio_tower.proj1.weight" + mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-bias"] = "thinker.audio_tower.proj1.bias" + mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-kernel"] = "thinker.audio_tower.proj2.weight" + mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-bias"] = "thinker.audio_tower.proj2.bias" return mapping @@ -1001,8 +1127,228 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving ) mapping.update(text_hooks) - # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly - # mapping.update(vision_hooks), mapping.update(audio_hooks), etc. + # Vision hooks + vision_config = config["thinker_config"]["vision_config"] + n_vision_layers = vision_config["depth"] + hidden_size = vision_config["hidden_size"] + + def reshape_kernel_vision(input_tensor, target_shape): + """Reshape kernel for vision layers.""" + if saving_to_hf: + flipped_target_shape = np.flip(np.array(target_shape)) + return input_tensor.reshape(flipped_target_shape).T + else: + return input_tensor.T.reshape(target_shape) + + def reshape_conv3d_patch_embed(input_tensor, target_shape): + """Reshape 3D conv patch embedding weight. + HF: (out_channels, in_channels, temporal, height, width) + MaxText: (temporal, height, width, in_channels, out_channels) + """ + if saving_to_hf: + # MaxText -> HF: (T, H, W, C_in, C_out) -> (C_out, C_in, T, H, W) + return input_tensor.transpose(4, 3, 0, 1, 2) + else: + # HF -> MaxText: (C_out, C_in, T, H, W) -> (T, H, W, C_in, C_out) + return input_tensor.transpose(2, 3, 4, 1, 0) + + def split_qkv_query(input_tensor, target_shape): + """Extract Q from fused QKV for HF->MaxText conversion. 
+  def split_qkv_key(input_tensor, target_shape):
+    """Extract K from fused QKV for HF->MaxText conversion."""
+    if saving_to_hf:
+      raise NotImplementedError("Use fusion hook for MaxText->HF")
+    else:
+      # Extract K from fused QKV
+      k_weight = input_tensor[hidden_size : 2 * hidden_size, :]
+      return k_weight.T.reshape(target_shape)
+
+  def split_qkv_value(input_tensor, target_shape):
+    """Extract V from fused QKV for HF->MaxText conversion."""
+    if saving_to_hf:
+      raise NotImplementedError("Use fusion hook for MaxText->HF")
+    else:
+      # Extract V from fused QKV
+      v_weight = input_tensor[2 * hidden_size :, :]
+      return v_weight.T.reshape(target_shape)
+
+  def split_qkv_bias_query(input_tensor, target_shape):
+    """Extract Q bias from fused QKV bias."""
+    if saving_to_hf:
+      raise NotImplementedError("Use fusion hook for MaxText->HF")
+    else:
+      q_bias = input_tensor[:hidden_size]
+      return q_bias.reshape(target_shape)  # (num_heads, head_dim)
+
+  def split_qkv_bias_key(input_tensor, target_shape):
+    """Extract K bias from fused QKV bias."""
+    if saving_to_hf:
+      raise NotImplementedError("Use fusion hook for MaxText->HF")
+    else:
+      k_bias = input_tensor[hidden_size : 2 * hidden_size]
+      return k_bias.reshape(target_shape)
+
+  def split_qkv_bias_value(input_tensor, target_shape):
+    """Extract V bias from fused QKV bias."""
+    if saving_to_hf:
+      raise NotImplementedError("Use fusion hook for MaxText->HF")
+    else:
+      v_bias = input_tensor[2 * hidden_size :]
+      return v_bias.reshape(target_shape)
+
+  def reshape_vision_attn_out(input_tensor, target_shape):
+    """Reshape vision attention output projection.
+
+    HF: (hidden_size, hidden_size)
+    MaxText: (num_heads, head_dim, hidden_size)
+    """
+    if saving_to_hf:
+      # MaxText -> HF: (num_heads, head_dim, hidden_size) -> (hidden_size, hidden_size)
+      return input_tensor.reshape(hidden_size, hidden_size).T
+    else:
+      # HF -> MaxText: (hidden_size, hidden_size) -> (num_heads, head_dim, hidden_size)
+      return input_tensor.T.reshape(target_shape)
+
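The two branches of reshape_vision_attn_out are inverses of each other, so a save/load round trip is lossless. A toy check (shapes illustrative):

import numpy as np

num_heads, head_dim = 2, 4               # toy values
hidden_size = num_heads * head_dim

w_maxtext = np.random.randn(num_heads, head_dim, hidden_size)
w_hf = w_maxtext.reshape(hidden_size, hidden_size).T            # saving_to_hf branch
w_roundtrip = w_hf.T.reshape(num_heads, head_dim, hidden_size)  # loading branch

assert np.array_equal(w_roundtrip, w_maxtext)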
+  # Vision patch embedding
+  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed
+
+  # Vision blocks
+  for i in range(n_vision_layers):
+    prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-blocks_{i}"
+
+    # Attention Q/K/V - split from fused QKV
+    mapping[f"{prefix}-attn-attn-query-kernel"] = split_qkv_query
+    mapping[f"{prefix}-attn-attn-query-bias"] = split_qkv_bias_query
+    mapping[f"{prefix}-attn-attn-key-kernel"] = split_qkv_key
+    mapping[f"{prefix}-attn-attn-key-bias"] = split_qkv_bias_key
+    mapping[f"{prefix}-attn-attn-value-kernel"] = split_qkv_value
+    mapping[f"{prefix}-attn-attn-value-bias"] = split_qkv_bias_value
+
+    # Attention output
+    mapping[f"{prefix}-attn-attn-out-kernel"] = reshape_vision_attn_out
+    # attn-attn-out-bias doesn't need a hook (no reshape needed)
+
+    # MLP
+    mapping[f"{prefix}-mlp-kernel"] = reshape_kernel_vision
+    mapping[f"{prefix}-mlp_out-kernel"] = reshape_kernel_vision
+
+  # Vision merger_list and projector MLPs
+  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
+  for merger_idx, _ in enumerate(deepstack_indexes):
+    prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-merger_{merger_idx}"
+    mapping[f"{prefix}-mlp_0-kernel"] = reshape_kernel_vision
+    mapping[f"{prefix}-mlp_2-kernel"] = reshape_kernel_vision
+
+  # Vision projector (final merger)
+  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = reshape_kernel_vision
+  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = reshape_kernel_vision
+
+  # Audio hooks
+  audio_config = config["thinker_config"]["audio_config"]
+  n_audio_layers = audio_config["encoder_layers"]
+  hidden_size_audio = audio_config["d_model"]
+
+  def reshape_kernel_audio(input_tensor, target_shape):
+    """Reshape kernel for audio layers."""
+    if saving_to_hf:
+      flipped_target_shape = np.flip(np.array(target_shape))
+      return input_tensor.reshape(flipped_target_shape).T
+    else:
+      return input_tensor.T.reshape(target_shape)
+
+  def reshape_conv2d_audio(input_tensor, target_shape):
+    """Reshape Conv2D weight for audio.
+
+    HF: (out_channels, in_channels, height, width)
+    MaxText: (height, width, in_channels, out_channels)
+    """
+    if saving_to_hf:
+      # MaxText -> HF: (H, W, C_in, C_out) -> (C_out, C_in, H, W)
+      return input_tensor.transpose(3, 2, 0, 1)
+    else:
+      # HF -> MaxText: (C_out, C_in, H, W) -> (H, W, C_in, C_out)
+      return input_tensor.transpose(2, 3, 1, 0)
+
+  def reshape_audio_attn_qkv(input_tensor, target_shape):
+    """Reshape audio attention Q/K/V projection.
+
+    HF: (hidden_size, hidden_size)
+    MaxText: (hidden_size, num_heads, head_dim)
+    """
+    if saving_to_hf:
+      # MaxText -> HF: (hidden_size, num_heads, head_dim) -> (hidden_size, hidden_size)
+      return input_tensor.reshape(hidden_size_audio, hidden_size_audio).T
+    else:
+      # HF -> MaxText: (hidden_size, hidden_size) -> (hidden_size, num_heads, head_dim)
+      return input_tensor.T.reshape(target_shape)
+
+  def reshape_audio_attn_out(input_tensor, target_shape):
+    """Reshape audio attention output projection.
+
+    HF: (hidden_size, hidden_size)
+    MaxText: (num_heads, head_dim, hidden_size)
+    """
+    if saving_to_hf:
+      # MaxText -> HF: (num_heads, head_dim, hidden_size) -> (hidden_size, hidden_size)
+      return input_tensor.reshape(hidden_size_audio, hidden_size_audio).T
+    else:
+      # HF -> MaxText: (hidden_size, hidden_size) -> (num_heads, head_dim, hidden_size)
+      return input_tensor.T.reshape(target_shape)
+
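Likewise, reshape_conv2d_audio is exactly invertible across its two branches: HF stores Conv2D weights as OIHW, MaxText as HWIO. A toy check (channel and kernel sizes illustrative):

import numpy as np

w_hf = np.random.randn(32, 1, 3, 3)      # (C_out, C_in, H, W), toy sizes
w_maxtext = w_hf.transpose(2, 3, 1, 0)   # HF -> MaxText: (H, W, C_in, C_out)

assert np.array_equal(w_maxtext.transpose(3, 2, 0, 1), w_hf)  # MaxText -> HF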
+  def reshape_audio_attn_qkv_bias(input_tensor, target_shape):
+    """Reshape audio attention Q/K/V bias.
+
+    HF: (hidden_size,)
+    MaxText: (num_heads, head_dim)
+    """
+    if saving_to_hf:
+      # MaxText -> HF: (num_heads, head_dim) -> (hidden_size,)
+      return input_tensor.reshape(hidden_size_audio)
+    else:
+      # HF -> MaxText: (hidden_size,) -> (num_heads, head_dim)
+      return input_tensor.reshape(target_shape)
+
+  # Audio conv layers
+  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-kernel"] = reshape_conv2d_audio
+  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-kernel"] = reshape_conv2d_audio
+  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-kernel"] = reshape_conv2d_audio
+
+  # Audio conv output projection
+  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv_out-kernel"] = reshape_kernel_audio
+
+  # Audio encoder layers
+  for i in range(n_audio_layers):
+    prefix = f"params-audio_encoder-Qwen3OmniAudioEncoder_0-layers_{i}"
+
+    # Attention Q/K/V
+    mapping[f"{prefix}-self_attention_audio-query-kernel"] = reshape_audio_attn_qkv
+    mapping[f"{prefix}-self_attention_audio-query-bias"] = reshape_audio_attn_qkv_bias
+    mapping[f"{prefix}-self_attention_audio-key-kernel"] = reshape_audio_attn_qkv
+    mapping[f"{prefix}-self_attention_audio-key-bias"] = reshape_audio_attn_qkv_bias
+    mapping[f"{prefix}-self_attention_audio-value-kernel"] = reshape_audio_attn_qkv
+    mapping[f"{prefix}-self_attention_audio-value-bias"] = reshape_audio_attn_qkv_bias
+
+    # Attention output
+    mapping[f"{prefix}-self_attention_audio-out-kernel"] = reshape_audio_attn_out
+
+    # MLP
+    mapping[f"{prefix}-AudioMLP-wi-kernel"] = reshape_kernel_audio
+    mapping[f"{prefix}-AudioMLP-wo-kernel"] = reshape_kernel_audio
+
+  # Audio projector
+  mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-kernel"] = reshape_kernel_audio
+  mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-kernel"] = reshape_kernel_audio
 
   return mapping
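For context on how the two functions in this patch are consumed together: the mapping pairs each MaxText param name with an HF tensor name, and the hook dict optionally transforms that tensor into the target layout. A hypothetical driver loop, assuming illustrative names (convert_hf_to_maxtext, target_shapes); the real conversion loop lives elsewhere in ckpt_conversion:

def convert_hf_to_maxtext(hf_state_dict, mapping, hooks, target_shapes):
  # Pair each MaxText key with its HF tensor and, if present, its reshape hook.
  converted = {}
  for maxtext_key, hf_name in mapping.items():
    if not isinstance(hf_name, str):   # scanned layers map to lists of names
      continue
    tensor = hf_state_dict[hf_name]
    hook = hooks.get(maxtext_key)
    if hook is not None:
      tensor = hook(tensor, target_shapes[maxtext_key])
    converted[maxtext_key] = tensor
  return converted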