diff --git a/docs/source/Customization/Custom-dataset.md b/docs/source/Customization/Custom-dataset.md index 84fd506231..09b66666eb 100644 --- a/docs/source/Customization/Custom-dataset.md +++ b/docs/source/Customization/Custom-dataset.md @@ -29,7 +29,7 @@ query-response格式: ```jsonl {"system": "", "query": "", "response": "", "history": [["", ""]]} ``` -注意:以下字段会自动转成对应的system、query、response字段。 +注意:以下字段会自动转成对应的system、query、response字段。(solution字段会保留) - system: 'system', 'system_prompt'. - query: 'query', 'prompt', 'input', 'instruction', 'question', 'problem'. - response: 'response', 'answer', 'output', 'targets', 'target', 'answer_key', 'answers', 'solution', 'text', 'completion', 'content'. diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index af857745ca..791a44a4dc 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -397,6 +397,7 @@ Vera使用`target_modules`、`target_regex`、`modules_to_save`三个参数, - vllm_disable_custom_all_reduce: 禁用自定义的 all-reduce 内核,回退到 NCCL。为了稳定性,默认为`True`。 - vllm_enforce_eager: vllm使用pytorch eager模式还是建立cuda graph,默认为`False`。设置为True可以节约显存,但会影响效率。 - vllm_mm_processor_cache_gb: 多模态处理器缓存大小(GiB),用于缓存已处理的多模态输入(如图像、视频)避免重复处理。默认为`4`。设置为`0`可禁用缓存但会降低性能(不推荐)。仅对多模态模型生效。 +- vllm_speculative_config: 推测解码配置,传入json字符串。默认为None。 - vllm_disable_cascade_attn: 是否强制关闭V1引擎的cascade attention实现以防止潜在数值误差,默认为False,由vLLM内部逻辑决定是否使用。 - 🔥vllm_limit_mm_per_prompt: 控制vllm使用多图,默认为`None`。例如传入`--vllm_limit_mm_per_prompt '{"image": 5, "video": 2}'`。 - vllm_max_lora_rank: 默认为`16`。vllm对于lora支持的参数。 diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md index 5ed4262059..dce4e6a6af 100644 --- a/docs/source/Instruction/Supported-models-and-datasets.md +++ b/docs/source/Instruction/Supported-models-and-datasets.md @@ -1137,6 +1137,7 @@ |-|default|huge dataset|-|pretrain, quality|[allenai/c4](https://huggingface.co/datasets/allenai/c4)| |[bespokelabs/Bespoke-Stratos-17k](https://modelscope.cn/datasets/bespokelabs/Bespoke-Stratos-17k)|default|16710|480.7±236.1, min=266, max=3556|chat, sft, cot, r1|[bespokelabs/Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)| |-|default|huge dataset|-|pretrain, quality|[cerebras/SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B)| +|[clip-benchmark/wds_voc2007_multilabel](https://modelscope.cn/datasets/clip-benchmark/wds_voc2007_multilabel)|default|2501|112.0±0.0, min=112, max=112|multilabel, multi-modal|[clip-benchmark/wds_voc2007_multilabel](https://huggingface.co/datasets/clip-benchmark/wds_voc2007_multilabel)| |[codefuse-ai/CodeExercise-Python-27k](https://modelscope.cn/datasets/codefuse-ai/CodeExercise-Python-27k)|default|27224|337.3±154.2, min=90, max=2826|chat, coding, 🔥|-| |[codefuse-ai/Evol-instruction-66k](https://modelscope.cn/datasets/codefuse-ai/Evol-instruction-66k)|default|66862|440.1±208.4, min=46, max=2661|chat, coding, 🔥|-| |[damo/MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench)|default
mini|638149|859.2±460.1, min=38, max=3479|chat, agent, multi-round|-| @@ -1164,6 +1165,7 @@ |[modelscope/clue](https://modelscope.cn/datasets/modelscope/clue)|cmnli|391783|81.6±16.0, min=54, max=157|text-generation, classification|[clue](https://huggingface.co/datasets/clue)| |[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption)|train
validation|454617|389.6±68.4, min=70, max=587|chat, multi-modal, vision, 🔥|-| |[modelscope/gsm8k](https://modelscope.cn/datasets/modelscope/gsm8k)|main|7473|88.6±21.6, min=41, max=241|qa, math|-| +|[open-r1/DAPO-Math-17k-Processed](https://modelscope.cn/datasets/open-r1/DAPO-Math-17k-Processed)|all|17398|122.3±65.2, min=41, max=1517|math, rlvr|[open-r1/DAPO-Math-17k-Processed](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed)| |[open-r1/verifiable-coding-problems-python](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python)|default|35735|559.0±255.2, min=74, max=6191|grpo, code|[open-r1/verifiable-coding-problems-python](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python)| |[open-r1/verifiable-coding-problems-python-10k](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k)|default|1800|581.6±233.4, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k)| |[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)|default|1574|575.7±234.3, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)| @@ -1193,7 +1195,7 @@ |[swift/RedPajama-Data-V2](https://modelscope.cn/datasets/swift/RedPajama-Data-V2)|default|huge dataset|-|pretrain, quality|[togethercomputer/RedPajama-Data-V2](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)| |[swift/ScienceQA](https://modelscope.cn/datasets/swift/ScienceQA)|default|16967|101.7±55.8, min=32, max=620|multi-modal, science, vqa, quality|[derek-thomas/ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA)| |[swift/SlimOrca](https://modelscope.cn/datasets/swift/SlimOrca)|default|517982|405.5±442.1, min=47, max=8312|quality, en|[Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)| -|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| +|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb
rerank|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| |[swift/ToolBench](https://modelscope.cn/datasets/swift/ToolBench)|default|124345|2251.7±1039.8, min=641, max=9451|chat, agent, multi-round|-| |[swift/VQAv2](https://modelscope.cn/datasets/swift/VQAv2)|default|huge dataset|-|en, vqa, quality|[HuggingFaceM4/VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)| |[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT)|Generic
Temporal
Consistency|3206|87.4±48.3, min=31, max=398|chat, multi-modal, video, 🔥|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)| diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index ffe22790e1..124c71cff1 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -218,6 +218,11 @@ - qk_head_dim: QK 投影中 head 的维度。 `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`。默认为None,自动从config.json读取。 - qk_pos_emb_head_dim: QK 投影中位置嵌入的维度。默认为None,自动从config.json读取。 +**MTP参数** +- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。(需要"megatron-core>=0.14") + - 注意:mtp_num_layers的值,将不自动从config.json获取,需手动设置。你可以参考config.json中的`num_nextn_predict_layers`字段填写该值。使用mcore-bridge时,将优先从safetensors文件中加载MTP权重,若无法找到,则进行随机初始化。 +- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 + **Tuner参数**: - train_type: 可选为'lora'和'full'。默认为'full'。 - 🔥freeze_llm: 该参数只对多模态模型生效,可用于全参数训练和LoRA训练,但会产生不同的效果。若是全参数训练,将freeze_llm设置为True会将LLM部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_llm设置为True将会取消在LLM部分添加LoRA模块。该参数默认为False。 diff --git a/docs/source_en/Customization/Custom-dataset.md b/docs/source_en/Customization/Custom-dataset.md index 69d49e4ed0..cae457d9eb 100644 --- a/docs/source_en/Customization/Custom-dataset.md +++ b/docs/source_en/Customization/Custom-dataset.md @@ -30,7 +30,7 @@ Query-Response format: ```jsonl {"system": "", "query": "", "response": "", "history": [["", ""]]} ``` -Note: The following fields will be automatically converted to the corresponding system, query, and response fields. +Note: The following fields will be automatically converted to the corresponding system, query, and response fields. (The 'solution' field will be retained) - system: 'system', 'system_prompt'. - query: 'query', 'prompt', 'input', 'instruction', 'question', 'problem'. - response: 'response', 'answer', 'output', 'targets', 'target', 'answer_key', 'answers', 'solution', 'text', 'completion', 'content'. diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 243c4d5d0f..b7737b54bf 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -404,6 +404,7 @@ Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai - vllm_disable_custom_all_reduce: Disables the custom all-reduce kernel and falls back to NCCL. For stability, the default is `True`. - vllm_enforce_eager: Determines whether vllm uses PyTorch eager mode or constructs a CUDA graph, default is `False`. Setting it to True can save memory but may affect efficiency. - vllm_mm_processor_cache_gb: The size (in GiB) of the multimodal processor cache, used to store processed multimodal inputs (e.g., images, videos) to avoid redundant processing. Default is 4. Setting it to 0 disables the cache but may degrade performance (not recommended). This option takes effect only for multimodal models. +- vllm_speculative_config: Speculative decoding configuration, passed as a JSON string. Default: None. - vllm_disable_cascade_attn: Whether to forcibly disable the V1 engine’s cascade-attention implementation to avoid potential numerical issues. Defaults to False; vLLM’s internal heuristics determine whether cascade attention is actually used. 
- 🔥vllm_limit_mm_per_prompt: Controls the use of multiple media in vllm, default is `None`. For example, you can pass in `--vllm_limit_mm_per_prompt '{"image": 5, "video": 2}'`. - vllm_max_lora_rank: Default is `16`. This is the parameter supported by vllm for lora. diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 31e61a42f1..7b66068b8d 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -1138,6 +1138,7 @@ The table below introduces information about the datasets integrated with ms-swi |-|default|huge dataset|-|pretrain, quality|[allenai/c4](https://huggingface.co/datasets/allenai/c4)| |[bespokelabs/Bespoke-Stratos-17k](https://modelscope.cn/datasets/bespokelabs/Bespoke-Stratos-17k)|default|16710|480.7±236.1, min=266, max=3556|chat, sft, cot, r1|[bespokelabs/Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)| |-|default|huge dataset|-|pretrain, quality|[cerebras/SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B)| +|[clip-benchmark/wds_voc2007_multilabel](https://modelscope.cn/datasets/clip-benchmark/wds_voc2007_multilabel)|default|2501|112.0±0.0, min=112, max=112|multilabel, multi-modal|[clip-benchmark/wds_voc2007_multilabel](https://huggingface.co/datasets/clip-benchmark/wds_voc2007_multilabel)| |[codefuse-ai/CodeExercise-Python-27k](https://modelscope.cn/datasets/codefuse-ai/CodeExercise-Python-27k)|default|27224|337.3±154.2, min=90, max=2826|chat, coding, 🔥|-| |[codefuse-ai/Evol-instruction-66k](https://modelscope.cn/datasets/codefuse-ai/Evol-instruction-66k)|default|66862|440.1±208.4, min=46, max=2661|chat, coding, 🔥|-| |[damo/MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench)|default
mini|638149|859.2±460.1, min=38, max=3479|chat, agent, multi-round|-| @@ -1165,6 +1166,7 @@ The table below introduces information about the datasets integrated with ms-swi |[modelscope/clue](https://modelscope.cn/datasets/modelscope/clue)|cmnli|391783|81.6±16.0, min=54, max=157|text-generation, classification|[clue](https://huggingface.co/datasets/clue)| |[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption)|train
validation|454617|389.6±68.4, min=70, max=587|chat, multi-modal, vision, 🔥|-| |[modelscope/gsm8k](https://modelscope.cn/datasets/modelscope/gsm8k)|main|7473|88.6±21.6, min=41, max=241|qa, math|-| +|[open-r1/DAPO-Math-17k-Processed](https://modelscope.cn/datasets/open-r1/DAPO-Math-17k-Processed)|all|17398|122.3±65.2, min=41, max=1517|math, rlvr|[open-r1/DAPO-Math-17k-Processed](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed)| |[open-r1/verifiable-coding-problems-python](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python)|default|35735|559.0±255.2, min=74, max=6191|grpo, code|[open-r1/verifiable-coding-problems-python](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python)| |[open-r1/verifiable-coding-problems-python-10k](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k)|default|1800|581.6±233.4, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k)| |[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)|default|1574|575.7±234.3, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)| @@ -1194,7 +1196,7 @@ The table below introduces information about the datasets integrated with ms-swi |[swift/RedPajama-Data-V2](https://modelscope.cn/datasets/swift/RedPajama-Data-V2)|default|huge dataset|-|pretrain, quality|[togethercomputer/RedPajama-Data-V2](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)| |[swift/ScienceQA](https://modelscope.cn/datasets/swift/ScienceQA)|default|16967|101.7±55.8, min=32, max=620|multi-modal, science, vqa, quality|[derek-thomas/ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA)| |[swift/SlimOrca](https://modelscope.cn/datasets/swift/SlimOrca)|default|517982|405.5±442.1, min=47, max=8312|quality, en|[Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)| -|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| +|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb
rerank|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| |[swift/ToolBench](https://modelscope.cn/datasets/swift/ToolBench)|default|124345|2251.7±1039.8, min=641, max=9451|chat, agent, multi-round|-| |[swift/VQAv2](https://modelscope.cn/datasets/swift/VQAv2)|default|huge dataset|-|en, vqa, quality|[HuggingFaceM4/VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)| |[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT)|Generic
Temporal
Consistency|3206|87.4±48.3, min=31, max=398|chat, multi-modal, video, 🔥|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)| diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 1c260f5405..c9e11cd8b9 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -231,6 +231,12 @@ For guidance on selecting parallelization strategies, please refer to the [Train - qk_head_dim: Dimension of the head in the QK projection. `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`. Default is None and will be automatically read from config.json. - qk_pos_emb_head_dim: Dimension of the position embedding in the QK projection. Default is None and will be automatically read from config.json. + +**MTP Parameters** +- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. (requires "megatron-core>=0.14") + - Note: The value of mtp_num_layers will not be automatically retrieved from config.json and must be set manually. You can refer to the `num_nextn_predict_layers` field in config.json to fill in this value. When using mcore-bridge, MTP weights will be loaded from safetensors files first. If not found, random initialization will be performed. +- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. + **Tuner Parameters**: - train_type: Options are `'lora'` and `'full'`. Default is `'full'`. 
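For reference, a minimal sketch of how `mtp_loss_scaling_factor` combines the per-depth MTP losses with the main language-modeling loss as described above; the loss values are made-up placeholders, not ms-swift code.

```python
# Sketch only: illustrates the scaling described for mtp_loss_scaling_factor.
mtp_num_layers = 2             # D sequential MTP modules
mtp_loss_scaling_factor = 0.1  # default

lm_loss = 1.80                 # main next-token prediction loss (placeholder)
mtp_losses = [2.10, 2.35]      # one loss per MTP depth (placeholders)

# Average over all depths, then scale; equivalently each depth is weighted by factor / D.
overall_mtp_loss = mtp_loss_scaling_factor * sum(mtp_losses) / mtp_num_layers
total_loss = lm_loss + overall_mtp_loss  # MTP acts as an additional training objective
print(total_loss)  # 1.80 + 0.1 * 2.225 = 2.0225
```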
diff --git a/examples/infer/sglang/mtp.sh b/examples/infer/sglang/mtp.sh new file mode 100644 index 0000000000..8582f43d94 --- /dev/null +++ b/examples/infer/sglang/mtp.sh @@ -0,0 +1,13 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift infer \ + --model ZhipuAI/GLM-4.5-Air \ + --sglang_tp_size 4 \ + --infer_backend sglang \ + --val_dataset AI-ModelScope/alpaca-gpt4-data-zh#100 \ + --sglang_context_length 8192 \ + --max_new_tokens 2048 \ + --sglang_mem_fraction_static 0.7 \ + --sglang_speculative_algorithm EAGLE \ + --sglang_speculative_eagle_topk 1 \ + --sglang_speculative_num_steps 3 \ + --sglang_speculative_num_draft_tokens 4 diff --git a/examples/infer/vllm/mtp.sh b/examples/infer/vllm/mtp.sh new file mode 100644 index 0000000000..7d3b6a58a9 --- /dev/null +++ b/examples/infer/vllm/mtp.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift infer \ + --model Qwen/Qwen3-Next-80B-A3B-Instruct \ + --vllm_tensor_parallel_size 4 \ + --infer_backend vllm \ + --vllm_max_model_len 8192 \ + --val_dataset AI-ModelScope/alpaca-gpt4-data-zh#100 \ + --vllm_speculative_config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \ + --vllm_gpu_memory_utilization 0.9 \ + --max_new_tokens 2048 diff --git a/examples/megatron/lora/glm4_5_106b.sh b/examples/megatron/lora/glm4_5_106b.sh index 69783aed4a..b7ecb23609 100644 --- a/examples/megatron/lora/glm4_5_106b.sh +++ b/examples/megatron/lora/glm4_5_106b.sh @@ -1,10 +1,13 @@ -# thinking -> non-thinking +# demo: thinking -> non-thinking # 4 * 70GiB; 40s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ megatron sft \ - --load GLM-4.5-Air-mcore \ + --model ZhipuAI/GLM-4.5-Air \ + --load_safetensors true \ + --save_safetensors true \ + --mtp_num_layers 1 \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' \ --load_from_cache_file true \ --train_type lora \ diff --git a/examples/megatron/lora/qwen3_235b.sh b/examples/megatron/lora/qwen3_235b.sh index fc4bada207..01a48f3fb9 100644 --- a/examples/megatron/lora/qwen3_235b.sh +++ b/examples/megatron/lora/qwen3_235b.sh @@ -5,9 +5,12 @@ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=8 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ megatron sft \ - --load Qwen3-235B-A22B-Instruct-2507-mcore \ + --model Qwen/Qwen3-235B-A22B-Instruct-2507 \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT#2000' \ 'swift/self-cognition#1000' \ + --load_safetensors true \ + --save_safetensors true \ + --merge_lora false \ --load_from_cache_file true \ --train_type lora \ --lora_rank 8 \ diff --git a/examples/models/qwen3_next/mtp.sh b/examples/models/qwen3_next/mtp.sh new file mode 100644 index 0000000000..a310f1a6d5 --- /dev/null +++ b/examples/models/qwen3_next/mtp.sh @@ -0,0 +1,57 @@ +# 8 * 60GiB, 10s/it + +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +megatron sft \ + --model Qwen/Qwen3-Next-80B-A3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --mtp_num_layers 1 \ + --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT#2000' \ + 'swift/self-cognition#1000' \ + --load_from_cache_file true \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --expert_model_parallel_size 4 \ + --moe_permute_fusion true \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 1e-6 \ + --micro_batch_size 2 \ + --global_batch_size 16 \ + --recompute_granularity full \ + 
--recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --save megatron_output/Qwen3-Next-80B-A3B-Instruct \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 2048 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --model_author swift \ + --model_name swift-robot + + +# CUDA_VISIBLE_DEVICES=0,1,2,3 \ +# swift infer \ +# --model megatron_output/Qwen3-Next-80B-A3B-Instruct/vx-xxx/checkpoint-xxx \ +# --vllm_tensor_parallel_size 4 \ +# --infer_backend vllm \ +# --vllm_max_model_len 8192 \ +# --val_dataset AI-ModelScope/alpaca-gpt4-data-zh#100 \ +# --vllm_gpu_memory_utilization 0.9 \ +# --vllm_speculative_config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \ +# --max_new_tokens 2048 diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index cbf746d4dc..519c46f971 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -61,6 +61,12 @@ class SglangArguments: sglang_kv_cache_dtype: str = 'auto' sglang_enable_dp_attention: bool = False sglang_disable_custom_all_reduce: bool = True + # speculative decoding + # e.g. EAGLE, EAGLE3, NEXTN + sglang_speculative_algorithm: Optional[str] = None + sglang_speculative_num_steps: Optional[int] = None + sglang_speculative_eagle_topk: Optional[int] = None + sglang_speculative_num_draft_tokens: Optional[int] = None def get_sglang_engine_kwargs(self): kwargs = { @@ -76,6 +82,10 @@ def get_sglang_engine_kwargs(self): 'kv_cache_dtype': self.sglang_kv_cache_dtype, 'enable_dp_attention': self.sglang_enable_dp_attention, 'disable_custom_all_reduce': self.sglang_disable_custom_all_reduce, + 'speculative_algorithm': self.sglang_speculative_algorithm, + 'speculative_num_steps': self.sglang_speculative_num_steps, + 'speculative_eagle_topk': self.sglang_speculative_eagle_topk, + 'speculative_num_draft_tokens': self.sglang_speculative_num_draft_tokens, } if self.task_type == 'embedding': kwargs['task_type'] = 'embedding' @@ -92,7 +102,7 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg ckpt_dir (Optional[str]): Directory to the checkpoint. Default is None. infer_backend (Literal): Backend to use for inference. Default is 'pt'. Allowed values are 'vllm', 'pt', 'lmdeploy'. - result_path (Optional[str]): Directory to store inference results. Default is None. + result_path (Optional[str]): Path to store inference results. Default is None. max_batch_size (int): Maximum batch size for the pt engine. Default is 1. val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None. reranker_use_activation (bool): reranker use activation after calculating. Default is True. 
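As a quick reference for the new `sglang_speculative_*` arguments above, this is the engine-kwargs mapping they produce in `get_sglang_engine_kwargs`; the values are taken from `examples/infer/sglang/mtp.sh` in this PR, and the snippet is a sketch rather than additional code.

```python
# Passing
#   --sglang_speculative_algorithm EAGLE \
#   --sglang_speculative_num_steps 3 \
#   --sglang_speculative_eagle_topk 1 \
#   --sglang_speculative_num_draft_tokens 4
# forwards the values one-to-one to the SGLang engine (other keys omitted):
kwargs = {
    'speculative_algorithm': 'EAGLE',
    'speculative_num_steps': 3,
    'speculative_eagle_topk': 1,
    'speculative_num_draft_tokens': 4,
}
```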
diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 18b626a505..cb0d2d56df 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -46,6 +46,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + speculative_config: Optional[Union[str, dict]] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -80,6 +81,7 @@ def __init__( disable_cascade_attn=disable_cascade_attn, load_format=load_format, mm_processor_cache_gb=mm_processor_cache_gb, + speculative_config=speculative_config, enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index 4e1c903d17..86c9583e40 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -32,6 +32,7 @@ def _post_init(self, template=None): self.max_model_len = self.model_info.max_model_len self.task_type = self.model_info.task_type self.config = self.model_info.config + self.max_tokens_offset = 0 if template is None: ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None)) logger.info('Create the default_template for the infer_engine') @@ -220,7 +221,7 @@ def set_default_max_tokens(self, request_config: RequestConfig, inputs: Dict[str max_model_len = 8192 logger.warning( 'The current model is unable to retrieve `max_model_len`. It is set to the default value of 8192.') - max_max_tokens = max_model_len - num_tokens + max_max_tokens = max_model_len - num_tokens + self.max_tokens_offset if max_tokens is None: request_config.max_tokens = max_max_tokens elif max_max_tokens < request_config.max_tokens: diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index dd78d0b651..37de0f845e 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -48,6 +48,10 @@ def __init__( kv_cache_dtype: str = 'auto', enable_dp_attention: bool = False, disable_custom_all_reduce: bool = True, + speculative_algorithm: Optional[str] = None, + speculative_num_steps: Optional[int] = None, + speculative_eagle_topk: Optional[int] = None, + speculative_num_draft_tokens: Optional[int] = None, log_level='error', engine_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, @@ -88,6 +92,10 @@ def __init__( kv_cache_dtype=kv_cache_dtype, enable_dp_attention=enable_dp_attention, disable_custom_all_reduce=disable_custom_all_reduce, + speculative_algorithm=speculative_algorithm, + speculative_num_steps=speculative_num_steps, + speculative_eagle_topk=speculative_eagle_topk, + speculative_num_draft_tokens=speculative_num_draft_tokens, log_level=log_level, skip_tokenizer_init=True, trust_remote_code=True, @@ -98,6 +106,8 @@ def __init__( self.server_args.is_embedding = True self.engine = sgl.Engine(server_args=self.server_args) self._load_generation_config() + if speculative_num_draft_tokens is not None: + self.max_tokens_offset = -speculative_num_draft_tokens def _load_generation_config(self) -> None: generation_config_path = os.path.join(self.model_dir, 'generation_config.json') diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 2dc37c0ee0..e5c927cbb4 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ 
b/swift/llm/infer/infer_engine/vllm_engine.py @@ -70,6 +70,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + speculative_config: Optional[Union[str, dict]] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -131,6 +132,7 @@ def __init__( task=task_type, disable_cascade_attn=disable_cascade_attn, mm_processor_cache_gb=mm_processor_cache_gb, + speculative_config=speculative_config, **engine_kwargs, ) context = nullcontext() @@ -172,6 +174,7 @@ def _prepare_engine_kwargs( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + speculative_config: Optional[Union[str, dict]] = None, **engine_kwargs, ) -> None: if task == 'embedding': @@ -202,7 +205,7 @@ def _prepare_engine_kwargs( 'The current version of vLLM does not support `limit_mm_per_prompt`. Please upgrade vLLM.') for key in [ 'enable_expert_parallel', 'enable_sleep_mode', 'disable_cascade_attn', 'load_format', - 'mm_processor_cache_gb' + 'mm_processor_cache_gb', 'speculative_config' ]: if key in parameters: if locals()[key] is not None: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index a15624e01b..2d34af793b 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -496,6 +496,10 @@ class MegatronArguments(ExtraMegatronArguments): qk_head_dim: Optional[int] = None qk_pos_emb_head_dim: Optional[int] = None + # mtp + mtp_num_layers: Optional[int] = None + mtp_loss_scaling_factor: float = 0.1 + # fp8 fp8_format: Literal['e4m3', 'hybrid'] = None fp8_recipe: Literal['tensorwise', 'delayed', 'mxfp8', 'blockwise'] = 'delayed' diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 30458c45a2..3a95952260 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -5,7 +5,8 @@ import subprocess import sys from contextlib import contextmanager -from copy import copy +from copy import copy, deepcopy +from functools import partial from typing import List, Optional, Tuple import peft @@ -388,6 +389,101 @@ def build_tokenizer(args): global_vars.build_tokenizer = build_tokenizer +def _patch_mtp(): + from megatron.core import InferenceParams + from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer + from megatron.core.packed_seq_params import PackedSeqParams + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor = None, + context_mask: torch.Tensor = None, + rotary_pos_emb: torch.Tensor = None, + rotary_pos_cos: torch.Tensor = None, + rotary_pos_sin: torch.Tensor = None, + attention_bias: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + sequence_len_offset: torch.Tensor = None, + embedding=None, + ): + """ + Execute the forward pass through the Multi-Token Prediction (MTP) layer. + + Args: + input_ids (Tensor): Input token IDs . + position_ids (Tensor): Positional IDs of the input tokens. + hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the + sequence length, b is the batch size, and h is the hidden size. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention, if applicable. 
+ context_mask (Tensor, optional): Mask for cross-attention context, if applicable. + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Tensor, optional): Cosine component of rotary positional embeddings. + rotary_pos_sin (Tensor, optional): Sine component of rotary positional embeddings. + sequence_len_offset (Tensor, optional): Offset for sequence length, if applicable. + embedding (Callable): The embedding module from gpt model to compute the decoder input. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + # TODO: Multimodal compatible + assert context is None, 'multi token prediction + cross attention is not yet supported.' + input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, + ) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if packed_seq: + assert not self.transformer_layer.self_attention.config.apply_rope_fusion + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + partial( + self._proj_and_transformer_layer, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ), + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + ) + else: + hidden_states = self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + return hidden_states, input_ids, position_ids + + MultiTokenPredictionLayer.forward = forward + + def _patch_peft_ModulesToSaveWrapper(): if version.parse(peft.__version__) >= version.parse('0.16'): from peft.utils import other as peft_module @@ -686,6 +782,7 @@ def _patch_megatron(): _patch_build_train_valid_test_datasets() _patch_mrope() _patch_megatron_tokenizer() + _patch_mtp() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index e4603352d4..5950df2722 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -504,6 +504,7 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): class Qwen3NextBridge(GPTBridge): + hf_mtp_prefix = 'mtp.layers' def _set_state_dict(self, mg_module, @@ -514,7 +515,7 @@ def _set_state_dict(self, *, offset: float = 0, is_expert: bool = False): - if 'layernorm' in mg_key or 'layer_norm_weight' in mg_key: + if 'layernorm' in mg_key or 'layer_norm_weight' in mg_key or 'enorm' in mg_key or 'hnorm' in mg_key: offset = 1 if to_mcore else -1 return super()._set_state_dict( 
mg_module, @@ -537,6 +538,15 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo hf_state_dict = super()._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore) return hf_state_dict + def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + hf_state_dict = self._remove_prefix(origin_hf_state_dict, 'mtp.') + for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'], + ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): + self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) + if not to_mcore: + origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.')) + register_megatron_model( MegatronModelMeta( diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 451925285d..f9da632687 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -26,6 +26,7 @@ # Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225 class GPTBridge: hf_layers_prefix = 'model.layers' + hf_mtp_prefix = 'model.layers' hf_embed_key = 'model.embed_tokens.weight' hf_final_layernorm_key = 'model.norm.weight' hf_lm_head_key = 'lm_head.weight' @@ -100,7 +101,9 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: # mla 'linear_q_proj', 'linear_q_up_proj', - 'linear_kv_up_proj' + 'linear_kv_up_proj', + # mtp + 'eh_proj', } if self.args.task_type == 'causal_lm': dim0_keys.add('output_layer') @@ -1034,6 +1037,18 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd else: yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} + + if not to_mcore or is_pp_last_stage and self.args.mtp_num_layers: + lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model + layer_idx = 0 + while layer_idx < self.args.mtp_num_layers: + res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) + layer_idx += 1 + if to_mcore: + yield + else: + yield from list(self._add_prefix(res, hf_prefix).items()) + hf_state_dict = {} if not to_mcore or is_pp_last_stage: hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: @@ -1042,6 +1057,48 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: + self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + + def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None + if self.hf_mtp_prefix == self.hf_layers_prefix: + hf_layer_idx = layer_idx + self.args.num_layers + else: + hf_layer_idx = layer_idx + hf_prefix = f'{hf_prefix}{hf_layer_idx}.' 
+ if to_mcore: + origin_hf_state_dict = hf_state_dict + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + if len(hf_state_dict) == 0: + logger.info_if( + f'MTP Layer {mtp_layer.layer_number} safetensors weights not found, ' + 'this part will be randomly initialized.', + cond=is_last_rank()) + for param in mtp_layer.parameters(): + if param.ndim == 2: + mtp_layer.config.init_method(param.data) + return {} + else: + origin_hf_state_dict = {} + hf_state_dict = {} + # Weights for shared parts are not stored, refer to GLM4.6 + # self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', + # to_mcore) + # self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) + self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) + transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer + hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore)) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + hf_state_dict.update(origin_hf_state_dict) + return hf_state_dict + def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False, adapter_name: str = 'default'): self._is_peft_format = is_peft_format self._adapter_name = adapter_name diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index b529a73337..f4c8f6ac39 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -1,10 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from collections import OrderedDict from contextlib import contextmanager +from copy import deepcopy from typing import Any, Dict, Literal, Optional, Tuple import megatron.core import torch +from megatron.core import parallel_state from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import TELinear @@ -13,6 +15,7 @@ from megatron.core.models.gpt import GPTModel as McoreGPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params @@ -140,8 +143,17 @@ def __init__( if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion: config.apply_rope_fusion = False - logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' + if self.attention_scaling != 1: + warning_string = 'attention_scaling' + else: + warning_string = 'mrope' + logger.warning(f'`apply_rope_fusion` does not support `{warning_string}`. 
' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') + if getattr(self, 'mtp', None) is not None: + for layer in self.mtp.layers: + attention = layer.transformer_layer.self_attention + attention.config = deepcopy(attention.config) + attention.config.apply_rope_fusion = False @contextmanager def _patch_apply_rotary_pos_emb(self): @@ -264,6 +276,8 @@ def forward( *, inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[torch.Tensor] = None, + # Mask labels to be compatible with thd & MTP + mtp_labels: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """Forward function of the GPT Model This function passes the input tensors @@ -304,52 +318,159 @@ def forward( args = get_args() labels = labels if args.task_type == 'causal_lm' else None - if mcore_013: - return self._postprocess( - hidden_states=hidden_states, + # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + mtp_labels=mtp_labels, + ) + + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + mtp_labels=None, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. 
+ """ + in_inference_mode = inference_context is not None and not self.training + if in_inference_mode: + assert runtime_gather_output, 'Inference must always gather TP logits' + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if mtp_in_postprocess: + hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, - labels=labels, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=self.mtp_process, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, + embedding=self.embedding, + **(extra_block_kwargs or {}), ) - else: - if not self.post_process: - return hidden_states - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer( - hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) - if has_config_logger_enabled(self.config): - payload = OrderedDict({ - 'input_ids': input_ids, - 'position_ids': position_ids, - 'attention_mask': attention_mask, - 'decoder_input': decoder_input, - 'logits': logits, - }) - log_config_to_disk(self.config, payload, prefix='input_and_logits') - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(labels, logits) - - return loss + + if not self.post_process: + return hidden_states + + if self.mtp_process: + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output + mtp_logits, _ = self.output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) + loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group) + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + sequence_parallel_override = False + if in_inference_mode and inference_context.materialize_only_last_token_logits: + if inference_context.is_static_batching(): + hidden_states = hidden_states[-1:, :, :] + else: + if self.output_layer.sequence_parallel: + # Perform the sequence parallel gather here instead of after the output layer + # because we need to slice the last token logits from the full view of the + # packed logits across all requests. + # TODO(ksanthanam): Make the equivalent change in the `MambaModel` code after + # merging in !3722. + hidden_states = gather_from_sequence_parallel_region(hidden_states, group=self.pg_collection.tp) + self.output_layer.sequence_parallel = False + sequence_parallel_override = True + + # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden + # state ([B, H]) → unsqueeze back to [1, B, H] + # (so that the output layer, which expects S×B×H, receives only the final token) + hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + + logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) + + # Restore sequence parallel execution to the output layer if necessary. + if sequence_parallel_override: + assert (in_inference_mode and inference_context.is_dynamic_batching() + and inference_context.materialize_only_last_token_logits) + self.output_layer.sequence_parallel = True + + if has_config_logger_enabled(self.config): + payload = OrderedDict({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'decoder_input': decoder_input, + 'logits': logits, + }) + log_config_to_disk(self.config, payload, prefix='input_and_logits') + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss def get_input_tensor(self): return self.decoder.input_tensor diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 8be3c36744..25a8244d01 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from contextlib import contextmanager +from typing import Optional import megatron.core import torch @@ -39,6 +40,8 @@ def __init__(self, args = get_args() self.megatron_model_meta = get_megatron_model_meta(args.hf_model_type) self.visual = None + if args.mtp_num_layers: + raise ValueError('MTP currently does not support multimodal models.') if pre_process and self.megatron_model_meta.visual_cls is not None: self.visual = self.megatron_model_meta.visual_cls(config) @@ -87,6 +90,8 @@ def forward( labels: torch.Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + *, + mtp_labels: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: if decoder_input is not None: @@ -108,6 +113,7 @@ def forward( labels=labels, inference_params=inference_params, packed_seq_params=packed_seq_params, + mtp_labels=mtp_labels, **kwargs, ) diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 997f53a231..84798aec49 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -117,10 +117,7 @@ def oom_observer(device, alloc, device_alloc, device_free): transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) else: if args.num_experts: - if mcore_013: - kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} - else: - kwargs = {} + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( config, use_transformer_engine=use_te, normalization=args.normalization, **kwargs) @@ -137,8 +134,9 @@ def oom_observer(device, alloc, device_alloc, device_free): transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config) else: transformer_layer_spec_for_mtp = transformer_layer_spec + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} mtp_block_spec = get_gpt_mtp_block_spec( - config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, vp_stage=vp_stage) + config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, **kwargs) if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: # qwen2_moe diff --git a/swift/megatron/train/__init__.py b/swift/megatron/train/__init__.py index 1b091bd4a3..537a1489de 100644 --- a/swift/megatron/train/__init__.py +++ b/swift/megatron/train/__init__.py @@ -1,4 +1,24 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .pt import megatron_pt_main -from .rlhf import megatron_rlhf_main -from .sft import megatron_sft_main +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .pt import megatron_pt_main + from .rlhf import megatron_rlhf_main + from .sft import megatron_sft_main +else: + _import_structure = { + 'pt': ['megatron_pt_main'], + 'rlhf': ['megatron_rlhf_main'], + 'sft': ['megatron_sft_main'], + } + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py index 80cf16fe22..1f5ce04967 100644 --- a/swift/megatron/trainers/__init__.py +++ b/swift/megatron/trainers/__init__.py @@ -1,6 +1,28 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .dpo_trainer import MegatronDPOTrainer -from .grpo_trainer import MegatronGRPOTrainer -from .kto_trainer import MegatronKTOTrainer -from .reward_trainer import MegatronRewardTrainer -from .trainer import MegatronTrainer +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .dpo_trainer import MegatronDPOTrainer + from .grpo_trainer import MegatronGRPOTrainer + from .kto_trainer import MegatronKTOTrainer + from .reward_trainer import MegatronRewardTrainer + from .trainer import MegatronTrainer +else: + _import_structure = { + 'dpo_trainer': ['MegatronDPOTrainer'], + 'grpo_trainer': ['MegatronGRPOTrainer'], + 'kto_trainer': ['MegatronKTOTrainer'], + 'reward_trainer': ['MegatronRewardTrainer'], + 'trainer': ['MegatronTrainer'], + } + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index f2e55224a2..3cd1b7d6ed 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -22,7 +22,7 @@ from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.utils import StragglerDetector from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, - get_wandb_writer, is_last_rank, one_logger_utils, pretrain, print_rank_0, + get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, print_rank_last, training) from megatron.training.checkpointing import load_checkpoint from megatron.training.theoretical_memory_usage import report_theoretical_memory @@ -81,6 +81,7 @@ def bridge(self): @contextmanager def _get_iters(self, train_dataset, val_dataset): origin_initialize_megatron = training.initialize_megatron + origin_validate_args = initialize.validate_args def initialize_megatron(*_args, **kwargs): res = origin_initialize_megatron(*_args, **kwargs) @@ -109,11 +110,16 @@ def initialize_megatron(*_args, **kwargs): logger.info(f'Setting args.eval_iters: {args.eval_iters}') return res + def validate_args(args, *_args, **kwargs): + return origin_validate_args(args, *_args, **kwargs) + training.initialize_megatron = initialize_megatron + initialize.validate_args = validate_args try: yield finally: training.initialize_megatron = origin_initialize_megatron + initialize.validate_args = origin_validate_args def new_cyclic_iter(self, iterable): args = get_args() @@ -790,6 +796,7 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear track_names.append('load_balancing_loss') if args.moe_z_loss_coeff is not None: track_names.append('z_loss') + track_moe_kwargs = {'mtp_num_layers': args.mtp_num_layers} if self.mcore_013 else {} track_moe_metrics( loss_scale=moe_loss_scale, iteration=iteration, @@ -800,7 +807,8 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear force_initialize=True, track_names=track_names, num_layers=args.num_layers, - moe_layer_freq=args.moe_layer_freq) + moe_layer_freq=args.moe_layer_freq, + **track_moe_kwargs) if args.mtp_num_layers is not None: mtp_loss_scale = 1 / get_num_microbatches() MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) @@ -1027,6 +1035,13 @@ def _prepare_batch(self, data, vp_stage, num_samples=None): if args.padding_free and text_position_ids is not None: 
batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) batch['packed_seq_params'].num_samples = num_samples + if args.mtp_num_layers and batch.get('labels') is not None: + cu_seqlens = batch['packed_seq_params'].cu_seqlens_q.clone() + mtp_labels = batch['labels'].clone() + for _ in range(args.mtp_num_layers): + mtp_labels[:, cu_seqlens[cu_seqlens < mtp_labels.shape[1]]] = -100 + cu_seqlens = cu_seqlens + 1 + batch['mtp_labels'] = mtp_labels # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 9d4c4c96b3..e5e376b014 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -57,7 +57,7 @@ def get_batch_on_this_tp_rank(data, vp_stage=None): else: is_pp_first_stage = mpu.is_pipeline_first_stage() is_pp_last_stage = mpu.is_pipeline_last_stage() - if not is_pp_first_stage: + if not args.mtp_num_layers and not is_pp_first_stage: batch['input_ids'] = None if not is_pp_last_stage: batch['labels'] = None @@ -110,7 +110,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: args = get_args() - keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] + keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale', 'mtp_labels'] if not args.is_multimodal: # Multimodal models will handle CP in input_embeds. keys.append('input_ids') diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 42f6afdcdd..a178d86841 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -206,12 +206,16 @@ class VllmArguments: vllm_reasoning_parser: Optional[str] = None vllm_disable_cascade_attn: bool = False vllm_mm_processor_cache_gb: Optional[float] = None + vllm_speculative_config: Optional[Union[dict, str]] = None vllm_engine_kwargs: Optional[Union[dict, str]] = None # rollout vllm_data_parallel_size: int = 1 def __post_init__(self): - self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) + if self.vllm_limit_mm_per_prompt is not None: + self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) + if self.vllm_speculative_config is not None: + self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config) self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs) def get_vllm_engine_kwargs(self): @@ -237,6 +241,7 @@ def get_vllm_engine_kwargs(self): 'reasoning_parser': self.vllm_reasoning_parser, 'disable_cascade_attn': self.vllm_disable_cascade_attn, 'mm_processor_cache_gb': self.vllm_mm_processor_cache_gb, + 'speculative_config': self.vllm_speculative_config, 'num_labels': self.num_labels, 'engine_kwargs': self.vllm_engine_kwargs, }
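For `--vllm_speculative_config`, the JSON string is parsed into a dict in `__post_init__` above and forwarded unchanged as the `speculative_config` engine kwarg. A minimal sketch of the parsing step, assuming `json_parse_to_dict` behaves like `json.loads` when given a JSON string:

```python
import json

# Same value as used in examples/infer/vllm/mtp.sh.
vllm_speculative_config = '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}'
speculative_config = json.loads(vllm_speculative_config)
print(speculative_config)  # {'method': 'qwen3_next_mtp', 'num_speculative_tokens': 2}
# The resulting dict is returned by get_vllm_engine_kwargs() and handed to vLLM as the
# `speculative_config` engine argument (see _prepare_engine_kwargs in vllm_engine.py above).
```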
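A toy illustration (assumed shapes, not ms-swift code) of the `mtp_labels` masking added to `_prepare_batch` in `swift/megatron/trainers/base.py` above: for packed (padding_free) batches, the first `mtp_num_layers` positions of every packed sample are set to -100 so that, once the MTP layer rolls the labels left, no sample is trained to predict tokens belonging to the next packed sample.

```python
import torch

# Two samples of length 4 packed into one row; cu_seqlens marks sample starts.
labels = torch.arange(1, 9).unsqueeze(0)   # [[1..4 | 5..8]]
cu_seqlens = torch.tensor([0, 4, 8])
mtp_num_layers = 1

mtp_labels = labels.clone()
for _ in range(mtp_num_layers):
    mtp_labels[:, cu_seqlens[cu_seqlens < mtp_labels.shape[1]]] = -100
    cu_seqlens = cu_seqlens + 1
print(mtp_labels)  # tensor([[-100, 2, 3, 4, -100, 6, 7, 8]])

# After the MTP layer rolls the labels left by one (roll_tensor with shifts=-1), the -100
# at each sample start lands on the previous sample's last position, so the extra-token
# target never crosses a packed-sample boundary:
print(torch.roll(mtp_labels, shifts=-1, dims=-1))  # tensor([[2, 3, 4, -100, 6, 7, 8, -100]])
```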