From 412e962feb82bf054202eb36a6e5ffaab68bf0dc Mon Sep 17 00:00:00 2001 From: cloudforge1 Date: Mon, 23 Mar 2026 23:34:24 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E3=80=90Hackathon=2010th=20No.50=E3=80=91a?= =?UTF-8?q?dd=20MiniCPM4/4.1=20model=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MiniCPM4MLP: gate/up merged, silu activation, no bias - MiniCPM4Attention: GQA with QKVParallelLinear(bias=False), neox rotary - MiniCPM4DecoderLayer: μP residual scaling (scale_depth/sqrt(num_layers)) - MiniCPM4Model: μP embedding scaling (scale_emb), LongRoPE support - MiniCPM4ForCausalLM: μP lm_head scaling (hidden_size/dim_model_base) - Weight mapping: HF model. to FD minicpm4. prefix - Architecture: MiniCPMForCausalLM (GQA, not MLA) - Follows Qwen2 patterns adapted for MiniCPM4 μP parametrization --- fastdeploy/model_executor/models/minicpm4.py | 516 +++++++++++++++++++ 1 file changed, 516 insertions(+) create mode 100644 fastdeploy/model_executor/models/minicpm4.py diff --git a/fastdeploy/model_executor/models/minicpm4.py b/fastdeploy/model_executor/models/minicpm4.py new file mode 100644 index 00000000000..96a4d86b1ab --- /dev/null +++ b/fastdeploy/model_executor/models/minicpm4.py @@ -0,0 +1,516 @@ +""" +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import math +import re +from functools import partial +from typing import Dict + +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig, ModelConfig +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.decorator import ( + support_graph_optimization, +) +from fastdeploy.model_executor.layers.activation import SiluAndMul +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding +from fastdeploy.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from fastdeploy.model_executor.layers.lm_head import ParallelLMHead +from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) +from fastdeploy.model_executor.utils import ( + WeightsMapper, + default_weight_loader, + process_weights_after_loading, + process_weights_before_loading, +) + + +class MiniCPM4MLP(nn.Layer): + """ """ + + def __init__( + self, + fd_config: FDConfig, + prefix: str = "", + ) -> None: + super().__init__() + self.up_gate_proj = MergedColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.up_gate_proj", + input_size=fd_config.model_config.hidden_size, + output_size=fd_config.model_config.intermediate_size * 2, + with_bias=False, + activation=fd_config.model_config.hidden_act, + ) + + self.down_proj = RowParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.down_proj", + input_size=fd_config.model_config.intermediate_size, + output_size=fd_config.model_config.hidden_size, + with_bias=False, + ) + + self.act_fn = SiluAndMul( + fd_config=fd_config, + bias=getattr(self.up_gate_proj, "bias", None), + act_method=fd_config.model_config.hidden_act, + ) + + def load_state_dict(self, state_dict): + """ """ + self.up_gate_proj.load_state_dict(state_dict) + self.down_proj.load_state_dict(state_dict) + + def forward(self, x, forward_meta): + """ """ + gate_up_out = self.up_gate_proj(x) + act_out = self.act_fn(gate_up_out) + down_out = self.down_proj(act_out) + return down_out + + +class MiniCPM4Attention(nn.Layer): + """ """ + + def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None: + super().__init__() + + self.qkv_proj = QKVParallelLinear(fd_config=fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False) + + self.o_proj = RowParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.o_proj", + input_size=fd_config.model_config.hidden_size, + output_size=fd_config.model_config.hidden_size, + ) + + self.attn = Attention( + fd_config=fd_config, + layer_id=layer_id, + prefix=prefix, + use_neox_rotary_style=True, + ) + + def load_state_dict(self, state_dict): + """ """ + self.qkv_proj.load_state_dict(state_dict) + self.o_proj.load_state_dict(state_dict) + self.attn.load_state_dict(state_dict) + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + ): + """ """ + qkv_out = self.qkv_proj(hidden_states) + + atten_out = self.attn( + qkv=qkv_out, + forward_meta=forward_meta, + ) + output = self.o_proj(atten_out) + return output + + +class MiniCPM4DecoderLayer(nn.Layer): + """MiniCPM4 decoder layer with μP residual scaling.""" + + def __init__( + self, + fd_config: FDConfig, + prefix: str = "", + ) -> None: + super().__init__() + layer_id = int(prefix.split(sep=".")[-1]) + + self.self_attn = MiniCPM4Attention( + fd_config=fd_config, + layer_id=layer_id, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = MiniCPM4MLP( + fd_config=fd_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{prefix}.input_layernorm", + ) + + self.post_attention_layernorm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{prefix}.post_attention_layernorm", + layer_id=layer_id, + ) + + # μP residual scaling: scale_depth / sqrt(num_hidden_layers) + scale_depth = getattr(fd_config.model_config, "scale_depth", 1.0) + num_hidden_layers = fd_config.model_config.num_hidden_layers + self.residual_scale = scale_depth / math.sqrt(num_hidden_layers) + + def load_state_dict(self, state_dict): + """ """ + self.self_attn.load_state_dict(state_dict) + self.mlp.load_state_dict(state_dict) + self.input_layernorm.load_state_dict(state_dict) + self.post_attention_layernorm.load_state_dict(state_dict) + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + residual: paddle.Tensor = None, + ): + """ """ + # Self Attention + hidden_states, residual = self.input_layernorm( + hidden_states, residual_input=residual, forward_meta=forward_meta + ) + + hidden_states = self.self_attn( + hidden_states=hidden_states, + forward_meta=forward_meta, + ) + + # μP: scale attention output before residual add + hidden_states = hidden_states * self.residual_scale + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + hidden_states = self.mlp(hidden_states, forward_meta) + + # μP: scale MLP output before residual add + hidden_states = hidden_states * self.residual_scale + + return hidden_states, residual + + +@support_graph_optimization +class MiniCPM4Model(nn.Layer): + """ """ + + def __init__( + self, + fd_config: FDConfig = None, + ): + super().__init__() + + self.num_layers = fd_config.model_config.num_hidden_layers + fd_config.model_config.pretrained_config.prefix_name = "minicpm4" + + # μP embedding scaling factor + self.scale_emb = getattr(fd_config.model_config, "scale_emb", 1) + + self.embed_tokens = VocabParallelEmbedding( + fd_config=fd_config, + num_embeddings=fd_config.model_config.vocab_size, + embedding_dim=fd_config.model_config.hidden_size, + params_dtype=paddle.get_default_dtype, + prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"), + ) + + self.layers = nn.LayerList( + [ + MiniCPM4DecoderLayer( + fd_config=fd_config, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}", + ) + for i in range(self.num_layers) + ] + ) + + self.norm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm", + ) + + def load_state_dict(self, state_dict): + """ + Load model parameters from a given state dictionary. + + Args: + state_dict (dict[str, np.ndarray | paddle.Tensor]): + A dictionary containing model parameters, where keys are parameter names + and values are NumPy arrays or PaddlePaddle tensors. + """ + self.embed_tokens.load_state_dict(state_dict) + self.norm.load_state_dict(state_dict) + for i in range(self.num_layers): + logger.info(f"Start load layer {i}") + self.layers[i].load_state_dict(state_dict) + + def forward( + self, + ids_remove_padding: paddle.Tensor, + forward_meta: ForwardMeta, + ): + + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) + + # μP: scale embeddings + if self.scale_emb != 1: + hidden_states = hidden_states * self.scale_emb + + residual = None + + for i in range(self.num_layers): + hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual) + + out = self.norm(hidden_states, residual)[0] + + return out + + +@ModelRegistry.register_model_class( + architecture="MiniCPMForCausalLM", + module_name="minicpm4", + category=[ModelCategory.TEXT_GENERATION], + primary_use=ModelCategory.TEXT_GENERATION, +) +class MiniCPM4ForCausalLM(ModelForCasualLM): + """ + MiniCPM4ForCausalLM — supports MiniCPM4 and MiniCPM4.1 series models. + + Key differences from Qwen2: + - μP (Maximal Update Parametrization) scaling: + * Embedding output scaled by `scale_emb` (default: 12) + * Residual connections scaled by `scale_depth / sqrt(num_hidden_layers)` (default: 1.4) + * LM head input scaled by `hidden_size / dim_model_base` (default: 4096/256 = 16) + - No QKV bias (attention_bias=false) + - LongRoPE position encoding + """ + + def __init__(self, fd_config: FDConfig): + super(MiniCPM4ForCausalLM, self).__init__(fd_config) + + self.fd_config = fd_config + self.minicpm4 = MiniCPM4Model(fd_config=fd_config) + + self.ori_vocab_size = fd_config.model_config.ori_vocab_size + self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", + ) + + # μP: lm_head input scaling factor = hidden_size / dim_model_base + dim_model_base = getattr(fd_config.model_config, "dim_model_base", None) + hidden_size = fd_config.model_config.hidden_size + if dim_model_base is not None and dim_model_base > 0: + self.lm_head_scale = hidden_size / dim_model_base + else: + self.lm_head_scale = 1.0 + + self.process_weights_before_loading_fn = process_weights_before_loading( + mapper=( + WeightsMapper(orig_to_new_prefix={"model.": "minicpm4."}) + if self.fd_config.model_config.model_format == "torch" + else None + ), + ) + + @paddle.no_grad() + def load_weights(self, weights_iterator) -> None: + """ + Load model parameters from a given weights_iterator object. + + Args: + weights_iterator (Iterator): An iterator yielding (name, weight) pairs. + """ + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("up_gate_proj", "gate_proj", "gate"), + ("up_gate_proj", "up_proj", "up"), + ("embed_tokens.embeddings", "embed_tokens", None), + ("lm_head.linear", "lm_head", None), + ] + + params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config) + for loaded_weight_name, loaded_weight in weights_iterator: + logger.debug(f"Loading weight: {loaded_weight_name}") + loaded_weight_name = ( + self.process_weights_before_loading_fn(loaded_weight_name) + if getattr(self, "process_weights_before_loading_fn", None) + else loaded_weight_name + ) + if loaded_weight_name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in loaded_weight_name: + continue + model_param_name = loaded_weight_name.replace(weight_name, param_name) + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight, shard_id) + break + else: + model_param_name = loaded_weight_name + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight) + model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) + if getattr(self, "tie_word_embeddings", False): + self.lm_head.linear.weight.set_value( + self.minicpm4.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) + ) + + @classmethod + def name(self): + """ """ + return "MiniCPMForCausalLM" + + @paddle.no_grad() + def set_state_dict(self, state_dict): + """ + Load model parameters from a given state dictionary. + + Args: + state_dict (dict[str, np.ndarray | paddle.Tensor]): + A dictionary containing model parameters, where keys are parameter names + and values are NumPy arrays or PaddlePaddle tensors. + """ + self.minicpm4.load_state_dict(state_dict) + self.lm_head.load_state_dict(state_dict) + + def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None): + """ """ + # μP: scale hidden states before lm_head + if self.lm_head_scale != 1.0: + hidden_states = hidden_states / self.lm_head_scale + logits = self.lm_head(hidden_states) + logits = logits.astype(paddle.float32) + logits[:, self.ori_vocab_size :] = -float("inf") + + return logits + + def forward( + self, + inputs: Dict, + forward_meta: ForwardMeta, + ): + ids_remove_padding = inputs["ids_remove_padding"] + hidden_states = self.minicpm4(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) + + return hidden_states + + def clear_grpah_opt_backend(self): + """Clear graph optimization backend, the captured cuda graph will be cleaned""" + self.minicpm4.clear_grpah_opt_backend(fd_config=self.fd_config) + + +class MiniCPM4PretrainedModel(PretrainedModel): + """ + MiniCPM4PretrainedModel + """ + + config_class = FDConfig + + def _init_weight(self, layer): + """ + _init_weight + """ + return None + + @classmethod + def arch_name(self): + return "MiniCPMForCausalLM" + + @classmethod + def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): + + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_model_parallel_size=config.tensor_model_parallel_size, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # MiniCPM4 has no QKV bias, only need weight splits + if config.num_key_value_heads % config.tensor_model_parallel_size == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings From 59758a588d200a8094f11a8380ee77f060be2ef2 Mon Sep 17 00:00:00 2001 From: cloudforge1 Date: Tue, 24 Mar 2026 02:31:02 +0800 Subject: [PATCH 2/3] add MiniCPM4 usage documentation and supported models entry --- docs/best_practices/MiniCPM4-8B.md | 104 +++++++++++++++++++++++++++++ docs/supported_models.md | 1 + 2 files changed, 105 insertions(+) create mode 100644 docs/best_practices/MiniCPM4-8B.md diff --git a/docs/best_practices/MiniCPM4-8B.md b/docs/best_practices/MiniCPM4-8B.md new file mode 100644 index 00000000000..6d758bb1c45 --- /dev/null +++ b/docs/best_practices/MiniCPM4-8B.md @@ -0,0 +1,104 @@ +# MiniCPM4/4.1-8B + +## I. Environment Preparation + +### 1.1 Hardware Requirements +The minimum number of GPUs required to deploy `MiniCPM4.1-8B` on the following hardware for each quantization is as follows: + +| | BF16 | WINT8 | WINT4 | FP8 | +|-----|-----|-----|-----|-----| +|H800 80GB| 1 | 1 | 1 | 1 | +|A800 80GB| 1 | 1 | 1 | / | +|H20 96GB| 1 | 1 | 1 | 1 | +|L20 48GB| 1 | 1 | 1 | 1 | +|A30 40GB| / | 1 | 1 | / | +|A10 24GB| / | 1 | 1 | / | +|V100 32GB| / | 1 | 1 | / | + +**Tips:** +1. MiniCPM4.1-8B is a dense 8B model — a single GPU is sufficient for inference at all supported quantization levels. +2. For hardware not listed in the table, you can estimate whether it can be deployed based on the GPU memory. BF16 requires ~16GB, WINT8 ~8GB, WINT4 ~4GB. + +### 1.2 Install FastDeploy +- Installation: For detail, please refer to [FastDeploy Installation](../get_started/installation/README.md). +- Model Download: For detail, please refer to [Supported Models](../supported_models.md). + +## II. How to Use + +### 2.1 Basic: Launching the Service + +**Example 1:** Deploying MiniCPM4.1-8B with WINT4 quantization + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model openbmb/MiniCPM4.1-8B \ + --tensor-parallel-size 1 \ + --quantization wint4 \ + --max-model-len 32768 \ + --max-num-seqs 128 +``` + +**Example 2:** Deploying MiniCPM4.1-8B with BF16 (full precision) + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model openbmb/MiniCPM4.1-8B \ + --tensor-parallel-size 1 \ + --max-model-len 32768 \ + --max-num-seqs 64 +``` + +- `--quantization`: Quantization strategy. Options: `wint8` / `wint4` / `block_wise_fp8` (Hopper required). Omit for BF16. +- `--max-model-len`: Maximum number of tokens for the deployed service. MiniCPM4.1 supports up to 65,536 tokens with LongRoPE, but larger values increase GPU memory usage. + +For more parameter meanings and default settings, see [FastDeploy Parameter Documentation](../parameters.md). + +### 2.2 Sending Requests + +After the service starts, send requests via the OpenAI-compatible API: + +```bash +curl http://localhost:8180/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openbmb/MiniCPM4.1-8B", + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "max_tokens": 512 + }' +``` + +### 2.3 Advanced: How to Get Better Performance + +#### 2.3.1 Correctly Set Parameters That Match the Application Scenario +Evaluate average input length, average output length, and maximum context length. +- Set `--max-model-len` according to the maximum context length. For example, if the average input length is 1000 and the output length is 4000, then it is recommended to set it to 8192. + +#### 2.3.2 Prefix Caching +**Idea:** The core idea of Prefix Caching is to avoid repeated calculations by caching the intermediate calculation results of the input sequence (KV Cache), thereby speeding up the response speed of multiple requests with the same prefix. For details, refer to [prefix-cache](../features/prefix_caching.md). + +**How to enable:** +Since version 2.2 (including the develop branch), Prefix Caching has been enabled by default. + +#### 2.3.3 Chunked Prefill +**Idea:** This strategy splits the prefill stage request into small-scale sub-chunks, and executes them in batches mixed with the decode request. For details, please refer to [Chunked Prefill](../features/chunked_prefill.md). + +**How to enable:** +Since version 2.2 (including the develop branch), Chunked Prefill has been enabled by default. + +#### 2.3.4 CudaGraph +**Idea:** CUDAGraph encapsulates GPU computing and memory operations into a re-executable graph, reducing CPU-GPU communication overhead and improving computing performance. + +**How to enable:** +CUDAGraph has been enabled by default since version 2.3. + +## Model Architecture Notes + +MiniCPM4.1-8B uses μP (Maximal Update Parametrization) for training stability: +- **Embedding scaling**: Output scaled by `scale_emb` (12×) +- **Residual scaling**: Connections scaled by `scale_depth / √num_hidden_layers` +- **LM head scaling**: Input scaled by `hidden_size / dim_model_base` + +These scaling factors are automatically read from the model's `config.json` and require no user configuration. + +## FAQ +If you encounter any problems during use, please refer to [FAQ](./FAQ.md). diff --git a/docs/supported_models.md b/docs/supported_models.md index b0684affc11..caf1cf54f82 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -40,6 +40,7 @@ These models accept text input. |⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.| |⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.| |⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
 [最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.| +|MINICPM4|BF16/WINT8/WINT4/FP8|[openbmb/MiniCPM4.1-8B](./best_practices/MiniCPM4-8B.md);
openbmb/MiniCPM4-8B| ## Multimodal Language Models From 1cb866132a69481f73aa35eb9c308cc6889403a7 Mon Sep 17 00:00:00 2001 From: cloudforge1 Date: Tue, 24 Mar 2026 03:41:35 +0800 Subject: [PATCH 3/3] =?UTF-8?q?add=20CPU-side=20unit=20tests=20for=20MiniC?= =?UTF-8?q?PM4=20=CE=BCP=20scaling,=20weight=20mapping,=20and=20registrati?= =?UTF-8?q?on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/model_executor/test_minicpm4.py | 320 ++++++++++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 tests/model_executor/test_minicpm4.py diff --git a/tests/model_executor/test_minicpm4.py b/tests/model_executor/test_minicpm4.py new file mode 100644 index 00000000000..d203bbf054c --- /dev/null +++ b/tests/model_executor/test_minicpm4.py @@ -0,0 +1,320 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math + +import numpy as np +import paddle +import pytest + +# ── μP scaling math tests (pure computation, no FD imports needed) ────────── + + +class TestMuPScaling: + """Test μP (Maximal Update Parametrization) scaling factors. + + MiniCPM4 applies three scaling sites: + 1. Embedding: output *= scale_emb + 2. Residual: hidden_states *= scale_depth / sqrt(num_hidden_layers) + 3. LM head: hidden_states /= (hidden_size / dim_model_base) + """ + + # Reference config values from openbmb/MiniCPM4.1-8B + SCALE_EMB = 12 + SCALE_DEPTH = 1.4 + NUM_HIDDEN_LAYERS = 32 + HIDDEN_SIZE = 4096 + DIM_MODEL_BASE = 256 + + def test_embedding_scaling(self): + """Embedding output scaled by scale_emb.""" + x = paddle.ones([2, 8, self.HIDDEN_SIZE], dtype="float32") + scaled = x * self.SCALE_EMB + np.testing.assert_allclose( + scaled.numpy(), + np.full([2, 8, self.HIDDEN_SIZE], 12.0, dtype="float32"), + ) + + def test_residual_scaling_value(self): + """Residual scale = scale_depth / sqrt(num_hidden_layers).""" + expected = self.SCALE_DEPTH / math.sqrt(self.NUM_HIDDEN_LAYERS) + assert abs(expected - 0.24748737341529164) < 1e-10 + + def test_residual_scaling_applied(self): + """Hidden states scaled by residual_scale before residual add.""" + residual_scale = self.SCALE_DEPTH / math.sqrt(self.NUM_HIDDEN_LAYERS) + x = paddle.full([4, self.HIDDEN_SIZE], 2.0, dtype="float32") + scaled = x * residual_scale + np.testing.assert_allclose( + scaled.numpy(), + np.full([4, self.HIDDEN_SIZE], 2.0 * residual_scale, dtype="float32"), + rtol=1e-6, + ) + + def test_lm_head_scaling(self): + """LM head input divided by hidden_size / dim_model_base.""" + lm_head_scale = self.HIDDEN_SIZE / self.DIM_MODEL_BASE + assert lm_head_scale == 16.0 + + x = paddle.full([4, self.HIDDEN_SIZE], 32.0, dtype="float32") + scaled = x / lm_head_scale + np.testing.assert_allclose( + scaled.numpy(), + np.full([4, self.HIDDEN_SIZE], 2.0, dtype="float32"), + ) + + def test_lm_head_scale_fallback(self): + """When dim_model_base is None or 0, lm_head_scale defaults to 1.0.""" + for dim_model_base in [None, 0]: + if dim_model_base is not None and dim_model_base > 0: + scale = self.HIDDEN_SIZE / dim_model_base + else: + scale = 1.0 + assert scale == 1.0 + + def test_residual_scale_depth_default(self): + """When scale_depth not in config, defaults to 1.0 → no scaling.""" + scale_depth = 1.0 # default + residual_scale = scale_depth / math.sqrt(self.NUM_HIDDEN_LAYERS) + x = paddle.full([4, self.HIDDEN_SIZE], 1.0, dtype="float32") + scaled = x * residual_scale + expected = 1.0 / math.sqrt(32) + np.testing.assert_allclose(scaled.numpy().mean(), expected, rtol=1e-6) + + +# ── Weight mapping tests ──────────────────────────────────────────────────── + + +class TestWeightMapping: + """Test HuggingFace → FastDeploy weight name mapping.""" + + STACKED_PARAMS = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("up_gate_proj", "gate_proj", "gate"), + ("up_gate_proj", "up_proj", "up"), + ("embed_tokens.embeddings", "embed_tokens", None), + ("lm_head.linear", "lm_head", None), + ] + + def test_hf_prefix_rename(self): + """HF 'model.' prefix maps to FD 'minicpm4.' prefix.""" + hf_names = [ + "model.layers.0.self_attn.q_proj.weight", + "model.embed_tokens.weight", + "model.norm.weight", + "lm_head.weight", # no model. prefix + ] + for name in hf_names: + fd_name = name.replace("model.", "minicpm4.") + if name.startswith("model."): + assert fd_name.startswith("minicpm4.") + else: + assert fd_name == name # lm_head unchanged + + def test_qkv_stacking(self): + """q_proj, k_proj, v_proj map to qkv_proj with correct shard_id.""" + qkv_map = {wn: (pn, sid) for pn, wn, sid in self.STACKED_PARAMS if "proj" in wn and sid in ("q", "k", "v")} + assert qkv_map["q_proj"] == ("qkv_proj", "q") + assert qkv_map["k_proj"] == ("qkv_proj", "k") + assert qkv_map["v_proj"] == ("qkv_proj", "v") + + def test_gate_up_stacking(self): + """gate_proj, up_proj map to up_gate_proj.""" + gu_map = {wn: (pn, sid) for pn, wn, sid in self.STACKED_PARAMS if sid in ("gate", "up")} + assert gu_map["gate_proj"] == ("up_gate_proj", "gate") + assert gu_map["up_proj"] == ("up_gate_proj", "up") + + def test_embed_and_lm_head_rename(self): + """embed_tokens → embed_tokens.embeddings, lm_head → lm_head.linear.""" + rename_map = {wn: pn for pn, wn, sid in self.STACKED_PARAMS if sid is None} + assert rename_map["embed_tokens"] == "embed_tokens.embeddings" + assert rename_map["lm_head"] == "lm_head.linear" + + def test_weight_name_replacement(self): + """Full pipeline: HF name → prefix rename → stacked param rename.""" + hf_name = "model.layers.5.self_attn.q_proj.weight" + # Step 1: prefix rename + fd_name = hf_name.replace("model.", "minicpm4.") + assert fd_name == "minicpm4.layers.5.self_attn.q_proj.weight" + + # Step 2: stacked param rename + for param_name, weight_name, shard_id in self.STACKED_PARAMS: + if weight_name in fd_name: + model_param_name = fd_name.replace(weight_name, param_name) + assert model_param_name == "minicpm4.layers.5.self_attn.qkv_proj.weight" + assert shard_id == "q" + break + + +# ── Registration & config tests ───────────────────────────────────────────── + + +class TestRegistration: + """Test model architecture registration string.""" + + def test_architecture_string(self): + """MiniCPM4 registers as 'MiniCPMForCausalLM' (matching HF config).""" + # The decorator uses architecture="MiniCPMForCausalLM" + # Verify by reading the source file directly + import ast + import os + + model_file = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "fastdeploy", + "model_executor", + "models", + "minicpm4.py", + ) + with open(model_file) as f: + tree = ast.parse(f.read()) + + # Find the register_model_class decorator + found_arch = None + for node in ast.walk(tree): + if isinstance(node, ast.Call): + for kw in node.keywords: + if kw.arg == "architecture" and isinstance(kw.value, ast.Constant): + found_arch = kw.value.value + break + assert found_arch == "MiniCPMForCausalLM" + + def test_module_name_is_minicpm4(self): + """The module_name in registration is 'minicpm4'.""" + import ast + import os + + model_file = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "fastdeploy", + "model_executor", + "models", + "minicpm4.py", + ) + with open(model_file) as f: + tree = ast.parse(f.read()) + + found_module = None + for node in ast.walk(tree): + if isinstance(node, ast.Call): + for kw in node.keywords: + if kw.arg == "module_name" and isinstance(kw.value, ast.Constant): + found_module = kw.value.value + break + assert found_module == "minicpm4" + + def test_model_classes_exist(self): + """Source file defines all 6 expected classes.""" + import ast + import os + + model_file = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "fastdeploy", + "model_executor", + "models", + "minicpm4.py", + ) + with open(model_file) as f: + tree = ast.parse(f.read()) + + class_names = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + expected = [ + "MiniCPM4MLP", + "MiniCPM4Attention", + "MiniCPM4DecoderLayer", + "MiniCPM4Model", + "MiniCPM4ForCausalLM", + "MiniCPM4PretrainedModel", + ] + for name in expected: + assert name in class_names, f"Missing class: {name}" + + def test_no_qkv_bias(self): + """MiniCPM4Attention uses with_bias=False (unlike Qwen2).""" + import ast + import os + + model_file = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "fastdeploy", + "model_executor", + "models", + "minicpm4.py", + ) + with open(model_file) as f: + source = f.read() + tree = ast.parse(source) + + # Find QKVParallelLinear call inside MiniCPM4Attention + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "MiniCPM4Attention": + for child in ast.walk(node): + if isinstance(child, ast.Call): + for kw in child.keywords: + if kw.arg == "with_bias" and isinstance(kw.value, ast.Constant): + assert kw.value.value is False, "QKV should have with_bias=False" + return + pytest.fail("with_bias keyword not found in MiniCPM4Attention.QKVParallelLinear") + + +# ── compute_logits logic test ─────────────────────────────────────────────── + + +class TestComputeLogits: + """Test the compute_logits μP scaling and vocab masking logic.""" + + def test_lm_head_scaling_and_vocab_mask(self): + """compute_logits divides by lm_head_scale and masks extended vocab.""" + hidden_size = 128 + ori_vocab_size = 100 + vocab_size = 128 # extended + lm_head_scale = 16.0 + + # Simulate hidden_states + hidden_states = paddle.full([4, hidden_size], 32.0, dtype="float32") + + # Step 1: μP scaling + scaled = hidden_states / lm_head_scale + np.testing.assert_allclose(scaled.numpy().mean(), 2.0, rtol=1e-6) + + # Step 2: Simulate lm_head projection (linear: hidden→vocab) + weight = paddle.ones([vocab_size, hidden_size], dtype="float32") + logits = paddle.matmul(scaled, weight.T) + logits = logits.astype(paddle.float32) + + # Step 3: Mask extended vocab positions + logits[:, ori_vocab_size:] = -float("inf") + + assert logits.shape == [4, vocab_size] + # Valid vocab positions should be finite + assert paddle.isfinite(logits[:, :ori_vocab_size]).all() + # Extended positions should be -inf + assert (logits[:, ori_vocab_size:] == -float("inf")).all() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])