
Conversation

@Jintao-Huang
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Jintao-Huang, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Megatron framework by introducing support for Multi-Token Prediction (MTP) within the GPT model. The changes involve adding new configurable parameters for MTP layers and loss scaling, alongside a substantial refactoring of the GPT model's forward pass. This refactoring separates input preprocessing and output postprocessing into distinct methods, creating a more modular and maintainable architecture that seamlessly incorporates MTP as an additional training objective.

Highlights

  • Multi-Token Prediction (MTP) Support: Introduced new command-line parameters (mtp_num_layers and mtp_loss_scaling_factor) to enable and configure Multi-Token Prediction, which extends the prediction scope to multiple future tokens.
  • GPTModel Forward Pass Refactoring: The GPTModel's forward method has been refactored into _preprocess and _postprocess helper methods. This modularization improves code organization and facilitates the integration of new features like MTP.
  • MTP Integration in GPTModel: MTP logic is now integrated into the _postprocess step of the GPTModel, allowing it to function as an additional training objective by computing and scaling MTP losses; a minimal sketch of this loss combination follows the list.
  • Documentation Updates: Updated both Chinese and English documentation for command-line parameters to include the newly added MTP configuration options.
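
To make the loss scaling concrete, here is a minimal, self-contained sketch of the combination described above. The scaling factor and per-depth loss values are illustrative only and are not taken from this pull request's code.

import torch

# Hedged sketch of the MTP loss combination described in the highlights above.
mtp_loss_scaling_factor = 0.1
lm_loss = torch.tensor(2.30)                # standard next-token LM loss
mtp_losses = torch.tensor([2.45, 2.60])     # one loss per MTP depth (D = 2 here)
mtp_loss = mtp_loss_scaling_factor * mtp_losses.mean()   # average across depths, then scale
total_loss = lm_loss + mtp_loss             # MTP acts as an additional training objective
print(total_loss.item())                    # 2.5525
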
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Multi-Token Prediction (MTP) in Megatron, including new command-line arguments and documentation. The core change is a significant refactoring of the GPTModel.forward method to accommodate the MTP logic. My review focuses on the correctness of this refactoring and the consistency of the documentation. I've identified a critical breaking API change in the forward method signature that needs to be addressed. Additionally, there are some language inconsistencies in the documentation files that should be corrected for clarity.

self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,


critical

The signature of the forward method has been changed, making attention_mask a required positional argument. This is a breaking API change that will cause errors in other parts of the codebase that call this method without providing attention_mask or passing it as a keyword argument. For example, MultimodalGPTModel.forward in swift/megatron/model/mm_gpt_model.py will break. To maintain backward compatibility, attention_mask should remain an optional keyword argument.

Suggested change
attention_mask: torch.Tensor,
attention_mask: torch.Tensor = None,
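
As a standalone illustration (these functions are hypothetical stand-ins, not code from this repository), the behavioral difference between the two signatures is:

# Hypothetical stand-ins for the new (required) and suggested (optional) signatures.
def forward_required(input_ids, position_ids, attention_mask):
    return attention_mask

def forward_optional(input_ids, position_ids, attention_mask=None):
    return attention_mask

forward_optional([1, 2], [0, 1])     # works: attention_mask defaults to None
# forward_required([1, 2], [0, 1])   # TypeError: missing 1 required positional argument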

Comment on lines 218 to 219
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.


medium

The descriptions for the new MTP parameters are in English, but this documentation file is in Chinese. For consistency, please translate these descriptions into Chinese.

Suggested change
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。

Comment on lines 232 to 233
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。


medium

The descriptions for the new MTP parameters are in Chinese, but this documentation file is in English. For consistency, please translate these descriptions into English.

Suggested change
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.

@hjh0119 mentioned this pull request Nov 12, 2025
@Jintao-Huang
Collaborator Author

/gemini review


gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Megatron's Multi-Token Prediction (MTP), which can also serve as a basis for speculative decoding. The changes include adding new command-line arguments, documentation, and example scripts for MTP. Core logic changes involve modifying the model bridge for weight conversion of MTP layers, particularly for qwen3_next, and patching Megatron's MTP layer to handle packed sequences. The implementation also adds speculative decoding parameters for the sglang inference backend.

My review found a critical issue in the logic for handling packed sequences within the patched MTP layer, which could lead to incorrect behavior. I've provided a code suggestion to fix this. The rest of the changes, including documentation, examples, and argument handling, appear to be correct and well-implemented.

if packed_seq:
    packed_seq_params = deepcopy(packed_seq_params)
    tensor = packed_seq_params.cu_seqlens_q
    cu_seqlens = torch.concat([tensor.new_zeros(1, ), tensor[1:] - 1, tensor[-1:]])


critical

The logic for adjusting cu_seqlens for packed sequences in the MTP layer appears to be incorrect. The current implementation does not correctly shorten each sequence in the batch by one token for the next-token prediction task. Instead, it seems to shorten only the first sequence and adds an erroneous sequence of length 1 at the end, which will likely lead to errors or incorrect behavior when using MTP with packed sequences.

A correct implementation should adjust the cumulative sequence lengths to account for each sequence being shortened by one. This can be achieved by subtracting an increasing sequence [0, 1, 2, ...] from the original cu_seqlens tensor.

Suggested change
cu_seqlens = torch.concat([tensor.new_zeros(1, ), tensor[1:] - 1, tensor[-1:]])
cu_seqlens = tensor - torch.arange(len(tensor), device=tensor.device, dtype=tensor.dtype)
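
A quick numeric check of the suggested adjustment (the sequence lengths are made up for illustration, not taken from this PR): with three packed sequences of lengths 5, 4, and 3, cu_seqlens is [0, 5, 9, 12]; shortening each sequence by one token should give [0, 4, 7, 9].

import torch

# Hedged numeric check of the suggestion above.
tensor = torch.tensor([0, 5, 9, 12], dtype=torch.int32)   # cu_seqlens for lengths 5, 4, 3
cu_seqlens = tensor - torch.arange(len(tensor), device=tensor.device, dtype=tensor.dtype)
print(cu_seqlens)   # tensor([0, 4, 7, 9], dtype=torch.int32) -> lengths 4, 3, 2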

@Jintao-Huang
Collaborator Author

/gemini review


gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Megatron Multi-Token Prediction (MTP) and speculative decoding for vLLM/sglang backends. The changes are comprehensive, spanning documentation, example scripts, and core logic for both training and inference. The implementation appears solid, including new arguments, necessary monkey-patching for Megatron integration, and weight conversion logic. I have a couple of suggestions to enhance code robustness and consistency.

Comment on lines 447 to 483
        if packed_seq and not self.config.apply_rope_fusion:
            assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
            rotary_pos_emb = rotary_pos_emb[position_ids[0]]
        if self.config.recompute_granularity == 'full' and self.training:
            hidden_states = self._checkpointed_forward(
                partial(
                    self._proj_and_transformer_layer,
                    packed_seq_params=packed_seq_params,
                    sequence_len_offset=sequence_len_offset,
                ),
                hidden_states=hidden_states,
                decoder_input=decoder_input,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
            )
        else:
            hidden_states = self._proj_and_transformer_layer(
                hidden_states=hidden_states,
                decoder_input=decoder_input,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            )
        self.config.apply_rope_fusion = apply_rope_fusion


medium

Modifying self.config.apply_rope_fusion during the forward pass can be risky. If an exception occurs between setting it to False and restoring its original value, the config state will be left inconsistent, which could affect subsequent operations. It's safer to use a try...finally block to ensure the original value is always restored.

        try:
            if packed_seq and not self.config.apply_rope_fusion:
                assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
                rotary_pos_emb = rotary_pos_emb[position_ids[0]]
            if self.config.recompute_granularity == 'full' and self.training:
                hidden_states = self._checkpointed_forward(
                    partial(
                        self._proj_and_transformer_layer,
                        packed_seq_params=packed_seq_params,
                        sequence_len_offset=sequence_len_offset,
                    ),
                    hidden_states=hidden_states,
                    decoder_input=decoder_input,
                    attention_mask=attention_mask,
                    context=context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    rotary_pos_cos=rotary_pos_cos,
                    rotary_pos_sin=rotary_pos_sin,
                    attention_bias=attention_bias,
                    inference_params=inference_params,
                )
            else:
                hidden_states = self._proj_and_transformer_layer(
                    hidden_states=hidden_states,
                    decoder_input=decoder_input,
                    attention_mask=attention_mask,
                    context=context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    rotary_pos_cos=rotary_pos_cos,
                    rotary_pos_sin=rotary_pos_sin,
                    attention_bias=attention_bias,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    sequence_len_offset=sequence_len_offset,
                )
        finally:
            self.config.apply_rope_fusion = apply_rope_fusion
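
An equivalent, slightly more reusable variant of the same pattern is a small context manager that temporarily overrides the attribute and always restores it. This is only a hedged sketch of the idea with a hypothetical helper name, not a change proposed in this pull request.

from contextlib import contextmanager

# Hedged sketch: temporarily override a config attribute and always restore it,
# equivalent in effect to the try/finally shown above. The helper name is hypothetical.
@contextmanager
def temporarily_set(config, name, value):
    original = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        setattr(config, name, original)

# Usage (illustrative):
# with temporarily_set(self.config, 'apply_rope_fusion', False):
#     hidden_states = self._proj_and_transformer_layer(...)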

self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt)
if self.vllm_speculative_config is not None:
    self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config)
self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)


medium

For consistency with how vllm_limit_mm_per_prompt and vllm_speculative_config are handled, it's good practice to also check if vllm_engine_kwargs is not None before parsing it. While json_parse_to_dict might handle None gracefully, this change improves code clarity and robustness.

Suggested change
self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)
if self.vllm_engine_kwargs is not None:
    self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)
